Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hashable edges #87

Merged
merged 2 commits into from
Oct 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions chgnet/graph/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
from abc import ABC, abstractmethod

from chgnet.utils import write_json

Expand Down Expand Up @@ -32,13 +33,13 @@ def add_neighbor(self, index, edge):
self.neighbors[index].append(edge)


class UndirectedEdge:
"""An undirected/bi-directed edge in a graph."""
class Edge(ABC):
"""Abstract base class for edges in a graph."""

def __init__(
self, nodes: list, index: int | None = None, info: dict | None = None
) -> None:
"""Initialize an UndirectedEdge."""
"""Initialize an Edge."""
self.nodes = nodes
self.index = index
self.info = info
Expand All @@ -48,29 +49,40 @@ def __repr__(self):
nodes, index, info = self.nodes, self.index, self.info
return f"{type(self).__name__}({nodes=}, {index=}, {info=})"

def __hash__(self) -> int:
"""Hash this edge."""
img = (self.info or {}).get("image")
img_str = "" if img is None else img.tostring()
return hash((self.nodes[0], self.nodes[1], img_str))

@abstractmethod
def __eq__(self, other: object) -> bool:
"""Check if two edges are equal."""
raise NotImplementedError


class UndirectedEdge(Edge):
"""An undirected/bi-directed edge in a graph."""

__hash__ = Edge.__hash__

def __eq__(self, other):
"""Check if two undirected edges are equal."""
return set(self.nodes) == set(other.nodes) and self.info == other.info


class DirectedEdge:
class DirectedEdge(Edge):
"""A directed edge in a graph."""

def __init__(
self, nodes: list, index: int | None = None, info: dict | None = None
) -> None:
"""Initialize a DirectedEdge."""
self.nodes = nodes
self.index = index
self.info = info
__hash__ = Edge.__hash__

def make_undirected(self, index, info=None):
def make_undirected(self, index: int, info: dict | None = None) -> UndirectedEdge:
"""Make a directed edge undirected."""
info = info or {}
info["distance"] = self.info["distance"]
return UndirectedEdge(self.nodes, index, info)

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Check if the two directed edges are equal.

Args:
Expand All @@ -80,7 +92,10 @@ def __eq__(self, other) -> bool:
bool: True if other is the same directed edge, or if other is the directed edge
with reverse direction of self, else False.
"""
if self.nodes == other.nodes and all(self.info["image"] == other.info["image"]):
self_img = (self.info or {}).get("image")
other_img = (other.info or {}).get("image")
none_img = self_img is other_img is None
if self.nodes == other.nodes and (none_img or all(self_img == other_img)):
# the image key here is provided by Pymatgen, which refers to the periodic
# cell the neighbor node comes from

Expand All @@ -94,18 +109,12 @@ def __eq__(self, other) -> bool:
)
return True

return (
# In this case the first edge is from node i to j and the second edge is
# from node j to i
self.nodes == other.nodes[::-1]
and (self.info["image"] == -1 * other.info["image"]).all()
# In this case the first edge is from node i to j and the second edge is
# from node j to i
return self.nodes == other.nodes[::-1] and (
none_img or all(self_img == -1 * other_img)
)

def __repr__(self):
"""String representation of this edge."""
nodes, index, info = self.nodes, self.index, self.info
return f"{type(self).__name__}({nodes=}, {index=}, {info=})"


class Graph:
"""A graph for storing the neighbor information of atoms."""
Expand Down
6 changes: 6 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,14 @@ def test_directed_edge() -> None:
assert edge.index == 0
assert repr(edge) == f"DirectedEdge(nodes=[0, 1], index=0, {info=})"

# test hashable
_ = {edge}


def test_undirected_edge() -> None:
info = {"image": np.array([0, 0, 0]), "distance": 1.0}
edge = UndirectedEdge([0, 1], index=0, info=info)
assert repr(edge) == f"UndirectedEdge(nodes=[0, 1], index=0, {info=})"

# test hashable
_ = {edge}