mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Fix transcriptions endpoint
This commit is contained in:
parent
f9b2ff1616
commit
432b070bde
@ -38,6 +38,8 @@ from .typing import (
|
||||
CompletionResponse,
|
||||
DecodeRequest,
|
||||
DecodeResponse,
|
||||
TranscriptionsRequest,
|
||||
TranscriptionsResponse,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
EncodeRequest,
|
||||
@ -53,6 +55,8 @@ from .typing import (
|
||||
to_dict
|
||||
)
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
params = {
|
||||
'embedding_device': 'cpu',
|
||||
'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
|
||||
@ -176,12 +180,13 @@ def handle_billing_usage():
|
||||
return JSONResponse(content={"total_usage": 0})
|
||||
|
||||
|
||||
@app.post('/v1/audio/transcriptions', dependencies=check_key)
|
||||
async def handle_audio_transcription(request: Request):
|
||||
@app.post('/v1/audio/transcriptions', response_model=TranscriptionsResponse, dependencies=check_key)
|
||||
async def handle_audio_transcription(request: Request, request_data: TranscriptionsRequest = Depends(TranscriptionsRequest.as_form)):
|
||||
r = sr.Recognizer()
|
||||
|
||||
form = await request.form()
|
||||
audio_file = await form["file"].read()
|
||||
file = request_data.file
|
||||
audio_file = await file.read()
|
||||
audio_file = BytesIO(audio_file)
|
||||
audio_data = AudioSegment.from_file(audio_file)
|
||||
|
||||
# Convert AudioSegment to raw data
|
||||
@ -189,8 +194,8 @@ async def handle_audio_transcription(request: Request):
|
||||
|
||||
# Create AudioData object
|
||||
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
|
||||
whisper_language = form.getvalue('language', None)
|
||||
whisper_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
|
||||
whisper_language = request_data.language
|
||||
whisper_model = request_data.model # Use the model from the form data if it exists, otherwise default to tiny
|
||||
|
||||
transcription = {"text": ""}
|
||||
|
||||
@ -200,10 +205,11 @@ async def handle_audio_transcription(request: Request):
|
||||
transcription["text"] = await run_in_executor(partial)
|
||||
|
||||
except sr.UnknownValueError:
|
||||
print("Whisper could not understand audio")
|
||||
logger.warning("Whisper could not understand audio")
|
||||
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
||||
|
||||
except sr.RequestError as e:
|
||||
print("Could not request results from Whisper", e)
|
||||
logger.warning("Could not request results from Whisper", e)
|
||||
transcription["text"] = "Whisper could not understand audio RequestError"
|
||||
|
||||
return JSONResponse(content=transcription)
|
||||
|
@ -3,6 +3,7 @@ import time
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from fastapi import UploadFile, Form
|
||||
|
||||
|
||||
class GenerationOptions(BaseModel):
|
||||
@ -128,6 +129,25 @@ class ChatPromptResponse(BaseModel):
|
||||
prompt: str
|
||||
|
||||
|
||||
class TranscriptionsRequest(BaseModel):
|
||||
file: UploadFile
|
||||
language: str | None = Field(default=None)
|
||||
model: str = Field(default='tiny')
|
||||
|
||||
@classmethod
|
||||
def as_form(
|
||||
cls,
|
||||
file: UploadFile = UploadFile(...),
|
||||
language: str | None = Form(None),
|
||||
model: str = Form('tiny'),
|
||||
) -> 'TranscriptionsRequest':
|
||||
return cls(file=file, language=language, model=model)
|
||||
|
||||
|
||||
class TranscriptionsResponse(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class EmbeddingsRequest(BaseModel):
|
||||
input: str | List[str] | List[int] | List[List[int]]
|
||||
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
|
||||
|
Loading…
Reference in New Issue
Block a user