Skip to content

Commit

Permalink
Merge pull request #39 from r0fls/rd/order-tracking
Browse files Browse the repository at this point in the history
update balance with order placement
  • Loading branch information
r0fls authored Jun 18, 2024
2 parents 0491ac6 + 5dc3684 commit 43eed6f
Showing 1 changed file with 64 additions and 21 deletions.
85 changes: 64 additions & 21 deletions brokers/base_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from sqlalchemy.orm import sessionmaker
from sqlalchemy.sql import and_
from database.db_manager import DBManager
from database.models import Trade, AccountInfo, Position
from database.models import Trade, AccountInfo, Position, Balance
from datetime import datetime
from utils.logger import logger # Import the logger


class BaseBroker(ABC):
def __init__(self, api_key, secret_key, broker_name, engine, prevent_day_trading=False):
self.api_key = api_key
Expand All @@ -16,7 +17,8 @@ def __init__(self, api_key, secret_key, broker_name, engine, prevent_day_trading
self.Session = sessionmaker(bind=engine)
self.account_id = None
self.prevent_day_trading = prevent_day_trading
logger.info('Initialized BaseBroker', extra={'broker_name': self.broker_name})
logger.info('Initialized BaseBroker', extra={
'broker_name': self.broker_name})

@abstractmethod
def connect(self):
Expand Down Expand Up @@ -49,11 +51,14 @@ def get_current_price(self, symbol):
def get_account_info(self):
try:
account_info = self._get_account_info()
self.db_manager.add_account_info(AccountInfo(broker=self.broker_name, value=account_info['value']))
logger.info('Account information retrieved', extra={'account_info': account_info})
self.db_manager.add_account_info(AccountInfo(
broker=self.broker_name, value=account_info['value']))
logger.info('Account information retrieved',
extra={'account_info': account_info})
return account_info
except Exception as e:
logger.error('Failed to get account information', extra={'error': str(e)})
logger.error('Failed to get account information',
extra={'error': str(e)})
return None

def has_bought_today(self, symbol):
Expand All @@ -68,15 +73,18 @@ def has_bought_today(self, symbol):
Trade.timestamp >= today
)
).all()
logger.info('Checked for trades today', extra={'symbol': symbol, 'trade_count': len(trades)})
logger.info('Checked for trades today', extra={
'symbol': symbol, 'trade_count': len(trades)})
return len(trades) > 0
except Exception as e:
logger.error('Failed to check if bought today', extra={'error': str(e)})
logger.error('Failed to check if bought today',
extra={'error': str(e)})
return False

def update_positions(self, session, trade):
try:
position = session.query(Position).filter_by(symbol=trade.symbol, broker=self.broker_name, strategy=trade.strategy).first()
position = session.query(Position).filter_by(
symbol=trade.symbol, broker=self.broker_name, strategy=trade.strategy).first()

if trade.order_type == 'buy':
if position:
Expand All @@ -97,7 +105,8 @@ def update_positions(self, session, trade):
position.quantity -= trade.quantity
position.latest_price = trade.executed_price
if position.quantity < 0:
logger.error('Sell quantity exceeds current position quantity', extra={'trade': trade})
logger.error('Sell quantity exceeds current position quantity', extra={
'trade': trade})
position.quantity = 0 # Set to 0 to handle the error gracefully

session.commit()
Expand All @@ -106,19 +115,23 @@ def update_positions(self, session, trade):
logger.error('Failed to update positions', extra={'error': str(e)})

async def place_order(self, symbol, quantity, order_type, strategy, price=None):
logger.info('Placing order', extra={'symbol': symbol, 'quantity': quantity, 'order_type': order_type, 'strategy': strategy})

logger.info('Placing order', extra={
'symbol': symbol, 'quantity': quantity, 'order_type': order_type, 'strategy': strategy})

if self.prevent_day_trading and order_type == 'sell':
if self.has_bought_today(symbol):
logger.error('Day trading is not allowed. Cannot sell positions opened today.', extra={'symbol': symbol})
logger.error('Day trading is not allowed. Cannot sell positions opened today.', extra={
'symbol': symbol})
return None

try:
if asyncio.iscoroutinefunction(self._place_order):
response = await self._place_order(symbol, quantity, order_type, price)
else:
response = self._place_order(symbol, quantity, order_type, price)
logger.info('Order placed successfully', extra={'response': response})
response = self._place_order(
symbol, quantity, order_type, price)
logger.info('Order placed successfully',
extra={'response': response})

trade = Trade(
symbol=symbol,
Expand All @@ -133,12 +146,36 @@ async def place_order(self, symbol, quantity, order_type, strategy, price=None):
profit_loss=0,
success='yes'
)

with self.Session() as session:
session.add(trade)
session.commit()
self.update_positions(session, trade)

# Fetch the latest cash balance for the strategy
latest_balance = session.query(Balance).filter_by(
strategy=strategy, type='cash').order_by(Balance.timestamp.desc()).first()
if latest_balance:
# Calculate the order cost
order_cost = trade.executed_price * quantity

# Subtract the order cost from the cash balance
if order_type == 'buy':
new_balance_amount = latest_balance.amount - order_cost
else: # order_type == 'sell'
new_balance_amount = latest_balance.amount + order_cost

# Create a new balance record with the updated cash balance
new_balance = Balance(
broker=self.broker_name,
strategy=strategy,
type='cash',
amount=new_balance_amount,
timestamp=datetime.now()
)
session.add(new_balance)
session.commit()

return response
except Exception as e:
logger.error('Failed to place order', extra={'error': str(e)})
Expand All @@ -152,7 +189,8 @@ def get_order_status(self, order_id):
trade = session.query(Trade).filter_by(id=order_id).first()
if trade:
self.update_trade(session, trade.id, order_status)
logger.info('Order status retrieved', extra={'order_status': order_status})
logger.info('Order status retrieved', extra={
'order_status': order_status})
return order_status
except Exception as e:
logger.error('Failed to get order status', extra={'error': str(e)})
Expand All @@ -166,27 +204,32 @@ def cancel_order(self, order_id):
trade = session.query(Trade).filter_by(id=order_id).first()
if trade:
self.update_trade(session, trade.id, cancel_status)
logger.info('Order cancelled successfully', extra={'cancel_status': cancel_status})
logger.info('Order cancelled successfully', extra={
'cancel_status': cancel_status})
return cancel_status
except Exception as e:
logger.error('Failed to cancel order', extra={'error': str(e)})
return None

def get_options_chain(self, symbol, expiration_date):
logger.info('Retrieving options chain', extra={'symbol': symbol, 'expiration_date': expiration_date})
logger.info('Retrieving options chain', extra={
'symbol': symbol, 'expiration_date': expiration_date})
try:
options_chain = self._get_options_chain(symbol, expiration_date)
logger.info('Options chain retrieved', extra={'options_chain': options_chain})
logger.info('Options chain retrieved', extra={
'options_chain': options_chain})
return options_chain
except Exception as e:
logger.error('Failed to retrieve options chain', extra={'error': str(e)})
logger.error('Failed to retrieve options chain',
extra={'error': str(e)})
return None

def update_trade(self, session, trade_id, order_info):
try:
trade = session.query(Trade).filter_by(id=trade_id).first()
if not trade:
logger.error('Trade not found for update', extra={'trade_id': trade_id})
logger.error('Trade not found for update',
extra={'trade_id': trade_id})
return

executed_price = order_info.get('filled_price', trade.price)
Expand Down

0 comments on commit 43eed6f

Please sign in to comment.