fixed more backend issues
maniospas committed Jun 10, 2024
1 parent cb0aca6 commit e46fa74
Showing 4 changed files with 25 additions and 17 deletions.
examples/playground/run_backend.py (34 changes: 21 additions & 13 deletions)
@@ -1,22 +1,30 @@
import pygrank as pg
import torch
from timeit import default_timer as time

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_, graph, community = next(pg.load_datasets_one_community(["youtube"], graph_api=pg, min_group_size=50))
print(f"Nodes {len(graph)}, edges {graph.number_of_edges()}")

ppr = pg.HeatKernel(
    normalization="symmetric",
    assume_immutability=True
)
signal = pg.to_signal(graph, {node: 1.0 for node in community})
preprocessor = ppr.preprocessor
#ppr = pg.ParameterTuner(preprocessor=preprocessor)
"""
with pg.Backend("numpy"):
    preprocessor(graph)
    torch.cuda.synchronize()  # correct timing
    tic = time()
    scores = ppr(signal)
    print("numpy", ppr.convergence, "actual time", time()-tic)"""

with pg.Backend("torch_sparse", device=device):
    _, graph, community = next(pg.load_datasets_one_community(["amazon"]))
    ppr = pg.PageRank(
        alpha=0.9,
        normalization="symmetric",
        assume_immutability=True,
        convergence=pg.ConvergenceManager(max_iters=38, error_type="iters"),
    )
    ppr.preprocessor(graph)
    signal = pg.to_signal(graph, {node: 1.0 for node in community})
    preprocessor(graph)
    torch.cuda.synchronize()  # correct timing
    tic = time()
    scores = ppr(signal)
    print(ppr.convergence)
    print(scores["B00005MHUG"])  # 0.00508212111890316
    print(scores["B00006RGI2"])  # 0.70645672082901
    print(scores["0006497993"])  # 0.19633759558200836
    print("torch_sparse", ppr.convergence, "actual time", time()-tic)
pygrank/algorithms/postprocess/postprocess.py (4 changes: 2 additions & 2 deletions)
@@ -375,7 +375,7 @@ class Sweep(Postprocessor):
     Applies a sweep procedure that divides personalized node ranks by corresponding non-personalized ones.
     """

-    def __init__(self, ranker: NodeRanking = None, uniform_ranker: NodeRanking = None):
+    def __init__(self, ranker: NodeRanking = None, uniform_ranker: NodeRanking = None, assume_immutability: bool = True):
         """
         Initializes the sweep procedure.
@@ -404,7 +404,7 @@ def __init__(self, ranker: NodeRanking = None, uniform_ranker: NodeRanking = Non
         super().__init__(ranker)
         self.uniform_ranker = ranker if uniform_ranker is None else uniform_ranker
         self.centrality = MethodHasher(
-            lambda graph: self.uniform_ranker.rank(graph), assume_immutability=True
+            lambda graph: self.uniform_ranker.rank(graph), assume_immutability=assume_immutability
         )

     def _transform(self, ranks: GraphSignal, **kwargs):
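A hedged usage sketch of the parameter introduced above: the default assume_immutability=True keeps the old behavior, where the wrapped MethodHasher caches the uniform (non-personalized) ranks per graph, while False forces them to be recomputed on every call, e.g. when the same graph object mutates between runs. Top-level exposure of these classes as pg.Sweep and pg.PageRank is assumed here.

import pygrank as pg

# Default: uniform ranks are hashed per graph and reused across calls.
sweep_cached = pg.Sweep(pg.PageRank())

# Opt out of caching when the graph changes between invocations.
sweep_fresh = pg.Sweep(pg.PageRank(), assume_immutability=False)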
pygrank/core/backend/pytorch.py (2 changes: 1 addition & 1 deletion)
@@ -41,7 +41,7 @@ def backend_init(mode="dense", device=None):
         return
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     warnings.warn(
-        f"[pygrank.backend.pytorch] Automatically detected device to run on {device}: {torch.cuda.get_device_name(device)}"
+        f"[pygrank.backend.pytorch] Automatically detected device to run on {device}: {torch.get_device(device)}"
     )
     if device is not None and isinstance(device, str):
         device = torch.device(device)
pygrank/core/backend/torch_sparse.py (2 changes: 1 addition & 1 deletion)
@@ -61,7 +61,7 @@ def backend_init(device="auto"):
         return
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     warnings.warn(
-        f"[pygrank.backend.torch_sparse] Automatically detected device to run on {device}: {torch.cuda.get_device_name(device)}"
+        f"[pygrank.backend.torch_sparse] Automatically detected device to run on {device}: {torch.device(device)}"
     )
     if device is not None and isinstance(device, str):
         device = torch.device(device)
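Both + lines above remove torch.cuda.get_device_name(device) from the auto-detection warning, a call that raises when the detected device is the CPU, since it only accepts CUDA devices. A sketch of one way to build a readable label for either device type; describe_device is an illustrative helper, not pygrank API.

import torch

def describe_device(device: torch.device) -> str:
    # torch.cuda.get_device_name is only valid for CUDA devices.
    if device.type == "cuda":
        return f"{device}: {torch.cuda.get_device_name(device)}"
    return str(device)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[pygrank.backend] Automatically detected device to run on {describe_device(device)}")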
