diff --git a/src/vunnel/providers/nvd/__init__.py b/src/vunnel/providers/nvd/__init__.py index d91fc310..b25c5ef7 100644 --- a/src/vunnel/providers/nvd/__init__.py +++ b/src/vunnel/providers/nvd/__init__.py @@ -20,6 +20,7 @@ class Config: ), ) request_timeout: int = 125 + request_retry_count: int = 10 api_key: Optional[str] = "env:NVD_API_KEY" # noqa: UP007 overrides_url: str = "https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz" overrides_enabled: bool = False @@ -71,6 +72,7 @@ def __init__(self, root: str, config: Config | None = None): workspace=self.workspace, schema=self.__schema__, download_timeout=self.config.request_timeout, + download_retry_count=self.config.request_retry_count, api_key=self.config.api_key, logger=self.logger, overrides_enabled=self.config.overrides_enabled, diff --git a/src/vunnel/providers/nvd/api.py b/src/vunnel/providers/nvd/api.py index 61acdc1f..1d7681b9 100644 --- a/src/vunnel/providers/nvd/api.py +++ b/src/vunnel/providers/nvd/api.py @@ -21,9 +21,16 @@ class NvdAPI: _max_results_per_page_: int = 2000 max_date_range_days: int = 120 - def __init__(self, api_key: str | None = None, logger: logging.Logger | None = None, timeout: int = 30): + def __init__( + self, + api_key: str | None = None, + logger: logging.Logger | None = None, + timeout: int = 30, + retries: int = 10, + ) -> None: self.api_key = api_key self.timeout = timeout + self.retries = retries if not logger: logger = logging.getLogger(self.__class__.__name__) @@ -154,7 +161,14 @@ def _request(self, url: str, parameters: dict[str, str], headers: dict[str, str] # NVD rate-limiting is detailed at https://nvd.nist.gov/developers/start-here and currently resets on a 30 second # rolling window, so setting retry to start trying again after 30 seconds. - response = http.get(url, self.logger, backoff_in_seconds=30, params=payload_str, headers=headers, timeout=self.timeout) + response = http.get( + url, + self.logger, + params=payload_str, + headers=headers, + timeout=self.timeout, + retries=self.retries, + ) response.encoding = "utf-8" return response diff --git a/src/vunnel/providers/nvd/manager.py b/src/vunnel/providers/nvd/manager.py index 23431c7c..637ace86 100644 --- a/src/vunnel/providers/nvd/manager.py +++ b/src/vunnel/providers/nvd/manager.py @@ -26,6 +26,7 @@ def __init__( # noqa: PLR0913 overrides_url: str, logger: logging.Logger | None = None, download_timeout: int = 125, + download_retry_count: int = 10, api_key: str | None = None, overrides_enabled: bool = False, ) -> None: @@ -35,7 +36,7 @@ def __init__( # noqa: PLR0913 logger = logging.getLogger(self.__class__.__name__) self.logger = logger - self.api = NvdAPI(api_key=api_key, logger=logger, timeout=download_timeout) + self.api = NvdAPI(api_key=api_key, logger=logger, timeout=download_timeout, retries=download_retry_count) self.overrides = NVDOverrides( enabled=overrides_enabled, @@ -43,6 +44,7 @@ def __init__( # noqa: PLR0913 workspace=workspace, logger=logger, download_timeout=download_timeout, + retries=download_retry_count, ) self.urls = [self.api._cve_api_url_] diff --git a/src/vunnel/providers/nvd/overrides.py b/src/vunnel/providers/nvd/overrides.py index 4a61cf01..c25581a4 100644 --- a/src/vunnel/providers/nvd/overrides.py +++ b/src/vunnel/providers/nvd/overrides.py @@ -17,18 +17,20 @@ class NVDOverrides: __file_name__ = "nvd-overrides.tar.gz" __extract_name__ = "nvd-overrides" - def __init__( + def __init__( # noqa: PLR0913 self, enabled: bool, url: str, workspace: Workspace, logger: logging.Logger | None = None, download_timeout: int = 125, + retries: int = 5, ): self.enabled = enabled self.__url__ = url self.workspace = workspace self.download_timeout = download_timeout + self.retries = retries if not logger: logger = logging.getLogger(self.__class__.__name__) self.logger = logger @@ -43,7 +45,7 @@ def download(self) -> None: self.logger.debug("overrides are not enabled, skipping download...") return - req = http.get(self.__url__, self.logger, stream=True, timeout=self.download_timeout) + req = http.get(self.__url__, self.logger, stream=True, timeout=self.download_timeout, retries=self.retries) file_path = os.path.join(self.workspace.input_path, self.__file_name__) with open(file_path, "wb") as fp: diff --git a/src/vunnel/utils/http.py b/src/vunnel/utils/http.py index 7b73fa1c..b2ad4f56 100644 --- a/src/vunnel/utils/http.py +++ b/src/vunnel/utils/http.py @@ -20,6 +20,7 @@ def get( # noqa: PLR0913 backoff_in_seconds: int = 3, timeout: int = DEFAULT_TIMEOUT, status_handler: Optional[Callable[[requests.Response], None]] = None, # noqa: UP007 - python 3.9 + max_interval: int = 600, **kwargs: Any, ) -> requests.Response: """ @@ -45,15 +46,15 @@ def get( # noqa: PLR0913 status_handler= lambda response: None if response.status_code in [200, 201, 405] else response.raise_for_status()) """ - logger.debug(f"http GET {url}") last_exception: Exception | None = None - sleep_interval = backoff_in_seconds for attempt in range(retries + 1): if last_exception: + sleep_interval = backoff_sleep_interval(backoff_in_seconds, attempt - 1, max_value=max_interval) + logger.warning(f"will retry in {int(sleep_interval)} seconds...") time.sleep(sleep_interval) - sleep_interval = backoff_in_seconds * 2**attempt + random.uniform(0, 1) # noqa: S311 - # explanation of S311 disable: rng is not used cryptographically + try: + logger.debug(f"http GET {url} timeout={timeout} retries={retries} backoff={backoff_in_seconds}") response = requests.get(url, timeout=timeout, **kwargs) if status_handler: status_handler(response) @@ -62,20 +63,25 @@ def get( # noqa: PLR0913 return response except requests.exceptions.HTTPError as e: last_exception = e - will_retry = "" - if attempt < retries: - will_retry = f" (will retry in {int(backoff_in_seconds)} seconds) " # HTTPError includes the attempted request, so don't include it redundantly here - logger.warning(f"attempt {attempt + 1} of {retries + 1} failed:{will_retry}{e}") + logger.warning(f"attempt {attempt + 1} of {retries + 1} failed: {e}") except Exception as e: last_exception = e - will_retry = "" - if attempt < retries: - will_retry = f" (will retry in {int(sleep_interval)} seconds) " # this is an unexpected exception type, so include the attempted request in case the # message from the unexpected exception doesn't. - logger.warning(f"attempt {attempt + 1} of {retries + 1}{will_retry}: unexpected exception during GET {url}: {e}") + logger.warning(f"attempt {attempt + 1} of {retries + 1}: unexpected exception during GET {url}: {e}") if last_exception: logger.error(f"last retry of GET {url} failed with {last_exception}") raise last_exception raise Exception("unreachable") + + +def backoff_sleep_interval(interval: int, attempt: int, max_value: None | int = None, jitter: bool = True) -> float: + # this is an exponential backoff + val = interval * 2**attempt + if max_value and val > max_value: + val = max_value + if jitter: + val += random.uniform(0, 1) # noqa: S311 + # explanation of S311 disable: rng is not used cryptographically + return val diff --git a/tests/unit/cli/test-fixtures/full.yaml b/tests/unit/cli/test-fixtures/full.yaml index 90f80b57..c4737cb5 100644 --- a/tests/unit/cli/test-fixtures/full.yaml +++ b/tests/unit/cli/test-fixtures/full.yaml @@ -38,6 +38,7 @@ providers: nvd: runtime: *runtime request_timeout: 20 + request_retry_count: 50 overrides_enabled: true overrides_url: https://github.com/anchore/nvd-data-overrides/SOMEWHEREELSE/main.tar.gz oracle: diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py index b154d7b3..8a19e061 100644 --- a/tests/unit/cli/test_cli.py +++ b/tests/unit/cli/test_cli.py @@ -249,6 +249,7 @@ def test_config(monkeypatch) -> None: api_key: secret overrides_enabled: false overrides_url: https://github.com/anchore/nvd-data-overrides/archive/refs/heads/main.tar.gz + request_retry_count: 10 request_timeout: 125 runtime: existing_input: keep diff --git a/tests/unit/cli/test_config.py b/tests/unit/cli/test_config.py index e578e20e..2e4ae0e7 100644 --- a/tests/unit/cli/test_config.py +++ b/tests/unit/cli/test_config.py @@ -84,6 +84,7 @@ def test_full_config(helpers): nvd=providers.nvd.Config( runtime=runtime_cfg, request_timeout=20, + request_retry_count=50, overrides_enabled=True, overrides_url="https://github.com/anchore/nvd-data-overrides/SOMEWHEREELSE/main.tar.gz", ), diff --git a/tests/unit/providers/nvd/test_api.py b/tests/unit/providers/nvd/test_api.py index d6019c90..c1fe0f6d 100644 --- a/tests/unit/providers/nvd/test_api.py +++ b/tests/unit/providers/nvd/test_api.py @@ -41,7 +41,7 @@ def test_cve_no_api_key(self, simple_mock, mocker): mocker.call( "https://services.nvd.nist.gov/rest/json/cves/2.0", subject.logger, - backoff_in_seconds=30, + retries=10, params="cveId=CVE-2020-0000", headers={"content-type": "application/json"}, timeout=1, @@ -59,7 +59,7 @@ def test_cve_single_cve(self, simple_mock, mocker): "https://services.nvd.nist.gov/rest/json/cves/2.0", subject.logger, params="cveId=CVE-2020-0000", - backoff_in_seconds=30, + retries=10, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -118,7 +118,7 @@ def test_cve_multi_page(self, mocker): "https://services.nvd.nist.gov/rest/json/cves/2.0", subject.logger, params="", - backoff_in_seconds=30, + retries=10, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -126,7 +126,7 @@ def test_cve_multi_page(self, mocker): "https://services.nvd.nist.gov/rest/json/cves/2.0", subject.logger, params="resultsPerPage=3&startIndex=3", - backoff_in_seconds=30, + retries=10, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -134,7 +134,7 @@ def test_cve_multi_page(self, mocker): "https://services.nvd.nist.gov/rest/json/cves/2.0", subject.logger, params="resultsPerPage=3&startIndex=6", - backoff_in_seconds=30, + retries=10, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -156,7 +156,7 @@ def test_cve_pub_date_range(self, simple_mock, mocker): "https://services.nvd.nist.gov/rest/json/cves/2.0", subject.logger, params="pubStartDate=2019-12-04T00:00:00&pubEndDate=2019-12-05T00:00:00", - backoff_in_seconds=30, + retries=10, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -178,7 +178,7 @@ def test_cve_last_modified_date_range(self, simple_mock, mocker): "https://services.nvd.nist.gov/rest/json/cves/2.0", subject.logger, params="lastModStartDate=2019-12-04T00:00:00&lastModEndDate=2019-12-05T00:00:00", - backoff_in_seconds=30, + retries=10, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -197,7 +197,7 @@ def test_results_per_page(self, simple_mock, mocker): "https://services.nvd.nist.gov/rest/json/cves/2.0", subject.logger, params="resultsPerPage=5", - backoff_in_seconds=30, + retries=10, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), @@ -214,7 +214,7 @@ def test_cve_history(self, simple_mock, mocker): "https://services.nvd.nist.gov/rest/json/cvehistory/2.0", subject.logger, params="cveId=CVE-2020-0000", - backoff_in_seconds=30, + retries=10, headers={"content-type": "application/json", "apiKey": "secret"}, timeout=1, ), diff --git a/tests/unit/utils/test_http.py b/tests/unit/utils/test_http.py index ba8d0a7e..5b89ed5c 100644 --- a/tests/unit/utils/test_http.py +++ b/tests/unit/utils/test_http.py @@ -52,10 +52,14 @@ def test_correct_number_of_retria(self, mock_requests, mock_sleep, mock_logger, @patch("time.sleep") @patch("requests.get") - def test_succeeds_if_retries_succeed(self, mock_requests, mock_sleep, mock_logger, error_response, success_response): + @patch("random.uniform") + def test_succeeds_if_retries_succeed( + self, mock_uniform_random, mock_requests, mock_sleep, mock_logger, error_response, success_response + ): + mock_uniform_random.side_effect = [0.1] mock_requests.side_effect = [error_response, success_response] http.get("http://example.com/some-path", mock_logger, retries=1, backoff_in_seconds=22) - mock_sleep.assert_called_with(22) + mock_sleep.assert_called_with(22.1) mock_logger.warning.assert_called() mock_logger.error.assert_not_called() mock_requests.assert_called_with("http://example.com/some-path", timeout=http.DEFAULT_TIMEOUT) @@ -74,7 +78,7 @@ def test_exponential_backoff_and_jitter( mock_requests.side_effect = [error_response, error_response, error_response, success_response] mock_uniform_random.side_effect = [0.5, 0.4, 0.1] http.get("http://example.com/some-path", mock_logger, backoff_in_seconds=10, retries=3) - assert mock_sleep.call_args_list == [call(10), call(10 * 2 + 0.5), call(10 * 4 + 0.4)] + assert mock_sleep.call_args_list == [call(10 + 0.5), call(10 * 2 + 0.4), call(10 * 4 + 0.1)] @patch("time.sleep") @patch("requests.get") @@ -91,8 +95,13 @@ def test_it_logs_the_url_on_failure(self, mock_requests, mock_sleep, mock_logger def test_it_log_warns_errors(self, mock_requests, mock_sleep, mock_logger, error_response, success_response): mock_requests.side_effect = [error_response, success_response] http.get("http://example.com/some-path", mock_logger, retries=1, backoff_in_seconds=33) - assert "HTTP ERROR" in mock_logger.warning.call_args.args[0] - assert "will retry in 33 seconds" in mock_logger.warning.call_args.args[0] + + logged_warnings = [call.args[0] for call in mock_logger.warning.call_args_list] + + assert any("HTTP ERROR" in message for message in logged_warnings), "Expected 'HTTP ERROR' in logged warnings." + assert any( + "will retry in 33 seconds" in message for message in logged_warnings + ), "Expected retry message in logged warnings." @patch("time.sleep") @patch("requests.get") @@ -109,16 +118,47 @@ def test_it_calls_status_handler(self, mock_requests, mock_sleep, mock_logger, e @patch("time.sleep") @patch("requests.get") + @patch("random.uniform") def test_it_retries_when_status_handler_raises( - self, mock_requests, mock_sleep, mock_logger, error_response, success_response + self, mock_uniform_random, mock_requests, mock_sleep, mock_logger, error_response, success_response ): + mock_uniform_random.side_effect = [0.25] mock_requests.side_effect = [success_response, error_response] status_handler = MagicMock() status_handler.side_effect = [Exception("custom exception"), None] result = http.get( "http://example.com/some-path", mock_logger, status_handler=status_handler, retries=1, backoff_in_seconds=33 ) - mock_sleep.assert_called_with(33) + mock_sleep.assert_called_with(33.25) # custom status handler raised the first time it was called, # so we expect the second mock response to be returned overall assert result == error_response + + +@pytest.mark.parametrize( + "interval, jitter, max_value, expected", + [ + ( + 30, # interval + False, # jitter + None, # max_value + [30, 60, 120, 240, 480, 960, 1920, 3840, 7680, 15360, 30720, 61440, 122880, 245760, 491520], # expected + ), + ( + 3, # interval + False, # jitter + 1000, # max_value + [3, 6, 12, 24, 48, 96, 192, 384, 768, 1000, 1000, 1000, 1000, 1000, 1000], # expected + ), + ], +) +def test_backoff_sleep_interval(interval, jitter, max_value, expected): + actual = [ + http.backoff_sleep_interval(interval, attempt, jitter=jitter, max_value=max_value) for attempt in range(len(expected)) + ] + + if not jitter: + assert actual == expected + else: + for i, (a, e) in enumerate(zip(actual, expected)): + assert a >= e and a <= e + 1, f"Jittered value out of bounds at attempt {i}: {a} (expected ~{e})"