Style improvements (#1957)

This commit is contained in:
oobabooga 2023-05-09 22:49:39 -03:00 committed by GitHub
parent 334486f527
commit 3913155c1f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 64 additions and 50 deletions

View File

@ -14,6 +14,7 @@ URI = f'ws://{HOST}/api/v1/stream'
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
async def run(context):
# Note: the selected defaults change from time to time.
request = {
@ -42,7 +43,7 @@ async def run(context):
async with websockets.connect(URI, ping_interval=None) as websocket:
await websocket.send(json.dumps(request))
yield context # Remove this if you just want to see the reply
yield context # Remove this if you just want to see the reply
while True:
incoming_data = await websocket.recv()
@ -58,7 +59,7 @@ async def run(context):
async def print_response_stream(prompt):
async for response in run(prompt):
print(response, end='')
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
sys.stdout.flush() # If we don't flush, we won't see tokens in realtime.
if __name__ == '__main__':

View File

@ -7,6 +7,7 @@ URI = f'http://{HOST}/api/v1/generate'
# For reverse-proxied streaming, the remote will likely host with ssl - https://
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
def run(prompt):
request = {
'prompt': prompt,
@ -37,6 +38,7 @@ def run(prompt):
result = response.json()['results'][0]['text']
print(prompt + result)
if __name__ == '__main__':
prompt = "In order to make homemade bread, follow these steps:\n1)"
run(prompt)

View File

@ -2,11 +2,10 @@ import json
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.text_generation import encode, generate_reply
from extensions.api.util import build_parameters, try_start_cloudflared
class Handler(BaseHTTPRequestHandler):
def do_GET(self):

View File

@ -5,6 +5,7 @@ from modules import shared
BLOCKING_PORT = 5000
STREAMING_PORT = 5005
def setup():
blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)

View File

@ -1,12 +1,12 @@
import json
import asyncio
from websockets.server import serve
import json
from threading import Thread
from modules import shared
from modules.text_generation import generate_reply
from websockets.server import serve
from extensions.api.util import build_parameters, try_start_cloudflared
from modules import shared
from modules.text_generation import generate_reply
PATH = '/api/v1/stream'

View File

@ -1,6 +1,7 @@
import gradio as gr
import os
import gradio as gr
# get the current directory of the script
current_dir = os.path.dirname(os.path.abspath(__file__))

View File

@ -1,6 +1,8 @@
import gradio as gr
import logging
import gradio as gr
def ui():
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")

View File

@ -6,10 +6,11 @@ from io import BytesIO
from typing import Any, List, Optional
import torch
from PIL import Image
from extensions.multimodal.pipeline_loader import load_pipeline
from modules import shared
from modules.text_generation import encode, get_max_prompt_length
from PIL import Image
@dataclass

View File

@ -7,6 +7,7 @@ from io import BytesIO
import gradio as gr
import torch
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
from modules import shared

View File

@ -1,11 +1,12 @@
import base64
import json
import numpy as np
import os
import time
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from threading import Thread
import numpy as np
from modules import shared
from modules.text_generation import encode, generate_reply
@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
ascii_string = encoded_bytes.decode('ascii')
return ascii_string
class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if self.path.startswith('/v1/models'):
@ -387,8 +389,8 @@ class Handler(BaseHTTPRequestHandler):
"created": created_time,
"model": model, # TODO: add Lora info?
resp_list: [{
"index": 0,
"finish_reason": "stop",
"index": 0,
"finish_reason": "stop",
}],
"usage": {
"prompt_tokens": token_count,

View File

@ -6,12 +6,13 @@ from datetime import date
from pathlib import Path
import gradio as gr
import modules.shared as shared
import requests
import torch
from modules.models import reload_model, unload_model
from PIL import Image
import modules.shared as shared
from modules.models import reload_model, unload_model
torch._C._jit_set_profiling_mode(False)
# parameters which can be customized in settings.json of webui
@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd
picture_response = False # specifies if the next model response should appear as a picture
def remove_surrounded_chars(string):
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
@ -122,7 +124,6 @@ def input_modifier(string):
# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description):
global params
if params['manage_VRAM']:
@ -259,6 +260,7 @@ def SD_api_address_update(address):
return gr.Textbox.update(label=msg)
def ui():
# Gradio elements
@ -292,10 +294,9 @@ def ui():
restore_faces = gr.Checkbox(value=params['restore_faces'], label='Restore faces')
enable_hr = gr.Checkbox(value=params['enable_hr'], label='Hires. fix')
with gr.Row(visible=params['enable_hr'], elem_classes="hires_opts") as hr_options:
hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
hr_scale = gr.Slider(1, 4, value=params['hr_scale'], step=0.1, label='Upscale by')
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
# Event functions to update the parameters in the backend
address.change(lambda x: params.update({"address": filter_address(x)}), address, None)

View File

@ -4,6 +4,7 @@ from pathlib import Path
import gradio as gr
import torch
from extensions.silero_tts import tts_preprocessor
from modules import chat, shared
from modules.html_generator import chat_html_wrapper

View File

@ -2,7 +2,6 @@ import time
from pathlib import Path
import torch
import tts_preprocessor
torch._C._jit_set_profiling_mode(False)

View File

@ -59,7 +59,7 @@ class ChromaCollector(Collecter):
def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
n_results = min(len(self.ids), n_results)
result = self.collection.query(query_texts=search_strings, n_results=n_results, include=['documents'])['ids'][0]
return list(map(lambda x : int(x[2:]), result))
return list(map(lambda x: int(x[2:]), result))
def clear(self):
self.collection.delete(ids=self.ids)
@ -162,13 +162,13 @@ def input_modifier(string):
def custom_generate_chat_prompt(user_input, state, **kwargs):
if len(shared.history['internal']) > 2 and user_input != '':
chunks = []
for i in range(len(shared.history['internal'])-1):
for i in range(len(shared.history['internal']) - 1):
chunks.append('\n'.join(shared.history['internal'][i]))
add_chunks_to_collector(chunks)
query = '\n'.join(shared.history['internal'][-1] + [user_input])
try:
best_ids = collector.get_ids(query, n_results=len(shared.history['internal'])-1)
best_ids = collector.get_ids(query, n_results=len(shared.history['internal']) - 1)
# Sort the history by relevance instead of by chronological order,
# except for the latest message

View File

@ -1,5 +1,6 @@
import gradio as gr
import speech_recognition as sr
from modules import shared
input_hijack = {

View File

@ -24,13 +24,12 @@ class RWKVModel:
@classmethod
def from_pretrained(self, path, dtype="fp16", device="cuda"):
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
if shared.args.rwkv_strategy is None:
model = RWKV(model=str(path), strategy=f'{device} {dtype}')
else:
model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
pipeline = PIPELINE(model, str(tokenizer_path))
pipeline = PIPELINE(model, str(tokenizer_path))
result = self()
result.pipeline = pipeline
result.model = model
@ -83,7 +82,6 @@ class RWKVModel:
out = self.cached_output_logits
for i in range(token_count):
# forward
tokens = self.pipeline.encode(ctx) if i == 0 else [token]
while len(tokens) > 0:
@ -102,6 +100,7 @@ class RWKVModel:
# adjust probabilities
for n in args.token_ban:
out[n] = -float('inf')
for n in occurrence:
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
@ -109,6 +108,7 @@ class RWKVModel:
token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
if token in args.token_stop:
break
all_tokens += [token]
if token not in occurrence:
occurrence[token] = 1
@ -117,9 +117,10 @@ class RWKVModel:
# output
tmp = self.pipeline.decode([token])
if '\ufffd' not in tmp: # is valid utf-8 string?
if '\ufffd' not in tmp: # is valid utf-8 string?
if callback:
callback(tmp)
out_str += tmp
return out_str
@ -133,7 +134,6 @@ class RWKVTokenizer:
def from_pretrained(self, path):
tokenizer_path = path / "20B_tokenizer.json"
tokenizer = Tokenizer.from_file(str(tokenizer_path))
result = self()
result.tokenizer = tokenizer
return result

View File

@ -1,5 +1,4 @@
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
'''
DeepSpeed configration
https://huggingface.co/docs/transformers/main_classes/deepspeed

View File

@ -20,6 +20,8 @@ def load_past_evaluations():
return df
else:
return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
past_evaluations = load_past_evaluations()

View File

@ -7,7 +7,6 @@ import gradio as gr
import extensions
import modules.shared as shared
state = {}
available_extensions = []
setup_called = set()

View File

@ -1,6 +1,8 @@
# Copied from https://stackoverflow.com/a/1336640
import logging
import platform
def add_coloring_to_emit_windows(fn):
# add methods we need to the class
@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn):
def _set_color(self, code):
import ctypes
# Constants from the Windows API
self.STD_OUTPUT_HANDLE = -11
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn):
return new
import platform
if platform.system() == 'Windows':
# Windows does not support ANSI escapes and we are using API calls to set the console color
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)

View File

@ -161,8 +161,8 @@ def load_model(model_name):
# Custom
else:
params = {
"low_cpu_mem_usage": True,
"trust_remote_code": trust_remote_code
"low_cpu_mem_usage": True,
"trust_remote_code": trust_remote_code
}
if not any((shared.args.cpu, torch.cuda.is_available(), torch.has_mps)):

View File

@ -377,7 +377,7 @@ def create_model_menus():
shared.gradio['lora_menu_apply'].click(load_lora_wrapper, shared.gradio['lora_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['download_model_button'].click(download_model_wrapper, shared.gradio['custom_model_menu'], shared.gradio['model_status'], show_progress=False)
shared.gradio['autoload_model'].change(lambda x : gr.update(visible=not x), shared.gradio['autoload_model'], load)
shared.gradio['autoload_model'].change(lambda x: gr.update(visible=not x), shared.gradio['autoload_model'], load)
def create_settings_menus(default_preset):