mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 16:17:57 +01:00
Improve usage of stopping_criteria
This commit is contained in:
parent
add9330e5e
commit
59b5f7a4b7
@ -119,18 +119,11 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
output = input_ids[0]
|
output = input_ids[0]
|
||||||
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
|
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
|
||||||
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
|
n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1])
|
||||||
|
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||||
if stopping_string is not None:
|
if stopping_string is not None:
|
||||||
# The stopping_criteria code below was copied from
|
# Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
||||||
# https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py
|
|
||||||
t = encode(stopping_string, 0, add_special_tokens=False)
|
t = encode(stopping_string, 0, add_special_tokens=False)
|
||||||
stopping_criteria_list = transformers.StoppingCriteriaList([
|
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
||||||
_SentinelTokenStoppingCriteria(
|
|
||||||
sentinel_token_ids=t,
|
|
||||||
starting_idx=len(input_ids[0])
|
|
||||||
)
|
|
||||||
])
|
|
||||||
else:
|
|
||||||
stopping_criteria_list = []
|
|
||||||
|
|
||||||
if not shared.args.flexgen:
|
if not shared.args.flexgen:
|
||||||
generate_params = [
|
generate_params = [
|
||||||
@ -184,10 +177,9 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
elif not shared.args.flexgen:
|
elif not shared.args.flexgen:
|
||||||
|
|
||||||
def generate_with_callback(callback=None, **kwargs):
|
def generate_with_callback(callback=None, **kwargs):
|
||||||
if 'stopping_criteria' not in kwargs:
|
|
||||||
kwargs['stopping_criteria'] = []
|
|
||||||
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
kwargs['stopping_criteria'].append(Stream(callback_func=callback))
|
||||||
clear_torch_cache()
|
clear_torch_cache()
|
||||||
|
with torch.no_grad():
|
||||||
shared.model.generate(**kwargs)
|
shared.model.generate(**kwargs)
|
||||||
|
|
||||||
def generate_with_streaming(**kwargs):
|
def generate_with_streaming(**kwargs):
|
||||||
@ -195,6 +187,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
|||||||
|
|
||||||
yield formatted_outputs(original_question, shared.model_name)
|
yield formatted_outputs(original_question, shared.model_name)
|
||||||
for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
|
for output in eval(f"generate_with_streaming({', '.join(generate_params)})"):
|
||||||
|
print(print('Used vram in gib:', torch.cuda.memory_allocated() / 1024**3))
|
||||||
if shared.soft_prompt:
|
if shared.soft_prompt:
|
||||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||||
reply = decode(output)
|
reply = decode(output)
|
||||||
|
Loading…
Reference in New Issue
Block a user