This commit is contained in:
oobabooga 2024-06-12 19:00:21 -07:00
parent 2d196ed2fe
commit a36fa73071
2 changed files with 3 additions and 5 deletions

View File

@ -1,5 +1,4 @@
import gc
import logging
import os
import pprint
import re
@ -29,7 +28,6 @@ import modules.shared as shared
from modules import RoPE, sampler_hijack
from modules.logging_colors import logger
from modules.models_settings import get_model_metadata
from modules.relative_imports import RelativeImport
transformers.logging.set_verbosity_error()
@ -267,7 +265,7 @@ def llamacpp_loader(model_name):
if path.is_file():
model_file = path
else:
model_file = sorted(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0]
model_file = sorted(Path(f'{shared.args.model_dir}/{model_name}').glob('*.gguf'))[0]
logger.info(f"llama.cpp weights detected: \"{model_file}\"")
model, tokenizer = LlamaCppModel.from_pretrained(model_file)

View File

@ -218,7 +218,7 @@ class DRYLogitsProcessor(LogitsProcessor):
match_lengths = {}
for i in match_indices:
next_token = input_ids_row[i+1].item()
next_token = input_ids_row[i + 1].item()
if next_token in self.sequence_breakers:
continue
@ -234,7 +234,7 @@ class DRYLogitsProcessor(LogitsProcessor):
# Start of input reached.
break
previous_token = input_ids_row[-(match_length+1)].item()
previous_token = input_ids_row[-(match_length + 1)].item()
if input_ids_row[j] != previous_token:
# Start of match reached.
break