Skip to content
This repository has been archived by the owner on Dec 3, 2024. It is now read-only.

Simplify how predictors are handled in LINAS search #27

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,7 @@ runs/

# Neural Compressor
nc_workspace/

mod/
models/
datasets/
92 changes: 40 additions & 52 deletions dynast/search/search_tactic.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def train_predictors(self, results_path: str = None):
"""

# Store predictor objects by objective name in a dictionary
self.predictor_dict = dict()
self.predictors = dict()

# Create/train a predictor for each objective
for objective in SUPERNET_METRICS[self.supernet]:
Expand All @@ -339,10 +339,10 @@ def train_predictors(self, results_path: str = None):
)
log.info(f'Training {objective} predictor.')
predictor = objective_predictor.train_predictor()
log.info(f'Updated self.predictor_dict[{objective}].')
self.predictor_dict[objective] = predictor
log.info(f'Updated self.predictors[{objective}].')
self.predictors[objective] = predictor
else:
self.predictor_dict[objective] = None
self.predictors[objective] = None

def search(self):
"""Runs the LINAS search"""
Expand Down Expand Up @@ -373,10 +373,7 @@ def search(self):
]:
runner_predict = OFARunner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_top1'],
predictors=self.predictors,
dataset_path=self.dataset_path,
device=self.device,
dataloader_workers=self.dataloader_workers,
Expand All @@ -385,32 +382,29 @@ def search(self):
elif self.supernet == 'transformer_lt_wmt_en_de':
runner_predict = TransformerLTRunner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['bleu'],
latency_predictor=self.predictors['latency'],
macs_predictor=self.predictors['macs'],
params_predictor=self.predictors['params'],
acc_predictor=self.predictors['bleu'],
dataset_path=self.dataset_path,
checkpoint_path=self.supernet_ckpt_path,
)

elif self.supernet == 'bert_base_sst2':
runner_predict = BertSST2Runner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_sst2'],
latency_predictor=self.predictors['latency'],
macs_predictor=self.predictors['macs'],
params_predictor=self.predictors['params'],
acc_predictor=self.predictors['accuracy_sst2'],
dataset_path=self.dataset_path,
checkpoint_path=self.supernet_ckpt_path,
device=self.device,
)
elif self.supernet == 'vit_base_imagenet':
runner_predict = ViTRunner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_top1'],
predictors=self.predictors,
dataset_path=self.dataset_path,
checkpoint_path=self.supernet_ckpt_path,
batch_size=self.batch_size,
Expand All @@ -420,10 +414,10 @@ def search(self):
elif self.supernet == 'inc_quantization_ofa_resnet50':
runner_predict = QuantizedOFARunner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
model_size_predictor=self.predictor_dict['model_size'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_top1'],
latency_predictor=self.predictors['latency'],
model_size_predictor=self.predictors['model_size'],
params_predictor=self.predictors['params'],
acc_predictor=self.predictors['accuracy_top1'],
dataset_path=self.dataset_path,
device=self.device,
dataloader_workers=self.dataloader_workers,
Expand All @@ -434,10 +428,10 @@ def search(self):
runner_predict = BootstrapNASRunner(
bootstrapnas_supernetwork=self.bootstrapnas_supernetwork,
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_top1'],
latency_predictor=self.predictors['latency'],
macs_predictor=self.predictors['macs'],
params_predictor=self.predictors['params'],
acc_predictor=self.predictors['accuracy_top1'],
dataset_path=self.dataset_path,
batch_size=self.batch_size,
device=self.device,
Expand Down Expand Up @@ -885,10 +879,7 @@ def search(self):
]:
runner_predict = OFARunner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_top1'],
predictors=self.predictors,
dataset_path=self.dataset_path,
device=self.device,
dataloader_workers=self.dataloader_workers,
Expand All @@ -897,32 +888,29 @@ def search(self):
elif self.supernet == 'transformer_lt_wmt_en_de':
runner_predict = TransformerLTRunner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['bleu'],
latency_predictor=self.predictors['latency'],
macs_predictor=self.predictors['macs'],
params_predictor=self.predictors['params'],
acc_predictor=self.predictors['bleu'],
dataset_path=self.dataset_path,
checkpoint_path=self.supernet_ckpt_path,
)

elif self.supernet == 'bert_base_sst2':
runner_predict = BertSST2Runner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_sst2'],
latency_predictor=self.predictors['latency'],
macs_predictor=self.predictors['macs'],
params_predictor=self.predictors['params'],
acc_predictor=self.predictors['accuracy_sst2'],
dataset_path=self.dataset_path,
checkpoint_path=self.supernet_ckpt_path,
device=self.device,
)
elif self.supernet == 'vit_base_imagenet':
runner_predict = ViTRunner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_top1'],
predictors=self.predictors,
dataset_path=self.dataset_path,
checkpoint_path=self.supernet_ckpt_path,
batch_size=self.batch_size,
Expand All @@ -932,10 +920,10 @@ def search(self):
elif self.supernet == 'inc_quantization_ofa_resnet50':
runner_predict = QuantizedOFARunner(
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
model_size_predictor=self.predictor_dict['model_size'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_top1'],
latency_predictor=self.predictors['latency'],
model_size_predictor=self.predictors['model_size'],
params_predictor=self.predictors['params'],
acc_predictor=self.predictors['accuracy_top1'],
dataset_path=self.dataset_path,
device=self.device,
dataloader_workers=self.dataloader_workers,
Expand All @@ -946,10 +934,10 @@ def search(self):
runner_predict = BootstrapNASRunner(
bootstrapnas_supernetwork=self.bootstrapnas_supernetwork,
supernet=self.supernet,
latency_predictor=self.predictor_dict['latency'],
macs_predictor=self.predictor_dict['macs'],
params_predictor=self.predictor_dict['params'],
acc_predictor=self.predictor_dict['accuracy_top1'],
latency_predictor=self.predictors['latency'],
macs_predictor=self.predictors['macs'],
params_predictor=self.predictors['params'],
acc_predictor=self.predictors['accuracy_top1'],
dataset_path=self.dataset_path,
batch_size=self.batch_size,
device=self.device,
Expand Down
91 changes: 25 additions & 66 deletions dynast/supernetwork/image_classification/ofa/ofa_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import csv
import uuid
from datetime import datetime
from typing import Tuple
from typing import Dict, Tuple, Union

import torch

Expand All @@ -31,12 +31,13 @@
ImagenetRunConfig,
RunManager,
)
from dynast.supernetwork.runner import Runner
from dynast.utils import log
from dynast.utils.datasets import ImageNet
from dynast.utils.nn import get_macs, get_parameters, measure_latency, validate_classification


class OFARunner:
class OFARunner(Runner):
"""The OFARunner class manages the sub-network selection from the OFA super-network and
the validation measurements of the sub-networks. ResNet50, MobileNetV3 w1.0, and MobileNetV3 w1.2
are currently supported. Imagenet is required for these super-networks `imagenet-ilsvrc2012`.
Expand All @@ -45,30 +46,27 @@ class OFARunner:
def __init__(
self,
supernet: str,
dataset_path: str,
acc_predictor: Predictor = None,
macs_predictor: Predictor = None,
latency_predictor: Predictor = None,
params_predictor: Predictor = None,
dataset_path: Union[None, str] = None,
predictors: Dict[str, Predictor] = {},
batch_size: int = 128,
eval_batch_size: int = 128,
dataloader_workers: int = 4,
device: str = 'cpu',
test_fraction: float = 1.0,
verbose: bool = False,
):
self.supernet = supernet
self.acc_predictor = acc_predictor
self.macs_predictor = macs_predictor
self.latency_predictor = latency_predictor
self.params_predictor = params_predictor
self.batch_size = batch_size
self.eval_batch_size = eval_batch_size
self.device = device
self.test_fraction = test_fraction
self.dataset_path = dataset_path
self.dataloader_workers = dataloader_workers
self.verbose = verbose
) -> None:
super().__init__(
supernet=supernet,
dataset_path=dataset_path,
predictors=predictors,
batch_size=batch_size,
eval_batch_size=eval_batch_size,
dataloader_workers=dataloader_workers,
device=device,
test_fraction=test_fraction,
verbose=verbose,
)

ImagenetDataProvider.DEFAULT_PATH = dataset_path

self.ofa_network = ofa_model_zoo.ofa_net(supernet, pretrained=True)
Expand All @@ -90,22 +88,6 @@ def _init_data(self):
self.dataloader = None
log.warning('No dataset path provided. Cannot validate sub-networks.')

def estimate_accuracy_top1(self, subnet_cfg) -> float:
top1 = self.acc_predictor.predict(subnet_cfg)
return top1

def estimate_macs(self, subnet_cfg) -> int:
macs = self.macs_predictor.predict(subnet_cfg)
return macs

def estimate_latency(self, subnet_cfg) -> float:
latency = self.latency_predictor.predict(subnet_cfg)
return latency

def estimate_parameters(self, subnet_cfg) -> int:
parameters = self.params_predictor.predict(subnet_cfg)
return parameters

def validate_top1(self, subnet_cfg, device=None) -> float:
device = self.device if not device else device

Expand Down Expand Up @@ -232,21 +214,10 @@ def eval_subnet(self, x):

# Predictor Mode
if self.predictor_mode == True:
if 'params' in self.optimization_metrics:
individual_results['params'] = self.evaluator.estimate_parameters(
self.manager.onehot_generic(x).reshape(1, -1)
)[0]
if 'latency' in self.optimization_metrics:
individual_results['latency'] = self.evaluator.estimate_latency(
self.manager.onehot_generic(x).reshape(1, -1)
)[0]
if 'macs' in self.optimization_metrics:
individual_results['macs'] = self.evaluator.estimate_macs(
self.manager.onehot_generic(x).reshape(1, -1)
)[0]
if 'accuracy_top1' in self.optimization_metrics:
individual_results['accuracy_top1'] = self.evaluator.estimate_accuracy_top1(
self.manager.onehot_generic(x).reshape(1, -1)
for metric in self.optimization_metrics:
# TODO(macsz) Maybe move [0] to the `estimate_metric`.
individual_results[metric] = self.evaluator.estimate_metric(
metric, self.manager.onehot_generic(x).reshape(1, -1)
)[0]

# Validation Mode
Expand Down Expand Up @@ -329,21 +300,9 @@ def eval_subnet(self, x):

# Predictor Mode
if self.predictor_mode == True:
if 'params' in self.optimization_metrics:
individual_results['params'] = self.evaluator.estimate_parameters(
self.manager.onehot_generic(x).reshape(1, -1)
)[0]
if 'latency' in self.optimization_metrics:
individual_results['latency'] = self.evaluator.estimate_latency(
self.manager.onehot_generic(x).reshape(1, -1)
)[0]
if 'macs' in self.optimization_metrics:
individual_results['macs'] = self.evaluator.estimate_macs(
self.manager.onehot_generic(x).reshape(1, -1)
)[0]
if 'accuracy_top1' in self.optimization_metrics:
individual_results['accuracy_top1'] = self.evaluator.estimate_accuracy_top1(
self.manager.onehot_generic(x).reshape(1, -1)
for metric in self.optimization_metrics:
individual_results[metric] = self.evaluator.estimate_metric(
metric, self.manager.onehot_generic(x).reshape(1, -1)
)[0]

# Validation Mode
Expand Down
Loading
Loading