mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 01:09:22 +01:00
Merge remote-tracking branch 'refs/remotes/origin/dev' into dev
This commit is contained in:
commit
8b5495ebf8
@ -22,7 +22,7 @@ from modules.chat import (
|
|||||||
load_instruction_template_memoized
|
load_instruction_template_memoized
|
||||||
)
|
)
|
||||||
from modules.presets import load_preset_memoized
|
from modules.presets import load_preset_memoized
|
||||||
from modules.text_generation import decode, encode, generate_reply
|
from modules.text_generation import decode, encode, generate_reply, get_reply_from_output_ids
|
||||||
|
|
||||||
|
|
||||||
class LogitsBiasProcessor(LogitsProcessor):
|
class LogitsBiasProcessor(LogitsProcessor):
|
||||||
@ -56,7 +56,7 @@ class LogprobProcessor(LogitsProcessor):
|
|||||||
if self.logprobs is not None: # 0-5
|
if self.logprobs is not None: # 0-5
|
||||||
log_e_probabilities = F.log_softmax(logits, dim=1)
|
log_e_probabilities = F.log_softmax(logits, dim=1)
|
||||||
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
top_values, top_indices = torch.topk(log_e_probabilities, k=self.logprobs + 1)
|
||||||
top_tokens = [decode(tok) for tok in top_indices[0]]
|
top_tokens = [get_reply_from_output_ids([tok]) for tok in top_indices[0]]
|
||||||
top_probs = [float(x) for x in top_values[0]]
|
top_probs = [float(x) for x in top_values[0]]
|
||||||
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
self.token_alternatives = dict(zip(top_tokens, top_probs))
|
||||||
debug_msg(repr(self))
|
debug_msg(repr(self))
|
||||||
@ -144,6 +144,26 @@ def convert_history(history):
|
|||||||
user_input = ""
|
user_input = ""
|
||||||
system_message = ""
|
system_message = ""
|
||||||
|
|
||||||
|
if any(isinstance(entry['content'], list) for entry in history):
|
||||||
|
new_history = []
|
||||||
|
for entry in history:
|
||||||
|
if isinstance(entry['content'], list):
|
||||||
|
image_url = None
|
||||||
|
content = None
|
||||||
|
for item in entry['content']:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
if item['type'] == 'image_url' and isinstance(item['image_url'], dict):
|
||||||
|
image_url = item['image_url']['url']
|
||||||
|
elif item['type'] == 'text' and isinstance(item['text'], str):
|
||||||
|
content = item['text']
|
||||||
|
if image_url and content:
|
||||||
|
new_history.append({"image_url": image_url, "role": "user"})
|
||||||
|
new_history.append({"content": content, "role": "user"})
|
||||||
|
else:
|
||||||
|
new_history.append(entry)
|
||||||
|
history = new_history
|
||||||
|
|
||||||
for entry in history:
|
for entry in history:
|
||||||
if "image_url" in entry:
|
if "image_url" in entry:
|
||||||
image_url = entry['image_url']
|
image_url = entry['image_url']
|
||||||
@ -156,8 +176,9 @@ def convert_history(history):
|
|||||||
img = Image.open(BytesIO(my_res.content))
|
img = Image.open(BytesIO(my_res.content))
|
||||||
except Exception:
|
except Exception:
|
||||||
raise 'Image cannot be loaded from the URL!'
|
raise 'Image cannot be loaded from the URL!'
|
||||||
|
|
||||||
buffered = BytesIO()
|
buffered = BytesIO()
|
||||||
|
if img.mode in ("RGBA", "P"):
|
||||||
|
img = img.convert("RGB")
|
||||||
img.save(buffered, format="JPEG")
|
img.save(buffered, format="JPEG")
|
||||||
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
img_str = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||||
content = f'<img src="data:image/jpeg;base64,{img_str}">'
|
content = f'<img src="data:image/jpeg;base64,{img_str}">'
|
||||||
|
@ -268,8 +268,8 @@ def apply_stopping_strings(reply, all_stop_strings):
|
|||||||
return reply, stop_found
|
return reply, stop_found
|
||||||
|
|
||||||
|
|
||||||
def get_reply_from_output_ids(output_ids, state, starting_from=0):
|
def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
|
||||||
reply = decode(output_ids[starting_from:], state['skip_special_tokens'])
|
reply = decode(output_ids[starting_from:], state['skip_special_tokens'] if state else True)
|
||||||
|
|
||||||
# Handle tokenizers that do not add the leading space for the first token
|
# Handle tokenizers that do not add the leading space for the first token
|
||||||
if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):
|
if (hasattr(shared.tokenizer, 'convert_ids_to_tokens') and len(output_ids) > starting_from) and not reply.startswith(' '):
|
||||||
|
Loading…
Reference in New Issue
Block a user