diff --git a/.idea/workspace.xml b/.idea/workspace.xml new file mode 100644 index 00000000..404920a8 --- /dev/null +++ b/.idea/workspace.xml @@ -0,0 +1,64 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 1678590722207 + + + + \ No newline at end of file diff --git a/extensions/llama_prompts/script.py b/extensions/llama_prompts/script.py new file mode 100644 index 00000000..e45cd445 --- /dev/null +++ b/extensions/llama_prompts/script.py @@ -0,0 +1,18 @@ +import gradio as gr +import modules.shared as shared +import pandas as pd + +df = pd.read_csv("https://raw.githubusercontent.com/devbrones/llama-prompts/main/prompts/prompts.csv") + +def get_prompt_by_name(name): + if name == 'None': + return '' + else: + return df[df['Prompt name'] == name].iloc[0]['Prompt'].replace('\\n', '\n') + +def ui(): + if not shared.args.chat or share.args.cai_chat: + choices = ['None'] + list(df['Prompt name']) + + prompts_menu = gr.Dropdown(value=choices[0], choices=choices, label='Prompt') + prompts_menu.change(get_prompt_by_name, prompts_menu, shared.gradio['textbox']) diff --git a/modules/RWKV.py b/modules/RWKV.py index b226a195..836d31dc 100644 --- a/modules/RWKV.py +++ b/modules/RWKV.py @@ -7,6 +7,7 @@ import numpy as np from tokenizers import Tokenizer import modules.shared as shared +from modules.callbacks import Iteratorize np.set_printoptions(precision=4, suppress=True, linewidth=200) @@ -49,11 +50,11 @@ class RWKVModel: return context+self.pipeline.generate(context, token_count=token_count, args=args, callback=callback) def generate_with_streaming(self, **kwargs): - iterable = Iteratorize(self.generate, kwargs, callback=None) - reply = kwargs['context'] - for token in iterable: - reply += token - yield reply + with Iteratorize(self.generate, kwargs, callback=None) as generator: + reply = kwargs['context'] + for token in generator: + reply += token + yield reply class RWKVTokenizer: def __init__(self): @@ -73,38 +74,3 @@ class RWKVTokenizer: def decode(self, ids): return self.tokenizer.decode(ids) - -class Iteratorize: - - """ - Transforms a function that takes a callback - into a lazy iterator (generator). - """ - - def __init__(self, func, kwargs={}, callback=None): - self.mfunc=func - self.c_callback=callback - self.q = Queue(maxsize=1) - self.sentinel = object() - self.kwargs = kwargs - - def _callback(val): - self.q.put(val) - - def gentask(): - ret = self.mfunc(callback=_callback, **self.kwargs) - self.q.put(self.sentinel) - if self.c_callback: - self.c_callback(ret) - - Thread(target=gentask).start() - - def __iter__(self): - return self - - def __next__(self): - obj = self.q.get(True,None) - if obj is self.sentinel: - raise StopIteration - else: - return obj diff --git a/modules/callbacks.py b/modules/callbacks.py new file mode 100644 index 00000000..faa4a5e9 --- /dev/null +++ b/modules/callbacks.py @@ -0,0 +1,98 @@ +import gc +from queue import Queue +from threading import Thread + +import torch +import transformers + +import modules.shared as shared + +# Copied from https://github.com/PygmalionAI/gradio-ui/ +class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): + + def __init__(self, sentinel_token_ids: torch.LongTensor, + starting_idx: int): + transformers.StoppingCriteria.__init__(self) + self.sentinel_token_ids = sentinel_token_ids + self.starting_idx = starting_idx + + def __call__(self, input_ids: torch.LongTensor, + _scores: torch.FloatTensor) -> bool: + for sample in input_ids: + trimmed_sample = sample[self.starting_idx:] + # Can't unfold, output is still too tiny. Skip. + if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]: + continue + + for window in trimmed_sample.unfold( + 0, self.sentinel_token_ids.shape[-1], 1): + if torch.all(torch.eq(self.sentinel_token_ids, window)): + return True + return False + +class Stream(transformers.StoppingCriteria): + def __init__(self, callback_func=None): + self.callback_func = callback_func + + def __call__(self, input_ids, scores) -> bool: + if self.callback_func is not None: + self.callback_func(input_ids[0]) + return False + +class Iteratorize: + + """ + Transforms a function that takes a callback + into a lazy iterator (generator). + """ + + def __init__(self, func, kwargs={}, callback=None): + self.mfunc=func + self.c_callback=callback + self.q = Queue() + self.sentinel = object() + self.kwargs = kwargs + self.stop_now = False + + def _callback(val): + if self.stop_now: + raise ValueError + self.q.put(val) + + def gentask(): + try: + ret = self.mfunc(callback=_callback, **self.kwargs) + except ValueError: + pass + clear_torch_cache() + self.q.put(self.sentinel) + if self.c_callback: + self.c_callback(ret) + + self.thread = Thread(target=gentask) + self.thread.start() + + def __iter__(self): + return self + + def __next__(self): + obj = self.q.get(True,None) + if obj is self.sentinel: + raise StopIteration + else: + return obj + + def __del__(self): + clear_torch_cache() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop_now = True + clear_torch_cache() + +def clear_torch_cache(): + gc.collect() + if not shared.args.cpu: + torch.cuda.empty_cache() diff --git a/modules/chat.py b/modules/chat.py index f40f8299..2048e2c5 100644 --- a/modules/chat.py +++ b/modules/chat.py @@ -84,6 +84,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate tmp = f"\n{asker}:" for j in range(1, len(tmp)): if reply[-j:] == tmp[:j]: + reply = reply[:-j] substring_found = True return reply, next_character_found, substring_found @@ -91,7 +92,7 @@ def extract_message_from_reply(question, reply, name1, name2, check, impersonate def stop_everything_event(): shared.stop_everything = True -def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1): +def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts=1, regenerate=False): shared.stop_everything = False just_started = True eos_token = '\n' if check else None @@ -120,6 +121,10 @@ def chatbot_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typical else: prompt = custom_generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size) + if not regenerate: + # Display user input and "*is typing...*" imediately + yield shared.history['visible']+[[visible_text, '*Is typing...*']] + # Generate reply = '' for i in range(chat_generation_attempts): @@ -158,6 +163,9 @@ def impersonate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typ prompt = generate_chat_prompt(text, max_new_tokens, name1, name2, context, chat_prompt_size, impersonate=True) + # Display "*is typing...*" imediately + yield '*Is typing...*' + reply = '' for i in range(chat_generation_attempts): for reply in generate_reply(prompt+reply, max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, eos_token=eos_token, stopping_string=f"\n{name2}:"): @@ -182,7 +190,7 @@ def regenerate_wrapper(text, max_new_tokens, do_sample, temperature, top_p, typi last_visible = shared.history['visible'].pop() last_internal = shared.history['internal'].pop() - for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts): + for _history in chatbot_wrapper(last_internal[0], max_new_tokens, do_sample, temperature, top_p, typical_p, repetition_penalty, top_k, min_length, no_repeat_ngram_size, num_beams, penalty_alpha, length_penalty, early_stopping, name1, name2, context, check, chat_prompt_size, chat_generation_attempts, regenerate=True): if shared.args.cai_chat: shared.history['visible'][-1] = [last_visible[0], _history[-1][1]] yield generate_chat_html(shared.history['visible'], name1, name2, shared.character) @@ -291,7 +299,7 @@ def save_history(timestamp=True): fname = f"{prefix}persistent.json" if not Path('logs').exists(): Path('logs').mkdir() - with open(Path(f'logs/{fname}'), 'w') as f: + with open(Path(f'logs/{fname}'), 'w', encoding='utf-8') as f: f.write(json.dumps({'data': shared.history['internal'], 'data_visible': shared.history['visible']}, indent=2)) return Path(f'logs/{fname}') @@ -332,7 +340,7 @@ def load_character(_character, name1, name2): shared.history['visible'] = [] if _character != 'None': shared.character = _character - data = json.loads(open(Path(f'characters/{_character}.json'), 'r').read()) + data = json.loads(open(Path(f'characters/{_character}.json'), 'r', encoding='utf-8').read()) name2 = data['char_name'] if 'char_persona' in data and data['char_persona'] != '': context += f"{data['char_name']}'s Persona: {data['char_persona']}\n" @@ -372,7 +380,7 @@ def upload_character(json_file, img, tavern=False): i += 1 if tavern: outfile_name = f'TavernAI-{outfile_name}' - with open(Path(f'characters/{outfile_name}.json'), 'w') as f: + with open(Path(f'characters/{outfile_name}.json'), 'w', encoding='utf-8') as f: f.write(json_file) if img is not None: img = Image.open(io.BytesIO(img)) diff --git a/modules/shared.py b/modules/shared.py index aa66761e..52f9300a 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -91,4 +91,5 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.') parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') +parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch') args = parser.parse_args() diff --git a/modules/stopping_criteria.py b/modules/stopping_criteria.py deleted file mode 100644 index 44a631b3..00000000 --- a/modules/stopping_criteria.py +++ /dev/null @@ -1,32 +0,0 @@ -''' -This code was copied from - -https://github.com/PygmalionAI/gradio-ui/ - -''' - -import torch -import transformers - - -class _SentinelTokenStoppingCriteria(transformers.StoppingCriteria): - - def __init__(self, sentinel_token_ids: torch.LongTensor, - starting_idx: int): - transformers.StoppingCriteria.__init__(self) - self.sentinel_token_ids = sentinel_token_ids - self.starting_idx = starting_idx - - def __call__(self, input_ids: torch.LongTensor, - _scores: torch.FloatTensor) -> bool: - for sample in input_ids: - trimmed_sample = sample[self.starting_idx:] - # Can't unfold, output is still too tiny. Skip. - if trimmed_sample.shape[-1] < self.sentinel_token_ids.shape[-1]: - continue - - for window in trimmed_sample.unfold( - 0, self.sentinel_token_ids.shape[-1], 1): - if torch.all(torch.eq(self.sentinel_token_ids, window)): - return True - return False diff --git a/modules/text_generation.py b/modules/text_generation.py index 5a715e8e..6f53e416 100644 --- a/modules/text_generation.py +++ b/modules/text_generation.py @@ -5,13 +5,13 @@ import time import numpy as np import torch import transformers -from tqdm import tqdm import modules.shared as shared +from modules.callbacks import (Iteratorize, Stream, + _SentinelTokenStoppingCriteria) from modules.extensions import apply_extensions from modules.html_generator import generate_4chan_html, generate_basic_html from modules.models import local_rank -from modules.stopping_criteria import _SentinelTokenStoppingCriteria def get_max_prompt_length(tokens): @@ -92,19 +92,22 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi # These models are not part of Hugging Face, so we handle them # separately and terminate the function call earlier if shared.is_RWKV: - if shared.args.no_stream: - reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) - yield formatted_outputs(reply, shared.model_name) - else: - yield formatted_outputs(question, shared.model_name) - # RWKV has proper streaming, which is very nice. - # No need to generate 8 tokens at a time. - for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): + try: + if shared.args.no_stream: + reply = shared.model.generate(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k) yield formatted_outputs(reply, shared.model_name) - - t1 = time.time() - print(f"Output generated in {(t1-t0):.2f} seconds.") - return + else: + yield formatted_outputs(question, shared.model_name) + # RWKV has proper streaming, which is very nice. + # No need to generate 8 tokens at a time. + for reply in shared.model.generate_with_streaming(context=question, token_count=max_new_tokens, temperature=temperature, top_p=top_p, top_k=top_k): + yield formatted_outputs(reply, shared.model_name) + finally: + t1 = time.time() + output = encode(reply)[0] + input_ids = encode(question) + print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(input_ids[0])} tokens)") + return original_question = question if not (shared.args.chat or shared.args.cai_chat): @@ -113,23 +116,19 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi print(f"\n\n{question}\n--------------------\n") input_ids = encode(question, max_new_tokens) + original_input_ids = input_ids + output = input_ids[0] cuda = "" if (shared.args.cpu or shared.args.deepspeed or shared.args.flexgen) else ".cuda()" n = shared.tokenizer.eos_token_id if eos_token is None else int(encode(eos_token)[0][-1]) + stopping_criteria_list = transformers.StoppingCriteriaList() if stopping_string is not None: - # The stopping_criteria code below was copied from - # https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py + # Copied from https://github.com/PygmalionAI/gradio-ui/blob/master/src/model.py t = encode(stopping_string, 0, add_special_tokens=False) - stopping_criteria_list = transformers.StoppingCriteriaList([ - _SentinelTokenStoppingCriteria( - sentinel_token_ids=t, - starting_idx=len(input_ids[0]) - ) - ]) - else: - stopping_criteria_list = None + stopping_criteria_list.append(_SentinelTokenStoppingCriteria(sentinel_token_ids=t, starting_idx=len(input_ids[0]))) if not shared.args.flexgen: generate_params = [ + f"max_new_tokens=max_new_tokens", f"eos_token_id={n}", f"stopping_criteria=stopping_criteria_list", f"do_sample={do_sample}", @@ -147,45 +146,23 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi ] else: generate_params = [ + f"max_new_tokens={max_new_tokens if shared.args.no_stream else 8}", f"do_sample={do_sample}", f"temperature={temperature}", f"stop={n}", ] if shared.args.deepspeed: generate_params.append("synced_gpus=True") - if shared.args.no_stream: - generate_params.append("max_new_tokens=max_new_tokens") - else: - generate_params.append("max_new_tokens=8") if shared.soft_prompt: inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) generate_params.insert(0, "inputs_embeds=inputs_embeds") - generate_params.insert(0, "filler_input_ids") + generate_params.insert(0, "inputs=filler_input_ids") else: - generate_params.insert(0, "input_ids") - - # Generate the entire reply at once - if shared.args.no_stream: - with torch.no_grad(): - output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] - if shared.soft_prompt: - output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) - - reply = decode(output) - if not (shared.args.chat or shared.args.cai_chat): - reply = original_question + apply_extensions(reply[len(question):], "output") - - t1 = time.time() - print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(input_ids[0]))/(t1-t0)/8:.2f} it/s, {len(output)-len(input_ids[0])} tokens)") - yield formatted_outputs(reply, shared.model_name) - - # Generate the reply 8 tokens at a time - else: - yield formatted_outputs(original_question, shared.model_name) - shared.still_streaming = True - for i in tqdm(range(max_new_tokens//8+1)): - clear_torch_cache() + generate_params.insert(0, "inputs=input_ids") + try: + # Generate the entire reply at once. + if shared.args.no_stream: with torch.no_grad(): output = eval(f"shared.model.generate({', '.join(generate_params)}){cuda}")[0] if shared.soft_prompt: @@ -194,22 +171,66 @@ def generate_reply(question, max_new_tokens, do_sample, temperature, top_p, typi reply = decode(output) if not (shared.args.chat or shared.args.cai_chat): reply = original_question + apply_extensions(reply[len(question):], "output") - - if not shared.args.flexgen: - if output[-1] == n: - break - input_ids = torch.reshape(output, (1, output.shape[0])) - else: - if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n): - break - input_ids = np.reshape(output, (1, output.shape[0])) - - #Mid-stream yield, ran if no breaks + yield formatted_outputs(reply, shared.model_name) - if shared.soft_prompt: - inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) - - #Stream finished from max tokens or break. Do final yield. - shared.still_streaming = False - yield formatted_outputs(reply, shared.model_name) \ No newline at end of file + # Stream the reply 1 token at a time. + # This is based on the trick of using 'stopping_criteria' to create an iterator. + elif not shared.args.flexgen: + + def generate_with_callback(callback=None, **kwargs): + kwargs['stopping_criteria'].append(Stream(callback_func=callback)) + clear_torch_cache() + with torch.no_grad(): + shared.model.generate(**kwargs) + + def generate_with_streaming(**kwargs): + return Iteratorize(generate_with_callback, kwargs, callback=None) + + shared.still_streaming = True + yield formatted_outputs(original_question, shared.model_name) + with eval(f"generate_with_streaming({', '.join(generate_params)})") as generator: + for output in generator: + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) + reply = decode(output) + + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") + + if output[-1] == n: + break + yield formatted_outputs(reply, shared.model_name) + + shared.still_streaming = False + yield formatted_outputs(reply, shared.model_name) + + # Stream the output naively for FlexGen since it doesn't support 'stopping_criteria' + else: + shared.still_streaming = True + for i in range(max_new_tokens//8+1): + clear_torch_cache() + with torch.no_grad(): + output = eval(f"shared.model.generate({', '.join(generate_params)})")[0] + if shared.soft_prompt: + output = torch.cat((input_ids[0], output[filler_input_ids.shape[1]:])) + reply = decode(output) + + if not (shared.args.chat or shared.args.cai_chat): + reply = original_question + apply_extensions(reply[len(question):], "output") + + if np.count_nonzero(input_ids[0] == n) < np.count_nonzero(output == n): + break + yield formatted_outputs(reply, shared.model_name) + + input_ids = np.reshape(output, (1, output.shape[0])) + if shared.soft_prompt: + inputs_embeds, filler_input_ids = generate_softprompt_input_tensors(input_ids) + + shared.still_streaming = False + yield formatted_outputs(reply, shared.model_name) + + finally: + t1 = time.time() + print(f"Output generated in {(t1-t0):.2f} seconds ({(len(output)-len(original_input_ids[0]))/(t1-t0):.2f} tokens/s, {len(output)-len(original_input_ids[0])} tokens)") + return diff --git a/requirements.txt b/requirements.txt index 6133f394..a7df93bb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,7 @@ bitsandbytes==0.37.0 flexgen==0.1.7 gradio==3.18.0 numpy +requests rwkv==0.1.0 safetensors==0.2.8 sentencepiece diff --git a/server.py b/server.py index c2977f41..47e9c8ed 100644 --- a/server.py +++ b/server.py @@ -18,9 +18,6 @@ from modules.html_generator import generate_chat_html from modules.models import load_model, load_soft_prompt from modules.text_generation import generate_reply -if (shared.args.chat or shared.args.cai_chat) and not shared.args.no_stream: - print('Warning: chat mode currently becomes somewhat slower with text streaming on.\nConsider starting the web UI with the --no-stream option.\n') - # Loading custom settings settings_file = None if shared.args.settings is not None and Path(shared.args.settings).exists(): @@ -272,10 +269,10 @@ if shared.args.chat or shared.args.cai_chat: function_call = 'chat.cai_chatbot_wrapper' if shared.args.cai_chat else 'chat.chatbot_wrapper' - gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream, api_name='textgen')) - gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=shared.args.no_stream)) - gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=shared.args.no_stream)) + gen_events.append(shared.gradio['Generate'].click(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False, api_name='textgen')) + gen_events.append(shared.gradio['textbox'].submit(eval(function_call), shared.input_params, shared.gradio['display'], show_progress=False)) + gen_events.append(shared.gradio['Regenerate'].click(chat.regenerate_wrapper, shared.input_params, shared.gradio['display'], show_progress=False)) + gen_events.append(shared.gradio['Impersonate'].click(chat.impersonate_wrapper, shared.input_params, shared.gradio['textbox'], show_progress=False)) shared.gradio['Stop'].click(chat.stop_everything_event, [], [], cancels=gen_events) shared.gradio['Copy last reply'].click(chat.send_last_reply_to_input, [], shared.gradio['textbox'], show_progress=shared.args.no_stream) @@ -309,6 +306,7 @@ if shared.args.chat or shared.args.cai_chat: reload_inputs = [shared.gradio['name1'], shared.gradio['name2']] if shared.args.cai_chat else [] shared.gradio['upload_chat_history'].upload(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['upload_img_me'].upload(reload_func, reload_inputs, [shared.gradio['display']]) + shared.gradio['Stop'].click(reload_func, reload_inputs, [shared.gradio['display']]) shared.gradio['interface'].load(lambda : chat.load_default_history(shared.settings[f'name1{suffix}'], shared.settings[f'name2{suffix}']), None, None) shared.gradio['interface'].load(reload_func, reload_inputs, [shared.gradio['display']], show_progress=True) @@ -372,9 +370,9 @@ else: shared.gradio['interface'].queue() if shared.args.listen: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) else: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) # I think that I will need this later while True: