Lint the openai extension

oobabooga 2023-09-15 20:11:16 -07:00
parent 760510db52
commit 8f97e87cac
12 changed files with 79 additions and 69 deletions

extensions/openai/cache_embedded_model.py

@@ -3,6 +3,9 @@
 # Dockerfile:
 # ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
 # RUN python3 cache_embedded_model.py
-import os, sentence_transformers
+import os
+import sentence_transformers
 
 st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
 model = sentence_transformers.SentenceTransformer(st_model)
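As an aside, the `st_model` fallback above is equivalent to a plain `os.environ.get` lookup; a minimal sketch of the same behavior, shown only for comparison and not part of the commit:

```python
import os

# Same semantics as the conditional expression above: use the env var when set,
# otherwise fall back to the default model name.
st_model = os.environ.get("OPENEDAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
```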

extensions/openai/completions.py

@@ -1,18 +1,15 @@
 import time
-import yaml
 import tiktoken
 import torch
 import torch.nn.functional as F
-from math import log, exp
-from transformers import LogitsProcessor, LogitsProcessorList
+import yaml
+from extensions.openai.defaults import clamp, default, get_default_req_params
+from extensions.openai.errors import InvalidRequestError
+from extensions.openai.utils import debug_msg, end_line
 from modules import shared
-from modules.text_generation import encode, decode, generate_reply
-from extensions.openai.defaults import get_default_req_params, default, clamp
-from extensions.openai.utils import end_line, debug_msg
-from extensions.openai.errors import *
+from modules.text_generation import decode, encode, generate_reply
+from transformers import LogitsProcessor, LogitsProcessorList
 
 # Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
@@ -21,7 +18,7 @@ class LogitsBiasProcessor(LogitsProcessor):
         self.logit_bias = logit_bias
         if self.logit_bias:
             self.keys = list([int(key) for key in self.logit_bias.keys()])
-            values = [ self.logit_bias[str(key)] for key in self.keys ]
+            values = [self.logit_bias[str(key)] for key in self.keys]
             self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
             debug_msg(f"{self})")
@@ -36,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
     def __repr__(self):
         return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
 
+
 class LogprobProcessor(LogitsProcessor):
     def __init__(self, logprobs=None):
         self.logprobs = logprobs
@@ -44,9 +42,9 @@ class LogprobProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
         if self.logprobs is not None:  # 0-5
             log_e_probabilities = F.log_softmax(logits, dim=1)
-            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs+1)
-            top_tokens = [ decode(tok) for tok in top_indices[0] ]
-            top_probs = [ float(x) for x in top_values[0] ]
+            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
+            top_tokens = [decode(tok) for tok in top_indices[0]]
+            top_probs = [float(x) for x in top_values[0]]
             self.token_alternatives = dict(zip(top_tokens, top_probs))
             debug_msg(repr(self))
         return logits
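The two processors above implement OpenAI's `logit_bias` and `logprobs` request parameters. As a standalone illustration of the bias update, here is a minimal sketch with made-up token ids and a fake vocabulary size — not the extension's exact code path:

```python
import torch

# Hypothetical bias map: token id -> additive bias (OpenAI-style logit_bias).
logit_bias = {"15": 5.0, "42": -100.0}
keys = [int(k) for k in logit_bias.keys()]
values = torch.tensor([logit_bias[str(k)] for k in keys], dtype=torch.float)

logits = torch.zeros(1, 32000)  # fake scores over a 32k-token vocabulary
logits[0, keys] += values       # the same additive update the processor applies
```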
@@ -56,14 +54,15 @@ class LogprobProcessor(LogitsProcessor):
 def convert_logprobs_to_tiktoken(model, logprobs):
     # more problems than it's worth.
     # try:
     #     encoder = tiktoken.encoding_for_model(model)
     #     # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
     #     return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
     # except KeyError:
     #     # assume native tokens if we can't find the tokenizer
     #     return logprobs
     return logprobs
@@ -115,7 +114,7 @@ def marshal_common_params(body):
             new_logit_bias = {}
             for logit, bias in logit_bias.items():
                 for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]:
-                    if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens
+                    if int(x) in [0, 1, 2, 29871]:  # XXX LLAMA tokens
                         continue
                     new_logit_bias[str(int(x))] = bias
             debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias)
@@ -146,7 +145,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
     if body.get('function_call', ''):  # chat only, 'none', 'auto', {'name': 'func'}
         raise InvalidRequestError(message="function_call is not supported.", param='function_call')
 
-    if not 'messages' in body:
+    if 'messages' not in body:
         raise InvalidRequestError(message="messages is required", param='messages')
 
     messages = body['messages']
@@ -159,7 +158,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
         'prompt': 'Assistant:',
     }
 
-    if not 'stopping_strings' in req_params:
+    if 'stopping_strings' not in req_params:
         req_params['stopping_strings'] = []
 
     # Instruct models can be much better
@@ -169,7 +168,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
         template = instruct['turn_template']
         system_message_template = "{message}"
         system_message_default = instruct.get('context', '')  # can be missing
-        bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token
+        bot_start = template.find('<|bot|>')  # So far, 100% of instruction templates have this token
         user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', ''))
         bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', ''))
@@ -216,7 +215,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
             raise InvalidRequestError(message="messages: missing role", param='messages')
         if 'content' not in m:
             raise InvalidRequestError(message="messages: missing content", param='messages')
 
         role = m['role']
         content = m['content']
         # name = m.get('name', None)
@@ -439,7 +438,7 @@ def completions(body: dict, is_legacy: bool = False):
     # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
     prompt_str = 'context' if is_legacy else 'prompt'
-    if not prompt_str in body:
+    if prompt_str not in body:
         raise InvalidRequestError("Missing required input", param=prompt_str)
 
     prompt_arg = body[prompt_str]
@@ -455,7 +454,7 @@ def completions(body: dict, is_legacy: bool = False):
     requested_model = req_params.pop('requested_model')
     logprob_proc = req_params.pop('logprob_proc', None)
     stopping_strings = req_params.pop('stopping_strings', [])
-    #req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
+    # req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
     req_params['echo'] = default(body, 'echo', req_params['echo'])
     req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
@@ -538,10 +537,12 @@ def stream_completions(body: dict, is_legacy: bool = False):
     # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
     prompt_str = 'context' if is_legacy else 'prompt'
-    if not prompt_str in body:
+    if prompt_str not in body:
         raise InvalidRequestError("Missing required input", param=prompt_str)
 
     prompt = body[prompt_str]
+    req_params = marshal_common_params(body)
+    requested_model = req_params.pop('requested_model')
 
     if isinstance(prompt, list):
         if prompt and isinstance(prompt[0], int):
             try:
@@ -553,15 +554,13 @@ def stream_completions(body: dict, is_legacy: bool = False):
             raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
 
     # common params
-    req_params = marshal_common_params(body)
     req_params['stream'] = True
     max_tokens_str = 'length' if is_legacy else 'max_tokens'
     max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
     req_params['max_new_tokens'] = max_tokens
-    requested_model = req_params.pop('requested_model')
     logprob_proc = req_params.pop('logprob_proc', None)
     stopping_strings = req_params.pop('stopping_strings', [])
-    #req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
+    # req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
     req_params['echo'] = default(body, 'echo', req_params['echo'])
     req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
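Beyond whitespace, the two hunks above hoist `marshal_common_params` (and the `requested_model` pop) ahead of the prompt-type handling in `stream_completions`, and apply pycodestyle's E713 rule, which recurs throughout this commit. A quick illustration of the E713 rewrite, using a throwaway dict rather than code from the extension:

```python
body = {'prompt': 'hello'}

# Flagged by pycodestyle (E713): negated membership test.
if not 'prompt' in body:
    ...

# Preferred, equivalent form used throughout this commit.
if 'prompt' not in body:
    ...
```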

extensions/openai/defaults.py

@@ -51,9 +51,11 @@ def get_default_req_params():
     return copy.deepcopy(default_req_params)
 
+
 # little helper to get defaults if arg is present but None and should be the same type as default.
 def default(dic, key, default):
     val = dic.get(key, default)
-    if type(val) != type(default):
+    if not isinstance(val, type(default)):
         # maybe it's just something like 1 instead of 1.0
         try:
             v = type(default)(val)
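The `isinstance` form silences pycodestyle's type-comparison warning (E721) while keeping the coercion behavior. A self-contained sketch of those semantics — the real helper continues past the `try` shown above, so details like the exception handling here are assumptions:

```python
def default(dic, key, default):
    # Fall back when the key is missing, and coerce when the type differs
    # (e.g. an int 1 supplied where a float 1.0 is expected).
    val = dic.get(key, default)
    if not isinstance(val, type(default)):
        try:
            val = type(default)(val)
        except (TypeError, ValueError):
            val = default
    return val

default({'temperature': 1}, 'temperature', 1.0)  # -> 1.0 (int coerced to float)
default({}, 'temperature', 1.0)                  # -> 1.0 (missing key falls back)
```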

extensions/openai/edits.py

@@ -1,10 +1,10 @@
 import time
 import yaml
-import os
-from modules import shared
 from extensions.openai.defaults import get_default_req_params
+from extensions.openai.errors import InvalidRequestError
 from extensions.openai.utils import debug_msg
-from extensions.openai.errors import *
+from modules import shared
 from modules.text_generation import encode, generate_reply
@@ -74,7 +74,6 @@ def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
     generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
 
-    longest_stop_len = max([len(x) for x in stopping_strings] + [0])
     answer = ''
     for a in generator:
         answer = a

extensions/openai/embeddings.py

@@ -1,8 +1,9 @@
 import os
-from sentence_transformers import SentenceTransformer
 import numpy as np
-from extensions.openai.utils import float_list_to_base64, debug_msg
-from extensions.openai.errors import *
+from extensions.openai.errors import ServiceUnavailableError
+from extensions.openai.utils import debug_msg, float_list_to_base64
+from sentence_transformers import SentenceTransformer
 
 st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
 embeddings_model = None
@@ -11,10 +12,11 @@ embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu")
 if embeddings_device.lower() == 'auto':
     embeddings_device = None
 
+
 def load_embedding_model(model: str) -> SentenceTransformer:
     global embeddings_device, embeddings_model
     try:
-        embeddings_model = 'loading...' # flag
+        embeddings_model = 'loading...'  # flag
         # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer
         emb_model = SentenceTransformer(model, device=embeddings_device)
         # ... emb_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM
@@ -41,6 +43,7 @@ def get_embeddings_model_name() -> str:
 def get_embeddings(input: list) -> np.ndarray:
     return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device)
 
+
 def embeddings(input: list, encoding_format: str) -> dict:
     embeddings = get_embeddings(input)
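A usage sketch for the module above, assuming it is importable from a text-generation-webui checkout; `all-mpnet-base-v2` produces 768-dimensional vectors, and `normalize_embeddings=True` means they come back unit-length:

```python
from extensions.openai.embeddings import get_embeddings

vecs = get_embeddings(["hello world"])  # shape (1, 768) for all-mpnet-base-v2
print(float((vecs[0] ** 2).sum()))      # ~1.0, since embeddings are L2-normalized
```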

extensions/openai/images.py

@@ -1,7 +1,8 @@
 import os
 import time
 
 import requests
-from extensions.openai.errors import *
+from extensions.openai.errors import ServiceUnavailableError
 
 def generations(prompt: str, size: str, response_format: str, n: int):
@@ -14,7 +15,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
     # At this point I will not add the edits and variations endpoints (ie. img2img) because they
     # require changing the form data handling to accept multipart form data, also to properly support
     # url return types will require file management and a web serving files... Perhaps later!
-    base_model_size = 512 if not 'SD_BASE_MODEL_SIZE' in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
+    base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
     sd_defaults = {
         'sampler_name': 'DPM++ 2M Karras',  # vast improvement
         'steps': 30,
@@ -56,7 +57,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
     r = response.json()
     if response.status_code != 200 or 'images' not in r:
         print(r)
-        raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors',None))
+        raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
     # r['parameters']...
     for b64_json in r['images']:
         if response_format == 'b64_json':
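For context on the response being checked above: the extension calls an AUTOMATIC1111-style Stable Diffusion web UI. A hedged sketch of the call shape — the endpoint path and payload keys follow that API, the defaults come from the `sd_defaults` dict in the diff, and `SD_WEBUI_URL` is the variable this file reads:

```python
import os

import requests

payload = {'prompt': 'a photo of a cat', 'sampler_name': 'DPM++ 2M Karras', 'steps': 30}
response = requests.post(url=f"{os.environ['SD_WEBUI_URL']}/sdapi/v1/txt2img", json=payload)
r = response.json()  # on success, r['images'] holds base64-encoded PNGs
```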

extensions/openai/models.py

@@ -1,7 +1,8 @@
 from extensions.openai.embeddings import get_embeddings_model_name
-from extensions.openai.errors import *
+from extensions.openai.errors import OpenAIError
 from modules import shared
-from modules.models import load_model, unload_model
+from modules.models import load_model as _load_model
+from modules.models import unload_model
 from modules.models_settings import get_model_metadata, update_model_parameters
 from modules.utils import get_available_models
@@ -38,7 +39,7 @@ def load_model(model_name: str) -> dict:
     if shared.settings['mode'] != 'instruct':
         shared.settings['instruction_template'] = None
 
-    shared.model, shared.tokenizer = load_model(shared.model_name)
+    shared.model, shared.tokenizer = _load_model(shared.model_name)
     if not shared.model:  # load failed.
         shared.model_name = "None"
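This rename fixes more than style: the module defines its own `load_model`, which shadowed the one imported from `modules.models`, so the call inside recursed into itself instead of loading anything. A stripped-down, hypothetical reproduction of the shadowing bug (not the extension's code):

```python
def greet(name):          # stands in for the imported modules.models.load_model
    return f"hello {name}"


def greet(name):          # a second def with the same name rebinds it...
    return greet(name)    # ...so this call now recurses into itself


greet("world")            # raises RecursionError, never reaching the first greet
```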

extensions/openai/moderations.py

@@ -1,8 +1,8 @@
 import time
 
-import numpy as np
-from numpy.linalg import norm
-from extensions.openai.embeddings import get_embeddings
+import numpy as np
+from extensions.openai.embeddings import get_embeddings
+from numpy.linalg import norm
 
 moderations_disabled = False  # return 0/false
 category_embeddings = None

extensions/openai/requirements.txt

@@ -1,3 +1,3 @@
 flask_cloudflared==0.0.12
 sentence-transformers
 tiktoken

extensions/openai/script.py

@@ -4,19 +4,21 @@ import traceback
 from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
 from threading import Thread
 
-from modules import shared
-from extensions.openai.tokens import token_count, token_encode, token_decode
-import extensions.openai.models as OAImodels
+import extensions.openai.completions as OAIcompletions
 import extensions.openai.edits as OAIedits
 import extensions.openai.embeddings as OAIembeddings
 import extensions.openai.images as OAIimages
+import extensions.openai.models as OAImodels
 import extensions.openai.moderations as OAImoderations
-import extensions.openai.completions as OAIcompletions
-from extensions.openai.errors import *
+from extensions.openai.defaults import clamp, default, get_default_req_params
+from extensions.openai.errors import (
+    InvalidRequestError,
+    OpenAIError,
+    ServiceUnavailableError
+)
+from extensions.openai.tokens import token_count, token_decode, token_encode
 from extensions.openai.utils import debug_msg
-from extensions.openai.defaults import (get_default_req_params, default, clamp)
+from modules import shared
 
 params = {
     'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
@@ -209,7 +211,7 @@ class Handler(BaseHTTPRequestHandler):
             self.return_json(response)
 
         elif '/images/generations' in self.path:
-            if not 'SD_WEBUI_URL' in os.environ:
+            if 'SD_WEBUI_URL' not in os.environ:
                 raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
 
             prompt = body['prompt']

extensions/openai/tokens.py

@@ -1,6 +1,5 @@
-from extensions.openai.utils import float_list_to_base64
-from modules.text_generation import encode, decode
-import numpy as np
+from modules.text_generation import decode, encode
 
 
 def token_count(prompt):
     tokens = encode(prompt)[0]

extensions/openai/utils.py

@@ -1,11 +1,12 @@
-import os
 import base64
+import os
+
 import numpy as np
 
 
 def float_list_to_base64(float_array: np.ndarray) -> str:
     # Convert the list to a float32 array that the OpenAPI client expects
-    #float_array = np.array(float_list, dtype="float32")
+    # float_array = np.array(float_list, dtype="float32")
 
     # Get raw bytes
     bytes_array = float_array.tobytes()
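A round-trip sketch of the encoding this helper performs, assuming float32 input as the comment above states (raw bytes in native byte order, base64-encoded the way OpenAI embedding clients expect):

```python
import base64

import numpy as np

arr = np.array([0.25, -0.5], dtype=np.float32)
b64 = base64.b64encode(arr.tobytes()).decode('ascii')

decoded = np.frombuffer(base64.b64decode(b64), dtype=np.float32)
assert np.array_equal(arr, decoded)  # lossless round trip
```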