Mirror of https://github.com/oobabooga/text-generation-webui.git (synced 2024-11-25 09:19:23 +01:00)

Commit 0d36c18f5d: Always return only the new tokens in generation functions
Parent: c4f0e6d740
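The change below makes every generation function yield only the newly generated text, never the prompt; callers that want the full transcript prepend the prompt themselves. A toy sketch of that contract (illustrative stand-ins only, not the project's code):

def fake_generate_reply(prompt):
    # stand-in for generate_reply() after this commit: each yielded value is
    # the reply accumulated so far, with no prompt echo at the front
    for chunk in ['Hello', 'Hello,', 'Hello, world']:
        yield chunk

prompt = 'Say hi: '
reply = ''
for reply in fake_generate_reply(prompt):
    pass

print(prompt + reply)  # callers re-attach the prompt only if they need it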
@@ -43,7 +43,7 @@ class Handler(BaseHTTPRequestHandler):

            response = json.dumps({
                'results': [{
-                   'text': answer[len(prompt):]
+                   'text': answer
                }]
            })
            self.wfile.write(response.encode('utf-8'))
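Because the generator already excludes the prompt, the blocking API can put the answer into 'text' directly instead of slicing off len(prompt). A hedged client-side sketch (host, port, and the extra payload key are assumptions, not taken from this diff):

import requests

prompt = 'Write a haiku about rivers:\n'
resp = requests.post(
    'http://127.0.0.1:5000/api/v1/generate',        # assumed default address of the API extension
    json={'prompt': prompt, 'max_new_tokens': 60},
    timeout=60,
)
completion = resp.json()['results'][0]['text']      # new tokens only after this commit
print(prompt + completion)                          # prepend the prompt yourself if you want the full text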
@@ -29,7 +29,7 @@ async def _handle_connection(websocket, path):
        prompt, generate_params, stopping_strings=stopping_strings, is_chat=False)

    # As we stream, only send the new bytes.
-   skip_index = len(prompt)
+   skip_index = 0
    message_num = 0

    for a in generator:
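With the prompt no longer included in what the generator yields, the streaming handler starts its skip index at 0. A minimal sketch of the incremental-send pattern the surrounding code uses (toy send function, not the extension's exact loop):

def stream_new_bytes(generator, send):
    skip_index = 0                      # was len(prompt) when the prompt was echoed back
    for a in generator:
        to_send = a[skip_index:]        # only the text added since the last message
        if to_send:
            send(to_send)
            skip_index += len(to_send)

stream_new_bytes(iter(['He', 'Hell', 'Hello!']), lambda s: print(repr(s)))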
@@ -340,17 +340,14 @@ class Handler(BaseHTTPRequestHandler):
            # generate reply #######################################
            if debug:
                print({'prompt': prompt, 'req_params': req_params, 'stopping_strings': stopping_strings})
-           generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=True)
+           generator = generate_reply(prompt, req_params, stopping_strings=stopping_strings, is_chat=False)

            answer = ''
            seen_content = ''
            longest_stop_len = max([len(x) for x in stopping_strings])

            for a in generator:
-               if isinstance(a, str):
-                   answer = a
-               else:
-                   answer = a[0]
+               answer = a

                stop_string_found = False
                len_seen = len(seen_content)
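With is_chat=False the generator yields plain strings, so the isinstance branch disappears and answer is taken as-is. The seen_content / longest_stop_len bookkeeping that follows is about scanning only the unseen tail for stop strings; a hedged reconstruction of that idea (names reused from the hunk, logic simplified):

def find_stop(answer, seen_content, stopping_strings):
    # only the tail that was not scanned before, padded by the longest stop
    # string, can contain a new stop sequence
    longest_stop_len = max([len(x) for x in stopping_strings])
    search_start = max(0, len(seen_content) - longest_stop_len)
    for stop in stopping_strings:
        idx = answer.find(stop, search_start)
        if idx != -1:
            return answer[:idx]          # truncate the reply at the stop string
    return None

print(find_stop('Sure.\nUser: hi', 'Sure.', ['\nUser:']))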
@@ -521,10 +518,7 @@ class Handler(BaseHTTPRequestHandler):

            answer = ''
            for a in generator:
-               if isinstance(a, str):
-                   answer = a
-               else:
-                   answer = a[0]
+               answer = a

            completion_token_count = len(encode(answer)[0])

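Since answer now contains only the completion, its token count can be reported directly, with no prompt-length subtraction. A toy sketch with a hypothetical whitespace tokenizer standing in for encode():

def encode(text):
    # hypothetical stand-in: the real encode() returns a batch of token ids
    return [text.split()]

answer = 'only the newly generated tokens'
completion_token_count = len(encode(answer)[0])   # matches the line in the hunk
print(completion_token_count)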
@@ -104,18 +104,15 @@ def fix_galactica(s):
def get_reply_from_output_ids(output_ids, input_ids, original_question, state, is_chat=False):
    if shared.model_type == 'HF_seq2seq':
        reply = decode(output_ids, state['skip_special_tokens'])
-       if not is_chat:
-           reply = apply_extensions('output', reply)
    else:
        new_tokens = len(output_ids) - len(input_ids[0])
        reply = decode(output_ids[-new_tokens:], state['skip_special_tokens'])

        if type(shared.tokenizer) is transformers.LlamaTokenizer:
            if len(original_question) > 0 and original_question[-1] not in [' ', '\n']:
                reply = ' ' + reply

-       if not is_chat:
-           reply = original_question + apply_extensions('output', reply)
+   if not is_chat:
+       reply = apply_extensions('output', reply)

    return reply

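The core of get_reply_from_output_ids is now: decode only the ids that were appended to the input, then run the output extensions once, without prepending original_question. A minimal sketch with toy stand-ins for decode() and the token lists:

def reply_from_new_ids(output_ids, input_ids, decode):
    new_tokens = len(output_ids) - len(input_ids)
    return decode(output_ids[-new_tokens:])   # the prompt ids are never decoded back into the reply

decode = lambda ids: ' '.join(f'<{i}>' for i in ids)
print(reply_from_new_ids([11, 12, 13, 14, 15], [11, 12, 13], decode))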
@@ -149,6 +146,9 @@ def stop_everything_event():

def generate_reply_wrapper(question, state, eos_token=None, stopping_strings=None):
    for reply in generate_reply(question, state, eos_token, stopping_strings, is_chat=False):
+       if shared.model_type not in ['HF_seq2seq']:
+           reply = reply + question
+
        yield formatted_outputs(reply, shared.model_name)

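Re-attaching the prompt is now the UI wrapper's job rather than the generators'. A toy sketch of that split of responsibilities (stand-in names; the expression reply + question mirrors the hunk verbatim):

def generate(question):
    # stand-in generator that, like generate_reply after this commit, yields
    # only the new text
    for reply in [' 42', ' 42.']:
        yield reply

def wrapper(question):
    for reply in generate(question):
        reply = reply + question          # the wrapper, not the generator, adds the prompt back
        yield reply

for out in wrapper('The answer is'):
    print(repr(out))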
@@ -236,7 +236,7 @@ def generate_reply_HF(question, original_question, seed, state, eos_token=None,
    t0 = time.time()
    try:
        if not is_chat and shared.model_type != 'HF_seq2seq':
-           yield original_question
+           yield ''

        # Generate the entire reply at once.
        if not state['stream']:
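generate_reply_HF used to emit the prompt as its very first yielded value; it now emits an empty string, so even the first value a streaming consumer sees contains no prompt text. A toy sketch of that priming pattern:

def generate(chunks):
    yield ''                 # immediate first value with no prompt echo
    acc = ''
    for c in chunks:
        acc += c
        yield acc

print(list(generate(['Hel', 'lo'])))   # ['', 'Hel', 'Hello']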
@@ -291,21 +291,18 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
    t0 = time.time()
    try:
        if not is_chat:
-           yield question
+           yield ''

        if not state['stream']:
            reply = shared.model.generate(context=question, **generate_params)
-           output = original_question + reply
            if not is_chat:
-               reply = original_question + apply_extensions('output', reply)
+               reply = apply_extensions('output', reply)

            yield reply
        else:

            for reply in shared.model.generate_with_streaming(context=question, **generate_params):
-               output = original_question + reply
                if not is_chat:
-                   reply = original_question + apply_extensions('output', reply)
+                   reply = apply_extensions('output', reply)

                yield reply

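In generate_reply_custom the 'output' extension hook now receives just the reply, and the temporary output = original_question + reply string is gone. A toy sketch of an output hook pipeline (hypothetical hooks; apply_extensions in the project dispatches to whatever extensions are loaded):

def apply_extensions(kind, text):
    hooks = {'output': [str.strip, lambda s: s.replace('\r\n', '\n')]}
    for hook in hooks.get(kind, []):
        text = hook(text)
    return text

reply = '  the reply only, no prompt attached  '
print(repr(apply_extensions('output', reply)))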
@@ -314,7 +311,7 @@ def generate_reply_custom(question, original_question, seed, state, eos_token=No
    finally:
        t1 = time.time()
        original_tokens = len(encode(original_question)[0])
-       new_tokens = len(encode(output)[0]) - original_tokens
+       new_tokens = len(encode(original_question + reply)[0]) - original_tokens
        print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
        return

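With the prompt+reply string no longer kept around, the tokens/s accounting concatenates the two pieces on the spot before encoding. A toy sketch with a hypothetical whitespace encode():

import time

def encode(text):
    return [text.split()]                  # hypothetical stand-in for the real tokenizer

t0 = time.time()
original_question, reply = 'one two three ', 'four five'
time.sleep(0.1)                            # pretend generation took a moment
t1 = time.time()
original_tokens = len(encode(original_question)[0])
new_tokens = len(encode(original_question + reply)[0]) - original_tokens
print(f'{new_tokens / (t1 - t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}')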
@@ -349,7 +346,7 @@ def generate_reply_flexgen(question, original_question, seed, state, eos_token=N
    t0 = time.time()
    try:
        if not is_chat:
-           yield question
+           yield ''

        # Generate the entire reply at once.
        if not state['stream']: