mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-23 00:18:20 +01:00
commit
6447b2eea6
@ -23,7 +23,8 @@ async def run(user_input, history):
|
|||||||
'history': history,
|
'history': history,
|
||||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||||
'character': 'Example',
|
'character': 'Example',
|
||||||
'instruction_template': 'Vicuna-v1.1',
|
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
|
||||||
|
# 'context_instruct': '', # Optional
|
||||||
'your_name': 'You',
|
'your_name': 'You',
|
||||||
|
|
||||||
'regenerate': False,
|
'regenerate': False,
|
||||||
|
@ -17,7 +17,8 @@ def run(user_input, history):
|
|||||||
'history': history,
|
'history': history,
|
||||||
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
'mode': 'instruct', # Valid options: 'chat', 'chat-instruct', 'instruct'
|
||||||
'character': 'Example',
|
'character': 'Example',
|
||||||
'instruction_template': 'Vicuna-v1.1',
|
'instruction_template': 'Vicuna-v1.1', # Will get autodetected if unset
|
||||||
|
# 'context_instruct': '', # Optional
|
||||||
'your_name': 'You',
|
'your_name': 'You',
|
||||||
|
|
||||||
'regenerate': False,
|
'regenerate': False,
|
||||||
|
@ -4,8 +4,9 @@ import requests
|
|||||||
|
|
||||||
HOST = '0.0.0.0:5000'
|
HOST = '0.0.0.0:5000'
|
||||||
|
|
||||||
def generate(prompt, tokens = 200):
|
|
||||||
request = { 'prompt': prompt, 'max_new_tokens': tokens }
|
def generate(prompt, tokens=200):
|
||||||
|
request = {'prompt': prompt, 'max_new_tokens': tokens}
|
||||||
response = requests.post(f'http://{HOST}/api/v1/generate', json=request)
|
response = requests.post(f'http://{HOST}/api/v1/generate', json=request)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
@ -23,7 +24,7 @@ def print_basic_model_info(response):
|
|||||||
print("Model: ", response['result']['model_name'])
|
print("Model: ", response['result']['model_name'])
|
||||||
print("Lora(s): ", response['result']['lora_names'])
|
print("Lora(s): ", response['result']['lora_names'])
|
||||||
for setting in basic_settings:
|
for setting in basic_settings:
|
||||||
print(setting, "=", response['result']['shared.settings'][setting])
|
print(setting, "=", response['result']['shared.settings'][setting])
|
||||||
|
|
||||||
|
|
||||||
# model info
|
# model info
|
||||||
@ -54,7 +55,7 @@ def complex_model_load(model):
|
|||||||
'action': 'load',
|
'action': 'load',
|
||||||
'model_name': model,
|
'model_name': model,
|
||||||
'args': {
|
'args': {
|
||||||
'gptq_for_llama': False, # Use AutoGPTQ by default, set to True for gptq-for-llama
|
'loader': 'AutoGPTQ',
|
||||||
|
|
||||||
'bf16': False,
|
'bf16': False,
|
||||||
'load_in_8bit': False,
|
'load_in_8bit': False,
|
||||||
@ -75,17 +76,17 @@ def complex_model_load(model):
|
|||||||
'rwkv_cuda_on': False,
|
'rwkv_cuda_on': False,
|
||||||
|
|
||||||
# b&b 4-bit
|
# b&b 4-bit
|
||||||
#'load_in_4bit': False,
|
# 'load_in_4bit': False,
|
||||||
#'compute_dtype': 'float16',
|
# 'compute_dtype': 'float16',
|
||||||
#'quant_type': 'nf4',
|
# 'quant_type': 'nf4',
|
||||||
#'use_double_quant': False,
|
# 'use_double_quant': False,
|
||||||
|
|
||||||
#"cpu": false,
|
# "cpu": false,
|
||||||
#"auto_devices": false,
|
# "auto_devices": false,
|
||||||
#"gpu_memory": null,
|
# "gpu_memory": null,
|
||||||
#"cpu_memory": null,
|
# "cpu_memory": null,
|
||||||
#"disk": false,
|
# "disk": false,
|
||||||
#"disk_cache_dir": "cache",
|
# "disk_cache_dir": "cache",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -104,26 +105,25 @@ def complex_model_load(model):
|
|||||||
req['args']['load_in_8bit'] = True
|
req['args']['load_in_8bit'] = True
|
||||||
elif '-hf' in model or 'fp16' in model:
|
elif '-hf' in model or 'fp16' in model:
|
||||||
if '7b' in model:
|
if '7b' in model:
|
||||||
req['args']['bf16'] = True # for 24GB
|
req['args']['bf16'] = True # for 24GB
|
||||||
elif '13b' in model:
|
elif '13b' in model:
|
||||||
req['args']['load_in_8bit'] = True # for 24GB
|
req['args']['load_in_8bit'] = True # for 24GB
|
||||||
elif 'ggml' in model:
|
elif 'ggml' in model:
|
||||||
#req['args']['threads'] = 16
|
# req['args']['threads'] = 16
|
||||||
if '7b' in model:
|
if '7b' in model:
|
||||||
req['args']['n_gpu_layers'] = 100
|
req['args']['n_gpu_layers'] = 100
|
||||||
elif '13b' in model:
|
elif '13b' in model:
|
||||||
req['args']['n_gpu_layers'] = 100
|
req['args']['n_gpu_layers'] = 100
|
||||||
elif '30b' in model or '33b' in model:
|
elif '30b' in model or '33b' in model:
|
||||||
req['args']['n_gpu_layers'] = 59 # 24GB
|
req['args']['n_gpu_layers'] = 59 # 24GB
|
||||||
elif '65b' in model:
|
elif '65b' in model:
|
||||||
req['args']['n_gpu_layers'] = 42 # 24GB
|
req['args']['n_gpu_layers'] = 42 # 24GB
|
||||||
elif 'rwkv' in model:
|
elif 'rwkv' in model:
|
||||||
req['args']['rwkv_cuda_on'] = True
|
req['args']['rwkv_cuda_on'] = True
|
||||||
if '14b' in model:
|
if '14b' in model:
|
||||||
req['args']['rwkv_strategy'] = 'cuda f16i8' # 24GB
|
req['args']['rwkv_strategy'] = 'cuda f16i8' # 24GB
|
||||||
else:
|
else:
|
||||||
req['args']['rwkv_strategy'] = 'cuda f16' # 24GB
|
req['args']['rwkv_strategy'] = 'cuda f16' # 24GB
|
||||||
|
|
||||||
|
|
||||||
return model_api(req)
|
return model_api(req)
|
||||||
|
|
||||||
@ -134,7 +134,7 @@ if __name__ == '__main__':
|
|||||||
resp = complex_model_load(model)
|
resp = complex_model_load(model)
|
||||||
|
|
||||||
if 'error' in resp:
|
if 'error' in resp:
|
||||||
print (f"❌ {model} FAIL Error: {resp['error']['message']}")
|
print(f"❌ {model} FAIL Error: {resp['error']['message']}")
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
print_basic_model_info(resp)
|
print_basic_model_info(resp)
|
||||||
@ -142,12 +142,12 @@ if __name__ == '__main__':
|
|||||||
ans = generate("0,1,1,2,3,5,8,13,", tokens=2)
|
ans = generate("0,1,1,2,3,5,8,13,", tokens=2)
|
||||||
|
|
||||||
if '21' in ans:
|
if '21' in ans:
|
||||||
print (f"✅ {model} PASS ({ans})")
|
print(f"✅ {model} PASS ({ans})")
|
||||||
else:
|
else:
|
||||||
print (f"❌ {model} FAIL ({ans})")
|
print(f"❌ {model} FAIL ({ans})")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print (f"❌ {model} FAIL Exception: {repr(e)}")
|
print(f"❌ {model} FAIL Exception: {repr(e)}")
|
||||||
|
|
||||||
|
|
||||||
# 0,1,1,2,3,5,8,13, is the fibonacci sequence, the next number is 21.
|
# 0,1,1,2,3,5,8,13, is the fibonacci sequence, the next number is 21.
|
||||||
|
@ -5,13 +5,13 @@ services:
|
|||||||
context: .
|
context: .
|
||||||
args:
|
args:
|
||||||
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
# specify which cuda version your card supports: https://developer.nvidia.com/cuda-gpus
|
||||||
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST}
|
TORCH_CUDA_ARCH_LIST: ${TORCH_CUDA_ARCH_LIST:-7.5}
|
||||||
WEBUI_VERSION: ${WEBUI_VERSION}
|
WEBUI_VERSION: ${WEBUI_VERSION:-HEAD}
|
||||||
env_file: .env
|
env_file: .env
|
||||||
ports:
|
ports:
|
||||||
- "${HOST_PORT}:${CONTAINER_PORT}"
|
- "${HOST_PORT:-7860}:${CONTAINER_PORT:-7860}"
|
||||||
- "${HOST_API_PORT}:${CONTAINER_API_PORT}"
|
- "${HOST_API_PORT:-5000}:${CONTAINER_API_PORT:-5000}"
|
||||||
- "${HOST_API_STREAM_PORT}:${CONTAINER_API_STREAM_PORT}"
|
- "${HOST_API_STREAM_PORT:-5005}:${CONTAINER_API_STREAM_PORT:-5005}"
|
||||||
stdin_open: true
|
stdin_open: true
|
||||||
tty: true
|
tty: true
|
||||||
volumes:
|
volumes:
|
||||||
|
@ -23,13 +23,15 @@ from tqdm.contrib.concurrent import thread_map
|
|||||||
|
|
||||||
|
|
||||||
class ModelDownloader:
|
class ModelDownloader:
|
||||||
def __init__(self, max_retries = 5):
|
def __init__(self, max_retries=5):
|
||||||
self.s = requests.Session()
|
self.s = requests.Session()
|
||||||
if max_retries:
|
if max_retries:
|
||||||
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
|
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
|
||||||
self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
|
self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries))
|
||||||
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
|
if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None:
|
||||||
self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
|
self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS'))
|
||||||
|
if os.getenv('HF_TOKEN') is not None:
|
||||||
|
self.s.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'}
|
||||||
|
|
||||||
def sanitize_model_and_branch_names(self, model, branch):
|
def sanitize_model_and_branch_names(self, model, branch):
|
||||||
if model[-1] == '/':
|
if model[-1] == '/':
|
||||||
@ -77,7 +79,7 @@ class ModelDownloader:
|
|||||||
is_safetensors = re.match(".*\.safetensors", fname)
|
is_safetensors = re.match(".*\.safetensors", fname)
|
||||||
is_pt = re.match(".*\.pt", fname)
|
is_pt = re.match(".*\.pt", fname)
|
||||||
is_ggml = re.match(".*ggml.*\.bin", fname)
|
is_ggml = re.match(".*ggml.*\.bin", fname)
|
||||||
is_tokenizer = re.match("(tokenizer|ice).*\.model", fname)
|
is_tokenizer = re.match("(tokenizer|ice|spiece).*\.model", fname)
|
||||||
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
|
is_text = re.match(".*\.(txt|json|py|md)", fname) or is_tokenizer
|
||||||
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
|
if any((is_pytorch, is_safetensors, is_pt, is_ggml, is_tokenizer, is_text)):
|
||||||
if 'lfs' in dict[i]:
|
if 'lfs' in dict[i]:
|
||||||
|
@ -59,7 +59,10 @@ def build_parameters(body, chat=False):
|
|||||||
|
|
||||||
if chat:
|
if chat:
|
||||||
character = body.get('character')
|
character = body.get('character')
|
||||||
instruction_template = body.get('instruction_template')
|
instruction_template = body.get('instruction_template', shared.settings['instruction_template'])
|
||||||
|
if str(instruction_template) == "None":
|
||||||
|
instruction_template = "Vicuna-v1.1"
|
||||||
|
|
||||||
name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), shared.settings['name2'], instruct=False)
|
name1, name2, _, greeting, context, _ = load_character_memoized(character, str(body.get('your_name', shared.settings['name1'])), shared.settings['name2'], instruct=False)
|
||||||
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
|
name1_instruct, name2_instruct, _, _, context_instruct, turn_template = load_character_memoized(instruction_template, '', '', instruct=True)
|
||||||
generate_params.update({
|
generate_params.update({
|
||||||
@ -72,7 +75,7 @@ def build_parameters(body, chat=False):
|
|||||||
'greeting': greeting,
|
'greeting': greeting,
|
||||||
'name1_instruct': name1_instruct,
|
'name1_instruct': name1_instruct,
|
||||||
'name2_instruct': name2_instruct,
|
'name2_instruct': name2_instruct,
|
||||||
'context_instruct': context_instruct,
|
'context_instruct': body.get('context_instruct', context_instruct),
|
||||||
'turn_template': turn_template,
|
'turn_template': turn_template,
|
||||||
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
|
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
|
||||||
'history': body.get('history', {'internal': [], 'visible': []})
|
'history': body.get('history', {'internal': [], 'visible': []})
|
||||||
|
@ -6,6 +6,7 @@ import gradio as gr
|
|||||||
|
|
||||||
from modules import chat, shared
|
from modules import chat, shared
|
||||||
from modules.utils import gradio
|
from modules.utils import gradio
|
||||||
|
from modules.logging_colors import logger
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'activate': True,
|
'activate': True,
|
||||||
@ -13,10 +14,12 @@ params = {
|
|||||||
'selected_voice': 'None',
|
'selected_voice': 'None',
|
||||||
'autoplay': False,
|
'autoplay': False,
|
||||||
'show_text': True,
|
'show_text': True,
|
||||||
|
'model': 'eleven_monolingual_v1',
|
||||||
}
|
}
|
||||||
|
|
||||||
voices = None
|
voices = None
|
||||||
wav_idx = 0
|
wav_idx = 0
|
||||||
|
LANG_MODELS = ['eleven_monolingual_v1', 'eleven_multilingual_v1']
|
||||||
|
|
||||||
|
|
||||||
def update_api_key(key):
|
def update_api_key(key):
|
||||||
@ -108,7 +111,7 @@ def output_modifier(string):
|
|||||||
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
|
output_file = Path(f'extensions/elevenlabs_tts/outputs/{wav_idx:06d}.mp3'.format(wav_idx))
|
||||||
print(f'Outputting audio to {str(output_file)}')
|
print(f'Outputting audio to {str(output_file)}')
|
||||||
try:
|
try:
|
||||||
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model="eleven_monolingual_v1")
|
audio = elevenlabs.generate(text=string, voice=params['selected_voice'], model=params['model'])
|
||||||
elevenlabs.save(audio, str(output_file))
|
elevenlabs.save(audio, str(output_file))
|
||||||
|
|
||||||
autoplay = 'autoplay' if params['autoplay'] else ''
|
autoplay = 'autoplay' if params['autoplay'] else ''
|
||||||
@ -132,7 +135,12 @@ def ui():
|
|||||||
global voices
|
global voices
|
||||||
if not voices:
|
if not voices:
|
||||||
voices = refresh_voices()
|
voices = refresh_voices()
|
||||||
params['selected_voice'] = voices[0]
|
selected = params['selected_voice']
|
||||||
|
if selected == 'None':
|
||||||
|
params['selected_voice'] = voices[0]
|
||||||
|
elif selected not in voices:
|
||||||
|
logger.error(f'Selected voice {selected} not available, switching to {voices[0]}')
|
||||||
|
params['selected_voice'] = voices[0]
|
||||||
|
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
@ -145,7 +153,14 @@ def ui():
|
|||||||
refresh = gr.Button(value='Refresh')
|
refresh = gr.Button(value='Refresh')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
|
if params['api_key']:
|
||||||
|
api_key = gr.Textbox(value=params['api_key'], label='API Key')
|
||||||
|
update_api_key(params['api_key'])
|
||||||
|
else:
|
||||||
|
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
|
||||||
|
|
||||||
|
with gr.Row():
|
||||||
|
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
convert = gr.Button('Permanently replace audios with the message texts')
|
convert = gr.Button('Permanently replace audios with the message texts')
|
||||||
@ -175,6 +190,7 @@ def ui():
|
|||||||
activate.change(lambda x: params.update({'activate': x}), activate, None)
|
activate.change(lambda x: params.update({'activate': x}), activate, None)
|
||||||
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
voice.change(lambda x: params.update({'selected_voice': x}), voice, None)
|
||||||
api_key.change(update_api_key, api_key, None)
|
api_key.change(update_api_key, api_key, None)
|
||||||
|
model.change(lambda x: params.update({'model': x}), model, None)
|
||||||
# connect.click(check_valid_api, [], connection_status)
|
# connect.click(check_valid_api, [], connection_status)
|
||||||
refresh.click(refresh_voices_dd, [], voice)
|
refresh.click(refresh_voices_dd, [], voice)
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
|
@ -38,6 +38,8 @@ As of now, the following multimodal pipelines are supported:
|
|||||||
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|
|[LLaVA 7B](https://github.com/haotian-liu/LLaVA)|`llava-7b`|[LLaVA 7B](https://huggingface.co/wojtab/llava-7b-v0-4bit-128g)|GPTQ 4-bit quant, old CUDA|built-in|
|
||||||
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
|[MiniGPT-4 7B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-7b`|[Vicuna v0 7B](https://huggingface.co/TheBloke/vicuna-7B-GPTQ-4bit-128g)|GPTQ 4-bit quant, new format|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
||||||
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
|[MiniGPT-4 13B](https://github.com/Vision-CAIR/MiniGPT-4)|`minigpt4-13b`|[Vicuna v0 13B](https://huggingface.co/anon8231489123/vicuna-13b-GPTQ-4bit-128g)|GPTQ 4-bit quant, old CUDA|[Wojtab/minigpt-4-pipeline](https://github.com/Wojtab/minigpt-4-pipeline)|
|
||||||
|
|[InstructBLIP 7B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-7b`|[Vicuna v1.1 7B](https://huggingface.co/TheBloke/vicuna-7B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)|
|
||||||
|
|[InstructBLIP 13B](https://github.com/salesforce/LAVIS/tree/main/projects/instructblip)|`instructblip-13b`|[Vicuna v1.1 13B](https://huggingface.co/TheBloke/vicuna-13B-1.1-GPTQ-4bit-128g)|GPTQ 4-bit quant|[kjerk/instructblip-pipeline](https://github.com/kjerk/instructblip-pipeline)|
|
||||||
|
|
||||||
Some pipelines could support different LLMs but do note that while it might work, it isn't a supported configuration.
|
Some pipelines could support different LLMs but do note that while it might work, it isn't a supported configuration.
|
||||||
|
|
||||||
|
@ -22,6 +22,7 @@ options = {
|
|||||||
'session_metadata': 'text-generation-webui',
|
'session_metadata': 'text-generation-webui',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
settings = shared.settings.get("ngrok")
|
settings = shared.settings.get("ngrok")
|
||||||
if settings:
|
if settings:
|
||||||
@ -33,4 +34,3 @@ def ui():
|
|||||||
logging.info(f"Ingress established at: {tunnel.url()}")
|
logging.info(f"Ingress established at: {tunnel.url()}")
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")
|
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")
|
||||||
|
|
||||||
|
@ -218,12 +218,11 @@ but there are some exceptions.
|
|||||||
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
|
| ✅❌ | langchain | https://github.com/hwchase17/langchain | OPENAI_API_BASE=http://127.0.0.1:5001/v1 even with a good 30B-4bit model the result is poor so far. It assumes zero shot python/json coding. Some model tailored prompt formatting improves results greatly. |
|
||||||
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
| ✅❌ | Auto-GPT | https://github.com/Significant-Gravitas/Auto-GPT | OPENAI_API_BASE=http://127.0.0.1:5001/v1 Same issues as langchain. Also assumes a 4k+ context |
|
||||||
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
| ✅❌ | babyagi | https://github.com/yoheinakajima/babyagi | OPENAI_API_BASE=http://127.0.0.1:5001/v1 |
|
||||||
|
| ❌ | guidance | https://github.com/microsoft/guidance | logit_bias and logprobs not yet supported |
|
||||||
|
|
||||||
## Future plans
|
## Future plans
|
||||||
* better error handling
|
|
||||||
* model changing, esp. something for swapping loras or embedding models
|
* model changing, esp. something for swapping loras or embedding models
|
||||||
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
* consider switching to FastAPI + starlette for SSE (openai SSE seems non-standard)
|
||||||
* do something about rate limiting or locking requests for completions, most systems will only be able handle a single request at a time before OOM
|
|
||||||
|
|
||||||
## Bugs? Feedback? Comments? Pull requests?
|
## Bugs? Feedback? Comments? Pull requests?
|
||||||
|
|
||||||
|
597
extensions/openai/completions.py
Normal file
597
extensions/openai/completions.py
Normal file
@ -0,0 +1,597 @@
|
|||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
import tiktoken
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from transformers import LogitsProcessor, LogitsProcessorList
|
||||||
|
|
||||||
|
from modules import shared
|
||||||
|
from modules.text_generation import encode, decode, generate_reply
|
||||||
|
|
||||||
|
from extensions.openai.defaults import get_default_req_params, default, clamp
|
||||||
|
from extensions.openai.utils import end_line, debug_msg
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
|
||||||
|
|
||||||
|
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
|
||||||
|
class LogitsBiasProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, logit_bias={}):
|
||||||
|
self.logit_bias = logit_bias
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if self.logit_bias:
|
||||||
|
keys = list([int(key) for key in self.logit_bias.keys()])
|
||||||
|
values = list([int(val) for val in self.logit_bias.values()])
|
||||||
|
logits[0, keys] += torch.tensor(values).cuda()
|
||||||
|
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
class LogprobProcessor(LogitsProcessor):
|
||||||
|
def __init__(self, logprobs=None):
|
||||||
|
self.logprobs = logprobs
|
||||||
|
self.token_alternatives = {}
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
|
if self.logprobs is not None: # 0-5
|
||||||
|
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||||
|
# XXX hack. should find the selected token and include the prob of that
|
||||||
|
# ... but we just +1 here instead because we don't know it yet.
|
||||||
|
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||||
|
top_tokens = [decode(tok) for tok in top_indices[0]]
|
||||||
|
self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist()))
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(model)
|
||||||
|
# just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||||
|
return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||||
|
except KeyError:
|
||||||
|
# assume native tokens if we can't find the tokenizer
|
||||||
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
|
def marshal_common_params(body):
|
||||||
|
# Request Parameters
|
||||||
|
# Try to use openai defaults or map them to something with the same intent
|
||||||
|
|
||||||
|
req_params = get_default_req_params()
|
||||||
|
|
||||||
|
# Common request parameters
|
||||||
|
req_params['truncation_length'] = shared.settings['truncation_length']
|
||||||
|
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||||
|
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||||
|
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||||
|
|
||||||
|
# OpenAI API Parameters
|
||||||
|
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
|
||||||
|
req_params['requested_model'] = body.get('model', shared.model_name)
|
||||||
|
|
||||||
|
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||||
|
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
|
||||||
|
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||||
|
n = default(body, 'n', 1)
|
||||||
|
if n != 1:
|
||||||
|
raise InvalidRequestError(message="Only n = 1 is supported.", param='n')
|
||||||
|
|
||||||
|
if 'stop' in body: # str or array, max len 4 (ignored)
|
||||||
|
if isinstance(body['stop'], str):
|
||||||
|
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
|
||||||
|
elif isinstance(body['stop'], list):
|
||||||
|
req_params['stopping_strings'] = body['stop']
|
||||||
|
|
||||||
|
# presence_penalty - ignored
|
||||||
|
# frequency_penalty - ignored
|
||||||
|
# user - ignored
|
||||||
|
|
||||||
|
logits_processor = []
|
||||||
|
logit_bias = body.get('logit_bias', None)
|
||||||
|
if logit_bias: # {str: float, ...}
|
||||||
|
# XXX convert tokens from tiktoken based on requested model
|
||||||
|
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(req_params['requested_model'])
|
||||||
|
new_logit_bias = {}
|
||||||
|
for logit, bias in logit_bias.items():
|
||||||
|
for x in encode(encoder.decode([int(logit)]))[0]:
|
||||||
|
new_logit_bias[str(int(x))] = bias
|
||||||
|
print(logit_bias, '->', new_logit_bias)
|
||||||
|
logit_bias = new_logit_bias
|
||||||
|
except KeyError:
|
||||||
|
pass # assume native tokens if we can't find the tokenizer
|
||||||
|
|
||||||
|
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
||||||
|
|
||||||
|
logprobs = None # coming to chat eventually
|
||||||
|
if 'logprobs' in body:
|
||||||
|
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||||
|
req_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||||
|
logits_processor.extend([req_params['logprob_proc']])
|
||||||
|
else:
|
||||||
|
logprobs = None
|
||||||
|
|
||||||
|
if logits_processor: # requires logits_processor support
|
||||||
|
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||||
|
|
||||||
|
return req_params
|
||||||
|
|
||||||
|
|
||||||
|
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||||
|
# functions
|
||||||
|
if body.get('functions', []): # chat only
|
||||||
|
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||||
|
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
||||||
|
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||||
|
|
||||||
|
if not 'messages' in body:
|
||||||
|
raise InvalidRequestError(message="messages is required", param='messages')
|
||||||
|
|
||||||
|
messages = body['messages']
|
||||||
|
|
||||||
|
role_formats = {
|
||||||
|
'user': 'user: {message}\n',
|
||||||
|
'assistant': 'assistant: {message}\n',
|
||||||
|
'system': '{message}',
|
||||||
|
'context': 'You are a helpful assistant. Answer as concisely as possible.',
|
||||||
|
'prompt': 'assistant:',
|
||||||
|
}
|
||||||
|
|
||||||
|
if not 'stopping_strings' in req_params:
|
||||||
|
req_params['stopping_strings'] = []
|
||||||
|
|
||||||
|
# Instruct models can be much better
|
||||||
|
if shared.settings['instruction_template']:
|
||||||
|
try:
|
||||||
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||||
|
|
||||||
|
template = instruct['turn_template']
|
||||||
|
system_message_template = "{message}"
|
||||||
|
system_message_default = instruct['context']
|
||||||
|
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||||
|
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||||
|
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||||
|
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||||
|
|
||||||
|
role_formats = {
|
||||||
|
'user': user_message_template,
|
||||||
|
'assistant': bot_message_template,
|
||||||
|
'system': system_message_template,
|
||||||
|
'context': system_message_default,
|
||||||
|
'prompt': bot_prompt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if 'Alpaca' in shared.settings['instruction_template']:
|
||||||
|
req_params['stopping_strings'].extend(['\n###'])
|
||||||
|
elif instruct['user']: # WizardLM and some others have no user prompt.
|
||||||
|
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||||
|
|
||||||
|
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
req_params['stopping_strings'].extend(['\nuser:'])
|
||||||
|
|
||||||
|
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||||
|
print("Warning: Loaded default instruction-following template for model.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
req_params['stopping_strings'].extend(['\nuser:'])
|
||||||
|
print("Warning: Loaded default instruction-following template for model.")
|
||||||
|
|
||||||
|
system_msgs = []
|
||||||
|
chat_msgs = []
|
||||||
|
|
||||||
|
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
||||||
|
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
||||||
|
context_msg = end_line(context_msg)
|
||||||
|
|
||||||
|
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
||||||
|
if 'prompt' in body:
|
||||||
|
context_msg = end_line(role_formats['system'].format(message=body['prompt'])) + context_msg
|
||||||
|
|
||||||
|
for m in messages:
|
||||||
|
role = m['role']
|
||||||
|
content = m['content']
|
||||||
|
# name = m.get('name', None)
|
||||||
|
# function_call = m.get('function_call', None) # user name or function name with output in content
|
||||||
|
msg = role_formats[role].format(message=content)
|
||||||
|
if role == 'system':
|
||||||
|
system_msgs.extend([msg])
|
||||||
|
elif role == 'function':
|
||||||
|
raise InvalidRequestError(message="role: function is not supported.", param='messages')
|
||||||
|
else:
|
||||||
|
chat_msgs.extend([msg])
|
||||||
|
|
||||||
|
system_msg = '\n'.join(system_msgs)
|
||||||
|
system_msg = end_line(system_msg)
|
||||||
|
|
||||||
|
prompt = system_msg + context_msg + ''.join(chat_msgs) + role_formats['prompt']
|
||||||
|
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
|
if token_count >= req_params['truncation_length']:
|
||||||
|
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens."
|
||||||
|
raise InvalidRequestError(message=err_msg)
|
||||||
|
|
||||||
|
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
||||||
|
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
||||||
|
print(f"Warning: ${err_msg}")
|
||||||
|
# raise InvalidRequestError(message=err_msg)
|
||||||
|
|
||||||
|
return prompt, token_count
|
||||||
|
|
||||||
|
|
||||||
|
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||||
|
# Chat Completions
|
||||||
|
object_type = 'chat.completions'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# common params
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
req_params['stream'] = False
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||||
|
|
||||||
|
# chat default max_tokens is 'inf', but also flexible
|
||||||
|
max_tokens = 0
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
if max_tokens_str in body:
|
||||||
|
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
else:
|
||||||
|
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||||
|
|
||||||
|
# format the prompt from messages
|
||||||
|
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||||
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if answer and answer[0] == ' ':
|
||||||
|
answer = answer[1:]
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name, # TODO: add Lora info?
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": stop_reason,
|
||||||
|
"message": {"role": "assistant", "content": answer}
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if logprob_proc: # not official for chat yet
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
# else:
|
||||||
|
# resp[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
# generator
|
||||||
|
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||||
|
|
||||||
|
# Chat Completions
|
||||||
|
stream_object_type = 'chat.completions.chunk'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# common params
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
req_params['stream'] = True
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||||
|
|
||||||
|
# chat default max_tokens is 'inf', but also flexible
|
||||||
|
max_tokens = 0
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
if max_tokens_str in body:
|
||||||
|
max_tokens = default(body, max_tokens_str, req_params['truncation_length'])
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
else:
|
||||||
|
req_params['max_new_tokens'] = req_params['truncation_length']
|
||||||
|
|
||||||
|
# format the prompt from messages
|
||||||
|
prompt, token_count = messages_to_prompt(body, req_params, max_tokens)
|
||||||
|
|
||||||
|
def chat_streaming_chunk(content):
|
||||||
|
# begin streaming
|
||||||
|
chunk = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": stream_object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": None,
|
||||||
|
# So yeah... do both methods? delta and messages.
|
||||||
|
"message": {'role': 'assistant', 'content': content},
|
||||||
|
"delta": {'role': 'assistant', 'content': content},
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
|
||||||
|
if logprob_proc: # not official for chat yet
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
# else:
|
||||||
|
# chunk[resp_list][0]["logprobs"] = None
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
yield chat_streaming_chunk('')
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||||
|
|
||||||
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
|
||||||
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
seen_content = ''
|
||||||
|
completion_token_count = 0
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
len_seen = len(seen_content)
|
||||||
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen_content = answer
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if len_seen == 0 and new_content[0] == ' ':
|
||||||
|
new_content = new_content[1:]
|
||||||
|
|
||||||
|
completion_token_count += len(encode(new_content)[0])
|
||||||
|
chunk = chat_streaming_chunk(new_content)
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
chunk = chat_streaming_chunk('')
|
||||||
|
chunk[resp_list][0]['finish_reason'] = stop_reason
|
||||||
|
chunk['usage'] = {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
|
||||||
|
def completions(body: dict, is_legacy: bool = False):
|
||||||
|
# Legacy
|
||||||
|
# Text Completions
|
||||||
|
object_type = 'text_completion'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||||
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
|
if not prompt_str in body:
|
||||||
|
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||||
|
|
||||||
|
prompt = body[prompt_str]
|
||||||
|
if isinstance(prompt, list):
|
||||||
|
if prompt and isinstance(prompt[0], int):
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(requested_model)
|
||||||
|
prompt = encode(encoder.decode(prompt))[0]
|
||||||
|
except KeyError:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
else:
|
||||||
|
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||||
|
|
||||||
|
# common params
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
req_params['stream'] = False
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
|
if token_count + max_tokens > req_params['truncation_length']:
|
||||||
|
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||||
|
# print(f"Warning: ${err_msg}")
|
||||||
|
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||||
|
|
||||||
|
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||||
|
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||||
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if answer and answer[0] == ' ':
|
||||||
|
answer = answer[1:]
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name, # TODO: add Lora info?
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": stop_reason,
|
||||||
|
"text": answer,
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if logprob_proc:
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
else:
|
||||||
|
resp[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
# generator
|
||||||
|
def stream_completions(body: dict, is_legacy: bool = False):
|
||||||
|
# Legacy
|
||||||
|
# Text Completions
|
||||||
|
# object_type = 'text_completion'
|
||||||
|
stream_object_type = 'text_completion.chunk'
|
||||||
|
created_time = int(time.time())
|
||||||
|
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||||
|
resp_list = 'data' if is_legacy else 'choices'
|
||||||
|
|
||||||
|
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||||
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
|
if not prompt_str in body:
|
||||||
|
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||||
|
|
||||||
|
prompt = body[prompt_str]
|
||||||
|
if isinstance(prompt, list):
|
||||||
|
if prompt and isinstance(prompt[0], int):
|
||||||
|
try:
|
||||||
|
encoder = tiktoken.encoding_for_model(requested_model)
|
||||||
|
prompt = encode(encoder.decode(prompt))[0]
|
||||||
|
except KeyError:
|
||||||
|
prompt = decode(prompt)[0]
|
||||||
|
else:
|
||||||
|
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||||
|
|
||||||
|
# common params
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
req_params['stream'] = True
|
||||||
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
|
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
|
||||||
|
token_count = len(encode(prompt)[0])
|
||||||
|
|
||||||
|
if token_count + max_tokens > req_params['truncation_length']:
|
||||||
|
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||||
|
# print(f"Warning: ${err_msg}")
|
||||||
|
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||||
|
|
||||||
|
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||||
|
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||||
|
|
||||||
|
def text_streaming_chunk(content):
|
||||||
|
# begin streaming
|
||||||
|
chunk = {
|
||||||
|
"id": cmpl_id,
|
||||||
|
"object": stream_object_type,
|
||||||
|
"created": created_time,
|
||||||
|
"model": shared.model_name,
|
||||||
|
resp_list: [{
|
||||||
|
"index": 0,
|
||||||
|
"finish_reason": None,
|
||||||
|
"text": content,
|
||||||
|
}],
|
||||||
|
}
|
||||||
|
if logprob_proc:
|
||||||
|
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||||
|
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||||
|
else:
|
||||||
|
chunk[resp_list][0]["logprobs"] = None
|
||||||
|
|
||||||
|
return chunk
|
||||||
|
|
||||||
|
yield text_streaming_chunk('')
|
||||||
|
|
||||||
|
# generate reply #######################################
|
||||||
|
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||||
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
|
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
answer = ''
|
||||||
|
seen_content = ''
|
||||||
|
completion_token_count = 0
|
||||||
|
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
len_seen = len(seen_content)
|
||||||
|
new_content = answer[len_seen:]
|
||||||
|
|
||||||
|
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen_content = answer
|
||||||
|
|
||||||
|
# strip extra leading space off new generated content
|
||||||
|
if len_seen == 0 and new_content[0] == ' ':
|
||||||
|
new_content = new_content[1:]
|
||||||
|
|
||||||
|
chunk = text_streaming_chunk(new_content)
|
||||||
|
|
||||||
|
completion_token_count += len(encode(new_content)[0])
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
stop_reason = "stop"
|
||||||
|
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||||
|
stop_reason = "length"
|
||||||
|
|
||||||
|
chunk = text_streaming_chunk('')
|
||||||
|
chunk[resp_list][0]["finish_reason"] = stop_reason
|
||||||
|
chunk["usage"] = {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
|
||||||
|
yield chunk
|
67
extensions/openai/defaults.py
Normal file
67
extensions/openai/defaults.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
import copy
|
||||||
|
|
||||||
|
# Slightly different defaults for OpenAI's API
|
||||||
|
# Data type is important, Ex. use 0.0 for a float 0
|
||||||
|
default_req_params = {
|
||||||
|
'max_new_tokens': 16, # 'Inf' for chat
|
||||||
|
'temperature': 1.0,
|
||||||
|
'top_p': 1.0,
|
||||||
|
'top_k': 1, # choose 20 for chat in absence of another default
|
||||||
|
'repetition_penalty': 1.18,
|
||||||
|
'repetition_penalty_range': 0,
|
||||||
|
'encoder_repetition_penalty': 1.0,
|
||||||
|
'suffix': None,
|
||||||
|
'stream': False,
|
||||||
|
'echo': False,
|
||||||
|
'seed': -1,
|
||||||
|
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
||||||
|
'truncation_length': 2048, # first use shared.settings value
|
||||||
|
'add_bos_token': True,
|
||||||
|
'do_sample': True,
|
||||||
|
'typical_p': 1.0,
|
||||||
|
'epsilon_cutoff': 0.0, # In units of 1e-4
|
||||||
|
'eta_cutoff': 0.0, # In units of 1e-4
|
||||||
|
'tfs': 1.0,
|
||||||
|
'top_a': 0.0,
|
||||||
|
'min_length': 0,
|
||||||
|
'no_repeat_ngram_size': 0,
|
||||||
|
'num_beams': 1,
|
||||||
|
'penalty_alpha': 0.0,
|
||||||
|
'length_penalty': 1.0,
|
||||||
|
'early_stopping': False,
|
||||||
|
'mirostat_mode': 0,
|
||||||
|
'mirostat_tau': 5.0,
|
||||||
|
'mirostat_eta': 0.1,
|
||||||
|
'ban_eos_token': False,
|
||||||
|
'skip_special_tokens': True,
|
||||||
|
'custom_stopping_strings': '',
|
||||||
|
# 'logits_processor' - conditionally passed
|
||||||
|
# 'stopping_strings' - temporarily used
|
||||||
|
# 'logprobs' - temporarily used
|
||||||
|
# 'requested_model' - temporarily used
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_req_params():
|
||||||
|
return copy.deepcopy(default_req_params)
|
||||||
|
|
||||||
|
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||||
|
|
||||||
|
|
||||||
|
def default(dic, key, default):
|
||||||
|
val = dic.get(key, default)
|
||||||
|
if type(val) != type(default):
|
||||||
|
# maybe it's just something like 1 instead of 1.0
|
||||||
|
try:
|
||||||
|
v = type(default)(val)
|
||||||
|
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
||||||
|
return v
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
val = default
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
|
def clamp(value, minvalue, maxvalue):
|
||||||
|
return max(minvalue, min(value, maxvalue))
|
102
extensions/openai/edits.py
Normal file
102
extensions/openai/edits.py
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
import time
|
||||||
|
import yaml
|
||||||
|
import os
|
||||||
|
from modules import shared
|
||||||
|
from extensions.openai.defaults import get_default_req_params
|
||||||
|
from extensions.openai.utils import debug_msg
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
from modules.text_generation import encode, generate_reply
|
||||||
|
|
||||||
|
|
||||||
|
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
|
||||||
|
|
||||||
|
created_time = int(time.time() * 1000)
|
||||||
|
|
||||||
|
# Request parameters
|
||||||
|
req_params = get_default_req_params()
|
||||||
|
stopping_strings = []
|
||||||
|
|
||||||
|
# Alpaca is verbose so a good default prompt
|
||||||
|
default_template = (
|
||||||
|
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||||
|
"Write a response that appropriately completes the request.\n\n"
|
||||||
|
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
instruction_template = default_template
|
||||||
|
|
||||||
|
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||||
|
if shared.settings['instruction_template']:
|
||||||
|
if 'Alpaca' in shared.settings['instruction_template']:
|
||||||
|
stopping_strings.extend(['\n###'])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
||||||
|
|
||||||
|
template = instruct['turn_template']
|
||||||
|
template = template\
|
||||||
|
.replace('<|user|>', instruct.get('user', ''))\
|
||||||
|
.replace('<|bot|>', instruct.get('bot', ''))\
|
||||||
|
.replace('<|user-message|>', '{instruction}\n{input}')
|
||||||
|
|
||||||
|
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||||
|
if instruct['user']:
|
||||||
|
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
instruction_template = default_template
|
||||||
|
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
||||||
|
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||||
|
else:
|
||||||
|
stopping_strings.extend(['\n###'])
|
||||||
|
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
||||||
|
|
||||||
|
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||||
|
|
||||||
|
truncation_length = shared.settings['truncation_length']
|
||||||
|
|
||||||
|
token_count = len(encode(edit_task)[0])
|
||||||
|
max_tokens = truncation_length - token_count
|
||||||
|
|
||||||
|
if max_tokens < 1:
|
||||||
|
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
|
||||||
|
raise InvalidRequestError(err_msg, param='input')
|
||||||
|
|
||||||
|
req_params['max_new_tokens'] = max_tokens
|
||||||
|
req_params['truncation_length'] = truncation_length
|
||||||
|
req_params['temperature'] = temperature
|
||||||
|
req_params['top_p'] = top_p
|
||||||
|
req_params['seed'] = shared.settings.get('seed', req_params['seed'])
|
||||||
|
req_params['add_bos_token'] = shared.settings.get('add_bos_token', req_params['add_bos_token'])
|
||||||
|
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||||
|
|
||||||
|
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||||
|
|
||||||
|
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
|
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
||||||
|
answer = ''
|
||||||
|
for a in generator:
|
||||||
|
answer = a
|
||||||
|
|
||||||
|
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
||||||
|
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
||||||
|
answer = answer[1:]
|
||||||
|
|
||||||
|
completion_token_count = len(encode(answer)[0])
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"object": "edit",
|
||||||
|
"created": created_time,
|
||||||
|
"choices": [{
|
||||||
|
"text": answer,
|
||||||
|
"index": 0,
|
||||||
|
}],
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": token_count,
|
||||||
|
"completion_tokens": completion_token_count,
|
||||||
|
"total_tokens": token_count + completion_token_count
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
54
extensions/openai/embeddings.py
Normal file
54
extensions/openai/embeddings.py
Normal file
@ -0,0 +1,54 @@
|
|||||||
|
import os
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from extensions.openai.utils import float_list_to_base64, debug_msg
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
|
||||||
|
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||||
|
embeddings_model = None
|
||||||
|
|
||||||
|
|
||||||
|
def load_embedding_model(model):
|
||||||
|
try:
|
||||||
|
emb_model = SentenceTransformer(model)
|
||||||
|
print(f"\nLoaded embedding model: {model}, max sequence length: {emb_model.max_seq_length}")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nError: Failed to load embedding model: {model}")
|
||||||
|
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||||
|
|
||||||
|
return emb_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_embeddings_model():
|
||||||
|
global embeddings_model, st_model
|
||||||
|
if st_model and not embeddings_model:
|
||||||
|
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||||
|
return embeddings_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_embeddings_model_name():
|
||||||
|
global st_model
|
||||||
|
return st_model
|
||||||
|
|
||||||
|
|
||||||
|
def embeddings(input: list, encoding_format: str):
|
||||||
|
|
||||||
|
embeddings = get_embeddings_model().encode(input).tolist()
|
||||||
|
|
||||||
|
if encoding_format == "base64":
|
||||||
|
data = [{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n} for n, emb in enumerate(embeddings)]
|
||||||
|
else:
|
||||||
|
data = [{"object": "embedding", "embedding": emb, "index": n} for n, emb in enumerate(embeddings)]
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"object": "list",
|
||||||
|
"data": data,
|
||||||
|
"model": st_model, # return the real model
|
||||||
|
"usage": {
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"total_tokens": 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||||
|
|
||||||
|
return response
|
31
extensions/openai/errors.py
Normal file
31
extensions/openai/errors.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
class OpenAIError(Exception):
|
||||||
|
def __init__(self, message=None, code=500, internal_message=''):
|
||||||
|
self.message = message
|
||||||
|
self.code = code
|
||||||
|
self.internal_message = internal_message
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "%s(message=%r, code=%d)" % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.message,
|
||||||
|
self.code,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRequestError(OpenAIError):
|
||||||
|
def __init__(self, message, param, code=400, error_type='InvalidRequestError', internal_message=''):
|
||||||
|
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||||
|
self.param = param
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "%s(message=%r, code=%d, param=%s)" % (
|
||||||
|
self.__class__.__name__,
|
||||||
|
self.message,
|
||||||
|
self.code,
|
||||||
|
self.param,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ServiceUnavailableError(OpenAIError):
|
||||||
|
def __init__(self, message=None, code=500, error_type='ServiceUnavailableError', internal_message=''):
|
||||||
|
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
49
extensions/openai/images.py
Normal file
49
extensions/openai/images.py
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
import os
|
||||||
|
import time
|
||||||
|
import requests
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
|
||||||
|
|
||||||
|
def generations(prompt: str, size: str, response_format: str, n: int):
|
||||||
|
# Stable Diffusion callout wrapper for txt2img
|
||||||
|
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
||||||
|
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
|
||||||
|
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
|
||||||
|
# it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
|
||||||
|
# Will probably work best with the stock SD models.
|
||||||
|
# SD configuration is beyond the scope of this API.
|
||||||
|
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
||||||
|
# require changing the form data handling to accept multipart form data, also to properly support
|
||||||
|
# url return types will require file management and a web serving files... Perhaps later!
|
||||||
|
|
||||||
|
width, height = [int(x) for x in size.split('x')] # ignore the restrictions on size
|
||||||
|
|
||||||
|
# to hack on better generation, edit default payload.
|
||||||
|
payload = {
|
||||||
|
'prompt': prompt, # ignore prompt limit of 1000 characters
|
||||||
|
'width': width,
|
||||||
|
'height': height,
|
||||||
|
'batch_size': n,
|
||||||
|
'restore_faces': True, # slightly less horrible
|
||||||
|
}
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
'created': int(time.time()),
|
||||||
|
'data': []
|
||||||
|
}
|
||||||
|
|
||||||
|
# TODO: support SD_WEBUI_AUTH username:password pair.
|
||||||
|
sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"
|
||||||
|
|
||||||
|
response = requests.post(url=sd_url, json=payload)
|
||||||
|
r = response.json()
|
||||||
|
if response.status_code != 200 or 'images' not in r:
|
||||||
|
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code=response.status_code)
|
||||||
|
# r['parameters']...
|
||||||
|
for b64_json in r['images']:
|
||||||
|
if response_format == 'b64_json':
|
||||||
|
resp['data'].extend([{'b64_json': b64_json}])
|
||||||
|
else:
|
||||||
|
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
|
||||||
|
|
||||||
|
return resp
|
79
extensions/openai/models.py
Normal file
79
extensions/openai/models.py
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
from modules import shared
|
||||||
|
from modules.utils import get_available_models
|
||||||
|
from modules.models import load_model, unload_model
|
||||||
|
from modules.models_settings import (get_model_settings_from_yamls,
|
||||||
|
update_model_parameters)
|
||||||
|
|
||||||
|
from extensions.openai.embeddings import get_embeddings_model_name
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_model_list() -> list:
|
||||||
|
return [shared.model_name] # The real chat/completions model, maybe "None"
|
||||||
|
|
||||||
|
|
||||||
|
def get_pseudo_model_list() -> list:
|
||||||
|
return [ # these are expected by so much, so include some here as a dummy
|
||||||
|
'gpt-3.5-turbo',
|
||||||
|
'text-embedding-ada-002',
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_name: str) -> dict:
|
||||||
|
resp = {
|
||||||
|
"id": model_name,
|
||||||
|
"object": "engine",
|
||||||
|
"owner": "self",
|
||||||
|
"ready": True,
|
||||||
|
}
|
||||||
|
if model_name not in get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list(): # Real model only
|
||||||
|
# No args. Maybe it works anyways!
|
||||||
|
# TODO: hack some heuristics into args for better results
|
||||||
|
|
||||||
|
shared.model_name = model_name
|
||||||
|
unload_model()
|
||||||
|
|
||||||
|
model_settings = get_model_settings_from_yamls(shared.model_name)
|
||||||
|
shared.settings.update(model_settings)
|
||||||
|
update_model_parameters(model_settings, initial=True)
|
||||||
|
|
||||||
|
if shared.settings['mode'] != 'instruct':
|
||||||
|
shared.settings['instruction_template'] = None
|
||||||
|
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
|
||||||
|
if not shared.model: # load failed.
|
||||||
|
shared.model_name = "None"
|
||||||
|
raise OpenAIError(f"Model load failed for: {shared.model_name}")
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def list_models(is_legacy: bool = False) -> dict:
|
||||||
|
# TODO: Lora's?
|
||||||
|
all_model_list = get_current_model_list() + [get_embeddings_model_name()] + get_pseudo_model_list() + get_available_models()
|
||||||
|
|
||||||
|
models = {}
|
||||||
|
|
||||||
|
if is_legacy:
|
||||||
|
models = [{"id": id, "object": "engine", "owner": "user", "ready": True} for id in all_model_list]
|
||||||
|
if not shared.model:
|
||||||
|
models[0]['ready'] = False
|
||||||
|
else:
|
||||||
|
models = [{"id": id, "object": "model", "owned_by": "user", "permission": []} for id in all_model_list]
|
||||||
|
|
||||||
|
resp = {
|
||||||
|
"object": "list",
|
||||||
|
"data": models,
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def model_info(model_name: str) -> dict:
|
||||||
|
return {
|
||||||
|
"id": model_name,
|
||||||
|
"object": "model",
|
||||||
|
"owned_by": "user",
|
||||||
|
"permission": []
|
||||||
|
}
|
69
extensions/openai/moderations.py
Normal file
69
extensions/openai/moderations.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from numpy.linalg import norm
|
||||||
|
from extensions.openai.embeddings import get_embeddings_model
|
||||||
|
|
||||||
|
|
||||||
|
moderations_disabled = False # return 0/false
|
||||||
|
category_embeddings = None
|
||||||
|
antonym_embeddings = None
|
||||||
|
categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence"]
|
||||||
|
flag_threshold = 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def get_category_embeddings():
|
||||||
|
global category_embeddings, categories
|
||||||
|
if category_embeddings is None:
|
||||||
|
embeddings = get_embeddings_model().encode(categories).tolist()
|
||||||
|
category_embeddings = dict(zip(categories, embeddings))
|
||||||
|
|
||||||
|
return category_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a, b):
|
||||||
|
return np.dot(a, b) / (norm(a) * norm(b))
|
||||||
|
|
||||||
|
|
||||||
|
# seems most openai like with all-mpnet-base-v2
|
||||||
|
def mod_score(a, b):
|
||||||
|
return 2.0 * np.dot(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
def moderations(input):
|
||||||
|
global category_embeddings, categories, flag_threshold, moderations_disabled
|
||||||
|
results = {
|
||||||
|
"id": f"modr-{int(time.time()*1e9)}",
|
||||||
|
"model": "text-moderation-001",
|
||||||
|
"results": [],
|
||||||
|
}
|
||||||
|
|
||||||
|
embeddings_model = get_embeddings_model()
|
||||||
|
if not embeddings_model or moderations_disabled:
|
||||||
|
results['results'] = [{
|
||||||
|
'categories': dict([(C, False) for C in categories]),
|
||||||
|
'category_scores': dict([(C, 0.0) for C in categories]),
|
||||||
|
'flagged': False,
|
||||||
|
}]
|
||||||
|
return results
|
||||||
|
|
||||||
|
category_embeddings = get_category_embeddings()
|
||||||
|
|
||||||
|
# input, string or array
|
||||||
|
if isinstance(input, str):
|
||||||
|
input = [input]
|
||||||
|
|
||||||
|
for in_str in input:
|
||||||
|
for ine in embeddings_model.encode([in_str]).tolist():
|
||||||
|
category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories])
|
||||||
|
category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories])
|
||||||
|
flagged = any(category_flags.values())
|
||||||
|
|
||||||
|
results['results'].extend([{
|
||||||
|
'flagged': flagged,
|
||||||
|
'categories': category_flags,
|
||||||
|
'category_scores': category_scores,
|
||||||
|
}])
|
||||||
|
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
return results
|
@ -1,2 +1,3 @@
|
|||||||
flask_cloudflared==0.0.12
|
flask_cloudflared==0.0.12
|
||||||
sentence-transformers
|
sentence-transformers
|
||||||
|
tiktoken
|
@ -1,107 +1,27 @@
|
|||||||
import base64
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import traceback
|
||||||
import requests
|
|
||||||
import yaml
|
|
||||||
import numpy as np
|
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from modules.utils import get_available_models
|
|
||||||
from modules.models import load_model, unload_model
|
|
||||||
from modules.models_settings import (get_model_settings_from_yamls,
|
|
||||||
update_model_parameters)
|
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.text_generation import encode, generate_reply
|
|
||||||
|
from extensions.openai.tokens import token_count, token_encode, token_decode
|
||||||
|
import extensions.openai.models as OAImodels
|
||||||
|
import extensions.openai.edits as OAIedits
|
||||||
|
import extensions.openai.embeddings as OAIembeddings
|
||||||
|
import extensions.openai.images as OAIimages
|
||||||
|
import extensions.openai.moderations as OAImoderations
|
||||||
|
import extensions.openai.completions as OAIcompletions
|
||||||
|
from extensions.openai.errors import *
|
||||||
|
from extensions.openai.utils import debug_msg
|
||||||
|
from extensions.openai.defaults import (get_default_req_params, default, clamp)
|
||||||
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
||||||
}
|
}
|
||||||
|
|
||||||
debug = True if 'OPENEDAI_DEBUG' in os.environ else False
|
|
||||||
|
|
||||||
# Slightly different defaults for OpenAI's API
|
|
||||||
# Data type is important, Ex. use 0.0 for a float 0
|
|
||||||
default_req_params = {
|
|
||||||
'max_new_tokens': 200,
|
|
||||||
'temperature': 1.0,
|
|
||||||
'top_p': 1.0,
|
|
||||||
'top_k': 1,
|
|
||||||
'repetition_penalty': 1.18,
|
|
||||||
'repetition_penalty_range': 0,
|
|
||||||
'encoder_repetition_penalty': 1.0,
|
|
||||||
'suffix': None,
|
|
||||||
'stream': False,
|
|
||||||
'echo': False,
|
|
||||||
'seed': -1,
|
|
||||||
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
|
||||||
'truncation_length': 2048,
|
|
||||||
'add_bos_token': True,
|
|
||||||
'do_sample': True,
|
|
||||||
'typical_p': 1.0,
|
|
||||||
'epsilon_cutoff': 0.0, # In units of 1e-4
|
|
||||||
'eta_cutoff': 0.0, # In units of 1e-4
|
|
||||||
'tfs': 1.0,
|
|
||||||
'top_a': 0.0,
|
|
||||||
'min_length': 0,
|
|
||||||
'no_repeat_ngram_size': 0,
|
|
||||||
'num_beams': 1,
|
|
||||||
'penalty_alpha': 0.0,
|
|
||||||
'length_penalty': 1.0,
|
|
||||||
'early_stopping': False,
|
|
||||||
'mirostat_mode': 0,
|
|
||||||
'mirostat_tau': 5.0,
|
|
||||||
'mirostat_eta': 0.1,
|
|
||||||
'ban_eos_token': False,
|
|
||||||
'skip_special_tokens': True,
|
|
||||||
'custom_stopping_strings': '',
|
|
||||||
}
|
|
||||||
|
|
||||||
# Optional, install the module and download the model to enable
|
|
||||||
# v1/embeddings
|
|
||||||
try:
|
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
|
||||||
embedding_model = None
|
|
||||||
|
|
||||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
|
||||||
def default(dic, key, default):
|
|
||||||
val = dic.get(key, default)
|
|
||||||
if type(val) != type(default):
|
|
||||||
# maybe it's just something like 1 instead of 1.0
|
|
||||||
try:
|
|
||||||
v = type(default)(val)
|
|
||||||
if type(val)(v) == val: # if it's the same value passed in, it's ok.
|
|
||||||
return v
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
val = default
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def clamp(value, minvalue, maxvalue):
|
|
||||||
return max(minvalue, min(value, maxvalue))
|
|
||||||
|
|
||||||
|
|
||||||
def float_list_to_base64(float_list):
|
|
||||||
# Convert the list to a float32 array that the OpenAPI client expects
|
|
||||||
float_array = np.array(float_list, dtype="float32")
|
|
||||||
|
|
||||||
# Get raw bytes
|
|
||||||
bytes_array = float_array.tobytes()
|
|
||||||
|
|
||||||
# Encode bytes into base64
|
|
||||||
encoded_bytes = base64.b64encode(bytes_array)
|
|
||||||
|
|
||||||
# Turn raw base64 encoded bytes into ASCII
|
|
||||||
ascii_string = encoded_bytes.decode('ascii')
|
|
||||||
return ascii_string
|
|
||||||
|
|
||||||
|
|
||||||
class Handler(BaseHTTPRequestHandler):
|
class Handler(BaseHTTPRequestHandler):
|
||||||
def send_access_control_headers(self):
|
def send_access_control_headers(self):
|
||||||
@ -118,11 +38,43 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
"Authorization"
|
"Authorization"
|
||||||
)
|
)
|
||||||
|
|
||||||
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
|
def do_OPTIONS(self):
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_access_control_headers()
|
||||||
|
self.send_header('Content-Type', 'application/json')
|
||||||
|
self.end_headers()
|
||||||
|
self.wfile.write("OK".encode('utf-8'))
|
||||||
|
|
||||||
|
def start_sse(self):
|
||||||
|
self.send_response(200)
|
||||||
|
self.send_access_control_headers()
|
||||||
|
self.send_header('Content-Type', 'text/event-stream')
|
||||||
|
self.send_header('Cache-Control', 'no-cache')
|
||||||
|
# self.send_header('Connection', 'keep-alive')
|
||||||
|
self.end_headers()
|
||||||
|
|
||||||
|
def send_sse(self, chunk: dict):
|
||||||
|
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
||||||
|
debug_msg(response)
|
||||||
|
self.wfile.write(response.encode('utf-8'))
|
||||||
|
|
||||||
|
def end_sse(self):
|
||||||
|
self.wfile.write('data: [DONE]\r\n\r\n'.encode('utf-8'))
|
||||||
|
|
||||||
|
def return_json(self, ret: dict, code: int = 200, no_debug=False):
|
||||||
self.send_response(code)
|
self.send_response(code)
|
||||||
self.send_access_control_headers()
|
self.send_access_control_headers()
|
||||||
self.send_header('Content-Type', 'application/json')
|
self.send_header('Content-Type', 'application/json')
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
|
|
||||||
|
response = json.dumps(ret)
|
||||||
|
r_utf8 = response.encode('utf-8')
|
||||||
|
self.wfile.write(r_utf8)
|
||||||
|
if not no_debug:
|
||||||
|
debug_msg(r_utf8)
|
||||||
|
|
||||||
|
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
|
||||||
|
|
||||||
error_resp = {
|
error_resp = {
|
||||||
'error': {
|
'error': {
|
||||||
'message': message,
|
'message': message,
|
||||||
@ -132,121 +84,61 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if internal_message:
|
if internal_message:
|
||||||
error_resp['internal_message'] = internal_message
|
print(internal_message)
|
||||||
|
# error_resp['internal_message'] = internal_message
|
||||||
|
|
||||||
response = json.dumps(error_resp)
|
self.return_json(error_resp, code)
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
def do_OPTIONS(self):
|
def openai_error_handler(func):
|
||||||
self.send_response(200)
|
def wrapper(self):
|
||||||
self.send_access_control_headers()
|
try:
|
||||||
self.send_header('Content-Type', 'application/json')
|
func(self)
|
||||||
self.end_headers()
|
except ServiceUnavailableError as e:
|
||||||
self.wfile.write("OK".encode('utf-8'))
|
self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
|
||||||
|
except InvalidRequestError as e:
|
||||||
|
self.openai_error(e.message, e.code, e.error_type, e.param, internal_message=e.internal_message)
|
||||||
|
except OpenAIError as e:
|
||||||
|
self.openai_error(e.message, e.code, e.error_type, internal_message=e.internal_message)
|
||||||
|
except Exception as e:
|
||||||
|
self.openai_error(repr(e), 500, 'OpenAIError', internal_message=traceback.format_exc())
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
@openai_error_handler
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
|
debug_msg(self.requestline)
|
||||||
current_model_list = [ shared.model_name ] # The real chat/completions model, maybe "None"
|
debug_msg(self.headers)
|
||||||
embeddings_model_list = [ st_model ] if embedding_model else [] # The real sentence transformer embeddings model
|
|
||||||
pseudo_model_list = [ # these are expected by so much, so include some here as a dummy
|
|
||||||
'gpt-3.5-turbo', # /v1/chat/completions
|
|
||||||
'text-curie-001', # /v1/completions, 2k context
|
|
||||||
'text-davinci-002' # /v1/embeddings text-embedding-ada-002:1536, text-davinci-002:768
|
|
||||||
]
|
|
||||||
|
|
||||||
|
if self.path.startswith('/v1/engines') or self.path.startswith('/v1/models'):
|
||||||
is_legacy = 'engines' in self.path
|
is_legacy = 'engines' in self.path
|
||||||
is_list = self.path in ['/v1/engines', '/v1/models']
|
is_list = self.path in ['/v1/engines', '/v1/models']
|
||||||
|
if is_legacy and not is_list:
|
||||||
resp = ''
|
|
||||||
|
|
||||||
if is_legacy and not is_list: # load model
|
|
||||||
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
|
model_name = self.path[self.path.find('/v1/engines/') + len('/v1/engines/'):]
|
||||||
|
resp = OAImodels.load_model(model_name)
|
||||||
resp = {
|
|
||||||
"id": model_name,
|
|
||||||
"object": "engine",
|
|
||||||
"owner": "self",
|
|
||||||
"ready": True,
|
|
||||||
}
|
|
||||||
if model_name not in pseudo_model_list + embeddings_model_list + current_model_list: # Real model only
|
|
||||||
# No args. Maybe it works anyways!
|
|
||||||
# TODO: hack some heuristics into args for better results
|
|
||||||
|
|
||||||
shared.model_name = model_name
|
|
||||||
unload_model()
|
|
||||||
|
|
||||||
model_settings = get_model_settings_from_yamls(shared.model_name)
|
|
||||||
shared.settings.update(model_settings)
|
|
||||||
update_model_parameters(model_settings, initial=True)
|
|
||||||
|
|
||||||
if shared.settings['mode'] != 'instruct':
|
|
||||||
shared.settings['instruction_template'] = None
|
|
||||||
|
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
|
||||||
|
|
||||||
if not shared.model: # load failed.
|
|
||||||
shared.model_name = "None"
|
|
||||||
resp['id'] = "None"
|
|
||||||
resp['ready'] = False
|
|
||||||
|
|
||||||
elif is_list:
|
elif is_list:
|
||||||
# TODO: Lora's?
|
resp = OAImodels.list_models(is_legacy)
|
||||||
available_model_list = get_available_models()
|
|
||||||
all_model_list = current_model_list + embeddings_model_list + pseudo_model_list + available_model_list
|
|
||||||
|
|
||||||
models = {}
|
|
||||||
|
|
||||||
if is_legacy:
|
|
||||||
models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ]
|
|
||||||
if not shared.model:
|
|
||||||
models[0]['ready'] = False
|
|
||||||
else:
|
|
||||||
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"object": "list",
|
|
||||||
"data": models,
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
the_model_name = self.path[len('/v1/models/'):]
|
model_name = self.path[len('/v1/models/'):]
|
||||||
resp = {
|
resp = OAImodels.model_info()
|
||||||
"id": the_model_name,
|
|
||||||
"object": "model",
|
|
||||||
"owned_by": "user",
|
|
||||||
"permission": []
|
|
||||||
}
|
|
||||||
|
|
||||||
self.send_response(200)
|
self.return_json(resp)
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
response = json.dumps(resp)
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/billing/usage' in self.path:
|
elif '/billing/usage' in self.path:
|
||||||
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
# Ex. /v1/dashboard/billing/usage?start_date=2023-05-01&end_date=2023-05-31
|
||||||
self.send_response(200)
|
self.return_json({"total_usage": 0}, no_debug=True)
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
response = json.dumps({
|
|
||||||
"total_usage": 0,
|
|
||||||
})
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
|
||||||
|
@openai_error_handler
|
||||||
def do_POST(self):
|
def do_POST(self):
|
||||||
if debug:
|
debug_msg(self.requestline)
|
||||||
print(self.headers) # did you know... python-openai sends your linux kernel & python version?
|
debug_msg(self.headers)
|
||||||
|
|
||||||
content_length = int(self.headers['Content-Length'])
|
content_length = int(self.headers['Content-Length'])
|
||||||
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
||||||
|
|
||||||
if debug:
|
debug_msg(body)
|
||||||
print(body)
|
|
||||||
|
|
||||||
if '/completions' in self.path or '/generate' in self.path:
|
if '/completions' in self.path or '/generate' in self.path:
|
||||||
|
|
||||||
@ -255,621 +147,109 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
return
|
return
|
||||||
|
|
||||||
is_legacy = '/generate' in self.path
|
is_legacy = '/generate' in self.path
|
||||||
is_chat_request = 'chat' in self.path
|
is_streaming = body.get('stream', False)
|
||||||
resp_list = 'data' if is_legacy else 'choices'
|
|
||||||
|
|
||||||
# XXX model is ignored for now
|
|
||||||
# model = body.get('model', shared.model_name) # ignored, use existing for now
|
|
||||||
model = shared.model_name
|
|
||||||
created_time = int(time.time())
|
|
||||||
|
|
||||||
cmpl_id = "chatcmpl-%d" % (created_time) if is_chat_request else "conv-%d" % (created_time)
|
|
||||||
|
|
||||||
# Request Parameters
|
|
||||||
# Try to use openai defaults or map them to something with the same intent
|
|
||||||
req_params = default_req_params.copy()
|
|
||||||
stopping_strings = []
|
|
||||||
|
|
||||||
if 'stop' in body:
|
|
||||||
if isinstance(body['stop'], str):
|
|
||||||
stopping_strings.extend([body['stop']])
|
|
||||||
elif isinstance(body['stop'], list):
|
|
||||||
stopping_strings.extend(body['stop'])
|
|
||||||
|
|
||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
|
||||||
truncation_length = clamp(default(body, 'truncation_length', truncation_length), 1, truncation_length)
|
|
||||||
|
|
||||||
default_max_tokens = truncation_length if is_chat_request else 16 # completions default, chat default is 'inf' so we need to cap it.
|
|
||||||
|
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
|
||||||
max_tokens = default(body, max_tokens_str, default(shared.settings, 'max_new_tokens', default_max_tokens))
|
|
||||||
# if the user assumes OpenAI, the max_tokens is way too large - try to ignore it unless it's small enough
|
|
||||||
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
|
||||||
req_params['truncation_length'] = truncation_length
|
|
||||||
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
|
||||||
req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
|
|
||||||
req_params['top_k'] = default(body, 'best_of', default_req_params['top_k'])
|
|
||||||
req_params['suffix'] = default(body, 'suffix', default_req_params['suffix'])
|
|
||||||
req_params['stream'] = default(body, 'stream', default_req_params['stream'])
|
|
||||||
req_params['echo'] = default(body, 'echo', default_req_params['echo'])
|
|
||||||
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
|
|
||||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
|
|
||||||
|
|
||||||
is_streaming = req_params['stream']
|
|
||||||
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
if is_streaming:
|
|
||||||
self.send_header('Content-Type', 'text/event-stream')
|
|
||||||
self.send_header('Cache-Control', 'no-cache')
|
|
||||||
# self.send_header('Connection', 'keep-alive')
|
|
||||||
else:
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
token_count = 0
|
|
||||||
completion_token_count = 0
|
|
||||||
prompt = ''
|
|
||||||
stream_object_type = ''
|
|
||||||
object_type = ''
|
|
||||||
|
|
||||||
if is_chat_request:
|
|
||||||
# Chat Completions
|
|
||||||
stream_object_type = 'chat.completions.chunk'
|
|
||||||
object_type = 'chat.completions'
|
|
||||||
|
|
||||||
messages = body['messages']
|
|
||||||
|
|
||||||
role_formats = {
|
|
||||||
'user': 'user: {message}\n',
|
|
||||||
'assistant': 'assistant: {message}\n',
|
|
||||||
'system': '{message}',
|
|
||||||
'context': 'You are a helpful assistant. Answer as concisely as possible.',
|
|
||||||
'prompt': 'assistant:',
|
|
||||||
}
|
|
||||||
|
|
||||||
# Instruct models can be much better
|
|
||||||
if shared.settings['instruction_template']:
|
|
||||||
try:
|
|
||||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
|
||||||
|
|
||||||
template = instruct['turn_template']
|
|
||||||
system_message_template = "{message}"
|
|
||||||
system_message_default = instruct['context']
|
|
||||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
|
||||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
|
||||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
|
||||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
|
||||||
|
|
||||||
role_formats = {
|
|
||||||
'user': user_message_template,
|
|
||||||
'assistant': bot_message_template,
|
|
||||||
'system': system_message_template,
|
|
||||||
'context': system_message_default,
|
|
||||||
'prompt': bot_prompt,
|
|
||||||
}
|
|
||||||
|
|
||||||
if 'Alpaca' in shared.settings['instruction_template']:
|
|
||||||
stopping_strings.extend(['\n###'])
|
|
||||||
elif instruct['user']: # WizardLM and some others have no user prompt.
|
|
||||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
stopping_strings.extend(['\nuser:'])
|
|
||||||
|
|
||||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
|
||||||
print("Warning: Loaded default instruction-following template for model.")
|
|
||||||
|
|
||||||
else:
|
|
||||||
stopping_strings.extend(['\nuser:'])
|
|
||||||
print("Warning: Loaded default instruction-following template for model.")
|
|
||||||
|
|
||||||
system_msgs = []
|
|
||||||
chat_msgs = []
|
|
||||||
|
|
||||||
# You are ChatGPT, a large language model trained by OpenAI. Answer as concisely as possible. Knowledge cutoff: {knowledge_cutoff} Current date: {current_date}
|
|
||||||
context_msg = role_formats['system'].format(message=role_formats['context']) if role_formats['context'] else ''
|
|
||||||
if context_msg:
|
|
||||||
system_msgs.extend([context_msg])
|
|
||||||
|
|
||||||
# Maybe they sent both? This is not documented in the API, but some clients seem to do this.
|
|
||||||
if 'prompt' in body:
|
|
||||||
prompt_msg = role_formats['system'].format(message=body['prompt'])
|
|
||||||
system_msgs.extend([prompt_msg])
|
|
||||||
|
|
||||||
for m in messages:
|
|
||||||
role = m['role']
|
|
||||||
content = m['content']
|
|
||||||
msg = role_formats[role].format(message=content)
|
|
||||||
if role == 'system':
|
|
||||||
system_msgs.extend([msg])
|
|
||||||
else:
|
|
||||||
chat_msgs.extend([msg])
|
|
||||||
|
|
||||||
# can't really truncate the system messages
|
|
||||||
system_msg = '\n'.join(system_msgs)
|
|
||||||
if system_msg and system_msg[-1] != '\n':
|
|
||||||
system_msg = system_msg + '\n'
|
|
||||||
|
|
||||||
system_token_count = len(encode(system_msg)[0])
|
|
||||||
remaining_tokens = truncation_length - system_token_count
|
|
||||||
chat_msg = ''
|
|
||||||
|
|
||||||
while chat_msgs:
|
|
||||||
new_msg = chat_msgs.pop()
|
|
||||||
new_size = len(encode(new_msg)[0])
|
|
||||||
if new_size <= remaining_tokens:
|
|
||||||
chat_msg = new_msg + chat_msg
|
|
||||||
remaining_tokens -= new_size
|
|
||||||
else:
|
|
||||||
print(f"Warning: too many messages for context size, dropping {len(chat_msgs) + 1} oldest message(s).")
|
|
||||||
break
|
|
||||||
|
|
||||||
prompt = system_msg + chat_msg + role_formats['prompt']
|
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Text Completions
|
|
||||||
stream_object_type = 'text_completion.chunk'
|
|
||||||
object_type = 'text_completion'
|
|
||||||
|
|
||||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
|
||||||
if is_legacy:
|
|
||||||
prompt = body['context'] # Older engines.generate API
|
|
||||||
else:
|
|
||||||
prompt = body['prompt'] # XXX this can be different types
|
|
||||||
|
|
||||||
if isinstance(prompt, list):
|
|
||||||
self.openai_error("API Batched generation not yet supported.")
|
|
||||||
return
|
|
||||||
|
|
||||||
token_count = len(encode(prompt)[0])
|
|
||||||
if token_count >= truncation_length:
|
|
||||||
new_len = int(len(prompt) * shared.settings['truncation_length'] / token_count)
|
|
||||||
prompt = prompt[-new_len:]
|
|
||||||
new_token_count = len(encode(prompt)[0])
|
|
||||||
print(f"Warning: truncating prompt to {new_len} characters, was {token_count} tokens. Now: {new_token_count} tokens.")
|
|
||||||
token_count = new_token_count
|
|
||||||
|
|
||||||
if truncation_length - token_count < req_params['max_new_tokens']:
|
|
||||||
print(f"Warning: Ignoring max_new_tokens ({req_params['max_new_tokens']}), too large for the remaining context. Remaining tokens: {truncation_length - token_count}")
|
|
||||||
req_params['max_new_tokens'] = truncation_length - token_count
|
|
||||||
print(f"Warning: Set max_new_tokens = {req_params['max_new_tokens']}")
|
|
||||||
|
|
||||||
if is_streaming:
|
if is_streaming:
|
||||||
# begin streaming
|
self.start_sse()
|
||||||
chunk = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": stream_object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": shared.model_name,
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": None,
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
|
|
||||||
if stream_object_type == 'text_completion.chunk':
|
response = []
|
||||||
chunk[resp_list][0]["text"] = ""
|
if 'chat' in self.path:
|
||||||
|
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
|
||||||
else:
|
else:
|
||||||
# So yeah... do both methods? delta and messages.
|
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
|
||||||
chunk[resp_list][0]["message"] = {'role': 'assistant', 'content': ''}
|
|
||||||
chunk[resp_list][0]["delta"] = {'role': 'assistant', 'content': ''}
|
|
||||||
|
|
||||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
for resp in response:
|
||||||
self.wfile.write(response.encode('utf-8'))
|
self.send_sse(resp)
|
||||||
|
|
||||||
# generate reply #######################################
|
self.end_sse()
|
||||||
if debug:
|
|
||||||
print({'prompt': prompt, 'req_params': req_params})
|
|
||||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
|
||||||
|
|
||||||
answer = ''
|
|
||||||
seen_content = ''
|
|
||||||
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
|
||||||
|
|
||||||
for a in generator:
|
|
||||||
answer = a
|
|
||||||
|
|
||||||
stop_string_found = False
|
|
||||||
len_seen = len(seen_content)
|
|
||||||
search_start = max(len_seen - longest_stop_len, 0)
|
|
||||||
|
|
||||||
for string in stopping_strings:
|
|
||||||
idx = answer.find(string, search_start)
|
|
||||||
if idx != -1:
|
|
||||||
answer = answer[:idx] # clip it.
|
|
||||||
stop_string_found = True
|
|
||||||
|
|
||||||
if stop_string_found:
|
|
||||||
break
|
|
||||||
|
|
||||||
# If something like "\nYo" is generated just before "\nYou:"
|
|
||||||
# is completed, buffer and generate more, don't send it
|
|
||||||
buffer_and_continue = False
|
|
||||||
|
|
||||||
for string in stopping_strings:
|
|
||||||
for j in range(len(string) - 1, 0, -1):
|
|
||||||
if answer[-j:] == string[:j]:
|
|
||||||
buffer_and_continue = True
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
break
|
|
||||||
|
|
||||||
if buffer_and_continue:
|
|
||||||
continue
|
|
||||||
|
|
||||||
if is_streaming:
|
|
||||||
# Streaming
|
|
||||||
new_content = answer[len_seen:]
|
|
||||||
|
|
||||||
if not new_content or chr(0xfffd) in new_content: # partial unicode character, don't send it yet.
|
|
||||||
continue
|
|
||||||
|
|
||||||
seen_content = answer
|
|
||||||
chunk = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": stream_object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": shared.model_name,
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": None,
|
|
||||||
}],
|
|
||||||
}
|
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
|
||||||
if len_seen == 0 and new_content[0] == ' ':
|
|
||||||
new_content = new_content[1:]
|
|
||||||
|
|
||||||
if stream_object_type == 'text_completion.chunk':
|
|
||||||
chunk[resp_list][0]['text'] = new_content
|
|
||||||
else:
|
|
||||||
# So yeah... do both methods? delta and messages.
|
|
||||||
chunk[resp_list][0]['message'] = {'content': new_content}
|
|
||||||
chunk[resp_list][0]['delta'] = {'content': new_content}
|
|
||||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\n'
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
completion_token_count += len(encode(new_content)[0])
|
|
||||||
|
|
||||||
if is_streaming:
|
|
||||||
chunk = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": stream_object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": model, # TODO: add Lora info?
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": "stop",
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if stream_object_type == 'text_completion.chunk':
|
|
||||||
chunk[resp_list][0]['text'] = ''
|
|
||||||
else:
|
|
||||||
# So yeah... do both methods? delta and messages.
|
|
||||||
chunk[resp_list][0]['message'] = {'content': ''}
|
|
||||||
chunk[resp_list][0]['delta'] = {'content': ''}
|
|
||||||
|
|
||||||
response = 'data: ' + json.dumps(chunk) + '\r\n\r\ndata: [DONE]\r\n\r\n'
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
# Finished if streaming.
|
|
||||||
if debug:
|
|
||||||
if answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
print({'answer': answer}, chunk)
|
|
||||||
return
|
|
||||||
|
|
||||||
# strip extra leading space off new generated content
|
|
||||||
if answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print({'response': answer})
|
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
|
||||||
stop_reason = "stop"
|
|
||||||
if token_count + completion_token_count >= truncation_length:
|
|
||||||
stop_reason = "length"
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"id": cmpl_id,
|
|
||||||
"object": object_type,
|
|
||||||
"created": created_time,
|
|
||||||
"model": model, # TODO: add Lora info?
|
|
||||||
resp_list: [{
|
|
||||||
"index": 0,
|
|
||||||
"finish_reason": stop_reason,
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if is_chat_request:
|
|
||||||
resp[resp_list][0]["message"] = {"role": "assistant", "content": answer}
|
|
||||||
else:
|
else:
|
||||||
resp[resp_list][0]["text"] = answer
|
response = ''
|
||||||
|
if 'chat' in self.path:
|
||||||
|
response = OAIcompletions.chat_completions(body, is_legacy=is_legacy)
|
||||||
|
else:
|
||||||
|
response = OAIcompletions.completions(body, is_legacy=is_legacy)
|
||||||
|
|
||||||
response = json.dumps(resp)
|
self.return_json(response)
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/edits' in self.path:
|
elif '/edits' in self.path:
|
||||||
|
# deprecated
|
||||||
|
|
||||||
if not shared.model:
|
if not shared.model:
|
||||||
self.openai_error("No model loaded.")
|
self.openai_error("No model loaded.")
|
||||||
return
|
return
|
||||||
|
|
||||||
self.send_response(200)
|
req_params = get_default_req_params()
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
created_time = int(time.time())
|
|
||||||
|
|
||||||
# Using Alpaca format, this may work with other models too.
|
|
||||||
instruction = body['instruction']
|
instruction = body['instruction']
|
||||||
input = body.get('input', '')
|
input = body.get('input', '')
|
||||||
|
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||||
|
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||||
|
|
||||||
# Request parameters
|
response = OAIedits.edits(instruction, input, temperature, top_p)
|
||||||
req_params = default_req_params.copy()
|
|
||||||
stopping_strings = []
|
|
||||||
|
|
||||||
# Alpaca is verbose so a good default prompt
|
self.return_json(response)
|
||||||
default_template = (
|
|
||||||
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
|
||||||
"Write a response that appropriately completes the request.\n\n"
|
|
||||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
instruction_template = default_template
|
|
||||||
|
|
||||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
|
||||||
if shared.settings['instruction_template']:
|
|
||||||
if 'Alpaca' in shared.settings['instruction_template']:
|
|
||||||
stopping_strings.extend(['\n###'])
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
instruct = yaml.safe_load(open(f"characters/instruction-following/{shared.settings['instruction_template']}.yaml", 'r'))
|
|
||||||
|
|
||||||
template = instruct['turn_template']
|
|
||||||
template = template\
|
|
||||||
.replace('<|user|>', instruct.get('user', ''))\
|
|
||||||
.replace('<|bot|>', instruct.get('bot', ''))\
|
|
||||||
.replace('<|user-message|>', '{instruction}\n{input}')
|
|
||||||
|
|
||||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
|
||||||
if instruct['user']:
|
|
||||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
instruction_template = default_template
|
|
||||||
print(f"Exception: When loading characters/instruction-following/{shared.settings['instruction_template']}.yaml: {repr(e)}")
|
|
||||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
||||||
else:
|
|
||||||
stopping_strings.extend(['\n###'])
|
|
||||||
print("Warning: Loaded default instruction-following template (Alpaca) for model.")
|
|
||||||
|
|
||||||
|
|
||||||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
|
||||||
|
|
||||||
truncation_length = default(shared.settings, 'truncation_length', 2048)
|
|
||||||
token_count = len(encode(edit_task)[0])
|
|
||||||
max_tokens = truncation_length - token_count
|
|
||||||
|
|
||||||
req_params['max_new_tokens'] = max_tokens
|
|
||||||
req_params['truncation_length'] = truncation_length
|
|
||||||
req_params['temperature'] = clamp(default(body, 'temperature', default_req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
|
||||||
req_params['top_p'] = clamp(default(body, 'top_p', default_req_params['top_p']), 0.001, 1.0)
|
|
||||||
req_params['seed'] = shared.settings.get('seed', default_req_params['seed'])
|
|
||||||
req_params['add_bos_token'] = shared.settings.get('add_bos_token', default_req_params['add_bos_token'])
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
|
||||||
|
|
||||||
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
|
||||||
|
|
||||||
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
|
||||||
answer = ''
|
|
||||||
seen_content = ''
|
|
||||||
for a in generator:
|
|
||||||
answer = a
|
|
||||||
|
|
||||||
stop_string_found = False
|
|
||||||
len_seen = len(seen_content)
|
|
||||||
search_start = max(len_seen - longest_stop_len, 0)
|
|
||||||
|
|
||||||
for string in stopping_strings:
|
|
||||||
idx = answer.find(string, search_start)
|
|
||||||
if idx != -1:
|
|
||||||
answer = answer[:idx] # clip it.
|
|
||||||
stop_string_found = True
|
|
||||||
|
|
||||||
if stop_string_found:
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
# some reply's have an extra leading space to fit the instruction template, just clip it off from the reply.
|
|
||||||
if edit_task[-1] != '\n' and answer and answer[0] == ' ':
|
|
||||||
answer = answer[1:]
|
|
||||||
|
|
||||||
completion_token_count = len(encode(answer)[0])
|
|
||||||
|
|
||||||
resp = {
|
|
||||||
"object": "edit",
|
|
||||||
"created": created_time,
|
|
||||||
"choices": [{
|
|
||||||
"text": answer,
|
|
||||||
"index": 0,
|
|
||||||
}],
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": token_count,
|
|
||||||
"completion_tokens": completion_token_count,
|
|
||||||
"total_tokens": token_count + completion_token_count
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print({'answer': answer, 'completion_token_count': completion_token_count})
|
|
||||||
|
|
||||||
response = json.dumps(resp)
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
elif '/images/generations' in self.path and 'SD_WEBUI_URL' in os.environ:
|
||||||
# Stable Diffusion callout wrapper for txt2img
|
prompt = body['prompt']
|
||||||
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
size = default(body, 'size', '1024x1024')
|
||||||
# the results will be limited and likely poor. SD has hundreds of models and dozens of settings.
|
|
||||||
# If you want high quality tailored results you should just use the Stable Diffusion API directly.
|
|
||||||
# it's too general an API to try and shape the result with specific tags like "masterpiece", etc,
|
|
||||||
# Will probably work best with the stock SD models.
|
|
||||||
# SD configuration is beyond the scope of this API.
|
|
||||||
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
|
||||||
# require changing the form data handling to accept multipart form data, also to properly support
|
|
||||||
# url return types will require file management and a web serving files... Perhaps later!
|
|
||||||
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
width, height = [ int(x) for x in default(body, 'size', '1024x1024').split('x') ] # ignore the restrictions on size
|
|
||||||
response_format = default(body, 'response_format', 'url') # or b64_json
|
response_format = default(body, 'response_format', 'url') # or b64_json
|
||||||
|
n = default(body, 'n', 1) # ignore the batch limits of max 10
|
||||||
|
|
||||||
payload = {
|
response = OAIimages.generations(prompt=prompt, size=size, response_format=response_format, n=n)
|
||||||
'prompt': body['prompt'], # ignore prompt limit of 1000 characters
|
|
||||||
'width': width,
|
|
||||||
'height': height,
|
|
||||||
'batch_size': default(body, 'n', 1) # ignore the batch limits of max 10
|
|
||||||
}
|
|
||||||
|
|
||||||
resp = {
|
self.return_json(response, no_debug=True)
|
||||||
'created': int(time.time()),
|
|
||||||
'data': []
|
|
||||||
}
|
|
||||||
|
|
||||||
# TODO: support SD_WEBUI_AUTH username:password pair.
|
elif '/embeddings' in self.path:
|
||||||
sd_url = f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img"
|
encoding_format = body.get('encoding_format', '')
|
||||||
|
|
||||||
response = requests.post(url=sd_url, json=payload)
|
input = body.get('input', body.get('text', ''))
|
||||||
r = response.json()
|
if not input:
|
||||||
# r['parameters']...
|
raise InvalidRequestError("Missing required argument input", params='input')
|
||||||
for b64_json in r['images']:
|
|
||||||
if response_format == 'b64_json':
|
|
||||||
resp['data'].extend([{'b64_json': b64_json}])
|
|
||||||
else:
|
|
||||||
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
|
|
||||||
|
|
||||||
response = json.dumps(resp)
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/embeddings' in self.path and embedding_model is not None:
|
|
||||||
self.send_response(200)
|
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
input = body['input'] if 'input' in body else body['text']
|
|
||||||
if type(input) is str:
|
if type(input) is str:
|
||||||
input = [input]
|
input = [input]
|
||||||
|
|
||||||
embeddings = embedding_model.encode(input).tolist()
|
response = OAIembeddings.embeddings(input, encoding_format)
|
||||||
|
|
||||||
def enc_emb(emb):
|
self.return_json(response, no_debug=True)
|
||||||
# If base64 is specified, encode. Otherwise, do nothing.
|
|
||||||
if body.get("encoding_format", "") == "base64":
|
|
||||||
return float_list_to_base64(emb)
|
|
||||||
else:
|
|
||||||
return emb
|
|
||||||
data = [{"object": "embedding", "embedding": enc_emb(emb), "index": n} for n, emb in enumerate(embeddings)]
|
|
||||||
|
|
||||||
response = json.dumps({
|
|
||||||
"object": "list",
|
|
||||||
"data": data,
|
|
||||||
"model": st_model, # return the real model
|
|
||||||
"usage": {
|
|
||||||
"prompt_tokens": 0,
|
|
||||||
"total_tokens": 0,
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
if debug:
|
|
||||||
print(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif '/moderations' in self.path:
|
elif '/moderations' in self.path:
|
||||||
# for now do nothing, just don't error.
|
input = body['input']
|
||||||
self.send_response(200)
|
if not input:
|
||||||
self.send_access_control_headers()
|
raise InvalidRequestError("Missing required argument input", params='input')
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
response = json.dumps({
|
response = OAImoderations.moderations(input)
|
||||||
"id": "modr-5MWoLO",
|
|
||||||
"model": "text-moderation-001",
|
self.return_json(response, no_debug=True)
|
||||||
"results": [{
|
|
||||||
"categories": {
|
|
||||||
"hate": False,
|
|
||||||
"hate/threatening": False,
|
|
||||||
"self-harm": False,
|
|
||||||
"sexual": False,
|
|
||||||
"sexual/minors": False,
|
|
||||||
"violence": False,
|
|
||||||
"violence/graphic": False
|
|
||||||
},
|
|
||||||
"category_scores": {
|
|
||||||
"hate": 0.0,
|
|
||||||
"hate/threatening": 0.0,
|
|
||||||
"self-harm": 0.0,
|
|
||||||
"sexual": 0.0,
|
|
||||||
"sexual/minors": 0.0,
|
|
||||||
"violence": 0.0,
|
|
||||||
"violence/graphic": 0.0
|
|
||||||
},
|
|
||||||
"flagged": False
|
|
||||||
}]
|
|
||||||
})
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
|
||||||
|
|
||||||
elif self.path == '/api/v1/token-count':
|
elif self.path == '/api/v1/token-count':
|
||||||
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
||||||
self.send_response(200)
|
response = token_count(body['prompt'])
|
||||||
self.send_access_control_headers()
|
|
||||||
self.send_header('Content-Type', 'application/json')
|
|
||||||
self.end_headers()
|
|
||||||
|
|
||||||
tokens = encode(body['prompt'])[0]
|
self.return_json(response, no_debug=True)
|
||||||
response = json.dumps({
|
|
||||||
'results': [{
|
elif self.path == '/api/v1/token/encode':
|
||||||
'tokens': len(tokens)
|
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||||
}]
|
encoding_format = body.get('encoding_format', '')
|
||||||
})
|
|
||||||
self.wfile.write(response.encode('utf-8'))
|
response = token_encode(body['input'], encoding_format)
|
||||||
|
|
||||||
|
self.return_json(response, no_debug=True)
|
||||||
|
|
||||||
|
elif self.path == '/api/v1/token/decode':
|
||||||
|
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||||
|
encoding_format = body.get('encoding_format', '')
|
||||||
|
|
||||||
|
response = token_decode(body['input'], encoding_format)
|
||||||
|
|
||||||
|
self.return_json(response, no_debug=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
print(self.path, self.headers)
|
|
||||||
self.send_error(404)
|
self.send_error(404)
|
||||||
|
|
||||||
|
|
||||||
def run_server():
|
def run_server():
|
||||||
global embedding_model
|
|
||||||
try:
|
|
||||||
embedding_model = SentenceTransformer(st_model)
|
|
||||||
print(f"\nLoaded embedding model: {st_model}, max sequence length: {embedding_model.max_seq_length}")
|
|
||||||
except:
|
|
||||||
print(f"\nFailed to load embedding model: {st_model}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
||||||
server = ThreadingHTTPServer(server_addr, Handler)
|
server = ThreadingHTTPServer(server_addr, Handler)
|
||||||
if shared.args.share:
|
if shared.args.share:
|
||||||
|
38
extensions/openai/tokens.py
Normal file
38
extensions/openai/tokens.py
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
from extensions.openai.utils import float_list_to_base64
|
||||||
|
from modules.text_generation import encode, decode
|
||||||
|
|
||||||
|
|
||||||
|
def token_count(prompt):
|
||||||
|
tokens = encode(prompt)[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'results': [{
|
||||||
|
'tokens': len(tokens)
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def token_encode(input, encoding_format=''):
|
||||||
|
# if isinstance(input, list):
|
||||||
|
tokens = encode(input)[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'results': [{
|
||||||
|
'encoding_format': encoding_format,
|
||||||
|
'tokens': float_list_to_base64(tokens) if encoding_format == "base64" else tokens,
|
||||||
|
'length': len(tokens),
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def token_decode(tokens, encoding_format):
|
||||||
|
# if isinstance(input, list):
|
||||||
|
# if encoding_format == "base64":
|
||||||
|
# tokens = base64_to_float_list(tokens)
|
||||||
|
output = decode(tokens)[0]
|
||||||
|
|
||||||
|
return {
|
||||||
|
'results': [{
|
||||||
|
'text': output
|
||||||
|
}]
|
||||||
|
}
|
29
extensions/openai/utils.py
Normal file
29
extensions/openai/utils.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
import os
|
||||||
|
import base64
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def float_list_to_base64(float_list):
|
||||||
|
# Convert the list to a float32 array that the OpenAPI client expects
|
||||||
|
float_array = np.array(float_list, dtype="float32")
|
||||||
|
|
||||||
|
# Get raw bytes
|
||||||
|
bytes_array = float_array.tobytes()
|
||||||
|
|
||||||
|
# Encode bytes into base64
|
||||||
|
encoded_bytes = base64.b64encode(bytes_array)
|
||||||
|
|
||||||
|
# Turn raw base64 encoded bytes into ASCII
|
||||||
|
ascii_string = encoded_bytes.decode('ascii')
|
||||||
|
return ascii_string
|
||||||
|
|
||||||
|
|
||||||
|
def end_line(s):
|
||||||
|
if s and s[-1] != '\n':
|
||||||
|
s = s + '\n'
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
def debug_msg(*args, **kwargs):
|
||||||
|
if 'OPENEDAI_DEBUG' in os.environ:
|
||||||
|
print(*args, **kwargs)
|
@ -7,6 +7,7 @@ from transformers import BlipForConditionalGeneration, BlipProcessor
|
|||||||
|
|
||||||
from modules import chat, shared
|
from modules import chat, shared
|
||||||
from modules.ui import gather_interface_values
|
from modules.ui import gather_interface_values
|
||||||
|
from modules.utils import gradio
|
||||||
|
|
||||||
# If 'state' is True, will hijack the next chat generation with
|
# If 'state' is True, will hijack the next chat generation with
|
||||||
# custom input text given by 'value' in the format [text, visible_text]
|
# custom input text given by 'value' in the format [text, visible_text]
|
||||||
@ -42,6 +43,6 @@ def ui():
|
|||||||
# Prepare the input hijack, update the interface values, call the generation function, and clear the picture
|
# Prepare the input hijack, update the interface values, call the generation function, and clear the picture
|
||||||
picture_select.upload(
|
picture_select.upload(
|
||||||
lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
|
lambda picture, name1, name2: input_hijack.update({"state": True, "value": generate_chat_picture(picture, name1, name2)}), [picture_select, shared.gradio['name1'], shared.gradio['name2']], None).then(
|
||||||
gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
gather_interface_values, gradio(shared.input_elements), gradio('interface_state')).then(
|
||||||
chat.generate_chat_reply_wrapper, shared.input_params, shared.gradio['display'], show_progress=False).then(
|
chat.generate_chat_reply_wrapper, shared.input_params, gradio('display', 'history'), show_progress=False).then(
|
||||||
lambda: None, None, picture_select, show_progress=False)
|
lambda: None, None, picture_select, show_progress=False)
|
||||||
|
@ -2,3 +2,4 @@ beautifulsoup4==4.12.2
|
|||||||
chromadb==0.3.18
|
chromadb==0.3.18
|
||||||
posthog==2.4.2
|
posthog==2.4.2
|
||||||
sentence_transformers==2.2.2
|
sentence_transformers==2.2.2
|
||||||
|
lxml
|
||||||
|
@ -69,7 +69,7 @@ def feed_url_into_collector(urls, chunk_len, chunk_sep, strong_cleanup, threads)
|
|||||||
cumulative += 'Processing the HTML sources...'
|
cumulative += 'Processing the HTML sources...'
|
||||||
yield cumulative
|
yield cumulative
|
||||||
for content in contents:
|
for content in contents:
|
||||||
soup = BeautifulSoup(content, features="html.parser")
|
soup = BeautifulSoup(content, features="lxml")
|
||||||
for script in soup(["script", "style"]):
|
for script in soup(["script", "style"]):
|
||||||
script.extract()
|
script.extract()
|
||||||
|
|
||||||
@ -113,7 +113,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
|||||||
if len(history['internal']) > params['chunk_count'] and user_input != '':
|
if len(history['internal']) > params['chunk_count'] and user_input != '':
|
||||||
chunks = []
|
chunks = []
|
||||||
hist_size = len(history['internal'])
|
hist_size = len(history['internal'])
|
||||||
for i in range(hist_size-1):
|
for i in range(hist_size - 1):
|
||||||
chunks.append(make_single_exchange(i))
|
chunks.append(make_single_exchange(i))
|
||||||
|
|
||||||
add_chunks_to_collector(chunks, chat_collector)
|
add_chunks_to_collector(chunks, chat_collector)
|
||||||
|
@ -16,7 +16,7 @@ params = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def do_stt(audio,whipser_model,whipser_language):
|
def do_stt(audio, whipser_model, whipser_language):
|
||||||
transcription = ""
|
transcription = ""
|
||||||
r = sr.Recognizer()
|
r = sr.Recognizer()
|
||||||
|
|
||||||
@ -33,10 +33,10 @@ def do_stt(audio,whipser_model,whipser_language):
|
|||||||
return transcription
|
return transcription
|
||||||
|
|
||||||
|
|
||||||
def auto_transcribe(audio, auto_submit,whipser_model,whipser_language):
|
def auto_transcribe(audio, auto_submit, whipser_model, whipser_language):
|
||||||
if audio is None:
|
if audio is None:
|
||||||
return "", ""
|
return "", ""
|
||||||
transcription = do_stt(audio,whipser_model,whipser_language)
|
transcription = do_stt(audio, whipser_model, whipser_language)
|
||||||
if auto_submit:
|
if auto_submit:
|
||||||
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
||||||
|
|
||||||
@ -50,11 +50,11 @@ def ui():
|
|||||||
with gr.Row():
|
with gr.Row():
|
||||||
with gr.Accordion("Settings", open=False):
|
with gr.Accordion("Settings", open=False):
|
||||||
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
|
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
|
||||||
whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'],choices=["tiny.en","base.en", "small.en","medium.en","tiny","base","small","medium","large"])
|
whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"])
|
||||||
whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'],choices=["chinese","german","spanish","russian","korean","french","japanese","portuguese","turkish","polish","catalan","dutch","arabic","swedish","italian","indonesian","hindi","finnish","vietnamese","hebrew","ukrainian","greek","malay","czech","romanian","danish","hungarian","tamil","norwegian","thai","urdu","croatian","bulgarian","lithuanian","latin","maori","malayalam","welsh","slovak","telugu","persian","latvian","bengali","serbian","azerbaijani","slovenian","kannada","estonian","macedonian","breton","basque","icelandic","armenian","nepali","mongolian","bosnian","kazakh","albanian","swahili","galician","marathi","punjabi","sinhala","khmer","shona","yoruba","somali","afrikaans","occitan","georgian","belarusian","tajik","sindhi","gujarati","amharic","yiddish","lao","uzbek","faroese","haitian creole","pashto","turkmen","nynorsk","maltese","sanskrit","luxembourgish","myanmar","tibetan","tagalog","malagasy","assamese","tatar","hawaiian","lingala","hausa","bashkir","javanese","sundanese"])
|
whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
|
||||||
|
|
||||||
audio.change(
|
audio.change(
|
||||||
auto_transcribe, [audio, auto_submit,whipser_model,whipser_language], [shared.gradio['textbox'], audio]).then(
|
auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then(
|
||||||
None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
|
None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
|
||||||
whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None)
|
whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None)
|
||||||
whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None)
|
whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None)
|
||||||
|
@ -97,6 +97,8 @@ llama-65b-gptq-3bit:
|
|||||||
.*raven:
|
.*raven:
|
||||||
mode: 'instruct'
|
mode: 'instruct'
|
||||||
instruction_template: 'RWKV-Raven'
|
instruction_template: 'RWKV-Raven'
|
||||||
|
.*ctx8192:
|
||||||
|
truncation_length: 8192
|
||||||
.*moss-moon.*sft:
|
.*moss-moon.*sft:
|
||||||
mode: 'instruct'
|
mode: 'instruct'
|
||||||
instruction_template: 'MOSS'
|
instruction_template: 'MOSS'
|
||||||
@ -143,6 +145,7 @@ llama-65b-gptq-3bit:
|
|||||||
.*wizard.*mega:
|
.*wizard.*mega:
|
||||||
mode: 'instruct'
|
mode: 'instruct'
|
||||||
instruction_template: 'Wizard-Mega'
|
instruction_template: 'Wizard-Mega'
|
||||||
|
custom_stopping_strings: '"</s>"'
|
||||||
.*ziya-:
|
.*ziya-:
|
||||||
mode: 'instruct'
|
mode: 'instruct'
|
||||||
instruction_template: 'Ziya'
|
instruction_template: 'Ziya'
|
||||||
@ -243,3 +246,26 @@ TheBloke_WizardLM-30B-GPTQ:
|
|||||||
.*xgen.*-inst:
|
.*xgen.*-inst:
|
||||||
truncation_length: 8192
|
truncation_length: 8192
|
||||||
instruction_template: 'Vicuna-v0'
|
instruction_template: 'Vicuna-v0'
|
||||||
|
.*(platypus|gplatty|superplatty):
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Alpaca'
|
||||||
|
.*longchat:
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Vicuna-v1.1'
|
||||||
|
.*vicuna-33b:
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Vicuna-v1.1'
|
||||||
|
.*redmond-hermes-coder:
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Alpaca'
|
||||||
|
truncation_length: 8192
|
||||||
|
.*wizardcoder-15b:
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Alpaca'
|
||||||
|
truncation_length: 8192
|
||||||
|
.*wizardlm-.*-v1.1:
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Vicuna-v1.1'
|
||||||
|
.*godzilla:
|
||||||
|
mode: 'instruct'
|
||||||
|
instruction_template: 'Alpaca'
|
||||||
|
@ -3,6 +3,7 @@ import copy
|
|||||||
import functools
|
import functools
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
@ -388,8 +389,25 @@ def load_history(file, history):
|
|||||||
return history
|
return history
|
||||||
|
|
||||||
|
|
||||||
|
def save_history_at_user_request(history, character, mode):
|
||||||
|
def make_timestamp_path(character=None):
|
||||||
|
return f"logs/{character or ''}{'_' if character else ''}{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||||
|
|
||||||
|
path = None
|
||||||
|
if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None]:
|
||||||
|
path = make_timestamp_path(character)
|
||||||
|
else:
|
||||||
|
# Try to use mode as the file name, otherwise just use the timestamp
|
||||||
|
try:
|
||||||
|
path = make_timestamp_path(mode.capitalize())
|
||||||
|
except:
|
||||||
|
path = make_timestamp_path()
|
||||||
|
|
||||||
|
return save_history(history, path)
|
||||||
|
|
||||||
|
|
||||||
def save_persistent_history(history, character, mode):
|
def save_persistent_history(history, character, mode):
|
||||||
if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user:
|
if mode in ['chat', 'chat-instruct'] and character not in ['', 'None', None] and not shared.args.multi_user:
|
||||||
save_history(history, path=Path(f'logs/{character}_persistent.json'))
|
save_history(history, path=Path(f'logs/{character}_persistent.json'))
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,6 +49,7 @@ class LlamaCppModel:
|
|||||||
'n_batch': shared.args.n_batch,
|
'n_batch': shared.args.n_batch,
|
||||||
'use_mmap': not shared.args.no_mmap,
|
'use_mmap': not shared.args.no_mmap,
|
||||||
'use_mlock': shared.args.mlock,
|
'use_mlock': shared.args.mlock,
|
||||||
|
'low_vram': shared.args.low_vram,
|
||||||
'n_gpu_layers': shared.args.n_gpu_layers
|
'n_gpu_layers': shared.args.n_gpu_layers
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,6 +34,7 @@ loaders_and_params = {
|
|||||||
'n_batch',
|
'n_batch',
|
||||||
'threads',
|
'threads',
|
||||||
'no_mmap',
|
'no_mmap',
|
||||||
|
'low_vram',
|
||||||
'mlock',
|
'mlock',
|
||||||
'llama_cpp_seed',
|
'llama_cpp_seed',
|
||||||
],
|
],
|
||||||
@ -53,14 +54,14 @@ loaders_and_params = {
|
|||||||
'trust_remote_code',
|
'trust_remote_code',
|
||||||
'transformers_info'
|
'transformers_info'
|
||||||
],
|
],
|
||||||
'ExLlama' : [
|
'ExLlama': [
|
||||||
'gpu_split',
|
'gpu_split',
|
||||||
'max_seq_len',
|
'max_seq_len',
|
||||||
'compress_pos_emb',
|
'compress_pos_emb',
|
||||||
'alpha_value',
|
'alpha_value',
|
||||||
'exllama_info',
|
'exllama_info',
|
||||||
],
|
],
|
||||||
'ExLlama_HF' : [
|
'ExLlama_HF': [
|
||||||
'gpu_split',
|
'gpu_split',
|
||||||
'max_seq_len',
|
'max_seq_len',
|
||||||
'compress_pos_emb',
|
'compress_pos_emb',
|
||||||
|
@ -106,7 +106,7 @@ def load_tokenizer(model_name, model):
|
|||||||
use_fast=False
|
use_fast=False
|
||||||
)
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
path_to_model,
|
path_to_model,
|
||||||
trust_remote_code=shared.args.trust_remote_code,
|
trust_remote_code=shared.args.trust_remote_code,
|
||||||
use_fast=True
|
use_fast=True
|
||||||
@ -339,6 +339,7 @@ def clear_torch_cache():
|
|||||||
def unload_model():
|
def unload_model():
|
||||||
shared.model = shared.tokenizer = None
|
shared.model = shared.tokenizer = None
|
||||||
shared.lora_names = []
|
shared.lora_names = []
|
||||||
|
shared.model_dirty_from_training = False
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@ -99,7 +99,10 @@ def apply_model_settings_to_state(model, state):
|
|||||||
|
|
||||||
for k in model_settings:
|
for k in model_settings:
|
||||||
if k in state:
|
if k in state:
|
||||||
state[k] = model_settings[k]
|
if k in ['wbits', 'groupsize']:
|
||||||
|
state[k] = str(model_settings[k])
|
||||||
|
else:
|
||||||
|
state[k] = model_settings[k]
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
|||||||
'''
|
'''
|
||||||
Copied from the transformers library
|
Copied from the transformers library
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def __init__(self, penalty: float, _range: int):
|
def __init__(self, penalty: float, _range: int):
|
||||||
if not isinstance(penalty, float) or not (penalty > 0):
|
if not isinstance(penalty, float) or not (penalty > 0):
|
||||||
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
||||||
|
@ -12,6 +12,7 @@ tokenizer = None
|
|||||||
is_seq2seq = False
|
is_seq2seq = False
|
||||||
model_name = "None"
|
model_name = "None"
|
||||||
lora_names = []
|
lora_names = []
|
||||||
|
model_dirty_from_training = False
|
||||||
|
|
||||||
# Chat variables
|
# Chat variables
|
||||||
stop_everything = False
|
stop_everything = False
|
||||||
@ -120,6 +121,7 @@ parser.add_argument('--use_double_quant', action='store_true', help='use_double_
|
|||||||
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
|
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use.')
|
||||||
parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
|
parser.add_argument('--n_batch', type=int, default=512, help='Maximum number of prompt tokens to batch together when calling llama_eval.')
|
||||||
parser.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
|
parser.add_argument('--no-mmap', action='store_true', help='Prevent mmap from being used.')
|
||||||
|
parser.add_argument('--low-vram', action='store_true', help='Low VRAM Mode')
|
||||||
parser.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
|
parser.add_argument('--mlock', action='store_true', help='Force the system to keep the model in RAM.')
|
||||||
parser.add_argument('--cache-capacity', type=str, help='Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
|
parser.add_argument('--cache-capacity', type=str, help='Maximum cache capacity. Examples: 2000MiB, 2GiB. When provided without units, bytes will be assumed.')
|
||||||
parser.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.')
|
parser.add_argument('--n-gpu-layers', type=int, default=0, help='Number of layers to offload to the GPU.')
|
||||||
@ -179,7 +181,7 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
|
|||||||
# API
|
# API
|
||||||
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||||
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
|
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
|
||||||
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
|
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
|
||||||
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||||
|
|
||||||
# Multimodal
|
# Multimodal
|
||||||
|
@ -1,18 +1,23 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["WANDB_MODE"] = "offline"
|
||||||
|
# os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import random
|
import random
|
||||||
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
|
from modules.models import load_model, unload_model
|
||||||
import shutil
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from datasets import Dataset, load_dataset
|
from datasets import Dataset, load_dataset
|
||||||
from peft import (
|
from peft import (
|
||||||
@ -29,6 +34,7 @@ from modules.evaluate import (
|
|||||||
save_past_evaluations
|
save_past_evaluations
|
||||||
)
|
)
|
||||||
from modules.logging_colors import logger
|
from modules.logging_colors import logger
|
||||||
|
from modules.utils import natural_keys
|
||||||
|
|
||||||
# This mapping is from a very recent commit, not yet released.
|
# This mapping is from a very recent commit, not yet released.
|
||||||
# If not available, default to a backup map for some common model types.
|
# If not available, default to a backup map for some common model types.
|
||||||
@ -56,7 +62,7 @@ train_log = {}
|
|||||||
train_template = {}
|
train_template = {}
|
||||||
|
|
||||||
WANT_INTERRUPT = False
|
WANT_INTERRUPT = False
|
||||||
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss"]
|
PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", "batch_size", "epochs", "learning_rate", "lr_scheduler_type", "lora_rank", "lora_alpha", "lora_dropout", "cutoff_len", "dataset", "eval_dataset", "format", "eval_steps", "raw_text_file", "overlap_len", "newline_favor_len", "higher_rank_limit", "warmup_steps", "optimizer", "hard_cut_string", "train_only_after", "stop_at_loss", "add_eos_token", "min_chars", "report_to"]
|
||||||
|
|
||||||
|
|
||||||
def create_train_interface():
|
def create_train_interface():
|
||||||
@ -104,6 +110,7 @@ def create_train_interface():
|
|||||||
raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
|
raw_text_file = gr.Dropdown(choices=utils.get_datasets('training/datasets', 'txt'), value='None', label='Text file', info='The raw text file to use for training.')
|
||||||
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
ui.create_refresh_button(raw_text_file, lambda: None, lambda: {'choices': utils.get_datasets('training/datasets', 'txt')}, 'refresh-button')
|
||||||
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
|
hard_cut_string = gr.Textbox(label='Hard Cut String', value='\\n\\n\\n', info='String that indicates a hard cut between text parts. Helps prevent unwanted overlap.')
|
||||||
|
min_chars = gr.Number(label='Ignore small blocks', value=0, info='Ignore Hard Cut blocks that have less or equal characters than this number')
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
overlap_len = gr.Slider(label='Overlap Length', minimum=0, maximum=512, value=128, step=16, info='Overlap length - ie how many tokens from the prior chunk of text to include into the next chunk. (The chunks themselves will be of a size determined by Cutoff Length below). Setting overlap to exactly half the cutoff length may be ideal.')
|
||||||
@ -115,9 +122,12 @@ def create_train_interface():
|
|||||||
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
|
optimizer = gr.Dropdown(label='Optimizer', value='adamw_torch', choices=['adamw_hf', 'adamw_torch', 'adamw_torch_fused', 'adamw_torch_xla', 'adamw_apex_fused', 'adafactor', 'adamw_bnb_8bit', 'adamw_anyprecision', 'sgd', 'adagrad'], info='Different optimizer implementation options, for advanced users. Effects of different options are not well documented yet.')
|
||||||
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
|
train_only_after = gr.Textbox(label='Train Only After', value='', info='Only consider text *after* this string in any given chunk for training. For Alpaca datasets, use "### Response:" to only train the response and ignore the input.')
|
||||||
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
stop_at_loss = gr.Slider(label='Stop at loss', minimum=0.0, maximum=3.0, step=0.1, value=0.00, info='The process will automatically stop once the desired loss value is reached. (reasonable numbers are 1.5-1.8)')
|
||||||
|
add_eos_token = gr.Checkbox(label='Add EOS token', value=False, info="Adds EOS token for each dataset item. In case of raw text, the EOS will be added at the Hard Cut")
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
higher_rank_limit = gr.Checkbox(label='Enable higher ranks', value=False, info='If checked, changes Rank/Alpha slider above to go much higher. This will not work without a datacenter-class GPU.')
|
||||||
|
with gr.Row():
|
||||||
|
report_to = gr.Radio(label="Save detailed logs with", value="None", choices=["None", "wandb", "tensorboard"], interactive=True)
|
||||||
|
|
||||||
with gr.Row():
|
with gr.Row():
|
||||||
start_button = gr.Button("Start LoRA Training")
|
start_button = gr.Button("Start LoRA Training")
|
||||||
@ -148,7 +158,9 @@ def create_train_interface():
|
|||||||
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
|
refresh_table = gr.Button('Refresh the table', elem_classes="small-button")
|
||||||
|
|
||||||
# Training events
|
# Training events
|
||||||
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss]
|
|
||||||
|
all_params = [lora_name, always_override, save_steps, micro_batch_size, batch_size, epochs, learning_rate, lr_scheduler_type, lora_rank, lora_alpha, lora_dropout, cutoff_len, dataset, eval_dataset, format, eval_steps, raw_text_file, overlap_len, newline_favor_len, higher_rank_limit, warmup_steps, optimizer, hard_cut_string, train_only_after, stop_at_loss, add_eos_token, min_chars, report_to]
|
||||||
|
|
||||||
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
copy_from.change(do_copy_params, [copy_from] + all_params, all_params)
|
||||||
start_button.click(do_train, all_params, output)
|
start_button.click(do_train, all_params, output)
|
||||||
stop_button.click(do_interrupt, None, None, queue=False)
|
stop_button.click(do_interrupt, None, None, queue=False)
|
||||||
@ -240,6 +252,7 @@ def backup_adapter(input_folder):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("An error occurred in backup_adapter:", str(e))
|
print("An error occurred in backup_adapter:", str(e))
|
||||||
|
|
||||||
|
|
||||||
def calc_trainable_parameters(model):
|
def calc_trainable_parameters(model):
|
||||||
trainable_params = 0
|
trainable_params = 0
|
||||||
all_param = 0
|
all_param = 0
|
||||||
@ -253,10 +266,10 @@ def calc_trainable_parameters(model):
|
|||||||
if param.requires_grad:
|
if param.requires_grad:
|
||||||
trainable_params += num_params
|
trainable_params += num_params
|
||||||
|
|
||||||
return trainable_params,all_param
|
return trainable_params, all_param
|
||||||
|
|
||||||
|
|
||||||
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float):
|
def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch_size: int, batch_size: int, epochs: int, learning_rate: str, lr_scheduler_type: str, lora_rank: int, lora_alpha: int, lora_dropout: float, cutoff_len: int, dataset: str, eval_dataset: str, format: str, eval_steps: int, raw_text_file: str, overlap_len: int, newline_favor_len: int, higher_rank_limit: bool, warmup_steps: int, optimizer: str, hard_cut_string: str, train_only_after: str, stop_at_loss: float, add_eos_token: bool, min_chars: int, report_to: str):
|
||||||
|
|
||||||
if shared.args.monkey_patch:
|
if shared.args.monkey_patch:
|
||||||
from monkeypatch.peft_tuners_lora_monkey_patch import (
|
from monkeypatch.peft_tuners_lora_monkey_patch import (
|
||||||
@ -314,14 +327,22 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
def encode(text, add_bos_token):
|
def encode(text, add_bos_token):
|
||||||
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
|
result = shared.tokenizer.encode(text, truncation=True, max_length=cutoff_len)
|
||||||
|
# Check if the first two tokens are BOS
|
||||||
|
if len(result) >= 2 and result[:2] == [shared.tokenizer.bos_token_id, shared.tokenizer.bos_token_id]:
|
||||||
|
result = result[1:]
|
||||||
|
|
||||||
if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
|
if not add_bos_token and result[0] == shared.tokenizer.bos_token_id:
|
||||||
result = result[1:]
|
result = result[1:]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def tokenize(prompt):
|
def tokenize(prompt, append_eos_token=False):
|
||||||
|
|
||||||
if train_only_after == '' or train_only_after not in prompt:
|
if train_only_after == '' or train_only_after not in prompt:
|
||||||
input_ids = encode(prompt, True)
|
input_ids = encode(prompt, True)
|
||||||
|
|
||||||
|
if append_eos_token and input_ids[-1] != shared.tokenizer.eos_token_id and len(input_ids) < cutoff_len:
|
||||||
|
input_ids.append(shared.tokenizer.eos_token_id)
|
||||||
|
|
||||||
input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
|
input_ids = [shared.tokenizer.pad_token_id] * (cutoff_len - len(input_ids)) + input_ids
|
||||||
labels = [1] * len(input_ids)
|
labels = [1] * len(input_ids)
|
||||||
|
|
||||||
@ -330,6 +351,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
before_tokens = encode(prompt[:ind], True)
|
before_tokens = encode(prompt[:ind], True)
|
||||||
after_tokens = encode(prompt[ind:], False)
|
after_tokens = encode(prompt[ind:], False)
|
||||||
|
|
||||||
|
if append_eos_token and after_tokens[-1] != shared.tokenizer.eos_token_id:
|
||||||
|
after_tokens.append(shared.tokenizer.eos_token_id)
|
||||||
|
|
||||||
full_length = len(after_tokens) + len(before_tokens)
|
full_length = len(after_tokens) + len(before_tokens)
|
||||||
if full_length > cutoff_len:
|
if full_length > cutoff_len:
|
||||||
after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
|
after_tokens = after_tokens[:cutoff_len - len(before_tokens)]
|
||||||
@ -350,31 +374,46 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
# == Prep the dataset, format, etc ==
|
# == Prep the dataset, format, etc ==
|
||||||
if raw_text_file not in ['None', '']:
|
if raw_text_file not in ['None', '']:
|
||||||
logger.info("Loading raw text file dataset...")
|
|
||||||
|
|
||||||
train_template["template_type"] = "raw_text"
|
train_template["template_type"] = "raw_text"
|
||||||
|
logger.info("Loading raw text file dataset...")
|
||||||
|
fullpath = clean_path('training/datasets', f'{raw_text_file}')
|
||||||
|
fullpath = Path(fullpath)
|
||||||
|
if fullpath.is_dir():
|
||||||
|
logger.info('Training path directory {}'.format(raw_text_file))
|
||||||
|
raw_text = ""
|
||||||
|
file_paths = sorted(fullpath.glob('*.txt'), key=lambda path: natural_keys(path.name))
|
||||||
|
for file_path in file_paths:
|
||||||
|
if file_path.is_file():
|
||||||
|
with file_path.open('r', encoding='utf-8') as file:
|
||||||
|
raw_text += file.read().replace('\r', '')
|
||||||
|
|
||||||
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
logger.info(f"Loaded training file: {file_path.name}")
|
||||||
raw_text = file.read().replace('\r', '')
|
else:
|
||||||
|
with open(clean_path('training/datasets', f'{raw_text_file}.txt'), 'r', encoding='utf-8') as file:
|
||||||
|
raw_text = file.read().replace('\r', '')
|
||||||
|
|
||||||
cut_string = hard_cut_string.replace('\\n', '\n')
|
cut_string = hard_cut_string.replace('\\n', '\n')
|
||||||
|
eos_added = 0
|
||||||
out_tokens = []
|
out_tokens = []
|
||||||
for text_part in raw_text.split(cut_string):
|
for text_part in raw_text.split(cut_string):
|
||||||
if text_part.strip() == '':
|
|
||||||
|
if len(text_part.strip()) <= min_chars:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tokens = shared.tokenizer.encode(text_part)
|
tokens = shared.tokenizer.encode(text_part)
|
||||||
|
if add_eos_token:
|
||||||
|
tokens.append(shared.tokenizer.eos_token_id)
|
||||||
|
eos_added += 1
|
||||||
|
|
||||||
step = cutoff_len - overlap_len
|
step = cutoff_len - overlap_len
|
||||||
if step <= 0:
|
if step <= 0:
|
||||||
yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
|
yield f"Error: overlap_len ({overlap_len}) cannot be greater than or equal to cutoff_len ({cutoff_len})"
|
||||||
return
|
return
|
||||||
|
|
||||||
tokens = list(split_chunks(tokens, step))
|
out_tokens.extend(split_chunks(tokens, cutoff_len, step))
|
||||||
for i in range(1, len(tokens)):
|
|
||||||
tokens[i] = tokens[i - 1][-overlap_len:] + tokens[i]
|
|
||||||
|
|
||||||
out_tokens.extend(tokens)
|
if eos_added > 0:
|
||||||
del tokens
|
print(f"EOS added to {eos_added} text blocks")
|
||||||
|
|
||||||
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
del raw_text # Note: could be a gig for a large dataset, so delete redundant data as we go to be safe on RAM
|
||||||
text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
|
text_chunks = [shared.tokenizer.decode(x) for x in out_tokens]
|
||||||
@ -415,7 +454,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
def generate_and_tokenize_prompt(data_point):
|
def generate_and_tokenize_prompt(data_point):
|
||||||
prompt = generate_prompt(data_point)
|
prompt = generate_prompt(data_point)
|
||||||
return tokenize(prompt)
|
return tokenize(prompt, add_eos_token)
|
||||||
|
|
||||||
logger.info("Loading JSON datasets...")
|
logger.info("Loading JSON datasets...")
|
||||||
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
data = load_dataset("json", data_files=clean_path('training/datasets', f'{dataset}.json'))
|
||||||
@ -427,11 +466,33 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
|
eval_data = load_dataset("json", data_files=clean_path('training/datasets', f'{eval_dataset}.json'))
|
||||||
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
eval_data = eval_data['train'].map(generate_and_tokenize_prompt, new_fingerprint='%030x' % random.randrange(16**30))
|
||||||
|
|
||||||
|
# == We MUST reload model if it went through any previous training, even failed one ==
|
||||||
|
if shared.model_dirty_from_training:
|
||||||
|
selected_model = shared.model_name
|
||||||
|
if selected_model:
|
||||||
|
print("\033[1;31;1m(Model has been modified by previous training, it needs to be reloaded...)\033[0;37;0m")
|
||||||
|
try:
|
||||||
|
yield f"Reloading {selected_model}..."
|
||||||
|
unload_model()
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name, None)
|
||||||
|
if shared.model is not None:
|
||||||
|
print("Model reloaded OK, continue with training.")
|
||||||
|
else:
|
||||||
|
return f"Failed to load {selected_model}."
|
||||||
|
except:
|
||||||
|
exc = traceback.format_exc()
|
||||||
|
logger.error('Failed to reload the model.')
|
||||||
|
print(exc)
|
||||||
|
return exc
|
||||||
|
|
||||||
# == Start prepping the model itself ==
|
# == Start prepping the model itself ==
|
||||||
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
if not hasattr(shared.model, 'lm_head') or hasattr(shared.model.lm_head, 'weight'):
|
||||||
logger.info("Getting model ready...")
|
logger.info("Getting model ready...")
|
||||||
prepare_model_for_int8_training(shared.model)
|
prepare_model_for_int8_training(shared.model)
|
||||||
|
|
||||||
|
# base model is now frozen and should not be reused for any other LoRA training than this one
|
||||||
|
shared.model_dirty_from_training = True
|
||||||
|
|
||||||
logger.info("Prepping for training...")
|
logger.info("Prepping for training...")
|
||||||
config = LoraConfig(
|
config = LoraConfig(
|
||||||
r=lora_rank,
|
r=lora_rank,
|
||||||
@ -518,6 +579,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
train_dataset=train_data,
|
train_dataset=train_data,
|
||||||
eval_dataset=eval_data,
|
eval_dataset=eval_data,
|
||||||
args=transformers.TrainingArguments(
|
args=transformers.TrainingArguments(
|
||||||
|
report_to=report_to if report_to != "None" else None,
|
||||||
per_device_train_batch_size=micro_batch_size,
|
per_device_train_batch_size=micro_batch_size,
|
||||||
gradient_accumulation_steps=gradient_accumulation_steps,
|
gradient_accumulation_steps=gradient_accumulation_steps,
|
||||||
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
|
warmup_steps=math.ceil(warmup_steps / gradient_accumulation_steps),
|
||||||
@ -534,7 +596,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
load_best_model_at_end=eval_data is not None,
|
load_best_model_at_end=eval_data is not None,
|
||||||
# TODO: Enable multi-device support
|
# TODO: Enable multi-device support
|
||||||
ddp_find_unused_parameters=None,
|
ddp_find_unused_parameters=None,
|
||||||
no_cuda=shared.args.cpu
|
no_cuda=shared.args.cpu,
|
||||||
),
|
),
|
||||||
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
data_collator=transformers.DataCollatorForLanguageModeling(shared.tokenizer, mlm=False),
|
||||||
callbacks=list([Callbacks()])
|
callbacks=list([Callbacks()])
|
||||||
@ -560,14 +622,18 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
|
|
||||||
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
lora_trainable_param, lora_all_param = calc_trainable_parameters(lora_model)
|
||||||
|
|
||||||
if lora_all_param>0:
|
projections_string = ", ".join([projection.replace("_proj", "") for projection in model_to_lora_modules[model_id]])
|
||||||
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
|
|
||||||
|
|
||||||
|
print(f"Training '{model_id}' model using ({projections_string}) projections")
|
||||||
|
|
||||||
|
if lora_all_param > 0:
|
||||||
|
print(f"Trainable params: {lora_trainable_param:,d} ({100 * lora_trainable_param / lora_all_param:.4f} %), All params: {lora_all_param:,d} (Model: {model_all_params:,d})")
|
||||||
|
|
||||||
train_log.update({"base_model_name": shared.model_name})
|
train_log.update({"base_model_name": shared.model_name})
|
||||||
train_log.update({"base_model_class": shared.model.__class__.__name__})
|
train_log.update({"base_model_class": shared.model.__class__.__name__})
|
||||||
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
|
train_log.update({"base_loaded_in_4bit": getattr(lora_model, "is_loaded_in_4bit", False)})
|
||||||
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
|
train_log.update({"base_loaded_in_8bit": getattr(lora_model, "is_loaded_in_8bit", False)})
|
||||||
|
train_log.update({"projections": projections_string})
|
||||||
|
|
||||||
if stop_at_loss > 0:
|
if stop_at_loss > 0:
|
||||||
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
|
print(f"Monitoring loss \033[1;31;1m(Auto-Stop at: {stop_at_loss})\033[0;37;0m")
|
||||||
@ -576,7 +642,26 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
yield "Interrupted before start."
|
yield "Interrupted before start."
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def log_train_dataset(trainer):
|
||||||
|
decoded_entries = []
|
||||||
|
# Try to decode the entries and write the log file
|
||||||
|
try:
|
||||||
|
# Iterate over the first 10 elements in the dataset (or fewer if there are less than 10)
|
||||||
|
for i in range(min(10, len(trainer.train_dataset))):
|
||||||
|
decoded_text = shared.tokenizer.decode(trainer.train_dataset[i]['input_ids'])
|
||||||
|
decoded_entries.append({"value": decoded_text})
|
||||||
|
|
||||||
|
# Write the log file
|
||||||
|
Path('logs').mkdir(exist_ok=True)
|
||||||
|
with open(Path('logs/train_dataset_sample.json'), 'w') as json_file:
|
||||||
|
json.dump(decoded_entries, json_file, indent=4)
|
||||||
|
|
||||||
|
logger.info("Log file 'train_dataset_sample.json' created in the 'logs' directory.")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to create log file due to error: {e}")
|
||||||
|
|
||||||
def threaded_run():
|
def threaded_run():
|
||||||
|
log_train_dataset(trainer)
|
||||||
trainer.train()
|
trainer.train()
|
||||||
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
# Note: save in the thread in case the gradio thread breaks (eg browser closed)
|
||||||
lora_model.save_pretrained(lora_file_path)
|
lora_model.save_pretrained(lora_file_path)
|
||||||
@ -625,9 +710,9 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
|||||||
yield f"Done! LoRA saved to `{lora_file_path}`"
|
yield f"Done! LoRA saved to `{lora_file_path}`"
|
||||||
|
|
||||||
|
|
||||||
def split_chunks(arr, step):
|
def split_chunks(arr, size, step):
|
||||||
for i in range(0, len(arr), step):
|
for i in range(0, len(arr), step):
|
||||||
yield arr[i:i + step]
|
yield arr[i:i + size]
|
||||||
|
|
||||||
|
|
||||||
def cut_chunk_for_newline(chunk: str, max_length: int):
|
def cut_chunk_for_newline(chunk: str, max_length: int):
|
||||||
|
@ -57,6 +57,7 @@ def list_model_elements():
|
|||||||
'threads',
|
'threads',
|
||||||
'n_batch',
|
'n_batch',
|
||||||
'no_mmap',
|
'no_mmap',
|
||||||
|
'low_vram',
|
||||||
'mlock',
|
'mlock',
|
||||||
'n_gpu_layers',
|
'n_gpu_layers',
|
||||||
'n_ctx',
|
'n_ctx',
|
||||||
@ -159,7 +160,7 @@ def apply_interface_values(state, use_persistent=False):
|
|||||||
return [state[k] if k in state else gr.update() for k in elements]
|
return [state[k] if k in state else gr.update() for k in elements]
|
||||||
|
|
||||||
|
|
||||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
class ToolButton(gr.Button, gr.components.IOComponent):
|
||||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
@ -114,6 +114,10 @@ def get_available_loras():
|
|||||||
|
|
||||||
|
|
||||||
def get_datasets(path: str, ext: str):
|
def get_datasets(path: str, ext: str):
|
||||||
|
# include subdirectories for raw txt files to allow training from a subdirectory of txt files
|
||||||
|
if ext == "txt":
|
||||||
|
return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
||||||
|
|
||||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
||||||
|
|
||||||
|
|
||||||
|
@ -16,10 +16,12 @@ safetensors==0.3.1
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
tqdm
|
tqdm
|
||||||
scipy
|
scipy
|
||||||
|
tensorboard
|
||||||
|
wandb
|
||||||
transformers==4.30.2
|
transformers==4.30.2
|
||||||
git+https://github.com/huggingface/peft@03eb378eb914fbee709ff7c86ba5b1d033b89524
|
git+https://github.com/huggingface/peft@03eb378eb914fbee709ff7c86ba5b1d033b89524
|
||||||
bitsandbytes==0.39.1; platform_system != "Windows"
|
bitsandbytes==0.40.0; platform_system != "Windows"
|
||||||
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.39.1-py3-none-win_amd64.whl; platform_system == "Windows"
|
https://github.com/jllllll/bitsandbytes-windows-webui/releases/download/wheels/bitsandbytes-0.40.0-py3-none-win_amd64.whl; platform_system == "Windows"
|
||||||
llama-cpp-python==0.1.70; platform_system != "Windows"
|
llama-cpp-python==0.1.70; platform_system != "Windows"
|
||||||
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.70/llama_cpp_python-0.1.70-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.70/llama_cpp_python-0.1.70-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||||
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
https://github.com/PanQiWei/AutoGPTQ/releases/download/v0.2.2/auto_gptq-0.2.2+cu117-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||||
|
17
server.py
17
server.py
@ -54,7 +54,7 @@ from modules.utils import gradio
|
|||||||
|
|
||||||
def load_model_wrapper(selected_model, loader, autoload=False):
|
def load_model_wrapper(selected_model, loader, autoload=False):
|
||||||
if not autoload:
|
if not autoload:
|
||||||
yield f"The settings for {selected_model} have been updated.\nClick on \"Load the model\" to load it."
|
yield f"The settings for {selected_model} have been updated.\nClick on \"Load\" to load it."
|
||||||
return
|
return
|
||||||
|
|
||||||
if selected_model == 'None':
|
if selected_model == 'None':
|
||||||
@ -145,7 +145,13 @@ def download_model_wrapper(repo_id, progress=gr.Progress()):
|
|||||||
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
links, sha256, is_lora = downloader.get_download_links_from_huggingface(model, branch, text_only=False)
|
||||||
|
|
||||||
yield ("Getting the output folder")
|
yield ("Getting the output folder")
|
||||||
output_folder = downloader.get_output_folder(model, branch, is_lora)
|
models_dir = Path(shared.args.model_dir)
|
||||||
|
|
||||||
|
# If the last part of the path is "models", remove it
|
||||||
|
if models_dir.name.lower() == 'models':
|
||||||
|
models_dir = models_dir.parent
|
||||||
|
|
||||||
|
output_folder = downloader.get_output_folder(model, branch, is_lora, base_folder=models_dir)
|
||||||
|
|
||||||
if check:
|
if check:
|
||||||
progress(0.5)
|
progress(0.5)
|
||||||
@ -218,8 +224,8 @@ def create_model_menus():
|
|||||||
shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
|
shared.gradio['n_batch'] = gr.Slider(label="n_batch", minimum=1, maximum=2048, value=shared.args.n_batch)
|
||||||
shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers)
|
shared.gradio['n_gpu_layers'] = gr.Slider(label="n-gpu-layers", minimum=0, maximum=128, value=shared.args.n_gpu_layers)
|
||||||
shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx)
|
shared.gradio['n_ctx'] = gr.Slider(minimum=0, maximum=16384, step=256, label="n_ctx", value=shared.args.n_ctx)
|
||||||
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=shared.args.wbits if shared.args.wbits > 0 else "None")
|
shared.gradio['wbits'] = gr.Dropdown(label="wbits", choices=["None", 1, 2, 3, 4, 8], value=str(shared.args.wbits) if shared.args.wbits > 0 else "None")
|
||||||
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=shared.args.groupsize if shared.args.groupsize > 0 else "None")
|
shared.gradio['groupsize'] = gr.Dropdown(label="groupsize", choices=["None", 32, 64, 128, 1024], value=str(shared.args.groupsize) if shared.args.groupsize > 0 else "None")
|
||||||
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
|
shared.gradio['model_type'] = gr.Dropdown(label="model_type", choices=["None", "llama", "opt", "gptj"], value=shared.args.model_type or "None")
|
||||||
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
|
shared.gradio['pre_layer'] = gr.Slider(label="pre_layer", minimum=0, maximum=100, value=shared.args.pre_layer[0] if shared.args.pre_layer is not None else 0)
|
||||||
shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
|
shared.gradio['autogptq_info'] = gr.Markdown('On some systems, AutoGPTQ can be 2x slower than GPTQ-for-LLaMa. You can manually select the GPTQ-for-LLaMa loader above.')
|
||||||
@ -242,6 +248,7 @@ def create_model_menus():
|
|||||||
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
|
shared.gradio['load_in_4bit'] = gr.Checkbox(label="load-in-4bit", value=shared.args.load_in_4bit)
|
||||||
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
|
shared.gradio['use_double_quant'] = gr.Checkbox(label="use_double_quant", value=shared.args.use_double_quant)
|
||||||
shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
|
shared.gradio['no_mmap'] = gr.Checkbox(label="no-mmap", value=shared.args.no_mmap)
|
||||||
|
shared.gradio['low_vram'] = gr.Checkbox(label="low-vram", value=shared.args.low_vram)
|
||||||
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
|
shared.gradio['mlock'] = gr.Checkbox(label="mlock", value=shared.args.mlock)
|
||||||
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
|
shared.gradio['llama_cpp_seed'] = gr.Number(label='Seed (0 for random)', value=shared.args.llama_cpp_seed)
|
||||||
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
|
shared.gradio['trust_remote_code'] = gr.Checkbox(label="trust-remote-code", value=shared.args.trust_remote_code, info='Make sure to inspect the .py files inside the model folder before loading it with this option enabled.')
|
||||||
@ -976,7 +983,7 @@ def create_interface():
|
|||||||
lambda: 'characters/instruction-following/', None, gradio('delete_root')).then(
|
lambda: 'characters/instruction-following/', None, gradio('delete_root')).then(
|
||||||
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
lambda: gr.update(visible=True), None, gradio('file_deleter'))
|
||||||
|
|
||||||
shared.gradio['download_button'].click(chat.save_history, gradio('history'), gradio('download'))
|
shared.gradio['download_button'].click(chat.save_history_at_user_request, gradio('history', 'character_menu', 'mode'), gradio('download'))
|
||||||
shared.gradio['Submit character'].click(chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu'))
|
shared.gradio['Submit character'].click(chat.upload_character, gradio('upload_json', 'upload_img_bot'), gradio('character_menu'))
|
||||||
shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character'))
|
shared.gradio['upload_json'].upload(lambda: gr.update(interactive=True), None, gradio('Submit character'))
|
||||||
shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character'))
|
shared.gradio['upload_json'].clear(lambda: gr.update(interactive=False), None, gradio('Submit character'))
|
||||||
|
@ -0,0 +1 @@
|
|||||||
|
to load multiple raw text files create a subdirectory and put them all there
|
Loading…
Reference in New Issue
Block a user