Skip to content

Commit

Permalink
Merge pull request #20 from PatrickOHara/gt_graph
Browse files Browse the repository at this point in the history
➕ Install graph tool
  • Loading branch information
PatrickOHara authored Jan 11, 2021
2 parents fdefce2 + 4716bb6 commit 93b42bb
Show file tree
Hide file tree
Showing 5 changed files with 117 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
[mypy]

[mypy-graph_tool.*]
ignore_missing_imports=True

[mypy-networkx.*]
ignore_missing_imports = True

Expand Down
14 changes: 14 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,20 @@ before_install:
- git clone https://github.com/bcamath-ds/OPLib.git ../OPLib
- git clone https://github.com/rhgrant10/tsplib95.git ../tsplib95
install:
# install conda
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda
- source "$HOME/miniconda/etc/profile.d/conda.sh"
- hash -r
- conda config --set always_yes yes --set changeps1 no
- conda update -q conda
- conda info -a
# create conda env and activate
- conda create -q -n test-environment python=$TRAVIS_PYTHON_VERSION
- conda activate test-environment
# install graph tool
- conda install -c conda-forge graph-tool
# install our package
- pip install .
- pip install -r requirements.txt
script:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
description="Library of instances for TSP with Profits",
install_requires=[
"pandas>=1.0.0",
"tsplib95",
"tsplib95>=0.7.1",
],
name="tspwplib",
packages=["tspwplib"],
Expand Down
33 changes: 33 additions & 0 deletions tests/test_profits_problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,36 @@ def test_get_graph(oplib_root, generation, graph_name, alpha):
assert type(value) in valid_types

assert graph.graph["root"] == 1


def test_get_graph_tool(oplib_root, generation, graph_name, alpha):
"""Test returning graph tool undirected weighted graph"""
filepath = build_path_to_oplib_instance(
oplib_root, generation, graph_name, alpha=alpha
)
problem = ProfitsProblem.load(filepath)
gt_graph = problem.get_graph_tool()
nx_graph = problem.get_graph(normalize=True)
assert nx_graph.has_node(0)
assert 0 in gt_graph.get_vertices()
assert gt_graph.num_vertices() == nx_graph.number_of_nodes()
assert gt_graph.num_edges() == nx_graph.number_of_edges()

# check weight
for u, v, data in nx_graph.edges(data=True):
gt_edge = gt_graph.edge(u, v, add_missing=False)
assert gt_edge
assert gt_graph.ep.weight[gt_edge] == data["weight"]
# check prize on vertices
for u, data in nx_graph.nodes(data=True):
assert data["prize"] == gt_graph.vertex_properties.prize[u]


def test_get_root_vertex(oplib_root, generation, graph_name, alpha):
"""Test the root vertex is 1 when un-normalized (0 when normalized)"""
filepath = build_path_to_oplib_instance(
oplib_root, generation, graph_name, alpha=alpha
)
problem = ProfitsProblem.load(filepath)
assert problem.get_root_vertex(normalize=False) == 1
assert problem.get_root_vertex(normalize=True) == 0
72 changes: 66 additions & 6 deletions tspwplib/problem.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
"""Functions and classes for datasets"""

from typing import List
import graph_tool as gt
import networkx as nx
import tsplib95
from .types import Vertex, VertexFunctionName, VertexLookup
from .types import Vertex, VertexLookup


class ProfitsProblem(tsplib95.models.StandardProblem):
Expand Down Expand Up @@ -32,11 +33,11 @@ def __set_graph_attributes(self, graph: nx.Graph) -> None:
graph.graph["type"] = self.type
graph.graph["dimension"] = self.dimension
graph.graph["capacity"] = self.capacity
# pylint: disable=unsubscriptable-object
graph.graph["root"] = self.depots[0]
graph.graph["root"] = self.get_root_vertex()

def __set_node_attributes(self, graph: nx.Graph, names: VertexLookup) -> None:
"""Add node attributes"""
node_score = self.get_node_score()
for vertex in list(self.get_nodes()):
# NOTE pyintergraph cannot handle bool, so we remove some attributes:
# is_depot, demand, display
Expand All @@ -47,6 +48,7 @@ def __set_node_attributes(self, graph: nx.Graph, names: VertexLookup) -> None:
names[vertex],
x=coord[0],
y=coord[1],
prize=node_score[vertex],
# is_depot=is_depot,
)
# demand: int = self.demands.get(vertex)
Expand All @@ -55,9 +57,9 @@ def __set_node_attributes(self, graph: nx.Graph, names: VertexLookup) -> None:
# graph[vertex]["demand"] = demand
# if not display is None:
# graph[vertex]["display"] = display
nx.set_node_attributes(
graph, self.get_node_score(), name=VertexFunctionName.prize
)
# nx.set_node_attributes(
# graph, self.get_node_score(), name=VertexFunctionName.prize
# )

def get_graph(self, normalize: bool = False) -> nx.Graph:
"""Return a networkx graph instance representing the problem.
Expand All @@ -82,6 +84,41 @@ def get_graph(self, normalize: bool = False) -> nx.Graph:
self.__set_edge_attributes(graph, names)
return graph

def get_graph_tool(self, normalize: bool = True) -> gt.Graph:
"""Return a graph tools undirected graph
Args:
normalize: rename nodes to be zero-indexed
"""
graph = gt.Graph(directed=not self.is_symmetric())

# by default normalize because graph tools index starts at zero
nodes: List[Vertex] = list(self.get_nodes())
if normalize:
names = {n: i for i, n in enumerate(nodes)}
else:
names = {n: n for n in nodes}

# create list of edges
edges = []
for u, v in self.get_edges():
if u <= v or not self.is_symmetric():
edges.append((names[u], names[v], self.get_weight(u, v)))

# assign weight to edges
weight_property = graph.new_edge_property("int")
graph.add_edge_list(edges, eprops=[weight_property])
graph.ep.weight = weight_property

# assign prize to vertices
prize_property = graph.new_vertex_property("int")
node_score = self.get_node_score()
prize_list = [node_score[v + 1] for v in graph.get_vertices()]
prize_property.a = prize_list
graph.vertex_properties.prize = prize_property

return graph

def get_cost_limit(self) -> int:
"""Get the cost limit for a TSP with Profits problem
Expand All @@ -105,3 +142,26 @@ def get_tsp_optimal_value(self) -> int:
TSP optimal value
"""
return self.tspsol

def get_root_vertex(self, normalize: bool = False) -> Vertex:
"""Get the root vertex
Args:
normalize: If true, vertices start at index 0
Returns:
The first depot in the list
Raises:
ValueError: If the list of depots is empty
"""
nodes: List[Vertex] = list(self.get_nodes())
if normalize:
names = {n: i for i, n in enumerate(nodes)}
else:
names = {n: n for n in nodes}
try:
# pylint: disable=unsubscriptable-object
return names[self.depots[0]]
except KeyError as key_error:
raise ValueError("The list of depots is empty") from key_error

0 comments on commit 93b42bb

Please sign in to comment.