
Cherry pick PR #7003 (#7441)

Merged (11 commits) on Sep 18, 2023
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/conf/megatron_gpt_config.yaml
@@ -1,3 +1,6 @@
defaults:
  - optional tp_overlap@model.ub_tp_comm_overlap_cfg:

name: megatron_gpt
restore_from_path: null # used when starting from a .nemo file

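The optional entry above lets Hydra merge one of the new tp_overlap configs (the four new files below; their paths are collapsed in this view) into model.ub_tp_comm_overlap_cfg, while leaving the key empty when no option is selected. A minimal sketch of the effect using plain OmegaConf, not part of the PR; the file name and path are placeholders rather than the actual names added here:

from omegaconf import OmegaConf

# Without a selected tp_overlap option, the key stays empty.
base = OmegaConf.create({"model": {"ub_tp_comm_overlap_cfg": None}})

# Selecting an option merges the chosen YAML under model.ub_tp_comm_overlap_cfg.
tp_overlap = OmegaConf.load("conf/tp_overlap/ub_cfg_a100_mbs1.yaml")  # placeholder path
cfg = OmegaConf.merge(base, {"model": {"ub_tp_comm_overlap_cfg": tp_overlap}})
print(cfg.model.ub_tp_comm_overlap_cfg.qkv_dgrad.method)  # -> bulk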
@@ -0,0 +1,53 @@
# UB communicator configurations
# Model configs: A100/175B/TP4/MBS1/SeqLen2K/BF16

# Bulk overlap with AllGather
qkv_dgrad:
  method: bulk
  num_sm: 2
  set_sm_margin: 0

qkv_wgrad:
  method: bulk
  num_sm: 2
  set_sm_margin: 0

fc1_dgrad:
  method: bulk
  num_sm: 2
  set_sm_margin: 0

fc1_wgrad:
  method: bulk
  num_sm: 2
  set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
  method: ring_exchange
  aggregate: 0

proj_dgrad:
  method: ring_exchange
  aggregate: 0

fc1_fprop:
  method: ring_exchange
  aggregate: 0

fc2_dgrad:
  method: ring_exchange
  aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
  method: pipeline
  num_sm: 4
  num_splits: 4
  set_sm_margin: 0

fc2_fprop:
  method: pipeline
  num_sm: 4
  num_splits: 4
  set_sm_margin: 0
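Each top-level key in these files names one GEMM in the transformer layer (QKV projection, attention output projection, MLP fc1/fc2, across forward, dgrad, and wgrad passes), and its value selects an overlap method (bulk, ring_exchange, or pipeline) plus tuning knobs such as the number of SMs given to the communication kernel. A small sketch, not part of the PR, of reading one of these files and listing the per-GEMM settings; the local file name is a placeholder:

import yaml

with open("ub_cfg_a100_mbs1.yaml") as f:  # placeholder name for the file above
    ub_cfgs = yaml.safe_load(f)

for gemm, knobs in ub_cfgs.items():
    print(gemm, knobs["method"], knobs.get("num_sm", "-"))
# e.g. qkv_dgrad bulk 2
#      proj_fprop pipeline 4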
@@ -0,0 +1,53 @@
# UB communicator configurations
# Model configs: A100/175B/TP4/MBS2/SeqLen2K/BF16

# Bulk overlap with AllGather
qkv_dgrad:
  method: bulk
  num_sm: 2
  set_sm_margin: 0

qkv_wgrad:
  method: bulk
  num_sm: 2
  set_sm_margin: 0

fc1_dgrad:
  method: bulk
  num_sm: 2
  set_sm_margin: 0

fc1_wgrad:
  method: bulk
  num_sm: 2
  set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
  method: ring_exchange
  aggregate: 0

proj_dgrad:
  method: ring_exchange
  aggregate: 0

fc1_fprop:
  method: ring_exchange
  aggregate: 0

fc2_dgrad:
  method: ring_exchange
  aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
  method: pipeline
  num_sm: 8
  num_splits: 4
  set_sm_margin: 0

fc2_fprop:
  method: pipeline
  num_sm: 4
  num_splits: 4
  set_sm_margin: 0
@@ -0,0 +1,59 @@
# UB communicator configurations
# Model configs: H100/175B/TP4/MBS1/SeqLen2K/FP8

# Bulk overlap with AllGather / ReduceScatter
qkv_dgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

qkv_wgrad:
  method: bulk
  num_sm: 8
  cga_size: 2
  set_sm_margin: 0

fc1_dgrad:
  method: bulk
  num_sm: 2
  cga_size: 2
  set_sm_margin: 0

fc1_wgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
  method: ring_exchange
  aggregate: 0

proj_dgrad:
  method: ring_exchange
  aggregate: 0

fc1_fprop:
  method: ring_exchange
  aggregate: 0

fc2_dgrad:
  method: ring_exchange
  aggregate: 1

# Chunked-collective overlap with ReduceScatter
proj_fprop:
  method: pipeline
  num_sm: 24
  cga_size: 2
  num_splits: 4
  set_sm_margin: 1

fc2_fprop:
  method: pipeline
  num_sm: 20
  cga_size: 2
  num_splits: 4
  set_sm_margin: 1
@@ -0,0 +1,59 @@
# UB communicator configurations
# Model configs: H100/175B/TP8/MBS2/SeqLen2K/FP8

# Bulk overlap with AllGather
qkv_dgrad:
  method: bulk
  num_sm: 8
  cga_size: 2
  set_sm_margin: 0

qkv_wgrad:
  method: bulk
  num_sm: 16
  cga_size: 2
  set_sm_margin: 0

fc1_dgrad:
  method: bulk
  num_sm: 4
  cga_size: 2
  set_sm_margin: 0

fc1_wgrad:
  method: bulk
  num_sm: 16
  cga_size: 2
  set_sm_margin: 0

## Ring-exchange overlap with AllGather
qkv_fprop:
  method: ring_exchange
  aggregate: 0

proj_dgrad:
  method: ring_exchange
  aggregate: 1

fc1_fprop:
  method: ring_exchange
  aggregate: 0

fc2_dgrad:
  method: ring_exchange
  aggregate: 0

# Chunked-collective overlap with ReduceScatter
proj_fprop:
  method: pipeline
  num_sm: 16
  cga_size: 2
  num_splits: 4
  set_sm_margin: 1

fc2_fprop:
  method: pipeline
  num_sm: 24
  cga_size: 2
  num_splits: 4
  set_sm_margin: 1
@@ -511,20 +511,16 @@ def fwd_bwd_step(self, dataloader_iter, batch_idx, forward_only):
         return loss_mean
 
     def initialize_ub_func(self):
+        ub_cfgs = self.cfg.get('ub_tp_comm_overlap_cfg', None)
+        if ub_cfgs is None:
+            warnings.warn(
+                "Couldn't find TP config. Please check the path correctness. Initializing TP comm overlap with the default config."
+            )
+
         input_shape = [
             self.cfg.get('encoder_seq_length') * self.cfg.get('micro_batch_size'),
             self.cfg.get('hidden_size'),
         ]
-        ub_cfg_file_name = self.cfg.get('ub_tp_comm_overlap_cfg', None)
-        ub_cfgs = None
-        if ub_cfg_file_name is not None:
-            try:
-                import yaml
-
-                with open(ub_cfg_file_name, 'r') as ub_cfg_file:
-                    ub_cfgs = yaml.safe_load(ub_cfg_file)
-            except (ImportError, TypeError):
-                logging.error(f"Fail to read ub_tp_comm_overlap config file: {ub_cfg_file_name}.")
         te_module.base.initialize_ub(
             shape=input_shape,
             tp_size=self.cfg.get('tensor_model_parallel_size'),
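The initialize_ub call above is truncated in this view. A hedged sketch of how the per-GEMM settings presumably reach Transformer Engine's userbuffers initialization; the import path, the use_fp8 and ub_cfgs keyword arguments, and the concrete sizes are assumptions for illustration, not copied from the PR:

import yaml
from transformer_engine.pytorch import module as te_module  # assumed import for te_module

seq_len, micro_batch_size, hidden_size = 2048, 1, 12288  # illustrative 175B-style sizes
with open("ub_cfg_a100_mbs1.yaml") as f:                  # placeholder config file
    ub_cfgs = yaml.safe_load(f)

te_module.base.initialize_ub(
    shape=[seq_len * micro_batch_size, hidden_size],
    tp_size=4,        # tensor_model_parallel_size
    use_fp8=False,    # assumption: FP8 flag forwarded from the model config
    ub_cfgs=ub_cfgs,  # per-GEMM overlap settings from the YAML files above
)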
18 changes: 18 additions & 0 deletions nemo/core/config/hydra_runner.py
@@ -23,6 +23,24 @@
from hydra.types import TaskFunction
from omegaconf import DictConfig, OmegaConf


def _get_gpu_name():
    import pynvml

    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
    cuda_capability, _ = pynvml.nvmlDeviceGetCudaComputeCapability(handle)
    pynvml.nvmlShutdown()
    if cuda_capability == 8:
        return "a100"
    elif cuda_capability == 9:
        return "h100"
    else:
        return None


OmegaConf.register_new_resolver("gpu_name", _get_gpu_name)

# multiple interpolated values in the config
OmegaConf.register_new_resolver("multiply", lambda x, y: x * y)

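The gpu_name resolver registered above is presumably meant to let configs pick a hardware-specific option through interpolation. A brief usage sketch, not from the PR; the resolver is stubbed out so it runs without pynvml or a GPU, and the config key and file name are placeholders:

from omegaconf import OmegaConf

# Stand-in for _get_gpu_name(); the real resolver queries the device via pynvml.
OmegaConf.register_new_resolver("gpu_name", lambda: "a100", replace=True)

cfg = OmegaConf.create({"tp_overlap_file": "ub_cfg_${gpu_name:}_mbs1"})  # placeholder key/name
print(cfg.tp_overlap_file)  # -> ub_cfg_a100_mbs1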