mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
lint
This commit is contained in:
parent
9b55d3a9f9
commit
e202190c4f
@ -23,7 +23,7 @@ from tqdm.contrib.concurrent import thread_map
|
||||
|
||||
|
||||
class ModelDownloader:
|
||||
def __init__(self, max_retries = 5):
|
||||
def __init__(self, max_retries=5):
|
||||
self.s = requests.Session()
|
||||
if max_retries:
|
||||
self.s.mount('https://cdn-lfs.huggingface.co', HTTPAdapter(max_retries=max_retries))
|
||||
|
@ -75,7 +75,7 @@ def build_parameters(body, chat=False):
|
||||
'greeting': greeting,
|
||||
'name1_instruct': name1_instruct,
|
||||
'name2_instruct': name2_instruct,
|
||||
'context_instruct': body.get('context_instruct', context_instruct),
|
||||
'context_instruct': body.get('context_instruct', context_instruct),
|
||||
'turn_template': turn_template,
|
||||
'chat-instruct_command': str(body.get('chat-instruct_command', shared.settings['chat-instruct_command'])),
|
||||
'history': body.get('history', {'internal': [], 'visible': []})
|
||||
|
@ -160,7 +160,7 @@ def ui():
|
||||
api_key = gr.Textbox(placeholder="Enter your API key.", label='API Key')
|
||||
|
||||
with gr.Row():
|
||||
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
|
||||
model = gr.Dropdown(value=params['model'], choices=LANG_MODELS, label='Language model')
|
||||
|
||||
with gr.Row():
|
||||
convert = gr.Button('Permanently replace audios with the message texts')
|
||||
|
@ -1,8 +1,8 @@
|
||||
# Adds ngrok ingress, to use add `--extension ngrok` to the command line options
|
||||
#
|
||||
# Parameters can be customized in settings.json of webui, e.g.:
|
||||
# Parameters can be customized in settings.json of webui, e.g.:
|
||||
# {"ngrok": {"basic_auth":"user:password"} }
|
||||
# or
|
||||
# or
|
||||
# {"ngrok": {"oauth_provider":"google", "oauth_allow_emails":["asdf@asdf.com"]} }
|
||||
#
|
||||
# See this example for full list of options: https://github.com/ngrok/ngrok-py/blob/main/examples/ngrok-connect-full.py
|
||||
@ -22,6 +22,7 @@ options = {
|
||||
'session_metadata': 'text-generation-webui',
|
||||
}
|
||||
|
||||
|
||||
def ui():
|
||||
settings = shared.settings.get("ngrok")
|
||||
if settings:
|
||||
@ -33,4 +34,3 @@ def ui():
|
||||
logging.info(f"Ingress established at: {tunnel.url()}")
|
||||
except ModuleNotFoundError:
|
||||
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")
|
||||
|
||||
|
@ -36,12 +36,12 @@ class LogprobProcessor(LogitsProcessor):
|
||||
super().__init__()
|
||||
|
||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||
if self.logprobs is not None: # 0-5
|
||||
if self.logprobs is not None: # 0-5
|
||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||
# XXX hack. should find the selected token and include the prob of that
|
||||
# ... but we just +1 here instead because we don't know it yet.
|
||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||
top_tokens = [ decode(tok) for tok in top_indices[0] ]
|
||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||
top_tokens = [decode(tok) for tok in top_indices[0]]
|
||||
self.token_alternatives = dict(zip(top_tokens, top_values[0].tolist()))
|
||||
return logits
|
||||
|
||||
@ -50,9 +50,9 @@ def convert_logprobs_to_tiktoken(model, logprobs):
|
||||
try:
|
||||
encoder = tiktoken.encoding_for_model(model)
|
||||
# just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||
return dict([ (encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items() ])
|
||||
return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||
except KeyError:
|
||||
# assume native tokens if we can't find the tokenizer
|
||||
# assume native tokens if we can't find the tokenizer
|
||||
return logprobs
|
||||
|
||||
|
||||
@ -71,9 +71,9 @@ def marshal_common_params(body):
|
||||
# OpenAI API Parameters
|
||||
# model - ignored for now, TODO: When we can reliably load a model or lora from a name only change this
|
||||
req_params['requested_model'] = body.get('model', shared.model_name)
|
||||
|
||||
|
||||
req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
|
||||
req_params['temperature'] = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0/2.0
|
||||
req_params['top_p'] = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||
n = default(body, 'n', 1)
|
||||
if n != 1:
|
||||
@ -81,7 +81,7 @@ def marshal_common_params(body):
|
||||
|
||||
if 'stop' in body: # str or array, max len 4 (ignored)
|
||||
if isinstance(body['stop'], str):
|
||||
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
|
||||
req_params['stopping_strings'] = [body['stop']] # non-standard parameter
|
||||
elif isinstance(body['stop'], list):
|
||||
req_params['stopping_strings'] = body['stop']
|
||||
|
||||
@ -91,7 +91,7 @@ def marshal_common_params(body):
|
||||
|
||||
logits_processor = []
|
||||
logit_bias = body.get('logit_bias', None)
|
||||
if logit_bias: # {str: float, ...}
|
||||
if logit_bias: # {str: float, ...}
|
||||
# XXX convert tokens from tiktoken based on requested model
|
||||
# Ex.: 'logit_bias': {'1129': 100, '11442': 100, '16243': 100}
|
||||
try:
|
||||
@ -103,19 +103,19 @@ def marshal_common_params(body):
|
||||
print(logit_bias, '->', new_logit_bias)
|
||||
logit_bias = new_logit_bias
|
||||
except KeyError:
|
||||
pass # assume native tokens if we can't find the tokenizer
|
||||
pass # assume native tokens if we can't find the tokenizer
|
||||
|
||||
logits_processor = [LogitsBiasProcessor(logit_bias)]
|
||||
|
||||
logprobs = None # coming to chat eventually
|
||||
logprobs = None # coming to chat eventually
|
||||
if 'logprobs' in body:
|
||||
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||
logprobs = default(body, 'logprobs', 0) # maybe cap at topk? don't clamp 0-5.
|
||||
req_params['logprob_proc'] = LogprobProcessor(logprobs)
|
||||
logits_processor.extend([req_params['logprob_proc']])
|
||||
else:
|
||||
logprobs = None
|
||||
|
||||
if logits_processor: # requires logits_processor support
|
||||
if logits_processor: # requires logits_processor support
|
||||
req_params['logits_processor'] = LogitsProcessorList(logits_processor)
|
||||
|
||||
return req_params
|
||||
@ -123,14 +123,14 @@ def marshal_common_params(body):
|
||||
|
||||
def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
# functions
|
||||
if body.get('functions', []): # chat only
|
||||
if body.get('functions', []): # chat only
|
||||
raise InvalidRequestError(message="functions is not supported.", param='functions')
|
||||
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
||||
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
||||
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||
|
||||
if not 'messages' in body:
|
||||
raise InvalidRequestError(message="messages is required", param='messages')
|
||||
|
||||
|
||||
messages = body['messages']
|
||||
|
||||
role_formats = {
|
||||
@ -152,11 +152,11 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
template = instruct['turn_template']
|
||||
system_message_template = "{message}"
|
||||
system_message_default = instruct['context']
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
|
||||
user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct['user'])
|
||||
bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct['bot'])
|
||||
bot_prompt = bot_message_template[:bot_message_template.find('{message}')].rstrip(' ')
|
||||
|
||||
|
||||
role_formats = {
|
||||
'user': user_message_template,
|
||||
'assistant': bot_message_template,
|
||||
@ -167,7 +167,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
|
||||
if 'Alpaca' in shared.settings['instruction_template']:
|
||||
req_params['stopping_strings'].extend(['\n###'])
|
||||
elif instruct['user']: # WizardLM and some others have no user prompt.
|
||||
elif instruct['user']: # WizardLM and some others have no user prompt.
|
||||
req_params['stopping_strings'].extend(['\n' + instruct['user'], instruct['user']])
|
||||
|
||||
debug_msg(f"Loaded instruction role format: {shared.settings['instruction_template']}")
|
||||
@ -220,16 +220,16 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
||||
if max_tokens > 0 and token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"This model maximum context length is {req_params['truncation_length']} tokens. However, your messages resulted in over {token_count} tokens and max_tokens is {max_tokens}."
|
||||
print(f"Warning: ${err_msg}")
|
||||
#raise InvalidRequestError(message=err_msg)
|
||||
# raise InvalidRequestError(message=err_msg)
|
||||
|
||||
return prompt, token_count
|
||||
|
||||
|
||||
def chat_completions(body: dict, is_legacy: bool=False) -> dict:
|
||||
def chat_completions(body: dict, is_legacy: bool = False) -> dict:
|
||||
# Chat Completions
|
||||
object_type = 'chat.completions'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000))
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# common params
|
||||
@ -237,7 +237,7 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
|
||||
req_params['stream'] = False
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
|
||||
# chat default max_tokens is 'inf', but also flexible
|
||||
max_tokens = 0
|
||||
@ -254,7 +254,7 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
answer = ''
|
||||
@ -286,9 +286,9 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
|
||||
"total_tokens": token_count + completion_token_count
|
||||
}
|
||||
}
|
||||
if logprob_proc: # not official for chat yet
|
||||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
# else:
|
||||
# resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
@ -296,12 +296,12 @@ def chat_completions(body: dict, is_legacy: bool=False) -> dict:
|
||||
|
||||
|
||||
# generator
|
||||
def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||
def stream_chat_completions(body: dict, is_legacy: bool = False):
|
||||
|
||||
# Chat Completions
|
||||
stream_object_type = 'chat.completions.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time()*1000000000))
|
||||
cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# common params
|
||||
@ -309,7 +309,7 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||
req_params['stream'] = True
|
||||
requested_model = req_params.pop('requested_model')
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
req_params['top_k'] = 20 # There is no best_of/top_k param for chat, but it is much improved with a higher top_k.
|
||||
|
||||
# chat default max_tokens is 'inf', but also flexible
|
||||
max_tokens = 0
|
||||
@ -339,10 +339,10 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||
}],
|
||||
}
|
||||
|
||||
if logprob_proc: # not official for chat yet
|
||||
if logprob_proc: # not official for chat yet
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
#else:
|
||||
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
# else:
|
||||
# chunk[resp_list][0]["logprobs"] = None
|
||||
return chunk
|
||||
|
||||
@ -350,7 +350,7 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||
|
||||
# generate reply #######################################
|
||||
debug_msg({'prompt': prompt, 'req_params': req_params})
|
||||
|
||||
|
||||
stopping_strings = req_params.pop('stopping_strings', [])
|
||||
logprob_proc = req_params.pop('logprob_proc', None)
|
||||
|
||||
@ -377,9 +377,8 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
chunk = chat_streaming_chunk(new_content)
|
||||
|
||||
yield chunk
|
||||
|
||||
yield chunk
|
||||
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
@ -396,12 +395,12 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||
yield chunk
|
||||
|
||||
|
||||
def completions(body: dict, is_legacy: bool=False):
|
||||
def completions(body: dict, is_legacy: bool = False):
|
||||
# Legacy
|
||||
# Text Completions
|
||||
object_type = 'text_completion'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time()*1000000000))
|
||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
@ -433,7 +432,7 @@ def completions(body: dict, is_legacy: bool=False):
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
#print(f"Warning: ${err_msg}")
|
||||
# print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
@ -478,21 +477,21 @@ def completions(body: dict, is_legacy: bool=False):
|
||||
|
||||
if logprob_proc:
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
else:
|
||||
resp[resp_list][0]["logprobs"] = None
|
||||
resp[resp_list][0]["logprobs"] = None
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
# generator
|
||||
def stream_completions(body: dict, is_legacy: bool=False):
|
||||
def stream_completions(body: dict, is_legacy: bool = False):
|
||||
# Legacy
|
||||
# Text Completions
|
||||
#object_type = 'text_completion'
|
||||
# object_type = 'text_completion'
|
||||
stream_object_type = 'text_completion.chunk'
|
||||
created_time = int(time.time())
|
||||
cmpl_id = "conv-%d" % (int(time.time()*1000000000))
|
||||
cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
|
||||
resp_list = 'data' if is_legacy else 'choices'
|
||||
|
||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||
@ -524,7 +523,7 @@ def stream_completions(body: dict, is_legacy: bool=False):
|
||||
|
||||
if token_count + max_tokens > req_params['truncation_length']:
|
||||
err_msg = f"The token count of your prompt ({token_count}) plus max_tokens ({max_tokens}) cannot exceed the model's context length ({req_params['truncation_length']})."
|
||||
#print(f"Warning: ${err_msg}")
|
||||
# print(f"Warning: ${err_msg}")
|
||||
raise InvalidRequestError(message=err_msg, param=max_tokens_str)
|
||||
|
||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||
@ -545,9 +544,9 @@ def stream_completions(body: dict, is_legacy: bool=False):
|
||||
}
|
||||
if logprob_proc:
|
||||
top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
|
||||
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}
|
||||
else:
|
||||
chunk[resp_list][0]["logprobs"] = None
|
||||
chunk[resp_list][0]["logprobs"] = None
|
||||
|
||||
return chunk
|
||||
|
||||
@ -583,7 +582,6 @@ def stream_completions(body: dict, is_legacy: bool=False):
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
yield chunk
|
||||
|
||||
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
@ -3,10 +3,10 @@ import copy
|
||||
# Slightly different defaults for OpenAI's API
|
||||
# Data type is important, Ex. use 0.0 for a float 0
|
||||
default_req_params = {
|
||||
'max_new_tokens': 16, # 'Inf' for chat
|
||||
'max_new_tokens': 16, # 'Inf' for chat
|
||||
'temperature': 1.0,
|
||||
'top_p': 1.0,
|
||||
'top_k': 1, # choose 20 for chat in absence of another default
|
||||
'top_k': 1, # choose 20 for chat in absence of another default
|
||||
'repetition_penalty': 1.18,
|
||||
'repetition_penalty_range': 0,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
@ -15,7 +15,7 @@ default_req_params = {
|
||||
'echo': False,
|
||||
'seed': -1,
|
||||
# 'n' : default(body, 'n', 1), # 'n' doesn't have a direct map
|
||||
'truncation_length': 2048, # first use shared.settings value
|
||||
'truncation_length': 2048, # first use shared.settings value
|
||||
'add_bos_token': True,
|
||||
'do_sample': True,
|
||||
'typical_p': 1.0,
|
||||
@ -41,10 +41,13 @@ default_req_params = {
|
||||
# 'requested_model' - temporarily used
|
||||
}
|
||||
|
||||
|
||||
def get_default_req_params():
|
||||
return copy.deepcopy(default_req_params)
|
||||
|
||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||
|
||||
|
||||
def default(dic, key, default):
|
||||
val = dic.get(key, default)
|
||||
if type(val) != type(default):
|
||||
@ -59,6 +62,6 @@ def default(dic, key, default):
|
||||
val = default
|
||||
return val
|
||||
|
||||
|
||||
def clamp(value, minvalue, maxvalue):
|
||||
return max(minvalue, min(value, maxvalue))
|
||||
|
||||
|
@ -8,9 +8,9 @@ from extensions.openai.errors import *
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
|
||||
def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
||||
def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
|
||||
|
||||
created_time = int(time.time()*1000)
|
||||
created_time = int(time.time() * 1000)
|
||||
|
||||
# Request parameters
|
||||
req_params = get_default_req_params()
|
||||
@ -24,7 +24,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
||||
)
|
||||
|
||||
instruction_template = default_template
|
||||
|
||||
|
||||
# Use the special instruction/input/response template for anything trained like Alpaca
|
||||
if shared.settings['instruction_template']:
|
||||
if 'Alpaca' in shared.settings['instruction_template']:
|
||||
@ -41,7 +41,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
||||
|
||||
instruction_template = instruct.get('context', '') + template[:template.find('<|bot-message|>')].rstrip(' ')
|
||||
if instruct['user']:
|
||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user'] ])
|
||||
stopping_strings.extend(['\n' + instruct['user'], instruct['user']])
|
||||
|
||||
except Exception as e:
|
||||
instruction_template = default_template
|
||||
@ -54,14 +54,14 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
||||
edit_task = instruction_template.format(instruction=instruction, input=input)
|
||||
|
||||
truncation_length = shared.settings['truncation_length']
|
||||
|
||||
|
||||
token_count = len(encode(edit_task)[0])
|
||||
max_tokens = truncation_length - token_count
|
||||
|
||||
if max_tokens < 1:
|
||||
err_msg = f"This model maximum context length is {truncation_length} tokens. However, your messages resulted in over {truncation_length - max_tokens} tokens."
|
||||
raise InvalidRequestError(err_msg, param='input')
|
||||
|
||||
|
||||
req_params['max_new_tokens'] = max_tokens
|
||||
req_params['truncation_length'] = truncation_length
|
||||
req_params['temperature'] = temperature
|
||||
@ -71,7 +71,7 @@ def edits(instruction: str, input: str, temperature = 1.0, top_p = 1.0) -> dict:
|
||||
req_params['custom_stopping_strings'] = shared.settings['custom_stopping_strings']
|
||||
|
||||
debug_msg({'edit_template': edit_task, 'req_params': req_params, 'token_count': token_count})
|
||||
|
||||
|
||||
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||
|
||||
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
||||
|
@ -6,26 +6,30 @@ from extensions.openai.errors import *
|
||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||
embeddings_model = None
|
||||
|
||||
|
||||
def load_embedding_model(model):
|
||||
try:
|
||||
emb_model = SentenceTransformer(model)
|
||||
print(f"\nLoaded embedding model: {model}, max sequence length: {emb_model.max_seq_length}")
|
||||
except Exception as e:
|
||||
print(f"\nError: Failed to load embedding model: {model}")
|
||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message = repr(e))
|
||||
|
||||
raise ServiceUnavailableError(f"Error: Failed to load embedding model: {model}", internal_message=repr(e))
|
||||
|
||||
return emb_model
|
||||
|
||||
|
||||
def get_embeddings_model():
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||
return embeddings_model
|
||||
|
||||
|
||||
def get_embeddings_model_name():
|
||||
global st_model
|
||||
return st_model
|
||||
|
||||
|
||||
def embeddings(input: list, encoding_format: str):
|
||||
|
||||
embeddings = get_embeddings_model().encode(input).tolist()
|
||||
@ -47,4 +51,4 @@ def embeddings(input: list, encoding_format: str):
|
||||
|
||||
debug_msg(f"Embeddings return size: {len(embeddings[0])}, number: {len(embeddings)}")
|
||||
|
||||
return response
|
||||
return response
|
||||
|
@ -1,8 +1,9 @@
|
||||
class OpenAIError(Exception):
|
||||
def __init__(self, message = None, code = 500, internal_message = ''):
|
||||
def __init__(self, message=None, code=500, internal_message=''):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.internal_message = internal_message
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(message=%r, code=%d)" % (
|
||||
self.__class__.__name__,
|
||||
@ -10,10 +11,12 @@ class OpenAIError(Exception):
|
||||
self.code,
|
||||
)
|
||||
|
||||
|
||||
class InvalidRequestError(OpenAIError):
|
||||
def __init__(self, message, param, code = 400, error_type ='InvalidRequestError', internal_message = ''):
|
||||
def __init__(self, message, param, code=400, error_type='InvalidRequestError', internal_message=''):
|
||||
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||
self.param = param
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(message=%r, code=%d, param=%s)" % (
|
||||
self.__class__.__name__,
|
||||
@ -22,6 +25,7 @@ class InvalidRequestError(OpenAIError):
|
||||
self.param,
|
||||
)
|
||||
|
||||
|
||||
class ServiceUnavailableError(OpenAIError):
|
||||
def __init__(self, message = None, code = 500, error_type ='ServiceUnavailableError', internal_message = ''):
|
||||
def __init__(self, message=None, code=500, error_type='ServiceUnavailableError', internal_message=''):
|
||||
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||
|
@ -3,6 +3,7 @@ import time
|
||||
import requests
|
||||
from extensions.openai.errors import *
|
||||
|
||||
|
||||
def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
# Stable Diffusion callout wrapper for txt2img
|
||||
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
||||
@ -15,7 +16,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
# require changing the form data handling to accept multipart form data, also to properly support
|
||||
# url return types will require file management and a web serving files... Perhaps later!
|
||||
|
||||
width, height = [ int(x) for x in size.split('x') ] # ignore the restrictions on size
|
||||
width, height = [int(x) for x in size.split('x')] # ignore the restrictions on size
|
||||
|
||||
# to hack on better generation, edit default payload.
|
||||
payload = {
|
||||
@ -23,7 +24,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
'width': width,
|
||||
'height': height,
|
||||
'batch_size': n,
|
||||
'restore_faces': True, # slightly less horrible
|
||||
'restore_faces': True, # slightly less horrible
|
||||
}
|
||||
|
||||
resp = {
|
||||
@ -37,7 +38,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
response = requests.post(url=sd_url, json=payload)
|
||||
r = response.json()
|
||||
if response.status_code != 200 or 'images' not in r:
|
||||
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code = response.status_code)
|
||||
raise ServiceUnavailableError(r.get('detail', [{'msg': 'Unknown error calling Stable Diffusion'}])[0]['msg'], code=response.status_code)
|
||||
# r['parameters']...
|
||||
for b64_json in r['images']:
|
||||
if response_format == 'b64_json':
|
||||
@ -45,4 +46,4 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
else:
|
||||
resp['data'].extend([{'url': f'data:image/png;base64,{b64_json}'}]) # yeah it's lazy. requests.get() will not work with this
|
||||
|
||||
return resp
|
||||
return resp
|
||||
|
@ -7,15 +7,18 @@ from modules.models_settings import (get_model_settings_from_yamls,
|
||||
from extensions.openai.embeddings import get_embeddings_model_name
|
||||
from extensions.openai.errors import *
|
||||
|
||||
|
||||
def get_current_model_list() -> list:
|
||||
return [ shared.model_name ] # The real chat/completions model, maybe "None"
|
||||
return [shared.model_name] # The real chat/completions model, maybe "None"
|
||||
|
||||
|
||||
def get_pseudo_model_list() -> list:
|
||||
return [ # these are expected by so much, so include some here as a dummy
|
||||
return [ # these are expected by so much, so include some here as a dummy
|
||||
'gpt-3.5-turbo',
|
||||
'text-embedding-ada-002',
|
||||
]
|
||||
|
||||
|
||||
def load_model(model_name: str) -> dict:
|
||||
resp = {
|
||||
"id": model_name,
|
||||
@ -23,7 +26,7 @@ def load_model(model_name: str) -> dict:
|
||||
"owner": "self",
|
||||
"ready": True,
|
||||
}
|
||||
if model_name not in get_pseudo_model_list() + [ get_embeddings_model_name() ] + get_current_model_list(): # Real model only
|
||||
if model_name not in get_pseudo_model_list() + [get_embeddings_model_name()] + get_current_model_list(): # Real model only
|
||||
# No args. Maybe it works anyways!
|
||||
# TODO: hack some heuristics into args for better results
|
||||
|
||||
@ -39,7 +42,7 @@ def load_model(model_name: str) -> dict:
|
||||
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
|
||||
if not shared.model: # load failed.
|
||||
if not shared.model: # load failed.
|
||||
shared.model_name = "None"
|
||||
raise OpenAIError(f"Model load failed for: {shared.model_name}")
|
||||
|
||||
@ -48,16 +51,16 @@ def load_model(model_name: str) -> dict:
|
||||
|
||||
def list_models(is_legacy: bool = False) -> dict:
|
||||
# TODO: Lora's?
|
||||
all_model_list = get_current_model_list() + [ get_embeddings_model_name() ] + get_pseudo_model_list() + get_available_models()
|
||||
all_model_list = get_current_model_list() + [get_embeddings_model_name()] + get_pseudo_model_list() + get_available_models()
|
||||
|
||||
models = {}
|
||||
|
||||
if is_legacy:
|
||||
models = [{ "id": id, "object": "engine", "owner": "user", "ready": True } for id in all_model_list ]
|
||||
models = [{"id": id, "object": "engine", "owner": "user", "ready": True} for id in all_model_list]
|
||||
if not shared.model:
|
||||
models[0]['ready'] = False
|
||||
else:
|
||||
models = [{ "id": id, "object": "model", "owned_by": "user", "permission": [] } for id in all_model_list ]
|
||||
models = [{"id": id, "object": "model", "owned_by": "user", "permission": []} for id in all_model_list]
|
||||
|
||||
resp = {
|
||||
"object": "list",
|
||||
@ -74,4 +77,3 @@ def model_info(model_name: str) -> dict:
|
||||
"owned_by": "user",
|
||||
"permission": []
|
||||
}
|
||||
|
||||
|
@ -4,10 +4,10 @@ from numpy.linalg import norm
|
||||
from extensions.openai.embeddings import get_embeddings_model
|
||||
|
||||
|
||||
moderations_disabled = False # return 0/false
|
||||
moderations_disabled = False # return 0/false
|
||||
category_embeddings = None
|
||||
antonym_embeddings = None
|
||||
categories = [ "sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence" ]
|
||||
categories = ["sexual", "hate", "harassment", "self-harm", "sexual/minors", "hate/threatening", "violence/graphic", "self-harm/intent", "self-harm/instructions", "harassment/threatening", "violence"]
|
||||
flag_threshold = 0.5
|
||||
|
||||
|
||||
@ -40,23 +40,22 @@ def moderations(input):
|
||||
embeddings_model = get_embeddings_model()
|
||||
if not embeddings_model or moderations_disabled:
|
||||
results['results'] = [{
|
||||
'categories': dict([ (C, False) for C in categories]),
|
||||
'category_scores': dict([ (C, 0.0) for C in categories]),
|
||||
'categories': dict([(C, False) for C in categories]),
|
||||
'category_scores': dict([(C, 0.0) for C in categories]),
|
||||
'flagged': False,
|
||||
}]
|
||||
return results
|
||||
|
||||
category_embeddings = get_category_embeddings()
|
||||
|
||||
|
||||
# input, string or array
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
||||
for in_str in input:
|
||||
for ine in embeddings_model.encode([in_str]).tolist():
|
||||
category_scores = dict([ (C, mod_score(category_embeddings[C], ine)) for C in categories ])
|
||||
category_flags = dict([ (C, bool(category_scores[C] > flag_threshold)) for C in categories ])
|
||||
category_scores = dict([(C, mod_score(category_embeddings[C], ine)) for C in categories])
|
||||
category_flags = dict([(C, bool(category_scores[C] > flag_threshold)) for C in categories])
|
||||
flagged = any(category_flags.values())
|
||||
|
||||
results['results'].extend([{
|
||||
@ -67,4 +66,4 @@ def moderations(input):
|
||||
|
||||
print(results)
|
||||
|
||||
return results
|
||||
return results
|
||||
|
@ -22,6 +22,7 @@ params = {
|
||||
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def send_access_control_headers(self):
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
@ -72,8 +73,8 @@ class Handler(BaseHTTPRequestHandler):
|
||||
if not no_debug:
|
||||
debug_msg(r_utf8)
|
||||
|
||||
def openai_error(self, message, code = 500, error_type = 'APIError', param = '', internal_message = ''):
|
||||
|
||||
def openai_error(self, message, code=500, error_type='APIError', param='', internal_message=''):
|
||||
|
||||
error_resp = {
|
||||
'error': {
|
||||
'message': message,
|
||||
@ -84,10 +85,10 @@ class Handler(BaseHTTPRequestHandler):
|
||||
}
|
||||
if internal_message:
|
||||
print(internal_message)
|
||||
#error_resp['internal_message'] = internal_message
|
||||
# error_resp['internal_message'] = internal_message
|
||||
|
||||
self.return_json(error_resp, code)
|
||||
|
||||
|
||||
def openai_error_handler(func):
|
||||
def wrapper(self):
|
||||
try:
|
||||
@ -156,7 +157,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
response = OAIcompletions.stream_chat_completions(body, is_legacy=is_legacy)
|
||||
else:
|
||||
response = OAIcompletions.stream_completions(body, is_legacy=is_legacy)
|
||||
|
||||
|
||||
for resp in response:
|
||||
self.send_sse(resp)
|
||||
|
||||
@ -182,7 +183,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
|
||||
instruction = body['instruction']
|
||||
input = body.get('input', '')
|
||||
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||
temperature = clamp(default(body, 'temperature', req_params['temperature']), 0.001, 1.999) # fixup absolute 0.0
|
||||
top_p = clamp(default(body, 'top_p', req_params['top_p']), 0.001, 1.0)
|
||||
|
||||
response = OAIedits.edits(instruction, input, temperature, top_p)
|
||||
@ -205,7 +206,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
input = body.get('input', body.get('text', ''))
|
||||
if not input:
|
||||
raise InvalidRequestError("Missing required argument input", params='input')
|
||||
|
||||
|
||||
if type(input) is str:
|
||||
input = [input]
|
||||
|
||||
@ -225,15 +226,15 @@ class Handler(BaseHTTPRequestHandler):
|
||||
elif self.path == '/api/v1/token-count':
|
||||
# NOT STANDARD. lifted from the api extension, but it's still very useful to calculate tokenized length client side.
|
||||
response = token_count(body['prompt'])
|
||||
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token/encode':
|
||||
# NOT STANDARD. needed to support logit_bias, logprobs and token arrays for native models
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
|
||||
|
||||
response = token_encode(body['input'], encoding_format)
|
||||
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
elif self.path == '/api/v1/token/decode':
|
||||
@ -241,7 +242,7 @@ class Handler(BaseHTTPRequestHandler):
|
||||
encoding_format = body.get('encoding_format', '')
|
||||
|
||||
response = token_decode(body['input'], encoding_format)
|
||||
|
||||
|
||||
self.return_json(response, no_debug=True)
|
||||
|
||||
else:
|
||||
|
@ -1,6 +1,7 @@
|
||||
from extensions.openai.utils import float_list_to_base64
|
||||
from modules.text_generation import encode, decode
|
||||
|
||||
|
||||
def token_count(prompt):
|
||||
tokens = encode(prompt)[0]
|
||||
|
||||
@ -11,8 +12,8 @@ def token_count(prompt):
|
||||
}
|
||||
|
||||
|
||||
def token_encode(input, encoding_format = ''):
|
||||
#if isinstance(input, list):
|
||||
def token_encode(input, encoding_format=''):
|
||||
# if isinstance(input, list):
|
||||
tokens = encode(input)[0]
|
||||
|
||||
return {
|
||||
@ -25,9 +26,9 @@ def token_encode(input, encoding_format = ''):
|
||||
|
||||
|
||||
def token_decode(tokens, encoding_format):
|
||||
#if isinstance(input, list):
|
||||
# if encoding_format == "base64":
|
||||
# tokens = base64_to_float_list(tokens)
|
||||
# if isinstance(input, list):
|
||||
# if encoding_format == "base64":
|
||||
# tokens = base64_to_float_list(tokens)
|
||||
output = decode(tokens)[0]
|
||||
|
||||
return {
|
||||
|
@ -2,6 +2,7 @@ import os
|
||||
import base64
|
||||
import numpy as np
|
||||
|
||||
|
||||
def float_list_to_base64(float_list):
|
||||
# Convert the list to a float32 array that the OpenAPI client expects
|
||||
float_array = np.array(float_list, dtype="float32")
|
||||
@ -16,11 +17,13 @@ def float_list_to_base64(float_list):
|
||||
ascii_string = encoded_bytes.decode('ascii')
|
||||
return ascii_string
|
||||
|
||||
|
||||
def end_line(s):
|
||||
if s and s[-1] != '\n':
|
||||
s = s + '\n'
|
||||
return s
|
||||
|
||||
|
||||
def debug_msg(*args, **kwargs):
|
||||
if 'OPENEDAI_DEBUG' in os.environ:
|
||||
print(*args, **kwargs)
|
||||
print(*args, **kwargs)
|
||||
|
@ -126,6 +126,8 @@ def input_modifier(string):
|
||||
return string
|
||||
|
||||
# Get and save the Stable Diffusion-generated picture
|
||||
|
||||
|
||||
def get_SD_pictures(description, character):
|
||||
|
||||
global params
|
||||
@ -186,6 +188,8 @@ def get_SD_pictures(description, character):
|
||||
|
||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||
# and replace it with 'text' for the purposes of logging?
|
||||
|
||||
|
||||
def output_modifier(string, state):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
|
@ -113,7 +113,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
if len(history['internal']) > params['chunk_count'] and user_input != '':
|
||||
chunks = []
|
||||
hist_size = len(history['internal'])
|
||||
for i in range(hist_size-1):
|
||||
for i in range(hist_size - 1):
|
||||
chunks.append(make_single_exchange(i))
|
||||
|
||||
add_chunks_to_collector(chunks, chat_collector)
|
||||
|
@ -16,7 +16,7 @@ params = {
|
||||
}
|
||||
|
||||
|
||||
def do_stt(audio,whipser_model,whipser_language):
|
||||
def do_stt(audio, whipser_model, whipser_language):
|
||||
transcription = ""
|
||||
r = sr.Recognizer()
|
||||
|
||||
@ -33,10 +33,10 @@ def do_stt(audio,whipser_model,whipser_language):
|
||||
return transcription
|
||||
|
||||
|
||||
def auto_transcribe(audio, auto_submit,whipser_model,whipser_language):
|
||||
def auto_transcribe(audio, auto_submit, whipser_model, whipser_language):
|
||||
if audio is None:
|
||||
return "", ""
|
||||
transcription = do_stt(audio,whipser_model,whipser_language)
|
||||
transcription = do_stt(audio, whipser_model, whipser_language)
|
||||
if auto_submit:
|
||||
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
||||
|
||||
@ -50,11 +50,11 @@ def ui():
|
||||
with gr.Row():
|
||||
with gr.Accordion("Settings", open=False):
|
||||
auto_submit = gr.Checkbox(label='Submit the transcribed audio automatically', value=params['auto_submit'])
|
||||
whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'],choices=["tiny.en","base.en", "small.en","medium.en","tiny","base","small","medium","large"])
|
||||
whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'],choices=["chinese","german","spanish","russian","korean","french","japanese","portuguese","turkish","polish","catalan","dutch","arabic","swedish","italian","indonesian","hindi","finnish","vietnamese","hebrew","ukrainian","greek","malay","czech","romanian","danish","hungarian","tamil","norwegian","thai","urdu","croatian","bulgarian","lithuanian","latin","maori","malayalam","welsh","slovak","telugu","persian","latvian","bengali","serbian","azerbaijani","slovenian","kannada","estonian","macedonian","breton","basque","icelandic","armenian","nepali","mongolian","bosnian","kazakh","albanian","swahili","galician","marathi","punjabi","sinhala","khmer","shona","yoruba","somali","afrikaans","occitan","georgian","belarusian","tajik","sindhi","gujarati","amharic","yiddish","lao","uzbek","faroese","haitian creole","pashto","turkmen","nynorsk","maltese","sanskrit","luxembourgish","myanmar","tibetan","tagalog","malagasy","assamese","tatar","hawaiian","lingala","hausa","bashkir","javanese","sundanese"])
|
||||
whipser_model = gr.Dropdown(label='Whisper Model', value=params['whipser_model'], choices=["tiny.en", "base.en", "small.en", "medium.en", "tiny", "base", "small", "medium", "large"])
|
||||
whipser_language = gr.Dropdown(label='Whisper Language', value=params['whipser_language'], choices=["chinese", "german", "spanish", "russian", "korean", "french", "japanese", "portuguese", "turkish", "polish", "catalan", "dutch", "arabic", "swedish", "italian", "indonesian", "hindi", "finnish", "vietnamese", "hebrew", "ukrainian", "greek", "malay", "czech", "romanian", "danish", "hungarian", "tamil", "norwegian", "thai", "urdu", "croatian", "bulgarian", "lithuanian", "latin", "maori", "malayalam", "welsh", "slovak", "telugu", "persian", "latvian", "bengali", "serbian", "azerbaijani", "slovenian", "kannada", "estonian", "macedonian", "breton", "basque", "icelandic", "armenian", "nepali", "mongolian", "bosnian", "kazakh", "albanian", "swahili", "galician", "marathi", "punjabi", "sinhala", "khmer", "shona", "yoruba", "somali", "afrikaans", "occitan", "georgian", "belarusian", "tajik", "sindhi", "gujarati", "amharic", "yiddish", "lao", "uzbek", "faroese", "haitian creole", "pashto", "turkmen", "nynorsk", "maltese", "sanskrit", "luxembourgish", "myanmar", "tibetan", "tagalog", "malagasy", "assamese", "tatar", "hawaiian", "lingala", "hausa", "bashkir", "javanese", "sundanese"])
|
||||
|
||||
audio.change(
|
||||
auto_transcribe, [audio, auto_submit,whipser_model,whipser_language], [shared.gradio['textbox'], audio]).then(
|
||||
auto_transcribe, [audio, auto_submit, whipser_model, whipser_language], [shared.gradio['textbox'], audio]).then(
|
||||
None, auto_submit, None, _js="(check) => {if (check) { document.getElementById('Generate').click() }}")
|
||||
whipser_model.change(lambda x: params.update({"whipser_model": x}), whipser_model, None)
|
||||
whipser_language.change(lambda x: params.update({"whipser_language": x}), whipser_language, None)
|
||||
|
@ -66,7 +66,7 @@ def add_lora_autogptq(lora_names):
|
||||
logger.error("This version of AutoGPTQ does not support LoRA. You need to install from source or wait for a new release.")
|
||||
return
|
||||
|
||||
if len(lora_names) == 0:
|
||||
if len(lora_names) == 0:
|
||||
reload_model()
|
||||
|
||||
shared.lora_names = []
|
||||
@ -108,14 +108,14 @@ def add_lora_transformers(lora_names):
|
||||
# If any LoRA needs to be removed, start over
|
||||
if len(removed_set) > 0:
|
||||
# shared.model may no longer be PeftModel
|
||||
if hasattr(shared.model, 'disable_adapter'):
|
||||
shared.model.disable_adapter()
|
||||
if hasattr(shared.model, 'disable_adapter'):
|
||||
shared.model.disable_adapter()
|
||||
shared.model = shared.model.base_model.model
|
||||
|
||||
if len(lora_names) > 0:
|
||||
params = {}
|
||||
if not shared.args.cpu:
|
||||
if shared.args.load_in_4bit or shared.args.load_in_8bit:
|
||||
if shared.args.load_in_4bit or shared.args.load_in_8bit:
|
||||
params['peft_type'] = shared.model.dtype
|
||||
else:
|
||||
params['dtype'] = shared.model.dtype
|
||||
|
@ -54,14 +54,14 @@ loaders_and_params = {
|
||||
'trust_remote_code',
|
||||
'transformers_info'
|
||||
],
|
||||
'ExLlama' : [
|
||||
'ExLlama': [
|
||||
'gpu_split',
|
||||
'max_seq_len',
|
||||
'compress_pos_emb',
|
||||
'alpha_value',
|
||||
'exllama_info',
|
||||
],
|
||||
'ExLlama_HF' : [
|
||||
'ExLlama_HF': [
|
||||
'gpu_split',
|
||||
'max_seq_len',
|
||||
'compress_pos_emb',
|
||||
|
@ -106,11 +106,11 @@ def load_tokenizer(model_name, model):
|
||||
use_fast=False
|
||||
)
|
||||
except ValueError:
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
path_to_model,
|
||||
trust_remote_code=shared.args.trust_remote_code,
|
||||
use_fast=True
|
||||
)
|
||||
)
|
||||
|
||||
if tokenizer.__class__.__name__ == 'LlamaTokenizer':
|
||||
pairs = [
|
||||
|
@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
'''
|
||||
Copied from the transformers library
|
||||
'''
|
||||
|
||||
def __init__(self, penalty: float, _range: int):
|
||||
if not isinstance(penalty, float) or not (penalty > 0):
|
||||
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
||||
|
@ -181,7 +181,7 @@ parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authent
|
||||
# API
|
||||
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||
parser.add_argument('--api-blocking-port', type=int, default=5000, help='The listening port for the blocking API.')
|
||||
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
|
||||
parser.add_argument('--api-streaming-port', type=int, default=5005, help='The listening port for the streaming API.')
|
||||
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||
|
||||
# Multimodal
|
||||
|
@ -116,7 +116,7 @@ def get_available_loras():
|
||||
def get_datasets(path: str, ext: str):
|
||||
# include subdirectories for raw txt files to allow training from a subdirectory of txt files
|
||||
if ext == "txt":
|
||||
return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt'))+list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
||||
return ['None'] + sorted(set([k.stem for k in list(Path(path).glob('txt')) + list(Path(path).glob('*/')) if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
||||
|
||||
return ['None'] + sorted(set([k.stem for k in Path(path).glob(f'*.{ext}') if k.stem != 'put-trainer-datasets-here']), key=natural_keys)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user