Created November 17, 2022 05:36
-
-
Save pdxjohnny/09d3051bea781a9b5e981374b8c1756a to your computer and use it in GitHub Desktop.
tests: add tests for NVD 2.0 API (https://github.com/intel/cve-bin-tool/issues/2334) — 2022-11-16, @pdxjohnny. Engineering Logs: https://github.com/intel/dffml/discussions/1406?sort=new#discussioncomment-4157129
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/cve_bin_tool/nvd_api.py b/cve_bin_tool/nvd_api.py | |
index 4a432a2..c7fc977 100644 | |
--- a/cve_bin_tool/nvd_api.py | |
+++ b/cve_bin_tool/nvd_api.py | |
@@ -8,6 +8,7 @@ Parameter values and more information: https://nvd.nist.gov/developers/products | |
""" | |
import asyncio | |
+import hashlib | |
import json | |
import math | |
import time | |
@@ -80,12 +81,15 @@ class NVD_API: | |
"Rejected": 0, | |
"Received": 0, | |
} | |
+ # TODO Remove extra await here | |
async with await session.get( | |
stats, | |
params={"reporttype": "countsbystatus"}, | |
raise_for_status=True, | |
) as response: | |
data = await response.json() | |
+ with open(f"/tmp/v2/stats.json", "w") as fileobj: | |
+ json.dump(data, fileobj, indent=4, sort_keys=True) | |
for key in data["vulnsByStatusCounts"]: | |
cve_count[key["name"]] = int(key["count"]) | |
return cve_count | |
@@ -130,14 +134,15 @@ class NVD_API: | |
if not self.session: | |
connector = aiohttp.TCPConnector(limit_per_host=19) | |
- self.session = RateLimiter( | |
- aiohttp.ClientSession(connector=connector, trust_env=True) | |
- ) | |
+ self.session = aiohttp.ClientSession(connector=connector, trust_env=True) | |
self.logger.info("Fetching metadata from NVD...") | |
cve_count = await self.nvd_count_metadata(self.session, self.stats) | |
+ self.logger.info("Got metadata from NVD: %r", cve_count) | |
+ self.logger.info("Valiating NVD api...") | |
await self.validate_nvd_api() | |
+ self.logger.info("Valiated NVD api") | |
if self.invalid_api: | |
self.logger.warning( | |
@@ -158,6 +163,7 @@ class NVD_API: | |
f'Fetching updated CVE entries after {self.params["modStartDate"]}' | |
) | |
else: | |
+ self.logger.info("Fetch v2 all the updated CVE entries from the modified date. Subtracting 2-minute offset for updating cve entries") | |
self.params[ | |
"lastModStartDate" | |
] = self.convert_date_to_nvd_date_api2( | |
@@ -180,8 +186,10 @@ class NVD_API: | |
progress.update(task) | |
progress.update(task, advance=1) | |
- else: | |
- self.total_results = cve_count["Total"] - cve_count["Rejected"] | |
+ self.total_results = cve_count["Total"] - cve_count["Rejected"] | |
+ self.logger.info( | |
+ f'self.total_results = Total: {cve_count["Total"]} - Rejected: {cve_count["Rejected"]}' | |
+ ) | |
self.logger.info(f"Adding {self.total_results} CVE entries") | |
async def validate_nvd_api(self): | |
@@ -197,6 +205,16 @@ class NVD_API: | |
self.feed, params=param_dict, raise_for_status=True | |
) as response: | |
data = await response.json() | |
+ with open( | |
+ f"/tmp/v2/validate-{hashlib.sha384(str(param_dict).encode()).hexdigest()}.json", | |
+ "w", | |
+ ) as fileobj: | |
+ json.dump(data, fileobj, indent=4, sort_keys=True) | |
+ with open( | |
+ f"/tmp/v2/validate-{hashlib.sha384(str(param_dict).encode()).hexdigest()}-params.json", | |
+ "w", | |
+ ) as fileobj: | |
+ json.dump(param_dict, fileobj, indent=4, sort_keys=True) | |
if data.get("error", False): | |
self.logger.error(f"NVD API error: {data['error']}") | |
raise NVDKeyError(self.params["apiKey"]) | |
@@ -227,6 +245,16 @@ class NVD_API: | |
self.logger.debug(f"Response received {response.status}") | |
if response.status == 200: | |
fetched_data = await response.json() | |
+ with open( | |
+ f"/tmp/v2/feed-{hashlib.sha384(str(param_dict).encode()).hexdigest()}.json", | |
+ "w", | |
+ ) as fileobj: | |
+ json.dump(fetched_data, fileobj, indent=4, sort_keys=True) | |
+ with open( | |
+ f"/tmp/v2/feed-{hashlib.sha384(str(param_dict).encode()).hexdigest()}-params.json", | |
+ "w", | |
+ ) as fileobj: | |
+ json.dump(param_dict, fileobj, indent=4, sort_keys=True) | |
if start_index == 0: | |
# Update total results in case there is discrepancy between NVD dashboard and API | |
@@ -238,6 +266,9 @@ class NVD_API: | |
self.total_results = ( | |
fetched_data["totalResults"] - reject_count | |
) | |
+ self.logger.info( | |
+ f'self.total_results = Total: {fetched_data["totalResults"]} - Rejected: {reject_count}' | |
+ ) | |
if self.api_version == "1.0": | |
self.all_cve_entries.extend( | |
fetched_data["result"]["CVE_Items"] | |
@@ -260,15 +291,15 @@ class NVD_API: | |
self.logger.info( | |
f"Pausing requests for {self.interval} seconds" | |
) | |
- time.sleep(self.interval) | |
+ await asyncio.sleep(self.interval) | |
else: | |
- time.sleep(1) | |
+ await asyncio.sleep(1) | |
except Exception as error: | |
self.logger.debug(f"Failed to connect to NVD {error}") | |
self.logger.debug(f"Pausing requests for {self.interval} seconds") | |
self.failed_count += 1 | |
- time.sleep(self.interval) | |
+ await asyncio.sleep(self.interval) | |
async def get(self): | |
"""Calls load_nvd_request() multiple times to fetch all NVD feeds""" | |
diff --git a/test/test_nvd_api.py b/test/test_nvd_api.py | |
index 91cf1fb..c42485f 100644 | |
--- a/test/test_nvd_api.py | |
+++ b/test/test_nvd_api.py | |
@@ -2,16 +2,26 @@ | |
# SPDX-License-Identifier: GPL-3.0-or-later | |
import os | |
+import types | |
import shutil | |
import tempfile | |
+import contextlib | |
from datetime import datetime, timedelta | |
from test.utils import LONG_TESTS | |
import pytest | |
+import aiohttp | |
+import httptest | |
+ | |
+import alice.threats.vulns.serve.nvdstyle | |
from cve_bin_tool.cvedb import CVEDB | |
from cve_bin_tool.data_sources import nvd_source | |
-from cve_bin_tool.nvd_api import NVD_API | |
+from cve_bin_tool.nvd_api import ( | |
+ NVD_API, | |
+ FEED as NVD_API_FEED, | |
+ NVD_CVE_STATUS, | |
+) | |
class TestNVD_API: | |
@@ -42,14 +52,46 @@ class TestNVD_API: | |
LONG_TESTS() != 1 or not os.getenv("nvd_api_key"), | |
reason="NVD tests run only in long tests", | |
) | |
- async def test_total_results_count(self): | |
+ @pytest.mark.parametrize( | |
+ "api_version, feed, stats", | |
+ [ | |
+ ( | |
+ "1.0", | |
+ httptest.Server(alice.threats.vulns.serve.nvdstyle.NVDStyleHTTPHandler), | |
+ httptest.Server(alice.threats.vulns.serve.nvdstyle.NVDStyleHTTPHandler), | |
+ ), | |
+ ( | |
+ "2.0", | |
+ httptest.Server(alice.threats.vulns.serve.nvdstyle.NVDStyleHTTPHandler), | |
+ httptest.Server(alice.threats.vulns.serve.nvdstyle.NVDStyleHTTPHandler), | |
+ ), | |
+ ], | |
+ ) | |
+ async def test_total_results_count(self, api_version, feed, stats): | |
"""Total results should be greater than or equal to the current fetched cves""" | |
- nvd_api = NVD_API(api_key=os.getenv("nvd_api_key") or "") | |
- await nvd_api.get_nvd_params( | |
- time_of_last_update=datetime.now() - timedelta(days=2) | |
- ) | |
- await nvd_api.get() | |
- assert len(nvd_api.all_cve_entries) >= nvd_api.total_results | |
+ # TODO alice.nvd.TestHTTPServer will become either | |
+ # alice.nvd.TestNVDVersion_1_0 or alice.nvd.TestNVDVersion_2_0 | |
+ # lambda *args: alice.nvd.TestHTTPServer(*args, directory=pathlib.Path(__file__).parent) | |
+ with feed as feed_http_server, stats as stats_http_server: | |
+ async with aiohttp.ClientSession() as session: | |
+ nvd_api = NVD_API( | |
+ feed=feed_http_server.url(), | |
+ stats=stats_http_server.url(), | |
+ api_key=os.getenv("nvd_api_key") or "", | |
+ session=session, | |
+ api_version=api_version, | |
+ ) | |
+ nvd_api.logger.info( | |
+ "api_version: %s, feed: %s, stats: %s", | |
+ api_version, | |
+ feed_http_server.url(), | |
+ stats_http_server.url(), | |
+ ) | |
+ await nvd_api.get_nvd_params( | |
+ time_of_last_update=datetime.now() - timedelta(days=2) | |
+ ) | |
+ await nvd_api.get() | |
+ assert len(nvd_api.all_cve_entries) >= nvd_api.total_results | |
@pytest.mark.asyncio | |
@pytest.mark.skipif( |
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.