diff --git a/server.py b/server.py index e74a5c3c..4991ba50 100644 --- a/server.py +++ b/server.py @@ -141,7 +141,7 @@ def generate_reply(question, tokens, inference_settings, selected_model, eos_tok input_ids = encode(question, 1) preset = preset.replace('max_new_tokens=tokens', 'max_new_tokens=1') - cuda = ".cuda()" if not args.cpu else "" + cuda = "" if args.cpu else ".cuda()" for i in range(tokens): if eos_token is None: output = eval(f"model.generate(input_ids, {preset}){cuda}")