Skip to content

Commit

Permalink
make DirectedEdge and UndirectedEdge inherit from new Edge abstract b…
Browse files Browse the repository at this point in the history
…ase class, make them hashable, handle missing info['image'] gracefully in DirectedEdge.__eq__
  • Loading branch information
janosh committed Oct 23, 2023
1 parent 19de8ff commit d0aa03b
Showing 1 changed file with 33 additions and 24 deletions.
57 changes: 33 additions & 24 deletions chgnet/graph/graph.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import sys
from abc import ABC, abstractmethod

from chgnet.utils import write_json

Expand Down Expand Up @@ -32,13 +33,13 @@ def add_neighbor(self, index, edge):
self.neighbors[index].append(edge)


class UndirectedEdge:
"""An undirected/bi-directed edge in a graph."""
class Edge(ABC):
"""Abstract base class for edges in a graph."""

def __init__(
self, nodes: list, index: int | None = None, info: dict | None = None
) -> None:
"""Initialize an UndirectedEdge."""
"""Initialize an Edge."""
self.nodes = nodes
self.index = index
self.info = info
Expand All @@ -48,29 +49,40 @@ def __repr__(self):
nodes, index, info = self.nodes, self.index, self.info
return f"{type(self).__name__}({nodes=}, {index=}, {info=})"

def __hash__(self) -> int:
"""Hash this edge."""
img = (self.info or {}).get("image")
img_str = "" if img is None else img.tostring()
return hash((self.nodes[0], self.nodes[1], img_str))

@abstractmethod
def __eq__(self, other: object) -> bool:
"""Check if two edges are equal."""
raise NotImplementedError


class UndirectedEdge(Edge):
"""An undirected/bi-directed edge in a graph."""

__hash__ = Edge.__hash__

def __eq__(self, other):
"""Check if two undirected edges are equal."""
return set(self.nodes) == set(other.nodes) and self.info == other.info


class DirectedEdge:
class DirectedEdge(Edge):
"""A directed edge in a graph."""

def __init__(
self, nodes: list, index: int | None = None, info: dict | None = None
) -> None:
"""Initialize a DirectedEdge."""
self.nodes = nodes
self.index = index
self.info = info
__hash__ = Edge.__hash__

def make_undirected(self, index, info=None):
def make_undirected(self, index: int, info: dict | None = None) -> UndirectedEdge:
"""Make a directed edge undirected."""
info = info or {}
info["distance"] = self.info["distance"]
return UndirectedEdge(self.nodes, index, info)

def __eq__(self, other) -> bool:
def __eq__(self, other: object) -> bool:
"""Check if the two directed edges are equal.
Args:
Expand All @@ -80,7 +92,10 @@ def __eq__(self, other) -> bool:
bool: True if other is the same directed edge, or if other is the directed edge
with reverse direction of self, else False.
"""
if self.nodes == other.nodes and all(self.info["image"] == other.info["image"]):
self_img = (self.info or {}).get("image")
other_img = (other.info or {}).get("image")
none_img = self_img is other_img is None
if self.nodes == other.nodes and (none_img or all(self_img == other_img)):
# the image key here is provided by Pymatgen, which refers to the periodic
# cell the neighbor node comes from

Expand All @@ -94,18 +109,12 @@ def __eq__(self, other) -> bool:
)
return True

return (
# In this case the first edge is from node i to j and the second edge is
# from node j to i
self.nodes == other.nodes[::-1]
and (self.info["image"] == -1 * other.info["image"]).all()
# In this case the first edge is from node i to j and the second edge is
# from node j to i
return self.nodes == other.nodes[::-1] and (
none_img or all(self_img == -1 * other_img)
)

def __repr__(self):
"""String representation of this edge."""
nodes, index, info = self.nodes, self.index, self.info
return f"{type(self).__name__}({nodes=}, {index=}, {info=})"


class Graph:
"""A graph for storing the neighbor information of atoms."""
Expand Down

0 comments on commit d0aa03b

Please sign in to comment.