diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 03d99e8d..b90af6cd 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -119,6 +119,18 @@ async def openai_completions(request: Request, request_data: CompletionRequest): @app.post('/v1/chat/completions', response_model=ChatCompletionResponse, dependencies=check_key) async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest): + requested_model = request_data.model + payload = OAImodels.get_current_model_info() + current_model = payload["model_name"] + if not current_model == requested_model: + requested_model_dict = {"model_name": requested_model} + try: + OAImodels._load_model(requested_model_dict) + except: + traceback.print_exc() + return HTTPException(status_code=400, detail="Failed to load the model.") + + path = request.url.path is_legacy = "/generate" in path