Skip to content

Instantly share code, notes, and snippets.

@oliverchang
Created November 21, 2023 00:10
Show Gist options
  • Save oliverchang/1f2ef87e70b5e51b8a68b649117e05b3 to your computer and use it in GitHub Desktop.
Save oliverchang/1f2ef87e70b5e51b8a68b649117e05b3 to your computer and use it in GitHub Desktop.
auto-patch.py
#!/usr/bin/env python3
# Proof of concept demonstrating an automated guided remediation patching workflow.
# We progressively try more and more patches until tests fail.
import os.path
import re
import subprocess
import sys
from typing import List, Optional, Tuple
# Argument sets handed to `osv-scanner fix`; each entry is tried in turn as a
# separate auto-patch run.
#
# A strategy could also carry extra filters that reduce or increase the scope
# of changes by prioritizing vulnerabilities, for example:
#   '--min-severity=X'               minimum severity of vulnerabilities to consider
#   '--max-depth=Y'                  maximum (shortest) dependency depth
#   '--major-upgrades={true/false}'  whether or not major upgrades are done
# e.g. ['--relock=true', '--major-upgrades=false', '--max-depth=5']
PATCH_STRATEGIES: List[List[str]] = [
    # Try every single transitive dependency upgrade without relocking or
    # bumping direct dependencies.
    ['--relock=false'],
    # Relock the manifest and try direct dependency bumps.
    ['--relock=true'],
]
# Command line: the first argument is the project directory; every further
# argument is forwarded unchanged to each `osv-scanner fix` invocation.
if len(sys.argv) < 2:
    # sys.exit() with a string writes it to stderr and exits with status 1.
    # The bare exit() builtin is injected by the `site` module for interactive
    # use and is not guaranteed to exist in every interpreter setup.
    sys.exit(f'Usage: {sys.argv[0]} <project-directory>')

directory = sys.argv[1]
osv_fix_args = sys.argv[2:]

# check if the directory is within a git repo (a non-zero exit status from
# `git rev-parse` is treated as "not a repository"). run_fix restores the
# manifest and lockfile through `git checkout`, so stop early otherwise.
if subprocess.call(['git', '-C', directory, 'rev-parse']):
    sys.exit(f'{directory} is not part of a git repository')

# npm manifest and lockfile that osv-scanner is pointed at.
manifest = os.path.join(directory, 'package.json')
lockfile = os.path.join(directory, 'package-lock.json')
def _parse_count(marker: str, output: str) -> Optional[int]:
    """Return the integer following `<marker>:` in *output*, or None if absent."""
    match = re.search(rf'{re.escape(marker)}:\s*(\d+)', output)
    return int(match.group(1)) if match else None


def run_fix(
    n_patches: int, avoid_pkgs: List[str], strategy: List[str]
) -> Tuple[List[str], Optional[int], Optional[int]]:
    """Restore the manifest/lockfile, run `osv-scanner fix` once and parse its output.

    Args:
        n_patches: Value passed as `--apply-top`. 0 is a magic value: the flag
            is omitted, which is meant to try all patches.
        avoid_pkgs: Package names joined with commas and passed as
            `--avoid-upgrades`; the flag is omitted when the list is empty.
        strategy: Extra command-line flags appended for this strategy.

    Returns:
        A tuple `(upgraded, remaining_vulns, unfixable_vulns)`:
        `upgraded` holds the first captured field of every `UPGRADED-PACKAGE:`
        line in the output; the two counts are parsed from the
        `REMAINING-VULNS:` and `UNFIXABLE-VULNS:` markers and are None when the
        corresponding marker does not appear.

    Raises:
        subprocess.CalledProcessError: if the `git checkout` restore fails.
    """
    # restore package.json & package-lock.json
    subprocess.check_call(
        ['git', 'checkout', 'package.json', 'package-lock.json'], cwd=directory)

    # run osv-fix and parse changes
    cmd = ['osv-scanner', 'fix', '-a', '-m', manifest, '-l', lockfile]
    cmd += osv_fix_args
    cmd += strategy
    if n_patches != 0:
        cmd.extend(['--apply-top', str(n_patches)])
    if avoid_pkgs:
        cmd.extend(['--avoid-upgrades', ','.join(avoid_pkgs)])

    try:
        output = subprocess.check_output(cmd, text=True)
    except subprocess.CalledProcessError as e:
        # Best effort: a non-zero exit status is tolerated and whatever output
        # was captured is still scanned for the markers below.
        output = (e.stdout or '') + (e.stderr or '')

    upgraded = [m[1] for m in re.finditer(r'UPGRADED-PACKAGE: (.*),(.*),(.*)', output)]
    remaining_vulns = _parse_count('REMAINING-VULNS', output)
    unfixable_vulns = _parse_count('UNFIXABLE-VULNS', output)
    return upgraded, remaining_vulns, unfixable_vulns
def run_loop(
    strategy: List[str],
) -> Tuple[List[str], Optional[int], Optional[int], List[str]]:
    """Apply patches for one strategy, keeping only upgrades that pass `npm run test`.

    The first attempt applies every patch at once. If the tests fail, patches
    are then added one at a time; packages introduced by a failing attempt are
    blocklisted and the attempt is repeated without them. The loop ends once
    run_fix reports the same upgrade set as the last verified one.

    Args:
        strategy: Extra `osv-scanner fix` flags, forwarded to run_fix.

    Returns:
        A tuple `(valid, remaining, total_unfixable, avoid)`:
        `valid` is the last upgrade set that passed the tests, `remaining` is
        the remaining-vulnerability count reported by the final run_fix call,
        `total_unfixable` is the unfixable count reported by the initial
        all-patches run, and `avoid` is the list of blocklisted packages.
        Either count may be None if run_fix could not parse it.
    """
    valid: List[str] = []
    avoid: List[str] = []
    # 0 is a special value meaning that we try applying every patch. This is
    # meant as a shortcut in case this would've succeeded anyway.
    n_patches = 0
    print('===== Attempting auto-patch with strategy', strategy, '=====')
    remaining = None
    total_unfixable = None
    while True:
        changes, remaining, unfixable = run_fix(n_patches, avoid, strategy)
        if n_patches == 0:
            # The all-patches run represents the real unfixable count if every
            # possible package upgrade was allowed. Record it whether or not
            # the tests go on to pass, so a first-try success still reports it.
            total_unfixable = unfixable
        if changes == valid:
            # if the result of running osv-fix hasn't changed, then we've run
            # out of patches to apply
            break
        print('===== Trying to upgrade:', changes, '=====')
        print('===== Current blocklist:', avoid, '=====')

        # check the tests; a non-zero exit status means they failed
        tests_failed = subprocess.call(['npm', 'run', 'test'], cwd=directory) != 0
        if tests_failed:
            if n_patches == 0:
                # The all-at-once shortcut did not work; fall back to applying
                # patches incrementally, starting with a single patch.
                n_patches += 1
                continue
            print('===== Tests failed, blocklisting upgrades =====')
            # add each new package to the avoid list
            for c in changes:
                if c not in valid:
                    avoid.append(c)
            print('===== Current blocklist:', avoid, '=====')
        else:
            valid = changes
            if n_patches == 0:
                # Every patch applied cleanly on the first try; nothing more to do.
                break
            # try now with the next patch
            n_patches += 1

    if valid:
        print()
        print('===== The following packages have been changed and verified against the tests: =====')
        for v in valid:
            print(v)
    return valid, remaining, total_unfixable, avoid
# Run every strategy and keep the one whose verified upgrades leave the fewest
# remaining vulnerabilities.
# NOTE(review): the manifest/lockfile on disk are left as the last run_fix call
# wrote them, which is presumably not always the best strategy's result —
# confirm whether the best result should be re-applied.
best_strategy = None
best_changes = []
best_avoid = []
# None means "no count known yet"; avoids printing a made-up sentinel number
# when no strategy produced verified upgrades.
best_remaining = None
best_unfixable = None
for strategy in PATCH_STRATEGIES:
    changes, remaining, unfixable, avoid = run_loop(strategy)
    if not changes:
        continue
    # The first strategy with verified changes is always accepted. Later ones
    # must report a strictly lower remaining count; a count of None (marker
    # not found in the output) can never beat an existing candidate.
    is_better = best_strategy is None or (
        remaining is not None
        and (best_remaining is None or remaining < best_remaining)
    )
    if is_better:
        best_strategy = strategy
        best_changes = changes
        best_avoid = avoid
        best_remaining = remaining
        best_unfixable = unfixable

print()
print('===== Auto-patch completed with the following changed packages =====')
print('Best strategy:', best_strategy)
for v in best_changes:
    print(v)
print('The follow packages cannot be upgraded due to failing tests:')
for v in best_avoid:
    print(v)
print()
if best_remaining is not None:
    print(best_remaining, 'vulnerabilities remain')
else:
    print('The number of remaining vulnerabilities is unknown')
if best_unfixable:
    print(best_unfixable, 'vulnerabilities are impossible to fix by package upgrades')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment