diff --git a/download-model.py b/download-model.py index d6b2ebff..34986c75 100644 --- a/download-model.py +++ b/download-model.py @@ -30,6 +30,8 @@ class ModelDownloader: self.s.mount('https://huggingface.co', HTTPAdapter(max_retries=max_retries)) if os.getenv('HF_USER') is not None and os.getenv('HF_PASS') is not None: self.s.auth = (os.getenv('HF_USER'), os.getenv('HF_PASS')) + if os.getenv('HF_TOKEN') is not None: + self.s.headers = {'authorization': f'Bearer {os.getenv("HF_TOKEN")}'} def sanitize_model_and_branch_names(self, model, branch): if model[-1] == '/':