mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 05:48:55 +01:00
Style improvements (#1957)
This commit is contained in:
parent
334486f527
commit
3913155c1f
@ -5,7 +5,7 @@ import sys
|
||||
try:
|
||||
import websockets
|
||||
except ImportError:
|
||||
print("Websockets package not found. Make sure it's installed.")
|
||||
print("Websockets package not found. Make sure it's installed.")
|
||||
|
||||
# For local streaming, the websockets are hosted without ssl - ws://
|
||||
HOST = 'localhost:5005'
|
||||
@ -14,6 +14,7 @@ URI = f'ws://{HOST}/api/v1/stream'
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||
|
||||
|
||||
async def run(context):
|
||||
# Note: the selected defaults change from time to time.
|
||||
request = {
|
||||
@ -42,7 +43,7 @@ async def run(context):
|
||||
async with websockets.connect(URI, ping_interval=None) as websocket:
|
||||
await websocket.send(json.dumps(request))
|
||||
|
||||
yield context # Remove this if you just want to see the reply
|
||||
yield context # Remove this if you just want to see the reply
|
||||
|
||||
while True:
|
||||
incoming_data = await websocket.recv()
|
||||
@ -58,7 +59,7 @@ async def run(context):
|
||||
async def print_response_stream(prompt):
|
||||
async for response in run(prompt):
|
||||
print(response, end='')
|
||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
@ -7,6 +7,7 @@ URI = f'http://{HOST}/api/v1/generate'
|
||||
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
||||
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
|
||||
|
||||
|
||||
def run(prompt):
|
||||
request = {
|
||||
'prompt': prompt,
|
||||
@ -37,6 +38,7 @@ def run(prompt):
|
||||
result = response.json()['results'][0]['text']
|
||||
print(prompt + result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||
run(prompt)
|
||||
|
@ -2,11 +2,10 @@ import json
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
|
@ -5,6 +5,7 @@ from modules import shared
|
||||
BLOCKING_PORT = 5000
|
||||
STREAMING_PORT = 5005
|
||||
|
||||
|
||||
def setup():
|
||||
blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
|
||||
streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)
|
||||
|
@ -1,12 +1,12 @@
|
||||
import json
|
||||
import asyncio
|
||||
from websockets.server import serve
|
||||
import json
|
||||
from threading import Thread
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import generate_reply
|
||||
from websockets.server import serve
|
||||
|
||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||
from modules import shared
|
||||
from modules.text_generation import generate_reply
|
||||
|
||||
PATH = '/api/v1/stream'
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
import gradio as gr
|
||||
import os
|
||||
|
||||
import gradio as gr
|
||||
|
||||
# get the current directory of the script
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
import gradio as gr
|
||||
import logging
|
||||
|
||||
import gradio as gr
|
||||
|
||||
|
||||
def ui():
|
||||
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
|
||||
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
|
||||
|
@ -6,10 +6,11 @@ from io import BytesIO
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from extensions.multimodal.pipeline_loader import load_pipeline
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, get_max_prompt_length
|
||||
from PIL import Image
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -7,6 +7,7 @@ from io import BytesIO
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
||||
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
||||
from modules import shared
|
||||
|
||||
|
@ -1,11 +1,12 @@
|
||||
import base64
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import time
|
||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||
from threading import Thread
|
||||
|
||||
import numpy as np
|
||||
|
||||
from modules import shared
|
||||
from modules.text_generation import encode, generate_reply
|
||||
|
||||
@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
|
||||
ascii_string = encoded_bytes.decode('ascii')
|
||||
return ascii_string
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def do_GET(self):
|
||||
if self.path.startswith('/v1/models'):
|
||||
@ -387,8 +389,8 @@ class Handler(BaseHTTPRequestHandler):
|
||||
"created": created_time,
|
||||
"model": model, # TODO: add Lora info?
|
||||
resp_list: [{
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
"index": 0,
|
||||
"finish_reason": "stop",
|
||||
}],
|
||||
"usage": {
|
||||
"prompt_tokens": token_count,
|
||||
|
@ -6,12 +6,13 @@ from datetime import date
|
||||
from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import modules.shared as shared
|
||||
import requests
|
||||
import torch
|
||||
from modules.models import reload_model, unload_model
|
||||
from PIL import Image
|
||||
|
||||
import modules.shared as shared
|
||||
from modules.models import reload_model, unload_model
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
||||
# parameters which can be customized in settings.json of webui
|
||||
@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd
|
||||
|
||||
picture_response = False # specifies if the next model response should appear as a picture
|
||||
|
||||
|
||||
def remove_surrounded_chars(string):
|
||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
@ -122,7 +124,6 @@ def input_modifier(string):
|
||||
|
||||
# Get and save the Stable Diffusion-generated picture
|
||||
def get_SD_pictures(description):
|
||||
|
||||
global params
|
||||
|
||||
if params['manage_VRAM']:
|
||||
@ -259,6 +260,7 @@ def SD_api_address_update(address):
|
||||
|
||||
return gr.Textbox.update(label=msg)
|
||||
|
||||
|
||||
def ui():
|
||||
|
||||
# Gradio elements
|
||||
@ -290,12 +292,11 @@ def ui():
|
||||
cfg_scale = gr.Number(label="CFG Scale", value=params['cfg_scale'], elem_id="cfg_box")
|
||||
with gr.Column() as hr_options:
|
||||
restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
|
||||
enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
|
||||
enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
|
||||
with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
|
||||
hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
|
||||
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
|
||||
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
|
||||
|
||||
hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
|
||||
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
|
||||
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
|
||||
|
||||
# Event functions to update the parameters in the backend
|
||||
address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
|
||||
|
@ -4,6 +4,7 @@ from pathlib import Path
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
|
||||
from extensions.silero_tts import tts_preprocessor
|
||||
from modules import chat, shared
|
||||
from modules.html_generator import chat_html_wrapper
|
||||
@ -216,4 +217,4 @@ def ui():
|
||||
|
||||
# Play preview
|
||||
preview_text.submit(voice_preview, preview_text, preview_audio)
|
||||
preview_play.click(voice_preview, preview_text, preview_audio)
|
||||
preview_play.click(voice_preview, preview_text, preview_audio)
|
||||
|
@ -2,7 +2,6 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
import tts_preprocessor
|
||||
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
|
@ -69,7 +69,7 @@ def remove_surrounded_chars(string):
|
||||
# first this expression will check if there is a string nested exclusively between an alt=
|
||||
# and a style= string. This would correspond to only the alt text of an embedded image
|
||||
# If it matches it will only keep that part as the string, and send it for further processing
|
||||
# Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any
|
||||
# Afterwards this expression matches to 'as few symbols as possible (0 upwards) between any
|
||||
# asterisks' OR 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||
if re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL):
|
||||
m = re.search(r'(?<=alt=)(.*)(?=style=)', string, re.DOTALL)
|
||||
|
@ -59,7 +59,7 @@ class ChromaCollector(Collecter):
|
||||
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
|
||||
n_results = min(len(self.ids), n_results)
|
||||
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
|
||||
return list(map(lambda x : int(x[2:]), result))
|
||||
return list(map(lambda x: int(x[2:]), result))
|
||||
|
||||
def clear(self):
|
||||
self.collection.delete(ids=self.ids)
|
||||
@ -162,13 +162,13 @@ def input_modifier(string):
|
||||
def custom_generate_chat_prompt(user_input, state, **kwargs):
|
||||
if len(shared.history['internal']) > 2 and user_input != '':
|
||||
chunks = []
|
||||
for i in range(len(shared.history['internal'])-1):
|
||||
for i in range(len(shared.history['internal']) - 1):
|
||||
chunks.append('\n'.join(shared.history['internal'][i]))
|
||||
|
||||
add_chunks_to_collector(chunks)
|
||||
query = '\n'.join(shared.history['internal'][-1] + [user_input])
|
||||
try:
|
||||
best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1)
|
||||
best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1)
|
||||
|
||||
# Sort the history by relevance instead of by chronological order,
|
||||
# except for the latest message
|
||||
@ -226,7 +226,7 @@ def ui():
|
||||
|
||||
## Chat mode
|
||||
|
||||
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
|
||||
In chat mode, the extension automatically sorts the history by relevance instead of chronologically, except for the very latest input/reply pair.
|
||||
|
||||
That is, the prompt will include (starting from the end):
|
||||
|
||||
|
@ -1,5 +1,6 @@
|
||||
import gradio as gr
|
||||
import speech_recognition as sr
|
||||
|
||||
from modules import shared
|
||||
|
||||
input_hijack = {
|
||||
|
@ -24,13 +24,12 @@ class RWKVModel:
|
||||
@classmethod
|
||||
def from_pretrained(self, path, dtype="fp16", device="cuda"):
|
||||
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
|
||||
|
||||
if shared.args.rwkv_strategy is None:
|
||||
model = RWKV(model=str(path), strategy=f'{device} {dtype}')
|
||||
else:
|
||||
model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
|
||||
pipeline = PIPELINE(model, str(tokenizer_path))
|
||||
|
||||
pipeline = PIPELINE(model, str(tokenizer_path))
|
||||
result = self()
|
||||
result.pipeline = pipeline
|
||||
result.model = model
|
||||
@ -83,7 +82,6 @@ class RWKVModel:
|
||||
out = self.cached_output_logits
|
||||
|
||||
for i in range(token_count):
|
||||
|
||||
# forward
|
||||
tokens = self.pipeline.encode(ctx) if i == 0 else [token]
|
||||
while len(tokens) > 0:
|
||||
@ -91,35 +89,38 @@ class RWKVModel:
|
||||
tokens = tokens[args.chunk_len:]
|
||||
|
||||
# cache the model state after scanning the context
|
||||
# we don't cache the state after processing our own generated tokens because
|
||||
# the output string might be post-processed arbitrarily. Therefore, what's fed into the model
|
||||
# we don't cache the state after processing our own generated tokens because
|
||||
# the output string might be post-processed arbitrarily. Therefore, what's fed into the model
|
||||
# on the next round of chat might be slightly different from what it output on the previous round
|
||||
if i == 0:
|
||||
self.cached_context += ctx
|
||||
self.cached_model_state = copy.deepcopy(state)
|
||||
self.cached_output_logits = copy.deepcopy(out)
|
||||
|
||||
|
||||
# adjust probabilities
|
||||
for n in args.token_ban:
|
||||
out[n] = -float('inf')
|
||||
|
||||
for n in occurrence:
|
||||
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
|
||||
|
||||
|
||||
# sampler
|
||||
token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
|
||||
if token in args.token_stop:
|
||||
break
|
||||
|
||||
all_tokens += [token]
|
||||
if token not in occurrence:
|
||||
occurrence[token] = 1
|
||||
else:
|
||||
occurrence[token] += 1
|
||||
|
||||
|
||||
# output
|
||||
tmp = self.pipeline.decode([token])
|
||||
if '\ufffd' not in tmp: # is valid utf-8 string?
|
||||
if '\ufffd' not in tmp: # is valid utf-8 string?
|
||||
if callback:
|
||||
callback(tmp)
|
||||
|
||||
out_str += tmp
|
||||
|
||||
return out_str
|
||||
@ -133,7 +134,6 @@ class RWKVTokenizer:
|
||||
def from_pretrained(self, path):
|
||||
tokenizer_path = path / "20B_tokenizer.json"
|
||||
tokenizer = Tokenizer.from_file(str(tokenizer_path))
|
||||
|
||||
result = self()
|
||||
result.tokenizer = tokenizer
|
||||
return result
|
||||
|
@ -1,5 +1,4 @@
|
||||
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
|
||||
|
||||
'''
|
||||
DeepSpeed configuration
|
||||
https://huggingface.co/docs/transformers/main_classes/deepspeed
|
||||
|
@ -20,6 +20,8 @@ def load_past_evaluations():
|
||||
return df
|
||||
else:
|
||||
return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
|
||||
|
||||
|
||||
past_evaluations = load_past_evaluations()
|
||||
|
||||
|
||||
|
@ -7,7 +7,6 @@ import gradio as gr
|
||||
import extensions
|
||||
import modules.shared as shared
|
||||
|
||||
|
||||
state = {}
|
||||
available_extensions = []
|
||||
setup_called = set()
|
||||
@ -91,7 +90,7 @@ def _apply_state_modifier_extensions(state):
|
||||
state = getattr(extension, "state_modifier")(state)
|
||||
|
||||
return state
|
||||
|
||||
|
||||
|
||||
# Extension functions that override the default tokenizer output - currently only the first one will work
|
||||
def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_embeds):
|
||||
@ -108,7 +107,7 @@ def _apply_custom_tokenized_length(prompt):
|
||||
for extension, _ in iterator():
|
||||
if hasattr(extension, 'custom_tokenized_length'):
|
||||
return getattr(extension, 'custom_tokenized_length')(prompt)
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
# Copied from https://stackoverflow.com/a/1336640
|
||||
|
||||
import logging
|
||||
import platform
|
||||
|
||||
|
||||
def add_coloring_to_emit_windows(fn):
|
||||
# add methods we need to the class
|
||||
@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn):
|
||||
|
||||
def _set_color(self, code):
|
||||
import ctypes
|
||||
|
||||
# Constants from the Windows API
|
||||
self.STD_OUTPUT_HANDLE = -11
|
||||
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
||||
@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn):
|
||||
return new
|
||||
|
||||
|
||||
import platform
|
||||
if platform.system() == 'Windows':
|
||||
# Windows does not support ANSI escapes and we are using API calls to set the console color
|
||||
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
|
||||
|
@ -161,10 +161,10 @@ def load_model(model_name):
|
||||
# Custom
|
||||
else:
|
||||
params = {
|
||||
"low_cpu_mem_usage": True,
|
||||
"trust_remote_code": trust_remote_code
|
||||
"low_cpu_mem_usage": True,
|
||||
"trust_remote_code": trust_remote_code
|
||||
}
|
||||
|
||||
|
||||
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):
|
||||
logging.warning("torch.cuda.is_available() returned False. This means that no GPU has been detected. Falling back to CPU mode.")
|
||||
shared.args.cpu = True
|
||||
@ -288,7 +288,7 @@ def load_soft_prompt(name):
|
||||
logging.info(f"{field}: {', '.join(j[field])}")
|
||||
else:
|
||||
logging.info(f"{field}: {j[field]}")
|
||||
|
||||
|
||||
logging.info()
|
||||
tensor = np.load('tensor.npy')
|
||||
Path('tensor.npy').unlink()
|
||||
|
@ -377,7 +377,7 @@ def create_model_menus():
|
||||
|
||||
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
|
||||
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
|
||||
shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load)
|
||||
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
|
||||
|
||||
|
||||
def create_settings_menus(default_preset):
|
||||
|
Loading…
Reference in New Issue
Block a user