# Copyright 2017 Observable Networks
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
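"""
Poll Azure Network Watcher packet captures for a configured set of virtual
machines, download the finished .cap blobs from Azure Storage, convert them
to flow-style CSV records with tshark, and hand the results to the base
Pusher class for upload. (Summary inferred from the code below.)
"""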
from __future__ import print_function, division, unicode_literals
# python builtins
import logging
from csv import DictWriter
from datetime import datetime, timedelta
from glob import iglob
from gzip import open as gz_open
from os import environ
from os.path import basename, join, splitext
from subprocess import check_call
# third party
from azure.storage.blob import BlockBlobService
import vendor.requests as requests
# local
from pusher import Pusher
from utils import create_dirs, utcnow
# logging setup
FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=FORMAT)
logging.captureWarnings(True)
py_warnings_logger = logging.getLogger('py.warnings')
py_warnings_logger.setLevel(logging.ERROR)
urllib3_logger = logging.getLogger('requests.packages.urllib3')
urllib3_logger.setLevel(logging.ERROR)
# environment variables
ENV_AZURE_PCAP_DIR = 'OBSRVBL_AZURE_PCAP_DIR'
DEFAULT_AZURE_PCAP_DIR = './logs'
ENV_AZURE_TENANT_ID = 'OBSRVBL_AZURE_TENANT_ID'
ENV_AZURE_APPLICATION_ID = 'OBSRVBL_AZURE_APPLICATION_ID'
ENV_AZURE_APPLICATION_SECRET = 'OBSRVBL_AZURE_APPLICATION_SECRET'
ENV_AZURE_SUBSCRIPTION_ID = 'OBSRVBL_AZURE_SUBSCRIPTION_ID'
ENV_AZURE_RESOURCE_GROUP = 'OBSRVBL_AZURE_RESOURCE_GROUP'
ENV_AZURE_NETWORK_WATCHER = 'OBSRVBL_AZURE_NETWORK_WATCHER'
ENV_AZURE_STORAGE_ACCOUNT = 'OBSRVBL_AZURE_STORAGE_ACCOUNT'
ENV_AZURE_STORAGE_CONTAINER = 'OBSRVBL_AZURE_STORAGE_CONTAINER'
ENV_AZURE_STORAGE_TOKEN = 'OBSRVBL_AZURE_STORAGE_TOKEN'
ENV_AZURE_VIRTUAL_MACHINES = 'OBSRVBL_AZURE_VIRTUAL_MACHINES'
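# All of the variables above except OBSRVBL_AZURE_PCAP_DIR must be present in
# the environment before AzurePusher is constructed; the constructor reads
# them with environ[...], so missing ones raise KeyError.
# OBSRVBL_AZURE_VIRTUAL_MACHINES may be an empty string. The values below are
# illustrative placeholders only, not real settings:
#   OBSRVBL_AZURE_TENANT_ID=00000000-0000-0000-0000-000000000000
#   OBSRVBL_AZURE_APPLICATION_ID=11111111-1111-1111-1111-111111111111
#   OBSRVBL_AZURE_APPLICATION_SECRET=<service principal secret>
#   OBSRVBL_AZURE_SUBSCRIPTION_ID=22222222-2222-2222-2222-222222222222
#   OBSRVBL_AZURE_RESOURCE_GROUP=example-resource-group
#   OBSRVBL_AZURE_NETWORK_WATCHER=example-network-watcher
#   OBSRVBL_AZURE_STORAGE_ACCOUNT=examplestorage
#   OBSRVBL_AZURE_STORAGE_CONTAINER=packet-captures
#   OBSRVBL_AZURE_STORAGE_TOKEN=<SAS token for the storage account>
#   OBSRVBL_AZURE_VIRTUAL_MACHINES=vm-01,vm-02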
# module variables
AZURE_AUTHENTICATION_ENDPOINT = 'https://login.microsoftonline.com/'
AZURE_MANAGEMENT_RESOURCE = 'https://management.core.windows.net/'
CAPTURE_SNAPLEN = 64
CAPTURE_TIMEOUT_SECONDS = 600
CAPTURE_BYTE_LIMIT = 4294967295
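# The 64-byte snaplen appears intended to keep only the IP/TCP/UDP headers
# that the tshark field extraction below needs; 4294967295 (2**32 - 1) serves
# as an effectively unlimited per-session byte count, and each capture session
# is limited to 600 seconds before Network Watcher stops it.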
CSV_HEADER = [
    'srcaddr',
    'dstaddr',
    'srcport',
    'dstport',
    'protocol',
    'bytes_in',
    'bytes_out',
    'start',
    'end',
]
TSHARK_ARGS = [
    'tshark',
    '-r', None,
    '-T', 'fields',
    '-e', 'ip.src',
    '-e', 'ip.dst',
    '-e', 'tcp.srcport',
    '-e', 'tcp.dstport',
    '-e', 'udp.srcport',
    '-e', 'udp.dstport',
    '-e', 'ip.proto',
    '-e', 'frame.len',
    '-e', 'frame.time_epoch',
    'ip',
]
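# With the field list above, tshark writes one tab-separated record per IP
# packet, with the columns in the same order as the -e options. A hypothetical
# TCP packet would look roughly like this (tab-separated, UDP port columns
# empty; UDP packets fill those instead and leave the TCP columns empty):
#   10.0.0.1  10.0.0.2  52100  443  <empty>  <empty>  6  60  1487707200.123456
# _convert_txt() below parses lines of exactly this shape.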
POLL_SECONDS = 60
class AzurePusher(Pusher):
    def __init__(self, *args, **kwargs):
        """Read Azure settings from the environment and set Pusher defaults."""
        input_dir = environ.get(ENV_AZURE_PCAP_DIR, DEFAULT_AZURE_PCAP_DIR)
        for key, default in (
            ('file_fmt', '%Y%m%d%H%M%S'),
            ('prefix_len', 14),
            ('data_type', 'csv'),
            ('input_dir', input_dir),
            ('poll_seconds', POLL_SECONDS),
        ):
            kwargs.setdefault(key, default)

        self.tar_mode = 'w'
        self.tenant_id = environ[ENV_AZURE_TENANT_ID]
        self.application_id = environ[ENV_AZURE_APPLICATION_ID]
        self.application_secret = environ[ENV_AZURE_APPLICATION_SECRET]
        self.subscription_id = environ[ENV_AZURE_SUBSCRIPTION_ID]
        self.resource_group = environ[ENV_AZURE_RESOURCE_GROUP]
        self.network_watcher = environ[ENV_AZURE_NETWORK_WATCHER]
        self.storage_account = environ[ENV_AZURE_STORAGE_ACCOUNT]
        self.storage_container = environ[ENV_AZURE_STORAGE_CONTAINER]
        self.storage_token = environ[ENV_AZURE_STORAGE_TOKEN]

        vms = environ[ENV_AZURE_VIRTUAL_MACHINES]
        if vms:
            self.virtual_machines = [x.strip() for x in vms.split(',')]
        else:
            self.virtual_machines = []

        self.access_token = None
        self.access_expires = datetime.min
        self.to_delete = []

        super(AzurePusher, self).__init__(*args, **kwargs)

    def _set_access_token(self):
        """Fetch an OAuth2 client-credentials token for the management API."""
        token_url = (
            AZURE_AUTHENTICATION_ENDPOINT +
            '{}/oauth2/token?api-version=1.0'.format(self.tenant_id)
        )
        token_data = {
            'grant_type': 'client_credentials',
            'resource': AZURE_MANAGEMENT_RESOURCE,
            'client_id': self.application_id,
            'client_secret': self.application_secret,
        }
        token_response = requests.post(url=token_url, data=token_data)
        token_data = token_response.json()
        self.access_token = token_data['access_token']
        self.access_expires = datetime.utcfromtimestamp(
            float(token_data['expires_on'])
        )

    def _get_headers(self, **kwargs):
        headers = {'Authorization': 'Bearer ' + self.access_token}
        headers.update(kwargs)
        return headers

    def _list_captures(self):
        """List the packet captures known to the Network Watcher."""
        url = (
            'https://management.azure.com/subscriptions/{subscriptionId}/'
            'resourceGroups/{resourceGroup}/'
            'providers/Microsoft.Network/networkWatchers/{networkWatcher}/'
            'packetCaptures?api-version=2016-03-30'
        ).format(
            subscriptionId=self.subscription_id,
            resourceGroup=self.resource_group,
            networkWatcher=self.network_watcher,
        )
        response = requests.get(url=url, headers=self._get_headers())
        return response

    def _get_capture_status(self, capture_name):
        """Query the status of a single packet capture."""
        url = (
            'https://management.azure.com/subscriptions/{subscriptionId}/'
            'resourceGroups/{resourceGroup}/'
            'providers/Microsoft.Network/networkWatchers/{networkWatcher}/'
            'packetCaptures/{packetCapture}/querystatus?api-version=2016-03-30'
        ).format(
            subscriptionId=self.subscription_id,
            resourceGroup=self.resource_group,
            networkWatcher=self.network_watcher,
            packetCapture=capture_name,
        )
        response = requests.post(url, headers=self._get_headers())
        return response

    def _delete_capture(self, capture_name):
        """Delete a packet capture resource from the Network Watcher."""
        url = (
            'https://management.azure.com/subscriptions/{subscriptionId}/'
            'resourceGroups/{resourceGroup}/'
            'providers/Microsoft.Network/networkWatchers/{networkWatcher}/'
            'packetCaptures/{packetCapture}?api-version=2016-03-30'
        ).format(
            subscriptionId=self.subscription_id,
            resourceGroup=self.resource_group,
            networkWatcher=self.network_watcher,
            packetCapture=capture_name,
        )
        response = requests.delete(url, headers=self._get_headers())
        return response

    def _start_capture(self, vm_name, now):
        """Start a new packet capture session for the given virtual machine."""
        dt_str = now.strftime('%Y%m%d%H%M%S')
        params = {
            'subscriptionId': self.subscription_id,
            'resourceGroup': self.resource_group,
            'networkWatcher': self.network_watcher,
            'packetCapture': 'obsrvbl_{}_{}'.format(vm_name, dt_str),
            'vmName': vm_name,
            'storageAccount': self.storage_account,
            'storageContainer': self.storage_container,
        }
        url = (
            'https://management.azure.com/subscriptions/{subscriptionId}/'
            'resourceGroups/{resourceGroup}/'
            'providers/Microsoft.Network/networkWatchers/{networkWatcher}/'
            'packetCaptures/{packetCapture}?api-version=2016-03-30'
        ).format(**params)
        data = {
            'properties': {
                'target': (
                    '/subscriptions/{subscriptionId}/'
                    'resourceGroups/{resourceGroup}/'
                    'providers/Microsoft.Compute/virtualMachines/{vmName}'
                ).format(**params),
                'bytesToCapturePerPacket': CAPTURE_SNAPLEN,
                'totalBytesPerSession': CAPTURE_BYTE_LIMIT,
                'timeLimitInSeconds': CAPTURE_TIMEOUT_SECONDS,
                'storageLocation': {
                    'storageId': (
                        '/subscriptions/{subscriptionId}/'
                        'resourceGroups/{resourceGroup}/'
                        'providers/Microsoft.Storage/'
                        'storageAccounts/{storageAccount}'
                    ).format(**params),
                    'storagePath': (
                        'https://{storageAccount}.blob.core.windows.net/'
                        '{storageContainer}/{packetCapture}.cap'
                    ).format(**params),
                },
            }
        }
        response = requests.put(url, json=data, headers=self._get_headers())
        return response

    def _move_captures(self):
        """Download finished capture blobs to the input directory, then delete them."""
        create_dirs(self.input_dir)
        service = BlockBlobService(
            self.storage_account, sas_token=self.storage_token
        )
        retrieved_blobs = []
        for blob in service.list_blobs(self.storage_container):
            blob_name = blob.name
            file_path = join(self.input_dir, blob_name)
            logging.info('Downloading %s to %s', blob_name, file_path)
            service.get_blob_to_path(
                self.storage_container, blob_name, file_path
            )
            retrieved_blobs.append(blob_name)
        # Blobs are removed only after every download has succeeded
        for blob_name in retrieved_blobs:
            logging.info('Deleting %s', blob_name)
            service.delete_blob(self.storage_container, blob_name)

    def _convert_pcap(self, infile_path, outfile_path):
        """Run tshark over a .cap file, writing tab-separated fields to outfile_path."""
        tshark_args = TSHARK_ARGS[:]
        tshark_args[2] = infile_path  # fill in the '-r' input file argument
        with open(outfile_path, 'wb') as outfile:
            check_call(tshark_args, stdout=outfile)

    def _convert_txt(self, txt_path):
        """Yield one CSV row dict per tab-separated line of tshark output."""
        with open(txt_path, 'r') as infile:
            for line in infile:
                fields = line.strip().split('\t')
                timestamp = str(int(float(fields[8])))
                yield {
                    'srcaddr': fields[0],
                    'dstaddr': fields[1],
                    'srcport': fields[2] or fields[4],
                    'dstport': fields[3] or fields[5],
                    'protocol': fields[6],
                    'bytes_in': '0',
                    'bytes_out': fields[7],
                    'start': timestamp,
                    'end': timestamp,
                }

    def _process_captures(self, now):
        """Convert downloaded .cap files into a single gzipped CSV for pushing."""
        # Turn each .cap file into a .txt file
        cap_files = sorted(iglob(join(self.input_dir, '*.cap')))
        txt_files = []
        for cap_path in cap_files:
            logging.info('Processing %s', cap_path)
            cap_name = splitext(basename(cap_path))[0]
            txt_path = join(self.input_dir, cap_name + '.txt')
            try:
                self._convert_pcap(cap_path, txt_path)
            finally:
                self._remove_file(cap_path)
            txt_files.append(txt_path)
        if not txt_files:
            return

        # Combine the .txt files into a .csv.gz file
        csv_name = '{}.csv.gz'.format(now.strftime(self.file_fmt))
        csv_path = join(self.input_dir, csv_name)
        with gz_open(csv_path, 'wt') as outfile:
            writer = DictWriter(outfile, CSV_HEADER)
            writer.writeheader()
            for txt_path in txt_files:
                writer.writerows(self._convert_txt(txt_path))
                self._remove_file(txt_path)

    def execute(self, now=None):
        now = now or utcnow()
        # Retrieve an access token if ours is near or past expiration
        if now + timedelta(seconds=POLL_SECONDS) >= self.access_expires:
            logging.info('Setting access token')
            self._set_access_token()

        # List captures
        capture_map = {}
        for item in self._list_captures().json()['value']:
            logging.info('Checking capture %s', item.get('name'))
            capture_name = item['name']
            if not capture_name.startswith('obsrvbl_'):
                continue
            # Capture names are obsrvbl_<vm_name>_<timestamp>; this assumes
            # VM names do not themselves contain underscores.
            prefix, vm_name, dt_str = capture_name.split('_')
            capture_map[vm_name] = capture_name
            logging.info('Tracking capture %s', capture_name)

        # Delete previously-completed captures
        while self.to_delete:
            capture_name = self.to_delete.pop()
            logging.info('Deleting capture %s', capture_name)
            self._delete_capture(capture_name)

        # Check for completed captures
        for vm_name, capture_name in capture_map.items():
            status_response = self._get_capture_status(capture_name)
            if status_response.status_code not in [200, 202]:
                logging.info(
                    'Could not retrieve status for %s', capture_name
                )
                continue
            status_data = status_response.json()
            capture_status = status_data.get('packetCaptureStatus')
            logging.info('%s is in state: %s', capture_name, capture_status)
            if capture_status in ['Error', 'Stopped']:
                self.to_delete.append(capture_name)

        # Start captures for everything that was clear
        start_captures = sorted(set(self.virtual_machines) - set(capture_map))
        for vm_name in start_captures:
            logging.info('Starting capture for %s', vm_name)
            self._start_capture(vm_name, now)

        # Move captures to input dir
        self._move_captures()
        # Process captures with TShark
        self._process_captures(now)
        # Push the files to the web proxy
        super(AzurePusher, self).execute(now=now)

if __name__ == '__main__':
    pusher = AzurePusher()
    pusher.run()
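# Example invocation, assuming this file is saved as azure_pusher.py (the
# filename is illustrative), tshark is on the PATH, and the OBSRVBL_AZURE_*
# variables listed near the top are exported; the polling loop itself is
# inherited from Pusher.run():
#   $ python azure_pusher.py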