Skip to content

Commit

Permalink
TYP: dual typing fixes (#543)
Browse files Browse the repository at this point in the history
Co-authored-by: JHM Darbyshire (win11) <[email protected]>
  • Loading branch information
attack68 and attack68 authored Dec 10, 2024
1 parent b22b7cd commit 27f93e5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
7 changes: 5 additions & 2 deletions python/rateslib/dual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,16 @@ def gradient(
return dual.dual
elif vars is not None and not keep_manifold:
return dual.grad1(vars)

_ = dual.grad1_manifold(vars)
elif isinstance(dual, Dual): # and keep_manifold:
raise TypeError("Dual type cannot perform `keep_manifold`.")
_ = dual.grad1_manifold(dual.vars if vars is None else vars)
return np.asarray(_)

elif order == 2:
if isinstance(dual, Variable):
dual = Dual2(dual.real, vars=dual.vars, dual=dual.dual, dual2=[])
elif isinstance(dual, Dual):
raise TypeError("Dual type cannot derive second order automatic derivatives.")

if vars is None:
return 2.0 * dual.dual2
Expand Down
2 changes: 1 addition & 1 deletion python/rateslib/rs.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ class Dual2:
dual2: list[float] | Arr1dF64,
) -> Dual2: ...
def grad1(self, vars: Sequence[str]) -> Arr1dF64: ...
def gra1_manifold(self, vars: Sequence[str]) -> list[Dual2]: ...
def grad1_manifold(self, vars: Sequence[str]) -> list[Dual2]: ...
def grad2(self, vars: list[str]) -> Arr2dF64: ...
def ptr_eq(self, other: Dual2) -> bool: ...
def __repr__(self) -> str: ...
Expand Down

0 comments on commit 27f93e5

Please sign in to comment.