Merge pull request #4532 from oobabooga/dev

Merge dev branch
This commit is contained in:
oobabooga 2023-11-09 09:33:55 -03:00 committed by GitHub
commit f7534b2f4b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@@ -1,3 +1,4 @@
+import asyncio
 import json
 import os
 import traceback
@@ -46,6 +47,9 @@ params = {
 }
+streaming_semaphore = asyncio.Semaphore(1)
 def verify_api_key(authorization: str = Header(None)) -> None:
     expected_api_key = shared.args.api_key
     if expected_api_key and (authorization is None or authorization != f"Bearer {expected_api_key}"):
@@ -84,9 +88,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
     if request_data.stream:
         async def generator():
-            response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
-            for resp in response:
-                yield {"data": json.dumps(resp)}
+            async with streaming_semaphore:
+                response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
+                for resp in response:
+                    yield {"data": json.dumps(resp)}
         return EventSourceResponse(generator()) # SSE streaming
@@ -102,9 +107,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
     if request_data.stream:
         async def generator():
-            response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
-            for resp in response:
-                yield {"data": json.dumps(resp)}
+            async with streaming_semaphore:
+                response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
+                for resp in response:
+                    yield {"data": json.dumps(resp)}
         return EventSourceResponse(generator()) # SSE streaming