Skip to content

Commit

Permalink
Merge pull request #110 from dasoran/add-gcs-discover-cache
Browse files Browse the repository at this point in the history
Add gcs discover cache
  • Loading branch information
nishiba authored Dec 17, 2019
2 parents a4c5e7b + da8c0fb commit c8e47ee
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 8 deletions.
46 changes: 45 additions & 1 deletion gokart/gcs_config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,61 @@
import json
import os
import fcntl
import uritemplate

import luigi
import luigi.contrib.gcs
from http import client as http_client
from googleapiclient.errors import HttpError
from google.oauth2.service_account import Credentials
from googleapiclient.http import build_http
from googleapiclient.discovery import _retrieve_discovery_doc


class GCSConfig(luigi.Config):
# Name of the environment variable that is looked up to obtain the GCS credential.
gcs_credential_name = luigi.Parameter(
    default='GCS_CREDENTIAL',
    description='GCS credential environment variable.',
)
# File path of the local discovery-document cache. The value is used directly
# as a path, so the default resolves relative to the working directory.
discover_cache_local_path = luigi.Parameter(
    default='DISCOVER_CACHE_LOCAL_PATH',
    description='The file path of discover api cache.',
)

# URI templates for the discovery document; `{api}` and `{apiVersion}` are
# filled in with `uritemplate.expand` before the request is made.
_DISCOVERY_URI = "https://www.googleapis.com/discovery/v1/apis/{api}/{apiVersion}/rest"
_V2_DISCOVERY_URI = "https://{api}.googleapis.com/$discovery/rest?version={apiVersion}"

def get_gcs_client(self) -> luigi.contrib.gcs.GCSClient:
    """Build a GCS client from a locally cached discovery document.

    The discovery document is read from ``discover_cache_local_path``. When
    the cache file is missing or empty, the document is downloaded once and
    written to that path while an exclusive ``flock`` is held, so concurrent
    processes sharing the same cache file do not download it in parallel or
    read a half-written file.

    Returns:
        A ``luigi.contrib.gcs.GCSClient`` created with the loaded OAuth
        credentials and the discovery document as ``descriptor``.

    Raises:
        RuntimeError: if no discovery URL serves the document (all 404).
        HttpError: for any other HTTP failure while downloading.
        OSError: if the cache file cannot be read or written.
    """
    cache_path = self.discover_cache_local_path
    descriptor = ''

    # Fast path: an existing cache only needs read access and a shared lock.
    if os.path.isfile(cache_path):
        with open(cache_path, 'r') as f:
            fcntl.flock(f, fcntl.LOCK_SH)
            try:
                descriptor = f.read()
            finally:
                fcntl.flock(f, fcntl.LOCK_UN)

    if not descriptor:
        # 'a+' creates the file when missing but never truncates it, so a
        # process arriving late cannot wipe what another one has written.
        with open(cache_path, 'a+') as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            try:
                # Re-check under the lock: another process may have filled
                # the cache while this one was waiting.
                f.seek(0)
                descriptor = f.read()
                if not descriptor:
                    descriptor = self._fetch_discovery_document()
                    f.write(descriptor)
                    f.flush()
            finally:
                fcntl.flock(f, fcntl.LOCK_UN)

    return luigi.contrib.gcs.GCSClient(
        oauth_credentials=self._load_oauth_credentials(), descriptor=descriptor)

def _fetch_discovery_document(self) -> str:
    """Download the ``storage`` ``v1`` discovery document, trying each known URI in turn."""
    params = {"api": "storage", "apiVersion": "v1"}
    discovery_http = build_http()
    for discovery_url in (self._DISCOVERY_URI, self._V2_DISCOVERY_URI):
        requested_url = uritemplate.expand(discovery_url, params)
        try:
            return _retrieve_discovery_doc(requested_url, discovery_http, False)
        except HttpError as e:
            # A 404 only means this endpoint does not serve the document;
            # fall through to the next URI. Anything else is a real failure.
            if e.resp.status != http_client.NOT_FOUND:
                raise
    raise RuntimeError('Discovery document for storage v1 was not found at any known discovery URL.')

def _load_oauth_credentials(self):
json_str = os.environ.get(self.gcs_credential_name)
Expand Down
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
'numpy',
'tqdm',
'google-auth',
'pyarrow'
'pyarrow',
'uritemplate',
'google-api-python-client'
]

setup(
Expand Down
26 changes: 20 additions & 6 deletions test/test_gcs_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,39 @@
class TestGcfConfig(unittest.TestCase):
def test_get_gcs_client_without_gcs_credential_name(self):
    """With an empty credential env var, the client gets no credentials and the cached descriptor."""
    import tempfile

    gcs_client_mock = MagicMock()
    # The cache lives in a temporary directory so the test leaves no file
    # behind in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        discover_path = os.path.join(tmp_dir, 'discover_cache.json')
        with open(discover_path, 'w') as f:
            f.write('{}')
        # patch.dict restores os.environ when the block exits.
        with patch.dict(os.environ, {'env_name': ''}), \
                patch('luigi.contrib.gcs.GCSClient', gcs_client_mock), \
                patch('fcntl.flock'):
            GCSConfig(gcs_credential_name='env_name', discover_cache_local_path=discover_path).get_gcs_client()
    self.assertEqual(dict(oauth_credentials=None, descriptor='{}'), gcs_client_mock.call_args[1])

def test_get_gcs_client_with_file_path(self):
    """When the credential env var holds a file path, that path is passed to from_service_account_file."""
    import tempfile

    credentials_mock = MagicMock()
    file_path = 'test.json'
    # The cache lives in a temporary directory so the test leaves no file
    # behind in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        discover_path = os.path.join(tmp_dir, 'discover_cache.json')
        with open(discover_path, 'w') as f:
            f.write('{}')
        # patch.dict restores os.environ when the block exits; os.path.isfile
        # is forced to True so the fake credential path counts as a file.
        with patch.dict(os.environ, {'env_name': file_path}), \
                patch('luigi.contrib.gcs.GCSClient'), \
                patch('google.oauth2.service_account.Credentials.from_service_account_file', credentials_mock), \
                patch('os.path.isfile', return_value=True):
            GCSConfig(gcs_credential_name='env_name', discover_cache_local_path=discover_path).get_gcs_client()
    self.assertEqual(file_path, credentials_mock.call_args[0][0])

def test_get_gcs_client_with_json(self):
    """When the credential env var holds JSON, the parsed dict is passed to from_service_account_info."""
    import tempfile

    credentials_mock = MagicMock()
    json_str = '{"test": 1}'
    # The cache lives in a temporary directory so the test leaves no file
    # behind in the working directory.
    with tempfile.TemporaryDirectory() as tmp_dir:
        discover_path = os.path.join(tmp_dir, 'discover_cache.json')
        with open(discover_path, 'w') as f:
            f.write('{}')
        # patch.dict restores os.environ when the block exits.
        with patch.dict(os.environ, {'env_name': json_str}), \
                patch('luigi.contrib.gcs.GCSClient'), \
                patch('google.oauth2.service_account.Credentials.from_service_account_info', credentials_mock):
            GCSConfig(gcs_credential_name='env_name', discover_cache_local_path=discover_path).get_gcs_client()
    self.assertEqual(dict(test=1), credentials_mock.call_args[0][0])

0 comments on commit c8e47ee

Please sign in to comment.