From 3ae305c5767329fa44786ae7c6934b74979c640c Mon Sep 17 00:00:00 2001 From: Alex Golec Date: Sat, 31 Aug 2024 09:49:44 -0400 Subject: [PATCH] Break up the login flow and expose the authorization URL to the library user --- schwab/auth.py | 150 ++++++++++++++++++++++++++++--------------------- 1 file changed, 86 insertions(+), 64 deletions(-) diff --git a/schwab/auth.py b/schwab/auth.py index e57a063..fb4c4df 100644 --- a/schwab/auth.py +++ b/schwab/auth.py @@ -1,6 +1,7 @@ from authlib.integrations.httpx_client import AsyncOAuth2Client, OAuth2Client from prompt_toolkit import prompt +import collections import contextlib import httpx import json @@ -27,7 +28,7 @@ def get_logger(): return logging.getLogger(__name__) -def __update_token(token_path): +def __make_update_token_func(token_path): def update_token(t, *args, **kwargs): get_logger().info('Updating token to file %s', token_path) @@ -119,51 +120,6 @@ def wrap_token_in_metadata(self, token): } -def __fetch_and_register_token_from_redirect( - oauth, redirected_url, api_key, app_secret, token_path, - token_write_func, asyncio, enforce_enums=True): - token = oauth.fetch_token( - TOKEN_ENDPOINT, - authorization_response=redirected_url, - client_id=api_key, auth=(api_key, app_secret)) - - # Don't emit token details in debug logs - register_redactions(token) - - # Set up token writing and perform the initial token write - update_token = ( - __update_token(token_path) if token_write_func is None - else token_write_func) - metadata_manager = TokenMetadata(token, int(time.time()), update_token) - update_token = metadata_manager.wrapped_token_write_func() - update_token(token) - - # The synchronous and asynchronous versions of the OAuth2Client are similar - # enough that can mostly be used interchangeably. The one currently known - # exception is the token update function: the synchronous version expects a - # synchronous one, the asynchronous requires an async one. The - # oauth_client_update_token variable will contain the appropriate one. - if asyncio: - async def oauth_client_update_token(t, *args, **kwargs): - update_token(t, *args, **kwargs) # pragma: no cover - session_class = AsyncOAuth2Client - client_class = AsyncClient - else: - oauth_client_update_token = update_token - session_class = OAuth2Client - client_class = Client - - # Return a new session configured to refresh credentials - return client_class( - api_key, - session_class(api_key, - client_secret=app_secret, - token=token, - update_token=oauth_client_update_token, - leeway=300), - token_metadata=metadata_manager, enforce_enums=enforce_enums) - - ################################################################################ # client_from_login_flow @@ -351,9 +307,7 @@ def callback_server(): time.sleep(0.1) # Open the browser - oauth = OAuth2Client(api_key, redirect_uri=callback_url) - authorization_url, state = oauth.create_authorization_url( - 'https://api.schwabapi.com/v1/oauth/authorize') + auth_context = get_auth_context(api_key, callback_url) print() print('***********************************************************************') @@ -363,7 +317,7 @@ def callback_server(): print('browser, captures the resulting OAuth callback, and creates a token') print('using the result. The authorization URL is:') print() - print('>>', authorization_url) + print('>>', auth_context.authorization_url) print() print('IMPORTANT: Your browser will give you a security warning about an') print('invalid certificate prior to issuing the redirect. This is because') @@ -388,7 +342,7 @@ def callback_server(): 'this method with interactive=False to skip this input.') controller = webbrowser.get(requested_browser) - controller.open(authorization_url) + controller.open(auth_context.authorization_url) # Wait for a response now = __TIME_TIME() @@ -420,9 +374,13 @@ def callback_server(): 'can set a longer timeout by passing a value of ' + 'callback_timeout to client_from_login_flow.') - return __fetch_and_register_token_from_redirect( - oauth, received_url, api_key, app_secret, token_path, - token_write_func, asyncio, enforce_enums=enforce_enums) + token_write_func = ( + __make_update_token_func(token_path) if token_write_func is None + else token_write_func) + + return client_from_received_url( + api_key, app_secret, auth_context, received_url, + token_write_func, asyncio, enforce_enums) ################################################################################ @@ -455,8 +413,8 @@ def client_from_token_file(token_path, api_key, app_secret, asyncio=False, load = __token_loader(token_path) return client_from_access_functions( - api_key, app_secret, load, __update_token(token_path), asyncio=asyncio, - enforce_enums=enforce_enums) + api_key, app_secret, load, __make_update_token_func(token_path), + asyncio=asyncio, enforce_enums=enforce_enums) ################################################################################ @@ -494,9 +452,7 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path, get_logger().info('Creating new token with callback URL \'%s\' ' + 'and token path \'%s\'', callback_url, token_path) - oauth = OAuth2Client(api_key, redirect_uri=callback_url) - authorization_url, state = oauth.create_authorization_url( - 'https://api.schwabapi.com/v1/oauth/authorize') + auth_context = get_auth_context(api_key, callback_url) print('\n**************************************************************\n') print('This is the manual login and token creation flow for schwab-py.') @@ -505,7 +461,7 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path, print(' 1. Open the following link by copy-pasting it into the browser') print(' of your choice:') print() - print(' ' + authorization_url) + print(' ' + auth_context.authorization_url) print() print(' 2. Log in with your account credentials. You may be asked to') print(' perform two-factor authentication using text messaging or') @@ -529,11 +485,15 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path, 'and update your callback URL to begin with \'https\' ' + 'to stop seeing this message.').format(callback_url)) - redirected_url = prompt('Redirect URL> ').strip() + received_url = prompt('Redirect URL> ').strip() + + token_write_func = ( + __make_update_token_func(token_path) if token_write_func is None + else token_write_func) - return __fetch_and_register_token_from_redirect( - oauth, redirected_url, api_key, app_secret, token_path, token_write_func, - asyncio, enforce_enums=enforce_enums) + return client_from_received_url( + api_key, app_secret, auth_context, received_url, token_write_func, + asyncio, enforce_enums) ################################################################################ @@ -611,6 +571,68 @@ async def oauth_client_update_token(t, *args, **kwargs): enforce_enums=enforce_enums) +################################################################################ +# Tools for incorporating token generation into webapp workflows + + +AuthContext = collections.namedtuple( + 'AuthContext', ['callback_url', 'authorization_url', 'state']) + +def get_auth_context(api_key, callback_url): + oauth = OAuth2Client(api_key, redirect_uri=callback_url) + authorization_url, state = oauth.create_authorization_url( + 'https://api.schwabapi.com/v1/oauth/authorize') + + return AuthContext(callback_url, authorization_url, state) + + +def client_from_received_url( + api_key, app_secret, auth_context, received_url, token_write_func, + asyncio=False, enforce_enums=True): + # XXX: The AuthContext must be serializable, which means the original + # OAuth2Client created in get_auth_context cannot be passed around. + # Instead, we reconstruct it here. + oauth = OAuth2Client(api_key, redirect_uri=auth_context.callback_url) + + token = oauth.fetch_token( + TOKEN_ENDPOINT, + authorization_response=received_url, + client_id=api_key, auth=(api_key, app_secret)) + + # Don't emit token details in debug logs + register_redactions(token) + + # Set up token writing and perform the initial token write + metadata_manager = TokenMetadata(token, int(time.time()), token_write_func) + token_write_func = metadata_manager.wrapped_token_write_func() + token_write_func(token) + + # The synchronous and asynchronous versions of the OAuth2Client are similar + # enough that can mostly be used interchangeably. The one currently known + # exception is the token update function: the synchronous version expects a + # synchronous one, the asynchronous requires an async one. The + # oauth_client_update_token variable will contain the appropriate one. + if asyncio: + async def oauth_client_update_token(t, *args, **kwargs): + token_write_func(t, *args, **kwargs) # pragma: no cover + session_class = AsyncOAuth2Client + client_class = AsyncClient + else: + oauth_client_update_token = token_write_func + session_class = OAuth2Client + client_class = Client + + # Return a new session configured to refresh credentials + return client_class( + api_key, + session_class(api_key, + client_secret=app_secret, + token=token, + update_token=oauth_client_update_token, + leeway=300), + token_metadata=metadata_manager, enforce_enums=enforce_enums) + + ################################################################################ # easy_client