Always return only the new tokens in generation functions

This commit is contained in:
oobabooga 2023-05-11 17:07:20 -03:00
parent c4f0e6d740
commit 0d36c18f5d
4 changed files with 16 additions and 25 deletions

View File

@ -43,7 +43,7 @@ class Handler(BaseHTTPRequestHandler):
response = json.dumps({ response = json.dumps({
'results': [{ 'results': [{
'text': answer[len(prompt):] 'text': answer
}] }]
}) })
self.wfile.write(response.encode('utf-8')) self.wfile.write(response.encode('utf-8'))

View File

@ -29,7 +29,7 @@ async def _handle_connection(websocket, path):
prompt, generate_params, stopping_strings=stopping_strings, is_chat=False) prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)
# As we stream, only send the new bytes. # As we stream, only send the new bytes.
skip_index = len(prompt) skip_index = 0
message_num = 0 message_num = 0
for a in generator: for a in generator:

View File

@ -340,17 +340,14 @@ class Handler(BaseHTTPRequestHandler):
# generate reply ####################################### # generate reply #######################################
if debug: if debug:
print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings}) print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=True) generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)
answer = '' answer = ''
seen_content = '' seen_content = ''
longest_stop_len = max([len(x) for x in stopping_strings]) longest_stop_len = max([len(x) for x in stopping_strings])
for a in generator: for a in generator:
if isinstance(a, str): answer = a
answer = a
else:
answer = a[0]
stop_string_found = False stop_string_found = False
len_seen = len(seen_content) len_seen = len(seen_content)
@ -521,10 +518,7 @@ class Handler(BaseHTTPRequestHandler):
answer = '' answer = ''
for a in generator: for a in generator:
if isinstance(a, str): answer = a
answer = a
else:
answer = a[0]
completion_token_count = len(encode(answer)[0]) completion_token_count = len(encode(answer)[0])

View File

@ -104,18 +104,15 @@ def fix_galactica(s):
def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False): def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
if shared.model_type == 'HF_seq2seq': if shared.model_type == 'HF_seq2seq':
reply = decode(output_ids, state['skip_special_tokens']) reply = decode(output_ids, state['skip_special_tokens'])
if not is_chat:
reply = apply_extensions('output', reply)
else: else:
new_tokens = len(output_ids) - len(input_ids[0]) new_tokens = len(output_ids) - len(input_ids[0])
reply = decode(output_ids[-new_tokens:], state['skip_special_tokens']) reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])
if type(shared.tokenizer) is transformers.LlamaTokenizer: if type(shared.tokenizer) is transformers.LlamaTokenizer:
if len(original_question) > 0 and original_question[-1] not in [' ', '\n']: if len(original_question) > 0 and original_question[-1] not in [' ', '\n']:
reply = ' ' + reply reply = ' ' + reply
if not is_chat: if not is_chat:
reply = original_question + apply_extensions('output', reply) reply = apply_extensions('output', reply)
return reply return reply
@ -149,6 +146,9 @@ def stop_everything_event():
def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None): def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None):
for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False): for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False):
if shared.model_type not in ['HF_seq2seq']:
reply = reply + question
yield formatted_outputs(reply, shared.model_name) yield formatted_outputs(reply, shared.model_name)
@ -236,7 +236,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
t0 = time.time() t0 = time.time()
try: try:
if not is_chat and shared.model_type != 'HF_seq2seq': if not is_chat and shared.model_type != 'HF_seq2seq':
yield original_question yield ''
# Generate the entire reply at once. # Generate the entire reply at once.
if not state['stream']: if not state['stream']:
@ -291,21 +291,18 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
t0 = time.time() t0 = time.time()
try: try:
if not is_chat: if not is_chat:
yield question yield ''
if not state['stream']: if not state['stream']:
reply = shared.model.generate(context=question, **generate_params) reply = shared.model.generate(context=question, **generate_params)
output = original_question + reply
if not is_chat: if not is_chat:
reply = original_question + apply_extensions('output', reply) reply = apply_extensions('output', reply)
yield reply yield reply
else: else:
for reply in shared.model.generate_with_streaming(context=question, **generate_params): for reply in shared.model.generate_with_streaming(context=question, **generate_params):
output = original_question + reply
if not is_chat: if not is_chat:
reply = original_question + apply_extensions('output', reply) reply = apply_extensions('output', reply)
yield reply yield reply
@ -314,7 +311,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
finally: finally:
t1 = time.time() t1 = time.time()
original_tokens = len(encode(original_question)[0]) original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(output)[0]) - original_tokens new_tokens = len(encode(original_question + reply)[0]) - original_tokens
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})') print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
return return
@ -349,7 +346,7 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
t0 = time.time() t0 = time.time()
try: try:
if not is_chat: if not is_chat:
yield question yield ''
# Generate the entire reply at once. # Generate the entire reply at once.
if not state['stream']: if not state['stream']: