Skip to content

Instantly share code, notes, and snippets.

@lossyrob
Last active August 25, 2019 15:57
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lossyrob/46062f74d38afc5eb6bcd7dd4bb62c3e to your computer and use it in GitHub Desktop.
Save lossyrob/46062f74d38afc5eb6bcd7dd4bb62c3e to your computer and use it in GitHub Desktop.
STAC Python model strawdog
"""STAC Model classes.
"""
import os
import json
from copy import copy, deepcopy
# STAC specification version stamped into every serialized Catalog/Collection.
STAC_VERSION = "0.7.0"
class STAC_IO:
    """Methods used to read and save STAC json.

    Allows users of the library to set their own methods
    (e.g. for reading and writing from cloud storage) by assigning new
    callables to ``read_text_method`` (``uri -> str``) and
    ``write_text_method`` (``(uri, txt) -> None``).
    """

    @staticmethod
    def default_read_text_method(uri):
        """Read ``uri`` as a local text file decoded as UTF-8."""
        # JSON exchanged between systems must be UTF-8 (RFC 8259); do not
        # depend on the platform's locale encoding.
        with open(uri, encoding='utf-8') as f:
            return f.read()

    @staticmethod
    def default_write_text_method(uri, txt):
        """Write ``txt`` to the local file ``uri`` encoded as UTF-8."""
        with open(uri, 'w', encoding='utf-8') as f:
            f.write(txt)

    # Pluggable hooks; accessed through the class, so a plain function
    # assigned here by a user is called without a bound ``self``.
    read_text_method = default_read_text_method
    write_text_method = default_write_text_method

    @classmethod
    def read_text(cls, uri):
        """Return the text at ``uri`` using the configured read method."""
        return cls.read_text_method(uri)

    @classmethod
    def write_text(cls, uri, txt):
        """Write ``txt`` to ``uri`` using the configured write method."""
        cls.write_text_method(uri, txt)
def read_stac_json(uri, root=None, parent=None):
    """Read the STAC JSON at ``uri`` and build the matching model object.

    Dispatch is on the keys present:
      * ``type``   -> Item (Item.to_dict always writes ``'type': 'Feature'``)
      * ``extent`` -> Collection (Collection.to_dict always writes ``extent``)
      * otherwise  -> Catalog

    Args:
        uri: Location passed to ``STAC_IO.read_text``.
        root: Accepted for callers such as ``Link.resolve_stac_object``;
            currently unused.
        parent: Accepted for the same callers; currently unused.

    Returns:
        An Item, Collection or Catalog.
    """
    d = json.loads(STAC_IO.read_text(uri))
    if 'type' in d:
        return Item.from_dict(d)
    # Collections serialize an 'extent', not a top-level 'bbox'; checking
    # 'bbox' here made every saved Collection load back as a plain Catalog.
    if 'extent' in d:
        return Collection.from_dict(d)
    return Catalog.from_dict(d)
def save_json(uri, json_dict):
    """Serialize ``json_dict`` (indent=4) and write it to ``uri`` via STAC_IO.

    The parent directory of ``uri`` is created locally if it does not exist.

    Args:
        uri: Destination path.
        json_dict: JSON-serializable object to write.
    """
    dirname = os.path.dirname(uri)
    # dirname is '' for a bare filename, and os.makedirs('') raises.
    # exist_ok avoids failing if the directory appears between check and create.
    # NOTE(review): this always creates *local* directories, even when a custom
    # STAC_IO.write_text_method targets remote storage — confirm that is intended.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    STAC_IO.write_text(uri, json.dumps(json_dict, indent=4))
class LinkMixin:
    """Methods for working with links. Requires mixing class has 'links' property.

    ``self.links`` is expected to be a list of Link objects.
    """

    def add_link(self, link):
        """Append ``link`` to this object's links."""
        self.links.append(link)

    def get_single_link(self, rel):
        """Return the first link whose ``rel`` matches, or None."""
        return next((link for link in self.links if link.rel == rel), None)

    def _set_single_link(self, link):
        """Replace every existing link that has ``link.rel`` with ``link``."""
        self.links = [l for l in self.links if l.rel != link.rel]
        self.links.append(link)

    def get_root(self):
        """Return the root object, reading it if the root link is still an href."""
        root_link = self.get_single_link('root')
        if root_link is None:
            return None
        return root_link.resolve_stac_object().target

    def set_root(self, root):
        """Point this object's (single) root link at ``root``."""
        self._set_single_link(Link.root(root))

    def get_parent(self):
        """Return the parent object, reading it if the parent link is still an href."""
        parent_link = self.get_single_link('parent')
        if parent_link is None:
            return None
        # Link's resolver is resolve_stac_object; the former call to the
        # nonexistent resolve_stac_target raised AttributeError.
        return parent_link.resolve_stac_object().target

    def set_parent(self, parent):
        """Point this object's (single) parent link at ``parent``."""
        self._set_single_link(Link.parent(parent))

    def get_self_href(self):
        """Return the target of the 'self' link, or None if there is none."""
        self_link = self.get_single_link('self')
        return None if self_link is None else self_link.target

    def set_self_href(self, href):
        """Set (replacing any existing) the 'self' link to ``href``."""
        self._set_single_link(Link.self_href(href))

    def get_stac_objects(self, rel, parent=None):
        """Resolve and return the targets of all links with ``rel``.

        Links whose target is still an href are read and updated in place.
        """
        matching = self.get_links(rel)
        if not matching:
            return []
        # The root does not change while resolving siblings, so look it up once.
        root = self.get_root()
        for link in matching:
            link.resolve_stac_object(root=root, parent=parent)
        return [link.target for link in matching]

    def get_links(self, rel):
        """Return all links whose ``rel`` matches, in order."""
        return [link for link in self.links if link.rel == rel]
class Catalog(LinkMixin):
    """A STAC Catalog: an id and description plus links to children and items.

    A newly constructed Catalog starts with a 'root' link pointing at itself;
    add_child/add_item propagate this catalog's root to what they attach.
    """

    def __init__(self, id, description, title=None, href=None):
        self.id = id
        self.description = description
        self.title = title
        self.links = [Link.root(self)]
        if href is not None:
            self.set_self_href(href)

    def add_child(self, child, title=None):
        """Link ``child`` (a Catalog or Collection) under this catalog.

        The child's root link is set to this catalog's root and its parent
        link to this catalog.
        """
        child.set_root(self.get_root())
        child.set_parent(self)
        self.add_link(Link.child(child, title=title))

    def add_item(self, item, title=None):
        """Link ``item`` under this catalog, setting its root and parent links."""
        item.set_root(self.get_root())
        item.set_parent(self)
        self.add_link(Link.item(item, title=title))

    def get_children(self):
        """Return child catalogs, reading any child links that are still hrefs."""
        return self.get_stac_objects('child')

    def get_child_links(self):
        return self.get_links('child')

    def clear_children(self):
        """Remove all child links; returns self for chaining."""
        self.links = [l for l in self.links if l.rel != 'child']
        return self

    def get_items(self):
        """Return items, reading any item links that are still hrefs."""
        return self.get_stac_objects('item')

    def clear_items(self):
        """Remove all item links; returns self for chaining."""
        self.links = [l for l in self.links if l.rel != 'item']
        return self

    def get_all_items(self):
        """Get all items from this catalog and all subcatalogs."""
        items = self.get_items()
        for child in self.get_children():
            items += child.get_all_items()
        return items

    def get_item_links(self):
        return self.get_links('item')

    def to_dict(self):
        """Serialize to a STAC Catalog dict (title only when set)."""
        d = {
            'id': self.id,
            'stac_version': STAC_VERSION,
            'description': self.description,
            'links': [l.to_dict() for l in self.links]
        }
        if self.title is not None:
            d['title'] = self.title
        return d

    def clone(self):
        """Return a copy with cloned links (link targets are shared, not copied)."""
        clone = Catalog(id=self.id,
                        description=self.description,
                        title=self.title)
        clone.links = [l.clone() for l in self.links]
        # A catalog that is its own root should yield a clone that is its own
        # root, not one whose root link still points at the original object.
        for link in clone.links:
            if link.rel == 'root' and link.target is self:
                link.target = clone
        return clone

    def set_uris_from_root(self, root_uri):
        """Assign self hrefs for this tree under ``root_uri``.

        This catalog goes to ``root_uri/catalog.json``, each child recursively
        under ``root_uri/<child.id>/`` and each item to ``root_uri/<item.id>.json``.
        """
        self.set_self_href(os.path.join(root_uri, 'catalog.json'))
        for child in self.get_children():
            child_root = os.path.join(root_uri, '{}/'.format(child.id))
            child.set_uris_from_root(child_root)
        for item in self.get_items():
            item.set_self_href(os.path.join(root_uri, '{}.json'.format(item.id)))

    def save(self):
        """Save resolved children and items, then this catalog, to their self hrefs.

        Raises:
            ValueError: if this catalog has no self href.
        """
        self_href = self.get_self_href()
        if self_href is None:
            raise ValueError(
                'Catalog {} has no self href; call set_uris_from_root or '
                'set_self_href before saving'.format(self.id))
        for child_link in self.get_child_links():
            if child_link.is_resolved():
                child_link.target.save()
        for item_link in self.get_item_links():
            if item_link.is_resolved():
                item_link.target.save()
        save_json(self_href, self.to_dict())

    def map_items(self, item_mapper):
        """Creates a copy of a catalog, with each item passed through the item_mapper function.

        Args:
            item_mapper: A function that takes in an item, and returns either an item or list of items.
                The item that is passed into the item_mapper is a copy, so the method can mutate it safely.

        Returns:
            The new Catalog. Mapped items and mapped children get parent links
            to the new catalogs. The original catalog's links are not resolved
            or modified.

        NOTE(review): root links of mapped children/items are copied as-is and
        still refer to the original tree's root — confirm whether they should
        be re-pointed at the new root.
        """
        new_cat = self.clone()
        new_cat.clear_children()
        new_cat.clear_items()
        for child_link in self.get_child_links():
            # Resolve on a cloned link so the original link is left untouched.
            new_link = child_link.clone().resolve_stac_object()
            new_child = new_link.target.map_items(item_mapper)
            new_child.set_parent(new_cat)
            new_link.target = new_child
            new_cat.add_link(new_link)
        for item_link in self.get_item_links():
            resolved = item_link.clone().resolve_stac_object()
            mapped = item_mapper(resolved.target.clone())
            # Honour the documented contract: a single Item or a list of Items.
            mapped_items = mapped if isinstance(mapped, list) else [mapped]
            for new_item in mapped_items:
                new_item.set_parent(new_cat)
                new_cat.add_link(Link(rel=item_link.rel,
                                      target=new_item,
                                      media_type=item_link.media_type,
                                      title=item_link.title))
        return new_cat

    def map_assets(self, asset_mapper):
        """Creates a copy of a catalog, with each Asset for each Item passed
        through the asset_mapper function.

        Args:
            asset_mapper: A function that takes in an Asset and returns an Asset.
                The Asset that is passed in is a copy (items are cloned by
                map_items), so the function can mutate it safely.
        """
        def item_mapper(item):
            # Map the Asset values and keep their keys; iterating the dict
            # directly would pass keys, and a list would break Item.to_dict.
            item.assets = {key: asset_mapper(asset)
                           for key, asset in item.assets.items()}
            return item
        return self.map_items(item_mapper)

    @staticmethod
    def _set_links_from_dicts(cat, link_dicts):
        """Populate ``cat.links`` from serialized link dicts.

        The constructor gives every catalog a root link to itself. If the
        serialized links already include a 'root' link, that default is dropped
        so the object does not end up with two root links. Otherwise the
        duplicates grow on every read/save round trip, and get_root() on a child
        returns the child itself. A serialized root whose href equals the
        serialized self href is pointed back at ``cat`` instead of re-reading
        the same file.
        """
        links = [Link.from_dict(l) for l in link_dicts]
        root_links = [l for l in links if l.rel == 'root']
        if not root_links:
            cat.links.extend(links)
            return
        self_href = next((l.target for l in links if l.rel == 'self'), None)
        for link in root_links:
            if self_href is not None and link.target == self_href:
                link.target = cat
        cat.links = links

    @staticmethod
    def from_dict(d):
        """Build a Catalog from a STAC Catalog dict (links stay unresolved hrefs)."""
        cat = Catalog(id=d['id'],
                      description=d['description'],
                      title=d.get('title'))
        Catalog._set_links_from_dicts(cat, d['links'])
        return cat

    @staticmethod
    def from_file(uri):
        """Read a Catalog from ``uri`` via STAC_IO."""
        d = json.loads(STAC_IO.read_text(uri))
        return Catalog.from_dict(d)
class Collection(Catalog):
    """A Catalog that additionally carries an Extent and a license string."""

    def __init__(self, id, description, extent, title=None, href=None, license='proprietary'):
        super(Collection, self).__init__(id, description, title, href)
        self.extent = extent
        self.license = license

    def to_dict(self):
        """Serialize as a Catalog dict plus 'extent' and 'license'."""
        d = super(Collection, self).to_dict()
        d['extent'] = self.extent.to_dict()
        d['license'] = self.license
        return d

    def clone(self):
        """Return a copy with a cloned Extent and cloned links (targets shared)."""
        clone = Collection(id=self.id,
                           description=self.description,
                           extent=self.extent.clone(),
                           title=self.title,
                           license=self.license)
        # Call Link.clone for each link; the copy must hold Link objects,
        # not bound methods.
        clone.links = [link.clone() for link in self.links]
        # A self-rooted collection yields a clone that is its own root.
        for link in clone.links:
            if link.rel == 'root' and link.target is self:
                link.target = clone
        return clone

    @staticmethod
    def from_dict(d):
        """Build a Collection from a STAC Collection dict (links stay unresolved hrefs)."""
        collection = Collection(id=d['id'],
                                description=d['description'],
                                extent=Extent.from_dict(d['extent']),
                                title=d.get('title'),
                                # Round-trip the license written by to_dict.
                                license=d.get('license', 'proprietary'))
        links = [Link.from_dict(l) for l in d['links']]
        root_links = [l for l in links if l.rel == 'root']
        if root_links:
            # Drop the constructor's default root->self link so the object does
            # not carry two root links; a root href equal to our own self href
            # is pointed back at this object instead of being re-read.
            self_href = next((l.target for l in links if l.rel == 'self'), None)
            for link in root_links:
                if self_href is not None and link.target == self_href:
                    link.target = collection
            collection.links = links
        else:
            collection.links.extend(links)
        return collection
class Item(LinkMixin):
    """A STAC Item: a GeoJSON Feature plus links and a dict of named Assets."""

    def __init__(self,
                 id,
                 geometry,
                 bbox,
                 properties,
                 stac_extensions=None,
                 href=None):
        self.id = id
        self.geometry = geometry
        self.bbox = bbox
        self.properties = properties
        self.stac_extensions = stac_extensions
        self.links = []
        self.assets = {}
        if href is not None:
            self.set_self_href(href)

    def add_asset(self, key, href, title=None, media_type=None):
        """Add (or replace) the Asset stored under ``key``."""
        self.assets[key] = Asset(href, title=title, media_type=media_type)

    def to_dict(self):
        """Serialize to a STAC Item (GeoJSON Feature) dict.

        If ``properties`` has no 'datetime', a fixed placeholder is written to
        the output; ``self.properties`` itself is not modified.
        """
        links = [link.to_dict() for link in self.links]
        assets = {key: asset.to_dict() for key, asset in self.assets.items()}
        # Copy so serializing has no side effect on the Item's own properties.
        properties = dict(self.properties)
        # NOTE(review): fixed placeholder datetime — callers should set a real one.
        properties.setdefault('datetime', "2016-05-03T13:21:30.040Z")
        d = {
            'type': 'Feature',
            'id': self.id,
            'properties': properties,
            'geometry': self.geometry,
            'bbox': self.bbox,
            'links': links,
            'assets': assets
        }
        if self.stac_extensions is not None:
            d['stac_extensions'] = self.stac_extensions
        return d

    def clone(self):
        """Return a copy with deep-copied geometry/properties/extensions and cloned links/assets."""
        clone = Item(id=self.id,
                     geometry=deepcopy(self.geometry),
                     bbox=copy(self.bbox),
                     properties=deepcopy(self.properties),
                     stac_extensions=deepcopy(self.stac_extensions))
        clone.links = [l.clone() for l in self.links]
        clone.assets = {k: a.clone() for k, a in self.assets.items()}
        return clone

    def save(self):
        """Write this Item to its self href.

        Raises:
            ValueError: if the Item has no self href.
        """
        self_href = self.get_self_href()
        if self_href is None:
            raise ValueError(
                'Item {} has no self href; set one before saving'.format(self.id))
        save_json(self_href, self.to_dict())

    @staticmethod
    def from_dict(d):
        """Build an Item from a STAC Item dict (links stay unresolved hrefs)."""
        item = Item(id=d['id'],
                    geometry=d['geometry'],
                    bbox=d['bbox'],
                    properties=d['properties'],
                    stac_extensions=d.get('stac_extensions'))
        for l in d['links']:
            item.add_link(Link.from_dict(l))
        for k, v in d['assets'].items():
            item.assets[k] = Asset.from_dict(v)
        return item
class Link:
    """A relation (``rel``) from a STAC object to a target.

    ``target`` is either an href string (unresolved) or the STAC object it
    refers to (resolved).
    """

    def __init__(self, rel, target, media_type=None, title=None):
        self.rel = rel
        self.target = target  # A STAC object, or an href string until resolved
        self.media_type = media_type
        self.title = title

    def __repr__(self):
        return '<Link rel={} target={}>'.format(self.rel, self.target)

    def resolve_stac_object(self, root=None, parent=None):
        """If the target is an href, read it with read_stac_json and store the object.

        Returns:
            self, so calls can be chained.
        """
        if isinstance(self.target, str):
            self.target = read_stac_json(self.target, root=root, parent=parent)
        return self

    def is_resolved(self):
        """True once the target is an object rather than an href string."""
        return not isinstance(self.target, str)

    def to_dict(self):
        """Serialize; a resolved target is written as its current self href."""
        d = {'rel': self.rel}
        if self.is_resolved():
            d['href'] = self.target.get_self_href()
        else:
            d['href'] = self.target
        if self.media_type is not None:
            d['type'] = self.media_type
        if self.title is not None:
            d['title'] = self.title
        return d

    def clone(self):
        """Return a new Link with the same fields (the target is shared)."""
        return Link(rel=self.rel,
                    target=self.target,
                    media_type=self.media_type,
                    title=self.title)

    @staticmethod
    def from_dict(d):
        """Build an unresolved Link from a serialized link dict."""
        # to_dict writes the media type under 'type' (the STAC key); reading
        # only 'media_type' lost it on every round trip. 'media_type' is still
        # accepted as a fallback for backward compatibility.
        return Link(rel=d['rel'],
                    target=d['href'],
                    media_type=d.get('type', d.get('media_type')),
                    title=d.get('title'))

    @staticmethod
    def root(c):
        """Creates a link to a root Catalog or Collection."""
        return Link('root', c, media_type='application/json')

    @staticmethod
    def parent(c):
        """Creates a link to a parent Catalog or Collection."""
        return Link('parent', c, media_type='application/json')

    @staticmethod
    def self_href(href):
        """Creates a self link to the file's location."""
        return Link('self', href, media_type='application/json')

    @staticmethod
    def child(c, title=None):
        """Creates a link to a child Catalog or Collection."""
        return Link('child', c, title=title, media_type='application/json')

    @staticmethod
    def item(item, title=None):
        """Creates a link to an Item."""
        return Link('item', item, title=title, media_type='application/json')
class Asset:
    """A file referenced by an Item: href plus optional title, media type and extra fields."""

    def __init__(self, href, title=None, media_type=None, properties=None):
        self.href = href
        self.title = title
        self.media_type = media_type
        # Extra (e.g. extension) fields merged into the serialized dict; this
        # argument used to be silently discarded.
        self.properties = properties

    def to_dict(self):
        """Serialize; extra ``properties`` are merged in after href/type/title."""
        d = {
            'href': self.href
        }
        if self.media_type is not None:
            d['type'] = self.media_type
        if self.title is not None:
            d['title'] = self.title
        if self.properties is not None:
            d.update(self.properties)
        return d

    def clone(self):
        """Return an independent copy, including a deep copy of ``properties``."""
        return Asset(href=self.href,
                     title=self.title,
                     media_type=self.media_type,
                     properties=deepcopy(self.properties))

    def __repr__(self):
        return '<Asset href={}>'.format(self.href)

    @staticmethod
    def from_dict(d):
        """Build an Asset from a serialized asset dict (the input is not mutated)."""
        d = copy(d)
        href = d.pop('href')
        # to_dict writes the media type under 'type'; 'media_type' is accepted
        # as a fallback. Both are popped so neither leaks into properties.
        media_type = d.pop('type', None)
        legacy_media_type = d.pop('media_type', None)
        if media_type is None:
            media_type = legacy_media_type
        title = d.pop('title', None)
        # Remaining keys are extra fields; test the dict itself (any(d) would
        # ignore a lone '' key).
        properties = d if d else None
        return Asset(href=href,
                     media_type=media_type,
                     title=title,
                     properties=properties)
class Extent:
    """Spatial and temporal extent of a Collection."""
    # TODO - fix temporal extent

    def __init__(self, spatial, temporal=None):
        """
        Args:
            spatial: Spatial extent, stored and serialized as given.
            temporal: Temporal extent; defaults to a new ``[None, None]``.
        """
        self.spatial = spatial
        # Build a fresh default per instance: a list default would be one
        # object shared (and mutated) by every Extent.
        self.temporal = [None, None] if temporal is None else temporal

    def to_dict(self):
        return {
            'spatial': self.spatial,
            'temporal': self.temporal
        }

    def clone(self):
        """Return a copy with shallow copies of the spatial and temporal values."""
        return Extent(spatial=copy(self.spatial),
                      temporal=copy(self.temporal))

    @staticmethod
    def from_dict(d):
        return Extent(d['spatial'], d['temporal'])
import rasterio
from shapely.geometry import (Polygon, mapping, shape)
import pyproj
from model import (Collection, Item, Extent)
def item_from_raster(id,
                     key,
                     uri,
                     properties,
                     title=None,
                     media_type=None,
                     stac_extensions=None):
    """Create an Item describing one raster file.

    The raster's corners are reprojected from its native CRS to EPSG:4326
    with pyproj. They form the Item's polygon geometry, and the bbox is that
    polygon's bounds. The raster is attached as the asset ``key``.

    Args:
        id: Item id.
        key: Asset key under which ``uri`` is registered.
        uri: Raster location, opened with rasterio.
        properties: Item properties dict (stored as given).
        title: Optional asset title.
        media_type: Optional asset media type.
        stac_extensions: Optional list of extension ids; a new empty list is
            used when omitted.

    Returns:
        The new Item.
    """
    if stac_extensions is None:
        # New list per call; a `[]` default would be shared by every Item.
        stac_extensions = []
    with rasterio.open(uri) as ds:
        # NOTE(review): Proj(init=...) and pyproj.transform are deprecated in
        # pyproj 2+; consider Transformer.from_crs(..., always_xy=True).
        map_proj = pyproj.Proj(init='epsg:4326')
        image_proj = pyproj.Proj(ds.crs)
        bounds = ds.bounds
        corners = [(bounds.left, bounds.top),
                   (bounds.right, bounds.top),
                   (bounds.right, bounds.bottom),
                   (bounds.left, bounds.bottom)]
        # Reproject all four corners: after reprojection the footprint is
        # generally not axis-aligned, so two opposite corners alone can
        # under-estimate the bbox.
        footprint = [pyproj.transform(image_proj, map_proj, x, y)
                     for x, y in corners]
    poly = Polygon(footprint)
    geometry = mapping(poly)
    bbox = list(poly.bounds)
    item = Item(id=id,
                bbox=bbox,
                geometry=geometry,
                properties=properties,
                stac_extensions=stac_extensions)
    item.add_asset(key, uri, title=title, media_type=media_type)
    return item
def collection_from_items(items, id, description, title=None, href=None, license='proprietary'):
    """Create a Collection containing ``items``, with a spatial extent covering them all.

    Args:
        items: Iterable of Items; each ``item.geometry`` must be GeoJSON that
            shapely's ``shape`` accepts.
        id, description, title, href, license: Passed to Collection.

    Returns:
        The new Collection, with every item added via ``add_item``.

    Raises:
        ValueError: if ``items`` is empty.
    """
    # Materialize once: items is traversed twice below, and a generator would
    # be exhausted by the first pass, silently adding no items.
    items = list(items)
    if not items:
        raise ValueError('collection_from_items requires at least one item')
    # The bounds of a union equal the min/max of the individual bounds, so no
    # polygon union is needed just to compute the bbox.
    all_bounds = [shape(item.geometry).bounds for item in items]
    collection_bbox = [min(b[0] for b in all_bounds),
                       min(b[1] for b in all_bounds),
                       max(b[2] for b in all_bounds),
                       max(b[3] for b in all_bounds)]
    collection = Collection(id=id,
                            description=description,
                            extent=Extent(collection_bbox),
                            title=title,
                            href=href,
                            license=license)
    for item in items:
        collection.add_item(item)
    return collection
import json
import os
import unittest
import shutil
class ModelTest(unittest.TestCase):
    """End-to-end checks of the STAC model against files under /opt/data.

    NOTE(review): every test reads a module-level mapping ``d`` that is not
    defined anywhere in the visible source. Its usage suggests
    ``d[country][area]`` is a list of raster paths; confirm where ``d`` is
    defined.
    NOTE(review): test_read, test_get_all_items and test_map_items read
    /opt/data/catalog/catalog.json, which is written only by test_create, so
    they rely on test_create having run earlier.
    """
    def test_create(self):
        # Build one Collection per country in ``d`` and one Item per area, with
        # one asset per raster path (asset keys '0', '1', ...). Then assign hrefs
        # under /opt/data/catalog and write the whole tree.
        catalog = Catalog(id='cat', description='Test Catalog', title='Test!')
        for country in d:
            items = []
            for area in d[country]:
                tiffs = d[country][area]
                # The first raster defines the Item geometry/bbox and is asset '0'.
                item = item_from_raster(id=area,
                                        key='0',
                                        uri=tiffs[0],
                                        properties={})
                # Remaining rasters become assets '1', '2', ...
                for i, uri in enumerate(tiffs[1:]):
                    item.add_asset(key=str(i+1), href=uri)
                items.append(item)
            country_collection = collection_from_items(
                items=items,
                id=country,
                description='Challenge imagery over {}'.format(country),
                license='proprietary')
            catalog.add_child(country_collection)
        catalog.set_uris_from_root('/opt/data/catalog')
        catalog.save()

    def test_read(self):
        # Re-read the saved catalog and check each item's asset hrefs match the
        # raster paths listed for it in ``d``.
        catalog = Catalog.from_file('/opt/data/catalog/catalog.json')
        collections = catalog.get_children()
        # NOTE(review): hard-coded count; presumably the number of countries in ``d``.
        self.assertEqual(len(collections), 3)
        for c in collections:
            areas = d[c.id]
            items = c.get_items()
            for item in items:
                self.assertTrue(item.id in areas)
                tiffs = set(areas[item.id])
                item_hrefs = set([a.href for a in item.assets.values()])
                self.assertEqual(tiffs, item_hrefs)

    def test_get_all_items(self):
        catalog = Catalog.from_file('/opt/data/catalog/catalog.json')
        items = catalog.get_all_items()
        # NOTE(review): hard-coded count; presumably the total number of areas in ``d``.
        self.assertEqual(len(items), 7)

    def test_map_items(self):
        # Tag every item via map_items and save the copy under a new root.
        # NOTE(review): this test makes no assertions; it only checks that
        # mapping and saving complete without raising.
        def item_mapper(item):
            item.properties['ITEM_MAPPER'] = 'YEP'
            return item
        catalog = Catalog.from_file('/opt/data/catalog/catalog.json')
        new_cat = catalog.map_items(item_mapper)
        new_cat.set_uris_from_root('/opt/data/catalog-im')
        new_cat.save()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment