Add spaces

This commit is contained in:
oobabooga 2023-04-25 00:10:21 -03:00
parent 1a0c12c6f2
commit b0ce750d4e

View File

@ -23,12 +23,14 @@ def load_extensions():
if extension not in setup_called and hasattr(extension, "setup"): if extension not in setup_called and hasattr(extension, "setup"):
setup_called.add(extension) setup_called.add(extension)
extension.setup() extension.setup()
state[name] = [True, i] state[name] = [True, i]
if name != 'api': if name != 'api':
print('Ok.') print('Ok.')
except: except:
if name != 'api': if name != 'api':
print('Fail.') print('Fail.')
traceback.print_exc() traceback.print_exc()
@ -44,6 +46,7 @@ def _apply_string_extensions(function_name, text):
for extension, _ in iterator(): for extension, _ in iterator():
if hasattr(extension, function_name): if hasattr(extension, function_name):
text = getattr(extension, function_name)(text) text = getattr(extension, function_name)(text)
return text return text
@ -56,6 +59,7 @@ def _apply_input_hijack(text, visible_text):
text, visible_text = extension.input_hijack['value'](text, visible_text) text, visible_text = extension.input_hijack['value'](text, visible_text)
else: else:
text, visible_text = extension.input_hijack['value'] text, visible_text = extension.input_hijack['value']
return text, visible_text return text, visible_text
@ -65,8 +69,10 @@ def _apply_custom_generate_chat_prompt(text, state, **kwargs):
for extension, _ in iterator(): for extension, _ in iterator():
if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'): if custom_generate_chat_prompt is None and hasattr(extension, 'custom_generate_chat_prompt'):
custom_generate_chat_prompt = extension.custom_generate_chat_prompt custom_generate_chat_prompt = extension.custom_generate_chat_prompt
if custom_generate_chat_prompt is not None: if custom_generate_chat_prompt is not None:
return custom_generate_chat_prompt(text, state, **kwargs) return custom_generate_chat_prompt(text, state, **kwargs)
return None return None
@ -75,6 +81,7 @@ def _apply_tokenizer_extensions(function_name, state, prompt, input_ids, input_e
for extension, _ in iterator(): for extension, _ in iterator():
if hasattr(extension, function_name): if hasattr(extension, function_name):
prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds) prompt, input_ids, input_embeds = getattr(extension, function_name)(state, prompt, input_ids, input_embeds)
return prompt, input_ids, input_embeds return prompt, input_ids, input_embeds
@ -91,6 +98,7 @@ EXTENSION_MAP = {
def apply_extensions(typ, *args, **kwargs): def apply_extensions(typ, *args, **kwargs):
if typ not in EXTENSION_MAP: if typ not in EXTENSION_MAP:
raise ValueError(f"Invalid extension type {typ}") raise ValueError(f"Invalid extension type {typ}")
return EXTENSION_MAP[typ](*args, **kwargs) return EXTENSION_MAP[typ](*args, **kwargs)