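For context, the snippet below starts at `def main(args)` and omits its imports; assuming it is adapted from EasyContext's train.py, they would look roughly like this:

import math
import os
from datetime import timedelta

import torch
import transformers
from accelerate import Accelerator
from accelerate.utils import DummyOptim, DummyScheduler, InitProcessGroupKwargs, set_seed
from datasets import DatasetDict, load_dataset, load_from_disk
from easy_context import apply_seq_parallel_monkey_patch, prepare_dataloader, prepare_seq_parallel_inputs
from flash_attn.losses.cross_entropy import CrossEntropyLoss
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, default_data_collator
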
def main(args):
    if args.output_dir:
        os.makedirs(args.output_dir, exist_ok=True)
    if args.wandb:
        import wandb
        wandb.login()
    set_seed(args.seed)
    timeout = InitProcessGroupKwargs(timeout=timedelta(seconds=1_000_000))
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulate_every,
        mixed_precision="bf16",
        log_with="wandb" if args.wandb else None,
        kwargs_handlers=[timeout],
        # fsdp_plugin=fsdp_plugin,
    )
    accelerator.init_trackers(
        project_name=args.wandb,
        init_kwargs={"wandb": {"name": args.output_dir.split("/")[-1]}},
    )
    accelerator.print(f"Total GPUs: {accelerator.num_processes}")
    model = AutoModelForCausalLM.from_pretrained(
        args.model,
        device_map=accelerator.device,
        torch_dtype=torch.bfloat16,
        rope_theta=args.rope_theta,
        _attn_implementation="flash_attention_2",
    )
    # tokenizer = AutoTokenizer.from_pretrained(
    #     args.model,
    #     trust_remote_code=True,
    #     # llama does not support the fast tokenizer
    # )
    # Load from the Hub (or a loading script) first; fall back to a dataset
    # saved with `save_to_disk`.
    try:
        train_dataset = load_dataset(args.dataset)
    except Exception:
        train_dataset = load_from_disk(args.dataset)
    if isinstance(train_dataset, DatasetDict):
        train_dataset = train_dataset["train"]
    # train_dataset = QwenSFTDataset(args.dataset, tokenizer, args)
    assert isinstance(
        model, (transformers.LlamaForCausalLM, transformers.MistralForCausalLM)
    ), "Only Llama and Mistral models are supported"
    model_type = (
        "llama" if isinstance(model, transformers.LlamaForCausalLM) else "mistral"
    )
    apply_seq_parallel_monkey_patch(args.parallel_mode, model_type)
    if "input_ids" not in train_dataset.column_names:
        raise RuntimeError("Dataset must include an `input_ids` feature")
    # remove everything that is not input_ids
    to_remove = [col for col in train_dataset.column_names if col != "input_ids"]
    train_dataset = train_dataset.remove_columns(to_remove)
    train_dataset = train_dataset.shuffle(seed=args.seed)
    print("Dataset Size:", len(train_dataset))
    train_loader = DataLoader(
        train_dataset,
        collate_fn=default_data_collator,
        shuffle=True,
        batch_size=args.batch_size,
    )
    if args.learning_rate != 2e-5:
        accelerator.print(
            "Warning: you also need to modify accelerate_configs/zero3_offload.json "
            "to change the learning rate"
        )
    # DummyOptim/DummyScheduler are placeholders: the actual optimizer and LR
    # schedule come from the DeepSpeed config passed to `accelerate launch`.
    optim = DummyOptim(model.parameters(), lr=args.learning_rate)
    scheduler = DummyScheduler(
        optim,
        num_training_steps=args.max_train_steps,
        total_num_steps=args.max_train_steps,
    )
    model, optim, scheduler = accelerator.prepare(model, optim, scheduler)
    train_loader = prepare_dataloader(args.parallel_mode, train_loader, accelerator)
    model.gradient_checkpointing_enable()
    accelerator.register_for_checkpointing(scheduler)
    accelerator.print(f"Max train steps: {args.max_train_steps}")
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )
    completed_steps = 0
    model.train()
    # Fused cross entropy from flash-attn; inplace_backward=True saves memory.
    loss_func = CrossEntropyLoss(inplace_backward=True)
    for step, batch in enumerate(train_loader):
        # Keep the first seq_length + 1 tokens, then shift by one token to get
        # the next-token prediction targets.
        input_ids = batch["input_ids"][..., : args.seq_length + 1][..., :-1]
        target_ids = batch["input_ids"][..., : args.seq_length + 1][..., 1:]
        position_ids = (
            torch.arange(args.seq_length).unsqueeze(0).expand(input_ids.shape[0], -1)
        )
        # shard the input_ids across ranks (by world size and rank) according to
        # the zigzag attention layout
        # print(input_ids.shape, position_ids.shape)  # these two shapes must match
        prepared = prepare_seq_parallel_inputs(
            args.parallel_mode,
            input_ids,
            position_ids,
            target_ids,
            accelerator.process_index,
            accelerator.num_processes,
            accelerator.device,
        )
        local_input_ids = prepared["local_input_ids"]
        local_position_ids = prepared["local_position_ids"]
        local_target_ids = prepared["local_target_ids"]
        loss_log = None
        with accelerator.accumulate(model):
            logits = model(
                local_input_ids,
                position_ids=local_position_ids,
            ).logits
            loss = loss_func(
                logits.reshape(-1, logits.shape[-1]), local_target_ids.reshape(-1)
            )
            accelerator.backward(loss)
            if accelerator.sync_gradients:
                # Pay attention here: when any sequence-parallel algorithm is enabled,
                # this technically only logs the very first chunk's loss, and what the
                # first chunk contains depends on how the sequence is sharded. For
                # zigzag attention, the first chunk holds the leftmost and rightmost
                # tokens, so you cannot compare the (logged) loss of dist attention
                # and zigzag ring attention.
                # loss_log = {"loss": loss.item(), "ppl": math.exp(loss.item())}
                # We now log the loss gathered across ranks instead, to verify that
                # ring attention and dist flash attention produce the same loss;
                # this may slow down training.
                gathered_loss = accelerator.reduce(loss.clone().detach(), "mean")
                loss_log = {
                    "loss": gathered_loss.item(),
                    "ppl": math.exp(gathered_loss.item()),
                }
                accelerator.log(loss_log, step=completed_steps)
            optim.step()
            scheduler.step()
            optim.zero_grad()
        if accelerator.sync_gradients:
            progress_bar.update(1)
            if loss_log is not None:
                progress_bar.set_postfix(loss_log)
            completed_steps += 1
            if completed_steps >= args.max_train_steps:
                break
accelerator.print(f"Training Finished")
accelerator.end_training()
if args.output_dir is not None:
accelerator.print(f"Saving model to {args.output_dir}")
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
accelerator.unwrap_model(model).save_pretrained(
f"{args.output_dir}",
is_main_process=accelerator.is_main_process,
save_function=accelerator.save,
state_dict=state_dict,
)
accelerator.print(f"Saving Finished")
Hello, I hit the error above when using EasyContext's zigzag_ring_flash_attn mode.
All of my data has been grouped by length to 32768 + 1 tokens (following https://github.com/jzhang38/EasyContext/issues/31#issue-2308064466).
It runs fine in data-parallel mode, but errors out with sequence parallelism.
The code is pasted above.
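For reference, here is a minimal sketch of the kind of length-grouping described above, under the assumption that sequences are simply concatenated and cut into fixed blocks of 32768 + 1 tokens; the actual preprocessing used in issue #31 may differ:

from datasets import Dataset

BLOCK = 32768 + 1  # seq_length + 1, so the training loop can shift by one token

def group_texts(examples):
    # Concatenate every tokenized sequence, then cut into fixed-size blocks,
    # dropping the incomplete tail so every row is exactly BLOCK tokens long.
    concatenated = sum(examples["input_ids"], [])
    total = (len(concatenated) // BLOCK) * BLOCK
    return {"input_ids": [concatenated[i : i + BLOCK] for i in range(0, total, BLOCK)]}

# `tokenized` stands in for an already-tokenized dataset with an `input_ids` column.
tokenized = Dataset.from_dict({"input_ids": [[1] * 40000, [2] * 30000]})
packed = tokenized.map(group_texts, batched=True, remove_columns=tokenized.column_names)
print(len(packed), len(packed[0]["input_ids"]))  # every example is BLOCK tokens long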
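Regarding the comments in the training loop about which tokens each rank sees: here is a tiny illustrative sketch of the zigzag split used by zigzag ring attention (not EasyContext's actual implementation, just the sharding idea):

import torch

def zigzag_shard(input_ids: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # Cut the sequence into 2 * world_size chunks and give rank i chunk i plus
    # chunk (2 * world_size - 1 - i), so every rank holds a balanced mix of early
    # and late positions; rank 0 ends up with the leftmost and rightmost chunks.
    assert input_ids.shape[-1] % (2 * world_size) == 0
    chunks = input_ids.chunk(2 * world_size, dim=-1)
    return torch.cat([chunks[rank], chunks[2 * world_size - 1 - rank]], dim=-1)

# e.g. 16 tokens over 4 ranks: rank 0 sees positions [0, 1] and [14, 15],
# which is why its local loss is not comparable to a plain contiguous split.
print(zigzag_shard(torch.arange(16).unsqueeze(0), rank=0, world_size=4))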