diff --git a/src/fairchem/core/common/logger.py b/src/fairchem/core/common/logger.py
index fd52756e2..c1d501920 100644
--- a/src/fairchem/core/common/logger.py
+++ b/src/fairchem/core/common/logger.py
@@ -56,6 +56,10 @@ def mark_preempting(self) -> None:
     def log_summary(self, summary_dict: dict[str, Any]) -> None:
         pass
 
+    @abstractmethod
+    def log_artifact(self, name: str, type: str, file_location: str) -> None:
+        pass
+
 
 @registry.register_logger("wandb")
 class WandBLogger(Logger):
@@ -101,6 +105,11 @@ def log_summary(self, summary_dict: dict[str, Any]):
     def mark_preempting(self) -> None:
         wandb.mark_preempting()
 
+    def log_artifact(self, name: str, type: str, file_location: str) -> None:
+        art = wandb.Artifact(name=name, type=type)
+        art.add_file(file_location)
+        art.save()
+
 
 @registry.register_logger("tensorboard")
 class TensorboardLogger(Logger):
@@ -130,3 +139,6 @@ def log_plots(self, plots) -> None:
 
     def log_summary(self, summary_dict: dict[str, Any]) -> None:
         logging.warning("log_summary for Tensorboard not supported")
+
+    def log_artifact(self, name: str, type: str, file_location: str) -> None:
+        logging.warning("log_artifact for Tensorboard not supported")
diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py
index 5b610aab6..0be243830 100644
--- a/src/fairchem/core/trainers/ocp_trainer.py
+++ b/src/fairchem/core/trainers/ocp_trainer.py
@@ -15,6 +15,7 @@
 import numpy as np
 import torch
 import torch_geometric
+from torch.profiler import ProfilerActivity, profile
 from tqdm import tqdm
 
 from fairchem.core.common import distutils
@@ -118,6 +119,27 @@ def __init__(
             gp_gpus=gp_gpus,
         )
 
+    def get_profiler_config(self):
+        """Return a trace handler, a profiler schedule, and the total number of profiled steps."""
+
+        def trace_handler(p):
+            # Export and log the trace from the master rank only.
+            if distutils.is_master():
+                trace_name = f"{self.config['cmd']['timestamp_id']}_rank_{distutils.get_rank()}.pt.trace.json"
+                output_path = os.path.join(self.config["cmd"]["results_dir"], trace_name)
+                logging.info(f"Saving trace in {output_path}")
+                p.export_chrome_trace(output_path)
+                if self.logger:
+                    self.logger.log_artifact(name=trace_name, type="profile", file_location=output_path)
+
+        wait = 5
+        warmup = 5
+        active = 2
+        total_profile_steps = wait + warmup + active
+        profile_schedule = torch.profiler.schedule(wait=wait, warmup=warmup, active=active)
+
+        return trace_handler, profile_schedule, total_profile_steps
+
     def train(self, disable_eval_tqdm: bool = False) -> None:
         ensure_fitted(self._unwrapped_model, warn=True)
 
@@ -136,96 +158,107 @@ def train(self, disable_eval_tqdm: bool = False) -> None:
         # to prevent inconsistencies due to different batch size in checkpoint.
         start_epoch = self.step // len(self.train_loader)
 
-        for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
-            skip_steps = self.step % len(self.train_loader)
-            self.train_sampler.set_epoch_and_start_iteration(epoch_int, skip_steps)
-            train_loader_iter = iter(self.train_loader)
-
-            for i in range(skip_steps, len(self.train_loader)):
-                self.epoch = epoch_int + (i + 1) / len(self.train_loader)
-                self.step = epoch_int * len(self.train_loader) + i + 1
-                self.model.train()
-
-                # Get a batch.
-                batch = next(train_loader_iter)
-
-                # Forward, loss, backward.
-                with torch.cuda.amp.autocast(enabled=self.scaler is not None):
-                    out = self._forward(batch)
-                    loss = self._compute_loss(out, batch)
-
-                # Compute metrics.
-                self.metrics = self._compute_metrics(
-                    out,
-                    batch,
-                    self.evaluator,
-                    self.metrics,
-                )
-                self.metrics = self.evaluator.update("loss", loss.item(), self.metrics)
-
-                loss = self.scaler.scale(loss) if self.scaler else loss
-                self._backward(loss)
-
-                # Log metrics.
-                log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
-                log_dict.update(
-                    {
-                        "lr": self.scheduler.get_lr(),
-                        "epoch": self.epoch,
-                        "step": self.step,
-                    }
-                )
-                if (
-                    self.step % self.config["cmd"]["print_every"] == 0
-                    and distutils.is_master()
-                ):
-                    log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()]
-                    logging.info(", ".join(log_str))
-                    self.metrics = {}
-
-                if self.logger is not None:
-                    self.logger.log(
-                        log_dict,
-                        step=self.step,
-                        split="train",
-                    )
-
-                if checkpoint_every != -1 and self.step % checkpoint_every == 0:
-                    self.save(checkpoint_file="checkpoint.pt", training_state=True)
-
-                # Evaluate on val set every `eval_every` iterations.
-                if self.step % eval_every == 0:
-                    if self.val_loader is not None:
-                        val_metrics = self.validate(
-                            split="val",
-                            disable_tqdm=disable_eval_tqdm,
-                        )
-                        self.update_best(
-                            primary_metric,
-                            val_metrics,
-                            disable_eval_tqdm=disable_eval_tqdm,
-                        )
-
-                    if self.config["task"].get("eval_relaxations", False):
-                        if "relax_dataset" not in self.config["task"]:
-                            logging.warning(
-                                "Cannot evaluate relaxations, relax_dataset not specified"
-                            )
-                        else:
-                            self.run_relaxations()
-
-                if self.scheduler.scheduler_type == "ReduceLROnPlateau":
-                    if self.step % eval_every == 0:
-                        self.scheduler.step(
-                            metrics=val_metrics[primary_metric]["metric"],
-                        )
-                else:
-                    self.scheduler.step()
-
-            torch.cuda.empty_cache()
-
-        if checkpoint_every == -1:
-            self.save(checkpoint_file="checkpoint.pt", training_state=True)
+        trace_handler, profile_schedule, total_profile_steps = self.get_profiler_config()
+
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            schedule=profile_schedule,
+            on_trace_ready=trace_handler,
+        ) as p:
+            for epoch_int in range(start_epoch, self.config["optim"]["max_epochs"]):
+                skip_steps = self.step % len(self.train_loader)
+                self.train_sampler.set_epoch_and_start_iteration(epoch_int, skip_steps)
+                train_loader_iter = iter(self.train_loader)
+
+                for i in range(skip_steps, len(self.train_loader)):
+                    self.epoch = epoch_int + (i + 1) / len(self.train_loader)
+                    self.step = epoch_int * len(self.train_loader) + i + 1
+                    self.model.train()
+
+                    # Get a batch.
+                    batch = next(train_loader_iter)
+
+                    # Forward, loss, backward.
+                    with torch.cuda.amp.autocast(enabled=self.scaler is not None):
+                        out = self._forward(batch)
+                        loss = self._compute_loss(out, batch)
+
+                    # Compute metrics.
+                    self.metrics = self._compute_metrics(
+                        out,
+                        batch,
+                        self.evaluator,
+                        self.metrics,
+                    )
+                    self.metrics = self.evaluator.update("loss", loss.item(), self.metrics)
+
+                    loss = self.scaler.scale(loss) if self.scaler else loss
+                    self._backward(loss)
+
+                    # Log metrics.
+                    log_dict = {k: self.metrics[k]["metric"] for k in self.metrics}
+                    log_dict.update(
+                        {
+                            "lr": self.scheduler.get_lr(),
+                            "epoch": self.epoch,
+                            "step": self.step,
+                        }
+                    )
+                    if (
+                        self.step % self.config["cmd"]["print_every"] == 0
+                        and distutils.is_master()
+                    ):
+                        log_str = [f"{k}: {v:.2e}" for k, v in log_dict.items()]
+                        logging.info(", ".join(log_str))
+                        self.metrics = {}
+
+                    if self.logger is not None:
+                        self.logger.log(
+                            log_dict,
+                            step=self.step,
+                            split="train",
+                        )
+
+                    if checkpoint_every != -1 and self.step % checkpoint_every == 0:
+                        self.save(checkpoint_file="checkpoint.pt", training_state=True)
+
+                    # Evaluate on val set every `eval_every` iterations.
+                    if self.step % eval_every == 0:
+                        if self.val_loader is not None:
+                            val_metrics = self.validate(
+                                split="val",
+                                disable_tqdm=disable_eval_tqdm,
+                            )
+                            self.update_best(
+                                primary_metric,
+                                val_metrics,
+                                disable_eval_tqdm=disable_eval_tqdm,
+                            )
+
+                        if self.config["task"].get("eval_relaxations", False):
+                            if "relax_dataset" not in self.config["task"]:
+                                logging.warning(
+                                    "Cannot evaluate relaxations, relax_dataset not specified"
+                                )
+                            else:
+                                self.run_relaxations()
+
+                    if self.scheduler.scheduler_type == "ReduceLROnPlateau":
+                        if self.step % eval_every == 0:
+                            self.scheduler.step(
+                                metrics=val_metrics[primary_metric]["metric"],
+                            )
+                    else:
+                        self.scheduler.step()
+
+                    # Step the profiler schedule for the first `total_profile_steps` iterations of each epoch.
+                    if i < total_profile_steps:
+                        p.step()
+
+                torch.cuda.empty_cache()
+
+            if checkpoint_every == -1:
+                self.save(checkpoint_file="checkpoint.pt", training_state=True)
 
         self.train_dataset.close_db()
         if self.config.get("val_dataset", False):
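
For reference, the schedule/trace-handler pattern introduced above can be exercised outside the trainer. The sketch below is illustrative only, not part of the diff: the toy model, trace file name, and loop count are made up, and it profiles CPU activity only so it runs on any PyTorch install. It shows why `total_profile_steps = wait + warmup + active` gates `p.step()`: with wait=5, warmup=5, active=2, `on_trace_ready` fires during the 12th step call, the one that completes the active window.

import torch
from torch.profiler import ProfilerActivity, profile, schedule

def trace_handler(p):
    # Called by the profiler once wait + warmup + active steps have completed.
    p.export_chrome_trace("toy_trace.pt.trace.json")
    print("Saved toy_trace.pt.trace.json")

wait, warmup, active = 5, 5, 2
total_profile_steps = wait + warmup + active  # 12

model = torch.nn.Linear(128, 128)
x = torch.randn(64, 128)

with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=wait, warmup=warmup, active=active),
    on_trace_ready=trace_handler,
) as p:
    for i in range(20):  # stands in for the inner training loop
        model(x).sum().backward()
        # Mirror the trainer: stop stepping once the schedule is exhausted, so
        # the default repeating schedule does not begin a second trace cycle.
        if i < total_profile_steps:
            p.step()

On the logger side, `WandBLogger.log_artifact` saves the exported trace as a run-attached artifact of type "profile" via `wandb.Artifact` / `add_file` / `save`, so traces can be retrieved later from the W&B UI or API; `TensorboardLogger.log_artifact` only emits a warning, since artifact logging is not supported there.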