diff --git a/modules/LoRA.py b/modules/LoRA.py index aa68ad32..394f7367 100644 --- a/modules/LoRA.py +++ b/modules/LoRA.py @@ -30,7 +30,7 @@ def add_lora_to_model(lora_name): elif shared.args.load_in_8bit: params['device_map'] = {'': 0} - shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params) + shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params) if not shared.args.load_in_8bit and not shared.args.cpu: shared.model.half() if not hasattr(shared.model, "hf_device_map"): diff --git a/modules/models.py b/modules/models.py index ccb97da3..757eb8b9 100644 --- a/modules/models.py +++ b/modules/models.py @@ -46,9 +46,9 @@ def load_model(model_name): # Default settings if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]): if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')): - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True) + model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True) else: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) + model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) if torch.has_mps: device = torch.device('mps') model = model.to(device) @@ -76,11 +76,11 @@ def load_model(model_name): num_bits=4, group_size=64, group_dim=2, symmetric=False)) - model = OptLM(f"facebook/{shared.model_name}", env, "models", policy) + model = OptLM(f"facebook/{shared.model_name}", env, shared.model_name, policy) # DeepSpeed ZeRO-3 elif shared.args.deepspeed: - model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) + model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16) model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0] model.module.eval() # Inference print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}") @@ -89,8 +89,8 @@ def load_model(model_name): elif shared.is_RWKV: from modules.RWKV import RWKVModel, RWKVTokenizer - model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") - tokenizer = RWKVTokenizer.from_pretrained(Path('models')) + model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda") + tokenizer = RWKVTokenizer.from_pretrained(Path(shared.model_name)) return model, tokenizer @@ -142,7 +142,7 @@ def load_model(model_name): if shared.args.disk: params["offload_folder"] = shared.args.disk_cache_dir - checkpoint = Path(f'models/{shared.model_name}') + checkpoint = Path(f'{shared.args.model_dir}/{shared.model_name}') if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto': config = AutoConfig.from_pretrained(checkpoint) @@ -159,10 +159,10 @@ def load_model(model_name): model = AutoModelForCausalLM.from_pretrained(checkpoint, **params) # Loading the tokenizer - if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists(): - tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/")) + if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists(): + tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/")) else: - tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/")) + tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/")) tokenizer.truncation_side = 'left' print(f"Loaded the model in {(time.time()-t0):.2f} seconds.") diff --git a/modules/shared.py b/modules/shared.py index 720c697e..72cea1d4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -106,6 +106,9 @@ parser.add_argument('--listen-port', type=int, help='The listening port that the parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.') parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.') parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.') +parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" with format like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") +parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras") args = parser.parse_args() # Provisional, this will be deleted later diff --git a/server.py b/server.py index f423e368..f8fd663c 100644 --- a/server.py +++ b/server.py @@ -31,9 +31,9 @@ if settings_file is not None: def get_available_models(): if shared.args.flexgen: - return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) + return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.model_name}/').glob('*')) if item.name.endswith('-np')], key=str.lower) else: - return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.model_name}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) def get_available_presets(): return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) @@ -48,7 +48,7 @@ def get_available_softprompts(): return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) def get_available_loras(): - return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + return ['None'] + sorted([item.name for item in list(Path('shared.args.lora_dir').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) def load_model_wrapper(selected_model): if selected_model != shared.model_name: @@ -448,11 +448,15 @@ def create_interface(): extensions_module.create_extensions_block() # Launch the interface + gradio_auth_creds = [] + with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] shared.gradio['interface'].queue() if shared.args.listen: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None) else: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None) create_interface()