diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 57a7bdb4..94e4160e 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -1,3 +1,4 @@ +import asyncio import json import os import traceback @@ -46,6 +47,9 @@ params = { } +streaming_semaphore = asyncio.Semaphore(1) + + def verify_api_key(authorization: str = Header(None)) -> None: expected_api_key = shared.args.api_key if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"): @@ -84,9 +88,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest): if request_data.stream: async def generator(): - response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy) - for resp in response: - yield {"data": json.dumps(resp)} + async with streaming_semaphore: + response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy) + for resp in response: + yield {"data": json.dumps(resp)} return EventSourceResponse(generator()) # SSE streaming @@ -102,9 +107,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion if request_data.stream: async def generator(): - response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy) - for resp in response: - yield {"data": json.dumps(resp)} + async with streaming_semaphore: + response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy) + for resp in response: + yield {"data": json.dumps(resp)} return EventSourceResponse(generator()) # SSE streaming