2023-07-09 04:21:20 +02:00
|
|
|
import asyncio
|
|
|
|
import functools
|
|
|
|
import threading
|
2023-04-23 20:52:43 +02:00
|
|
|
import time
|
2023-04-26 04:23:47 +02:00
|
|
|
import traceback
|
|
|
|
from threading import Thread
|
2023-04-23 20:52:43 +02:00
|
|
|
from typing import Callable, Optional
|
2023-04-26 04:23:47 +02:00
|
|
|
|
2023-05-20 23:42:17 +02:00
|
|
|
from modules import shared
|
2023-05-23 05:50:58 +02:00
|
|
|
from modules.chat import load_character_memoized
|
2023-06-14 01:34:35 +02:00
|
|
|
from modules.presets import load_preset_memoized
|
2023-04-23 20:52:43 +02:00
|
|
|
|
2023-07-09 04:21:20 +02:00
|
|
|
# Thread-local storage for the API's asyncio lock. _get_api_lock() creates the
# lock lazily on this object, so each thread gets its own lock. This isn't
# strictly necessary today, but it keeps the locking correct if we support
# multiple worker threads in the future to handle requests in parallel.
api_tls: threading.local = threading.local()
|
|
|
|
|
|
|
|
|
2023-05-20 23:42:17 +02:00
|
|
|
# Generation fields read from the request body, in output order.
# Each entry is (lookup keys, converter, default). The first lookup key is also
# the output key; any later keys are legacy aliases, consulted only when the
# earlier keys are absent.
_GENERATION_FIELDS = (
    (('max_new_tokens', 'max_length'), int, 200),
    (('auto_max_new_tokens',), bool, False),
    (('do_sample',), bool, True),
    (('temperature',), float, 0.5),
    (('top_p',), float, 1),
    (('typical_p', 'typical'), float, 1),
    (('epsilon_cutoff',), float, 0),
    (('eta_cutoff',), float, 0),
    (('tfs',), float, 1),
    (('top_a',), float, 0),
    (('repetition_penalty', 'rep_pen'), float, 1.1),
    (('repetition_penalty_range',), int, 0),
    (('encoder_repetition_penalty',), float, 1.0),
    (('top_k',), int, 0),
    (('min_length',), int, 0),
    (('no_repeat_ngram_size',), int, 0),
    (('num_beams',), int, 1),
    (('penalty_alpha',), float, 0),
    (('length_penalty',), float, 1),
    (('early_stopping',), bool, False),
    (('mirostat_mode',), int, 0),
    (('mirostat_tau',), float, 5),
    (('mirostat_eta',), float, 0.1),
    (('guidance_scale',), float, 1),
    (('negative_prompt',), str, ''),
    (('seed',), int, -1),
    (('add_bos_token',), bool, True),
    (('truncation_length', 'max_context_length'), int, 2048),
    (('ban_eos_token',), bool, False),
    (('skip_special_tokens',), bool, True),
)


def _first_present(body, keys, default):
    """Return ``body[k]`` for the first key ``k`` in ``keys`` found in ``body``, else ``default``.

    Folds from the last key backwards, so the ``body.get`` calls are exactly
    those of a nested ``body.get(k0, body.get(k1, default))`` expression.
    """
    value = default
    for key in reversed(keys):
        value = body.get(key, value)
    return value


def build_parameters(body, chat=False):
    """
    Translate an API request body into a dict of generation parameters.

    Each field listed in _GENERATION_FIELDS is looked up in ``body``. The
    lookup falls back to the field's legacy alias, then to its default, and
    the result is coerced with the field's converter. ``stopping_strings`` is
    passed through unconverted, and ``custom_stopping_strings`` is always
    blank. If ``body['preset']`` names a preset (anything other than 'None',
    None or ''), that preset is loaded and its values override the fields
    parsed from ``body``.

    When ``chat`` is true, chat-specific fields are added as well. The
    character named by ``body['character']`` is loaded, along with the
    instruction template from ``body['instruction_template']``. The template
    defaults to shared.settings, and "Vicuna-v1.1" is used when the template
    is None or "None". Matching keys in ``body`` override the loaded values.

    Args:
        body: mapping of request fields (presumably the decoded JSON request
            body; confirm against callers).
        chat: whether to add the chat/character fields.

    Returns:
        dict of generation parameters.
    """
    generate_params = {
        keys[0]: convert(_first_present(body, keys, default))
        for keys, convert, default in _GENERATION_FIELDS
    }
    generate_params['custom_stopping_strings'] = ''  # leave this blank
    # The default list is built inline so every call gets a fresh object.
    generate_params['stopping_strings'] = body.get('stopping_strings', [])

    preset_name = body.get('preset', 'None')
    if preset_name not in ('None', None, ''):
        generate_params.update(load_preset_memoized(preset_name))

    if not chat:
        return generate_params

    character = body.get('character')
    template = body.get('instruction_template', shared.settings['instruction_template'])
    if str(template) == "None":
        template = "Vicuna-v1.1"

    your_name = str(body.get('your_name', shared.settings['name1']))
    name1, name2, _, greeting, context, _ = load_character_memoized(character, your_name, shared.settings['name2'], instruct=False)
    name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(template, '', '', instruct=True)

    # Request values win; otherwise fall back to what the loaders returned.
    chat_fallbacks = (
        ('mode', 'chat'),
        ('name1', name1),
        ('name2', name2),
        ('context', context),
        ('greeting', greeting),
        ('name1_instruct', name1_instruct),
        ('name2_instruct', name2_instruct),
        ('context_instruct', context_instruct),
        ('turn_template', turn_template),
    )
    for key, fallback in chat_fallbacks:
        generate_params[key] = str(body.get(key, fallback))

    generate_params['chat-instruct_command'] = str(_first_present(
        body,
        ('chat_instruct_command', 'chat-instruct_command'),
        shared.settings['chat-instruct_command'],
    ))
    generate_params['history'] = body.get('history', {'internal': [], 'visible': []})
    return generate_params
|
|
|
|
|
|
|
|
|
2023-08-09 03:20:27 +02:00
|
|
|
def try_start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
    """
    Start a cloudflared tunnel for ``port`` in the background.

    The work is handed to ``_start_cloudflared`` on a daemon thread. This call
    therefore returns immediately, and the thread does not keep the
    interpreter alive. ``on_start``, if given, is passed through to
    ``_start_cloudflared``, which calls it with the public URL. Errors raised
    there surface in the background thread, not in this caller.
    """
    launcher = Thread(
        target=_start_cloudflared,
        args=(port, tunnel_id, max_attempts, on_start),
        daemon=True,
    )
    launcher.start()
|
2023-04-23 20:52:43 +02:00
|
|
|
|
|
|
|
|
2023-08-09 03:20:27 +02:00
|
|
|
def _start_cloudflared(port: int, tunnel_id: str, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
2023-04-23 20:52:43 +02:00
|
|
|
try:
|
|
|
|
from flask_cloudflared import _run_cloudflared
|
|
|
|
except ImportError:
|
|
|
|
print('You should install flask_cloudflared manually')
|
|
|
|
raise Exception(
|
|
|
|
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
|
|
|
|
|
|
|
|
for _ in range(max_attempts):
|
|
|
|
try:
|
2023-08-17 20:20:36 +02:00
|
|
|
if tunnel_id is not None:
|
|
|
|
public_url = _run_cloudflared(port, port + 1, tunnel_id=tunnel_id)
|
|
|
|
else:
|
|
|
|
public_url = _run_cloudflared(port, port + 1)
|
2023-04-23 20:52:43 +02:00
|
|
|
|
|
|
|
if on_start:
|
|
|
|
on_start(public_url)
|
|
|
|
|
|
|
|
return
|
|
|
|
except Exception:
|
2023-04-26 04:23:47 +02:00
|
|
|
traceback.print_exc()
|
2023-04-23 20:52:43 +02:00
|
|
|
time.sleep(3)
|
|
|
|
|
|
|
|
raise Exception('Could not start cloudflared.')
|
2023-07-09 04:21:20 +02:00
|
|
|
|
|
|
|
|
|
|
|
def _get_api_lock(tls) -> asyncio.Lock:
|
|
|
|
"""
|
|
|
|
The streaming and blocking API implementations each run on their own
|
|
|
|
thread, and multiplex requests using asyncio. If multiple outstanding
|
|
|
|
requests are received at once, we will try to acquire the shared lock
|
|
|
|
shared.generation_lock multiple times in succession in the same thread,
|
|
|
|
which will cause a deadlock.
|
|
|
|
|
|
|
|
To avoid this, we use this wrapper function to block on an asyncio
|
|
|
|
lock, and then try and grab the shared lock only while holding
|
|
|
|
the asyncio lock.
|
|
|
|
"""
|
|
|
|
if not hasattr(tls, "asyncio_lock"):
|
|
|
|
tls.asyncio_lock = asyncio.Lock()
|
|
|
|
|
|
|
|
return tls.asyncio_lock
|
|
|
|
|
|
|
|
|
|
|
|
def with_api_lock(func):
    """
    Decorate an async API handler so it runs under this thread's asyncio lock.

    Apply it to every streaming API coroutine that needs
    shared.generation_lock. The wrapper acquires the per-thread lock returned
    by ``_get_api_lock(api_tls)`` before awaiting ``func``, and releases it
    when ``func`` returns or raises. ``func``'s name and docstring are
    preserved via ``functools.wraps``.
    """
    async def api_wrapper(*args, **kwargs):
        lock = _get_api_lock(api_tls)
        async with lock:
            return await func(*args, **kwargs)

    return functools.wraps(func)(api_wrapper)
|