diff --git a/datatracker/oauth.py b/datatracker/oauth.py index 5395885c..b93197ca 100644 --- a/datatracker/oauth.py +++ b/datatracker/oauth.py @@ -140,22 +140,37 @@ def get_client(request): As a side effect, if session is not authenticated, clears OAuth data from session. + + Will attempt to auto-refresh the token if expired. """ + def token_updater(token: str): + request.session[OAUTH_TOKEN_KEY] = token + if is_datatracker_oauth_enabled() and OAUTH_TOKEN_KEY in request.session: + provider = get_provider() session = OAuth2Session( CLIENT_ID, scope=OAUTH_SCOPES, redirect_uri=_get_redirect_uri(), - token=request.session[OAUTH_TOKEN_KEY]) - provider = get_provider() + token=request.session[OAUTH_TOKEN_KEY], + + auto_refresh_url=provider.token_endpoint, + auto_refresh_kwargs=dict( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + scope=OAUTH_SCOPES, + ), + token_updater=token_updater, + ) try: - session.get(provider.userinfo_endpoint).json() + user_info = session.get(provider.userinfo_endpoint).json() except Exception: - # Most likely, token expired. + log.exception("Unable to retrieve user info") clear_session(request) return None else: + request.session[OAUTH_USER_INFO_KEY] = user_info return session else: return None