From 459e725af9c73ba2043ab009f904fc9a09d833e6 Mon Sep 17 00:00:00 2001
From: "Alex \"mcmonkey\" Goodwin" <4000772+mcmonkey4eva@users.noreply.github.com>
Date: Sun, 23 Apr 2023 08:54:41 -0700
Subject: [PATCH 01/19] Lora trainer docs (#1493)

---
 docs/Training-LoRAs.md | 167 +++++++++++++++++++++++++++++++++++++++++
 docs/Using-LoRAs.md    |  36 +--------
 modules/training.py    |  21 ++++--
 3 files changed, 182 insertions(+), 42 deletions(-)
 create mode 100644 docs/Training-LoRAs.md

diff --git a/docs/Training-LoRAs.md b/docs/Training-LoRAs.md
new file mode 100644
index 00000000..3d75ec5a
--- /dev/null
+++ b/docs/Training-LoRAs.md
@@ -0,0 +1,167 @@
## Training Your Own LoRAs

The WebUI seeks to make training your own LoRAs as easy as possible. It comes down to just a few simple steps:

### **Step 1**: Make a plan.
- What base model do you want to use? The LoRA you make has to be matched to a single architecture (eg LLaMA-13B) and cannot be transferred to others (eg LLaMA-7B, StableLM, etc. would all be different). Derivatives of the same model (eg an Alpaca finetune of LLaMA-13B) might be transferable, but even then it's best to train on exactly what you plan to use.
- What model format do you want? At the time of writing, 8-bit models are the most stable, and 4-bit is supported but experimental. In the near future, 4-bit is likely to become the best option for most users.
- What are you training it on? Do you want it to learn real information, a simple format, ...?

### **Step 2**: Gather a dataset.
- If you use a dataset similar to the [Alpaca](https://github.com/gururise/AlpacaDataCleaned/blob/main/alpaca_data_cleaned.json) format, it is natively supported by the `Formatted Dataset` input in the WebUI, with premade formatter options.
- If you use a dataset that isn't matched to Alpaca's format but uses the same basic JSON structure, you can make your own format file by copying `training/formats/alpaca-format.json` to a new file and [editing its content](#format-files).
- If you can get the dataset into a simple text file, that works too! You can train using the `Raw text file` input option.
  - This means you can, for example, just copy/paste a chat log, a documentation page, or whatever else you want into a plain text file and train on it.
- If you use a structured dataset not in this format, you may have to find an external way to convert it - or open an issue to request native support.

### **Step 3**: Do the training.
- **3.1**: Load the WebUI, and your model.
  - Make sure you don't have any LoRAs already loaded (unless you want to train for multi-LoRA usage).
- **3.2**: Open the `Training` tab at the top, then the `Train LoRA` sub-tab.
- **3.3**: Fill in the name of the LoRA and select your dataset in the dataset options.
- **3.4**: Select other parameters to your preference. See [parameters below](#parameters).
- **3.5**: Click `Start LoRA Training`, and wait.
  - It can take a few hours for a large dataset, or just a few minutes for a small run.
  - You may want to monitor your [loss value](#loss) while it goes.

### **Step 4**: Evaluate your results.
- Load the LoRA under the Models tab.
- You can test-drive it on the `Text generation` tab, or you can use the `Perplexity evaluation` sub-tab of the `Training` tab.
- If you used the `Save every n steps` option, you can grab prior copies of the model from sub-folders within the LoRA model's folder and try them instead.

### **Step 5**: Re-run if you're unhappy.
- Make sure to unload the LoRA before training it again.
- You can simply resume a prior run - use `Copy parameters from` to select your LoRA, and edit the parameters. Note that you cannot change the `Rank` of an already created LoRA.
  - If you want to resume from a checkpoint saved along the way, simply copy the contents of the checkpoint folder into the LoRA's folder.
  - (Note: `adapter_model.bin` is the important file that holds the actual LoRA content).
  - This will reset the Learning Rate and Steps to the start. If you want to resume as if you were midway through, you can adjust your Learning Rate to the last reported LR in the logs and reduce your epochs.
- Or, you can start over entirely if you prefer.
- If your model is producing corrupted outputs, you probably need to start over and use a lower Learning Rate.
- If your model isn't learning detailed information but you want it to, you might need to run more epochs, or you might need a higher Rank.
- If your model is enforcing a format you didn't want, you may need to tweak your dataset, or start over and not train as far.

## Format Files

If using JSON formatted datasets, they are presumed to be in the following approximate format:

```json
[
    {
        "somekey": "somevalue",
        "key2": "value2"
    },
    {
        // etc
    }
]
```

Where the keys (eg `somekey`, `key2` above) are standardized and relatively consistent across the dataset, and the values (eg `somevalue`, `value2`) contain the content actually intended to be trained.

For Alpaca, the keys are `instruction`, `input`, and `output`, wherein `input` is sometimes blank.

A simple format file for Alpaca to be used as a chat bot is:

```json
{
    "instruction,output": "User: %instruction%\nAssistant: %output%",
    "instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
}
```

Note that the keys (eg `instruction,output`) are a comma-separated list of dataset keys, and the values are a simple string that uses those keys as `%%`-wrapped placeholders.

So for example if a dataset has `"instruction": "answer my question"`, then the format file's `User: %instruction%\n` will be automatically filled in as `User: answer my question\n`. (A short illustrative sketch of this substitution is included further below.)

If you have different sets of key inputs, you can make your own format file to match them. This format file is designed to be as simple as possible, so it's easy to edit to match your needs.

## Parameters

The basic purpose and function of each parameter is documented on-page in the WebUI, so read through them in the UI to understand your options.

That said, here's a guide to the most important parameter choices you should consider:

### VRAM

- First, you must consider your VRAM availability.
  - Generally, under default settings, VRAM usage for training is very close to VRAM usage when generating text (with 1000+ tokens of context) (ie, if you can generate text, you can train LoRAs).
    - Note: this is currently worse by default with the 4-bit monkeypatch. Reduce `Micro Batch Size` to `1` to restore this to expectations.
  - If you have VRAM to spare, setting higher batch sizes will use more VRAM and get you better quality training in exchange.
  - If you have large data, setting a higher cutoff length may be beneficial, but will cost significant VRAM. If you can spare some, set your batch size to `1` and see how high you can push your cutoff length.
  - If you're low on VRAM, reducing batch size or cutoff length will of course improve that.
  - Don't be afraid to just try it and see what happens.
If it's too much, it will just error out, and you can lower the settings and try again.

### Rank

- Second, consider the amount of learning you want.
  - For example, you may wish to just learn a dialogue format (as in the case of Alpaca), in which case setting a low `Rank` value (32 or lower) works great.
  - Or, you might be training on project documentation that you want the bot to understand and be able to answer questions about, in which case the higher the rank, the better.
  - Generally, higher Rank = more precise learning = more total content learned = more VRAM usage while training.

### Learning Rate and Epochs

- Third, consider how carefully you want it to be learned.
  - In other words, how okay you are with the model losing unrelated understandings.
  - You can control this with 3 key settings: the Learning Rate, its scheduler, and your total epochs.
  - The learning rate controls how much change is made to the model by each token it sees.
    - It's normally written in scientific notation, so for example `3e-4` means `3 * 10^-4`, which is `0.0003`. The number after `e-` controls how many `0`s are in the number.
    - Higher values let training run faster, but are also more likely to corrupt prior data in the model.
  - You essentially have two variables to balance: the LR and Epochs.
    - If you make the LR higher, you can set Epochs equally lower to match. High LR + low epochs = very fast, low-quality training.
    - If you make the LR low, set epochs high. Low LR + high epochs = slow but high-quality training.
  - The scheduler controls the change over time as you train - it starts high and then goes low. This helps balance getting data in and keeping decent quality at the same time.
    - You can see graphs of the different scheduler options [in the HuggingFace docs here](https://moon-ci-docs.huggingface.co/docs/transformers/pr_1/en/main_classes/optimizer_schedules#transformers.SchedulerType)

## Loss

When you're running training, the WebUI's console window will log reports that include, among other things, a numeric value named `Loss`. It will start as a high number, and gradually get lower and lower as it goes.

"Loss" in the world of AI training theoretically means "how close is the model to perfect", with `0` meaning "absolutely perfect". This is calculated by measuring the difference between the text you're training the model to output and what it actually outputs.

In practice, a good LLM should have a very complex, variable range of ideas running in its artificial head, so a loss of `0` would indicate that the model has broken and forgotten how to think about anything other than what you trained it on.

So, in effect, Loss is a balancing game: you want to get it low enough that it understands your data, but high enough that it isn't forgetting everything else. Generally, if it goes below `1.0`, it's going to start forgetting its prior memories, and you should stop training. In some cases you may prefer to take it as low as `0.5` (if you want it to be very, very predictable). Different goals have different needs, so don't be afraid to experiment and see what works best for you.

Note: if you see Loss start at or suddenly jump to exactly `0`, it is likely that something has gone wrong in your training process (eg model corruption).

## Note: 4-Bit Monkeypatch

The [4-bit LoRA monkeypatch](GPTQ-models-(4-bit-mode).md#using-loras-in-4-bit-mode) works for training, but has side effects:
- VRAM usage is higher currently.
You can reduce the `Micro Batch Size` to `1` to compensate. +- Models do funky things. LoRAs apply themselves, or refuse to apply, or spontaneously error out, or etc. It can be helpful to reload base model or restart the WebUI between training/usage to minimize chances of anything going haywire. +- Loading or working with multiple LoRAs at the same time doesn't currently work. +- Generally, recognize and treat the monkeypatch as the dirty temporary hack it is - it works, but isn't very stable. It will get better in time when everything is merged upstream for full official support. + +## Legacy notes + +LoRA training was contributed by [mcmonkey4eva](https://github.com/mcmonkey4eva) in PR [#570](https://github.com/oobabooga/text-generation-webui/pull/570). + +### Using the original alpaca-lora code + +Kept here for reference. The Training tab has much more features than this method. + +``` +conda activate textgen +git clone https://github.com/tloen/alpaca-lora +``` + +Edit those two lines in `alpaca-lora/finetune.py` to use your existing model folder instead of downloading everything from decapoda: + +``` +model = LlamaForCausalLM.from_pretrained( + "models/llama-7b", + load_in_8bit=True, + device_map="auto", +) +tokenizer = LlamaTokenizer.from_pretrained( + "models/llama-7b", add_eos_token=True +) +``` + +Run the script with: + +``` +python finetune.py +``` + +It just works. It runs at 22.32s/it, with 1170 iterations in total, so about 7 hours and a half for training a LoRA. RTX 3090, 18153MiB VRAM used, drawing maximum power (350W, room heater mode). diff --git a/docs/Using-LoRAs.md b/docs/Using-LoRAs.md index de271e3d..0a679c0f 100644 --- a/docs/Using-LoRAs.md +++ b/docs/Using-LoRAs.md @@ -52,38 +52,4 @@ print(f"Predicted {len(output)} tokens for '{sentence}':\n{output}") ## Training a LoRA -The Training tab in the interface can be used to train a LoRA. The parameters are self-documenting and good defaults are included. - -You can interrupt and resume LoRA training in this tab. If the name and rank are the same, training will resume using the `adapter_model.bin` in your LoRA folder. You can resume from a past checkpoint by replacing this file using the contents of one of the checkpoint folders. Note that the learning rate and steps will be reset, and you may want to set the learning rate to the last reported rate in the console output. - -LoRA training was contributed by [mcmonkey4eva](https://github.com/mcmonkey4eva) in PR [#570](https://github.com/oobabooga/text-generation-webui/pull/570). - -#### Using the original alpaca-lora code - -Kept here for reference. The Training tab has much more features than this method. - -``` -conda activate textgen -git clone https://github.com/tloen/alpaca-lora -``` - -Edit those two lines in `alpaca-lora/finetune.py` to use your existing model folder instead of downloading everything from decapoda: - -``` -model = LlamaForCausalLM.from_pretrained( - "models/llama-7b", - load_in_8bit=True, - device_map="auto", -) -tokenizer = LlamaTokenizer.from_pretrained( - "models/llama-7b", add_eos_token=True -) -``` - -Run the script with: - -``` -python finetune.py -``` - -It just works. It runs at 22.32s/it, with 1170 iterations in total, so about 7 hours and a half for training a LoRA. RTX 3090, 18153MiB VRAM used, drawing maximum power (350W, room heater mode). +You can train your own LoRAs from the `Training` tab. See [Training LoRAs](Training-LoRAs.md) for details. 
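As a companion to the format-file description in Training-LoRAs.md above, here is a minimal standalone sketch of how a `%key%` format entry might be applied to one Alpaca-style record. This is illustrative only - the helper name `apply_format` and the inline sample data are not part of the codebase, and the real trainer's implementation may differ in details.

```python
# Hypothetical helper: pick the format entry whose comma-separated key list
# matches the record's non-empty keys, then substitute %key% placeholders.
def apply_format(record: dict, format_spec: dict) -> str:
    keys = [k for k in ('instruction', 'input', 'output') if record.get(k)]
    template = format_spec[','.join(keys)]
    for k in keys:
        template = template.replace(f'%{k}%', record[k])
    return template


format_spec = {
    "instruction,output": "User: %instruction%\nAssistant: %output%",
    "instruction,input,output": "User: %instruction%: %input%\nAssistant: %output%"
}

record = {"instruction": "answer my question", "input": "", "output": "Sure - ask away!"}
print(apply_format(record, format_spec))
# User: answer my question
# Assistant: Sure - ask away!
```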
diff --git a/modules/training.py b/modules/training.py index 70629ef3..cde4a555 100644 --- a/modules/training.py +++ b/modules/training.py @@ -18,14 +18,14 @@ from modules.evaluate import calculate_perplexity, generate_markdown_table, save from server import get_available_loras, get_available_models # This mapping is from a very recent commit, not yet released. -# If not available, default to a backup map for the 3 safe model types. +# If not available, default to a backup map for some common model types. try: from peft.utils.other import \ TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING as \ model_to_lora_modules except: standard_modules = ["q_proj", "v_proj"] - model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules} + model_to_lora_modules = {"llama": standard_modules, "opt": standard_modules, "gptj": standard_modules, "gpt_neox": ["query_key_value"]} WANT_INTERRUPT = False @@ -35,7 +35,8 @@ PARAMETERS = ["lora_name", "always_override", "save_steps", "micro_batch_size", MODEL_CLASSES = { "LlamaForCausalLM": "llama", "OPTForCausalLM": "opt", - "GPTJForCausalLM": "gptj" + "GPTJForCausalLM": "gptj", + "GPTNeoXForCausalLM": "gpt_neox" } @@ -45,6 +46,8 @@ def get_datasets(path: str, ext: str): def create_train_interface(): with gr.Tab('Train LoRA', elem_id='lora-train-tab'): + gr.Markdown("Confused? [[Click here for a guide]](https://github.com/oobabooga/text-generation-webui/blob/main/docs/Training-LoRAs.md)") + with gr.Row(): lora_name = gr.Textbox(label='Name', info='The name of your new LoRA file') always_override = gr.Checkbox(label='Override Existing Files', value=False, info='If the name given is the same as an existing file, checking this will replace that file. Leaving unchecked will load that file and continue from it (must use the same rank value as the original had).') @@ -215,11 +218,15 @@ def do_train(lora_name: str, always_override: bool, save_steps: int, micro_batch else: model_id = "llama" if model_type == "PeftModelForCausalLM": - yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" - print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.") + if len(shared.args.lora_names) > 0: + yield "You are trying to train a LoRA while you already have another LoRA loaded. This will work, but may have unexpected effects. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" + print("Warning: Training LoRA over top of another LoRA. May have unexpected effects.") + else: + yield "Model ID not matched due to LoRA loading. Consider reloading base model. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" + print("Warning: Model ID not matched due to LoRA loading. Consider reloading base model.") else: - yield "LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" - print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, and GPT-J models. (Found model type: {model_type})") + yield "LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. Unexpected errors may follow. *(Will continue anyway in 5 seconds, press `Interrupt` to stop.)*" + print(f"Warning: LoRA training has only currently been validated for LLaMA, OPT, GPT-J, and GPT-NeoX models. 
(Found model type: {model_type})") time.sleep(5) if shared.args.wbits > 0 and not shared.args.monkey_patch: From 654933c634d5f685618db3bf0f2ff76de5524425 Mon Sep 17 00:00:00 2001 From: Andy Salerno Date: Sun, 23 Apr 2023 11:52:43 -0700 Subject: [PATCH 02/19] New universal API with streaming/blocking endpoints (#990) Previous title: Add api_streaming extension and update api-example-stream to use it * Merge with latest main * Add parameter capturing encoder_repetition_penalty * Change some defaults, minor fixes * Add --api, --public-api flags * remove unneeded/broken comment from blocking API startup. The comment is already correctly emitted in try_start_cloudflared by calling the lambda we pass in. * Update on_start message for blocking_api, it should say 'non-streaming' and not 'streaming' * Update the API examples * Change a comment * Update README * Remove the gradio API * Remove unused import * Minor change * Remove unused import --------- Co-authored-by: oobabooga <112222186+oobabooga@users.noreply.github.com> --- README.md | 7 ++ api-example-stream.py | 102 +++++++++++----------------- api-example.py | 85 ++++++++++------------- extensions/api/blocking_api.py | 90 ++++++++++++++++++++++++ extensions/api/requirements.txt | 3 +- extensions/api/script.py | 117 ++------------------------------ extensions/api/streaming_api.py | 80 ++++++++++++++++++++++ extensions/api/util.py | 69 +++++++++++++++++++ modules/api.py | 52 -------------- modules/extensions.py | 9 ++- modules/shared.py | 12 ++++ server.py | 6 +- 12 files changed, 346 insertions(+), 286 deletions(-) create mode 100644 extensions/api/blocking_api.py create mode 100644 extensions/api/streaming_api.py create mode 100644 extensions/api/util.py delete mode 100644 modules/api.py diff --git a/README.md b/README.md index 681180ba..bea64666 100644 --- a/README.md +++ b/README.md @@ -269,6 +269,13 @@ Optionally, you can use the following command-line flags: | `--auto-launch` | Open the web UI in the default browser upon launch. | | `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" | +#### API + +| Flag | Description | +|---------------------------------------|-------------| +| `--api` | Enable the API extension. | +| `--public-api` | Create a public URL for the API using Cloudfare. | + Out of memory errors? [Check the low VRAM guide](docs/Low-VRAM-guide.md). ## Presets diff --git a/api-example-stream.py b/api-example-stream.py index b8e7cfb5..b299616f 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -1,39 +1,30 @@ -''' - -Contributed by SagsMug. Thank you SagsMug. -https://github.com/oobabooga/text-generation-webui/pull/175 - -''' - import asyncio import json -import random -import string +import sys -import websockets +try: + import websockets +except ImportError: + print("Websockets package not found. Make sure it's installed.") -# Gradio changes this index from time to time. 
To rediscover it, set VISIBLE = False in -# modules/api.py and use the dev tools to inspect the request made after clicking on the -# button called "Run" at the bottom of the UI -GRADIO_FN = 34 - - -def random_hash(): - letters = string.ascii_lowercase + string.digits - return ''.join(random.choice(letters) for i in range(9)) +# For local streaming, the websockets are hosted without ssl - ws:// +HOST = 'localhost:5005' +URI = f'ws://{HOST}/api/v1/stream' +# For reverse-proxied streaming, the remote will likely host with ssl - wss:// +# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' async def run(context): - server = "127.0.0.1" - params = { - 'max_new_tokens': 200, + # Note: the selected defaults change from time to time. + request = { + 'prompt': context, + 'max_new_tokens': 250, 'do_sample': True, - 'temperature': 0.72, - 'top_p': 0.73, + 'temperature': 1.3, + 'top_p': 0.1, 'typical_p': 1, - 'repetition_penalty': 1.1, - 'encoder_repetition_penalty': 1.0, - 'top_k': 0, + 'repetition_penalty': 1.18, + 'top_k': 40, 'min_length': 0, 'no_repeat_ngram_size': 0, 'num_beams': 1, @@ -45,48 +36,31 @@ async def run(context): 'truncation_length': 2048, 'ban_eos_token': False, 'skip_special_tokens': True, - 'stopping_strings': [], + 'stopping_strings': [] } - payload = json.dumps([context, params]) - session = random_hash() - async with websockets.connect(f"ws://{server}:7860/queue/join") as websocket: - while content := json.loads(await websocket.recv()): - # Python3.10 syntax, replace with if elif on older - match content["msg"]: - case "send_hash": - await websocket.send(json.dumps({ - "session_hash": session, - "fn_index": GRADIO_FN - })) - case "estimation": - pass - case "send_data": - await websocket.send(json.dumps({ - "session_hash": session, - "fn_index": GRADIO_FN, - "data": [ - payload - ] - })) - case "process_starts": - pass - case "process_generating" | "process_completed": - yield content["output"]["data"][0] - # You can search for your desired end indicator and - # stop generation by closing the websocket here - if (content["msg"] == "process_completed"): - break + async with websockets.connect(URI) as websocket: + await websocket.send(json.dumps(request)) -prompt = "What I would like to say is the following: " + yield context # Remove this if you just want to see the reply + + while True: + incoming_data = await websocket.recv() + incoming_data = json.loads(incoming_data) + + match incoming_data['event']: + case 'text_stream': + yield incoming_data['text'] + case 'stream_end': + return -async def get_result(): +async def print_response_stream(prompt): async for response in run(prompt): - # Print intermediate steps - print(response) + print(response, end='') + sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. - # Print final result - print(response) -asyncio.run(get_result()) +if __name__ == '__main__': + prompt = "In order to make homemade bread, follow these steps:\n1)" + asyncio.run(print_response_stream(prompt)) diff --git a/api-example.py b/api-example.py index eff610c1..4bf4f0d6 100644 --- a/api-example.py +++ b/api-example.py @@ -1,57 +1,42 @@ -''' - -This is an example on how to use the API for oobabooga/text-generation-webui. - -Make sure to start the web UI with the following flags: - -python server.py --model MODEL --listen --no-stream - -Optionally, you can also add the --share flag to generate a public gradio URL, -allowing you to use the API remotely. 
- -''' -import json - import requests -# Server address -server = "127.0.0.1" +# For local streaming, the websockets are hosted without ssl - http:// +HOST = 'localhost:5000' +URI = f'http://{HOST}/api/v1/generate' -# Generation parameters -# Reference: https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig -params = { - 'max_new_tokens': 200, - 'do_sample': True, - 'temperature': 0.72, - 'top_p': 0.73, - 'typical_p': 1, - 'repetition_penalty': 1.1, - 'encoder_repetition_penalty': 1.0, - 'top_k': 0, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - 'seed': -1, - 'add_bos_token': True, - 'truncation_length': 2048, - 'ban_eos_token': False, - 'skip_special_tokens': True, - 'stopping_strings': [], -} +# For reverse-proxied streaming, the remote will likely host with ssl - https:// +# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate' -# Input prompt -prompt = "What I would like to say is the following: " +def run(context): + request = { + 'prompt': prompt, + 'max_new_tokens': 250, + 'do_sample': True, + 'temperature': 1.3, + 'top_p': 0.1, + 'typical_p': 1, + 'repetition_penalty': 1.18, + 'top_k': 40, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + 'seed': -1, + 'add_bos_token': True, + 'truncation_length': 2048, + 'ban_eos_token': False, + 'skip_special_tokens': True, + 'stopping_strings': [] + } -payload = json.dumps([prompt, params]) + response = requests.post(URI, json=request) -response = requests.post(f"http://{server}:7860/run/textgen", json={ - "data": [ - payload - ] -}).json() + if response.status_code == 200: + result = response.json()['results'][0]['text'] + print(prompt + result) -reply = response["data"][0] -print(reply) +if __name__ == '__main__': + prompt = "In order to make homemade bread, follow these steps:\n1)" + run(prompt) diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py new file mode 100644 index 00000000..e66a6a50 --- /dev/null +++ b/extensions/api/blocking_api.py @@ -0,0 +1,90 @@ +import json +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer +from threading import Thread + +from modules import shared +from modules.text_generation import encode, generate_reply + +from extensions.api.util import build_parameters, try_start_cloudflared + + +class Handler(BaseHTTPRequestHandler): + def do_GET(self): + if self.path == '/api/v1/model': + self.send_response(200) + self.end_headers() + response = json.dumps({ + 'result': shared.model_name + }) + + self.wfile.write(response.encode('utf-8')) + else: + self.send_error(404) + + def do_POST(self): + content_length = int(self.headers['Content-Length']) + body = json.loads(self.rfile.read(content_length).decode('utf-8')) + + if self.path == '/api/v1/generate': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + + prompt = body['prompt'] + generate_params = build_parameters(body) + stopping_strings = generate_params.pop('stopping_strings') + + generator = generate_reply( + prompt, generate_params, stopping_strings=stopping_strings) + + answer = '' + for a in generator: + if isinstance(a, str): + answer = a + else: + answer = a[0] + + response = json.dumps({ + 'results': [{ + 'text': answer if shared.is_chat() else answer[len(prompt):] + }] + }) + self.wfile.write(response.encode('utf-8')) + elif 
self.path == '/api/v1/token-count': + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.end_headers() + + tokens = encode(body['prompt'])[0] + response = json.dumps({ + 'results': [{ + 'tokens': len(tokens) + }] + }) + self.wfile.write(response.encode('utf-8')) + else: + self.send_error(404) + + +def _run_server(port: int, share: bool=False): + address = '0.0.0.0' if shared.args.listen else '127.0.0.1' + + server = ThreadingHTTPServer((address, port), Handler) + + def on_start(public_url: str): + print(f'Starting non-streaming server at public url {public_url}/api') + + if share: + try: + try_start_cloudflared(port, max_attempts=3, on_start=on_start) + except Exception: + pass + else: + print( + f'Starting API at http://{address}:{port}/api') + + server.serve_forever() + + +def start_server(port: int, share: bool = False): + Thread(target=_run_server, args=[port, share], daemon=True).start() diff --git a/extensions/api/requirements.txt b/extensions/api/requirements.txt index ad788ab8..14e29d35 100644 --- a/extensions/api/requirements.txt +++ b/extensions/api/requirements.txt @@ -1 +1,2 @@ -flask_cloudflared==0.0.12 \ No newline at end of file +flask_cloudflared==0.0.12 +websockets==11.0.2 \ No newline at end of file diff --git a/extensions/api/script.py b/extensions/api/script.py index e4c3a556..efeed71f 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -1,115 +1,10 @@ -import json -from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer -from threading import Thread - +import extensions.api.blocking_api as blocking_api +import extensions.api.streaming_api as streaming_api from modules import shared -from modules.text_generation import encode, generate_reply - -params = { - 'port': 5000, -} - - -class Handler(BaseHTTPRequestHandler): - def do_GET(self): - if self.path == '/api/v1/model': - self.send_response(200) - self.end_headers() - response = json.dumps({ - 'result': shared.model_name - }) - - self.wfile.write(response.encode('utf-8')) - else: - self.send_error(404) - - def do_POST(self): - content_length = int(self.headers['Content-Length']) - body = json.loads(self.rfile.read(content_length).decode('utf-8')) - - if self.path == '/api/v1/generate': - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - prompt = body['prompt'] - prompt_lines = [k.strip() for k in prompt.split('\n')] - max_context = body.get('max_context_length', 2048) - while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context: - prompt_lines.pop(0) - - prompt = '\n'.join(prompt_lines) - generate_params = { - 'max_new_tokens': int(body.get('max_length', 200)), - 'do_sample': bool(body.get('do_sample', True)), - 'temperature': float(body.get('temperature', 0.5)), - 'top_p': float(body.get('top_p', 1)), - 'typical_p': float(body.get('typical', 1)), - 'repetition_penalty': float(body.get('rep_pen', 1.1)), - 'encoder_repetition_penalty': 1, - 'top_k': int(body.get('top_k', 0)), - 'min_length': int(body.get('min_length', 0)), - 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)), - 'num_beams': int(body.get('num_beams', 1)), - 'penalty_alpha': float(body.get('penalty_alpha', 0)), - 'length_penalty': float(body.get('length_penalty', 1)), - 'early_stopping': bool(body.get('early_stopping', False)), - 'seed': int(body.get('seed', -1)), - 'add_bos_token': int(body.get('add_bos_token', True)), - 'truncation_length': int(body.get('truncation_length', 2048)), - 
'ban_eos_token': bool(body.get('ban_eos_token', False)), - 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), - 'custom_stopping_strings': '', # leave this blank - 'stopping_strings': body.get('stopping_strings', []), - } - stopping_strings = generate_params.pop('stopping_strings') - generator = generate_reply(prompt, generate_params, stopping_strings=stopping_strings) - answer = '' - for a in generator: - if isinstance(a, str): - answer = a - else: - answer = a[0] - - response = json.dumps({ - 'results': [{ - 'text': answer if shared.is_chat() else answer[len(prompt):] - }] - }) - self.wfile.write(response.encode('utf-8')) - - elif self.path == '/api/v1/token-count': - # Not compatible with KoboldAI api - self.send_response(200) - self.send_header('Content-Type', 'application/json') - self.end_headers() - - tokens = encode(body['prompt'])[0] - response = json.dumps({ - 'results': [{ - 'tokens': len(tokens) - }] - }) - self.wfile.write(response.encode('utf-8')) - - else: - self.send_error(404) - - -def run_server(): - server_addr = ('0.0.0.0' if shared.args.listen else '127.0.0.1', params['port']) - server = ThreadingHTTPServer(server_addr, Handler) - if shared.args.share: - try: - from flask_cloudflared import _run_cloudflared - public_url = _run_cloudflared(params['port'], params['port'] + 1) - print(f'Starting KoboldAI compatible api at {public_url}/api') - except ImportError: - print('You should install flask_cloudflared manually') - else: - print(f'Starting KoboldAI compatible api at http://{server_addr[0]}:{server_addr[1]}/api') - server.serve_forever() +BLOCKING_PORT = 5000 +STREAMING_PORT = 5005 def setup(): - Thread(target=run_server, daemon=True).start() + blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api) + streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py new file mode 100644 index 00000000..5ffd925b --- /dev/null +++ b/extensions/api/streaming_api.py @@ -0,0 +1,80 @@ +import json +import asyncio +from websockets.server import serve +from threading import Thread + +from modules import shared +from modules.text_generation import generate_reply + +from extensions.api.util import build_parameters, try_start_cloudflared + +PATH = '/api/v1/stream' + + +async def _handle_connection(websocket, path): + + if path != PATH: + print(f'Streaming api: unknown path: {path}') + return + + async for message in websocket: + message = json.loads(message) + + prompt = message['prompt'] + generate_params = build_parameters(message) + stopping_strings = generate_params.pop('stopping_strings') + + generator = generate_reply( + prompt, generate_params, stopping_strings=stopping_strings) + + # As we stream, only send the new bytes. 
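+        # In chat mode the generator does not echo the prompt back, so start at index 0;
+        # in other modes, skip past the prompt text. skip_index then advances by the length
+        # of every chunk already sent, so each websocket message carries only new text.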
+ skip_index = len(prompt) if not shared.is_chat() else 0 + message_num = 0 + + for a in generator: + to_send = '' + if isinstance(a, str): + to_send = a[skip_index:] + else: + to_send = a[0][skip_index:] + + await websocket.send(json.dumps({ + 'event': 'text_stream', + 'message_num': message_num, + 'text': to_send + })) + + skip_index += len(to_send) + message_num += 1 + + await websocket.send(json.dumps({ + 'event': 'stream_end', + 'message_num': message_num + })) + + +async def _run(host: str, port: int): + async with serve(_handle_connection, host, port): + await asyncio.Future() # run forever + + +def _run_server(port: int, share: bool = False): + address = '0.0.0.0' if shared.args.listen else '127.0.0.1' + + def on_start(public_url: str): + public_url = public_url.replace('https://', 'wss://') + print(f'Starting streaming server at public url {public_url}{PATH}') + + if share: + try: + try_start_cloudflared(port, max_attempts=3, on_start=on_start) + except Exception as e: + print(e) + else: + print(f'Starting streaming server at ws://{address}:{port}{PATH}') + + asyncio.run(_run(host=address, port=port)) + + +def start_server(port: int, share: bool = False): + Thread(target=_run_server, args=[port, share], daemon=True).start() diff --git a/extensions/api/util.py b/extensions/api/util.py new file mode 100644 index 00000000..cb9d9d06 --- /dev/null +++ b/extensions/api/util.py @@ -0,0 +1,69 @@ + +from threading import Thread +import time +from typing import Callable, Optional +from modules.text_generation import encode + + +def build_parameters(body): + prompt = body['prompt'] + + prompt_lines = [k.strip() for k in prompt.split('\n')] + max_context = body.get('max_context_length', 2048) + while len(prompt_lines) >= 0 and len(encode('\n'.join(prompt_lines))) > max_context: + prompt_lines.pop(0) + + prompt = '\n'.join(prompt_lines) + + generate_params = { + 'max_new_tokens': int(body.get('max_new_tokens', body.get('max_length', 200))), + 'do_sample': bool(body.get('do_sample', True)), + 'temperature': float(body.get('temperature', 0.5)), + 'top_p': float(body.get('top_p', 1)), + 'typical_p': float(body.get('typical_p', body.get('typical', 1))), + 'repetition_penalty': float(body.get('repetition_penalty', body.get('rep_pen', 1.1))), + 'encoder_repetition_penalty': float(body.get('encoder_repetition_penalty', 1.0)), + 'top_k': int(body.get('top_k', 0)), + 'min_length': int(body.get('min_length', 0)), + 'no_repeat_ngram_size': int(body.get('no_repeat_ngram_size', 0)), + 'num_beams': int(body.get('num_beams', 1)), + 'penalty_alpha': float(body.get('penalty_alpha', 0)), + 'length_penalty': float(body.get('length_penalty', 1)), + 'early_stopping': bool(body.get('early_stopping', False)), + 'seed': int(body.get('seed', -1)), + 'add_bos_token': int(body.get('add_bos_token', True)), + 'truncation_length': int(body.get('truncation_length', 2048)), + 'ban_eos_token': bool(body.get('ban_eos_token', False)), + 'skip_special_tokens': bool(body.get('skip_special_tokens', True)), + 'custom_stopping_strings': '', # leave this blank + 'stopping_strings': body.get('stopping_strings', []), + } + + return generate_params + + +def try_start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): + Thread(target=_start_cloudflared, args=[ + port, max_attempts, on_start], daemon=True).start() + + +def _start_cloudflared(port: int, max_attempts: int = 3, on_start: Optional[Callable[[str], None]] = None): + try: + from flask_cloudflared import _run_cloudflared + except 
ImportError: + print('You should install flask_cloudflared manually') + raise Exception( + 'flask_cloudflared not installed. Make sure you installed the requirements.txt for this extension.') + + for _ in range(max_attempts): + try: + public_url = _run_cloudflared(port, port + 1) + + if on_start: + on_start(public_url) + + return + except Exception: + time.sleep(3) + + raise Exception('Could not start cloudflared.') diff --git a/modules/api.py b/modules/api.py deleted file mode 100644 index 9de8e25d..00000000 --- a/modules/api.py +++ /dev/null @@ -1,52 +0,0 @@ -import json - -import gradio as gr - -from modules import shared -from modules.text_generation import generate_reply - -# set this to True to rediscover the fn_index using the browser DevTools -VISIBLE = False - - -def generate_reply_wrapper(string): - - # Provide defaults so as to not break the API on the client side when new parameters are added - generate_params = { - 'max_new_tokens': 200, - 'do_sample': True, - 'temperature': 0.5, - 'top_p': 1, - 'typical_p': 1, - 'repetition_penalty': 1.1, - 'encoder_repetition_penalty': 1, - 'top_k': 0, - 'min_length': 0, - 'no_repeat_ngram_size': 0, - 'num_beams': 1, - 'penalty_alpha': 0, - 'length_penalty': 1, - 'early_stopping': False, - 'seed': -1, - 'add_bos_token': True, - 'custom_stopping_strings': '', - 'truncation_length': 2048, - 'ban_eos_token': False, - 'skip_special_tokens': True, - 'stopping_strings': [], - } - params = json.loads(string) - generate_params.update(params[1]) - stopping_strings = generate_params.pop('stopping_strings') - for i in generate_reply(params[0], generate_params, stopping_strings=stopping_strings): - yield i - - -def create_apis(): - t1 = gr.Textbox(visible=VISIBLE) - t2 = gr.Textbox(visible=VISIBLE) - dummy = gr.Button(visible=VISIBLE) - - input_params = [t1] - output_params = [t2] + [shared.gradio[k] for k in ['markdown', 'html']] - dummy.click(generate_reply_wrapper, input_params, output_params, api_name='textgen') diff --git a/modules/extensions.py b/modules/extensions.py index a6903a9b..24d57f89 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -14,7 +14,8 @@ def load_extensions(): global state, setup_called for i, name in enumerate(shared.args.extensions): if name in available_extensions: - print(f'Loading the extension "{name}"... ', end='') + if name != 'api': + print(f'Loading the extension "{name}"... ', end='') try: exec(f"import extensions.{name}.script") extension = getattr(extensions, name).script @@ -22,9 +23,11 @@ def load_extensions(): setup_called.add(extension) extension.setup() state[name] = [True, i] - print('Ok.') + if name != 'api': + print('Ok.') except: - print('Fail.') + if name != 'api': + print('Fail.') traceback.print_exc() diff --git a/modules/shared.py b/modules/shared.py index 1517526a..7540e3fb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -150,6 +150,11 @@ parser.add_argument('--share', action='store_true', help='Create a public URL. T parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. 
The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None) +# API +parser.add_argument('--api', action='store_true', help='Enable the API extension.') +parser.add_argument('--public-api', action='store_true', help='Create a public URL for the API using Cloudfare.') + + args = parser.parse_args() args_defaults = parser.parse_args([]) @@ -171,6 +176,13 @@ if args.trust_remote_code: if args.share: print("Warning: the gradio \"share link\" feature downloads a proprietary and\nunaudited blob to create a reverse tunnel. This is potentially dangerous.\n") +# Activating the API extension +if args.api or args.public_api: + if args.extensions is None: + args.extensions = ['api'] + elif 'api' not in args.extensions: + args.extensions.append('api') + def is_chat(): return args.chat diff --git a/server.py b/server.py index 2de817cb..ca44cdb5 100644 --- a/server.py +++ b/server.py @@ -40,7 +40,7 @@ import yaml from PIL import Image import modules.extensions as extensions_module -from modules import api, chat, shared, training, ui +from modules import chat, shared, training, ui from modules.html_generator import chat_html_wrapper from modules.LoRA import add_lora_to_model from modules.models import load_model, load_soft_prompt, unload_model @@ -714,10 +714,6 @@ def create_interface(): if shared.args.extensions is not None: extensions_module.create_extensions_block() - # Create the invisible elements that define the API - if not shared.is_chat(): - api.create_apis() - # chat mode event handlers if shared.is_chat(): shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']] From 9197d3fec8c74643b1c19aeeefefa1dd889249e0 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Sun, 23 Apr 2023 16:11:17 -0300 Subject: [PATCH 03/19] Update Extensions.md --- docs/Extensions.md | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/docs/Extensions.md b/docs/Extensions.md index 184ace55..b3365d84 100644 --- a/docs/Extensions.md +++ b/docs/Extensions.md @@ -16,19 +16,23 @@ command-line flag. The link above contains a directory of user extensions for text-generation-webui. +If you create an extension, you are welcome to host it in a GitHub repository and submit it to the list above. + ## Built-in extensions +Most of these have been created by the extremely talented contributors that you can find here: [contributors](https://github.com/oobabooga/text-generation-webui/graphs/contributors?from=2022-12-18&to=&type=a). + |Extension|Description| |---------|-----------| +|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API with two endpoints, one for streaming at `/api/v1/stream` port 5005 and another for blocking at `/api/v1/generate` por 5000. This is the main API for this web UI. | |[google_translate](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/google_translate)| Automatically translates inputs and outputs using Google Translate.| |[character_bias](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/character_bias)| Just a very simple example that biases the bot's responses in chat mode.| |[gallery](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/gallery/)| Creates a gallery with the chat characters and their pictures. 
| |[silero_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/silero_tts)| Text-to-speech extension using [Silero](https://github.com/snakers4/silero-models). When used in chat mode, it replaces the responses with an audio widget. | -|[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. Author: [@MetaIX](https://github.com/MetaIX). | -|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. Author: [@SillyLossy](https://github.com/sillylossy).| -|[api](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/api)| Creates an API similar to the one provided by KoboldAI. Works with TavernAI: start the web UI with `python server.py --no-stream --extensions api` and set the API URL to `http://127.0.0.1:5000/api`. Author: [@mayaeary](https://github.com/mayaeary).| -|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. Author: [@EliasVincent](https://github.com/EliasVincent).| -|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). Author: [@Brawlence](https://github.com/Brawlence).| +|[elevenlabs_tts](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/elevenlabs_tts)| Text-to-speech extension using the [ElevenLabs](https://beta.elevenlabs.io/) API. You need an API key to use it. | +|[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. | +|[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. | +|[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). 
| ## How to write an extension From 12212cf6be8e7ab683593c83f8de78deb77fda79 Mon Sep 17 00:00:00 2001 From: Wojtab Date: Mon, 24 Apr 2023 01:32:22 +0200 Subject: [PATCH 04/19] LLaVA support (#1487) --- characters/instruction-following/LLaVA.yaml | 3 + docs/Extensions.md | 6 +- extensions/llava/README.md | 49 ++++ extensions/llava/script.py | 279 ++++++++++++++++++++ models/config.yaml | 4 + modules/GPTQ_loader.py | 2 +- modules/chat.py | 31 +-- modules/extensions.py | 58 +++- modules/models.py | 13 +- modules/shared.py | 2 +- modules/text_generation.py | 19 +- settings-template.json | 2 +- 12 files changed, 426 insertions(+), 42 deletions(-) create mode 100644 characters/instruction-following/LLaVA.yaml create mode 100644 extensions/llava/README.md create mode 100644 extensions/llava/script.py diff --git a/characters/instruction-following/LLaVA.yaml b/characters/instruction-following/LLaVA.yaml new file mode 100644 index 00000000..b3999e46 --- /dev/null +++ b/characters/instruction-following/LLaVA.yaml @@ -0,0 +1,3 @@ +name: "### Assistant" +your_name: "### Human" +context: "You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language. Follow the instructions carefully and explain your answers in detail.\n### Human: \nHi!\n### Assistant: \nHi there! How can I help you today?\n" \ No newline at end of file diff --git a/docs/Extensions.md b/docs/Extensions.md index b3365d84..72eec19e 100644 --- a/docs/Extensions.md +++ b/docs/Extensions.md @@ -33,6 +33,7 @@ Most of these have been created by the extremely talented contributors that you |[send_pictures](https://github.com/oobabooga/text-generation-webui/blob/main/extensions/send_pictures/)| Creates an image upload field that can be used to send images to the bot in chat mode. Captions are automatically generated using BLIP. | |[whisper_stt](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/whisper_stt)| Allows you to enter your inputs in chat mode using your microphone. | |[sd_api_pictures](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/sd_api_pictures)| Allows you to request pictures from the bot in chat mode, which will be generated using the AUTOMATIC1111 Stable Diffusion API. See examples [here](https://github.com/oobabooga/text-generation-webui/pull/309). | +|[llava](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava) | Adds LLaVA multimodal model support. For detailed description see [README.md](https://github.com/oobabooga/text-generation-webui/tree/main/extensions/llava/README.md) in the extension directory. | ## How to write an extension @@ -45,6 +46,7 @@ Most of these have been created by the extremely talented contributors that you | `def output_modifier(string)` | Modifies the output string before it is presented in the UI. In chat mode, it is applied to the bot's reply. Otherwise, it is applied to the entire output. | | `def bot_prefix_modifier(string)` | Applied in chat mode to the prefix for the bot's reply (more on that below). | | `def custom_generate_chat_prompt(...)` | Overrides the prompt generator in chat mode. | +| `def tokenizer_modifier(state, prompt, input_ids, input_embeds)` | Modifies the `input_ids`/`input_embeds` fed to the model. Should return `prompt`, `input_ids`, `input_embeds`. 
See `llava` extension for an example | Additionally, the script may define two special global variables: @@ -70,7 +72,9 @@ input_hijack = { 'value': ["", ""] } ``` -This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the vales inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example. +This is only relevant in chat mode. If your extension sets `input_hijack['state']` to `True` at any moment, the next call to `modules.chat.chatbot_wrapper` will use the values inside `input_hijack['value']` as the user input for text generation. See the `send_pictures` extension above for an example. + +Additionally, your extension can set the value to be a callback, in the form of `def cb(text: str, visible_text: str) -> [str, str]`. See the `llava` extension above for an example. ## The `bot_prefix_modifier` diff --git a/extensions/llava/README.md b/extensions/llava/README.md new file mode 100644 index 00000000..848c7cb0 --- /dev/null +++ b/extensions/llava/README.md @@ -0,0 +1,49 @@ +# LLaVA + +## Description +Adds [LLaVA](https://github.com/haotian-liu/LLaVA) multimodality support to text-generation-webui. + +https://user-images.githubusercontent.com/3718215/233817203-69b57e77-0c55-4fd6-b742-3204bb13b8fc.mp4 + +## Usage +To run this extension, download LLaVA weights, for example from [here](https://huggingface.co/wojtab/llava-13b-v0-4bit-128g), and then start server.py with `--extensions llava` argument. + +When in ui, go to instruct mode, and select LLaVA template, you also should add `"\n###"` to "Custom stopping strings" in parameters tab. + +Do note, that each image takes up 258 tokens, so adjust max_new_tokens to be at most 1700 (recommended value is between 200 to 500), so the images don't get truncated. + +To send an image, just upload it to the extension field below chat, and send a prompt as always. The image will be added to the end of your message. If you wish to modify the placement, include a string `` in your prompt. + +Additionally, there is *Embed all images, not only the last one* checkbox. It modifies the image embeddings, by default (if it's unchecked), all but the most recent images have their embeddings empty, so they are not fed to the network. From initial testing, it seems as LLaVA considers the features in all images at the same time, so by default the extension skips previous images. If you want to include them anyway, just tick this checkbox. + +## Extension config +This extension uses following parameters (from settings.json): +|Parameter|Description| +|---------|-----------| +|`llava-clip_bits`|Number of bits to load CLIP feature extractor in (either 32 or 16, default=32)| +|`llava-clip_device`|Torch device to run the extractor on, for example `cpu` or `cuda:0`, by default `cuda:0` if available| +|`llava-projector_bits`|Number of bits to load CLIP->LLaMA feature projector in (either 32 or 16, default=32)| +|`llava-projector_bits`|Torch device to run the CLIP->LLaMA feature projector on, for example `cpu` or `cuda:0`, by default `cuda:0` if available| +|`llava-add_all_images_to_prompt`|Default value of "Embed all images, not only the last one" checkbox| + +## Technical description + +### Original LLaVA +The default LLaVA implementation uses modified `transformers` library, however this extension forgoes this requirement. 
The transformers are modified in LLaVA in such a way, that the entire LLaVA model gets loaded, and the inference now looks as follows: +``` +images --> CLIP --> projector --> input embeddings for images --> | + | --> LLaMA +prompt -------------------------> input embeddings for text ----> | +``` +The images are represented in the prompt by the following token IDs: +- 32000 - `` - placeholder token for embeddings from projector +- 32001 - `` - token marking start of an image +- 32002 - `` - token marking end of an image + +By default, image will be represented as `*256`. The input embeddings for an image are converted with a single linear layer of the projector, then they are placed instead of `` tokens. +The concatenated prompt then gets fed to fine-tuned LLaMA. + +### In this extension + +Using default transformers, they only load the LLaMA part of LLaVA, ignoring the added projector weights, and not loading CLIP. We then reconstruct the `images -> CLIP -> projector` pipeline ourselves, then concatenate the input embeddings, and feed it to LLaMA loaded by transformers. This allows us to use normal flow from webui to load this model, and just hijack the model input with additional features. +Splitting it to 3 separate models, allows us to configure each of them, and to move them to different devices(for example we can run CLIP+projector on CPU and LLaMA on GPU). Also, it enables us to use 4-bit GPTQ quantization for LLaVA, massively cutting down the VRAM requirement (it should be possible to fit on 12GB of VRAM with full context size by moving CLIP and projector to CPU). \ No newline at end of file diff --git a/extensions/llava/script.py b/extensions/llava/script.py new file mode 100644 index 00000000..a2ad34d5 --- /dev/null +++ b/extensions/llava/script.py @@ -0,0 +1,279 @@ +import base64 +import re +import time +from dataclasses import dataclass +from functools import partial +from io import BytesIO + +import gradio as gr +import torch +from huggingface_hub import hf_hub_download +from PIL import Image +from transformers import CLIPImageProcessor, CLIPVisionModel + +from modules import shared +from modules.extensions import apply_extensions +from modules.text_generation import encode, get_max_prompt_length + +params = { + "add_all_images_to_prompt": False, + # device to run CLIP on + "clip_device": None, + # bits to load clip in either 32 or 16 (it doesn't support 8-bit) + "clip_bits": 32, + # device to run projector on + "projector_device": None, + # projector bits, either 32 or 16 + "projector_bits": 32 +} + + +# If 'state' is True, will hijack the next chat generation +input_hijack = { + 'state': False, + 'value': ["", ""] +} + + +# initialized in ui, so that params are loaded from settings +llava_embedder = None + + +@dataclass +class Token: + token: str + id: int + + +class LLaVAEmbedder: + IM_PATCH = Token("", 32000) + IM_START = Token("", 32001) + IM_END = Token("", 32002) + CLIP_VIT_HUB_NAME = 'openai/clip-vit-large-patch14' + PROJECTOR_HUB_NAME = 'liuhaotian/LLaVA-13b-pretrain-projector-v0' + PROJECTOR_FILE = 'LLaVA-13b-pretrain-projector-v0-CC3M-595K-original_caption.bin' + + def __init__(self): + self.clip_device = self._get_device("clip_device") + self.clip_dtype = self._get_dtype("clip_bits") + self.projector_device = self._get_device("projector_device") + self.projector_dtype = self._get_dtype("projector_bits") + self.image_processor, self.vision_tower, self.mm_projector = self._load_models() + + def _get_device(self, setting_name): + if params[setting_name] is None: + 
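+            # No device was configured in settings, so fall back to the first
+            # CUDA device when available, otherwise the CPU.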
return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + return torch.device(params[setting_name]) + + def _get_dtype(self, setting_name): + return torch.float32 if int(params[setting_name]) == 32 else torch.float16 + + def _load_models(self): + start_ts = time.time() + + print(f"LLaVA - Loading {LLaVAEmbedder.CLIP_VIT_HUB_NAME} as {self.clip_dtype} on {self.clip_device}...") + image_processor = CLIPImageProcessor.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype) + vision_tower = CLIPVisionModel.from_pretrained(LLaVAEmbedder.CLIP_VIT_HUB_NAME, torch_dtype=self.clip_dtype).to(self.clip_device) + + print(f"LLaVA - Loading {LLaVAEmbedder.PROJECTOR_HUB_NAME} as {self.projector_dtype} on {self.projector_device}...") + projector_path = hf_hub_download(LLaVAEmbedder.PROJECTOR_HUB_NAME, LLaVAEmbedder.PROJECTOR_FILE) + mm_projector = torch.nn.Linear(1024, 5120) + projector_data = torch.load(projector_path) + mm_projector.weight = torch.nn.Parameter(projector_data['model.mm_projector.weight'].to(dtype=self.projector_dtype), False) + mm_projector.bias = torch.nn.Parameter(projector_data['model.mm_projector.bias'].to(dtype=self.projector_dtype), False) + mm_projector = mm_projector.to(self.projector_device) + + print(f"LLaVA supporting models loaded, took {time.time() - start_ts:.2f} seconds") + return image_processor, vision_tower, mm_projector + + def _update_prompt(self, prompt, images): + for _ in images: + # replace the image token with the image patch token in the prompt (each occurrence) + replace_token = LLaVAEmbedder.IM_PATCH.token * 256 + replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token + prompt = re.sub(r"", replace_token, prompt, 1) + return prompt + + def _extract_image_features(self, images): + images = self.image_processor(images, return_tensors='pt')['pixel_values'] + images = images.to(self.clip_device, dtype=self.clip_dtype) + + with torch.no_grad(): + image_forward_outs = self.vision_tower(images, output_hidden_states=True) + select_hidden_state_layer = -2 + select_hidden_state = image_forward_outs.hidden_states[select_hidden_state_layer] + image_features = select_hidden_state[:, 1:].to(self.projector_device, dtype=self.projector_dtype) + image_features = self.mm_projector(image_features) + return image_features + + def forward(self, prompt, images, state): + prompt = self._update_prompt(prompt, images) + input_ids = encode(prompt, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))[0] + input_embeds = shared.model.model.embed_tokens(input_ids).to(self.projector_device) + + if input_ids[0] == LLaVAEmbedder.IM_PATCH.id: + # prompt got truncated in the middle of an image, remove the image data + im_end = torch.where(input_ids == LLaVAEmbedder.IM_END.id)[0][0] + input_ids = input_ids[im_end+1:] + input_embeds = input_embeds[im_end+1:] + leftover_images = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0].shape[0] + print(f"LLaVA - WARNING: removed {len(images) - leftover_images} image(s) from prompt. 
The generation might be broken, try decreasing max_new_tokens") + images = images[-leftover_images:] + if len(images) == 0: + return prompt, input_ids, input_embeds, 0 + + total_embedded = 0 + image_features = self._extract_image_features(images).to(self.projector_device) + image_start_tokens = torch.where(input_ids == LLaVAEmbedder.IM_START.id)[0] + + if not torch.any(input_ids == LLaVAEmbedder.IM_PATCH.id) or len(image_start_tokens) == 0: + # multimodal LLM, but the current prompt is not multimodal/truncated + return prompt, input_ids, input_embeds, total_embedded + + cur_image_idx = 0 + if not params['add_all_images_to_prompt']: + image_start_tokens = [image_start_tokens[-1]] + cur_image_idx = -1 + + for image_start_token_pos in image_start_tokens: + cur_image_features = image_features[cur_image_idx] + num_patches = cur_image_features.shape[0] + input_embeds = torch.cat((input_embeds[:image_start_token_pos+1], cur_image_features, input_embeds[image_start_token_pos + num_patches + 1:]), dim=0) + cur_image_idx += 1 + total_embedded += 1 + + return prompt, input_ids, input_embeds, total_embedded + + @staticmethod + def len_in_tokens(text): + images = re.findall(r"", text) + image_tokens = 0 + for _ in images: + image_tokens += 258 + return len(encode(re.sub(r"", '', text))[0]) + image_tokens + + +def add_chat_picture(picture, text, visible_text): + # resize the image, so that shortest edge is at least 224 (size for CLIP), and at most 300 (to keep history manageable) + max_hw, min_hw = max(picture.size), min(picture.size) + aspect_ratio = max_hw / min_hw + shortest_edge = int(max(300 / aspect_ratio, 224)) + longest_edge = int(shortest_edge * aspect_ratio) + w = shortest_edge if picture.width < picture.height else longest_edge + h = shortest_edge if picture.width >= picture.height else longest_edge + picture = picture.resize((w,h)) + + buffer = BytesIO() + picture.save(buffer, format="JPEG") + img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') + visible = f'' + internal = f'' + + if visible_text == '' or visible_text is None: + visible_text = text + + if '' in text: + text = text.replace('', internal) + else: + text = text + '\n' + internal + + if '' in visible_text: + visible_text = visible_text.replace('', visible) + else: + visible_text = visible_text + '\n' + visible + + return text, visible_text + + +def fix_picture_after_remove_last(text, visible_text): + image = re.search(r'', text) + if image is None: + return text, visible_text + if visible_text is None: + visible_text = text + text = re.sub(r'', "", text) + return text, visible_text + + +def custom_generate_chat_prompt(user_input, state, **kwargs): + impersonate = kwargs['impersonate'] if 'impersonate' in kwargs else False + _continue = kwargs['_continue'] if '_continue' in kwargs else False + also_return_rows = kwargs['also_return_rows'] if 'also_return_rows' in kwargs else False + rows = [f"{state['context'].strip()}\n"] + min_rows = 3 + + # Finding the maximum prompt size + chat_prompt_size = state['chat_prompt_size'] + if shared.soft_prompt: + chat_prompt_size -= shared.soft_prompt_tensor.shape[1] + max_length = min(get_max_prompt_length(state), chat_prompt_size) + + prefix1 = f"{state['name1']}: " + prefix2 = f"{state['name2']}: " + + i = len(shared.history['internal']) - 1 + while i >= 0 and LLaVAEmbedder.len_in_tokens(''.join(rows)) < max_length: + if _continue and i == len(shared.history['internal']) - 1: + rows.insert(1, f"{prefix2}{shared.history['internal'][i][1]}") + else: + rows.insert(1, 
f"{prefix2}{shared.history['internal'][i][1].strip()}{state['end_of_turn']}\n") + + string = shared.history['internal'][i][0] + if string != '': + rows.insert(1, f"{prefix1}{string.strip()}{state['end_of_turn']}\n") + + i -= 1 + + if impersonate: + min_rows = 2 + rows.append(f"{prefix1}") + elif not _continue: + # Adding the user message + if len(user_input) > 0: + rows.append(f"{prefix1}{user_input}{state['end_of_turn']}\n") + + # Adding the Character prefix + rows.append(apply_extensions("bot_prefix", f"{prefix2}")) + + while len(rows) > min_rows and LLaVAEmbedder.len_in_tokens(''.join(rows)) >= max_length: + rows.pop(1) + prompt = ''.join(rows) + + if also_return_rows: + return prompt, rows + else: + return prompt + + +def tokenizer_modifier(state, prompt, input_ids, input_embeds): + global params + start_ts = time.time() + image_matches = re.finditer(r"", prompt) + images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches] + + if len(images) == 0: + return prompt, input_ids, input_embeds + + prompt, input_ids, input_embeds, total_embedded = llava_embedder.forward(prompt, images, state) + print(f'LLaVA - Embedded {total_embedded} image(s) in {time.time()-start_ts:.2f}s') + return prompt, input_ids.unsqueeze(0).to(shared.model.device), input_embeds.unsqueeze(0).to(shared.model.device) + + +def ui(): + global llava_embedder + llava_embedder = LLaVAEmbedder() + with gr.Column(): + picture_select = gr.Image(label='Send a picture', type='pil') + # I found that it doesn't deal super well with multiple images, and demo ui had a bug where it included only the last image anyway + single_image_checkbox = gr.Checkbox(False, label='Embed all images, not only the last one') + # Prepare the input hijack + picture_select.upload( + lambda picture: input_hijack.update({"state": True, "value": partial(add_chat_picture, picture)}), + [picture_select], + None + ) + picture_select.clear(lambda: input_hijack.update({"state": False, "value": ["",""]}), None, None) + single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None) + shared.gradio['Generate'].click(lambda: None, None, picture_select) + shared.gradio['textbox'].submit(lambda: None, None, picture_select) + shared.gradio['Remove last'].click(lambda: input_hijack.update({"state": True, "value": fix_picture_after_remove_last}), None, None) diff --git a/models/config.yaml b/models/config.yaml index 3ebf21f8..e9aa3a55 100644 --- a/models/config.yaml +++ b/models/config.yaml @@ -48,3 +48,7 @@ llama-[0-9]*b-4bit$: .*chatglm: mode: 'instruct' instruction_template: 'ChatGLM' +.*llava: + mode: 'instruct' + model_type: 'llama' + instruction_template: 'LLaVA' diff --git a/modules/GPTQ_loader.py b/modules/GPTQ_loader.py index a42dbcf3..58c4a0bb 100644 --- a/modules/GPTQ_loader.py +++ b/modules/GPTQ_loader.py @@ -135,7 +135,7 @@ def load_quantized(model_name): # Find the model type if not shared.args.model_type: name = model_name.lower() - if any((k in name for k in ['llama', 'alpaca', 'vicuna'])): + if any((k in name for k in ['llama', 'alpaca', 'vicuna', 'llava'])): model_type = 'llama' elif any((k in name for k in ['opt-', 'galactica'])): model_type = 'opt' diff --git a/modules/chat.py b/modules/chat.py index 3eebb1a3..c4703236 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -64,7 +64,7 @@ def generate_chat_prompt(user_input, state, **kwargs): rows.append(f"{this_prefix1}{user_input}{state['end_of_turn']}\n") # Adding the Character prefix - 
rows.append(apply_extensions(f"{prefix2.strip() if not is_instruct else prefix2}", "bot_prefix")) + rows.append(apply_extensions("bot_prefix", f"{prefix2.strip() if not is_instruct else prefix2}")) while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length: rows.pop(1) @@ -127,29 +127,22 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): cumulative_reply = '' last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None just_started = True - visible_text = custom_generate_chat_prompt = None + visible_text = None eos_token = '\n' if state['stop_at_newline'] else None stopping_strings = get_stopping_strings(state) - # Check if any extension wants to hijack this function call - for extension, _ in extensions_module.iterator(): - if hasattr(extension, 'input_hijack') and extension.input_hijack['state']: - extension.input_hijack['state'] = False - text, visible_text = extension.input_hijack['value'] - if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'): - custom_generate_chat_prompt = extension.custom_generate_chat_prompt + text, visible_text = apply_extensions('input_hijack', text, visible_text) if visible_text is None: visible_text = text if not _continue: - text = apply_extensions(text, "input") + text = apply_extensions("input", text) # Generating the prompt kwargs = {'_continue': _continue} - if custom_generate_chat_prompt is None: + prompt = apply_extensions('custom_generate_chat_prompt', text, state, **kwargs) + if prompt is None: prompt = generate_chat_prompt(text, state, **kwargs) - else: - prompt = custom_generate_chat_prompt(text, state, **kwargs) # Yield *Is typing...* if not any((regenerate, _continue)): @@ -164,7 +157,7 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): # Extracting the reply reply, next_character_found = extract_message_from_reply(reply, state) visible_reply = re.sub("(||{{user}})", state['name1'], reply) - visible_reply = apply_extensions(visible_reply, "output") + visible_reply = apply_extensions("output", visible_reply) # We need this global variable to handle the Stop event, # otherwise gradio gets confused @@ -273,14 +266,14 @@ def send_last_reply_to_input(): def replace_last_reply(text, name1, name2, mode): if len(shared.history['visible']) > 0: shared.history['visible'][-1][1] = text - shared.history['internal'][-1][1] = apply_extensions(text, "input") + shared.history['internal'][-1][1] = apply_extensions("input", text) return chat_html_wrapper(shared.history['visible'], name1, name2, mode) def send_dummy_message(text, name1, name2, mode): shared.history['visible'].append([text, '']) - shared.history['internal'].append([apply_extensions(text, "input"), '']) + shared.history['internal'].append([apply_extensions("input", text), '']) return chat_html_wrapper(shared.history['visible'], name1, name2, mode) @@ -289,7 +282,7 @@ def send_dummy_reply(text, name1, name2, mode): shared.history['visible'].append(['', '']) shared.history['internal'].append(['', '']) shared.history['visible'][-1][1] = text - shared.history['internal'][-1][1] = apply_extensions(text, "input") + shared.history['internal'][-1][1] = apply_extensions("input", text) return chat_html_wrapper(shared.history['visible'], name1, name2, mode) @@ -303,7 +296,7 @@ def clear_chat_log(name1, name2, greeting, mode): if greeting != '': shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] - shared.history['visible'] += [['', 
apply_extensions(greeting, "output")]] + shared.history['visible'] += [['', apply_extensions("output", greeting)]] # Save cleared logs save_history(mode) @@ -475,7 +468,7 @@ def load_character(character, name1, name2, mode): # Insert greeting if it exists if greeting != "": shared.history['internal'] += [['<|BEGIN-VISIBLE-CHAT|>', greeting]] - shared.history['visible'] += [['', apply_extensions(greeting, "output")]] + shared.history['visible'] += [['', apply_extensions("output", greeting)]] # Create .json log files since they don't already exist save_history(mode) diff --git a/modules/extensions.py b/modules/extensions.py index 24d57f89..92d86772 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,4 +1,5 @@ import traceback +from functools import partial import gradio as gr @@ -39,17 +40,60 @@ def iterator(): # Extension functions that map string -> string -def apply_extensions(text, typ): +def _apply_string_extensions(function_name, text): for extension, _ in iterator(): - if typ == "input" and hasattr(extension, "input_modifier"): - text = extension.input_modifier(text) - elif typ == "output" and hasattr(extension, "output_modifier"): - text = extension.output_modifier(text) - elif typ == "bot_prefix" and hasattr(extension, "bot_prefix_modifier"): - text = extension.bot_prefix_modifier(text) + if hasattr(extension, function_name): + text = getattr(extension, function_name)(text) return text +# Input hijack of extensions +def _apply_input_hijack(text, visible_text): + for extension, _ in iterator(): + if hasattr(extension, 'input_hijack') and extension.input_hijack['state']: + extension.input_hijack['state'] = False + if callable(extension.input_hijack['value']): + text, visible_text = extension.input_hijack['value'](text, visible_text) + else: + text, visible_text = extension.input_hijack['value'] + return text, visible_text + + +# custom_generate_chat_prompt handling +def _apply_custom_generate_chat_prompt(text, state, **kwargs): + custom_generate_chat_prompt = None + for extension, _ in iterator(): + if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'): + custom_generate_chat_prompt = extension.custom_generate_chat_prompt + if custom_generate_chat_prompt is not None: + return custom_generate_chat_prompt(text, state, **kwargs) + return None + + +# Extension functions that override the default tokenizer output +def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): + for extension, _ in iterator(): + if hasattr(extension, function_name): + prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds) + return prompt, input_ids, input_embeds + + +EXTENSION_MAP = { + "input": partial(_apply_string_extensions, "input_modifier"), + "output": partial(_apply_string_extensions, "output_modifier"), + "bot_prefix": partial(_apply_string_extensions, "bot_prefix_modifier"), + "tokenizer": partial(_apply_tokenizer_extensions, "tokenizer_modifier"), + "input_hijack": _apply_input_hijack, + "custom_generate_chat_prompt": _apply_custom_generate_chat_prompt +} + + +def apply_extensions(typ, *args, **kwargs): + if typ not in EXTENSION_MAP: + raise ValueError(f"Invalid extension type {typ}") + return EXTENSION_MAP[typ](*args, **kwargs) + + def create_extensions_block(): global setup_called diff --git a/modules/models.py b/modules/models.py index 469cbaf7..a17fba4b 100644 --- a/modules/models.py +++ b/modules/models.py @@ -50,6 +50,8 @@ def find_model_type(model_name): 
return 'chatglm' elif 'galactica' in model_name: return 'galactica' + elif 'llava' in model_name: + return 'llava' elif any((k in model_name for k in ['gpt4chan', 'gpt-4chan'])): return 'gpt4chan' else: @@ -217,11 +219,12 @@ def load_model(model_name): tokenizer = None # Try to load an universal LLaMA tokenizer - for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]: - if p.exists(): - print(f"Loading the universal LLaMA tokenizer from {p}...") - tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True) - break + if shared.model_type != 'llava': + for p in [Path(f"{shared.args.model_dir}/llama-tokenizer/"), Path(f"{shared.args.model_dir}/oobabooga_llama-tokenizer/")]: + if p.exists(): + print(f"Loading the universal LLaMA tokenizer from {p}...") + tokenizer = LlamaTokenizer.from_pretrained(p, clean_up_tokenization_spaces=True) + break # Otherwise, load it from the model folder and hope that these # are not outdated tokenizer files. diff --git a/modules/shared.py b/modules/shared.py index 7540e3fb..82acf3c0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -56,7 +56,7 @@ settings = { 'chat_default_extensions': ["gallery"], 'presets': { 'default': 'Default', - '.*(alpaca|llama)': "LLaMA-Precise", + '.*(alpaca|llama|llava)': "LLaMA-Precise", '.*pygmalion': 'NovelAI-Storywriter', '.*RWKV': 'Naive', }, diff --git a/modules/text_generation.py b/modules/text_generation.py index e1e169a0..032fc84c 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -138,7 +138,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): original_question = question if not shared.is_chat(): - question = apply_extensions(question, 'input') + question = apply_extensions('input', question) # These models are not part of Hugging Face, so we handle them # separately and terminate the function call earlier @@ -155,7 +155,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): reply = shared.model.generate(context=question, **generate_params) output = original_question + reply if not shared.is_chat(): - reply = original_question + apply_extensions(reply, 'output') + reply = original_question + apply_extensions('output', reply) yield formatted_outputs(reply, shared.model_name) else: if not shared.is_chat(): @@ -166,7 +166,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): for reply in shared.model.generate_with_streaming(context=question, **generate_params): output = original_question + reply if not shared.is_chat(): - reply = original_question + apply_extensions(reply, 'output') + reply = original_question + apply_extensions('output', reply) yield formatted_outputs(reply, shared.model_name) except Exception: @@ -179,7 +179,6 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): return input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state)) - original_input_ids = input_ids output = input_ids[0] if shared.args.verbose: @@ -218,10 +217,16 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): generate_params.update({'synced_gpus': True}) if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) + question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds) + original_input_ids = input_ids generate_params.update({'inputs_embeds': 
inputs_embeds}) generate_params.update({'inputs': filler_input_ids}) else: + question, input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, input_ids, None) + original_input_ids = input_ids generate_params.update({'inputs': input_ids}) + if inputs_embeds is not None: + generate_params.update({'inputs_embeds': inputs_embeds}) try: # Generate the entire reply at once. @@ -237,7 +242,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): new_tokens = len(output) - len(input_ids[0]) reply = decode(output[-new_tokens:], state['skip_special_tokens']) if not shared.is_chat(): - reply = original_question + apply_extensions(reply, 'output') + reply = original_question + apply_extensions('output', reply) yield formatted_outputs(reply, shared.model_name) @@ -265,7 +270,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): new_tokens = len(output) - len(input_ids[0]) reply = decode(output[-new_tokens:], state['skip_special_tokens']) if not shared.is_chat(): - reply = original_question + apply_extensions(reply, 'output') + reply = original_question + apply_extensions('output', reply) if output[-1] in eos_token_ids: break @@ -285,7 +290,7 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]): new_tokens = len(output) - len(original_input_ids[0]) reply = decode(output[-new_tokens:], state['skip_special_tokens']) if not shared.is_chat(): - reply = original_question + apply_extensions(reply, 'output') + reply = original_question + apply_extensions('output', reply) if np.count_nonzero(np.isin(input_ids[0], eos_token_ids)) < np.count_nonzero(np.isin(output, eos_token_ids)): break diff --git a/settings-template.json b/settings-template.json index 298e4421..286add0b 100644 --- a/settings-template.json +++ b/settings-template.json @@ -30,7 +30,7 @@ ], "presets": { "default": "Default", - ".*(alpaca|llama)": "LLaMA-Precise", + ".*(alpaca|llama|llava)": "LLaMA-Precise", ".*pygmalion": "NovelAI-Storywriter", ".*RWKV": "Naive" }, From 04b98a8485c93f3a6356947d7f500ece892e5931 Mon Sep 17 00:00:00 2001 From: Wojtab Date: Mon, 24 Apr 2023 03:58:15 +0200 Subject: [PATCH 05/19] Fix Continue for LLaVA (#1507) --- extensions/llava/script.py | 40 +++++++++++++------------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/extensions/llava/script.py b/extensions/llava/script.py index a2ad34d5..d48e35fa 100644 --- a/extensions/llava/script.py +++ b/extensions/llava/script.py @@ -91,7 +91,7 @@ class LLaVAEmbedder: # replace the image token with the image patch token in the prompt (each occurrence) replace_token = LLaVAEmbedder.IM_PATCH.token * 256 replace_token = LLaVAEmbedder.IM_START.token + replace_token + LLaVAEmbedder.IM_END.token - prompt = re.sub(r"", replace_token, prompt, 1) + prompt = re.sub(r'', replace_token, prompt, 1) return prompt def _extract_image_features(self, images): @@ -146,11 +146,11 @@ class LLaVAEmbedder: @staticmethod def len_in_tokens(text): - images = re.findall(r"", text) + images = re.findall(r'', text) image_tokens = 0 for _ in images: image_tokens += 258 - return len(encode(re.sub(r"", '', text))[0]) + image_tokens + return len(encode(re.sub(r'', '', text))[0]) + image_tokens def add_chat_picture(picture, text, visible_text): @@ -166,32 +166,21 @@ def add_chat_picture(picture, text, visible_text): buffer = BytesIO() picture.save(buffer, format="JPEG") img_str = base64.b64encode(buffer.getvalue()).decode('utf-8') - visible = f'' - internal = f'' + image = f'' + + + if '' in text: + 
text = text.replace('', image) + else: + text = text + '\n' + image if visible_text == '' or visible_text is None: visible_text = text - - if '' in text: - text = text.replace('', internal) + elif '' in visible_text: + visible_text = visible_text.replace('', image) else: - text = text + '\n' + internal + visible_text = visible_text + '\n' + image - if '' in visible_text: - visible_text = visible_text.replace('', visible) - else: - visible_text = visible_text + '\n' + visible - - return text, visible_text - - -def fix_picture_after_remove_last(text, visible_text): - image = re.search(r'', text) - if image is None: - return text, visible_text - if visible_text is None: - visible_text = text - text = re.sub(r'', "", text) return text, visible_text @@ -248,7 +237,7 @@ def custom_generate_chat_prompt(user_input, state, **kwargs): def tokenizer_modifier(state, prompt, input_ids, input_embeds): global params start_ts = time.time() - image_matches = re.finditer(r"", prompt) + image_matches = re.finditer(r'', prompt) images = [Image.open(BytesIO(base64.b64decode(match.group(1)))) for match in image_matches] if len(images) == 0: @@ -276,4 +265,3 @@ def ui(): single_image_checkbox.change(lambda x: params.update({"add_all_images_to_prompt": x}), single_image_checkbox, None) shared.gradio['Generate'].click(lambda: None, None, picture_select) shared.gradio['textbox'].submit(lambda: None, None, picture_select) - shared.gradio['Remove last'].click(lambda: input_hijack.update({"state": True, "value": fix_picture_after_remove_last}), None, None) From 435f8cc0e75c7566036644fbb6b9885b10533f61 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 00:47:40 -0300 Subject: [PATCH 06/19] Simplify some chat functions --- modules/chat.py | 88 +++++++++++++++++++++++++++---------------------- 1 file changed, 49 insertions(+), 39 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index c4703236..6801741a 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -10,7 +10,6 @@ from pathlib import Path import yaml from PIL import Image -import modules.extensions as extensions_module import modules.shared as shared from modules.extensions import apply_extensions from modules.html_generator import chat_html_wrapper, make_thumbnail @@ -30,8 +29,8 @@ def generate_chat_prompt(user_input, state, **kwargs): chat_prompt_size = state['chat_prompt_size'] if shared.soft_prompt: chat_prompt_size -= shared.soft_prompt_tensor.shape[1] - max_length = min(get_max_prompt_length(state), chat_prompt_size) + max_length = min(get_max_prompt_length(state), chat_prompt_size) if is_instruct: prefix1 = f"{state['name1']}\n" prefix2 = f"{state['name2']}\n" @@ -57,7 +56,6 @@ def generate_chat_prompt(user_input, state, **kwargs): min_rows = 2 rows.append(f"{prefix1.strip() if not is_instruct else prefix1}") elif not _continue: - # Adding the user message if len(user_input) > 0: this_prefix1 = prefix1.replace('<|round|>', f'{len(shared.history["internal"])}') # for ChatGLM @@ -68,8 +66,8 @@ def generate_chat_prompt(user_input, state, **kwargs): while len(rows) > min_rows and len(encode(''.join(rows))[0]) >= max_length: rows.pop(1) - prompt = ''.join(rows) + prompt = ''.join(rows) if also_return_rows: return prompt, rows else: @@ -81,6 +79,7 @@ def get_stopping_strings(state): stopping_strings = [f"\n{state['name1']}", f"\n{state['name2']}"] else: stopping_strings = [f"\n{state['name1']}:", f"\n{state['name2']}:"] + stopping_strings += 
ast.literal_eval(f"[{state['custom_stopping_strings']}]") return stopping_strings @@ -111,13 +110,13 @@ def extract_message_from_reply(reply, state): break else: continue + break return reply, next_character_found def chatbot_wrapper(text, state, regenerate=False, _continue=False): - if shared.model_name == 'None' or shared.model is None: print("No model is loaded! Select one in the Model tab.") yield shared.history['visible'] @@ -125,18 +124,30 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): # Defining some variables cumulative_reply = '' - last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] if _continue else None just_started = True visible_text = None eos_token = '\n' if state['stop_at_newline'] else None stopping_strings = get_stopping_strings(state) - text, visible_text = apply_extensions('input_hijack', text, visible_text) + # Preparing the input + if not any((regenerate, _continue)): + text, visible_text = apply_extensions('input_hijack', text, visible_text) + if visible_text is None: + visible_text = text - if visible_text is None: - visible_text = text - if not _continue: - text = apply_extensions("input", text) + text = apply_extensions('input', text) + # *Is typing...* + yield shared.history['visible'] + [[visible_text, shared.processing_message]] + else: + text, visible_text = shared.history['internal'][-1][0], shared.history['visible'][-1][0] + if regenerate: + shared.history['visible'].pop() + shared.history['internal'].pop() + # *Is typing...* + yield shared.history['visible'] + [[visible_text, shared.processing_message]] + elif _continue: + last_reply = [shared.history['internal'][-1][1], shared.history['visible'][-1][1]] + yield shared.history['visible'][:-1] + [[visible_text, last_reply[1] + '...']] # Generating the prompt kwargs = {'_continue': _continue} @@ -144,10 +155,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): if prompt is None: prompt = generate_chat_prompt(text, state, **kwargs) - # Yield *Is typing...* - if not any((regenerate, _continue)): - yield shared.history['visible'] + [[visible_text, shared.processing_message]] - # Generate for i in range(state['chat_generation_attempts']): reply = None @@ -158,26 +165,26 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): reply, next_character_found = extract_message_from_reply(reply, state) visible_reply = re.sub("(||{{user}})", state['name1'], reply) visible_reply = apply_extensions("output", visible_reply) + if _continue: + sep = ' ' if last_reply[0][-1] not in [' ', '\n'] else '' + reply = last_reply[0] + sep + reply + sep = ' ' if last_reply[1][-1] not in [' ', '\n'] else '' + visible_reply = last_reply[1] + sep + visible_reply # We need this global variable to handle the Stop event, # otherwise gradio gets confused if shared.stop_everything: return shared.history['visible'] + if just_started: just_started = False if not _continue: shared.history['internal'].append(['', '']) shared.history['visible'].append(['', '']) - if _continue: - sep = list(map(lambda x: ' ' if len(x) > 0 and x[-1] != ' ' else '', last_reply)) - shared.history['internal'][-1] = [text, f'{last_reply[0]}{sep[0]}{reply}'] - shared.history['visible'][-1] = [visible_text, f'{last_reply[1]}{sep[1]}{visible_reply}'] - else: - shared.history['internal'][-1] = [text, reply] - shared.history['visible'][-1] = [visible_text, visible_reply] - if not shared.args.no_stream: - yield shared.history['visible'] + shared.history['internal'][-1] = [text, reply] + 
shared.history['visible'][-1] = [visible_text, visible_reply] + yield shared.history['visible'] if next_character_found: break @@ -188,7 +195,6 @@ def chatbot_wrapper(text, state, regenerate=False, _continue=False): def impersonate_wrapper(text, state): - if shared.model_name == 'None' or shared.model is None: print("No model is loaded! Select one in the Model tab.") yield '' @@ -202,7 +208,6 @@ def impersonate_wrapper(text, state): # Yield *Is typing...* yield shared.processing_message - for i in range(state['chat_generation_attempts']): reply = None for reply in generate_reply(f"{prompt}{' ' if len(cumulative_reply) > 0 else ''}{cumulative_reply}", state, eos_token=eos_token, stopping_strings=stopping_strings): @@ -227,23 +232,16 @@ def regenerate_wrapper(text, state): if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode']) else: - last_visible = shared.history['visible'].pop() - last_internal = shared.history['internal'].pop() - # Yield '*Is typing...*' - yield chat_html_wrapper(shared.history['visible'] + [[last_visible[0], shared.processing_message]], state['name1'], state['name2'], state['mode']) - for history in chatbot_wrapper(last_internal[0], state, regenerate=True): - shared.history['visible'][-1] = [last_visible[0], history[-1][1]] - yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode']) + for history in chatbot_wrapper('', state, regenerate=True): + yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode']) def continue_wrapper(text, state): if (len(shared.history['visible']) == 1 and not shared.history['visible'][0][0]) or len(shared.history['internal']) == 0: yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode']) else: - # Yield ' ...' 
- yield chat_html_wrapper(shared.history['visible'][:-1] + [[shared.history['visible'][-1][0], shared.history['visible'][-1][1] + ' ...']], state['name1'], state['name2'], state['mode']) - for history in chatbot_wrapper(shared.history['internal'][-1][0], state, _continue=True): - yield chat_html_wrapper(shared.history['visible'], state['name1'], state['name2'], state['mode']) + for history in chatbot_wrapper('', state, _continue=True): + yield chat_html_wrapper(history, state['name1'], state['name2'], state['mode']) def remove_last_message(name1, name2, mode): @@ -281,6 +279,7 @@ def send_dummy_reply(text, name1, name2, mode): if len(shared.history['visible']) > 0 and not shared.history['visible'][-1][1] == '': shared.history['visible'].append(['', '']) shared.history['internal'].append(['', '']) + shared.history['visible'][-1][1] = text shared.history['internal'][-1][1] = apply_extensions("input", text) return chat_html_wrapper(shared.history['visible'], name1, name2, mode) @@ -300,7 +299,6 @@ def clear_chat_log(name1, name2, greeting, mode): # Save cleared logs save_history(mode) - return chat_html_wrapper(shared.history['visible'], name1, name2, mode) @@ -321,8 +319,8 @@ def tokenize_dialogue(dialogue, name1, name2, mode): for i in range(len(idx) - 1): messages.append(dialogue[idx[i]:idx[i + 1]].strip()) - messages.append(dialogue[idx[-1]:].strip()) + messages.append(dialogue[idx[-1]:].strip()) entry = ['', ''] for i in messages: if i.startswith(f'{name1}:'): @@ -331,6 +329,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode): entry[1] = i[len(f'{name2}:'):].strip() if not (len(entry[0]) == 0 and len(entry[1]) == 0): history.append(entry) + entry = ['', ''] print("\033[1;32;1m\nDialogue tokenized to:\033[0;37;0m\n", end='') @@ -339,6 +338,7 @@ def tokenize_dialogue(dialogue, name1, name2, mode): print("\n") for line in column.strip().split('\n'): print("| " + line + "\n") + print("|\n") print("------------------------------") @@ -351,14 +351,17 @@ def save_history(mode, timestamp=False): if mode == 'instruct': if not timestamp: return + fname = f"Instruct_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" else: if timestamp: fname = f"{shared.character}_{datetime.now().strftime('%Y%m%d-%H%M%S')}.json" else: fname = f"{shared.character}_persistent.json" + if not Path('logs').exists(): Path('logs').mkdir() + with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f: f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) @@ -389,8 +392,10 @@ def build_pygmalion_style_context(data): context = "" if 'char_persona' in data and data['char_persona'] != '': context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" + if 'world_scenario' in data and data['world_scenario'] != '': context += f"Scenario: {data['world_scenario']}\n" + context = f"{context.strip()}\n\n" return context @@ -405,6 +410,7 @@ def generate_pfp_cache(character): img = make_thumbnail(Image.open(path)) img.save(Path('cache/pfp_character.png'), format='PNG') return img + return None @@ -488,13 +494,17 @@ def upload_character(json_file, img, tavern=False): while Path(f'characters/{outfile_name}.json').exists(): outfile_name = f'{data["char_name"]}_{i:03d}' i += 1 + if tavern: outfile_name = f'TavernAI-{outfile_name}' + with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f: f.write(json_file) + if img is not None: img = Image.open(io.BytesIO(img)) img.save(Path(f'characters/{outfile_name}.png')) + print(f'New character saved to 
"characters/{outfile_name}.json".') return outfile_name From 47809e28aa015b6550833250ee9c97f531cd4378 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 01:04:48 -0300 Subject: [PATCH 07/19] Minor changes --- server.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/server.py b/server.py index ca44cdb5..573552ff 100644 --- a/server.py +++ b/server.py @@ -806,7 +806,6 @@ def create_interface(): # notebook/default modes event handlers else: shared.input_params = [shared.gradio[k] for k in ['textbox', 'interface_state']] - if shared.args.notebook: output_params = [shared.gradio[k] for k in ['textbox', 'markdown', 'html']] else: @@ -908,7 +907,6 @@ if __name__ == "__main__": # If any model has been selected, load it if shared.model_name != 'None': - model_settings = get_model_specific_settings(shared.model_name) shared.settings.update(model_settings) # hijacking the interface defaults update_model_parameters(model_settings, initial=True) # hijacking the command-line arguments From b1ee674d75fca92c638e967118addb796a34a9b0 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 03:05:47 -0300 Subject: [PATCH 08/19] Make interface state (mostly) persistent on page reload --- modules/shared.py | 3 +++ modules/ui.py | 24 ++++++++++++++++++++---- server.py | 6 +++++- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/modules/shared.py b/modules/shared.py index 82acf3c0..6b0c6f06 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -20,6 +20,9 @@ processing_message = '*Is typing...*' # UI elements (buttons, sliders, HTML, etc) gradio = {} +# For keeping the values of UI elements on page reload +persistent_interface_state = {} + # Generation input parameters input_params = [] diff --git a/modules/ui.py b/modules/ui.py index 5db36b3e..0ddcc833 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -33,9 +33,10 @@ def list_model_elements(): def list_interface_input_elements(chat=False): - elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens'] + elements = ['max_new_tokens', 'seed', 'temperature', 'top_p', 'top_k', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'no_repeat_ngram_size', 'min_length', 'do_sample', 'penalty_alpha', 'num_beams', 'length_penalty', 'early_stopping', 'add_bos_token', 'ban_eos_token', 'truncation_length', 'custom_stopping_strings', 'skip_special_tokens', 'preset_menu'] if chat: - elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template'] + elements += ['name1', 'name2', 'greeting', 'context', 'end_of_turn', 'chat_prompt_size', 'chat_generation_attempts', 'stop_at_newline', 'mode', 'instruction_template', 'character_menu'] + elements += list_model_elements() return elements @@ -44,11 +45,26 @@ def gather_interface_values(*args): output = {} for i, element in enumerate(shared.input_elements): output[element] = args[i] + + shared.persistent_interface_state = output return output -def apply_interface_values(state): - return [state[i] for i in list_interface_input_elements(chat=shared.is_chat())] +def apply_interface_values(state, 
use_persistent=False): + if use_persistent: + state = shared.persistent_interface_state + + elements = list_interface_input_elements(chat=shared.is_chat()) + if len(state) == 0: + return [gr.update() for k in elements] # Dummy, do nothing + else: + if use_persistent and 'mode' in state: + if state['mode'] == 'instruct': + return [state[k] if k not in ['character_menu'] else gr.update() for k in elements] + else: + return [state[k] if k not in ['instruction_template'] else gr.update() for k in elements] + else: + return [state[k] for k in elements] class ToolButton(gr.Button, gr.components.FormComponent): diff --git a/server.py b/server.py index 573552ff..8308c9b8 100644 --- a/server.py +++ b/server.py @@ -32,6 +32,7 @@ import time import traceback import zipfile from datetime import datetime +from functools import partial from pathlib import Path import psutil @@ -846,6 +847,8 @@ def create_interface(): shared.gradio['count_tokens'].click(count_tokens, shared.gradio['textbox'], shared.gradio['status'], show_progress=False) shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") + shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False) + # Launch the interface shared.gradio['interface'].queue() if shared.args.listen: @@ -855,7 +858,6 @@ def create_interface(): if __name__ == "__main__": - # Loading custom settings settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): @@ -900,9 +902,11 @@ if __name__ == "__main__": print('The following models are available:\n') for i, model in enumerate(available_models): print(f'{i+1}. {model}') + print(f'\nWhich one do you want to load? 1-{len(available_models)}\n') i = int(input()) - 1 print() + shared.model_name = available_models[i] # If any model has been selected, load it From caaa5561593d645bf1676eca9db57df073aeada1 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 03:30:35 -0300 Subject: [PATCH 09/19] Move extensions block definition to the bottom --- server.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/server.py b/server.py index 8308c9b8..da786349 100644 --- a/server.py +++ b/server.py @@ -711,10 +711,6 @@ def create_interface(): set_interface_arguments, [shared.gradio[k] for k in ['interface_modes_menu', 'extensions_menu', 'bool_menu']], None).then( lambda: None, None, None, _js='() => {document.body.innerHTML=\'

Reloading...

\'; setTimeout(function(){location.reload()},2500); return []}') - # Extensions block - if shared.args.extensions is not None: - extensions_module.create_extensions_block() - # chat mode event handlers if shared.is_chat(): shared.input_params = [shared.gradio[k] for k in ['Chat input', 'interface_state']] @@ -848,6 +844,9 @@ def create_interface(): shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js}}}") shared.gradio['interface'].load(partial(ui.apply_interface_values, {}, use_persistent=True), None, [shared.gradio[k] for k in ui.list_interface_input_elements(chat=shared.is_chat())], show_progress=False) + # Extensions block + if shared.args.extensions is not None: + extensions_module.create_extensions_block() # Launch the interface shared.gradio['interface'].queue() From 2f6e2ddeac28090c06c9c2cc14bce971dba70ffd Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 03:42:03 -0300 Subject: [PATCH 10/19] Bump llama-cpp-python version --- requirements.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/requirements.txt b/requirements.txt index e5f0a8f7..2ee5274e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,5 +16,5 @@ tqdm git+https://github.com/huggingface/peft transformers==4.28.1 bitsandbytes==0.38.1; platform_system != "Windows" -llama-cpp-python==0.1.34; platform_system != "Windows" -https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.34/llama_cpp_python-0.1.34-cp310-cp310-win_amd64.whl; platform_system == "Windows" +llama-cpp-python==0.1.36; platform_system != "Windows" +https://github.com/abetlen/llama-cpp-python/releases/download/v0.1.36/llama_cpp_python-0.1.36-cp310-cp310-win_amd64.whl; platform_system == "Windows" From 78d1977ebf9ffaa8c973ba6747985c5bf8342fbf Mon Sep 17 00:00:00 2001 From: eiery <19350831+eiery@users.noreply.github.com> Date: Mon, 24 Apr 2023 02:46:18 -0400 Subject: [PATCH 11/19] add n_batch support for llama.cpp (#1115) --- README.md | 1 + modules/llamacpp_model_alternative.py | 3 ++- modules/shared.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index bea64666..d29c85a9 100644 --- a/README.md +++ b/README.md @@ -220,6 +220,7 @@ Optionally, you can use the following command-line flags: | Flag | Description | |-------------|-------------| | `--threads` | Number of threads to use in llama.cpp. | +| `--n_batch` | Processing batch size for llama.cpp. 
| #### GPTQ diff --git a/modules/llamacpp_model_alternative.py b/modules/llamacpp_model_alternative.py index 6bdf9bc3..2671f227 100644 --- a/modules/llamacpp_model_alternative.py +++ b/modules/llamacpp_model_alternative.py @@ -24,7 +24,8 @@ class LlamaCppModel: 'model_path': str(path), 'n_ctx': 2048, 'seed': 0, - 'n_threads': shared.args.threads or None + 'n_threads': shared.args.threads or None, + 'n_batch': shared.args.n_batch } self.model = Llama(**params) self.model.set_cache(LlamaCache) diff --git a/modules/shared.py b/modules/shared.py index 6b0c6f06..9a24f220 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -119,6 +119,7 @@ parser.add_argument('--trust-remote-code', action='store_true', help="Set trust_ # llama.cpp parser.add_argument('--threads', type=int, default=0, help='Number of threads to use in llama.cpp.') +parser.add_argument('--n_batch', type=int, default=8, help='Processing batch size for llama.cpp.') # GPTQ parser.add_argument('--wbits', type=int, default=0, help='Load a pre-quantized model with specified precision in bits. 2, 3, 4 and 8 are supported.') From c86e9a3372c078bfae9bee8dfc147583b3f30b8c Mon Sep 17 00:00:00 2001 From: MajdajkD Date: Mon, 24 Apr 2023 08:51:32 +0200 Subject: [PATCH 12/19] fix websocket batching (#1511) --- extensions/api/streaming_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py index 5ffd925b..c06facc5 100644 --- a/extensions/api/streaming_api.py +++ b/extensions/api/streaming_api.py @@ -44,6 +44,8 @@ async def _handle_connection(websocket, path): 'text': to_send })) + await asyncio.sleep(0) + skip_index += len(to_send) message_num += 1 From 0c32ae27cc7712c03953e5583a2abe35ea938b0a Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 11:50:51 -0300 Subject: [PATCH 13/19] Only load the default history if it's empty --- modules/chat.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/chat.py b/modules/chat.py index 6801741a..863353d8 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -483,7 +483,8 @@ def load_character(character, name1, name2, mode): def load_default_history(name1, name2): - load_character("None", name1, name2, "chat") + if len(shared.history['visible']) == 0 and len(shared.history['internal']) == 0: + load_character("None", name1, name2, "chat") def upload_character(json_file, img, tavern=False): From b6af2e56a230458638726d1eeb2f1c427ee5773f Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 13:19:42 -0300 Subject: [PATCH 14/19] Add --character flag, add character to settings.json --- README.md | 1 + modules/shared.py | 2 ++ modules/ui.py | 6 +++--- server.py | 16 ++++++++++------ settings-template.json | 1 + 5 files changed, 17 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index d29c85a9..3b0c6936 100644 --- a/README.md +++ b/README.md @@ -188,6 +188,7 @@ Optionally, you can use the following command-line flags: | `-h`, `--help` | Show this help message and exit. | | `--notebook` | Launch the web UI in notebook mode, where the output is written to the same text box as the input. | | `--chat` | Launch the web UI in chat mode. | +| `--character CHARACTER` | The name of the character to load in chat mode by default. | | `--model MODEL` | Name of the model to load by default. | | `--lora LORA` | Name of the LoRA to apply to the model by default. 
| | `--model-dir MODEL_DIR` | Path to directory with all the models. | diff --git a/modules/shared.py b/modules/shared.py index 9a24f220..a3b867de 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -34,6 +34,7 @@ settings = { 'max_new_tokens_min': 1, 'max_new_tokens_max': 2000, 'seed': -1, + 'character': 'None', 'name1': 'You', 'name2': 'Assistant', 'context': 'This is a conversation with your Assistant. The Assistant is very helpful and is eager to chat with you and answer your questions.', @@ -93,6 +94,7 @@ parser = argparse.ArgumentParser(formatter_class=lambda prog: argparse.HelpForma parser.add_argument('--notebook', action='store_true', help='Launch the web UI in notebook mode, where the output is written to the same text box as the input.') parser.add_argument('--chat', action='store_true', help='Launch the web UI in chat mode with a style similar to the Character.AI website.') parser.add_argument('--cai-chat', action='store_true', help='DEPRECATED: use --chat instead.') +parser.add_argument('--character', type=str, help='The name of the character to load in chat mode by default.') parser.add_argument('--model', type=str, help='Name of the model to load by default.') parser.add_argument('--lora', type=str, help='Name of the LoRA to apply to the model by default.') parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") diff --git a/modules/ui.py b/modules/ui.py index 0ddcc833..0d62ab3c 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -60,11 +60,11 @@ def apply_interface_values(state, use_persistent=False): else: if use_persistent and 'mode' in state: if state['mode'] == 'instruct': - return [state[k] if k not in ['character_menu'] else gr.update() for k in elements] + return [state[k] if (k not in ['character_menu'] and k in state) else gr.update() for k in elements] else: - return [state[k] if k not in ['instruction_template'] else gr.update() for k in elements] + return [state[k] if (k not in ['instruction_template'] and k in state) else gr.update() for k in elements] else: - return [state[k] for k in elements] + return [state[k] if k in state else gr.update() for k in elements] class ToolButton(gr.Button, gr.components.FormComponent): diff --git a/server.py b/server.py index da786349..dc804a98 100644 --- a/server.py +++ b/server.py @@ -544,7 +544,7 @@ def create_interface(): shared.gradio['Clear history-cancel'] = gr.Button('Cancel', visible=False) shared.gradio['mode'] = gr.Radio(choices=['cai-chat', 'chat', 'instruct'], value=shared.settings['mode'], label='Mode') - shared.gradio['instruction_template'] = gr.Dropdown(choices=get_available_instruction_templates(), label='Instruction template', value=shared.settings['instruction_template'], visible=shared.settings['mode'] == 'instruct', info='Change this according to the model/LoRA that you are using.') + shared.gradio['instruction_template'] = gr.Dropdown(choices=get_available_instruction_templates(), label='Instruction template', value='None', visible=shared.settings['mode'] == 'instruct', info='Change this according to the model/LoRA that you are using.') with gr.Tab('Character', elem_id='chat-settings'): with gr.Row(): @@ -560,7 +560,7 @@ def create_interface(): shared.gradio['your_picture'] = gr.Image(label='Your picture', type='pil', value=Image.open(Path('cache/pfp_me.png')) if Path('cache/pfp_me.png').exists() else None) with gr.Row(): - shared.gradio['character_menu'] = gr.Dropdown(choices=get_available_characters(), value='None', label='Character', 
elem_id='character-menu') + shared.gradio['character_menu'] = gr.Dropdown(choices=get_available_characters(), label='Character', elem_id='character-menu') ui.create_refresh_button(shared.gradio['character_menu'], lambda: None, lambda: {'choices': get_available_characters()}, 'refresh-button') with gr.Row(): @@ -794,11 +794,7 @@ def create_interface(): shared.gradio['character_menu'].change(chat.load_character, [shared.gradio[k] for k in ['character_menu', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) shared.gradio['upload_img_tavern'].upload(chat.upload_tavern_character, [shared.gradio['upload_img_tavern'], shared.gradio['name1'], shared.gradio['name2']], [shared.gradio['character_menu']]) shared.gradio['your_picture'].change(chat.upload_your_profile_picture, [shared.gradio[k] for k in ['your_picture', 'name1', 'name2', 'mode']], shared.gradio['display']) - shared.gradio['interface'].load(None, None, None, _js=f"() => {{{ui.main_js+ui.chat_js}}}") - shared.gradio['interface'].load(chat.load_character, [shared.gradio[k] for k in ['instruction_template', 'name1', 'name2', 'mode']], [shared.gradio[k] for k in ['name1', 'name2', 'character_picture', 'greeting', 'context', 'end_of_turn', 'display']]) - shared.gradio['interface'].load(chat.load_default_history, [shared.gradio[k] for k in ['name1', 'name2']], None) - shared.gradio['interface'].load(chat.redraw_html, reload_inputs, shared.gradio['display'], show_progress=True) # notebook/default modes event handlers else: @@ -919,6 +915,14 @@ if __name__ == "__main__": if shared.args.lora: add_lora_to_model([shared.args.lora]) + # Force a character to be loaded + if shared.is_chat(): + shared.persistent_interface_state.update({ + 'mode': shared.settings['mode'], + 'character_menu': shared.args.character or shared.settings['character'], + 'instruction_template': shared.settings['instruction_template'] + }) + # Launch the web UI create_interface() while True: diff --git a/settings-template.json b/settings-template.json index 286add0b..55032aa9 100644 --- a/settings-template.json +++ b/settings-template.json @@ -3,6 +3,7 @@ "max_new_tokens_min": 1, "max_new_tokens_max": 2000, "seed": -1, + "character": "None", "name1": "You", "name2": "Assistant", "context": "This is a conversation with your Assistant. 
The Assistant is very helpful and is eager to chat with you and answer your questions.", From 2f4f1241324df9666b11d1699194c163fe162c72 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 13:27:24 -0300 Subject: [PATCH 15/19] Remove obsolete function --- modules/chat.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/modules/chat.py b/modules/chat.py index 863353d8..efdc0de8 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -482,11 +482,6 @@ def load_character(character, name1, name2, mode): return name1, name2, picture, greeting, context, end_of_turn, chat_html_wrapper(shared.history['visible'], name1, name2, mode) -def load_default_history(name1, name2): - if len(shared.history['visible']) == 0 and len(shared.history['internal']) == 0: - load_character("None", name1, name2, "chat") - - def upload_character(json_file, img, tavern=False): json_file = json_file if type(json_file) == str else json_file.decode('utf-8') data = json.loads(json_file) From 1a0c12c6f203d28b865346a767cd30720737a5ca Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 24 Apr 2023 19:24:12 -0300 Subject: [PATCH 16/19] Refactor text-generation.py a bit --- modules/text_generation.py | 79 ++++++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 28 deletions(-) diff --git a/modules/text_generation.py b/modules/text_generation.py index 032fc84c..936ec647 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -113,9 +113,11 @@ def set_manual_seed(seed): seed = int(seed) if seed == -1: seed = random.randint(1, 2**31) + torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + return seed @@ -123,8 +125,41 @@ def stop_everything_event(): shared.stop_everything = True -def generate_reply(question, state, eos_token=None, stopping_strings=[]): +def get_generate_params(state): + generate_params = {} + # Models that are not on transformers + if shared.model_type in ['rwkv', 'llamacpp']: + generate_params['token_count'] = state['max_new_tokens'] + for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']: + generate_params[k] = state[k] + else: + # FlexGen + if shared.args.flexgen: + for k in ['max_new_tokens', 'do_sample', 'temperature']: + generate_params[k] = state[k] + + if not shared.args.no_stream: + generate_params['max_new_tokens'] = 8 + + # transformers + else: + for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']: + generate_params[k] = state[k] + + if state['ban_eos_token']: + generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id] + + if shared.args.no_cache: + generate_params.update({'use_cache': False}) + + if shared.args.deepspeed: + generate_params.update({'synced_gpus': True}) + + return generate_params + + +def generate_reply(question, state, eos_token=None, stopping_strings=[]): if shared.model_name == 'None' or shared.model is None: print("No model is loaded! 
         print("No model is loaded! Select one in the Model tab.")
         yield formatted_outputs(question, shared.model_name)

@@ -133,40 +168,37 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
     clear_torch_cache()
     seed = set_manual_seed(state['seed'])
     shared.stop_everything = False
-    generate_params = {}
+    generate_params = get_generate_params(state)
     t0 = time.time()

+    # Preparing the input
     original_question = question
     if not shared.is_chat():
         question = apply_extensions('input', question)

-    # These models are not part of Hugging Face, so we handle them
-    # separately and terminate the function call earlier
+    # If the model is not on transformers, handle it separately and end this
+    # function call earlier.
     if shared.model_type in ['rwkv', 'llamacpp']:
-
         if shared.args.verbose:
             print(f'\n\n{question}\n--------------------\n')

-        for k in ['temperature', 'top_p', 'top_k', 'repetition_penalty']:
-            generate_params[k] = state[k]
-        generate_params['token_count'] = state['max_new_tokens']
         try:
             if shared.args.no_stream:
                 reply = shared.model.generate(context=question, **generate_params)
                 output = original_question + reply
                 if not shared.is_chat():
                     reply = original_question + apply_extensions('output', reply)
+
                 yield formatted_outputs(reply, shared.model_name)
             else:
                 if not shared.is_chat():
                     yield formatted_outputs(question, shared.model_name)

-                # RWKV has proper streaming, which is very nice.
-                # No need to generate 8 tokens at a time.
                 for reply in shared.model.generate_with_streaming(context=question, **generate_params):
                     output = original_question + reply
                     if not shared.is_chat():
                         reply = original_question + apply_extensions('output', reply)
+
                     yield formatted_outputs(reply, shared.model_name)

         except Exception:
@@ -178,18 +210,19 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             print(f'Output generated in {(t1-t0):.2f} seconds ({new_tokens/(t1-t0):.2f} tokens/s, {new_tokens} tokens, context {original_tokens}, seed {seed})')
         return

+    # Encode the input
     input_ids = encode(question, add_bos_token=state['add_bos_token'], truncation_length=get_max_prompt_length(state))
     output = input_ids[0]
-
+    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
     if shared.args.verbose:
         print(f'\n\n{decode(input_ids[0], state["skip_special_tokens"])}\n--------------------\n')

-    cuda = not any((shared.args.cpu, shared.args.deepspeed, shared.args.flexgen))
+    # Find the eos tokens
     eos_token_ids = [shared.tokenizer.eos_token_id] if shared.tokenizer.eos_token_id is not None else []
     if eos_token is not None:
         eos_token_ids.append(int(encode(eos_token)[0][-1]))

-    # Handling the stopping strings
+    # Create the StoppingCriteriaList with the stopping strings
     stopping_criteria_list = transformers.StoppingCriteriaList()
     for st in (stopping_strings, ast.literal_eval(f"[{state['custom_stopping_strings']}]")):
         if type(st) is list and len(st) > 0:
@@ -197,24 +230,14 @@ def generate_reply(question, state, eos_token=None, stopping_strings=[]):
             stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=sentinel_token_ids, starting_idx=len(input_ids[0])))
             break

-    if not shared.args.flexgen:
-        for k in ['max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p', 'repetition_penalty', 'encoder_repetition_penalty', 'top_k', 'min_length', 'no_repeat_ngram_size', 'num_beams', 'penalty_alpha', 'length_penalty', 'early_stopping']:
-            generate_params[k] = state[k]
+    # Update generate_params with the eos token and the stopping strings
+    if shared.args.flexgen:
+        generate_params['stop'] = eos_token_ids[-1]
+    else:
         generate_params['eos_token_id'] = eos_token_ids
         generate_params['stopping_criteria'] = stopping_criteria_list
-        if state['ban_eos_token']:
-            generate_params['suppress_tokens'] = [shared.tokenizer.eos_token_id]
-    else:
-        for k in ['max_new_tokens', 'do_sample', 'temperature']:
-            generate_params[k] = state[k]
-        generate_params['stop'] = eos_token_ids[-1]
-        if not shared.args.no_stream:
-            generate_params['max_new_tokens'] = 8
-    if shared.args.no_cache:
-        generate_params.update({'use_cache': False})
-    if shared.args.deepspeed:
-        generate_params.update({'synced_gpus': True})

+    # Add the encoded tokens to generate_params
     if shared.soft_prompt:
         inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids)
         question, filler_input_ids, inputs_embeds = apply_extensions('tokenizer', state, question, filler_input_ids, inputs_embeds)
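The refactor above pulls all backend-specific parameter plumbing out of `generate_reply()` and into a single `get_generate_params(state)` call, so the RWKV/llama.cpp, FlexGen, and transformers paths each receive only the keyword arguments they understand. A condensed, self-contained sketch of that dispatch follows; it is an assumption-laden simplification that replaces the `shared.*` globals with plain arguments and omits several sampler keys handled by the real function.

```python
# Simplified sketch of the get_generate_params() dispatch introduced above.
# model_type, flexgen and streaming stand in for the shared.* globals consulted
# by the real function; state mirrors the UI state dictionary.
def get_generate_params(state, model_type, flexgen=False, streaming=True):
    if model_type in ('rwkv', 'llamacpp'):
        # Non-transformers backends take a small, fixed set of options.
        params = {k: state[k] for k in ('temperature', 'top_p', 'top_k', 'repetition_penalty')}
        params['token_count'] = state['max_new_tokens']
        return params

    if flexgen:
        params = {k: state[k] for k in ('max_new_tokens', 'do_sample', 'temperature')}
        if streaming:
            # FlexGen streams by repeatedly generating short chunks.
            params['max_new_tokens'] = 8
        return params

    # Plain transformers path: pass the sampler configuration straight through.
    keys = ('max_new_tokens', 'do_sample', 'temperature', 'top_p', 'typical_p',
            'repetition_penalty', 'top_k', 'num_beams')
    return {k: state[k] for k in keys}
```

With the helper in place, `generate_reply()` builds the dictionary once and hands it to whichever backend is active.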
From b0ce750d4ecb8a9cf46f6319a143bc21c8b6d087 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 25 Apr 2023 00:10:21 -0300
Subject: [PATCH 17/19] Add spaces

---
 modules/extensions.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/modules/extensions.py b/modules/extensions.py
index 92d86772..547fe3f4 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -23,12 +23,14 @@ def load_extensions():
             if extension not in setup_called and hasattr(extension, "setup"):
                 setup_called.add(extension)
                 extension.setup()
+
             state[name] = [True, i]
             if name != 'api':
                 print('Ok.')
         except:
             if name != 'api':
                 print('Fail.')
+
             traceback.print_exc()
@@ -44,6 +46,7 @@ def _apply_string_extensions(function_name, text):
     for extension, _ in iterator():
         if hasattr(extension, function_name):
             text = getattr(extension, function_name)(text)
+
     return text
@@ -56,6 +59,7 @@ def _apply_input_hijack(text, visible_text):
             text, visible_text = extension.input_hijack['value'](text, visible_text)
         else:
             text, visible_text = extension.input_hijack['value']
+
     return text, visible_text
@@ -65,8 +69,10 @@ def _apply_custom_generate_chat_prompt(text, state, **kwargs):
     for extension, _ in iterator():
         if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
            custom_generate_chat_prompt = extension.custom_generate_chat_prompt
+
     if custom_generate_chat_prompt is not None:
         return custom_generate_chat_prompt(text, state, **kwargs)
+
     return None
@@ -75,6 +81,7 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e
     for extension, _ in iterator():
         if hasattr(extension, function_name):
             prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
+
     return prompt, input_ids, input_embeds
@@ -91,6 +98,7 @@ EXTENSION_MAP = {
 def apply_extensions(typ, *args, **kwargs):
     if typ not in EXTENSION_MAP:
         raise ValueError(f"Invalid extension type {typ}")
+
     return EXTENSION_MAP[typ](*args, **kwargs)
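The context lines above also show the dispatch table behind the `apply_extensions('input', ...)`, `apply_extensions('output', ...)` and `apply_extensions('tokenizer', ...)` calls seen in the text_generation.py patch: the type string is looked up in `EXTENSION_MAP` and forwarded to the matching `_apply_*` helper, which pipes the value through every loaded extension that defines the corresponding hook. A stripped-down sketch of the pattern follows; the hook names and the `loaded_extensions` list are illustrative stand-ins, not the module's exact identifiers.

```python
# Stripped-down sketch of the dispatch-table pattern behind apply_extensions().
loaded_extensions = []  # extension modules that may define optional hook functions


def _apply_string_extensions(hook_name, text):
    # Pipe the string through every extension that implements the hook.
    for extension in loaded_extensions:
        if hasattr(extension, hook_name):
            text = getattr(extension, hook_name)(text)

    return text


EXTENSION_MAP = {
    'input': lambda text: _apply_string_extensions('input_modifier', text),
    'output': lambda text: _apply_string_extensions('output_modifier', text),
}


def apply_extensions(typ, *args, **kwargs):
    if typ not in EXTENSION_MAP:
        raise ValueError(f"Invalid extension type {typ}")

    return EXTENSION_MAP[typ](*args, **kwargs)
```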
From ebca3f86d55400f359550490aedad0b4082024bc Mon Sep 17 00:00:00 2001
From: da3dsoul
Date: Mon, 24 Apr 2023 23:23:11 -0400
Subject: [PATCH 18/19] Apply the settings for extensions after import, but
 before setup() (#1484)

---
 modules/extensions.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/modules/extensions.py b/modules/extensions.py
index 547fe3f4..731fe089 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -11,6 +11,18 @@ available_extensions = []
 setup_called = set()


+def apply_settings(extension, name):
+    if not hasattr(extension, 'params'):
+        return
+
+    for param in extension.params:
+        _id = f"{name}-{param}"
+        if _id not in shared.settings:
+            continue
+
+        extension.params[param] = shared.settings[_id]
+
+
 def load_extensions():
     global state, setup_called
     for i, name in enumerate(shared.args.extensions):
@@ -22,6 +34,7 @@ def load_extensions():
             extension = getattr(extensions, name).script
             if extension not in setup_called and hasattr(extension, "setup"):
                 setup_called.add(extension)
+                apply_settings(extension, name)
                 extension.setup()

             state[name] = [True, i]
@@ -105,18 +118,11 @@ def apply_extensions(typ, *args, **kwargs):
 def create_extensions_block():
     global setup_called

-    # Updating the default values
-    for extension, name in iterator():
-        if hasattr(extension, 'params'):
-            for param in extension.params:
-                _id = f"{name}-{param}"
-                if _id in shared.settings:
-                    extension.params[param] = shared.settings[_id]
-
     should_display_ui = False
     for extension, name in iterator():
         if hasattr(extension, "ui"):
             should_display_ui = True
+            break

     # Creating the extension ui elements
     if should_display_ui:

From da812600f49309cd7090464d7ffcfa1246d3a309 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Tue, 25 Apr 2023 01:16:23 -0300
Subject: [PATCH 19/19] Apply settings regardless of setup() function

---
 modules/extensions.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/extensions.py b/modules/extensions.py
index 731fe089..e07e90ac 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -32,9 +32,9 @@ def load_extensions():
         try:
             exec(f"import extensions.{name}.script")
             extension = getattr(extensions, name).script
+            apply_settings(extension, name)
             if extension not in setup_called and hasattr(extension, "setup"):
                 setup_called.add(extension)
-                apply_settings(extension, name)
                 extension.setup()

             state[name] = [True, i]
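Taken together, patches 18 and 19 mean that any `"{extension_name}-{param}"` key found in `shared.settings` overrides the extension's default `params` right after import, whether or not the extension defines a `setup()` function. The hypothetical extension below illustrates the convention; the name `example` and its parameters are invented for illustration.

```python
# extensions/example/script.py  (hypothetical extension showing the naming
# convention that apply_settings() relies on)
params = {
    'activate': True,   # overridden by an "example-activate" key in settings.json
    'max_items': 10,    # overridden by an "example-max_items" key in settings.json
}


def setup():
    # Runs once at load time. After patch 19, apply_settings() has already run
    # by this point, so params reflects any settings.json overrides.
    print('example extension loaded with params:', params)
```

With `"example-max_items": 5` present in the settings file, `params['max_items']` is already `5` by the time `setup()` executes.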