This commit is contained in:
oobabooga 2023-07-12 11:33:25 -07:00
parent 9b55d3a9f9
commit e202190c4f
24 changed files with 146 additions and 125 deletions

View File

@ -23,7 +23,7 @@ from tqdm.contrib.concurrent import thread_map
class ModelDownloader: class ModelDownloader:
def __init__(self, max_retries = 5): def __init__(self, max_retries=5):
self.s = requests.Session() self.s = requests.Session()
if max_retries: if max_retries:
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries)) self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))

View File

@ -22,6 +22,7 @@ options = {
'session_metadata': 'text-generation-webui', 'session_metadata': 'text-generation-webui',
} }
def ui(): def ui():
settings = shared.settings.get("ngrok") settings = shared.settings.get("ngrok")
if settings: if settings:
@ -33,4 +34,3 @@ def ui():
logging.info(f"Ingress established at: {tunnel.url()}") logging.info(f"Ingress established at: {tunnel.url()}")
except ModuleNotFoundError: except ModuleNotFoundError:
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`") logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")

View File

@ -41,7 +41,7 @@ class LogprobProcessor(LogitsProcessor):
# XXX hack. should find the selected token and include the prob of that # XXX hack. should find the selected token and include the prob of that
# ... but we just +1 here instead because we don't know it yet. # ... but we just +1 here instead because we don't know it yet.
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1) top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
top_tokens = [ decode(tok) for tok in top_indices[0] ] top_tokens = [decode(tok) for tok in top_indices[0]]
self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist())) self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist()))
return logits return logits
@ -50,7 +50,7 @@ def convert_logprobs_to_tiktoken(model, logprobs):
try: try:
encoder = tiktoken.encoding_for_model(model) encoder = tiktoken.encoding_for_model(model)
# just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall. # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
return dict([ (encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items() ]) return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
except KeyError: except KeyError:
# assume native tokens if we can't find the tokenizer # assume native tokens if we can't find the tokenizer
return logprobs return logprobs
@ -220,16 +220,16 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']: if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}." err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
print(f"Warning: ${err_msg}") print(f"Warning: ${err_msg}")
#raise InvalidRequestError(message=err_msg) # raise InvalidRequestError(message=err_msg)
return prompt, token_count return prompt, token_count
def chat_completions(body: dict, is_legacy: bool=False) -> dict: def chat_completions(body: dict, is_legacy: bool = False) -> dict:
# Chat Completions # Chat Completions
object_type = 'chat.completions' object_type = 'chat.completions'
created_time = int(time.time()) created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000)) cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices' resp_list = 'data' if is_legacy else 'choices'
# common params # common params
@ -296,12 +296,12 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
# generator # generator
def stream_chat_completions(body: dict, is_legacy: bool=False): def stream_chat_completions(body: dict, is_legacy: bool = False):
# Chat Completions # Chat Completions
stream_object_type = 'chat.completions.chunk' stream_object_type = 'chat.completions.chunk'
created_time = int(time.time()) created_time = int(time.time())
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000)) cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices' resp_list = 'data' if is_legacy else 'choices'
# common params # common params
@ -342,7 +342,7 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
if logprob_proc: # not official for chat yet if logprob_proc: # not official for chat yet
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives) top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]} chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
#else: # else:
# chunk[resp_list][0]["logprobs"] = None # chunk[resp_list][0]["logprobs"] = None
return chunk return chunk
@ -380,7 +380,6 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
yield chunk yield chunk
stop_reason = "stop" stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens: if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length" stop_reason = "length"
@ -396,12 +395,12 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
yield chunk yield chunk
def completions(body: dict, is_legacy: bool=False): def completions(body: dict, is_legacy: bool = False):
# Legacy # Legacy
# Text Completions # Text Completions
object_type = 'text_completion' object_type = 'text_completion'
created_time = int(time.time()) created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time()*1000000000)) cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices' resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays. # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
@ -433,7 +432,7 @@ def completions(body: dict, is_legacy: bool=False):
if token_count + max_tokens > req_params['truncation_length']: if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})." err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
#print(f"Warning: ${err_msg}") # print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str) raise InvalidRequestError(message=err_msg, param=max_tokens_str)
req_params['echo'] = default(body, 'echo', req_params['echo']) req_params['echo'] = default(body, 'echo', req_params['echo'])
@ -486,13 +485,13 @@ def completions(body: dict, is_legacy: bool=False):
# generator # generator
def stream_completions(body: dict, is_legacy: bool=False): def stream_completions(body: dict, is_legacy: bool = False):
# Legacy # Legacy
# Text Completions # Text Completions
#object_type = 'text_completion' # object_type = 'text_completion'
stream_object_type = 'text_completion.chunk' stream_object_type = 'text_completion.chunk'
created_time = int(time.time()) created_time = int(time.time())
cmpl_id = "conv-%d" % (int(time.time()*1000000000)) cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
resp_list = 'data' if is_legacy else 'choices' resp_list = 'data' if is_legacy else 'choices'
# ... encoded as a string, array of strings, array of tokens, or array of token arrays. # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
@ -524,7 +523,7 @@ def stream_completions(body: dict, is_legacy: bool=False):
if token_count + max_tokens > req_params['truncation_length']: if token_count + max_tokens > req_params['truncation_length']:
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})." err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
#print(f"Warning: ${err_msg}") # print(f"Warning: ${err_msg}")
raise InvalidRequestError(message=err_msg, param=max_tokens_str) raise InvalidRequestError(message=err_msg, param=max_tokens_str)
req_params['echo'] = default(body, 'echo', req_params['echo']) req_params['echo'] = default(body, 'echo', req_params['echo'])
@ -583,7 +582,6 @@ def stream_completions(body: dict, is_legacy: bool=False):
completion_token_count += len(encode(new_content)[0]) completion_token_count += len(encode(new_content)[0])
yield chunk yield chunk
stop_reason = "stop" stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens: if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length" stop_reason = "length"

View File

@ -41,10 +41,13 @@ default_req_params = {
# 'requested_model' - temporarily used # 'requested_model' - temporarily used
} }
def get_default_req_params(): def get_default_req_params():
return copy.deepcopy(default_req_params) return copy.deepcopy(default_req_params)
# little helper to get defaults if arg is present but None and should be the same type as default. # little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default): def default(dic, key, default):
val = dic.get(key, default) val = dic.get(key, default)
if type(val) != type(default): if type(val) != type(default):
@ -59,6 +62,6 @@ def default(dic, key, default):
val = default val = default
return val return val
def clamp(value, minvalue, maxvalue): def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue)) return max(minvalue, min(value, maxvalue))

View File

@ -8,9 +8,9 @@ from extensions.openai.errors import *
from modules.text_generation import encode, generate_reply from modules.text_generation import encode, generate_reply
def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict: def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
created_time = int(time.time()*1000) created_time = int(time.time() * 1000)
# Request parameters # Request parameters
req_params = get_default_req_params() req_params = get_default_req_params()
@ -41,7 +41,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ') instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
if instruct['user']: if instruct['user']:
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ]) stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
except Exception as e: except Exception as e:
instruction_template = default_template instruction_template = default_template

View File

@ -6,26 +6,30 @@ from extensions.openai.errors import *
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2" st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
embeddings_model = None embeddings_model = None
def load_embedding_model(model): def load_embedding_model(model):
try: try:
emb_model = SentenceTransformer(model) emb_model = SentenceTransformer(model)
print(f"\nLoaded embedding model: {model}, max sequence length: {emb_model.max_seq_length}") print(f"\nLoaded embedding model: {model}, max sequence length: {emb_model.max_seq_length}")
except Exception as e: except Exception as e:
print(f"\nError: Failed to load embedding model: {model}") print(f"\nError: Failed to load embedding model: {model}")
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message = repr(e)) raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
return emb_model return emb_model
def get_embeddings_model(): def get_embeddings_model():
global embeddings_model, st_model global embeddings_model, st_model
if st_model and not embeddings_model: if st_model and not embeddings_model:
embeddings_model = load_embedding_model(st_model) # lazy load the model embeddings_model = load_embedding_model(st_model) # lazy load the model
return embeddings_model return embeddings_model
def get_embeddings_model_name(): def get_embeddings_model_name():
global st_model global st_model
return st_model return st_model
def embeddings(input: list, encoding_format: str): def embeddings(input: list, encoding_format: str):
embeddings = get_embeddings_model().encode(input).tolist() embeddings = get_embeddings_model().encode(input).tolist()

View File

@ -1,8 +1,9 @@
class OpenAIError(Exception): class OpenAIError(Exception):
def __init__(self, message = None, code = 500, internal_message = ''): def __init__(self, message=None, code=500, internal_message=''):
self.message = message self.message = message
self.code = code self.code = code
self.internal_message = internal_message self.internal_message = internal_message
def __repr__(self): def __repr__(self):
return "%s(message=%r, code=%d)" % ( return "%s(message=%r, code=%d)" % (
self.__class__.__name__, self.__class__.__name__,
@ -10,10 +11,12 @@ class OpenAIError(Exception):
self.code, self.code,
) )
class InvalidRequestError(OpenAIError): class InvalidRequestError(OpenAIError):
def __init__(self, message, param, code = 400, error_type ='InvalidRequestError', internal_message = ''): def __init__(self, message, param, code=400, error_type='InvalidRequestError', internal_message=''):
super(OpenAIError, self).__init__(message, code, error_type, internal_message) super(OpenAIError, self).__init__(message, code, error_type, internal_message)
self.param = param self.param = param
def __repr__(self): def __repr__(self):
return "%s(message=%r, code=%d, param=%s)" % ( return "%s(message=%r, code=%d, param=%s)" % (
self.__class__.__name__, self.__class__.__name__,
@ -22,6 +25,7 @@ class InvalidRequestError(OpenAIError):
self.param, self.param,
) )
class ServiceUnavailableError(OpenAIError): class ServiceUnavailableError(OpenAIError):
def __init__(self, message = None, code = 500, error_type ='ServiceUnavailableError', internal_message = ''): def __init__(self, message=None, code=500, error_type='ServiceUnavailableError', internal_message=''):
super(OpenAIError, self).__init__(message, code, error_type, internal_message) super(OpenAIError, self).__init__(message, code, error_type, internal_message)

View File

@ -3,6 +3,7 @@ import time
import requests import requests
from extensions.openai.errors import * from extensions.openai.errors import *
def generations(prompt: str, size: str, response_format: str, n: int): def generations(prompt: str, size: str, response_format: str, n: int):
# Stable Diffusion callout wrapper for txt2img # Stable Diffusion callout wrapper for txt2img
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
@ -15,7 +16,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
# require changing the form data handling to accept multipart form data, also to properly support # require changing the form data handling to accept multipart form data, also to properly support
# url return types will require file management and a web serving files... Perhaps later! # url return types will require file management and a web serving files... Perhaps later!
width, height = [ int(x) for x in size.split('x') ] # ignore the restrictions on size width, height = [int(x) for x in size.split('x')] # ignore the restrictions on size
# to hack on better generation, edit default payload. # to hack on better generation, edit default payload.
payload = { payload = {
@ -37,7 +38,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
response = requests.post(url=sd_url, json=payload) response = requests.post(url=sd_url, json=payload)
r = response.json() r = response.json()
if response.status_code != 200 or 'images' not in r: if response.status_code != 200 or 'images' not in r:
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code = response.status_code) raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code=response.status_code)
# r['parameters']... # r['parameters']...
for b64_json in r['images']: for b64_json in r['images']:
if response_format == 'b64_json': if response_format == 'b64_json':

View File

@ -7,8 +7,10 @@ from modules.models_settings import (get_model_settings_from_yamls,
from extensions.openai.embeddings import get_embeddings_model_name from extensions.openai.embeddings import get_embeddings_model_name
from extensions.openai.errors import * from extensions.openai.errors import *
def get_current_model_list() -> list: def get_current_model_list() -> list:
return [ shared.model_name ] # The real chat/completions model, maybe "None" return [shared.model_name] # The real chat/completions model, maybe "None"
def get_pseudo_model_list() -> list: def get_pseudo_model_list() -> list:
return [ # these are expected by so much, so include some here as a dummy return [ # these are expected by so much, so include some here as a dummy
@ -16,6 +18,7 @@ def get_pseudo_model_list() -> list:
'text-embedding-ada-002', 'text-embedding-ada-002',
] ]
def load_model(model_name: str) -> dict: def load_model(model_name: str) -> dict:
resp = { resp = {
"id": model_name, "id": model_name,
@ -23,7 +26,7 @@ def load_model(model_name: str) -> dict:
"owner": "self", "owner": "self",
"ready": True, "ready": True,
} }
if model_name not in get_pseudo_model_list() + [ get_embeddings_model_name() ] + get_current_model_list(): # Real model only if model_name not in get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list(): # Real model only
# No args. Maybe it works anyways! # No args. Maybe it works anyways!
# TODO: hack some heuristics into args for better results # TODO: hack some heuristics into args for better results
@ -48,16 +51,16 @@ def load_model(model_name: str) -> dict:
def list_models(is_legacy: bool = False) -> dict: def list_models(is_legacy: bool = False) -> dict:
# TODO: Lora's? # TODO: Lora's?
all_model_list = get_current_model_list() + [ get_embeddings_model_name() ] + get_pseudo_model_list() + get_available_models() all_model_list = get_current_model_list() + [get_embeddings_model_name()] + get_pseudo_model_list() + get_available_models()
models = {} models = {}
if is_legacy: if is_legacy:
models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ] models = [{"id": id, "object": "engine", "owner": "user", "ready": True} for id in all_model_list]
if not shared.model: if not shared.model:
models[0]['ready'] = False models[0]['ready'] = False
else: else:
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ] models = [{"id": id, "object": "model", "owned_by": "user", "permission": []} for id in all_model_list]
resp = { resp = {
"object": "list", "object": "list",
@ -74,4 +77,3 @@ def model_info(model_name: str) -> dict:
"owned_by": "user", "owned_by": "user",
"permission": [] "permission": []
} }

View File

@ -7,7 +7,7 @@ from extensions.openai.embeddings import get_embeddings_model
moderations_disabled = False # return 0/false moderations_disabled = False # return 0/false
category_embeddings = None category_embeddings = None
antonym_embeddings = None antonym_embeddings = None
categories = [ "sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence" ] categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence"]
flag_threshold = 0.5 flag_threshold = 0.5
@ -40,23 +40,22 @@ def moderations(input):
embeddings_model = get_embeddings_model() embeddings_model = get_embeddings_model()
if not embeddings_model or moderations_disabled: if not embeddings_model or moderations_disabled:
results['results'] = [{ results['results'] = [{
'categories': dict([ (C, False) for C in categories]), 'categories': dict([(C, False) for C in categories]),
'category_scores': dict([ (C, 0.0) for C in categories]), 'category_scores': dict([(C, 0.0) for C in categories]),
'flagged': False, 'flagged': False,
}] }]
return results return results
category_embeddings = get_category_embeddings() category_embeddings = get_category_embeddings()
# input, string or array # input, string or array
if isinstance(input, str): if isinstance(input, str):
input = [input] input = [input]
for in_str in input: for in_str in input:
for ine in embeddings_model.encode([in_str]).tolist(): for ine in embeddings_model.encode([in_str]).tolist():
category_scores = dict([ (C, mod_score(category_embeddings[C], ine)) for C in categories ]) category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories])
category_flags = dict([ (C, bool(category_scores[C] > flag_threshold)) for C in categories ]) category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories])
flagged = any(category_flags.values()) flagged = any(category_flags.values())
results['results'].extend([{ results['results'].extend([{

View File

@ -22,6 +22,7 @@ params = {
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001, 'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
} }
class Handler(BaseHTTPRequestHandler): class Handler(BaseHTTPRequestHandler):
def send_access_control_headers(self): def send_access_control_headers(self):
self.send_header("Access-Control-Allow-Origin", "*") self.send_header("Access-Control-Allow-Origin", "*")
@ -72,7 +73,7 @@ class Handler(BaseHTTPRequestHandler):
if not no_debug: if not no_debug:
debug_msg(r_utf8) debug_msg(r_utf8)
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''): def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
error_resp = { error_resp = {
'error': { 'error': {
@ -84,7 +85,7 @@ class Handler(BaseHTTPRequestHandler):
} }
if internal_message: if internal_message:
print(internal_message) print(internal_message)
#error_resp['internal_message'] = internal_message # error_resp['internal_message'] = internal_message
self.return_json(error_resp, code) self.return_json(error_resp, code)

View File

@ -1,6 +1,7 @@
from extensions.openai.utils import float_list_to_base64 from extensions.openai.utils import float_list_to_base64
from modules.text_generation import encode, decode from modules.text_generation import encode, decode
def token_count(prompt): def token_count(prompt):
tokens = encode(prompt)[0] tokens = encode(prompt)[0]
@ -11,8 +12,8 @@ def token_count(prompt):
} }
def token_encode(input, encoding_format = ''): def token_encode(input, encoding_format=''):
#if isinstance(input, list): # if isinstance(input, list):
tokens = encode(input)[0] tokens = encode(input)[0]
return { return {
@ -25,9 +26,9 @@ def token_encode(input, encoding_format = ''):
def token_decode(tokens, encoding_format): def token_decode(tokens, encoding_format):
#if isinstance(input, list): # if isinstance(input, list):
# if encoding_format == "base64": # if encoding_format == "base64":
# tokens = base64_to_float_list(tokens) # tokens = base64_to_float_list(tokens)
output = decode(tokens)[0] output = decode(tokens)[0]
return { return {

View File

@ -2,6 +2,7 @@ import os
import base64 import base64
import numpy as np import numpy as np
def float_list_to_base64(float_list): def float_list_to_base64(float_list):
# Convert the list to a float32 array that the OpenAPI client expects # Convert the list to a float32 array that the OpenAPI client expects
float_array = np.array(float_list, dtype="float32") float_array = np.array(float_list, dtype="float32")
@ -16,11 +17,13 @@ def float_list_to_base64(float_list):
ascii_string = encoded_bytes.decode('ascii') ascii_string = encoded_bytes.decode('ascii')
return ascii_string return ascii_string
def end_line(s): def end_line(s):
if s and s[-1] != '\n': if s and s[-1] != '\n':
s = s + '\n' s = s + '\n'
return s return s
def debug_msg(*args, **kwargs): def debug_msg(*args, **kwargs):
if 'OPENEDAI_DEBUG' in os.environ: if 'OPENEDAI_DEBUG' in os.environ:
print(*args, **kwargs) print(*args, **kwargs)

View File

@ -126,6 +126,8 @@ def input_modifier(string):
return string return string
# Get and save the Stable Diffusion-generated picture # Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description, character): def get_SD_pictures(description, character):
global params global params
@ -186,6 +188,8 @@ def get_SD_pictures(description, character):
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history) # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging? # and replace it with 'text' for the purposes of logging?
def output_modifier(string, state): def output_modifier(string, state):
""" """
This function is applied to the model outputs. This function is applied to the model outputs.

View File

@ -113,7 +113,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
if len(history['internal']) > params['chunk_count'] and user_input != '': if len(history['internal']) > params['chunk_count'] and user_input != '':
chunks = [] chunks = []
hist_size = len(history['internal']) hist_size = len(history['internal'])
for i in range(hist_size-1): for i in range(hist_size - 1):
chunks.append(make_single_exchange(i)) chunks.append(make_single_exchange(i))
add_chunks_to_collector(chunks, chat_collector) add_chunks_to_collector(chunks, chat_collector)

View File

@ -16,7 +16,7 @@ params = {
} }
def do_stt(audio,whipser_model,whipser_language): def do_stt(audio, whipser_model, whipser_language):
transcription = "" transcription = ""
r = sr.Recognizer() r = sr.Recognizer()
@ -33,10 +33,10 @@ def do_stt(audio,whipser_model,whipser_language):
return transcription return transcription
def auto_transcribe(audio, auto_submit,whipser_model,whipser_language): def auto_transcribe(audio, auto_submit, whipser_model, whipser_language):
if audio is None: if audio is None:
return "", "" return "", ""
transcription = do_stt(audio,whipser_model,whipser_language) transcription = do_stt(audio, whipser_model, whipser_language)
if auto_submit: if auto_submit:
input_hijack.update({"state": True, "value": [transcription, transcription]}) input_hijack.update({"state": True, "value": [transcription, transcription]})
@ -50,11 +50,11 @@ def ui():
with gr.Row(): with gr.Row():
with gr.Accordion("Settings", open=False): with gr.Accordion("Settings", open=False):
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit']) auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'],choices=["tiny.en","base.en", "small.en","medium.en","tiny","base","small","medium","large"]) whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"])
whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'],choices=["chinese","german","spanish","russian","korean","french","japanese","portuguese","turkish","polish","catalan","dutch","arabic","swedish","italian","indonesian","hindi","finnish","vietnamese","hebrew","ukrainian","greek","malay","czech","romanian","danish","hungarian","tamil","norwegian","thai","urdu","croatian","bulgarian","lithuanian","latin","maori","malayalam","welsh","slovak","telugu","persian","latvian","bengali","serbian","azerbaijani","slovenian","kannada","estonian","macedonian","breton","basque","icelandic","armenian","nepali","mongolian","bosnian","kazakh","albanian","swahili","galician","marathi","punjabi","sinhala","khmer","shona","yoruba","somali","afrikaans","occitan","georgian","belarusian","tajik","sindhi","gujarati","amharic","yiddish","lao","uzbek","faroese","haitian creole","pashto","turkmen","nynorsk","maltese","sanskrit","luxembourgish","myanmar","tibetan","tagalog","malagasy","assamese","tatar","hawaiian","lingala","hausa","bashkir","javanese","sundanese"]) whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
audio.change( audio.change(
auto_transcribe, [audio, auto_submit,whipser_model,whipser_language], [shared.gradio['textbox'], audio]).then( auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then(
None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}") None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None) whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None)
whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None) whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None)

View File

@ -54,14 +54,14 @@ loaders_and_params = {
'trust_remote_code', 'trust_remote_code',
'transformers_info' 'transformers_info'
], ],
'ExLlama' : [ 'ExLlama': [
'gpu_split', 'gpu_split',
'max_seq_len', 'max_seq_len',
'compress_pos_emb', 'compress_pos_emb',
'alpha_value', 'alpha_value',
'exllama_info', 'exllama_info',
], ],
'ExLlama_HF' : [ 'ExLlama_HF': [
'gpu_split', 'gpu_split',
'max_seq_len', 'max_seq_len',
'compress_pos_emb', 'compress_pos_emb',

View File

@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
''' '''
Copied from the transformers library Copied from the transformers library
''' '''
def __init__(self, penalty: float, _range: int): def __init__(self, penalty: float, _range: int):
if not isinstance(penalty, float) or not (penalty > 0): if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")

View File

@ -116,7 +116,7 @@ def get_available_loras():
def get_datasets(path: str, ext: str): def get_datasets(path: str, ext: str):
# include subdirectories for raw txt files to allow training from a subdirectory of txt files # include subdirectories for raw txt files to allow training from a subdirectory of txt files
if ext == "txt": if ext == "txt":
return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt'))+list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys) return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys) return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)