Created
November 10, 2018 19:34
-
-
Save nrweir/cee3f63bb48498484e7134c731752c42 to your computer and use it in GitHub Desktop.
Script to produce a predictions CSV file from GeoJSON building-label files.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import spacenetutilities.labeltools.coreLabelTools as cLT | |
import os | |
import argparse | |
import re | |
# Command-line interface for the prediction-CSV generator.
argparser = argparse.ArgumentParser()
argparser.add_argument(
    '--geojson_src_dir', '-j', type=str, required=True,
    help=('Path to the directory containing geojsons. If '
          'the referenced directory contains subdirectories '
          'that must be searched, the `--recursive` argument '
          'must be passed.'))
argparser.add_argument(
    '--geotiff_src_dir', '-t', type=str, required=True,
    help=('Path to a directory containing geotiffs for '
          'affine transformations. Does not need to have'
          'geotiffs for all angles, but must have all chips.'))
argparser.add_argument(
    '--output_path', '-o', default='predictions.csv',
    help=('Path to save predictions to. Defaults to '
          'predictions.csv in the working directory.'))
# store_true is equivalent to store_const(const=True, default=False).
argparser.add_argument(
    '--recursive', '-r', action='store_true',
    help=('Should subdirectories within `geojson_src_dir`'
          ' be searched? Defaults to False. If this flag is '
          'passed, ALL .json or .geojson files in ALL '
          'subdirectories contained within `geojson_src_dir` '
          'will be included.'))
args = argparser.parse_args()
def main(geojson_src_dir, geotiff_src_dir,
         output_path='predictions.csv', recursive=False):
    """Generate a predictions.csv file.

    Arguments
    ---------
    geojson_src_dir : str
        Directory containing .json/.geojson label files.
    geotiff_src_dir : str
        Directory containing .tif chips used for affine transformations.
    output_path : str, optional
        Destination path for the CSV. Defaults to 'predictions.csv'.
    recursive : bool, optional
        If True, search subdirectories of `geojson_src_dir` as well.

    Raises
    ------
    ValueError
        If no geotiff in `geotiff_src_dir` matches a geojson's chip ID.
    """
    # Report the arguments actually passed to this call rather than the
    # module-level CLI namespace `args`, which does not exist when this
    # module is imported and main() is called programmatically.
    print('Arguments: ')
    print({'geojson_src_dir': geojson_src_dir,
           'geotiff_src_dir': geotiff_src_dir,
           'output_path': output_path,
           'recursive': recursive})
    print()
    json_paths = get_files_recursively(
        geojson_src_dir, traverse_subdirs=recursive,
        file_ext='json'  # endswith match, so .geojson is included too
    )
    geotiffs = [f for f in os.listdir(geotiff_src_dir) if f.endswith('.tif')]
    chip_summary_list = []
    for json_path in json_paths:
        fname = os.path.split(json_path)[1]
        collect_id, chip_id = get_chip_and_collect_ids(fname)
        # Fail with an explanatory message instead of a bare IndexError
        # when no geotiff chip matches this geojson.
        matches = [g for g in geotiffs if chip_id in g]
        if not matches:
            raise ValueError(
                'No geotiff in {} matches chip ID {}'.format(
                    geotiff_src_dir, chip_id))
        chip_summary = {
            'chipName': matches[0],
            'geoVectorName': json_path,
            'imageId': '_'.join([collect_id, chip_id])
        }
        chip_summary_list.append(chip_summary)
    cLT.createCSVSummaryFile(chip_summary_list, output_path,
                             rasterChipDirectory=geotiff_src_dir,
                             createProposalsFile=True,
                             competitionType='buildings',
                             pixPrecision=2)
def get_chip_and_collect_ids(fname):
    """Extract the collect ID and chip ID from a geojson filename.

    Returns a ``(collect_id, chip_id)`` tuple. Raises ``IndexError`` when
    either pattern cannot be found in ``fname``.
    """
    # Guard-clause style: search for each pattern and bail out early
    # if it is absent, preserving the original check order (chip first).
    chip_match = re.search(r'[0-9]{6}_[0-9]{7}', fname)
    if chip_match is None:
        raise IndexError(
            'There is no chip ID within filename {}'.format(fname))
    collect_match = re.search(r'Atlanta_nadir[0-9]{1,2}_catid_[0-9A-Z]{16}',
                              fname)
    if collect_match is None:
        raise IndexError(
            'There is no collect ID within filename {}'.format(fname))
    return (collect_match.group(), chip_match.group())
def get_files_recursively(image_path, traverse_subdirs=False, file_ext='.tif'):
    """Collect paths to files under ``image_path`` ending with ``file_ext``.

    When ``traverse_subdirs`` is True the entire directory tree below
    ``image_path`` is walked; otherwise only the top-level directory is
    listed. Returned paths are joined onto their containing directory.
    """
    if not traverse_subdirs:
        return [os.path.join(image_path, entry)
                for entry in os.listdir(image_path)
                if entry.endswith(file_ext)]
    matched = []
    for dirpath, _, filenames in os.walk(image_path):
        matched.extend(os.path.join(dirpath, name)
                       for name in filenames
                       if name.endswith(file_ext))
    return matched
if __name__ == '__main__':
    # Forward the parsed CLI namespace to main() explicitly, field by field.
    main(args.geojson_src_dir, args.geotiff_src_dir,
         output_path=args.output_path, recursive=args.recursive)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment