mirror of
https://github.com/oobabooga/text-generation-webui.git
synced 2024-11-25 17:29:22 +01:00
If only 1 model is available, load that model
This commit is contained in:
parent
8b482b4127
commit
fbb448ce4f
53
server.py
53
server.py
@ -209,8 +209,8 @@ def list_model_parameters():
|
|||||||
|
|
||||||
# Model parameters: update the command-line arguments based on the interface values
|
# Model parameters: update the command-line arguments based on the interface values
|
||||||
def update_model_parameters(*args):
|
def update_model_parameters(*args):
|
||||||
args = list(args) # the values of the parameters
|
args = list(args) # the values of the parameters
|
||||||
elements = list_model_parameters() # the names of the parameters
|
elements = list_model_parameters() # the names of the parameters
|
||||||
|
|
||||||
gpu_memories = []
|
gpu_memories = []
|
||||||
for i, element in enumerate(elements):
|
for i, element in enumerate(elements):
|
||||||
@ -232,8 +232,8 @@ def update_model_parameters(*args):
|
|||||||
elif element == 'cpu_memory' and args[i] is not None:
|
elif element == 'cpu_memory' and args[i] is not None:
|
||||||
args[i] = f"{args[i]}MiB"
|
args[i] = f"{args[i]}MiB"
|
||||||
|
|
||||||
#print(element, repr(eval(f"shared.args.{element}")), repr(args[i]))
|
# print(element, repr(eval(f"shared.args.{element}")), repr(args[i]))
|
||||||
#print(f"shared.args.{element} = args[i]")
|
# print(f"shared.args.{element} = args[i]")
|
||||||
exec(f"shared.args.{element} = args[i]")
|
exec(f"shared.args.{element} = args[i]")
|
||||||
|
|
||||||
found_positive = False
|
found_positive = False
|
||||||
@ -251,7 +251,7 @@ def create_model_menus():
|
|||||||
# Finding the default values for the GPU and CPU memories
|
# Finding the default values for the GPU and CPU memories
|
||||||
total_mem = []
|
total_mem = []
|
||||||
for i in range(torch.cuda.device_count()):
|
for i in range(torch.cuda.device_count()):
|
||||||
total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024*1024)))
|
total_mem.append(math.floor(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024)))
|
||||||
|
|
||||||
default_gpu_mem = []
|
default_gpu_mem = []
|
||||||
if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
|
if shared.args.gpu_memory is not None and len(shared.args.gpu_memory) > 0:
|
||||||
@ -259,11 +259,11 @@ def create_model_menus():
|
|||||||
if 'mib' in i.lower():
|
if 'mib' in i.lower():
|
||||||
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
|
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)))
|
||||||
else:
|
else:
|
||||||
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i))*1000)
|
default_gpu_mem.append(int(re.sub('[a-zA-Z ]', '', i)) * 1000)
|
||||||
while len(default_gpu_mem) < len(total_mem):
|
while len(default_gpu_mem) < len(total_mem):
|
||||||
default_gpu_mem.append(0)
|
default_gpu_mem.append(0)
|
||||||
|
|
||||||
total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024*1024))
|
total_cpu_mem = math.floor(psutil.virtual_memory().total / (1024 * 1024))
|
||||||
if shared.args.cpu_memory is not None:
|
if shared.args.cpu_memory is not None:
|
||||||
default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
|
default_cpu_mem = re.sub('[a-zA-Z ]', '', shared.args.cpu_memory)
|
||||||
else:
|
else:
|
||||||
@ -441,16 +441,19 @@ else:
|
|||||||
if extension not in shared.args.extensions:
|
if extension not in shared.args.extensions:
|
||||||
shared.args.extensions.append(extension)
|
shared.args.extensions.append(extension)
|
||||||
|
|
||||||
# Default model
|
# Model defined through --model
|
||||||
if shared.args.model is not None:
|
if shared.args.model is not None:
|
||||||
shared.model_name = shared.args.model
|
shared.model_name = shared.args.model
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
|
||||||
|
# Only one model is available
|
||||||
|
elif len(available_models) == 1:
|
||||||
|
shared.model_name = available_models[0]
|
||||||
|
|
||||||
|
# Select the model from a command-line menu
|
||||||
elif shared.args.model_menu:
|
elif shared.args.model_menu:
|
||||||
if len(available_models) == 0:
|
if len(available_models) == 0:
|
||||||
print('No models are available! Please download at least one.')
|
print('No models are available! Please download at least one.')
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
elif len(available_models) == 1:
|
|
||||||
i = 0
|
|
||||||
else:
|
else:
|
||||||
print('The following models are available:\n')
|
print('The following models are available:\n')
|
||||||
for i, model in enumerate(available_models):
|
for i, model in enumerate(available_models):
|
||||||
@ -459,10 +462,12 @@ elif shared.args.model_menu:
|
|||||||
i = int(input()) - 1
|
i = int(input()) - 1
|
||||||
print()
|
print()
|
||||||
shared.model_name = available_models[i]
|
shared.model_name = available_models[i]
|
||||||
shared.model, shared.tokenizer = load_model(shared.model_name)
|
|
||||||
|
|
||||||
if shared.args.model is not None and shared.args.lora:
|
# If any model has been selected, load it
|
||||||
add_lora_to_model(shared.args.lora)
|
if shared.model_name != 'None':
|
||||||
|
shared.model, shared.tokenizer = load_model(shared.model_name)
|
||||||
|
if shared.args.lora:
|
||||||
|
add_lora_to_model(shared.args.lora)
|
||||||
|
|
||||||
# Default UI settings
|
# Default UI settings
|
||||||
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
default_preset = shared.settings['presets'][next((k for k in shared.settings['presets'] if re.match(k.lower(), shared.model_name.lower())), 'default')]
|
||||||
@ -685,14 +690,14 @@ def create_interface():
|
|||||||
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
|
||||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
|
||||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[0]; element.scrollTop = element.scrollHeight}")
|
||||||
)
|
)
|
||||||
|
|
||||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
@ -744,20 +749,20 @@ def create_interface():
|
|||||||
|
|
||||||
gen_events.append(shared.gradio['Generate'].click(
|
gen_events.append(shared.gradio['Generate'].click(
|
||||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
|
||||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_events.append(shared.gradio['textbox'].submit(
|
gen_events.append(shared.gradio['textbox'].submit(
|
||||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream)#.then(
|
generate_reply, shared.input_params, output_params, show_progress=shared.args.no_stream) # .then(
|
||||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
)
|
)
|
||||||
|
|
||||||
gen_events.append(shared.gradio['Continue'].click(
|
gen_events.append(shared.gradio['Continue'].click(
|
||||||
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
ui.gather_interface_values, [shared.gradio[k] for k in shared.input_elements], shared.gradio['interface_state']).then(
|
||||||
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream)#.then(
|
generate_reply, [shared.gradio['output_textbox']] + shared.input_params[1:], output_params, show_progress=shared.args.no_stream) # .then(
|
||||||
#None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
# None, None, None, _js="() => {element = document.getElementsByTagName('textarea')[1]; element.scrollTop = element.scrollHeight}")
|
||||||
)
|
)
|
||||||
|
|
||||||
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
shared.gradio['Stop'].click(stop_everything_event, None, None, queue=False, cancels=gen_events if shared.args.no_stream else None)
|
||||||
|
Loading…
Reference in New Issue
Block a user