-
Notifications
You must be signed in to change notification settings - Fork 32
/
fetch_model.py
executable file
·51 lines (46 loc) · 1.79 KB
/
fetch_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
#!/usr/bin/env python3
import sys
import os
import argparse
import subprocess
from urllib.parse import urlparse
from huggingface_hub import model_info
# Fetch a model and record where it landed.
#
# Usage: fetch_model.py MODEL OUTPUT_FOLDER
#   MODEL is either an http(s) URL to a single file (downloaded with wget) or a
#   Hugging Face repo id such as 'TheBloke/WizardLM-7B-Uncensored' (downloaded
#   with download_model.py from SCRIPT_DIR).
# On success the downloaded file or folder name (relative to OUTPUT_FOLDER) is
# written to /tmp/text-gen-model for later use as a --model argument.
# Exit status: 0 on success, 1 on failure.
parser = argparse.ArgumentParser(
    description="Download a model (http(s) URL or Hugging Face repo id) into "
                "output_folder and record its name in /tmp/text-gen-model.")
parser.add_argument('model', type=str,
                    help="http(s) URL of a single file, or a Hugging Face repo id")
parser.add_argument('output_folder', type=str,
                    help="directory to download into (created if missing)")
args = parser.parse_args()

# Directory containing the download_model.py helper used for HF repos.
SCRIPT_DIR = "/"
# Maximum number of download attempts before giving up.
MAX_ATTEMPTS = 10
# File that receives the downloaded file/folder name on success.
MODEL_NAME_FILE = '/tmp/text-gen-model'

model = args.model.strip()
# Resolve to an absolute path once, up front. Re-running makedirs()/chdir()
# with a relative path inside the retry loop nested every retry one level
# deeper (out/out/...) and misdirected download_model.py's --output.
output_folder = os.path.abspath(args.output_folder)
os.makedirs(output_folder, exist_ok=True)
# wget writes into the current directory; --continue resumes partial files there.
os.chdir(output_folder)

# Decide how to download exactly once; only the download itself is retried.
parsed = urlparse(model)
if parsed.scheme.lower() in ('http', 'https'):
    # We've been passed a URL to download.
    # The last path segment is the name recorded for --model.
    # NOTE(review): wget names the saved file from the URL; if the URL carries a
    # query string the on-disk name may differ from this path-only name — confirm.
    filename = parsed.path.split('/')[-1]
    if not filename:
        print(f"Error, could not determine a filename from URL {model}.", flush=True)
        sys.exit(1)
    print(f"Passed URL: {model}", flush=True)
    # List form: no shell, so characters like $ or ` in the URL are passed verbatim.
    command = ['/usr/bin/wget', '--continue', '--progress=dot:giga', model]
    downloaded_name = filename
else:
    try:
        is_hf_model = model_info(model).id == model
    except Exception as err:  # presumably unknown/gated repo or network failure
        print(f"Error, could not look up {model} on Hugging Face: {err}", flush=True)
        sys.exit(1)
    if not is_hf_model:
        print(f"Error, {model} does not seem to be in a supported format.", flush=True)
        sys.exit(1)
    # We've got an HF model, eg 'TheBloke/WizardLM-7B-Uncensored'
    print(f"Passed HF model: {model}", flush=True)
    model_folder = model.replace('/', '_')
    # Pass the stripped `model` (the value that was validated), not args.model.
    command = [os.path.join(SCRIPT_DIR, 'download_model.py'),
               '--threads', '2',
               '--output', os.path.join(output_folder, model_folder),
               model]
    downloaded_name = model_folder

success = False
for attempt in range(1, MAX_ATTEMPTS + 1):
    # Flush so this line is not reordered after the child's output when piped.
    print(f'Downloading {model} to {output_folder}, attempt {attempt}', flush=True)
    if subprocess.run(command, check=False).returncode == 0:
        # Successful download. Write the model file or folder name for use in --model arg.
        with open(MODEL_NAME_FILE, 'w', encoding='utf-8') as f:
            f.write(downloaded_name + '\n')
        success = True
        break

# Exit 0 for success, 1 for failure.
sys.exit(0 if success else 1)