Add docs for image generation

This commit is contained in:
Artificiangel 2024-05-23 08:44:15 -04:00
parent ee7d2c7406
commit d9fdb3db71
2 changed files with 20 additions and 7 deletions

View File

@ -40,6 +40,8 @@ from .typing import (
DecodeResponse, DecodeResponse,
TranscriptionsRequest, TranscriptionsRequest,
TranscriptionsResponse, TranscriptionsResponse,
ImageGenerationRequest,
ImageGenerationResponse,
EmbeddingsRequest, EmbeddingsRequest,
EmbeddingsResponse, EmbeddingsResponse,
EncodeRequest, EncodeRequest,
@ -215,17 +217,16 @@ async def handle_audio_transcription(request: Request, request_data: Transcripti
return JSONResponse(content=transcription) return JSONResponse(content=transcription)
@app.post('/v1/images/generations', dependencies=check_key) @app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
async def handle_image_generation(request: Request): async def handle_image_generation(request: Request, request_data: ImageGenerationRequest):
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')): if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.") raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
body = await request.json() prompt = request_data.prompt
prompt = body['prompt'] size = request_data.size
size = body.get('size', '1024x1024') response_format = request_data.response_format # or b64_json
response_format = body.get('response_format', 'url') # or b64_json n = request_data.n # ignore the batch limits of max 10
n = body.get('n', 1) # ignore the batch limits of max 10
partial = functools.partial(OAIimages.generations, prompt=prompt, size=size, response_format=response_format, n=n) partial = functools.partial(OAIimages.generations, prompt=prompt, size=size, response_format=response_format, n=n)
response = await run_in_executor(partial) response = await run_in_executor(partial)

View File

@ -148,6 +148,18 @@ class TranscriptionsResponse(BaseModel):
text: str text: str
class ImageGenerationRequest(BaseModel):
prompt: str
size: str = Field(default='1024x1024')
response_format: str = Field(default='url')
n: int = Field(default=1)
class ImageGenerationResponse(BaseModel):
created: int
data: list[dict]
class EmbeddingsRequest(BaseModel): class EmbeddingsRequest(BaseModel):
input: str | List[str] | List[int] | List[List[int]] input: str | List[str] | List[int] | List[List[int]]
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.") model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")