Run in executor for long blocking functions.

This commit is contained in:
Artificiangel 2024-04-28 08:24:45 -04:00
parent 5770e06c48
commit 3ffb09d465
2 changed files with 114 additions and 26 deletions

View File

@ -23,11 +23,12 @@ import extensions.openai.models as OAImodels
import extensions.openai.moderations as OAImoderations
from extensions.openai.errors import ServiceUnavailableError
from extensions.openai.tokens import token_count, token_decode, token_encode
from extensions.openai.utils import _start_cloudflared
from extensions.openai.utils import _start_cloudflared, generate_in_executor, run_in_executor
from modules import shared
from modules.logging_colors import logger
from modules.models import unload_model
from modules.text_generation import stop_everything_event
import functools
from .typing import (
ChatCompletionRequest,
@ -59,8 +60,12 @@ params = {
'debug': 0
}
streaming_semaphore = asyncio.Semaphore(1)
# Allow some actions to run at the same time.
text_generation_semaphore = asyncio.Semaphore(1) # Use same lock for streaming and generations.
embedding_semaphore = asyncio.Semaphore(1)
stt_semaphore = asyncio.Semaphore(1)
io_semaphore = asyncio.Semaphore(1)
small_tasks_semaphore = asyncio.Semaphore(5)
def verify_api_key(authorization: str = Header(None)) -> None:
@ -101,9 +106,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
if request_data.stream:
async def generator():
async with streaming_semaphore:
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
async with text_generation_semaphore:
partial = functools.partial(OAIcompletions.stream_completions, to_dict(request_data), is_legacy=is_legacy)
async for resp in generate_in_executor(partial):
disconnected = await request.is_disconnected()
if disconnected:
break
@ -113,7 +119,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
return EventSourceResponse(generator()) # SSE streaming
else:
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
async with text_generation_semaphore:
partial = functools.partial(OAIcompletions.completions, to_dict(request_data), is_legacy=is_legacy)
response = await run_in_executor(partial)
return JSONResponse(response)
@ -124,9 +133,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
if request_data.stream:
async def generator():
async with streaming_semaphore:
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
for resp in response:
async with text_generation_semaphore:
partial = functools.partial(OAIcompletions.stream_chat_completions, to_dict(request_data), is_legacy=is_legacy)
async for resp in generate_in_executor(partial):
disconnected = await request.is_disconnected()
if disconnected:
break
@ -136,7 +146,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
return EventSourceResponse(generator()) # SSE streaming
else:
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
async with text_generation_semaphore:
partial = functools.partial(OAIcompletions.chat_completions, to_dict(request_data), is_legacy=is_legacy)
response = await run_in_executor(partial)
return JSONResponse(response)
@ -182,7 +195,10 @@ async def handle_audio_transcription(request: Request):
transcription = {"text": ""}
try:
transcription["text"] = r.recognize_whisper(audio_data, language=whisper_language, model=whisper_model)
async with stt_semaphore:
partial = functools.partial(r.recognize_whisper, audio_data, language=whisper_language, model=whisper_model)
transcription["text"] = await run_in_executor(partial)
except sr.UnknownValueError:
print("Whisper could not understand audio")
transcription["text"] = "Whisper could not understand audio UnknownValueError"
@ -205,7 +221,8 @@ async def handle_image_generation(request: Request):
response_format = body.get('response_format', 'url') # or b64_json
n = body.get('n', 1) # ignore the batch limits of max 10
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
partial = functools.partial(OAIimages.generations, prompt=prompt, size=size, response_format=response_format, n=n)
response = await run_in_executor(partial)
return JSONResponse(response)
@ -218,7 +235,10 @@ async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
if type(input) is str:
input = [input]
response = OAIembeddings.embeddings(input, request_data.encoding_format)
async with embedding_semaphore:
partial = functools.partial(OAIembeddings.embeddings, input, request_data.encoding_format)
response = await run_in_executor(partial)
return JSONResponse(response)
@ -229,25 +249,37 @@ async def handle_moderations(request: Request):
if not input:
raise HTTPException(status_code=400, detail="Missing required argument input")
response = OAImoderations.moderations(input)
async with embedding_semaphore:
partial = functools.partial(OAImoderations.moderations, input)
response = await run_in_executor(partial)
return JSONResponse(response)
@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
async def handle_token_encode(request_data: EncodeRequest):
response = token_encode(request_data.text)
async with small_tasks_semaphore:
partial = functools.partial(token_encode, request_data.text)
response = await run_in_executor(partial)
return JSONResponse(response)
@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
async def handle_token_decode(request_data: DecodeRequest):
response = token_decode(request_data.tokens)
async with small_tasks_semaphore:
partial = functools.partial(token_decode, request_data.tokens)
response = await run_in_executor(partial)
return JSONResponse(response)
@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
async def handle_token_count(request_data: EncodeRequest):
response = token_count(request_data.text)
async with small_tasks_semaphore:
partial = functools.partial(token_count, request_data.text)
response = await run_in_executor(partial)
return JSONResponse(response)
@ -257,7 +289,10 @@ async def handle_logits(request_data: LogitsRequest):
Given a prompt, returns the top 50 most likely logits as a dict.
The keys are the tokens, and the values are the probabilities.
'''
response = OAIlogits._get_next_logits(to_dict(request_data))
async with small_tasks_semaphore:
partial = functools.partial(OAIlogits._get_next_logits, to_dict(request_data))
response = await run_in_executor(partial)
return JSONResponse(response)
@ -265,8 +300,14 @@ async def handle_logits(request_data: LogitsRequest):
async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
path = request.url.path
is_legacy = "/generate" in path
generator = OAIcompletions.chat_completions_common(to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
response = deque(generator, maxlen=1).pop()
async with small_tasks_semaphore:
# Run in executor as there are calls to get_encoded_length
# which might slow down at really long contexts.
partial = functools.partial(OAIcompletions.chat_completions_common, to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
generator = await run_in_executor(partial)
response = deque(generator, maxlen=1).pop()
return JSONResponse(response)
@ -318,7 +359,10 @@ async def handle_load_model(request_data: LoadModelRequest):
'''
try:
OAImodels._load_model(to_dict(request_data))
async with io_semaphore:
partial = functools.partial(OAImodels._load_model, to_dict(request_data))
await run_in_executor(partial)
return JSONResponse(content="OK")
except:
traceback.print_exc()
@ -327,7 +371,8 @@ async def handle_load_model(request_data: LoadModelRequest):
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
async def handle_unload_model():
unload_model()
async with io_semaphore:
await run_in_executor(unload_model)
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
@ -339,7 +384,10 @@ async def handle_list_loras():
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
async def handle_load_loras(request_data: LoadLorasRequest):
try:
OAImodels.load_loras(request_data.lora_names)
async with io_semaphore:
partial = functools.partial(OAImodels.load_loras, request_data.lora_names)
await run_in_executor(partial)
return JSONResponse(content="OK")
except:
traceback.print_exc()
@ -348,7 +396,9 @@ async def handle_load_loras(request_data: LoadLorasRequest):
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
async def handle_unload_loras():
OAImodels.unload_all_loras()
async with io_semaphore:
await run_in_executor(OAImodels.unload_all_loras)
return JSONResponse(content="OK")

View File

@ -2,9 +2,13 @@ import base64
import os
import time
import traceback
from typing import Callable, Optional
from typing import Callable, Optional, AsyncGenerator, Generator
import numpy as np
from modules import shared
from functools import partial
import asyncio
from asyncio import AbstractEventLoop, Future
def float_list_to_base64(float_array: np.ndarray) -> str:
@ -52,3 +56,37 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
time.sleep(3)
raise Exception('Could not start cloudflared.')
def get_next_generator_result(gen: Generator) -> tuple[any, bool]:
"""
Because StopIteration interacts badly with generators and cannot be raised into a Future
"""
try:
result = next(gen)
return result, False
except StopIteration:
return None, True
async def generate_in_executor(partial: partial, loop: AbstractEventLoop = None) -> AsyncGenerator[any, any]:
"""
Converts a blocking generator to an async one
"""
loop = loop or asyncio.get_running_loop()
gen = await loop.run_in_executor(None, partial)
while not shared.stop_everything:
result, is_done = await loop.run_in_executor(None, get_next_generator_result, gen)
if is_done:
break
yield result
async def run_in_executor(partial: partial, loop: AbstractEventLoop = None) -> Future:
"""
Runs a blocking function in a new thread so it can be awaited.
"""
loop = loop or asyncio.get_running_loop()
return await loop.run_in_executor(None, partial)