This repository has been archived by the owner on Oct 16, 2023. It is now read-only.

Commit

Deprecate optimize model function
mejai1206 committed Oct 4, 2023
1 parent f9a3112 commit 7ae5580
Showing 1 changed file with 0 additions and 30 deletions.
30 changes: 0 additions & 30 deletions trident/util/util.py
@@ -86,36 +86,6 @@ def size_and_stride(input: torch.Tensor, dim: int):
         raise ValueError(f"{dim} is not supported.")
 
 
-def optimize_module(mod):
-    opt_mod = None
-
-    if isinstance(mod, torch.nn.Dropout):
-        opt_mod = module.Dropout(mod.p)
-    elif isinstance(mod, torch.nn.GroupNorm):
-        opt_mod = module.GroupNorm(mod.num_groups, mod.num_channels, mod.eps, mod.affine)
-    elif isinstance(mod, torch.nn.InstanceNorm1d):
-        opt_mod = module.InstanceNorm1d(mod.num_features, mod.eps, mod.momentum, mod.affine, mod.track_running_stats)
-    elif isinstance(mod, torch.nn.InstanceNorm2d):
-        opt_mod = module.InstanceNorm2d(mod.num_features, mod.eps, mod.momentum, mod.affine, mod.track_running_stats)
-    elif isinstance(mod, torch.nn.LayerNorm):
-        opt_mod = module.LayerNorm(mod.normalized_shape, mod.eps, mod.elementwise_affine)
-    elif isinstance(mod, torch.nn.Softmax):
-        opt_mod = module.Softmax(mod.dim)
-
-    if opt_mod is not None:
-        opt_mod.load_state_dict(mod.state_dict())
-
-    return opt_mod
-
-
-def optimize_model(model):
-    for name, child in model.named_children():
-        if other := optimize_module(child):
-            setattr(model, name, other)
-
-        optimize_model(child)
-
-
 def push_trace(message: str):
     if config.use_trace:
         nvtx.push_range(message, color="green", domain="Trident")
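
For reference, a minimal usage sketch of the helper removed above, written against the deleted code. It assumes optimize_model was importable from trident.util.util and that trident.module exposes the replacement layers named in the deleted branches; nothing in the sketch is taken from the commit itself.

import torch.nn as nn

from trident.util.util import optimize_model

# A small model mixing layers the removed helper knew how to swap
# (LayerNorm, Dropout, Softmax) with one it did not (Linear).
model = nn.Sequential(
    nn.Linear(64, 64),
    nn.LayerNorm(64),
    nn.Dropout(p=0.1),
    nn.Softmax(dim=-1),
).cuda()  # Trident's Triton kernels target CUDA devices.

# optimize_model walked the children recursively, replaced each supported
# torch.nn layer in place with its Trident counterpart, and copied the
# parameters over via load_state_dict; unsupported layers were left as-is.
optimize_model(model)

With this commit the conversion path is gone, so the presumable way forward is to construct the Trident modules directly (for example module.LayerNorm) rather than converting an existing torch.nn model.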

0 comments on commit 7ae5580
