Merge remote-tracking branch 'refs/remotes/origin/dev' into dev

This commit is contained in:
oobabooga 2024-09-02 21:16:39 -07:00
commit 68d52c60f3
3 changed files with 9 additions and 3 deletions

View File

@@ -241,7 +241,7 @@ def ui():
stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.') stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
with gr.Column(): with gr.Column():
max_length = gr.Slider(label='max_length', minimum=0, maximum=shared.settings['truncation_length_max'], value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.') max_length = gr.Number(label='max_length', precision=0, step=256, value=0, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
with gr.Row(): with gr.Row():
start_current_evaluation = gr.Button("Evaluate loaded model") start_current_evaluation = gr.Button("Evaluate loaded model")

View File

@@ -154,8 +154,9 @@ def convert_history(history):
elif item['type'] == 'text' and isinstance(item['text'], str): elif item['type'] == 'text' and isinstance(item['text'], str):
content = item['text'] content = item['text']
if image_url and content: if image_url:
new_history.append({"image_url": image_url, "role": "user"}) new_history.append({"image_url": image_url, "role": "user"})
if content:
new_history.append({"content": content, "role": "user"}) new_history.append({"content": content, "role": "user"})
else: else:
new_history.append(entry) new_history.append(entry)

View File

@@ -274,7 +274,12 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '): if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):
first_token = shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from])) first_token = shared.tokenizer.convert_ids_to_tokens(int(output_ids[starting_from]))
if isinstance(first_token, (bytes,)): if isinstance(first_token, (bytes,)):
first_token = first_token.decode('utf8') #try to decode the bytes to a string
try:
first_token = first_token.decode('utf8')
#if it fails, which means it's not a string in this turn, just ignore it
except UnicodeDecodeError:
first_token = ''
if first_token.startswith('▁'): if first_token.startswith('▁'):
reply = ' ' + reply reply = ' ' + reply