Skip to content

Commit

Permalink
Add tests for the tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
kbolashev committed Aug 21, 2023
1 parent 7f01035 commit 2c70634
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 2 deletions.
2 changes: 1 addition & 1 deletion dagshub/auth/oauth.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import hashlib
from typing import Optional, Dict
from typing import Optional
import logging
import urllib
import uuid
Expand Down
3 changes: 2 additions & 1 deletion dagshub/auth/token_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,10 @@ class DagshubAuthenticator(Auth):
This class contains a token + flow on how to re-init the token in case of failure
"""

def __init__(self, token: "DagshubTokenABC", token_storage: "TokenStorage"):
def __init__(self, token: "DagshubTokenABC", token_storage: "TokenStorage", host: str):
self.token = token
self.token_storage = token_storage
self.host = host

def auth_flow(self, request: Request) -> Generator[Request, Response, None]:
# TODO: failure mode recovery
Expand Down
145 changes: 145 additions & 0 deletions tests/dda/test_tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import datetime

import httpx

import dagshub.common.config
import dateutil.parser
import pytest
from dagshub.auth.token_auth import AppDagshubToken, DagshubTokenABC, OAuthDagshubToken, EnvVarDagshubToken
from dagshub.auth.tokens import (
TokenStorage,
InvalidTokenError,
)


def valid_token_side_effect(request: httpx.Request) -> httpx.Response:
if request.headers["Authorization"] == "Bearer good-token":
return httpx.Response(200, json={
"id": 1,
"login": "user",
"full_name": "user",
"avatar_url": "random_url",
"username": "user",
})
else:
return httpx.Response(401)


@pytest.fixture
def token_api(mock_api):
dagshub.common.config.token = None # Disable the env var token for these tests explicitly
mock_api.get("https://dagshub.com/api/v1/user").mock(side_effect=valid_token_side_effect)

mock_api.post("https://dagshub.com/api/v1/middleman").mock(httpx.Response(200, json="code"))

a_day_away = datetime.datetime.utcnow() + datetime.timedelta(days=1)
good_auth_token = {
"access_token": "good-token",
"expiry": a_day_away.isoformat(),
"token_type": "bearer",
}
mock_api.post("https://dagshub.com/api/v1/access_token").mock(httpx.Response(200, json=good_auth_token))
yield mock_api


@pytest.fixture
def token_cache(token_api, tmp_path) -> TokenStorage:
cache_file = tmp_path / "tokens"
storage = TokenStorage(cache_location=str(cache_file.absolute()))
yield storage


@pytest.fixture
def valid_token() -> DagshubTokenABC:
return AppDagshubToken("good-token")


@pytest.fixture
def invalid_token() -> DagshubTokenABC:
return AppDagshubToken("fake-token")


@pytest.fixture
def expired_token() -> DagshubTokenABC:
old_date = dateutil.parser.parse("1990-01-01T16:24:53.451259Z")
return OAuthDagshubToken("fake-token", old_date)


def test_no_token_fails_if_set_to_fail(token_cache):
with pytest.raises(RuntimeError):
token_cache.get_token(fail_if_no_token=True)


def test_cant_add_invalid_token(token_cache, invalid_token):
with pytest.raises(InvalidTokenError):
token_cache.add_token(invalid_token)


def test_can_add_valid_token(token_cache, valid_token):
# Assume there is a known good token in the regular get_token flow
token_cache.add_token(valid_token)


def test_expired_token_gets_cleaned_up(token_cache, expired_token):
token_cleanup_test(token_cache, expired_token)


def test_invalid_token_gets_cleaned_up(token_cache, invalid_token):
token_cleanup_test(token_cache, invalid_token)


def token_cleanup_test(token_cache, token):
token_cache.add_token(token, skip_validation=True)

with pytest.raises(RuntimeError):
token_cache.get_token(fail_if_no_token=True)
# Also call this to remove expired tokens
token_cache.remove_expired_tokens()

# Check that the token got deleted from the file
failed = False
with open(token_cache.cache_location, "r") as f:
for line in f.readlines():
if token.token_text in line:
failed = True
break
print(token_cache.cache_location)
assert not failed


def test_token_addition(token_cache, valid_token):
token_cache.add_token(valid_token, skip_validation=True)
passed = False
with open(token_cache.cache_location, "r") as f:
for line in f.readlines():
if valid_token.token_text in line:
passed = True
break
assert passed


def test_token_validity(token_cache, valid_token):
assert token_cache.is_valid_token(valid_token, host=dagshub.common.config.host)


def test_valid_token_gets_returned(token_cache, valid_token, invalid_token):
token_cache.add_token(valid_token, skip_validation=True)
token_cache.add_token(invalid_token, skip_validation=True)

actual = token_cache.get_token()
assert actual == valid_token.token_text


def test_env_var_tokens_gets_returned_no_matter_what(token_cache, valid_token):
token_cache.add_token(valid_token, skip_validation=True)

old_val = dagshub.common.config.token
val = "token-set-in-env-var"
try:
dagshub.common.config.token = val
actual = token_cache.get_token_object()

assert type(actual) is EnvVarDagshubToken
assert actual.token_text == val
finally:
dagshub.common.config.token = old_val

0 comments on commit 2c70634

Please sign in to comment.