Patch: type the POST /v1/images/generations endpoint with Pydantic models.

extensions/openai/script.py imports ImageGenerationRequest and
ImageGenerationResponse, declares the response model on the route, and makes
handle_image_generation read prompt, size, response_format and n from the
parsed request_data instead of the raw JSON body. extensions/openai/typing.py
defines the two models; their defaults ('1024x1024', 'url', 1) are the same
values the removed body.get() calls fell back to.

NOTE(review): line breaks and blank lines were reconstructed from the hunk
header counts. The blank context line after the `async def` line in the
second hunk is inferred (it is needed to reach the 17/16 line counts) and
should be confirmed against the original commit.
NOTE(review): `request` is no longer read in the visible part of the handler,
and size, response_format and n are accepted without any value or range
check. Confirm both are intended.

diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 6ae2e2af..a6710163 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -40,6 +40,8 @@ from .typing import (
     DecodeResponse,
     TranscriptionsRequest,
     TranscriptionsResponse,
+    ImageGenerationRequest,
+    ImageGenerationResponse,
     EmbeddingsRequest,
     EmbeddingsResponse,
     EncodeRequest,
@@ -215,17 +217,16 @@ async def handle_audio_transcription(request: Request, request_data: Transcripti
     return JSONResponse(content=transcription)
 
 
-@app.post('/v1/images/generations', dependencies=check_key)
-async def handle_image_generation(request: Request):
+@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
+async def handle_image_generation(request: Request, request_data: ImageGenerationRequest):
 
     if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
         raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
 
-    body = await request.json()
-    prompt = body['prompt']
-    size = body.get('size', '1024x1024')
-    response_format = body.get('response_format', 'url')  # or b64_json
-    n = body.get('n', 1)  # ignore the batch limits of max 10
+    prompt = request_data.prompt
+    size = request_data.size
+    response_format = request_data.response_format  # or b64_json
+    n = request_data.n  # ignore the batch limits of max 10
 
     partial = functools.partial(OAIimages.generations, prompt=prompt, size=size, response_format=response_format, n=n)
     response = await run_in_executor(partial)
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index a558fa32..9cd1c80e 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -148,6 +148,18 @@ class TranscriptionsResponse(BaseModel):
     text: str
 
 
+class ImageGenerationRequest(BaseModel):
+    prompt: str
+    size: str = Field(default='1024x1024')
+    response_format: str = Field(default='url')
+    n: int = Field(default=1)
+
+
+class ImageGenerationResponse(BaseModel):
+    created: int
+    data: list[dict]
+
+
 class EmbeddingsRequest(BaseModel):
     input: str | List[str] | List[int] | List[List[int]]
     model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")