From b37c54edcfee36ef5fdbaae9f6337d236be52b99 Mon Sep 17 00:00:00 2001
From: catalpaaa
Date: Fri, 24 Mar 2023 17:30:18 -0700
Subject: [PATCH 1/9] lora-dir, model-dir and login auth

Added lora-dir and model-dir arguments, plus a login-auth argument that
points to a file containing usernames and passwords in the format of
"u:pw,u:pw,..."

---
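Notes: the auth file is a single line of comma-separated credentials. With
made-up credentials it could look like this:

    alice:wonderland,bob:builder

The parsing added to server.py below turns that file into the list of
(username, password) tuples that gradio's launch() accepts for its auth
parameter. A standalone Python sketch of the same logic, with a placeholder
file path:

    gradio_auth_creds = []
    with open('/path/to/auth/file', 'r', encoding='utf8') as file:
        for line in file.readlines():
            # one line may hold several "user:password" entries, comma-separated
            gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
    auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
    # auth == [('alice', 'wonderland'), ('bob', 'builder')]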
 modules/LoRA.py   |  2 +-
 modules/models.py | 20 ++++++++++----------
 modules/shared.py |  3 +++
 server.py         | 14 +++++++++-----
 4 files changed, 23 insertions(+), 16 deletions(-)

diff --git a/modules/LoRA.py b/modules/LoRA.py
index aa68ad32..394f7367 100644
--- a/modules/LoRA.py
+++ b/modules/LoRA.py
@@ -30,7 +30,7 @@ def add_lora_to_model(lora_name):
         elif shared.args.load_in_8bit:
             params['device_map'] = {'': 0}

-        shared.model = PeftModel.from_pretrained(shared.model, Path(f"loras/{lora_name}"), **params)
+        shared.model = PeftModel.from_pretrained(shared.model, Path(f"{shared.args.lora_dir}/{lora_name}"), **params)
         if not shared.args.load_in_8bit and not shared.args.cpu:
             shared.model.half()
             if not hasattr(shared.model, "hf_device_map"):

diff --git a/modules/models.py b/modules/models.py
index ccb97da3..757eb8b9 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -46,9 +46,9 @@ def load_model(model_name):
     # Default settings
     if not any([shared.args.cpu, shared.args.load_in_8bit, shared.args.gptq_bits, shared.args.auto_devices, shared.args.disk, shared.args.gpu_memory is not None, shared.args.cpu_memory is not None, shared.args.deepspeed, shared.args.flexgen, shared.is_RWKV]):
         if any(size in shared.model_name.lower() for size in ('13b', '20b', '30b')):
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), device_map='auto', load_in_8bit=True)
+            model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), device_map='auto', load_in_8bit=True)
         else:
-            model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+            model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), low_cpu_mem_usage=True, torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
             if torch.has_mps:
                 device = torch.device('mps')
                 model = model.to(device)
@@ -76,11 +76,11 @@ def load_model(model_name):
                                           num_bits=4, group_size=64,
                                           group_dim=2, symmetric=False))

-        model = OptLM(f"facebook/{shared.model_name}", env, "models", policy)
+        model = OptLM(f"facebook/{shared.model_name}", env, shared.model_name, policy)

     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed:
-        model = AutoModelForCausalLM.from_pretrained(Path(f"models/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
+        model = AutoModelForCausalLM.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}"), torch_dtype=torch.bfloat16 if shared.args.bf16 else torch.float16)
         model = deepspeed.initialize(model=model, config_params=ds_config, model_parameters=None, optimizer=None, lr_scheduler=None)[0]
         model.module.eval()  # Inference
         print(f"DeepSpeed ZeRO-3 is enabled: {is_deepspeed_zero3_enabled()}")
@@ -89,8 +89,8 @@ def load_model(model_name):
     elif shared.is_RWKV:
         from modules.RWKV import RWKVModel, RWKVTokenizer

-        model = RWKVModel.from_pretrained(Path(f'models/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
-        tokenizer = RWKVTokenizer.from_pretrained(Path('models'))
+        model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
+        tokenizer = RWKVTokenizer.from_pretrained(Path(shared.model_name))

         return model, tokenizer
@@ -142,7 +142,7 @@ def load_model(model_name):
         if shared.args.disk:
             params["offload_folder"] = shared.args.disk_cache_dir

-        checkpoint = Path(f'models/{shared.model_name}')
+        checkpoint = Path(f'{shared.args.model_dir}/{shared.model_name}')

         if shared.args.load_in_8bit and params.get('max_memory', None) is not None and params['device_map'] == 'auto':
             config = AutoConfig.from_pretrained(checkpoint)
@@ -159,10 +159,10 @@ def load_model(model_name):
         model = AutoModelForCausalLM.from_pretrained(checkpoint, **params)

     # Loading the tokenizer
-    if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path("models/gpt-j-6B/").exists():
-        tokenizer = AutoTokenizer.from_pretrained(Path("models/gpt-j-6B/"))
+    if shared.model_name.lower().startswith(('gpt4chan', 'gpt-4chan', '4chan')) and Path(f"{shared.args.model_dir}/gpt-j-6B/").exists():
+        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/gpt-j-6B/"))
     else:
-        tokenizer = AutoTokenizer.from_pretrained(Path(f"models/{shared.model_name}/"))
+        tokenizer = AutoTokenizer.from_pretrained(Path(f"{shared.args.model_dir}/{shared.model_name}/"))
     tokenizer.truncation_side = 'left'

     print(f"Loaded the model in {(time.time()-t0):.2f} seconds.")

diff --git a/modules/shared.py b/modules/shared.py
index 720c697e..72cea1d4 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -106,6 +106,9 @@ parser.add_argument('--listen-port', type=int, help='The listening port that the
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
 parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
+parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" with format like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
+parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
 args = parser.parse_args()

 # Provisional, this will be deleted later

diff --git a/server.py b/server.py
index f423e368..f8fd663c 100644
--- a/server.py
+++ b/server.py
@@ -31,9 +31,9 @@ if settings_file is not None:

 def get_available_models():
     if shared.args.flexgen:
-        return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower)
+        return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.model_name}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
     else:
-        return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+        return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.model_name}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)

 def get_available_presets():
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)
@@ -48,7 +48,7 @@ def get_available_softprompts():
     return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower)

 def get_available_loras():
-    return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+    return ['None'] + sorted([item.name for item in list(Path('shared.args.lora_dir').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)

 def load_model_wrapper(selected_model):
     if selected_model != shared.model_name:
@@ -448,11 +448,15 @@ def create_interface():
         extensions_module.create_extensions_block()

     # Launch the interface
+    gradio_auth_creds = []
+    with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
+        for line in file.readlines():
+            gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
     shared.gradio['interface'].queue()
     if shared.args.listen:
-        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
+        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None)
     else:
-        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch)
+        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None)

 create_interface()
"/path/to/auth/file" with format like "u1:p1,u2:p2,u3:p3"', default=None) +parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") +parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras") args = parser.parse_args() # Provisional, this will be deleted later diff --git a/server.py b/server.py index f423e368..f8fd663c 100644 --- a/server.py +++ b/server.py @@ -31,9 +31,9 @@ if settings_file is not None: def get_available_models(): if shared.args.flexgen: - return sorted([re.sub('-np$', '', item.name) for item in list(Path('models/').glob('*')) if item.name.endswith('-np')], key=str.lower) + return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.model_name}/').glob('*')) if item.name.endswith('-np')], key=str.lower) else: - return sorted([re.sub('.pth$', '', item.name) for item in list(Path('models/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.model_name}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) def get_available_presets(): return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower) @@ -48,7 +48,7 @@ def get_available_softprompts(): return ['None'] + sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('softprompts').glob('*.zip'))), key=str.lower) def get_available_loras(): - return ['None'] + sorted([item.name for item in list(Path('loras/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) + return ['None'] + sorted([item.name for item in list(Path('shared.args.lora_dir').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower) def load_model_wrapper(selected_model): if selected_model != shared.model_name: @@ -448,11 +448,15 @@ def create_interface(): extensions_module.create_extensions_block() # Launch the interface + gradio_auth_creds = [] + with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] shared.gradio['interface'].queue() if shared.args.listen: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None) else: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None) create_interface() From ec2a1faceecddf1400245a6c8983e40ef430cccf Mon Sep 17 00:00:00 2001 From: catalpaaa Date: Fri, 24 Mar 2023 17:34:33 -0700 Subject: [PATCH 2/9] Update server.py --- server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/server.py b/server.py index f8fd663c..c69abb4b 100644 --- 
From d51cb8292b42eb29e4e45ed850d23b446208a0d3 Mon Sep 17 00:00:00 2001
From: catalpaaa
Date: Fri, 24 Mar 2023 17:36:31 -0700
Subject: [PATCH 4/9] Update server.py

Yeah, I should go to bed.

---
 server.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/server.py b/server.py
index 67c1e915..8ac6031a 100644
--- a/server.py
+++ b/server.py
@@ -31,9 +31,9 @@ if settings_file is not None:

 def get_available_models():
     if shared.args.flexgen:
-        return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.arg.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
+        return sorted([re.sub('-np$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if item.name.endswith('-np')], key=str.lower)
     else:
-        return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.arg.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)
+        return sorted([re.sub('.pth$', '', item.name) for item in list(Path(f'{shared.args.model_dir}/').glob('*')) if not item.name.endswith(('.txt', '-np', '.pt', '.json'))], key=str.lower)

 def get_available_presets():
     return sorted(set(map(lambda x : '.'.join(str(x.name).split('.')[:-1]), Path('presets').glob('*.txt'))), key=str.lower)

From 005f552ea311e9bf932b91337da101a490bdd5ff Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 27 Mar 2023 23:29:52 -0300
Subject: [PATCH 5/9] Some simplifications

---
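Notes: besides regrouping the flags in shared.py, this guards the credential
parsing in server.py behind the flag. The version from patch 1 ran
open(shared.args.gradio_auth_path) unconditionally, so launching without
--gradio-auth-path (whose argparse default is None) would crash before the UI
came up, roughly like this:

    >>> open(None, 'r', encoding='utf8')
    TypeError: expected str, bytes or os.PathLike object, not NoneType

With the refactor, auth stays None when the flag is absent, and passing
auth=None to launch() disables the login prompt entirely.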
 modules/shared.py |  6 +++---
 server.py         | 17 +++++++++++------
 2 files changed, 14 insertions(+), 9 deletions(-)

diff --git a/modules/shared.py b/modules/shared.py
index d9bcf241..71829a01 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -107,14 +107,14 @@ parser.add_argument('--rwkv-cuda-on', action='store_true', help='RWKV: Compile t
 parser.add_argument('--no-stream', action='store_true', help='Don\'t stream the text output in real time.')
 parser.add_argument('--settings', type=str, help='Load the default interface settings from this json file. See settings-template.json for an example. If you create a file called settings.json, this file will be loaded by default without the need to use the --settings flag.')
 parser.add_argument('--extensions', type=str, nargs="+", help='The list of extensions to load. If you want to load more than one extension, write the names separated by spaces.')
+parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
+parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
+parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
 parser.add_argument('--listen', action='store_true', help='Make the web UI reachable from your local network.')
 parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
-parser.add_argument('--verbose', action='store_true', help='Print the prompts to the terminal.')
 parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" with format like "u1:p1,u2:p2,u3:p3"', default=None)
-parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models")
-parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras")
 args = parser.parse_args()

 # Provisional, this will be deleted later

diff --git a/server.py b/server.py
index 15aa84bb..66f60074 100644
--- a/server.py
+++ b/server.py
@@ -498,16 +498,21 @@ def create_interface():
     if shared.args.extensions is not None:
         extensions_module.create_extensions_block()

+    # Authentication
+    auth = None
+    if shared.args.gradio_auth_path is not None:
+        gradio_auth_creds = []
+        with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
+            for line in file.readlines():
+                gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
+        auth = [tuple(cred.split(':')) for cred in gradio_auth_creds]
+
     # Launch the interface
-    gradio_auth_creds = []
-    with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file:
-        for line in file.readlines():
-            gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()]
     shared.gradio['interface'].queue()
     if shared.args.listen:
-        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None)
+        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)
     else:
-        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None)
+        shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth)

 create_interface()
"/path/to/auth/file" with format like "u1:p1,u2:p2,u3:p3"', default=None) -parser.add_argument("--model-dir", type=str, default='models/', help="Path to directory with all the models") -parser.add_argument("--lora-dir", type=str, default='loras/', help="Path to directory with all the loras") args = parser.parse_args() # Provisional, this will be deleted later diff --git a/server.py b/server.py index 15aa84bb..66f60074 100644 --- a/server.py +++ b/server.py @@ -498,16 +498,21 @@ def create_interface(): if shared.args.extensions is not None: extensions_module.create_extensions_block() + # Authentication + auth = None + if shared.args.gradio_auth_path is not None: + gradio_auth_creds = [] + with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] + auth = [tuple(cred.split(':')) for cred in gradio_auth_creds] + # Launch the interface - gradio_auth_creds = [] - with open(shared.args.gradio_auth_path, 'r', encoding="utf8") as file: - for line in file.readlines(): - gradio_auth_creds += [x.strip() for x in line.split(',') if x.strip()] shared.gradio['interface'].queue() if shared.args.listen: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_name='0.0.0.0', server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) else: - shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None) + shared.gradio['interface'].launch(prevent_thread_lock=True, share=shared.args.share, server_port=shared.args.listen_port, inbrowser=shared.args.auto_launch, auth=auth) create_interface() From 30585b3e716e646ffabb8d590e5fe3b53863656d Mon Sep 17 00:00:00 2001 From: oobabooga <112222186+oobabooga@users.noreply.github.com> Date: Mon, 27 Mar 2023 23:35:01 -0300 Subject: [PATCH 6/9] Update README --- README.md | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 3bfbc72f..cd75284c 100644 --- a/README.md +++ b/README.md @@ -198,12 +198,15 @@ Optionally, you can use the following command-line flags: | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. | | `--no-stream` | Don't stream the text output in real time. | | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.| -| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. | -| `--listen` | Make the web UI reachable from your local network.| -| `--listen-port LISTEN_PORT` | The listening port that the server will use. | -| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. | -| `--auto-launch` | Open the web UI in the default browser upon launch. | -| `--verbose` | Print the prompts to the terminal. 
 README.md | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/README.md b/README.md
index 3bfbc72f..cd75284c 100644
--- a/README.md
+++ b/README.md
@@ -198,12 +198,15 @@ Optionally, you can use the following command-line flags:
 | `--rwkv-cuda-on` | RWKV: Compile the CUDA kernel for better performance. |
 | `--no-stream` | Don't stream the text output in real time. |
 | `--settings SETTINGS_FILE` | Load the default interface settings from this json file. See `settings-template.json` for an example. If you create a file called `settings.json`, this file will be loaded by default without the need to use the `--settings` flag.|
-| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
-| `--listen` | Make the web UI reachable from your local network.|
-| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
-| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
-| `--auto-launch` | Open the web UI in the default browser upon launch. |
-| `--verbose` | Print the prompts to the terminal. |
+| `--extensions EXTENSIONS [EXTENSIONS ...]` | The list of extensions to load. If you want to load more than one extension, write the names separated by spaces. |
+| `--model-dir MODEL_DIR` | Path to directory with all the models |
+| `--lora-dir LORA_DIR` | Path to directory with all the loras |
+| `--verbose` | Print the prompts to the terminal. |
+| `--listen` | Make the web UI reachable from your local network. |
+| `--listen-port LISTEN_PORT` | The listening port that the server will use. |
+| `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
+| `--auto-launch` | Open the web UI in the default browser upon launch. |
+| `--gradio-auth-path GRADIO_AUTH_PATH` | set gradio authentication file path ex. "/path/to/auth/file" with format like "u1:p1,u2:p2,u3:p3" |

 Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).

From 036163a75134ba88d83754548b992331d2b450f5 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 27 Mar 2023 23:39:26 -0300
Subject: [PATCH 7/9] Change description

---
 README.md         | 2 +-
 modules/shared.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index cd75284c..f6b1d4f5 100644
--- a/README.md
+++ b/README.md
@@ -206,7 +206,7 @@ Optionally, you can use the following command-line flags:
 | `--listen-port LISTEN_PORT` | The listening port that the server will use. |
 | `--share` | Create a public URL. This is useful for running the web UI on Google Colab or similar. |
 | `--auto-launch` | Open the web UI in the default browser upon launch. |
-| `--gradio-auth-path GRADIO_AUTH_PATH` | set gradio authentication file path ex. "/path/to/auth/file" with format like "u1:p1,u2:p2,u3:p3" |
+| `--gradio-auth-path GRADIO_AUTH_PATH` | Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3" |

 Out of memory errors? [Check the low VRAM guide](https://github.com/oobabooga/text-generation-webui/wiki/Low-VRAM-guide).

diff --git a/modules/shared.py b/modules/shared.py
index 71829a01..ac9d750c 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -114,7 +114,7 @@ parser.add_argument('--listen', action='store_true', help='Make the web UI reach
 parser.add_argument('--listen-port', type=int, help='The listening port that the server will use.')
 parser.add_argument('--share', action='store_true', help='Create a public URL. This is useful for running the web UI on Google Colab or similar.')
 parser.add_argument('--auto-launch', action='store_true', default=False, help='Open the web UI in the default browser upon launch.')
-parser.add_argument("--gradio-auth-path", type=str, help='set gradio authentication file path ex. "/path/to/auth/file" with format like "u1:p1,u2:p2,u3:p3"', default=None)
+parser.add_argument("--gradio-auth-path", type=str, help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"', default=None)
 args = parser.parse_args()

 # Provisional, this will be deleted later

From ee95e55df67468902fc411bbfc51bb961d1953d2 Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 27 Mar 2023 23:42:29 -0300
Subject: [PATCH 8/9] Fix RWKV tokenizer

---
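Notes: RWKVTokenizer.from_pretrained() resolves its tokenizer file relative to
the directory it is given, so it needs the models directory itself;
Path(shared.model_name) instead pointed at a cwd-relative directory named
after the model, which does not contain the tokenizer. A quick illustration
with a placeholder model name (the 20B_tokenizer.json filename is an
assumption from modules/RWKV.py, not shown in this patch):

    from pathlib import Path

    model_name = 'rwkv-4-pile-169m'  # placeholder
    # old argument: a directory named after the model; the file is not there
    print((Path(model_name) / '20B_tokenizer.json').exists())     # False
    # new argument: the models directory, where the tokenizer actually lives
    print((Path('models/') / '20B_tokenizer.json').exists())      # True once downloaded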
 modules/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/models.py b/modules/models.py
index 5aaef800..26a10f7a 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -90,7 +90,7 @@ def load_model(model_name):
         from modules.RWKV import RWKVModel, RWKVTokenizer

         model = RWKVModel.from_pretrained(Path(f'{shared.args.model_dir}/{model_name}'), dtype="fp32" if shared.args.cpu else "bf16" if shared.args.bf16 else "fp16", device="cpu" if shared.args.cpu else "cuda")
-        tokenizer = RWKVTokenizer.from_pretrained(Path(shared.model_name))
+        tokenizer = RWKVTokenizer.from_pretrained(Path(shared.args.model_dir))

         return model, tokenizer

From 53da672315d3914b1af728274f0223e7bac60b7a Mon Sep 17 00:00:00 2001
From: oobabooga <112222186+oobabooga@users.noreply.github.com>
Date: Mon, 27 Mar 2023 23:44:21 -0300
Subject: [PATCH 9/9] Fix FlexGen

---
 modules/models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/modules/models.py b/modules/models.py
index 26a10f7a..a6839318 100644
--- a/modules/models.py
+++ b/modules/models.py
@@ -76,7 +76,7 @@ def load_model(model_name):
                                           num_bits=4, group_size=64,
                                           group_dim=2, symmetric=False))

-        model = OptLM(f"facebook/{shared.model_name}", env, shared.model_name, policy)
+        model = OptLM(f"facebook/{shared.model_name}", env, shared.args.model_dir, policy)

     # DeepSpeed ZeRO-3
     elif shared.args.deepspeed: