Skip to content

Instantly share code, notes, and snippets.

@deeso
Created January 26, 2018 17:51
Show Gist options
  • Save deeso/4208f1d86adf3b65d150303e4ae61cac to your computer and use it in GitHub Desktop.
Save deeso/4208f1d86adf3b65d150303e4ae61cac to your computer and use it in GitHub Desktop.
Rate limit the IOCs submitted to VT
import string
from pymongo import MongoClient
from datetime import datetime
import traceback
import argparse
import sys
from threading import Thread, Lock
import logging
import time
from virus_total_apis import PublicApi as VTPublic
from virus_total_apis import PrivateApi as VTPrivate
# Route every record (DEBUG and above) reaching the root logger to stdout,
# prefixed with a timestamp and the originating logger name.
formatter = logging.Formatter('[%(asctime)s - %(name)s] %(message)s')
ch = logging.StreamHandler(sys.stdout)
ch.setFormatter(formatter)
ch.setLevel(logging.DEBUG)
logging.getLogger().setLevel(logging.DEBUG)
logging.getLogger().addHandler(ch)
def mongo_callback(ioc, result):
    """Persist a VirusTotal report for *ioc* into MongoDB.

    Connection settings come from the module-level MONGO_HOST, MONGO_PORT
    and MONGO_DB values. The report is upserted into the ``reports``
    collection as ``{'ioc': ioc, 'result': result}``, keyed by ``ioc``.

    :param ioc: the indicator (hash, IP address or domain) that was queried
    :param result: the report object returned by the VT client
    :return: False when no Mongo host is configured, True once stored
    """
    if MONGO_HOST is None:
        return False
    # The port may arrive from the command line as a string.
    con = MongoClient(MONGO_HOST, int(MONGO_PORT))
    try:
        # pymongo Database objects do not support item assignment
        # (``db[ioc] = result`` raises TypeError), so store a document.
        # NOTE(review): older MongoDB servers reject field names containing
        # '.' or '$'; confirm VT report keys are acceptable for the target
        # server version.
        con[MONGO_DB]['reports'].replace_one(
            {'ioc': ioc}, {'ioc': ioc, 'result': result}, upsert=True)
    finally:
        # Release the connection pool created for this call.
        con.close()
    return True
class VTQueryObject(object):
    """Rate-limited, caching wrapper around the VirusTotal API client.

    Each IOC is queried at most once per instance (results are cached in
    ``self.completed``). No more than ``queries_pm`` queries are issued in
    any rolling ``QUERY_WINDOW``-second window.
    """

    # Length, in seconds, of the rolling window that queries_pm applies to.
    QUERY_WINDOW = 60
    # Accepted hex-digest lengths: 32 = MD5, 40 = SHA-1, 64 = SHA-256.
    # The remaining lengths are kept from the previously accepted set.
    HASH_LENGTHS = (32, 40, 48, 64, 96, 128)

    def __init__(self, api_key, key_type='public',
                 queries_pm=4, callback=None):
        """Create the VT client and the rate-limiting state.

        :param api_key: VirusTotal API key
        :param key_type: 'public' selects the public API client; any other
            value selects the private API client
        :param queries_pm: maximum number of queries per QUERY_WINDOW seconds
        :param callback: optional ``callback(ioc, report)`` invoked after
            every successful query, in addition to any per-call callback
        """
        self.key = api_key
        if key_type == 'public':
            self.vtapi = VTPublic(self.key)
        else:
            self.vtapi = VTPrivate(self.key)
        self.queries = []           # datetimes of queries still in the window
        self.last_query_time = None
        self.waiting = False        # True while a wait timer is pending
        self.queries_pm = queries_pm
        self.wait_thread = None
        self.wait_lock = Lock()     # guards wait_thread / waiting
        self.qlock = Lock()         # serializes check-and-reserve of a slot
        self.completed = {}         # ioc -> cached report
        self.callback = callback

    def prune_queries(self):
        """Drop query timestamps older than QUERY_WINDOW seconds.

        :return: True if at least one timestamp was removed
        """
        now = datetime.now()
        kept = [t for t in self.queries
                if (now - t).total_seconds() < self.QUERY_WINDOW]
        pruned = len(kept) != len(self.queries)
        self.queries = kept
        return pruned

    def need_to_wait(self):
        """Return True if the query budget for the current window is spent.

        When the budget is spent and no timer is running, this starts a
        daemon thread. The thread clears ``self.waiting`` once the oldest
        recorded query has left the window.
        """
        self.prune_queries()
        if len(self.queries) < self.queries_pm:
            return False
        with self.wait_lock:
            if self.wait_thread is not None and self.wait_thread.is_alive():
                # A timer is already counting down; the caller must wait.
                return True
            self.waiting = True
            now = datetime.now()
            oldest = self.queries[0] if self.queries else now
            # Sleep only for what is left of the oldest query's window.
            remaining = self.QUERY_WINDOW - (now - oldest).total_seconds()
            self.wait_thread = Thread(target=self.wait_for_secs,
                                      args=(max(0.0, remaining),))
            self.wait_thread.daemon = True
            self.wait_thread.start()
        return True

    def performed_query(self):
        """Record that a query is being issued right now."""
        self.prune_queries()
        self.last_query_time = datetime.now()
        self.queries.append(self.last_query_time)

    def wait_for_secs(self, seconds=60):
        """Timer-thread body: sleep *seconds*, then clear ``waiting``."""
        time.sleep(max(0, seconds))
        self.waiting = False

    def spinlock(self):
        """Block the caller, polling once a second, while a timer is pending."""
        while self.waiting:
            time.sleep(1.0)

    def _report(self, ioc, fetch, wait=False, callback=None):
        """Return a cached or freshly fetched report for *ioc*.

        :param fetch: VT client method to call as ``fetch(ioc)``
        :param wait: if True, block until a query slot frees up; if False,
            return None immediately when rate-limited
        :param callback: optional ``callback(ioc, report)`` for this call
        :return: the report, or None if rate-limited (and not waiting) or
            if the VT call raised
        """
        while True:
            cached = self.completed.get(ioc)
            if cached is not None:
                return cached
            with self.qlock:
                throttled = self.need_to_wait()
                if not throttled:
                    # Reserve the slot while holding the lock so concurrent
                    # callers count it against the budget.
                    self.performed_query()
            if not throttled:
                break
            if not wait:
                return None
            self.spinlock()
        try:
            result = fetch(ioc)
        except Exception:
            logging.exception("VT query failed for %s", ioc)
            return None
        self.completed[ioc] = result
        for cb in (callback, self.callback):
            if cb is None:
                continue
            try:
                cb(ioc, result)
            except Exception:
                # The report was obtained and cached; a failing consumer
                # should not discard it.
                logging.exception("callback failed for %s", ioc)
        return result

    def domain_report(self, ioc, wait=False, callback=None):
        """Return the VT domain report for *ioc* (None if unavailable)."""
        return self._report(ioc, self.vtapi.get_domain_report,
                            wait=wait, callback=callback)

    def file_report(self, ioc, wait=False, callback=None):
        """Return the VT file report for hash *ioc* (None if unavailable)."""
        return self._report(ioc, self.vtapi.get_file_report,
                            wait=wait, callback=callback)

    def ip_report(self, ioc, wait=False, callback=None):
        """Return the VT IP report for *ioc* (None if unavailable)."""
        return self._report(ioc, self.vtapi.get_ip_report,
                            wait=wait, callback=callback)

    @classmethod
    def is_ip(cls, ioc):
        """Return True if *ioc* is a dotted-quad IPv4 address (0-255 octets)."""
        octets = ioc.strip().split('.')
        if len(octets) != 4:
            return False
        # Check ASCII digits explicitly: str.isdigit accepts characters
        # such as '²' that int() cannot parse.
        return all(o and all(c in string.digits for c in o) and int(o) < 256
                   for o in octets)

    @classmethod
    def is_hash(cls, ioc):
        """Return True if *ioc* is a hex string of a known digest length."""
        return (len(ioc) in cls.HASH_LENGTHS and
                all(c in string.hexdigits for c in ioc))

    def submit_iocs(self, iocs, wait=False, callback=None):
        """Query VT for every IOC in *iocs*.

        The caller's list is not modified.

        :return: ``(reports, failed)``, where ``reports`` holds the obtained
            reports and ``failed`` lists the IOCs with no report
        """
        pending = list(iocs)
        success = []
        failed = []
        logging.info("Submitting %d iocs to VT.", len(pending))
        for completed, ioc in enumerate(pending, 1):
            r = self.submit_ioc(ioc, wait=wait, callback=callback)
            if r is None:
                failed.append(ioc)
            else:
                success.append(r)
            if completed % 100 == 0:
                logging.info("completed %d submissions to VT.", completed)
        logging.info("Submitted %d submissions to VT with %d reports.",
                     len(pending), len(success))
        return success, failed

    def submit_ioc(self, ioc, wait=False, callback=None):
        """Dispatch *ioc* to the file, IP or domain report by its shape."""
        if self.is_hash(ioc):
            return self.file_report(ioc, wait=wait, callback=callback)
        elif self.is_ip(ioc):
            return self.ip_report(ioc, wait=wait, callback=callback)
        return self.domain_report(ioc, wait=wait, callback=callback)
# Command-line interface and default MongoDB settings. The MONGO_* values
# are module-level so mongo_callback can read them; __main__ overwrites
# them with the parsed arguments.
CMD_DESC = 'Query VT up to the limit.'
MONGO_HOST = None
MONGO_PORT = 27017
MONGO_DB = 'ioc_reports'
parser = argparse.ArgumentParser(description=CMD_DESC)
parser.add_argument('-key', type=str, default=None,
                    help='api key')
# Takes a host value; the previous action='store_true' made it a flag
# that could only ever be True.
parser.add_argument('-mongo_host', type=str, default=MONGO_HOST,
                    help='mongo host to put results in')
# type=int so a port given on the command line is not left as a string.
parser.add_argument('-mongo_port', type=int, default=MONGO_PORT,
                    help='mongo port to put results in')
parser.add_argument('-mongo_db', type=str, default=MONGO_DB,
                    help='mongo db to put results in')
parser.add_argument('-ioc_file', default=None, type=str,
                    help='ioc_file to read from')
parser.add_argument('-ioc', default=None, type=str,
                    help='ioc to search for')
parser.add_argument('-queries_pm', default=4, type=int,
                    help='limit queries per minute')
if __name__ == '__main__':
    # Collect IOCs from -ioc and/or -ioc_file, then query VT for each one
    # and store every report in MongoDB through mongo_callback.
    args = parser.parse_args()
    iocs = []
    if args.ioc is not None:
        iocs.append(args.ioc)
    if args.ioc_file is not None:
        with open(args.ioc_file) as ioc_fh:
            # Skip blank lines so empty strings are never submitted as IOCs.
            iocs.extend(line.strip() for line in ioc_fh if line.strip())
    MONGO_HOST = args.mongo_host
    MONGO_PORT = args.mongo_port
    MONGO_DB = args.mongo_db
    if MONGO_HOST is None:
        raise Exception("Mongo host must be specified")
    elif not iocs:
        raise Exception("No IOCs provided")
    elif args.key is None:
        raise Exception("VT api key must be specified")
    vtqo = VTQueryObject(args.key, queries_pm=args.queries_pm,
                         callback=mongo_callback)
    # wait=True throttles to queries_pm per minute instead of dropping
    # every IOC past the first window's budget.
    reports, failed = vtqo.submit_iocs(iocs, wait=True)
    if failed:
        logging.warning("No report obtained for %d iocs.", len(failed))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment