From 97a6a50d98c96554ef0ece3b18b68501a78aca01 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Thu, 4 May 2023 15:55:39 -0300
Subject: [PATCH] Use oasst tokenizer instead of universal tokenizer

---
 models/config.yaml | 1 -
 modules/models.py  | 4 +++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/models/config.yaml b/models/config.yaml
index 5e5c7cf8..9ddf0e5a 100644
--- a/models/config.yaml
+++ b/models/config.yaml
@@ -24,7 +24,6 @@ llama-[0-9]*b-4bit$:
 .*(oasst-sft-1-pythia-12b|oasst-sft-6-llama-30b):
   mode: 'instruct'
   instruction_template: 'Open Assistant'
-  custom_stopping_strings: '"<|endoftext|>"'
 .*vicuna:
   mode: 'instruct'
   instruction_template: 'Vicuna-v0'
diff --git a/modules/models.py b/modules/models.py
index 8151c5e2..01201eee 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -54,6 +54,8 @@ def find_model_type(model_name):
         return 'galactica'
     elif 'llava' in model_name_lower:
         return 'llava'
+    elif 'oasst' in model_name_lower:
+        return 'oasst'
     elif any((k in model_name_lower for k in ['gpt4chan', 'gpt-4chan'])):
         return 'gpt4chan'
     else:
@@ -227,7 +229,7 @@ def load_model(model_name):
         tokenizer = None
 
         # Try to load an universal LLaMA tokenizer
-        if shared.model_type != 'llava':
+        if shared.model_type not in ['llava', 'oasst']:
             for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
                 if p.exists():
                     logging.info(f"Loading the universal LLaMA tokenizer from {p}...")
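
For reviewers, a minimal standalone sketch of the behavior this patch is meant to produce. It is illustrative only: find_model_type below mirrors the new branch in the diff, while uses_universal_llama_tokenizer is a hypothetical helper standing in for the patched check in load_model(); neither is repository code.

# Illustrative sketch only -- not repository code. It mimics the routing added by
# this patch: 'oasst' models are detected by name and excluded from the universal
# LLaMA tokenizer fallback, so they load their own tokenizer (which defines
# <|endoftext|>), making the removed custom_stopping_strings entry unnecessary.

def find_model_type(model_name):
    # Mirrors the new branch added in modules/models.py.
    model_name_lower = model_name.lower()
    if 'llava' in model_name_lower:
        return 'llava'
    elif 'oasst' in model_name_lower:
        return 'oasst'
    else:
        return 'other'

def uses_universal_llama_tokenizer(model_type):
    # Hypothetical helper: stands in for the condition patched in load_model().
    return model_type not in ['llava', 'oasst']

if __name__ == '__main__':
    t = find_model_type('oasst-sft-6-llama-30b')
    print(t)                                  # oasst
    print(uses_universal_llama_tokenizer(t))  # False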