Skip to content

Commit

Permalink
Added EnvVarDagshubToken for when the user has overriden the token wi…
Browse files Browse the repository at this point in the history
…th the env var.

Made sure the get_token_object() function works e2e
  • Loading branch information
kbolashev committed Aug 21, 2023
1 parent 486aafa commit 7f01035
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 7 deletions.
29 changes: 29 additions & 0 deletions dagshub/auth/token_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def __init__(self, token_value: str, expiry_date: datetime.datetime):
self.token_value = token_value
self.expiry_date = expiry_date

# TODO: override the call function to warn about how much lifetime the token has

def serialize(self) -> Dict[str, Any]:
return {
"access_token": self.token_value,
Expand Down Expand Up @@ -123,6 +125,33 @@ def __repr__(self):
return "Dagshub App token"


class EnvVarDagshubToken(DagshubTokenABC):
token_type = "env-var"
priority = -1

def __init__(self, token_value: str, host: str):
self.token_value = token_value
self.host = host

def serialize(self) -> Dict[str, Any]:
raise RuntimeError("Can't serialize env var token")

@staticmethod
def deserialize(values: Dict[str, Any]):
raise RuntimeError("Can't deserialize env var token")

@property
def token_text(self) -> str:
return self.token_value

@property
def is_expired(self) -> bool:
return False

def __repr__(self):
return f"Dagshub Env Var token for host {self.host}"


class HTTPBearerAuth(Auth):
"""Attaches HTTP Bearer Authorization to the given Request object."""

Expand Down
43 changes: 36 additions & 7 deletions dagshub/auth/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from typing import Optional, Dict, List, Set, Union

import yaml
from httpx import Auth

from dagshub.auth import oauth
from dagshub.auth.token_auth import HTTPBearerAuth, DagshubTokenABC, TokenDeserializationError, AppDagshubToken
from dagshub.auth.token_auth import HTTPBearerAuth, DagshubTokenABC, TokenDeserializationError, AppDagshubToken, \
EnvVarDagshubToken
from dagshub.common import config
from dagshub.common.helpers import http_request
from dagshub.common.util import multi_urljoin
Expand Down Expand Up @@ -85,11 +87,11 @@ def get_token_object(self, host: str = None, fail_if_no_token: bool = False, **k
We're using a set of known good tokens to skip rechecking for token validity every time
"""

# TODO: warn on timed tokens
# TODO: different token types
host = host or config.host
if host == config.host and config.token is not None:
return EnvVarDagshubToken(config.token, host)

with self._token_access_lock:
host = host or config.host
tokens = self._token_cache.get(host, [])

had_changes = False # For saving if we invalidate some tokens
Expand Down Expand Up @@ -171,7 +173,7 @@ def _is_expired(token: Dict[str, str]) -> bool:
return is_expired

@staticmethod
def is_valid_token(token: str, host: str) -> bool:
def is_valid_token(token: Union[str, Auth, DagshubTokenABC], host: str) -> bool:
"""
Check for token validity
Expand All @@ -181,7 +183,10 @@ def is_valid_token(token: str, host: str) -> bool:
"""
host = host or config.host
check_url = multi_urljoin(host, "api/v1/user")
auth = HTTPBearerAuth(token)
if type(token) is str:
auth = HTTPBearerAuth(token)
else:
auth = token
resp = http_request("GET", check_url, auth=auth)

try:
Expand Down Expand Up @@ -262,10 +267,34 @@ def _get_token_storage(**kwargs):
return _token_storage


def get_token(**kwargs):
def get_authenticator(**kwargs):
"""
Get an authenticator object.
This object can be used as auth argument for the httpx requests
The authenticator has renegotiation logic in case where a token gets invalidated
"""
return _get_token_storage(**kwargs).get_authenticator(**kwargs)


def get_token_object(**kwargs):
"""
Gets a DagsHub token, by default if no token is found authenticates with OAuth
Kwargs:
host (str): URL of a dagshub instance (defaults to dagshub.com)
cache_location (str): Location of the cache file with the token (defaults to <cache_dir>/dagshub/tokens)
fail_if_no_token (bool): What to do if token is not found.
If set to False (default), goes through OAuth flow
If set to True, throws a RuntimeError
"""
return _get_token_storage(**kwargs).get_token_object(**kwargs)


def get_token(**kwargs):
"""
Gets a DagsHub token text, by default if no token is found authenticates with OAuth
Kwargs:
host (str): URL of a dagshub instance (defaults to dagshub.com)
cache_location (str): Location of the cache file with the token (defaults to <cache_dir>/dagshub/tokens)
Expand Down

0 comments on commit 7f01035

Please sign in to comment.