Skip to content

Commit

Permalink
Adds restarts to exchange heuristic (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
matt035343 authored Jun 13, 2024
1 parent 74fdf6a commit 8c770a9
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 25 deletions.
61 changes: 38 additions & 23 deletions anti_clustering/exchange_heuristic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ class ExchangeHeuristicAntiClustering(ClusterSwapHeuristic):
The exchange heuristic to solving the anti-clustering problem.
"""

def __init__(self, verbose: bool = False, random_seed: int = None):
def __init__(self, verbose: bool = False, random_seed: int = None, restarts: int = 9):
super().__init__(verbose=verbose, random_seed=random_seed)
self.restarts = restarts

def _solve(self, distance_matrix: npt.NDArray[float], num_groups: int) -> npt.NDArray[bool]:
# Starts with random cluster assignment
Expand All @@ -38,31 +39,45 @@ def _solve(self, distance_matrix: npt.NDArray[float], num_groups: int) -> npt.ND
if self.verbose:
print("Solving")

# Initial objective value
current_objective = self._calculate_objective(cluster_assignment, distance_matrix)
for i in range(len(distance_matrix)):
if self.verbose and i % 5 == 0:
print(f"Iteration {i + 1} of {len(distance_matrix)}")
candidate_solutions = []

# Get list of possible swaps
exchange_indices = self._get_exchanges(cluster_assignment, i)
for restart in range(self.restarts):
# Initial objective value
current_objective = self._calculate_objective(cluster_assignment, distance_matrix)
for i in range(len(distance_matrix)):
if self.verbose and i % 5 == 0:
print(f"Iteration {i + 1} of {len(distance_matrix)}")

if len(exchange_indices) == 0:
continue
# Get list of possible swaps
exchange_indices = self._get_exchanges(cluster_assignment, i)

# Calculate objective value for all possible swaps.
# List contains tuples of obj. val. and swapped element index.
exchanges = [
(self._calculate_objective(self._swap(cluster_assignment, i, j), distance_matrix), j)
for j in exchange_indices
]
if len(exchange_indices) == 0:
continue

# Find best swap
best_exchange = max(exchanges)
# Calculate objective value for all possible swaps.
# List contains tuples of obj. val. and swapped element index.
exchanges = [
(self._calculate_objective(self._swap(cluster_assignment, i, j), distance_matrix), j)
for j in exchange_indices
]

# If best swap is better than current objective value then complete swap
if best_exchange[0] > current_objective:
cluster_assignment = self._swap(cluster_assignment, i, best_exchange[1])
current_objective = best_exchange[0]
# Find best swap
best_exchange = max(exchanges)

return cluster_assignment
# If best swap is better than current objective value then complete swap
if best_exchange[0] > current_objective:
cluster_assignment = self._swap(cluster_assignment, i, best_exchange[1])
current_objective = best_exchange[0]

candidate_solutions.append((current_objective, cluster_assignment))

if self.verbose:
print(f"Restart {restart + 1} of {self.restarts}")

# Cold restart, select random cluster assignment
cluster_assignment = self._get_random_clusters(num_groups=num_groups, num_elements=len(distance_matrix))

# Select best solution, maximizing objective
_, best_cluster_assignment = max(candidate_solutions, key=lambda x: x[0])

return best_cluster_assignment
4 changes: 2 additions & 2 deletions examples/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
iris_df = pd.DataFrame(data=iris_data.data, columns=iris_data.feature_names)

methods: List[AntiClustering] = [
ExchangeHeuristicAntiClustering(),
SimulatedAnnealingHeuristicAntiClustering(alpha=0.95, iterations=5000, starting_temperature=1000, restarts=15),
ExchangeHeuristicAntiClustering(restarts=20),
SimulatedAnnealingHeuristicAntiClustering(alpha=0.95, iterations=5000, starting_temperature=1000, restarts=20),
NaiveRandomHeuristicAntiClustering(),
ExactClusterEditingAntiClustering(),
]
Expand Down

0 comments on commit 8c770a9

Please sign in to comment.