Skip to content

Commit

Permalink
feat: break df into chunks, handle df larger than memory
Browse files Browse the repository at this point in the history
  • Loading branch information
softwareentrepreneer committed Apr 29, 2024
1 parent cf28507 commit 3d62d31
Show file tree
Hide file tree
Showing 25 changed files with 435 additions and 314 deletions.
8 changes: 4 additions & 4 deletions docs/getting-started/basics.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Therefore, a virtual broker named `CRYPTO` has been created as an intermediary t
```{code-cell}
:tags: [hide-output]
from pfund.const.commons import SUPPORTED_BROKERS
from pfund.const.common import SUPPORTED_BROKERS
from pprint import pprint
pprint(SUPPORTED_BROKERS)
Expand Down Expand Up @@ -70,7 +70,7 @@ Unlike the virtual broker `CRYPTO`, which is an actual broker object in `pfund`
```{code-cell}
:tags: [hide-output]
from pfund.const.commons import SUPPORTED_CRYPTO_EXCHANGES
from pfund.const.common import SUPPORTED_CRYPTO_EXCHANGES
from pprint import pprint
pprint(SUPPORTED_CRYPTO_EXCHANGES)
Expand Down Expand Up @@ -114,7 +114,7 @@ Financial products/instruments are in the format of `XXX_YYY_PTYPE` where
```{code-cell}
:tags: [hide-output]
from pfund.const.commons import SUPPORTED_PRODUCT_TYPES
from pfund.const.common import SUPPORTED_PRODUCT_TYPES
from pprint import pprint
pprint(SUPPORTED_PRODUCT_TYPES)
Expand All @@ -132,7 +132,7 @@ Crypto product types supported by `pfund` include:
```{code-cell}
:tags: [hide-output]
from pfund.const.commons import SUPPORTED_CRYPTO_PRODUCT_TYPES
from pfund.const.common import SUPPORTED_CRYPTO_PRODUCT_TYPES
from pprint import pprint
pprint(SUPPORTED_CRYPTO_PRODUCT_TYPES)
Expand Down
2 changes: 1 addition & 1 deletion pfund/accounts/account_crypto.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pfund.accounts.account_base import BaseAccount
from pfund.const.commons import SUPPORTED_BYBIT_ACCOUNT_TYPES
from pfund.const.common import SUPPORTED_BYBIT_ACCOUNT_TYPES


class CryptoAccount(BaseAccount):
Expand Down
2 changes: 1 addition & 1 deletion pfund/brokers/broker_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from collections import defaultdict

from pfund.const.commons import SUPPORTED_ENVIRONMENTS
from pfund.const.common import SUPPORTED_ENVIRONMENTS
from pfund.utils.utils import get_engine_class


Expand Down
2 changes: 1 addition & 1 deletion pfund/brokers/broker_crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pfund.utils.utils import convert_to_uppercases
from pfund.brokers.broker_live import LiveBroker
from pfund.exchanges.exchange_base import BaseExchange
from pfund.const.commons import SUPPORTED_CRYPTO_EXCHANGES, SUPPORTED_CRYPTO_PRODUCT_TYPES
from pfund.const.common import SUPPORTED_CRYPTO_EXCHANGES, SUPPORTED_CRYPTO_PRODUCT_TYPES


class CryptoBroker(LiveBroker):
Expand Down
2 changes: 1 addition & 1 deletion pfund/brokers/ib/broker_ib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pfund.adapter import Adapter
from pfund.config.configuration import Configuration
from pfund.const.paths import PROJ_CONFIG_PATH
from pfund.const.commons import SUPPORTED_PRODUCT_TYPES
from pfund.const.common import SUPPORTED_PRODUCT_TYPES
from pfund.products import IBProduct
from pfund.accounts import IBAccount
from pfund.orders import IBOrder
Expand Down
2 changes: 1 addition & 1 deletion pfund/brokers/ib/ib_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from pfund.brokers.ib.ib_client import IBClient
from pfund.brokers.ib.ib_wrapper import *
from pfund.const.commons import SUPPORTED_DATA_CHANNELS
from pfund.const.common import SUPPORTED_DATA_CHANNELS
from pfund.zeromq import ZeroMQ


Expand Down
2 changes: 1 addition & 1 deletion pfund/const/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from pfund.const.commons import *
from pfund.const.common import *
from pfund.const.paths import *
File renamed without changes.
77 changes: 49 additions & 28 deletions pfund/data_tools/data_tool_pandas.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from __future__ import annotations
from collections import defaultdict
from decimal import Decimal

from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING, Generator
if TYPE_CHECKING:
from pfund.datas.data_base import BaseData

Expand All @@ -15,7 +14,12 @@
class PandasDataTool(BaseDataTool):
_INDEX = ['ts', 'product', 'resolution']
_GROUP = ['product', 'resolution']
_DECIMAL_COLS = ['price', 'open', 'high', 'low', 'close', 'volume']

def get_df(self, copy=True):
    """Return the tool's DataFrame; a deep copy by default so callers
    cannot mutate internal state, or the live object when copy=False."""
    if copy:
        return self.df.copy(deep=True)
    return self.df

def concat(self, dfs: list[pd.DataFrame]) -> pd.DataFrame:
    """Stack the given DataFrames vertically (thin wrapper over pd.concat)."""
    combined = pd.concat(dfs)
    return combined

def prepare_df(self):
assert self._raw_dfs, "No data is found, make sure add_data(...) is called correctly"
Expand All @@ -24,19 +28,37 @@ def prepare_df(self):
# arrange columns
self.df = self.df[self._INDEX + [col for col in self.df.columns if col not in self._INDEX]]
self._raw_dfs.clear()


def get_total_rows(self, df: pd.DataFrame):
    """Return the number of rows in *df*."""
    return len(df.index)

@backtest
def iterate_df_by_chunks(self, df: pd.DataFrame, num_chunks=1) -> Generator[pd.DataFrame, None, None]:
    """Yield deep-copied row chunks of *df*, in roughly ``num_chunks`` pieces.

    Lets a DataFrame larger than the memory budget be processed piecewise.

    Fix: the original computed ``chunk_size = total_rows // num_chunks``
    unguarded, which is 0 whenever ``num_chunks > total_rows`` (or the df is
    empty), making ``range(..., step=0)`` raise ValueError.
    """
    total_rows = self.get_total_rows(df)
    if total_rows == 0:
        return  # nothing to yield for an empty DataFrame
    # clamp so the range step can never be zero
    chunk_size = max(1, total_rows // num_chunks)
    for start in range(0, total_rows, chunk_size):
        # deep copy so downstream mutation of a chunk cannot touch the source df
        yield df.iloc[start:start + chunk_size].copy(deep=True)

@backtest
def preprocess_event_driven_df(self, df: pd.DataFrame) -> Iterator:
def preprocess_event_driven_df(self, df: pd.DataFrame) -> pd.DataFrame:
def _check_resolution(res):
from pfund.datas.resolution import Resolution
resolution = Resolution(res)
return resolution.is_quote(), resolution.is_tick()

# converts 'ts' from datetime to unix timestamp
df['ts'] = df['ts'].astype(int) // 10**6 # in milliseconds
df['ts'] = df['ts'] / 10**3 # in seconds with milliseconds precision
# convert float columns to decimal for consistency with live trading
for col in df.columns:
if col in self._DECIMAL_COLS:
df[col] = df[col].apply(lambda x: Decimal(str(x)))
# TODO: split 'broker' str column from 'product' str column
# df['broker'] = ...
return df.itertuples(index=False)
# in milliseconds int -> in seconds with milliseconds precision
df['ts'] = df['ts'].astype(int) // 10**6 / 10**3

# add 'broker', 'is_quote', 'is_tick' columns
df['broker'] = df['product'].str.split('-').str[0]
df['is_quote'], df['is_tick'] = zip(*df['resolution'].apply(_check_resolution))

# arrange columns
left_cols = self._INDEX + ['broker', 'is_quote', 'is_tick']
df = df[left_cols + [col for col in df.columns if col not in left_cols]]
return df

@backtest
def postprocess_vectorized_df(self, df: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -129,34 +151,33 @@ def append_to_df(self, data: BaseData, predictions: dict, **kwargs):
index=self.create_multi_index(index_data, self.df.index.names)
)
self.df = pd.concat([self.df, new_row], ignore_index=False)

def convert_ts_index_to_dt(self, df: pd.DataFrame) -> pd.DataFrame:
    """Replace the 'ts' index level (unix seconds) with datetimes, in place.

    Mutates *df*'s index and also returns the same object for chaining.
    """
    ts_index = df.index.get_level_values('ts')
    # unit='s' assumes 'ts' holds unix seconds (matches preprocess_event_driven_df)
    dt_index = pd.to_datetime(ts_index, unit='s')
    # NOTE(review): set_levels is being fed per-row values from
    # get_level_values, not unique level categories — confirm this behaves
    # as intended when 'ts' contains duplicates.
    df.index = df.index.set_levels(dt_index, level='ts')
    return df

def create_multi_index(self, index_data: dict, index_names: list[str]) -> pd.MultiIndex:
    """Build a single-row MultiIndex whose levels follow *index_names* order."""
    row = tuple(index_data[name] for name in index_names)
    return pd.MultiIndex.from_tuples([row], names=index_names)

def output_df_to_parquet(self, df: pd.DataFrame, file_path: str):
df.to_parquet(file_path, compression='zstd')



'''
************************************************
Helper Functions
************************************************
'''
def get_index_values(self, df: pd.DataFrame, index: str) -> list:
assert index in df.index.names, f"index must be one of {df.index.names}"
return df.index.get_level_values(index).unique().to_list()

def set_index_values(self, df: pd.DataFrame, index: str, values: list) -> pd.DataFrame:
    """Overwrite the values of one index level, in place.

    Mutates *df*'s index and also returns the same object for chaining.
    Raises AssertionError if *index* is not an existing index level name.
    """
    assert index in df.index.names, f"index must be one of {df.index.names}"
    # NOTE(review): set_levels expects level categories; passing a full-length
    # per-row list only matches row order when the level has no duplicates —
    # confirm callers guarantee that.
    df.index = df.index.set_levels(values, level=index)
    return df

def output_df_to_parquet(self, df: pd.DataFrame, file_path: str, compression: str='zstd'):
    """Write *df* to *file_path* as Parquet, zstd-compressed by default."""
    df.to_parquet(path=file_path, compression=compression)

def filter_df(self, df: pd.DataFrame, start_date: str | None=None, end_date: str | None=None, product: str='', resolution: str=''):
product = product or slice(None)
resolution = resolution or slice(None)
return df.loc[(slice(start_date, end_date), product, resolution), :]

def get_index_values(self, df: pd.DataFrame, index: str) -> list:
    """Return the distinct values of one index level, in first-seen order."""
    assert index in self._INDEX, f"index must be one of {self._INDEX}"
    level_values = df.index.get_level_values(index)
    return level_values.unique().to_list()

def unstack_df(self, df: pd.DataFrame):
    """Pivot the grouping levels (product, resolution) from the index into columns."""
    group_levels = self._GROUP
    return df.unstack(level=group_levels)

Expand Down
81 changes: 71 additions & 10 deletions pfund/data_tools/data_tool_polars.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations
from collections import defaultdict
from decimal import Decimal

from typing import TYPE_CHECKING, Iterator
from typing import TYPE_CHECKING, Generator
if TYPE_CHECKING:
from pfund.datas.data_base import BaseData

import pandas as pd
import polars as pl

from pfund.data_tools.data_tool_base import BaseDataTool
Expand All @@ -15,6 +15,12 @@
class PolarsDataTool(BaseDataTool):
_INDEX = ['ts', 'product', 'resolution']

def get_df(self, copy=True):
    """Return the tool's frame; a clone by default so callers cannot
    mutate internal state, or the live object when copy=False."""
    if copy:
        return self.df.clone()
    return self.df

def concat(self, dfs: list[pl.DataFrame | pl.LazyFrame]) -> pl.DataFrame | pl.LazyFrame:
    """Stack the given frames vertically (thin wrapper over pl.concat)."""
    combined = pl.concat(dfs)
    return combined

def prepare_df(self):
assert self._raw_dfs, "No data is found, make sure add_data(...) is called correctly"
self.df = pl.concat(self._raw_dfs.values())
Expand All @@ -23,13 +29,55 @@ def prepare_df(self):
self.df = self.df.select(self._INDEX + [col for col in self.df.columns if col not in self._INDEX])
self._raw_dfs.clear()

def get_total_rows(self, df: pl.DataFrame | pl.LazyFrame):
    """Return the row count of an eager or lazy polars frame.

    Raises ValueError for any other input type.
    """
    if isinstance(df, pl.DataFrame):
        return df.shape[0]
    elif isinstance(df, pl.LazyFrame):
        # counts via the 'ts' column (always present per _INDEX);
        # NOTE(review): LazyFrame.count() availability/semantics vary across
        # polars versions — confirm against the pinned polars release.
        return df.count().collect()['ts'][0]
    else:
        raise ValueError("df should be either pl.DataFrame or pl.LazyFrame")

@backtest
def preprocess_event_driven_df(self, df: pl.DataFrame | pl.LazyFrame) -> Iterator:
pass
def iterate_df_by_chunks(self, lf: pl.LazyFrame, num_chunks=1) -> Generator[pl.DataFrame, None, None]:
    """Yield materialized (collected) chunks of *lf*, in roughly ``num_chunks`` pieces.

    Only one chunk is held in memory at a time, so a LazyFrame larger than
    memory can be processed piecewise.

    Fixes vs. original:
    - ``chunk_size = total_rows // num_chunks`` was unguarded: it is 0 when
      ``num_chunks > total_rows`` (or the frame is empty), making
      ``range(..., step=0)`` raise ValueError.
    - the return annotation said ``pd.DataFrame`` although ``collect()``
      yields ``pl.DataFrame``.
    """
    total_rows = self.get_total_rows(lf)
    if total_rows == 0:
        return  # nothing to collect from an empty frame
    # clamp so the range step can never be zero
    chunk_size = max(1, total_rows // num_chunks)
    for start in range(0, total_rows, chunk_size):
        # slice lazily, then collect only this window into memory
        yield lf.slice(start, chunk_size).collect()

@backtest
def postprocess_vectorized_df(self, df: pl.DataFrame | pl.LazyFrame) -> pl.DataFrame | pl.LazyFrame:
pass
def preprocess_event_driven_df(self, df: pl.DataFrame) -> pl.DataFrame:
    """Prepare a backtest frame for event-driven replay.

    Converts the 'ts' datetime column to unix seconds, derives 'broker',
    'is_quote' and 'is_tick' columns, and moves the index/flag columns to
    the front. Returns a new DataFrame (polars expressions do not mutate).
    """
    def _check_resolution(res):
        # local import avoids a module-level circular dependency
        from pfund.datas.resolution import Resolution
        resolution = Resolution(res)
        return {
            'is_quote': resolution.is_quote(),
            'is_tick': resolution.is_tick()
        }

    df = df.with_columns(
        # converts 'ts' from datetime to unix timestamp:
        # // 10**6 then / 10**3 yields seconds with millisecond precision
        # (assumes nanosecond-resolution datetimes — TODO confirm)
        pl.col("ts").cast(pl.Int64) // 10**6 / 10**3,

        # add 'broker', 'is_quote', 'is_tick' columns;
        # broker is the prefix of 'product' before the first '-'
        pl.col('product').str.split("-").list.get(0).alias("broker"),
        pl.col('resolution').map_elements(
            _check_resolution,
            return_dtype=pl.Struct([
                pl.Field('is_quote', pl.Boolean),
                pl.Field('is_tick', pl.Boolean)
            ])
        ).alias('Resolution')
    ).unnest('Resolution')

    # arrange columns: index + derived flags first, everything else after
    left_cols = self._INDEX + ['broker', 'is_quote', 'is_tick']
    df = df.select(left_cols + [col for col in df.columns if col not in left_cols])
    return df

@backtest
def postprocess_vectorized_df(self, df: pl.DataFrame) -> pl.LazyFrame:
    """Hand the eager frame back as a LazyFrame for downstream lazy processing."""
    lazy_df = df.lazy()
    return lazy_df

# TODO:
def prepare_df_with_signals(self, models):
Expand All @@ -39,13 +87,26 @@ def prepare_df_with_signals(self, models):
def prepare_datasets(self, datas):
pass

# TODO:
def clear_df(self):
pass
self.df.clear()

# TODO:
def append_to_df(self, data: BaseData, predictions: dict, **kwargs):
    # TODO stub: not implemented for polars yet — should mirror
    # PandasDataTool.append_to_df; currently a no-op returning None.
    pass

def output_df_to_parquet(self, df: pl.DataFrame | pl.LazyFrame, file_path: str):
df.write_parquet(file_path, compression='zstd')

'''
************************************************
Helper Functions
************************************************
'''
def output_df_to_parquet(self, df: pl.DataFrame | pl.LazyFrame, file_path: str, compression: str='zstd'):
    """Write the frame to *file_path* as Parquet, zstd-compressed by default."""
    target = file_path
    df.write_parquet(target, compression=compression)

# TODO
# TODO
def filter_df(self, df: pl.DataFrame | pl.LazyFrame, **kwargs) -> pl.DataFrame | pl.LazyFrame:
    # stub: not implemented for polars yet — should mirror
    # PandasDataTool.filter_df; currently a no-op returning None.
    pass

# TODO
# TODO
def unstack_df(self, df: pl.DataFrame | pl.LazyFrame, **kwargs) -> pl.DataFrame | pl.LazyFrame:
    # stub: not implemented for polars yet — should mirror
    # PandasDataTool.unstack_df; currently a no-op returning None.
    pass
25 changes: 9 additions & 16 deletions pfund/datas/data_bar.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sys
import logging
from decimal import Decimal

from pfund.datas.resolution import Resolution
from pfund.datas.data_time_based import TimeBasedData
Expand All @@ -21,14 +20,16 @@ def __init__(self, product, resolution, shift: int=0):
self.timeframe = resolution.timeframe
self.unit = self.get_unit()
self.shift_unit = self.get_shift_unit(shift)
# variables that will be cleared using clear() for each new bar
self.o = self.open = Decimal(0.0)
self.h = self.high = Decimal(0.0)
self.l = self.low = Decimal(sys.float_info.max)
self.c = self.close = Decimal(0.0)
self.v = self.volume = Decimal(0.0)
self._start_ts = self._end_ts = self.ts = 0.0
self.clear()

def clear(self):
    """Reset all OHLCV fields and timestamps so a fresh bar can accumulate.

    'low' starts at the largest float so any real price becomes the new low;
    the other fields start at 0.0.
    """
    self.o = self.open = 0.0
    self.c = self.close = 0.0
    self.h = self.high = 0.0
    self.l = self.low = sys.float_info.max
    self.v = self.volume = 0.0
    self.ts = 0.0
    self._start_ts = 0.0
    self._end_ts = 0.0

def __str__(self):
bar_type = 'Bar'
if not self._start_ts:
Expand Down Expand Up @@ -116,14 +117,6 @@ def get_unit(self):
unit = 60 * 60 * 24 * 7 * 4 * self.period
return unit

def clear(self):
self.o = self.open = Decimal(0.0)
self.h = self.high = Decimal(0.0)
self.l = self.low = Decimal(sys.float_info.max)
self.c = self.close = Decimal(0.0)
self.v = self.volume = Decimal(0.0)
self._start_ts = self._end_ts = self.ts = 0.0

def update(self, o, h, l, c, v, ts, is_volume_aggregated):
if not self.o:
self.o = self.open = o
Expand Down
2 changes: 1 addition & 1 deletion pfund/datas/resolution.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import re

from pfund.datas.timeframe import Timeframe, TimeframeUnits
from pfund.const.commons import SUPPORTED_TIMEFRAMES
from pfund.const.common import SUPPORTED_TIMEFRAMES


class Resolution:
Expand Down
Loading

0 comments on commit 3d62d31

Please sign in to comment.