From 3913155c1fdd209bea1f3404e1557936dd475979 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Tue, 9 May 2023 22:49:39 -0300 Subject: [PATCH] Style improvements (#1957) --- api-example-stream.py | 7 ++++--- api-example.py | 2 ++ extensions/api/blocking_api.py | 3 +-- extensions/api/script.py | 1 + extensions/api/streaming_api.py | 8 ++++---- extensions/character_bias/script.py | 3 ++- extensions/llava/script.py | 4 +++- extensions/multimodal/multimodal_embedder.py | 3 ++- extensions/multimodal/script.py | 1 + extensions/openai/script.py | 8 +++++--- extensions/sd_api_pictures/script.py | 17 +++++++++-------- extensions/silero_tts/script.py | 3 ++- extensions/silero_tts/test_tts.py | 1 - extensions/silero_tts/tts_preprocessor.py | 2 +- extensions/superbooga/script.py | 8 ++++---- extensions/whisper_stt/script.py | 1 + modules/RWKV.py | 20 ++++++++++---------- modules/deepspeed_parameters.py | 1 - modules/evaluate.py | 2 ++ modules/extensions.py | 5 ++--- modules/logging_colors.py | 4 +++- modules/models.py | 8 ++++---- server.py | 2 +- 23 files changed, 64 insertions(+), 50 deletions(-) diff --git a/api-example-stream.py b/api-example-stream.py index 49058776..ad8f7bf8 100644 --- a/api-example-stream.py +++ b/api-example-stream.py @@ -5,7 +5,7 @@ import sys try: import websockets except ImportError: - print("Websockets package not found. Make sure it's installed.") + print("Websockets package not found. Make sure it's installed.") # For local streaming, the websockets are hosted without ssl - ws:// HOST = 'localhost:5005' @@ -14,6 +14,7 @@ URI = f'ws://{HOST}/api/v1/stream' # For reverse-proxied streaming, the remote will likely host with ssl - wss:// # URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream' + async def run(context): # Note: the selected defaults change from time to time. request = { @@ -42,7 +43,7 @@ async def run(context): async with websockets.connect(URI, ping_interval=None) as websocket: await websocket.send(json.dumps(request)) - yield context # Remove this if you just want to see the reply + yield context # Remove this if you just want to see the reply while True: incoming_data = await websocket.recv() @@ -58,7 +59,7 @@ async def run(context): async def print_response_stream(prompt): async for response in run(prompt): print(response, end='') - sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. + sys.stdout.flush() # If we don't flush, we won't see tokens in realtime. if __name__ == '__main__': diff --git a/api-example.py b/api-example.py index d6053fda..f35ea1db 100644 --- a/api-example.py +++ b/api-example.py @@ -7,6 +7,7 @@ URI = f'http://{HOST}/api/v1/generate' # For reverse-proxied streaming, the remote will likely host with ssl - https:// # URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate' + def run(prompt): request = { 'prompt': prompt, @@ -37,6 +38,7 @@ def run(prompt): result = response.json()['results'][0]['text'] print(prompt + result) + if __name__ == '__main__': prompt = "In order to make homemade bread, follow these steps:\n1)" run(prompt) diff --git a/extensions/api/blocking_api.py b/extensions/api/blocking_api.py index 2c72d789..57cc0b9e 100644 --- a/extensions/api/blocking_api.py +++ b/extensions/api/blocking_api.py @@ -2,11 +2,10 @@ import json from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from threading import Thread +from extensions.api.util import build_parameters, try_start_cloudflared from modules import shared from modules.text_generation import encode, generate_reply -from extensions.api.util import build_parameters, try_start_cloudflared - class Handler(BaseHTTPRequestHandler): def do_GET(self): diff --git a/extensions/api/script.py b/extensions/api/script.py index efeed71f..3911b106 100644 --- a/extensions/api/script.py +++ b/extensions/api/script.py @@ -5,6 +5,7 @@ from modules import shared BLOCKING_PORT = 5000 STREAMING_PORT = 5005 + def setup(): blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api) streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api) diff --git a/extensions/api/streaming_api.py b/extensions/api/streaming_api.py index 42570c94..e847178a 100644 --- a/extensions/api/streaming_api.py +++ b/extensions/api/streaming_api.py @@ -1,12 +1,12 @@ -import json import asyncio -from websockets.server import serve +import json from threading import Thread -from modules import shared -from modules.text_generation import generate_reply +from websockets.server import serve from extensions.api.util import build_parameters, try_start_cloudflared +from modules import shared +from modules.text_generation import generate_reply PATH = '/api/v1/stream' diff --git a/extensions/character_bias/script.py b/extensions/character_bias/script.py index 614d9ce3..ff12f3af 100644 --- a/extensions/character_bias/script.py +++ b/extensions/character_bias/script.py @@ -1,6 +1,7 @@ -import gradio as gr import os +import gradio as gr + # get the current directory of the script current_dir = os.path.dirname(os.path.abspath(__file__)) diff --git a/extensions/llava/script.py b/extensions/llava/script.py index eaf6b313..3f6c73a2 100644 --- a/extensions/llava/script.py +++ b/extensions/llava/script.py @@ -1,6 +1,8 @@ -import gradio as gr import logging +import gradio as gr + + def ui(): gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead") logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead") diff --git a/extensions/multimodal/multimodal_embedder.py b/extensions/multimodal/multimodal_embedder.py index 816e3866..62e99ca7 100644 --- a/extensions/multimodal/multimodal_embedder.py +++ b/extensions/multimodal/multimodal_embedder.py @@ -6,10 +6,11 @@ from io import BytesIO from typing import Any, List, Optional import torch +from PIL import Image + from extensions.multimodal.pipeline_loader import load_pipeline from modules import shared from modules.text_generation import encode, get_max_prompt_length -from PIL import Image @dataclass diff --git a/extensions/multimodal/script.py b/extensions/multimodal/script.py index aeaadffd..2ca11bf5 100644 --- a/extensions/multimodal/script.py +++ b/extensions/multimodal/script.py @@ -7,6 +7,7 @@ from io import BytesIO import gradio as gr import torch + from extensions.multimodal.multimodal_embedder import MultimodalEmbedder from modules import shared diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 9eb35a46..c46dbe04 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -1,11 +1,12 @@ import base64 import json -import numpy as np import os import time from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from threading import Thread +import numpy as np + from modules import shared from modules.text_generation import encode, generate_reply @@ -61,6 +62,7 @@ def float_list_to_base64(float_list): ascii_string = encoded_bytes.decode('ascii') return ascii_string + class Handler(BaseHTTPRequestHandler): def do_GET(self): if self.path.startswith('/v1/models'): @@ -387,8 +389,8 @@ class Handler(BaseHTTPRequestHandler): "created": created_time, "model": model, # TODO: add Lora info? resp_list: [{ - "index": 0, - "finish_reason": "stop", + "index": 0, + "finish_reason": "stop", }], "usage": { "prompt_tokens": token_count, diff --git a/extensions/sd_api_pictures/script.py b/extensions/sd_api_pictures/script.py index 2d4e39dc..2c054242 100644 --- a/extensions/sd_api_pictures/script.py +++ b/extensions/sd_api_pictures/script.py @@ -6,12 +6,13 @@ from datetime import date from pathlib import Path import gradio as gr -import modules.shared as shared import requests import torch -from modules.models import reload_model, unload_model from PIL import Image +import modules.shared as shared +from modules.models import reload_model, unload_model + torch._C._jit_set_profiling_mode(False) # parameters which can be customized in settings.json of webui @@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd picture_response = False # specifies if the next model response should appear as a picture + def remove_surrounded_chars(string): # this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR # 'as few symbols as possible (0 upwards) between an asterisk and the end of the string' @@ -122,7 +124,6 @@ def input_modifier(string): # Get and save the Stable Diffusion-generated picture def get_SD_pictures(description): - global params if params['manage_VRAM']: @@ -259,6 +260,7 @@ def SD_api_address_update(address): return gr.Textbox.update(label=msg) + def ui(): # Gradio elements @@ -290,12 +292,11 @@ def ui(): cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box") with gr.Column() as hr_options: restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces') - enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix') + enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix') with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options: - hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by') - denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength') - hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler') - + hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by') + denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength') + hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler') # Event functions to update the parameters in the backend address.change(lambda x: params.update({"address": filter_address(x)}), address, None) diff --git a/extensions/silero_tts/script.py b/extensions/silero_tts/script.py index 345e3821..3166bb63 100644 --- a/extensions/silero_tts/script.py +++ b/extensions/silero_tts/script.py @@ -4,6 +4,7 @@ from pathlib import Path import gradio as gr import torch + from extensions.silero_tts import tts_preprocessor from modules import chat, shared from modules.html_generator import chat_html_wrapper @@ -216,4 +217,4 @@ def ui(): # Play preview preview_text.submit(voice_preview, preview_text, preview_audio) - preview_play.click(voice_preview, preview_text, preview_audio) \ No newline at end of file + preview_play.click(voice_preview, preview_text, preview_audio) diff --git a/extensions/silero_tts/test_tts.py b/extensions/silero_tts/test_tts.py index ad8ee764..ebc2c102 100644 --- a/extensions/silero_tts/test_tts.py +++ b/extensions/silero_tts/test_tts.py @@ -2,7 +2,6 @@ import time from pathlib import Path import torch - import tts_preprocessor torch._C._jit_set_profiling_mode(False) diff --git a/extensions/silero_tts/tts_preprocessor.py b/extensions/silero_tts/tts_preprocessor.py index eb2ca41b..daefdcbd 100644 --- a/extensions/silero_tts/tts_preprocessor.py +++ b/extensions/silero_tts/tts_preprocessor.py @@ -69,7 +69,7 @@ def remove_surrounded_chars(string): # first this expression will check if there is a string nested exclusively between a alt= # and a style= string. This would correspond to only a the alt text of an embedded image # If it matches it will only keep that part as the string, and rend it for further processing - # Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any + # Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any # asterisks' OR' as few symbols as possible (0 upwards) between an asterisk and the end of the string' if re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL): m = re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL) diff --git a/extensions/superbooga/script.py b/extensions/superbooga/script.py index 5b98128e..e239c58a 100644 --- a/extensions/superbooga/script.py +++ b/extensions/superbooga/script.py @@ -59,7 +59,7 @@ class ChromaCollector(Collecter): def get_ids(self, search_strings: list[str], n_results: int) -> list[str]: n_results = min(len(self.ids), n_results) result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0] - return list(map(lambda x : int(x[2:]), result)) + return list(map(lambda x: int(x[2:]), result)) def clear(self): self.collection.delete(ids=self.ids) @@ -162,13 +162,13 @@ def input_modifier(string): def custom_generate_chat_prompt(user_input, state, **kwargs): if len(shared.history['internal']) > 2 and user_input != '': chunks = [] - for i in range(len(shared.history['internal'])-1): + for i in range(len(shared.history['internal']) - 1): chunks.append('\n'.join(shared.history['internal'][i])) add_chunks_to_collector(chunks) query = '\n'.join(shared.history['internal'][-1] + [user_input]) try: - best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1) + best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1) # Sort the history by relevance instead of by chronological order, # except for the latest message @@ -226,7 +226,7 @@ def ui(): ## Chat mode - In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair. + In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair. That is, the prompt will include (starting from the end): diff --git a/extensions/whisper_stt/script.py b/extensions/whisper_stt/script.py index 9daee7be..32226404 100644 --- a/extensions/whisper_stt/script.py +++ b/extensions/whisper_stt/script.py @@ -1,5 +1,6 @@ import gradio as gr import speech_recognition as sr + from modules import shared input_hijack = { diff --git a/modules/RWKV.py b/modules/RWKV.py index 35d650e1..bb6bab50 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -24,13 +24,12 @@ class RWKVModel: @classmethod def from_pretrained(self, path, dtype="fp16", device="cuda"): tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json") - if shared.args.rwkv_strategy is None: model = RWKV(model=str(path), strategy=f'{device} {dtype}') else: model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy) - pipeline = PIPELINE(model, str(tokenizer_path)) + pipeline = PIPELINE(model, str(tokenizer_path)) result = self() result.pipeline = pipeline result.model = model @@ -83,7 +82,6 @@ class RWKVModel: out = self.cached_output_logits for i in range(token_count): - # forward tokens = self.pipeline.encode(ctx) if i == 0 else [token] while len(tokens) > 0: @@ -91,35 +89,38 @@ class RWKVModel: tokens = tokens[args.chunk_len:] # cache the model state after scanning the context - # we don't cache the state after processing our own generated tokens because - # the output string might be post-processed arbitrarily. Therefore, what's fed into the model + # we don't cache the state after processing our own generated tokens because + # the output string might be post-processed arbitrarily. Therefore, what's fed into the model # on the next round of chat might be slightly different what what it output on the previous round if i == 0: self.cached_context += ctx self.cached_model_state = copy.deepcopy(state) self.cached_output_logits = copy.deepcopy(out) - + # adjust probabilities for n in args.token_ban: out[n] = -float('inf') + for n in occurrence: out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency) - + # sampler token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k) if token in args.token_stop: break + all_tokens += [token] if token not in occurrence: occurrence[token] = 1 else: occurrence[token] += 1 - + # output tmp = self.pipeline.decode([token]) - if '\ufffd' not in tmp: # is valid utf-8 string? + if '\ufffd' not in tmp: # is valid utf-8 string? if callback: callback(tmp) + out_str += tmp return out_str @@ -133,7 +134,6 @@ class RWKVTokenizer: def from_pretrained(self, path): tokenizer_path = path / "20B_tokenizer.json" tokenizer = Tokenizer.from_file(str(tokenizer_path)) - result = self() result.tokenizer = tokenizer return result diff --git a/modules/deepspeed_parameters.py b/modules/deepspeed_parameters.py index 3dbed437..9116f579 100644 --- a/modules/deepspeed_parameters.py +++ b/modules/deepspeed_parameters.py @@ -1,5 +1,4 @@ def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir): - ''' DeepSpeed configration https://huggingface.co/docs/transformers/main_classes/deepspeed diff --git a/modules/evaluate.py b/modules/evaluate.py index 3134280c..adafa713 100644 --- a/modules/evaluate.py +++ b/modules/evaluate.py @@ -20,6 +20,8 @@ def load_past_evaluations(): return df else: return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment']) + + past_evaluations = load_past_evaluations() diff --git a/modules/extensions.py b/modules/extensions.py index 8e88e0cc..47629012 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -7,7 +7,6 @@ import gradio as gr import extensions import modules.shared as shared - state = {} available_extensions = [] setup_called = set() @@ -91,7 +90,7 @@ def _apply_state_modifier_extensions(state): state = getattr(extension, "state_modifier")(state) return state - + # Extension functions that override the default tokenizer output - currently only the first one will work def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds): @@ -108,7 +107,7 @@ def _apply_custom_tokenized_length(prompt): for extension, _ in iterator(): if hasattr(extension, 'custom_tokenized_length'): return getattr(extension, 'custom_tokenized_length')(prompt) - + return None diff --git a/modules/logging_colors.py b/modules/logging_colors.py index 5485b090..5c9714f7 100644 --- a/modules/logging_colors.py +++ b/modules/logging_colors.py @@ -1,6 +1,8 @@ # Copied from https://stackoverflow.com/a/1336640 import logging +import platform + def add_coloring_to_emit_windows(fn): # add methods we need to the class @@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn): def _set_color(self, code): import ctypes + # Constants from the Windows API self.STD_OUTPUT_HANDLE = -11 hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE) @@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn): return new -import platform if platform.system() == 'Windows': # Windows does not support ANSI escapes and we are using API calls to set the console color logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit) diff --git a/modules/models.py b/modules/models.py index 1f2219ae..d5f6594c 100644 --- a/modules/models.py +++ b/modules/models.py @@ -161,10 +161,10 @@ def load_model(model_name): # Custom else: params = { - "low_cpu_mem_usage": True, - "trust_remote_code": trust_remote_code + "low_cpu_mem_usage": True, + "trust_remote_code": trust_remote_code } - + if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)): logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.") shared.args.cpu = True @@ -288,7 +288,7 @@ def load_soft_prompt(name): logging.info(f"{field}: {', '.join(j[field])}") else: logging.info(f"{field}: {j[field]}") - + logging.info() tensor = np.load('tensor.npy') Path('tensor.npy').unlink() diff --git a/server.py b/server.py index e2bbeaef..10df484b 100644 --- a/server.py +++ b/server.py @@ -377,7 +377,7 @@ def create_model_menus(): shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False) shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False) - shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load) + shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load) def create_settings_menus(default_preset):