diff --git a/src/pybroker/data.py b/src/pybroker/data.py index a8285ed..1fb6054 100644 --- a/src/pybroker/data.py +++ b/src/pybroker/data.py @@ -422,10 +422,10 @@ def query( start_date: Union[str, datetime], end_date: Union[str, datetime], timeframe: Optional[str] = "1d", - _: Optional[str] = None, + _adjust: Optional[str] = None, ) -> pd.DataFrame: _parse_alpaca_timeframe(timeframe) - return super().query(symbols, start_date, end_date, timeframe, _) + return super().query(symbols, start_date, end_date, timeframe, _adjust) def _fetch_data( self, @@ -433,7 +433,7 @@ def _fetch_data( start_date: datetime, end_date: datetime, timeframe: Optional[str], - _: Optional[str], + _adjust: Optional[str], ) -> pd.DataFrame: """:meta private:""" amount, unit = _parse_alpaca_timeframe(timeframe) diff --git a/src/pybroker/strategy.py b/src/pybroker/strategy.py index 68d6a23..0ec99e7 100644 --- a/src/pybroker/strategy.py +++ b/src/pybroker/strategy.py @@ -62,6 +62,7 @@ from decimal import Decimal from numpy.typing import NDArray from typing import ( + Any, Callable, Iterable, Iterator, @@ -1035,6 +1036,7 @@ def backtest( disable_parallel: bool = False, warmup: Optional[int] = None, portfolio: Optional[Portfolio] = None, + adjust: Optional[Any] = None, ) -> TestResult: """Backtests the trading strategy by running executions that were added with :meth:`.add_execution`. @@ -1085,6 +1087,8 @@ def backtest( executions. portfolio: Custom :class:`pybroker.portfolio.Portfolio` to use for backtests. + adjust: The type of adjustment to make to the + :class:`pybroker.data.DataSource`. Returns: :class:`.TestResult` containing portfolio balances, order @@ -1104,6 +1108,7 @@ def backtest( disable_parallel=disable_parallel, warmup=warmup, portfolio=portfolio, + adjust=adjust, ) def walkforward( @@ -1121,6 +1126,7 @@ def walkforward( disable_parallel: bool = False, warmup: Optional[int] = None, portfolio: Optional[Portfolio] = None, + adjust: Optional[Any] = None, ) -> TestResult: """Backtests the trading strategy using `Walkforward Analysis `_. @@ -1177,6 +1183,8 @@ def walkforward( executions. portfolio: Custom :class:`pybroker.portfolio.Portfolio` to use for backtests. + adjust: The type of adjustment to make to the + :class:`pybroker.data.DataSource`. Returns: :class:`.TestResult` containing portfolio balances, order @@ -1210,7 +1218,7 @@ def walkforward( if start_dt is not None and end_dt is not None: verify_date_range(start_dt, end_dt) self._logger.walkforward_start(start_dt, end_dt) - df = self._fetch_data(timeframe) + df = self._fetch_data(timeframe, adjust) day_ids = self._to_day_ids(days) df = self._filter_dates( df=df, @@ -1436,13 +1444,19 @@ def _fetch_indicators( disable_parallel=disable_parallel, ) - def _fetch_data(self, timeframe: str) -> pd.DataFrame: + def _fetch_data( + self, timeframe: str, adjust: Optional[Any] + ) -> pd.DataFrame: unique_syms = { sym for execution in self._executions for sym in execution.symbols } if isinstance(self._data_source, DataSource): df = self._data_source.query( - unique_syms, self._start_date, self._end_date, timeframe + unique_syms, + self._start_date, + self._end_date, + timeframe, + adjust, ) else: df = _between(self._data_source, self._start_date, self._end_date) diff --git a/tests/test_strategy.py b/tests/test_strategy.py index 370e010..8e8a400 100644 --- a/tests/test_strategy.py +++ b/tests/test_strategy.py @@ -1216,6 +1216,7 @@ def test_walkforward( between_time=between_time, calc_bootstrap=calc_bootstrap, disable_parallel=disable_parallel, + adjust="adjustment", ) if date_range[0] is None: expected_start_date = datetime.strptime(START_DATE, "%Y-%m-%d")