diff --git a/modules/text_generation.py b/modules/text_generation.py index e1ee5294..caa77df9 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -194,7 +194,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi yield formatted_outputs(reply, shared.model_name) if not shared.args.flexgen: - if output[-1] == n: + if int(output[-1]) == int(n): break input_ids = torch.reshape(output, (1, output.shape[0])) else: