Skip to content

Commit

Permalink
TYP: add typing to dual directory (#539)
Browse files Browse the repository at this point in the history
  • Loading branch information
attack68 authored Dec 8, 2024
1 parent f05f82b commit 33f5827
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 59 deletions.
49 changes: 29 additions & 20 deletions python/rateslib/dual/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from rateslib.dual.variable import FLOATS, INTS, Variable
from rateslib.dual.variable import FLOATS, INTS, Arr1dF64, Arr1dObj, Arr2dF64, Arr2dObj, Variable
from rateslib.rs import ADOrder, Dual, Dual2, _dsolve1, _dsolve2, _fdsolve1, _fdsolve2

Dual.__doc__ = "Dual number data type to perform first derivative automatic differentiation."
Expand Down Expand Up @@ -77,13 +77,17 @@ def set_order_convert(
elif order == 1:
if vars_from is None:
return Dual(val, _, [])
else:
elif isinstance(vars_from, Dual):
return Dual.vars_from(vars_from, val, _, [])
else:
raise TypeError("`vars_from` must be a Dual when converting to ADOrder:1.")
elif order == 2:
if vars_from is None:
return Dual2(val, _, [], [])
else:
elif isinstance(vars_from, Dual2):
return Dual2.vars_from(vars_from, val, _, [], [])
else:
raise TypeError("`vars_from` must be a Dual2 when converting to ADOrder:2.")
# else val is Dual or Dual2 so convert directly
return set_order(val, order)

Expand All @@ -93,9 +97,7 @@ def gradient(
vars: list[str] | None = None,
order: int = 1,
keep_manifold: bool = False,
) -> (
np.ndarray[tuple[int], np.dtype[np.float64]] | np.ndarray[tuple[int, int], np.dtype[np.float64]]
):
) -> Arr1dF64 | Arr2dF64:
"""
Return derivatives of a dual number.
Expand Down Expand Up @@ -135,6 +137,7 @@ def gradient(
elif order == 2:
if isinstance(dual, Variable):
dual = Dual2(dual.real, vars=dual.vars, dual=dual.dual, dual2=[])

if vars is None:
return 2.0 * dual.dual2
else:
Expand Down Expand Up @@ -240,11 +243,14 @@ def dual_inv_norm_cdf(x: DualTypes) -> Number:


def dual_solve(
A: np.ndarray[tuple[int, int], np.dtype[np.object_]],
b: np.ndarray[tuple[int], np.dtype[np.object_]],
A: Arr2dObj | Arr2dF64,
b: Arr1dObj | Arr1dF64,
allow_lsq: bool = False,
types: tuple[Number, Number] = (Dual, Dual),
) -> np.ndarray[tuple[int], np.dtype[np.object_]]:
types: tuple[type[float] | type[Dual] | type[Dual2], type[float] | type[Dual] | type[Dual2]] = (
Dual,
Dual,
),
) -> Arr1dObj | Arr1dF64:
"""
Solve a linear system of equations involving dual number data types.
Expand Down Expand Up @@ -272,9 +278,9 @@ def dual_solve(
if types == (float, float):
# Use basic Numpy LinAlg
if allow_lsq:
return np.linalg.lstsq(A, b, rcond=None)[0]
return np.linalg.lstsq(A, b, rcond=None)[0] # type: ignore[arg-type]
else:
return np.linalg.solve(A, b)
return np.linalg.solve(A, b) # type: ignore[arg-type]

# Move to Rust implementation
if types in [(Dual, float), (Dual2, float)]:
Expand All @@ -287,19 +293,22 @@ def dual_solve(
A_ = np.vectorize(partial(set_order_convert, tag=[], order=map[types[0]], vars_from=None))(A)
b_ = np.vectorize(partial(set_order_convert, tag=[], order=map[types[1]], vars_from=None))(b)

a = [item for sublist in A_.tolist() for item in sublist] # 1D array of A_
b = b_[:, 0].tolist()
a_ = [item for sublist in A_.tolist() for item in sublist] # 1D array of A_
b_ = b_[:, 0].tolist()

if types == (Dual, Dual):
out = _dsolve1(a, b, allow_lsq)
return np.array(_dsolve1(a_, b_, allow_lsq))[:, None]
elif types == (Dual2, Dual2):
out = _dsolve2(a, b, allow_lsq)
return np.array(_dsolve2(a_, b_, allow_lsq))[:, None]
elif types == (float, Dual):
out = _fdsolve1(A_, b, allow_lsq)
return np.array(_fdsolve1(A_, b_, allow_lsq))[:, None]
elif types == (float, Dual2):
out = _fdsolve2(A_, b, allow_lsq)

return np.array(out)[:, None]
return np.array(_fdsolve2(A_, b_, allow_lsq))[:, None]
else:
raise TypeError(
"Provided `types` argument are not permitted. Must be a 2-tuple with "
"elements from {float, Dual, Dual2}"
)


def _get_adorder(order: int) -> ADOrder:
Expand Down
53 changes: 27 additions & 26 deletions python/rateslib/dual/variable.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import math
from typing import Any
from collections.abc import Sequence
from typing import Any, TypeAlias

import numpy as np

Expand All @@ -13,6 +14,12 @@
FLOATS = float | np.float16 | np.float32 | np.float64 | np.longdouble
INTS = int | np.int8 | np.int16 | np.int32 | np.int32 | np.int64

# https://stackoverflow.com/questions/68916893/
Arr1dF64: TypeAlias = "np.ndarray[tuple[int], np.dtype[np.float64]]"
Arr2dF64: TypeAlias = "np.ndarray[tuple[int, int], np.dtype[np.float64]]"
Arr1dObj: TypeAlias = "np.ndarray[tuple[int], np.dtype[np.object_]]"
Arr2dObj: TypeAlias = "np.ndarray[tuple[int, int], np.dtype[np.object_]]"


class Variable:
"""
Expand Down Expand Up @@ -43,14 +50,14 @@ class Variable:
def __init__(
self,
real: float,
vars: tuple[str, ...] = (),
dual: np.ndarray[tuple[int, int], np.dtype[np.float64]] | NoInput = NoInput(0),
vars: Sequence[str] = (),
dual: list[float] | Arr1dF64 | NoInput = NoInput(0),
):
self.real: float = float(real)
self.vars: tuple[str, ...] = tuple(vars)
n = len(self.vars)
if isinstance(dual, NoInput) or len(dual) == 0:
self.dual: np.ndarray = np.ones(n, dtype=np.float64)
self.dual: Arr1dF64 = np.ones(n, dtype=np.float64)
else:
self.dual = np.asarray(dual.copy())

Expand Down Expand Up @@ -103,67 +110,61 @@ def __eq_coeffs__(self, argument: Dual | Dual2 | Variable, precision: float) ->
def __neg__(self) -> Variable:
return Variable(-self.real, vars=self.vars, dual=-self.dual)

def __add__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
def __add__(self, other: Dual | Dual2 | float | Variable) -> Dual | Dual2 | Variable:
if isinstance(other, Variable):
_1 = self._to_dual_type(defaults._global_ad_order)
_2 = other._to_dual_type(defaults._global_ad_order)
return _1.__add__(_2)
elif isinstance(other, FLOATS | INTS):
return Variable(self.real + float(other), vars=self.vars, dual=self.dual)
elif isinstance(other, Dual):
_ = Dual(self.real, vars=self.vars, dual=self.dual)
return _.__add__(other)
return Dual(self.real, vars=self.vars, dual=self.dual).__add__(other)
elif isinstance(other, Dual2):
_ = Dual2(self.real, vars=self.vars, dual=self.dual, dual2=[])
return _.__add__(other)
return Dual2(self.real, vars=self.vars, dual=self.dual, dual2=[]).__add__(other)
else:
raise TypeError(f"No operation defined between `Variable` and type: `{type(other)}`")

def __radd__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
def __radd__(self, other: Dual | Dual2 | float | Variable) -> Dual | Dual2 | Variable:
return self.__add__(other)

def __rsub__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
def __rsub__(self, other: Dual | Dual2 | float | Variable) -> Dual | Dual2 | Variable:
return (self.__neg__()).__add__(other)

def __sub__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
def __sub__(self, other: Dual | Dual2 | float | Variable) -> Dual | Dual2 | Variable:
return self.__add__(other.__neg__())

def __mul__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
def __mul__(self, other: Dual | Dual2 | float | Variable) -> Dual | Dual2 | Variable:
if isinstance(other, Variable):
_1 = self._to_dual_type(defaults._global_ad_order)
_2 = other._to_dual_type(defaults._global_ad_order)
return _1.__mul__(_2)
elif isinstance(other, FLOATS | INTS):
return Variable(self.real * float(other), vars=self.vars, dual=self.dual * float(other))
elif isinstance(other, Dual):
_ = Dual(self.real, vars=self.vars, dual=self.dual)
return _.__mul__(other)
return Dual(self.real, vars=self.vars, dual=self.dual).__mul__(other)
elif isinstance(other, Dual2):
_ = Dual2(self.real, vars=self.vars, dual=self.dual, dual2=[])
return _.__mul__(other)
return Dual2(self.real, vars=self.vars, dual=self.dual, dual2=[]).__mul__(other)
else:
raise TypeError(f"No operation defined between `Variable` and type: `{type(other)}`")

def __rmul__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
def __rmul__(self, other: Dual | Dual2 | float | Variable) -> Dual | Dual2 | Variable:
return self.__mul__(other)

def __truediv__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
def __truediv__(self, other: Dual | Dual2 | float | Variable) -> Dual | Dual2 | Variable:
if isinstance(other, Variable):
_1 = self._to_dual_type(defaults._global_ad_order)
_2 = other._to_dual_type(defaults._global_ad_order)
return _1.__truediv__(_2)
elif isinstance(other, FLOATS | INTS):
return Variable(self.real / float(other), vars=self.vars, dual=self.dual / float(other))
elif isinstance(other, Dual):
_ = Dual(self.real, vars=self.vars, dual=self.dual)
return _.__truediv__(other)
return Dual(self.real, vars=self.vars, dual=self.dual).__truediv__(other)
elif isinstance(other, Dual2):
_ = Dual2(self.real, vars=self.vars, dual=self.dual, dual2=[])
return _.__truediv__(other)
return Dual2(self.real, vars=self.vars, dual=self.dual, dual2=[]).__truediv__(other)
else:
raise TypeError(f"No operation defined between `Variable` and type: `{type(other)}`")

def __rtruediv__(self, other: Dual | Dual2 | float | int | Variable) -> Dual | Dual2:
def __rtruediv__(self, other: Dual | Dual2 | float | Variable) -> Dual | Dual2 | Variable:
if isinstance(other, Variable):
# cannot reach this line
raise TypeError("Impossible line execution - please report issue.") # pragma: no cover
Expand Down Expand Up @@ -195,9 +196,9 @@ def __norm_inv_cdf__(self) -> Dual | Dual2:
_1 = self._to_dual_type(defaults._global_ad_order)
return _1.__norm_inv_cdf__()

def __pow__(self, exponent: float) -> Dual | Dual2:
def __pow__(self, exponent: float | Dual | Dual2, modulo: int | None = None) -> Dual | Dual2:
_1 = self._to_dual_type(defaults._global_ad_order)
return _1.__pow__(exponent)
return _1.__pow__(exponent, modulo)

def __repr__(self) -> str:
a = ", ".join(self.vars[:3])
Expand Down
6 changes: 3 additions & 3 deletions python/rateslib/fx/fx_forwards.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def convert(
value_date: datetime | NoInput = NoInput(0),
collateral: str | NoInput = NoInput(0),
on_error: str = "ignore",
) -> Number:
) -> Number | None:
"""
Convert an amount of a domestic currency, as of a settlement date
into a foreign currency, valued on another date.
Expand Down Expand Up @@ -1036,7 +1036,7 @@ def plot(
points: int = (right_ - left_).days
x = [left_ + timedelta(days=i) for i in range(points)]
_, path = self._rate_with_path(pair, x[0])
rates: list[DualTypes] = [self._rate_with_path(pair, _, path=path)[0] for _ in x]
rates: list[Number] = [self._rate_with_path(pair, _, path=path)[0] for _ in x]
if not fx_swap:
y: list[Number] = [rates]
else:
Expand Down Expand Up @@ -1155,7 +1155,7 @@ def forward_fx(
curve_foreign: Curve,
fx_rate: DualTypes,
fx_settlement: datetime | NoInput = NoInput(0),
) -> Dual:
) -> DualTypes:
"""
Return a forward FX rate based on interest rate parity.
Expand Down
23 changes: 14 additions & 9 deletions python/rateslib/fx/fx_rates.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from rateslib import defaults
from rateslib.default import NoInput, _drb, _make_py_json
from rateslib.dual import Dual, DualTypes, Number, _get_adorder, gradient
from rateslib.dual.variable import Arr1dF64, Arr1dObj, Arr2dObj
from rateslib.rs import Ccy, FXRate
from rateslib.rs import FXRates as FXRatesObj

Expand Down Expand Up @@ -155,10 +156,14 @@ def __repr__(self) -> str:
return f"<rl.FXRates:[{','.join(self.currencies_list)}] at {hex(id(self))}>"

@cached_property
def fx_array(self) -> np.ndarray[tuple[int, int], np.dtype[np.object_]]:
def fx_array(self) -> Arr2dObj:
# caching this prevents repetitive data transformations between Rust/Python
return np.array(self.obj.fx_array)

def fx_array_el(self, i: int, j: int) -> Number:
# this is for typing since this numpy object array can only hold float | Dual | Dual2
return self.fx_array[i, j] # type: ignore

@property
def base(self) -> str:
return self.obj.base.name
Expand All @@ -184,7 +189,7 @@ def q(self) -> int:
return len(self.obj.currencies)

@property
def fx_vector(self) -> np.ndarray[tuple[int], np.dtype[np.object_]]:
def fx_vector(self) -> Arr1dObj:
return self.fx_array[0, :]

@property
Expand All @@ -199,7 +204,7 @@ def variables(self) -> tuple[str, ...]:
def _ad(self) -> int:
return self.obj.ad

def rate(self, pair: str) -> DualTypes:
def rate(self, pair: str) -> Number:
"""
Return a specified FX rate for a given currency pair.
Expand All @@ -221,7 +226,7 @@ def rate(self, pair: str) -> DualTypes:
fxr.rate("eurgbp")
"""
domi, fori = self.currencies[pair[:3].lower()], self.currencies[pair[3:].lower()]
return self.fx_array[domi][fori]
return self.fx_array_el(domi, fori)

def restate(self, pairs: list[str], keep_ad: bool = False) -> FXRates:
"""
Expand Down Expand Up @@ -355,7 +360,7 @@ def convert(
domestic: str,
foreign: str | NoInput = NoInput(0),
on_error: str = "ignore",
) -> Number | None:
) -> DualTypes | None:
"""
Convert an amount of a domestic currency into a foreign currency.
Expand Down Expand Up @@ -402,11 +407,11 @@ def convert(
raise ValueError(f"'{ccy}' not in FXRates.currencies.")

i, j = self.currencies[domestic.lower()], self.currencies[foreign.lower()]
return value * self.fx_array[i, j]
return value * self.fx_array_el(i, j)

def convert_positions(
self,
array: np.ndarray[tuple[int], np.dtype[np.float64]] | list[float],
array: Arr1dF64 | list[float],
base: str | NoInput = NoInput(0),
) -> Number:
"""
Expand Down Expand Up @@ -491,9 +496,9 @@ def _get_positions_from_delta(
# _[f_idx] = f_val
# _[d_idx] = -f_val / float(self.fx_array[d_idx, f_idx])
# return _
f_val = delta * float(self.fx_array[b_idx, f_idx])
f_val = delta * float(self.fx_array_el(b_idx, f_idx))
_[d_idx] = f_val
_[f_idx] = -f_val / float(self.fx_array[f_idx, d_idx])
_[f_idx] = -f_val / float(self.fx_array_el(f_idx, d_idx))
return _ # calculation is more efficient from a domestic pov than foreign

def rates_table(self) -> DataFrame:
Expand Down
Loading

0 comments on commit 33f5827

Please sign in to comment.