mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 17:06:53 +01:00
Merge branch 'oobabooga:main' into searx_integration_bs4
This commit is contained in:
commit
a325e13857
@ -188,6 +188,7 @@ Optionally, you can use the following command-line flags:
|
||||
| `-h`, `--help` | Show this help message and exit. |
|
||||
| `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. |
|
||||
| `--chat` | Launch the web UI in chat mode. |
|
||||
| `--character CHARACTER` | The name of the character to load in chat mode by default. |
|
||||
| `--model MODEL` | Name of the model to load by default. |
|
||||
| `--lora LORA` | Name of the LoRA to apply to the model by default. |
|
||||
| `--model-dir MODEL_DIR` | Path to directory with all the models. |
|
||||
@ -220,6 +221,7 @@ Optionally, you can use the following command-line flags:
|
||||
| Flag | Description |
|
||||
|-------------|-------------|
|
||||
| `--threads` | Number of threads to use in llama.cpp. |
|
||||
| `--n_batch` | Processing batch size for llama.cpp. |
|
||||
|
||||
#### GPTQ
|
||||
|
||||
@ -269,6 +271,13 @@ Optionally, you can use the following command-line flags:
|
||||
| `--auto-launch` | Open the web UI in the default browser upon launch. |
|
||||
| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |
|
||||
|
||||
#### API
|
||||
|
||||
| Flag | Description |
|
||||
|---------------------------------------|-------------|
|
||||
| `--api` | Enable the API extension. |
|
||||
| `--public-api` | Create a public URL for the API using Cloudfare. |
|
||||
|
||||
Out of memory errors? [Check the low VRAM guide](docs/Low-VRAM-guide.md).
|
||||
|
||||
## Presets
|
||||
|
@ -1,39 +1,30 @@
|
||||
'''
|
||||
|
||||
Contributed by SagsMug. Thank you SagsMug.
|
||||
https://github.com/oobabooga/text-generation-webui/pull/175
|
||||
|
||||
'''
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import string
|
||||
import sys
|
||||
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Websockets package not found. Make sure it's installed.")
|
||||
|
||||
# Gradio changes this index from time to time. To rediscover it, set VISIBLE = False in
|
||||
# modules/api.py and use the dev tools to inspect the request made after clicking on the
|
||||
# button called "Run" at the bottom of the UI
|
||||
GRADIO_FN = 34
|
||||
|
||||
|
||||
def random_hash():
|
||||
letters = string.ascii_lowercase + string.digits
|
||||
return ''.join(random.choice(letters) for i in range(9))
|
||||
# For local streaming, the websockets are hosted without ssl - ws://
|
||||
HOST = 'localhost:5005'
|
||||
URI = f'ws://{HOST}/api/v1/stream'
|
||||
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||
|
||||
async def run(context):
|
||||
server = "127.0.0.1"
|
||||
params = {
|
||||
'max_new_tokens': 200,
|
||||
# Note: the selected defaults change from time to time.
|
||||
request = {
|
||||
'prompt': context,
|
||||
'max_new_tokens': 250,
|
||||
'do_sample': True,
|
||||
'temperature': 0.72,
|
||||
'top_p': 0.73,
|
||||
'temperature': 1.3,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.1,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'top_k': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
@ -45,48 +36,31 @@ async def run(context):
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': [],
|
||||
'stopping_strings': []
|
||||
}
|
||||
payload = json.dumps([context, params])
|
||||
session = random_hash()
|
||||
|
||||
async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket:
|
||||
while content := json.loads(await websocket.recv()):
|
||||
# Python3.10 syntax, replace with if elif on older
|
||||
match content["msg"]:
|
||||
case "send_hash":
|
||||
await websocket.send(json.dumps({
|
||||
"session_hash": session,
|
||||
"fn_index": GRADIO_FN
|
||||
}))
|
||||
case "estimation":
|
||||
pass
|
||||
case "send_data":
|
||||
await websocket.send(json.dumps({
|
||||
"session_hash": session,
|
||||
"fn_index": GRADIO_FN,
|
||||
"data": [
|
||||
payload
|
||||
]
|
||||
}))
|
||||
case "process_starts":
|
||||
pass
|
||||
case "process_generating" | "process_completed":
|
||||
yield content["output"]["data"][0]
|
||||
# You can search for your desired end indicator and
|
||||
# stop generation by closing the websocket here
|
||||
if (content["msg"] == "process_completed"):
|
||||
break
|
||||
async with websockets.connect(URI) as websocket:
|
||||
await websocket.send(json.dumps(request))
|
||||
|
||||
prompt = "What I would like to say is the following: "
|
||||
yield context # Remove this if you just want to see the reply
|
||||
|
||||
while True:
|
||||
incoming_data = await websocket.recv()
|
||||
incoming_data = json.loads(incoming_data)
|
||||
|
||||
match incoming_data['event']:
|
||||
case 'text_stream':
|
||||
yield incoming_data['text']
|
||||
case 'stream_end':
|
||||
return
|
||||
|
||||
|
||||
async def get_result():
|
||||
async def print_response_stream(prompt):
|
||||
async for response in run(prompt):
|
||||
# Print intermediate steps
|
||||
print(response)
|
||||
print(response, end='')
|
||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||
|
||||
# Print final result
|
||||
print(response)
|
||||
|
||||
asyncio.run(get_result())
|
||||
if __name__ == '__main__':
|
||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||
asyncio.run(print_response_stream(prompt))
|
||||
|
@ -1,33 +1,22 @@
|
||||
'''
|
||||
|
||||
This is an example on how to use the API for oobabooga/text-generation-webui.
|
||||
|
||||
Make sure to start the web UI with the following flags:
|
||||
|
||||
python server.py --model MODEL --listen --no-stream
|
||||
|
||||
Optionally, you can also add the --share flag to generate a public gradio URL,
|
||||
allowing you to use the API remotely.
|
||||
|
||||
'''
|
||||
import json
|
||||
|
||||
import requests
|
||||
|
||||
# Server address
|
||||
server = "127.0.0.1"
|
||||
# For local streaming, the websockets are hosted without ssl - http://
|
||||
HOST = 'localhost:5000'
|
||||
URI = f'http://{HOST}/api/v1/generate'
|
||||
|
||||
# Generation parameters
|
||||
# Reference: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig
|
||||
params = {
|
||||
'max_new_tokens': 200,
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
||||
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
|
||||
|
||||
def run(context):
|
||||
request = {
|
||||
'prompt': prompt,
|
||||
'max_new_tokens': 250,
|
||||
'do_sample': True,
|
||||
'temperature': 0.72,
|
||||
'top_p': 0.73,
|
||||
'temperature': 1.3,
|
||||
'top_p': 0.1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.1,
|
||||
'encoder_repetition_penalty': 1.0,
|
||||
'top_k': 0,
|
||||
'repetition_penalty': 1.18,
|
||||
'top_k': 40,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
@ -39,19 +28,15 @@ params = {
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': [],
|
||||
'stopping_strings': []
|
||||
}
|
||||
|
||||
# Input prompt
|
||||
prompt = "What I would like to say is the following: "
|
||||
response = requests.post(URI, json=request)
|
||||
|
||||
payload = json.dumps([prompt, params])
|
||||
if response.status_code == 200:
|
||||
result = response.json()['results'][0]['text']
|
||||
print(prompt + result)
|
||||
|
||||
response = requests.post(f"http://{server}:7860/run/textgen", json={
|
||||
"data": [
|
||||
payload
|
||||
]
|
||||
}).json()
|
||||
|
||||
reply = response["data"][0]
|
||||
print(reply)
|
||||
if __name__ == '__main__':
|
||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||
run(prompt)
|
||||
|
3
characters/instruction-following/LLaVA.yaml
Normal file
3
characters/instruction-following/LLaVA.yaml
Normal file
@ -0,0 +1,3 @@
|
||||
name: "### Assistant"
|
||||
your_name: "### Human"
|
||||
context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n"
|
@ -16,19 +16,24 @@ command-line flag.
|
||||
|
||||
The link above contains a directory of user extensions for text-generation-webui.
|
||||
|
||||
If you create an extension, you are welcome to host it in a GitHub repository and submit it to the list above.
|
||||
|
||||
## Built-in extensions
|
||||
|
||||
Most of these have been created by the extremely talented contributors that you can find here: [contributors](https://github.com/oobabooga/text-generation-webui/graphs/contributors?from=2022-12-18&to=&type=a).
|
||||
|
||||
|Extension|Description|
|
||||
|---------|-----------|
|
||||
|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` por 5000. This is the main API for this web UI. |
|
||||
|[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.|
|
||||
|[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that biases the bot's responses in chat mode.|
|
||||
|[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. |
|
||||
|[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, it replaces the responses with an audio widget. |
|
||||
|[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. Author: [@MetaIX](https://github.com/MetaIX). |
|
||||
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. Author: [@SillyLossy](https://github.com/sillylossy).|
|
||||
|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API similar to the one provided by KoboldAI. Works with TavernAI: start the web UI with `python server.py --no-stream --extensions api` and set the API URL to `http://127.0.0.1:5000/api`. Author: [@mayaeary](https://github.com/mayaeary).|
|
||||
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. Author: [@EliasVincent](https://github.com/EliasVincent).|
|
||||
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). Author: [@Brawlence](https://github.com/Brawlence).|
|
||||
|[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. |
|
||||
|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. |
|
||||
|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. |
|
||||
|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). |
|
||||
|[llava](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava) | Adds LLaVA multimodal model support. For detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava/README.md) in the extension directory. |
|
||||
|
||||
## How to write an extension
|
||||
|
||||
@ -41,6 +46,7 @@ The link above contains a directory of user extensions for text-generation-webui
|
||||
| `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. |
|
||||
| `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). |
|
||||
| `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. |
|
||||
| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. See `llava` extension for an example |
|
||||
|
||||
Additionally, the script may define two special global variables:
|
||||
|
||||
@ -66,7 +72,9 @@ input_hijack = {
|
||||
'value': ["", ""]
|
||||
}
|
||||
```
|
||||
This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the vales inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example.
|
||||
This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example.
|
||||
|
||||
Additionally, your extension can set the value to be a callback, in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `llava` extension above for an example.
|
||||
|
||||
## The `bot_prefix_modifier`
|
||||
|
||||
|
167
docs/Training-LoRAs.md
Normal file
167
docs/Training-LoRAs.md
Normal file
@ -0,0 +1,167 @@
|
||||
## Training Your Own LoRAs
|
||||
|
||||
The WebUI seeks to make training your own LoRAs as easy as possible. It comes down to just a few simple steps:
|
||||
|
||||
### **Step 1**: Make a plan.
|
||||
- What base model do you want to use? The LoRA you make has to be matched up to a single architecture (eg LLaMA-13B) and cannot be transferred to others (eg LLaMA-7B, StableLM, etc. would all be different). Derivatives of the same model (eg Alpaca finetune of LLaMA-13B) might be transferrable, but even then it's best to train exactly on what you plan to use.
|
||||
- What model format do you want? At time of writing, 8-bit models are most stable, and 4-bit are supported but experimental. In the near future it is likely that 4-bit will be the best option for most users.
|
||||
- What are you training it on? Do you want it to learn real information, a simple format, ...?
|
||||
|
||||
### **Step 2**: Gather a dataset.
|
||||
- If you use a dataset similar to the [Alpaca](https://github.com/gururise/AlpacaDataCleaned/blob/main/alpaca_data_cleaned.json) format, that is natively supported by the `Formatted Dataset` input in the WebUI, with premade formatter options.
|
||||
- If you use a dataset that isn't matched to Alpaca's format, but uses the same basic JSON structure, you can make your own format file by copying `training/formats/alpaca-format.json` to a new file and [editing its content](#format-files).
|
||||
- If you can get the dataset into a simple text file, that works too! You can train using the `Raw text file` input option.
|
||||
- This means you can for example just copy/paste a chatlog/documentation page/whatever you want, shove it in a plain text file, and train on it.
|
||||
- If you use a structured dataset not in this format, you may have to find an external way to convert it - or open an issue to request native support.
|
||||
|
||||
### **Step 3**: Do the training.
|
||||
- **3.1**: Load the WebUI, and your model.
|
||||
- Make sure you don't have any LoRAs already loaded (unless you want to train for multi-LoRA usage).
|
||||
- **3.2**: Open the `Training` tab at the top, `Train LoRA` sub-tab.
|
||||
- **3.3**: Fill in the name lof the LoRA, select your dataset in the dataset options.
|
||||
- **3.4**: Select other parameters to your preference. See [parameters below](#parameters).
|
||||
- **3.5**: click `Start LoRA Training`, and wait.
|
||||
- It can take a few hours for a large dataset, or just a few minute if doing a small run.
|
||||
- You may want to monitor your [loss value](#loss) while it goes.
|
||||
|
||||
### **Step 4**: Evaluate your results.
|
||||
- Load the LoRA under the Models Tab.
|
||||
- You can go test-drive it on the `Text generation` tab, or you can use the `Perplexity evaluation` sub-tab of the `Training` tab.
|
||||
- If you used the `Save every n steps` option, you can grab prior copies of the model from sub-folders within the LoRA model's folder and try them instead.
|
||||
|
||||
### **Step 5**: Re-run if you're unhappy.
|
||||
- Make sure to unload the LoRA before training it.
|
||||
- You can simply resume a prior run - use `Copy parameters from` to select your LoRA, and edit parameters. Note that you cannot change the `Rank` of an already created LoRA.
|
||||
- If you want to resume from a checkpoint saved along the way, simply copy the contents of the checkpoint folder into the LoRA's folder.
|
||||
- (Note: `adapter_model.bin` is the important file that holds the actual LoRA content).
|
||||
- This will start Learning Rate and Steps back to the start. If you want to resume as if you were midway through, you can adjust your Learning Rate to the last reported LR in logs and reduce your epochs.
|
||||
- Or, you can start over entirely if you prefer.
|
||||
- If your model is producing corrupted outputs, you probably need to start over and use a lower Learning Rate.
|
||||
- If your model isn't learning detailed information but you want it to, you might need to just run more epochs, or you might need a higher Rank.
|
||||
- If your model is enforcing a format you didn't want, you may need to tweak your dataset, or start over and not train as far.
|
||||
|
||||
## Format Files
|
||||
|
||||
If using JSON formatted datasets, they are presumed to be in the following approximate format:
|
||||
|
||||
```json
|
||||
[
|
||||
{
|
||||
"somekey": "somevalue",
|
||||
"key2": "value2"
|
||||
},
|
||||
{
|
||||
// etc
|
||||
}
|
||||
]
|
||||
```
|
||||
|
||||
Where the keys (eg `somekey`, `key2` above) are standardized, and relatively consistent across the dataset, and the values (eg `somevalue`, `value2`) contain the content actually intended to be trained.
|
||||
|
||||
For Alpaca, the keys are `instruction`, `input`, and `output`, wherein `input` is sometimes blank.
|
||||
|
||||
A simple format file for Alpaca to be used as a chat bot is:
|
||||
|
||||
```json
|
||||
{
|
||||
"instruction,output": "User: %instruction%\nAssistant: %output%",
|
||||
"instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
|
||||
}
|
||||
```
|
||||
|
||||
Note that the keys (eg `instruction,output`) are a comma-separated list of dataset keys, and the values are a simple string that use those keys with `%%`.
|
||||
|
||||
So for example if a dataset has `"instruction": "answer my question"`, then the format file's `User: %instruction%\n` will be automatically filled in as `User: answer my question\n`.
|
||||
|
||||
If you have different sets of key inputs, you can make your own format file to match it. This format-file is designed to be as simple as possible to enable easy editing to match your needs.
|
||||
|
||||
## Parameters
|
||||
|
||||
The basic purpose and function of each parameter is documented on-page in the WebUI, so read through them in the UI to understand your options.
|
||||
|
||||
That said, here's a guide to the most important parameter choices you should consider:
|
||||
|
||||
### VRAM
|
||||
|
||||
- First, you must consider your VRAM availability.
|
||||
- Generally, under default settings, VRAM usage for training with default parameters is very close to when generating text (with 1000+ tokens of context) (ie, if you can generate text, you can train LoRAs).
|
||||
- Note: worse by default in the 4-bit monkeypatch currently. Reduce `Micro Batch Size` to `1` to restore this to expectations.
|
||||
- If you have VRAM to spare, setting higher batch sizes will use more VRAM and get you better quality training in exchange.
|
||||
- If you have large data, setting a higher cutoff length may be beneficial, but will cost significant VRAM. If you can spare some, set your batch size to `1` and see how high you can push your cutoff length.
|
||||
- If you're low on VRAM, reducing batch size or cutoff length will of course improve that.
|
||||
- Don't be afraid to just try it and see what happens. If it's too much, it will just error out, and you can lower settings and try again.
|
||||
|
||||
### Rank
|
||||
|
||||
- Second, you want to consider the amount of learning you want.
|
||||
- For example, you may wish to just learn a dialogue format (as in the case of Alpaca) in which case setting a low `Rank` value (32 or lower) works great.
|
||||
- Or, you might be training on project documentation you want the bot to understand and be able to understand questions about, in which case the higher the rank, the better.
|
||||
- Generally, higher Rank = more precise learning = more total content learned = more VRAM usage while training.
|
||||
|
||||
### Learning Rate and Epochs
|
||||
|
||||
- Third, how carefully you want it to be learned.
|
||||
- In other words, how okay or not you are with the model losing unrelated understandings.
|
||||
- You can control this with 3 key settings: the Learning Rate, its scheduler, and your total epochs.
|
||||
- The learning rate controls how much change is made to the model by each token it sees.
|
||||
- It's in scientific notation normally, so for example `3e-4` means `3 * 10^-4` which is `0.0003`. The number after `e-` controls how many `0`s are in the number.
|
||||
- Higher values let training run faster, but also are more likely to corrupt prior data in the model.
|
||||
- You essentially have two variables to balance: the LR, and Epochs.
|
||||
- If you make LR higher, you can set Epochs equally lower to match. High LR + low epochs = very fast, low quality training.
|
||||
- If you make LR low, set epochs high. Low LR + high epochs = slow but high-quality training.
|
||||
- The scheduler controls change-over-time as you train - it starts high, and then goes low. This helps balance getting data in, and having decent quality, at the same time.
|
||||
- You can see graphs of the different scheduler options [in the HuggingFace docs here](https://moon-ci-docs.huggingface.co/docs/transformers/pr_1/en/main_classes/optimizer_schedules#transformers.SchedulerType)
|
||||
|
||||
## Loss
|
||||
|
||||
When you're running training, the WebUI's console window will log reports that include, among other things, a numeric value named `Loss`. It will start as a high number, and gradually get lower and lower as it goes.
|
||||
|
||||
"Loss" in the world of AI training theoretically means "how close is the model to perfect", with `0` meaning "absolutely perfect". This is calculated by measuring the difference between the model outputting exactly the text you're training it to output, and what it actually outputs.
|
||||
|
||||
In practice, a good LLM should have a very complex variable range of ideas running in its artificial head, so a loss of `0` would indicate that the model has broken and forgotten to how think about anything other than what you trained it.
|
||||
|
||||
So, in effect, Loss is a balancing game: you want to get it low enough that it understands your data, but high enough that it isn't forgetting everything else. Generally, if it goes below `1.0`, it's going to start forgetting its prior memories, and you should stop training. In some cases you may prefer to take it as low as `0.5` (if you want it to be very very predictable). Different goals have different needs, so don't be afraid to experiment and see what works best for you.
|
||||
|
||||
Note: if you see Loss start at or suddenly jump to exactly `0`, it is likely something has gone wrong in your training process (eg model corruption).
|
||||
|
||||
## Note: 4-Bit Monkeypatch
|
||||
|
||||
The [4-bit LoRA monkeypatch](GPTQ-models-(4-bit-mode).md#using-loras-in-4-bit-mode) works for training, but has side effects:
|
||||
- VRAM usage is higher currently. You can reduce the `Micro Batch Size` to `1` to compensate.
|
||||
- Models do funky things. LoRAs apply themselves, or refuse to apply, or spontaneously error out, or etc. It can be helpful to reload base model or restart the WebUI between training/usage to minimize chances of anything going haywire.
|
||||
- Loading or working with multiple LoRAs at the same time doesn't currently work.
|
||||
- Generally, recognize and treat the monkeypatch as the dirty temporary hack it is - it works, but isn't very stable. It will get better in time when everything is merged upstream for full official support.
|
||||
|
||||
## Legacy notes
|
||||
|
||||
LoRA training was contributed by [mcmonkey4eva](https://github.com/mcmonkey4eva) in PR [#570](https://github.com/oobabooga/text-generation-webui/pull/570).
|
||||
|
||||
### Using the original alpaca-lora code
|
||||
|
||||
Kept here for reference. The Training tab has much more features than this method.
|
||||
|
||||
```
|
||||
conda activate textgen
|
||||
git clone https://github.com/tloen/alpaca-lora
|
||||
```
|
||||
|
||||
Edit those two lines in `alpaca-lora/finetune.py` to use your existing model folder instead of downloading everything from decapoda:
|
||||
|
||||
```
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"models/llama-7b",
|
||||
load_in_8bit=True,
|
||||
device_map="auto",
|
||||
)
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"models/llama-7b", add_eos_token=True
|
||||
)
|
||||
```
|
||||
|
||||
Run the script with:
|
||||
|
||||
```
|
||||
python finetune.py
|
||||
```
|
||||
|
||||
It just works. It runs at 22.32s/it, with 1170 iterations in total, so about 7 hours and a half for training a LoRA. RTX 3090, 18153MiB VRAM used, drawing maximum power (350W, room heater mode).
|
@ -52,38 +52,4 @@ print(f"Predicted {len(output)} tokens for '{sentence}':\n{output}")
|
||||
|
||||
## Training a LoRA
|
||||
|
||||
The Training tab in the interface can be used to train a LoRA. The parameters are self-documenting and good defaults are included.
|
||||
|
||||
You can interrupt and resume LoRA training in this tab. If the name and rank are the same, training will resume using the `adapter_model.bin` in your LoRA folder. You can resume from a past checkpoint by replacing this file using the contents of one of the checkpoint folders. Note that the learning rate and steps will be reset, and you may want to set the learning rate to the last reported rate in the console output.
|
||||
|
||||
LoRA training was contributed by [mcmonkey4eva](https://github.com/mcmonkey4eva) in PR [#570](https://github.com/oobabooga/text-generation-webui/pull/570).
|
||||
|
||||
#### Using the original alpaca-lora code
|
||||
|
||||
Kept here for reference. The Training tab has much more features than this method.
|
||||
|
||||
```
|
||||
conda activate textgen
|
||||
git clone https://github.com/tloen/alpaca-lora
|
||||
```
|
||||
|
||||
Edit those two lines in `alpaca-lora/finetune.py` to use your existing model folder instead of downloading everything from decapoda:
|
||||
|
||||
```
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"models/llama-7b",
|
||||
load_in_8bit=True,
|
||||
device_map="auto",
|
||||
)
|
||||
tokenizer = LlamaTokenizer.from_pretrained(
|
||||
"models/llama-7b", add_eos_token=True
|
||||
)
|
||||
```
|
||||
|
||||
Run the script with:
|
||||
|
||||
```
|
||||
python finetune.py
|
||||
```
|
||||
|
||||
It just works. It runs at 22.32s/it, with 1170 iterations in total, so about 7 hours and a half for training a LoRA. RTX 3090, 18153MiB VRAM used, drawing maximum power (350W, room heater mode).
|
||||
You can train your own LoRAs from the `Training` tab. See [Training LoRAs](Training-LoRAs.md) for details.
|
||||
|
90
extensions/api/blocking_api.py
Normal file
90
extensions/api/blocking_api.py
Normal file
@ -0,0 +1,90 @@
|
||||
import json
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == '/api/v1/model':
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
response = json.dumps({
|
||||
'result': shared.model_name
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
def do_POST(self):
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
||||
|
||||
if self.path == '/api/v1/generate':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
prompt = body['prompt']
|
||||
generate_params = build_parameters(body)
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings)
|
||||
|
||||
answer = ''
|
||||
for a in generator:
|
||||
if isinstance(a, str):
|
||||
answer = a
|
||||
else:
|
||||
answer = a[0]
|
||||
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'text': answer if shared.is_chat() else answer[len(prompt):]
|
||||
}]
|
||||
})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
elif self.path == '/api/v1/token-count':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
tokens = encode(body['prompt'])[0]
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'tokens': len(tokens)
|
||||
}]
|
||||
})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
|
||||
def _run_server(port: int, share: bool=False):
|
||||
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||
|
||||
server = ThreadingHTTPServer((address, port), Handler)
|
||||
|
||||
def on_start(public_url: str):
|
||||
print(f'Starting non-streaming server at public url {public_url}/api')
|
||||
|
||||
if share:
|
||||
try:
|
||||
try_start_cloudflared(port, max_attempts=3, on_start=on_start)
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
print(
|
||||
f'Starting API at http://{address}:{port}/api')
|
||||
|
||||
server.serve_forever()
|
||||
|
||||
|
||||
def start_server(port: int, share: bool = False):
|
||||
Thread(target=_run_server, args=[port, share], daemon=True).start()
|
@ -1 +1,2 @@
|
||||
flask_cloudflared==0.0.12
|
||||
websockets==11.0.2
|
@ -1,115 +1,10 @@
|
||||
import json
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
|
||||
import extensions.api.blocking_api as blocking_api
|
||||
import extensions.api.streaming_api as streaming_api
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
params = {
|
||||
'port': 5000,
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path == '/api/v1/model':
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
response = json.dumps({
|
||||
'result': shared.model_name
|
||||
})
|
||||
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
def do_POST(self):
|
||||
content_length = int(self.headers['Content-Length'])
|
||||
body = json.loads(self.rfile.read(content_length).decode('utf-8'))
|
||||
|
||||
if self.path == '/api/v1/generate':
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
prompt = body['prompt']
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_length', 200)),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical', 1)),
|
||||
'repetition_penalty': float(body.get('rep_pen', 1.1)),
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
|
||||
'num_beams': int(body.get('num_beams', 1)),
|
||||
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||
'length_penalty': float(body.get('length_penalty', 1)),
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
'seed': int(body.get('seed', -1)),
|
||||
'add_bos_token': int(body.get('add_bos_token', True)),
|
||||
'truncation_length': int(body.get('truncation_length', 2048)),
|
||||
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
||||
'custom_stopping_strings': '', # leave this blank
|
||||
'stopping_strings': body.get('stopping_strings', []),
|
||||
}
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
generator = generate_reply(prompt, generate_params, stopping_strings=stopping_strings)
|
||||
answer = ''
|
||||
for a in generator:
|
||||
if isinstance(a, str):
|
||||
answer = a
|
||||
else:
|
||||
answer = a[0]
|
||||
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'text': answer if shared.is_chat() else answer[len(prompt):]
|
||||
}]
|
||||
})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
elif self.path == '/api/v1/token-count':
|
||||
# Not compatible with KoboldAI api
|
||||
self.send_response(200)
|
||||
self.send_header('Content-Type', 'application/json')
|
||||
self.end_headers()
|
||||
|
||||
tokens = encode(body['prompt'])[0]
|
||||
response = json.dumps({
|
||||
'results': [{
|
||||
'tokens': len(tokens)
|
||||
}]
|
||||
})
|
||||
self.wfile.write(response.encode('utf-8'))
|
||||
|
||||
else:
|
||||
self.send_error(404)
|
||||
|
||||
|
||||
def run_server():
|
||||
server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port'])
|
||||
server = ThreadingHTTPServer(server_addr, Handler)
|
||||
if shared.args.share:
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
public_url = _run_cloudflared(params['port'], params['port'] + 1)
|
||||
print(f'Starting KoboldAI compatible api at {public_url}/api')
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
else:
|
||||
print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api')
|
||||
server.serve_forever()
|
||||
|
||||
BLOCKING_PORT = 5000
|
||||
STREAMING_PORT = 5005
|
||||
|
||||
def setup():
|
||||
Thread(target=run_server, daemon=True).start()
|
||||
blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
|
||||
streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)
|
||||
|
82
extensions/api/streaming_api.py
Normal file
82
extensions/api/streaming_api.py
Normal file
@ -0,0 +1,82 @@
|
||||
import json
|
||||
import asyncio
|
||||
from websockets.server import serve
|
||||
from threading import Thread
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
|
||||
PATH = '/api/v1/stream'
|
||||
|
||||
|
||||
async def _handle_connection(websocket, path):
|
||||
|
||||
if path != PATH:
|
||||
print(f'Streaming api: unknown path: {path}')
|
||||
return
|
||||
|
||||
async for message in websocket:
|
||||
message = json.loads(message)
|
||||
|
||||
prompt = message['prompt']
|
||||
generate_params = build_parameters(message)
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
|
||||
generator = generate_reply(
|
||||
prompt, generate_params, stopping_strings=stopping_strings)
|
||||
|
||||
# As we stream, only send the new bytes.
|
||||
skip_index = len(prompt) if not shared.is_chat() else 0
|
||||
message_num = 0
|
||||
|
||||
for a in generator:
|
||||
to_send = ''
|
||||
if isinstance(a, str):
|
||||
to_send = a[skip_index:]
|
||||
else:
|
||||
to_send = a[0][skip_index:]
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'text_stream',
|
||||
'message_num': message_num,
|
||||
'text': to_send
|
||||
}))
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
skip_index += len(to_send)
|
||||
message_num += 1
|
||||
|
||||
await websocket.send(json.dumps({
|
||||
'event': 'stream_end',
|
||||
'message_num': message_num
|
||||
}))
|
||||
|
||||
|
||||
async def _run(host: str, port: int):
|
||||
async with serve(_handle_connection, host, port):
|
||||
await asyncio.Future() # run forever
|
||||
|
||||
|
||||
def _run_server(port: int, share: bool = False):
|
||||
address = '0.0.0.0' if shared.args.listen else '127.0.0.1'
|
||||
|
||||
def on_start(public_url: str):
|
||||
public_url = public_url.replace('https://', 'wss://')
|
||||
print(f'Starting streaming server at public url {public_url}{PATH}')
|
||||
|
||||
if share:
|
||||
try:
|
||||
try_start_cloudflared(port, max_attempts=3, on_start=on_start)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
else:
|
||||
print(f'Starting streaming server at ws://{address}:{port}{PATH}')
|
||||
|
||||
asyncio.run(_run(host=address, port=port))
|
||||
|
||||
|
||||
def start_server(port: int, share: bool = False):
|
||||
Thread(target=_run_server, args=[port, share], daemon=True).start()
|
69
extensions/api/util.py
Normal file
69
extensions/api/util.py
Normal file
@ -0,0 +1,69 @@
|
||||
|
||||
from threading import Thread
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from modules.text_generation import encode
|
||||
|
||||
|
||||
def build_parameters(body):
|
||||
prompt = body['prompt']
|
||||
|
||||
prompt_lines = [k.strip() for k in prompt.split('\n')]
|
||||
max_context = body.get('max_context_length', 2048)
|
||||
while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context:
|
||||
prompt_lines.pop(0)
|
||||
|
||||
prompt = '\n'.join(prompt_lines)
|
||||
|
||||
generate_params = {
|
||||
'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))),
|
||||
'do_sample': bool(body.get('do_sample', True)),
|
||||
'temperature': float(body.get('temperature', 0.5)),
|
||||
'top_p': float(body.get('top_p', 1)),
|
||||
'typical_p': float(body.get('typical_p', body.get('typical', 1))),
|
||||
'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))),
|
||||
'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)),
|
||||
'top_k': int(body.get('top_k', 0)),
|
||||
'min_length': int(body.get('min_length', 0)),
|
||||
'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)),
|
||||
'num_beams': int(body.get('num_beams', 1)),
|
||||
'penalty_alpha': float(body.get('penalty_alpha', 0)),
|
||||
'length_penalty': float(body.get('length_penalty', 1)),
|
||||
'early_stopping': bool(body.get('early_stopping', False)),
|
||||
'seed': int(body.get('seed', -1)),
|
||||
'add_bos_token': int(body.get('add_bos_token', True)),
|
||||
'truncation_length': int(body.get('truncation_length', 2048)),
|
||||
'ban_eos_token': bool(body.get('ban_eos_token', False)),
|
||||
'skip_special_tokens': bool(body.get('skip_special_tokens', True)),
|
||||
'custom_stopping_strings': '', # leave this blank
|
||||
'stopping_strings': body.get('stopping_strings', []),
|
||||
}
|
||||
|
||||
return generate_params
|
||||
|
||||
|
||||
def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||
Thread(target=_start_cloudflared, args=[
|
||||
port, max_attempts, on_start], daemon=True).start()
|
||||
|
||||
|
||||
def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None):
|
||||
try:
|
||||
from flask_cloudflared import _run_cloudflared
|
||||
except ImportError:
|
||||
print('You should install flask_cloudflared manually')
|
||||
raise Exception(
|
||||
'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.')
|
||||
|
||||
for _ in range(max_attempts):
|
||||
try:
|
||||
public_url = _run_cloudflared(port, port + 1)
|
||||
|
||||
if on_start:
|
||||
on_start(public_url)
|
||||
|
||||
return
|
||||
except Exception:
|
||||
time.sleep(3)
|
||||
|
||||
raise Exception('Could not start cloudflared.')
|
49
extensions/llava/README.md
Normal file
49
extensions/llava/README.md
Normal file
@ -0,0 +1,49 @@
|
||||
# LLaVA
|
||||
|
||||
## Description
|
||||
Adds [LLaVA](https://github.com/haotian-liu/LLaVA) multimodality support to text-generation-webui.
|
||||
|
||||
https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4
|
||||
|
||||
## Usage
|
||||
To run this extension, download LLaVA weights, for example from [here](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g), and then start server.py with `--extensions llava` argument.
|
||||
|
||||
When in ui, go to instruct mode, and select LLaVA template, you also should add `"\n###"` to "Custom stopping strings" in parameters tab.
|
||||
|
||||
Do note, that each image takes up 258 tokens, so adjust max_new_tokens to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated.
|
||||
|
||||
To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `<image>` in your prompt.
|
||||
|
||||
Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. From initial testing, it seems as LLaVA considers the features in all images at the same time, so by default the extension skips previous images. If you want to include them anyway, just tick this checkbox.
|
||||
|
||||
## Extension config
|
||||
This extension uses following parameters (from settings.json):
|
||||
|Parameter|Description|
|
||||
|---------|-----------|
|
||||
|`llava-clip_bits`|Number of bits to load CLIP feature extractor in (either 32 or 16, default=32)|
|
||||
|`llava-clip_device`|Torch device to run the extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`llava-projector_bits`|Number of bits to load CLIP->LLaMA feature projector in (either 32 or 16, default=32)|
|
||||
|`llava-projector_bits`|Torch device to run the CLIP->LLaMA feature projector on, for example `cpu` or `cuda:0`, by default `cuda:0` if available|
|
||||
|`llava-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox|
|
||||
|
||||
## Technical description
|
||||
|
||||
### Original LLaVA
|
||||
The default LLaVA implementation uses modified `transformers` library, however this extension forgoes this requirement. The transformers are modified in LLaVA in such a way, that the entire LLaVA model gets loaded, and the inference now looks as follows:
|
||||
```
|
||||
images --> CLIP --> projector --> input embeddings for images --> |
|
||||
| --> LLaMA
|
||||
prompt -------------------------> input embeddings for text ----> |
|
||||
```
|
||||
The images are represented in the prompt by the following token IDs:
|
||||
- 32000 - `<im_patch>` - placeholder token for embeddings from projector
|
||||
- 32001 - `<im_start>` - token marking start of an image
|
||||
- 32002 - `<im_end>` - token marking end of an image
|
||||
|
||||
By default, image will be represented as `<im_start><im_patch>*256<im_end>`. The input embeddings for an image are converted with a single linear layer of the projector, then they are placed instead of `<im_patch>` tokens.
|
||||
The concatenated prompt then gets fed to fine-tuned LLaMA.
|
||||
|
||||
### In this extension
|
||||
|
||||
Using default transformers, they only load the LLaMA part of LLaVA, ignoring the added projector weights, and not loading CLIP. We then reconstruct the `images -> CLIP -> projector` pipeline ourselves, then concatenate the input embeddings, and feed it to LLaMA loaded by transformers. This allows us to use normal flow from webui to load this model, and just hijack the model input with additional features.
|
||||
Splitting it to 3 separate models, allows us to configure each of them, and to move them to different devices(for example we can run CLIP+projector on CPU and LLaMA on GPU). Also, it enables us to use 4-bit GPTQ quantization for LLaVA, massively cutting down the VRAM requirement (it should be possible to fit on 12GB of VRAM with full context size by moving CLIP and projector to CPU).
|
267
extensions/llava/script.py
Normal file
267
extensions/llava/script.py
Normal file
@ -0,0 +1,267 @@
|
||||
import base64
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModel
|
||||
|
||||
from modules import shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.text_generation import encode, get_max_prompt_length
|
||||
|
||||
params = {
|
||||
"add_all_images_to_prompt": False,
|
||||
# device to run CLIP on
|
||||
"clip_device": None,
|
||||
# bits to load clip in either 32 or 16 (it doesn't support 8-bit)
|
||||
"clip_bits": 32,
|
||||
# device to run projector on
|
||||
"projector_device": None,
|
||||
# projector bits, either 32 or 16
|
||||
"projector_bits": 32
|
||||
}
|
||||
|
||||
|
||||
# If 'state' is True, will hijack the next chat generation
|
||||
input_hijack = {
|
||||
'state': False,
|
||||
'value': ["", ""]
|
||||
}
|
||||
|
||||
|
||||
# initialized in ui, so that params are loaded from settings
|
||||
llava_embedder = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Token:
|
||||
token: str
|
||||
id: int
|
||||
|
||||
|
||||
class LLaVAEmbedder:
|
||||
IM_PATCH = Token("<im_patch>", 32000)
|
||||
IM_START = Token("<im_start>", 32001)
|
||||
IM_END = Token("<im_end>", 32002)
|
||||
CLIP_VIT_HUB_NAME = 'openai/clip-vit-large-patch14'
|
||||
PROJECTOR_HUB_NAME = 'liuhaotian/LLaVA-13b-pretrain-projector-v0'
|
||||
PROJECTOR_FILE = 'LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption.bin'
|
||||
|
||||
def __init__(self):
|
||||
self.clip_device = self._get_device("clip_device")
|
||||
self.clip_dtype = self._get_dtype("clip_bits")
|
||||
self.projector_device = self._get_device("projector_device")
|
||||
self.projector_dtype = self._get_dtype("projector_bits")
|
||||
self.image_processor, self.vision_tower, self.mm_projector = self._load_models()
|
||||
|
||||
def _get_device(self, setting_name):
|
||||
if params[setting_name] is None:
|
||||
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
return torch.device(params[setting_name])
|
||||
|
||||
def _get_dtype(self, setting_name):
|
||||
return torch.float32 if int(params[setting_name]) == 32 else torch.float16
|
||||
|
||||
def _load_models(self):
|
||||
start_ts = time.time()
|
||||
|
||||
print(f"LLaVA - Loading {LLaVAEmbedder.CLIP_VIT_HUB_NAME} as {self.clip_dtype} on {self.clip_device}...")
|
||||
image_processor = CLIPImageProcessor.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype)
|
||||
vision_tower = CLIPVisionModel.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype).to(self.clip_device)
|
||||
|
||||
print(f"LLaVA - Loading {LLaVAEmbedder.PROJECTOR_HUB_NAME} as {self.projector_dtype} on {self.projector_device}...")
|
||||
projector_path = hf_hub_download(LLaVAEmbedder.PROJECTOR_HUB_NAME, LLaVAEmbedder.PROJECTOR_FILE)
|
||||
mm_projector = torch.nn.Linear(1024, 5120)
|
||||
projector_data = torch.load(projector_path)
|
||||
mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False)
|
||||
mm_projector = mm_projector.to(self.projector_device)
|
||||
|
||||
print(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds")
|
||||
return image_processor, vision_tower, mm_projector
|
||||
|
||||
def _update_prompt(self, prompt, images):
|
||||
for _ in images:
|
||||
# replace the image token with the image patch token in the prompt (each occurrence)
|
||||
replace_token = LLaVAEmbedder.IM_PATCH.token * 256
|
||||
replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token
|
||||
prompt = re.sub(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', replace_token, prompt, 1)
|
||||
return prompt
|
||||
|
||||
def _extract_image_features(self, images):
|
||||
images = self.image_processor(images, return_tensors='pt')['pixel_values']
|
||||
images = images.to(self.clip_device, dtype=self.clip_dtype)
|
||||
|
||||
with torch.no_grad():
|
||||
image_forward_outs = self.vision_tower(images, output_hidden_states=True)
|
||||
select_hidden_state_layer = -2
|
||||
select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer]
|
||||
image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype)
|
||||
image_features = self.mm_projector(image_features)
|
||||
return image_features
|
||||
|
||||
def forward(self, prompt, images, state):
|
||||
prompt = self._update_prompt(prompt, images)
|
||||
input_ids = encode(prompt, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))[0]
|
||||
input_embeds = shared.model.model.embed_tokens(input_ids).to(self.projector_device)
|
||||
|
||||
if input_ids[0] == LLaVAEmbedder.IM_PATCH.id:
|
||||
# prompt got truncated in the middle of an image, remove the image data
|
||||
im_end = torch.where(input_ids == LLaVAEmbedder.IM_END.id)[0][0]
|
||||
input_ids = input_ids[im_end+1:]
|
||||
input_embeds = input_embeds[im_end+1:]
|
||||
leftover_images = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0].shape[0]
|
||||
print(f"LLaVA - WARNING: removed {len(images) - leftover_images} image(s) from prompt. The generation might be broken, try decreasing max_new_tokens")
|
||||
images = images[-leftover_images:]
|
||||
if len(images) == 0:
|
||||
return prompt, input_ids, input_embeds, 0
|
||||
|
||||
total_embedded = 0
|
||||
image_features = self._extract_image_features(images).to(self.projector_device)
|
||||
image_start_tokens = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0]
|
||||
|
||||
if not torch.any(input_ids == LLaVAEmbedder.IM_PATCH.id) or len(image_start_tokens) == 0:
|
||||
# multimodal LLM, but the current prompt is not multimodal/truncated
|
||||
return prompt, input_ids, input_embeds, total_embedded
|
||||
|
||||
cur_image_idx = 0
|
||||
if not params['add_all_images_to_prompt']:
|
||||
image_start_tokens = [image_start_tokens[-1]]
|
||||
cur_image_idx = -1
|
||||
|
||||
for image_start_token_pos in image_start_tokens:
|
||||
cur_image_features = image_features[cur_image_idx]
|
||||
num_patches = cur_image_features.shape[0]
|
||||
input_embeds = torch.cat((input_embeds[:image_start_token_pos+1], cur_image_features, input_embeds[image_start_token_pos + num_patches + 1:]), dim=0)
|
||||
cur_image_idx += 1
|
||||
total_embedded += 1
|
||||
|
||||
return prompt, input_ids, input_embeds, total_embedded
|
||||
|
||||
@staticmethod
|
||||
def len_in_tokens(text):
|
||||
images = re.findall(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', text)
|
||||
image_tokens = 0
|
||||
for _ in images:
|
||||
image_tokens += 258
|
||||
return len(encode(re.sub(r'<img src="data:image/jpeg;base64,[A-Za-z0-9+/=]+">', '', text))[0]) + image_tokens
|
||||
|
||||
|
||||
def add_chat_picture(picture, text, visible_text):
|
||||
# resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable)
|
||||
max_hw, min_hw = max(picture.size), min(picture.size)
|
||||
aspect_ratio = max_hw / min_hw
|
||||
shortest_edge = int(max(300 / aspect_ratio, 224))
|
||||
longest_edge = int(shortest_edge * aspect_ratio)
|
||||
w = shortest_edge if picture.width < picture.height else longest_edge
|
||||
h = shortest_edge if picture.width >= picture.height else longest_edge
|
||||
picture = picture.resize((w,h))
|
||||
|
||||
buffer = BytesIO()
|
||||
picture.save(buffer, format="JPEG")
|
||||
img_str = base64.b64encode(buffer.getvalue()).decode('utf-8')
|
||||
image = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||
|
||||
|
||||
if '<image>' in text:
|
||||
text = text.replace('<image>', image)
|
||||
else:
|
||||
text = text + '\n' + image
|
||||
|
||||
if visible_text == '' or visible_text is None:
|
||||
visible_text = text
|
||||
elif '<image>' in visible_text:
|
||||
visible_text = visible_text.replace('<image>', image)
|
||||
else:
|
||||
visible_text = visible_text + '\n' + image
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
||||
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False
|
||||
_continue = kwargs['_continue'] if '_continue' in kwargs else False
|
||||
also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False
|
||||
rows = [f"{state['context'].strip()}\n"]
|
||||
min_rows = 3
|
||||
|
||||
# Finding the maximum prompt size
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
|
||||
prefix1 = f"{state['name1']}: "
|
||||
prefix2 = f"{state['name2']}: "
|
||||
|
||||
i = len(shared.history['internal']) - 1
|
||||
while i >= 0 and LLaVAEmbedder.len_in_tokens(''.join(rows)) < max_length:
|
||||
if _continue and i == len(shared.history['internal']) - 1:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}")
|
||||
else:
|
||||
rows.insert(1, f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n")
|
||||
|
||||
string = shared.history['internal'][i][0]
|
||||
if string != '':
|
||||
rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n")
|
||||
|
||||
i -= 1
|
||||
|
||||
if impersonate:
|
||||
min_rows = 2
|
||||
rows.append(f"{prefix1}")
|
||||
elif not _continue:
|
||||
# Adding the user message
|
||||
if len(user_input) > 0:
|
||||
rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions("bot_prefix", f"{prefix2}"))
|
||||
|
||||
while len(rows) > min_rows and LLaVAEmbedder.len_in_tokens(''.join(rows)) >= max_length:
|
||||
rows.pop(1)
|
||||
prompt = ''.join(rows)
|
||||
|
||||
if also_return_rows:
|
||||
return prompt, rows
|
||||
else:
|
||||
return prompt
|
||||
|
||||
|
||||
def tokenizer_modifier(state, prompt, input_ids, input_embeds):
|
||||
global params
|
||||
start_ts = time.time()
|
||||
image_matches = re.finditer(r'<img src="data:image/jpeg;base64,([A-Za-z0-9+/=]+)">', prompt)
|
||||
images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches]
|
||||
|
||||
if len(images) == 0:
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state)
|
||||
print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s')
|
||||
return prompt, input_ids.unsqueeze(0).to(shared.model.device), input_embeds.unsqueeze(0).to(shared.model.device)
|
||||
|
||||
|
||||
def ui():
|
||||
global llava_embedder
|
||||
llava_embedder = LLaVAEmbedder()
|
||||
with gr.Column():
|
||||
picture_select = gr.Image(label='Send a picture', type='pil')
|
||||
# I found that it doesn't deal super well with multiple images, and demo ui had a bug where it included only the last image anyway
|
||||
single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one')
|
||||
# Prepare the input hijack
|
||||
picture_select.upload(
|
||||
lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}),
|
||||
[picture_select],
|
||||
None
|
||||
)
|
||||
picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None)
|
||||
single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None)
|
||||
shared.gradio['Generate'].click(lambda: None, None, picture_select)
|
||||
shared.gradio['textbox'].submit(lambda: None, None, picture_select)
|
@ -48,3 +48,7 @@ llama-[0-9]*b-4bit$:
|
||||
.*chatglm:
|
||||
mode: 'instruct'
|
||||
instruction_template: 'ChatGLM'
|
||||
.*llava:
|
||||
mode: 'instruct'
|
||||
model_type: 'llama'
|
||||
instruction_template: 'LLaVA'
|
||||
|
@ -135,7 +135,7 @@ def load_quantized(model_name):
|
||||
# Find the model type
|
||||
if not shared.args.model_type:
|
||||
name = model_name.lower()
|
||||
if any((k in name for k in ['llama', 'alpaca', 'vicuna'])):
|
||||
if any((k in name for k in ['llama', 'alpaca', 'vicuna', 'llava'])):
|
||||
model_type = 'llama'
|
||||
elif any((k in name for k in ['opt-', 'galactica'])):
|
||||
model_type = 'opt'
|
||||
|
@ -1,52 +0,0 @@
|
||||
import json
|
||||
|
||||
import gradio as gr
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
# set this to True to rediscover the fn_index using the browser DevTools
|
||||
VISIBLE = False
|
||||
|
||||
|
||||
def generate_reply_wrapper(string):
|
||||
|
||||
# Provide defaults so as to not break the API on the client side when new parameters are added
|
||||
generate_params = {
|
||||
'max_new_tokens': 200,
|
||||
'do_sample': True,
|
||||
'temperature': 0.5,
|
||||
'top_p': 1,
|
||||
'typical_p': 1,
|
||||
'repetition_penalty': 1.1,
|
||||
'encoder_repetition_penalty': 1,
|
||||
'top_k': 0,
|
||||
'min_length': 0,
|
||||
'no_repeat_ngram_size': 0,
|
||||
'num_beams': 1,
|
||||
'penalty_alpha': 0,
|
||||
'length_penalty': 1,
|
||||
'early_stopping': False,
|
||||
'seed': -1,
|
||||
'add_bos_token': True,
|
||||
'custom_stopping_strings': '',
|
||||
'truncation_length': 2048,
|
||||
'ban_eos_token': False,
|
||||
'skip_special_tokens': True,
|
||||
'stopping_strings': [],
|
||||
}
|
||||
params = json.loads(string)
|
||||
generate_params.update(params[1])
|
||||
stopping_strings = generate_params.pop('stopping_strings')
|
||||
for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings):
|
||||
yield i
|
||||
|
||||
|
||||
def create_apis():
|
||||
t1 = gr.Textbox(visible=VISIBLE)
|
||||
t2 = gr.Textbox(visible=VISIBLE)
|
||||
dummy = gr.Button(visible=VISIBLE)
|
||||
|
||||
input_params = [t1]
|
||||
output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']]
|
||||
dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen')
|
111
modules/chat.py
111
modules/chat.py
@ -10,7 +10,6 @@ from pathlib import Path
|
||||
import yaml
|
||||
from PIL import Image
|
||||
|
||||
import modules.extensions as extensions_module
|
||||
import modules.shared as shared
|
||||
from modules.extensions import apply_extensions
|
||||
from modules.html_generator import chat_html_wrapper, make_thumbnail
|
||||
@ -30,8 +29,8 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
chat_prompt_size = state['chat_prompt_size']
|
||||
if shared.soft_prompt:
|
||||
chat_prompt_size -= shared.soft_prompt_tensor.shape[1]
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
|
||||
max_length = min(get_max_prompt_length(state), chat_prompt_size)
|
||||
if is_instruct:
|
||||
prefix1 = f"{state['name1']}\n"
|
||||
prefix2 = f"{state['name2']}\n"
|
||||
@ -57,19 +56,18 @@ def generate_chat_prompt(user_input, state, **kwargs):
|
||||
min_rows = 2
|
||||
rows.append(f"{prefix1.strip() if not is_instruct else prefix1}")
|
||||
elif not _continue:
|
||||
|
||||
# Adding the user message
|
||||
if len(user_input) > 0:
|
||||
this_prefix1 = prefix1.replace('<|round|>', f'{len(shared.history["internal"])}') # for ChatGLM
|
||||
rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n")
|
||||
|
||||
# Adding the Character prefix
|
||||
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix"))
|
||||
rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}"))
|
||||
|
||||
while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length:
|
||||
rows.pop(1)
|
||||
prompt = ''.join(rows)
|
||||
|
||||
prompt = ''.join(rows)
|
||||
if also_return_rows:
|
||||
return prompt, rows
|
||||
else:
|
||||
@ -81,6 +79,7 @@ def get_stopping_strings(state):
|
||||
stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"]
|
||||
else:
|
||||
stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"]
|
||||
|
||||
stopping_strings += ast.literal_eval(f"[{state['custom_stopping_strings']}]")
|
||||
return stopping_strings
|
||||
|
||||
@ -111,13 +110,13 @@ def extract_message_from_reply(reply, state):
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
break
|
||||
|
||||
return reply, next_character_found
|
||||
|
||||
|
||||
def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
print("No model is loaded! Select one in the Model tab.")
|
||||
yield shared.history['visible']
|
||||
@ -125,35 +124,36 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
|
||||
# Defining some variables
|
||||
cumulative_reply = ''
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None
|
||||
just_started = True
|
||||
visible_text = custom_generate_chat_prompt = None
|
||||
visible_text = None
|
||||
eos_token = '\n' if state['stop_at_newline'] else None
|
||||
stopping_strings = get_stopping_strings(state)
|
||||
|
||||
# Check if any extension wants to hijack this function call
|
||||
for extension, _ in extensions_module.iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
|
||||
# Preparing the input
|
||||
if not any((regenerate, _continue)):
|
||||
text, visible_text = apply_extensions('input_hijack', text, visible_text)
|
||||
if visible_text is None:
|
||||
visible_text = text
|
||||
if not _continue:
|
||||
text = apply_extensions(text, "input")
|
||||
|
||||
text = apply_extensions('input', text)
|
||||
# *Is typing...*
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
else:
|
||||
text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0]
|
||||
if regenerate:
|
||||
shared.history['visible'].pop()
|
||||
shared.history['internal'].pop()
|
||||
# *Is typing...*
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
elif _continue:
|
||||
last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]]
|
||||
yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']]
|
||||
|
||||
# Generating the prompt
|
||||
kwargs = {'_continue': _continue}
|
||||
if custom_generate_chat_prompt is None:
|
||||
prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs)
|
||||
if prompt is None:
|
||||
prompt = generate_chat_prompt(text, state, **kwargs)
|
||||
else:
|
||||
prompt = custom_generate_chat_prompt(text, state, **kwargs)
|
||||
|
||||
# Yield *Is typing...*
|
||||
if not any((regenerate, _continue)):
|
||||
yield shared.history['visible'] + [[visible_text, shared.processing_message]]
|
||||
|
||||
# Generate
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
@ -164,26 +164,26 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
# Extracting the reply
|
||||
reply, next_character_found = extract_message_from_reply(reply, state)
|
||||
visible_reply = re.sub("(<USER>|<user>|{{user}})", state['name1'], reply)
|
||||
visible_reply = apply_extensions(visible_reply, "output")
|
||||
visible_reply = apply_extensions("output", visible_reply)
|
||||
if _continue:
|
||||
sep = ' ' if last_reply[0][-1] not in [' ', '\n'] else ''
|
||||
reply = last_reply[0] + sep + reply
|
||||
sep = ' ' if last_reply[1][-1] not in [' ', '\n'] else ''
|
||||
visible_reply = last_reply[1] + sep + visible_reply
|
||||
|
||||
# We need this global variable to handle the Stop event,
|
||||
# otherwise gradio gets confused
|
||||
if shared.stop_everything:
|
||||
return shared.history['visible']
|
||||
|
||||
if just_started:
|
||||
just_started = False
|
||||
if not _continue:
|
||||
shared.history['internal'].append(['', ''])
|
||||
shared.history['visible'].append(['', ''])
|
||||
|
||||
if _continue:
|
||||
sep = list(map(lambda x: ' ' if len(x) > 0 and x[-1] != ' ' else '', last_reply))
|
||||
shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}']
|
||||
shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}']
|
||||
else:
|
||||
shared.history['internal'][-1] = [text, reply]
|
||||
shared.history['visible'][-1] = [visible_text, visible_reply]
|
||||
if not shared.args.no_stream:
|
||||
yield shared.history['visible']
|
||||
if next_character_found:
|
||||
break
|
||||
@ -195,7 +195,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False):
|
||||
|
||||
|
||||
def impersonate_wrapper(text, state):
|
||||
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
print("No model is loaded! Select one in the Model tab.")
|
||||
yield ''
|
||||
@ -209,7 +208,6 @@ def impersonate_wrapper(text, state):
|
||||
|
||||
# Yield *Is typing...*
|
||||
yield shared.processing_message
|
||||
|
||||
for i in range(state['chat_generation_attempts']):
|
||||
reply = None
|
||||
for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings):
|
||||
@ -234,23 +232,16 @@ def regenerate_wrapper(text, state):
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
else:
|
||||
last_visible = shared.history['visible'].pop()
|
||||
last_internal = shared.history['internal'].pop()
|
||||
# Yield '*Is typing...*'
|
||||
yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode'])
|
||||
for history in chatbot_wrapper(last_internal[0], state, regenerate=True):
|
||||
shared.history['visible'][-1] = [last_visible[0], history[-1][1]]
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
for history in chatbot_wrapper('', state, regenerate=True):
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
|
||||
|
||||
|
||||
def continue_wrapper(text, state):
|
||||
if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0:
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
else:
|
||||
# Yield ' ...'
|
||||
yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode'])
|
||||
for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True):
|
||||
yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode'])
|
||||
for history in chatbot_wrapper('', state, _continue=True):
|
||||
yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode'])
|
||||
|
||||
|
||||
def remove_last_message(name1, name2, mode):
|
||||
@ -273,14 +264,14 @@ def send_last_reply_to_input():
|
||||
def replace_last_reply(text, name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0:
|
||||
shared.history['visible'][-1][1] = text
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
shared.history['internal'][-1][1] = apply_extensions("input", text)
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def send_dummy_message(text, name1, name2, mode):
|
||||
shared.history['visible'].append([text, ''])
|
||||
shared.history['internal'].append([apply_extensions(text, "input"), ''])
|
||||
shared.history['internal'].append([apply_extensions("input", text), ''])
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
@ -288,8 +279,9 @@ def send_dummy_reply(text, name1, name2, mode):
|
||||
if len(shared.history['visible']) > 0 and not shared.history['visible'][-1][1] == '':
|
||||
shared.history['visible'].append(['', ''])
|
||||
shared.history['internal'].append(['', ''])
|
||||
|
||||
shared.history['visible'][-1][1] = text
|
||||
shared.history['internal'][-1][1] = apply_extensions(text, "input")
|
||||
shared.history['internal'][-1][1] = apply_extensions("input", text)
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
@ -303,11 +295,10 @@ def clear_chat_log(name1, name2, greeting, mode):
|
||||
|
||||
if greeting != '':
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
|
||||
|
||||
# Save cleared logs
|
||||
save_history(mode)
|
||||
|
||||
return chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
@ -328,8 +319,8 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
|
||||
for i in range(len(idx) - 1):
|
||||
messages.append(dialogue[idx[i]:idx[i + 1]].strip())
|
||||
messages.append(dialogue[idx[-1]:].strip())
|
||||
|
||||
messages.append(dialogue[idx[-1]:].strip())
|
||||
entry = ['', '']
|
||||
for i in messages:
|
||||
if i.startswith(f'{name1}:'):
|
||||
@ -338,6 +329,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
entry[1] = i[len(f'{name2}:'):].strip()
|
||||
if not (len(entry[0]) == 0 and len(entry[1]) == 0):
|
||||
history.append(entry)
|
||||
|
||||
entry = ['', '']
|
||||
|
||||
print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='')
|
||||
@ -346,6 +338,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode):
|
||||
print("\n")
|
||||
for line in column.strip().split('\n'):
|
||||
print("| " + line + "\n")
|
||||
|
||||
print("|\n")
|
||||
print("------------------------------")
|
||||
|
||||
@ -358,14 +351,17 @@ def save_history(mode, timestamp=False):
|
||||
if mode == 'instruct':
|
||||
if not timestamp:
|
||||
return
|
||||
|
||||
fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
else:
|
||||
if timestamp:
|
||||
fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json"
|
||||
else:
|
||||
fname = f"{shared.character}_persistent.json"
|
||||
|
||||
if not Path('logs').exists():
|
||||
Path('logs').mkdir()
|
||||
|
||||
with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f:
|
||||
f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2))
|
||||
|
||||
@ -396,8 +392,10 @@ def build_pygmalion_style_context(data):
|
||||
context = ""
|
||||
if 'char_persona' in data and data['char_persona'] != '':
|
||||
context += f"{data['char_name']}'s Persona: {data['char_persona']}\n"
|
||||
|
||||
if 'world_scenario' in data and data['world_scenario'] != '':
|
||||
context += f"Scenario: {data['world_scenario']}\n"
|
||||
|
||||
context = f"{context.strip()}\n<START>\n"
|
||||
return context
|
||||
|
||||
@ -412,6 +410,7 @@ def generate_pfp_cache(character):
|
||||
img = make_thumbnail(Image.open(path))
|
||||
img.save(Path('cache/pfp_character.png'), format='PNG')
|
||||
return img
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@ -475,7 +474,7 @@ def load_character(character, name1, name2, mode):
|
||||
# Insert greeting if it exists
|
||||
if greeting != "":
|
||||
shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]]
|
||||
shared.history['visible'] += [['', apply_extensions(greeting, "output")]]
|
||||
shared.history['visible'] += [['', apply_extensions("output", greeting)]]
|
||||
|
||||
# Create .json log files since they don't already exist
|
||||
save_history(mode)
|
||||
@ -483,10 +482,6 @@ def load_character(character, name1, name2, mode):
|
||||
return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode)
|
||||
|
||||
|
||||
def load_default_history(name1, name2):
|
||||
load_character("None", name1, name2, "chat")
|
||||
|
||||
|
||||
def upload_character(json_file, img, tavern=False):
|
||||
json_file = json_file if type(json_file) == str else json_file.decode('utf-8')
|
||||
data = json.loads(json_file)
|
||||
@ -495,13 +490,17 @@ def upload_character(json_file, img, tavern=False):
|
||||
while Path(f'characters/{outfile_name}.json').exists():
|
||||
outfile_name = f'{data["char_name"]}_{i:03d}'
|
||||
i += 1
|
||||
|
||||
if tavern:
|
||||
outfile_name = f'TavernAI-{outfile_name}'
|
||||
|
||||
with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f:
|
||||
f.write(json_file)
|
||||
|
||||
if img is not None:
|
||||
img = Image.open(io.BytesIO(img))
|
||||
img.save(Path(f'characters/{outfile_name}.png'))
|
||||
|
||||
print(f'New character saved to "characters/{outfile_name}.json".')
|
||||
return outfile_name
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
import traceback
|
||||
from functools import partial
|
||||
|
||||
import gradio as gr
|
||||
|
||||
@ -10,21 +11,39 @@ available_extensions = []
|
||||
setup_called = set()
|
||||
|
||||
|
||||
def apply_settings(extension, name):
|
||||
if not hasattr(extension, 'params'):
|
||||
return
|
||||
|
||||
for param in extension.params:
|
||||
_id = f"{name}-{param}"
|
||||
if _id not in shared.settings:
|
||||
continue
|
||||
|
||||
extension.params[param] = shared.settings[_id]
|
||||
|
||||
|
||||
def load_extensions():
|
||||
global state, setup_called
|
||||
for i, name in enumerate(shared.args.extensions):
|
||||
if name in available_extensions:
|
||||
if name != 'api':
|
||||
print(f'Loading the extension "{name}"... ', end='')
|
||||
try:
|
||||
exec(f"import extensions.{name}.script")
|
||||
extension = getattr(extensions, name).script
|
||||
apply_settings(extension, name)
|
||||
if extension not in setup_called and hasattr(extension, "setup"):
|
||||
setup_called.add(extension)
|
||||
extension.setup()
|
||||
|
||||
state[name] = [True, i]
|
||||
if name != 'api':
|
||||
print('Ok.')
|
||||
except:
|
||||
if name != 'api':
|
||||
print('Fail.')
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
@ -36,32 +55,74 @@ def iterator():
|
||||
|
||||
|
||||
# Extension functions that map string -> string
|
||||
def apply_extensions(text, typ):
|
||||
def _apply_string_extensions(function_name, text):
|
||||
for extension, _ in iterator():
|
||||
if typ == "input" and hasattr(extension, "input_modifier"):
|
||||
text = extension.input_modifier(text)
|
||||
elif typ == "output" and hasattr(extension, "output_modifier"):
|
||||
text = extension.output_modifier(text)
|
||||
elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"):
|
||||
text = extension.bot_prefix_modifier(text)
|
||||
if hasattr(extension, function_name):
|
||||
text = getattr(extension, function_name)(text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
# Input hijack of extensions
|
||||
def _apply_input_hijack(text, visible_text):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'input_hijack') and extension.input_hijack['state']:
|
||||
extension.input_hijack['state'] = False
|
||||
if callable(extension.input_hijack['value']):
|
||||
text, visible_text = extension.input_hijack['value'](text, visible_text)
|
||||
else:
|
||||
text, visible_text = extension.input_hijack['value']
|
||||
|
||||
return text, visible_text
|
||||
|
||||
|
||||
# custom_generate_chat_prompt handling
|
||||
def _apply_custom_generate_chat_prompt(text, state, **kwargs):
|
||||
custom_generate_chat_prompt = None
|
||||
for extension, _ in iterator():
|
||||
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
|
||||
custom_generate_chat_prompt = extension.custom_generate_chat_prompt
|
||||
|
||||
if custom_generate_chat_prompt is not None:
|
||||
return custom_generate_chat_prompt(text, state, **kwargs)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, function_name):
|
||||
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
|
||||
|
||||
return prompt, input_ids, input_embeds
|
||||
|
||||
|
||||
EXTENSION_MAP = {
|
||||
"input": partial(_apply_string_extensions, "input_modifier"),
|
||||
"output": partial(_apply_string_extensions, "output_modifier"),
|
||||
"bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"),
|
||||
"tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"),
|
||||
"input_hijack": _apply_input_hijack,
|
||||
"custom_generate_chat_prompt": _apply_custom_generate_chat_prompt
|
||||
}
|
||||
|
||||
|
||||
def apply_extensions(typ, *args, **kwargs):
|
||||
if typ not in EXTENSION_MAP:
|
||||
raise ValueError(f"Invalid extension type {typ}")
|
||||
|
||||
return EXTENSION_MAP[typ](*args, **kwargs)
|
||||
|
||||
|
||||
def create_extensions_block():
|
||||
global setup_called
|
||||
|
||||
# Updating the default values
|
||||
for extension, name in iterator():
|
||||
if hasattr(extension, 'params'):
|
||||
for param in extension.params:
|
||||
_id = f"{name}-{param}"
|
||||
if _id in shared.settings:
|
||||
extension.params[param] = shared.settings[_id]
|
||||
|
||||
should_display_ui = False
|
||||
for extension, name in iterator():
|
||||
if hasattr(extension, "ui"):
|
||||
should_display_ui = True
|
||||
break
|
||||
|
||||
# Creating the extension ui elements
|
||||
if should_display_ui:
|
||||
|
@ -24,7 +24,8 @@ class LlamaCppModel:
|
||||
'model_path': str(path),
|
||||
'n_ctx': 2048,
|
||||
'seed': 0,
|
||||
'n_threads': shared.args.threads or None
|
||||
'n_threads': shared.args.threads or None,
|
||||
'n_batch': shared.args.n_batch
|
||||
}
|
||||
self.model = Llama(**params)
|
||||
self.model.set_cache(LlamaCache)
|
||||
|
@ -50,6 +50,8 @@ def find_model_type(model_name):
|
||||
return 'chatglm'
|
||||
elif 'galactica' in model_name:
|
||||
return 'galactica'
|
||||
elif 'llava' in model_name:
|
||||
return 'llava'
|
||||
elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])):
|
||||
return 'gpt4chan'
|
||||
else:
|
||||
@ -217,6 +219,7 @@ def load_model(model_name):
|
||||
tokenizer = None
|
||||
|
||||
# Try to load an universal LLaMA tokenizer
|
||||
if shared.model_type != 'llava':
|
||||
for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]:
|
||||
if p.exists():
|
||||
print(f"Loading the universal LLaMA tokenizer from {p}...")
|
||||
|
@ -20,6 +20,9 @@ processing_message = '*Is typing...*'
|
||||
# UI elements (buttons, sliders, HTML, etc)
|
||||
gradio = {}
|
||||
|
||||
# For keeping the values of UI elements on page reload
|
||||
persistent_interface_state = {}
|
||||
|
||||
# Generation input parameters
|
||||
input_params = []
|
||||
|
||||
@ -31,6 +34,7 @@ settings = {
|
||||
'max_new_tokens_min': 1,
|
||||
'max_new_tokens_max': 2000,
|
||||
'seed': -1,
|
||||
'character': 'None',
|
||||
'name1': 'You',
|
||||
'name2': 'Assistant',
|
||||
'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.',
|
||||
@ -56,7 +60,7 @@ settings = {
|
||||
'chat_default_extensions': ["gallery"],
|
||||
'presets': {
|
||||
'default': 'Default',
|
||||
'.*(alpaca|llama)': "LLaMA-Precise",
|
||||
'.*(alpaca|llama|llava)': "LLaMA-Precise",
|
||||
'.*pygmalion': 'NovelAI-Storywriter',
|
||||
'.*RWKV': 'Naive',
|
||||
},
|
||||
@ -90,6 +94,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma
|
||||
parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.')
|
||||
parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.')
|
||||
parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.')
|
||||
parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.')
|
||||
parser.add_argument('--model', type=str, help='Name of the model to load by default.')
|
||||
parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.')
|
||||
parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
|
||||
@ -116,6 +121,7 @@ parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_
|
||||
|
||||
# llama.cpp
|
||||
parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.')
|
||||
parser.add_argument('--n_batch', type=int, default=8, help='Processing batch size for llama.cpp.')
|
||||
|
||||
# GPTQ
|
||||
parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.')
|
||||
@ -150,6 +156,11 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T
|
||||
parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
|
||||
parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
|
||||
|
||||
# API
|
||||
parser.add_argument('--api', action='store_true', help='Enable the API extension.')
|
||||
parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.')
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
args_defaults = parser.parse_args([])
|
||||
|
||||
@ -171,6 +182,13 @@ if args.trust_remote_code:
|
||||
if args.share:
|
||||
print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n")
|
||||
|
||||
# Activating the API extension
|
||||
if args.api or args.public_api:
|
||||
if args.extensions is None:
|
||||
args.extensions = ['api']
|
||||
elif 'api' not in args.extensions:
|
||||
args.extensions.append('api')
|
||||
|
||||
|
||||
def is_chat():
|
||||
return args.chat
|
||||
|
@ -113,9 +113,11 @@ def set_manual_seed(seed):
|
||||
seed = int(seed)
|
||||
if seed == -1:
|
||||
seed = random.randint(1, 2**31)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
return seed
|
||||
|
||||
|
||||
@ -123,8 +125,41 @@ def stop_everything_event():
|
||||
shared.stop_everything = True
|
||||
|
||||
|
||||
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
def get_generate_params(state):
|
||||
generate_params = {}
|
||||
|
||||
# Models that are not on transformers
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
generate_params['token_count'] = state['max_new_tokens']
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = state[k]
|
||||
else:
|
||||
# FlexGen
|
||||
if shared.args.flexgen:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if not shared.args.no_stream:
|
||||
generate_params['max_new_tokens'] = 8
|
||||
|
||||
# transformers
|
||||
else:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||
generate_params[k] = state[k]
|
||||
|
||||
if state['ban_eos_token']:
|
||||
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
|
||||
|
||||
if shared.args.no_cache:
|
||||
generate_params.update({'use_cache': False})
|
||||
|
||||
if shared.args.deepspeed:
|
||||
generate_params.update({'synced_gpus': True})
|
||||
|
||||
return generate_params
|
||||
|
||||
|
||||
def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
if shared.model_name == 'None' or shared.model is None:
|
||||
print("No model is loaded! Select one in the Model tab.")
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
@ -133,40 +168,37 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
clear_torch_cache()
|
||||
seed = set_manual_seed(state['seed'])
|
||||
shared.stop_everything = False
|
||||
generate_params = {}
|
||||
generate_params = get_generate_params(state)
|
||||
t0 = time.time()
|
||||
|
||||
# Preparing the input
|
||||
original_question = question
|
||||
if not shared.is_chat():
|
||||
question = apply_extensions(question, 'input')
|
||||
question = apply_extensions('input', question)
|
||||
|
||||
# These models are not part of Hugging Face, so we handle them
|
||||
# separately and terminate the function call earlier
|
||||
# If the model is not on transformers, handle it separately and end this
|
||||
# function call earlier.
|
||||
if shared.model_type in ['rwkv', 'llamacpp']:
|
||||
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{question}\n--------------------\n')
|
||||
|
||||
for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
|
||||
generate_params[k] = state[k]
|
||||
generate_params['token_count'] = state['max_new_tokens']
|
||||
try:
|
||||
if shared.args.no_stream:
|
||||
reply = shared.model.generate(context=question, **generate_params)
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
else:
|
||||
if not shared.is_chat():
|
||||
yield formatted_outputs(question, shared.model_name)
|
||||
|
||||
# RWKV has proper streaming, which is very nice.
|
||||
# No need to generate 8 tokens at a time.
|
||||
for reply in shared.model.generate_with_streaming(context=question, **generate_params):
|
||||
output = original_question + reply
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
except Exception:
|
||||
@ -178,19 +210,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
|
||||
return
|
||||
|
||||
# Encode the input
|
||||
input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
|
||||
original_input_ids = input_ids
|
||||
output = input_ids[0]
|
||||
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||
if shared.args.verbose:
|
||||
print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')
|
||||
|
||||
cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
|
||||
# Find the eos tokens
|
||||
eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
|
||||
if eos_token is not None:
|
||||
eos_token_ids.append(int(encode(eos_token)[0][-1]))
|
||||
|
||||
# Handling the stopping strings
|
||||
# Create the StoppingCriteriaList with the stopping strings
|
||||
stopping_criteria_list = transformers.StoppingCriteriaList()
|
||||
for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
|
||||
if type(st) is list and len(st) > 0:
|
||||
@ -198,30 +230,26 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
|
||||
break
|
||||
|
||||
if not shared.args.flexgen:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
|
||||
generate_params[k] = state[k]
|
||||
# Update generate_params with the eos token and the stopping strings
|
||||
if shared.args.flexgen:
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
else:
|
||||
generate_params['eos_token_id'] = eos_token_ids
|
||||
generate_params['stopping_criteria'] = stopping_criteria_list
|
||||
if state['ban_eos_token']:
|
||||
generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
|
||||
else:
|
||||
for k in ['max_new_tokens', 'do_sample', 'temperature']:
|
||||
generate_params[k] = state[k]
|
||||
generate_params['stop'] = eos_token_ids[-1]
|
||||
if not shared.args.no_stream:
|
||||
generate_params['max_new_tokens'] = 8
|
||||
|
||||
if shared.args.no_cache:
|
||||
generate_params.update({'use_cache': False})
|
||||
if shared.args.deepspeed:
|
||||
generate_params.update({'synced_gpus': True})
|
||||
# Add the encoded tokens to generate_params
|
||||
if shared.soft_prompt:
|
||||
inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
|
||||
question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
generate_params.update({'inputs': filler_input_ids})
|
||||
else:
|
||||
question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None)
|
||||
original_input_ids = input_ids
|
||||
generate_params.update({'inputs': input_ids})
|
||||
if inputs_embeds is not None:
|
||||
generate_params.update({'inputs_embeds': inputs_embeds})
|
||||
|
||||
try:
|
||||
# Generate the entire reply at once.
|
||||
@ -237,7 +265,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
yield formatted_outputs(reply, shared.model_name)
|
||||
|
||||
@ -265,7 +293,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
new_tokens = len(output) - len(input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if output[-1] in eos_token_ids:
|
||||
break
|
||||
@ -285,7 +313,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
|
||||
new_tokens = len(output) - len(original_input_ids[0])
|
||||
reply = decode(output[-new_tokens:], state['skip_special_tokens'])
|
||||
if not shared.is_chat():
|
||||
reply = original_question + apply_extensions(reply, 'output')
|
||||
reply = original_question + apply_extensions('output', reply)
|
||||
|
||||
if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)):
|
||||
break
|
||||
|
@ -18,14 +18,14 @@ from modules.evaluate import calculate_perplexity, generate_markdown_table, save
|
||||
from server import get_available_loras, get_available_models
|
||||
|
||||
# This mapping is from a very recent commit, not yet released.
|
||||
# If not available, default to a backup map for the 3 safe model types.
|
||||
# If not available, default to a backup map for some common model types.
|
||||
try:
|
||||
from peft.utils.other import \
|
||||
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \
|
||||
model_to_lora_modules
|
||||
except:
|
||||
standard_modules = ["q_proj", "v_proj"]
|
||||
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules}
|
||||
model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]}
|
||||
|
||||
WANT_INTERRUPT = False
|
||||
|
||||
@ -35,7 +35,8 @@ PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size",
|
||||
MODEL_CLASSES = {
|
||||
"LlamaForCausalLM": "llama",
|
||||
"OPTForCausalLM": "opt",
|
||||
"GPTJForCausalLM": "gptj"
|
||||
"GPTJForCausalLM": "gptj",
|
||||
"GPTNeoXForCausalLM": "gpt_neox"
|
||||
}
|
||||
|
||||
|
||||
@ -45,6 +46,8 @@ def get_datasets(path: str, ext: str):
|
||||
|
||||
def create_train_interface():
|
||||
with gr.Tab('Train LoRA', elem_id='lora-train-tab'):
|
||||
gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)")
|
||||
|
||||
with gr.Row():
|
||||
lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file')
|
||||
always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).')
|
||||
@ -215,11 +218,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch
|
||||
else:
|
||||
model_id = "llama"
|
||||
if model_type == "PeftModelForCausalLM":
|
||||
if len(shared.args.lora_names) > 0:
|
||||
yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.")
|
||||
else:
|
||||
yield "LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. (Found model type: {model_type})")
|
||||
yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.")
|
||||
else:
|
||||
yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*"
|
||||
print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. (Found model type: {model_type})")
|
||||
time.sleep(5)
|
||||
|
||||
if shared.args.wbits > 0 and not shared.args.monkey_patch:
|
||||
|
@ -33,9 +33,10 @@ def list_model_elements():
|
||||
|
||||
|
||||
def list_interface_input_elements(chat=False):
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens']
|
||||
elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu']
|
||||
if chat:
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template']
|
||||
elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu']
|
||||
|
||||
elements += list_model_elements()
|
||||
return elements
|
||||
|
||||
@ -44,11 +45,26 @@ def gather_interface_values(*args):
|
||||
output = {}
|
||||
for i, element in enumerate(shared.input_elements):
|
||||
output[element] = args[i]
|
||||
|
||||
shared.persistent_interface_state = output
|
||||
return output
|
||||
|
||||
|
||||
def apply_interface_values(state):
|
||||
return [state[i] for i in list_interface_input_elements(chat=shared.is_chat())]
|
||||
def apply_interface_values(state, use_persistent=False):
|
||||
if use_persistent:
|
||||
state = shared.persistent_interface_state
|
||||
|
||||
elements = list_interface_input_elements(chat=shared.is_chat())
|
||||
if len(state) == 0:
|
||||
return [gr.update() for k in elements] # Dummy, do nothing
|
||||
else:
|
||||
if use_persistent and 'mode' in state:
|
||||
if state['mode'] == 'instruct':
|
||||
return [state[k] if (k not in ['character_menu'] and k in state) else gr.update() for k in elements]
|
||||
else:
|
||||
return [state[k] if (k not in ['instruction_template'] and k in state) else gr.update() for k in elements]
|
||||
else:
|
||||
return [state[k] if k in state else gr.update() for k in elements]
|
||||
|
||||
|
||||
class ToolButton(gr.Button, gr.components.FormComponent):
|
||||
|
@ -16,5 +16,5 @@ tqdm
|
||||
git+https://github.com/huggingface/peft
|
||||
transformers==4.28.1
|
||||
bitsandbytes==0.38.1; platform_system != "Windows"
|
||||
llama-cpp-python==0.1.34; platform_system != "Windows"
|
||||
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||
llama-cpp-python==0.1.36; platform_system != "Windows"
|
||||
https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.36/llama_cpp_python-0.1.36-cp310-cp310-win_amd64.whl; platform_system == "Windows"
|
||||
|
37
server.py
37
server.py
@ -32,6 +32,7 @@ import time
|
||||
import traceback
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
import psutil
|
||||
@ -40,7 +41,7 @@ import yaml
|
||||
from PIL import Image
|
||||
|
||||
import modules.extensions as extensions_module
|
||||
from modules import api, chat, shared, training, ui
|
||||
from modules import chat, shared, training, ui
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
from modules.LoRA import add_lora_to_model
|
||||
from modules.models import load_model, load_soft_prompt, unload_model
|
||||
@ -543,7 +544,7 @@ def create_interface():
|
||||
shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False)
|
||||
|
||||
shared.gradio['mode'] = gr.Radio(choices=['cai-chat', 'chat', 'instruct'], value=shared.settings['mode'], label='Mode')
|
||||
shared.gradio['instruction_template'] = gr.Dropdown(choices=get_available_instruction_templates(), label='Instruction template', value=shared.settings['instruction_template'], visible=shared.settings['mode'] == 'instruct', info='Change this according to the model/LoRA that you are using.')
|
||||
shared.gradio['instruction_template'] = gr.Dropdown(choices=get_available_instruction_templates(), label='Instruction template', value='None', visible=shared.settings['mode'] == 'instruct', info='Change this according to the model/LoRA that you are using.')
|
||||
|
||||
with gr.Tab('Character', elem_id='chat-settings'):
|
||||
with gr.Row():
|
||||
@ -559,7 +560,7 @@ def create_interface():
|
||||
shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None)
|
||||
|
||||
with gr.Row():
|
||||
shared.gradio['character_menu'] = gr.Dropdown(choices=get_available_characters(), value='None', label='Character', elem_id='character-menu')
|
||||
shared.gradio['character_menu'] = gr.Dropdown(choices=get_available_characters(), label='Character', elem_id='character-menu')
|
||||
ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button')
|
||||
|
||||
with gr.Row():
|
||||
@ -710,14 +711,6 @@ def create_interface():
|
||||
set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then(
|
||||
lambda: None, None, None, _js='() => {document.body.innerHTML=\'<h1 style="font-family:monospace;margin-top:20%;color:lightgray;text-align:center;">Reloading...</h1>\'; setTimeout(function(){location.reload()},2500); return []}')
|
||||
|
||||
# Extensions block
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
# Create the invisible elements that define the API
|
||||
if not shared.is_chat():
|
||||
api.create_apis()
|
||||
|
||||
# chat mode event handlers
|
||||
if shared.is_chat():
|
||||
shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']]
|
||||
@ -801,16 +794,11 @@ def create_interface():
|
||||
shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']])
|
||||
shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display'])
|
||||
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}")
|
||||
shared.gradio['interface'].load(chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']])
|
||||
shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None)
|
||||
shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True)
|
||||
|
||||
# notebook/default modes event handlers
|
||||
else:
|
||||
shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']]
|
||||
|
||||
if shared.args.notebook:
|
||||
output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']]
|
||||
else:
|
||||
@ -851,6 +839,11 @@ def create_interface():
|
||||
shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False)
|
||||
shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}")
|
||||
|
||||
shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False)
|
||||
# Extensions block
|
||||
if shared.args.extensions is not None:
|
||||
extensions_module.create_extensions_block()
|
||||
|
||||
# Launch the interface
|
||||
shared.gradio['interface'].queue()
|
||||
if shared.args.listen:
|
||||
@ -860,7 +853,6 @@ def create_interface():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# Loading custom settings
|
||||
settings_file = None
|
||||
if shared.args.settings is not None and Path(shared.args.settings).exists():
|
||||
@ -905,14 +897,15 @@ if __name__ == "__main__":
|
||||
print('The following models are available:\n')
|
||||
for i, model in enumerate(available_models):
|
||||
print(f'{i+1}. {model}')
|
||||
|
||||
print(f'\nWhich one do you want to load? 1-{len(available_models)}\n')
|
||||
i = int(input()) - 1
|
||||
print()
|
||||
|
||||
shared.model_name = available_models[i]
|
||||
|
||||
# If any model has been selected, load it
|
||||
if shared.model_name != 'None':
|
||||
|
||||
model_settings = get_model_specific_settings(shared.model_name)
|
||||
shared.settings.update(model_settings) # hijacking the interface defaults
|
||||
update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments
|
||||
@ -922,6 +915,14 @@ if __name__ == "__main__":
|
||||
if shared.args.lora:
|
||||
add_lora_to_model([shared.args.lora])
|
||||
|
||||
# Force a character to be loaded
|
||||
if shared.is_chat():
|
||||
shared.persistent_interface_state.update({
|
||||
'mode': shared.settings['mode'],
|
||||
'character_menu': shared.args.character or shared.settings['character'],
|
||||
'instruction_template': shared.settings['instruction_template']
|
||||
})
|
||||
|
||||
# Launch the web UI
|
||||
create_interface()
|
||||
while True:
|
||||
|
@ -3,6 +3,7 @@
|
||||
"max_new_tokens_min": 1,
|
||||
"max_new_tokens_max": 2000,
|
||||
"seed": -1,
|
||||
"character": "None",
|
||||
"name1": "You",
|
||||
"name2": "Assistant",
|
||||
"context": "This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.",
|
||||
@ -30,7 +31,7 @@
|
||||
],
|
||||
"presets": {
|
||||
"default": "Default",
|
||||
".*(alpaca|llama)": "LLaMA-Precise",
|
||||
".*(alpaca|llama|llava)": "LLaMA-Precise",
|
||||
".*pygmalion": "NovelAI-Storywriter",
|
||||
".*RWKV": "Naive"
|
||||
},
|
||||
|
Loading…
Reference in New Issue
Block a user