diff --git a/extensions/openai/script.py b/extensions/openai/script.py
index 746a4390..6ae2e2af 100644
--- a/extensions/openai/script.py
+++ b/extensions/openai/script.py
@@ -38,6 +38,8 @@ from .typing import (
     CompletionResponse,
     DecodeRequest,
     DecodeResponse,
+    TranscriptionsRequest,
+    TranscriptionsResponse,
     EmbeddingsRequest,
     EmbeddingsResponse,
     EncodeRequest,
@@ -53,6 +55,8 @@ from .typing import (
     to_dict
 )
 
+from io import BytesIO
+
 params = {
     'embedding_device': 'cpu',
     'embedding_model': 'sentence-transformers/all-mpnet-base-v2',
@@ -176,12 +180,14 @@ def handle_billing_usage():
     return JSONResponse(content={"total_usage": 0})
 
 
-@app.post('/v1/audio/transcriptions', dependencies=check_key)
-async def handle_audio_transcription(request: Request):
+@app.post('/v1/audio/transcriptions', response_model=TranscriptionsResponse, dependencies=check_key)
+async def handle_audio_transcription(request: Request, request_data: TranscriptionsRequest = Depends(TranscriptionsRequest.as_form)):
+    """Transcribe the uploaded audio file and return it as ``{"text": ...}``."""
     r = sr.Recognizer()
 
-    form = await request.form()
-    audio_file = await form["file"].read()
+    file = request_data.file
+    audio_file = await file.read()
+    audio_file = BytesIO(audio_file)
     audio_data = AudioSegment.from_file(audio_file)
 
     # Convert AudioSegment to raw data
@@ -189,8 +195,8 @@ async def handle_audio_transcription(request: Request):
 
     # Create AudioData object
     audio_data = sr.AudioData(raw_data, audio_data.frame_rate, audio_data.sample_width)
-    whisper_language = form.getvalue('language', None)
-    whisper_model = form.getvalue('model', 'tiny')  # Use the model from the form data if it exists, otherwise default to tiny
+    whisper_language = request_data.language
+    whisper_model = request_data.model  # Use the model from the form data if it exists, otherwise default to tiny
 
     transcription = {"text": ""}
 
@@ -200,10 +206,11 @@ async def handle_audio_transcription(request: Request):
         transcription["text"] = await run_in_executor(partial)
 
     except sr.UnknownValueError:
-        print("Whisper could not understand audio")
+        logger.warning("Whisper could not understand audio")
         transcription["text"] = "Whisper could not understand audio UnknownValueError"
+
     except sr.RequestError as e:
-        print("Could not request results from Whisper", e)
+        logger.warning("Could not request results from Whisper: %s", e)
         transcription["text"] = "Whisper could not understand audio RequestError"
 
     return JSONResponse(content=transcription)
diff --git a/extensions/openai/typing.py b/extensions/openai/typing.py
index 2b30ebf2..a558fa32 100644
--- a/extensions/openai/typing.py
+++ b/extensions/openai/typing.py
@@ -3,6 +3,7 @@ import time
 from typing import Dict, List
 
 from pydantic import BaseModel, Field
+from fastapi import File, Form, UploadFile
 
 
 class GenerationOptions(BaseModel):
@@ -128,6 +129,34 @@ class ChatPromptResponse(BaseModel):
     prompt: str
 
 
+class TranscriptionsRequest(BaseModel):
+    """Multipart form payload accepted by ``/v1/audio/transcriptions``."""
+
+    file: UploadFile
+    language: str | None = Field(default=None)
+    model: str = Field(default='tiny')
+
+    @classmethod
+    def as_form(
+        cls,
+        file: UploadFile = File(...),
+        language: str | None = Form(None),
+        model: str = Form('tiny'),
+    ) -> 'TranscriptionsRequest':
+        """Build the request from multipart form fields, for use with ``Depends``.
+
+        ``File(...)`` marks the upload as a required form field; ``language``
+        and ``model`` fall back to ``None`` and ``'tiny'`` when omitted.
+        """
+        return cls(file=file, language=language, model=model)
+
+
+class TranscriptionsResponse(BaseModel):
+    """Response body of the transcription endpoint."""
+
+    text: str
+
+
 class EmbeddingsRequest(BaseModel):
     input: str | List[str] | List[int] | List[List[int]]
     model: str | None = Field(default=None, description="Unused parameter. To change the model, set the OPENEDAI_EMBEDDING_MODEL and OPENEDAI_EMBEDDING_DEVICE environment variables before starting the server.")