mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 17:50:22 +01:00
Add nice HTML output for all models
This commit is contained in:
parent
18ae08ef91
commit
d5e01c80e3
@ -94,7 +94,7 @@ Optionally, you can use the following command-line flags:
|
|||||||
--cpu Use the CPU to generate text.
|
--cpu Use the CPU to generate text.
|
||||||
--auto-devices Automatically split the model across the available GPU(s) and CPU.
|
--auto-devices Automatically split the model across the available GPU(s) and CPU.
|
||||||
--load-in-8bit Load the model with 8-bit precision.
|
--load-in-8bit Load the model with 8-bit precision.
|
||||||
--listen Make the webui reachable from your local network.
|
--no-listen Make the webui unreachable from your local network.
|
||||||
```
|
```
|
||||||
|
|
||||||
## Presets
|
## Presets
|
||||||
|
@ -20,7 +20,7 @@ def process_post(post, c):
|
|||||||
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
|
src = f'<span class="name">Anonymous </span> <span class="number">No.{number}</span>\n{src}'
|
||||||
return src
|
return src
|
||||||
|
|
||||||
def generate_html(f):
|
def generate_4chan_html(f):
|
||||||
css = """
|
css = """
|
||||||
#container {
|
#container {
|
||||||
background-color: #eef2ff;
|
background-color: #eef2ff;
|
||||||
|
23
server.py
23
server.py
@ -18,7 +18,7 @@ parser.add_argument('--chat', action='store_true', help='Launch the webui in cha
|
|||||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
|
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
|
||||||
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
||||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||||
parser.add_argument('--listen', action='store_true', help='Make the webui reachable from your local network.')
|
parser.add_argument('--no-listen', action='store_true', help='Make the webui unreachable from your local network.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
loaded_preset = None
|
loaded_preset = None
|
||||||
available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
|
available_models = sorted(set(map(lambda x : str(x.name).replace('.pt', ''), list(Path('models/').glob('*'))+list(Path('torch-dumps/').glob('*')))))
|
||||||
@ -63,7 +63,7 @@ def load_model(model_name):
|
|||||||
model = eval(command)
|
model = eval(command)
|
||||||
|
|
||||||
# Loading the tokenizer
|
# Loading the tokenizer
|
||||||
if model_name.lower().startswith('gpt4chan') and Path(f"models/gpt-j-6B/").exists():
|
if model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"models/gpt-j-6B/").exists():
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
|
||||||
else:
|
else:
|
||||||
tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
|
tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{model_name}/"))
|
||||||
@ -79,6 +79,7 @@ def fix_gpt4chan(s):
|
|||||||
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
s = re.sub("--- [0-9]*\n\n\n---", "---", s)
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
# Fix the LaTeX equations in GALACTICA
|
||||||
def fix_galactica(s):
|
def fix_galactica(s):
|
||||||
s = s.replace(r'\[', r'$')
|
s = s.replace(r'\[', r'$')
|
||||||
s = s.replace(r'\]', r'$')
|
s = s.replace(r'\]', r'$')
|
||||||
@ -87,6 +88,11 @@ def fix_galactica(s):
|
|||||||
s = s.replace(r'$$', r'$')
|
s = s.replace(r'$$', r'$')
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
def generate_html(s):
|
||||||
|
s = '\n'.join([f'<p style="margin-bottom: 20px">{line}</p>' for line in s.split('\n')])
|
||||||
|
s = f'<div style="max-width: 600px; margin-left: auto; margin-right: auto; background-color:#eef2ff; color:#0b0f19; padding:3em; font-size:1.2em;">{s}</div>'
|
||||||
|
return s
|
||||||
|
|
||||||
def generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token=None):
|
def generate_reply(question, temperature, max_length, inference_settings, selected_model, eos_token=None):
|
||||||
global model, tokenizer, model_name, loaded_preset, preset
|
global model, tokenizer, model_name, loaded_preset, preset
|
||||||
|
|
||||||
@ -117,14 +123,15 @@ def generate_reply(question, temperature, max_length, inference_settings, select
|
|||||||
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
|
output = eval(f"model.generate(input_ids, eos_token_id={n}, {preset}){cuda}")
|
||||||
|
|
||||||
reply = tokenizer.decode(output[0], skip_special_tokens=True)
|
reply = tokenizer.decode(output[0], skip_special_tokens=True)
|
||||||
|
reply = reply.replace(r'<|endoftext|>', '')
|
||||||
if model_name.lower().startswith('galactica'):
|
if model_name.lower().startswith('galactica'):
|
||||||
reply = fix_galactica(reply)
|
reply = fix_galactica(reply)
|
||||||
return reply, reply, 'Only applicable for gpt4chan.'
|
return reply, reply, generate_html(reply)
|
||||||
elif model_name.lower().startswith('gpt4chan'):
|
elif model_name.lower().startswith('gpt4chan'):
|
||||||
reply = fix_gpt4chan(reply)
|
reply = fix_gpt4chan(reply)
|
||||||
return reply, 'Only applicable for galactica models.', generate_html(reply)
|
return reply, 'Only applicable for galactica models.', generate_4chan_html(reply)
|
||||||
else:
|
else:
|
||||||
return reply, 'Only applicable for galactica models.', 'Only applicable for gpt4chan.'
|
return reply, 'Only applicable for galactica models.', generate_html(reply)
|
||||||
|
|
||||||
# Choosing the default model
|
# Choosing the default model
|
||||||
if args.model is not None:
|
if args.model is not None:
|
||||||
@ -248,7 +255,7 @@ else:
|
|||||||
btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
|
btn.click(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
|
||||||
textbox.submit(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
|
textbox.submit(generate_reply, [textbox, temp_slider, length_slider, preset_menu, model_menu], [output_textbox, markdown, html], show_progress=True)
|
||||||
|
|
||||||
if args.listen:
|
if args.no_listen:
|
||||||
interface.launch(share=False, server_name="0.0.0.0")
|
|
||||||
else:
|
|
||||||
interface.launch(share=False)
|
interface.launch(share=False)
|
||||||
|
else:
|
||||||
|
interface.launch(share=False, server_name="0.0.0.0")
|
||||||
|
Loading…
Reference in New Issue
Block a user