Skip to content

Commit

Permalink
add more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Joaquim Garcia committed Dec 11, 2023
1 parent e68d603 commit b8d5df4
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,16 @@ function _quadratic_constraint_get_reverse!(
p_2 = p_idx(term.variable_2)
value_1 = get!(model.parameter_output_backward, p_1, 0.0)
value_2 = get!(model.parameter_output_backward, p_2, 0.0)
# TODO: why there is no factor of 2 here????
# ANS: probably because it was SET
model.parameter_output_backward[p_1] =
value_1 +
term.coefficient * grad_pf_cte * model.parameters[p_2] /
ifelse(term.variable_1 === term.variable_2, 2, 1)
ifelse(term.variable_1 === term.variable_2, 1, 1)
model.parameter_output_backward[p_2] =
value_2 +
term.coefficient * grad_pf_cte * model.parameters[p_1] /
ifelse(term.variable_1 === term.variable_2, 2, 1)
ifelse(term.variable_1 === term.variable_2, 1, 1)
end
for term in pf.pv
p = p_idx(term.variable_1)
Expand Down Expand Up @@ -479,3 +481,5 @@ function MOI.get(
JuMP.check_belongs_to_model(var_ref, model)
return _moi_get_result(JuMP.backend(model), attr, JuMP.index(var_ref))
end

# TODO: ignore ops that are 0
25 changes: 25 additions & 0 deletions test/jump_diff_param.jl
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,31 @@ function test_quadratic_rhs_changes()
atol = 1e-10,
)
end
for dir_x = 0:3
MOI.set(model, DiffOpt.ReverseVariablePrimal(), x, dir_x)
DiffOpt.reverse_differentiate!(model)
@test isapprox(MOI.get(model, POI.ReverseParameter(), p),
dir_x * 3 * q_val / (11 * t_val),
atol = 1e-10,
)
@test isapprox(MOI.get(model, POI.ReverseParameter(), q),
dir_x * 3 * p_val / (11 * t_val),
atol = 1e-10,
)
@test isapprox(MOI.get(model, POI.ReverseParameter(), r),
dir_x * 10 * r_val / (11 * t_val),
atol = 1e-10,
)
@test isapprox(MOI.get(model, POI.ReverseParameter(), s),
dir_x * 7 / (11 * t_val),
atol = 1e-10,
)
@test isapprox(MOI.get(model, POI.ReverseParameter(), t),
dir_x * (- (1 + 3 * p_val * q_val + 5 * r_val ^ 2 + 7 * s_val) /
(11 * t_val ^ 2)),
atol = 1e-10,
)
end
end
return
end
Expand Down

0 comments on commit b8d5df4

Please sign in to comment.