Fix transcriptions endpoint

This commit is contained in:
Artificiangel 2024-05-23 08:07:51 -04:00
parent f9b2ff1616
commit 432b070bde
2 changed files with 34 additions and 8 deletions

View File

@ -38,6 +38,8 @@ from .typing import (
CompletionResponse, CompletionResponse,
DecodeRequest, DecodeRequest,
DecodeResponse, DecodeResponse,
TranscriptionsRequest,
TranscriptionsResponse,
EmbeddingsRequest, EmbeddingsRequest,
EmbeddingsResponse, EmbeddingsResponse,
EncodeRequest, EncodeRequest,
@ -53,6 +55,8 @@ from .typing import (
to_dict to_dict
) )
from io import BytesIO
params = { params = {
'embedding_device': 'cpu', 'embedding_device': 'cpu',
'embedding_model': 'sentence-transformers/all-mpnet-base-v2', 'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
@ -176,12 +180,13 @@ def handle_billing_usage():
return JSONResponse(content={"total_usage": 0}) return JSONResponse(content={"total_usage": 0})
@app.post('/v1/audio/transcriptions', dependencies=check_key) @app.post('/v1/audio/transcriptions', response_model=TranscriptionsResponse, dependencies=check_key)
async def handle_audio_transcription(request: Request): async def handle_audio_transcription(request: Request, request_data: TranscriptionsRequest = Depends(TranscriptionsRequest.as_form)):
r = sr.Recognizer() r = sr.Recognizer()
form = await request.form() file = request_data.file
audio_file = await form["file"].read() audio_file = await file.read()
audio_file = BytesIO(audio_file)
audio_data = AudioSegment.from_file(audio_file) audio_data = AudioSegment.from_file(audio_file)
# Convert AudioSegment to raw data # Convert AudioSegment to raw data
@ -189,8 +194,8 @@ async def handle_audio_transcription(request: Request):
# Create AudioData object # Create AudioData object
audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width) audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
whisper_language = form.getvalue('language', None) whisper_language = request_data.language
whisper_model = form.getvalue('model', 'tiny') # Use the model from the form data if it exists, otherwise default to tiny whisper_model = request_data.model # Use the model from the form data if it exists, otherwise default to tiny
transcription = {"text": ""} transcription = {"text": ""}
@ -200,10 +205,11 @@ async def handle_audio_transcription(request: Request):
transcription["text"] = await run_in_executor(partial) transcription["text"] = await run_in_executor(partial)
except sr.UnknownValueError: except sr.UnknownValueError:
print("Whisper could not understand audio") logger.warning("Whisper could not understand audio")
transcription["text"] = "Whisper could not understand audio UnknownValueError" transcription["text"] = "Whisper could not understand audio UnknownValueError"
except sr.RequestError as e: except sr.RequestError as e:
print("Could not request results from Whisper", e) logger.warning("Could not request results from Whisper", e)
transcription["text"] = "Whisper could not understand audio RequestError" transcription["text"] = "Whisper could not understand audio RequestError"
return JSONResponse(content=transcription) return JSONResponse(content=transcription)

View File

@ -3,6 +3,7 @@ import time
from typing import Dict, List from typing import Dict, List
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from fastapi import UploadFile, Form
class GenerationOptions(BaseModel): class GenerationOptions(BaseModel):
@ -128,6 +129,25 @@ class ChatPromptResponse(BaseModel):
prompt: str prompt: str
class TranscriptionsRequest(BaseModel):
    """Form payload for the OpenAI-compatible /v1/audio/transcriptions endpoint."""

    # The uploaded audio file to transcribe.
    file: UploadFile
    # Optional BCP-47 / Whisper language hint; None lets Whisper auto-detect.
    language: str | None = Field(default=None)
    # Whisper model size; defaults to the smallest ('tiny') for speed.
    model: str = Field(default='tiny')

    @classmethod
    def as_form(
        cls,
        # BUGFIX: the original default was `UploadFile(...)`, which *instantiates*
        # an UploadFile wrapping Ellipsis and makes the field optional with a
        # bogus value. A bare `UploadFile` annotation (no default) is FastAPI's
        # documented way to declare a required multipart file parameter.
        file: UploadFile,
        language: str | None = Form(None),
        model: str = Form('tiny'),
    ) -> 'TranscriptionsRequest':
        """Dependency constructor so this model can be populated from multipart/form-data."""
        return cls(file=file, language=language, model=model)
class TranscriptionsResponse(BaseModel):
    """Response body for /v1/audio/transcriptions (OpenAI-compatible shape)."""

    # The transcribed text produced by Whisper (or an error placeholder string).
    text: str
class EmbeddingsRequest(BaseModel): class EmbeddingsRequest(BaseModel):
input: str | List[str] | List[int] | List[List[int]] input: str | List[str] | List[int] | List[List[int]]
model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.") model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")