add profiler to ocp trainer
rayg1234 committed Jul 9, 2024
1 parent 657598b commit bbaaeab
Showing 2 changed files with 121 additions and 82 deletions.
10 changes: 10 additions & 0 deletions src/fairchem/core/common/logger.py
@@ -56,6 +56,9 @@ def mark_preempting(self) -> None:
    def log_summary(self, summary_dict: dict[str, Any]) -> None:
        pass

    @abstractmethod
    def log_artifact(self, name: str, type: str, file_location: str) -> None:
        pass

@registry.register_logger("wandb")
class WandBLogger(Logger):
@@ -101,6 +104,10 @@ def log_summary(self, summary_dict: dict[str, Any]):
    def mark_preempting(self) -> None:
        wandb.mark_preempting()

    def log_artifact(self, name: str, type: str, file_location: str) -> None:
        art = wandb.Artifact(name=name, type=type)
        art.add_file(file_location)
        art.save()

@registry.register_logger("tensorboard")
class TensorboardLogger(Logger):
@@ -130,3 +137,6 @@ def log_plots(self, plots) -> None:

    def log_summary(self, summary_dict: dict[str, Any]) -> None:
        logging.warning("log_summary for Tensorboard not supported")

    def log_artifact(self, name: str, type: str, file_location: str) -> None:
        logging.warning("log_artifact for Tensorboard not supported")
193 changes: 111 additions & 82 deletions src/fairchem/core/trainers/ocp_trainer.py
@@ -15,6 +15,7 @@
import numpy as np
import torch
import torch_geometric
from torch.profiler import ProfilerActivity, profile
from tqdm import tqdm

from fairchem.core.common import distutils
@@ -118,6 +119,24 @@ def __init__(
            gp_gpus=gp_gpus,
        )

    def get_profiler_config(self):
        def trace_handler(p):
            if distutils.is_master():
                trace_name = f"{self.config['cmd']['timestamp_id']}_rank_{distutils.get_rank()}.pt.trace.json"
                output_path = os.path.join(self.config["cmd"]["results_dir"], trace_name)
                print(f"Saving trace in {output_path}")
                p.export_chrome_trace(output_path)
                if self.logger:
                    self.logger.log_artifact(name=trace_name, type="profile", file_location=output_path)

        wait = 5
        warmup = 5
        active = 2
        total_profile_steps = wait + warmup + active
        profile_schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active)

        return trace_handler, profile_schedule, total_profile_steps

    def train(self, disable_eval_tqdm: bool = False) -> None:
        ensure_fitted(self._unwrapped_model, warn=True)

@@ -136,96 +155,106 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
# to prevent inconsistencies due to different batch size in checkpoint.
start_epoch = self.step // len(self.train_loader)

for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
skip_steps = self.step % len(self.train_loader)
self.train_sampler.set_epoch_and_start_iteration(epoch_int, skip_steps)
train_loader_iter = iter(self.train_loader)

for i in range(skip_steps, len(self.train_loader)):
self.epoch = epoch_int + (i + 1) / len(self.train_loader)
self.step = epoch_int * len(self.train_loader) + i + 1
self.model.train()

# Get a batch.
batch = next(train_loader_iter)

# Forward, loss, backward.
with torch.cuda.amp.autocast(enabled=self.scaler is not None):
out = self._forward(batch)
loss = self._compute_loss(out, batch)

# Compute metrics.
self.metrics = self._compute_metrics(
out,
batch,
self.evaluator,
self.metrics,
)
self.metrics = self.evaluator.update("loss", loss.item(), self.metrics)

loss = self.scaler.scale(loss) if self.scaler else loss
self._backward(loss)

# Log metrics.
log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
log_dict.update(
{
"lr": self.scheduler.get_lr(),
"epoch": self.epoch,
"step": self.step,
}
)
if (
self.step % self.config["cmd"]["print_every"] == 0
and distutils.is_master()
):
log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()]
logging.info(", ".join(log_str))
self.metrics = {}

if self.logger is not None:
self.logger.log(
log_dict,
step=self.step,
split="train",
trace_handler, profile_schedule, total_profile_steps = self.get_profiler_config()

with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=profile_schedule,
on_trace_ready=trace_handler
) as p:
for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
skip_steps = self.step % len(self.train_loader)
self.train_sampler.set_epoch_and_start_iteration(epoch_int, skip_steps)
train_loader_iter = iter(self.train_loader)

for i in range(skip_steps, len(self.train_loader)):
self.epoch = epoch_int + (i + 1) / len(self.train_loader)
self.step = epoch_int * len(self.train_loader) + i + 1
self.model.train()

# Get a batch.
batch = next(train_loader_iter)

# Forward, loss, backward.
with torch.cuda.amp.autocast(enabled=self.scaler is not None):
out = self._forward(batch)
loss = self._compute_loss(out, batch)

# Compute metrics.
self.metrics = self._compute_metrics(
out,
batch,
self.evaluator,
self.metrics,
)
self.metrics = self.evaluator.update("loss", loss.item(), self.metrics)

loss = self.scaler.scale(loss) if self.scaler else loss
self._backward(loss)

# Log metrics.
log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
log_dict.update(
{
"lr": self.scheduler.get_lr(),
"epoch": self.epoch,
"step": self.step,
}
)
if (
self.step % self.config["cmd"]["print_every"] == 0
and distutils.is_master()
):
log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()]
logging.info(", ".join(log_str))
self.metrics = {}

if self.logger is not None:
self.logger.log(
log_dict,
step=self.step,
split="train",
)

if checkpoint_every != -1 and self.step % checkpoint_every == 0:
self.save(checkpoint_file="checkpoint.pt", training_state=True)
if checkpoint_every != -1 and self.step % checkpoint_every == 0:
self.save(checkpoint_file="checkpoint.pt", training_state=True)

# Evaluate on val set every `eval_every` iterations.
if self.step % eval_every == 0:
if self.val_loader is not None:
val_metrics = self.validate(
split="val",
disable_tqdm=disable_eval_tqdm,
)
self.update_best(
primary_metric,
val_metrics,
disable_eval_tqdm=disable_eval_tqdm,
)
# Evaluate on val set every `eval_every` iterations.
if self.step % eval_every == 0:
if self.val_loader is not None:
val_metrics = self.validate(
split="val",
disable_tqdm=disable_eval_tqdm,
)
self.update_best(
primary_metric,
val_metrics,
disable_eval_tqdm=disable_eval_tqdm,
)

if self.config["task"].get("eval_relaxations", False):
if "relax_dataset" not in self.config["task"]:
logging.warning(
"Cannot evaluate relaxations, relax_dataset not specified"
if self.config["task"].get("eval_relaxations", False):
if "relax_dataset" not in self.config["task"]:
logging.warning(
"Cannot evaluate relaxations, relax_dataset not specified"
)
else:
self.run_relaxations()

if self.scheduler.scheduler_type == "ReduceLROnPlateau":
if self.step % eval_every == 0:
self.scheduler.step(
metrics=val_metrics[primary_metric]["metric"],
)
else:
self.run_relaxations()
else:
self.scheduler.step()
if i < total_profile_steps:
p.step()

if self.scheduler.scheduler_type == "ReduceLROnPlateau":
if self.step % eval_every == 0:
self.scheduler.step(
metrics=val_metrics[primary_metric]["metric"],
)
else:
self.scheduler.step()
torch.cuda.empty_cache()

torch.cuda.empty_cache()
if checkpoint_every == -1:
self.save(checkpoint_file="checkpoint.pt", training_state=True)

if checkpoint_every == -1:
self.save(checkpoint_file="checkpoint.pt", training_state=True)

self.train_dataset.close_db()
if self.config.get("val_dataset", False):
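
For context on the wiring above (an illustrative sketch, not code from this commit): torch.profiler records according to its schedule — it idles for `wait` steps, warms up for `warmup`, records for `active`, then fires `on_trace_ready`; the profiler only advances when `step()` is called, which is why the trainer calls `p.step()` after each batch until `total_profile_steps` is reached. A minimal standalone version of that pattern, with placeholder workload and output path:

import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Same wait/warmup/active split that get_profiler_config hard-codes above.
profile_schedule = schedule(wait=5, warmup=5, active=2)

def trace_handler(p):
    # Fires once the active window completes; the exported trace opens in
    # chrome://tracing or Perfetto. Output path is a placeholder.
    p.export_chrome_trace("example.pt.trace.json")

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, schedule=profile_schedule, on_trace_ready=trace_handler) as p:
    for step in range(5 + 5 + 2):  # wait + warmup + active
        torch.matmul(torch.randn(128, 128), torch.randn(128, 128))  # stand-in for one training step
        p.step()  # advance the profiler schedule, as the trainer does after each batch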
