Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wire up retry count config to NVD provider #738

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/vunnel/providers/nvd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class Config:
),
)
request_timeout: int = 125
request_retry_count: int = 10
api_key: Optional[str] = "env:NVD_API_KEY" # noqa: UP007
overrides_url: str = "https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz"
overrides_enabled: bool = False
Expand Down Expand Up @@ -71,6 +72,7 @@ def __init__(self, root: str, config: Config | None = None):
workspace=self.workspace,
schema=self.__schema__,
download_timeout=self.config.request_timeout,
download_retry_count=self.config.request_retry_count,
api_key=self.config.api_key,
logger=self.logger,
overrides_enabled=self.config.overrides_enabled,
Expand Down
18 changes: 16 additions & 2 deletions src/vunnel/providers/nvd/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,16 @@ class NvdAPI:
_max_results_per_page_: int = 2000
max_date_range_days: int = 120

def __init__(self, api_key: str | None = None, logger: logging.Logger | None = None, timeout: int = 30):
def __init__(
self,
api_key: str | None = None,
logger: logging.Logger | None = None,
timeout: int = 30,
retries: int = 10,
) -> None:
self.api_key = api_key
self.timeout = timeout
self.retries = retries

if not logger:
logger = logging.getLogger(self.__class__.__name__)
Expand Down Expand Up @@ -154,7 +161,14 @@ def _request(self, url: str, parameters: dict[str, str], headers: dict[str, str]

# NVD rate-limiting is detailed at https://nvd.nist.gov/developers/start-here and currently resets on a 30 second
# rolling window, so setting retry to start trying again after 30 seconds.
response = http.get(url, self.logger, backoff_in_seconds=30, params=payload_str, headers=headers, timeout=self.timeout)
response = http.get(
url,
self.logger,
params=payload_str,
headers=headers,
timeout=self.timeout,
retries=self.retries,
)
response.encoding = "utf-8"

return response
Expand Down
4 changes: 3 additions & 1 deletion src/vunnel/providers/nvd/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__( # noqa: PLR0913
overrides_url: str,
logger: logging.Logger | None = None,
download_timeout: int = 125,
download_retry_count: int = 10,
api_key: str | None = None,
overrides_enabled: bool = False,
) -> None:
Expand All @@ -35,14 +36,15 @@ def __init__( # noqa: PLR0913
logger = logging.getLogger(self.__class__.__name__)
self.logger = logger

self.api = NvdAPI(api_key=api_key, logger=logger, timeout=download_timeout)
self.api = NvdAPI(api_key=api_key, logger=logger, timeout=download_timeout, retries=download_retry_count)

self.overrides = NVDOverrides(
enabled=overrides_enabled,
url=overrides_url,
workspace=workspace,
logger=logger,
download_timeout=download_timeout,
retries=download_retry_count,
)

self.urls = [self.api._cve_api_url_]
Expand Down
6 changes: 4 additions & 2 deletions src/vunnel/providers/nvd/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,20 @@ class NVDOverrides:
__file_name__ = "nvd-overrides.tar.gz"
__extract_name__ = "nvd-overrides"

def __init__(
def __init__( # noqa: PLR0913
self,
enabled: bool,
url: str,
workspace: Workspace,
logger: logging.Logger | None = None,
download_timeout: int = 125,
retries: int = 5,
):
self.enabled = enabled
self.__url__ = url
self.workspace = workspace
self.download_timeout = download_timeout
self.retries = retries
if not logger:
logger = logging.getLogger(self.__class__.__name__)
self.logger = logger
Expand All @@ -43,7 +45,7 @@ def download(self) -> None:
self.logger.debug("overrides are not enabled, skipping download...")
return

req = http.get(self.__url__, self.logger, stream=True, timeout=self.download_timeout)
req = http.get(self.__url__, self.logger, stream=True, timeout=self.download_timeout, retries=self.retries)

file_path = os.path.join(self.workspace.input_path, self.__file_name__)
with open(file_path, "wb") as fp:
Expand Down
30 changes: 18 additions & 12 deletions src/vunnel/utils/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def get( # noqa: PLR0913
backoff_in_seconds: int = 3,
timeout: int = DEFAULT_TIMEOUT,
status_handler: Optional[Callable[[requests.Response], None]] = None, # noqa: UP007 - python 3.9
max_interval: int = 600,
**kwargs: Any,
) -> requests.Response:
"""
Expand All @@ -45,15 +46,15 @@ def get( # noqa: PLR0913
status_handler= lambda response: None if response.status_code in [200, 201, 405] else response.raise_for_status())

"""
logger.debug(f"http GET {url}")
last_exception: Exception | None = None
sleep_interval = backoff_in_seconds
for attempt in range(retries + 1):
if last_exception:
sleep_interval = backoff_sleep_interval(backoff_in_seconds, attempt - 1, max_value=max_interval)
logger.warning(f"will retry in {int(sleep_interval)} seconds...")
time.sleep(sleep_interval)
sleep_interval = backoff_in_seconds * 2**attempt + random.uniform(0, 1) # noqa: S311
# explanation of S311 disable: rng is not used cryptographically

try:
logger.debug(f"http GET {url} timeout={timeout} retries={retries} backoff={backoff_in_seconds}")
response = requests.get(url, timeout=timeout, **kwargs)
if status_handler:
status_handler(response)
Expand All @@ -62,20 +63,25 @@ def get( # noqa: PLR0913
return response
except requests.exceptions.HTTPError as e:
last_exception = e
will_retry = ""
if attempt < retries:
will_retry = f" (will retry in {int(backoff_in_seconds)} seconds) "
# HTTPError includes the attempted request, so don't include it redundantly here
logger.warning(f"attempt {attempt + 1} of {retries + 1} failed:{will_retry}{e}")
logger.warning(f"attempt {attempt + 1} of {retries + 1} failed: {e}")
except Exception as e:
last_exception = e
will_retry = ""
if attempt < retries:
will_retry = f" (will retry in {int(sleep_interval)} seconds) "
# this is an unexpected exception type, so include the attempted request in case the
# message from the unexpected exception doesn't.
logger.warning(f"attempt {attempt + 1} of {retries + 1}{will_retry}: unexpected exception during GET {url}: {e}")
logger.warning(f"attempt {attempt + 1} of {retries + 1}: unexpected exception during GET {url}: {e}")
if last_exception:
logger.error(f"last retry of GET {url} failed with {last_exception}")
raise last_exception
raise Exception("unreachable")


def backoff_sleep_interval(interval: int, attempt: int, max_value: None | int = None, jitter: bool = True) -> float:
# this is an exponential backoff
val = interval * 2**attempt
if max_value and val > max_value:
val = max_value
if jitter:
val += random.uniform(0, 1) # noqa: S311
# explanation of S311 disable: rng is not used cryptographically
return val
1 change: 1 addition & 0 deletions tests/unit/cli/test-fixtures/full.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ providers:
nvd:
runtime: *runtime
request_timeout: 20
request_retry_count: 50
overrides_enabled: true
overrides_url: https://github.com/anchore/nvd-data-overrides/SOMEWHEREELSE/main.tar.gz
oracle:
Expand Down
1 change: 1 addition & 0 deletions tests/unit/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ def test_config(monkeypatch) -> None:
api_key: secret
overrides_enabled: false
overrides_url: https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz
request_retry_count: 10
request_timeout: 125
runtime:
existing_input: keep
Expand Down
1 change: 1 addition & 0 deletions tests/unit/cli/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_full_config(helpers):
nvd=providers.nvd.Config(
runtime=runtime_cfg,
request_timeout=20,
request_retry_count=50,
overrides_enabled=True,
overrides_url="https://github.com/anchore/nvd-data-overrides/SOMEWHEREELSE/main.tar.gz",
),
Expand Down
18 changes: 9 additions & 9 deletions tests/unit/providers/nvd/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def test_cve_no_api_key(self, simple_mock, mocker):
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
backoff_in_seconds=30,
retries=10,
params="cveId=CVE-2020-0000",
headers={"content-type": "application/json"},
timeout=1,
Expand All @@ -59,7 +59,7 @@ def test_cve_single_cve(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="cveId=CVE-2020-0000",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand Down Expand Up @@ -118,23 +118,23 @@ def test_cve_multi_page(self, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=3&startIndex=3",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
mocker.call(
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=3&startIndex=6",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -156,7 +156,7 @@ def test_cve_pub_date_range(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="pubStartDate=2019-12-04T00:00:00&pubEndDate=2019-12-05T00:00:00",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -178,7 +178,7 @@ def test_cve_last_modified_date_range(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="lastModStartDate=2019-12-04T00:00:00&lastModEndDate=2019-12-05T00:00:00",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -197,7 +197,7 @@ def test_results_per_page(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cves/2.0",
subject.logger,
params="resultsPerPage=5",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand All @@ -214,7 +214,7 @@ def test_cve_history(self, simple_mock, mocker):
"https://services.nvd.nist.gov/rest/json/cvehistory/2.0",
subject.logger,
params="cveId=CVE-2020-0000",
backoff_in_seconds=30,
retries=10,
headers={"content-type": "application/json", "apiKey": "secret"},
timeout=1,
),
Expand Down
54 changes: 47 additions & 7 deletions tests/unit/utils/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,14 @@ def test_correct_number_of_retria(self, mock_requests, mock_sleep, mock_logger,

@patch("time.sleep")
@patch("requests.get")
def test_succeeds_if_retries_succeed(self, mock_requests, mock_sleep, mock_logger, error_response, success_response):
@patch("random.uniform")
def test_succeeds_if_retries_succeed(
self, mock_uniform_random, mock_requests, mock_sleep, mock_logger, error_response, success_response
):
mock_uniform_random.side_effect = [0.1]
mock_requests.side_effect = [error_response, success_response]
http.get("http://example.com/some-path", mock_logger, retries=1, backoff_in_seconds=22)
mock_sleep.assert_called_with(22)
mock_sleep.assert_called_with(22.1)
mock_logger.warning.assert_called()
mock_logger.error.assert_not_called()
mock_requests.assert_called_with("http://example.com/some-path", timeout=http.DEFAULT_TIMEOUT)
Expand All @@ -74,7 +78,7 @@ def test_exponential_backoff_and_jitter(
mock_requests.side_effect = [error_response, error_response, error_response, success_response]
mock_uniform_random.side_effect = [0.5, 0.4, 0.1]
http.get("http://example.com/some-path", mock_logger, backoff_in_seconds=10, retries=3)
assert mock_sleep.call_args_list == [call(10), call(10 * 2 + 0.5), call(10 * 4 + 0.4)]
assert mock_sleep.call_args_list == [call(10 + 0.5), call(10 * 2 + 0.4), call(10 * 4 + 0.1)]

@patch("time.sleep")
@patch("requests.get")
Expand All @@ -91,8 +95,13 @@ def test_it_logs_the_url_on_failure(self, mock_requests, mock_sleep, mock_logger
def test_it_log_warns_errors(self, mock_requests, mock_sleep, mock_logger, error_response, success_response):
mock_requests.side_effect = [error_response, success_response]
http.get("http://example.com/some-path", mock_logger, retries=1, backoff_in_seconds=33)
assert "HTTP ERROR" in mock_logger.warning.call_args.args[0]
assert "will retry in 33 seconds" in mock_logger.warning.call_args.args[0]

logged_warnings = [call.args[0] for call in mock_logger.warning.call_args_list]

assert any("HTTP ERROR" in message for message in logged_warnings), "Expected 'HTTP ERROR' in logged warnings."
assert any(
"will retry in 33 seconds" in message for message in logged_warnings
), "Expected retry message in logged warnings."

@patch("time.sleep")
@patch("requests.get")
Expand All @@ -109,16 +118,47 @@ def test_it_calls_status_handler(self, mock_requests, mock_sleep, mock_logger, e

@patch("time.sleep")
@patch("requests.get")
@patch("random.uniform")
def test_it_retries_when_status_handler_raises(
self, mock_requests, mock_sleep, mock_logger, error_response, success_response
self, mock_uniform_random, mock_requests, mock_sleep, mock_logger, error_response, success_response
):
mock_uniform_random.side_effect = [0.25]
mock_requests.side_effect = [success_response, error_response]
status_handler = MagicMock()
status_handler.side_effect = [Exception("custom exception"), None]
result = http.get(
"http://example.com/some-path", mock_logger, status_handler=status_handler, retries=1, backoff_in_seconds=33
)
mock_sleep.assert_called_with(33)
mock_sleep.assert_called_with(33.25)
# custom status handler raised the first time it was called,
# so we expect the second mock response to be returned overall
assert result == error_response


@pytest.mark.parametrize(
"interval, jitter, max_value, expected",
[
(
30, # interval
False, # jitter
None, # max_value
[30, 60, 120, 240, 480, 960, 1920, 3840, 7680, 15360, 30720, 61440, 122880, 245760, 491520], # expected
),
(
3, # interval
False, # jitter
1000, # max_value
[3, 6, 12, 24, 48, 96, 192, 384, 768, 1000, 1000, 1000, 1000, 1000, 1000], # expected
),
],
)
def test_backoff_sleep_interval(interval, jitter, max_value, expected):
actual = [
http.backoff_sleep_interval(interval, attempt, jitter=jitter, max_value=max_value) for attempt in range(len(expected))
]

if not jitter:
assert actual == expected
else:
for i, (a, e) in enumerate(zip(actual, expected)):
assert a >= e and a <= e + 1, f"Jittered value out of bounds at attempt {i}: {a} (expected ~{e})"
Loading