From f4005164f4318ce8ba728d0ed7de7b7d40315bf3 Mon Sep 17 00:00:00 2001
From: Pete <33569918+jparmstr@users.noreply.github.com>
Date: Thu, 3 Aug 2023 19:01:15 -0400
Subject: [PATCH] Fix llama.cpp truncation (#3400)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
---
 modules/llamacpp_model.py  | 7 +++++++
 modules/text_generation.py | 1 -
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/modules/llamacpp_model.py b/modules/llamacpp_model.py
index 53177f4f..e5401378 100644
--- a/modules/llamacpp_model.py
+++ b/modules/llamacpp_model.py
@@ -6,6 +6,7 @@ import torch
 from modules import shared
 from modules.callbacks import Iteratorize
 from modules.logging_colors import logger
+from modules.text_generation import get_max_prompt_length
 
 import llama_cpp
 
@@ -91,6 +92,12 @@ class LlamaCppModel:
         LogitsProcessorList = llama_cpp_lib().LogitsProcessorList
 
         prompt = prompt if type(prompt) is str else prompt.decode()
+
+        # Handle truncation
+        prompt = self.encode(prompt)
+        prompt = prompt[-get_max_prompt_length(state):]
+        prompt = self.decode(prompt).decode('utf-8')
+
         completion_chunks = self.model.create_completion(
             prompt=prompt,
             max_tokens=state['max_new_tokens'],
diff --git a/modules/text_generation.py b/modules/text_generation.py
index f6f71990..7507a731 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -39,7 +39,6 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
     if shared.model.__class__.__name__ in ['LlamaCppModel', 'RWKVModel']:
         input_ids = shared.tokenizer.encode(str(prompt))
         input_ids = np.array(input_ids).reshape(1, len(input_ids))
-        return input_ids
     else:
         input_ids = shared.tokenizer.encode(str(prompt), return_tensors='pt', add_special_tokens=add_special_tokens)
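
Note on the change: the added lines in llamacpp_model.py perform token-level truncation before the llama.cpp call: the prompt is tokenized, only the last get_max_prompt_length(state) tokens are kept, and the result is decoded back to a string. The sketch below is a minimal standalone illustration of that idea, not code from the patch; truncate_prompt and the toy whitespace tokenizer are hypothetical stand-ins for the model's own encode()/decode() and get_max_prompt_length(state).

    # Illustrative sketch only; assumes a generic encode/decode pair.
    def truncate_prompt(prompt, encode, decode, max_prompt_length):
        tokens = encode(prompt)                # text -> tokens
        tokens = tokens[-max_prompt_length:]   # drop the oldest tokens, keep the tail
        return decode(tokens)                  # tokens -> text

    if __name__ == '__main__':
        encode = lambda s: s.split()           # toy "tokenizer": words as tokens
        decode = lambda toks: ' '.join(toks)
        print(truncate_prompt('a b c d e f', encode, decode, max_prompt_length=3))  # -> "d e f"

Keeping the tail of the token sequence preserves the most recent context, which is why the slice uses a negative index rather than truncating from the front.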