diff --git a/chgnet/graph/graph.py b/chgnet/graph/graph.py index 5b725d79..90c93508 100644 --- a/chgnet/graph/graph.py +++ b/chgnet/graph/graph.py @@ -1,6 +1,7 @@ from __future__ import annotations import sys +from abc import ABC, abstractmethod from chgnet.utils import write_json @@ -32,13 +33,13 @@ def add_neighbor(self, index, edge): self.neighbors[index].append(edge) -class UndirectedEdge: - """An undirected/bi-directed edge in a graph.""" +class Edge(ABC): + """Abstract base class for edges in a graph.""" def __init__( self, nodes: list, index: int | None = None, info: dict | None = None ) -> None: - """Initialize an UndirectedEdge.""" + """Initialize an Edge.""" self.nodes = nodes self.index = index self.info = info @@ -48,29 +49,40 @@ def __repr__(self): nodes, index, info = self.nodes, self.index, self.info return f"{type(self).__name__}({nodes=}, {index=}, {info=})" + def __hash__(self) -> int: + """Hash this edge.""" + img = (self.info or {}).get("image") + img_str = "" if img is None else img.tostring() + return hash((self.nodes[0], self.nodes[1], img_str)) + + @abstractmethod + def __eq__(self, other: object) -> bool: + """Check if two edges are equal.""" + raise NotImplementedError + + +class UndirectedEdge(Edge): + """An undirected/bi-directed edge in a graph.""" + + __hash__ = Edge.__hash__ + def __eq__(self, other): """Check if two undirected edges are equal.""" return set(self.nodes) == set(other.nodes) and self.info == other.info -class DirectedEdge: +class DirectedEdge(Edge): """A directed edge in a graph.""" - def __init__( - self, nodes: list, index: int | None = None, info: dict | None = None - ) -> None: - """Initialize a DirectedEdge.""" - self.nodes = nodes - self.index = index - self.info = info + __hash__ = Edge.__hash__ - def make_undirected(self, index, info=None): + def make_undirected(self, index: int, info: dict | None = None) -> UndirectedEdge: """Make a directed edge undirected.""" info = info or {} info["distance"] = self.info["distance"] return UndirectedEdge(self.nodes, index, info) - def __eq__(self, other) -> bool: + def __eq__(self, other: object) -> bool: """Check if the two directed edges are equal. Args: @@ -80,7 +92,10 @@ def __eq__(self, other) -> bool: bool: True if other is the same directed edge, or if other is the directed edge with reverse direction of self, else False. """ - if self.nodes == other.nodes and all(self.info["image"] == other.info["image"]): + self_img = (self.info or {}).get("image") + other_img = (other.info or {}).get("image") + none_img = self_img is other_img is None + if self.nodes == other.nodes and (none_img or all(self_img == other_img)): # the image key here is provided by Pymatgen, which refers to the periodic # cell the neighbor node comes from @@ -94,18 +109,12 @@ def __eq__(self, other) -> bool: ) return True - return ( - # In this case the first edge is from node i to j and the second edge is - # from node j to i - self.nodes == other.nodes[::-1] - and (self.info["image"] == -1 * other.info["image"]).all() + # In this case the first edge is from node i to j and the second edge is + # from node j to i + return self.nodes == other.nodes[::-1] and ( + none_img or all(self_img == -1 * other_img) ) - def __repr__(self): - """String representation of this edge.""" - nodes, index, info = self.nodes, self.index, self.info - return f"{type(self).__name__}({nodes=}, {index=}, {info=})" - class Graph: """A graph for storing the neighbor information of atoms.""" diff --git a/tests/test_graph.py b/tests/test_graph.py index c8b95be6..115f085d 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -106,8 +106,14 @@ def test_directed_edge() -> None: assert edge.index == 0 assert repr(edge) == f"DirectedEdge(nodes=[0, 1], index=0, {info=})" + # test hashable + _ = {edge} + def test_undirected_edge() -> None: info = {"image": np.array([0, 0, 0]), "distance": 1.0} edge = UndirectedEdge([0, 1], index=0, info=info) assert repr(edge) == f"UndirectedEdge(nodes=[0, 1], index=0, {info=})" + + # test hashable + _ = {edge}