From 92cdb4f22b86459f92246a5108f859b351f93ce2 Mon Sep 17 00:00:00 2001 From: Vincent Brouwers Date: Wed, 26 Apr 2023 03:39:04 +0200 Subject: [PATCH] Seq2Seq support (including FLAN-T5) (#1535) --------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com> --- modules/models.py | 13 ++++++++++-- modules/text_generation.py | 41 ++++++++++++++++++-------------------- 2 files changed, 30 insertions(+), 24 deletions(-) diff --git a/modules/models.py b/modules/models.py index a17fba4b..4ebb6597 100644 --- a/modules/models.py +++ b/modules/models.py @@ -11,7 +11,8 @@ import torch import transformers from accelerate import infer_auto_device_map, init_empty_weights from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, - AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer) + AutoModelForSeq2SeqLM, AutoTokenizer, + BitsAndBytesConfig, LlamaTokenizer) import modules.shared as shared from modules import llama_attn_hijack @@ -55,7 +56,12 @@ def find_model_type(model_name): elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])): return 'gpt4chan' else: - return 'HF_generic' + config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}") + # Not a "catch all", but fairly accurate + if config.to_dict().get("is_encoder_decoder", False): + return 'HF_seq2seq' + else: + return 'HF_generic' def load_model(model_name): @@ -66,6 +72,9 @@ def load_model(model_name): if shared.model_type == 'chatglm': LoaderClass = AutoModel trust_remote_code = shared.args.trust_remote_code + elif shared.model_type == 'HF_seq2seq': + LoaderClass = AutoModelForSeq2SeqLM + trust_remote_code = False else: LoaderClass = AutoModelForCausalLM trust_remote_code = False diff --git a/modules/text_generation.py b/modules/text_generation.py index 936ec647..ba915482 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -58,12 +58,21 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt def decode(output_ids, skip_special_tokens=True): - if skip_special_tokens: - reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True) - reply = reply.replace(r'<|endoftext|>', '') - return reply + return shared.tokenizer.decode(output_ids, skip_special_tokens) + + +def get_reply_from_output_ids(output_ids, input_ids, original_question, state): + if shared.model_type == 'HF_seq2seq': + reply = decode(output_ids, state['skip_special_tokens']) + if not shared.is_chat(): + reply = apply_extensions('output', reply) else: - return shared.tokenizer.decode(output_ids, skip_special_tokens=False) + new_tokens = len(output_ids) - len(input_ids[0]) + reply = decode(output_ids[-new_tokens:], state['skip_special_tokens']) + if not shared.is_chat(): + reply = original_question + apply_extensions('output', reply) + + return reply def generate_softprompt_input_tensors(input_ids): @@ -262,11 +271,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - new_tokens = len(output) - len(input_ids[0]) - reply = decode(output[-new_tokens:], state['skip_special_tokens']) - if not shared.is_chat(): - reply = original_question + apply_extensions('output', reply) - + reply = get_reply_from_output_ids(output, input_ids, original_question, state) yield formatted_outputs(reply, shared.model_name) # Stream the reply 1 token at a time. @@ -282,7 +287,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): def generate_with_streaming(**kwargs): return Iteratorize(generate_with_callback, kwargs, callback=None) - if not shared.is_chat(): + if not shared.is_chat() and shared.model_type != 'HF_seq2seq': yield formatted_outputs(original_question, shared.model_name) with generate_with_streaming(**generate_params) as generator: @@ -290,11 +295,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - new_tokens = len(output) - len(input_ids[0]) - reply = decode(output[-new_tokens:], state['skip_special_tokens']) - if not shared.is_chat(): - reply = original_question + apply_extensions('output', reply) - + reply = get_reply_from_output_ids(output, input_ids, original_question, state) if output[-1] in eos_token_ids: break @@ -310,11 +311,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): if shared.soft_prompt: output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - new_tokens = len(output) - len(original_input_ids[0]) - reply = decode(output[-new_tokens:], state['skip_special_tokens']) - if not shared.is_chat(): - reply = original_question + apply_extensions('output', reply) - + reply = get_reply_from_output_ids(output, input_ids, original_question, state) if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): break @@ -334,6 +331,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): finally: t1 = time.time() original_tokens = len(original_input_ids[0]) - new_tokens = len(output) - original_tokens + new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0) print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') return