mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
Merge branch 'main' into HideLord-main
This commit is contained in:
commit
693b53d957
53
.github/ISSUE_TEMPLATE/bug_report_template.yml
vendored
Normal file
53
.github/ISSUE_TEMPLATE/bug_report_template.yml
vendored
Normal file
@ -0,0 +1,53 @@
|
||||
name: "Bug report"
|
||||
description: Report a bug
|
||||
labels: [ "bug" ]
|
||||
body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: |
|
||||
Thanks for taking the time to fill out this bug report!
|
||||
- type: textarea
|
||||
id: bug-description
|
||||
attributes:
|
||||
label: Describe the bug
|
||||
description: A clear and concise description of what the bug is.
|
||||
placeholder: Bug description
|
||||
validations:
|
||||
required: true
|
||||
- type: checkboxes
|
||||
attributes:
|
||||
label: Is there an existing issue for this?
|
||||
description: Please search to see if an issue already exists for the issue you encountered.
|
||||
options:
|
||||
- label: I have searched the existing issues
|
||||
required: true
|
||||
- type: textarea
|
||||
id: reproduction
|
||||
attributes:
|
||||
label: Reproduction
|
||||
description: Please provide the steps necessary to reproduce your issue.
|
||||
placeholder: Reproduction
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: screenshot
|
||||
attributes:
|
||||
label: Screenshot
|
||||
description: "If possible, please include screenshot(s) so that we can understand what the issue is."
|
||||
- type: textarea
|
||||
id: logs
|
||||
attributes:
|
||||
label: Logs
|
||||
description: "Please include the full stacktrace of the errors you get in the command-line (if any)."
|
||||
render: shell
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
id: system-info
|
||||
attributes:
|
||||
label: System Info
|
||||
description: "Please share your system info with us: operating system, GPU brand, and GPU model. If you are using a Google Colab notebook, mention that instead."
|
||||
render: shell
|
||||
placeholder:
|
||||
validations:
|
||||
required: true
|
16
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
16
.github/ISSUE_TEMPLATE/feature_request.md
vendored
Normal file
@ -0,0 +1,16 @@
|
||||
---
|
||||
name: Feature request
|
||||
about: Suggest an improvement or new feature for the web UI
|
||||
title: ''
|
||||
labels: 'enhancement'
|
||||
assignees: ''
|
||||
|
||||
---
|
||||
|
||||
**Description**
|
||||
|
||||
A clear and concise description of what you want to be implemented.
|
||||
|
||||
**Additional Context**
|
||||
|
||||
If applicable, please provide any extra information, external links, or screenshots that could be useful.
|
11
.github/dependabot.yml
vendored
Normal file
11
.github/dependabot.yml
vendored
Normal file
@ -0,0 +1,11 @@
|
||||
# To get started with Dependabot version updates, you'll need to specify which
|
||||
# package ecosystems to update and where the package manifests are located.
|
||||
# Please see the documentation for all configuration options:
|
||||
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
|
||||
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: "pip" # See documentation for possible values
|
||||
directory: "/" # Location of package manifests
|
||||
schedule:
|
||||
interval: "weekly"
|
22
.github/workflows/stale.yml
vendored
Normal file
22
.github/workflows/stale.yml
vendored
Normal file
@ -0,0 +1,22 @@
|
||||
name: Close inactive issues
|
||||
on:
|
||||
schedule:
|
||||
- cron: "10 23 * * *"
|
||||
|
||||
jobs:
|
||||
close-issues:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/stale@v5
|
||||
with:
|
||||
stale-issue-message: ""
|
||||
close-issue-message: "This issue has been closed due to inactivity for 30 days. If you believe it is still relevant, you can reopen it (if you are the author) or leave a comment below."
|
||||
days-before-issue-stale: 30
|
||||
days-before-issue-close: 0
|
||||
stale-issue-label: "stale"
|
||||
days-before-pr-stale: -1
|
||||
days-before-pr-close: -1
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
@ -60,7 +60,9 @@ pip3 install torch torchvision torchaudio --extra-index-url https://download.pyt
|
||||
conda install pytorch torchvision torchaudio git -c pytorch
|
||||
```
|
||||
|
||||
See also: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
|
||||
> **Note**
|
||||
> 1. If you are on Windows, it may be easier to run the commands above in a WSL environment. The performance may also be better.
|
||||
> 2. For a more detailed, user-contributed guide, see: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
|
||||
|
||||
## Installation option 2: one-click installers
|
||||
|
||||
@ -140,8 +142,9 @@ Optionally, you can use the following command-line flags:
|
||||
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
|
||||
| `--cpu` | Use the CPU to generate text.|
|
||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||
| `--load-in-4bit` | Load the model with 4-bit precision. Currently only works with LLaMA.|
|
||||
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA. |
|
||||
| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
|
||||
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
|
||||
| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
|
||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
|
||||
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
|
||||
|
@ -26,6 +26,7 @@ async def run(context):
|
||||
'top_p': 0.9,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.05,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'top_k': 0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
@ -59,6 +60,7 @@ async def run(context):
|
||||
params['top_p'],
|
||||
params['typical_p'],
|
||||
params['repetition_penalty'],
|
||||
params['encoder_repetition_penalty'],
|
||||
params['top_k'],
|
||||
params['min_length'],
|
||||
params['no_repeat_ngram_size'],
|
||||
|
@ -24,6 +24,7 @@ params = {
|
||||
'top_p': 0.9,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.05,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'top_k': 0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
@ -45,6 +46,7 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={
|
||||
params['top_p'],
|
||||
params['typical_p'],
|
||||
params['repetition_penalty'],
|
||||
params['encoder_repetition_penalty'],
|
||||
params['top_k'],
|
||||
params['min_length'],
|
||||
params['no_repeat_ngram_size'],
|
||||
|
@ -76,7 +76,7 @@ def generate_html():
|
||||
return container_html
|
||||
|
||||
def ui():
|
||||
with gr.Accordion("Character gallery"):
|
||||
with gr.Accordion("Character gallery", open=False):
|
||||
update = gr.Button("Refresh")
|
||||
gallery = gr.HTML(value=generate_html())
|
||||
update.click(generate_html, [], gallery)
|
||||
|
@ -81,6 +81,7 @@ def input_modifier(string):
|
||||
if (shared.args.chat or shared.args.cai_chat) and len(shared.history['internal']) > 0:
|
||||
shared.history['visible'][-1] = [shared.history['visible'][-1][0], shared.history['visible'][-1][1].replace('controls autoplay>','controls>')]
|
||||
|
||||
shared.processing_message = "*Is recording a voice message...*"
|
||||
return string
|
||||
|
||||
def output_modifier(string):
|
||||
@ -119,6 +120,7 @@ def output_modifier(string):
|
||||
if params['show_text']:
|
||||
string += f'\n\n{original_string}'
|
||||
|
||||
shared.processing_message = "*Is typing...*"
|
||||
return string
|
||||
|
||||
def bot_prefix_modifier(string):
|
||||
|
@ -7,28 +7,40 @@ import torch
|
||||
import modules.shared as shared
|
||||
|
||||
sys.path.insert(0, str(Path("repositories/GPTQ-for-LLaMa")))
|
||||
from llama import load_quant
|
||||
import llama
|
||||
import opt
|
||||
|
||||
|
||||
# 4-bit LLaMA
|
||||
def load_quantized_LLaMA(model_name):
|
||||
if shared.args.load_in_4bit:
|
||||
bits = 4
|
||||
def load_quantized(model_name):
|
||||
if not shared.args.gptq_model_type:
|
||||
# Try to determine model type from model name
|
||||
model_type = model_name.split('-')[0].lower()
|
||||
if model_type not in ('llama', 'opt'):
|
||||
print("Can't determine model type from model name. Please specify it manually using --gptq-model-type "
|
||||
"argument")
|
||||
exit()
|
||||
else:
|
||||
bits = shared.args.gptq_bits
|
||||
model_type = shared.args.gptq_model_type.lower()
|
||||
|
||||
if model_type == 'llama':
|
||||
load_quant = llama.load_quant
|
||||
elif model_type == 'opt':
|
||||
load_quant = opt.load_quant
|
||||
else:
|
||||
print("Unknown pre-quantized model type specified. Only 'llama' and 'opt' are supported")
|
||||
exit()
|
||||
|
||||
path_to_model = Path(f'models/{model_name}')
|
||||
pt_model = ''
|
||||
if path_to_model.name.lower().startswith('llama-7b'):
|
||||
pt_model = f'llama-7b-{bits}bit.pt'
|
||||
pt_model = f'llama-7b-{shared.args.gptq_bits}bit.pt'
|
||||
elif path_to_model.name.lower().startswith('llama-13b'):
|
||||
pt_model = f'llama-13b-{bits}bit.pt'
|
||||
pt_model = f'llama-13b-{shared.args.gptq_bits}bit.pt'
|
||||
elif path_to_model.name.lower().startswith('llama-30b'):
|
||||
pt_model = f'llama-30b-{bits}bit.pt'
|
||||
pt_model = f'llama-30b-{shared.args.gptq_bits}bit.pt'
|
||||
elif path_to_model.name.lower().startswith('llama-65b'):
|
||||
pt_model = f'llama-65b-{bits}bit.pt'
|
||||
pt_model = f'llama-65b-{shared.args.gptq_bits}bit.pt'
|
||||
else:
|
||||
pt_model = f'{model_name}-{bits}bit.pt'
|
||||
pt_model = f'{model_name}-{shared.args.gptq_bits}bit.pt'
|
||||
|
||||
# Try to find the .pt both in models/ and in the subfolder
|
||||
pt_path = None
|
||||
@ -40,7 +52,7 @@ def load_quantized_LLaMA(model_name):
|
||||
print(f"Could not find {pt_model}, exiting...")
|
||||
exit()
|
||||
|
||||
model = load_quant(str(path_to_model), str(pt_path), bits)
|
||||
model = load_quant(str(path_to_model), str(pt_path), shared.args.gptq_bits)
|
||||
|
||||
# Multiple GPUs or GPU+CPU
|
||||
if shared.args.gpu_memory:
|
@ -97,7 +97,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
|
||||
def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||
shared.stop_everything = False
|
||||
just_started = True
|
||||
eos_token = '\n' if check else None
|
||||
@ -126,13 +126,14 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not regenerate:
|
||||
yield shared.history['visible']+[[visible_text, '*Is typing...*']]
|
||||
yield shared.history['visible']+[[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
reply = ''
|
||||
for i in range(chat_generation_attempts):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||
|
||||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
|
||||
@ -159,7 +160,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
||||
|
||||
yield shared.history['visible']
|
||||
|
||||
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
eos_token = '\n' if check else None
|
||||
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
@ -168,28 +169,29 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
|
||||
prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True)
|
||||
|
||||
reply = ''
|
||||
yield '*Is typing...*'
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
for i in range(chat_generation_attempts):
|
||||
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
|
||||
yield reply
|
||||
if next_character_found:
|
||||
break
|
||||
yield reply
|
||||
|
||||
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
|
||||
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
|
||||
yield generate_chat_html(_history, name1, name2, shared.character)
|
||||
|
||||
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
||||
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
else:
|
||||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
|
||||
yield generate_chat_output(shared.history['visible']+[[last_visible[0], '*Is typing...*']], name1, name2, shared.character)
|
||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
||||
# Yield '*Is typing...*'
|
||||
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
|
||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
||||
if shared.args.cai_chat:
|
||||
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
|
||||
else:
|
||||
|
@ -1,6 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
@ -35,6 +34,7 @@ if shared.args.deepspeed:
|
||||
ds_config = generate_ds_config(shared.args.bf16, 1 * world_size, shared.args.nvme_offload_dir)
|
||||
dschf = HfDeepSpeedConfig(ds_config) # Keep this object alive for the Transformers integration
|
||||
|
||||
|
||||
def load_model(model_name):
|
||||
print(f"Loading {model_name}...")
|
||||
t0 = time.time()
|
||||
@ -42,7 +42,7 @@ def load_model(model_name):
|
||||
shared.is_RWKV = model_name.lower().startswith('rwkv-')
|
||||
|
||||
# Default settings
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.load_in_4bit, shared.args.gptq_bits > 0, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
|
||||
if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
|
||||
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
|
||||
else:
|
||||
@ -87,11 +87,11 @@ def load_model(model_name):
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
# 4-bit LLaMA
|
||||
elif shared.args.gptq_bits > 0 or shared.args.load_in_4bit:
|
||||
from modules.quantized_LLaMA import load_quantized_LLaMA
|
||||
# Quantized model
|
||||
elif shared.args.gptq_bits > 0:
|
||||
from modules.GPTQ_loader import load_quantized
|
||||
|
||||
model = load_quantized_LLaMA(model_name)
|
||||
model = load_quantized(model_name)
|
||||
|
||||
# Custom
|
||||
else:
|
||||
|
@ -11,6 +11,7 @@ is_RWKV = False
|
||||
history = {'internal': [], 'visible': []}
|
||||
character = 'None'
|
||||
stop_everything = False
|
||||
processing_message = '*Is typing...*'
|
||||
|
||||
# UI elements (buttons, sliders, HTML, etc)
|
||||
gradio = {}
|
||||
@ -68,8 +69,9 @@ parser.add_argument('--chat', action='store_true', help='Launch the web UI in ch
|
||||
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
|
||||
parser.add_argument('--cpu', action='store_true', help='Use the CPU to generate text.')
|
||||
parser.add_argument('--load-in-8bit', action='store_true', help='Load the model with 8-bit precision.')
|
||||
parser.add_argument('--load-in-4bit', action='store_true', help='Load the model with 4-bit precision. Currently only works with LLaMA.')
|
||||
parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA.')
|
||||
parser.add_argument('--load-in-4bit', action='store_true', help='DEPRECATED: use --gptq-bits 4 instead.')
|
||||
parser.add_argument('--gptq-bits', type=int, default=0, help='Load a pre-quantized model with specified precision. 2, 3, 4 and 8bit are supported. Currently only works with LLaMA and OPT.')
|
||||
parser.add_argument('--gptq-model-type', type=str, help='Model type of pre-quantized model. Currently only LLaMa and OPT are supported.')
|
||||
parser.add_argument('--bf16', action='store_true', help='Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU.')
|
||||
parser.add_argument('--auto-devices', action='store_true', help='Automatically split the model across the available GPU(s) and CPU.')
|
||||
parser.add_argument('--disk', action='store_true', help='If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk.')
|
||||
@ -94,3 +96,8 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
|
||||
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
|
||||
parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# Provisional, this will be deleted later
|
||||
if args.load_in_4bit:
|
||||
print("Warning: --load-in-4bit is deprecated and will be removed. Use --gptq-bits 4 instead.\n")
|
||||
args.gptq_bits = 4
|
||||
|
@ -89,7 +89,7 @@ def clear_torch_cache():
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
||||
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
||||
clear_torch_cache()
|
||||
t0 = time.time()
|
||||
|
||||
@ -122,7 +122,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
input_ids = encode(question, max_new_tokens)
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()"
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
if eos_token is not None:
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
@ -132,45 +132,49 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
t = encode(stopping_string, 0, add_special_tokens=False)
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0])))
|
||||
|
||||
generate_params = {}
|
||||
if not shared.args.flexgen:
|
||||
generate_params = [
|
||||
f"max_new_tokens=max_new_tokens",
|
||||
f"eos_token_id={eos_token_ids}",
|
||||
f"stopping_criteria=stopping_criteria_list",
|
||||
f"do_sample={do_sample}",
|
||||
f"temperature={temperature}",
|
||||
f"top_p={top_p}",
|
||||
f"typical_p={typical_p}",
|
||||
f"repetition_penalty={repetition_penalty}",
|
||||
f"top_k={top_k}",
|
||||
f"min_length={min_length if shared.args.no_stream else 0}",
|
||||
f"no_repeat_ngram_size={no_repeat_ngram_size}",
|
||||
f"num_beams={num_beams}",
|
||||
f"penalty_alpha={penalty_alpha}",
|
||||
f"length_penalty={length_penalty}",
|
||||
f"early_stopping={early_stopping}",
|
||||
]
|
||||
generate_params.update({
|
||||
"max_new_tokens": max_new_tokens,
|
||||
"eos_token_id": eos_token_ids,
|
||||
"stopping_criteria": stopping_criteria_list,
|
||||
"do_sample": do_sample,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"typical_p": typical_p,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"encoder_repetition_penalty": encoder_repetition_penalty,
|
||||
"top_k": top_k,
|
||||
"min_length": min_length if shared.args.no_stream else 0,
|
||||
"no_repeat_ngram_size": no_repeat_ngram_size,
|
||||
"num_beams": num_beams,
|
||||
"penalty_alpha": penalty_alpha,
|
||||
"length_penalty": length_penalty,
|
||||
"early_stopping": early_stopping,
|
||||
})
|
||||
else:
|
||||
generate_params = [
|
||||
f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}",
|
||||
f"do_sample={do_sample}",
|
||||
f"temperature={temperature}",
|
||||
f"stop={eos_token_ids[-1]}",
|
||||
]
|
||||
generate_params.update({
|
||||
"max_new_tokens": max_new_tokens if shared.args.no_stream else 8,
|
||||
"do_sample": do_sample,
|
||||
"temperature": temperature,
|
||||
"stop": eos_token_ids[-1],
|
||||
})
|
||||
if shared.args.deepspeed:
|
||||
generate_params.append("synced_gpus=True")
|
||||
generate_params.update({"synced_gpus": True})
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
generate_params.insert(0, "inputs_embeds=inputs_embeds")
|
||||
generate_params.insert(0, "inputs=filler_input_ids")
|
||||
generate_params.update({"inputs_embeds": inputs_embeds})
|
||||
generate_params.update({"inputs": filler_input_ids})
|
||||
else:
|
||||
generate_params.insert(0, "inputs=input_ids")
|
||||
generate_params.update({"inputs": input_ids})
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
if shared.args.no_stream:
|
||||
with torch.no_grad():
|
||||
output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0]
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
if cuda:
|
||||
output = output.cuda()
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
|
||||
@ -194,7 +198,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator:
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
for output in generator:
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
@ -214,7 +218,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
for i in range(max_new_tokens//8+1):
|
||||
clear_torch_cache()
|
||||
with torch.no_grad():
|
||||
output = eval(f"shared.model.generate({', '.join(generate_params)})")[0]
|
||||
output = shared.model.generate(**generate_params)[0]
|
||||
if shared.soft_prompt:
|
||||
output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:]))
|
||||
reply = decode(output)
|
||||
|
@ -38,6 +38,9 @@ svg {
|
||||
ol li p, ul li p {
|
||||
display: inline-block;
|
||||
}
|
||||
#main, #settings, #extensions, #chat-settings {
|
||||
border: 0;
|
||||
}
|
||||
"""
|
||||
|
||||
chat_css = """
|
||||
@ -64,6 +67,12 @@ div.svelte-362y77>*, div.svelte-362y77>.form>* {
|
||||
}
|
||||
"""
|
||||
|
||||
page_js = """
|
||||
document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px"
|
||||
document.getElementById("main").parentNode.style = "padding: 0; margin: 0"
|
||||
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0"
|
||||
"""
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
||||
|
@ -1,10 +1,10 @@
|
||||
accelerate==0.17.0
|
||||
bitsandbytes==0.37.0
|
||||
accelerate==0.17.1
|
||||
bitsandbytes==0.37.1
|
||||
flexgen==0.1.7
|
||||
gradio==3.18.0
|
||||
numpy
|
||||
requests
|
||||
rwkv==0.3.1
|
||||
rwkv==0.4.2
|
||||
safetensors==0.3.0
|
||||
sentencepiece
|
||||
tqdm
|
||||
|
89
server.py
89
server.py
@ -66,6 +66,7 @@ def load_preset_values(preset_menu, return_dict=False):
|
||||
'top_p': 1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 50,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
@ -86,7 +87,7 @@ def load_preset_values(preset_menu, return_dict=False):
|
||||
if return_dict:
|
||||
return generate_params
|
||||
else:
|
||||
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
|
||||
def upload_soft_prompt(file):
|
||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||
@ -100,9 +101,7 @@ def upload_soft_prompt(file):
|
||||
|
||||
return name
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
|
||||
|
||||
def create_model_and_preset_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
@ -113,22 +112,29 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
||||
|
||||
with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
|
||||
|
||||
with gr.Box():
|
||||
gr.Markdown('Custom generation parameters')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
|
||||
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
|
||||
with gr.Column():
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
|
||||
with gr.Column():
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
|
||||
with gr.Box():
|
||||
gr.Markdown('Contrastive search:')
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||
|
||||
with gr.Box():
|
||||
gr.Markdown('Beam search (uses a lot of VRAM):')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -137,7 +143,8 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
|
||||
with gr.Accordion('Soft prompt', open=False, elem_id='accordion'):
|
||||
with gr.Box():
|
||||
gr.Markdown('Soft prompt')
|
||||
with gr.Row():
|
||||
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
|
||||
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
|
||||
@ -147,7 +154,7 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['encoder_repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
|
||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
|
||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
|
||||
|
||||
@ -200,6 +207,7 @@ suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
|
||||
|
||||
if shared.args.chat or shared.args.cai_chat:
|
||||
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
if shared.args.cai_chat:
|
||||
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
|
||||
else:
|
||||
@ -219,7 +227,21 @@ if shared.args.chat or shared.args.cai_chat:
|
||||
shared.gradio['Clear history'] = gr.Button('Clear history')
|
||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
with gr.Tab('Chat settings'):
|
||||
|
||||
create_model_and_preset_menus()
|
||||
|
||||
with gr.Box():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||
with gr.Column():
|
||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
with gr.Tab("Chat settings", elem_id="chat-settings"):
|
||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
|
||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
|
||||
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
|
||||
@ -253,23 +275,13 @@ if shared.args.chat or shared.args.cai_chat:
|
||||
with gr.Tab('Upload TavernAI Character Card'):
|
||||
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
|
||||
|
||||
with gr.Tab('Generation settings'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
with gr.Column():
|
||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||
with gr.Tab("Settings", elem_id="settings"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
|
||||
if shared.args.extensions is not None:
|
||||
with gr.Tab('Extensions'):
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
@ -308,36 +320,42 @@ if shared.args.chat or shared.args.cai_chat:
|
||||
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
|
||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
|
||||
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
gr.Markdown(description)
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23)
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
with gr.Row():
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
|
||||
create_settings_menus(default_preset)
|
||||
create_model_and_preset_menus()
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
with gr.Tab("Settings", elem_id="settings"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
|
||||
|
||||
else:
|
||||
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
gr.Markdown(description)
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
|
||||
@ -349,24 +367,27 @@ else:
|
||||
with gr.Column():
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
|
||||
create_settings_menus(default_preset)
|
||||
create_model_and_preset_menus()
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
with gr.Column():
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output')
|
||||
shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output')
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
with gr.Tab("Settings", elem_id="settings"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.page_js}}}")
|
||||
|
||||
shared.gradio['interface'].queue()
|
||||
if shared.args.listen:
|
||||
|
Loading…
Reference in New Issue
Block a user