Skip to content

Commit

Permalink
gated repo exception handling
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchitvj committed May 2, 2024
1 parent fa6757d commit bc9a48d
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions src/grag/quantize/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import requests
from git import Repo
from grag.components.utils import get_config
from huggingface_hub import snapshot_download
from huggingface_hub import login, snapshot_download
from huggingface_hub.utils import GatedRepoError

config = get_config()

Expand Down Expand Up @@ -107,12 +108,29 @@ def fetch_model_repo(repo_id: str, model_path: Union[str, Path] = './grag-quanti
model_path = Path(model_path)
local_dir = model_path / f"{repo_id.split('/')[1]}"
local_dir.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
local_dir_use_symlinks="auto",
resume_download=True,
)

try:
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
local_dir_use_symlinks="auto",
resume_download=True,
)
except GatedRepoError:
print(
"This model comes under gated repository. You must be authenticated to download the model. For more: https://huggingface.co/docs/hub/en/models-gated")
resp = input("If you have auth token, please provide it here ['n' or enter to exit]: ")
if resp == 'n' or resp == '':
print("No token provided, exiting.")
exit(0)
else:
login(resp)
snapshot_download(
repo_id=repo_id,
local_dir=local_dir,
local_dir_use_symlinks="auto",
resume_download=True,
)
print(f"Model downloaded in {local_dir}")
return local_dir

Expand Down

0 comments on commit bc9a48d

Please sign in to comment.