Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-23 16:38:21 +01:00)

Commit b28020a9e4
@@ -22,17 +22,12 @@
 .message-body p, .message-body li {
   font-size: 15px !important;
   line-height: 24px !important;
-  list-style-position: outside;
 }
 
 .message-body p, .chat .message-body ul, .chat .message-body ol {
   margin-bottom: 16px !important;
 }
 
-.chat .message-body ul, .chat .message-body ol {
-  padding-inline-start: 2em;
-}
-
 .message-body p:last-child, .chat .message-body ul:last-child, .chat .message-body ol:last-child {
   margin-bottom: 0 !important;
 }
@@ -364,6 +364,14 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
   padding-bottom: 0 !important;
 }
 
+.message-body li {
+  list-style-position: outside;
+}
+
+.chat .message-body ul, .chat .message-body ol {
+  padding-inline-start: 2em;
+}
+
 .message-body li:not(:last-child) {
   margin-top: 0 !important;
   margin-bottom: 2px !important;
@@ -51,59 +51,9 @@ from modules.logging_colors import logger
 from modules.models import reload_model
 from modules.utils import natural_keys
 
-import warnings
-warnings.filterwarnings(action = "ignore", message="torch.utils.checkpoint:")
-## just temporary to avoid warning
-warnings.filterwarnings(action = "ignore", message="`do_sample` is set to `False`")
-
-import inspect
-
-from typing import Callable, Optional, Tuple, ContextManager
-
-
-if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
-    def my_checkpoint(
-        function,
-        *args,
-        use_reentrant: Optional[bool] = None,
-        context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = torch.utils.checkpoint.noop_context_fn,
-        determinism_check: str = torch.utils.checkpoint._DEFAULT_DETERMINISM_MODE,
-        debug: bool = False,
-        **kwargs
-    ):
-
-        if use_reentrant is None:
-            #print ("reentran = NONE")
-            use_reentrant = True
-        # Hack to mix *args with **kwargs in a python 2.7-compliant way
-        preserve = kwargs.pop("preserve_rng_state", True)
-        if kwargs and use_reentrant:
-            raise ValueError(
-                "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
-            )
-
-        if use_reentrant:
-            if context_fn is not torch.utils.checkpoint.noop_context_fn or debug is not False:
-                raise ValueError(
-                    "Passing `context_fn` or `debug` is only supported when "
-                    "use_reentrant=False."
-                )
-            return torch.utils.checkpoint.CheckpointFunction.apply(function, preserve, *args)
-        else:
-            print ("reentran = FALSE")
-            gen = torch.utils.checkpoint._checkpoint_without_reentrant_generator(
-                function, preserve, context_fn, determinism_check, debug, *args, **kwargs
-            )
-            # Runs pre-forward logic
-            next(gen)
-            ret = function(*args, **kwargs)
-            # Runs post-forward logic
-            try:
-                next(gen)
-            except StopIteration:
-                return ret
-
-
 params = {
     "display_name": "Training PRO",
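
The block deleted above monkey-patched torch.utils.checkpoint.checkpoint so that an unset use_reentrant no longer triggered PyTorch's deprecation warning. The patch becomes unnecessary once callers pass the flag explicitly; a minimal sketch of that idiom (the hypothetical block() stands in for a checkpointed layer):

    import torch
    import torch.utils.checkpoint as cp

    def block(x):
        # stand-in for a transformer layer
        return torch.relu(x) @ x.T

    x = torch.randn(4, 4, requires_grad=True)
    # passing use_reentrant explicitly is what silences the
    # "use_reentrant not explicitly set" warning
    y = cp.checkpoint(block, x, use_reentrant=False)
    y.sum().backward()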
@@ -121,6 +71,7 @@ non_serialized_params = {
     "save_epochs": 0,
     "checkpoint_offset": 0,
     "epoch_offset":0,
+    "safe_serialization": False,
 }
 
 MODEL_CLASSES = {v[1]: v[0] for v in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.items()}
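
The new safe_serialization entry is threaded into every lora_model.save_pretrained(...) call in the hunks below. A short sketch of what the flag selects (assuming a loaded PEFT model named lora_model; the paths are illustrative):

    # False -> adapter_model.bin (pickle format)
    lora_model.save_pretrained("loras/my-lora", safe_serialization=False)
    # True -> adapter_model.safetensors
    lora_model.save_pretrained("loras/my-lora-st", safe_serialization=True)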
@@ -150,7 +101,7 @@ def ui():
     with gr.Row():
         with gr.Column():
             # YY.MM.DD
-            gr.Markdown("`Ver: 23.10.20` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
+            gr.Markdown("`Ver: 23.10.20 (REV2)` This is enhanced version of QLora Training. [Maintained by FP](https://github.com/FartyPants/Training_PRO/tree/main)")
 
     with gr.Row():
         with gr.Column(scale=5):
@@ -290,7 +241,7 @@ def ui():
                 stride_length = gr.Slider(label='Stride', minimum=1, maximum=2048, value=512, step=1, info='Used to make the evaluation faster at the cost of accuracy. 1 = slowest but most accurate. 512 is a common value.')
 
             with gr.Column():
-                max_length = gr.Slider(label='max_length', minimum=0, maximum=8096, value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
+                max_length = gr.Slider(label='max_length', minimum=0, maximum=shared.settings['truncation_length_max'], value=0, step=1, info='The context for each evaluation. If set to 0, the maximum context length for the model will be used.')
 
         with gr.Row():
             start_current_evaluation = gr.Button("Evaluate loaded model")
@@ -713,7 +664,6 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     train_template.clear()
 
-
     #reset stuff
     print(f"*** LoRA: {lora_name} ***")
     non_serialized_params.update({"stop_at_loss": stop_at_loss})
@@ -726,24 +676,6 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     non_serialized_params.update({"epoch_offset": 0})
     train_log_graph.clear()
 
-    # === once fixed, this can be removed ==============================
-    if hasattr(torch.utils.checkpoint, 'noop_context_fn'):
-        print("Testing Pytorch...")
-        old_checkpoint_signature = inspect.signature(torch.utils.checkpoint.checkpoint)
-
-        # Get the signature of your new checkpoint function
-        my_checkpoint_signature = inspect.signature(my_checkpoint)
-
-        # Check if the signatures match
-        if old_checkpoint_signature.parameters == my_checkpoint_signature.parameters:
-            print(F"{RED}Overriding Torch checkpoint function to avoid repeated 'use_reentrant not explicitly set' warnings{RESET}")
-            #print(" - Note: Transformers need to pass use_reentrant in llama.modeling_llama in def forward, layer_outputs = torch.utils.checkpoint.checkpoint")
-            #print("   Once they do, this function can be removed")
-            torch.utils.checkpoint.checkpoint = my_checkpoint
-
-
-    # END OF FPHAM SENTENCE SPLIT functions ===================
-
     # == Prep the dataset, format, etc ==
     if raw_text_file not in ['None', '']:
         train_template["template_type"] = "raw_text"
@@ -1025,7 +957,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
                     force_save = True
 
                 if force_save:
-                    lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/")
+                    lora_model.save_pretrained(f"{lora_file_path}/{folder_save}/", safe_serialization = non_serialized_params['safe_serialization'])
                     print(f"\033[1;30;40mStep: {tracked.current_steps:6} \033[0;37;0m Saved: [{folder_save}]")
                     # Save log
                     with open(f"{lora_file_path}/{folder_save}/training_log.json", 'w', encoding='utf-8') as file:
@@ -1252,7 +1184,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
         log_train_dataset(trainer)
         trainer.train()
         # Note: save in the thread in case the gradio thread breaks (eg browser closed)
-        lora_model.save_pretrained(lora_file_path)
+        lora_model.save_pretrained(lora_file_path, safe_serialization = non_serialized_params['safe_serialization'])
         logger.info("LoRA training run is completed and saved.")
         # Save log
         with open(f"{lora_file_path}/training_log.json", 'w', encoding='utf-8') as file:
@@ -1353,7 +1285,7 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
 
     if not tracked.did_save:
         logger.info("Training complete, saving...")
-        lora_model.save_pretrained(lora_file_path)
+        lora_model.save_pretrained(lora_file_path, safe_serialization = non_serialized_params['safe_serialization'])
 
     if WANT_INTERRUPT:
         logger.info("Training interrupted.")
@@ -1,25 +0,0 @@
-instruction_template: |-
-  {%- set found_item = false -%}
-  {%- for message in messages -%}
-      {%- if message['role'] == 'system' -%}
-          {%- set found_item = true -%}
-      {%- endif -%}
-  {%- endfor -%}
-  {%- if not found_item -%}
-      {{- '' + 'A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human\'s questions.' + '\n\n' -}}
-  {%- endif %}
-  {%- for message in messages %}
-      {%- if message['role'] == 'system' -%}
-          {{- '' + message['content'] + '\n\n' -}}
-      {%- else -%}
-          {%- if message['role'] == 'user' -%}
-              {{-'### Human: ' + message['content'] + '\n'-}}
-          {%- else -%}
-              {{-'### Assistant: ' + message['content'] + '\n' -}}
-          {%- endif -%}
-      {%- endif -%}
-  {%- endfor -%}
-  {%- if add_generation_prompt -%}
-      {{-'### Assistant:'-}}
-  {%- endif -%}
-
@@ -1,25 +0,0 @@
-instruction_template: |-
-  {%- set found_item = false -%}
-  {%- for message in messages -%}
-      {%- if message['role'] == 'system' -%}
-          {%- set found_item = true -%}
-      {%- endif -%}
-  {%- endfor -%}
-  {%- if not found_item -%}
-      {{- '' + '' + '' -}}
-  {%- endif %}
-  {%- for message in messages %}
-      {%- if message['role'] == 'system' -%}
-          {{- '' + message['content'] + '' -}}
-      {%- else -%}
-          {%- if message['role'] == 'user' -%}
-              {{-'<human>: ' + message['content'] + '\n'-}}
-          {%- else -%}
-              {{-'<bot>:' + message['content'] + '\n' -}}
-          {%- endif -%}
-      {%- endif -%}
-  {%- endfor -%}
-  {%- if add_generation_prompt -%}
-      {{-'<bot>:'-}}
-  {%- endif -%}
-
@@ -1,25 +0,0 @@
-instruction_template: |-
-  {%- set found_item = false -%}
-  {%- for message in messages -%}
-      {%- if message['role'] == 'system' -%}
-          {%- set found_item = true -%}
-      {%- endif -%}
-  {%- endfor -%}
-  {%- if not found_item -%}
-      {{- '' + 'A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user\'s questions.' + '\n\n' -}}
-  {%- endif %}
-  {%- for message in messages %}
-      {%- if message['role'] == 'system' -%}
-          {{- '' + message['content'] + '\n\n' -}}
-      {%- else -%}
-          {%- if message['role'] == 'user' -%}
-              {{-'USER: ' + message['content'] + '\n'-}}
-          {%- else -%}
-              {{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}
-          {%- endif -%}
-      {%- endif -%}
-  {%- endfor -%}
-  {%- if add_generation_prompt -%}
-      {{-'ASSISTANT:'-}}
-  {%- endif -%}
-
@@ -6,20 +6,19 @@ instruction_template: |-
       {%- endif -%}
   {%- endfor -%}
   {%- if not found_item -%}
-      {{- '' + '' + '' -}}
+      {{-'SYSTEM: ' + '' + '\n' -}}
   {%- endif %}
   {%- for message in messages %}
       {%- if message['role'] == 'system' -%}
-          {{- '' + message['content'] + '' -}}
+          {{-'SYSTEM: ' + message['content'] + '\n' -}}
       {%- else -%}
           {%- if message['role'] == 'user' -%}
               {{-'USER: ' + message['content'] + '\n'-}}
           {%- else -%}
-              {{-'ASSISTANT:' + message['content'] + '\n' -}}
+              {{-'ASSISTANT: ' + message['content'] + '</s>\n' -}}
           {%- endif -%}
       {%- endif -%}
   {%- endfor -%}
   {%- if add_generation_prompt -%}
       {{-'ASSISTANT:'-}}
   {%- endif -%}
-
@@ -13,9 +13,9 @@ instruction_template: |-
           {{- '' + message['content'] + '\n' -}}
       {%- else -%}
           {%- if message['role'] == 'user' -%}
-              {{-'\n<|USER|>: ' + message['content'] + '\n'-}}
+              {{-'<|USER|>: ' + message['content'] + '\n'-}}
           {%- else -%}
-              {{-'<|ASSISTANT|>: ' + message['content'] + '' -}}
+              {{-'<|ASSISTANT|>: ' + message['content'] + '\n' -}}
           {%- endif -%}
       {%- endif -%}
   {%- endfor -%}
@@ -1,25 +0,0 @@
-instruction_template: |-
-  {%- set found_item = false -%}
-  {%- for message in messages -%}
-      {%- if message['role'] == 'system' -%}
-          {%- set found_item = true -%}
-      {%- endif -%}
-  {%- endfor -%}
-  {%- if not found_item -%}
-      {{- '' + 'Below is an instruction that describes a task. Write a response that appropriately completes the request.' + '\n\n' -}}
-  {%- endif %}
-  {%- for message in messages %}
-      {%- if message['role'] == 'system' -%}
-          {{- '' + message['content'] + '\n\n' -}}
-      {%- else -%}
-          {%- if message['role'] == 'user' -%}
-              {{-'### Instruction:\n' + message['content'] + '\n\n'-}}
-          {%- else -%}
-              {{-'### Response:\n' + message['content'] + '\n\n' -}}
-          {%- endif -%}
-      {%- endif -%}
-  {%- endfor -%}
-  {%- if add_generation_prompt -%}
-      {{-'### Response:\n'-}}
-  {%- endif -%}
-
@@ -38,7 +38,7 @@
   instruction_template: 'LLaVA'
   custom_stopping_strings: '"\n###"'
 .*llava.*1.5:
-  instruction_template: 'LLaVA-v1'
+  instruction_template: 'Vicuna-v1.1'
 .*wizard.*mega:
   instruction_template: 'Wizard-Mega'
   custom_stopping_strings: '"</s>"'
@@ -108,7 +108,7 @@
 .*bactrian:
   instruction_template: 'Bactrian'
 .*(h2ogpt-oig-|h2ogpt-oasst1-|h2ogpt-research-oasst1-):
-  instruction_template: 'H2O-human_bot'
+  instruction_template: 'INCITE-Chat'
 .*h2ogpt-gm-:
   instruction_template: 'H2O-prompt_answer'
 .*manticore:
@@ -128,7 +128,7 @@
 .*lazarus:
   instruction_template: 'Alpaca'
 .*guanaco-.*(7|13|33|65)b:
-  instruction_template: 'Guanaco'
+  instruction_template: 'Vicuna-v0'
 .*hypermantis:
   instruction_template: 'Alpaca'
 .*open-llama-.*-open-instruct:
@@ -144,7 +144,7 @@
 .*wizardcoder:
   instruction_template: 'Alpaca'
 .*minotaur:
-  instruction_template: 'Minotaur'
+  instruction_template: 'Manticore Chat'
 .*orca_mini:
   instruction_template: 'Orca Mini'
 .*(platypus|gplatty|superplatty):
@@ -186,3 +186,5 @@
   instruction_template: 'ChatML'
 .*Yi-34B-Chat:
   instruction_template: 'ChatML'
+(dolphin).*:
+  instruction_template: 'ChatML'
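
The keys in these hunks are regular expressions matched against the model name to pick an instruction template; roughly (a simplified sketch, not the repository's exact matching code):

    import re

    rules = {
        r'.*llava.*1.5': 'Vicuna-v1.1',
        r'(dolphin).*': 'ChatML',
    }

    model = 'dolphin-2.2.1-mistral-7B'
    for pattern, template in rules.items():
        if re.match(pattern.lower(), model.lower()):
            print(template)  # -> ChatML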
@@ -112,6 +112,13 @@ def generate_chat_prompt(user_input, state, **kwargs):
     if user_input and not impersonate and not _continue:
         messages.append({"role": "user", "content": user_input})
 
+    def remove_extra_bos(prompt):
+        for bos_token in ['<s>', '<|startoftext|>']:
+            while prompt.startswith(bos_token):
+                prompt = prompt[len(bos_token):]
+
+        return prompt
+
     def make_prompt(messages):
         if state['mode'] == 'chat-instruct' and _continue:
             prompt = renderer(messages=messages[:-1])
@@ -123,6 +130,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
             if state['custom_system_message'].strip() != '':
                 outer_messages.append({"role": "system", "content": state['custom_system_message']})
 
+            prompt = remove_extra_bos(prompt)
             command = state['chat-instruct_command']
             command = command.replace('<|character|>', state['name2'] if not impersonate else state['name1'])
             command = command.replace('<|prompt|>', prompt)
@@ -153,6 +161,7 @@ def generate_chat_prompt(user_input, state, **kwargs):
 
             prompt += prefix
 
+        prompt = remove_extra_bos(prompt)
         return prompt
 
     prompt = make_prompt(messages)
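
Templates that hard-code a BOS token would otherwise stack it on top of the one the tokenizer inserts; the new helper strips such duplicates from the front of the rendered prompt. Repeating it standalone with a usage example:

    def remove_extra_bos(prompt):
        for bos_token in ['<s>', '<|startoftext|>']:
            while prompt.startswith(bos_token):
                prompt = prompt[len(bos_token):]
        return prompt

    print(remove_extra_bos('<s><s>USER: hi\n'))  # -> 'USER: hi\n'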
@@ -82,8 +82,9 @@ def load_metadata(fname):
         if value_type == GGUFValueType.ARRAY:
             ltype = GGUFValueType(struct.unpack("<I", file.read(4))[0])
             length = struct.unpack("<Q", file.read(8))[0]
-            for j in range(length):
-                _ = get_single(ltype, file)
+            arr = [get_single(ltype, file) for _ in range(length)]
+            metadata[key.decode()] = arr
         else:
             value = get_single(value_type, file)
             metadata[key.decode()] = value
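
Previously the array branch read each element and threw it away, so array-valued keys such as tokenizer.ggml.tokens never reached the metadata dict; the fix stores them, which the chat-template hunk below relies on. The GGUF array layout being parsed is a 4-byte element type, an 8-byte element count, then that many packed elements:

    import struct

    # sketch of the array wire format (get_single and GGUFValueType as in the module)
    def read_gguf_array(file, get_single, GGUFValueType):
        ltype = GGUFValueType(struct.unpack("<I", file.read(4))[0])  # element type, uint32
        length = struct.unpack("<Q", file.read(8))[0]                # element count, uint64
        return [get_single(ltype, file) for _ in range(length)]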
@@ -64,6 +64,16 @@ def get_model_metadata(model):
             model_settings['compress_pos_emb'] = metadata['llama.rope.scale_linear']
         if 'llama.rope.freq_base' in metadata:
             model_settings['rope_freq_base'] = metadata['llama.rope.freq_base']
+        if 'tokenizer.chat_template' in metadata:
+            template = metadata['tokenizer.chat_template']
+            eos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.eos_token_id']]
+            bos_token = metadata['tokenizer.ggml.tokens'][metadata['tokenizer.ggml.bos_token_id']]
+            template = template.replace('eos_token', "'{}'".format(eos_token))
+            template = template.replace('bos_token', "'{}'".format(bos_token))
+
+            template = re.sub(r'raise_exception\([^)]*\)', "''", template)
+            model_settings['instruction_template'] = 'Custom (obtained from model metadata)'
+            model_settings['instruction_template_str'] = template
 
     else:
         # Transformers metadata
@@ -114,7 +124,6 @@ def get_model_metadata(model):
             template = template.replace(k, "'{}'".format(value))
 
-
         template = re.sub(r'raise_exception\([^)]*\)', "''", template)
 
         model_settings['instruction_template'] = 'Custom (obtained from model metadata)'
         model_settings['instruction_template_str'] = template
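
The string stored in instruction_template_str is a Jinja2 chat template with the bos_token/eos_token variables already replaced by literals. A minimal sketch of how such a template gets rendered (assuming the jinja2 package; the template string here is a toy example, not one from the repository):

    from jinja2 import Environment

    template_str = (
        "{%- for message in messages -%}"
        "{{ message['role'] }}: {{ message['content'] }}\n"
        "{%- endfor -%}"
    )
    renderer = Environment().from_string(template_str)
    print(renderer.render(messages=[{'role': 'user', 'content': 'hi'}]))  # -> user: hi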
@@ -107,7 +107,7 @@ def create_chat_settings_ui():
     with gr.Row():
         with gr.Column():
             with gr.Row():
-                shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', value='Custom', info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes='slim-dropdown')
+                shared.gradio['instruction_template'] = gr.Dropdown(choices=utils.get_available_instruction_templates(), label='Saved instruction templates', value='Select template to load...', elem_classes='slim-dropdown')
                 ui.create_refresh_button(shared.gradio['instruction_template'], lambda: None, lambda: {'choices': utils.get_available_instruction_templates()}, 'refresh-button', interactive=not mu)
                 shared.gradio['load_template'] = gr.Button("Load", elem_classes='refresh-button')
                 shared.gradio['save_template'] = gr.Button('💾', elem_classes='refresh-button', interactive=not mu)
@@ -119,7 +119,7 @@ def create_chat_settings_ui():
     with gr.Row():
         with gr.Column():
             shared.gradio['custom_system_message'] = gr.Textbox(value=shared.settings['custom_system_message'], lines=2, label='Custom system message', info='If not empty, will be used instead of the default one.', elem_classes=['add_scrollbar'])
-            shared.gradio['instruction_template_str'] = gr.Textbox(value='', label='Instruction template', lines=24, elem_classes=['add_scrollbar', 'monospace'])
+            shared.gradio['instruction_template_str'] = gr.Textbox(value='', label='Instruction template', lines=24, info='Change this according to the model/LoRA that you are using. Used in instruct and chat-instruct modes.', elem_classes=['add_scrollbar', 'monospace'])
         with gr.Row():
             shared.gradio['send_instruction_to_default'] = gr.Button('Send to default', elem_classes=['small-button'])
             shared.gradio['send_instruction_to_notebook'] = gr.Button('Send to notebook', elem_classes=['small-button'])
@@ -299,7 +299,10 @@ def create_event_handlers():
 
     shared.gradio['delete_character'].click(lambda: gr.update(visible=True), None, gradio('character_deleter'))
 
-    shared.gradio['load_template'].click(chat.load_instruction_template, gradio('instruction_template'), gradio('instruction_template_str'))
+    shared.gradio['load_template'].click(
+        chat.load_instruction_template, gradio('instruction_template'), gradio('instruction_template_str')).then(
+        lambda: "Select template to load...", None, gradio('instruction_template'))
 
     shared.gradio['save_template'].click(
         lambda: 'My Template.yaml', None, gradio('save_filename')).then(
         lambda: 'instruction-templates/', None, gradio('save_root')).then(
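
The rewritten handler uses gradio's event chaining: .click(...) returns a dependency whose .then(...) runs a follow-up callback once the first finishes, here resetting the dropdown to its placeholder. A self-contained sketch of the pattern (hypothetical components, not the repository's UI):

    import gradio as gr

    with gr.Blocks() as demo:
        dd = gr.Dropdown(choices=['Pick one...', 'A', 'B'], value='Pick one...')
        out = gr.Textbox()
        btn = gr.Button('Load')
        # run the load, then reset the dropdown back to the placeholder
        btn.click(lambda x: f'loaded {x}', dd, out).then(
            lambda: gr.update(value='Pick one...'), None, dd)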
@@ -105,7 +105,7 @@ def get_available_instruction_templates():
     if os.path.exists(path):
         paths = (x for x in Path(path).iterdir() if x.suffix in ('.json', '.yaml', '.yml'))
 
-    return ['Custom'] + sorted(set((k.stem for k in paths)), key=natural_keys)
+    return ['Select template to load...'] + sorted(set((k.stem for k in paths)), key=natural_keys)
 
 
 def get_available_extensions():