Skip to content

Commit

Permalink
Fix for issue #620 regarding AGE-MOEA-II (#635)
Browse files Browse the repository at this point in the history
* Fix overflow and division by zero errors: Add checks to prevent division by zero and apply regularization to avoid extremely large results in exponentiation when the base is too small.

* Add example for AGE-MOEA-II with constrained problems

* Fix NumbaDeprecationWarning in AGE-MOEA and AGE-MOEA-II

* Add additional checks for Inf and NaN values

* Fix more overflow errors

* Adding more overflow checks

* Fixing division-by-zero errors in survival_score(..)
  • Loading branch information
apanichella authored Aug 25, 2024
1 parent e00e3c1 commit ad98c14
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 19 deletions.
64 changes: 64 additions & 0 deletions examples/algorithms/moo/age2_constrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from pymoo.indicators.igd import IGD
from pymoo.util.ref_dirs import get_reference_directions
from pymoo.algorithms.moo.age2 import AGEMOEA2
from pymoo.optimize import minimize

from pymoo.problems.many import C1DTLZ1, DC1DTLZ1, DC1DTLZ3, DC2DTLZ1, DC2DTLZ3, DC3DTLZ1, DC3DTLZ3, C1DTLZ3, \
C2DTLZ2, C3DTLZ1, C3DTLZ4
import ray
import numpy as np

benchmark_algorithms = [
AGEMOEA2(),
]

benchmark_problems = [
C1DTLZ1, DC1DTLZ1, DC1DTLZ3, DC2DTLZ1, DC2DTLZ3, DC3DTLZ1, DC3DTLZ3, C1DTLZ3, C2DTLZ2, C3DTLZ1, C3DTLZ4
]


def run_benchmark(problem_class, algorithm):
# Instantiate the problem
problem = problem_class()

res = minimize(
problem,
algorithm,
pop_size=100,
verbose=True,
seed=1,
termination=('n_gen', 2000)
)

# Step 4: Generate the reference points
ref_dirs = get_reference_directions("uniform", problem.n_obj, n_points=528)

# Obtain the true Pareto front (for synthetic problems)
pareto_front = problem.pareto_front(ref_dirs)

# Calculate IGD
if res.F is None:
igd = np.Infinity
else:
igd = IGD(pareto_front)(res.F)

result = {
"problem": problem,
"algorithm": algorithm,
"result": res,
"igd": igd
}

return result


tasks = []
for problem in benchmark_problems:
for algorithm in benchmark_algorithms:
tasks.append(ray.remote(run_benchmark).remote(problem, algorithm))
result = ray.get(tasks)

for res in result:
print(f"Algorithm = {res['algorithm'].__class__.__name__}, "
f"Problem = {res['problem'].__class__.__name__}, "
f"IGD = {res['igd']}")
14 changes: 10 additions & 4 deletions pymoo/algorithms/moo/age.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,13 @@ def survival_score(self, front, ideal_point):
p = self.compute_geometry(front, extreme, n)

nn = np.linalg.norm(front, p, axis=1)
distances = self.pairwise_distances(front, p) / (nn[:, None])
# Replace very small norms with 1 to prevent division by zero
nn[nn < 1e-8] = 1

distances = self.pairwise_distances(front, p)
distances[distances < 1e-8] = 1e-8 # Replace very small entries to prevent division by zero

distances = distances / (nn[:, None])

neighbors = 2
remaining = np.arange(m)
Expand Down Expand Up @@ -209,7 +215,7 @@ def compute_geometry(front, extreme, n):
return p

@staticmethod
@jit(fastmath=True)
#@jit(nopython=True, fastmath=True)
def pairwise_distances(front, p):
m = np.shape(front)[0]
distances = np.zeros((m, m))
Expand All @@ -219,7 +225,7 @@ def pairwise_distances(front, p):
return distances

@staticmethod
@jit(fastmath=True)
@jit(nopython=True, fastmath=True)
def minkowski_distances(A, B, p):
m1 = np.shape(A)[0]
m2 = np.shape(B)[0]
Expand Down Expand Up @@ -254,7 +260,7 @@ def find_corner_solutions(front):
return indexes


@jit(fastmath=True)
@jit(nopython=True, fastmath=True)
def point_2_line_distance(P, A, B):
d = np.zeros(P.shape[0])

Expand Down
60 changes: 45 additions & 15 deletions pymoo/algorithms/moo/age2.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,48 +64,78 @@ def __init__(self,
self.tournament_type = 'comp_by_rank_and_crowding'


@jit(fastmath=True)
@jit(nopython=True, fastmath=True)
def project_on_manifold(point, p):
dist = sum(point[point > 0] ** p) ** (1/p)
return np.multiply(point, 1 / dist)


import numpy as np


def find_zero(point, n, precision):
x = 1

epsilon = 1e-10 # Small constant for regularization
past_value = x
max_float = np.finfo(np.float64).max # Maximum representable float value
log_max_float = np.log(max_float) # Logarithm of the maximum float

for i in range(0, 100):

# Original function
# Original function with regularization
f = 0.0
for obj_index in range(0, n):
if point[obj_index] > 0:
f += np.power(point[obj_index], x)
log_value = x * np.log(point[obj_index] + epsilon)
if log_value < log_max_float:
f += np.exp(log_value)
else:
return 1 # Handle overflow by returning a fallback value

f = np.log(f)
f = np.log(f) if f > 0 else 0 # Avoid log of non-positive numbers

# Derivative
# Derivative with regularization
numerator = 0
denominator = 0
for obj_index in range(0, n):
if point[obj_index] > 0:
numerator = numerator + np.power(point[obj_index], x) * np.log(point[obj_index])
denominator = denominator + np.power(point[obj_index], x)

if denominator == 0:
return 1
log_value = x * np.log(point[obj_index] + epsilon)
if log_value < log_max_float:
power_value = np.exp(log_value)
log_term = np.log(point[obj_index] + epsilon)

# Use logarithmic comparison to avoid overflow
if log_value + np.log(abs(log_term) + epsilon) < log_max_float:
result = power_value * log_term
numerator += result
denominator += power_value
else:
# Handle extreme cases by capping the result
numerator += max_float
denominator += power_value
else:
return 1 # Handle overflow by returning a fallback value

if denominator == 0 or np.isnan(denominator) or np.isinf(denominator):
return 1 # Handle division by zero or NaN

if np.isnan(numerator) or np.isinf(numerator):
return 1 # Handle invalid numerator

ff = numerator / denominator

# zero of function
if ff == 0: # Check for zero before division
return 1 # Handle by returning a fallback value

# Update x using Newton's method
x = x - f / ff

if abs(x - past_value) <= precision:
break
else:
paste_value = x # update current point
past_value = x # Update current point

if isinstance(x, complex):
if isinstance(x, complex) or np.isinf(x) or np.isnan(x):
return 1
else:
return x
Expand Down Expand Up @@ -135,7 +165,7 @@ def compute_geometry(front, extreme, n):
return p

@staticmethod
@jit(fastmath=True)
@jit(nopython=True, fastmath=True)
def pairwise_distances(front, p):
m, n = front.shape
projected_front = front.copy()
Expand Down

0 comments on commit ad98c14

Please sign in to comment.