diff --git a/download-model.py b/download-model.py new file mode 100644 index 00000000..2b2ca909 --- /dev/null +++ b/download-model.py @@ -0,0 +1,54 @@ +''' +Downloads models from Hugging Face to models/model-name. + +Example: +python download-model.py facebook/opt-1.3b + +''' + +import requests +from bs4 import BeautifulSoup +import multiprocessing +import os +import tqdm +from sys import argv + +def get_file(args): + url = args[0] + output_folder = args[1] + + r = requests.get(url, stream=True) + with open(f"{output_folder}/{url.split('/')[-1]}", 'wb') as f: + total_size = int(r.headers.get('content-length', 0)) + block_size = 1024 + t = tqdm.tqdm(total=total_size, unit='iB', unit_scale=True) + for data in r.iter_content(block_size): + t.update(len(data)) + f.write(data) + t.close() + +model = argv[1] +if model.endswith('/'): + model = model[:-1] +url = f'https://huggingface.co/{model}/tree/main' +output_folder = f"models/{model.split('/')[-1]}" +if not os.path.exists(output_folder): + os.mkdir(output_folder) + +# Finding the relevant files to download +page = requests.get(url) +soup = BeautifulSoup(page.content, 'html.parser') +links = soup.find_all('a') +downloads = [] +for link in links: + href = link.get('href')[1:] + if href.startswith(f'{model}/resolve/main'): + if href.endswith(('.json', '.txt')) or (href.endswith('.bin') and 'pytorch_model' in href): + downloads.append(f'https://huggingface.co/{href}') + +# Downloading the files +print(f"Downloading the model to {output_folder}...") +pool = multiprocessing.Pool(processes=4) +results = pool.map(get_file, [[downloads[i], output_folder] for i in range(len(downloads))]) +pool.close() +pool.join() diff --git a/requirements.txt b/requirements.txt index fc6edc77..0e1eb64d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,8 +5,10 @@ altair==4.2.0 anyio==3.6.2 async-timeout==4.0.2 attrs==22.1.0 +beautifulsoup4==4.11.1 bitsandbytes==0.35.4 brotlipy==0.7.0 +bs4==0.0.1 charset-normalizer==2.1.1 click==8.1.3 contourpy==1.0.6 @@ -14,6 +16,7 @@ cycler==0.11.0 entrypoints==0.4 fastapi==0.88.0 ffmpy==0.3.0 +filelock==3.9.0 fonttools==4.38.0 frozenlist==1.3.3 fsspec==2022.11.0 @@ -47,9 +50,11 @@ pyrsistent==0.19.2 python-dateutil==2.8.2 python-multipart==0.0.5 pytz==2022.7 +PyYAML==6.0 regex==2022.10.31 rfc3986==1.5.0 sniffio==1.3.0 +soupsieve==2.3.2.post1 starlette==0.22.0 tokenizers==0.13.2 toolz==0.12.0