Skip to content

Commit

Permalink
A careful optimization approach
Browse files Browse the repository at this point in the history
  • Loading branch information
dustalov committed Jul 8, 2024
1 parent 937cacd commit 3d2fe6e
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 22 deletions.
4 changes: 2 additions & 2 deletions python/evalica/naive.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ def newman(
broadcast_scores_t = scores[:, None].T
sqrt_scores_outer = np.sqrt(np.outer(scores, scores))
sum_scores = np.add.outer(scores, scores)
sqrt_div_scores_t = np.sqrt(np.divide.outer(scores, scores)).T
sqrt_div_scores_outer_t = np.sqrt(np.divide.outer(scores, scores)).T

scores_numerator = np.sum(
win_tie_half * (broadcast_scores_t + v * sqrt_scores_outer) / (sum_scores + 2 * v * sqrt_scores_outer),
axis=1,
)
scores_denominator = np.sum(
win_tie_half.T * (1 + v * sqrt_div_scores_t) / (sum_scores + 2 * v * sqrt_scores_outer),
win_tie_half.T * (1 + v * sqrt_div_scores_outer_t) / (sum_scores + 2 * v * sqrt_scores_outer),
axis=1,
)
scores_new[:] = scores_numerator / scores_denominator
Expand Down
39 changes: 19 additions & 20 deletions src/bradley_terry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub fn bradley_terry(
"The matrix must be square"
);

let totals = matrix.t().to_owned() + matrix;
let totals = &matrix.t().clone() + matrix;

let active = totals.mapv(|x| x > 0.0);

Expand Down Expand Up @@ -92,7 +92,6 @@ pub fn newman(
let win_tie_half = win_matrix + &(tie_matrix / 2.0);

let mut scores = Array1::<f64>::ones(win_matrix.shape()[0]);
let mut scores_new = scores.clone();
let mut v = v_init;
let mut v_new = v;

Expand All @@ -104,22 +103,28 @@ pub fn newman(

v = if v_new.is_nan() { tolerance } else { v_new };

let broadcast_scores = scores.broadcast((scores.len(), scores.len())).unwrap();
let sqrt_scores_outer = (&broadcast_scores * &broadcast_scores.t()).mapv_into(f64::sqrt);
let sum_scores = &broadcast_scores + &broadcast_scores.t();
let sqrt_div_scores_outer =
(&broadcast_scores / &broadcast_scores.t()).mapv_into(f64::sqrt);

let mut scores_new = scores.clone();

for i in 0..win_matrix.shape()[0] {
let mut i_numerator = 0.0;
let mut i_denominator = 0.0;

for j in 0..win_matrix.shape()[1] {
let sqrt_scores_ij = (scores[i] * scores[j]).sqrt();
let ij_numerator = scores[j] + v * sqrt_scores_ij;
let ij_denominator = scores[i] + scores[j] + 2.0 * v * sqrt_scores_ij;
let ij_numerator = scores[j] + v * sqrt_scores_outer[[i, j]];
let ij_denominator = sum_scores[[i, j]] + 2.0 * v * sqrt_scores_outer[[i, j]];

i_numerator += win_tie_half[[i, j]] * ij_numerator / ij_denominator;
}

for j in 0..win_matrix.shape()[1] {
let sqrt_scores_ij = (scores[i] * scores[j]).sqrt();
let ij_num = 1.0 + v * (scores[j] / scores[i]).sqrt();
let ij_den = scores[i] + scores[j] + 2.0 * v * sqrt_scores_ij;
let ij_num = 1.0 + v * sqrt_div_scores_outer[[i, j]];
let ij_den = sum_scores[[i, j]] + 2.0 * v * sqrt_scores_outer[[i, j]];

i_denominator += win_tie_half[[j, i]] * ij_num / ij_den;
}
Expand All @@ -133,25 +138,19 @@ pub fn newman(
}
});

let mut v_numerator = 0.0;
let mut v_denominator = 0.0;
let v_numerator =
(tie_matrix * &sum_scores / (&sum_scores + 2.0 * v * &sqrt_scores_outer)).sum() / 2.0;

for i in 0..win_matrix.shape()[0] {
for j in 0..win_matrix.shape()[1] {
let sqrt_scores_ij = (scores[i] * scores[j]).sqrt();
v_numerator += tie_matrix[[i, j]] / 2.0 * (scores[i] + scores[j])
/ (scores[i] + scores[j] + 2.0 * v * sqrt_scores_ij);
v_denominator += win_matrix[[i, j]] * (2.0 * sqrt_scores_ij)
/ (scores[i] + scores[j] + 2.0 * v * sqrt_scores_ij);
}
}
let v_denominator =
(win_matrix * &sqrt_scores_outer / (&sum_scores + 2.0 * v * &sqrt_scores_outer)).sum()
* 2.0;

v_new = v_numerator / v_denominator;

let difference = &scores_new - &scores;
converged = difference.dot(&difference).sqrt() < tolerance;

scores = scores_new.clone();
scores.assign(&scores_new);
}

(scores, v, iterations)
Expand Down

0 comments on commit 3d2fe6e

Please sign in to comment.