Skip to content

Commit

Permalink
Adds adjust parameter to walkforward and backtest.
Browse files Browse the repository at this point in the history
  • Loading branch information
edtechre committed Nov 15, 2024
1 parent a67422a commit 27422de
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
6 changes: 3 additions & 3 deletions src/pybroker/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,18 +422,18 @@ def query(
start_date: Union[str, datetime],
end_date: Union[str, datetime],
timeframe: Optional[str] = "1d",
_: Optional[str] = None,
_adjust: Optional[str] = None,
) -> pd.DataFrame:
_parse_alpaca_timeframe(timeframe)
return super().query(symbols, start_date, end_date, timeframe, _)
return super().query(symbols, start_date, end_date, timeframe, _adjust)

def _fetch_data(
self,
symbols: frozenset[str],
start_date: datetime,
end_date: datetime,
timeframe: Optional[str],
_: Optional[str],
_adjust: Optional[str],
) -> pd.DataFrame:
""":meta private:"""
amount, unit = _parse_alpaca_timeframe(timeframe)
Expand Down
20 changes: 17 additions & 3 deletions src/pybroker/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from decimal import Decimal
from numpy.typing import NDArray
from typing import (
Any,
Callable,
Iterable,
Iterator,
Expand Down Expand Up @@ -1035,6 +1036,7 @@ def backtest(
disable_parallel: bool = False,
warmup: Optional[int] = None,
portfolio: Optional[Portfolio] = None,
adjust: Optional[Any] = None,
) -> TestResult:
"""Backtests the trading strategy by running executions that were added
with :meth:`.add_execution`.
Expand Down Expand Up @@ -1085,6 +1087,8 @@ def backtest(
executions.
portfolio: Custom :class:`pybroker.portfolio.Portfolio` to use for
backtests.
adjust: The type of adjustment to make to the
:class:`pybroker.data.DataSource`.
Returns:
:class:`.TestResult` containing portfolio balances, order
Expand All @@ -1104,6 +1108,7 @@ def backtest(
disable_parallel=disable_parallel,
warmup=warmup,
portfolio=portfolio,
adjust=adjust,
)

def walkforward(
Expand All @@ -1121,6 +1126,7 @@ def walkforward(
disable_parallel: bool = False,
warmup: Optional[int] = None,
portfolio: Optional[Portfolio] = None,
adjust: Optional[Any] = None,
) -> TestResult:
"""Backtests the trading strategy using `Walkforward Analysis
<https://www.pybroker.com/en/latest/notebooks/6.%20Training%20a%20Model.html#Walkforward-Analysis>`_.
Expand Down Expand Up @@ -1177,6 +1183,8 @@ def walkforward(
executions.
portfolio: Custom :class:`pybroker.portfolio.Portfolio` to use for
backtests.
adjust: The type of adjustment to make to the
:class:`pybroker.data.DataSource`.
Returns:
:class:`.TestResult` containing portfolio balances, order
Expand Down Expand Up @@ -1210,7 +1218,7 @@ def walkforward(
if start_dt is not None and end_dt is not None:
verify_date_range(start_dt, end_dt)
self._logger.walkforward_start(start_dt, end_dt)
df = self._fetch_data(timeframe)
df = self._fetch_data(timeframe, adjust)
day_ids = self._to_day_ids(days)
df = self._filter_dates(
df=df,
Expand Down Expand Up @@ -1436,13 +1444,19 @@ def _fetch_indicators(
disable_parallel=disable_parallel,
)

def _fetch_data(self, timeframe: str) -> pd.DataFrame:
def _fetch_data(
self, timeframe: str, adjust: Optional[Any]
) -> pd.DataFrame:
unique_syms = {
sym for execution in self._executions for sym in execution.symbols
}
if isinstance(self._data_source, DataSource):
df = self._data_source.query(
unique_syms, self._start_date, self._end_date, timeframe
unique_syms,
self._start_date,
self._end_date,
timeframe,
adjust,
)
else:
df = _between(self._data_source, self._start_date, self._end_date)
Expand Down
1 change: 1 addition & 0 deletions tests/test_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,7 @@ def test_walkforward(
between_time=between_time,
calc_bootstrap=calc_bootstrap,
disable_parallel=disable_parallel,
adjust="adjustment",
)
if date_range[0] is None:
expected_start_date = datetime.strptime(START_DATE, "%Y-%m-%d")
Expand Down

0 comments on commit 27422de

Please sign in to comment.