FM-v4 branch into main #752
Changes from 100 commits
```diff
@@ -8,6 +8,12 @@
 
 if TYPE_CHECKING:
     from torch_geometric.data import Data
+from contextlib import suppress
+
+with suppress(ImportError):
+    # TODO remove this in favor of a better solution
+    # We should never be importing * from a module
+    from fairchem.experimental.foundation_models.multi_task_dataloader.transforms.data_object import *  # noqa
 
 
 class DataTransforms:
```
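For readers unfamiliar with `contextlib.suppress`: it swallows only the named exception, so the guarded star-import above silently becomes a no-op when the experimental package is missing. A minimal standalone sketch (the module name below is made up for illustration):

```python
from contextlib import suppress

# With suppress: the block is skipped silently if the import fails.
with suppress(ImportError):
    import some_optional_package  # hypothetical optional dependency

# Equivalent spelled out with try/except:
try:
    import some_optional_package  # hypothetical optional dependency
except ImportError:
    pass
```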
```diff
@@ -13,6 +13,7 @@
 import logging
 import os
 import random
+import sys
 from abc import ABC, abstractmethod
 from itertools import chain
 from typing import TYPE_CHECKING
```
```diff
@@ -126,6 +127,7 @@ def __init__(
             "gpus": distutils.get_world_size() if not self.cpu else 0,
             "cmd": {
                 "identifier": identifier,
+                "parent": identifier,
                 "print_every": print_every,
                 "seed": seed,
                 "timestamp_id": self.timestamp_id,
```
```diff
@@ -232,6 +234,8 @@ def load(self) -> None:
         self.load_loss()
         self.load_optimizer()
         self.load_extras()
+        if self.config["optim"].get("load_datasets_and_model_then_exit", False):
+            sys.exit(0)
 
     def set_seed(self, seed) -> None:
         # https://pytorch.org/docs/stable/notes/randomness.html
```

Review comment on the new early-exit block: why is this needed if this is the last line of the init anyway?
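As a usage sketch, the new early-exit flag would live under the trainer's `optim` config section; only the key name comes from this diff, and the surrounding structure is illustrative:

```python
# Illustrative config fragment: build datasets, model, loss, and optimizer,
# then exit immediately -- usable as a smoke test of trainer setup.
config = {
    "optim": {
        "load_datasets_and_model_then_exit": True,
    },
}
```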
```diff
@@ -571,6 +575,9 @@ def load_checkpoint(
         self.step = checkpoint.get("step", 0)
         self.best_val_metric = checkpoint.get("best_val_metric", None)
         self.primary_metric = checkpoint.get("primary_metric", None)
+        self.config["cmd"]["parent"] = checkpoint["config"]["cmd"].get(
+            "parent", "identifier"
+        )
 
         new_dict = match_state_dict(self.model.state_dict(), checkpoint["state_dict"])
         strict = self.config.get("task", {}).get("strict_load", True)
```

Review thread on the `parent` entry:
- Where is this parent actually used?
- I don't think it's read anywhere; I believe it was intended to link fine-tuning runs to their parent runs. @mshuaibii is this right?
- Let's remove it if it's not used.
- Yeah, it was introduced to organize things in wandb by this entry. But you guys probably have a different solution at this point, so omitting it is fine.
```diff
@@ -792,6 +799,22 @@ def update_best(
             disable_tqdm=disable_eval_tqdm,
         )
 
+    def _aggregate_metrics(self, metrics):
+        aggregated_metrics = {}
+        for k in metrics:
+            aggregated_metrics[k] = {
+                "total": distutils.all_reduce(
+                    metrics[k]["total"], average=False, device=self.device
+                ),
+                "numel": distutils.all_reduce(
+                    metrics[k]["numel"], average=False, device=self.device
+                ),
+            }
+            aggregated_metrics[k]["metric"] = (
+                aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
+            )
+        return aggregated_metrics
+
     @torch.no_grad()
     def validate(self, split: str = "val", disable_tqdm: bool = False):
         ensure_fitted(self._unwrapped_model, warn=True)
```
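Why `_aggregate_metrics` sums `total` and `numel` separately instead of averaging each rank's mean: ranks may hold different numbers of samples, so the correct global mean is the ratio of the summed accumulators. A plain-Python sketch of the arithmetic (no distributed setup; the metric name is made up):

```python
# Per-rank accumulators in the same shape the trainer uses.
rank_metrics = [
    {"energy_mae": {"total": 12.0, "numel": 4}},  # rank 0: 4 samples
    {"energy_mae": {"total": 2.0, "numel": 1}},   # rank 1: 1 sample
]

# all_reduce with average=False is a cross-rank sum:
total = sum(r["energy_mae"]["total"] for r in rank_metrics)  # 14.0
numel = sum(r["energy_mae"]["numel"] for r in rank_metrics)  # 5

print(total / numel)             # 2.8 -- the true mean over all 5 samples
print((12.0 / 4 + 2.0 / 1) / 2)  # 2.5 -- averaging per-rank means is wrong
```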
```diff
@@ -833,20 +856,7 @@ def validate(self, split: str = "val", disable_tqdm: bool = False):
             metrics = self._compute_metrics(out, batch, evaluator, metrics)
             metrics = evaluator.update("loss", loss.item(), metrics)
 
-        aggregated_metrics = {}
-        for k in metrics:
-            aggregated_metrics[k] = {
-                "total": distutils.all_reduce(
-                    metrics[k]["total"], average=False, device=self.device
-                ),
-                "numel": distutils.all_reduce(
-                    metrics[k]["numel"], average=False, device=self.device
-                ),
-            }
-            aggregated_metrics[k]["metric"] = (
-                aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
-            )
-        metrics = aggregated_metrics
+        metrics = self._aggregate_metrics(metrics)
 
         log_dict = {k: metrics[k]["metric"] for k in metrics}
         log_dict.update({"epoch": self.epoch})
```
Review comment: comment on what this does.
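One way the requested comment could read, based only on what the diff itself shows:

```python
def _aggregate_metrics(self, metrics):
    """Reduce per-rank metric accumulators to global values.

    For each metric, sum the running "total" and element count "numel"
    across all distributed ranks (all_reduce with average=False), then
    recompute "metric" as total / numel so the reported mean covers the
    full validation set rather than a single rank's shard.
    """
    ...
```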