Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rd/update tastytrade broker #36

Merged
merged 9 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/docker-image.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
name: Build and Push Docker image

on:
pull_request:
push:
branches:
- main
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@ Chart.lock
charts/
test-config.yaml
app.log
venv
3 changes: 1 addition & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# Use an official Python runtime as a parent image
FROM python:3.9-buster
FROM python:3.12-buster

# Set the working directory in the container
WORKDIR /app
Expand Down
9 changes: 6 additions & 3 deletions brokers/base_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class BaseBroker(ABC):
def __init__(self, api_key, secret_key, broker_name, engine, prevent_day_trading=False):
self.api_key = api_key
self.secret_key = secret_key
self.broker_name = broker_name
self.broker_name = broker_name.lower()
self.db_manager = DBManager(engine)
self.Session = sessionmaker(bind=engine)
self.account_id = None
Expand Down Expand Up @@ -104,7 +104,7 @@ def update_positions(self, session, trade):
except Exception as e:
logger.error('Failed to update positions', extra={'error': str(e)})

def place_order(self, symbol, quantity, order_type, strategy, price=None):
async def place_order(self, symbol, quantity, order_type, strategy, price=None):
logger.info('Placing order', extra={'symbol': symbol, 'quantity': quantity, 'order_type': order_type, 'strategy': strategy})

if self.prevent_day_trading and order_type == 'sell':
Expand All @@ -113,7 +113,10 @@ def place_order(self, symbol, quantity, order_type, strategy, price=None):
return None

try:
response = self._place_order(symbol, quantity, order_type, price)
if self.broker_name == 'tastytrade':
response = await self._place_order(symbol, quantity, order_type, price)
else:
response = self._place_order(symbol, quantity, order_type, price)
logger.info('Order placed successfully', extra={'response': response})

trade = Trade(
Expand Down
193 changes: 157 additions & 36 deletions brokers/tastytrade_broker.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,171 @@
import requests
import time
import json
from brokers.base_broker import BaseBroker
from utils.logger import logger # Import the logger
from tastytrade import ProductionSession
from tastytrade import DXLinkStreamer
from tastytrade.dxfeed import EventType

class TastytradeBroker(BaseBroker):
def __init__(self, api_key, secret_key, engine):
super().__init__(api_key, secret_key, 'Tastytrade', engine)
def __init__(self, username, password, engine, **kwargs):
super().__init__(username, password, 'Tastytrade', engine=engine, **kwargs)
self.base_url = 'https://api.tastytrade.com'
self.username = username
self.password = password
self.headers = {
"Accept": "application/json",
"Content-Type": "application/json"
}
self.order_timeout = 1
self.auto_cancel_orders = True
logger.info('Initialized TastytradeBroker', extra={'base_url': self.base_url})
self.session = None
self.connect()

def connect(self):
# Implement the connection logic
response = requests.post("https://api.tastytrade.com/oauth/token", data={"key": self.api_key, "secret": self.secret_key})
self.auth = response.json().get('access_token')
logger.info('Connecting to Tastytrade API')
auth_data = {
"login": self.username,
"password": self.password,
"remember-me": True
}
response = requests.post(f"{self.base_url}/sessions", json=auth_data, headers={"Content-Type": "application/json"})
response.raise_for_status()
auth_response = response.json().get('data')
self.auth = auth_response['session-token']
self.headers["Authorization"] = self.auth
if self.session is None:
self.session = ProductionSession(self.username, self.password)
logger.info('Connected to Tastytrade API')

def _get_account_info(self):
response = requests.get("https://api.tastytrade.com/accounts", headers={"Authorization": f"Bearer {self.auth}"})
account_info = response.json()
account_id = account_info['accounts'][0]['accountId']
self.account_id = account_id
account_data = account_info.get('accounts')[0]
return {'value': account_data.get('value')}

def _place_order(self, symbol, quantity, order_type, price=None):
# Implement order placement
order_data = {
"symbol": symbol,
"quantity": quantity,
"order_type": order_type,
"price": price
}
response = requests.post("https://api.tastytrade.com/orders", json=order_data, headers={"Authorization": f"Bearer {self.auth}"})
return response.json()
logger.info('Retrieving account information')
try:
response = requests.get(f"{self.base_url}/customers/me/accounts", headers=self.headers)
response.raise_for_status()
account_info = response.json()
account_id = account_info['data']['items'][0]['account']['account-number']
self.account_id = account_id
logger.info('Account info retrieved', extra={'account_id': self.account_id})

response = requests.get(f"{self.base_url}/accounts/{self.account_id}/balances", headers=self.headers)
response.raise_for_status()
account_data = response.json().get('data')

if not account_data:
logger.error("Invalid account info response")

buying_power = account_data['equity-buying-power']
account_value = account_data['net-liquidating-value']
account_type = None

logger.info('Account balances retrieved', extra={'account_type': account_type, 'buying_power': buying_power, 'value': account_value})
return {
'account_number': self.account_id,
'account_type': account_type,
'buying_power': float(buying_power),
'value': float(account_value)
}
except requests.RequestException as e:
logger.error('Failed to retrieve account information', extra={'error': str(e)})

def get_positions(self):
logger.info('Retrieving positions')
url = f"{self.base_url}/accounts/{self.account_id}/positions"
try:
response = requests.get(url, headers=self.headers)
response.raise_for_status()
positions_data = response.json()['data']['items']

positions = {p['symbol']: p for p in positions_data}
logger.info('Positions retrieved', extra={'positions': positions})
return positions
except requests.RequestException as e:
logger.error('Failed to retrieve positions', extra={'error': str(e)})

async def _place_order(self, symbol, quantity, order_type, price=None):
logger.info('Placing order', extra={'symbol': symbol, 'quantity': quantity, 'order_type': order_type, 'price': price})
try:
last_price = await self.get_current_price(symbol)

if price is None:
price = round(last_price, 2)

order_data = {
"class": "equity",
"symbol": symbol,
"quantity": quantity,
"side": order_type,
"type": "limit",
"duration": "day",
"price": price
}

response = requests.post(f"{self.base_url}/accounts/{self.account_id}/orders", json=order_data, headers=self.headers)
response.raise_for_status()

order_id = response.json()['data']['order']['order_id']
logger.info('Order placed', extra={'order_id': order_id})

if self.auto_cancel_orders:
time.sleep(self.order_timeout)
order_status_url = f"{self.base_url}/accounts/{self.account_id}/orders/{order_id}"
status_response = requests.get(order_status_url, headers=self.headers)
status_response.raise_for_status()
order_status = status_response.json()['data']['order']['status']

if order_status != 'filled':
cancel_url = f"{self.base_url}/accounts/{self.account_id}/orders/{order_id}/cancel"
cancel_response = requests.put(cancel_url, headers=self.headers)
cancel_response.raise_for_status()
logger.info('Order cancelled', extra={'order_id': order_id})

data = response.json()
if data.get('filled_price') is None:
data['filled_price'] = price
logger.info('Order execution complete', extra={'order_data': data})
return data
except requests.RequestException as e:
logger.error('Failed to place order', extra={'error': str(e)})

def _get_order_status(self, order_id):
# Implement order status retrieval
response = requests.get(f"https://api.tastytrade.com/orders/{order_id}", headers={"Authorization": f"Bearer {self.auth}"})
return response.json()
logger.info('Retrieving order status', extra={'order_id': order_id})
try:
response = requests.get(f"{self.base_url}/accounts/{self.account_id}/orders/{order_id}", headers=self.headers)
response.raise_for_status()
order_status = response.json()
logger.info('Order status retrieved', extra={'order_status': order_status})
return order_status
except requests.RequestException as e:
logger.error('Failed to retrieve order status', extra={'error': str(e)})

def _cancel_order(self, order_id):
# Implement order cancellation
response = requests.put(f"https://api.tastytrade.com/orders/{order_id}/cancel", headers={"Authorization": f"Bearer {self.auth}"})
return response.json()
logger.info('Cancelling order', extra={'order_id': order_id})
try:
response = requests.put(f"{self.base_url}/accounts/{self.account_id}/orders/{order_id}/cancel", headers=self.headers)
response.raise_for_status()
cancellation_response = response.json()
logger.info('Order cancelled successfully', extra={'cancellation_response': cancellation_response})
return cancellation_response
except requests.RequestException as e:
logger.error('Failed to cancel order', extra={'error': str(e)})

def _get_options_chain(self, symbol, expiration_date):
# Implement options chain retrieval
response = requests.get(f"https://api.tastytrade.com/markets/options/chains?symbol={symbol}&expiration={expiration_date}", headers={"Authorization": f"Bearer {self.auth}"})
return response.json()

def get_current_price(self, symbol):
# Implement current price retrieval
response = requests.get(f"https://api.tastytrade.com/markets/quotes/{symbol}", headers={"Authorization": f"Bearer {self.auth}"})
return response.json().get('lastPrice')
logger.info('Retrieving options chain', extra={'symbol': symbol, 'expiration_date': expiration_date})
try:
response = requests.get(f"{self.base_url}/markets/options/chains", params={"symbol": symbol, "expiration": expiration_date}, headers=self.headers)
response.raise_for_status()
options_chain = response.json()
logger.info('Options chain retrieved', extra={'options_chain': options_chain})
return options_chain
except requests.RequestException as e:
logger.error('Failed to retrieve options chain', extra={'error': str(e)})

async def get_current_price(self, symbol):
async with DXLinkStreamer(self.session) as streamer:
subs_list = [symbol]
await streamer.subscribe(EventType.QUOTE, subs_list)
quote = await streamer.get_event(EventType.QUOTE)
# Just return the mid price for now
return round(float((quote.bidPrice + quote.askPrice) / 2), 2)
3 changes: 1 addition & 2 deletions database/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from sqlalchemy import Column, Integer, String, Float, DateTime, create_engine, ForeignKey, PrimaryKeyConstraint, ForeignKeyConstraint
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.orm import sessionmaker, relationship, declarative_base
from datetime import datetime

Base = declarative_base()
Expand Down
11 changes: 6 additions & 5 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import asyncio
import time
import os
from datetime import datetime, timedelta
Expand All @@ -8,7 +9,7 @@
from sqlalchemy import create_engine
from utils.logger import logger # Import the logger

def start_trading_system(config_path):
async def start_trading_system(config_path):
logger.info('Starting the trading system', extra={'config_path': config_path})

# Parse the configuration file
Expand Down Expand Up @@ -62,7 +63,7 @@ def start_trading_system(config_path):
for i, strategy in enumerate(strategies):
if now - last_rebalances[i] >= rebalance_intervals[i]:
try:
strategy.rebalance()
await strategy.rebalance()
last_rebalances[i] = now
logger.info(f'Strategy {i} rebalanced successfully', extra={'time': now})
except Exception as e:
Expand Down Expand Up @@ -109,7 +110,7 @@ def start_api_server(config_path=None, local_testing=False):
except Exception as e:
logger.error('Failed to start API server', extra={'error': str(e)})

def main():
async def main():
parser = argparse.ArgumentParser(description="Run trading strategies or start API server based on YAML configuration.")
parser.add_argument('--mode', choices=['trade', 'api'], required=True, help='Mode to run the system in: "trade" or "api"')
parser.add_argument('--config', type=str, help='Path to the YAML configuration file.')
Expand All @@ -119,9 +120,9 @@ def main():
if args.mode == 'trade':
if not args.config:
parser.error('--config is required when mode is "trade"')
start_trading_system(args.config)
await start_trading_system(args.config)
elif args.mode == 'api':
start_api_server(config_path=args.config, local_testing=args.local_testing)

if __name__ == "__main__":
main()
asyncio.run(main())
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,5 @@ pytz
numpy
scipy
psycopg2
websocket-client
tastytrade
8 changes: 6 additions & 2 deletions strategies/constant_percentage_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def __init__(self, broker, stock_allocations, cash_percentage, rebalance_interva
self.sync_positions_with_broker() # Ensure positions are synced on initialization
logger.info(f"Initialized {self.strategy_name} strategy with starting capital {self.starting_capital}")

def rebalance(self):
async def rebalance(self):
logger.debug("Starting rebalance process")
self.sync_positions_with_broker() # Ensure positions are synced before rebalancing

Expand Down Expand Up @@ -47,7 +47,11 @@ def rebalance(self):
for stock, allocation in self.stock_allocations.items():
target_balance = target_investment_balance * allocation
current_position = current_positions.get(stock, 0)
current_price = self.broker.get_current_price(stock)
# async price fetcher
if self.broker.broker_name == 'tastytrade':
current_price = await self.broker.get_current_price(stock)
else:
current_price = self.broker.get_current_price(stock)
target_quantity = target_balance // current_price
logger.debug(f"Stock: {stock}, Allocation: {allocation}, Target balance: {target_balance}, Current position: {current_position}, Current price: {current_price}, Target quantity: {target_quantity}")

Expand Down
6 changes: 3 additions & 3 deletions tests/test_brokers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from unittest.mock import MagicMock, patch
from datetime import datetime
from datetime import datetime, timezone
from database.models import Trade, Balance, Position
from .base_test import BaseTest
from brokers.base_broker import BaseBroker
Expand Down Expand Up @@ -42,7 +42,7 @@ def setUp(self):

# Additional setup
additional_fake_trades = [
Trade(symbol='MSFT', quantity=8, price=200.0, executed_price=202.0, order_type='buy', status='executed', timestamp=datetime.utcnow(), broker='Tastytrade', strategy='RSI', profit_loss=16.0, success='yes'),
Trade(symbol='MSFT', quantity=8, price=200.0, executed_price=202.0, order_type='buy', status='executed', timestamp=datetime.now(timezone.utc), broker='Tastytrade', strategy='RSI', profit_loss=16.0, success='yes'),
]
self.session.add_all(additional_fake_trades)
self.session.commit()
Expand All @@ -56,7 +56,7 @@ def skip_test_execute_trade(self):
'executed_price': 151.0,
'order_type': 'buy',
'status': 'executed',
'timestamp': datetime.utcnow(),
'timestamp': datetime.now(timezone.utc),
'broker': 'E*TRADE',
'strategy': 'SMA',
'profit_loss': 10.0,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ def setUp(self):
api_key: "your_tradier_api_key"
tastytrade:
type: "tastytrade"
api_key: "your_tastytrade_api_key"
secret_key: "example_key"
username: "your_tastytrade_username"
password: "password"
strategies:
- type: "constant_percentage"
broker: "tradier"
Expand Down
Loading
Loading