Skip to content

Commit

Permalink
Bring back client from login flow (#122)
Browse files Browse the repository at this point in the history
* Reinstate client_from_login_flow

* Added initial tests

* Adds missing dependencies

* fix tests

* Ignore coverage reports on unreachable code

* tolerate the insecure request to the server at the start

* move pragma no cover for server

* swap

* one more time

* disallow netlocs other than 127.0.0.1

* Fixed tests

* test for disallowed hostname with port number

* added todo
  • Loading branch information
alexgolec authored Jun 12, 2024
1 parent c576ac7 commit 632058e
Show file tree
Hide file tree
Showing 5 changed files with 272 additions and 6 deletions.
2 changes: 2 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Note this requirements file is only used by sphinx. For regular installations,
# see the setup.py file.
authlib==1.3.0
Flask==3.0.3
httpx==0.27.0
prompt_toolkit==3.0.43
psutil==5.9.8
python-dateutil==2.9.0.post0
sphinx-rtd-theme==2.0.0
websockets==12.0
129 changes: 127 additions & 2 deletions schwab/auth.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
##########################################################################
# Authentication Wrappers


from authlib.integrations.httpx_client import AsyncOAuth2Client, OAuth2Client
from prompt_toolkit import prompt

import urllib
import json
import logging
import multiprocessing
import os
import psutil
import queue
import requests
import sys
import time
import urllib3
import warnings
import webbrowser

from schwab.client import AsyncClient, Client
from schwab.debug import register_redactions
Expand Down Expand Up @@ -159,6 +165,125 @@ async def oauth_client_update_token(t, *args, **kwargs):
token_metadata=metadata_manager, enforce_enums=enforce_enums)


# This runs in a separate process and is invisible to coverage
def __run_client_from_login_flow_server(
q, callback_port, callback_path): # pragma: no cover
'''Helper server for intercepting redirects to the callback URL. See
client_from_login_flow for details.'''

import flask

app = flask.Flask(__name__)

@app.route(callback_path)
def handle_token():
q.put(flask.request.url)
return 'schwab-py callback received! You may now close this window/tab.'

@app.route('/schwab-py-internal/status')
def status():
return 'running'

app.run(port=callback_port, ssl_context='adhoc')


class RedirectTimeoutError(Exception):
pass

class RedirectServerExitedError(Exception):
pass


def client_from_login_flow(api_key, app_secret, callback_url, token_path,
asyncio=False, enforce_enums=False,
token_write_func=None, callback_timeout=300.0):
# TODO: documentation

# Start the server
parsed = urllib.parse.urlparse(callback_url)

if parsed.hostname != '127.0.0.1':
raise ValueError(
('disallowed hostname {}. client_from_login_flow only allows '+
'callback URLs with hostname 127.0.0.1').format(
parsed.hostname))

callback_port = parsed.port if parsed.port else 80
callback_path = parsed.path if parsed.path else '/'

output_queue = multiprocessing.Queue()

server = multiprocessing.Process(
target=__run_client_from_login_flow_server,
args=(output_queue, callback_port, callback_path))

print('Running a server to intercept the callback. Please ignore the ' +
'following debug messages:')
print()
server.start()

# Wait until the server successfully starts
while True:
# Check if the server is still alive
if server.exitcode is not None:
raise RedirectServerExitedError(
'Redirect server exited. Are you attempting to use a ' +
'callback URL without a port number specified?')

import traceback

# Attempt to send a request to the server
try:
with warnings.catch_warnings():
warnings.filterwarnings(
'ignore', category=urllib3.exceptions.InsecureRequestWarning)

resp = requests.get(
'https://127.0.0.1:{}/schwab-py-internal/status'.format(
callback_port), verify=False)
break
except requests.exceptions.ConnectionError as e:
pass

time.sleep(0.1)

# Open the browser
oauth = OAuth2Client(api_key, redirect_uri=callback_url)
authorization_url, state = oauth.create_authorization_url(
'https://api.schwabapi.com/v1/oauth/authorize')

webbrowser.open(authorization_url)

# Wait for a response
now = time.time()
timeout_time = now + callback_timeout
callback_url = None
while now < timeout_time:
# Attempt to fetch from the queue
try:
callback_url = output_queue.get(
timeout=min(timeout_time - now, 0.1))
break
except queue.Empty:
pass

now = time.time()

# Clean up and create the client
psutil.Process(server.pid).kill()

if callback_url:
return __fetch_and_register_token_from_redirect(
oauth, callback_url, api_key, app_secret, token_path, token_write_func,
asyncio, enforce_enums=enforce_enums)
else:
raise RedirectTimeoutError(
'Timed out waiting for a post-authorization callback. You '+
'can set a longer timeout by passing a value of ' +
'callback_timeout to client_from_login_flow.')



def client_from_token_file(token_path, api_key, app_secret, asyncio=False,
enforce_enums=True):
'''
Expand Down
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,11 @@
install_requires=[
'autopep8',
'authlib',
'flask',
'httpx',
'prompt_toolkit',
'psutil',
'requests',
'python-dateutil',
'selenium',
'websockets'
Expand Down
142 changes: 139 additions & 3 deletions tests/auth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@

import json
import os
import requests
import tempfile
import unittest

import schwab


API_KEY = 'APIKEY'
APP_SECRET = '0x5EC07'
Expand All @@ -21,6 +24,142 @@
REDIRECT_URL = 'https://redirect.url.com'


class ClientFromLoginFlowTest(unittest.TestCase):

def setUp(self):
self.tmp_dir = tempfile.TemporaryDirectory()
self.token_path = os.path.join(self.tmp_dir.name, 'token.json')
self.raw_token = {'token': 'yes'}
self.token = {
'token': self.raw_token,
'creation_timestamp': TOKEN_CREATION_TIMESTAMP
}

@patch('schwab.auth.Client')
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
@patch('schwab.auth.webbrowser.open', new_callable=MagicMock)
@patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW))
def test_create_token_file(
self, mock_webbrowser_open, async_session, sync_session, client):
AUTH_URL = 'https://auth.url.com'

sync_session.return_value = sync_session
sync_session.create_authorization_url.return_value = AUTH_URL, None
sync_session.fetch_token.return_value = self.raw_token

callback_url = 'https://127.0.0.1:6969/callback'

mock_webbrowser_open.side_effect = \
lambda auth_url: requests.get(
'https://127.0.0.1:6969/callback', verify=False)

client.return_value = 'returned client'

auth.client_from_login_flow(
API_KEY, APP_SECRET, callback_url, self.token_path)

with open(self.token_path, 'r') as f:
self.assertEqual({
'creation_timestamp': MOCK_NOW,
'token': self.raw_token
}, json.load(f))


@patch('schwab.auth.Client')
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
@patch('schwab.auth.webbrowser.open', new_callable=MagicMock)
@patch('time.time', unittest.mock.MagicMock(return_value=MOCK_NOW))
def test_create_token_file_root_callback_url(
self, mock_webbrowser_open, async_session, sync_session, client):
AUTH_URL = 'https://auth.url.com'

sync_session.return_value = sync_session
sync_session.create_authorization_url.return_value = AUTH_URL, None
sync_session.fetch_token.return_value = self.raw_token

callback_url = 'https://127.0.0.1:6969/'

mock_webbrowser_open.side_effect = \
lambda auth_url: requests.get(
'https://127.0.0.1:6969/', verify=False)

client.return_value = 'returned client'

auth.client_from_login_flow(
API_KEY, APP_SECRET, callback_url, self.token_path)

with open(self.token_path, 'r') as f:
self.assertEqual({
'creation_timestamp': MOCK_NOW,
'token': self.raw_token
}, json.load(f))


@patch('schwab.auth.Client')
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
@patch('schwab.auth.webbrowser.open', new_callable=MagicMock)
def test_disallowed_hostname(
self, mock_webbrowser_open, async_session, sync_session, client):
callback_url = 'https://example.com/callback'

with self.assertRaisesRegex(
ValueError,'disallowed hostname example.com'):
auth.client_from_login_flow(
API_KEY, APP_SECRET, callback_url, self.token_path)


@patch('schwab.auth.Client')
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
@patch('schwab.auth.webbrowser.open', new_callable=MagicMock)
def test_disallowed_hostname_with_port(
self, mock_webbrowser_open, async_session, sync_session, client):
callback_url = 'https://example.com:8080/callback'

with self.assertRaisesRegex(
ValueError,'disallowed hostname example.com'):
auth.client_from_login_flow(
API_KEY, APP_SECRET, callback_url, self.token_path)


@patch('schwab.auth.Client')
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
@patch('schwab.auth.webbrowser.open', new_callable=MagicMock)
def test_unprivileged_start_on_port_80(
self, mock_webbrowser_open, async_session, sync_session, client):
callback_url = 'https://127.0.0.1/callback'

with self.assertRaisesRegex(schwab.auth.RedirectServerExitedError,
'callback URL without a port number'):
auth.client_from_login_flow(
API_KEY, APP_SECRET, callback_url, self.token_path)


@patch('schwab.auth.Client')
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
@patch('schwab.auth.webbrowser.open', new_callable=MagicMock)
def test_time_out_waiting_for_request(
self, mock_webbrowser_open, async_session, sync_session, client):
AUTH_URL = 'https://auth.url.com'

sync_session.return_value = sync_session
sync_session.create_authorization_url.return_value = AUTH_URL, None
sync_session.fetch_token.return_value = self.raw_token

callback_url = 'https://127.0.0.1:6969/callback'

with self.assertRaisesRegex(schwab.auth.RedirectTimeoutError,
'Timed out waiting'):
auth.client_from_login_flow(
API_KEY, APP_SECRET, callback_url, self.token_path,
callback_timeout=0.01)


class ClientFromTokenFileTest(unittest.TestCase):

def setUp(self):
Expand Down Expand Up @@ -309,9 +448,6 @@ def test_custom_token_write_func(
sync_session.create_authorization_url.return_value = AUTH_URL, None
sync_session.fetch_token.return_value = self.raw_token

webdriver = MagicMock()
webdriver.current_url = REDIRECT_URL + '/token_params'

client.return_value = 'returned client'
prompt_func.return_value = 'http://redirect.url.com/?data'

Expand Down
2 changes: 1 addition & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ setenv =
RCFILE=setup.cfg
allowlist_externals=coverage
commands =
coverage run --rcfile={env:RCFILE} --source=schwab -p -m pytest {env:TESTPATH}
coverage run --rcfile={env:RCFILE} --source=schwab -p -m pytest -W ignore::urllib3.exceptions.SystemTimeWarning -W ignore::urllib3.exceptions.InsecureRequestWarning {env:TESTPATH}

[testenv:coverage]
skip_install = true
Expand Down

0 comments on commit 632058e

Please sign in to comment.