mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-24 17:06:53 +01:00
Style improvements (#1957)
This commit is contained in:
parent
334486f527
commit
3913155c1f
@ -14,6 +14,7 @@ URI = f'ws://{HOST}/api/v1/stream'
|
|||||||
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
# For reverse-proxied streaming, the remote will likely host with ssl - wss://
|
||||||
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
# URI = 'wss://your-uri-here.trycloudflare.com/api/v1/stream'
|
||||||
|
|
||||||
|
|
||||||
async def run(context):
|
async def run(context):
|
||||||
# Note: the selected defaults change from time to time.
|
# Note: the selected defaults change from time to time.
|
||||||
request = {
|
request = {
|
||||||
|
@ -7,6 +7,7 @@ URI = f'http://{HOST}/api/v1/generate'
|
|||||||
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
# For reverse-proxied streaming, the remote will likely host with ssl - https://
|
||||||
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
|
# URI = 'https://your-uri-here.trycloudflare.com/api/v1/generate'
|
||||||
|
|
||||||
|
|
||||||
def run(prompt):
|
def run(prompt):
|
||||||
request = {
|
request = {
|
||||||
'prompt': prompt,
|
'prompt': prompt,
|
||||||
@ -37,6 +38,7 @@ def run(prompt):
|
|||||||
result = response.json()['results'][0]['text']
|
result = response.json()['results'][0]['text']
|
||||||
print(prompt + result)
|
print(prompt + result)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
prompt = "In order to make homemade bread, follow these steps:\n1)"
|
||||||
run(prompt)
|
run(prompt)
|
||||||
|
@ -2,11 +2,10 @@ import json
|
|||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.text_generation import encode, generate_reply
|
from modules.text_generation import encode, generate_reply
|
||||||
|
|
||||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
|
||||||
|
|
||||||
|
|
||||||
class Handler(BaseHTTPRequestHandler):
|
class Handler(BaseHTTPRequestHandler):
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
|
@ -5,6 +5,7 @@ from modules import shared
|
|||||||
BLOCKING_PORT = 5000
|
BLOCKING_PORT = 5000
|
||||||
STREAMING_PORT = 5005
|
STREAMING_PORT = 5005
|
||||||
|
|
||||||
|
|
||||||
def setup():
|
def setup():
|
||||||
blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
|
blocking_api.start_server(BLOCKING_PORT, share=shared.args.public_api)
|
||||||
streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)
|
streaming_api.start_server(STREAMING_PORT, share=shared.args.public_api)
|
||||||
|
@ -1,12 +1,12 @@
|
|||||||
import json
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from websockets.server import serve
|
import json
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
from modules import shared
|
from websockets.server import serve
|
||||||
from modules.text_generation import generate_reply
|
|
||||||
|
|
||||||
from extensions.api.util import build_parameters, try_start_cloudflared
|
from extensions.api.util import build_parameters, try_start_cloudflared
|
||||||
|
from modules import shared
|
||||||
|
from modules.text_generation import generate_reply
|
||||||
|
|
||||||
PATH = '/api/v1/stream'
|
PATH = '/api/v1/stream'
|
||||||
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import gradio as gr
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
# get the current directory of the script
|
# get the current directory of the script
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
import gradio as gr
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
|
gr.Markdown("### This extension is deprecated, use \"multimodal\" extension instead")
|
||||||
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
|
logging.error("LLaVA extension is deprecated, use \"multimodal\" extension instead")
|
||||||
|
@ -6,10 +6,11 @@ from io import BytesIO
|
|||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
from extensions.multimodal.pipeline_loader import load_pipeline
|
from extensions.multimodal.pipeline_loader import load_pipeline
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.text_generation import encode, get_max_prompt_length
|
from modules.text_generation import encode, get_max_prompt_length
|
||||||
from PIL import Image
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -7,6 +7,7 @@ from io import BytesIO
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
from extensions.multimodal.multimodal_embedder import MultimodalEmbedder
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
import json
|
||||||
import numpy as np
|
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.text_generation import encode, generate_reply
|
from modules.text_generation import encode, generate_reply
|
||||||
|
|
||||||
@ -61,6 +62,7 @@ def float_list_to_base64(float_list):
|
|||||||
ascii_string = encoded_bytes.decode('ascii')
|
ascii_string = encoded_bytes.decode('ascii')
|
||||||
return ascii_string
|
return ascii_string
|
||||||
|
|
||||||
|
|
||||||
class Handler(BaseHTTPRequestHandler):
|
class Handler(BaseHTTPRequestHandler):
|
||||||
def do_GET(self):
|
def do_GET(self):
|
||||||
if self.path.startswith('/v1/models'):
|
if self.path.startswith('/v1/models'):
|
||||||
|
@ -6,12 +6,13 @@ from datetime import date
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import modules.shared as shared
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
from modules.models import reload_model, unload_model
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
import modules.shared as shared
|
||||||
|
from modules.models import reload_model, unload_model
|
||||||
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
|
||||||
# parameters which can be customized in settings.json of webui
|
# parameters which can be customized in settings.json of webui
|
||||||
@ -77,6 +78,7 @@ SD_models = ['NeverEndingDream'] # TODO: get with http://{address}}/sdapi/v1/sd
|
|||||||
|
|
||||||
picture_response = False # specifies if the next model response should appear as a picture
|
picture_response = False # specifies if the next model response should appear as a picture
|
||||||
|
|
||||||
|
|
||||||
def remove_surrounded_chars(string):
|
def remove_surrounded_chars(string):
|
||||||
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
# this expression matches to 'as few symbols as possible (0 upwards) between any asterisks' OR
|
||||||
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
# 'as few symbols as possible (0 upwards) between an asterisk and the end of the string'
|
||||||
@ -122,7 +124,6 @@ def input_modifier(string):
|
|||||||
|
|
||||||
# Get and save the Stable Diffusion-generated picture
|
# Get and save the Stable Diffusion-generated picture
|
||||||
def get_SD_pictures(description):
|
def get_SD_pictures(description):
|
||||||
|
|
||||||
global params
|
global params
|
||||||
|
|
||||||
if params['manage_VRAM']:
|
if params['manage_VRAM']:
|
||||||
@ -259,6 +260,7 @@ def SD_api_address_update(address):
|
|||||||
|
|
||||||
return gr.Textbox.update(label=msg)
|
return gr.Textbox.update(label=msg)
|
||||||
|
|
||||||
|
|
||||||
def ui():
|
def ui():
|
||||||
|
|
||||||
# Gradio elements
|
# Gradio elements
|
||||||
@ -296,7 +298,6 @@ def ui():
|
|||||||
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
|
denoising_strength = gr.Slider(0, 1, value=params['denoising_strength'], step=0.01, label='Denoising strength')
|
||||||
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
|
hr_upscaler = gr.Textbox(placeholder=params['hr_upscaler'], value=params['hr_upscaler'], label='Upscaler')
|
||||||
|
|
||||||
|
|
||||||
# Event functions to update the parameters in the backend
|
# Event functions to update the parameters in the backend
|
||||||
address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
|
address.change(lambda x: params.update({"address": filter_address(x)}), address, None)
|
||||||
mode.select(lambda x: params.update({"mode": x}), mode, None)
|
mode.select(lambda x: params.update({"mode": x}), mode, None)
|
||||||
|
@ -4,6 +4,7 @@ from pathlib import Path
|
|||||||
|
|
||||||
import gradio as gr
|
import gradio as gr
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from extensions.silero_tts import tts_preprocessor
|
from extensions.silero_tts import tts_preprocessor
|
||||||
from modules import chat, shared
|
from modules import chat, shared
|
||||||
from modules.html_generator import chat_html_wrapper
|
from modules.html_generator import chat_html_wrapper
|
||||||
|
@ -2,7 +2,6 @@ import time
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import tts_preprocessor
|
import tts_preprocessor
|
||||||
|
|
||||||
torch._C._jit_set_profiling_mode(False)
|
torch._C._jit_set_profiling_mode(False)
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
import speech_recognition as sr
|
import speech_recognition as sr
|
||||||
|
|
||||||
from modules import shared
|
from modules import shared
|
||||||
|
|
||||||
input_hijack = {
|
input_hijack = {
|
||||||
|
@ -24,13 +24,12 @@ class RWKVModel:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(self, path, dtype="fp16", device="cuda"):
|
def from_pretrained(self, path, dtype="fp16", device="cuda"):
|
||||||
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
|
tokenizer_path = Path(f"{path.parent}/20B_tokenizer.json")
|
||||||
|
|
||||||
if shared.args.rwkv_strategy is None:
|
if shared.args.rwkv_strategy is None:
|
||||||
model = RWKV(model=str(path), strategy=f'{device} {dtype}')
|
model = RWKV(model=str(path), strategy=f'{device} {dtype}')
|
||||||
else:
|
else:
|
||||||
model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
|
model = RWKV(model=str(path), strategy=shared.args.rwkv_strategy)
|
||||||
pipeline = PIPELINE(model, str(tokenizer_path))
|
|
||||||
|
|
||||||
|
pipeline = PIPELINE(model, str(tokenizer_path))
|
||||||
result = self()
|
result = self()
|
||||||
result.pipeline = pipeline
|
result.pipeline = pipeline
|
||||||
result.model = model
|
result.model = model
|
||||||
@ -83,7 +82,6 @@ class RWKVModel:
|
|||||||
out = self.cached_output_logits
|
out = self.cached_output_logits
|
||||||
|
|
||||||
for i in range(token_count):
|
for i in range(token_count):
|
||||||
|
|
||||||
# forward
|
# forward
|
||||||
tokens = self.pipeline.encode(ctx) if i == 0 else [token]
|
tokens = self.pipeline.encode(ctx) if i == 0 else [token]
|
||||||
while len(tokens) > 0:
|
while len(tokens) > 0:
|
||||||
@ -102,6 +100,7 @@ class RWKVModel:
|
|||||||
# adjust probabilities
|
# adjust probabilities
|
||||||
for n in args.token_ban:
|
for n in args.token_ban:
|
||||||
out[n] = -float('inf')
|
out[n] = -float('inf')
|
||||||
|
|
||||||
for n in occurrence:
|
for n in occurrence:
|
||||||
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
|
out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
|
||||||
|
|
||||||
@ -109,6 +108,7 @@ class RWKVModel:
|
|||||||
token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
|
token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p, top_k=args.top_k)
|
||||||
if token in args.token_stop:
|
if token in args.token_stop:
|
||||||
break
|
break
|
||||||
|
|
||||||
all_tokens += [token]
|
all_tokens += [token]
|
||||||
if token not in occurrence:
|
if token not in occurrence:
|
||||||
occurrence[token] = 1
|
occurrence[token] = 1
|
||||||
@ -120,6 +120,7 @@ class RWKVModel:
|
|||||||
if '\ufffd' not in tmp: # is valid utf-8 string?
|
if '\ufffd' not in tmp: # is valid utf-8 string?
|
||||||
if callback:
|
if callback:
|
||||||
callback(tmp)
|
callback(tmp)
|
||||||
|
|
||||||
out_str += tmp
|
out_str += tmp
|
||||||
|
|
||||||
return out_str
|
return out_str
|
||||||
@ -133,7 +134,6 @@ class RWKVTokenizer:
|
|||||||
def from_pretrained(self, path):
|
def from_pretrained(self, path):
|
||||||
tokenizer_path = path / "20B_tokenizer.json"
|
tokenizer_path = path / "20B_tokenizer.json"
|
||||||
tokenizer = Tokenizer.from_file(str(tokenizer_path))
|
tokenizer = Tokenizer.from_file(str(tokenizer_path))
|
||||||
|
|
||||||
result = self()
|
result = self()
|
||||||
result.tokenizer = tokenizer
|
result.tokenizer = tokenizer
|
||||||
return result
|
return result
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
|
def generate_ds_config(ds_bf16, train_batch_size, nvme_offload_dir):
|
||||||
|
|
||||||
'''
|
'''
|
||||||
DeepSpeed configration
|
DeepSpeed configration
|
||||||
https://huggingface.co/docs/transformers/main_classes/deepspeed
|
https://huggingface.co/docs/transformers/main_classes/deepspeed
|
||||||
|
@ -20,6 +20,8 @@ def load_past_evaluations():
|
|||||||
return df
|
return df
|
||||||
else:
|
else:
|
||||||
return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
|
return pd.DataFrame(columns=['Model', 'LoRAs', 'Dataset', 'Perplexity', 'stride', 'max_length', 'Date', 'Comment'])
|
||||||
|
|
||||||
|
|
||||||
past_evaluations = load_past_evaluations()
|
past_evaluations = load_past_evaluations()
|
||||||
|
|
||||||
|
|
||||||
|
@ -7,7 +7,6 @@ import gradio as gr
|
|||||||
import extensions
|
import extensions
|
||||||
import modules.shared as shared
|
import modules.shared as shared
|
||||||
|
|
||||||
|
|
||||||
state = {}
|
state = {}
|
||||||
available_extensions = []
|
available_extensions = []
|
||||||
setup_called = set()
|
setup_called = set()
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
# Copied from https://stackoverflow.com/a/1336640
|
# Copied from https://stackoverflow.com/a/1336640
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import platform
|
||||||
|
|
||||||
|
|
||||||
def add_coloring_to_emit_windows(fn):
|
def add_coloring_to_emit_windows(fn):
|
||||||
# add methods we need to the class
|
# add methods we need to the class
|
||||||
@ -11,6 +13,7 @@ def add_coloring_to_emit_windows(fn):
|
|||||||
|
|
||||||
def _set_color(self, code):
|
def _set_color(self, code):
|
||||||
import ctypes
|
import ctypes
|
||||||
|
|
||||||
# Constants from the Windows API
|
# Constants from the Windows API
|
||||||
self.STD_OUTPUT_HANDLE = -11
|
self.STD_OUTPUT_HANDLE = -11
|
||||||
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
hdl = ctypes.windll.kernel32.GetStdHandle(self.STD_OUTPUT_HANDLE)
|
||||||
@ -94,7 +97,6 @@ def add_coloring_to_emit_ansi(fn):
|
|||||||
return new
|
return new
|
||||||
|
|
||||||
|
|
||||||
import platform
|
|
||||||
if platform.system() == 'Windows':
|
if platform.system() == 'Windows':
|
||||||
# Windows does not support ANSI escapes and we are using API calls to set the console color
|
# Windows does not support ANSI escapes and we are using API calls to set the console color
|
||||||
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
|
logging.StreamHandler.emit = add_coloring_to_emit_windows(logging.StreamHandler.emit)
|
||||||
|
Loading…
Reference in New Issue
Block a user