Improve usage of stopping_criteria

This commit is contained in:
oobabooga 2023-03-08 12:13:40 -03:00
parent add9330e5e
commit 59b5f7a4b7

View File

@ -119,18 +119,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
output = input_ids[0] output = input_ids[0]
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1]) n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
stopping_criteria_list = transformers.StoppingCriteriaList()
if stopping_string is not None: if stopping_string is not None:
# The stopping_criteria code below was copied from # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
t = encode(stopping_string, 0, add_special_tokens=False) t = encode(stopping_string, 0, add_special_tokens=False)
stopping_criteria_list = transformers.StoppingCriteriaList([ stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
_SentinelTokenStoppingCriteria(
sentinel_token_ids=t,
starting_idx=len(input_ids[0])
)
])
else:
stopping_criteria_list = []
if not shared.args.flexgen: if not shared.args.flexgen:
generate_params = [ generate_params = [
@ -184,17 +177,17 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
elif not shared.args.flexgen: elif not shared.args.flexgen:
def generate_with_callback(callback=None, **kwargs): def generate_with_callback(callback=None, **kwargs):
if 'stopping_criteria' not in kwargs:
kwargs['stopping_criteria'] = []
kwargs['stopping_criteria'].append(Stream(callback_func=callback)) kwargs['stopping_criteria'].append(Stream(callback_func=callback))
clear_torch_cache() clear_torch_cache()
shared.model.generate(**kwargs) with torch.no_grad():
shared.model.generate(**kwargs)
def generate_with_streaming(**kwargs): def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None) return Iteratorize(generate_with_callback, kwargs, callback=None)
yield formatted_outputs(original_question, shared.model_name) yield formatted_outputs(original_question, shared.model_name)
for output in eval(f"generate_with_streaming({', '.join(generate_params)})"): for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
print(print('Used vram in gib:', torch.cuda.memory_allocated() / 1024**3))
if shared.soft_prompt: if shared.soft_prompt:
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
reply = decode(output) reply = decode(output)