fix: smp will not be imported if not specified by user (#651)
* fix: smp will not be imported if not specified by user

* fix: add pipeline_parallel_degree for smp after v1.60

* remove tf related model parallel vars

* version bump
yl-to authored Mar 13, 2023
1 parent 6cb0d55 commit 1af0e2f
Showing 3 changed files with 13 additions and 24 deletions.
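In short: importing smdebug no longer pulls in `smdistributed.modelparallel` unless the training job actually requested model parallelism. A rough way to verify the fixed behavior — a sketch, assuming smdebug is installed and no model-parallel `SM_HPS` configuration is present in the environment:

```python
import sys

import smdebug.core.utils  # noqa: F401  # this import used to probe smp eagerly

# Without mp_parameters in SM_HPS, the guarded import never runs, so no
# smdistributed module should have been pulled into the process.
assert not any(name.startswith("smdistributed") for name in sys.modules)
```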
2 changes: 1 addition & 1 deletion smdebug/_version.py
@@ -1 +1 @@
__version__ = "1.0.29"
__version__ = "1.0.30"
18 changes: 12 additions & 6 deletions smdebug/core/utils.py
@@ -35,6 +35,7 @@
     SMDebugRuntimeError,
     SMDebugTypeError,
     SMDebugValueError,
+    SMDebugError
 )


@@ -49,18 +50,18 @@ class FRAMEWORK(Enum):
 _smddp_tf_imported = None
 _smddp_pt_imported = None
 _is_using_smmodelparallel = None
+_smp_imported = None
 
-try:
-    import smdistributed.modelparallel.tensorflow as smp
-
-    _smp_imported = smp
-except (ImportError, ModuleNotFoundError):
+if check_smmodelparallel_training():
     try:
         import smdistributed.modelparallel.torch as smp
 
         _smp_imported = smp
     except (ImportError, ModuleNotFoundError):
         _smp_imported = None
+    except Exception as e:
+        raise SMDebugError(e)
 
 
 try:
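The extra `except Exception` arm changes the failure semantics deliberately: a missing smp package is still tolerated (smdebug runs without it), but any other error raised while importing smp now surfaces as an `SMDebugError` instead of being silently swallowed. A minimal sketch of the pattern in isolation, assuming `SMDebugError` is importable from `smdebug.exceptions` as in `utils.py`:

```python
from smdebug.exceptions import SMDebugError


def try_import_smp():
    try:
        import smdistributed.modelparallel.torch as smp

        return smp
    except (ImportError, ModuleNotFoundError):
        return None  # smp simply isn't installed: fall back to non-smp behavior
    except Exception as e:
        raise SMDebugError(e)  # smp is present but broken: fail loudly
```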
@@ -644,8 +645,13 @@ def check_smmodelparallel_training():
     else:
         try:
             smp_flag = json.loads(os.getenv("SM_HPS"))
-            if "mp_parameters" in smp_flag and "partitions" in smp_flag["mp_parameters"]:
-                _is_using_smmodelparallel = True
+            if "mp_parameters" in smp_flag:
+                if "pipeline_parallel_degree" in smp_flag["mp_parameters"]:
+                    _is_using_smmodelparallel = True
+                elif "partitions" in smp_flag["mp_parameters"]:
+                    _is_using_smmodelparallel = True
+                else:
+                    _is_using_smmodelparallel = False
             else:
                 _is_using_smmodelparallel = False
         except:
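For context, `SM_HPS` is the JSON hyperparameter blob SageMaker injects into the training container; per the commit message, smp v1.60+ jobs carry `pipeline_parallel_degree` under `mp_parameters` while older jobs carry `partitions`, and the new branching accepts either. A hypothetical payload and the resulting decision:

```python
import json
import os

# Hypothetical SM_HPS as SageMaker might set it for an smp >= 1.60 job.
os.environ["SM_HPS"] = json.dumps({"mp_parameters": {"pipeline_parallel_degree": 2}})

smp_flag = json.loads(os.getenv("SM_HPS"))
mp_params = smp_flag.get("mp_parameters", {})
print("pipeline_parallel_degree" in mp_params or "partitions" in mp_params)
# True; a legacy {"mp_parameters": {"partitions": 2}} payload also passes
```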
17 changes: 0 additions & 17 deletions smdebug/tensorflow/base_hook.py
@@ -40,13 +40,6 @@
     load_tf_config_json,
 )
 
-try:
-    import smdistributed.modelparallel.tensorflow as smp  # noqa isort:skip
-
-    _smp_imported = smp
-except ImportError:
-    _smp_imported = None
-
 
 DEFAULT_INCLUDE_COLLECTIONS = [
     CollectionKeys.METRICS,
@@ -195,11 +188,6 @@ def _get_worker_name(self) -> str:
"""
self._assert_distribution_strategy()
if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
if _smp_imported and _smp_imported.core.initialized:
# when model parallel is being used, there will be multiple processes
# with same hvd rank, hence use smp.rank
return f"worker_{smp.rank()}"

import horovod.tensorflow as hvd

return f"worker_{hvd.rank()}"
@@ -277,11 +265,6 @@ def _get_custom_and_default_collections(self) -> Tuple[Set["Collection"], Set["Collection"]]:
     def _get_num_workers(self):
         self._assert_distribution_strategy()
         if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
-            if _smp_imported and smp.core.initialized:
-                # when model parallel is being used, there will be multiple hvd process groups,
-                # hence use smp.size
-                return smp.size()
-
             import horovod.tensorflow as hvd
 
             return hvd.size()
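With the TensorFlow model-parallel branches removed, worker identity under Horovod comes from hvd alone. A self-contained sketch of the surviving logic, assuming a working `horovod.tensorflow` installation:

```python
import horovod.tensorflow as hvd

hvd.init()  # required before rank()/size()

worker_name = f"worker_{hvd.rank()}"  # one unique rank per process again
num_workers = hvd.size()
```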
