Merge branch 'oobabooga:main' into searx_integration_bs4

This commit is contained in:
catalpaaa 2023-04-25 05:41:08 -07:00 committed by GitHub
commit a325e13857
28 changed files with 1124 additions and 472 deletions


@ -188,6 +188,7 @@ Optionally, you can use the following command-line flags:
| `-h`, `--help` | Show this help message and exit. |
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
| `--chat` | Launch the web UI in chat mode. |
| `--character CHARACTER` | The name of the character to load in chat mode by default. |
| `--model MODEL` | Name of the model to load by default. |
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
| `--model-dir MODEL_DIR` | Path to directory with all the models. |
@ -220,6 +221,7 @@ Optionally, you can use the following command-line flags:
| Flag | Description |
|-------------|-------------|
| `--threads` | Number of threads to use in llama.cpp. |
| `--n_batch` | Processing batch size for llama.cpp. |

#### GPTQ
@ -269,6 +271,13 @@ Optionally, you can use the following command-line flags:
| `--auto-launch` | Open the web UI in the default browser upon launch. |
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
#### API
| Flag | Description |
|---------------------------------------|-------------|
| `--api` | Enable the API extension. |
| `--public-api` | Create a public URL for the API using Cloudflare. |
Out of memory errors? [Check the low VRAM guide](docs/Low-VRAM-guide.md).

## Presets


@ -1,39 +1,30 @@
import asyncio
import json
import sys

try:
    import websockets
except ImportError:
    print("Websockets package not found. Make sure it's installed.")

# For local streaming, the websockets are hosted without ssl - ws://
HOST = 'localhost:5005'
URI = f'ws://{HOST}/api/v1/stream'

# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'


async def run(context):
    # Note: the selected defaults change from time to time.
    request = {
        'prompt': context,
        'max_new_tokens': 250,
        'do_sample': True,
        'temperature': 1.3,
        'top_p': 0.1,
        'typical_p': 1,
        'repetition_penalty': 1.18,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
@ -45,48 +36,31 @@ async def run(context):
        'truncation_length': 2048,
        'ban_eos_token': False,
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    async with websockets.connect(URI) as websocket:
        await websocket.send(json.dumps(request))

        yield context  # Remove this if you just want to see the reply

        while True:
            incoming_data = await websocket.recv()
            incoming_data = json.loads(incoming_data)

            match incoming_data['event']:
                case 'text_stream':
                    yield incoming_data['text']
                case 'stream_end':
                    return


async def print_response_stream(prompt):
    async for response in run(prompt):
        print(response, end='')
        sys.stdout.flush()  # If we don't flush, we won't see tokens in realtime.


if __name__ == '__main__':
    prompt = "In order to make homemade bread, follow these steps:\n1)"
    asyncio.run(print_response_stream(prompt))


@ -1,33 +1,22 @@
import requests

# For local streaming, the websockets are hosted without ssl - http://
HOST = 'localhost:5000'
URI = f'http://{HOST}/api/v1/generate'

# For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'


def run(prompt):
    request = {
        'prompt': prompt,
        'max_new_tokens': 250,
        'do_sample': True,
        'temperature': 1.3,
        'top_p': 0.1,
        'typical_p': 1,
        'repetition_penalty': 1.18,
        'top_k': 40,
        'min_length': 0,
        'no_repeat_ngram_size': 0,
        'num_beams': 1,
@ -39,19 +28,15 @@ params = {
        'truncation_length': 2048,
        'ban_eos_token': False,
        'skip_special_tokens': True,
        'stopping_strings': []
    }

    response = requests.post(URI, json=request)

    if response.status_code == 200:
        result = response.json()['results'][0]['text']
        print(prompt + result)


if __name__ == '__main__':
    prompt = "In order to make homemade bread, follow these steps:\n1)"
    run(prompt)


@ -0,0 +1,3 @@
name: "### Assistant"
your_name: "### Human"
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"


@ -16,19 +16,24 @@ command-line flag.
The link above contains a directory of user extensions for text-generation-webui.

If you create an extension, you are welcome to host it in a GitHub repository and submit it to the list above.

## Built-in extensions

|Extension|Description|
|---------|-----------|
|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` port 5000. This is the main API for this web UI. |
|[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.|
|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that biases the bot's responses in chat mode.|
|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. |
|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, it replaces the responses with an audio widget. |
|[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. |
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|[llava](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava)| Adds LLaVA multimodal model support. For a detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava/README.md) in the extension directory. |
## How to write an extension
@ -41,6 +46,7 @@ The link above contains a directory of user extensions for text-generation-webui
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
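
For reference, here is a minimal skeleton of an extension `script.py` implementing two of these functions. This is a hedged sketch, not an actual built-in extension; the `params` contents are illustrative.

```python
# Hypothetical extension skeleton; only the hooks it defines are picked up by the web UI.
params = {
    "suffix": " [processed]",  # illustrative setting, not a real parameter
}


def output_modifier(string):
    # Modifies the output string before it is presented in the UI.
    return string + params["suffix"]


def bot_prefix_modifier(string):
    # Applied in chat mode to the prefix for the bot's reply; returned unchanged here.
    return string
```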
Additionally, the script may define two special global variables:
@ -66,7 +72,9 @@ input_hijack = {
'value': ["", ""] 'value': ["", ""]
} }
``` ```
This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the vales inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example. This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example.
Additionally, your extension can set the value to be a callback, in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `llava` extension above for an example.
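
As a hedged illustration of that callback form (modeled on the `send_pictures` and `llava` extensions in this commit; `attach_note` is a hypothetical helper):

```python
from functools import partial

input_hijack = {
    'state': False,
    'value': ["", ""]
}


def attach_note(note, text, visible_text):
    # Callback signature: (text, visible_text) -> [str, str].
    # Append a note to the internal text; keep the visible text as the user typed it.
    return f"{text}\n{note}", visible_text or text


def arm_hijack(note):
    # The next chatbot_wrapper call will invoke the callback and use its return
    # value as the user input for text generation.
    input_hijack.update({'state': True, 'value': partial(attach_note, note)})
```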
## The `bot_prefix_modifier`

docs/Training-LoRAs.md (new file, 167 lines)

@ -0,0 +1,167 @@
## Training Your Own LoRAs
The WebUI seeks to make training your own LoRAs as easy as possible. It comes down to just a few simple steps:
### **Step 1**: Make a plan.
- What base model do you want to use? The LoRA you make has to be matched up to a single architecture (eg LLaMA-13B) and cannot be transferred to others (eg LLaMA-7B, StableLM, etc. would all be different). Derivatives of the same model (eg Alpaca finetune of LLaMA-13B) might be transferrable, but even then it's best to train exactly on what you plan to use.
- What model format do you want? At time of writing, 8-bit models are most stable, and 4-bit are supported but experimental. In the near future it is likely that 4-bit will be the best option for most users.
- What are you training it on? Do you want it to learn real information, a simple format, ...?
### **Step 2**: Gather a dataset.
- If you use a dataset similar to the [Alpaca](https://github.com/gururise/AlpacaDataCleaned/blob/main/alpaca_data_cleaned.json) format, that is natively supported by the `Formatted Dataset` input in the WebUI, with premade formatter options.
- If you use a dataset that isn't matched to Alpaca's format, but uses the same basic JSON structure, you can make your own format file by copying `training/formats/alpaca-format.json` to a new file and [editing its content](#format-files).
- If you can get the dataset into a simple text file, that works too! You can train using the `Raw text file` input option.
- This means you can for example just copy/paste a chatlog/documentation page/whatever you want, shove it in a plain text file, and train on it.
- If you use a structured dataset not in this format, you may have to find an external way to convert it - or open an issue to request native support.
### **Step 3**: Do the training.
- **3.1**: Load the WebUI, and your model.
- Make sure you don't have any LoRAs already loaded (unless you want to train for multi-LoRA usage).
- **3.2**: Open the `Training` tab at the top, `Train LoRA` sub-tab.
- **3.3**: Fill in the name of the LoRA, select your dataset in the dataset options.
- **3.4**: Select other parameters to your preference. See [parameters below](#parameters).
- **3.5**: click `Start LoRA Training`, and wait.
- It can take a few hours for a large dataset, or just a few minutes if doing a small run.
- You may want to monitor your [loss value](#loss) while it goes.
### **Step 4**: Evaluate your results.
- Load the LoRA under the Models Tab.
- You can go test-drive it on the `Text generation` tab, or you can use the `Perplexity evaluation` sub-tab of the `Training` tab.
- If you used the `Save every n steps` option, you can grab prior copies of the model from sub-folders within the LoRA model's folder and try them instead.
### **Step 5**: Re-run if you're unhappy.
- Make sure to unload the LoRA before training it.
- You can simply resume a prior run - use `Copy parameters from` to select your LoRA, and edit parameters. Note that you cannot change the `Rank` of an already created LoRA.
- If you want to resume from a checkpoint saved along the way, simply copy the contents of the checkpoint folder into the LoRA's folder.
- (Note: `adapter_model.bin` is the important file that holds the actual LoRA content).
- This will reset the Learning Rate and Steps back to the start. If you want to resume as if you were midway through, you can adjust your Learning Rate to the last reported LR in logs and reduce your epochs.
- Or, you can start over entirely if you prefer.
- If your model is producing corrupted outputs, you probably need to start over and use a lower Learning Rate.
- If your model isn't learning detailed information but you want it to, you might need to just run more epochs, or you might need a higher Rank.
- If your model is enforcing a format you didn't want, you may need to tweak your dataset, or start over and not train as far.
## Format Files
If using JSON formatted datasets, they are presumed to be in the following approximate format:
```json
[
{
"somekey": "somevalue",
"key2": "value2"
},
{
// etc
}
]
```
Where the keys (eg `somekey`, `key2` above) are standardized, and relatively consistent across the dataset, and the values (eg `somevalue`, `value2`) contain the content actually intended to be trained.
For Alpaca, the keys are `instruction`, `input`, and `output`, wherein `input` is sometimes blank.
A simple format file for Alpaca to be used as a chat bot is:
```json
{
"instruction,output": "User: %instruction%\nAssistant: %output%",
"instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
}
```
Note that the keys (eg `instruction,output`) are a comma-separated list of dataset keys, and the values are simple strings that use those keys wrapped in `%` signs.
So for example if a dataset has `"instruction": "answer my question"`, then the format file's `User: %instruction%\n` will be automatically filled in as `User: answer my question\n`.
If you have different sets of key inputs, you can make your own format file to match it. This format-file is designed to be as simple as possible to enable easy editing to match your needs.
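
To make the substitution concrete, here is a small illustrative Python sketch (not part of the WebUI code) of how a format file could be applied to a single dataset entry:

```python
# Apply an Alpaca-style format file to one dataset entry (illustrative only).
format_spec = {
    "instruction,output": "User: %instruction%\nAssistant: %output%",
    "instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
}

entry = {"instruction": "answer my question", "output": "Sure, here is the answer."}

# The comma-separated key list selects which template matches this entry's non-empty fields.
keys = ",".join(k for k in ("instruction", "input", "output") if entry.get(k))
text = format_spec[keys]

# Each %key% placeholder is replaced with that field's value.
for key, value in entry.items():
    text = text.replace(f"%{key}%", value)

print(text)
# User: answer my question
# Assistant: Sure, here is the answer.
```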
## Parameters
The basic purpose and function of each parameter is documented on-page in the WebUI, so read through them in the UI to understand your options.
That said, here's a guide to the most important parameter choices you should consider:
### VRAM
- First, you must consider your VRAM availability.
- Generally, under default settings, VRAM usage for training with default parameters is very close to when generating text (with 1000+ tokens of context) (ie, if you can generate text, you can train LoRAs).
- Note: this is currently worse by default with the 4-bit monkeypatch. Reduce `Micro Batch Size` to `1` to restore this to expectations.
- If you have VRAM to spare, setting higher batch sizes will use more VRAM and get you better quality training in exchange.
- If you have large data, setting a higher cutoff length may be beneficial, but will cost significant VRAM. If you can spare some, set your batch size to `1` and see how high you can push your cutoff length.
- If you're low on VRAM, reducing batch size or cutoff length will of course improve that.
- Don't be afraid to just try it and see what happens. If it's too much, it will just error out, and you can lower settings and try again.
### Rank
- Second, you want to consider the amount of learning you want.
- For example, you may wish to just learn a dialogue format (as in the case of Alpaca) in which case setting a low `Rank` value (32 or lower) works great.
- Or, you might be training on project documentation you want the bot to understand and be able to answer questions about, in which case the higher the rank, the better.
- Generally, higher Rank = more precise learning = more total content learned = more VRAM usage while training.
### Learning Rate and Epochs
- Third, how carefully you want it to be learned.
- In other words, how okay or not you are with the model losing unrelated understandings.
- You can control this with 3 key settings: the Learning Rate, its scheduler, and your total epochs.
- The learning rate controls how much change is made to the model by each token it sees.
- It's in scientific notation normally, so for example `3e-4` means `3 * 10^-4` which is `0.0003`. The number after `e-` controls how many `0`s are in the number.
- Higher values let training run faster, but also are more likely to corrupt prior data in the model.
- You essentially have two variables to balance: the LR, and Epochs.
- If you make LR higher, you can set Epochs equally lower to match. High LR + low epochs = very fast, low quality training.
- If you make LR low, set epochs high. Low LR + high epochs = slow but high-quality training.
- The scheduler controls change-over-time as you train - it starts high, and then goes low. This helps balance getting data in, and having decent quality, at the same time.
- You can see graphs of the different scheduler options [in the HuggingFace docs here](https://moon-ci-docs.huggingface.co/docs/transformers/pr_1/en/main_classes/optimizer_schedules#transformers.SchedulerType). A short generic sketch of a scheduler in action follows below.
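
Here is a short, generic PyTorch/`transformers` sketch of the LR/scheduler interaction. It is illustrative only and not taken from the WebUI's trainer; the model and step count are placeholders.

```python
import torch
from transformers import get_scheduler

model = torch.nn.Linear(10, 10)  # placeholder model
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)  # 3e-4 == 0.0003

# "cosine" is one of the scheduler options linked above: it starts at the
# configured LR and decays it toward zero over the run.
scheduler = get_scheduler("cosine", optimizer=optimizer,
                          num_warmup_steps=0, num_training_steps=1000)

for step in range(1000):
    # ...forward pass and loss.backward() would go here in real training...
    optimizer.step()
    scheduler.step()

print(scheduler.get_last_lr())  # close to zero by the end of the schedule
```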
## Loss
When you're running training, the WebUI's console window will log reports that include, among other things, a numeric value named `Loss`. It will start as a high number, and gradually get lower and lower as it goes.
"Loss" in the world of AI training theoretically means "how close is the model to perfect", with `0` meaning "absolutely perfect". This is calculated by measuring the difference between the model outputting exactly the text you're training it to output, and what it actually outputs.
In practice, a good LLM should have a very complex variable range of ideas running in its artificial head, so a loss of `0` would indicate that the model has broken and forgotten how to think about anything other than what you trained it on.
So, in effect, Loss is a balancing game: you want to get it low enough that it understands your data, but high enough that it isn't forgetting everything else. Generally, if it goes below `1.0`, it's going to start forgetting its prior memories, and you should stop training. In some cases you may prefer to take it as low as `0.5` (if you want it to be very very predictable). Different goals have different needs, so don't be afraid to experiment and see what works best for you.
Note: if you see Loss start at or suddenly jump to exactly `0`, it is likely something has gone wrong in your training process (eg model corruption).
## Note: 4-Bit Monkeypatch
The [4-bit LoRA monkeypatch](GPTQ-models-(4-bit-mode).md#using-loras-in-4-bit-mode) works for training, but has side effects:
- VRAM usage is higher currently. You can reduce the `Micro Batch Size` to `1` to compensate.
- Models do funky things. LoRAs apply themselves, or refuse to apply, or spontaneously error out, or etc. It can be helpful to reload base model or restart the WebUI between training/usage to minimize chances of anything going haywire.
- Loading or working with multiple LoRAs at the same time doesn't currently work.
- Generally, recognize and treat the monkeypatch as the dirty temporary hack it is - it works, but isn't very stable. It will get better in time when everything is merged upstream for full official support.
## Legacy notes
LoRA training was contributed by [mcmonkey4eva](https://github.com/mcmonkey4eva) in PR [#570](https://github.com/oobabooga/text-generation-webui/pull/570).
### Using the original alpaca-lora code
Kept here for reference. The Training tab has much more features than this method.
```
conda activate textgen
git clone https://github.com/tloen/alpaca-lora
```
Edit those two lines in `alpaca-lora/finetune.py` to use your existing model folder instead of downloading everything from decapoda:
```
model = LlamaForCausalLM.from_pretrained(
"models/llama-7b",
load_in_8bit=True,
device_map="auto",
)
tokenizer = LlamaTokenizer.from_pretrained(
"models/llama-7b", add_eos_token=True
)
```
Run the script with:
```
python finetune.py
```
It just works. It runs at 22.32s/it, with 1170 iterations in total, so about 7 hours and a half for training a LoRA. RTX 3090, 18153MiB VRAM used, drawing maximum power (350W, room heater mode).


@ -52,38 +52,4 @@ print(f"Predicted {len(output)} tokens for '{sentence}':\n{output}")
## Training a LoRA

You can train your own LoRAs from the `Training` tab. See [Training LoRAs](Training-LoRAs.md) for details.


@ -0,0 +1,90 @@
import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from modules import shared
from modules.text_generation import encode, generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path == '/api/v1/model':
self.send_response(200)
self.end_headers()
response = json.dumps({
'result': shared.model_name
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def do_POST(self):
content_length = int(self.headers['Content-Length'])
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
if self.path == '/api/v1/generate':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
prompt = body['prompt']
generate_params = build_parameters(body)
stopping_strings = generate_params.pop('stopping_strings')
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)
answer = ''
for a in generator:
if isinstance(a, str):
answer = a
else:
answer = a[0]
response = json.dumps({
'results': [{
'text': answer if shared.is_chat() else answer[len(prompt):]
}]
})
self.wfile.write(response.encode('utf-8'))
elif self.path == '/api/v1/token-count':
self.send_response(200)
self.send_header('Content-Type', 'application/json')
self.end_headers()
tokens = encode(body['prompt'])[0]
response = json.dumps({
'results': [{
'tokens': len(tokens)
}]
})
self.wfile.write(response.encode('utf-8'))
else:
self.send_error(404)
def _run_server(port: int, share: bool=False):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
server = ThreadingHTTPServer((address, port), Handler)
def on_start(public_url: str):
print(f'Starting non-streaming server at public url {public_url}/api')
if share:
try:
try_start_cloudflared(port, max_attempts=3, on_start=on_start)
except Exception:
pass
else:
print(
f'Starting API at http://{address}:{port}/api')
server.serve_forever()
def start_server(port: int, share: bool = False):
Thread(target=_run_server, args=[port, share], daemon=True).start()


@ -1 +1,2 @@
flask_cloudflared==0.0.12
websockets==11.0.2


@ -1,115 +1,10 @@
import extensions.api.blocking_api as blocking_api
import extensions.api.streaming_api as streaming_api
from modules import shared

BLOCKING_PORT = 5000
STREAMING_PORT = 5005


def setup():
    blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
    streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)


@ -0,0 +1,82 @@
import json
import asyncio
from websockets.server import serve
from threading import Thread
from modules import shared
from modules.text_generation import generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared
PATH = '/api/v1/stream'
async def _handle_connection(websocket, path):
if path != PATH:
print(f'Streaming api: unknown path: {path}')
return
async for message in websocket:
message = json.loads(message)
prompt = message['prompt']
generate_params = build_parameters(message)
stopping_strings = generate_params.pop('stopping_strings')
generator = generate_reply(
prompt, generate_params, stopping_strings=stopping_strings)
# As we stream, only send the new bytes.
skip_index = len(prompt) if not shared.is_chat() else 0
message_num = 0
for a in generator:
to_send = ''
if isinstance(a, str):
to_send = a[skip_index:]
else:
to_send = a[0][skip_index:]
await websocket.send(json.dumps({
'event': 'text_stream',
'message_num': message_num,
'text': to_send
}))
await asyncio.sleep(0)
skip_index += len(to_send)
message_num += 1
await websocket.send(json.dumps({
'event': 'stream_end',
'message_num': message_num
}))
async def _run(host: str, port: int):
async with serve(_handle_connection, host, port):
await asyncio.Future() # run forever
def _run_server(port: int, share: bool = False):
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
def on_start(public_url: str):
public_url = public_url.replace('https://', 'wss://')
print(f'Starting streaming server at public url {public_url}{PATH}')
if share:
try:
try_start_cloudflared(port, max_attempts=3, on_start=on_start)
except Exception as e:
print(e)
else:
print(f'Starting streaming server at ws://{address}:{port}{PATH}')
asyncio.run(_run(host=address, port=port))
def start_server(port: int, share: bool = False):
Thread(target=_run_server, args=[port, share], daemon=True).start()

extensions/api/util.py (new file, 69 lines)

@ -0,0 +1,69 @@
from threading import Thread
import time
from typing import Callable, Optional
from modules.text_generation import encode
def build_parameters(body):
prompt = body['prompt']
prompt_lines = [k.strip() for k in prompt.split('\n')]
max_context = body.get('max_context_length', 2048)
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
prompt_lines.pop(0)
prompt = '\n'.join(prompt_lines)
generate_params = {
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
'do_sample': bool(body.get('do_sample', True)),
'temperature': float(body.get('temperature', 0.5)),
'top_p': float(body.get('top_p', 1)),
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
'top_k': int(body.get('top_k', 0)),
'min_length': int(body.get('min_length', 0)),
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
'num_beams': int(body.get('num_beams', 1)),
'penalty_alpha': float(body.get('penalty_alpha', 0)),
'length_penalty': float(body.get('length_penalty', 1)),
'early_stopping': bool(body.get('early_stopping', False)),
'seed': int(body.get('seed', -1)),
'add_bos_token': int(body.get('add_bos_token', True)),
'truncation_length': int(body.get('truncation_length', 2048)),
'ban_eos_token': bool(body.get('ban_eos_token', False)),
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
'custom_stopping_strings': '', # leave this blank
'stopping_strings': body.get('stopping_strings', []),
}
return generate_params
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
Thread(target=_start_cloudflared, args=[
port, max_attempts, on_start], daemon=True).start()
def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
try:
from flask_cloudflared import _run_cloudflared
except ImportError:
print('You should install flask_cloudflared manually')
raise Exception(
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
for _ in range(max_attempts):
try:
public_url = _run_cloudflared(port, port + 1)
if on_start:
on_start(public_url)
return
except Exception:
time.sleep(3)
raise Exception('Could not start cloudflared.')


@ -0,0 +1,49 @@
# LLaVA
## Description
Adds [LLaVA](https://github.com/haotian-liu/LLaVA) multimodality support to text-generation-webui.
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
## Usage
To run this extension, download the LLaVA weights, for example from [here](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g), and then start server.py with the `--extensions llava` argument.

In the UI, go to instruct mode and select the LLaVA template; you should also add `"\n###"` to "Custom stopping strings" in the Parameters tab.

Note that each image takes up 258 tokens, so adjust max_new_tokens to be at most 1700 (the recommended value is between 200 and 500) so that the images don't get truncated.
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
Additionally, there is an *Embed all images, not only the last one* checkbox. It modifies the image embeddings: by default (when it is unchecked), all but the most recent image have their embeddings left empty, so they are not fed to the network. From initial testing, it seems that LLaVA considers the features in all images at the same time, so by default the extension skips previous images. If you want to include them anyway, just tick this checkbox.
## Extension config
This extension uses the following parameters (from settings.json):
|Parameter|Description|
|---------|-----------|
|`llava-clip_bits`|Number of bits to load CLIP feature extractor in (either 32 or 16, default=32)|
|`llava-clip_device`|Torch device to run the extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`llava-projector_bits`|Number of bits to load CLIP->LLaMA feature projector in (either 32 or 16, default=32)|
|`llava-projector_device`|Torch device to run the CLIP->LLaMA feature projector on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|`llava-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
## Technical description
### Original LLaVA
The default LLaVA implementation uses a modified `transformers` library; this extension forgoes that requirement. LLaVA modifies `transformers` so that the entire LLaVA model gets loaded, and inference then looks as follows:
```
images --> CLIP --> projector --> input embeddings for images --> |
| --> LLaMA
prompt -------------------------> input embeddings for text ----> |
```
The images are represented in the prompt by the following token IDs:
- 32000 - `<im_patch>` - placeholder token for embeddings from projector
- 32001 - `<im_start>` - token marking start of an image
- 32002 - `<im_end>` - token marking end of an image
By default, an image is represented as `<im_start><im_patch>*256<im_end>`. The input embeddings for an image are converted with a single linear layer of the projector, then placed in place of the `<im_patch>` tokens.
The concatenated prompt then gets fed to fine-tuned LLaMA.
### In this extension
With the default transformers, only the LLaMA part of LLaVA is loaded, ignoring the added projector weights and not loading CLIP. We then reconstruct the `images -> CLIP -> projector` pipeline ourselves, concatenate the input embeddings, and feed them to the LLaMA model loaded by transformers. This allows us to use the normal webui flow to load this model and simply hijack the model input with the additional features.

Splitting it into 3 separate models allows us to configure each of them and move them to different devices (for example, we can run CLIP+projector on the CPU and LLaMA on the GPU). It also enables 4-bit GPTQ quantization for LLaVA, massively cutting down the VRAM requirement (it should be possible to fit it in 12GB of VRAM with full context size by moving CLIP and the projector to the CPU).

extensions/llava/script.py (new file, 267 lines)

@ -0,0 +1,267 @@
import base64
import re
import time
from dataclasses import dataclass
from functools import partial
from io import BytesIO
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel
from modules import shared
from modules.extensions import apply_extensions
from modules.text_generation import encode, get_max_prompt_length
params = {
"add_all_images_to_prompt": False,
# device to run CLIP on
"clip_device": None,
# bits to load clip in either 32 or 16 (it doesn't support 8-bit)
"clip_bits": 32,
# device to run projector on
"projector_device": None,
# projector bits, either 32 or 16
"projector_bits": 32
}
# If 'state' is True, will hijack the next chat generation
input_hijack = {
'state': False,
'value': ["", ""]
}
# initialized in ui, so that params are loaded from settings
llava_embedder = None
@dataclass
class Token:
token: str
id: int
class LLaVAEmbedder:
IM_PATCH = Token("<im_patch>", 32000)
IM_START = Token("<im_start>", 32001)
IM_END = Token("<im_end>", 32002)
CLIP_VIT_HUB_NAME = 'openai/clip-vit-large-patch14'
PROJECTOR_HUB_NAME = 'liuhaotian/LLaVA-13b-pretrain-projector-v0'
PROJECTOR_FILE = 'LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption.bin'
def __init__(self):
self.clip_device = self._get_device("clip_device")
self.clip_dtype = self._get_dtype("clip_bits")
self.projector_device = self._get_device("projector_device")
self.projector_dtype = self._get_dtype("projector_bits")
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
def _get_device(self, setting_name):
if params[setting_name] is None:
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
return torch.device(params[setting_name])
def _get_dtype(self, setting_name):
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
def _load_models(self):
start_ts = time.time()
print(f"LLaVA - Loading {LLaVAEmbedder.CLIP_VIT_HUB_NAME} as {self.clip_dtype} on {self.clip_device}...")
image_processor = CLIPImageProcessor.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype)
vision_tower = CLIPVisionModel.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype).to(self.clip_device)
print(f"LLaVA - Loading {LLaVAEmbedder.PROJECTOR_HUB_NAME} as {self.projector_dtype} on {self.projector_device}...")
projector_path = hf_hub_download(LLaVAEmbedder.PROJECTOR_HUB_NAME, LLaVAEmbedder.PROJECTOR_FILE)
mm_projector = torch.nn.Linear(1024, 5120)
projector_data = torch.load(projector_path)
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
mm_projector = mm_projector.to(self.projector_device)
print(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
return image_processor, vision_tower, mm_projector
def _update_prompt(self, prompt, images):
for _ in images:
# replace the image token with the image patch token in the prompt (each occurrence)
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
prompt = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', replace_token, prompt, 1)
return prompt
def _extract_image_features(self, images):
images = self.image_processor(images, return_tensors='pt')['pixel_values']
images = images.to(self.clip_device, dtype=self.clip_dtype)
with torch.no_grad():
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
select_hidden_state_layer = -2
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
image_features = self.mm_projector(image_features)
return image_features
def forward(self, prompt, images, state):
prompt = self._update_prompt(prompt, images)
input_ids = encode(prompt, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))[0]
input_embeds = shared.model.model.embed_tokens(input_ids).to(self.projector_device)
if input_ids[0] == LLaVAEmbedder.IM_PATCH.id:
# prompt got truncated in the middle of an image, remove the image data
im_end = torch.where(input_ids == LLaVAEmbedder.IM_END.id)[0][0]
input_ids = input_ids[im_end+1:]
input_embeds = input_embeds[im_end+1:]
leftover_images = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0].shape[0]
print(f"LLaVA - WARNING: removed {len(images) - leftover_images} image(s) from prompt. The generation might be broken, try decreasing max_new_tokens")
images = images[-leftover_images:]
if len(images) == 0:
return prompt, input_ids, input_embeds, 0
total_embedded = 0
image_features = self._extract_image_features(images).to(self.projector_device)
image_start_tokens = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0]
if not torch.any(input_ids == LLaVAEmbedder.IM_PATCH.id) or len(image_start_tokens) == 0:
# multimodal LLM, but the current prompt is not multimodal/truncated
return prompt, input_ids, input_embeds, total_embedded
cur_image_idx = 0
if not params['add_all_images_to_prompt']:
image_start_tokens = [image_start_tokens[-1]]
cur_image_idx = -1
for image_start_token_pos in image_start_tokens:
cur_image_features = image_features[cur_image_idx]
num_patches = cur_image_features.shape[0]
input_embeds = torch.cat((input_embeds[:image_start_token_pos+1], cur_image_features, input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
cur_image_idx += 1
total_embedded += 1
return prompt, input_ids, input_embeds, total_embedded
@staticmethod
def len_in_tokens(text):
images = re.findall(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', text)
image_tokens = 0
for _ in images:
image_tokens += 258
return len(encode(re.sub(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', '', text))[0]) + image_tokens
def add_chat_picture(picture, text, visible_text):
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
max_hw, min_hw = max(picture.size), min(picture.size)
aspect_ratio = max_hw / min_hw
shortest_edge = int(max(300 / aspect_ratio, 224))
longest_edge = int(shortest_edge * aspect_ratio)
w = shortest_edge if picture.width < picture.height else longest_edge
h = shortest_edge if picture.width >= picture.height else longest_edge
picture = picture.resize((w,h))
buffer = BytesIO()
picture.save(buffer, format="JPEG")
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
image = f'<img src="data:image/jpeg;base64,{img_str}">'
if '<image>' in text:
text = text.replace('<image>', image)
else:
text = text + '\n' + image
if visible_text == '' or visible_text is None:
visible_text = text
elif '<image>' in visible_text:
visible_text = visible_text.replace('<image>', image)
else:
visible_text = visible_text + '\n' + image
return text, visible_text
def custom_generate_chat_prompt(user_input, state, **kwargs):
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
_continue = kwargs['_continue'] if '_continue' in kwargs else False
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
rows = [f"{state['context'].strip()}\n"]
min_rows = 3
# Finding the maximum prompt size
chat_prompt_size = state['chat_prompt_size']
if shared.soft_prompt:
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
max_length = min(get_max_prompt_length(state), chat_prompt_size)
prefix1 = f"{state['name1']}: "
prefix2 = f"{state['name2']}: "
i = len(shared.history['internal']) - 1
while i >= 0 and LLaVAEmbedder.len_in_tokens(''.join(rows)) < max_length:
if _continue and i == len(shared.history['internal']) - 1:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
else:
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
string = shared.history['internal'][i][0]
if string != '':
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
i -= 1
if impersonate:
min_rows = 2
rows.append(f"{prefix1}")
elif not _continue:
# Adding the user message
if len(user_input) > 0:
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
# Adding the Character prefix
rows.append(apply_extensions("bot_prefix", f"{prefix2}"))
while len(rows) > min_rows and LLaVAEmbedder.len_in_tokens(''.join(rows)) >= max_length:
rows.pop(1)
prompt = ''.join(rows)
if also_return_rows:
return prompt, rows
else:
return prompt
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
global params
start_ts = time.time()
image_matches = re.finditer(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt)
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
if len(images) == 0:
return prompt, input_ids, input_embeds
prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state)
print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
return prompt, input_ids.unsqueeze(0).to(shared.model.device), input_embeds.unsqueeze(0).to(shared.model.device)
def ui():
global llava_embedder
llava_embedder = LLaVAEmbedder()
with gr.Column():
picture_select = gr.Image(label='Send a picture', type='pil')
# I found that it doesn't deal super well with multiple images, and demo ui had a bug where it included only the last image anyway
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
# Prepare the input hijack
picture_select.upload(
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
[picture_select],
None
)
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
shared.gradio['Generate'].click(lambda: None, None, picture_select)
shared.gradio['textbox'].submit(lambda: None, None, picture_select)


@ -48,3 +48,7 @@ llama-[0-9]*b-4bit$:
.*chatglm:
  mode: 'instruct'
  instruction_template: 'ChatGLM'
.*llava:
  mode: 'instruct'
  model_type: 'llama'
  instruction_template: 'LLaVA'


@ -135,7 +135,7 @@ def load_quantized(model_name):
    # Find the model type
    if not shared.args.model_type:
        name = model_name.lower()
        if any((k in name for k in ['llama', 'alpaca', 'vicuna', 'llava'])):
            model_type = 'llama'
        elif any((k in name for k in ['opt-', 'galactica'])):
            model_type = 'opt'


@ -1,52 +0,0 @@
import json
import gradio as gr
from modules import shared
from modules.text_generation import generate_reply
# set this to True to rediscover the fn_index using the browser DevTools
VISIBLE = False
def generate_reply_wrapper(string):
# Provide defaults so as to not break the API on the client side when new parameters are added
generate_params = {
'max_new_tokens': 200,
'do_sample': True,
'temperature': 0.5,
'top_p': 1,
'typical_p': 1,
'repetition_penalty': 1.1,
'encoder_repetition_penalty': 1,
'top_k': 0,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'seed': -1,
'add_bos_token': True,
'custom_stopping_strings': '',
'truncation_length': 2048,
'ban_eos_token': False,
'skip_special_tokens': True,
'stopping_strings': [],
}
params = json.loads(string)
generate_params.update(params[1])
stopping_strings = generate_params.pop('stopping_strings')
for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings):
yield i
def create_apis():
t1 = gr.Textbox(visible=VISIBLE)
t2 = gr.Textbox(visible=VISIBLE)
dummy = gr.Button(visible=VISIBLE)
input_params = [t1]
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')


@ -10,7 +10,6 @@ from pathlib import Path
import yaml
from PIL import Image

import modules.shared as shared
from modules.extensions import apply_extensions
from modules.html_generator import chat_html_wrapper, make_thumbnail
@ -30,8 +29,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
    chat_prompt_size = state['chat_prompt_size']
    if shared.soft_prompt:
        chat_prompt_size -= shared.soft_prompt_tensor.shape[1]

    max_length = min(get_max_prompt_length(state), chat_prompt_size)
    if is_instruct:
        prefix1 = f"{state['name1']}\n"
        prefix2 = f"{state['name2']}\n"
@ -57,19 +56,18 @@ def generate_chat_prompt(user_input, state, **kwargs):
min_rows = 2 min_rows = 2
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
elif not _continue: elif not _continue:
# Adding the user message # Adding the user message
if len(user_input) > 0: if len(user_input) > 0:
this_prefix1 = prefix1.replace('<|round|>', f'{len(shared.history["internal"])}') # for ChatGLM this_prefix1 = prefix1.replace('<|round|>', f'{len(shared.history["internal"])}') # for ChatGLM
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n") rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
# Adding the Character prefix # Adding the Character prefix
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}"))
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length: while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
rows.pop(1) rows.pop(1)
prompt = ''.join(rows)
prompt = ''.join(rows)
if also_return_rows: if also_return_rows:
return prompt, rows return prompt, rows
else: else:
@ -81,6 +79,7 @@ def get_stopping_strings(state):
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"] stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
else: else:
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"] stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
stopping_strings += ast.literal_eval(f"[{state['custom_stopping_strings']}]") stopping_strings += ast.literal_eval(f"[{state['custom_stopping_strings']}]")
return stopping_strings return stopping_strings
@@ -111,13 +110,13 @@ def extract_message_from_reply(reply, state):
                     break
             else:
                 continue

             break

     return reply, next_character_found


 def chatbot_wrapper(text, state, regenerate=False, _continue=False):
     if shared.model_name == 'None' or shared.model is None:
         print("No model is loaded! Select one in the Model tab.")
         yield shared.history['visible']
@@ -125,35 +124,36 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
     # Defining some variables
     cumulative_reply = ''
-    last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
     just_started = True
-    visible_text = custom_generate_chat_prompt = None
+    visible_text = None
     eos_token = '\n' if state['stop_at_newline'] else None
     stopping_strings = get_stopping_strings(state)

-    # Check if any extension wants to hijack this function call
-    for extension, _ in extensions_module.iterator():
-        if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
-            extension.input_hijack['state'] = False
-            text, visible_text = extension.input_hijack['value']
-        if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
-            custom_generate_chat_prompt = extension.custom_generate_chat_prompt
-
-    if visible_text is None:
-        visible_text = text
-    if not _continue:
-        text = apply_extensions(text, "input")
+    # Preparing the input
+    if not any((regenerate, _continue)):
+        text, visible_text = apply_extensions('input_hijack', text, visible_text)
+        if visible_text is None:
+            visible_text = text
+
+        text = apply_extensions('input', text)
+        # *Is typing...*
+        yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+    else:
+        text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0]
+        if regenerate:
+            shared.history['visible'].pop()
+            shared.history['internal'].pop()
+            # *Is typing...*
+            yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+        elif _continue:
+            last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]]
+            yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']]

     # Generating the prompt
     kwargs = {'_continue': _continue}
-    if custom_generate_chat_prompt is None:
-        prompt = generate_chat_prompt(text, state, **kwargs)
-    else:
-        prompt = custom_generate_chat_prompt(text, state, **kwargs)
-
-    # Yield *Is typing...*
-    if not any((regenerate, _continue)):
-        yield shared.history['visible'] + [[visible_text, shared.processing_message]]
+    prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
+    if prompt is None:
+        prompt = generate_chat_prompt(text, state, **kwargs)

     # Generate
     for i in range(state['chat_generation_attempts']):
@@ -164,26 +164,26 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
             # Extracting the reply
             reply, next_character_found = extract_message_from_reply(reply, state)
             visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
-            visible_reply = apply_extensions(visible_reply, "output")
-            if _continue:
-                sep = ' ' if last_reply[0][-1] not in [' ', '\n'] else ''
-                reply = last_reply[0] + sep + reply
-                sep = ' ' if last_reply[1][-1] not in [' ', '\n'] else ''
-                visible_reply = last_reply[1] + sep + visible_reply
+            visible_reply = apply_extensions("output", visible_reply)

             # We need this global variable to handle the Stop event,
             # otherwise gradio gets confused
             if shared.stop_everything:
                 return shared.history['visible']

             if just_started:
                 just_started = False
                 if not _continue:
                     shared.history['internal'].append(['', ''])
                     shared.history['visible'].append(['', ''])

-            shared.history['internal'][-1] = [text, reply]
-            shared.history['visible'][-1] = [visible_text, visible_reply]
-            yield shared.history['visible']
+            if _continue:
+                sep = list(map(lambda x: ' ' if len(x) > 0 and x[-1] != ' ' else '', last_reply))
+                shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
+                shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
+            else:
+                shared.history['internal'][-1] = [text, reply]
+                shared.history['visible'][-1] = [visible_text, visible_reply]
+
+            if not shared.args.no_stream:
+                yield shared.history['visible']

             if next_character_found:
                 break
@@ -195,7 +195,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):


 def impersonate_wrapper(text, state):
     if shared.model_name == 'None' or shared.model is None:
         print("No model is loaded! Select one in the Model tab.")
         yield ''
@@ -209,7 +208,6 @@ def impersonate_wrapper(text, state):
     # Yield *Is typing...*
     yield shared.processing_message

     for i in range(state['chat_generation_attempts']):
         reply = None
         for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
@@ -234,23 +232,16 @@ def regenerate_wrapper(text, state):
     if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
         yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
     else:
-        last_visible = shared.history['visible'].pop()
-        last_internal = shared.history['internal'].pop()
-        # Yield '*Is typing...*'
-        yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
-        for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
-            shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
-            yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
+        for history in chatbot_wrapper('', state, regenerate=True):
+            yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])


 def continue_wrapper(text, state):
     if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
         yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
     else:
-        # Yield ' ...'
-        yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
-        for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
-            yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
+        for history in chatbot_wrapper('', state, _continue=True):
+            yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])


 def remove_last_message(name1, name2, mode):
@@ -273,14 +264,14 @@ def send_last_reply_to_input():

 def replace_last_reply(text, name1, name2, mode):
     if len(shared.history['visible']) > 0:
         shared.history['visible'][-1][1] = text
-        shared.history['internal'][-1][1] = apply_extensions(text, "input")
+        shared.history['internal'][-1][1] = apply_extensions("input", text)

     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)


 def send_dummy_message(text, name1, name2, mode):
     shared.history['visible'].append([text, ''])
-    shared.history['internal'].append([apply_extensions(text, "input"), ''])
+    shared.history['internal'].append([apply_extensions("input", text), ''])
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@@ -288,8 +279,9 @@ def send_dummy_reply(text, name1, name2, mode):
     if len(shared.history['visible']) > 0 and not shared.history['visible'][-1][1] == '':
         shared.history['visible'].append(['', ''])
         shared.history['internal'].append(['', ''])

     shared.history['visible'][-1][1] = text
-    shared.history['internal'][-1][1] = apply_extensions(text, "input")
+    shared.history['internal'][-1][1] = apply_extensions("input", text)
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@@ -303,11 +295,10 @@ def clear_chat_log(name1, name2, greeting, mode):
     if greeting != '':
         shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
-        shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
+        shared.history['visible'] += [['', apply_extensions("output", greeting)]]

     # Save cleared logs
     save_history(mode)
     return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
@@ -328,8 +319,8 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
     for i in range(len(idx) - 1):
         messages.append(dialogue[idx[i]:idx[i + 1]].strip())

     messages.append(dialogue[idx[-1]:].strip())
     entry = ['', '']
     for i in messages:
         if i.startswith(f'{name1}:'):
@@ -338,6 +329,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
             entry[1] = i[len(f'{name2}:'):].strip()
             if not (len(entry[0]) == 0 and len(entry[1]) == 0):
                 history.append(entry)
                 entry = ['', '']

     print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
@@ -346,6 +338,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
             print("\n")
             for line in column.strip().split('\n'):
                 print("| " + line + "\n")

             print("|\n")
     print("------------------------------")
@@ -358,14 +351,17 @@ def save_history(mode, timestamp=False):
     if mode == 'instruct':
         if not timestamp:
             return

         fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
     else:
         if timestamp:
             fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
         else:
             fname = f"{shared.character}_persistent.json"

     if not Path('logs').exists():
         Path('logs').mkdir()

     with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
         f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
@@ -396,8 +392,10 @@ def build_pygmalion_style_context(data):
     context = ""
     if 'char_persona' in data and data['char_persona'] != '':
         context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"

     if 'world_scenario' in data and data['world_scenario'] != '':
         context += f"Scenario: {data['world_scenario']}\n"

     context = f"{context.strip()}\n<START>\n"
     return context
@@ -412,6 +410,7 @@ def generate_pfp_cache(character):
             img = make_thumbnail(Image.open(path))
             img.save(Path('cache/pfp_character.png'), format='PNG')
             return img

     return None
@@ -475,7 +474,7 @@ def load_character(character, name1, name2, mode):
         # Insert greeting if it exists
         if greeting != "":
             shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
-            shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
+            shared.history['visible'] += [['', apply_extensions("output", greeting)]]

         # Create .json log files since they don't already exist
         save_history(mode)
@@ -483,10 +482,6 @@ def load_character(character, name1, name2, mode):
     return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode)


-def load_default_history(name1, name2):
-    load_character("None", name1, name2, "chat")
-
-
 def upload_character(json_file, img, tavern=False):
     json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
     data = json.loads(json_file)
@@ -495,13 +490,17 @@ def upload_character(json_file, img, tavern=False):
     while Path(f'characters/{outfile_name}.json').exists():
         outfile_name = f'{data["char_name"]}_{i:03d}'
         i += 1

     if tavern:
         outfile_name = f'TavernAI-{outfile_name}'

     with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
         f.write(json_file)

     if img is not None:
         img = Image.open(io.BytesIO(img))
         img.save(Path(f'characters/{outfile_name}.png'))

     print(f'New character saved to "characters/{outfile_name}.json".')
     return outfile_name

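With the hijack loop gone from chatbot_wrapper, an extension that wants to build the prompt itself now only needs to expose custom_generate_chat_prompt; it is picked up through apply_extensions('custom_generate_chat_prompt', text, state, **kwargs). A minimal sketch of such a hook (the extension itself is hypothetical, only the signature and fallback behaviour are taken from the code above):

# hypothetical extensions/<name>/script.py
def custom_generate_chat_prompt(user_input, state, **kwargs):
    # kwargs currently carries {'_continue': ...}; returning None falls back
    # to the default generate_chat_prompt() in modules/chat.py
    from modules.chat import generate_chat_prompt
    return generate_chat_prompt(user_input, state, **kwargs)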
View File

@@ -1,4 +1,5 @@
 import traceback
+from functools import partial

 import gradio as gr
@@ -10,21 +11,39 @@ available_extensions = []
 setup_called = set()


+def apply_settings(extension, name):
+    if not hasattr(extension, 'params'):
+        return
+
+    for param in extension.params:
+        _id = f"{name}-{param}"
+        if _id not in shared.settings:
+            continue
+
+        extension.params[param] = shared.settings[_id]
+
+
 def load_extensions():
     global state, setup_called
     for i, name in enumerate(shared.args.extensions):
         if name in available_extensions:
-            print(f'Loading the extension "{name}"... ', end='')
+            if name != 'api':
+                print(f'Loading the extension "{name}"... ', end='')
             try:
                 exec(f"import extensions.{name}.script")
                 extension = getattr(extensions, name).script
+                apply_settings(extension, name)
                 if extension not in setup_called and hasattr(extension, "setup"):
                     setup_called.add(extension)
                     extension.setup()

                 state[name] = [True, i]
-                print('Ok.')
+                if name != 'api':
+                    print('Ok.')
             except:
-                print('Fail.')
+                if name != 'api':
+                    print('Fail.')
+
                 traceback.print_exc()
@@ -36,32 +55,74 @@ def iterator():

 # Extension functions that map string -> string
-def apply_extensions(text, typ):
+def _apply_string_extensions(function_name, text):
     for extension, _ in iterator():
-        if typ == "input" and hasattr(extension, "input_modifier"):
-            text = extension.input_modifier(text)
-        elif typ == "output" and hasattr(extension, "output_modifier"):
-            text = extension.output_modifier(text)
-        elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
-            text = extension.bot_prefix_modifier(text)
+        if hasattr(extension, function_name):
+            text = getattr(extension, function_name)(text)

     return text


+# Input hijack of extensions
+def _apply_input_hijack(text, visible_text):
+    for extension, _ in iterator():
+        if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
+            extension.input_hijack['state'] = False
+            if callable(extension.input_hijack['value']):
+                text, visible_text = extension.input_hijack['value'](text, visible_text)
+            else:
+                text, visible_text = extension.input_hijack['value']
+
+    return text, visible_text
+
+
+# custom_generate_chat_prompt handling
+def _apply_custom_generate_chat_prompt(text, state, **kwargs):
+    custom_generate_chat_prompt = None
+    for extension, _ in iterator():
+        if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
+            custom_generate_chat_prompt = extension.custom_generate_chat_prompt
+
+    if custom_generate_chat_prompt is not None:
+        return custom_generate_chat_prompt(text, state, **kwargs)
+
+    return None
+
+
+# Extension functions that override the default tokenizer output
+def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
+    for extension, _ in iterator():
+        if hasattr(extension, function_name):
+            prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
+
+    return prompt, input_ids, input_embeds
+
+
+EXTENSION_MAP = {
+    "input": partial(_apply_string_extensions, "input_modifier"),
+    "output": partial(_apply_string_extensions, "output_modifier"),
+    "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
+    "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
+    "input_hijack": _apply_input_hijack,
+    "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
+}
+
+
+def apply_extensions(typ, *args, **kwargs):
+    if typ not in EXTENSION_MAP:
+        raise ValueError(f"Invalid extension type {typ}")
+
+    return EXTENSION_MAP[typ](*args, **kwargs)
+
+
 def create_extensions_block():
     global setup_called

-    # Updating the default values
-    for extension, name in iterator():
-        if hasattr(extension, 'params'):
-            for param in extension.params:
-                _id = f"{name}-{param}"
-                if _id in shared.settings:
-                    extension.params[param] = shared.settings[_id]
-
     should_display_ui = False
     for extension, name in iterator():
         if hasattr(extension, "ui"):
             should_display_ui = True
+            break

     # Creating the extension ui elements
     if should_display_ui:

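The dispatcher above replaces the old typ == ... chain with EXTENSION_MAP, so every hook is resolved by name. A short sketch of how an extension script is reached under the new scheme (the 'example' extension and its params are hypothetical, used only to illustrate the call path):

# hypothetical extensions/example/script.py
params = {'suffix': ''}

def input_modifier(text):
    # reached via apply_extensions('input', text) -> EXTENSION_MAP['input']
    return text + params['suffix']

def output_modifier(text):
    # reached via apply_extensions('output', text)
    return text

Because of the new apply_settings() helper, a key such as 'example-suffix' in shared.settings would now override params['suffix'] when the extension is loaded.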
View File

@@ -24,7 +24,8 @@ class LlamaCppModel:
             'model_path': str(path),
             'n_ctx': 2048,
             'seed': 0,
-            'n_threads': shared.args.threads or None
+            'n_threads': shared.args.threads or None,
+            'n_batch': shared.args.n_batch
         }
         self.model = Llama(**params)
         self.model.set_cache(LlamaCache)

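For reference, the dictionary above is passed straight to llama-cpp-python, so the new flag maps directly onto the Llama constructor. A hedged sketch with placeholder values (the model path and numbers are illustrative, not from the diff):

from llama_cpp import Llama

# roughly what LlamaCppModel builds when launched with --threads 8 --n_batch 256
model = Llama(
    model_path='models/ggml-model-q4_0.bin',  # placeholder path
    n_ctx=2048,
    seed=0,
    n_threads=8,   # shared.args.threads or None
    n_batch=256    # shared.args.n_batch (default 8)
)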
View File

@@ -50,6 +50,8 @@ def find_model_type(model_name):
         return 'chatglm'
     elif 'galactica' in model_name:
         return 'galactica'
+    elif 'llava' in model_name:
+        return 'llava'
     elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
         return 'gpt4chan'
     else:
@@ -217,6 +219,7 @@ def load_model(model_name):
         tokenizer = None

         # Try to load an universal LLaMA tokenizer
+        if shared.model_type != 'llava':
             for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
                 if p.exists():
                     print(f"Loading the universal LLaMA tokenizer from {p}...")

View File

@@ -20,6 +20,9 @@ processing_message = '*Is typing...*'
 # UI elements (buttons, sliders, HTML, etc)
 gradio = {}

+# For keeping the values of UI elements on page reload
+persistent_interface_state = {}
+
 # Generation input parameters
 input_params = []
@@ -31,6 +34,7 @@ settings = {
     'max_new_tokens_min': 1,
     'max_new_tokens_max': 2000,
     'seed': -1,
+    'character': 'None',
     'name1': 'You',
     'name2': 'Assistant',
     'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
@@ -56,7 +60,7 @@ settings = {
     'chat_default_extensions': ["gallery"],
     'presets': {
         'default': 'Default',
-        '.*(alpaca|llama)': "LLaMA-Precise",
+        '.*(alpaca|llama|llava)': "LLaMA-Precise",
         '.*pygmalion': 'NovelAI-Storywriter',
         '.*RWKV': 'Naive',
     },
@@ -90,6 +94,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
 parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
 parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
 parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
+parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.')
 parser.add_argument('--model', type=str, help='Name of the model to load by default.')
 parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
 parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
@@ -116,6 +121,7 @@ parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_
 # llama.cpp
 parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
+parser.add_argument('--n_batch', type=int, default=8, help='Processing batch size for llama.cpp.')

 # GPTQ
 parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
@@ -150,6 +156,11 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
 parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
 parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)

+# API
+parser.add_argument('--api', action='store_true', help='Enable the API extension.')
+parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
+
 args = parser.parse_args()
 args_defaults = parser.parse_args([])
@@ -171,6 +182,13 @@ if args.trust_remote_code:
 if args.share:
     print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n")

+# Activating the API extension
+if args.api or args.public_api:
+    if args.extensions is None:
+        args.extensions = ['api']
+    elif 'api' not in args.extensions:
+        args.extensions.append('api')
+

 def is_chat():
     return args.chat

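The only effect of the two new flags is to append the bundled api extension to the extension list. A small sketch of that logic in isolation, mirroring the block added above (the Namespace values are just test inputs):

import argparse

args = argparse.Namespace(api=True, public_api=False, extensions=None)
if args.api or args.public_api:
    if args.extensions is None:
        args.extensions = ['api']
    elif 'api' not in args.extensions:
        args.extensions.append('api')

print(args.extensions)  # ['api'], the same effect as enabling the api extension by hand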
View File

@@ -113,9 +113,11 @@ def set_manual_seed(seed):
     seed = int(seed)
     if seed == -1:
         seed = random.randint(1, 2**31)

     torch.manual_seed(seed)
     if torch.cuda.is_available():
         torch.cuda.manual_seed_all(seed)

     return seed
@@ -123,8 +125,41 @@ def stop_everything_event():
     shared.stop_everything = True


-def generate_reply(question, state, eos_token=None, stopping_strings=[]):
+def get_generate_params(state):
+    generate_params = {}
+
+    # Models that are not on transformers
+    if shared.model_type in ['rwkv', 'llamacpp']:
+        generate_params['token_count'] = state['max_new_tokens']
+        for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
+            generate_params[k] = state[k]
+    else:
+        # FlexGen
+        if shared.args.flexgen:
+            for k in ['max_new_tokens', 'do_sample', 'temperature']:
+                generate_params[k] = state[k]
+
+            if not shared.args.no_stream:
+                generate_params['max_new_tokens'] = 8
+
+        # transformers
+        else:
+            for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
+                generate_params[k] = state[k]
+
+            if state['ban_eos_token']:
+                generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
+
+        if shared.args.no_cache:
+            generate_params.update({'use_cache': False})
+
+        if shared.args.deepspeed:
+            generate_params.update({'synced_gpus': True})
+
+    return generate_params
+
+
+def generate_reply(question, state, eos_token=None, stopping_strings=[]):
     if shared.model_name == 'None' or shared.model is None:
         print("No model is loaded! Select one in the Model tab.")
         yield formatted_outputs(question, shared.model_name)
@@ -133,40 +168,37 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
     seed = set_manual_seed(state['seed'])
     shared.stop_everything = False
-    generate_params = {}
+    generate_params = get_generate_params(state)
     t0 = time.time()

+    # Preparing the input
     original_question = question
     if not shared.is_chat():
-        question = apply_extensions(question, 'input')
+        question = apply_extensions('input', question)

-    # These models are not part of Hugging Face, so we handle them
-    # separately and terminate the function call earlier
+    # If the model is not on transformers, handle it separately and end this
+    # function call earlier.
     if shared.model_type in ['rwkv', 'llamacpp']:
         if shared.args.verbose:
             print(f'\n\n{question}\n--------------------\n')

-        for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
-            generate_params[k] = state[k]
-        generate_params['token_count'] = state['max_new_tokens']
         try:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, **generate_params)
                 output = original_question + reply
                 if not shared.is_chat():
-                    reply = original_question + apply_extensions(reply, 'output')
+                    reply = original_question + apply_extensions('output', reply)

                 yield formatted_outputs(reply, shared.model_name)
             else:
                 if not shared.is_chat():
                     yield formatted_outputs(question, shared.model_name)

-                # RWKV has proper streaming, which is very nice.
-                # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, **generate_params):
                     output = original_question + reply
                     if not shared.is_chat():
-                        reply = original_question + apply_extensions(reply, 'output')
+                        reply = original_question + apply_extensions('output', reply)

                     yield formatted_outputs(reply, shared.model_name)

         except Exception:
@@ -178,19 +210,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
             return

+    # Encode the input
     input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
-    original_input_ids = input_ids
     output = input_ids[0]
+    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
     if shared.args.verbose:
         print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')

-    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
+    # Find the eos tokens
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
         eos_token_ids.append(int(encode(eos_token)[0][-1]))

-    # Handling the stopping strings
+    # Create the StoppingCriteriaList with the stopping strings
     stopping_criteria_list = transformers.StoppingCriteriaList()
     for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
         if type(st) is list and len(st) > 0:
@@ -198,30 +230,26 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
             break

-    if not shared.args.flexgen:
-        for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
-            generate_params[k] = state[k]
+    # Update generate_params with the eos token and the stopping strings
+    if shared.args.flexgen:
+        generate_params['stop'] = eos_token_ids[-1]
+    else:
         generate_params['eos_token_id'] = eos_token_ids
         generate_params['stopping_criteria'] = stopping_criteria_list
-        if state['ban_eos_token']:
-            generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
-    else:
-        for k in ['max_new_tokens', 'do_sample', 'temperature']:
-            generate_params[k] = state[k]
-        generate_params['stop'] = eos_token_ids[-1]
-        if not shared.args.no_stream:
-            generate_params['max_new_tokens'] = 8
-
-    if shared.args.no_cache:
-        generate_params.update({'use_cache': False})
-    if shared.args.deepspeed:
-        generate_params.update({'synced_gpus': True})

+    # Add the encoded tokens to generate_params
     if shared.soft_prompt:
         inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
+        question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
+        original_input_ids = input_ids
         generate_params.update({'inputs_embeds': inputs_embeds})
         generate_params.update({'inputs': filler_input_ids})
     else:
+        question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
+        original_input_ids = input_ids
         generate_params.update({'inputs': input_ids})
+        if inputs_embeds is not None:
+            generate_params.update({'inputs_embeds': inputs_embeds})

     try:
         # Generate the entire reply at once.
@@ -237,7 +265,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             new_tokens = len(output) - len(input_ids[0])
             reply = decode(output[-new_tokens:], state['skip_special_tokens'])
             if not shared.is_chat():
-                reply = original_question + apply_extensions(reply, 'output')
+                reply = original_question + apply_extensions('output', reply)

             yield formatted_outputs(reply, shared.model_name)
@@ -265,7 +293,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
                 new_tokens = len(output) - len(input_ids[0])
                 reply = decode(output[-new_tokens:], state['skip_special_tokens'])
                 if not shared.is_chat():
-                    reply = original_question + apply_extensions(reply, 'output')
+                    reply = original_question + apply_extensions('output', reply)

                 if output[-1] in eos_token_ids:
                     break
@@ -285,7 +313,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
                 new_tokens = len(output) - len(original_input_ids[0])
                 reply = decode(output[-new_tokens:], state['skip_special_tokens'])
                 if not shared.is_chat():
-                    reply = original_question + apply_extensions(reply, 'output')
+                    reply = original_question + apply_extensions('output', reply)

                 if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
                     break

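The 'tokenizer' hook used above gives extensions a chance to rewrite the prompt, the token ids and the embeddings right before generation. A minimal no-op modifier an extension could expose (hypothetical, the signature follows _apply_tokenizer_extensions):

# hypothetical extensions/<name>/script.py
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
    # called via apply_extensions('tokenizer', state, question, input_ids, inputs_embeds)
    # just before the encoded tokens are added to generate_params
    return prompt, input_ids, input_embeds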
View File

@@ -18,14 +18,14 @@ from modules.evaluate import calculate_perplexity, generate_markdown_table, save
 from server import get_available_loras, get_available_models

 # This mapping is from a very recent commit, not yet released.
-# If not available, default to a backup map for the 3 safe model types.
+# If not available, default to a backup map for some common model types.
 try:
     from peft.utils.other import \
         TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
         model_to_lora_modules
 except:
     standard_modules = ["q_proj", "v_proj"]
-    model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules}
+    model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}

 WANT_INTERRUPT = False
@@ -35,7 +35,8 @@ PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size",
 MODEL_CLASSES = {
     "LlamaForCausalLM": "llama",
     "OPTForCausalLM": "opt",
-    "GPTJForCausalLM": "gptj"
+    "GPTJForCausalLM": "gptj",
+    "GPTNeoXForCausalLM": "gpt_neox"
 }
@@ -45,6 +46,8 @@ def get_datasets(path: str, ext: str):

 def create_train_interface():
     with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
+        gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
+
         with gr.Row():
             lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
             always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
@@ -215,11 +218,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
     else:
         model_id = "llama"
         if model_type == "PeftModelForCausalLM":
-            yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-            print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
+            if len(shared.args.lora_names) > 0:
+                yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+                print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
+            else:
+                yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+                print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
         else:
-            yield "LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
-            print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. (Found model type: {model_type})")
+            yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
+            print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")

         time.sleep(5)

     if shared.args.wbits > 0 and not shared.args.monkey_patch:

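With the GPT-NeoX entries above, a "GPTNeoXForCausalLM" checkpoint resolves to the "gpt_neox" model id, whose fallback LoRA target is the fused attention projection. A sketch of the resulting peft configuration (the hyperparameters are illustrative, not taken from the diff):

from peft import LoraConfig

config = LoraConfig(
    r=8,                                  # illustrative rank
    lora_alpha=16,                        # illustrative scaling
    target_modules=['query_key_value'],   # model_to_lora_modules['gpt_neox']
    lora_dropout=0.05,
    task_type='CAUSAL_LM'
)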
View File

@@ -33,9 +33,10 @@ def list_model_elements():

 def list_interface_input_elements(chat=False):
-    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens']
+    elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
     if chat:
-        elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']
+        elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']

     elements += list_model_elements()
     return elements
@@ -44,11 +45,26 @@ def gather_interface_values(*args):
     output = {}
     for i, element in enumerate(shared.input_elements):
         output[element] = args[i]

+    shared.persistent_interface_state = output
     return output


-def apply_interface_values(state):
-    return [state[i] for i in list_interface_input_elements(chat=shared.is_chat())]
+def apply_interface_values(state, use_persistent=False):
+    if use_persistent:
+        state = shared.persistent_interface_state
+
+    elements = list_interface_input_elements(chat=shared.is_chat())
+    if len(state) == 0:
+        return [gr.update() for k in elements]  # Dummy, do nothing
+    else:
+        if use_persistent and 'mode' in state:
+            if state['mode'] == 'instruct':
+                return [state[k] if (k not in ['character_menu'] and k in state) else gr.update() for k in elements]
+            else:
+                return [state[k] if (k not in ['instruction_template'] and k in state) else gr.update() for k in elements]
+        else:
+            return [state[k] if k in state else gr.update() for k in elements]


 class ToolButton(gr.Button, gr.components.FormComponent):

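gather_interface_values() now also mirrors every submitted value into shared.persistent_interface_state, and apply_interface_values(..., use_persistent=True) replays it when the page is reloaded. A rough sketch of the round trip (not a standalone script; it assumes the web UI modules and gradio are importable, and the dummy values are placeholders):

import modules.shared as shared
from modules import ui

shared.input_elements = ui.list_interface_input_elements(chat=False)
state = ui.gather_interface_values(*range(len(shared.input_elements)))  # captured on every generation
updates = ui.apply_interface_values({}, use_persistent=True)            # replayed on interface load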
View File

@@ -16,5 +16,5 @@ tqdm
 git+https://github.com/huggingface/peft
 transformers==4.28.1
 bitsandbytes==0.38.1; platform_system != "Windows"
-llama-cpp-python==0.1.34; platform_system != "Windows"
-https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
+llama-cpp-python==0.1.36; platform_system != "Windows"
+https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.36/llama_cpp_python-0.1.36-cp310-cp310-win_amd64.whl; platform_system == "Windows"

View File

@@ -32,6 +32,7 @@ import time
 import traceback
 import zipfile
 from datetime import datetime
+from functools import partial
 from pathlib import Path

 import psutil
@@ -40,7 +41,7 @@ import yaml
 from PIL import Image

 import modules.extensions as extensions_module
-from modules import api, chat, shared, training, ui
+from modules import chat, shared, training, ui
 from modules.html_generator import chat_html_wrapper
 from modules.LoRA import add_lora_to_model
 from modules.models import load_model, load_soft_prompt, unload_model
@@ -543,7 +544,7 @@ def create_interface():
                     shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)

                 shared.gradio['mode'] = gr.Radio(choices=['cai-chat', 'chat', 'instruct'], value=shared.settings['mode'], label='Mode')
-                shared.gradio['instruction_template'] = gr.Dropdown(choices=get_available_instruction_templates(), label='Instruction template', value=shared.settings['instruction_template'], visible=shared.settings['mode'] == 'instruct', info='Change this according to the model/LoRA that you are using.')
+                shared.gradio['instruction_template'] = gr.Dropdown(choices=get_available_instruction_templates(), label='Instruction template', value='None', visible=shared.settings['mode'] == 'instruct', info='Change this according to the model/LoRA that you are using.')

             with gr.Tab('Character', elem_id='chat-settings'):
                 with gr.Row():
@@ -559,7 +560,7 @@ def create_interface():
                     shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)

                 with gr.Row():
-                    shared.gradio['character_menu'] = gr.Dropdown(choices=get_available_characters(), value='None', label='Character', elem_id='character-menu')
+                    shared.gradio['character_menu'] = gr.Dropdown(choices=get_available_characters(), label='Character', elem_id='character-menu')
                     ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')

                 with gr.Row():
@@ -710,14 +711,6 @@ def create_interface():
             set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
             lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')

-        # Extensions block
-        if shared.args.extensions is not None:
-            extensions_module.create_extensions_block()
-
-        # Create the invisible elements that define the API
-        if not shared.is_chat():
-            api.create_apis()
-
         # chat mode event handlers
         if shared.is_chat():
             shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
@@ -801,16 +794,11 @@ def create_interface():
             shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
             shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
             shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])

             shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
-            shared.gradio['interface'].load(chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
-            shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
-            shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)

         # notebook/default modes event handlers
         else:
             shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']]
             if shared.args.notebook:
                 output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
             else:
@@ -851,6 +839,11 @@ def create_interface():
             shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)

         shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
+        shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)
+
+        # Extensions block
+        if shared.args.extensions is not None:
+            extensions_module.create_extensions_block()

     # Launch the interface
     shared.gradio['interface'].queue()
     if shared.args.listen:
@@ -860,7 +853,6 @@ def create_interface():


 if __name__ == "__main__":
     # Loading custom settings
     settings_file = None
     if shared.args.settings is not None and Path(shared.args.settings).exists():
@@ -905,14 +897,15 @@ if __name__ == "__main__":
         print('The following models are available:\n')
         for i, model in enumerate(available_models):
             print(f'{i+1}. {model}')

         print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
         i = int(input()) - 1
         print()

     shared.model_name = available_models[i]

+    # If any model has been selected, load it
     if shared.model_name != 'None':
         model_settings = get_model_specific_settings(shared.model_name)
         shared.settings.update(model_settings)  # hijacking the interface defaults
         update_model_parameters(model_settings, initial=True)  # hijacking the command-line arguments
@@ -922,6 +915,14 @@ if __name__ == "__main__":
     if shared.args.lora:
         add_lora_to_model([shared.args.lora])

+    # Force a character to be loaded
+    if shared.is_chat():
+        shared.persistent_interface_state.update({
+            'mode': shared.settings['mode'],
+            'character_menu': shared.args.character or shared.settings['character'],
+            'instruction_template': shared.settings['instruction_template']
+        })
+
     # Launch the web UI
     create_interface()
     while True:

View File

@@ -3,6 +3,7 @@
     "max_new_tokens_min": 1,
     "max_new_tokens_max": 2000,
     "seed": -1,
+    "character": "None",
     "name1": "You",
     "name2": "Assistant",
    "context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
@@ -30,7 +31,7 @@
     ],
     "presets": {
         "default": "Default",
-        ".*(alpaca|llama)": "LLaMA-Precise",
+        ".*(alpaca|llama|llava)": "LLaMA-Precise",
         ".*pygmalion": "NovelAI-Storywriter",
         ".*RWKV": "Naive"
     },