mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-26 01:30:20 +01:00
Lint the openai extension
This commit is contained in:
parent
760510db52
commit
8f97e87cac
@ -3,6 +3,9 @@
|
|||||||
# Dockerfile:
|
# Dockerfile:
|
||||||
# ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
|
# ENV OPENEDAI_EMBEDDING_MODEL=all-mpnet-base-v2 # Optional
|
||||||
# RUN python3 cache_embedded_model.py
|
# RUN python3 cache_embedded_model.py
|
||||||
import os, sentence_transformers
|
import os
|
||||||
|
|
||||||
|
import sentence_transformers
|
||||||
|
|
||||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||||
model = sentence_transformers.SentenceTransformer(st_model)
|
model = sentence_transformers.SentenceTransformer(st_model)
|
||||||
|
@ -1,18 +1,15 @@
|
|||||||
import time
|
import time
|
||||||
import yaml
|
|
||||||
import tiktoken
|
import tiktoken
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from math import log, exp
|
import yaml
|
||||||
|
from extensions.openai.defaults import clamp, default, get_default_req_params
|
||||||
from transformers import LogitsProcessor, LogitsProcessorList
|
from extensions.openai.errors import InvalidRequestError
|
||||||
|
from extensions.openai.utils import debug_msg, end_line
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.text_generation import encode, decode, generate_reply
|
from modules.text_generation import decode, encode, generate_reply
|
||||||
|
from transformers import LogitsProcessor, LogitsProcessorList
|
||||||
from extensions.openai.defaults import get_default_req_params, default, clamp
|
|
||||||
from extensions.openai.utils import end_line, debug_msg
|
|
||||||
from extensions.openai.errors import *
|
|
||||||
|
|
||||||
|
|
||||||
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
|
# Thanks to @Cypherfox [Cypherfoxy] for the logits code, blame to @matatonic
|
||||||
@ -21,7 +18,7 @@ class LogitsBiasProcessor(LogitsProcessor):
|
|||||||
self.logit_bias = logit_bias
|
self.logit_bias = logit_bias
|
||||||
if self.logit_bias:
|
if self.logit_bias:
|
||||||
self.keys = list([int(key) for key in self.logit_bias.keys()])
|
self.keys = list([int(key) for key in self.logit_bias.keys()])
|
||||||
values = [ self.logit_bias[str(key)] for key in self.keys ]
|
values = [self.logit_bias[str(key)] for key in self.keys]
|
||||||
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
|
self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
|
||||||
debug_msg(f"{self})")
|
debug_msg(f"{self})")
|
||||||
|
|
||||||
@ -36,6 +33,7 @@ class LogitsBiasProcessor(LogitsProcessor):
|
|||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
|
return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"
|
||||||
|
|
||||||
|
|
||||||
class LogprobProcessor(LogitsProcessor):
|
class LogprobProcessor(LogitsProcessor):
|
||||||
def __init__(self, logprobs=None):
|
def __init__(self, logprobs=None):
|
||||||
self.logprobs = logprobs
|
self.logprobs = logprobs
|
||||||
@ -44,9 +42,9 @@ class LogprobProcessor(LogitsProcessor):
|
|||||||
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
|
||||||
if self.logprobs is not None: # 0-5
|
if self.logprobs is not None: # 0-5
|
||||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs+1)
|
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||||
top_tokens = [ decode(tok) for tok in top_indices[0] ]
|
top_tokens = [decode(tok) for tok in top_indices[0]]
|
||||||
top_probs = [ float(x) for x in top_values[0] ]
|
top_probs = [float(x) for x in top_values[0]]
|
||||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||||
debug_msg(repr(self))
|
debug_msg(repr(self))
|
||||||
return logits
|
return logits
|
||||||
@ -56,14 +54,15 @@ class LogprobProcessor(LogitsProcessor):
|
|||||||
|
|
||||||
|
|
||||||
def convert_logprobs_to_tiktoken(model, logprobs):
|
def convert_logprobs_to_tiktoken(model, logprobs):
|
||||||
# more problems than it's worth.
|
# more problems than it's worth.
|
||||||
# try:
|
# try:
|
||||||
# encoder = tiktoken.encoding_for_model(model)
|
# encoder = tiktoken.encoding_for_model(model)
|
||||||
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
# # just pick the first one if it encodes to multiple tokens... 99.9% not required and maybe worse overall.
|
||||||
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
# return dict([(encoder.decode([encoder.encode(token)[0]]), prob) for token, prob in logprobs.items()])
|
||||||
# except KeyError:
|
# except KeyError:
|
||||||
# # assume native tokens if we can't find the tokenizer
|
# # assume native tokens if we can't find the tokenizer
|
||||||
# return logprobs
|
# return logprobs
|
||||||
|
|
||||||
return logprobs
|
return logprobs
|
||||||
|
|
||||||
|
|
||||||
@ -146,7 +145,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
|||||||
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
if body.get('function_call', ''): # chat only, 'none', 'auto', {'name': 'func'}
|
||||||
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
raise InvalidRequestError(message="function_call is not supported.", param='function_call')
|
||||||
|
|
||||||
if not 'messages' in body:
|
if 'messages' not in body:
|
||||||
raise InvalidRequestError(message="messages is required", param='messages')
|
raise InvalidRequestError(message="messages is required", param='messages')
|
||||||
|
|
||||||
messages = body['messages']
|
messages = body['messages']
|
||||||
@ -159,7 +158,7 @@ def messages_to_prompt(body: dict, req_params: dict, max_tokens):
|
|||||||
'prompt': 'Assistant:',
|
'prompt': 'Assistant:',
|
||||||
}
|
}
|
||||||
|
|
||||||
if not 'stopping_strings' in req_params:
|
if 'stopping_strings' not in req_params:
|
||||||
req_params['stopping_strings'] = []
|
req_params['stopping_strings'] = []
|
||||||
|
|
||||||
# Instruct models can be much better
|
# Instruct models can be much better
|
||||||
@ -439,7 +438,7 @@ def completions(body: dict, is_legacy: bool = False):
|
|||||||
|
|
||||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||||
prompt_str = 'context' if is_legacy else 'prompt'
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
if not prompt_str in body:
|
if prompt_str not in body:
|
||||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||||
|
|
||||||
prompt_arg = body[prompt_str]
|
prompt_arg = body[prompt_str]
|
||||||
@ -455,7 +454,7 @@ def completions(body: dict, is_legacy: bool = False):
|
|||||||
requested_model = req_params.pop('requested_model')
|
requested_model = req_params.pop('requested_model')
|
||||||
logprob_proc = req_params.pop('logprob_proc', None)
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
stopping_strings = req_params.pop('stopping_strings', [])
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||||
|
|
||||||
@ -538,10 +537,12 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
|||||||
|
|
||||||
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
# ... encoded as a string, array of strings, array of tokens, or array of token arrays.
|
||||||
prompt_str = 'context' if is_legacy else 'prompt'
|
prompt_str = 'context' if is_legacy else 'prompt'
|
||||||
if not prompt_str in body:
|
if prompt_str not in body:
|
||||||
raise InvalidRequestError("Missing required input", param=prompt_str)
|
raise InvalidRequestError("Missing required input", param=prompt_str)
|
||||||
|
|
||||||
prompt = body[prompt_str]
|
prompt = body[prompt_str]
|
||||||
|
req_params = marshal_common_params(body)
|
||||||
|
requested_model = req_params.pop('requested_model')
|
||||||
if isinstance(prompt, list):
|
if isinstance(prompt, list):
|
||||||
if prompt and isinstance(prompt[0], int):
|
if prompt and isinstance(prompt[0], int):
|
||||||
try:
|
try:
|
||||||
@ -553,15 +554,13 @@ def stream_completions(body: dict, is_legacy: bool = False):
|
|||||||
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)
|
||||||
|
|
||||||
# common params
|
# common params
|
||||||
req_params = marshal_common_params(body)
|
|
||||||
req_params['stream'] = True
|
req_params['stream'] = True
|
||||||
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
max_tokens_str = 'length' if is_legacy else 'max_tokens'
|
||||||
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
max_tokens = default(body, max_tokens_str, req_params['max_new_tokens'])
|
||||||
req_params['max_new_tokens'] = max_tokens
|
req_params['max_new_tokens'] = max_tokens
|
||||||
requested_model = req_params.pop('requested_model')
|
|
||||||
logprob_proc = req_params.pop('logprob_proc', None)
|
logprob_proc = req_params.pop('logprob_proc', None)
|
||||||
stopping_strings = req_params.pop('stopping_strings', [])
|
stopping_strings = req_params.pop('stopping_strings', [])
|
||||||
#req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
# req_params['suffix'] = default(body, 'suffix', req_params['suffix'])
|
||||||
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
req_params['echo'] = default(body, 'echo', req_params['echo'])
|
||||||
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
req_params['top_k'] = default(body, 'best_of', req_params['top_k'])
|
||||||
|
|
||||||
|
@ -51,9 +51,11 @@ def get_default_req_params():
|
|||||||
return copy.deepcopy(default_req_params)
|
return copy.deepcopy(default_req_params)
|
||||||
|
|
||||||
# little helper to get defaults if arg is present but None and should be the same type as default.
|
# little helper to get defaults if arg is present but None and should be the same type as default.
|
||||||
|
|
||||||
|
|
||||||
def default(dic, key, default):
|
def default(dic, key, default):
|
||||||
val = dic.get(key, default)
|
val = dic.get(key, default)
|
||||||
if type(val) != type(default):
|
if not isinstance(val, type(default)):
|
||||||
# maybe it's just something like 1 instead of 1.0
|
# maybe it's just something like 1 instead of 1.0
|
||||||
try:
|
try:
|
||||||
v = type(default)(val)
|
v = type(default)(val)
|
||||||
|
@ -1,10 +1,10 @@
|
|||||||
import time
|
import time
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
import os
|
|
||||||
from modules import shared
|
|
||||||
from extensions.openai.defaults import get_default_req_params
|
from extensions.openai.defaults import get_default_req_params
|
||||||
|
from extensions.openai.errors import InvalidRequestError
|
||||||
from extensions.openai.utils import debug_msg
|
from extensions.openai.utils import debug_msg
|
||||||
from extensions.openai.errors import *
|
from modules import shared
|
||||||
from modules.text_generation import encode, generate_reply
|
from modules.text_generation import encode, generate_reply
|
||||||
|
|
||||||
|
|
||||||
@ -74,7 +74,6 @@ def edits(instruction: str, input: str, temperature=1.0, top_p=1.0) -> dict:
|
|||||||
|
|
||||||
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
generator = generate_reply(edit_task, req_params, stopping_strings=stopping_strings, is_chat=False)
|
||||||
|
|
||||||
longest_stop_len = max([len(x) for x in stopping_strings] + [0])
|
|
||||||
answer = ''
|
answer = ''
|
||||||
for a in generator:
|
for a in generator:
|
||||||
answer = a
|
answer = a
|
||||||
|
@ -1,8 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from extensions.openai.utils import float_list_to_base64, debug_msg
|
from extensions.openai.errors import ServiceUnavailableError
|
||||||
from extensions.openai.errors import *
|
from extensions.openai.utils import debug_msg, float_list_to_base64
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
st_model = os.environ["OPENEDAI_EMBEDDING_MODEL"] if "OPENEDAI_EMBEDDING_MODEL" in os.environ else "all-mpnet-base-v2"
|
||||||
embeddings_model = None
|
embeddings_model = None
|
||||||
@ -11,6 +12,7 @@ embeddings_device = os.environ.get("OPENEDAI_EMBEDDING_DEVICE", "cpu")
|
|||||||
if embeddings_device.lower() == 'auto':
|
if embeddings_device.lower() == 'auto':
|
||||||
embeddings_device = None
|
embeddings_device = None
|
||||||
|
|
||||||
|
|
||||||
def load_embedding_model(model: str) -> SentenceTransformer:
|
def load_embedding_model(model: str) -> SentenceTransformer:
|
||||||
global embeddings_device, embeddings_model
|
global embeddings_device, embeddings_model
|
||||||
try:
|
try:
|
||||||
@ -41,6 +43,7 @@ def get_embeddings_model_name() -> str:
|
|||||||
def get_embeddings(input: list) -> np.ndarray:
|
def get_embeddings(input: list) -> np.ndarray:
|
||||||
return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device)
|
return get_embeddings_model().encode(input, convert_to_numpy=True, normalize_embeddings=True, convert_to_tensor=False, device=embeddings_device)
|
||||||
|
|
||||||
|
|
||||||
def embeddings(input: list, encoding_format: str) -> dict:
|
def embeddings(input: list, encoding_format: str) -> dict:
|
||||||
|
|
||||||
embeddings = get_embeddings(input)
|
embeddings = get_embeddings(input)
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from extensions.openai.errors import *
|
from extensions.openai.errors import ServiceUnavailableError
|
||||||
|
|
||||||
|
|
||||||
def generations(prompt: str, size: str, response_format: str, n: int):
|
def generations(prompt: str, size: str, response_format: str, n: int):
|
||||||
@ -14,7 +15,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
|||||||
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
# At this point I will not add the edits and variations endpoints (ie. img2img) because they
|
||||||
# require changing the form data handling to accept multipart form data, also to properly support
|
# require changing the form data handling to accept multipart form data, also to properly support
|
||||||
# url return types will require file management and a web serving files... Perhaps later!
|
# url return types will require file management and a web serving files... Perhaps later!
|
||||||
base_model_size = 512 if not 'SD_BASE_MODEL_SIZE' in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
|
base_model_size = 512 if 'SD_BASE_MODEL_SIZE' not in os.environ else int(os.environ.get('SD_BASE_MODEL_SIZE', 512))
|
||||||
sd_defaults = {
|
sd_defaults = {
|
||||||
'sampler_name': 'DPM++ 2M Karras', # vast improvement
|
'sampler_name': 'DPM++ 2M Karras', # vast improvement
|
||||||
'steps': 30,
|
'steps': 30,
|
||||||
@ -56,7 +57,7 @@ def generations(prompt: str, size: str, response_format: str, n: int):
|
|||||||
r = response.json()
|
r = response.json()
|
||||||
if response.status_code != 200 or 'images' not in r:
|
if response.status_code != 200 or 'images' not in r:
|
||||||
print(r)
|
print(r)
|
||||||
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors',None))
|
raise ServiceUnavailableError(r.get('error', 'Unknown error calling Stable Diffusion'), code=response.status_code, internal_message=r.get('errors', None))
|
||||||
# r['parameters']...
|
# r['parameters']...
|
||||||
for b64_json in r['images']:
|
for b64_json in r['images']:
|
||||||
if response_format == 'b64_json':
|
if response_format == 'b64_json':
|
||||||
|
@ -1,7 +1,8 @@
|
|||||||
from extensions.openai.embeddings import get_embeddings_model_name
|
from extensions.openai.embeddings import get_embeddings_model_name
|
||||||
from extensions.openai.errors import *
|
from extensions.openai.errors import OpenAIError
|
||||||
from modules import shared
|
from modules import shared
|
||||||
from modules.models import load_model, unload_model
|
from modules.models import load_model as _load_model
|
||||||
|
from modules.models import unload_model
|
||||||
from modules.models_settings import get_model_metadata, update_model_parameters
|
from modules.models_settings import get_model_metadata, update_model_parameters
|
||||||
from modules.utils import get_available_models
|
from modules.utils import get_available_models
|
||||||
|
|
||||||
@ -38,7 +39,7 @@ def load_model(model_name: str) -> dict:
|
|||||||
if shared.settings['mode'] != 'instruct':
|
if shared.settings['mode'] != 'instruct':
|
||||||
shared.settings['instruction_template'] = None
|
shared.settings['instruction_template'] = None
|
||||||
|
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
shared.model, shared.tokenizer = _load_model(shared.model_name)
|
||||||
|
|
||||||
if not shared.model: # load failed.
|
if not shared.model: # load failed.
|
||||||
shared.model_name = "None"
|
shared.model_name = "None"
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
import time
|
import time
|
||||||
import numpy as np
|
|
||||||
from numpy.linalg import norm
|
|
||||||
from extensions.openai.embeddings import get_embeddings
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from extensions.openai.embeddings import get_embeddings
|
||||||
|
from numpy.linalg import norm
|
||||||
|
|
||||||
moderations_disabled = False # return 0/false
|
moderations_disabled = False # return 0/false
|
||||||
category_embeddings = None
|
category_embeddings = None
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
flask_cloudflared==0.0.12
|
flask_cloudflared == 0.0.12
|
||||||
sentence-transformers
|
sentence - transformers
|
||||||
tiktoken
|
tiktoken
|
@ -4,19 +4,21 @@ import traceback
|
|||||||
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
|
|
||||||
from modules import shared
|
import extensions.openai.completions as OAIcompletions
|
||||||
|
|
||||||
from extensions.openai.tokens import token_count, token_encode, token_decode
|
|
||||||
import extensions.openai.models as OAImodels
|
|
||||||
import extensions.openai.edits as OAIedits
|
import extensions.openai.edits as OAIedits
|
||||||
import extensions.openai.embeddings as OAIembeddings
|
import extensions.openai.embeddings as OAIembeddings
|
||||||
import extensions.openai.images as OAIimages
|
import extensions.openai.images as OAIimages
|
||||||
|
import extensions.openai.models as OAImodels
|
||||||
import extensions.openai.moderations as OAImoderations
|
import extensions.openai.moderations as OAImoderations
|
||||||
import extensions.openai.completions as OAIcompletions
|
from extensions.openai.defaults import clamp, default, get_default_req_params
|
||||||
from extensions.openai.errors import *
|
from extensions.openai.errors import (
|
||||||
|
InvalidRequestError,
|
||||||
|
OpenAIError,
|
||||||
|
ServiceUnavailableError
|
||||||
|
)
|
||||||
|
from extensions.openai.tokens import token_count, token_decode, token_encode
|
||||||
from extensions.openai.utils import debug_msg
|
from extensions.openai.utils import debug_msg
|
||||||
from extensions.openai.defaults import (get_default_req_params, default, clamp)
|
from modules import shared
|
||||||
|
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
'port': int(os.environ.get('OPENEDAI_PORT')) if 'OPENEDAI_PORT' in os.environ else 5001,
|
||||||
@ -209,7 +211,7 @@ class Handler(BaseHTTPRequestHandler):
|
|||||||
self.return_json(response)
|
self.return_json(response)
|
||||||
|
|
||||||
elif '/images/generations' in self.path:
|
elif '/images/generations' in self.path:
|
||||||
if not 'SD_WEBUI_URL' in os.environ:
|
if 'SD_WEBUI_URL' not in os.environ:
|
||||||
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
raise ServiceUnavailableError("Stable Diffusion not available. SD_WEBUI_URL not set.")
|
||||||
|
|
||||||
prompt = body['prompt']
|
prompt = body['prompt']
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
from extensions.openai.utils import float_list_to_base64
|
from modules.text_generation import decode, encode
|
||||||
from modules.text_generation import encode, decode
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def token_count(prompt):
|
def token_count(prompt):
|
||||||
tokens = encode(prompt)[0]
|
tokens = encode(prompt)[0]
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
import os
|
|
||||||
import base64
|
import base64
|
||||||
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
def float_list_to_base64(float_array: np.ndarray) -> str:
|
def float_list_to_base64(float_array: np.ndarray) -> str:
|
||||||
# Convert the list to a float32 array that the OpenAPI client expects
|
# Convert the list to a float32 array that the OpenAPI client expects
|
||||||
#float_array = np.array(float_list, dtype="float32")
|
# float_array = np.array(float_list, dtype="float32")
|
||||||
|
|
||||||
# Get raw bytes
|
# Get raw bytes
|
||||||
bytes_array = float_array.tobytes()
|
bytes_array = float_array.tobytes()
|
||||||
|
Loading…
Reference in New Issue
Block a user