Skip to content
This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Commit

Permalink
Update OFA download paths (#81)
Browse files Browse the repository at this point in the history
MIT Han Lab updated URLs from which OFA super-networks can be download.
This change updates paths within the DyNAS-T to accommodate for that.

Signed-off-by: Maciej Szankin <[email protected]>
  • Loading branch information
macsz authored Sep 27, 2023
1 parent bd97a6f commit fa03430
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
16 changes: 10 additions & 6 deletions dynast/supernetwork/image_classification/ofa/ofa/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ def ofa_specialized(net_id, pretrained=True):


def ofa_net(net_id, pretrained=True):
url_base = "https://raw.githubusercontent.com/han-cai/files/master/ofa/ofa_nets/"
googledrive = False
if net_id == "ofa_proxyless_d234_e346_k357_w1.3":
net = OFAProxylessNASNets(
dropout_rate=0,
Expand Down Expand Up @@ -107,16 +109,18 @@ def ofa_net(net_id, pretrained=True):
expand_ratio_list=[0.2, 0.25, 0.35],
width_mult_list=[0.65, 0.8, 1.0],
)
net_id = "ofa_resnet50_d=0+1+2_e=0.2+0.25+0.35_w=0.65+0.8+1.0"
net_id = "ofa_supernet_resnet50"
url_base = "https://huggingface.co/han-cai/ofa/resolve/main/"
else:
raise ValueError("Not supported: %s" % net_id)

if pretrained:
url_base = "https://hanlab.mit.edu/files/OnceForAll/ofa_nets/"
init = torch.load(
download_url(url_base + net_id, model_dir=".torch/ofa_nets"),
map_location="cpu",
)["state_dict"]
if googledrive:
pt_path = f".torch/ofa_nets/{net_id}"
gdown.download(url_base, pt_path, quiet=False)
else:
pt_path = download_url(url_base + net_id, model_dir=".torch/ofa_nets")
init = torch.load(pt_path, map_location="cpu")["state_dict"]
net.load_state_dict(init)
return net

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
autograd>=1.5
fairseq>=0.12.2
gdown
numba>=0.56.4
numpy>=1.21.6
pandas>=1.3.5
Expand Down

0 comments on commit fa03430

Please sign in to comment.