mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Run in executor for long blocking functions.
This commit is contained in:
parent
5770e06c48
commit
3ffb09d465
@ -23,11 +23,12 @@ import extensions.openai.models as OAImodels
|
|||||||
import extensions.openai.moderations as OAImoderations
|
import extensions.openai.moderations as OAImoderations
|
||||||
from extensions.openai.errors import ServiceUnavailableError
|
from extensions.openai.errors import ServiceUnavailableError
|
||||||
from extensions.openai.tokens import token_count, token_decode, token_encode
|
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||||
from extensions.openai.utils import _start_cloudflared
|
from extensions.openai.utils import _start_cloudflared, generate_in_executor, run_in_executor
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
from modules.models import unload_model
|
from modules.models import unload_model
|
||||||
from modules.text_generation import stop_everything_event
|
from modules.text_generation import stop_everything_event
|
||||||
|
import functools
|
||||||
|
|
||||||
from .typing import (
|
from .typing import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
@ -59,8 +60,12 @@ params = {
|
|||||||
'debug': 0
|
'debug': 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Allow some actions to run at the same time.
|
||||||
streaming_semaphore = asyncio.Semaphore(1)
|
text_generation_semaphore = asyncio.Semaphore(1) # Use same lock for streaming and generations.
|
||||||
|
embedding_semaphore = asyncio.Semaphore(1)
|
||||||
|
stt_semaphore = asyncio.Semaphore(1)
|
||||||
|
io_semaphore = asyncio.Semaphore(1)
|
||||||
|
small_tasks_semaphore = asyncio.Semaphore(5)
|
||||||
|
|
||||||
|
|
||||||
def verify_api_key(authorization: str = Header(None)) -> None:
|
def verify_api_key(authorization: str = Header(None)) -> None:
|
||||||
@ -101,9 +106,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
|
|||||||
|
|
||||||
if request_data.stream:
|
if request_data.stream:
|
||||||
async def generator():
|
async def generator():
|
||||||
async with streaming_semaphore:
|
async with text_generation_semaphore:
|
||||||
response = OAIcompletions.stream_completions(to_dict(request_data), is_legacy=is_legacy)
|
partial = functools.partial(OAIcompletions.stream_completions, to_dict(request_data), is_legacy=is_legacy)
|
||||||
for resp in response:
|
|
||||||
|
async for resp in generate_in_executor(partial):
|
||||||
disconnected = await request.is_disconnected()
|
disconnected = await request.is_disconnected()
|
||||||
if disconnected:
|
if disconnected:
|
||||||
break
|
break
|
||||||
@ -113,7 +119,10 @@ async def openai_completions(request: Request, request_data: CompletionRequest):
|
|||||||
return EventSourceResponse(generator()) # SSE streaming
|
return EventSourceResponse(generator()) # SSE streaming
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = OAIcompletions.completions(to_dict(request_data), is_legacy=is_legacy)
|
async with text_generation_semaphore:
|
||||||
|
partial = functools.partial(OAIcompletions.completions, to_dict(request_data), is_legacy=is_legacy)
|
||||||
|
response = await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@ -124,9 +133,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
|
|||||||
|
|
||||||
if request_data.stream:
|
if request_data.stream:
|
||||||
async def generator():
|
async def generator():
|
||||||
async with streaming_semaphore:
|
async with text_generation_semaphore:
|
||||||
response = OAIcompletions.stream_chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
partial = functools.partial(OAIcompletions.stream_chat_completions, to_dict(request_data), is_legacy=is_legacy)
|
||||||
for resp in response:
|
|
||||||
|
async for resp in generate_in_executor(partial):
|
||||||
disconnected = await request.is_disconnected()
|
disconnected = await request.is_disconnected()
|
||||||
if disconnected:
|
if disconnected:
|
||||||
break
|
break
|
||||||
@ -136,7 +146,10 @@ async def openai_chat_completions(request: Request, request_data: ChatCompletion
|
|||||||
return EventSourceResponse(generator()) # SSE streaming
|
return EventSourceResponse(generator()) # SSE streaming
|
||||||
|
|
||||||
else:
|
else:
|
||||||
response = OAIcompletions.chat_completions(to_dict(request_data), is_legacy=is_legacy)
|
async with text_generation_semaphore:
|
||||||
|
partial = functools.partial(OAIcompletions.chat_completions, to_dict(request_data), is_legacy=is_legacy)
|
||||||
|
response = await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@ -182,7 +195,10 @@ async def handle_audio_transcription(request: Request):
|
|||||||
transcription = {"text": ""}
|
transcription = {"text": ""}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
transcription["text"] = r.recognize_whisper(audio_data, language=whisper_language, model=whisper_model)
|
async with stt_semaphore:
|
||||||
|
partial = functools.partial(r.recognize_whisper, audio_data, language=whisper_language, model=whisper_model)
|
||||||
|
transcription["text"] = await run_in_executor(partial)
|
||||||
|
|
||||||
except sr.UnknownValueError:
|
except sr.UnknownValueError:
|
||||||
print("Whisper could not understand audio")
|
print("Whisper could not understand audio")
|
||||||
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
transcription["text"] = "Whisper could not understand audio UnknownValueError"
|
||||||
@ -205,7 +221,8 @@ async def handle_image_generation(request: Request):
|
|||||||
response_format = body.get('response_format', 'url') # or b64_json
|
response_format = body.get('response_format', 'url') # or b64_json
|
||||||
n = body.get('n', 1) # ignore the batch limits of max 10
|
n = body.get('n', 1) # ignore the batch limits of max 10
|
||||||
|
|
||||||
response = await OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
partial = functools.partial(OAIimages.generations, prompt=prompt, size=size, response_format=response_format, n=n)
|
||||||
|
response = await run_in_executor(partial)
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@ -218,7 +235,10 @@ async def handle_embeddings(request: Request, request_data: EmbeddingsRequest):
|
|||||||
if type(input) is str:
|
if type(input) is str:
|
||||||
input = [input]
|
input = [input]
|
||||||
|
|
||||||
response = OAIembeddings.embeddings(input, request_data.encoding_format)
|
async with embedding_semaphore:
|
||||||
|
partial = functools.partial(OAIembeddings.embeddings, input, request_data.encoding_format)
|
||||||
|
response = await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@ -229,25 +249,37 @@ async def handle_moderations(request: Request):
|
|||||||
if not input:
|
if not input:
|
||||||
raise HTTPException(status_code=400, detail="Missing required argument input")
|
raise HTTPException(status_code=400, detail="Missing required argument input")
|
||||||
|
|
||||||
response = OAImoderations.moderations(input)
|
async with embedding_semaphore:
|
||||||
|
partial = functools.partial(OAImoderations.moderations, input)
|
||||||
|
response = await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
|
@app.post("/v1/internal/encode", response_model=EncodeResponse, dependencies=check_key)
|
||||||
async def handle_token_encode(request_data: EncodeRequest):
|
async def handle_token_encode(request_data: EncodeRequest):
|
||||||
response = token_encode(request_data.text)
|
async with small_tasks_semaphore:
|
||||||
|
partial = functools.partial(token_encode, request_data.text)
|
||||||
|
response = await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
|
@app.post("/v1/internal/decode", response_model=DecodeResponse, dependencies=check_key)
|
||||||
async def handle_token_decode(request_data: DecodeRequest):
|
async def handle_token_decode(request_data: DecodeRequest):
|
||||||
response = token_decode(request_data.tokens)
|
async with small_tasks_semaphore:
|
||||||
|
partial = functools.partial(token_decode, request_data.tokens)
|
||||||
|
response = await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
|
@app.post("/v1/internal/token-count", response_model=TokenCountResponse, dependencies=check_key)
|
||||||
async def handle_token_count(request_data: EncodeRequest):
|
async def handle_token_count(request_data: EncodeRequest):
|
||||||
response = token_count(request_data.text)
|
async with small_tasks_semaphore:
|
||||||
|
partial = functools.partial(token_count, request_data.text)
|
||||||
|
response = await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@ -257,7 +289,10 @@ async def handle_logits(request_data: LogitsRequest):
|
|||||||
Given a prompt, returns the top 50 most likely logits as a dict.
|
Given a prompt, returns the top 50 most likely logits as a dict.
|
||||||
The keys are the tokens, and the values are the probabilities.
|
The keys are the tokens, and the values are the probabilities.
|
||||||
'''
|
'''
|
||||||
response = OAIlogits._get_next_logits(to_dict(request_data))
|
async with small_tasks_semaphore:
|
||||||
|
partial = functools.partial(OAIlogits._get_next_logits, to_dict(request_data))
|
||||||
|
response = await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@ -265,8 +300,14 @@ async def handle_logits(request_data: LogitsRequest):
|
|||||||
async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
|
async def handle_chat_prompt(request: Request, request_data: ChatCompletionRequest):
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
is_legacy = "/generate" in path
|
is_legacy = "/generate" in path
|
||||||
generator = OAIcompletions.chat_completions_common(to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
|
async with small_tasks_semaphore:
|
||||||
|
# Run in executor as there are calls to get_encoded_length
|
||||||
|
# which might slow down at really long contexts.
|
||||||
|
partial = functools.partial(OAIcompletions.chat_completions_common, to_dict(request_data), is_legacy=is_legacy, prompt_only=True)
|
||||||
|
generator = await run_in_executor(partial)
|
||||||
|
|
||||||
response = deque(generator, maxlen=1).pop()
|
response = deque(generator, maxlen=1).pop()
|
||||||
|
|
||||||
return JSONResponse(response)
|
return JSONResponse(response)
|
||||||
|
|
||||||
|
|
||||||
@ -318,7 +359,10 @@ async def handle_load_model(request_data: LoadModelRequest):
|
|||||||
'''
|
'''
|
||||||
|
|
||||||
try:
|
try:
|
||||||
OAImodels._load_model(to_dict(request_data))
|
async with io_semaphore:
|
||||||
|
partial = functools.partial(OAImodels._load_model, to_dict(request_data))
|
||||||
|
await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(content="OK")
|
return JSONResponse(content="OK")
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -327,7 +371,8 @@ async def handle_load_model(request_data: LoadModelRequest):
|
|||||||
|
|
||||||
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
|
@app.post("/v1/internal/model/unload", dependencies=check_admin_key)
|
||||||
async def handle_unload_model():
|
async def handle_unload_model():
|
||||||
unload_model()
|
async with io_semaphore:
|
||||||
|
await run_in_executor(unload_model)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
|
@app.get("/v1/internal/lora/list", response_model=LoraListResponse, dependencies=check_admin_key)
|
||||||
@ -339,7 +384,10 @@ async def handle_list_loras():
|
|||||||
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
|
@app.post("/v1/internal/lora/load", dependencies=check_admin_key)
|
||||||
async def handle_load_loras(request_data: LoadLorasRequest):
|
async def handle_load_loras(request_data: LoadLorasRequest):
|
||||||
try:
|
try:
|
||||||
OAImodels.load_loras(request_data.lora_names)
|
async with io_semaphore:
|
||||||
|
partial = functools.partial(OAImodels.load_loras, request_data.lora_names)
|
||||||
|
await run_in_executor(partial)
|
||||||
|
|
||||||
return JSONResponse(content="OK")
|
return JSONResponse(content="OK")
|
||||||
except:
|
except:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@ -348,7 +396,9 @@ async def handle_load_loras(request_data: LoadLorasRequest):
|
|||||||
|
|
||||||
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
|
@app.post("/v1/internal/lora/unload", dependencies=check_admin_key)
|
||||||
async def handle_unload_loras():
|
async def handle_unload_loras():
|
||||||
OAImodels.unload_all_loras()
|
async with io_semaphore:
|
||||||
|
await run_in_executor(OAImodels.unload_all_loras)
|
||||||
|
|
||||||
return JSONResponse(content="OK")
|
return JSONResponse(content="OK")
|
||||||
|
|
||||||
|
|
||||||
|
@ -2,9 +2,13 @@ import base64
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import Callable, Optional
|
from typing import Callable, Optional, AsyncGenerator, Generator
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from modules import shared
|
||||||
|
from functools import partial
|
||||||
|
import asyncio
|
||||||
|
from asyncio import AbstractEventLoop, Future
|
||||||
|
|
||||||
|
|
||||||
def float_list_to_base64(float_array: np.ndarray) -> str:
|
def float_list_to_base64(float_array: np.ndarray) -> str:
|
||||||
@ -52,3 +56,37 @@ def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_star
|
|||||||
time.sleep(3)
|
time.sleep(3)
|
||||||
|
|
||||||
raise Exception('Could not start cloudflared.')
|
raise Exception('Could not start cloudflared.')
|
||||||
|
|
||||||
|
|
||||||
|
def get_next_generator_result(gen: Generator) -> tuple[any, bool]:
|
||||||
|
"""
|
||||||
|
Because StopIteration interacts badly with generators and cannot be raised into a Future
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
result = next(gen)
|
||||||
|
return result, False
|
||||||
|
except StopIteration:
|
||||||
|
return None, True
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_in_executor(partial: partial, loop: AbstractEventLoop = None) -> AsyncGenerator[any, any]:
|
||||||
|
"""
|
||||||
|
Converts a blocking generator to an async one
|
||||||
|
"""
|
||||||
|
loop = loop or asyncio.get_running_loop()
|
||||||
|
gen = await loop.run_in_executor(None, partial)
|
||||||
|
|
||||||
|
while not shared.stop_everything:
|
||||||
|
result, is_done = await loop.run_in_executor(None, get_next_generator_result, gen)
|
||||||
|
if is_done:
|
||||||
|
break
|
||||||
|
|
||||||
|
yield result
|
||||||
|
|
||||||
|
|
||||||
|
async def run_in_executor(partial: partial, loop: AbstractEventLoop = None) -> Future:
|
||||||
|
"""
|
||||||
|
Runs a blocking function in a new thread so it can be awaited.
|
||||||
|
"""
|
||||||
|
loop = loop or asyncio.get_running_loop()
|
||||||
|
return await loop.run_in_executor(None, partial)
|
||||||
|
Loading…
Reference in New Issue
Block a user