mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-23 21:18:00 +01:00
Merge branch 'main' into Brawlence-main
This commit is contained in:
commit
eab8de0d4a
4
.gitignore
vendored
4
.gitignore
vendored
@ -4,12 +4,15 @@ extensions/silero_tts/outputs/*
|
||||
extensions/elevenlabs_tts/outputs/*
|
||||
extensions/sd_api_pictures/outputs/*
|
||||
logs/*
|
||||
loras/*
|
||||
models/*
|
||||
softprompts/*
|
||||
torch-dumps/*
|
||||
*pycache*
|
||||
*/*pycache*
|
||||
*/*/pycache*
|
||||
venv/
|
||||
.venv/
|
||||
|
||||
settings.json
|
||||
img_bot*
|
||||
@ -17,6 +20,7 @@ img_me*
|
||||
|
||||
!characters/Example.json
|
||||
!characters/Example.png
|
||||
!loras/place-your-loras-here.txt
|
||||
!models/place-your-models-here.txt
|
||||
!softprompts/place-your-softprompts-here.txt
|
||||
!torch-dumps/place-your-pt-models-here.txt
|
||||
|
152
README.md
152
README.md
@ -19,52 +19,76 @@ Its goal is to become the [AUTOMATIC1111/stable-diffusion-webui](https://github.
|
||||
* Generate Markdown output for [GALACTICA](https://github.com/paperswithcode/galai), including LaTeX support.
|
||||
* Support for [Pygmalion](https://huggingface.co/models?search=pygmalionai/pygmalion) and custom characters in JSON or TavernAI Character Card formats ([FAQ](https://github.com/oobabooga/text-generation-webui/wiki/Pygmalion-chat-model-FAQ)).
|
||||
* Advanced chat features (send images, get audio responses with TTS).
|
||||
* Stream the text output in real time.
|
||||
* Stream the text output in real time very efficiently.
|
||||
* Load parameter presets from text files.
|
||||
* Load large models in 8-bit mode (see [here](https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134), [here](https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652) and [here](https://www.reddit.com/r/PygmalionAI/comments/1115gom/running_pygmalion_6b_with_8gb_of_vram/) if you are on Windows).
|
||||
* Load large models in 8-bit mode.
|
||||
* Split large models across your GPU(s), CPU, and disk.
|
||||
* CPU mode.
|
||||
* [FlexGen offload](https://github.com/oobabooga/text-generation-webui/wiki/FlexGen).
|
||||
* [DeepSpeed ZeRO-3 offload](https://github.com/oobabooga/text-generation-webui/wiki/DeepSpeed).
|
||||
* Get responses via API, [with](https://github.com/oobabooga/text-generation-webui/blob/main/api-example-streaming.py) or [without](https://github.com/oobabooga/text-generation-webui/blob/main/api-example.py) streaming.
|
||||
* [Supports the LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
|
||||
* [Supports the RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
|
||||
* [LLaMA model, including 4-bit mode](https://github.com/oobabooga/text-generation-webui/wiki/LLaMA-model).
|
||||
* [RWKV model](https://github.com/oobabooga/text-generation-webui/wiki/RWKV-model).
|
||||
* [Supports LoRAs](https://github.com/oobabooga/text-generation-webui/wiki/Using-LoRAs).
|
||||
* Supports softprompts.
|
||||
* [Supports extensions](https://github.com/oobabooga/text-generation-webui/wiki/Extensions).
|
||||
* [Works on Google Colab](https://github.com/oobabooga/text-generation-webui/wiki/Running-on-Colab).
|
||||
|
||||
## Installation option 1: conda
|
||||
## Installation
|
||||
|
||||
Open a terminal and copy and paste these commands one at a time ([install conda](https://docs.conda.io/en/latest/miniconda.html) first if you don't have it already):
|
||||
The recommended installation methods are the following:
|
||||
|
||||
* Linux and MacOS: using conda natively.
|
||||
* Windows: using conda on WSL ([WSL installation guide](https://github.com/oobabooga/text-generation-webui/wiki/Windows-Subsystem-for-Linux-(Ubuntu)-Installation-Guide)).
|
||||
|
||||
Conda can be downloaded here: https://docs.conda.io/en/latest/miniconda.html
|
||||
|
||||
On Linux or WSL, it can be automatically installed with these two commands:
|
||||
|
||||
```
|
||||
conda create -n textgen
|
||||
curl -sL "https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh" > "Miniconda3.sh"
|
||||
bash Miniconda3.sh
|
||||
```
|
||||
|
||||
Source: https://educe-ubc.github.io/conda.html
|
||||
|
||||
#### 1. Create a new conda environment
|
||||
|
||||
```
|
||||
conda create -n textgen python=3.10.9
|
||||
conda activate textgen
|
||||
conda install torchvision torchaudio pytorch-cuda=11.7 git -c pytorch -c nvidia
|
||||
```
|
||||
|
||||
#### 2. Install Pytorch
|
||||
|
||||
| System | GPU | Command |
|
||||
|--------|---------|---------|
|
||||
| Linux/WSL | NVIDIA | `pip3 install torch torchvision torchaudio` |
|
||||
| Linux | AMD | `pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2` |
|
||||
| MacOS + MPS (untested) | Any | `pip3 install torch torchvision torchaudio` |
|
||||
|
||||
The up to date commands can be found here: https://pytorch.org/get-started/locally/.
|
||||
|
||||
MacOS users, refer to the comments here: https://github.com/oobabooga/text-generation-webui/pull/393
|
||||
|
||||
|
||||
#### 3. Install the web UI
|
||||
|
||||
```
|
||||
git clone https://github.com/oobabooga/text-generation-webui
|
||||
cd text-generation-webui
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
The third line assumes that you have an NVIDIA GPU.
|
||||
|
||||
* If you have an AMD GPU, replace the third command with this one:
|
||||
|
||||
```
|
||||
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
|
||||
```
|
||||
|
||||
* If you are running it in CPU mode, replace the third command with this one:
|
||||
|
||||
```
|
||||
conda install pytorch torchvision torchaudio git -c pytorch
|
||||
```
|
||||
|
||||
> **Note**
|
||||
> 1. If you are on Windows, it may be easier to run the commands above in a WSL environment. The performance may also be better.
|
||||
> 2. For a more detailed, user-contributed guide, see: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
|
||||
>
|
||||
> For bitsandbytes and `--load-in-8bit` to work on Linux/WSL, this dirty fix is currently necessary: https://github.com/oobabooga/text-generation-webui/issues/400#issuecomment-1474876859
|
||||
|
||||
## Installation option 2: one-click installers
|
||||
### Alternative: native Windows installation
|
||||
|
||||
As an alternative to the recommended WSL method, you can install the web UI natively on Windows using this guide. It will be a lot harder and the performance may be slower: [Installation instructions for human beings](https://github.com/oobabooga/text-generation-webui/wiki/Installation-instructions-for-human-beings).
|
||||
|
||||
### Alternative: one-click installers
|
||||
|
||||
[oobabooga-windows.zip](https://github.com/oobabooga/one-click-installers/archive/refs/heads/oobabooga-windows.zip)
|
||||
|
||||
@ -75,19 +99,25 @@ Just download the zip above, extract it, and double click on "install". The web
|
||||
* To download a model, double click on "download-model"
|
||||
* To start the web UI, double click on "start-webui"
|
||||
|
||||
Source codes: https://github.com/oobabooga/one-click-installers
|
||||
|
||||
This method lags behind the newest developments and does not support 8-bit mode on Windows without additional set up: https://github.com/oobabooga/text-generation-webui/issues/147#issuecomment-1456040134, https://github.com/oobabooga/text-generation-webui/issues/20#issuecomment-1411650652
|
||||
|
||||
### Alternative: Docker
|
||||
|
||||
https://github.com/oobabooga/text-generation-webui/issues/174, https://github.com/oobabooga/text-generation-webui/issues/87
|
||||
|
||||
## Downloading models
|
||||
|
||||
Models should be placed under `models/model-name`. For instance, `models/gpt-j-6B` for [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main).
|
||||
|
||||
#### Hugging Face
|
||||
Models should be placed inside the `models` folder.
|
||||
|
||||
[Hugging Face](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads) is the main place to download models. These are some noteworthy examples:
|
||||
|
||||
* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main)
|
||||
* [GPT-Neo](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=eleutherai+%2F+gpt-neo)
|
||||
* [Pythia](https://huggingface.co/models?search=eleutherai/pythia)
|
||||
* [OPT](https://huggingface.co/models?search=facebook/opt)
|
||||
* [GALACTICA](https://huggingface.co/models?search=facebook/galactica)
|
||||
* [GPT-J 6B](https://huggingface.co/EleutherAI/gpt-j-6B/tree/main)
|
||||
* [GPT-Neo](https://huggingface.co/models?pipeline_tag=text-generation&sort=downloads&search=eleutherai+%2F+gpt-neo)
|
||||
* [\*-Erebus](https://huggingface.co/models?search=erebus) (NSFW)
|
||||
* [Pygmalion](https://huggingface.co/models?search=pygmalion) (NSFW)
|
||||
|
||||
@ -101,7 +131,7 @@ For instance:
|
||||
|
||||
If you want to download a model manually, note that all you need are the json, txt, and pytorch\*.bin (or model*.safetensors) files. The remaining files are not necessary.
|
||||
|
||||
#### GPT-4chan
|
||||
### GPT-4chan
|
||||
|
||||
[GPT-4chan](https://huggingface.co/ykilcher/gpt-4chan) has been shut down from Hugging Face, so you need to download it elsewhere. You have two options:
|
||||
|
||||
@ -123,6 +153,7 @@ python download-model.py EleutherAI/gpt-j-6B --text-only
|
||||
## Starting the web UI
|
||||
|
||||
conda activate textgen
|
||||
cd text-generation-webui
|
||||
python server.py
|
||||
|
||||
Then browse to
|
||||
@ -133,41 +164,42 @@ Then browse to
|
||||
|
||||
Optionally, you can use the following command-line flags:
|
||||
|
||||
| Flag | Description |
|
||||
|-------------|-------------|
|
||||
| `-h`, `--help` | show this help message and exit |
|
||||
| `--model MODEL` | Name of the model to load by default. |
|
||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||
| `--chat` | Launch the web UI in chat mode.|
|
||||
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
|
||||
| `--cpu` | Use the CPU to generate text.|
|
||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||
| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
|
||||
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
|
||||
| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
|
||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||
| Flag | Description |
|
||||
|------------------|-------------|
|
||||
| `-h`, `--help` | show this help message and exit |
|
||||
| `--model MODEL` | Name of the model to load by default. |
|
||||
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||
| `--chat` | Launch the web UI in chat mode.|
|
||||
| `--cai-chat` | Launch the web UI in chat mode with a style similar to Character.AI's. If the file `img_bot.png` or `img_bot.jpg` exists in the same folder as server.py, this image will be used as the bot's profile picture. Similarly, `img_me.png` or `img_me.jpg` will be used as your profile picture. |
|
||||
| `--cpu` | Use the CPU to generate text.|
|
||||
| `--load-in-8bit` | Load the model with 8-bit precision.|
|
||||
| `--load-in-4bit` | DEPRECATED: use `--gptq-bits 4` instead. |
|
||||
| `--gptq-bits GPTQ_BITS` | Load a pre-quantized model with specified precision. 2, 3, 4 and 8 (bit) are supported. Currently only works with LLaMA and OPT. |
|
||||
| `--gptq-model-type MODEL_TYPE` | Model type of pre-quantized model. Currently only LLaMa and OPT are supported. |
|
||||
| `--bf16` | Load the model with bfloat16 precision. Requires NVIDIA Ampere GPU. |
|
||||
| `--auto-devices` | Automatically split the model across the available GPU(s) and CPU.|
|
||||
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
|
||||
| `--disk` | If the model is too large for your GPU(s) and CPU combined, send the remaining layers to the disk. |
|
||||
| `--disk-cache-dir DISK_CACHE_DIR` | Directory to save the disk cache to. Defaults to `cache/`. |
|
||||
| `--gpu-memory GPU_MEMORY [GPU_MEMORY ...]` | Maxmimum GPU memory in GiB to be allocated per GPU. Example: `--gpu-memory 10` for a single GPU, `--gpu-memory 10 5` for two GPUs. |
|
||||
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
|
||||
| `--flexgen` | Enable the use of FlexGen offloading. |
|
||||
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
|
||||
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
|
||||
| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
|
||||
| `--cpu-memory CPU_MEMORY` | Maximum CPU memory in GiB to allocate for offloaded weights. Must be an integer number. Defaults to 99.|
|
||||
| `--flexgen` | Enable the use of FlexGen offloading. |
|
||||
| `--percent PERCENT [PERCENT ...]` | FlexGen: allocation percentages. Must be 6 numbers separated by spaces (default: 0, 100, 100, 0, 100, 0). |
|
||||
| `--compress-weight` | FlexGen: Whether to compress weight (default: False).|
|
||||
| `--pin-weight [PIN_WEIGHT]` | FlexGen: whether to pin weights (setting this to False reduces CPU memory by 20%). |
|
||||
| `--deepspeed` | Enable the use of DeepSpeed ZeRO-3 for inference via the Transformers integration. |
|
||||
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
|
||||
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
|
||||
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
|
||||
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
|
||||
| `--no-stream` | Don't stream the text output in real time. |
|
||||
| `--nvme-offload-dir NVME_OFFLOAD_DIR` | DeepSpeed: Directory to use for ZeRO-3 NVME offloading. |
|
||||
| `--local_rank LOCAL_RANK` | DeepSpeed: Optional argument for distributed setups. |
|
||||
| `--rwkv-strategy RWKV_STRATEGY` | RWKV: The strategy to use while loading the model. Examples: "cpu fp32", "cuda fp16", "cuda fp16i8". |
|
||||
| `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
|
||||
| `--no-stream` | Don't stream the text output in real time. |
|
||||
| `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
|
||||
| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
|
||||
| `--listen` | Make the web UI reachable from your local network.|
|
||||
| `--listen` | Make the web UI reachable from your local network.|
|
||||
| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
|
||||
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
||||
| `--auto-launch` | Open the web UI in the default browser upon launch. |
|
||||
| `--verbose` | Print the prompts to the terminal. |
|
||||
| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
|
||||
| `--auto-launch` | Open the web UI in the default browser upon launch. |
|
||||
| `--verbose` | Print the prompts to the terminal. |
|
||||
|
||||
Out of memory errors? [Check this guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).
|
||||
|
||||
@ -192,7 +224,7 @@ Before reporting a bug, make sure that you have:
|
||||
|
||||
## Credits
|
||||
|
||||
- Gradio dropdown menu refresh button: https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||
- Gradio dropdown menu refresh button, code for reloading the interface: https://github.com/AUTOMATIC1111/stable-diffusion-webui
|
||||
- Verbose preset: Anonymous 4chan user.
|
||||
- NovelAI and KoboldAI presets: https://github.com/KoboldAI/KoboldAI-Client/wiki/Settings-Presets
|
||||
- Pygmalion preset, code for early stopping in chat mode, code for some of the sliders, --chat mode colors: https://github.com/PygmalionAI/gradio-ui/
|
||||
|
@ -26,6 +26,7 @@ async def run(context):
|
||||
'top_p': 0.9,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.05,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'top_k': 0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
@ -43,14 +44,14 @@ async def run(context):
|
||||
case "send_hash":
|
||||
await websocket.send(json.dumps({
|
||||
"session_hash": session,
|
||||
"fn_index": 7
|
||||
"fn_index": 9
|
||||
}))
|
||||
case "estimation":
|
||||
pass
|
||||
case "send_data":
|
||||
await websocket.send(json.dumps({
|
||||
"session_hash": session,
|
||||
"fn_index": 7,
|
||||
"fn_index": 9,
|
||||
"data": [
|
||||
context,
|
||||
params['max_new_tokens'],
|
||||
@ -59,6 +60,7 @@ async def run(context):
|
||||
params['top_p'],
|
||||
params['typical_p'],
|
||||
params['repetition_penalty'],
|
||||
params['encoder_repetition_penalty'],
|
||||
params['top_k'],
|
||||
params['min_length'],
|
||||
params['no_repeat_ngram_size'],
|
||||
|
@ -24,6 +24,7 @@ params = {
|
||||
'top_p': 0.9,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.05,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'top_k': 0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
@ -45,6 +46,7 @@ response = requests.post(f"http://{server}:7860/run/textgen", json={
|
||||
params['top_p'],
|
||||
params['typical_p'],
|
||||
params['repetition_penalty'],
|
||||
params['encoder_repetition_penalty'],
|
||||
params['top_k'],
|
||||
params['min_length'],
|
||||
params['no_repeat_ngram_size'],
|
||||
|
25
css/chat.css
Normal file
25
css/chat.css
Normal file
@ -0,0 +1,25 @@
|
||||
.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
|
||||
height: 66.67vh
|
||||
}
|
||||
|
||||
.gradio-container {
|
||||
margin-left: auto !important;
|
||||
margin-right: auto !important;
|
||||
}
|
||||
|
||||
.w-screen {
|
||||
width: unset
|
||||
}
|
||||
|
||||
div.svelte-362y77>*, div.svelte-362y77>.form>* {
|
||||
flex-wrap: nowrap
|
||||
}
|
||||
|
||||
/* fixes the API documentation in chat mode */
|
||||
.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h {
|
||||
display: grid;
|
||||
}
|
||||
|
||||
.pending.svelte-1ed2p3z {
|
||||
opacity: 1;
|
||||
}
|
4
css/chat.js
Normal file
4
css/chat.js
Normal file
@ -0,0 +1,4 @@
|
||||
document.getElementById("main").childNodes[0].style = "max-width: 800px; margin-left: auto; margin-right: auto";
|
||||
document.getElementById("extensions").style.setProperty("max-width", "800px");
|
||||
document.getElementById("extensions").style.setProperty("margin-left", "auto");
|
||||
document.getElementById("extensions").style.setProperty("margin-right", "auto");
|
103
css/html_4chan_style.css
Normal file
103
css/html_4chan_style.css
Normal file
@ -0,0 +1,103 @@
|
||||
#parent #container {
|
||||
background-color: #eef2ff;
|
||||
padding: 17px;
|
||||
}
|
||||
#parent #container .reply {
|
||||
background-color: rgb(214, 218, 240);
|
||||
border-bottom-color: rgb(183, 197, 217);
|
||||
border-bottom-style: solid;
|
||||
border-bottom-width: 1px;
|
||||
border-image-outset: 0;
|
||||
border-image-repeat: stretch;
|
||||
border-image-slice: 100%;
|
||||
border-image-source: none;
|
||||
border-image-width: 1;
|
||||
border-left-color: rgb(0, 0, 0);
|
||||
border-left-style: none;
|
||||
border-left-width: 0px;
|
||||
border-right-color: rgb(183, 197, 217);
|
||||
border-right-style: solid;
|
||||
border-right-width: 1px;
|
||||
border-top-color: rgb(0, 0, 0);
|
||||
border-top-style: none;
|
||||
border-top-width: 0px;
|
||||
color: rgb(0, 0, 0);
|
||||
display: table;
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
margin-bottom: 4px;
|
||||
margin-left: 0px;
|
||||
margin-right: 0px;
|
||||
margin-top: 4px;
|
||||
overflow-x: hidden;
|
||||
overflow-y: hidden;
|
||||
padding-bottom: 4px;
|
||||
padding-left: 2px;
|
||||
padding-right: 2px;
|
||||
padding-top: 4px;
|
||||
}
|
||||
|
||||
#parent #container .number {
|
||||
color: rgb(0, 0, 0);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
width: 342.65px;
|
||||
margin-right: 7px;
|
||||
}
|
||||
|
||||
#parent #container .op {
|
||||
color: rgb(0, 0, 0);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
margin-bottom: 8px;
|
||||
margin-left: 0px;
|
||||
margin-right: 0px;
|
||||
margin-top: 4px;
|
||||
overflow-x: hidden;
|
||||
overflow-y: hidden;
|
||||
}
|
||||
|
||||
#parent #container .op blockquote {
|
||||
margin-left: 0px !important;
|
||||
}
|
||||
|
||||
#parent #container .name {
|
||||
color: rgb(17, 119, 67);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
font-weight: 700;
|
||||
margin-left: 7px;
|
||||
}
|
||||
|
||||
#parent #container .quote {
|
||||
color: rgb(221, 0, 0);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
text-decoration-color: rgb(221, 0, 0);
|
||||
text-decoration-line: underline;
|
||||
text-decoration-style: solid;
|
||||
text-decoration-thickness: auto;
|
||||
}
|
||||
|
||||
#parent #container .greentext {
|
||||
color: rgb(120, 153, 34);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
}
|
||||
|
||||
#parent #container blockquote {
|
||||
margin: 0px !important;
|
||||
margin-block-start: 1em;
|
||||
margin-block-end: 1em;
|
||||
margin-inline-start: 40px;
|
||||
margin-inline-end: 40px;
|
||||
margin-top: 13.33px !important;
|
||||
margin-bottom: 13.33px !important;
|
||||
margin-left: 40px !important;
|
||||
margin-right: 40px !important;
|
||||
}
|
||||
|
||||
#parent #container .message {
|
||||
color: black;
|
||||
border: none;
|
||||
}
|
73
css/html_cai_style.css
Normal file
73
css/html_cai_style.css
Normal file
@ -0,0 +1,73 @@
|
||||
.chat {
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
max-width: 800px;
|
||||
height: 66.67vh;
|
||||
overflow-y: auto;
|
||||
padding-right: 20px;
|
||||
display: flex;
|
||||
flex-direction: column-reverse;
|
||||
}
|
||||
|
||||
.message {
|
||||
display: grid;
|
||||
grid-template-columns: 60px 1fr;
|
||||
padding-bottom: 25px;
|
||||
font-size: 15px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
line-height: 1.428571429;
|
||||
}
|
||||
|
||||
.circle-you {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
background-color: rgb(238, 78, 59);
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.circle-bot {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
background-color: rgb(59, 78, 244);
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.circle-bot img,
|
||||
.circle-you img {
|
||||
border-radius: 50%;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.text {}
|
||||
|
||||
.text p {
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
.username {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.message-body {}
|
||||
|
||||
.message-body img {
|
||||
max-width: 300px;
|
||||
max-height: 300px;
|
||||
border-radius: 20px;
|
||||
}
|
||||
|
||||
.message-body p {
|
||||
margin-bottom: 0 !important;
|
||||
font-size: 15px !important;
|
||||
line-height: 1.428571429 !important;
|
||||
}
|
||||
|
||||
.dark .message-body p em {
|
||||
color: rgb(138, 138, 138) !important;
|
||||
}
|
||||
|
||||
.message-body p em {
|
||||
color: rgb(110, 110, 110) !important;
|
||||
}
|
14
css/html_readable_style.css
Normal file
14
css/html_readable_style.css
Normal file
@ -0,0 +1,14 @@
|
||||
.container {
|
||||
max-width: 600px;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
background-color: rgb(31, 41, 55);
|
||||
padding:3em;
|
||||
}
|
||||
|
||||
.container p {
|
||||
font-size: 16px !important;
|
||||
color: white !important;
|
||||
margin-bottom: 22px;
|
||||
line-height: 1.4 !important;
|
||||
}
|
52
css/main.css
Normal file
52
css/main.css
Normal file
@ -0,0 +1,52 @@
|
||||
.tabs.svelte-710i53 {
|
||||
margin-top: 0
|
||||
}
|
||||
|
||||
.py-6 {
|
||||
padding-top: 2.5rem
|
||||
}
|
||||
|
||||
.dark #refresh-button {
|
||||
background-color: #ffffff1f;
|
||||
}
|
||||
|
||||
#refresh-button {
|
||||
flex: none;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
min-width: 50px;
|
||||
border: none;
|
||||
box-shadow: none;
|
||||
border-radius: 10px;
|
||||
background-color: #0000000d;
|
||||
}
|
||||
|
||||
#download-label, #upload-label {
|
||||
min-height: 0
|
||||
}
|
||||
|
||||
#accordion {
|
||||
}
|
||||
|
||||
.dark svg {
|
||||
fill: white;
|
||||
}
|
||||
|
||||
.dark a {
|
||||
color: white !important;
|
||||
text-decoration: none !important;
|
||||
}
|
||||
|
||||
svg {
|
||||
display: unset !important;
|
||||
vertical-align: middle !important;
|
||||
margin: 5px;
|
||||
}
|
||||
|
||||
ol li p, ul li p {
|
||||
display: inline-block;
|
||||
}
|
||||
|
||||
#main, #parameters, #chat-settings, #interface-mode, #lora {
|
||||
border: 0;
|
||||
}
|
18
css/main.js
Normal file
18
css/main.js
Normal file
@ -0,0 +1,18 @@
|
||||
document.getElementById("main").parentNode.childNodes[0].style = "border: none; background-color: #8080802b; margin-bottom: 40px";
|
||||
document.getElementById("main").parentNode.style = "padding: 0; margin: 0";
|
||||
document.getElementById("main").parentNode.parentNode.parentNode.style = "padding: 0";
|
||||
|
||||
// Get references to the elements
|
||||
let main = document.getElementById('main');
|
||||
let main_parent = main.parentNode;
|
||||
let extensions = document.getElementById('extensions');
|
||||
|
||||
// Add an event listener to the main element
|
||||
main_parent.addEventListener('click', function(e) {
|
||||
// Check if the main element is visible
|
||||
if (main.offsetHeight > 0 && main.offsetWidth > 0) {
|
||||
extensions.style.display = 'block';
|
||||
} else {
|
||||
extensions.style.display = 'none';
|
||||
}
|
||||
});
|
@ -101,6 +101,7 @@ def get_download_links_from_huggingface(model, branch):
|
||||
classifications = []
|
||||
has_pytorch = False
|
||||
has_safetensors = False
|
||||
is_lora = False
|
||||
while True:
|
||||
content = requests.get(f"{base}{page}{cursor.decode()}").content
|
||||
|
||||
@ -110,8 +111,10 @@ def get_download_links_from_huggingface(model, branch):
|
||||
|
||||
for i in range(len(dict)):
|
||||
fname = dict[i]['path']
|
||||
if not is_lora and fname.endswith(('adapter_config.json', 'adapter_model.bin')):
|
||||
is_lora = True
|
||||
|
||||
is_pytorch = re.match("pytorch_model.*\.bin", fname)
|
||||
is_pytorch = re.match("(pytorch|adapter)_model.*\.bin", fname)
|
||||
is_safetensors = re.match("model.*\.safetensors", fname)
|
||||
is_tokenizer = re.match("tokenizer.*\.model", fname)
|
||||
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
|
||||
@ -130,6 +133,7 @@ def get_download_links_from_huggingface(model, branch):
|
||||
has_pytorch = True
|
||||
classifications.append('pytorch')
|
||||
|
||||
|
||||
cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
|
||||
cursor = base64.b64encode(cursor)
|
||||
cursor = cursor.replace(b'=', b'%3D')
|
||||
@ -140,7 +144,7 @@ def get_download_links_from_huggingface(model, branch):
|
||||
if classifications[i] == 'pytorch':
|
||||
links.pop(i)
|
||||
|
||||
return links
|
||||
return links, is_lora
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = args.MODEL
|
||||
@ -159,15 +163,16 @@ if __name__ == '__main__':
|
||||
except ValueError as err_branch:
|
||||
print(f"Error: {err_branch}")
|
||||
sys.exit()
|
||||
|
||||
links, is_lora = get_download_links_from_huggingface(model, branch)
|
||||
base_folder = 'models' if not is_lora else 'loras'
|
||||
if branch != 'main':
|
||||
output_folder = Path("models") / (model.split('/')[-1] + f'_{branch}')
|
||||
output_folder = Path(base_folder) / (model.split('/')[-1] + f'_{branch}')
|
||||
else:
|
||||
output_folder = Path("models") / model.split('/')[-1]
|
||||
output_folder = Path(base_folder) / model.split('/')[-1]
|
||||
if not output_folder.exists():
|
||||
output_folder.mkdir()
|
||||
|
||||
links = get_download_links_from_huggingface(model, branch)
|
||||
|
||||
# Downloading the files
|
||||
print(f"Downloading the model to {output_folder}")
|
||||
pool = multiprocessing.Pool(processes=args.threads)
|
||||
|
1
extensions/api/requirements.txt
Normal file
1
extensions/api/requirements.txt
Normal file
@ -0,0 +1 @@
|
||||
flask_cloudflared==0.0.12
|
90
extensions/api/script.py
Normal file
90
extensions/api/script.py
Normal file
@ -0,0 +1,90 @@
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
from modules import shared
|
||||
from modules.text_generation import generate_reply, encode
|
||||
import json
|
||||
|
||||
params = {
|
||||
'port': 5000,
|
||||
}
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == '/api/v1/model':
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
response = json.dumps({
|
||||
'result': shared.model_name
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
def do_POST(self):
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
||||
|
||||
if self.path == '/api/v1/generate':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
prompt = body['prompt']
|
||||
prompt_lines = [l.strip() for l in prompt.split('\n')]
|
||||
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
|
||||
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
|
||||
generator = generate_reply(
|
||||
question = prompt,
|
||||
max_new_tokens = body.get('max_length', 200),
|
||||
do_sample=True,
|
||||
temperature=body.get('temperature', 0.5),
|
||||
top_p=body.get('top_p', 1),
|
||||
typical_p=body.get('typical', 1),
|
||||
repetition_penalty=body.get('rep_pen', 1.1),
|
||||
encoder_repetition_penalty=1,
|
||||
top_k=body.get('top_k', 0),
|
||||
min_length=0,
|
||||
no_repeat_ngram_size=0,
|
||||
num_beams=1,
|
||||
penalty_alpha=0,
|
||||
length_penalty=1,
|
||||
early_stopping=False,
|
||||
)
|
||||
|
||||
answer = ''
|
||||
for a in generator:
|
||||
answer = a[0]
|
||||
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'text': answer[len(prompt):]
|
||||
}]
|
||||
})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
|
||||
def run_server():
|
||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
||||
server = ThreadingHTTPServer(server_addr, Handler)
|
||||
if shared.args.share:
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
||||
print(f'Starting KoboldAI compatible api at {public_url}/api')
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
else:
|
||||
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
|
||||
server.serve_forever()
|
||||
|
||||
def ui():
|
||||
Thread(target=run_server, daemon=True).start()
|
@ -1,8 +1,8 @@
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
from elevenlabslib import *
|
||||
from elevenlabslib.helpers import *
|
||||
from elevenlabslib import ElevenLabsUser
|
||||
from elevenlabslib.helpers import save_bytes_to_path
|
||||
|
||||
params = {
|
||||
'activate': True,
|
||||
|
@ -76,7 +76,7 @@ def generate_html():
|
||||
return container_html
|
||||
|
||||
def ui():
|
||||
with gr.Accordion("Character gallery"):
|
||||
with gr.Accordion("Character gallery", open=False):
|
||||
update = gr.Button("Refresh")
|
||||
gallery = gr.HTML(value=generate_html())
|
||||
update.click(generate_html, [], gallery)
|
||||
|
4
extensions/whisper_stt/requirements.txt
Normal file
4
extensions/whisper_stt/requirements.txt
Normal file
@ -0,0 +1,4 @@
|
||||
git+https://github.com/Uberi/speech_recognition.git@010382b
|
||||
openai-whisper
|
||||
soundfile
|
||||
ffmpeg
|
54
extensions/whisper_stt/script.py
Normal file
54
extensions/whisper_stt/script.py
Normal file
@ -0,0 +1,54 @@
|
||||
import gradio as gr
|
||||
import speech_recognition as sr
|
||||
|
||||
input_hijack = {
|
||||
'state': False,
|
||||
'value': ["", ""]
|
||||
}
|
||||
|
||||
|
||||
def do_stt(audio, text_state=""):
|
||||
transcription = ""
|
||||
r = sr.Recognizer()
|
||||
|
||||
# Convert to AudioData
|
||||
audio_data = sr.AudioData(sample_rate=audio[0], frame_data=audio[1], sample_width=4)
|
||||
|
||||
try:
|
||||
transcription = r.recognize_whisper(audio_data, language="english", model="base.en")
|
||||
except sr.UnknownValueError:
|
||||
print("Whisper could not understand audio")
|
||||
except sr.RequestError as e:
|
||||
print("Could not request results from Whisper", e)
|
||||
|
||||
input_hijack.update({"state": True, "value": [transcription, transcription]})
|
||||
|
||||
text_state += transcription + " "
|
||||
return text_state, text_state
|
||||
|
||||
|
||||
def update_hijack(val):
|
||||
input_hijack.update({"state": True, "value": [val, val]})
|
||||
return val
|
||||
|
||||
|
||||
def auto_transcribe(audio, audio_auto, text_state=""):
|
||||
if audio is None:
|
||||
return "", ""
|
||||
if audio_auto:
|
||||
return do_stt(audio, text_state)
|
||||
return "", ""
|
||||
|
||||
|
||||
def ui():
|
||||
tr_state = gr.State(value="")
|
||||
output_transcription = gr.Textbox(label="STT-Input",
|
||||
placeholder="Speech Preview. Click \"Generate\" to send",
|
||||
interactive=True)
|
||||
output_transcription.change(fn=update_hijack, inputs=[output_transcription], outputs=[tr_state])
|
||||
audio_auto = gr.Checkbox(label="Auto-Transcribe", value=True)
|
||||
with gr.Row():
|
||||
audio = gr.Audio(source="microphone")
|
||||
audio.change(fn=auto_transcribe, inputs=[audio, audio_auto, tr_state], outputs=[output_transcription, tr_state])
|
||||
transcribe_button = gr.Button(value="Transcribe")
|
||||
transcribe_button.click(do_stt, inputs=[audio, tr_state], outputs=[output_transcription, tr_state])
|
0
loras/place-your-loras-here.txt
Normal file
0
loras/place-your-loras-here.txt
Normal file
@ -61,7 +61,7 @@ def load_quantized(model_name):
|
||||
max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
|
||||
max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
|
||||
|
||||
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
|
||||
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LlamaDecoderLayer"])
|
||||
model = accelerate.dispatch_model(model, device_map=device_map)
|
||||
|
||||
# Single GPU
|
||||
|
22
modules/LoRA.py
Normal file
22
modules/LoRA.py
Normal file
@ -0,0 +1,22 @@
|
||||
from pathlib import Path
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.models import load_model
|
||||
|
||||
|
||||
def add_lora_to_model(lora_name):
|
||||
|
||||
from peft import PeftModel
|
||||
|
||||
# Is there a more efficient way of returning to the base model?
|
||||
if lora_name == "None":
|
||||
print("Reloading the model to remove the LoRA...")
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
else:
|
||||
# Why doesn't this work in 16-bit mode?
|
||||
print(f"Adding the LoRA {lora_name} to the model...")
|
||||
|
||||
params = {}
|
||||
params['device_map'] = {'': 0}
|
||||
#params['dtype'] = shared.model.dtype
|
||||
shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
|
@ -7,6 +7,7 @@ import transformers
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
# Copied from https://github.com/PygmalionAI/gradio-ui/
|
||||
class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria):
|
||||
|
||||
|
@ -11,17 +11,11 @@ from PIL import Image
|
||||
import modules.extensions as extensions_module
|
||||
import modules.shared as shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import generate_chat_html
|
||||
from modules.text_generation import encode, generate_reply, get_max_prompt_length
|
||||
from modules.html_generator import fix_newlines, generate_chat_html
|
||||
from modules.text_generation import (encode, generate_reply,
|
||||
get_max_prompt_length)
|
||||
|
||||
|
||||
# This gets the new line characters right.
|
||||
def clean_chat_message(text):
|
||||
text = text.replace('\n', '\n\n')
|
||||
text = re.sub(r"\n{3,}", "\n\n", text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
def generate_chat_output(history, name1, name2, character):
|
||||
if shared.args.cai_chat:
|
||||
return generate_chat_html(history, name1, name2, character)
|
||||
@ -29,7 +23,7 @@ def generate_chat_output(history, name1, name2, character):
|
||||
return history
|
||||
|
||||
def generate_chat_prompt(user_input, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=False):
|
||||
user_input = clean_chat_message(user_input)
|
||||
user_input = fix_newlines(user_input)
|
||||
rows = [f"{context.strip()}\n"]
|
||||
|
||||
if shared.soft_prompt:
|
||||
@ -82,7 +76,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
|
||||
if idx != -1:
|
||||
reply = reply[:idx]
|
||||
next_character_found = True
|
||||
reply = clean_chat_message(reply)
|
||||
reply = fix_newlines(reply)
|
||||
|
||||
# If something like "\nYo" is generated just before "\nYou:"
|
||||
# is completed, trim it
|
||||
@ -97,7 +91,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate
|
||||
def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||
def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False):
|
||||
shared.stop_everything = False
|
||||
just_started = True
|
||||
eos_token = '\n' if check else None
|
||||
@ -133,7 +127,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
||||
# Generate
|
||||
reply = ''
|
||||
for i in range(chat_generation_attempts):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(reply) > 0 else ''}{reply}", max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name1}:"):
|
||||
|
||||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check)
|
||||
@ -160,7 +154,7 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical
|
||||
|
||||
yield shared.history['visible']
|
||||
|
||||
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
eos_token = '\n' if check else None
|
||||
|
||||
if 'pygmalion' in shared.model_name.lower():
|
||||
@ -172,18 +166,18 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
for i in range(chat_generation_attempts):
|
||||
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||
for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"):
|
||||
reply, next_character_found = extract_message_from_reply(prompt, reply, name1, name2, check, impersonate=True)
|
||||
yield reply
|
||||
if next_character_found:
|
||||
break
|
||||
yield reply
|
||||
|
||||
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
|
||||
def cai_chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
for _history in chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts):
|
||||
yield generate_chat_html(_history, name1, name2, shared.character)
|
||||
|
||||
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1):
|
||||
if (shared.character != 'None' and len(shared.history['visible']) == 1) or len(shared.history['internal']) == 0:
|
||||
yield generate_chat_output(shared.history['visible'], name1, name2, shared.character)
|
||||
else:
|
||||
@ -191,7 +185,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
last_internal = shared.history['internal'].pop()
|
||||
# Yield '*Is typing...*'
|
||||
yield generate_chat_output(shared.history['visible']+[[last_visible[0], shared.processing_message]], name1, name2, shared.character)
|
||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
||||
for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True):
|
||||
if shared.args.cai_chat:
|
||||
shared.history['visible'][-1] = [last_visible[0], _history[-1][1]]
|
||||
else:
|
||||
|
@ -1,3 +1,5 @@
|
||||
import gradio as gr
|
||||
|
||||
import extensions
|
||||
import modules.shared as shared
|
||||
|
||||
@ -9,9 +11,12 @@ def load_extensions():
|
||||
for i, name in enumerate(shared.args.extensions):
|
||||
if name in available_extensions:
|
||||
print(f'Loading the extension "{name}"... ', end='')
|
||||
exec(f"import extensions.{name}.script")
|
||||
state[name] = [True, i]
|
||||
print('Ok.')
|
||||
try:
|
||||
exec(f"import extensions.{name}.script")
|
||||
state[name] = [True, i]
|
||||
print('Ok.')
|
||||
except:
|
||||
print('Fail.')
|
||||
|
||||
# This iterator returns the extensions in the order specified in the command-line
|
||||
def iterator():
|
||||
@ -40,6 +45,9 @@ def create_extensions_block():
|
||||
extension.params[param] = shared.settings[_id]
|
||||
|
||||
# Creating the extension ui elements
|
||||
for extension, name in iterator():
|
||||
if hasattr(extension, "ui"):
|
||||
extension.ui()
|
||||
if len(state) > 0:
|
||||
with gr.Box(elem_id="extensions"):
|
||||
gr.Markdown("Extensions")
|
||||
for extension, name in iterator():
|
||||
if hasattr(extension, "ui"):
|
||||
extension.ui()
|
||||
|
@ -1,6 +1,6 @@
|
||||
'''
|
||||
|
||||
This is a library for formatting GPT-4chan and chat outputs as nice HTML.
|
||||
This is a library for formatting text outputs as nice HTML.
|
||||
|
||||
'''
|
||||
|
||||
@ -8,30 +8,39 @@ import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import markdown
|
||||
from PIL import Image
|
||||
|
||||
# This is to store the paths to the thumbnails of the profile pictures
|
||||
image_cache = {}
|
||||
|
||||
def generate_basic_html(s):
|
||||
css = """
|
||||
.container {
|
||||
max-width: 600px;
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
background-color: rgb(31, 41, 55);
|
||||
padding:3em;
|
||||
}
|
||||
.container p {
|
||||
font-size: 16px !important;
|
||||
color: white !important;
|
||||
margin-bottom: 22px;
|
||||
line-height: 1.4 !important;
|
||||
}
|
||||
"""
|
||||
s = '\n'.join([f'<p>{line}</p>' for line in s.split('\n')])
|
||||
s = f'<style>{css}</style><div class="container">{s}</div>'
|
||||
return s
|
||||
with open(Path(__file__).resolve().parent / '../css/html_readable_style.css', 'r') as f:
|
||||
readable_css = f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/html_4chan_style.css', 'r') as css_f:
|
||||
_4chan_css = css_f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/html_cai_style.css', 'r') as f:
|
||||
cai_css = f.read()
|
||||
|
||||
def fix_newlines(string):
|
||||
string = string.replace('\n', '\n\n')
|
||||
string = re.sub(r"\n{3,}", "\n\n", string)
|
||||
string = string.strip()
|
||||
return string
|
||||
|
||||
# This could probably be generalized and improved
|
||||
def convert_to_markdown(string):
|
||||
string = string.replace('\\begin{code}', '```')
|
||||
string = string.replace('\\end{code}', '```')
|
||||
string = string.replace('\\begin{blockquote}', '> ')
|
||||
string = string.replace('\\end{blockquote}', '')
|
||||
string = re.sub(r"(.)```", r"\1\n```", string)
|
||||
# string = fix_newlines(string)
|
||||
return markdown.markdown(string, extensions=['fenced_code'])
|
||||
|
||||
def generate_basic_html(string):
|
||||
string = convert_to_markdown(string)
|
||||
string = f'<style>{readable_css}</style><div class="container">{string}</div>'
|
||||
return string
|
||||
|
||||
def process_post(post, c):
|
||||
t = post.split('\n')
|
||||
@ -48,113 +57,6 @@ def process_post(post, c):
|
||||
return src
|
||||
|
||||
def generate_4chan_html(f):
|
||||
css = """
|
||||
|
||||
#parent #container {
|
||||
background-color: #eef2ff;
|
||||
padding: 17px;
|
||||
}
|
||||
#parent #container .reply {
|
||||
background-color: rgb(214, 218, 240);
|
||||
border-bottom-color: rgb(183, 197, 217);
|
||||
border-bottom-style: solid;
|
||||
border-bottom-width: 1px;
|
||||
border-image-outset: 0;
|
||||
border-image-repeat: stretch;
|
||||
border-image-slice: 100%;
|
||||
border-image-source: none;
|
||||
border-image-width: 1;
|
||||
border-left-color: rgb(0, 0, 0);
|
||||
border-left-style: none;
|
||||
border-left-width: 0px;
|
||||
border-right-color: rgb(183, 197, 217);
|
||||
border-right-style: solid;
|
||||
border-right-width: 1px;
|
||||
border-top-color: rgb(0, 0, 0);
|
||||
border-top-style: none;
|
||||
border-top-width: 0px;
|
||||
color: rgb(0, 0, 0);
|
||||
display: table;
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
margin-bottom: 4px;
|
||||
margin-left: 0px;
|
||||
margin-right: 0px;
|
||||
margin-top: 4px;
|
||||
overflow-x: hidden;
|
||||
overflow-y: hidden;
|
||||
padding-bottom: 4px;
|
||||
padding-left: 2px;
|
||||
padding-right: 2px;
|
||||
padding-top: 4px;
|
||||
}
|
||||
|
||||
#parent #container .number {
|
||||
color: rgb(0, 0, 0);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
width: 342.65px;
|
||||
margin-right: 7px;
|
||||
}
|
||||
|
||||
#parent #container .op {
|
||||
color: rgb(0, 0, 0);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
margin-bottom: 8px;
|
||||
margin-left: 0px;
|
||||
margin-right: 0px;
|
||||
margin-top: 4px;
|
||||
overflow-x: hidden;
|
||||
overflow-y: hidden;
|
||||
}
|
||||
|
||||
#parent #container .op blockquote {
|
||||
margin-left: 0px !important;
|
||||
}
|
||||
|
||||
#parent #container .name {
|
||||
color: rgb(17, 119, 67);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
font-weight: 700;
|
||||
margin-left: 7px;
|
||||
}
|
||||
|
||||
#parent #container .quote {
|
||||
color: rgb(221, 0, 0);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
text-decoration-color: rgb(221, 0, 0);
|
||||
text-decoration-line: underline;
|
||||
text-decoration-style: solid;
|
||||
text-decoration-thickness: auto;
|
||||
}
|
||||
|
||||
#parent #container .greentext {
|
||||
color: rgb(120, 153, 34);
|
||||
font-family: arial, helvetica, sans-serif;
|
||||
font-size: 13.3333px;
|
||||
}
|
||||
|
||||
#parent #container blockquote {
|
||||
margin: 0px !important;
|
||||
margin-block-start: 1em;
|
||||
margin-block-end: 1em;
|
||||
margin-inline-start: 40px;
|
||||
margin-inline-end: 40px;
|
||||
margin-top: 13.33px !important;
|
||||
margin-bottom: 13.33px !important;
|
||||
margin-left: 40px !important;
|
||||
margin-right: 40px !important;
|
||||
}
|
||||
|
||||
#parent #container .message {
|
||||
color: black;
|
||||
border: none;
|
||||
}
|
||||
"""
|
||||
|
||||
posts = []
|
||||
post = ''
|
||||
c = -2
|
||||
@ -181,7 +83,7 @@ def generate_4chan_html(f):
|
||||
posts[i] = f'<div class="reply">{posts[i]}</div>\n'
|
||||
|
||||
output = ''
|
||||
output += f'<style>{css}</style><div id="parent"><div id="container">'
|
||||
output += f'<style>{_4chan_css}</style><div id="parent"><div id="container">'
|
||||
for post in posts:
|
||||
output += post
|
||||
output += '</div></div>'
|
||||
@ -208,135 +110,39 @@ def get_image_cache(path):
|
||||
|
||||
return image_cache[path][1]
|
||||
|
||||
def load_html_image(paths):
|
||||
for str_path in paths:
|
||||
path = Path(str_path)
|
||||
if path.exists():
|
||||
return f'<img src="file/{get_image_cache(path)}">'
|
||||
return ''
|
||||
|
||||
def generate_chat_html(history, name1, name2, character):
|
||||
css = """
|
||||
.chat {
|
||||
margin-left: auto;
|
||||
margin-right: auto;
|
||||
max-width: 800px;
|
||||
height: 66.67vh;
|
||||
overflow-y: auto;
|
||||
padding-right: 20px;
|
||||
display: flex;
|
||||
flex-direction: column-reverse;
|
||||
}
|
||||
|
||||
.message {
|
||||
display: grid;
|
||||
grid-template-columns: 60px 1fr;
|
||||
padding-bottom: 25px;
|
||||
font-size: 15px;
|
||||
font-family: Helvetica, Arial, sans-serif;
|
||||
line-height: 1.428571429;
|
||||
}
|
||||
|
||||
.circle-you {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
background-color: rgb(238, 78, 59);
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.circle-bot {
|
||||
width: 50px;
|
||||
height: 50px;
|
||||
background-color: rgb(59, 78, 244);
|
||||
border-radius: 50%;
|
||||
}
|
||||
|
||||
.circle-bot img, .circle-you img {
|
||||
border-radius: 50%;
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
object-fit: cover;
|
||||
}
|
||||
|
||||
.text {
|
||||
}
|
||||
|
||||
.text p {
|
||||
margin-top: 5px;
|
||||
}
|
||||
|
||||
.username {
|
||||
font-weight: bold;
|
||||
}
|
||||
|
||||
.message-body {
|
||||
}
|
||||
|
||||
.message-body img {
|
||||
max-width: 300px;
|
||||
max-height: 300px;
|
||||
border-radius: 20px;
|
||||
}
|
||||
|
||||
.message-body p {
|
||||
margin-bottom: 0 !important;
|
||||
font-size: 15px !important;
|
||||
line-height: 1.428571429 !important;
|
||||
}
|
||||
|
||||
.dark .message-body p em {
|
||||
color: rgb(138, 138, 138) !important;
|
||||
}
|
||||
|
||||
.message-body p em {
|
||||
color: rgb(110, 110, 110) !important;
|
||||
}
|
||||
|
||||
"""
|
||||
|
||||
output = ''
|
||||
output += f'<style>{css}</style><div class="chat" id="chat">'
|
||||
img = ''
|
||||
|
||||
for i in [
|
||||
f"characters/{character}.png",
|
||||
f"characters/{character}.jpg",
|
||||
f"characters/{character}.jpeg",
|
||||
"img_bot.png",
|
||||
"img_bot.jpg",
|
||||
"img_bot.jpeg"
|
||||
]:
|
||||
|
||||
path = Path(i)
|
||||
if path.exists():
|
||||
img = f'<img src="file/{get_image_cache(path)}">'
|
||||
break
|
||||
|
||||
img_me = ''
|
||||
for i in ["img_me.png", "img_me.jpg", "img_me.jpeg"]:
|
||||
path = Path(i)
|
||||
if path.exists():
|
||||
img_me = f'<img src="file/{get_image_cache(path)}">'
|
||||
break
|
||||
output = f'<style>{cai_css}</style><div class="chat" id="chat">'
|
||||
|
||||
img_bot = load_html_image([f"characters/{character}.{ext}" for ext in ['png', 'jpg', 'jpeg']] + ["img_bot.png","img_bot.jpg","img_bot.jpeg"])
|
||||
img_me = load_html_image(["img_me.png", "img_me.jpg", "img_me.jpeg"])
|
||||
|
||||
for i,_row in enumerate(history[::-1]):
|
||||
row = _row.copy()
|
||||
row[0] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"<b>\2</b>", row[0])
|
||||
row[1] = re.sub(r"(\*\*)([^\*\n]*)(\*\*)", r"<b>\2</b>", row[1])
|
||||
row[0] = re.sub(r"(\*)([^\*\n]*)(\*)", r"<em>\2</em>", row[0])
|
||||
row[1] = re.sub(r"(\*)([^\*\n]*)(\*)", r"<em>\2</em>", row[1])
|
||||
p = '\n'.join([f"<p>{x}</p>" for x in row[1].split('\n')])
|
||||
row = [convert_to_markdown(entry) for entry in _row]
|
||||
|
||||
output += f"""
|
||||
<div class="message">
|
||||
<div class="circle-bot">
|
||||
{img}
|
||||
{img_bot}
|
||||
</div>
|
||||
<div class="text">
|
||||
<div class="username">
|
||||
{name2}
|
||||
</div>
|
||||
<div class="message-body">
|
||||
{p}
|
||||
{row[1]}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
"""
|
||||
|
||||
if not (i == len(history)-1 and len(row[0]) == 0):
|
||||
p = '\n'.join([f"<p>{x}</p>" for x in row[0].split('\n')])
|
||||
output += f"""
|
||||
<div class="message">
|
||||
<div class="circle-you">
|
||||
@ -347,7 +153,7 @@ def generate_chat_html(history, name1, name2, character):
|
||||
{name1}
|
||||
</div>
|
||||
<div class="message-body">
|
||||
{p}
|
||||
{row[0]}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -7,7 +7,9 @@ from pathlib import Path
|
||||
import numpy as np
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
from accelerate import infer_auto_device_map, init_empty_weights
|
||||
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||
BitsAndBytesConfig)
|
||||
|
||||
import modules.shared as shared
|
||||
|
||||
@ -16,8 +18,7 @@ transformers.logging.set_verbosity_error()
|
||||
local_rank = None
|
||||
|
||||
if shared.args.flexgen:
|
||||
from flexgen.flex_opt import (CompressionConfig, ExecutionEnv, OptLM,
|
||||
Policy, str2bool)
|
||||
from flexgen.flex_opt import CompressionConfig, ExecutionEnv, OptLM, Policy
|
||||
|
||||
if shared.args.deepspeed:
|
||||
import deepspeed
|
||||
@ -46,7 +47,12 @@ def load_model(model_name):
|
||||
if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16).cuda()
|
||||
model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
|
||||
if torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
model = model.to(device)
|
||||
else:
|
||||
model = model.cuda()
|
||||
|
||||
# FlexGen
|
||||
elif shared.args.flexgen:
|
||||
@ -95,39 +101,60 @@ def load_model(model_name):
|
||||
|
||||
# Custom
|
||||
else:
|
||||
command = "AutoModelForCausalLM.from_pretrained"
|
||||
params = ["low_cpu_mem_usage=True"]
|
||||
if not shared.args.cpu and not torch.cuda.is_available():
|
||||
print("Warning: no GPU has been detected.\nFalling back to CPU mode.\n")
|
||||
params = {"low_cpu_mem_usage": True}
|
||||
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
||||
print("Warning: torch.cuda.is_available() returned False.\nThis means that no GPU has been detected.\nFalling back to CPU mode.\n")
|
||||
shared.args.cpu = True
|
||||
|
||||
if shared.args.cpu:
|
||||
params.append("low_cpu_mem_usage=True")
|
||||
params.append("torch_dtype=torch.float32")
|
||||
params["torch_dtype"] = torch.float32
|
||||
else:
|
||||
params.append("device_map='auto'")
|
||||
params.append("load_in_8bit=True" if shared.args.load_in_8bit else "torch_dtype=torch.bfloat16" if shared.args.bf16 else "torch_dtype=torch.float16")
|
||||
params["device_map"] = 'auto'
|
||||
if shared.args.load_in_8bit and any((shared.args.auto_devices, shared.args.gpu_memory)):
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True, llm_int8_enable_fp32_cpu_offload=True)
|
||||
elif shared.args.load_in_8bit:
|
||||
params['quantization_config'] = BitsAndBytesConfig(load_in_8bit=True)
|
||||
elif shared.args.bf16:
|
||||
params["torch_dtype"] = torch.bfloat16
|
||||
else:
|
||||
params["torch_dtype"] = torch.float16
|
||||
|
||||
if shared.args.gpu_memory:
|
||||
memory_map = shared.args.gpu_memory
|
||||
max_memory = f"max_memory={{0: '{memory_map[0]}GiB'"
|
||||
for i in range(1, len(memory_map)):
|
||||
max_memory += (f", {i}: '{memory_map[i]}GiB'")
|
||||
max_memory += (f", 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
|
||||
params.append(max_memory)
|
||||
elif not shared.args.load_in_8bit:
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory/(1024*1024))
|
||||
suggestion = round((total_mem-1000)/1000)*1000
|
||||
if total_mem-suggestion < 800:
|
||||
max_memory = {}
|
||||
for i in range(len(memory_map)):
|
||||
max_memory[i] = f'{memory_map[i]}GiB'
|
||||
max_memory['cpu'] = f'{shared.args.cpu_memory or 99}GiB'
|
||||
params['max_memory'] = max_memory
|
||||
elif shared.args.auto_devices:
|
||||
total_mem = (torch.cuda.get_device_properties(0).total_memory / (1024*1024))
|
||||
suggestion = round((total_mem-1000) / 1000) * 1000
|
||||
if total_mem - suggestion < 800:
|
||||
suggestion -= 1000
|
||||
suggestion = int(round(suggestion/1000))
|
||||
print(f"\033[1;32;1mAuto-assiging --gpu-memory {suggestion} for your GPU to try to prevent out-of-memory errors.\nYou can manually set other values.\033[0;37;0m")
|
||||
params.append(f"max_memory={{0: '{suggestion}GiB', 'cpu': '{shared.args.cpu_memory or '99'}GiB'}}")
|
||||
if shared.args.disk:
|
||||
params.append(f"offload_folder='{shared.args.disk_cache_dir}'")
|
||||
|
||||
max_memory = {0: f'{suggestion}GiB', 'cpu': f'{shared.args.cpu_memory or 99}GiB'}
|
||||
params['max_memory'] = max_memory
|
||||
|
||||
command = f"{command}(Path(f'models/{shared.model_name}'), {', '.join(set(params))})"
|
||||
model = eval(command)
|
||||
if shared.args.disk:
|
||||
params["offload_folder"] = shared.args.disk_cache_dir
|
||||
|
||||
checkpoint = Path(f'models/{shared.model_name}')
|
||||
|
||||
if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
|
||||
config = AutoConfig.from_pretrained(checkpoint)
|
||||
with init_empty_weights():
|
||||
model = AutoModelForCausalLM.from_config(config)
|
||||
model.tie_weights()
|
||||
params['device_map'] = infer_auto_device_map(
|
||||
model,
|
||||
dtype=torch.int8,
|
||||
max_memory=params['max_memory'],
|
||||
no_split_module_classes = model._no_split_modules
|
||||
)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)
|
||||
|
||||
# Loading the tokenizer
|
||||
if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
|
||||
|
@ -2,7 +2,8 @@ import argparse
|
||||
|
||||
model = None
|
||||
tokenizer = None
|
||||
model_name = ""
|
||||
model_name = "None"
|
||||
lora_name = "None"
|
||||
soft_prompt_tensor = None
|
||||
soft_prompt = False
|
||||
is_RWKV = False
|
||||
@ -19,6 +20,9 @@ gradio = {}
|
||||
# Generation input parameters
|
||||
input_params = []
|
||||
|
||||
# For restarting the interface
|
||||
need_restart = False
|
||||
|
||||
settings = {
|
||||
'max_new_tokens': 200,
|
||||
'max_new_tokens_min': 1,
|
||||
@ -26,7 +30,7 @@ settings = {
|
||||
'name1': 'Person 1',
|
||||
'name2': 'Person 2',
|
||||
'context': 'This is a conversation between two people.',
|
||||
'stop_at_newline': True,
|
||||
'stop_at_newline': False,
|
||||
'chat_prompt_size': 2048,
|
||||
'chat_prompt_size_min': 0,
|
||||
'chat_prompt_size_max': 2048,
|
||||
@ -49,6 +53,10 @@ settings = {
|
||||
'^(gpt4chan|gpt-4chan|4chan)': '-----\n--- 865467536\nInput text\n--- 865467537\n',
|
||||
'(rosey|chip|joi)_.*_instruct.*': 'User: \n',
|
||||
'oasst-*': '<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>'
|
||||
},
|
||||
'lora_prompts': {
|
||||
'default': 'Common sense questions and answers\n\nQuestion: \nFactual answer:',
|
||||
'(alpaca-lora-7b|alpaca-lora-13b|alpaca-lora-30b)': "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,6 +72,7 @@ def str2bool(v):
|
||||
|
||||
parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpFormatter(prog,max_help_position=54))
|
||||
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
|
||||
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
|
||||
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
|
||||
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode.')
|
||||
parser.add_argument('--cai-chat', action='store_true', help='Launch the web UI in chat mode with a style similar to Character.AI\'s. If the file img_bot.png or img_bot.jpg exists in the same folder as server.py, this image will be used as the bot\'s profile picture. Similarly, img_me.png or img_me.jpg will be used as your profile picture.')
|
||||
|
@ -33,12 +33,15 @@ def encode(prompt, tokens_to_generate=0, add_special_tokens=True):
|
||||
return input_ids.numpy()
|
||||
elif shared.args.deepspeed:
|
||||
return input_ids.to(device=local_rank)
|
||||
elif torch.has_mps:
|
||||
device = torch.device('mps')
|
||||
return input_ids.to(device)
|
||||
else:
|
||||
return input_ids.cuda()
|
||||
|
||||
def decode(output_ids):
|
||||
# Open Assistant relies on special tokens like <|endoftext|>
|
||||
if re.match('oasst-*', shared.model_name.lower()):
|
||||
if re.match('(oasst|galactica)-*', shared.model_name.lower()):
|
||||
return shared.tokenizer.decode(output_ids, skip_special_tokens=False)
|
||||
else:
|
||||
reply = shared.tokenizer.decode(output_ids, skip_special_tokens=True)
|
||||
@ -89,7 +92,7 @@ def clear_torch_cache():
|
||||
if not shared.args.cpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
||||
def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, encoder_repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=None, stopping_string=None):
|
||||
clear_torch_cache()
|
||||
t0 = time.time()
|
||||
|
||||
@ -101,7 +104,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k)
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
if not (shared.args.chat or shared.args.cai_chat):
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
# RWKV has proper streaming, which is very nice.
|
||||
# No need to generate 8 tokens at a time.
|
||||
for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k):
|
||||
@ -143,6 +147,7 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
"top_p": top_p,
|
||||
"typical_p": typical_p,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"encoder_repetition_penalty": encoder_repetition_penalty,
|
||||
"top_k": top_k,
|
||||
"min_length": min_length if shared.args.no_stream else 0,
|
||||
"no_repeat_ngram_size": no_repeat_ngram_size,
|
||||
@ -196,7 +201,8 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi
|
||||
def generate_with_streaming(**kwargs):
|
||||
return Iteratorize(generate_with_callback, kwargs, callback=None)
|
||||
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
if not (shared.args.chat or shared.args.cai_chat):
|
||||
yield formatted_outputs(original_question, shared.model_name)
|
||||
with generate_with_streaming(**generate_params) as generator:
|
||||
for output in generator:
|
||||
if shared.soft_prompt:
|
||||
|
@ -1,68 +1,17 @@
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
|
||||
refresh_symbol = '\U0001f504' # 🔄
|
||||
|
||||
css = """
|
||||
.tabs.svelte-710i53 {
|
||||
margin-top: 0
|
||||
}
|
||||
.py-6 {
|
||||
padding-top: 2.5rem
|
||||
}
|
||||
.dark #refresh-button {
|
||||
background-color: #ffffff1f;
|
||||
}
|
||||
#refresh-button {
|
||||
flex: none;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
min-width: 50px;
|
||||
border: none;
|
||||
box-shadow: none;
|
||||
border-radius: 10px;
|
||||
background-color: #0000000d;
|
||||
}
|
||||
#download-label, #upload-label {
|
||||
min-height: 0
|
||||
}
|
||||
#accordion {
|
||||
}
|
||||
.dark svg {
|
||||
fill: white;
|
||||
}
|
||||
svg {
|
||||
display: unset !important;
|
||||
vertical-align: middle !important;
|
||||
margin: 5px;
|
||||
}
|
||||
ol li p, ul li p {
|
||||
display: inline-block;
|
||||
}
|
||||
"""
|
||||
|
||||
chat_css = """
|
||||
.h-\[40vh\], .wrap.svelte-byatnx.svelte-byatnx.svelte-byatnx {
|
||||
height: 66.67vh
|
||||
}
|
||||
.gradio-container {
|
||||
max-width: 800px !important;
|
||||
margin-left: auto !important;
|
||||
margin-right: auto !important;
|
||||
}
|
||||
.w-screen {
|
||||
width: unset
|
||||
}
|
||||
div.svelte-362y77>*, div.svelte-362y77>.form>* {
|
||||
flex-wrap: nowrap
|
||||
}
|
||||
/* fixes the API documentation in chat mode */
|
||||
.api-docs.svelte-1iguv9h.svelte-1iguv9h.svelte-1iguv9h {
|
||||
display: grid;
|
||||
}
|
||||
.pending.svelte-1ed2p3z {
|
||||
opacity: 1;
|
||||
}
|
||||
"""
|
||||
with open(Path(__file__).resolve().parent / '../css/main.css', 'r') as f:
|
||||
css = f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/chat.css', 'r') as f:
|
||||
chat_css = f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/main.js', 'r') as f:
|
||||
main_js = f.read()
|
||||
with open(Path(__file__).resolve().parent / '../css/chat.js', 'r') as f:
|
||||
chat_js = f.read()
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
"""Small button with single emoji as text, fits inside gradio forms"""
|
||||
|
@ -1,12 +1,7 @@
|
||||
do_sample=True
|
||||
temperature=1
|
||||
top_p=1
|
||||
typical_p=1
|
||||
repetition_penalty=1
|
||||
top_k=50
|
||||
num_beams=1
|
||||
penalty_alpha=0
|
||||
min_length=0
|
||||
length_penalty=1
|
||||
no_repeat_ngram_size=0
|
||||
top_p=0.5
|
||||
top_k=40
|
||||
temperature=0.7
|
||||
repetition_penalty=1.2
|
||||
typical_p=1.0
|
||||
early_stopping=False
|
||||
|
@ -1,6 +0,0 @@
|
||||
do_sample=True
|
||||
top_p=0.9
|
||||
top_k=50
|
||||
temperature=1.39
|
||||
repetition_penalty=1.08
|
||||
typical_p=0.2
|
@ -2,10 +2,12 @@ accelerate==0.17.1
|
||||
bitsandbytes==0.37.1
|
||||
flexgen==0.1.7
|
||||
gradio==3.18.0
|
||||
markdown
|
||||
numpy
|
||||
peft==0.2.0
|
||||
requests
|
||||
rwkv==0.4.2
|
||||
rwkv==0.7.0
|
||||
safetensors==0.3.0
|
||||
sentencepiece
|
||||
tqdm
|
||||
git+https://github.com/zphang/transformers.git@68d640f7c368bcaaaecfc678f11908ebbd3d6176
|
||||
git+https://github.com/huggingface/transformers
|
||||
|
482
server.py
482
server.py
@ -15,6 +15,7 @@ import modules.extensions as extensions_module
|
||||
import modules.shared as shared
|
||||
import modules.ui as ui
|
||||
from modules.html_generator import generate_chat_html
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
@ -34,7 +35,7 @@ def get_available_models():
|
||||
if shared.args.flexgen:
|
||||
return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
|
||||
else:
|
||||
return sorted([item.name for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt'))], key=str.lower)
|
||||
return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
def get_available_presets():
|
||||
return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
|
||||
@ -48,6 +49,9 @@ def get_available_extensions():
|
||||
def get_available_softprompts():
|
||||
return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)
|
||||
|
||||
def get_available_loras():
|
||||
return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
|
||||
|
||||
def load_model_wrapper(selected_model):
|
||||
if selected_model != shared.model_name:
|
||||
shared.model_name = selected_model
|
||||
@ -59,6 +63,17 @@ def load_model_wrapper(selected_model):
|
||||
|
||||
return selected_model
|
||||
|
||||
def load_lora_wrapper(selected_lora):
|
||||
shared.lora_name = selected_lora
|
||||
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
|
||||
|
||||
if not shared.args.cpu:
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
add_lora_to_model(selected_lora)
|
||||
|
||||
return selected_lora, default_text
|
||||
|
||||
def load_preset_values(preset_menu, return_dict=False):
|
||||
generate_params = {
|
||||
'do_sample': True,
|
||||
@ -66,6 +81,7 @@ def load_preset_values(preset_menu, return_dict=False):
|
||||
'top_p': 1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 50,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
@ -86,7 +102,7 @@ def load_preset_values(preset_menu, return_dict=False):
|
||||
if return_dict:
|
||||
return generate_params
|
||||
else:
|
||||
return generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
return preset_menu, generate_params['do_sample'], generate_params['temperature'], generate_params['top_p'], generate_params['typical_p'], generate_params['repetition_penalty'], generate_params['encoder_repetition_penalty'], generate_params['top_k'], generate_params['min_length'], generate_params['no_repeat_ngram_size'], generate_params['num_beams'], generate_params['penalty_alpha'], generate_params['length_penalty'], generate_params['early_stopping']
|
||||
|
||||
def upload_soft_prompt(file):
|
||||
with zipfile.ZipFile(io.BytesIO(file)) as zf:
|
||||
@ -100,9 +116,7 @@ def upload_soft_prompt(file):
|
||||
|
||||
return name
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
|
||||
|
||||
def create_model_and_preset_menus():
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
@ -113,31 +127,48 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['preset_menu'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
||||
|
||||
with gr.Accordion('Custom generation parameters', open=False, elem_id='accordion'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 2.99, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
|
||||
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
|
||||
with gr.Column():
|
||||
def create_settings_menus(default_preset):
|
||||
generate_params = load_preset_values(default_preset if not shared.args.flexgen else 'Naive', return_dict=True)
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['preset_menu_mirror'] = gr.Dropdown(choices=available_presets, value=default_preset if not shared.args.flexgen else 'Naive', label='Generation parameters preset')
|
||||
ui.create_refresh_button(shared.gradio['preset_menu_mirror'], lambda : None, lambda : {'choices': get_available_presets()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
with gr.Box():
|
||||
gr.Markdown('Custom generation parameters ([reference](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig))')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['temperature'] = gr.Slider(0.01, 1.99, value=generate_params['temperature'], step=0.01, label='temperature')
|
||||
shared.gradio['top_p'] = gr.Slider(0.0,1.0,value=generate_params['top_p'],step=0.01,label='top_p')
|
||||
shared.gradio['top_k'] = gr.Slider(0,200,value=generate_params['top_k'],step=1,label='top_k')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
|
||||
with gr.Column():
|
||||
shared.gradio['repetition_penalty'] = gr.Slider(1.0, 1.5, value=generate_params['repetition_penalty'],step=0.01,label='repetition_penalty')
|
||||
shared.gradio['encoder_repetition_penalty'] = gr.Slider(0.8, 1.5, value=generate_params['encoder_repetition_penalty'],step=0.01,label='encoder_repetition_penalty')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
|
||||
shared.gradio['do_sample'] = gr.Checkbox(value=generate_params['do_sample'], label='do_sample')
|
||||
shared.gradio['typical_p'] = gr.Slider(0.0,1.0,value=generate_params['typical_p'],step=0.01,label='typical_p')
|
||||
shared.gradio['no_repeat_ngram_size'] = gr.Slider(0, 20, step=1, value=generate_params['no_repeat_ngram_size'], label='no_repeat_ngram_size')
|
||||
shared.gradio['min_length'] = gr.Slider(0, 2000, step=1, value=generate_params['min_length'] if shared.args.no_stream else 0, label='min_length', interactive=shared.args.no_stream)
|
||||
with gr.Column():
|
||||
with gr.Box():
|
||||
gr.Markdown('Contrastive search')
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||
|
||||
gr.Markdown('Contrastive search:')
|
||||
shared.gradio['penalty_alpha'] = gr.Slider(0, 5, value=generate_params['penalty_alpha'], label='penalty_alpha')
|
||||
with gr.Box():
|
||||
gr.Markdown('Beam search (uses a lot of VRAM)')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
|
||||
with gr.Column():
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
|
||||
gr.Markdown('Beam search (uses a lot of VRAM):')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['num_beams'] = gr.Slider(1, 20, step=1, value=generate_params['num_beams'], label='num_beams')
|
||||
with gr.Column():
|
||||
shared.gradio['length_penalty'] = gr.Slider(-5, 5, value=generate_params['length_penalty'], label='length_penalty')
|
||||
shared.gradio['early_stopping'] = gr.Checkbox(value=generate_params['early_stopping'], label='early_stopping')
|
||||
with gr.Row():
|
||||
shared.gradio['lora_menu'] = gr.Dropdown(choices=available_loras, value=shared.lora_name, label='LoRA')
|
||||
ui.create_refresh_button(shared.gradio['lora_menu'], lambda : None, lambda : {'choices': get_available_loras()}, 'refresh-button')
|
||||
|
||||
with gr.Accordion('Soft prompt', open=False, elem_id='accordion'):
|
||||
with gr.Accordion('Soft prompt', open=False):
|
||||
with gr.Row():
|
||||
shared.gradio['softprompts_menu'] = gr.Dropdown(choices=available_softprompts, value='None', label='Soft prompt')
|
||||
ui.create_refresh_button(shared.gradio['softprompts_menu'], lambda : None, lambda : {'choices': get_available_softprompts()}, 'refresh-button')
|
||||
@ -147,14 +178,35 @@ def create_settings_menus(default_preset):
|
||||
shared.gradio['upload_softprompt'] = gr.File(type='binary', file_types=['.zip'])
|
||||
|
||||
shared.gradio['model_menu'].change(load_model_wrapper, [shared.gradio['model_menu']], [shared.gradio['model_menu']], show_progress=True)
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio['do_sample'], shared.gradio['temperature'], shared.gradio['top_p'], shared.gradio['typical_p'], shared.gradio['repetition_penalty'], shared.gradio['top_k'], shared.gradio['min_length'], shared.gradio['no_repeat_ngram_size'], shared.gradio['num_beams'], shared.gradio['penalty_alpha'], shared.gradio['length_penalty'], shared.gradio['early_stopping']])
|
||||
shared.gradio['preset_menu'].change(load_preset_values, [shared.gradio['preset_menu']], [shared.gradio[k] for k in ['preset_menu_mirror', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['preset_menu_mirror'].change(load_preset_values, [shared.gradio['preset_menu_mirror']], [shared.gradio[k] for k in ['preset_menu', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']])
|
||||
shared.gradio['lora_menu'].change(load_lora_wrapper, [shared.gradio['lora_menu']], [shared.gradio['lora_menu'], shared.gradio['textbox']], show_progress=True)
|
||||
shared.gradio['softprompts_menu'].change(load_soft_prompt, [shared.gradio['softprompts_menu']], [shared.gradio['softprompts_menu']], show_progress=True)
|
||||
shared.gradio['upload_softprompt'].upload(upload_soft_prompt, [shared.gradio['upload_softprompt']], [shared.gradio['softprompts_menu']])
|
||||
|
||||
def set_interface_arguments(interface_mode, extensions, cmd_active):
|
||||
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||
cmd_list = vars(shared.args)
|
||||
cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||
|
||||
shared.args.extensions = extensions
|
||||
for k in modes[1:]:
|
||||
exec(f"shared.args.{k} = False")
|
||||
if interface_mode != "default":
|
||||
exec(f"shared.args.{interface_mode} = True")
|
||||
|
||||
for k in cmd_list:
|
||||
exec(f"shared.args.{k} = False")
|
||||
for k in cmd_active:
|
||||
exec(f"shared.args.{k} = True")
|
||||
|
||||
shared.need_restart = True
|
||||
|
||||
available_models = get_available_models()
|
||||
available_presets = get_available_presets()
|
||||
available_characters = get_available_characters()
|
||||
available_softprompts = get_available_softprompts()
|
||||
available_loras = get_available_loras()
|
||||
|
||||
# Default extensions
|
||||
extensions_module.available_extensions = get_available_extensions()
|
||||
@ -168,8 +220,6 @@ else:
|
||||
shared.args.extensions = shared.args.extensions or []
|
||||
if extension not in shared.args.extensions:
|
||||
shared.args.extensions.append(extension)
|
||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||
extensions_module.load_extensions()
|
||||
|
||||
# Default model
|
||||
if shared.args.model is not None:
|
||||
@ -189,191 +239,235 @@ else:
|
||||
print()
|
||||
shared.model_name = available_models[i]
|
||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||
if shared.args.lora:
|
||||
print(shared.args.lora)
|
||||
shared.lora_name = shared.args.lora
|
||||
add_lora_to_model(shared.lora_name)
|
||||
|
||||
# Default UI settings
|
||||
gen_events = []
|
||||
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
||||
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
||||
default_text = shared.settings['lora_prompts'][next((k for k in shared.settings['lora_prompts'] if re.match(k.lower(), shared.lora_name.lower())), 'default')]
|
||||
if default_text == '':
|
||||
default_text = shared.settings['prompts'][next((k for k in shared.settings['prompts'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
||||
title ='Text generation web UI'
|
||||
description = '\n\n# Text generation lab\nGenerate text using Large Language Models.\n'
|
||||
suffix = '_pygmalion' if 'pygmalion' in shared.model_name.lower() else ''
|
||||
|
||||
if shared.args.chat or shared.args.cai_chat:
|
||||
with gr.Blocks(css=ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
if shared.args.cai_chat:
|
||||
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
|
||||
else:
|
||||
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
|
||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||
with gr.Row():
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
with gr.Row():
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||
with gr.Row():
|
||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||
shared.gradio['Remove last'] = gr.Button('Remove last')
|
||||
def create_interface():
|
||||
|
||||
shared.gradio['Clear history'] = gr.Button('Clear history')
|
||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
with gr.Tab('Chat settings'):
|
||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
|
||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
|
||||
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
|
||||
with gr.Row():
|
||||
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
|
||||
gen_events = []
|
||||
if shared.args.extensions is not None and len(shared.args.extensions) > 0:
|
||||
extensions_module.load_extensions()
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
|
||||
with gr.Row():
|
||||
with gr.Tab('Chat history'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown('Upload')
|
||||
shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
|
||||
with gr.Column():
|
||||
gr.Markdown('Download')
|
||||
shared.gradio['download'] = gr.File()
|
||||
shared.gradio['download_button'] = gr.Button(value='Click me')
|
||||
with gr.Tab('Upload character'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown('1. Select the JSON file')
|
||||
shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
|
||||
with gr.Column():
|
||||
gr.Markdown('2. Select your character\'s profile picture (optional)')
|
||||
shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
|
||||
shared.gradio['Upload character'] = gr.Button(value='Submit')
|
||||
with gr.Tab('Upload your profile picture'):
|
||||
shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
|
||||
with gr.Tab('Upload TavernAI Character Card'):
|
||||
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
|
||||
|
||||
with gr.Tab('Generation settings'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
with gr.Column():
|
||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
|
||||
if shared.args.extensions is not None:
|
||||
with gr.Tab('Extensions'):
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear history with confirmation
|
||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
||||
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||
|
||||
# Clearing stuff and saving the history
|
||||
for i in ['Generate', 'Regenerate', 'Replace last reply']:
|
||||
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
|
||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
|
||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
|
||||
|
||||
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
|
||||
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
|
||||
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
|
||||
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
gr.Markdown(description)
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=23)
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
|
||||
create_settings_menus(default_preset)
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
|
||||
else:
|
||||
with gr.Blocks(css=ui.css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
gr.Markdown(description)
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
with gr.Blocks(css=ui.css if not any((shared.args.chat, shared.args.cai_chat)) else ui.css+ui.chat_css, analytics_enabled=False, title=title) as shared.gradio['interface']:
|
||||
if shared.args.chat or shared.args.cai_chat:
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
if shared.args.cai_chat:
|
||||
shared.gradio['display'] = gr.HTML(value=generate_chat_html(shared.history['visible'], shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}'], shared.character))
|
||||
else:
|
||||
shared.gradio['display'] = gr.Chatbot(value=shared.history['visible']).style(color_map=("#326efd", "#212528"))
|
||||
shared.gradio['textbox'] = gr.Textbox(label='Input')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['Continue'] = gr.Button('Continue')
|
||||
with gr.Column():
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
shared.gradio['Stop'] = gr.Button('Stop', elem_id="stop")
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
with gr.Row():
|
||||
shared.gradio['Impersonate'] = gr.Button('Impersonate')
|
||||
shared.gradio['Regenerate'] = gr.Button('Regenerate')
|
||||
with gr.Row():
|
||||
shared.gradio['Copy last reply'] = gr.Button('Copy last reply')
|
||||
shared.gradio['Replace last reply'] = gr.Button('Replace last reply')
|
||||
shared.gradio['Remove last'] = gr.Button('Remove last')
|
||||
|
||||
shared.gradio['Clear history'] = gr.Button('Clear history')
|
||||
shared.gradio['Clear history-confirm'] = gr.Button('Confirm', variant="stop", visible=False)
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
|
||||
create_model_and_preset_menus()
|
||||
|
||||
with gr.Tab("Character", elem_id="chat-settings"):
|
||||
shared.gradio['name1'] = gr.Textbox(value=shared.settings[f'name1{suffix}'], lines=1, label='Your name')
|
||||
shared.gradio['name2'] = gr.Textbox(value=shared.settings[f'name2{suffix}'], lines=1, label='Bot\'s name')
|
||||
shared.gradio['context'] = gr.Textbox(value=shared.settings[f'context{suffix}'], lines=5, label='Context')
|
||||
with gr.Row():
|
||||
shared.gradio['character_menu'] = gr.Dropdown(choices=available_characters, value='None', label='Character', elem_id='character-menu')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda : None, lambda : {'choices': get_available_characters()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Tab('Chat history'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown('Upload')
|
||||
shared.gradio['upload_chat_history'] = gr.File(type='binary', file_types=['.json', '.txt'])
|
||||
with gr.Column():
|
||||
gr.Markdown('Download')
|
||||
shared.gradio['download'] = gr.File()
|
||||
shared.gradio['download_button'] = gr.Button(value='Click me')
|
||||
with gr.Tab('Upload character'):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown('1. Select the JSON file')
|
||||
shared.gradio['upload_json'] = gr.File(type='binary', file_types=['.json'])
|
||||
with gr.Column():
|
||||
gr.Markdown('2. Select your character\'s profile picture (optional)')
|
||||
shared.gradio['upload_img_bot'] = gr.File(type='binary', file_types=['image'])
|
||||
shared.gradio['Upload character'] = gr.Button(value='Submit')
|
||||
with gr.Tab('Upload your profile picture'):
|
||||
shared.gradio['upload_img_me'] = gr.File(type='binary', file_types=['image'])
|
||||
with gr.Tab('Upload TavernAI Character Card'):
|
||||
shared.gradio['upload_img_tavern'] = gr.File(type='binary', file_types=['image'])
|
||||
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
with gr.Box():
|
||||
gr.Markdown("Chat parameters")
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
shared.gradio['chat_prompt_size_slider'] = gr.Slider(minimum=shared.settings['chat_prompt_size_min'], maximum=shared.settings['chat_prompt_size_max'], step=1, label='Maximum prompt size in tokens', value=shared.settings['chat_prompt_size'])
|
||||
with gr.Column():
|
||||
shared.gradio['chat_generation_attempts'] = gr.Slider(minimum=shared.settings['chat_generation_attempts_min'], maximum=shared.settings['chat_generation_attempts_max'], value=shared.settings['chat_generation_attempts'], step=1, label='Generation attempts (for longer replies)')
|
||||
shared.gradio['check'] = gr.Checkbox(value=shared.settings[f'stop_at_newline{suffix}'], label='Stop generating at new line character?')
|
||||
|
||||
create_settings_menus(default_preset)
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
with gr.Column():
|
||||
function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper'
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping', 'name1', 'name2', 'context', 'check', 'chat_prompt_size_slider', 'chat_generation_attempts']]
|
||||
|
||||
gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events)
|
||||
|
||||
shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream)
|
||||
shared.gradio['Replace last reply'].click(chat.replace_last_reply, [shared.gradio['textbox'], shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'], show_progress=shared.args.no_stream)
|
||||
|
||||
# Clear history with confirmation
|
||||
clear_arr = [shared.gradio[k] for k in ['Clear history-confirm', 'Clear history', 'Clear history-cancel']]
|
||||
shared.gradio['Clear history'].click(lambda :[gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
shared.gradio['Clear history-confirm'].click(chat.clear_chat_log, [shared.gradio['name1'], shared.gradio['name2']], shared.gradio['display'])
|
||||
shared.gradio['Clear history-cancel'].click(lambda :[gr.update(visible=False), gr.update(visible=True), gr.update(visible=False)], None, clear_arr)
|
||||
|
||||
shared.gradio['Remove last'].click(chat.remove_last_message, [shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['display'], shared.gradio['textbox']], show_progress=False)
|
||||
shared.gradio['download_button'].click(chat.save_history, inputs=[], outputs=[shared.gradio['download']])
|
||||
shared.gradio['Upload character'].click(chat.upload_character, [shared.gradio['upload_json'], shared.gradio['upload_img_bot']], [shared.gradio['character_menu']])
|
||||
|
||||
# Clearing stuff and saving the history
|
||||
for i in ['Generate', 'Regenerate', 'Replace last reply']:
|
||||
shared.gradio[i].click(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio[i].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['Clear history-confirm'].click(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda x: '', shared.gradio['textbox'], shared.gradio['textbox'], show_progress=False)
|
||||
shared.gradio['textbox'].submit(lambda : chat.save_history(timestamp=False), [], [], show_progress=False)
|
||||
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio['character_menu'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['name2'], shared.gradio['context'], shared.gradio['display']])
|
||||
shared.gradio['upload_chat_history'].upload(chat.load_history, [shared.gradio['upload_chat_history'], shared.gradio['name1'], shared.gradio['name2']], [])
|
||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||
shared.gradio['upload_img_me'].upload(chat.upload_your_profile_picture, [shared.gradio['upload_img_me']], [])
|
||||
|
||||
reload_func = chat.redraw_html if shared.args.cai_chat else lambda : shared.history['visible']
|
||||
reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else []
|
||||
shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||
shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None)
|
||||
shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True)
|
||||
|
||||
elif shared.args.notebook:
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['output_textbox'] = gr.Textbox(lines=15, label='Output')
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=25)
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
with gr.Row():
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
|
||||
shared.gradio['interface'].queue()
|
||||
if shared.args.listen:
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
|
||||
else:
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
|
||||
create_model_and_preset_menus()
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
else:
|
||||
with gr.Tab("Text generation", elem_id="main"):
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['textbox'] = gr.Textbox(value=default_text, lines=15, label='Input')
|
||||
shared.gradio['max_new_tokens'] = gr.Slider(minimum=shared.settings['max_new_tokens_min'], maximum=shared.settings['max_new_tokens_max'], step=1, label='max_new_tokens', value=shared.settings['max_new_tokens'])
|
||||
shared.gradio['Generate'] = gr.Button('Generate')
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
shared.gradio['Continue'] = gr.Button('Continue')
|
||||
with gr.Column():
|
||||
shared.gradio['Stop'] = gr.Button('Stop')
|
||||
|
||||
create_model_and_preset_menus()
|
||||
|
||||
with gr.Column():
|
||||
with gr.Tab('Raw'):
|
||||
shared.gradio['output_textbox'] = gr.Textbox(lines=25, label='Output')
|
||||
with gr.Tab('Markdown'):
|
||||
shared.gradio['markdown'] = gr.Markdown()
|
||||
with gr.Tab('HTML'):
|
||||
shared.gradio['html'] = gr.HTML()
|
||||
with gr.Tab("Parameters", elem_id="parameters"):
|
||||
create_settings_menus(default_preset)
|
||||
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']]
|
||||
output_params = [shared.gradio[k] for k in ['output_textbox', 'markdown', 'html']]
|
||||
gen_events.append(shared.gradio['Generate'].click(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream, api_name='textgen'))
|
||||
gen_events.append(shared.gradio['textbox'].submit(generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream))
|
||||
gen_events.append(shared.gradio['Continue'].click(generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream))
|
||||
shared.gradio['Stop'].click(None, None, None, cancels=gen_events)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
with gr.Tab("Interface mode", elem_id="interface-mode"):
|
||||
modes = ["default", "notebook", "chat", "cai_chat"]
|
||||
current_mode = "default"
|
||||
for mode in modes[1:]:
|
||||
if eval(f"shared.args.{mode}"):
|
||||
current_mode = mode
|
||||
break
|
||||
cmd_list = vars(shared.args)
|
||||
cmd_list = [k for k in cmd_list if type(cmd_list[k]) is bool and k not in modes]
|
||||
active_cmd_list = [k for k in cmd_list if vars(shared.args)[k]]
|
||||
|
||||
gr.Markdown("*Experimental*")
|
||||
shared.gradio['interface_modes_menu'] = gr.Dropdown(choices=modes, value=current_mode, label="Mode")
|
||||
shared.gradio['extensions_menu'] = gr.CheckboxGroup(choices=get_available_extensions(), value=shared.args.extensions, label="Available extensions")
|
||||
shared.gradio['cmd_arguments_menu'] = gr.CheckboxGroup(choices=cmd_list, value=active_cmd_list, label="Boolean command-line flags")
|
||||
shared.gradio['reset_interface'] = gr.Button("Apply and restart the interface", type="primary")
|
||||
|
||||
shared.gradio['reset_interface'].click(set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'cmd_arguments_menu']], None)
|
||||
shared.gradio['reset_interface'].click(lambda : None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500)}')
|
||||
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
# Launch the interface
|
||||
shared.gradio['interface'].queue()
|
||||
if shared.args.listen:
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
|
||||
else:
|
||||
shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
|
||||
|
||||
create_interface()
|
||||
|
||||
# I think that I will need this later
|
||||
while True:
|
||||
time.sleep(0.5)
|
||||
if shared.need_restart:
|
||||
shared.need_restart = False
|
||||
shared.gradio['interface'].close()
|
||||
create_interface()
|
||||
|
@ -5,7 +5,7 @@
|
||||
"name1": "Person 1",
|
||||
"name2": "Person 2",
|
||||
"context": "This is a conversation between two people.",
|
||||
"stop_at_newline": true,
|
||||
"stop_at_newline": false,
|
||||
"chat_prompt_size": 2048,
|
||||
"chat_prompt_size_min": 0,
|
||||
"chat_prompt_size_max": 2048,
|
||||
@ -23,13 +23,16 @@
|
||||
"presets": {
|
||||
"default": "NovelAI-Sphinx Moth",
|
||||
"pygmalion-*": "Pygmalion",
|
||||
"RWKV-*": "Naive",
|
||||
"(rosey|chip|joi)_.*_instruct.*": "Instruct Joi (Contrastive Search)"
|
||||
"RWKV-*": "Naive"
|
||||
},
|
||||
"prompts": {
|
||||
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
|
||||
"^(gpt4chan|gpt-4chan|4chan)": "-----\n--- 865467536\nInput text\n--- 865467537\n",
|
||||
"(rosey|chip|joi)_.*_instruct.*": "User: \n",
|
||||
"oasst-*": "<|prompter|>Write a story about future of AI development<|endoftext|><|assistant|>"
|
||||
},
|
||||
"lora_prompts": {
|
||||
"default": "Common sense questions and answers\n\nQuestion: \nFactual answer:",
|
||||
"alpaca-lora-7b": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n### Instruction:\nWrite a poem about the transformers Python library. \nMention the word \"large language models\" in that poem.\n### Response:\n"
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user