This commit is contained in:
oobabooga 2023-07-12 11:33:25 -07:00
parent 9b55d3a9f9
commit e202190c4f
24 changed files with 146 additions and 125 deletions

View File

@ -22,6 +22,7 @@ options = {
'session_metadata': 'text-generation-webui',
}
def ui():
settings = shared.settings.get("ngrok")
if settings:
@ -33,4 +34,3 @@ def ui():
logging.info(f"Ingress established at: {tunnel.url()}")
except ModuleNotFoundError:
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")

View File

@ -380,7 +380,6 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
yield chunk
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"
@ -583,7 +582,6 @@ def stream_completions(body: dict, is_legacy: bool=False):
completion_token_count += len(encode(new_content)[0])
yield chunk
stop_reason = "stop"
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
stop_reason = "length"

View File

@ -41,10 +41,13 @@ default_req_params = {
# 'requested_model' - temporarily used
}
def get_default_req_params():
return copy.deepcopy(default_req_params)
# little helper to get defaults if arg is present but None and should be the same type as default.
def default(dic, key, default):
val = dic.get(key, default)
if type(val) != type(default):
@ -59,6 +62,6 @@ def default(dic, key, default):
val = default
return val
def clamp(value, minvalue, maxvalue):
return max(minvalue, min(value, maxvalue))

View File

@ -6,6 +6,7 @@ from extensions.openai.errors import *
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
embeddings_model = None
def load_embedding_model(model):
try:
emb_model = SentenceTransformer(model)
@ -16,16 +17,19 @@ def load_embedding_model(model):
return emb_model
def get_embeddings_model():
global embeddings_model, st_model
if st_model and not embeddings_model:
embeddings_model = load_embedding_model(st_model) # lazy load the model
return embeddings_model
def get_embeddings_model_name():
global st_model
return st_model
def embeddings(input: list, encoding_format: str):
embeddings = get_embeddings_model().encode(input).tolist()

View File

@ -3,6 +3,7 @@ class OpenAIError(Exception):
self.message = message
self.code = code
self.internal_message = internal_message
def __repr__(self):
return "%s(message=%r, code=%d)" % (
self.__class__.__name__,
@ -10,10 +11,12 @@ class OpenAIError(Exception):
self.code,
)
class InvalidRequestError(OpenAIError):
def __init__(self, message, param, code=400, error_type='InvalidRequestError', internal_message=''):
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
self.param = param
def __repr__(self):
return "%s(message=%r, code=%d, param=%s)" % (
self.__class__.__name__,
@ -22,6 +25,7 @@ class InvalidRequestError(OpenAIError):
self.param,
)
class ServiceUnavailableError(OpenAIError):
def __init__(self, message=None, code=500, error_type='ServiceUnavailableError', internal_message=''):
super(OpenAIError, self).__init__(message, code, error_type, internal_message)

View File

@ -3,6 +3,7 @@ import time
import requests
from extensions.openai.errors import *
def generations(prompt: str, size: str, response_format: str, n: int):
# Stable Diffusion callout wrapper for txt2img
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E

View File

@ -7,15 +7,18 @@ from modules.models_settings import (get_model_settings_from_yamls,
from extensions.openai.embeddings import get_embeddings_model_name
from extensions.openai.errors import *
def get_current_model_list() -> list:
return [shared.model_name] # The real chat/completions model, maybe "None"
def get_pseudo_model_list() -> list:
return [ # these are expected by so much, so include some here as a dummy
'gpt-3.5-turbo',
'text-embedding-ada-002',
]
def load_model(model_name: str) -> dict:
resp = {
"id": model_name,
@ -74,4 +77,3 @@ def model_info(model_name: str) -> dict:
"owned_by": "user",
"permission": []
}

View File

@ -48,7 +48,6 @@ def moderations(input):
category_embeddings = get_category_embeddings()
# input, string or array
if isinstance(input, str):
input = [input]

View File

@ -22,6 +22,7 @@ params = {
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
}
class Handler(BaseHTTPRequestHandler):
def send_access_control_headers(self):
self.send_header("Access-Control-Allow-Origin", "*")

View File

@ -1,6 +1,7 @@
from extensions.openai.utils import float_list_to_base64
from modules.text_generation import encode, decode
def token_count(prompt):
tokens = encode(prompt)[0]

View File

@ -2,6 +2,7 @@ import os
import base64
import numpy as np
def float_list_to_base64(float_list):
# Convert the list to a float32 array that the OpenAPI client expects
float_array = np.array(float_list, dtype="float32")
@ -16,11 +17,13 @@ def float_list_to_base64(float_list):
ascii_string = encoded_bytes.decode('ascii')
return ascii_string
def end_line(s):
if s and s[-1] != '\n':
s = s + '\n'
return s
def debug_msg(*args, **kwargs):
if 'OPENEDAI_DEBUG' in os.environ:
print(*args, **kwargs)

View File

@ -126,6 +126,8 @@ def input_modifier(string):
return string
# Get and save the Stable Diffusion-generated picture
def get_SD_pictures(description, character):
global params
@ -186,6 +188,8 @@ def get_SD_pictures(description, character):
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
# and replace it with 'text' for the purposes of logging?
def output_modifier(string, state):
"""
This function is applied to the model outputs.

View File

@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
'''
Copied from the transformers library
'''
def __init__(self, penalty: float, _range: int):
if not isinstance(penalty, float) or not (penalty > 0):
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")