Skip to content

Commit

Permalink
introduce --dry-run mode (#533)
Browse files Browse the repository at this point in the history
  • Loading branch information
junyuanz1 authored Dec 6, 2024
1 parent e3be592 commit 3921e3e
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 70 deletions.
9 changes: 7 additions & 2 deletions thetagang/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@
help="Run without IBC. Enable this if you want to run the TWS "
"gateway yourself, without having ThetaGang manage it for you.",
)
def cli(config: str, without_ibc: bool) -> None:
@click.option(
"--dry-run",
is_flag=True,
help="Perform a dry run. This will display the the orders without sending any live trades.",
)
def cli(config: str, without_ibc: bool, dry_run: bool) -> None:
"""ThetaGang is an IBKR bot for collecting money.
You can configure this tool by supplying a toml configuration file.
Expand All @@ -38,4 +43,4 @@ def cli(config: str, without_ibc: bool) -> None:

from .thetagang import start

start(config, without_ibc)
start(config, without_ibc, dry_run)
46 changes: 46 additions & 0 deletions thetagang/orders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from typing import List, Tuple

from ib_async import Contract, LimitOrder
from rich import box
from rich.pretty import Pretty
from rich.table import Table

from thetagang import log
from thetagang.fmt import dfmt, ifmt


class Orders:
def __init__(self) -> None:
self.__records: List[Tuple[Contract, LimitOrder]] = []

def add_order(self, contract: Contract, order: LimitOrder) -> None:
self.__records.append((contract, order))

def records(self) -> List[Tuple[Contract, LimitOrder]]:
return self.__records

def print_summary(self) -> None:
if not self.__records:
return

table = Table(
title="Order Summary", show_lines=True, box=box.MINIMAL_HEAVY_HEAD
)
table.add_column("Symbol")
table.add_column("Exchange")
table.add_column("Contract")
table.add_column("Action")
table.add_column("Price")
table.add_column("Qty")

for contract, order in self.__records:
table.add_row(
contract.symbol,
contract.exchange,
Pretty(contract, indent_size=2),
order.action,
dfmt(order.lmtPrice),
ifmt(int(order.totalQuantity)),
)

log.print(table)
102 changes: 36 additions & 66 deletions thetagang/portfolio_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,23 @@
import numpy as np
from ib_async import (
AccountValue,
Order,
PortfolioItem,
TagValue,
Ticker,
Trade,
util,
)
from ib_async.contract import ComboLeg, Contract, Index, Option, Stock
from ib_async.ib import IB
from ib_async.order import LimitOrder
from rich import box
from rich.console import Group
from rich.panel import Panel
from rich.pretty import Pretty
from rich.table import Table

from thetagang import log
from thetagang.fmt import dfmt, ffmt, ifmt, pfmt
from thetagang.ibkr import IBKR, RequiredFieldValidationError, TickerField
from thetagang.orders import Orders
from thetagang.trades import Trades
from thetagang.util import (
account_summary_to_dict,
algo_params_from,
Expand Down Expand Up @@ -72,7 +70,11 @@ def __init__(self, message: str) -> None:

class PortfolioManager:
def __init__(
self, config: Dict[str, Dict[str, Any]], ib: IB, completion_future: Future[bool]
self,
config: Dict[str, Dict[str, Any]],
ib: IB,
completion_future: Future[bool],
dry_run: bool,
) -> None:
self.account_number = config["account"]["number"]
self.config = config
Expand All @@ -84,10 +86,11 @@ def __init__(
self.completion_future = completion_future
self.has_excess_calls: set[str] = set()
self.has_excess_puts: set[str] = set()
self.orders: List[tuple[Contract, LimitOrder]] = []
self.trades: List[Trade] = []
self.orders: Orders = Orders()
self.trades: Trades = Trades(self.ibkr)
self.target_quantities: Dict[str, int] = {}
self.qualified_contracts: Dict[int, Contract] = {}
self.dry_run = dry_run

def get_short_calls(
self, portfolio_positions: Dict[str, List[PortfolioItem]]
Expand Down Expand Up @@ -587,17 +590,22 @@ async def manage(self) -> None:
# manage dat cash
await self.do_cashman(account_summary, portfolio_positions)

self.submit_orders()
if self.dry_run:
log.warning("Dry run enabled, no trades will be executed.")

try:
await self.ibkr.wait_for_submitting_orders(self.trades)
except RuntimeError:
log.error("Submitting orders failed. Continuing anyway..")
pass
self.orders.print_summary()
else:
self.submit_orders()

await self.adjust_prices()
try:
await self.ibkr.wait_for_submitting_orders(self.trades.records())
except RuntimeError:
log.error("Submitting orders failed. Continuing anyway..")
pass

await self.ibkr.wait_for_submitting_orders(self.trades)
await self.adjust_prices()

await self.ibkr.wait_for_submitting_orders(self.trades.records())

log.info("ThetaGang is done, shutting down! Cya next time. :sparkles:")
except:
Expand Down Expand Up @@ -1951,13 +1959,13 @@ def get_multiplier(contract: Contract) -> float:
return sum(
[
order.lmtPrice * order.totalQuantity * get_multiplier(contract)
for (contract, order) in self.orders
for (contract, order) in self.orders.records()
if order.action == "SELL"
]
) - sum(
[
order.lmtPrice * order.totalQuantity * get_multiplier(contract)
for (contract, order) in self.orders
for (contract, order) in self.orders.records()
if order.action == "BUY"
]
)
Expand Down Expand Up @@ -2090,49 +2098,12 @@ async def make_order() -> tuple[Optional[Ticker], Optional[LimitOrder]]:
def enqueue_order(self, contract: Optional[Contract], order: LimitOrder) -> None:
if not contract:
return
self.orders.append((contract, order))
self.orders.add_order(contract, order)

def submit_orders(self) -> None:
def submit(contract: Contract, order: Order) -> Optional[Trade]:
try:
trade = self.ibkr.place_order(contract, order)
return trade
except RuntimeError:
log.error(f"Failed to submit contract: {contract}, order: {order}")
return None

self.trades = [
trade
for trade in [submit(order[0], order[1]) for order in self.orders]
if trade
]

if len(self.trades) > 0:
table = Table(
title="Orders submitted", show_lines=True, box=box.MINIMAL_HEAVY_HEAD
)
table.add_column("Symbol")
table.add_column("Exchange")
table.add_column("Contract")
table.add_column("Action")
table.add_column("Price")
table.add_column("Qty")
table.add_column("Status")
table.add_column("Filled")

for trade in self.trades:
if trade:
table.add_row(
trade.contract.symbol,
trade.contract.exchange,
Pretty(trade.contract, indent_size=2),
trade.order.action,
dfmt(trade.order.lmtPrice),
ifmt(int(trade.order.totalQuantity)),
trade.orderStatus.status,
ffmt(trade.orderStatus.filled, 0),
)
log.print(table)
for contract, order in self.orders.records():
self.trades.submit_order(contract, order)
self.trades.print_summary()

async def adjust_prices(self) -> None:
if (
Expand All @@ -2144,7 +2115,7 @@ async def adjust_prices(self) -> None:
for symbol in self.config["symbols"]
]
)
or len(self.trades) == 0
or self.trades.is_empty()
):
log.warning("Skipping order price adjustments...")
return
Expand All @@ -2154,11 +2125,11 @@ async def adjust_prices(self) -> None:
self.config["orders"]["price_update_delay"][1],
)

await self.ibkr.wait_for_orders_complete(self.trades, delay)
await self.ibkr.wait_for_orders_complete(self.trades.records(), delay)

unfilled = [
(idx, trade)
for idx, trade in enumerate(self.trades)
for idx, trade in enumerate(self.trades.records())
if trade
and trade.contract.symbol in self.config["symbols"]
and self.config["symbols"][trade.contract.symbol].get(
Expand Down Expand Up @@ -2220,12 +2191,11 @@ async def adjust_prices(self) -> None:
algoParams=order.algoParams,
)

# put the trade back from whence it came
self.trades[idx] = self.ibkr.place_order(contract, order)
# resubmit the order and it will be placed back to the
# original position in the queue
self.trades.submit_order(contract, order, idx)

log.info(
f"{contract.symbol}: Order updated, order={self.trades[idx].order}"
)
log.info(f"{contract.symbol}: Order updated, order={order}")
except (RuntimeError, RequiredFieldValidationError):
log.error(
f"Couldn't generate midpoint price for {trade.contract}, skipping"
Expand Down
4 changes: 2 additions & 2 deletions thetagang/thetagang.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
console = Console()


def start(config_path: str, without_ibc: bool = False) -> None:
def start(config_path: str, without_ibc: bool = False, dry_run: bool = False) -> None:
import toml

with open(config_path, "r", encoding="utf8") as file:
Expand Down Expand Up @@ -418,7 +418,7 @@ async def onConnected() -> None:
ib.connectedEvent += onConnected

completion_future: Future[bool] = Future()
portfolio_manager = PortfolioManager(config, ib, completion_future)
portfolio_manager = PortfolioManager(config, ib, completion_future, dry_run)

probeContractConfig = config["watchdog"]["probeContract"]
watchdogConfig = config.get("watchdog", {})
Expand Down
70 changes: 70 additions & 0 deletions thetagang/trades.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import List, Optional

from ib_async import Contract, LimitOrder, Trade
from rich import box
from rich.pretty import Pretty
from rich.table import Table

from thetagang import log
from thetagang.fmt import dfmt, ffmt, ifmt
from thetagang.ibkr import IBKR


class Trades:
def __init__(self, ibkr: IBKR) -> None:
self.ibkr = ibkr
self.__records: List[Trade] = []

def submit_order(
self, contract: Contract, order: LimitOrder, idx: Optional[int] = None
) -> None:
try:
trade = self.ibkr.place_order(contract, order)
if idx is not None:
self.__replace_trade(trade, idx)
else:
self.__add_trade(trade)
except RuntimeError:
log.error(f"{contract.symbol}: Failed to submit contract, order={order}")

def records(self) -> List[Trade]:
return self.__records

def is_empty(self) -> bool:
return len(self.__records) == 0

def print_summary(self) -> None:
if not self.__records:
return

table = Table(
title="Trade Summary", show_lines=True, box=box.MINIMAL_HEAVY_HEAD
)
table.add_column("Symbol")
table.add_column("Exchange")
table.add_column("Contract")
table.add_column("Action")
table.add_column("Price")
table.add_column("Qty")
table.add_column("Status")
table.add_column("Filled")

for trade in self.__records:
table.add_row(
trade.contract.symbol,
trade.contract.exchange,
Pretty(trade.contract, indent_size=2),
trade.order.action,
dfmt(trade.order.lmtPrice),
ifmt(int(trade.order.totalQuantity)),
trade.orderStatus.status,
ffmt(trade.orderStatus.filled, 0),
)

log.print(table)

def __add_trade(self, trade: Trade) -> None:
self.__records.append(trade)

def __replace_trade(self, trade: Trade, idx: int) -> None:
self.__records[idx] = trade
Loading

0 comments on commit 3921e3e

Please sign in to comment.