Clean things up

oobabooga 2023-01-16 16:35:45 -03:00
parent 3a99b2b030
commit 6456777b09
5 changed files with 16 additions and 23 deletions

@@ -133,7 +133,7 @@ Optionally, you can use the following command-line flags:
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--max-gpu-memory MAX_GPU_MEMORY` | Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. |
| `--no-listen` | Make the web UI unreachable from your local network.|
| `--settings-file SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
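
A settings file passed with `--settings` is plain JSON whose keys override the interface defaults (see the settings-override loop further down in this diff). A minimal sketch of writing one, using hypothetical values for two of the documented keys:

```python
# Minimal sketch (hypothetical values): create a JSON file that could be passed
# to the web UI via --settings my-settings.json. Only keys that already exist
# in the default settings dictionary are picked up.
import json

my_settings = {"max_new_tokens": 200, "stop_at_newline": True}
with open("my-settings.json", "w") as f:
    json.dump(my_settings, f, indent=4)
```
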
## Presets

@@ -17,7 +17,5 @@ model_name = path.name
print(f"Loading {model_name}...")
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
print("Model loaded.")
print(f"Saving to torch-dumps/{model_name}.pt")
print(f"Model loaded.\nSaving to torch-dumps/{model_name}.pt")
torch.save(model, Path(f"torch-dumps/{model_name}.pt"))
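
Since torch.save is given the full model object here, the dump is a pickle of the whole module. A hedged sketch of loading it back (the model name is hypothetical, and the defining classes must be importable when unpickling):

```python
# Hedged sketch: restore a model that was pickled to torch-dumps/ by the script above.
# On newer PyTorch releases, torch.load may additionally need weights_only=False.
import torch
from pathlib import Path

model = torch.load(Path("torch-dumps/opt-1.3b.pt"))  # "opt-1.3b" is a made-up name
model.eval()
```
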

@@ -28,7 +28,6 @@ def get_file(args):
t.close()
if __name__ == '__main__':
model = argv[1]
if model[-1] == '/':
model = model[:-1]
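
For illustration, with a hypothetical command-line argument the trailing-slash handling above behaves like this:

```python
# Illustration with a made-up value: a single trailing slash on the model
# identifier is dropped before the string is used any further.
model = "facebook/opt-1.3b/"
if model[-1] == '/':
    model = model[:-1]
print(model)  # facebook/opt-1.3b
```
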

@@ -1,6 +1,6 @@
'''
This is a library for formatting gpt4chan outputs as nice HTML.
This is a library for formatting GPT-4chan and chat outputs as nice HTML.
'''
@@ -267,7 +267,5 @@ def generate_chat_html(history, name1, name2):
</div>
"""
output += '<script>document.getElementById("chat").scrollTo(0, document.getElementById("chat").scrollHeight);</script>'
output += "</div>"
return output
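
A hedged usage sketch of generate_chat_html, assuming the formatter is importable as html_generator and that history holds [user, bot] message pairs, as the chat loop at the end of this diff suggests:

```python
# Hedged usage sketch: the module name, speaker names, and messages are assumptions.
from html_generator import generate_chat_html

history = [["Hello!", "Hi there, how can I help?"],
           ["What is this?", "A chat log rendered as HTML."]]
html = generate_chat_html(history, "You", "Bot")
with open("chat.html", "w") as f:
    f.write(html)
```
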

@@ -25,14 +25,12 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.')
parser.add_argument('--settings-file', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
args = parser.parse_args()
loaded_preset = None
available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
available_models = [item for item in available_models if not item.endswith('.txt')]
available_models = sorted(available_models, key=str.lower)
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt')))))
available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], Path('presets').glob('*.txt'))), key=str.lower)
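
The consolidated one-liner for available_models packs several steps together; an equivalent, more verbose restatement (a reading aid, not the author's code):

```python
# Restatement of the available_models one-liner: gather names from models/ and
# torch-dumps/, drop .txt entries, strip the .pt suffix, de-duplicate, and sort
# case-insensitively.
from pathlib import Path

names = [p.name for p in list(Path('models/').glob('*')) + list(Path('torch-dumps/').glob('*'))]
available_models = sorted({n.replace('.pt', '') for n in names if not n.endswith('.txt')},
                          key=str.lower)
```
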
settings = {
'max_new_tokens': 200,
@@ -50,12 +48,12 @@ settings = {
'stop_at_newline': True,
}
if args.settings_file is not None and Path(args.settings_file).exists():
with open(Path(args.settings_file), 'r') as f:
if args.settings is not None and Path(args.settings).exists():
with open(Path(args.settings), 'r') as f:
new_settings = json.load(f)
for i in new_settings:
if i in settings:
settings[i] = new_settings[i]
for item in new_settings:
if item in settings:
settings[item] = new_settings[item]
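
A small self-contained sketch of the override behaviour above, with hypothetical values: keys already present in the defaults are replaced, anything else in the JSON file is silently ignored:

```python
# Hedged sketch with made-up values mirroring the loop above.
defaults = {'max_new_tokens': 200, 'stop_at_newline': True}
new_settings = {'max_new_tokens': 400, 'unknown_key': 123}  # e.g. parsed from the --settings file

for item in new_settings:
    if item in defaults:
        defaults[item] = new_settings[item]

print(defaults)  # {'max_new_tokens': 400, 'stop_at_newline': True}
```
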
def load_model(model_name):
print(f"Loading {model_name}...")
@@ -87,7 +85,7 @@ def load_model(model_name):
else:
settings.append("torch_dtype=torch.float16")
settings = ', '.join(list(set(settings)))
settings = ', '.join(set(settings))
command = f"{command}(Path(f'models/{model_name}'), {settings})"
model = eval(command)
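
To make the string building concrete, this is roughly what the assembled command looks like; the loader name and the settings entries are assumptions (inferred from the from_pretrained call in the conversion script above), not the exact values used here:

```python
# Illustration with assumed values: the call is assembled as a string and then
# handed to eval(), as in the lines above.
model_name = "opt-1.3b"                                  # hypothetical model folder
command = "AutoModelForCausalLM.from_pretrained"         # assumed loader name
settings = ', '.join(["low_cpu_mem_usage=True", "torch_dtype=torch.float16"])
command = f"{command}(Path(f'models/{model_name}'), {settings})"
print(command)
# AutoModelForCausalLM.from_pretrained(Path(f'models/opt-1.3b'), low_cpu_mem_usage=True, torch_dtype=torch.float16)
```
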
@@ -109,7 +107,7 @@ def fix_gpt4chan(s):
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
return s
# Fix the LaTeX equations in GALACTICA
# Fix the LaTeX equations in galactica
def fix_galactica(s):
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
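
A quick before/after for the two replacements above (the example string is made up): the model's \[ ... \] display-math delimiters become $-delimited math:

```python
# Before/after illustration of the two replace() calls above.
s = r"The energy is \[ E = mc^2 \] in natural units."
s = s.replace(r'\[', r'$')
s = s.replace(r'\]', r'$')
print(s)  # The energy is $ E = mc^2 $ in natural units.
```
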
@@ -154,9 +152,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
return reply, reply, generate_basic_html(reply)
elif model_name.lower().startswith('gpt4chan'):
reply = fix_gpt4chan(reply)
return reply, 'Only applicable for galactica models.', generate_4chan_html(reply)
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
else:
return reply, 'Only applicable for galactica models.', generate_basic_html(reply)
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
# Choosing the default model
if args.model is not None:
@@ -219,7 +217,7 @@ elif args.chat or args.cai_chat:
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
text = chat_response_cleaner(text)
question = context+'\n\n'
question = f"{context}\n\n"
for i in range(len(history)):
if args.cai_chat:
question += f"{name1}: {history[i][0].strip()}\n"