mirror of https://github.com/oobabooga/text-generation-webui.git
synced 2024-12-25 13:58:56 +01:00
Downloader: Add --model-dir argument, respect --model-dir in the UI
This commit is contained in:
parent ad54d524f7
commit 4f1e96b9e3
@@ -167,8 +167,11 @@ class ModelDownloader:
         is_llamacpp = has_gguf and specific_file is not None
         return links, sha256, is_lora, is_llamacpp
 
-    def get_output_folder(self, model, branch, is_lora, is_llamacpp=False):
-        base_folder = 'models' if not is_lora else 'loras'
+    def get_output_folder(self, model, branch, is_lora, is_llamacpp=False, model_dir=None):
+        if model_dir:
+            base_folder = model_dir
+        else:
+            base_folder = 'models' if not is_lora else 'loras'
 
         # If the model is of type GGUF, save directly in the base_folder
         if is_llamacpp:
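For quick reference, a minimal standalone sketch of the base-folder selection introduced above (the helper name is hypothetical; the rest of get_output_folder, including the GGUF handling visible in the trailing context lines, is unchanged and not reproduced here):

from pathlib import Path

def resolve_base_folder(is_lora, model_dir=None):
    # An explicit model_dir overrides the defaults entirely.
    if model_dir:
        return Path(model_dir)

    # Otherwise keep the historical defaults: 'models' for full models,
    # 'loras' for LoRA adapters.
    return Path('loras' if is_lora else 'models')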
@@ -304,7 +307,8 @@ if __name__ == '__main__':
     parser.add_argument('--threads', type=int, default=4, help='Number of files to download simultaneously.')
     parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).')
     parser.add_argument('--specific-file', type=str, default=None, help='Name of the specific file to download (if not provided, downloads all).')
-    parser.add_argument('--output', type=str, default=None, help='The folder where the model should be saved.')
+    parser.add_argument('--output', type=str, default=None, help='Save the model files to this folder.')
+    parser.add_argument('--model-dir', type=str, default=None, help='Save the model files to a subfolder of this folder instead of the default one (text-generation-webui/models).')
     parser.add_argument('--clean', action='store_true', help='Does not resume the previous download.')
     parser.add_argument('--check', action='store_true', help='Validates the checksums of model files.')
     parser.add_argument('--max-retries', type=int, default=5, help='Max retries count when get error in download time.')
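Note the difference between the two options touched here: --output names the exact folder to save into, while the new --model-dir only replaces the base folder (by default text-generation-webui/models) under which the downloader still creates its usual per-model subfolder. For example, passing --model-dir /srv/llm-models stores the files in a subfolder of /srv/llm-models instead of a subfolder of models/ (the path is an illustrative assumption, not taken from this diff).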
@@ -333,7 +337,7 @@ if __name__ == '__main__':
     if args.output:
         output_folder = Path(args.output)
     else:
-        output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp)
+        output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp, model_dir=args.model_dir)
 
     if args.check:
         # Check previously downloaded files

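Combined with the --output branch above, the CLI's destination logic after this commit can be summarized by the following hedged sketch (standalone, with a hypothetical helper name; downloader stands in for a ModelDownloader instance). The hunk below then applies the same idea on the UI side, in download_model_wrapper:

from pathlib import Path

def pick_output_folder(args, downloader, model, branch, is_lora, is_llamacpp):
    # --output wins outright and names the exact destination folder.
    if args.output:
        return Path(args.output)

    # Otherwise --model-dir (possibly None) only swaps the base folder that
    # get_output_folder builds the per-model subfolder under.
    return downloader.get_output_folder(
        model, branch, is_lora,
        is_llamacpp=is_llamacpp,
        model_dir=args.model_dir
    )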
@@ -290,7 +290,13 @@ def download_model_wrapper(repo_id, specific_file, progress=gr.Progress(), retur
             return
 
         yield ("Getting the output folder")
-        output_folder = downloader.get_output_folder(model, branch, is_lora, is_llamacpp=is_llamacpp)
+        output_folder = downloader.get_output_folder(
+            model,
+            branch,
+            is_lora,
+            is_llamacpp=is_llamacpp,
+            model_dir=shared.args.model_dir if shared.args.model_dir != shared.args_defaults.model_dir else None
+        )
 
         if output_folder == Path("models"):
            output_folder = Path(shared.args.model_dir)