Add docs for image generation

This commit is contained in:
Artificiangel 2024-05-23 08:44:15 -04:00
parent ee7d2c7406
commit d9fdb3db71
2 changed files with 20 additions and 7 deletions

View File

@ -40,6 +40,8 @@ from .typing import (
DecodeResponse,
TranscriptionsRequest,
TranscriptionsResponse,
ImageGenerationRequest,
ImageGenerationResponse,
EmbeddingsRequest,
EmbeddingsResponse,
EncodeRequest,
@ -215,17 +217,16 @@ async def handle_audio_transcription(request: Request, request_data: Transcripti
return JSONResponse(content=transcription)
@app.post('/v1/images/generations', dependencies=check_key)
async def handle_image_generation(request: Request):
@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
async def handle_image_generation(request: Request, request_data: ImageGenerationRequest):
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
body = await request.json()
prompt = body['prompt']
size = body.get('size', '1024x1024')
response_format = body.get('response_format', 'url') # or b64_json
n = body.get('n', 1) # ignore the batch limits of max 10
prompt = request_data.prompt
size = request_data.size
response_format = request_data.response_format # or b64_json
n = request_data.n # ignore the batch limits of max 10
partial = functools.partial(OAIimages.generations, prompt=prompt, size=size, response_format=response_format, n=n)
response = await run_in_executor(partial)

View File

@ -148,6 +148,18 @@ class TranscriptionsResponse(BaseModel):
text: str
class ImageGenerationRequest(BaseModel):
prompt: str
size: str = Field(default='1024x1024')
response_format: str = Field(default='url')
n: int = Field(default=1)
class ImageGenerationResponse(BaseModel):
created: int
data: list[dict]
class EmbeddingsRequest(BaseModel):
input: str | List[str] | List[int] | List[List[int]]
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")