Skip to content

Commit

Permalink
Fixing prize function
Browse files Browse the repository at this point in the history
  • Loading branch information
PatrickOHara committed Mar 29, 2024
1 parent a458531 commit cb95c1e
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 24 deletions.
20 changes: 15 additions & 5 deletions tests/test_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@

import numpy as np
import pytest
from tspwplib import build_path_to_oplib_instance, build_path_to_tsplib_instance, ProfitsProblem
from tspwplib import (
build_path_to_oplib_instance,
build_path_to_tsplib_instance,
ProfitsProblem,
)
from tspwplib.problem import PrizeCollectingTSP, BaseTSP
from tspwplib.weights import generation_three_prize, generation_three_prize_list
from tspwplib.weights import euc_2d, generation_three_prize, generation_three_prize_list
from tspwplib.types import NodeCoordType, Generation


Expand All @@ -17,7 +21,7 @@ def test_generation_three_prize_map():
[1, 1],
]
)
prizes = generation_three_prize_list(vertex_coords)
prizes = generation_three_prize_list(euc_2d, vertex_coords)
assert prizes[0] == 1
assert prizes[1] == 100
assert prizes[2] == 29
Expand All @@ -34,9 +38,11 @@ def test_generation_three_prize_map():
def test_generation_three_prize(vertex_coord, root_coord, max_distance, expected_prize):
"""Test generation three prizes are generated correctly"""
assert (
generation_three_prize(vertex_coord, root_coord, max_distance) == expected_prize
generation_three_prize(euc_2d, vertex_coord, root_coord, max_distance)
== expected_prize
)


def test_gen3_prize_is_same(oplib_root, tsplib_root, graph_name):
"""Test an OP instance can be parsed"""
op_filepath = build_path_to_oplib_instance(oplib_root, Generation.gen3, graph_name)
Expand All @@ -48,4 +54,8 @@ def test_gen3_prize_is_same(oplib_root, tsplib_root, graph_name):
assert op_problem.node_coord_type == tsp_problem.node_coord_type
if op_problem.node_coord_type == NodeCoordType.TWOD_COORDS:
coords = np.array(list(tsp_problem.node_coords.values()))
assert list(profits_problem.get_node_score().values()) == generation_three_prize_list(coords)
# TODO note that root node prize is 0 neq 1
assert (
list(profits_problem.get_node_score().values())[1:]
== generation_three_prize_list(euc_2d, coords)[1:]
)
19 changes: 15 additions & 4 deletions tspwplib/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,23 @@ def from_tsplib95(cls, problem: tsplib95.models.StandardProblem):
node_coord_type == NodeCoordType.NO_COORDS
node_coords = None
raise RuntimeWarning(f"Problem {problem.name} has no node co-ordinates.")
elif len(node_coords.get(1)) == 3 or node_coord_type == NodeCoordType.THREED_COORDS:
elif (
len(node_coords.get(1)) == 3
or node_coord_type == NodeCoordType.THREED_COORDS
):
raise NotImplementedError("3D coords not yet supported")
elif (len(node_coords.get(1)) == 2 and node_coord_type == NodeCoordType.THREED_COORDS) or (len(node_coords.get(1)) == 3 and node_coord_type == NodeCoordType.TWOD_COORDS):
raise ValueError(f"Problem {problem.name} has NODE_COORD_TYPE {node_coord_type.value}, but the length of the co-ordinates for node 1 is {len(node_coords.get(1))}.")
elif (
len(node_coords.get(1)) == 2
and node_coord_type == NodeCoordType.THREED_COORDS
) or (
len(node_coords.get(1)) == 3
and node_coord_type == NodeCoordType.TWOD_COORDS
):
raise ValueError(
f"Problem {problem.name} has NODE_COORD_TYPE {node_coord_type.value}, but the length of the co-ordinates for node 1 is {len(node_coords.get(1))}."
)
elif len(node_coords.get(1)) == 2:
node_coord_type = NodeCoordType.TWOD_COORDS
node_coord_type = NodeCoordType.TWOD_COORDS

return cls(
capacity=problem.capacity,
Expand Down
30 changes: 15 additions & 15 deletions tspwplib/weights.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Functions for calculating edge and vertex weights"""

from typing import List
from typing import List, Callable
import numpy as np
import numpy.typing as npt
from .types import Vertex
Expand All @@ -17,34 +17,34 @@ def generation_two_prize(vertex: Vertex) -> int:


def generation_three_prize_list(
dist_fn: Callable[[npt.NDArray, npt.NDArray], int],
vertex_coords: npt.ArrayLike,
root_vertex_index: int = 0,
) -> List[int]:
"""Generate a prize map from the vertex coordinates using the generation three prize function"""
# get the distance from the root vertex to every other vertex
root_coord = vertex_coords[root_vertex_index]

def __euclidean_distance_oplib(coord2: npt.NDArray) -> int:
xd = root_coord[0] - coord2[0]
yd = root_coord[1] - coord2[1]
return np.sqrt( xd*xd + yd*yd) + 0.5

# distance_v = np.vectorize(__euclidean_distance_oplib)
# max_distance = np.max(distance_v(vertex_coords))
# FIXME there is a bug here somewhere, see Section 2.1 of https://github.com/bcamath-ds/OPLib/tree/master/instances
max_distance = int(np.max(np.linalg.norm(vertex_coords - root_coord, ord=2, axis=1) + 0.5))
distances = np.apply_along_axis(dist_fn, 1, vertex_coords, root_coord)
max_distance = np.max(distances)
if np.isclose(max_distance, 0):
raise ValueError("Maximum distance from any vertex to the root vertex is zero.")
return [
generation_three_prize(coord, root_coord, max_distance)
generation_three_prize(dist_fn, coord, root_coord, max_distance)
for coord in vertex_coords
]


def generation_three_prize(
vertex_coord: npt.NDArray, root_coord: npt.NDArray, max_distance: float
dist_fn: Callable[[npt.NDArray, npt.NDArray], int],
vertex_coord: npt.NDArray,
root_coord: npt.NDArray,
max_distance: float,
) -> int:
"""Vertices have larger prizes when they are further away from the root vertex"""
return int(
1 + np.floor((99 / max_distance) * np.linalg.norm(root_coord - vertex_coord, ord=2))
)
return int(1 + np.floor((99 / max_distance) * dist_fn(vertex_coord, root_coord)))


def euc_2d(i: npt.NDArray, j: npt.NDArray) -> int:
"""The rounded Euclidean distance / L2 norm between two co-ordinates"""
return int(np.round(np.linalg.norm(i - j, ord=2)))

0 comments on commit cb95c1e

Please sign in to comment.