Add enum for trainer actions
The enum is used for validation checks before triggering one of the trainer actions. Previously I was checking whether the command queue was alive, but that isn't enough: if you request resume while training is already resumed, for example, the queue is operational, yet the action shouldn't be valid.
thodkatz committed Dec 10, 2024
1 parent 5cb34c0 commit e4b53d5
Showing 4 changed files with 61 additions and 34 deletions.
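In short: a requested action must be validated against the current trainer state, not against whether the command queue is alive. A minimal, self-contained sketch of the idea (the names mirror the diff below, but the snippet is illustrative rather than the tiktorch implementation):

```python
from enum import Enum


class TrainerState(Enum):
    IDLE = 0
    RUNNING = 1
    PAUSED = 2


class TrainerAction(Enum):
    START = "start"
    PAUSE = "pause"
    RESUME = "resume"


# Each action is only valid from certain source states. A live queue cannot
# express this: e.g. RESUME while already RUNNING must be rejected even
# though the queue is operational.
VALID_SOURCE_STATES = {
    TrainerAction.START: {TrainerState.IDLE},
    TrainerAction.PAUSE: {TrainerState.RUNNING},
    TrainerAction.RESUME: {TrainerState.PAUSED},
}


def check_action(current: TrainerState, action: TrainerAction) -> None:
    if current not in VALID_SOURCE_STATES[action]:
        raise RuntimeError(f"Invalid action {action} in state {current}")
```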
34 changes: 17 additions & 17 deletions tests/test_server/test_grpc/test_training_servicer.py
@@ -9,14 +9,14 @@
import numpy as np
import pytest

from tiktorch.converters import trainer_state_to_pb
from tiktorch.converters import pb_state_to_trainer, trainer_state_to_pb
from tiktorch.proto import training_pb2, training_pb2_grpc
from tiktorch.server.device_pool import TorchDevicePool
from tiktorch.server.grpc import training_servicer
from tiktorch.server.session.backend.base import TrainerSessionBackend
from tiktorch.server.session.process import TrainerSessionProcess
from tiktorch.server.session_manager import SessionManager
from tiktorch.trainer import Callbacks, ShouldStopCallbacks, Trainer, TrainerState
from tiktorch.trainer import ShouldStopCallbacks, Trainer, TrainerState


@pytest.fixture(scope="module")
@@ -57,8 +57,8 @@ def unet2d_config_path(checkpoint_dir, train_data_dir, val_data_path, device: st
resume: null
validate_after_iters: 2
log_after_iters: 2
max_num_epochs: 1
max_num_iterations: 1
max_num_epochs: 1000
max_num_iterations: 10000
eval_score_higher_is_better: True
optimizer:
learning_rate: 0.0002
@@ -164,7 +164,7 @@ def create_random_dataset(shape, channel_per_class):
l_shape = (2,) + l_shape

f.create_dataset("raw", data=np.random.rand(*shape))
f.create_dataset("label", data=np.random.randint(0, 1, l_shape))
f.create_dataset("label", data=np.random.randint(0, 2, l_shape, dtype=np.int64))
f.create_dataset("weight_map", data=np.random.rand(*w_shape))

return tmp.name
@@ -188,7 +188,7 @@ def assert_state(self, grpc_stub, training_session_id: str, state_to_check: Trai

def poll_for_state_grpc(self, grpc_stub, session_id, expected_state: TrainerState, timeout=3, poll_interval=0.1):
def get_status(*args):
return trainer_state_to_pb[grpc_stub.GetStatus(session_id).state]
return pb_state_to_trainer[grpc_stub.GetStatus(session_id).state]

self.poll_for_state(get_status, expected_state, timeout, poll_interval)

@@ -285,12 +285,12 @@ def test_concurrent_state_transitions(self, grpc_stub):
thread.join()

def test_queueing_multiple_commands(self, grpc_stub):
init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

def assert_state(state_to_check):
self.assert_state(grpc_stub, training_session_id, state_to_check)

init_response = grpc_stub.Init(training_pb2.TrainingConfig(yaml_content=prepare_unet2d_test_environment()))
training_session_id = training_pb2.TrainingSessionId(id=init_response.id)

grpc_stub.Start(training_session_id)
assert_state(TrainerState.RUNNING)

@@ -330,12 +330,12 @@ def test_error_handling_on_invalid_state_transitions_before_training_started(sel
# Attempt to resume before start
with pytest.raises(grpc.RpcError) as excinfo:
grpc_stub.Resume(training_session_id)
assert "Training hasn't started" in excinfo.value.details()
assert "Invalid state transition: TrainerState.IDLE -> TrainerState.RUNNING" in excinfo.value.details()

# Attempt to pause before start
with pytest.raises(grpc.RpcError) as excinfo:
grpc_stub.Pause(training_session_id)
assert "Training hasn't started" in excinfo.value.details()
assert "Invalid state transition: TrainerState.IDLE -> TrainerState.PAUSED" in excinfo.value.details()

def test_start_training_without_init(self, grpc_stub):
"""
@@ -347,20 +347,20 @@ def test_start_training_without_init(self, grpc_stub):
assert "trainer-session with id doesn't exist" in excinfo.value.details()

def test_recover_training_failed(self):
class MockedExceptionTrainer(Trainer):
class MockedExceptionTrainer:
def __init__(self):
self.should_stop_callbacks = Callbacks()
self.should_stop_callbacks = ShouldStopCallbacks()

def fit(self):
raise Exception("mocked exception")

class MockedNominalTrainer(Trainer):
class MockedNominalTrainer:
def __init__(self):
self.num_epochs = 0
self.max_num_epochs = 10
self.num_iterations = 0
self.max_num_iterations = 100
self.should_stop_callbacks = Callbacks()
self.should_stop_callbacks = ShouldStopCallbacks()

def fit(self):
for epoch in range(self.max_num_epochs):
@@ -397,9 +397,9 @@ def assert_error(func, expected_message: str):
func()
assert expected_message in str(excinfo.value)

class MockedExceptionTrainer(Trainer):
class MockedExceptionTrainer:
def __init__(self):
self.should_stop_callbacks = Callbacks()
self.should_stop_callbacks = ShouldStopCallbacks()

def fit(self):
raise Exception("mocked exception")
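A note on the converter swap in the test helper above: GetStatus returns a protobuf state, so polling needs the protobuf-to-TrainerState direction. Presumably the two converters are a forward mapping and its inverse; a sketch under that assumption (the protobuf enum member names here are made up, the real definitions live in tiktorch.converters and training_pb2):

```python
# Hypothetical sketch of the converter pair used by the tests above.
trainer_state_to_pb = {
    TrainerState.IDLE: training_pb2.TrainerState.Idle,        # assumed pb names
    TrainerState.RUNNING: training_pb2.TrainerState.Running,  # assumed pb names
    TrainerState.PAUSED: training_pb2.TrainerState.Paused,    # assumed pb names
}
# The reverse lookup is simply the inverted dict.
pb_state_to_trainer = {pb: state for state, pb in trainer_state_to_pb.items()}
```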
8 changes: 4 additions & 4 deletions tiktorch/server/session/backend/commands.py
@@ -8,7 +8,7 @@
from dataclasses import dataclass, field
from typing import Generic, Type, TypeVar

from tiktorch.trainer import TrainerState
from tiktorch.trainer import TrainerAction, TrainerState

if typing.TYPE_CHECKING:
from tiktorch.server.session.backend.supervisor import BioModelSupervisor, Supervisors, TrainerSupervisor

@@ -110,17 +110,17 @@ def execute(self, ctx: Context[TrainerSupervisor]) -> None:

class SetStartStateTrainingCmd(ICommand):
def execute(self, ctx: Context[TrainerSupervisor]) -> None:
ctx.session.transition_to_state(new_state=TrainerState.RUNNING, valid_states={TrainerState.IDLE})
ctx.session.transition_to_state(new_state=TrainerState.RUNNING, trainer_action=TrainerAction.START)


class SetPauseStateTrainingCmd(ICommand):
def execute(self, ctx: Context[TrainerSupervisor]) -> None:
ctx.session.transition_to_state(new_state=TrainerState.PAUSED, valid_states={TrainerState.RUNNING})
ctx.session.transition_to_state(new_state=TrainerState.PAUSED, trainer_action=TrainerAction.PAUSE)



class SetResumeStateTrainingCmd(ICommand):
def execute(self, ctx: Context[TrainerSupervisor]) -> None:
ctx.session.transition_to_state(new_state=TrainerState.RUNNING, valid_states={TrainerState.PAUSED})
ctx.session.transition_to_state(new_state=TrainerState.RUNNING, trainer_action=TrainerAction.RESUME)



class ShutdownCmd(ICommand):
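Design note: the three Set*StateTrainingCmd classes differ only in the (new_state, trainer_action) pair they forward. A sketch of a single parametrized command, assuming ICommand's constructor takes no required arguments and the imports already present in commands.py (illustrative, not part of this commit):

```python
class SetStateTrainingCmd(ICommand):
    """Hypothetical single command covering start/pause/resume."""

    def __init__(self, new_state: TrainerState, trainer_action: TrainerAction):
        super().__init__()  # assumes ICommand.__init__ needs no arguments
        self._new_state = new_state
        self._trainer_action = trainer_action

    def execute(self, ctx: Context[TrainerSupervisor]) -> None:
        ctx.session.transition_to_state(
            new_state=self._new_state, trainer_action=self._trainer_action
        )


# Usage mirroring the existing commands:
start_cmd = SetStateTrainingCmd(TrainerState.RUNNING, TrainerAction.START)
pause_cmd = SetStateTrainingCmd(TrainerState.PAUSED, TrainerAction.PAUSE)
```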
32 changes: 21 additions & 11 deletions tiktorch/server/session/backend/supervisor.py
@@ -6,7 +6,7 @@

from tiktorch.server.session.backend import commands
from tiktorch.server.session.backend.commands import CommandPriorityQueueUtils, ShutdownWithTeardownCmd
from tiktorch.trainer import BaseCallbacks, ErrorCallbacks, Trainer, TrainerState
from tiktorch.trainer import BaseCallbacks, ErrorCallbacks, Trainer, TrainerAction, TrainerState

logger = logging.getLogger(__name__)

@@ -57,10 +57,7 @@ def get_state(self) -> TrainerState:
return self._state

def start(self):
if self._state != TrainerState.IDLE:
raise StateTransitionError(
current_state=self._state, transitioning_state=TrainerState.RUNNING, valid_states={TrainerState.IDLE}
)
self._check_transition_to_start()
self._session_thread.start()
self._pause_triggered = False
start_cmd = commands.SetStartStateTrainingCmd()
@@ -98,23 +95,23 @@ def is_training_finished(self):
return (
self._trainer.num_epochs == self._trainer.max_num_epochs
or self._trainer.num_iterations == self._trainer.max_num_iterations
)
) or self._trainer.should_stop_model_criteria()

def _get_num_iterations_epochs(self) -> str:
iterations = f"Iterations[{self._trainer.num_iterations}/{self._trainer.max_num_iterations}]"
epochs = f"Epochs[{self._trainer.num_epochs}/{self._trainer.max_num_epochs}]"
return f"{iterations}, {epochs}"

@requires_queue_alive
def resume(self):
self._check_transition_to_resume()
self._pause_triggered = False
resume_cmd = commands.SetResumeStateTrainingCmd()
self._command_queue_utils.send_command(resume_cmd.awaitable)
resume_cmd.awaitable.wait() # make sure that the state has actually changed (acknowledge)
logger.info(f"Resume training: {self._get_num_iterations_epochs()}")


@requires_queue_alive
def pause(self):
self._check_transition_to_pause()
self._pause_triggered = True
pause_cmd = commands.SetPauseStateTrainingCmd()
self._command_queue_utils.send_command(pause_cmd.awaitable)
@@ -128,7 +125,6 @@ def shutdown(self):
self._command_queue_utils.send_command(commands.ShutdownCmd())
self._session_thread.join()

@requires_queue_alive
def forward(self, input_tensors):
self.pause()
self._trainer.forward(input_tensors)
@@ -143,14 +139,28 @@ def export(self):
def _should_stop(self):
return self._pause_triggered


def transition_to_state(self, new_state: TrainerState, valid_states: Set[TrainerState]):
def transition_to_state(self, new_state: TrainerState, trainer_action: TrainerAction):
"""
Should only be invoked via the ICommand implementations, so that state changes go through the command queue
"""
self._check_transition_to_state(new_state, valid_states)
if trainer_action == TrainerAction.START:
self._check_transition_to_start()
elif trainer_action == TrainerAction.PAUSE:
self._check_transition_to_pause()
elif trainer_action == TrainerAction.RESUME:
self._check_transition_to_resume()

logger.info(f"State transition: {self._state} -> {new_state}")
self._state = new_state

def _check_transition_to_start(self):
return self._check_transition_to_state(TrainerState.RUNNING, {TrainerState.IDLE})

def _check_transition_to_pause(self):
return self._check_transition_to_state(TrainerState.PAUSED, {TrainerState.RUNNING})

def _check_transition_to_resume(self):
return self._check_transition_to_state(TrainerState.RUNNING, {TrainerState.PAUSED})


def _check_transition_to_state(self, new_state: TrainerState, valid_states: Set[TrainerState]):
if self._state not in valid_states:
raise StateTransitionError(
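The if/elif dispatch in transition_to_state plus the three _check_transition_to_* helpers encode a small transition table: each action has a target state and a set of states it may be triggered from. The same rules written as data, as a sketch (the commit keeps the explicit helpers, which start(), pause(), and resume() also call directly):

```python
# (target state, valid source states) per action; the values are taken from
# the checks added in this commit, but the table itself is illustrative.
TRANSITIONS = {
    TrainerAction.START: (TrainerState.RUNNING, {TrainerState.IDLE}),
    TrainerAction.PAUSE: (TrainerState.PAUSED, {TrainerState.RUNNING}),
    TrainerAction.RESUME: (TrainerState.RUNNING, {TrainerState.PAUSED}),
}


def transition_to_state(self, new_state: TrainerState, trainer_action: TrainerAction):
    _, valid_states = TRANSITIONS[trainer_action]
    self._check_transition_to_state(new_state, valid_states)
    logger.info(f"State transition: {self._state} -> {new_state}")
    self._state = new_state
```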
21 changes: 19 additions & 2 deletions tiktorch/trainer.py
@@ -77,6 +77,13 @@ def __str__(self):
LogsCallbacks = Callbacks[Callable[[Logs], None]]


class TrainerAction(Enum):
START = "start"
PAUSE = "pause"
RESUME = "resume"
SHUTDOWN = "shutdown"


class TrainerState(Enum):
IDLE = 0
RUNNING = 1
@@ -148,8 +155,18 @@ def forward(self, input_tensors):
with torch.no_grad():
self.model(input_tensors)


def should_stop(self):
return self.should_stop_callbacks() or super().should_stop()
def should_stop(self) -> bool:
"""
Combine externally registered stop requests with the model's own stopping criteria.
"""
return self.should_stop_callbacks() or self.should_stop_model_criteria()


def should_stop_model_criteria(self) -> bool:
"""
Retain the stopping logic of the underlying model trainer,
e.g. stop once the learning rate falls below a threshold.
"""
return super().should_stop()

def _log_stats(self, phase, loss_avg, eval_score_avg):
logs = Logs(

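The should_stop split above gives the training loop two independent stop sources: externally registered callbacks (this is how the supervisor's pause request reaches the loop) and the stopping criteria of the underlying model trainer. A self-contained sketch of the pattern (class and attribute names are illustrative, not the tiktorch implementation):

```python
from typing import Callable, List


class StopSignal:
    """Stand-in for ShouldStopCallbacks: stop if any callback says so."""

    def __init__(self) -> None:
        self._callbacks: List[Callable[[], bool]] = []

    def register(self, cb: Callable[[], bool]) -> None:
        self._callbacks.append(cb)

    def __call__(self) -> bool:
        return any(cb() for cb in self._callbacks)


class SketchTrainer:
    def __init__(self) -> None:
        self.should_stop_callbacks = StopSignal()
        self.learning_rate = 1e-4
        self.min_learning_rate = 1e-6

    def should_stop_model_criteria(self) -> bool:
        # e.g. the wrapped trainer stops once the learning rate decays too far
        return self.learning_rate < self.min_learning_rate

    def should_stop(self) -> bool:
        # external requests (e.g. pause) act alongside the model's criteria
        return self.should_stop_callbacks() or self.should_stop_model_criteria()


# e.g. a supervisor wires its pause flag in as a callback:
pause_triggered = False
trainer = SketchTrainer()
trainer.should_stop_callbacks.register(lambda: pause_triggered)
```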
