Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update test loss #329

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions tests/test_training/test_loss.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -25,8 +25,8 @@
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.megatron_timers import megatron_timer as timer

CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_sft.py")
INTERNLM1_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss/model_ckpt")
CONFIG_FILE_PATH = os.getenv("CONFIG_FILE_PATH", "./configs/7B_internlm2.py")
INTERNLM2_CKPT_PATH = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss_pri/model_ckpt")
TOTAL_STEPS = 10
LOSS_SPIKE_LIMIT = 1.5
LOSS_DEVIATION_LIMIT = 0.02
Expand Down Expand Up @@ -59,32 +59,34 @@ def train(
enable_sp: bool = False,
save_ckpt: bool = False,
load_ckpt: bool = False,
model_type: str = "INTERNLM",
optimizer_ver: str = "v1",
zero_bubble: bool = False,
model_type: str = None,
):
# initialize distributed environment
config = Config.from_file(CONFIG_FILE_PATH)

# init setting
config.data.total_steps = TOTAL_STEPS
config.data.total_steps = 50000
config.data.fixed_random_dataset_seqlen = False
config.lr_scheduler.total_steps = TOTAL_STEPS
config.model_type = model_type
config.lr_scheduler.total_steps = config.data.total_steps
config.ckpt.load_ckpt_folder = None
config.ckpt.load_ckpt_info = None
config.ckpt.auto_resume = False
total_steps = config.data.total_steps
total_steps = TOTAL_STEPS
skip_batches = config.data.skip_batches
label_smoothing = config.loss.label_smoothing

if not model_type:
model_type = config.model_type

if optimizer_ver == "v2":
config.hybrid_zero_optimizer.use_split_tensor_optim = True
config.all_gather_size = 512 * 1024 * 1024

# update ckpt config
if model_type == "INTERNLM" and tp_mode != "isp" and interleaved is False:
config.ckpt.load_ckpt_info = dict(path=INTERNLM1_CKPT_PATH, content=("model",), ckpt_type="internlm_test")
if model_type == "INTERNLM2_PUBLIC" and tp_mode != "isp" and interleaved is False:
config.ckpt.load_ckpt_info = dict(path=INTERNLM2_CKPT_PATH, content=("model",), ckpt_type="internlm2")

if save_ckpt:
config.ckpt.enable_save_ckpt = True
Expand All @@ -97,6 +99,7 @@ def train(

# update parallel config
config.parallel.tensor = dict(size=tp_size, mode=tp_mode)
config.parallel.zero1 = dict(size=-1)
if zero_bubble:
config.hybrid_zero_optimizer.overlap_sync_grad = False
config.parallel.pipeline = dict(size=pp_size, zero_bubble=True)
Expand All @@ -109,16 +112,17 @@ def train(

if "use_packed_dataset" not in config.data:
config.data.use_packed_dataset = True
if tp_mode == "isp" and internlm_accelerator.get_accelerator_backend() in [
AcceleratorType.NPU,
AcceleratorType.DIPU,
]:
config.data.use_packed_dataset = False
# if tp_mode == "isp" and internlm_accelerator.get_accelerator_backend() in [
# AcceleratorType.NPU,
# AcceleratorType.DIPU,
# ]:
# config.data.use_packed_dataset = False

if internlm_accelerator.get_accelerator_backend() == AcceleratorType.GPU:
launcher = "slurm"
else:
launcher = "torch"
config.model.checkpoint = True

initialize_distributed_env(config=config, launcher=launcher)
assert hasattr(gpc, "config") and gpc.config is not None
Expand Down Expand Up @@ -211,7 +215,7 @@ def train(

train_iter = iter(train_dl)

if model_type == "INTERNLM":
if model_type == "INTERNLM2_PUBLIC":
data_path = os.path.join(os.environ["share_path"], "quailty_assurance/test_loss/data_batch_4DP")
data_batch = torch.load(f"{data_path}/{gpc.get_local_rank(ParallelMode.DATA)}_data_batch.pt")

Expand All @@ -220,7 +224,7 @@ def train(
empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
timer("one-batch").start()

if model_type == "INTERNLM":
if model_type == "INTERNLM2_PUBLIC":
if batch_count >= 10:
batch = data_batch[batch_count - 10]
else:
Expand Down
Loading