From 9852250eab1823420e93767bae3fa222acaa37d9 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 13 Sep 2024 15:01:38 +0800 Subject: [PATCH] test loss --- tests/test_training/test_loss.py | 36 ++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/tests/test_training/test_loss.py b/tests/test_training/test_loss.py index f928e2b6..432684d9 100644 --- a/tests/test_training/test_loss.py +++ b/tests/test_training/test_loss.py @@ -25,8 +25,8 @@ from internlm.utils.gputest import empty_cache_and_diag from internlm.utils.megatron_timers import megatron_timer as timer -CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_sft.py") -INTERNLM1_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss/model_ckpt") +CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_internlm2.py") +INTERNLM2_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss_pri/model_ckpt") TOTAL_STEPS = 10 LOSS_SPIKE_LIMIT = 1.5 LOSS_DEVIATION_LIMIT = 0.02 @@ -59,32 +59,34 @@ def train( enable_sp: bool = False, save_ckpt: bool = False, load_ckpt: bool = False, - model_type: str = "INTERNLM", optimizer_ver: str = "v1", zero_bubble: bool = False, + model_type: str = None, ): # initialize distributed environment config = Config.from_file(CONFIG_FILE_PATH) # init setting - config.data.total_steps = TOTAL_STEPS + config.data.total_steps = 50000 config.data.fixed_random_dataset_seqlen = False - config.lr_scheduler.total_steps = TOTAL_STEPS - config.model_type = model_type + config.lr_scheduler.total_steps = config.data.total_steps config.ckpt.load_ckpt_folder = None config.ckpt.load_ckpt_info = None config.ckpt.auto_resume = False - total_steps = config.data.total_steps + total_steps = TOTAL_STEPS skip_batches = config.data.skip_batches label_smoothing = config.loss.label_smoothing + if not model_type: + model_type = config.model_type + if optimizer_ver == "v2": config.hybrid_zero_optimizer.use_split_tensor_optim = True config.all_gather_size = 512 * 1024 * 1024 # update ckpt config - if model_type == "INTERNLM" and tp_mode != "isp" and interleaved is False: - config.ckpt.load_ckpt_info = dict(path=INTERNLM1_CKPT_PATH, content=("model",), ckpt_type="internlm_test") + if model_type == "INTERNLM2_PUBLIC" and tp_mode != "isp" and interleaved is False: + config.ckpt.load_ckpt_info = dict(path=INTERNLM2_CKPT_PATH, content=("model",), ckpt_type="internlm2") if save_ckpt: config.ckpt.enable_save_ckpt = True @@ -97,6 +99,7 @@ def train( # update parallel config config.parallel.tensor = dict(size=tp_size, mode=tp_mode) + config.parallel.zero1 = dict(size=-1) if zero_bubble: config.hybrid_zero_optimizer.overlap_sync_grad = False config.parallel.pipeline = dict(size=pp_size, zero_bubble=True) @@ -109,16 +112,17 @@ def train( if "use_packed_dataset" not in config.data: config.data.use_packed_dataset = True - if tp_mode == "isp" and internlm_accelerator.get_accelerator_backend() in [ - AcceleratorType.NPU, - AcceleratorType.DIPU, - ]: - config.data.use_packed_dataset = False + # if tp_mode == "isp" and internlm_accelerator.get_accelerator_backend() in [ + # AcceleratorType.NPU, + # AcceleratorType.DIPU, + # ]: + # config.data.use_packed_dataset = False if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU: launcher = "slurm" else: launcher = "torch" + config.model.checkpoint = True initialize_distributed_env(config=config, launcher=launcher) assert hasattr(gpc, "config") and gpc.config is not None @@ -211,7 +215,7 @@ def train( train_iter = iter(train_dl) - if model_type == "INTERNLM": + if model_type == "INTERNLM2_PUBLIC": data_path = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss/data_batch_4DP") data_batch = torch.load(f"{data_path}/{gpc.get_local_rank(ParallelMode.DATA)}_data_batch.pt") @@ -220,7 +224,7 @@ def train( empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) timer("one-batch").start() - if model_type == "INTERNLM": + if model_type == "INTERNLM2_PUBLIC": if batch_count >= 10: batch = data_batch[batch_count - 10] else: