Skip to content

Commit

Permalink
layer block compiles and exports
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Sep 2, 2024
1 parent f56b440 commit 5f223a3
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 7 deletions.
3 changes: 0 additions & 3 deletions src/fairchem/core/models/escn/escn_exportable.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,6 @@ def forward(
edge_index,
wigner,
)
print(f"x_message: {x_message.mean()}")

# Compute point-wise spherical non-linearity on aggregated messages

Expand All @@ -517,14 +516,12 @@ def forward(
x_grid = self.act(self.fc1_sphere(x_grid))
x_grid = self.act(self.fc2_sphere(x_grid))
x_grid = self.fc3_sphere(x_grid)
print(f"x_grid: {x_grid.mean()}")

# Project back to spherical harmonic coefficients
# x_message._from_grid(x_grid, self.SO3_grid["lmax_lmax"])
from_grid_mat = self.SO3_grid["lmax_lmax"].from_grid_mat[:, :, self.SO3_grid["lmax_lmax"].mapping.coefficient_idx(self.lmax_list[0], self.lmax_list[0])]
x_message_final = torch.einsum("bai,zbac->zic", from_grid_mat, x_grid)

print(f"x_message_final: {x_message_final.mean()}")
# Return aggregated messages
return x_message_final

Expand Down
70 changes: 66 additions & 4 deletions tests/core/models/test_escn_compiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,6 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
torch._dynamo.config.assume_static_by_default = False
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.verbose = True
# torch._logging.set_logs(dynamo = logging.INFO)
# torch._dynamo.reset()
# explain_output = torch._dynamo.explain(message_block)(*args[0])
# print(explain_output)
compiled_model = torch.compile(message_block, dynamic=True)
compiled_output = compiled_model(*args[0])

Expand All @@ -240,6 +236,72 @@ def test_escn_message_block_exports_and_compiles(self, tol=1e-5) -> None:
assert torch.allclose(compiled_output, regular_out, atol=tol)
assert torch.allclose(exported_output, regular_out, atol=tol)

def test_escn_layer_block_exports_and_compiles(self, tol=1e-5) -> None:
random.seed(1)

sphere_channels = 128
hidden_channels = 128
edge_channels = 128
lmax, mmax = 4, 2
distance_expansion = GaussianSmearing(0.0, 8.0, int(8.0 / 0.02), 1.0)
SO3_grid = torch.nn.ModuleDict()
SO3_grid["lmax_lmax"] = SO3_Grid(lmax, lmax)
SO3_grid["lmax_mmax"] = SO3_Grid(lmax, mmax)
mappingReduced = CoefficientMapping([lmax], [mmax])
layer_block = escn_exportable.LayerBlock(
layer_idx = 0,
sphere_channels = sphere_channels,
hidden_channels = hidden_channels,
edge_channels = edge_channels,
lmax_list = [lmax],
mmax_list = [mmax],
distance_expansion = distance_expansion,
max_num_elements = 90,
SO3_grid = SO3_grid,
act = torch.nn.SiLU(),
mappingReduced = mappingReduced
)

# generate inputs
batch_sizes = [34, 35, 35]
num_edges = [680, 700, 680]
num_coefs = 25
run_args = []
for b,edges in zip(batch_sizes, num_edges):
x = torch.rand([b, num_coefs, sphere_channels])
atom_n = torch.randint(1, 90, (b,))
edge_d = torch.rand([edges])
edge_indx = torch.randint(0, b, (2, edges))
wigner = torch.rand([edges, num_coefs, num_coefs])
run_args.append((x, atom_n, edge_d, edge_indx, wigner))

torch._dynamo.config.optimize_ddp = False
torch._dynamo.config.assume_static_by_default = False
torch._dynamo.config.automatic_dynamic_shapes = True
torch._dynamo.config.verbose = True
# torch._logging.set_logs(dynamo = logging.INFO)
# torch._dynamo.reset()
# explain_output = torch._dynamo.explain(message_block)(*args[0])
# print(explain_output)

batch_dim = Dim("batch_dim")
edges_dim = Dim("edges_dim")
dynamic_shapes1 = {
"x": {0: batch_dim, 1: None, 2: None},
"atomic_numbers": {0: batch_dim},
"edge_distance": {0: edges_dim},
"edge_index": {0: None, 1: edges_dim},
"wigner": {0: edges_dim, 1: None, 2: None}
}
exported_prog = export(layer_block, args=run_args[0], dynamic_shapes=dynamic_shapes1)
for run_arg in run_args:
exported_output = exported_prog(*run_arg)
compiled_model = torch.compile(layer_block, dynamic=True)
compiled_output = compiled_model(*run_arg)
regular_out = layer_block(*run_arg)
assert torch.allclose(compiled_output, regular_out, atol=tol)
assert torch.allclose(exported_output, regular_out, atol=tol)

def test_escn_compiles(self):
init("gloo")
data = load_data()
Expand Down

0 comments on commit 5f223a3

Please sign in to comment.