Skip to content

Instantly share code, notes, and snippets.

@pdxjohnny
Created November 17, 2022 02:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pdxjohnny/1b44a4bcd176809a422d85299375132a to your computer and use it in GitHub Desktop.
diff --git a/cve_bin_tool/nvd_api.py b/cve_bin_tool/nvd_api.py
index 4a432a2..48b3ba8 100644
--- a/cve_bin_tool/nvd_api.py
+++ b/cve_bin_tool/nvd_api.py
@@ -8,6 +8,7 @@ Parameter values and more information: https://nvd.nist.gov/developers/products
"""
import asyncio
+import hashlib
import json
import math
import time
@@ -80,12 +81,15 @@ class NVD_API:
"Rejected": 0,
"Received": 0,
}
+ # TODO: drop the redundant await; session.get() works directly with async with
async with await session.get(
stats,
params={"reporttype": "countsbystatus"},
raise_for_status=True,
) as response:
data = await response.json()
+ with open("/tmp/v2/stats.json", "w") as fileobj:
+ json.dump(data, fileobj, indent=4, sort_keys=True)
for key in data["vulnsByStatusCounts"]:
cve_count[key["name"]] = int(key["count"])
return cve_count
@@ -136,8 +140,11 @@ class NVD_API:
self.logger.info("Fetching metadata from NVD...")
cve_count = await self.nvd_count_metadata(self.session, self.stats)
+ self.logger.info("Got metadata from NVD: %r", cve_count)
+ self.logger.info("Validating NVD API...")
await self.validate_nvd_api()
+ self.logger.info("Validated NVD API")
if self.invalid_api:
self.logger.warning(
@@ -158,6 +165,7 @@ class NVD_API:
f'Fetching updated CVE entries after {self.params["modStartDate"]}'
)
else:
+ self.logger.info("Fetching updated CVE entries (API v2) after the last modified date, minus a 2-minute offset")
self.params[
"lastModStartDate"
] = self.convert_date_to_nvd_date_api2(
@@ -180,8 +188,10 @@ class NVD_API:
progress.update(task)
progress.update(task, advance=1)
- else:
- self.total_results = cve_count["Total"] - cve_count["Rejected"]
+ self.total_results = cve_count["Total"] - cve_count["Rejected"]
+ self.logger.info(
+ f'self.total_results = Total: {cve_count["Total"]} - Rejected: {cve_count["Rejected"]}'
+ )
self.logger.info(f"Adding {self.total_results} CVE entries")
async def validate_nvd_api(self):
@@ -197,6 +207,16 @@ class NVD_API:
self.feed, params=param_dict, raise_for_status=True
) as response:
data = await response.json()
+ # DEBUG: save the raw response and its request parameters under
+ # /tmp/v2 (which must already exist), named after a SHA-384 of the
+ # params so each distinct set of params gets its own pair of files.
+ params_digest = hashlib.sha384(str(param_dict).encode()).hexdigest()
+ with open(f"/tmp/v2/validate-{params_digest}.json", "w") as fileobj:
+ json.dump(data, fileobj, indent=4, sort_keys=True)
+ with open(
+ f"/tmp/v2/validate-{params_digest}-params.json", "w"
+ ) as fileobj:
+ json.dump(param_dict, fileobj, indent=4, sort_keys=True)
if data.get("error", False):
self.logger.error(f"NVD API error: {data['error']}")
raise NVDKeyError(self.params["apiKey"])
@@ -227,6 +247,16 @@ class NVD_API:
self.logger.debug(f"Response received {response.status}")
if response.status == 200:
fetched_data = await response.json()
+ # DEBUG: save the raw feed response and its request parameters under
+ # /tmp/v2 (which must already exist), named after a SHA-384 of the
+ # params so each distinct set of params gets its own pair of files.
+ params_digest = hashlib.sha384(str(param_dict).encode()).hexdigest()
+ with open(f"/tmp/v2/feed-{params_digest}.json", "w") as fileobj:
+ json.dump(fetched_data, fileobj, indent=4, sort_keys=True)
+ with open(
+ f"/tmp/v2/feed-{params_digest}-params.json", "w"
+ ) as fileobj:
+ json.dump(param_dict, fileobj, indent=4, sort_keys=True)
if start_index == 0:
# Update total results in case there is discrepancy between NVD dashboard and API
@@ -238,6 +268,9 @@ class NVD_API:
self.total_results = (
fetched_data["totalResults"] - reject_count
)
+ self.logger.info(
+ f'self.total_results = Total: {fetched_data["totalResults"]} - Rejected: {reject_count}'
+ )
if self.api_version == "1.0":
self.all_cve_entries.extend(
fetched_data["result"]["CVE_Items"]
diff --git a/test/test_nvd_api.py b/test/test_nvd_api.py
index 91cf1fb..a6c26d5 100644
--- a/test/test_nvd_api.py
+++ b/test/test_nvd_api.py
@@ -2,16 +2,26 @@
# SPDX-License-Identifier: GPL-3.0-or-later
+import contextlib
import os
import shutil
import tempfile
+import types
from datetime import datetime, timedelta
from test.utils import LONG_TESTS
+import aiohttp
+import httptest
import pytest
+
+import alice.threats.vulns.serve.nvdstyle
from cve_bin_tool.cvedb import CVEDB
from cve_bin_tool.data_sources import nvd_source
-from cve_bin_tool.nvd_api import NVD_API
+from cve_bin_tool.nvd_api import (
+ NVD_API,
+ FEED as NVD_API_FEED,
+ NVD_CVE_STATUS,
+)
class TestNVD_API:
@@ -42,14 +52,47 @@ class TestNVD_API:
LONG_TESTS() != 1 or not os.getenv("nvd_api_key"),
reason="NVD tests run only in long tests",
)
- async def test_total_results_count(self):
+ @pytest.mark.parametrize(
+ "api_version, feed, stats",
+ [
+ (
+ "1.0",
+ httptest.Server(alice.threats.vulns.serve.nvdstyle.NVDStyleHTTPHandler),
+ httptest.Server(alice.threats.vulns.serve.nvdstyle.NVDStyleHTTPHandler),
+ ),
+ (
+ "2.0",
+ contextlib.nullcontext(types.SimpleNamespace(url=lambda: NVD_API_FEED)),
+ # 2.0 feed: real NVD endpoint; stats still come from the local server.
+ httptest.Server(alice.threats.vulns.serve.nvdstyle.NVDStyleHTTPHandler),
+ ),
+ ],
+ )
+ async def test_total_results_count(self, api_version, feed, stats):
"""Total results should be greater than or equal to the current fetched cves"""
- nvd_api = NVD_API(api_key=os.getenv("nvd_api_key") or "")
- await nvd_api.get_nvd_params(
- time_of_last_update=datetime.now() - timedelta(days=2)
- )
- await nvd_api.get()
- assert len(nvd_api.all_cve_entries) >= nvd_api.total_results
+ # TODO alice.nvd.TestHTTPServer will become either
+ # alice.nvd.TestNVDVersion_1_0 or alice.nvd.TestNVDVersion_2_0
+ # lambda *args: alice.nvd.TestHTTPServer(*args, directory=pathlib.Path(__file__).parent)
+ with feed as feed_http_server, stats as stats_http_server:
+ async with aiohttp.ClientSession() as session:
+ nvd_api = NVD_API(
+ feed=feed_http_server.url(),
+ stats=stats_http_server.url(),
+ api_key=os.getenv("nvd_api_key") or "",
+ session=session,
+ api_version=api_version,
+ )
+ nvd_api.logger.info(
+ "api_version: %s, feed: %s, stats: %s",
+ api_version,
+ feed_http_server.url(),
+ stats_http_server.url(),
+ )
+ await nvd_api.get_nvd_params(
+ time_of_last_update=datetime.now() - timedelta(days=2)
+ )
+ await nvd_api.get()
+ assert len(nvd_api.all_cve_entries) >= nvd_api.total_results
@pytest.mark.asyncio
@pytest.mark.skipif(
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment