mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00
Clean things up
This commit is contained in:
parent
3a99b2b030
commit
6456777b09
@ -133,7 +133,7 @@ Optionally, you can use the following command-line flags:
|
||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||
| `--max-gpu-memory MAX_GPU_MEMORY` | Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number. |
|
||||
| `--no-listen` | Make the web UI unreachable from your local network.|
|
||||
| `--settings-file SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
|
||||
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example.|
|
||||
|
||||
## Presets
|
||||
|
||||
|
@ -17,7 +17,5 @@ model_name = path.name
|
||||
|
||||
print(f"Loading {model_name}...")
|
||||
model = AutoModelForCausalLM.from_pretrained(path, low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
|
||||
print("Model loaded.")
|
||||
|
||||
print(f"Saving to torch-dumps/{model_name}.pt")
|
||||
print(f"Model loaded.\nSaving to torch-dumps/{model_name}.pt")
|
||||
torch.save(model, Path(f"torch-dumps/{model_name}.pt"))
|
||||
|
@ -28,7 +28,6 @@ def get_file(args):
|
||||
t.close()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
model = argv[1]
|
||||
if model[-1] == '/':
|
||||
model = model[:-1]
|
||||
|
@ -1,6 +1,6 @@
|
||||
'''
|
||||
|
||||
This is a library for formatting gpt4chan outputs as nice HTML.
|
||||
This is a library for formatting GPT-4chan and chat outputs as nice HTML.
|
||||
|
||||
'''
|
||||
|
||||
@ -267,7 +267,5 @@ def generate_chat_html(history, name1, name2):
|
||||
</div>
|
||||
"""
|
||||
|
||||
output += '<script>document.getElementById("chat").scrollTo(0, document.getElementById("chat").scrollHeight);</script>'
|
||||
output += "</div>"
|
||||
|
||||
return output
|
||||
|
28
server.py
28
server.py
@ -25,14 +25,12 @@ parser.add_argument('--auto-devices', action='store_true', help='Automatically s
|
||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||
parser.add_argument('--max-gpu-memory', type=int, help='Maximum memory in GiB to allocate to the GPU when loading the model. This is useful if you get out of memory errors while trying to generate text. Must be an integer number.')
|
||||
parser.add_argument('--no-listen', action='store_true', help='Make the web UI unreachable from your local network.')
|
||||
parser.add_argument('--settings-file', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
|
||||
parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example.')
|
||||
args = parser.parse_args()
|
||||
|
||||
loaded_preset = None
|
||||
available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
|
||||
available_models = [item for item in available_models if not item.endswith('.txt')]
|
||||
available_models = sorted(available_models, key=str.lower)
|
||||
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], list(Path('presets').glob('*.txt')))))
|
||||
available_models = sorted(set([item.replace('.pt', '') for item in map(lambda x : str(x.name), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*'))) if not item.endswith('.txt')]), key=str.lower)
|
||||
available_presets = sorted(set(map(lambda x : str(x.name).split('.')[0], Path('presets').glob('*.txt'))), key=str.lower)
|
||||
|
||||
settings = {
|
||||
'max_new_tokens': 200,
|
||||
@ -50,12 +48,12 @@ settings = {
|
||||
'stop_at_newline': True,
|
||||
}
|
||||
|
||||
if args.settings_file is not None and Path(args.settings_file).exists():
|
||||
with open(Path(args.settings_file), 'r') as f:
|
||||
if args.settings is not None and Path(args.settings).exists():
|
||||
with open(Path(args.settings), 'r') as f:
|
||||
new_settings = json.load(f)
|
||||
for i in new_settings:
|
||||
if i in settings:
|
||||
settings[i] = new_settings[i]
|
||||
for item in new_settings:
|
||||
if item in settings:
|
||||
settings[item] = new_settings[item]
|
||||
|
||||
def load_model(model_name):
|
||||
print(f"Loading {model_name}...")
|
||||
@ -87,7 +85,7 @@ def load_model(model_name):
|
||||
else:
|
||||
settings.append("torch_dtype=torch.float16")
|
||||
|
||||
settings = ', '.join(list(set(settings)))
|
||||
settings = ', '.join(set(settings))
|
||||
command = f"{command}(Path(f'models/{model_name}'), {settings})"
|
||||
model = eval(command)
|
||||
|
||||
@ -109,7 +107,7 @@ def fix_gpt4chan(s):
|
||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||
return s
|
||||
|
||||
# Fix the LaTeX equations in GALACTICA
|
||||
# Fix the LaTeX equations in galactica
|
||||
def fix_galactica(s):
|
||||
s = s.replace(r'\[', r'$')
|
||||
s = s.replace(r'\]', r'$')
|
||||
@ -154,9 +152,9 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok
|
||||
return reply, reply, generate_basic_html(reply)
|
||||
elif model_name.lower().startswith('gpt4chan'):
|
||||
reply = fix_gpt4chan(reply)
|
||||
return reply, 'Only applicable for galactica models.', generate_4chan_html(reply)
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_4chan_html(reply)
|
||||
else:
|
||||
return reply, 'Only applicable for galactica models.', generate_basic_html(reply)
|
||||
return reply, 'Only applicable for GALACTICA models.', generate_basic_html(reply)
|
||||
|
||||
# Choosing the default model
|
||||
if args.model is not None:
|
||||
@ -219,7 +217,7 @@ elif args.chat or args.cai_chat:
|
||||
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check):
|
||||
text = chat_response_cleaner(text)
|
||||
|
||||
question = context+'\n\n'
|
||||
question = f"{context}\n\n"
|
||||
for i in range(len(history)):
|
||||
if args.cai_chat:
|
||||
question += f"{name1}: {history[i][0].strip()}\n"
|
||||
|
Loading…
Reference in New Issue
Block a user