mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Run in executor for long blocking functions.
This commit is contained in:
parent
5770e06c48
commit
3ffb09d465
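The core technique in this commit is to hand long blocking calls to asyncio's default thread-pool executor with loop.run_in_executor, wrapping them in functools.partial so they can be awaited without stalling the event loop. A minimal, self-contained sketch of that pattern (the blocking function and its names are illustrative, not taken from the diff):

import asyncio
import functools
import time


def blocking_work(seconds: float) -> str:
    # Stands in for a long blocking call such as model loading or inference.
    time.sleep(seconds)
    return f"finished after {seconds}s"


async def handler() -> str:
    loop = asyncio.get_running_loop()
    partial = functools.partial(blocking_work, 0.5)
    # Runs in the default ThreadPoolExecutor; the event loop stays free
    # to serve other coroutines while this is awaited.
    return await loop.run_in_executor(None, partial)


print(asyncio.run(handler()))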
@@ -23,11 +23,12 @@ import extensions.openai.models as OAImodels
 import extensions.openai.moderations as OAImoderations
 from extensions.openai.errors import ServiceUnavailableError
 from extensions.openai.tokens import token_count, token_decode, token_encode
-from extensions.openai.utils import _start_cloudflared
+from extensions.openai.utils import _start_cloudflared, generate_in_executor, run_in_executor
 from modules import shared
 from modules.logging_colors import logger
 from modules.models import unload_model
 from modules.text_generation import stop_everything_event
+import functools

 from .typing import (
     ChatCompletionRequest,
@@ -59,8 +60,12 @@ params = {
     'debug': 0
 }


-streaming_semaphore = asyncio.Semaphore(1)
+# Allow some actions to run at the same time.
+text_generation_semaphore = asyncio.Semaphore(1)  # Use same lock for streaming and generations.
+embedding_semaphore = asyncio.Semaphore(1)
+stt_semaphore = asyncio.Semaphore(1)
+io_semaphore = asyncio.Semaphore(1)
+small_tasks_semaphore = asyncio.Semaphore(5)


 def verify_api_key(authorization: str = Header(None)) -> None:
@@ -101,9 +106,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):

     if request_data.stream:
         async def generator():
-            async with streaming_semaphore:
-                response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
-                for resp in response:
+            async with text_generation_semaphore:
+                partial = functools.partial(OAIcompletions.stream_completions, to_dict(request_data), is_legacy=is_legacy)
+
+                async for resp in generate_in_executor(partial):
                     disconnected = await request.is_disconnected()
                     if disconnected:
                         break
@@ -113,7 +119,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
         return EventSourceResponse(generator())  # SSE streaming

     else:
-        response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
+        async with text_generation_semaphore:
+            partial = functools.partial(OAIcompletions.completions, to_dict(request_data), is_legacy=is_legacy)
+            response = await run_in_executor(partial)
+
         return JSONResponse(response)

@@ -124,9 +133,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):

     if request_data.stream:
         async def generator():
-            async with streaming_semaphore:
-                response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
-                for resp in response:
+            async with text_generation_semaphore:
+                partial = functools.partial(OAIcompletions.stream_chat_completions, to_dict(request_data), is_legacy=is_legacy)
+
+                async for resp in generate_in_executor(partial):
                     disconnected = await request.is_disconnected()
                     if disconnected:
                         break
@@ -136,7 +146,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletionRequest):
         return EventSourceResponse(generator())  # SSE streaming

     else:
-        response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
+        async with text_generation_semaphore:
+            partial = functools.partial(OAIcompletions.chat_completions, to_dict(request_data), is_legacy=is_legacy)
+            response = await run_in_executor(partial)
+
         return JSONResponse(response)

@@ -182,7 +195,10 @@ async def handle_audio_transcription(request: Request):
     transcription = {"text": ""}

     try:
-        transcription["text"] = r.recognize_whisper(audio_data, language=whisper_language, model=whisper_model)
+        async with stt_semaphore:
+            partial = functools.partial(r.recognize_whisper, audio_data, language=whisper_language, model=whisper_model)
+            transcription["text"] = await run_in_executor(partial)
+
     except sr.UnknownValueError:
         print("Whisper could not understand audio")
         transcription["text"] = "Whisper could not understand audio UnknownValueError"
@@ -205,7 +221,8 @@ async def handle_image_generation(request: Request):
     response_format = body.get('response_format', 'url')  # or b64_json
     n = body.get('n', 1)  # ignore the batch limits of max 10

-    response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
+    partial = functools.partial(OAIimages.generations, prompt=prompt, size=size, response_format=response_format, n=n)
+    response = await run_in_executor(partial)
     return JSONResponse(response)

@@ -218,7 +235,10 @@ async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
     if type(input) is str:
         input = [input]

-    response = OAIembeddings.embeddings(input, request_data.encoding_format)
+    async with embedding_semaphore:
+        partial = functools.partial(OAIembeddings.embeddings, input, request_data.encoding_format)
+        response = await run_in_executor(partial)
+
     return JSONResponse(response)

@@ -229,25 +249,37 @@ async def handle_moderations(request: Request):
     if not input:
         raise HTTPException(status_code=400, detail="Missing required argument input")

-    response = OAImoderations.moderations(input)
+    async with embedding_semaphore:
+        partial = functools.partial(OAImoderations.moderations, input)
+        response = await run_in_executor(partial)
+
     return JSONResponse(response)


 @app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
 async def handle_token_encode(request_data: EncodeRequest):
-    response = token_encode(request_data.text)
+    async with small_tasks_semaphore:
+        partial = functools.partial(token_encode, request_data.text)
+        response = await run_in_executor(partial)
+
     return JSONResponse(response)


 @app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
 async def handle_token_decode(request_data: DecodeRequest):
-    response = token_decode(request_data.tokens)
+    async with small_tasks_semaphore:
+        partial = functools.partial(token_decode, request_data.tokens)
+        response = await run_in_executor(partial)
+
     return JSONResponse(response)


 @app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
 async def handle_token_count(request_data: EncodeRequest):
-    response = token_count(request_data.text)
+    async with small_tasks_semaphore:
+        partial = functools.partial(token_count, request_data.text)
+        response = await run_in_executor(partial)
+
     return JSONResponse(response)

@@ -257,7 +289,10 @@ async def handle_logits(request_data: LogitsRequest):
     Given a prompt, returns the top 50 most likely logits as a dict.
     The keys are the tokens, and the values are the probabilities.
     '''
-    response = OAIlogits._get_next_logits(to_dict(request_data))
+    async with small_tasks_semaphore:
+        partial = functools.partial(OAIlogits._get_next_logits, to_dict(request_data))
+        response = await run_in_executor(partial)
+
     return JSONResponse(response)

@@ -265,8 +300,14 @@ async def handle_logits(request_data: LogitsRequest):
 async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
     path = request.url.path
     is_legacy = "/generate" in path
-    generator = OAIcompletions.chat_completions_common(to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
+    async with small_tasks_semaphore:
+        # Run in executor as there are calls to get_encoded_length
+        # which might slow down at really long contexts.
+        partial = functools.partial(OAIcompletions.chat_completions_common, to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
+        generator = await run_in_executor(partial)
+
     response = deque(generator, maxlen=1).pop()

     return JSONResponse(response)

@@ -318,7 +359,10 @@ async def handle_load_model(request_data: LoadModelRequest):
     '''

     try:
-        OAImodels._load_model(to_dict(request_data))
+        async with io_semaphore:
+            partial = functools.partial(OAImodels._load_model, to_dict(request_data))
+            await run_in_executor(partial)
+
         return JSONResponse(content="OK")
     except:
         traceback.print_exc()
@@ -327,7 +371,8 @@ async def handle_load_model(request_data: LoadModelRequest):

 @app.post("/v1/internal/model/unload", dependencies=check_admin_key)
 async def handle_unload_model():
-    unload_model()
+    async with io_semaphore:
+        await run_in_executor(unload_model)


 @app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
@@ -339,7 +384,10 @@ async def handle_list_loras():
 @app.post("/v1/internal/lora/load", dependencies=check_admin_key)
 async def handle_load_loras(request_data: LoadLorasRequest):
     try:
-        OAImodels.load_loras(request_data.lora_names)
+        async with io_semaphore:
+            partial = functools.partial(OAImodels.load_loras, request_data.lora_names)
+            await run_in_executor(partial)
+
         return JSONResponse(content="OK")
     except:
         traceback.print_exc()
@@ -348,7 +396,9 @@ async def handle_load_loras(request_data: LoadLorasRequest):

 @app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
 async def handle_unload_loras():
-    OAImodels.unload_all_loras()
+    async with io_semaphore:
+        await run_in_executor(OAImodels.unload_all_loras)
+
     return JSONResponse(content="OK")

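Throughout the endpoint changes above, the executor work is also gated behind per-task asyncio.Semaphore instances (text_generation_semaphore, io_semaphore, small_tasks_semaphore, and so on), so requests are accepted concurrently while the blocking backend work stays serialized. A rough sketch of that gating with a limit of one concurrent generation (all names here are illustrative, not from the diff):

import asyncio
import functools
import time

# Mirrors text_generation_semaphore above; the limit of 1 is illustrative.
generation_semaphore = asyncio.Semaphore(1)


def fake_generation(request_id: int) -> str:
    # Stands in for a blocking completion call.
    time.sleep(0.3)
    return f"response {request_id}"


async def endpoint(request_id: int) -> str:
    async with generation_semaphore:
        loop = asyncio.get_running_loop()
        partial = functools.partial(fake_generation, request_id)
        return await loop.run_in_executor(None, partial)


async def main() -> None:
    # Both requests are accepted at once, but the semaphore serializes the
    # blocking work so only one generation runs at a time.
    print(await asyncio.gather(endpoint(1), endpoint(2)))


asyncio.run(main())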
@@ -2,9 +2,13 @@ import base64
 import os
 import time
 import traceback
-from typing import Callable, Optional
+from typing import Callable, Optional, AsyncGenerator, Generator

 import numpy as np
+from modules import shared
+from functools import partial
+import asyncio
+from asyncio import AbstractEventLoop, Future


 def float_list_to_base64(float_array: np.ndarray) -> str:
@@ -52,3 +56,37 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
             time.sleep(3)

     raise Exception('Could not start cloudflared.')
+
+
+def get_next_generator_result(gen: Generator) -> tuple[any, bool]:
+    """
+    Because StopIteration interacts badly with generators and cannot be raised into a Future
+    """
+    try:
+        result = next(gen)
+        return result, False
+    except StopIteration:
+        return None, True
+
+
+async def generate_in_executor(partial: partial, loop: AbstractEventLoop = None) -> AsyncGenerator[any, any]:
+    """
+    Converts a blocking generator to an async one
+    """
+    loop = loop or asyncio.get_running_loop()
+    gen = await loop.run_in_executor(None, partial)
+
+    while not shared.stop_everything:
+        result, is_done = await loop.run_in_executor(None, get_next_generator_result, gen)
+        if is_done:
+            break
+
+        yield result
+
+
+async def run_in_executor(partial: partial, loop: AbstractEventLoop = None) -> Future:
+    """
+    Runs a blocking function in a new thread so it can be awaited.
+    """
+    loop = loop or asyncio.get_running_loop()
+    return await loop.run_in_executor(None, partial)
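The streaming endpoints rely on generate_in_executor above, which pulls items from a blocking generator one next() call at a time inside the executor and signals completion with a flag, since a StopIteration stored on a Future does not behave as a normal end-of-iteration signal when awaited. A self-contained sketch of the same idea, with a heartbeat task added to show the event loop stays responsive (the generator and task names are illustrative, not from the repository):

import asyncio
import time
from typing import Any, Generator


def slow_tokens() -> Generator[str, None, None]:
    # Stands in for a blocking, per-token generation loop.
    for tok in ["stream", "of", "tokens"]:
        time.sleep(0.2)
        yield tok


def next_or_done(gen: Generator) -> tuple[Any, bool]:
    # Mirrors get_next_generator_result: report exhaustion with a flag
    # instead of letting StopIteration escape into the Future.
    try:
        return next(gen), False
    except StopIteration:
        return None, True


async def stream() -> None:
    loop = asyncio.get_running_loop()
    gen = slow_tokens()
    while True:
        tok, done = await loop.run_in_executor(None, next_or_done, gen)
        if done:
            break
        print("token:", tok)


async def heartbeat() -> None:
    # Keeps ticking while tokens stream, showing the loop is not blocked.
    for _ in range(3):
        print("heartbeat")
        await asyncio.sleep(0.2)


async def main() -> None:
    await asyncio.gather(stream(), heartbeat())


asyncio.run(main())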