Skip to content

Commit

Permalink
Break up the login flow and expose the authorization URL to the libra…
Browse files Browse the repository at this point in the history
…ry user
  • Loading branch information
alexgolec committed Aug 31, 2024
1 parent eb985d3 commit 3ae305c
Showing 1 changed file with 86 additions and 64 deletions.
150 changes: 86 additions & 64 deletions schwab/auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from authlib.integrations.httpx_client import AsyncOAuth2Client, OAuth2Client
from prompt_toolkit import prompt

import collections
import contextlib
import httpx
import json
Expand All @@ -27,7 +28,7 @@ def get_logger():
return logging.getLogger(__name__)


def __update_token(token_path):
def __make_update_token_func(token_path):
def update_token(t, *args, **kwargs):
get_logger().info('Updating token to file %s', token_path)

Expand Down Expand Up @@ -119,51 +120,6 @@ def wrap_token_in_metadata(self, token):
}


def __fetch_and_register_token_from_redirect(
oauth, redirected_url, api_key, app_secret, token_path,
token_write_func, asyncio, enforce_enums=True):
token = oauth.fetch_token(
TOKEN_ENDPOINT,
authorization_response=redirected_url,
client_id=api_key, auth=(api_key, app_secret))

# Don't emit token details in debug logs
register_redactions(token)

# Set up token writing and perform the initial token write
update_token = (
__update_token(token_path) if token_write_func is None
else token_write_func)
metadata_manager = TokenMetadata(token, int(time.time()), update_token)
update_token = metadata_manager.wrapped_token_write_func()
update_token(token)

# The synchronous and asynchronous versions of the OAuth2Client are similar
# enough that can mostly be used interchangeably. The one currently known
# exception is the token update function: the synchronous version expects a
# synchronous one, the asynchronous requires an async one. The
# oauth_client_update_token variable will contain the appropriate one.
if asyncio:
async def oauth_client_update_token(t, *args, **kwargs):
update_token(t, *args, **kwargs) # pragma: no cover
session_class = AsyncOAuth2Client
client_class = AsyncClient
else:
oauth_client_update_token = update_token
session_class = OAuth2Client
client_class = Client

# Return a new session configured to refresh credentials
return client_class(
api_key,
session_class(api_key,
client_secret=app_secret,
token=token,
update_token=oauth_client_update_token,
leeway=300),
token_metadata=metadata_manager, enforce_enums=enforce_enums)


################################################################################
# client_from_login_flow

Expand Down Expand Up @@ -351,9 +307,7 @@ def callback_server():
time.sleep(0.1)

# Open the browser
oauth = OAuth2Client(api_key, redirect_uri=callback_url)
authorization_url, state = oauth.create_authorization_url(
'https://api.schwabapi.com/v1/oauth/authorize')
auth_context = get_auth_context(api_key, callback_url)

print()
print('***********************************************************************')
Expand All @@ -363,7 +317,7 @@ def callback_server():
print('browser, captures the resulting OAuth callback, and creates a token')
print('using the result. The authorization URL is:')
print()
print('>>', authorization_url)
print('>>', auth_context.authorization_url)
print()
print('IMPORTANT: Your browser will give you a security warning about an')
print('invalid certificate prior to issuing the redirect. This is because')
Expand All @@ -388,7 +342,7 @@ def callback_server():
'this method with interactive=False to skip this input.')

controller = webbrowser.get(requested_browser)
controller.open(authorization_url)
controller.open(auth_context.authorization_url)

# Wait for a response
now = __TIME_TIME()
Expand Down Expand Up @@ -420,9 +374,13 @@ def callback_server():
'can set a longer timeout by passing a value of ' +
'callback_timeout to client_from_login_flow.')

return __fetch_and_register_token_from_redirect(
oauth, received_url, api_key, app_secret, token_path,
token_write_func, asyncio, enforce_enums=enforce_enums)
token_write_func = (
__make_update_token_func(token_path) if token_write_func is None
else token_write_func)

return client_from_received_url(
api_key, app_secret, auth_context, received_url,
token_write_func, asyncio, enforce_enums)


################################################################################
Expand Down Expand Up @@ -455,8 +413,8 @@ def client_from_token_file(token_path, api_key, app_secret, asyncio=False,
load = __token_loader(token_path)

return client_from_access_functions(
api_key, app_secret, load, __update_token(token_path), asyncio=asyncio,
enforce_enums=enforce_enums)
api_key, app_secret, load, __make_update_token_func(token_path),
asyncio=asyncio, enforce_enums=enforce_enums)


################################################################################
Expand Down Expand Up @@ -494,9 +452,7 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path,
get_logger().info('Creating new token with callback URL \'%s\' ' +
'and token path \'%s\'', callback_url, token_path)

oauth = OAuth2Client(api_key, redirect_uri=callback_url)
authorization_url, state = oauth.create_authorization_url(
'https://api.schwabapi.com/v1/oauth/authorize')
auth_context = get_auth_context(api_key, callback_url)

print('\n**************************************************************\n')
print('This is the manual login and token creation flow for schwab-py.')
Expand All @@ -505,7 +461,7 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path,
print(' 1. Open the following link by copy-pasting it into the browser')
print(' of your choice:')
print()
print(' ' + authorization_url)
print(' ' + auth_context.authorization_url)
print()
print(' 2. Log in with your account credentials. You may be asked to')
print(' perform two-factor authentication using text messaging or')
Expand All @@ -529,11 +485,15 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path,
'and update your callback URL to begin with \'https\' ' +
'to stop seeing this message.').format(callback_url))

redirected_url = prompt('Redirect URL> ').strip()
received_url = prompt('Redirect URL> ').strip()

token_write_func = (
__make_update_token_func(token_path) if token_write_func is None
else token_write_func)

return __fetch_and_register_token_from_redirect(
oauth, redirected_url, api_key, app_secret, token_path, token_write_func,
asyncio, enforce_enums=enforce_enums)
return client_from_received_url(
api_key, app_secret, auth_context, received_url, token_write_func,
asyncio, enforce_enums)


################################################################################
Expand Down Expand Up @@ -611,6 +571,68 @@ async def oauth_client_update_token(t, *args, **kwargs):
enforce_enums=enforce_enums)


################################################################################
# Tools for incorporating token generation into webapp workflows


AuthContext = collections.namedtuple(
'AuthContext', ['callback_url', 'authorization_url', 'state'])

def get_auth_context(api_key, callback_url):
oauth = OAuth2Client(api_key, redirect_uri=callback_url)
authorization_url, state = oauth.create_authorization_url(
'https://api.schwabapi.com/v1/oauth/authorize')

return AuthContext(callback_url, authorization_url, state)


def client_from_received_url(
api_key, app_secret, auth_context, received_url, token_write_func,
asyncio=False, enforce_enums=True):
# XXX: The AuthContext must be serializable, which means the original
# OAuth2Client created in get_auth_context cannot be passed around.
# Instead, we reconstruct it here.
oauth = OAuth2Client(api_key, redirect_uri=auth_context.callback_url)

token = oauth.fetch_token(
TOKEN_ENDPOINT,
authorization_response=received_url,
client_id=api_key, auth=(api_key, app_secret))

# Don't emit token details in debug logs
register_redactions(token)

# Set up token writing and perform the initial token write
metadata_manager = TokenMetadata(token, int(time.time()), token_write_func)
token_write_func = metadata_manager.wrapped_token_write_func()
token_write_func(token)

# The synchronous and asynchronous versions of the OAuth2Client are similar
# enough that can mostly be used interchangeably. The one currently known
# exception is the token update function: the synchronous version expects a
# synchronous one, the asynchronous requires an async one. The
# oauth_client_update_token variable will contain the appropriate one.
if asyncio:
async def oauth_client_update_token(t, *args, **kwargs):

Check warning on line 616 in schwab/auth.py

View check run for this annotation

Codecov / codecov/patch

schwab/auth.py#L616

Added line #L616 was not covered by tests
token_write_func(t, *args, **kwargs) # pragma: no cover
session_class = AsyncOAuth2Client
client_class = AsyncClient

Check warning on line 619 in schwab/auth.py

View check run for this annotation

Codecov / codecov/patch

schwab/auth.py#L618-L619

Added lines #L618 - L619 were not covered by tests
else:
oauth_client_update_token = token_write_func
session_class = OAuth2Client
client_class = Client

# Return a new session configured to refresh credentials
return client_class(
api_key,
session_class(api_key,
client_secret=app_secret,
token=token,
update_token=oauth_client_update_token,
leeway=300),
token_metadata=metadata_manager, enforce_enums=enforce_enums)


################################################################################
# easy_client

Expand Down

0 comments on commit 3ae305c

Please sign in to comment.