Merge branch 'main' into Brawlence-main

This commit is contained in:
oobabooga 2023-03-19 13:09:59 -03:00
commit eab8de0d4a
34 changed files with 1038 additions and 645 deletions

4
.gitignore vendored
View File

@ -4,12 +4,15 @@ extensions/silero_tts/outputs/*
extensions/elevenlabs_tts/outputs/*
extensions/sd_api_pictures/outputs/*
logs/*
loras/*
models/*
softprompts/*
torch-dumps/*
*pycache*
*/*pycache*
*/*/pycache*
venv/
.venv/
settings.json
img_bot*
@ -17,6 +20,7 @@ img_me*
!characters/Example.json
!characters/Example.png
!loras/place-your-loras-here.txt
!models/place-your-models-here.txt
!softprompts/place-your-softprompts-here.txt
!torch-dumps/place-your-pt-models-here.txt

152
README.md
View File

@ -19,52 +19,76 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
* Generate Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX support.
* Support for [Pygmalion](https://huggingface.co/models?search=pygmalionai/pygmalion) and custom characters in JSON or TavernAI Character Card formats ([FAQ](https://github.com/oobabooga/text-generation-webui/wiki/Pygmalion-chat-model-FAQ)).
* Advanced chat features (send images, get audio responses with TTS).
* Stream the text output in real time.
* Stream the text output in real time very efficiently.
* Load parameter presets from text files.
* Load large models in 8-bit mode (see [here](https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134), [here](https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652) and [here](https://www.reddit.com/r/PygmalionAI/comments/1115gom/running_pygmalion_6b_with_8gb_of_vram/) if you are on Windows).
* Load large models in 8-bit mode.
* Split large models across your GPU(s), CPU, and disk.
* CPU mode.
* [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
* [LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
* [Supports LoRAs](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs).
* Supports softprompts.
* [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
* [Works on Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab).
## Installation option 1: conda
## Installation
Open a terminal and copy and paste these commands one at a time ([install conda](https://docs.conda.io/en/latest/miniconda.html) first if you don't have it already):
The recommended installation methods are the following:
* Linux and MacOS: using conda natively.
* Windows: using conda on WSL ([WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide)).
Conda can be downloaded here: https://docs.conda.io/en/latest/miniconda.html
On Linux or WSL, it can be automatically installed with these two commands:
```
conda create -n textgen
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
bash Miniconda3.sh
```
Source: https://educe-ubc.github.io/conda.html
#### 1. Create a new conda environment
```
conda create -n textgen python=3.10.9
conda activate textgen
conda install torchvision torchaudio pytorch-cuda=11.7 git -c pytorch -c nvidia
```
#### 2. Install Pytorch
| System | GPU | Command |
|--------|---------|---------|
| Linux/WSL | NVIDIA | `pip3 install torch torchvision torchaudio` |
| Linux | AMD | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2` |
| MacOS + MPS (untested) | Any | `pip3 install torch torchvision torchaudio` |
The up to date commands can be found here: https://pytorch.org/get-started/locally/.
MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
#### 3. Install the web UI
```
git clone https://github.com/oobabooga/text-generation-webui
cd text-generation-webui
pip install -r requirements.txt
```
The third line assumes that you have an NVIDIA GPU.
* If you have an AMD GPU, replace the third command with this one:
```
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
```
* If you are running it in CPU mode, replace the third command with this one:
```
conda install pytorch torchvision torchaudio git -c pytorch
```
> **Note**
> 1. If you are on Windows, it may be easier to run the commands above in a WSL environment. The performance may also be better.
> 2. For a more detailed, user-contributed guide, see: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
>
> For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859
## Installation option 2: one-click installers
### Alternative: native Windows installation
As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
### Alternative: one-click installers
[oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip)
@ -75,19 +99,25 @@ Just download the zip above, extract it, and double click on "install". The web
* To download a model, double click on "download-model"
* To start the web UI, double click on "start-webui"
Source codes: https://github.com/oobabooga/one-click-installers
This method lags behind the newest developments and does not support 8-bit mode on Windows without additional set up: https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134, https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652
### Alternative: Docker
https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
## Downloading models
Models should be placed under `models/model-name`. For instance, `models/gpt-j-6B` for [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main).
#### Hugging Face
Models should be placed inside the `models` folder.
[Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) is the main place to download models. These are some noteworthy examples:
* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main)
* [GPT-Neo](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=eleutherai+%2F+gpt-neo)
* [Pythia](https://huggingface.co/models?search=eleutherai/pythia)
* [OPT](https://huggingface.co/models?search=facebook/opt)
* [GALACTICA](https://huggingface.co/models?search=facebook/galactica)
* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main)
* [GPT-Neo](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=eleutherai+%2F+gpt-neo)
* [\*-Erebus](https://huggingface.co/models?search=erebus) (NSFW)
* [Pygmalion](https://huggingface.co/models?search=pygmalion) (NSFW)
@ -101,7 +131,7 @@ For instance:
If you want to download a model manually, note that all you need are the json, txt, and pytorch\*.bin (or model*.safetensors) files. The remaining files are not necessary.
#### GPT-4chan
### GPT-4chan
[GPT-4chan](https://huggingface.co/ykilcher/gpt-4chan) has been shut down from Hugging Face, so you need to download it elsewhere. You have two options:
@ -123,6 +153,7 @@ python download-model.py EleutherAI/gpt-j-6B --text-only
## Starting the web UI
conda activate textgen
cd text-generation-webui
python server.py
Then browse to
@ -133,41 +164,42 @@ Then browse to
Optionally, you can use the following command-line flags:
| Flag | Description |
|-------------|-------------|
| `-h`, `--help` | show this help message and exit |
| `--model MODEL` | Name of the model to load by default. |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode.|
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
| `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| Flag | Description |
|------------------|-------------|
| `-h`, `--help` | show this help message and exit |
| `--model MODEL` | Name of the model to load by default. |
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode.|
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
| `--cpu` | Use the CPU to generate text.|
| `--load-in-8bit` | Load the model with 8-bit precision.|
| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
| `--flexgen` | Enable the use of FlexGen offloading. |
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
| `--flexgen` | Enable the use of FlexGen offloading. |
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
| `--no-stream` | Don't stream the text output in real time. |
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
| `--no-stream` | Don't stream the text output in real time. |
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
| `--listen` | Make the web UI reachable from your local network.|
| `--listen` | Make the web UI reachable from your local network.|
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--verbose` | Print the prompts to the terminal. |
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--verbose` | Print the prompts to the terminal. |
Out of memory errors? [Check this guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
@ -192,7 +224,7 @@ Before reporting a bug, make sure that you have:
## Credits
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui
- Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui
- Verbose preset: Anonymous 4chan user.
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
- Pygmalion preset, code for early stopping in chat mode, code for some of the sliders, --chat mode colors: https://github.com/PygmalionAI/gradio-ui/

View File

@ -26,6 +26,7 @@ async def run(context):
'top_p': 0.9,
'typical_p': 1,
'repetition_penalty': 1.05,
'encoder_repetition_penalty': 1.0,
'top_k': 0,
'min_length': 0,
'no_repeat_ngram_size': 0,
@ -43,14 +44,14 @@ async def run(context):
case "send_hash":
await websocket.send(json.dumps({
"session_hash": session,
"fn_index": 7
"fn_index": 9
}))
case "estimation":
pass
case "send_data":
await websocket.send(json.dumps({
"session_hash": session,
"fn_index": 7,
"fn_index": 9,
"data": [
context,
params['max_new_tokens'],
@ -59,6 +60,7 @@ async def run(context):
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],

View File

@ -24,6 +24,7 @@ params = {
'top_p': 0.9,
'typical_p': 1,
'repetition_penalty': 1.05,
'encoder_repetition_penalty': 1.0,
'top_k': 0,
'min_length': 0,
'no_repeat_ngram_size': 0,
@ -45,6 +46,7 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={
params['top_p'],
params['typical_p'],
params['repetition_penalty'],
params['encoder_repetition_penalty'],
params['top_k'],
params['min_length'],
params['no_repeat_ngram_size'],

25
css/chat.css Normal file
View File

@ -0,0 +1,25 @@
.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
height: 66.67vh
}
.gradio-container {
margin-left: auto !important;
margin-right: auto !important;
}
.w-screen {
width: unset
}
div.svelte-362y77>*, div.svelte-362y77>.form>* {
flex-wrap: nowrap
}
/* fixes the API documentation in chat mode */
.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h {
display: grid;
}
.pending.svelte-1ed2p3z {
opacity: 1;
}

4
css/chat.js Normal file
View File

@ -0,0 +1,4 @@
document.getElementById("main").childNodes[0].style = "max-width: 800px; margin-left: auto; margin-right: auto";
document.getElementById("extensions").style.setProperty("max-width", "800px");
document.getElementById("extensions").style.setProperty("margin-left", "auto");
document.getElementById("extensions").style.setProperty("margin-right", "auto");

103
css/html_4chan_style.css Normal file
View File

@ -0,0 +1,103 @@
#parent #container {
background-color: #eef2ff;
padding: 17px;
}
#parent #container .reply {
background-color: rgb(214, 218, 240);
border-bottom-color: rgb(183, 197, 217);
border-bottom-style: solid;
border-bottom-width: 1px;
border-image-outset: 0;
border-image-repeat: stretch;
border-image-slice: 100%;
border-image-source: none;
border-image-width: 1;
border-left-color: rgb(0, 0, 0);
border-left-style: none;
border-left-width: 0px;
border-right-color: rgb(183, 197, 217);
border-right-style: solid;
border-right-width: 1px;
border-top-color: rgb(0, 0, 0);
border-top-style: none;
border-top-width: 0px;
color: rgb(0, 0, 0);
display: table;
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
margin-bottom: 4px;
margin-left: 0px;
margin-right: 0px;
margin-top: 4px;
overflow-x: hidden;
overflow-y: hidden;
padding-bottom: 4px;
padding-left: 2px;
padding-right: 2px;
padding-top: 4px;
}
#parent #container .number {
color: rgb(0, 0, 0);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
width: 342.65px;
margin-right: 7px;
}
#parent #container .op {
color: rgb(0, 0, 0);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
margin-bottom: 8px;
margin-left: 0px;
margin-right: 0px;
margin-top: 4px;
overflow-x: hidden;
overflow-y: hidden;
}
#parent #container .op blockquote {
margin-left: 0px !important;
}
#parent #container .name {
color: rgb(17, 119, 67);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
font-weight: 700;
margin-left: 7px;
}
#parent #container .quote {
color: rgb(221, 0, 0);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
text-decoration-color: rgb(221, 0, 0);
text-decoration-line: underline;
text-decoration-style: solid;
text-decoration-thickness: auto;
}
#parent #container .greentext {
color: rgb(120, 153, 34);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
}
#parent #container blockquote {
margin: 0px !important;
margin-block-start: 1em;
margin-block-end: 1em;
margin-inline-start: 40px;
margin-inline-end: 40px;
margin-top: 13.33px !important;
margin-bottom: 13.33px !important;
margin-left: 40px !important;
margin-right: 40px !important;
}
#parent #container .message {
color: black;
border: none;
}

73
css/html_cai_style.css Normal file
View File

@ -0,0 +1,73 @@
.chat {
margin-left: auto;
margin-right: auto;
max-width: 800px;
height: 66.67vh;
overflow-y: auto;
padding-right: 20px;
display: flex;
flex-direction: column-reverse;
}
.message {
display: grid;
grid-template-columns: 60px 1fr;
padding-bottom: 25px;
font-size: 15px;
font-family: Helvetica, Arial, sans-serif;
line-height: 1.428571429;
}
.circle-you {
width: 50px;
height: 50px;
background-color: rgb(238, 78, 59);
border-radius: 50%;
}
.circle-bot {
width: 50px;
height: 50px;
background-color: rgb(59, 78, 244);
border-radius: 50%;
}
.circle-bot img,
.circle-you img {
border-radius: 50%;
width: 100%;
height: 100%;
object-fit: cover;
}
.text {}
.text p {
margin-top: 5px;
}
.username {
font-weight: bold;
}
.message-body {}
.message-body img {
max-width: 300px;
max-height: 300px;
border-radius: 20px;
}
.message-body p {
margin-bottom: 0 !important;
font-size: 15px !important;
line-height: 1.428571429 !important;
}
.dark .message-body p em {
color: rgb(138, 138, 138) !important;
}
.message-body p em {
color: rgb(110, 110, 110) !important;
}

View File

@ -0,0 +1,14 @@
.container {
max-width: 600px;
margin-left: auto;
margin-right: auto;
background-color: rgb(31, 41, 55);
padding:3em;
}
.container p {
font-size: 16px !important;
color: white !important;
margin-bottom: 22px;
line-height: 1.4 !important;
}

52
css/main.css Normal file
View File

@ -0,0 +1,52 @@
.tabs.svelte-710i53 {
margin-top: 0
}
.py-6 {
padding-top: 2.5rem
}
.dark #refresh-button {
background-color: #ffffff1f;
}
#refresh-button {
flex: none;
margin: 0;
padding: 0;
min-width: 50px;
border: none;
box-shadow: none;
border-radius: 10px;
background-color: #0000000d;
}
#download-label, #upload-label {
min-height: 0
}
#accordion {
}
.dark svg {
fill: white;
}
.dark a {
color: white !important;
text-decoration: none !important;
}
svg {
display: unset !important;
vertical-align: middle !important;
margin: 5px;
}
ol li p, ul li p {
display: inline-block;
}
#main, #parameters, #chat-settings, #interface-mode, #lora {
border: 0;
}

18
css/main.js Normal file
View File

@ -0,0 +1,18 @@
document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px";
document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
// Get references to the elements
let main = document.getElementById('main');
let main_parent = main.parentNode;
let extensions = document.getElementById('extensions');
// Add an event listener to the main element
main_parent.addEventListener('click', function(e) {
// Check if the main element is visible
if (main.offsetHeight > 0 && main.offsetWidth > 0) {
extensions.style.display = 'block';
} else {
extensions.style.display = 'none';
}
});

View File

@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
classifications = []
has_pytorch = False
has_safetensors = False
is_lora = False
while True:
content = requests.get(f"{base}{page}{cursor.decode()}").content
@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):
for i in range(len(dict)):
fname = dict[i]['path']
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
is_lora = True
is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname)
is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True
classifications.append('pytorch')
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')
@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
if classifications[i] == 'pytorch':
links.pop(i)
return links
return links, is_lora
if __name__ == '__main__':
model = args.MODEL
@ -159,15 +163,16 @@ if __name__ == '__main__':
except ValueError as err_branch:
print(f"Error: {err_branch}")
sys.exit()
links, is_lora = get_download_links_from_huggingface(model, branch)
base_folder = 'models' if not is_lora else 'loras'
if branch != 'main':
output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
else:
output_folder = Path("models") / model.split('/')[-1]
output_folder = Path(base_folder) / model.split('/')[-1]
if not output_folder.exists():
output_folder.mkdir()
links = get_download_links_from_huggingface(model, branch)
# Downloading the files
print(f"Downloading the model to {output_folder}")
pool = multiprocessing.Pool(processes=args.threads)

View File

@ -0,0 +1 @@
flask_cloudflared==0.0.12

90
extensions/api/script.py Normal file
View File

@ -0,0 +1,90 @@
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from modules import shared
from modules.text_generation import generate_reply, encode
import json
params = {
'port': 5000,
}
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == '/api/v1/model':
self.send_response(200)
self.end_headers()
response = json.dumps({
'result': shared.model_name
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def do_POST(self):
content_length = int(self.headers['Content-Length'])
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
if self.path == '/api/v1/generate':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
prompt = body['prompt']
prompt_lines = [l.strip() for l in prompt.split('\n')]
max_context = body.get('max_context_length', 2048)
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)
generator = generate_reply(
question = prompt,
max_new_tokens = body.get('max_length', 200),
do_sample=True,
temperature=body.get('temperature', 0.5),
top_p=body.get('top_p', 1),
typical_p=body.get('typical', 1),
repetition_penalty=body.get('rep_pen', 1.1),
encoder_repetition_penalty=1,
top_k=body.get('top_k', 0),
min_length=0,
no_repeat_ngram_size=0,
num_beams=1,
penalty_alpha=0,
length_penalty=1,
early_stopping=False,
)
answer = ''
for a in generator:
answer = a[0]
response = json.dumps({
'results': [{
'text': answer[len(prompt):]
}]
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def run_server():
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
server = ThreadingHTTPServer(server_addr, Handler)
if shared.args.share:
try:
from flask_cloudflared import _run_cloudflared
public_url = _run_cloudflared(params['port'], params['port'] + 1)
print(f'Starting KoboldAI compatible api at {public_url}/api')
except ImportError:
print('You should install flask_cloudflared manually')
else:
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
server.serve_forever()
def ui():
Thread(target=run_server, daemon=True).start()

View File

@ -1,8 +1,8 @@
from pathlib import Path
import gradio as gr
from elevenlabslib import *
from elevenlabslib.helpers import *
from elevenlabslib import ElevenLabsUser
from elevenlabslib.helpers import save_bytes_to_path
params = {
'activate': True,

View File

@ -76,7 +76,7 @@ def generate_html():
return container_html
def ui():
with gr.Accordion("Character gallery"):
with gr.Accordion("Character gallery", open=False):
update = gr.Button("Refresh")
gallery = gr.HTML(value=generate_html())
update.click(generate_html, [], gallery)

View File

@ -0,0 +1,4 @@
git+https://github.com/Uberi/speech_recognition.git@010382b
openai-whisper
soundfile
ffmpeg

View File

@ -0,0 +1,54 @@
import gradio as gr
import speech_recognition as sr
input_hijack = {
'state': False,
'value': ["", ""]
}
def do_stt(audio, text_state=""):
transcription = ""
r = sr.Recognizer()
# Convert to AudioData
audio_data = sr.AudioData(sample_rate=audio[0], frame_data=audio[1], sample_width=4)
try:
transcription = r.recognize_whisper(audio_data, language="english", model="base.en")
except sr.UnknownValueError:
print("Whisper could not understand audio")
except sr.RequestError as e:
print("Could not request results from Whisper", e)
input_hijack.update({"state": True, "value": [transcription, transcription]})
text_state += transcription + " "
return text_state, text_state
def update_hijack(val):
input_hijack.update({"state": True, "value": [val, val]})
return val
def auto_transcribe(audio, audio_auto, text_state=""):
if audio is None:
return "", ""
if audio_auto:
return do_stt(audio, text_state)
return "", ""
def ui():
tr_state = gr.State(value="")
output_transcription = gr.Textbox(label="STT-Input",
placeholder="Speech Preview. Click \"Generate\" to send",
interactive=True)
output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state])
audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True)
with gr.Row():
audio = gr.Audio(source="microphone")
audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state])
transcribe_button = gr.Button(value="Transcribe")
transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state])

View File

View File

@ -61,7 +61,7 @@ def load_quantized(model_name):
max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
model = accelerate.dispatch_model(model, device_map=device_map)
# Single GPU

22
modules/LoRA.py Normal file
View File

@ -0,0 +1,22 @@
from pathlib import Path
import modules.shared as shared
from modules.models import load_model
def add_lora_to_model(lora_name):
from peft import PeftModel
# Is there a more efficient way of returning to the base model?
if lora_name == "None":
print("Reloading the model to remove the LoRA...")
shared.model, shared.tokenizer = load_model(shared.model_name)
else:
# Why doesn't this work in 16-bit mode?
print(f"Adding the LoRA {lora_name} to the model...")
params = {}
params['device_map'] = {'': 0}
#params['dtype'] = shared.model.dtype
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)

View File

@ -7,6 +7,7 @@ import transformers
import modules.shared as shared
# Copied from https://github.com/PygmalionAI/gradio-ui/
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):

View File

@ -11,17 +11,11 @@ from PIL import Image
import modules.extensions as extensions_module
import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import generate_chat_html
from modules.text_generation import encode, generate_reply, get_max_prompt_length
from modules.html_generator import fix_newlines, generate_chat_html
from modules.text_generation import (encode, generate_reply,
get_max_prompt_length)
# This gets the new line characters right.
def clean_chat_message(text):
text = text.replace('\n', '\n\n')
text = re.sub(r"\n{3,}", "\n\n", text)
text = text.strip()
return text
def generate_chat_output(history, name1, name2, character):
if shared.args.cai_chat:
return generate_chat_html(history, name1, name2, character)
@ -29,7 +23,7 @@ def generate_chat_output(history, name1, name2, character):
return history
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
user_input = clean_chat_message(user_input)
user_input = fix_newlines(user_input)
rows = [f"{context.strip()}\n"]
if shared.soft_prompt:
@ -82,7 +76,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
if idx != -1:
reply = reply[:idx]
next_character_found = True
reply = clean_chat_message(reply)
reply = fix_newlines(reply)
# If something like "\nYo" is generated just before "\nYou:"
# is completed, trim it
@ -97,7 +91,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
def stop_everything_event():
shared.stop_everything = True
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
shared.stop_everything = False
just_started = True
eos_token = '\n' if check else None
@ -133,7 +127,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
# Generate
reply = ''
for i in range(chat_generation_attempts):
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
# Extracting the reply
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
@ -160,7 +154,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
yield shared.history['visible']
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
eos_token = '\n' if check else None
if 'pygmalion' in shared.model_name.lower():
@ -172,18 +166,18 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
# Yield *Is typing...*
yield shared.processing_message
for i in range(chat_generation_attempts):
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
yield reply
if next_character_found:
break
yield reply
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
yield generate_chat_html(_history, name1, name2, shared.character)
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
else:
@ -191,7 +185,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
last_internal = shared.history['internal'].pop()
# Yield '*Is typing...*'
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
if shared.args.cai_chat:
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
else:

View File

@ -1,3 +1,5 @@
import gradio as gr
import extensions
import modules.shared as shared
@ -9,9 +11,12 @@ def load_extensions():
for i, name in enumerate(shared.args.extensions):
if name in available_extensions:
print(f'Loading the extension "{name}"... ', end='')
exec(f"import extensions.{name}.script")
state[name] = [True, i]
print('Ok.')
try:
exec(f"import extensions.{name}.script")
state[name] = [True, i]
print('Ok.')
except:
print('Fail.')
# This iterator returns the extensions in the order specified in the command-line
def iterator():
@ -40,6 +45,9 @@ def create_extensions_block():
extension.params[param] = shared.settings[_id]
# Creating the extension ui elements
for extension, name in iterator():
if hasattr(extension, "ui"):
extension.ui()
if len(state) > 0:
with gr.Box(elem_id="extensions"):
gr.Markdown("Extensions")
for extension, name in iterator():
if hasattr(extension, "ui"):
extension.ui()

View File

@ -1,6 +1,6 @@
'''
This is a library for formatting GPT-4chan and chat outputs as nice HTML.
This is a library for formatting text outputs as nice HTML.
'''
@ -8,30 +8,39 @@ import os
import re
from pathlib import Path
import markdown
from PIL import Image
# This is to store the paths to the thumbnails of the profile pictures
image_cache = {}
def generate_basic_html(s):
css = """
.container {
max-width: 600px;
margin-left: auto;
margin-right: auto;
background-color: rgb(31, 41, 55);
padding:3em;
}
.container p {
font-size: 16px !important;
color: white !important;
margin-bottom: 22px;
line-height: 1.4 !important;
}
"""
s = '\n'.join([f'<p>{line}</p>' for line in s.split('\n')])
s = f'<style>{css}</style><div class="container">{s}</div>'
return s
with open(Path(__file__).resolve().parent / '../css/html_readable_style.css', 'r') as f:
readable_css = f.read()
with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') as css_f:
_4chan_css = css_f.read()
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
cai_css = f.read()
def fix_newlines(string):
string = string.replace('\n', '\n\n')
string = re.sub(r"\n{3,}", "\n\n", string)
string = string.strip()
return string
# This could probably be generalized and improved
def convert_to_markdown(string):
string = string.replace('\\begin{code}', '```')
string = string.replace('\\end{code}', '```')
string = string.replace('\\begin{blockquote}', '> ')
string = string.replace('\\end{blockquote}', '')
string = re.sub(r"(.)```", r"\1\n```", string)
# string = fix_newlines(string)
return markdown.markdown(string, extensions=['fenced_code'])
def generate_basic_html(string):
string = convert_to_markdown(string)
string = f'<style>{readable_css}</style><div class="container">{string}</div>'
return string
def process_post(post, c):
t = post.split('\n')
@ -48,113 +57,6 @@ def process_post(post, c):
return src
def generate_4chan_html(f):
css = """
#parent #container {
background-color: #eef2ff;
padding: 17px;
}
#parent #container .reply {
background-color: rgb(214, 218, 240);
border-bottom-color: rgb(183, 197, 217);
border-bottom-style: solid;
border-bottom-width: 1px;
border-image-outset: 0;
border-image-repeat: stretch;
border-image-slice: 100%;
border-image-source: none;
border-image-width: 1;
border-left-color: rgb(0, 0, 0);
border-left-style: none;
border-left-width: 0px;
border-right-color: rgb(183, 197, 217);
border-right-style: solid;
border-right-width: 1px;
border-top-color: rgb(0, 0, 0);
border-top-style: none;
border-top-width: 0px;
color: rgb(0, 0, 0);
display: table;
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
margin-bottom: 4px;
margin-left: 0px;
margin-right: 0px;
margin-top: 4px;
overflow-x: hidden;
overflow-y: hidden;
padding-bottom: 4px;
padding-left: 2px;
padding-right: 2px;
padding-top: 4px;
}
#parent #container .number {
color: rgb(0, 0, 0);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
width: 342.65px;
margin-right: 7px;
}
#parent #container .op {
color: rgb(0, 0, 0);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
margin-bottom: 8px;
margin-left: 0px;
margin-right: 0px;
margin-top: 4px;
overflow-x: hidden;
overflow-y: hidden;
}
#parent #container .op blockquote {
margin-left: 0px !important;
}
#parent #container .name {
color: rgb(17, 119, 67);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
font-weight: 700;
margin-left: 7px;
}
#parent #container .quote {
color: rgb(221, 0, 0);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
text-decoration-color: rgb(221, 0, 0);
text-decoration-line: underline;
text-decoration-style: solid;
text-decoration-thickness: auto;
}
#parent #container .greentext {
color: rgb(120, 153, 34);
font-family: arial, helvetica, sans-serif;
font-size: 13.3333px;
}
#parent #container blockquote {
margin: 0px !important;
margin-block-start: 1em;
margin-block-end: 1em;
margin-inline-start: 40px;
margin-inline-end: 40px;
margin-top: 13.33px !important;
margin-bottom: 13.33px !important;
margin-left: 40px !important;
margin-right: 40px !important;
}
#parent #container .message {
color: black;
border: none;
}
"""
posts = []
post = ''
c = -2
@ -181,7 +83,7 @@ def generate_4chan_html(f):
posts[i] = f'<div class="reply">{posts[i]}</div>\n'
output = ''
output += f'<style>{css}</style><div id="parent"><div id="container">'
output += f'<style>{_4chan_css}</style><div id="parent"><div id="container">'
for post in posts:
output += post
output += '</div></div>'
@ -208,135 +110,39 @@ def get_image_cache(path):
return image_cache[path][1]
def load_html_image(paths):
for str_path in paths:
path = Path(str_path)
if path.exists():
return f'<img src="file/{get_image_cache(path)}">'
return ''
def generate_chat_html(history, name1, name2, character):
css = """
.chat {
margin-left: auto;
margin-right: auto;
max-width: 800px;
height: 66.67vh;
overflow-y: auto;
padding-right: 20px;
display: flex;
flex-direction: column-reverse;
}
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
.message {
display: grid;
grid-template-columns: 60px 1fr;
padding-bottom: 25px;
font-size: 15px;
font-family: Helvetica, Arial, sans-serif;
line-height: 1.428571429;
}
.circle-you {
width: 50px;
height: 50px;
background-color: rgb(238, 78, 59);
border-radius: 50%;
}
.circle-bot {
width: 50px;
height: 50px;
background-color: rgb(59, 78, 244);
border-radius: 50%;
}
.circle-bot img, .circle-you img {
border-radius: 50%;
width: 100%;
height: 100%;
object-fit: cover;
}
.text {
}
.text p {
margin-top: 5px;
}
.username {
font-weight: bold;
}
.message-body {
}
.message-body img {
max-width: 300px;
max-height: 300px;
border-radius: 20px;
}
.message-body p {
margin-bottom: 0 !important;
font-size: 15px !important;
line-height: 1.428571429 !important;
}
.dark .message-body p em {
color: rgb(138, 138, 138) !important;
}
.message-body p em {
color: rgb(110, 110, 110) !important;
}
"""
output = ''
output += f'<style>{css}</style><div class="chat" id="chat">'
img = ''
for i in [
f"characters/{character}.png",
f"characters/{character}.jpg",
f"characters/{character}.jpeg",
"img_bot.png",
"img_bot.jpg",
"img_bot.jpeg"
]:
path = Path(i)
if path.exists():
img = f'<img src="file/{get_image_cache(path)}">'
break
img_me = ''
for i in ["img_me.png", "img_me.jpg", "img_me.jpeg"]:
path = Path(i)
if path.exists():
img_me = f'<img src="file/{get_image_cache(path)}">'
break
img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
for i,_row in enumerate(history[::-1]):
row = _row.copy()
row[0] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"<b>\2</b>", row[0])
row[1] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"<b>\2</b>", row[1])
row[0] = re.sub(r"(\*)([^\*\n]*)(\*)", r"<em>\2</em>", row[0])
row[1] = re.sub(r"(\*)([^\*\n]*)(\*)", r"<em>\2</em>", row[1])
p = '\n'.join([f"<p>{x}</p>" for x in row[1].split('\n')])
row = [convert_to_markdown(entry) for entry in _row]
output += f"""
<div class="message">
<div class="circle-bot">
{img}
{img_bot}
</div>
<div class="text">
<div class="username">
{name2}
</div>
<div class="message-body">
{p}
{row[1]}
</div>
</div>
</div>
"""
if not (i == len(history)-1 and len(row[0]) == 0):
p = '\n'.join([f"<p>{x}</p>" for x in row[0].split('\n')])
output += f"""
<div class="message">
<div class="circle-you">
@ -347,7 +153,7 @@ def generate_chat_html(history, name1, name2, character):
{name1}
</div>
<div class="message-body">
{p}
{row[0]}
</div>
</div>
</div>

View File

@ -7,7 +7,9 @@ from pathlib import Path
import numpy as np
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig)
import modules.shared as shared
@ -16,8 +18,7 @@ transformers.logging.set_verbosity_error()
local_rank = None
if shared.args.flexgen:
from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
Policy, str2bool)
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
if shared.args.deepspeed:
import deepspeed
@ -46,7 +47,12 @@ def load_model(model_name):
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
else:
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda()
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
if torch.has_mps:
device = torch.device('mps')
model = model.to(device)
else:
model = model.cuda()
# FlexGen
elif shared.args.flexgen:
@ -95,39 +101,60 @@ def load_model(model_name):
# Custom
else:
command = "AutoModelForCausalLM.from_pretrained"
params = ["low_cpu_mem_usage=True"]
if not shared.args.cpu and not torch.cuda.is_available():
print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n")
params = {"low_cpu_mem_usage": True}
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
shared.args.cpu = True
if shared.args.cpu:
params.append("low_cpu_mem_usage=True")
params.append("torch_dtype=torch.float32")
params["torch_dtype"] = torch.float32
else:
params.append("device_map='auto'")
params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
params["device_map"] = 'auto'
if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
elif shared.args.load_in_8bit:
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
elif shared.args.bf16:
params["torch_dtype"] = torch.bfloat16
else:
params["torch_dtype"] = torch.float16
if shared.args.gpu_memory:
memory_map = shared.args.gpu_memory
max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
for i in range(1, len(memory_map)):
max_memory += (f", {i}: '{memory_map[i]}GiB'")
max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
params.append(max_memory)
elif not shared.args.load_in_8bit:
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
suggestion = round((total_mem-1000)/1000)*1000
if total_mem-suggestion < 800:
max_memory = {}
for i in range(len(memory_map)):
max_memory[i] = f'{memory_map[i]}GiB'
max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
params['max_memory'] = max_memory
elif shared.args.auto_devices:
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
suggestion = round((total_mem-1000) / 1000) * 1000
if total_mem - suggestion < 800:
suggestion -= 1000
suggestion = int(round(suggestion/1000))
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
if shared.args.disk:
params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
model = eval(command)
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
params['max_memory'] = max_memory
if shared.args.disk:
params["offload_folder"] = shared.args.disk_cache_dir
checkpoint = Path(f'models/{shared.model_name}')
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
config = AutoConfig.from_pretrained(checkpoint)
with init_empty_weights():
model = AutoModelForCausalLM.from_config(config)
model.tie_weights()
params['device_map'] = infer_auto_device_map(
model,
dtype=torch.int8,
max_memory=params['max_memory'],
no_split_module_classes = model._no_split_modules
)
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
# Loading the tokenizer
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():

View File

@ -2,7 +2,8 @@ import argparse
model = None
tokenizer = None
model_name = ""
model_name = "None"
lora_name = "None"
soft_prompt_tensor = None
soft_prompt = False
is_RWKV = False
@ -19,6 +20,9 @@ gradio = {}
# Generation input parameters
input_params = []
# For restarting the interface
need_restart = False
settings = {
'max_new_tokens': 200,
'max_new_tokens_min': 1,
@ -26,7 +30,7 @@ settings = {
'name1': 'Person 1',
'name2': 'Person 2',
'context': 'This is a conversation between two people.',
'stop_at_newline': True,
'stop_at_newline': False,
'chat_prompt_size': 2048,
'chat_prompt_size_min': 0,
'chat_prompt_size_max': 2048,
@ -49,6 +53,10 @@ settings = {
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
'(rosey|chip|joi)_.*_instruct.*': 'User: \n',
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
},
'lora_prompts': {
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
'(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
}
}
@ -64,6 +72,7 @@ def str2bool(v):
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')

View File

@ -33,12 +33,15 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
return input_ids.numpy()
elif shared.args.deepspeed:
return input_ids.to(device=local_rank)
elif torch.has_mps:
device = torch.device('mps')
return input_ids.to(device)
else:
return input_ids.cuda()
def decode(output_ids):
# Open Assistant relies on special tokens like <|endoftext|>
if re.match('oasst-*', shared.model_name.lower()):
if re.match('(oasst|galactica)-*', shared.model_name.lower()):
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
else:
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
@ -89,7 +92,7 @@ def clear_torch_cache():
if not shared.args.cpu:
torch.cuda.empty_cache()
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
clear_torch_cache()
t0 = time.time()
@ -101,7 +104,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
yield formatted_outputs(reply, shared.model_name)
else:
yield formatted_outputs(question, shared.model_name)
if not (shared.args.chat or shared.args.cai_chat):
yield formatted_outputs(question, shared.model_name)
# RWKV has proper streaming, which is very nice.
# No need to generate 8 tokens at a time.
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
@ -143,6 +147,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
"top_p": top_p,
"typical_p": typical_p,
"repetition_penalty": repetition_penalty,
"encoder_repetition_penalty": encoder_repetition_penalty,
"top_k": top_k,
"min_length": min_length if shared.args.no_stream else 0,
"no_repeat_ngram_size": no_repeat_ngram_size,
@ -196,7 +201,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
def generate_with_streaming(**kwargs):
return Iteratorize(generate_with_callback, kwargs, callback=None)
yield formatted_outputs(original_question, shared.model_name)
if not (shared.args.chat or shared.args.cai_chat):
yield formatted_outputs(original_question, shared.model_name)
with generate_with_streaming(**generate_params) as generator:
for output in generator:
if shared.soft_prompt:

View File

@ -1,68 +1,17 @@
from pathlib import Path
import gradio as gr
refresh_symbol = '\U0001f504' # 🔄
css = """
.tabs.svelte-710i53 {
margin-top: 0
}
.py-6 {
padding-top: 2.5rem
}
.dark #refresh-button {
background-color: #ffffff1f;
}
#refresh-button {
flex: none;
margin: 0;
padding: 0;
min-width: 50px;
border: none;
box-shadow: none;
border-radius: 10px;
background-color: #0000000d;
}
#download-label, #upload-label {
min-height: 0
}
#accordion {
}
.dark svg {
fill: white;
}
svg {
display: unset !important;
vertical-align: middle !important;
margin: 5px;
}
ol li p, ul li p {
display: inline-block;
}
"""
chat_css = """
.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
height: 66.67vh
}
.gradio-container {
max-width: 800px !important;
margin-left: auto !important;
margin-right: auto !important;
}
.w-screen {
width: unset
}
div.svelte-362y77>*, div.svelte-362y77>.form>* {
flex-wrap: nowrap
}
/* fixes the API documentation in chat mode */
.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h {
display: grid;
}
.pending.svelte-1ed2p3z {
opacity: 1;
}
"""
with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
css = f.read()
with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
chat_css = f.read()
with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
main_js = f.read()
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
chat_js = f.read()
class ToolButton(gr.Button, gr.components.FormComponent):
"""Small button with single emoji as text, fits inside gradio forms"""

View File

@ -1,12 +1,7 @@
do_sample=True
temperature=1
top_p=1
typical_p=1
repetition_penalty=1
top_k=50
num_beams=1
penalty_alpha=0
min_length=0
length_penalty=1
no_repeat_ngram_size=0
top_p=0.5
top_k=40
temperature=0.7
repetition_penalty=1.2
typical_p=1.0
early_stopping=False

View File

@ -1,6 +0,0 @@
do_sample=True
top_p=0.9
top_k=50
temperature=1.39
repetition_penalty=1.08
typical_p=0.2

View File

@ -2,10 +2,12 @@ accelerate==0.17.1
bitsandbytes==0.37.1
flexgen==0.1.7
gradio==3.18.0
markdown
numpy
peft==0.2.0
requests
rwkv==0.4.2
rwkv==0.7.0
safetensors==0.3.0
sentencepiece
tqdm
git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176
git+https://github.com/huggingface/transformers

482
server.py
View File

@ -15,6 +15,7 @@ import modules.extensions as extensions_module
import modules.shared as shared
import modules.ui as ui
from modules.html_generator import generate_chat_html
from modules.LoRA import add_lora_to_model
from modules.models import load_model, load_soft_prompt
from modules.text_generation import generate_reply
@ -34,7 +35,7 @@ def get_available_models():
if shared.args.flexgen:
return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
else:
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def get_available_presets():
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
@ -48,6 +49,9 @@ def get_available_extensions():
def get_available_softprompts():
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
def get_available_loras():
return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
def load_model_wrapper(selected_model):
if selected_model != shared.model_name:
shared.model_name = selected_model
@ -59,6 +63,17 @@ def load_model_wrapper(selected_model):
return selected_model
def load_lora_wrapper(selected_lora):
shared.lora_name = selected_lora
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
if not shared.args.cpu:
gc.collect()
torch.cuda.empty_cache()
add_lora_to_model(selected_lora)
return selected_lora, default_text
def load_preset_values(preset_menu, return_dict=False):
generate_params = {
'do_sample': True,
@ -66,6 +81,7 @@ def load_preset_values(preset_menu, return_dict=False):
'top_p': 1,
'typical_p': 1,
'repetition_penalty': 1,
'encoder_repetition_penalty': 1,
'top_k': 50,
'num_beams': 1,
'penalty_alpha': 0,
@ -86,7 +102,7 @@ def load_preset_values(preset_menu, return_dict=False):
if return_dict:
return generate_params
else:
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
def upload_soft_prompt(file):
with zipfile.ZipFile(io.BytesIO(file)) as zf:
@ -100,9 +116,7 @@ def upload_soft_prompt(file):
return name
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
def create_model_and_preset_menus():
with gr.Row():
with gr.Column():
with gr.Row():
@ -113,31 +127,48 @@ def create_settings_menus(default_preset):
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
with gr.Row():
with gr.Column():
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
with gr.Column():
def create_settings_menus(default_preset):
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
with gr.Row():
shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
with gr.Row():
with gr.Column():
with gr.Box():
gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
with gr.Row():
with gr.Column():
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
with gr.Column():
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
with gr.Column():
with gr.Box():
gr.Markdown('Contrastive search')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
gr.Markdown('Contrastive search:')
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
with gr.Box():
gr.Markdown('Beam search (uses a lot of VRAM)')
with gr.Row():
with gr.Column():
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
with gr.Column():
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
gr.Markdown('Beam search (uses a lot of VRAM):')
with gr.Row():
with gr.Column():
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
with gr.Column():
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
with gr.Row():
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
with gr.Accordion('Soft prompt', open=False, elem_id='accordion'):
with gr.Accordion('Soft prompt', open=False):
with gr.Row():
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
@ -147,14 +178,35 @@ def create_settings_menus(default_preset):
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
def set_interface_arguments(interface_mode, extensions, cmd_active):
modes = ["default", "notebook", "chat", "cai_chat"]
cmd_list = vars(shared.args)
cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
shared.args.extensions = extensions
for k in modes[1:]:
exec(f"shared.args.{k} = False")
if interface_mode != "default":
exec(f"shared.args.{interface_mode} = True")
for k in cmd_list:
exec(f"shared.args.{k} = False")
for k in cmd_active:
exec(f"shared.args.{k} = True")
shared.need_restart = True
available_models = get_available_models()
available_presets = get_available_presets()
available_characters = get_available_characters()
available_softprompts = get_available_softprompts()
available_loras = get_available_loras()
# Default extensions
extensions_module.available_extensions = get_available_extensions()
@ -168,8 +220,6 @@ else:
shared.args.extensions = shared.args.extensions or []
if extension not in shared.args.extensions:
shared.args.extensions.append(extension)
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions()
# Default model
if shared.args.model is not None:
@ -189,191 +239,235 @@ else:
print()
shared.model_name = available_models[i]
shared.model, shared.tokenizer = load_model(shared.model_name)
if shared.args.lora:
print(shared.args.lora)
shared.lora_name = shared.args.lora
add_lora_to_model(shared.lora_name)
# Default UI settings
gen_events = []
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
if default_text == '':
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
title ='Text generation web UI'
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
if shared.args.chat or shared.args.cai_chat:
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
else:
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop')
shared.gradio['Generate'] = gr.Button('Generate')
with gr.Row():
shared.gradio['Impersonate'] = gr.Button('Impersonate')
shared.gradio['Regenerate'] = gr.Button('Regenerate')
with gr.Row():
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
shared.gradio['Remove last'] = gr.Button('Remove last')
def create_interface():
shared.gradio['Clear history'] = gr.Button('Clear history')
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
with gr.Tab('Chat settings'):
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
gen_events = []
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
extensions_module.load_extensions()
with gr.Row():
shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
with gr.Row():
with gr.Tab('Chat history'):
with gr.Row():
with gr.Column():
gr.Markdown('Upload')
shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
with gr.Column():
gr.Markdown('Download')
shared.gradio['download'] = gr.File()
shared.gradio['download_button'] = gr.Button(value='Click me')
with gr.Tab('Upload character'):
with gr.Row():
with gr.Column():
gr.Markdown('1. Select the JSON file')
shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
with gr.Column():
gr.Markdown('2. Select your character\'s profile picture (optional)')
shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
shared.gradio['Upload character'] = gr.Button(value='Submit')
with gr.Tab('Upload your profile picture'):
shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
with gr.Tab('Upload TavernAI Character Card'):
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
with gr.Tab('Generation settings'):
with gr.Row():
with gr.Column():
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
with gr.Column():
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
if shared.args.extensions is not None:
with gr.Tab('Extensions'):
extensions_module.create_extensions_block()
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
# Clearing stuff and saving the history
for i in ['Generate', 'Regenerate', 'Replace last reply']:
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook:
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
gr.Markdown(description)
with gr.Tab('Raw'):
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23)
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
shared.gradio['Generate'] = gr.Button('Generate')
shared.gradio['Stop'] = gr.Button('Stop')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
create_settings_menus(default_preset)
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
else:
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
gr.Markdown(description)
with gr.Row():
with gr.Column():
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['Generate'] = gr.Button('Generate')
with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
if shared.args.chat or shared.args.cai_chat:
with gr.Tab("Text generation", elem_id="main"):
if shared.args.cai_chat:
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
else:
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
shared.gradio['textbox'] = gr.Textbox(label='Input')
with gr.Row():
with gr.Column():
shared.gradio['Continue'] = gr.Button('Continue')
with gr.Column():
shared.gradio['Stop'] = gr.Button('Stop')
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
shared.gradio['Generate'] = gr.Button('Generate')
with gr.Row():
shared.gradio['Impersonate'] = gr.Button('Impersonate')
shared.gradio['Regenerate'] = gr.Button('Regenerate')
with gr.Row():
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
shared.gradio['Remove last'] = gr.Button('Remove last')
shared.gradio['Clear history'] = gr.Button('Clear history')
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
create_model_and_preset_menus()
with gr.Tab("Character", elem_id="chat-settings"):
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
with gr.Row():
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
with gr.Row():
with gr.Tab('Chat history'):
with gr.Row():
with gr.Column():
gr.Markdown('Upload')
shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
with gr.Column():
gr.Markdown('Download')
shared.gradio['download'] = gr.File()
shared.gradio['download_button'] = gr.Button(value='Click me')
with gr.Tab('Upload character'):
with gr.Row():
with gr.Column():
gr.Markdown('1. Select the JSON file')
shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
with gr.Column():
gr.Markdown('2. Select your character\'s profile picture (optional)')
shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
shared.gradio['Upload character'] = gr.Button(value='Submit')
with gr.Tab('Upload your profile picture'):
shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
with gr.Tab('Upload TavernAI Character Card'):
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
with gr.Tab("Parameters", elem_id="parameters"):
with gr.Box():
gr.Markdown("Chat parameters")
with gr.Row():
with gr.Column():
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
with gr.Column():
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
create_settings_menus(default_preset)
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
with gr.Column():
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
# Clear history with confirmation
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
# Clearing stuff and saving the history
for i in ['Generate', 'Regenerate', 'Replace last reply']:
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
elif shared.args.notebook:
with gr.Tab("Text generation", elem_id="main"):
with gr.Tab('Raw'):
shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output')
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
with gr.Row():
shared.gradio['Stop'] = gr.Button('Stop')
shared.gradio['Generate'] = gr.Button('Generate')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['interface'].queue()
if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
create_model_and_preset_menus()
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
else:
with gr.Tab("Text generation", elem_id="main"):
with gr.Row():
with gr.Column():
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
shared.gradio['Generate'] = gr.Button('Generate')
with gr.Row():
with gr.Column():
shared.gradio['Continue'] = gr.Button('Continue')
with gr.Column():
shared.gradio['Stop'] = gr.Button('Stop')
create_model_and_preset_menus()
with gr.Column():
with gr.Tab('Raw'):
shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output')
with gr.Tab('Markdown'):
shared.gradio['markdown'] = gr.Markdown()
with gr.Tab('HTML'):
shared.gradio['html'] = gr.HTML()
with gr.Tab("Parameters", elem_id="parameters"):
create_settings_menus(default_preset)
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
with gr.Tab("Interface mode", elem_id="interface-mode"):
modes = ["default", "notebook", "chat", "cai_chat"]
current_mode = "default"
for mode in modes[1:]:
if eval(f"shared.args.{mode}"):
current_mode = mode
break
cmd_list = vars(shared.args)
cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
active_cmd_list = [k for k in cmd_list if vars(shared.args)[k]]
gr.Markdown("*Experimental*")
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
shared.gradio['cmd_arguments_menu'] = gr.CheckboxGroup(choices=cmd_list, value=active_cmd_list, label="Boolean command-line flags")
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None)
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500)}')
if shared.args.extensions is not None:
extensions_module.create_extensions_block()
# Launch the interface
shared.gradio['interface'].queue()
if shared.args.listen:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
else:
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
create_interface()
# I think that I will need this later
while True:
time.sleep(0.5)
if shared.need_restart:
shared.need_restart = False
shared.gradio['interface'].close()
create_interface()

View File

@ -5,7 +5,7 @@
"name1": "Person 1",
"name2": "Person 2",
"context": "This is a conversation between two people.",
"stop_at_newline": true,
"stop_at_newline": false,
"chat_prompt_size": 2048,
"chat_prompt_size_min": 0,
"chat_prompt_size_max": 2048,
@ -23,13 +23,16 @@
"presets": {
"default": "NovelAI-Sphinx Moth",
"pygmalion-*": "Pygmalion",
"RWKV-*": "Naive",
"(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)"
"RWKV-*": "Naive"
},
"prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
"(rosey|chip|joi)_.*_instruct.*": "User: \n",
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
},
"lora_prompts": {
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
"alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
}
}