diff --git a/server.py b/server.py
index 37e45cc6..bf9ff219 100644
--- a/server.py
+++ b/server.py
@@ -87,10 +87,11 @@ def load_model(model_name):
         model = AutoModelForCausalLM.from_pretrained(Path(f"models/{model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.float16).cuda()
 
     # Custom
     else:
-        settings = ["low_cpu_mem_usage=True"]
         command = "AutoModelForCausalLM.from_pretrained"
+        settings = []
         if args.cpu:
+            settings.append("low_cpu_mem_usage=True")
             settings.append("torch_dtype=torch.float32")
         else:
             settings.append("device_map='auto'")
@@ -374,7 +375,7 @@ if args.chat or args.cai_chat:
             reply = reply[idx + 1 + len(apply_extensions(f"{current}:", "bot_prefix")):]
         else:
             reply = reply[idx + 1 + len(f"{current}:"):]
-            
+
     if check:
         reply = reply.split('\n')[0].strip()
     else: