mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2025-01-23 18:19:22 +01:00
Improve example dialogue handling
This commit is contained in:
parent
f9dbe7e08e
commit
aadf4e899a
@ -161,7 +161,7 @@ def generate_4chan_html(f):
|
||||
|
||||
return output
|
||||
|
||||
def generate_chat_html(_history, name1, name2, character):
|
||||
def generate_chat_html(history, name1, name2, character):
|
||||
css = """
|
||||
.chat {
|
||||
margin-left: auto;
|
||||
@ -234,13 +234,6 @@ def generate_chat_html(_history, name1, name2, character):
|
||||
img = f'<img src="file/{i}">'
|
||||
break
|
||||
|
||||
history = copy.deepcopy(_history)
|
||||
for i in range(len(history)):
|
||||
if '<|BEGIN-VISIBLE-CHAT|>' in history[i][0]:
|
||||
history[i][0] = history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
|
||||
history = history[i:]
|
||||
break
|
||||
|
||||
for i,_row in enumerate(history[::-1]):
|
||||
row = _row.copy()
|
||||
row[0] = re.sub(r"[\\]*\*", r"*", row[0])
|
||||
|
29
server.py
29
server.py
@ -98,7 +98,7 @@ def load_model(model_name):
|
||||
settings.append(f"max_memory={{0: '{args.gpu_memory}GiB', 'cpu': '99GiB'}}")
|
||||
if args.disk:
|
||||
if args.disk_cache_dir is not None:
|
||||
settings.append("offload_folder='"+args.disk_cache_dir+"'")
|
||||
settings.append(f"offload_folder='{args.disk_cache_dir}'")
|
||||
else:
|
||||
settings.append("offload_folder='cache'")
|
||||
if args.load_in_8bit:
|
||||
@ -265,6 +265,15 @@ if args.chat or args.cai_chat:
|
||||
question = question.replace('<|BEGIN-VISIBLE-CHAT|>', '')
|
||||
return question
|
||||
|
||||
def remove_example_dialogue_from_history(history):
|
||||
_history = copy.deepcopy(history)
|
||||
for i in range(len(_history)):
|
||||
if '<|BEGIN-VISIBLE-CHAT|>' in _history[i][0]:
|
||||
_history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
|
||||
_history = _history[i:]
|
||||
break
|
||||
return _history
|
||||
|
||||
def chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
question = generate_chat_prompt(text, tokens, name1, name2, context, history_size)
|
||||
history.append(['', ''])
|
||||
@ -300,9 +309,9 @@ if args.chat or args.cai_chat:
|
||||
next_character_substring_found = True
|
||||
|
||||
if not next_character_substring_found:
|
||||
yield history
|
||||
yield remove_example_dialogue_from_history(history)
|
||||
|
||||
yield history
|
||||
yield remove_example_dialogue_from_history(history)
|
||||
|
||||
def cai_chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
for history in chatbot_wrapper(text, tokens, inference_settings, selected_model, name1, name2, context, check, history_size):
|
||||
@ -327,17 +336,10 @@ if args.chat or args.cai_chat:
|
||||
return generate_chat_html(history, name1, name2, character)
|
||||
|
||||
def save_history():
|
||||
_history = copy.deepcopy(history)
|
||||
for i in range(len(_history)):
|
||||
if '<|BEGIN-VISIBLE-CHAT|>' in history[i][0]:
|
||||
_history[i][0] = _history[i][0].replace('<|BEGIN-VISIBLE-CHAT|>', '')
|
||||
_history = _history[i:]
|
||||
break
|
||||
|
||||
if not Path('logs').exists():
|
||||
Path('logs').mkdir()
|
||||
with open(Path('logs/conversation.json'), 'w') as f:
|
||||
f.write(json.dumps({'data': _history}))
|
||||
f.write(json.dumps({'data': history}))
|
||||
return Path('logs/conversation.json')
|
||||
|
||||
def load_history(file):
|
||||
@ -389,10 +391,11 @@ if args.chat or args.cai_chat:
|
||||
context = settings['context_pygmalion']
|
||||
name2 = settings['name2_pygmalion']
|
||||
|
||||
_history = remove_example_dialogue_from_history(history)
|
||||
if args.cai_chat:
|
||||
return name2, context, generate_chat_html(history, name1, name2, character)
|
||||
return name2, context, generate_chat_html(_history, name1, name2, character)
|
||||
else:
|
||||
return name2, context, history
|
||||
return name2, context, _history
|
||||
|
||||
suffix = '_pygmalion' if 'pygmalion' in model_name.lower() else ''
|
||||
with gr.Blocks(css=css+".h-\[40vh\] {height: 66.67vh} .gradio-container {max-width: 800px; margin-left: auto; margin-right: auto}", analytics_enabled=False) as interface:
|
||||
|
Loading…
Reference in New Issue
Block a user