-
Notifications
You must be signed in to change notification settings - Fork 83
/
rule.py
68 lines (56 loc) · 2.31 KB
/
rule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Standard Library
from abc import ABC, abstractmethod
# First Party
from smdebug.analysis.utils import no_refresh
from smdebug.core.logger import get_logger
from smdebug.exceptions import RuleEvaluationConditionMet
from smdebug.rules.action import Actions
# Local
from .req_tensors import RequiredTensors
class Rule(ABC):
    """Abstract base class (the Rule interface) for rules evaluated against one or more trials.

    Subclasses must implement :meth:`invoke_at_step` and may override
    :meth:`set_required_tensors`. :meth:`invoke` prepares ``self.req_tensors``
    for a step and then calls :meth:`invoke_at_step`.
    """

    def __init__(self, base_trial, action_str, other_trials=None):
        """Set up the trials, required-tensor tracker, logger and actions.

        Args:
            base_trial: The primary trial. :meth:`invoke` calls
                ``base_trial.wait_for_steps`` on it before evaluating.
            action_str: Action specification forwarded to ``Actions`` together
                with this rule's class name. Those actions are invoked when
                :meth:`invoke_at_step` reports that the condition was met.
            other_trials: Optional iterable of additional trials. It is stored
                unchanged on ``self.other_trials``, passed to
                ``RequiredTensors``, and appended to ``self.trials``.
        """
        self.base_trial = base_trial
        self.other_trials = other_trials
        # Every trial this rule reads: the base trial first, then any others
        # in the order given. Used to suspend refreshing during evaluation.
        self.trials = [base_trial]
        if self.other_trials is not None:
            self.trials.extend(self.other_trials)
        self.req_tensors = RequiredTensors(self.base_trial, self.other_trials)
        self.logger = get_logger()
        # Name of the concrete subclass, used in logs, actions and the raised exception.
        self.rule_name = self.__class__.__name__
        self._actions = Actions(action_str, rule_name=self.rule_name)
        # Evaluation summary. Not read or updated by this base class.
        # NOTE(review): presumably filled in by subclasses; confirm.
        self.report = {
            "RuleTriggered": 0,
            "Violations": 0,
            "Details": {},
            "Datapoints": 0,
            "RuleParameters": "",
        }

    def set_required_tensors(self, step):
        """Hook for declaring the tensors needed at global step ``step``.

        :meth:`invoke` calls this after clearing ``self.req_tensors`` and
        before fetching it, so overrides can populate ``self.req_tensors``
        for ``step``. The base implementation does nothing and returns None.
        """

    @abstractmethod
    def invoke_at_step(self, step):
        """Evaluate the rule at global step ``step``. Subclasses must implement this.

        :meth:`invoke` calls it inside ``no_refresh(self.trials)`` after the
        required tensors have been fetched. A truthy return value means the
        rule's condition was met. :meth:`invoke` then runs the configured
        actions and raises ``RuleEvaluationConditionMet``.

        NOTE(review): an earlier comment suggested long evaluations checkpoint
        their progress through a ``storage_handler``. No such handler is
        provided by this class.
        """

    def invoke(self, step):
        """Evaluate the rule at global step ``step``.

        Calls ``base_trial.wait_for_steps([step])``, then rebuilds and fetches
        the required tensors and calls :meth:`invoke_at_step`.

        Returns:
            None, when :meth:`invoke_at_step` returns a falsy value.

        Raises:
            RuleEvaluationConditionMet: When :meth:`invoke_at_step` returns a
                truthy value. The configured actions are invoked first.
        """
        # Lazy %-style arguments: assuming get_logger() returns a stdlib
        # logging.Logger, the message is only formatted when DEBUG is enabled.
        self.logger.debug("Invoking rule %s for step %s", self.rule_name, step)
        self.base_trial.wait_for_steps([step])
        # Do not refresh during evaluation. We have already waited for `step`,
        # so this keeps the step numbers seen by the required tensors the same
        # as those seen by invoke_at_step.
        with no_refresh(self.trials):
            self.req_tensors.clear()
            self.set_required_tensors(step)
            self.req_tensors.fetch()
            condition_met = self.invoke_at_step(step)
        if condition_met:
            self._actions.invoke()
            raise RuleEvaluationConditionMet(self.rule_name, step)