Seq2Seq support (including FLAN-T5) (#1535)

---------

Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com>
This commit is contained in:
Vincent Brouwers 2023-04-26 03:39:04 +02:00 committed by GitHub
parent 95aa43b9c2
commit 92cdb4f22b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 30 additions and 24 deletions

View File

@ -11,7 +11,8 @@ import torch
import transformers import transformers
from accelerate import infer_auto_device_map, init_empty_weights from accelerate import infer_auto_device_map, init_empty_weights
from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM, from transformers import (AutoConfig, AutoModel, AutoModelForCausalLM,
AutoTokenizer, BitsAndBytesConfig, LlamaTokenizer) AutoModelForSeq2SeqLM, AutoTokenizer,
BitsAndBytesConfig, LlamaTokenizer)
import modules.shared as shared import modules.shared as shared
from modules import llama_attn_hijack from modules import llama_attn_hijack
@ -54,6 +55,11 @@ def find_model_type(model_name):
return 'llava' return 'llava'
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])): elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
return 'gpt4chan' return 'gpt4chan'
else:
config = AutoConfig.from_pretrained(f"{shared.args.model_dir}/{model_name}")
# Not a "catch all", but fairly accurate
if config.to_dict().get("is_encoder_decoder", False):
return 'HF_seq2seq'
else: else:
return 'HF_generic' return 'HF_generic'
@ -66,6 +72,9 @@ def load_model(model_name):
if shared.model_type == 'chatglm': if shared.model_type == 'chatglm':
LoaderClass = AutoModel LoaderClass = AutoModel
trust_remote_code = shared.args.trust_remote_code trust_remote_code = shared.args.trust_remote_code
elif shared.model_type == 'HF_seq2seq':
LoaderClass = AutoModelForSeq2SeqLM
trust_remote_code = False
else: else:
LoaderClass = AutoModelForCausalLM LoaderClass = AutoModelForCausalLM
trust_remote_code = False trust_remote_code = False

View File

@ -58,12 +58,21 @@ def encode(prompt, add_special_tokens=True, add_bos_token=True, truncation_lengt
def decode(output_ids, skip_special_tokens=True): def decode(output_ids, skip_special_tokens=True):
if skip_special_tokens: return shared.tokenizer.decode(output_ids, skip_special_tokens)
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
reply = reply.replace(r'<|endoftext|>', '')
return reply def get_reply_from_output_ids(output_ids, input_ids, original_question, state):
if shared.model_type == 'HF_seq2seq':
reply = decode(output_ids, state['skip_special_tokens'])
if not shared.is_chat():
reply = apply_extensions('output', reply)
else: else:
return shared.tokenizer.decode(output_ids, skip_special_tokens=False) new_tokens = len(output_ids) - len(input_ids[0])
reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
return reply
def generate_softprompt_input_tensors(input_ids): def generate_softprompt_input_tensors(input_ids):
@ -262,11 +271,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.soft_prompt: if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0]) reply = get_reply_from_output_ids(output, input_ids, original_question, state)
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
# Stream the reply 1 token at a time. # Stream the reply 1 token at a time.
@ -282,7 +287,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
def generate_with_streaming(**kwargs): def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None) return Iteratorize(generate_with_callback, kwargs, callback=None)
if not shared.is_chat(): if not shared.is_chat() and shared.model_type != 'HF_seq2seq':
yield formatted_outputs(original_question, shared.model_name) yield formatted_outputs(original_question, shared.model_name)
with generate_with_streaming(**generate_params) as generator: with generate_with_streaming(**generate_params) as generator:
@ -290,11 +295,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.soft_prompt: if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(input_ids[0]) reply = get_reply_from_output_ids(output, input_ids, original_question, state)
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
if output[-1] in eos_token_ids: if output[-1] in eos_token_ids:
break break
@ -310,11 +311,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
if shared.soft_prompt: if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
new_tokens = len(output) - len(original_input_ids[0]) reply = get_reply_from_output_ids(output, input_ids, original_question, state)
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
if not shared.is_chat():
reply = original_question + apply_extensions('output', reply)
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
break break
@ -334,6 +331,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
finally: finally:
t1 = time.time() t1 = time.time()
original_tokens = len(original_input_ids[0]) original_tokens = len(original_input_ids[0])
new_tokens = len(output) - original_tokens new_tokens = len(output) - (original_tokens if shared.model_type != 'HF_seq2seq' else 0)
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return return