Merge branch 'main' into pt-path-changes

This commit is contained in:
oobabooga 2023-03-10 11:03:42 -03:00 committed by GitHub
commit e9dbdafb14
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 32 additions and 9 deletions

View File

@ -54,7 +54,7 @@ The third line assumes that you have an NVIDIA GPU.
pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2 pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/rocm5.2
``` ```
* If you are running in CPU mode, replace the third command with this one: * If you are running it in CPU mode, replace the third command with this one:
``` ```
conda install pytorch torchvision torchaudio git -c pytorch conda install pytorch torchvision torchaudio git -c pytorch

View File

@ -5,7 +5,9 @@ Example:
python download-model.py facebook/opt-1.3b python download-model.py facebook/opt-1.3b
''' '''
import argparse import argparse
import base64
import json import json
import multiprocessing import multiprocessing
import re import re
@ -93,23 +95,28 @@ facebook/opt-1.3b
def get_download_links_from_huggingface(model, branch): def get_download_links_from_huggingface(model, branch):
base = "https://huggingface.co" base = "https://huggingface.co"
page = f"/api/models/{model}/tree/{branch}?cursor=" page = f"/api/models/{model}/tree/{branch}?cursor="
cursor = b""
links = [] links = []
classifications = [] classifications = []
has_pytorch = False has_pytorch = False
has_safetensors = False has_safetensors = False
while page is not None: while True:
content = requests.get(f"{base}{page}").content content = requests.get(f"{base}{page}{cursor.decode()}").content
dict = json.loads(content) dict = json.loads(content)
if len(dict) == 0:
break
for i in range(len(dict)): for i in range(len(dict)):
fname = dict[i]['path'] fname = dict[i]['path']
is_pytorch = re.match("pytorch_model.*\.bin", fname) is_pytorch = re.match("pytorch_model.*\.bin", fname)
is_safetensors = re.match("model.*\.safetensors", fname) is_safetensors = re.match("model.*\.safetensors", fname)
is_text = re.match(".*\.(txt|json)", fname) is_tokenizer = re.match("tokenizer.*\.model", fname)
is_text = re.match(".*\.(txt|json)", fname) or is_tokenizer
if is_text or is_safetensors or is_pytorch: if any((is_pytorch, is_safetensors, is_text, is_tokenizer)):
if is_text: if is_text:
links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}") links.append(f"https://huggingface.co/{model}/resolve/{branch}/{fname}")
classifications.append('text') classifications.append('text')
@ -123,8 +130,9 @@ def get_download_links_from_huggingface(model, branch):
has_pytorch = True has_pytorch = True
classifications.append('pytorch') classifications.append('pytorch')
#page = dict['nextUrl'] cursor = base64.b64encode(f'{{"file_name":"{dict[-1]["path"]}"}}'.encode()) + b':50'
page = None cursor = base64.b64encode(cursor)
cursor = cursor.replace(b'=', b'%3D')
# If both pytorch and safetensors are available, download safetensors only # If both pytorch and safetensors are available, download safetensors only
if has_pytorch and has_safetensors: if has_pytorch and has_safetensors:

View File

@ -116,8 +116,23 @@ def load_model(model_name):
print(f"Could not find {pt_model}, exiting...") print(f"Could not find {pt_model}, exiting...")
exit() exit()
model = load_quant(path_to_model, pt_path, 4) model = load_quant(path_to_model, Path(f"models/{pt_model}"), 4)
model = model.to(torch.device('cuda:0'))
# Multi-GPU setup
if shared.args.gpu_memory:
import accelerate
max_memory = {}
for i in range(len(shared.args.gpu_memory)):
max_memory[i] = f"{shared.args.gpu_memory[i]}GiB"
max_memory['cpu'] = f"{shared.args.cpu_memory or '99'}GiB"
device_map = accelerate.infer_auto_device_map(model, max_memory=max_memory, no_split_module_classes=["LLaMADecoderLayer"])
model = accelerate.dispatch_model(model, device_map=device_map)
# Single GPU
else:
model = model.to(torch.device('cuda:0'))
# Custom # Custom
else: else: