From 8f97e87cac4fda9e25b7222d18fdc17669f203a5 Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Fri, 15 Sep 2023 20:11:16 -0700 Subject: [PATCH] Lint the openai extension --- extensions/openai/cache_embedding_model.py | 5 +- extensions/openai/completions.py | 65 +++++++++++----------- extensions/openai/defaults.py | 4 +- extensions/openai/edits.py | 7 +-- extensions/openai/embeddings.py | 11 ++-- extensions/openai/images.py | 7 ++- extensions/openai/models.py | 7 ++- extensions/openai/moderations.py | 6 +- extensions/openai/requirements.txt | 6 +- extensions/openai/script.py | 20 ++++--- extensions/openai/tokens.py | 5 +- extensions/openai/utils.py | 5 +- 12 files changed, 79 insertions(+), 69 deletions(-) diff --git a/extensions/openai/cache_embedding_model.py b/extensions/openai/cache_embedding_model.py index 44ac1dcd..2bd69844 100644 --- a/extensions/openai/cache_embedding_model.py +++ b/extensions/openai/cache_embedding_model.py @@ -3,6 +3,9 @@ # Dockerfile: # ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional # RUN python3 cache_embedded_model.py -import os, sentence_transformers +import os + +import sentence_transformers + st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2" model = sentence_transformers.SentenceTransformer(st_model) diff --git a/extensions/openai/completions.py b/extensions/openai/completions.py index 3e277710..40d96c1f 100644 --- a/extensions/openai/completions.py +++ b/extensions/openai/completions.py @@ -1,18 +1,15 @@ import time -import yaml + import tiktoken import torch import torch.nn.functional as F -from math import log, exp - -from transformers import LogitsProcessor, LogitsProcessorList - +import yaml +from extensions.openai.defaults import clamp, default, get_default_req_params +from extensions.openai.errors import InvalidRequestError +from extensions.openai.utils import debug_msg, end_line from modules import shared -from modules.text_generation import encode, decode, generate_reply - -from extensions.openai.defaults import get_default_req_params, default, clamp -from extensions.openai.utils import end_line, debug_msg -from extensions.openai.errors import * +from modules.text_generation import decode, encode, generate_reply +from transformers import LogitsProcessor, LogitsProcessorList # Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic @@ -21,7 +18,7 @@ class LogitsBiasProcessor(LogitsProcessor): self.logit_bias = logit_bias if self.logit_bias: self.keys = list([int(key) for key in self.logit_bias.keys()]) - values = [ self.logit_bias[str(key)] for key in self.keys ] + values = [self.logit_bias[str(key)] for key in self.keys] self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device) debug_msg(f"{self})") @@ -36,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor): def __repr__(self): return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>" + class LogprobProcessor(LogitsProcessor): def __init__(self, logprobs=None): self.logprobs = logprobs @@ -44,9 +42,9 @@ class LogprobProcessor(LogitsProcessor): def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor: if self.logprobs is not None: # 0-5 log_e_probabilities = F.log_softmax(logits, dim=1) - top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs+1) - top_tokens = [ decode(tok) for tok in top_indices[0] ] - top_probs = [ float(x) for x in top_values[0] ] + top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1) + top_tokens = [decode(tok) for tok in top_indices[0]] + top_probs = [float(x) for x in top_values[0]] self.token_alternatives = dict(zip(top_tokens, top_probs)) debug_msg(repr(self)) return logits @@ -56,14 +54,15 @@ class LogprobProcessor(LogitsProcessor): def convert_logprobs_to_tiktoken(model, logprobs): -# more problems than it's worth. -# try: -# encoder = tiktoken.encoding_for_model(model) -# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall. -# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()]) -# except KeyError: -# # assume native tokens if we can't find the tokenizer -# return logprobs + # more problems than it's worth. + # try: + # encoder = tiktoken.encoding_for_model(model) + # # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall. + # return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()]) + # except KeyError: + # # assume native tokens if we can't find the tokenizer + # return logprobs + return logprobs @@ -115,7 +114,7 @@ def marshal_common_params(body): new_logit_bias = {} for logit, bias in logit_bias.items(): for x in encode(encoder.decode([int(logit)]), add_special_tokens=False)[0]: - if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens + if int(x) in [0, 1, 2, 29871]: # XXX LLAMA tokens continue new_logit_bias[str(int(x))] = bias debug_msg('logit_bias_map', logit_bias, '->', new_logit_bias) @@ -146,7 +145,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens): if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'} raise InvalidRequestError(message="function_call is not supported.", param='function_call') - if not 'messages' in body: + if 'messages' not in body: raise InvalidRequestError(message="messages is required", param='messages') messages = body['messages'] @@ -159,7 +158,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens): 'prompt': 'Assistant:', } - if not 'stopping_strings' in req_params: + if 'stopping_strings' not in req_params: req_params['stopping_strings'] = [] # Instruct models can be much better @@ -169,7 +168,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens): template = instruct['turn_template'] system_message_template = "{message}" - system_message_default = instruct.get('context', '') # can be missing + system_message_default = instruct.get('context', '') # can be missing bot_start = template.find('<|bot|>') # So far, 100% of instruction templates have this token user_message_template = template[:bot_start].replace('<|user-message|>', '{message}').replace('<|user|>', instruct.get('user', '')) bot_message_template = template[bot_start:].replace('<|bot-message|>', '{message}').replace('<|bot|>', instruct.get('bot', '')) @@ -216,7 +215,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens): raise InvalidRequestError(message="messages: missing role", param='messages') if 'content' not in m: raise InvalidRequestError(message="messages: missing content", param='messages') - + role = m['role'] content = m['content'] # name = m.get('name', None) @@ -439,7 +438,7 @@ def completions(body: dict, is_legacy: bool = False): # ... encoded as a string, array of strings, array of tokens, or array of token arrays. prompt_str = 'context' if is_legacy else 'prompt' - if not prompt_str in body: + if prompt_str not in body: raise InvalidRequestError("Missing required input", param=prompt_str) prompt_arg = body[prompt_str] @@ -455,7 +454,7 @@ def completions(body: dict, is_legacy: bool = False): requested_model = req_params.pop('requested_model') logprob_proc = req_params.pop('logprob_proc', None) stopping_strings = req_params.pop('stopping_strings', []) - #req_params['suffix'] = default(body, 'suffix', req_params['suffix']) + # req_params['suffix'] = default(body, 'suffix', req_params['suffix']) req_params['echo'] = default(body, 'echo', req_params['echo']) req_params['top_k'] = default(body, 'best_of', req_params['top_k']) @@ -538,10 +537,12 @@ def stream_completions(body: dict, is_legacy: bool = False): # ... encoded as a string, array of strings, array of tokens, or array of token arrays. prompt_str = 'context' if is_legacy else 'prompt' - if not prompt_str in body: + if prompt_str not in body: raise InvalidRequestError("Missing required input", param=prompt_str) prompt = body[prompt_str] + req_params = marshal_common_params(body) + requested_model = req_params.pop('requested_model') if isinstance(prompt, list): if prompt and isinstance(prompt[0], int): try: @@ -553,15 +554,13 @@ def stream_completions(body: dict, is_legacy: bool = False): raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str) # common params - req_params = marshal_common_params(body) req_params['stream'] = True max_tokens_str = 'length' if is_legacy else 'max_tokens' max_tokens = default(body, max_tokens_str, req_params['max_new_tokens']) req_params['max_new_tokens'] = max_tokens - requested_model = req_params.pop('requested_model') logprob_proc = req_params.pop('logprob_proc', None) stopping_strings = req_params.pop('stopping_strings', []) - #req_params['suffix'] = default(body, 'suffix', req_params['suffix']) + # req_params['suffix'] = default(body, 'suffix', req_params['suffix']) req_params['echo'] = default(body, 'echo', req_params['echo']) req_params['top_k'] = default(body, 'best_of', req_params['top_k']) diff --git a/extensions/openai/defaults.py b/extensions/openai/defaults.py index 052862f7..88a1aaac 100644 --- a/extensions/openai/defaults.py +++ b/extensions/openai/defaults.py @@ -51,9 +51,11 @@ def get_default_req_params(): return copy.deepcopy(default_req_params) # little helper to get defaults if arg is present but None and should be the same type as default. + + def default(dic, key, default): val = dic.get(key, default) - if type(val) != type(default): + if not isinstance(val, type(default)): # maybe it's just something like 1 instead of 1.0 try: v = type(default)(val) diff --git a/extensions/openai/edits.py b/extensions/openai/edits.py index 2b527dc0..edf4e6c0 100644 --- a/extensions/openai/edits.py +++ b/extensions/openai/edits.py @@ -1,10 +1,10 @@ import time + import yaml -import os -from modules import shared from extensions.openai.defaults import get_default_req_params +from extensions.openai.errors import InvalidRequestError from extensions.openai.utils import debug_msg -from extensions.openai.errors import * +from modules import shared from modules.text_generation import encode, generate_reply @@ -74,7 +74,6 @@ def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict: generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False) - longest_stop_len = max([len(x) for x in stopping_strings] + [0]) answer = '' for a in generator: answer = a diff --git a/extensions/openai/embeddings.py b/extensions/openai/embeddings.py index be4cd80b..24ab55e5 100644 --- a/extensions/openai/embeddings.py +++ b/extensions/openai/embeddings.py @@ -1,8 +1,9 @@ import os -from sentence_transformers import SentenceTransformer + import numpy as np -from extensions.openai.utils import float_list_to_base64, debug_msg -from extensions.openai.errors import * +from extensions.openai.errors import ServiceUnavailableError +from extensions.openai.utils import debug_msg, float_list_to_base64 +from sentence_transformers import SentenceTransformer st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2" embeddings_model = None @@ -11,10 +12,11 @@ embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu") if embeddings_device.lower() == 'auto': embeddings_device = None + def load_embedding_model(model: str) -> SentenceTransformer: global embeddings_device, embeddings_model try: - embeddings_model = 'loading...' # flag + embeddings_model = 'loading...' # flag # see: https://www.sbert.net/docs/package_reference/SentenceTransformer.html#sentence_transformers.SentenceTransformer emb_model = SentenceTransformer(model, device=embeddings_device) # ... emb_model.device doesn't seem to work, always cpu anyways? but specify cpu anyways to free more VRAM @@ -41,6 +43,7 @@ def get_embeddings_model_name() -> str: def get_embeddings(input: list) -> np.ndarray: return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device) + def embeddings(input: list, encoding_format: str) -> dict: embeddings = get_embeddings(input) diff --git a/extensions/openai/images.py b/extensions/openai/images.py index 9fdb625e..cdb29d6c 100644 --- a/extensions/openai/images.py +++ b/extensions/openai/images.py @@ -1,7 +1,8 @@ import os import time + import requests -from extensions.openai.errors import * +from extensions.openai.errors import ServiceUnavailableError def generations(prompt: str, size: str, response_format: str, n: int): @@ -14,7 +15,7 @@ def generations(prompt: str, size: str, response_format: str, n: int): # At this point I will not add the edits and variations endpoints (ie. img2img) because they # require changing the form data handling to accept multipart form data, also to properly support # url return types will require file management and a web serving files... Perhaps later! - base_model_size = 512 if not 'SD_BASE_MODEL_SIZE' in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512)) + base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512)) sd_defaults = { 'sampler_name': 'DPM++ 2M Karras', # vast improvement 'steps': 30, @@ -56,7 +57,7 @@ def generations(prompt: str, size: str, response_format: str, n: int): r = response.json() if response.status_code != 200 or 'images' not in r: print(r) - raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors',None)) + raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None)) # r['parameters']... for b64_json in r['images']: if response_format == 'b64_json': diff --git a/extensions/openai/models.py b/extensions/openai/models.py index e6715a81..83e550f8 100644 --- a/extensions/openai/models.py +++ b/extensions/openai/models.py @@ -1,7 +1,8 @@ from extensions.openai.embeddings import get_embeddings_model_name -from extensions.openai.errors import * +from extensions.openai.errors import OpenAIError from modules import shared -from modules.models import load_model, unload_model +from modules.models import load_model as _load_model +from modules.models import unload_model from modules.models_settings import get_model_metadata, update_model_parameters from modules.utils import get_available_models @@ -38,7 +39,7 @@ def load_model(model_name: str) -> dict: if shared.settings['mode'] != 'instruct': shared.settings['instruction_template'] = None - shared.model, shared.tokenizer = load_model(shared.model_name) + shared.model, shared.tokenizer = _load_model(shared.model_name) if not shared.model: # load failed. shared.model_name = "None" diff --git a/extensions/openai/moderations.py b/extensions/openai/moderations.py index 5b06a672..1d2d4c1d 100644 --- a/extensions/openai/moderations.py +++ b/extensions/openai/moderations.py @@ -1,8 +1,8 @@ import time -import numpy as np -from numpy.linalg import norm -from extensions.openai.embeddings import get_embeddings +import numpy as np +from extensions.openai.embeddings import get_embeddings +from numpy.linalg import norm moderations_disabled = False # return 0/false category_embeddings = None diff --git a/extensions/openai/requirements.txt b/extensions/openai/requirements.txt index 56d567b8..e0f7c7ce 100644 --- a/extensions/openai/requirements.txt +++ b/extensions/openai/requirements.txt @@ -1,3 +1,3 @@ -flask_cloudflared==0.0.12 -sentence-transformers -tiktoken \ No newline at end of file +flask_cloudflared == 0.0.12 +sentence - transformers +tiktoken diff --git a/extensions/openai/script.py b/extensions/openai/script.py index 5b319674..28263fa5 100644 --- a/extensions/openai/script.py +++ b/extensions/openai/script.py @@ -4,19 +4,21 @@ import traceback from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from threading import Thread -from modules import shared - -from extensions.openai.tokens import token_count, token_encode, token_decode -import extensions.openai.models as OAImodels +import extensions.openai.completions as OAIcompletions import extensions.openai.edits as OAIedits import extensions.openai.embeddings as OAIembeddings import extensions.openai.images as OAIimages +import extensions.openai.models as OAImodels import extensions.openai.moderations as OAImoderations -import extensions.openai.completions as OAIcompletions -from extensions.openai.errors import * +from extensions.openai.defaults import clamp, default, get_default_req_params +from extensions.openai.errors import ( + InvalidRequestError, + OpenAIError, + ServiceUnavailableError +) +from extensions.openai.tokens import token_count, token_decode, token_encode from extensions.openai.utils import debug_msg -from extensions.openai.defaults import (get_default_req_params, default, clamp) - +from modules import shared params = { 'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001, @@ -209,7 +211,7 @@ class Handler(BaseHTTPRequestHandler): self.return_json(response) elif '/images/generations' in self.path: - if not 'SD_WEBUI_URL' in os.environ: + if 'SD_WEBUI_URL' not in os.environ: raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.") prompt = body['prompt'] diff --git a/extensions/openai/tokens.py b/extensions/openai/tokens.py index f8d6737a..0338e7f2 100644 --- a/extensions/openai/tokens.py +++ b/extensions/openai/tokens.py @@ -1,6 +1,5 @@ -from extensions.openai.utils import float_list_to_base64 -from modules.text_generation import encode, decode -import numpy as np +from modules.text_generation import decode, encode + def token_count(prompt): tokens = encode(prompt)[0] diff --git a/extensions/openai/utils.py b/extensions/openai/utils.py index abc1acbc..1e83bcbe 100644 --- a/extensions/openai/utils.py +++ b/extensions/openai/utils.py @@ -1,11 +1,12 @@ -import os import base64 +import os + import numpy as np def float_list_to_base64(float_array: np.ndarray) -> str: # Convert the list to a float32 array that the OpenAPI client expects - #float_array = np.array(float_list, dtype="float32") + # float_array = np.array(float_list, dtype="float32") # Get raw bytes bytes_array = float_array.tobytes()