mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-29 10:59:32 +01:00
Add docs for image generation
This commit is contained in:
parent
ee7d2c7406
commit
d9fdb3db71
@ -40,6 +40,8 @@ from .typing import (
|
||||
DecodeResponse,
|
||||
TranscriptionsRequest,
|
||||
TranscriptionsResponse,
|
||||
ImageGenerationRequest,
|
||||
ImageGenerationResponse,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
EncodeRequest,
|
||||
@ -215,17 +217,16 @@ async def handle_audio_transcription(request: Request, request_data: Transcripti
|
||||
return JSONResponse(content=transcription)
|
||||
|
||||
|
||||
@app.post('/v1/images/generations', dependencies=check_key)
|
||||
async def handle_image_generation(request: Request):
|
||||
@app.post('/v1/images/generations', response_model=ImageGenerationResponse, dependencies=check_key)
|
||||
async def handle_image_generation(request: Request, request_data: ImageGenerationRequest):
|
||||
|
||||
if not os.environ.get('SD_WEBUI_URL', params.get('sd_webui_url', '')):
|
||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||
|
||||
body = await request.json()
|
||||
prompt = body['prompt']
|
||||
size = body.get('size', '1024x1024')
|
||||
response_format = body.get('response_format', 'url') # or b64_json
|
||||
n = body.get('n', 1) # ignore the batch limits of max 10
|
||||
prompt = request_data.prompt
|
||||
size = request_data.size
|
||||
response_format = request_data.response_format # or b64_json
|
||||
n = request_data.n # ignore the batch limits of max 10
|
||||
|
||||
partial = functools.partial(OAIimages.generations, prompt=prompt, size=size, response_format=response_format, n=n)
|
||||
response = await run_in_executor(partial)
|
||||
|
@ -148,6 +148,18 @@ class TranscriptionsResponse(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class ImageGenerationRequest(BaseModel):
|
||||
prompt: str
|
||||
size: str = Field(default='1024x1024')
|
||||
response_format: str = Field(default='url')
|
||||
n: int = Field(default=1)
|
||||
|
||||
|
||||
class ImageGenerationResponse(BaseModel):
|
||||
created: int
|
||||
data: list[dict]
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
input: str | List[str] | List[int] | List[List[int]]
|
||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
|
||||
|
Loading…
Reference in New Issue
Block a user