Skip to content

Instantly share code, notes, and snippets.

@joehand
Last active November 22, 2022 02:39
Show Gist options
  • Save joehand/498a1656e028c6163aa9 to your computer and use it in GitHub Desktop.
Save joehand/498a1656e028c6163aa9 to your computer and use it in GitHub Desktop.
Filter Shapefiles with Python
"""
Shapefile Filters
~~~~~~~~~~~~~~~~~
:copyright: (c) 2015 by Joe Hand, Santa Fe Institute.
:license: MIT
"""
import logging
import ntpath
import os
import re
import unicodedata

from osgeo import ogr
class ShapeFilter(object):
    """Split a single shapefile into one shapefile per value of a field.

    For every distinct value found in ``filter_field`` a new shapefile
    holding only the matching features is written to ``out_dir`` (resolved
    relative to the directory of the source shapefile).

    Usage:
        shape_filter = ShapeFilter('my_shapefile.shp', 'some_field')
        new_shapefiles = shape_filter.create_all_shapefiles()
    """

    # Compiled once; raw strings avoid invalid-escape warnings for \w / \s.
    _NON_SLUG_CHARS = re.compile(r'[^\w\s-]')
    _SLUG_SEPARATORS = re.compile(r'[-\s]+')

    def __init__(self, shapefile, filter_field, out_dir='tmp'):
        """Open the source shapefile and prepare the output directory.

        :param shapefile: path of the source shapefile.
        :param filter_field: name of the attribute field to split on.
        :param out_dir: output directory, joined onto the directory of
            ``shapefile``; created if it does not exist.
        :raises IOError: if OGR cannot open ``shapefile``.
        """
        super(ShapeFilter, self).__init__()
        self.shapefile = shapefile
        self.field = filter_field
        self.input_ds = ogr.Open('{}'.format(shapefile))
        if self.input_ds is None:
            # ogr.Open returns None on failure when GDAL exceptions are
            # disabled; fail here instead of with an AttributeError later.
            raise IOError('Could not open shapefile: {}'.format(shapefile))
        self.filename = self._get_filename()
        self.out_dir = self._get_create_out_dir(out_dir)

    def _get_create_out_dir(self, out_dir):
        """Return the output directory path, creating it if it is missing.

        ``out_dir`` is joined onto the directory containing the source
        shapefile.
        """
        path = os.path.dirname(self.shapefile)
        out_dir = os.path.join(path, out_dir)
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        return out_dir

    def _get_filename(self):
        """Return the source shapefile's base name without its extension."""
        # ntpath.basename splits on both '/' and '\\' separators.
        return os.path.splitext(ntpath.basename(self.shapefile))[0]

    def _slugify(self, value):
        """Return a lowercase, ASCII-only, hyphen-separated slug of ``value``.

        Adapted from Django: characters that are not alphanumerics,
        underscores, whitespace or hyphens are removed, leading/trailing
        whitespace is stripped, and runs of whitespace/hyphens collapse to
        a single hyphen.
        """
        value = unicodedata.normalize('NFKD', value).encode(
            'ascii', 'ignore').decode('ascii')
        value = self._NON_SLUG_CHARS.sub('', value).strip().lower()
        return self._SLUG_SEPARATORS.sub('-', value)

    def _build_query(self, value):
        """Return the attribute-filter expression selecting ``value``.

        The field name is double-quoted (identifier). String values are
        single-quoted literals with embedded single quotes doubled; ``None``
        maps to ``IS NULL``; any other value is inserted unquoted.
        """
        field = '"{}"'.format(self.field)
        if value is None:
            return '{} IS NULL'.format(field)
        if isinstance(value, str):
            return "{} = '{}'".format(field, value.replace("'", "''"))
        return '{} = {}'.format(field, value)

    def _create_filtered_shapefile(self, value):
        """Write the features matching ``value`` to a new shapefile.

        Any existing shapefile at the target path is deleted first.

        :param value: field value to filter the source layer by.
        :returns: path of the shapefile that was written.
        :raises ValueError: if OGR rejects the attribute filter.
        :raises IOError: if the output data source cannot be created.
        """
        input_layer = self.input_ds.GetLayer()
        query_str = self._build_query(value)
        # A rejected filter would silently copy every feature, so check the
        # OGR error code (0 means success; None is tolerated defensively).
        if input_layer.SetAttributeFilter(query_str) not in (0, None):
            raise ValueError('Invalid attribute filter: {}'.format(query_str))
        try:
            driver = ogr.GetDriverByName('ESRI Shapefile')
            out_shapefile = self._value_to_fname_path(value)
            # Remove output shapefile if it already exists
            if os.path.exists(out_shapefile):
                driver.DeleteDataSource(out_shapefile)
            out_ds = driver.CreateDataSource(out_shapefile)
            if out_ds is None:
                raise IOError(
                    'Could not create shapefile: {}'.format(out_shapefile))
            out_layer = out_ds.CopyLayer(input_layer, str(value))
            # Drop the references so OGR flushes and closes the output.
            del out_layer, out_ds
        finally:
            # The layer belongs to the shared input data source; do not
            # leave this value's filter applied to it.
            input_layer.SetAttributeFilter(None)
        return out_shapefile

    def _get_unique_values(self):
        """Return the distinct values of the filter field in the source.

        :raises ValueError: if the query yields no result set.
        """
        # Quote both identifiers so names with spaces or hyphens still parse.
        sql = 'SELECT DISTINCT "{}" FROM "{}"'.format(
            self.field, self.filename)
        layer = self.input_ds.ExecuteSQL(sql)
        if layer is None:
            raise ValueError('Query returned no result set: {}'.format(sql))
        try:
            return [feature.GetField(0) for feature in layer]
        finally:
            # Result sets from ExecuteSQL must be released explicitly.
            self.input_ds.ReleaseResultSet(layer)

    def _value_to_fname_path(self, value):
        """Return the output shapefile path for a field value.

        For string values only the text before the first hyphen is used
        (a hack to make US city names prettier), so distinct values that
        share that prefix map to the same file. Non-string values are
        converted with ``str`` and slugified whole.
        """
        if isinstance(value, str):
            value = value.split('-')[0]
        else:
            value = str(value)
        fname = "{}.shp".format(self._slugify(value))
        return os.path.join(self.out_dir, fname)

    def _shapefile_exists(self, value):
        """Return True if the output shapefile for ``value`` already exists.

        Takes the field value, not a file name.
        """
        return os.path.isfile(self._value_to_fname_path(value))

    def create_all_shapefiles(self, overwrite=False):
        """Create one shapefile per distinct value of the filter field.

        :param overwrite: when False, values whose output file already
            exists are skipped; their existing paths are still returned.
        :returns: list of output shapefile paths, one per distinct value.
        """
        shapefiles = []
        values = self._get_unique_values()
        logging.info('Creating %s Shapefiles', len(values))
        for val in values:
            # TODO: make this multiprocess also, too slow for big filters
            if overwrite or not self._shapefile_exists(val):
                out_file = self._create_filtered_shapefile(val)
                logging.debug('Shapefile created: %s', val)
            else:
                out_file = self._value_to_fname_path(val)
                logging.debug('Shapefile exists, skipped: %s', val)
            shapefiles.append(out_file)
        return shapefiles
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment