mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-21 23:57:58 +01:00
lint
This commit is contained in:
parent
9b55d3a9f9
commit
e202190c4f
@ -22,6 +22,7 @@ options = {
|
||||
'session_metadata': 'text-generation-webui',
|
||||
}
|
||||
|
||||
|
||||
def ui():
|
||||
settings = shared.settings.get("ngrok")
|
||||
if settings:
|
||||
@ -33,4 +34,3 @@ def ui():
|
||||
logging.info(f"Ingress established at: {tunnel.url()}")
|
||||
except ModuleNotFoundError:
|
||||
logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")
|
||||
|
||||
|
@ -380,7 +380,6 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):
|
||||
|
||||
yield chunk
|
||||
|
||||
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
@ -583,7 +582,6 @@ def stream_completions(body: dict, is_legacy: bool=False):
|
||||
completion_token_count += len(encode(new_content)[0])
|
||||
yield chunk
|
||||
|
||||
|
||||
stop_reason = "stop"
|
||||
if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
|
||||
stop_reason = "length"
|
||||
|
@ -41,10 +41,13 @@ default_req_params = {
|
||||
# 'requested_model' - temporarily used
|
||||
}
|
||||
|
||||
|
||||
def get_default_req_params():
|
||||
return copy.deepcopy(default_req_params)
|
||||
|
||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||
|
||||
|
||||
def default(dic, key, default):
|
||||
val = dic.get(key, default)
|
||||
if type(val) != type(default):
|
||||
@ -59,6 +62,6 @@ def default(dic, key, default):
|
||||
val = default
|
||||
return val
|
||||
|
||||
|
||||
def clamp(value, minvalue, maxvalue):
|
||||
return max(minvalue, min(value, maxvalue))
|
||||
|
||||
|
@ -6,6 +6,7 @@ from extensions.openai.errors import *
|
||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||
embeddings_model = None
|
||||
|
||||
|
||||
def load_embedding_model(model):
|
||||
try:
|
||||
emb_model = SentenceTransformer(model)
|
||||
@ -16,16 +17,19 @@ def load_embedding_model(model):
|
||||
|
||||
return emb_model
|
||||
|
||||
|
||||
def get_embeddings_model():
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
embeddings_model = load_embedding_model(st_model) # lazy load the model
|
||||
return embeddings_model
|
||||
|
||||
|
||||
def get_embeddings_model_name():
|
||||
global st_model
|
||||
return st_model
|
||||
|
||||
|
||||
def embeddings(input: list, encoding_format: str):
|
||||
|
||||
embeddings = get_embeddings_model().encode(input).tolist()
|
||||
|
@ -3,6 +3,7 @@ class OpenAIError(Exception):
|
||||
self.message = message
|
||||
self.code = code
|
||||
self.internal_message = internal_message
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(message=%r, code=%d)" % (
|
||||
self.__class__.__name__,
|
||||
@ -10,10 +11,12 @@ class OpenAIError(Exception):
|
||||
self.code,
|
||||
)
|
||||
|
||||
|
||||
class InvalidRequestError(OpenAIError):
|
||||
def __init__(self, message, param, code=400, error_type='InvalidRequestError', internal_message=''):
|
||||
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||
self.param = param
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(message=%r, code=%d, param=%s)" % (
|
||||
self.__class__.__name__,
|
||||
@ -22,6 +25,7 @@ class InvalidRequestError(OpenAIError):
|
||||
self.param,
|
||||
)
|
||||
|
||||
|
||||
class ServiceUnavailableError(OpenAIError):
|
||||
def __init__(self, message=None, code=500, error_type='ServiceUnavailableError', internal_message=''):
|
||||
super(OpenAIError, self).__init__(message, code, error_type, internal_message)
|
||||
|
@ -3,6 +3,7 @@ import time
|
||||
import requests
|
||||
from extensions.openai.errors import *
|
||||
|
||||
|
||||
def generations(prompt: str, size: str, response_format: str, n: int):
|
||||
# Stable Diffusion callout wrapper for txt2img
|
||||
# Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
|
||||
|
@ -7,15 +7,18 @@ from modules.models_settings import (get_model_settings_from_yamls,
|
||||
from extensions.openai.embeddings import get_embeddings_model_name
|
||||
from extensions.openai.errors import *
|
||||
|
||||
|
||||
def get_current_model_list() -> list:
|
||||
return [shared.model_name] # The real chat/completions model, maybe "None"
|
||||
|
||||
|
||||
def get_pseudo_model_list() -> list:
|
||||
return [ # these are expected by so much, so include some here as a dummy
|
||||
'gpt-3.5-turbo',
|
||||
'text-embedding-ada-002',
|
||||
]
|
||||
|
||||
|
||||
def load_model(model_name: str) -> dict:
|
||||
resp = {
|
||||
"id": model_name,
|
||||
@ -74,4 +77,3 @@ def model_info(model_name: str) -> dict:
|
||||
"owned_by": "user",
|
||||
"permission": []
|
||||
}
|
||||
|
||||
|
@ -48,7 +48,6 @@ def moderations(input):
|
||||
|
||||
category_embeddings = get_category_embeddings()
|
||||
|
||||
|
||||
# input, string or array
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
|
@ -22,6 +22,7 @@ params = {
|
||||
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
||||
}
|
||||
|
||||
|
||||
class Handler(BaseHTTPRequestHandler):
|
||||
def send_access_control_headers(self):
|
||||
self.send_header("Access-Control-Allow-Origin", "*")
|
||||
|
@ -1,6 +1,7 @@
|
||||
from extensions.openai.utils import float_list_to_base64
|
||||
from modules.text_generation import encode, decode
|
||||
|
||||
|
||||
def token_count(prompt):
|
||||
tokens = encode(prompt)[0]
|
||||
|
||||
|
@ -2,6 +2,7 @@ import os
|
||||
import base64
|
||||
import numpy as np
|
||||
|
||||
|
||||
def float_list_to_base64(float_list):
|
||||
# Convert the list to a float32 array that the OpenAPI client expects
|
||||
float_array = np.array(float_list, dtype="float32")
|
||||
@ -16,11 +17,13 @@ def float_list_to_base64(float_list):
|
||||
ascii_string = encoded_bytes.decode('ascii')
|
||||
return ascii_string
|
||||
|
||||
|
||||
def end_line(s):
|
||||
if s and s[-1] != '\n':
|
||||
s = s + '\n'
|
||||
return s
|
||||
|
||||
|
||||
def debug_msg(*args, **kwargs):
|
||||
if 'OPENEDAI_DEBUG' in os.environ:
|
||||
print(*args, **kwargs)
|
@ -126,6 +126,8 @@ def input_modifier(string):
|
||||
return string
|
||||
|
||||
# Get and save the Stable Diffusion-generated picture
|
||||
|
||||
|
||||
def get_SD_pictures(description, character):
|
||||
|
||||
global params
|
||||
@ -186,6 +188,8 @@ def get_SD_pictures(description, character):
|
||||
|
||||
# TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
|
||||
# and replace it with 'text' for the purposes of logging?
|
||||
|
||||
|
||||
def output_modifier(string, state):
|
||||
"""
|
||||
This function is applied to the model outputs.
|
||||
|
@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
|
||||
'''
|
||||
Copied from the transformers library
|
||||
'''
|
||||
|
||||
def __init__(self, penalty: float, _range: int):
|
||||
if not isinstance(penalty, float) or not (penalty > 0):
|
||||
raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
|
||||
|
Loading…
Reference in New Issue
Block a user