Skip to content

Commit

Permalink
Hashable edges (#87)
Browse files Browse the repository at this point in the history
* make DirectedEdge and UndirectedEdge inherit from new Edge abstract base class, make them hashable, handle missing info['image'] gracefully in DirectedEdge.__eq__

* check (Un)directedEdge is hashable in test_(un)directed_edge
  • Loading branch information
janosh authored Oct 23, 2023
1 parent 19de8ff commit 8308faa
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 24 deletions.
57 changes: 33 additions & 24 deletions chgnet/graph/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
from abc import ABC, abstractmethod

from chgnet.utils import write_json

Expand Down Expand Up @@ -32,13 +33,13 @@ def add_neighbor(self, index, edge):
self.neighbors[index].append(edge)


class UndirectedEdge:
"""An undirected/bi-directed edge in a graph."""
class Edge(ABC):
"""Abstract base class for edges in a graph."""

def __init__(
self, nodes: list, index: int | None = None, info: dict | None = None
) -> None:
"""Initialize an UndirectedEdge."""
"""Initialize an Edge."""
self.nodes = nodes
self.index = index
self.info = info
Expand All @@ -48,29 +49,40 @@ def __repr__(self):
nodes, index, info = self.nodes, self.index, self.info
return f"{type(self).__name__}({nodes=}, {index=}, {info=})"

def __hash__(self) -> int:
"""Hash this edge."""
img = (self.info or {}).get("image")
img_str = "" if img is None else img.tostring()
return hash((self.nodes[0], self.nodes[1], img_str))

@abstractmethod
def __eq__(self, other: object) -> bool:
"""Check if two edges are equal."""
raise NotImplementedError


class UndirectedEdge(Edge):
"""An undirected/bi-directed edge in a graph."""

__hash__ = Edge.__hash__

def __eq__(self, other):
"""Check if two undirected edges are equal."""
return set(self.nodes) == set(other.nodes) and self.info == other.info


class DirectedEdge:
class DirectedEdge(Edge):
"""A directed edge in a graph."""

def __init__(
self, nodes: list, index: int | None = None, info: dict | None = None
) -> None:
"""Initialize a DirectedEdge."""
self.nodes = nodes
self.index = index
self.info = info
__hash__ = Edge.__hash__

def make_undirected(self, index, info=None):
def make_undirected(self, index: int, info: dict | None = None) -> UndirectedEdge:
"""Make a directed edge undirected."""
info = info or {}
info["distance"] = self.info["distance"]
return UndirectedEdge(self.nodes, index, info)

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Check if the two directed edges are equal.
Args:
Expand All @@ -80,7 +92,10 @@ def __eq__(self, other) -> bool:
bool: True if other is the same directed edge, or if other is the directed edge
with reverse direction of self, else False.
"""
if self.nodes == other.nodes and all(self.info["image"] == other.info["image"]):
self_img = (self.info or {}).get("image")
other_img = (other.info or {}).get("image")
none_img = self_img is other_img is None
if self.nodes == other.nodes and (none_img or all(self_img == other_img)):
# the image key here is provided by Pymatgen, which refers to the periodic
# cell the neighbor node comes from

Expand All @@ -94,18 +109,12 @@ def __eq__(self, other) -> bool:
)
return True

return (
# In this case the first edge is from node i to j and the second edge is
# from node j to i
self.nodes == other.nodes[::-1]
and (self.info["image"] == -1 * other.info["image"]).all()
# In this case the first edge is from node i to j and the second edge is
# from node j to i
return self.nodes == other.nodes[::-1] and (
none_img or all(self_img == -1 * other_img)
)

def __repr__(self):
"""String representation of this edge."""
nodes, index, info = self.nodes, self.index, self.info
return f"{type(self).__name__}({nodes=}, {index=}, {info=})"


class Graph:
"""A graph for storing the neighbor information of atoms."""
Expand Down
6 changes: 6 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,14 @@ def test_directed_edge() -> None:
assert edge.index == 0
assert repr(edge) == f"DirectedEdge(nodes=[0, 1], index=0, {info=})"

# test hashable
_ = {edge}


def test_undirected_edge() -> None:
info = {"image": np.array([0, 0, 0]), "distance": 1.0}
edge = UndirectedEdge([0, 1], index=0, info=info)
assert repr(edge) == f"UndirectedEdge(nodes=[0, 1], index=0, {info=})"

# test hashable
_ = {edge}

0 comments on commit 8308faa

Please sign in to comment.