Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle resume on RTMOModeSwitchHook and YOLOXPoseModeSwitchHook #3045

Open
wants to merge 3 commits into
base: dev-1.x
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion mmpose/engine/hooks/mode_switch_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
from typing import Dict, Sequence

import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner
Expand Down Expand Up @@ -38,6 +39,7 @@ def __init__(self,
self.num_last_epochs = num_last_epochs
self.new_train_dataset = new_train_dataset
self.new_train_pipeline = new_train_pipeline
self.switched = False

def _modify_dataloader(self, runner: Runner):
"""Modify dataloader with new dataset and pipeline configurations."""
Expand All @@ -60,10 +62,12 @@ def before_train_epoch(self, runner: Runner):
if is_model_wrapper(model):
model = model.module

if epoch + 1 == runner.max_epochs - self.num_last_epochs:
if self.switched is False and (
epoch + 1 >= runner.max_epochs - self.num_last_epochs):
self._modify_dataloader(runner)
runner.logger.info('Added additional reg loss now!')
model.head.use_aux_loss = True
self.switched = True


@HOOKS.register_module()
Expand All @@ -89,6 +93,21 @@ class RTMOModeSwitchHook(Hook):

def __init__(self, epoch_attributes: Dict[int, Dict]):
    """Initialize the hook with its per-epoch attribute schedule.

    Args:
        epoch_attributes (Dict[int, Dict]): Mapping from epoch number to
            the attribute overrides that should be applied to ``model.head``
            when that epoch is reached.
    """
    # No resume has been processed yet for this hook instance.
    self.handled_resume = False
    self.epoch_attributes = epoch_attributes

def handle_resume(self, runner: Runner, model: nn.Module,
                  resumed_epoch: int):
    """Replay all attribute switches scheduled before the resumed epoch.

    When training is resumed mid-run, the switches this hook applied in
    earlier epochs were lost with the previous process, so they are
    re-applied here in chronological order.

    Args:
        runner (Runner): The runner driving training (used for logging).
        model (nn.Module): The unwrapped model whose ``head`` attributes
            are updated.
        resumed_epoch (int): The epoch training resumes from; entries
            scheduled at or after it are left for ``before_train_epoch``.
    """
    # Sort the scheduled epochs so the early ``break`` is correct even
    # when ``epoch_attributes`` keys were not inserted in ascending
    # order (plain dicts only preserve insertion order).
    for epoch in sorted(self.epoch_attributes):
        if epoch >= resumed_epoch:
            break

        for key, value in self.epoch_attributes[epoch].items():
            rsetattr(model.head, key, value)
            runner.logger.info(
                f'Change model.head.{key} to {rgetattr(model.head, key)}')
    # Mark resume handling done so it only runs once per training run.
    self.handled_resume = True

def before_train_epoch(self, runner: Runner):
"""Method called before each training epoch.
Expand Down
Loading