Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-26 01:30:20 +01:00)

Fix transcriptions endpoint

commit 432b070bde
parent f9b2ff1616
@@ -38,6 +38,8 @@ from .typing import (
     CompletionResponse,
     DecodeRequest,
     DecodeResponse,
+    TranscriptionsRequest,
+    TranscriptionsResponse,
     EmbeddingsRequest,
     EmbeddingsResponse,
     EncodeRequest,
@@ -53,6 +55,8 @@ from .typing import (
     to_dict
 )

+from io import BytesIO
+
 params = {
     'embedding_device': 'cpu',
     'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
@@ -176,12 +180,13 @@ def handle_billing_usage():
     return JSONResponse(content={"total_usage": 0})


-@app.post('/v1/audio/transcriptions', dependencies=check_key)
-async def handle_audio_transcription(request: Request):
+@app.post('/v1/audio/transcriptions', response_model=TranscriptionsResponse, dependencies=check_key)
+async def handle_audio_transcription(request: Request, request_data: TranscriptionsRequest = Depends(TranscriptionsRequest.as_form)):
     r = sr.Recognizer()

-    form = await request.form()
-    audio_file = await form["file"].read()
+    file = request_data.file
+    audio_file = await file.read()
+    audio_file = BytesIO(audio_file)
     audio_data = AudioSegment.from_file(audio_file)

     # Convert AudioSegment to raw data
@@ -189,8 +194,8 @@ async def handle_audio_transcription(request: Request):

     # Create AudioData object
     audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
-    whisper_language = form.getvalue('language', None)
-    whisper_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny
+    whisper_language = request_data.language
+    whisper_model = request_data.model # Use the model from the form data if it exists, otherwise default to tiny

     transcription = {"text": ""}

@@ -200,10 +205,11 @@ async def handle_audio_transcription(request: Request):
         transcription["text"] = await run_in_executor(partial)

     except sr.UnknownValueError:
-        print("Whisper could not understand audio")
+        logger.warning("Whisper could not understand audio")
         transcription["text"] = "Whisper could not understand audio UnknownValueError"

     except sr.RequestError as e:
-        print("Could not request results from Whisper", e)
+        logger.warning("Could not request results from Whisper", e)
         transcription["text"] = "Whisper could not understand audio RequestError"
+
     return JSONResponse(content=transcription)
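With this change, /v1/audio/transcriptions validates its input through TranscriptionsRequest and reads the uploaded audio plus the optional language and model fields from multipart form data. As a reference, here is a minimal client-side sketch for exercising the endpoint; the base URL, port, and file name are assumptions for illustration, not part of this commit:

# Minimal sketch: post an audio file to the fixed endpoint as multipart form data.
# Assumptions for illustration only: the API is listening on http://127.0.0.1:5000
# and a local file named sample.wav exists.
import requests

url = "http://127.0.0.1:5000/v1/audio/transcriptions"

with open("sample.wav", "rb") as f:
    response = requests.post(
        url,
        files={"file": ("sample.wav", f, "audio/wav")},  # the "file" form field read by the handler
        data={"model": "tiny", "language": "en"},  # optional fields; defaults match the diff above
    )

print(response.json())  # expected shape: {"text": "..."}

The hunks below belong to the typing module and add the TranscriptionsRequest and TranscriptionsResponse models that the handler now depends on.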
@@ -3,6 +3,7 @@ import time
 from typing import Dict, List

 from pydantic import BaseModel, Field
+from fastapi import UploadFile, Form


 class GenerationOptions(BaseModel):
@@ -128,6 +129,25 @@ class ChatPromptResponse(BaseModel):
     prompt: str


+class TranscriptionsRequest(BaseModel):
+    file: UploadFile
+    language: str | None = Field(default=None)
+    model: str = Field(default='tiny')
+
+    @classmethod
+    def as_form(
+        cls,
+        file: UploadFile = UploadFile(...),
+        language: str | None = Form(None),
+        model: str = Form('tiny'),
+    ) -> 'TranscriptionsRequest':
+        return cls(file=file, language=language, model=model)
+
+
+class TranscriptionsResponse(BaseModel):
+    text: str
+
+
 class EmbeddingsRequest(BaseModel):
     input: str | List[str] | List[int] | List[List[int]]
     model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")
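For context, the as_form classmethod added above is a common workaround for letting FastAPI populate a Pydantic model from multipart form fields, since a plain Pydantic-model parameter is otherwise treated as a JSON body. Below is a self-contained sketch of the pattern, independent of this repository: the route path and all names are illustrative, and it uses the conventional File(...) marker for the upload rather than UploadFile(...).

# Standalone sketch of the "Pydantic model as form data" pattern.
# Requires FastAPI with python-multipart installed for form parsing;
# names and the route are hypothetical, not from the commit.
from fastapi import Depends, FastAPI, File, Form, UploadFile
from pydantic import BaseModel

class AudioForm(BaseModel):
    file: UploadFile
    language: str | None = None
    model: str = 'tiny'

    @classmethod
    def as_form(
        cls,
        file: UploadFile = File(...),
        language: str | None = Form(None),
        model: str = Form('tiny'),
    ) -> 'AudioForm':
        # FastAPI inspects this signature, pulls the File/Form parameters
        # out of the parsed multipart body, and we assemble the model here.
        return cls(file=file, language=language, model=model)

app = FastAPI()

@app.post('/sketch/transcriptions')
async def sketch_endpoint(request_data: AudioForm = Depends(AudioForm.as_form)):
    audio_bytes = await request_data.file.read()
    return {"model": request_data.model, "bytes_received": len(audio_bytes)}

Declaring the dependency as Depends(AudioForm.as_form) keeps the endpoint working with a single typed request_data object while still accepting a standard multipart upload, which is the same shape the handler in this commit uses.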