Mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-22 08:07:56 +01:00

lint

parent 9b55d3a9f9
commit e202190c4f
@@ -22,6 +22,7 @@ options = {
    'session_metadata': 'text-generation-webui',
}


def ui():
    settings = shared.settings.get("ngrok")
    if settings:
@@ -33,4 +34,3 @@ def ui():
            logging.info(f"Ingress established at: {tunnel.url()}")
        except ModuleNotFoundError:
            logging.error("===> ngrok library not found, please run `pip install -r extensions/ngrok/requirements.txt`")
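The two ngrok hunks above show the extension's pattern: read optional settings from shared.settings and import the ngrok dependency only when the feature is actually used, so a missing package produces a clear log message instead of a startup crash. A minimal standalone sketch of the same guarded-import pattern (the function name is illustrative, not from the diff):

import logging

def start_tunnel(settings: dict):
    # Only act when the feature is configured.
    if not settings:
        return
    try:
        import ngrok  # optional dependency, imported lazily
        tunnel = ngrok.connect(**settings)
        logging.info(f"Ingress established at: {tunnel.url()}")
    except ModuleNotFoundError:
        logging.error("ngrok library not found; install it before enabling the extension")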
@@ -380,7 +380,6 @@ def stream_chat_completions(body: dict, is_legacy: bool=False):

        yield chunk

    stop_reason = "stop"
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
        stop_reason = "length"
@@ -583,7 +582,6 @@ def stream_completions(body: dict, is_legacy: bool=False):
        completion_token_count += len(encode(new_content)[0])
        yield chunk

    stop_reason = "stop"
    if token_count + completion_token_count >= req_params['truncation_length'] or completion_token_count >= max_tokens:
        stop_reason = "length"
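Both streaming hunks end with the same finish-reason logic, matching OpenAI's convention. A small sketch that isolates it (the helper name is hypothetical; the condition is copied from the diff):

def get_finish_reason(token_count, completion_token_count, truncation_length, max_tokens):
    # "length" means generation halted because a token budget ran out;
    # "stop" means the model ended naturally (e.g. EOS or a stop string).
    if token_count + completion_token_count >= truncation_length or completion_token_count >= max_tokens:
        return "length"
    return "stop"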
@@ -41,10 +41,13 @@ default_req_params = {
    # 'requested_model' - temporarily used
}


def get_default_req_params():
    return copy.deepcopy(default_req_params)

# little helper to get defaults if arg is present but None and should be the same type as default.


def default(dic, key, default):
    val = dic.get(key, default)
    if type(val) != type(default):
@@ -59,6 +62,6 @@ def default(dic, key, default):
        val = default
    return val


def clamp(value, minvalue, maxvalue):
    return max(minvalue, min(value, maxvalue))
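The default() and clamp() helpers above are used together when sanitising request parameters: default() falls back when a key is missing or has the wrong type (including None), and clamp() bounds the result. A hedged usage sketch (the request body is invented for illustration):

body = {'temperature': None, 'top_p': 1.3}
temperature = default(body, 'temperature', 1.0)        # None has the wrong type, so the default 1.0 is used
top_p = clamp(default(body, 'top_p', 1.0), 0.0, 1.0)   # out-of-range 1.3 is clamped to 1.0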
@@ -6,6 +6,7 @@ from extensions.openai.errors import *
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
embeddings_model = None


def load_embedding_model(model):
    try:
        emb_model = SentenceTransformer(model)
@@ -16,16 +17,19 @@ def load_embedding_model(model):

    return emb_model


def get_embeddings_model():
    global embeddings_model, st_model
    if st_model and not embeddings_model:
        embeddings_model = load_embedding_model(st_model)  # lazy load the model
    return embeddings_model


def get_embeddings_model_name():
    global st_model
    return st_model


def embeddings(input: list, encoding_format: str):

    embeddings = get_embeddings_model().encode(input).tolist()
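The embeddings hunks implement a lazy module-level singleton: the SentenceTransformer is loaded on first request, then cached for every later call. A self-contained sketch of the same pattern (names are illustrative; the default model follows the diff):

from sentence_transformers import SentenceTransformer

_model = None  # module-level cache

def get_model(name: str = "all-mpnet-base-v2"):
    # Load once, on first use; subsequent calls reuse the cached instance.
    global _model
    if _model is None:
        _model = SentenceTransformer(name)
    return _model

def embed(texts: list) -> list:
    # encode() returns a numpy array; tolist() makes it JSON-serialisable.
    return get_model().encode(texts).tolist()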
@@ -3,6 +3,7 @@ class OpenAIError(Exception):
        self.message = message
        self.code = code
        self.internal_message = internal_message

    def __repr__(self):
        return "%s(message=%r, code=%d)" % (
            self.__class__.__name__,
@@ -10,10 +11,12 @@ class OpenAIError(Exception):
            self.code,
        )


class InvalidRequestError(OpenAIError):
    def __init__(self, message, param, code=400, error_type='InvalidRequestError', internal_message=''):
        super(OpenAIError, self).__init__(message, code, error_type, internal_message)
        self.param = param

    def __repr__(self):
        return "%s(message=%r, code=%d, param=%s)" % (
            self.__class__.__name__,
@@ -22,6 +25,7 @@ class InvalidRequestError(OpenAIError):
            self.param,
        )


class ServiceUnavailableError(OpenAIError):
    def __init__(self, message=None, code=500, error_type='ServiceUnavailableError', internal_message=''):
        super(OpenAIError, self).__init__(message, code, error_type, internal_message)
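These error classes let route handlers raise typed errors and translate them into HTTP responses in one place. A simplified, self-contained sketch of that dispatch (note it uses plain super().__init__ rather than the diff's super(OpenAIError, self) form, and the handler shape is hypothetical):

class OpenAIError(Exception):
    def __init__(self, message='', code=500, internal_message=''):
        super().__init__(message)
        self.message = message
        self.code = code
        self.internal_message = internal_message

class InvalidRequestError(OpenAIError):
    def __init__(self, message, param, code=400, internal_message=''):
        super().__init__(message, code, internal_message)
        self.param = param

def handle_request(body: dict):
    # One catch site maps any typed error to a status code plus payload.
    try:
        if 'model' not in body:
            raise InvalidRequestError("you must supply a model", param='model')
        return 200, {"ok": True}
    except OpenAIError as e:
        return e.code, {"error": {"message": e.message, "code": e.code}}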
@@ -3,6 +3,7 @@ import time
import requests
from extensions.openai.errors import *


def generations(prompt: str, size: str, response_format: str, n: int):
    # Stable Diffusion callout wrapper for txt2img
    # Low effort implementation for compatibility. With only "prompt" being passed and assuming DALL-E
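generations() wraps a Stable Diffusion backend to emulate the DALL-E images endpoint. A hedged sketch of such a callout, assuming an AUTOMATIC1111-style web API (the endpoint, payload fields, and URL are assumptions, not taken from the diff):

import requests

def txt2img(prompt: str, base_url: str = "http://127.0.0.1:7860"):
    # Assumed AUTOMATIC1111-style txt2img endpoint; adjust for your backend.
    payload = {"prompt": prompt, "steps": 20}
    r = requests.post(f"{base_url}/sdapi/v1/txt2img", json=payload, timeout=120)
    r.raise_for_status()
    return r.json().get("images", [])  # base64-encoded images, per that API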
@@ -7,15 +7,18 @@ from modules.models_settings import (get_model_settings_from_yamls,
from extensions.openai.embeddings import get_embeddings_model_name
from extensions.openai.errors import *


def get_current_model_list() -> list:
    return [shared.model_name]  # The real chat/completions model, maybe "None"


def get_pseudo_model_list() -> list:
    return [  # these are expected by so much, so include some here as a dummy
        'gpt-3.5-turbo',
        'text-embedding-ada-002',
    ]


def load_model(model_name: str) -> dict:
    resp = {
        "id": model_name,
@@ -74,4 +77,3 @@ def model_info(model_name: str) -> dict:
        "owned_by": "user",
        "permission": []
    }
@@ -48,7 +48,6 @@ def moderations(input):

    category_embeddings = get_category_embeddings()

    # input, string or array
    if isinstance(input, str):
        input = [input]
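The isinstance check above normalises the argument so the rest of the endpoint can assume a list. As a standalone helper (name hypothetical):

def as_list(input):
    # Accept a single string or a list of strings; always return a list.
    return [input] if isinstance(input, str) else input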
@@ -22,6 +22,7 @@ params = {
    'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
}


class Handler(BaseHTTPRequestHandler):
    def send_access_control_headers(self):
        self.send_header("Access-Control-Allow-Origin", "*")
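send_access_control_headers() is what lets browser-based clients call the API cross-origin. A minimal sketch of how such a handler typically completes the CORS story with a preflight response, using the standard http.server API (do_OPTIONS and the extra headers are assumptions in the spirit of the hunk, not lines from the diff):

from http.server import BaseHTTPRequestHandler

class Handler(BaseHTTPRequestHandler):
    def send_access_control_headers(self):
        self.send_header("Access-Control-Allow-Origin", "*")
        self.send_header("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
        self.send_header("Access-Control-Allow-Headers", "Content-Type, Authorization")

    def do_OPTIONS(self):
        # Preflight request: reply 200 with CORS headers and an empty body.
        self.send_response(200)
        self.send_access_control_headers()
        self.end_headers()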
@@ -1,6 +1,7 @@
from extensions.openai.utils import float_list_to_base64
from modules.text_generation import encode, decode


def token_count(prompt):
    tokens = encode(prompt)[0]
@@ -2,6 +2,7 @@ import os
import base64
import numpy as np


def float_list_to_base64(float_list):
    # Convert the list to a float32 array that the OpenAPI client expects
    float_array = np.array(float_list, dtype="float32")
@@ -16,11 +17,13 @@ def float_list_to_base64(float_list):
    ascii_string = encoded_bytes.decode('ascii')
    return ascii_string


def end_line(s):
    if s and s[-1] != '\n':
        s = s + '\n'
    return s


def debug_msg(*args, **kwargs):
    if 'OPENEDAI_DEBUG' in os.environ:
        print(*args, **kwargs)
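The two utils hunks show only the first and last lines of float_list_to_base64(); the diff elides the middle. A plausible completion reconstructed around the visible lines (the tobytes/b64encode middle is an assumption):

import base64
import numpy as np

def float_list_to_base64(float_list):
    # Convert the list to a float32 array that the OpenAPI client expects,
    # then base64-encode the raw bytes and return them as an ASCII string.
    float_array = np.array(float_list, dtype="float32")
    encoded_bytes = base64.b64encode(float_array.tobytes())
    ascii_string = encoded_bytes.decode('ascii')
    return ascii_string

# Round-trip check:
# np.frombuffer(base64.b64decode(float_list_to_base64([0.5])), dtype="float32")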
@@ -126,6 +126,8 @@ def input_modifier(string):
    return string

# Get and save the Stable Diffusion-generated picture


def get_SD_pictures(description, character):

    global params
@@ -186,6 +188,8 @@ def get_SD_pictures(description, character):

    # TODO: how do I make the UI history ignore the resulting pictures (I don't want HTML to appear in history)
    # and replace it with 'text' for the purposes of logging?


def output_modifier(string, state):
    """
    This function is applied to the model outputs.
@@ -126,6 +126,7 @@ class RepetitionPenaltyLogitsProcessorWithRange(LogitsProcessor):
    '''
    Copied from the transformers library
    '''

    def __init__(self, penalty: float, _range: int):
        if not isinstance(penalty, float) or not (penalty > 0):
            raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")