import base64
import copy
import re
import time
from collections import deque
from io import BytesIO

import requests
import tiktoken
import torch
import torch.nn.functional as F
from PIL import Image
from transformers import LogitsProcessor, LogitsProcessorList

from extensions.openai.errors import InvalidRequestError
from extensions.openai.utils import debug_msg
from modules import shared
from modules.chat import (
    generate_chat_prompt,
    generate_chat_reply,
    load_character_memoized,
    load_instruction_template_memoized
)
from modules.presets import load_preset_memoized
from modules.text_generation import (
    decode,
    encode,
    generate_reply,
    get_reply_from_output_ids
)


class LogitsBiasProcessor(LogitsProcessor):
    """Logits processor that adds a fixed bias to selected token logits.

    ``logit_bias`` maps token ids to bias values. The ids are converted with
    ``int(key)`` and the values are looked up again with ``str(key)``, so the
    mapping is expected to use string keys (``{str: float, ...}``).
    """

    def __init__(self, logit_bias=None):
        # A None default instead of `{}` avoids sharing one dict across all
        # instances created without an explicit argument.
        self.logit_bias = {} if logit_bias is None else logit_bias
        if self.logit_bias:
            self.keys = [int(key) for key in self.logit_bias.keys()]
            values = [self.logit_bias[str(key)] for key in self.keys]
            self.values = torch.tensor(values, dtype=torch.float, device=shared.model.device)
            debug_msg(f"{self})")

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        """Add the configured biases to row 0 of ``logits`` in place and return it."""
        if self.logit_bias:
            debug_msg(logits[0, self.keys], " + ", self.values)
            logits[0, self.keys] += self.values
            debug_msg(" --> ", logits[0, self.keys])
            debug_msg(" max/min ", float(torch.max(logits[0])), float(torch.min(logits[0])))

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logit_bias={self.logit_bias})>"


class LogprobProcessor(LogitsProcessor):
    """Logits processor that records the top token alternatives of the latest step.

    It never modifies the logits. After each call, ``token_alternatives`` holds
    a ``{decoded_token: log_probability}`` dict for the ``logprobs + 1`` most
    likely tokens of row 0. Only the most recent step is kept.
    """

    def __init__(self, logprobs=None):
        self.logprobs = logprobs
        self.token_alternatives = {}

    def __call__(self, input_ids: torch.LongTensor, logits: torch.FloatTensor) -> torch.FloatTensor:
        """Refresh ``token_alternatives`` from ``logits`` and return ``logits`` unchanged."""
        if self.logprobs is not None:  # 0-5
            log_e_probabilities = F.log_softmax(logits, dim=1)
            top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
            top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
            top_probs = [float(x) for x in top_values[0]]
            self.token_alternatives = dict(zip(top_tokens, top_probs))
            debug_msg(repr(self))

        return logits

    def __repr__(self):
        return f"<{self.__class__.__name__}(logprobs={self.logprobs}, token_alternatives={self.token_alternatives})>"


def convert_logprobs_to_tiktoken(model, logprobs):
    """Return ``logprobs`` unchanged.

    Re-keying the alternatives through the tiktoken encoder of ``model`` was
    tried and disabled ("more problems than it's worth"); the native tokens
    are reported as-is. ``model`` is accepted only to keep the call signature.
    """
    return logprobs


def process_parameters(body, is_legacy=False):
    """Turn a request body into generation parameters.

    NOTE: ``body`` is modified in place and returned; the result is the same
    dict object, not a copy.

    Args:
        body: Request dict. The max-token key (``'length'`` when ``is_legacy``,
            else ``'max_tokens'``) is popped and stored as ``'max_new_tokens'``.
            ``'truncation_length'``, ``'temperature'`` and ``'preset'`` are
            indexed directly and must be present. ``'stop'``, ``'logit_bias'``
            and ``'logprobs'`` are optional.
        is_legacy: Selects the legacy key name for the max-token setting.

    Returns:
        The updated ``body``. When a ``'logprobs'`` key is present, the
        ``LogprobProcessor`` is also stored under ``'logprob_proc'`` so callers
        can read its ``token_alternatives``.
    """
    generate_params = body
    max_tokens_key = 'length' if is_legacy else 'max_tokens'
    generate_params['max_new_tokens'] = body.pop(max_tokens_key)

    # A truncation length of 0 means "use the server setting".
    if generate_params['truncation_length'] == 0:
        generate_params['truncation_length'] = shared.settings['truncation_length']

    # Temperature 0 is mapped to greedy decoding.
    if generate_params['temperature'] == 0:
        generate_params['do_sample'] = False
        generate_params['top_k'] = 1

    if body['preset'] is not None:
        generate_params.update(load_preset_memoized(body['preset']))

    generate_params['custom_stopping_strings'] = []
    if 'stop' in body:  # str or array, max len 4 (ignored)
        stop = body['stop']
        if isinstance(stop, str):
            generate_params['custom_stopping_strings'] = [stop]
        elif isinstance(stop, list):
            generate_params['custom_stopping_strings'] = stop

    logits_processor = []
    logit_bias = body.get('logit_bias', None)
    if logit_bias:  # {str: float, ...}
        logits_processor.append(LogitsBiasProcessor(logit_bias))

    # coming to chat eventually
    if 'logprobs' in body:
        # maybe cap at topk? don't clamp 0-5.
        logprob_proc = LogprobProcessor(body.get('logprobs', 0))
        generate_params['logprob_proc'] = logprob_proc
        logits_processor.append(logprob_proc)

    if logits_processor:  # requires logits_processor support
        generate_params['logits_processor'] = LogitsProcessorList(logits_processor)

    return generate_params


def _load_image_as_html(image_url):
    """Load an image from a data URI or an HTTP URL and return it as an inline <img> tag.

    The image is re-encoded as JPEG (RGBA / palette images are converted to
    RGB first) and embedded as a base64 data URI.

    Raises:
        InvalidRequestError: if the image cannot be fetched or opened from a URL.
    """
    if "base64" in image_url:
        encoded = re.sub(r'^data:image/.+;base64,', '', image_url)
        img = Image.open(BytesIO(base64.b64decode(encoded)))
    else:
        # SECURITY NOTE(review): this fetches a client-supplied URL from the
        # server (SSRF exposure). Consider restricting allowed hosts/schemes.
        try:
            # Bounded wait so an unresponsive host cannot hang the request forever.
            my_res = requests.get(image_url, timeout=10)
            img = Image.open(BytesIO(my_res.content))
        except Exception as err:
            # The original raised a bare string, which is itself a TypeError.
            raise InvalidRequestError(message='Image cannot be loaded from the URL!', param='messages') from err

    if img.mode in ("RGBA", "P"):
        img = img.convert("RGB")

    buffered = BytesIO()
    img.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')

    # The encoded image was previously computed but never used (the content
    # was an empty string). NOTE(review): the inline <img> data-URI form is
    # presumably what the multimodal extension expects -- confirm.
    return f'<img src="data:image/jpeg;base64,{img_str}">'


def _flatten_multimodal_history(history):
    """Expand OpenAI list-style message contents into flat user entries.

    For an entry whose ``content`` is a list, the last ``image_url`` item and
    the last ``text`` item found are emitted (image first) as separate entries
    with role ``"user"``. Other entries are passed through unchanged.
    """
    flat_history = []
    for entry in history:
        parts = entry.get('content')
        if not isinstance(parts, list):
            flat_history.append(entry)
            continue

        image_url = None
        text = None
        for item in parts:
            if not isinstance(item, dict):
                continue

            # .get() so malformed items are skipped instead of raising KeyError.
            item_type = item.get('type')
            if item_type == 'image_url' and isinstance(item.get('image_url'), dict):
                image_url = item['image_url'].get('url')
            elif item_type == 'text' and isinstance(item.get('text'), str):
                text = item['text']

        if image_url:
            flat_history.append({"image_url": image_url, "role": "user"})

        if text:
            flat_history.append({"content": text, "role": "user"})

    return flat_history


def convert_history(history):
    """Convert an OpenAI message list to this program's chat history format.

    Chat histories in this program are lists of ``[message, reply]`` pairs.

    Args:
        history: List of message dicts, each with a ``'role'`` and either a
            ``'content'`` (string, or OpenAI multimodal list) or an
            ``'image_url'``.

    Returns:
        A tuple ``(user_input, system_message, history_dict)``:
        ``user_input`` is the last user content, or ``""`` when the last
        user/assistant message was from the assistant; ``system_message`` is
        the content of the last system message; ``history_dict`` has
        ``'internal'`` and ``'visible'`` lists of ``[message, reply]`` pairs
        (``'visible'`` is a deep copy of ``'internal'``).

    Raises:
        InvalidRequestError: if an image URL cannot be loaded.
    """
    chat_dialogue = []
    current_message = ""
    current_reply = ""
    user_input = ""
    user_input_last = True
    system_message = ""

    # Multimodal: convert OpenAI format to multimodal extension format
    if any('content' in entry and isinstance(entry['content'], list) for entry in history):
        history = _flatten_multimodal_history(history)

    for entry in history:
        if "image_url" in entry:
            content = _load_image_as_html(entry['image_url'])
        else:
            content = entry["content"]

        role = entry["role"]

        if role == "user":
            user_input = content
            user_input_last = True
            # A pending user message with no reply becomes its own pair.
            if current_message:
                chat_dialogue.append([current_message, ''])

            current_message = content
        elif role == "assistant":
            current_reply = content
            user_input_last = False
            if current_message:
                chat_dialogue.append([current_message, current_reply])
                current_message = ""
                current_reply = ""
            else:
                # Assistant message without a preceding user message.
                chat_dialogue.append(['', current_reply])
        elif role == "system":
            system_message = content

    if not user_input_last:
        user_input = ""

    return user_input, system_message, {'internal': chat_dialogue, 'visible': copy.deepcopy(chat_dialogue)}


def chat_completions_common(body: dict, is_legacy: bool = False, stream=False, prompt_only=False) -> dict:
    """Generate a chat completion; this function is a generator.

    Because it is a generator, nothing (including request validation) runs
    until the first item is requested.

    Args:
        body: Request dict. Must contain ``'messages'`` plus the settings
            indexed directly below (``'continue_'``, ``'mode'``, ``'model'``,
            ``'character'``, ``'user_name'``, ``'bot_name'``, ``'context'``,
            ``'greeting'``, ``'user_bio'``, the template keys, and the keys
            required by ``process_parameters``). It is modified in place.
        is_legacy: Use the ``'data'`` response key and the legacy max-token key.
        stream: Yield incremental ``chat.completion.chunk`` dicts, ending with
            a chunk that carries ``finish_reason`` and ``usage``.
        prompt_only: Yield only ``{'prompt': prompt}`` and stop.

    Yields:
        Response dicts. With ``stream=False`` exactly one ``chat.completion``
        dict is yielded.

    Raises:
        InvalidRequestError: for unsupported ``functions`` / ``function_call``,
            a missing ``messages`` key, or messages lacking a role or content.
    """
    if body.get('functions', []):
        raise InvalidRequestError(message="functions is not supported.", param='functions')

    if body.get('function_call', ''):
        raise InvalidRequestError(message="function_call is not supported.", param='function_call')

    if 'messages' not in body:
        raise InvalidRequestError(message="messages is required", param='messages')

    messages = body['messages']
    for m in messages:
        if 'role' not in m:
            raise InvalidRequestError(message="messages: missing role", param='messages')
        elif m['role'] == 'function':
            raise InvalidRequestError(message="role: function is not supported.", param='messages')

        if 'content' not in m and "image_url" not in m:
            raise InvalidRequestError(message="messages: missing content", param='messages')

    # Chat Completions
    object_type = 'chat.completion' if not stream else 'chat.completion.chunk'
    created_time = int(time.time())
    cmpl_id = "chatcmpl-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    # generation parameters
    generate_params = process_parameters(body, is_legacy=is_legacy)
    continue_ = body['continue_']

    # Instruction template
    if body['instruction_template_str']:
        instruction_template_str = body['instruction_template_str']
    elif body['instruction_template']:
        instruction_template = body['instruction_template']
        instruction_template = "Alpaca" if instruction_template == "None" else instruction_template
        instruction_template_str = load_instruction_template_memoized(instruction_template)
    else:
        instruction_template_str = shared.settings['instruction_template_str']

    chat_template_str = body['chat_template_str'] or shared.default_settings['chat_template_str']
    chat_instruct_command = body['chat_instruct_command'] or shared.default_settings['chat-instruct_command']

    # Chat character
    character = body['character'] or shared.default_settings['character']
    character = "Assistant" if character == "None" else character
    name1 = body['user_name'] or shared.default_settings['name1']
    name1, name2, _, greeting, context = load_character_memoized(character, name1, '')
    # Explicit request values override what the character file provides.
    name2 = body['bot_name'] or name2
    context = body['context'] or context
    greeting = body['greeting'] or greeting
    user_bio = body['user_bio'] or ''

    # History
    user_input, custom_system_message, history = convert_history(messages)

    generate_params.update({
        'mode': body['mode'],
        'name1': name1,
        'name2': name2,
        'context': context,
        'greeting': greeting,
        'user_bio': user_bio,
        'instruction_template_str': instruction_template_str,
        'custom_system_message': custom_system_message,
        'chat_template_str': chat_template_str,
        'chat-instruct_command': chat_instruct_command,
        'history': history,
        'stream': stream
    })

    # No explicit limit: fall back to 512 and let the backend extend it.
    max_tokens = generate_params['max_new_tokens']
    if max_tokens in [None, 0]:
        generate_params['max_new_tokens'] = 512
        generate_params['auto_max_new_tokens'] = True

    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)

    def chat_streaming_chunk(content):
        # begin streaming
        chunk = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": None,
                "delta": {'role': 'assistant', 'content': content},
            }],
        }

        if logprob_proc:  # not official for chat yet
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            chunk[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}

        return chunk

    # generate reply #######################################
    prompt = generate_chat_prompt(user_input, generate_params, _continue=continue_)
    if prompt_only:
        yield {'prompt': prompt}
        return

    debug_msg({'prompt': prompt, 'generate_params': generate_params})

    if stream:
        yield chat_streaming_chunk('')

    generator = generate_chat_reply(
        user_input, generate_params, regenerate=False, _continue=continue_, loading_message=False)

    answer = ''
    seen_content = ''

    for a in generator:
        # Each item carries the full reply so far; only the new tail is streamed.
        answer = a['internal'][-1][1]
        if stream:
            new_content = answer[len(seen_content):]

            if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                continue

            seen_content = answer
            yield chat_streaming_chunk(new_content)

    token_count = len(encode(prompt)[0])
    completion_token_count = len(encode(answer)[0])
    stop_reason = "stop"
    if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= generate_params['max_new_tokens']:
        stop_reason = "length"

    usage = {
        "prompt_tokens": token_count,
        "completion_tokens": completion_token_count,
        "total_tokens": token_count + completion_token_count
    }

    if stream:
        chunk = chat_streaming_chunk('')
        chunk[resp_list][0]['finish_reason'] = stop_reason
        chunk['usage'] = usage
        yield chunk
    else:
        resp = {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: [{
                "index": 0,
                "finish_reason": stop_reason,
                "message": {"role": "assistant", "content": answer}
            }],
            "usage": usage
        }

        if logprob_proc:  # not official for chat yet
            top_logprobs = convert_logprobs_to_tiktoken(model=requested_model, logprobs=logprob_proc.token_alternatives)
            resp[resp_list][0]["logprobs"] = {'top_logprobs': [top_logprobs]}

        yield resp


def _is_token_list(prompt):
    """Return True when ``prompt`` is a non-empty list whose first item is an int."""
    return isinstance(prompt, list) and bool(prompt) and isinstance(prompt[0], int)


def _decode_token_prompt(tokens, requested_model):
    """Decode a token-id prompt to text.

    Uses the loaded model's tokenizer when the requested model is the loaded
    one; otherwise tries the tiktoken encoding for ``requested_model`` and
    falls back to the loaded tokenizer on ``KeyError``.
    """
    # NOTE(review): `decode(...)[0]` assumes `decode` returns an indexable whose
    # first element is the text. If it returns a plain str this keeps only the
    # first character -- confirm against modules.text_generation.decode.
    if requested_model == shared.model_name:
        return decode(tokens)[0]

    try:
        encoder = tiktoken.encoding_for_model(requested_model)
        return encoder.decode(tokens)
    except KeyError:
        return decode(tokens)[0]


def completions_common(body: dict, is_legacy: bool = False, stream=False):
    """Generate text completions; this function is a generator.

    Args:
        body: Request dict. The prompt is read from ``'context'`` when
            ``is_legacy`` else ``'prompt'`` and may be a string, a list of
            strings, a list of token ids, or a list of token-id lists.
            ``'suffix'``, ``'echo'``, ``'model'`` and the keys required by
            ``process_parameters`` are indexed directly. Modified in place.
        is_legacy: Use the ``'data'`` response key and legacy request keys.
        stream: Yield ``text_completion.chunk`` dicts instead of one response.
            Streaming accepts a single prompt only.

    Yields:
        With ``stream=False``, one ``text_completion`` dict with one choice per
        prompt. With ``stream=True``, a first chunk holding the echoed prompt
        (or ``''``), one chunk per new piece of text, and a final chunk with
        the suffix, ``finish_reason`` and ``usage``.

    Raises:
        InvalidRequestError: if the prompt key is missing, or if a batched
            prompt is sent with ``stream=True``.
    """
    object_type = 'text_completion.chunk' if stream else 'text_completion'
    created_time = int(time.time())
    cmpl_id = "conv-%d" % (int(time.time() * 1000000000))
    resp_list = 'data' if is_legacy else 'choices'

    prompt_str = 'context' if is_legacy else 'prompt'

    # ... encoded as a string, array of strings, array of tokens, or array of token arrays.
    if prompt_str not in body:
        raise InvalidRequestError("Missing required input", param=prompt_str)

    # common params
    generate_params = process_parameters(body, is_legacy=is_legacy)
    max_tokens = generate_params['max_new_tokens']
    generate_params['stream'] = stream
    requested_model = generate_params.pop('model')
    logprob_proc = generate_params.pop('logprob_proc', None)
    suffix = body['suffix'] or ''
    echo = body['echo']

    if not stream:
        prompt_arg = body[prompt_str]
        # Normalize a single prompt (string or flat token list) to a batch of one.
        # _is_token_list guards against indexing an empty list.
        if isinstance(prompt_arg, str) or _is_token_list(prompt_arg):
            prompt_arg = [prompt_arg]

        resp_list_data = []
        total_completion_token_count = 0
        total_prompt_token_count = 0

        for idx, prompt in enumerate(prompt_arg, start=0):
            if _is_token_list(prompt):
                # token lists
                prompt = _decode_token_prompt(prompt, requested_model)

            prefix = prompt if echo else ''

            # generate reply #######################################
            debug_msg({'prompt': prompt, 'generate_params': generate_params})
            generator = generate_reply(prompt, generate_params, is_chat=False)
            answer = ''

            # Only the final item of the generator is kept.
            for a in generator:
                answer = a

            token_count = len(encode(prompt)[0])
            total_prompt_token_count += token_count
            completion_token_count = len(encode(answer)[0])
            total_completion_token_count += completion_token_count
            stop_reason = "stop"
            if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
                stop_reason = "length"

            resp_list_data.append({
                "index": idx,
                "finish_reason": stop_reason,
                "text": prefix + answer + suffix,
                "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
            })

        yield {
            "id": cmpl_id,
            "object": object_type,
            "created": created_time,
            "model": shared.model_name,
            resp_list: resp_list_data,
            "usage": {
                "prompt_tokens": total_prompt_token_count,
                "completion_tokens": total_completion_token_count,
                "total_tokens": total_prompt_token_count + total_completion_token_count
            }
        }
    else:
        prompt = body[prompt_str]
        if isinstance(prompt, list):
            if _is_token_list(prompt):
                # Same decoding rule as the non-streaming path.
                prompt = _decode_token_prompt(prompt, requested_model)
            else:
                raise InvalidRequestError(message="API Batched generation not yet supported.", param=prompt_str)

        prefix = prompt if echo else ''
        token_count = len(encode(prompt)[0])

        def text_streaming_chunk(content):
            # begin streaming
            return {
                "id": cmpl_id,
                "object": object_type,
                "created": created_time,
                "model": shared.model_name,
                resp_list: [{
                    "index": 0,
                    "finish_reason": None,
                    "text": content,
                    "logprobs": {'top_logprobs': [logprob_proc.token_alternatives]} if logprob_proc else None,
                }],
            }

        yield text_streaming_chunk(prefix)

        # generate reply #######################################
        debug_msg({'prompt': prompt, 'generate_params': generate_params})
        generator = generate_reply(prompt, generate_params, is_chat=False)

        answer = ''
        seen_content = ''
        completion_token_count = 0

        for a in generator:
            # Each item is the full text so far; only the new tail is streamed.
            answer = a
            new_content = answer[len(seen_content):]

            if not new_content or chr(0xfffd) in new_content:  # partial unicode character, don't send it yet.
                continue

            seen_content = answer
            yield text_streaming_chunk(new_content)

        completion_token_count = len(encode(answer)[0])
        stop_reason = "stop"
        if token_count + completion_token_count >= generate_params['truncation_length'] or completion_token_count >= max_tokens:
            stop_reason = "length"

        chunk = text_streaming_chunk(suffix)
        chunk[resp_list][0]["finish_reason"] = stop_reason
        chunk["usage"] = {
            "prompt_tokens": token_count,
            "completion_tokens": completion_token_count,
            "total_tokens": token_count + completion_token_count
        }

        yield chunk


def chat_completions(body: dict, is_legacy: bool = False) -> dict:
    """Run a non-streaming chat completion and return its single response dict."""
    generator = chat_completions_common(body, is_legacy, stream=False)
    # Drain the generator, keeping only the last item it yields.
    return deque(generator, maxlen=1).pop()


def stream_chat_completions(body: dict, is_legacy: bool = False):
    """Yield the streaming chunks of a chat completion."""
    yield from chat_completions_common(body, is_legacy, stream=True)


def completions(body: dict, is_legacy: bool = False) -> dict:
    """Run a non-streaming text completion and return its single response dict."""
    generator = completions_common(body, is_legacy, stream=False)
    # Drain the generator, keeping only the last item it yields.
    return deque(generator, maxlen=1).pop()


def stream_completions(body: dict, is_legacy: bool = False):
    """Yield the streaming chunks of a text completion."""
    yield from completions_common(body, is_legacy, stream=True)