Created
October 27, 2020 16:44
-
-
Save dolohow/aa0b45b0137aa73eef0b1d26d9a5e795 to your computer and use it in GitHub Desktop.
wp_merge
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
This module combines input from CSV file with API endpoint data and outputs | |
result to CSV file. | |
Example: | |
Usage: | |
$ wpe_merge input_csv_file output_csv_file | |
To show help use: | |
$ wpe_merge --help | |
""" | |
import argparse | |
import csv | |
import logging | |
import sys | |
from typing import Tuple, Dict, Generator, cast, TextIO | |
import requests | |
def take_from_csv(names: Tuple[str, ...], line: Dict[str, str]) -> Dict[str, str]:
    """Pick a subset of columns out of a single CSV row.

    Args:
        names: Keys to copy out of `line`, in the order they should appear
            in the result.
        line: Row dictionary the values are read from.

    Returns:
        New dictionary holding only the requested keys.

    Raises:
        KeyError: If any entry of `names` is absent from `line`.
    """
    extracted: Dict[str, str] = {}
    for column in names:
        extracted[column] = line[column]
    return extracted
class CSV(csv.DictReader):
    """CSV reader that performs a sanity check on the number of elements.

    Rows that carry fewer values than the header has columns are silently
    skipped while iterating.
    """

    def __init__(self, f: TextIO, required: Tuple[str, ...] = ()):
        """
        Args:
            f: Open text file (or file-like object) containing CSV data.
            required: Column names `validate` expects to find in the header.
        """
        super().__init__(f)
        self.f = f
        self.required = required

    def __next__(self) -> Dict[str, str]:
        """Returns the next complete row.

        Raises:
            StopIteration: When the underlying file is exhausted.
        """
        # DictReader pads short rows with None (its default `restval`), so a
        # None value marks a row with too few fields. Loop instead of
        # recursing so a long run of short rows cannot hit the recursion limit.
        while True:
            elem = super().__next__()
            if None not in elem.values():
                return elem

    def validate(self) -> None:
        """Validates if CSV contains required fields

        Raises:
            KeyError: If key not found in header of CSV file
        """
        # The `fieldnames` property reads the header lazily and only once, so
        # validating twice never consumes a data row, and an empty file yields
        # None here instead of leaking StopIteration to the caller.
        header = self.fieldnames or ()
        missing = [name for name in self.required if name not in header]
        if missing:
            raise KeyError("Missing required keys in CSV: " + ", ".join(missing))
class API:
    """Performs operations on API."""

    def __init__(self, base_url: str):
        """
        Args:
            base_url: Base url for API endpoint
        """
        self.base_url = base_url
        self.api_url = "/v1/accounts/{account_id}"

    def _account_url(self, account_id: int) -> str:
        """Builds the full URL of the account resource."""
        # `api_url` already starts with "/", so drop any trailing slash from
        # the base to avoid producing "//v1/..." for bases like "http://host/".
        return self.base_url.rstrip("/") + self.api_url.format(account_id=account_id)

    def get_account(self, account_id: int) -> Dict[str, str]:
        """Returns account details.

        Args:
            account_id: ID of account

        Returns:
            Decoded JSON body of the response.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # NOTE(review): no timeout is passed, so a stalled server blocks the
        # call indefinitely - consider adding one.
        response = requests.get(self._account_url(account_id))
        response.raise_for_status()
        return cast(Dict[str, str], response.json())
class Merger:
    """Merges CSV file input with data from API."""

    def __init__(self, api: "API", data: csv.DictReader):
        """
        Args:
            api: Client used to look up account details.
            data: Iterable of CSV rows; each row is read by column name.
        """
        self.api = api
        self.data = data

    def run(self) -> Generator[Dict[str, str], None, None]:
        """Runs merging operation.

        Rows are skipped (with a warning) when 'Account ID' is not an
        integer, when the API request fails with an HTTP error, or when the
        API response lacks 'status' or 'created_on'.

        Yields:
            Dictionary with merged data
        """
        for line in self.data:
            try:
                account_id = int(line["Account ID"])
            except ValueError:
                logging.warning(
                    "Could not convert 'Account ID' to int for account_id='%s'",
                    line["Account ID"],
                )
                continue

            try:
                api_data = self.api.get_account(account_id)
            except requests.HTTPError:
                logging.warning(
                    "Could not fetch account status for account_id='%s'", account_id
                )
                continue

            # Both keys are read below, so both must be checked. The previous
            # `("status" or "created_on") not in api_data` only tested
            # "status", because `or` returns its first truthy operand.
            if "status" not in api_data or "created_on" not in api_data:
                logging.warning(
                    "Malformed account data for account_id='%s'", account_id
                )
                continue

            yield {
                **take_from_csv(("Account ID", "First Name", "Created On"), line),
                "Status": api_data["status"],
                "Status Set On": api_data["created_on"],
            }
if __name__ == "__main__":
    # Command-line entry point: read the input CSV, enrich every row with the
    # account status fetched from the API and write the result as CSV.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "input_file", type=argparse.FileType("r"), help="Input CSV file"
    )
    parser.add_argument(
        "output_file", type=argparse.FileType("w"), help="Output CSV file"
    )
    args = parser.parse_args()

    # argparse.FileType hands back files that are already open; the `with`
    # block closes (and flushes) them deterministically, including on the
    # early exit below.
    with args.input_file, args.output_file:
        reader = CSV(
            args.input_file, required=("Account ID", "First Name", "Created On")
        )
        try:
            reader.validate()
        except KeyError as error:
            # `logging.critical` is the documented name; `fatal` is only a
            # legacy alias for it.
            logging.critical(error)
            sys.exit(1)

        writer = csv.DictWriter(
            args.output_file,
            fieldnames=(
                "Account ID",
                "First Name",
                "Created On",
                "Status",
                "Status Set On",
            ),
        )
        merger = Merger(API("http://interview.wpengine.io"), reader)
        writer.writeheader()
        for line_to_write in merger.run():
            writer.writerow(line_to_write)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import unittest | |
import logging | |
import io | |
from unittest.mock import MagicMock, patch | |
import requests | |
from main import API, CSV, Merger | |
logging.disable(logging.CRITICAL) | |
class TestCSV(unittest.TestCase):
    """Unit tests for the CSV reader wrapper."""

    def test_csv(self):
        """A well-formed row comes back as a header-keyed dictionary."""
        reader = CSV(io.StringIO("a,b\n1,2"))
        first_row = next(reader)
        self.assertDictEqual({"a": "1", "b": "2"}, first_row)

    def test_csv_skip_none(self):
        """A row with too few values is skipped in favour of the next one."""
        reader = CSV(io.StringIO("a,b\n1\n1,2"))
        first_row = next(reader)
        self.assertDictEqual({"a": "1", "b": "2"}, first_row)

    def test_csv_validate_ok(self):
        """Validation passes and iteration still starts at the first data row."""
        reader = CSV(io.StringIO("a,b\n1,2"), required=("a",))
        reader.validate()
        first_row = next(reader)
        self.assertDictEqual({"a": "1", "b": "2"}, first_row)

    def test_csv_validate_missing(self):
        """Validation raises KeyError when a required column is absent."""
        reader = CSV(io.StringIO("b,c\n1,2"), required=("a",))
        with self.assertRaises(KeyError):
            reader.validate()
class TestAPI(unittest.TestCase):
    """Unit tests for the API client; `requests.get` is patched out."""

    @patch("main.requests.get")
    def test_pull_account(self, mock_request):
        """get_account hits the account URL, checks status, returns the JSON."""
        payload = {"status": "good", "created_on": "2011-01-12"}
        mock_request.return_value.json.return_value = payload

        api = API("http://fake.url")
        result = api.get_account(1)

        mock_request.assert_called_with("http://fake.url/v1/accounts/1")
        # The URL alone does not prove error statuses are surfaced or that
        # the decoded body is what the caller receives.
        mock_request.return_value.raise_for_status.assert_called_once_with()
        self.assertEqual(payload, result)
class TestMerger(unittest.TestCase):
    """Unit tests for Merger, driven by in-memory rows and a mocked API."""

    def setUp(self):
        # Two CSV rows and, in matching order, the API payloads for them.
        self.data = (
            {
                "Account ID": "12345",
                "Account Name": "lexcorp",
                "First Name": "Lex",
                "Created On": "2011-01-12",
            },
            {
                "Account ID": "8171",
                "Account Name": "latveriaembassy",
                "First Name": "Victor",
                "Created On": "2014-11-19",
            },
        )
        self.api_data = (
            {
                "status": "good",
                "created_on": "2011-01-12",
            },
            {
                "status": "closed",
                "created_on": "2015-09-01",
            },
        )

    def test_run(self):
        """Every row is merged with the API payload fetched for it."""
        api_mock = MagicMock()
        merger = Merger(api_mock, self.data)
        api_mock.get_account.side_effect = self.api_data
        result = tuple(merger.run())
        # Without this check the loop below would pass vacuously if run()
        # yielded nothing at all.
        self.assertEqual(len(result), len(self.data))
        for n, line in enumerate(result):
            self.assertDictEqual(
                {
                    "Account ID": self.data[n]["Account ID"],
                    "First Name": self.data[n]["First Name"],
                    "Created On": self.data[n]["Created On"],
                    "Status": self.api_data[n]["status"],
                    "Status Set On": self.api_data[n]["created_on"],
                },
                line,
            )

    def test_run_invalid_account_id(self):
        """A row whose 'Account ID' is not an integer is skipped."""
        self.data[0]["Account ID"] = "a"
        api_mock = MagicMock()
        merger = Merger(api_mock, self.data)
        api_mock.get_account.side_effect = self.api_data
        result = tuple(merger.run())
        self.assertEqual(len(result), 1)
        # The surviving record must be the valid second row.
        self.assertEqual(result[0]["Account ID"], "8171")

    def test_run_http_error(self):
        """A row whose API lookup raises HTTPError is skipped."""
        api_mock = MagicMock()
        merger = Merger(api_mock, self.data)
        api_mock.get_account.side_effect = (self.api_data[0], requests.HTTPError)
        result = tuple(merger.run())
        self.assertEqual(len(result), 1)
        # Only the first lookup succeeds, so only the first row is merged.
        self.assertEqual(result[0]["Account ID"], "12345")

    def test_run_malformed_api_data(self):
        """A row whose API payload lacks 'status' is skipped."""
        api_mock = MagicMock()
        merger = Merger(api_mock, self.data)
        self.api_data[0].pop("status")
        api_mock.get_account.side_effect = self.api_data
        result = tuple(merger.run())
        self.assertEqual(len(result), 1)
        # The first payload is malformed, so only the second row is merged.
        self.assertEqual(result[0]["Account ID"], "8171")
# Run the whole suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment