diff --git a/README.md b/README.md index 4b0d3c77..7f332c8c 100644 --- a/README.md +++ b/README.md @@ -105,14 +105,12 @@ After downloading the model, follow these steps: 1. Place the files under `models/gpt4chan_model_float16` or `models/gpt4chan_model`. 2. Place GPT-J 6B's config.json file in that same folder: [config.json](https://huggingface.co/EleutherAI/gpt-j-6B/raw/main/config.json). -3. Download GPT-J 6B under `models/gpt-j-6B`: +3. Download GPT-J 6B's tokenizer files (they will be automatically detected when you attempt to load GPT-4chan): ``` -python download-model.py EleutherAI/gpt-j-6B +python download-model.py EleutherAI/gpt-j-6B --text-only ``` -You don't really need all of GPT-J 6B's files, just the tokenizer files, but you might as well download the whole thing. Those files will be automatically detected when you attempt to load GPT-4chan. - #### Converting to pytorch (optional) The script `convert-to-torch.py` allows you to convert models to .pt format, which is sometimes 10x faster to load to the GPU: diff --git a/download-model.py b/download-model.py index 46aa9d77..0e114c65 100644 --- a/download-model.py +++ b/download-model.py @@ -19,6 +19,7 @@ parser = argparse.ArgumentParser() parser.add_argument('MODEL', type=str) parser.add_argument('--branch', type=str, default='main', help='Name of the Git branch to download from.') parser.add_argument('--threads', type=int, default=1, help='Number of files to download simultaneously.') +parser.add_argument('--text-only', action='store_true', help='Only download text files (txt/json).') args = parser.parse_args() def get_file(args): @@ -84,15 +85,18 @@ if __name__ == '__main__': is_text = re.match(".*\.(txt|json)", fname) if is_text or is_safetensors or is_pytorch: - downloads.append(f'https://huggingface.co/{href}') if is_text: + downloads.append(f'https://huggingface.co/{href}') classifications.append('text') - elif is_safetensors: - has_safetensors = True - classifications.append('safetensors') - elif is_pytorch: - has_pytorch = True - classifications.append('pytorch') + continue + if not args.text_only: + downloads.append(f'https://huggingface.co/{href}') + if is_safetensors: + has_safetensors = True + classifications.append('safetensors') + elif is_pytorch: + has_pytorch = True + classifications.append('pytorch') # If both pytorch and safetensors are available, download safetensors only if has_pytorch and has_safetensors: