Skip to content

Instantly share code, notes, and snippets.

@ashleysommer
Created September 28, 2021 02:14
Show Gist options
  • Save ashleysommer/c2e41696ab848c2cbd8aa321302e8f2d to your computer and use it in GitHub Desktop.
Save ashleysommer/c2e41696ab848c2cbd8aa321302e8f2d to your computer and use it in GitHub Desktop.
Walk directories, make VRT files, add stats and overviews to the VRT files
from collections import defaultdict
from math import e
from re import L
import sys
from glob import glob
from pathlib import Path
from types import new_class
from lxml import etree
from os.path import commonprefix
from os import unlink
import rasterio
import subprocess
# Root directory scanned by main(); it globs BASE_DIR/*/*/*ETa*.tif beneath it,
# e.g. BASE_DIR/2018/2018_01_01/my_ETa_.cog.tif. Hard-coded for this dataset.
BASE_DIR = "/q3774/landscapes-aet/CMRSET_LANDSAT_V2_1_newcogs"
def get_raster_stats(source_filename, source_band):
    """Read GDAL statistics and metadata tags for one band of a raster.

    Opens *source_filename* with rasterio and inspects the tags of band
    *source_band* (1-based). The GDAL statistics tags STATISTICS_MINIMUM,
    STATISTICS_MAXIMUM, STATISTICS_MEAN, STATISTICS_STDDEV and
    STATISTICS_VALID_PERCENT are parsed into numbers. Every other band tag is
    returned unchanged in ``other_meta``, keyed by its upper-cased name.

    Args:
        source_filename: path of the raster to open.
        source_band: 1-based band index.

    Returns:
        ``(min, max, mean, stddev, valid_percent, scale, base_meta, other_meta)``.
        Each statistic is ``None`` when its tag is absent. ``min``/``max`` are
        ints when the tag text is an integer literal, otherwise floats.
        ``scale`` is the band's scale factor (1 if the dataset reports no
        scale for that band). ``base_meta`` is a dict of the dataset-level
        tags; ``other_meta`` holds the remaining band-level tags.
    """

    def _parse_number(text):
        # Keep exact ints where possible; fall back to float for decimals,
        # exponents and special values such as "nan"/"inf" (which int() rejects).
        try:
            return int(text)
        except ValueError:
            return float(text)

    _min = _max = _mean = _stddev = _valid = None
    _scale = 1
    other_meta = {}
    with rasterio.open(source_filename) as ds:
        base_meta = dict(ds.tags())
        if len(ds.scales) >= source_band:
            _scale = ds.scales[source_band - 1]
        band_tags = ds.tags(source_band)
        print(band_tags)
        for key, value in band_tags.items():
            key = key.upper()
            # Only the comparison key is normalised; the value keeps its
            # original case so copied metadata is not altered.
            if key == "STATISTICS_MAXIMUM":
                _max = _parse_number(value)
            elif key == "STATISTICS_MINIMUM":
                _min = _parse_number(value)
            elif key == "STATISTICS_MEAN":
                _mean = float(value)
            elif key == "STATISTICS_STDDEV":
                _stddev = float(value)
            elif key == "STATISTICS_VALID_PERCENT":
                _valid = float(value)
            else:
                other_meta[key] = value
    return _min, _max, _mean, _stddev, _valid, _scale, base_meta, other_meta
def _read_source_ref(source_elem, base_file):
    """Return ``(path, band, ScaleRatio element or None)`` for a VRT source, or None.

    None is returned (after printing a diagnostic) when the source has no
    <SourceFilename> or its <SourceBand> is not an integer; a missing
    <SourceBand> defaults to band 1. Filenames flagged ``relativeToVRT`` are
    resolved against *base_file* (or its parent when it is not a directory).
    """
    source_filename = None
    source_band = None
    scale_ratio = None
    relative = False
    for cc in source_elem:
        if cc.tag == "SourceFilename":
            source_filename = cc.text
            if cc.get("relativeToVRT", "0") in ("1", 1, "True", "true", True):
                relative = True
        elif cc.tag == "SourceBand":
            source_band = cc.text
        elif cc.tag == "ScaleRatio":
            scale_ratio = cc
    if source_filename is None:
        print("Bad Source?: {}".format(source_elem.tag))
        return None
    if source_band is None:
        source_band = "1"
    try:
        source_band = int(source_band)
    except ValueError:
        print("Can't interpret band! {}".format(source_band))
        return None
    source_path = Path(source_filename)
    print(source_path)
    if relative:
        source_path = (base_file if base_file.is_dir() else base_file.parent) / source_path
        print(source_path)
    return source_path, source_band, scale_ratio


def _add_mdi(m_elem, key, value):
    """Append ``<MDI key="key">str(value)</MDI>`` to the <Metadata> element *m_elem*."""
    item = etree.SubElement(m_elem, "MDI", {"key": key})
    item.text = str(value)
    return item


def process_raster_band(elem, base_file):
    """Merge per-source statistics and metadata into one <VRTRasterBand>.

    Every <SimpleSource>/<ComplexSource> child of *elem* is read with
    get_raster_stats. Sources that lack a <ScaleRatio> get one with value
    "1". The band's <Metadata> element is then cleared and rebuilt:

    * band tags whose value is the same in every source that has them are copied;
    * STATISTICS_MINIMUM / STATISTICS_MAXIMUM are the min / max over sources;
    * STATISTICS_MEAN, STATISTICS_STDDEV and STATISTICS_VALID_PERCENT are
      unweighted averages over sources (valid percent capped at 100). An
      average of standard deviations only approximates the mosaic's stddev.

    A statistic is written only when at least one source was read and every
    read source provided it. The band's <Scale> is set to the common source
    scale, removed when sources disagree, or set to "1.0" when there are none.

    Args:
        elem: the <VRTRasterBand> element; modified in place.
        base_file: path of the VRT file, used to resolve relativeToVRT sources.

    Returns:
        ``(elem, base_meta)`` where ``base_meta`` maps each dataset-level tag
        seen in any source to the set of its distinct values.
    """
    base_file = Path(base_file)
    m_elem = None
    scale_elem = None
    scales = set()
    base_metas = []
    other_metas = []
    stat_names = ("min", "max", "mean", "stddev", "valid")
    collected = {name: [] for name in stat_names}
    missing = set()  # statistics that at least one source did not provide

    for c in elem:
        if c.tag == "Metadata":
            m_elem = c
            continue
        if c.tag == "Scale":
            scale_elem = c
            continue
        if c.tag not in ("SimpleSource", "ComplexSource"):
            continue
        ref = _read_source_ref(c, base_file)
        if ref is None:
            continue
        source_filename, source_band, scale_ratio = ref
        _min, _max, _mean, _stddev, _valid, _scale, base_meta, other = get_raster_stats(
            source_filename, source_band
        )
        base_metas.append(base_meta)
        other_metas.append(other)
        scales.add(_scale)
        if scale_ratio is None:
            scale_ratio = etree.SubElement(c, "ScaleRatio")
            scale_ratio.text = "1"
        for name, value in zip(stat_names, (_min, _max, _mean, _stddev, _valid)):
            if value is None:
                missing.add(name)
            else:
                collected[name].append(value)

    # A statistic is usable only if every source had it AND there was a source;
    # this also keeps min()/max() below from ever seeing an empty list.
    usable = {name: bool(collected[name]) and name not in missing for name in stat_names}

    base_meta = defaultdict(set)
    for b in base_metas:
        for k, v in b.items():
            base_meta[k].add(v)

    if m_elem is None:
        m_elem = etree.SubElement(elem, "Metadata")
    for m in list(m_elem):
        m_elem.remove(m)

    if scale_elem is None:
        scale_elem = etree.SubElement(elem, "Scale")
    if len(scales) > 1:
        # More than one scale: don't have a layer-wide scale.
        elem.remove(scale_elem)
    elif not scales:
        scale_elem.text = "1.0"
    else:
        scale_elem.text = str(next(iter(scales)))

    if not any(usable.values()) and not other_metas:
        return elem, base_meta

    elem.insert(0, m_elem)
    meta_other = defaultdict(set)
    for o in other_metas:
        for k, v in o.items():
            meta_other[k].add(v)
    for k, vals in meta_other.items():
        if len(vals) == 1:
            _add_mdi(m_elem, k, next(iter(vals)))

    if usable["min"]:
        _add_mdi(m_elem, "STATISTICS_MINIMUM", min(collected["min"]))
    if usable["max"]:
        _add_mdi(m_elem, "STATISTICS_MAXIMUM", max(collected["max"]))
    if usable["mean"]:
        means = collected["mean"]
        _add_mdi(m_elem, "STATISTICS_MEAN", sum(means) / len(means))
    if usable["stddev"]:
        stddevs = collected["stddev"]
        _add_mdi(m_elem, "STATISTICS_STDDEV", sum(stddevs) / len(stddevs))
    if usable["valid"]:
        valids = collected["valid"]
        _add_mdi(m_elem, "STATISTICS_VALID_PERCENT", min(sum(valids) / len(valids), 100.0))
    return elem, base_meta
def process_vrt(infile):
    """Add band statistics, dataset metadata and an overview list to a VRT, in place.

    Parses *infile* as GDAL VRT XML and runs process_raster_band on every
    <VRTRasterBand>. Dataset-level tags gathered from the band sources are
    appended to the root <Metadata> element, but only for keys whose value is
    identical in every source that defines them. A <Metadata> element is
    created as the first child if missing. A default <OverviewList>
    (nearest, levels 2..128) is appended if missing. The file is then
    rewritten.

    NOTE(review): existing root <MDI> entries are kept, so re-running on an
    already processed VRT can produce duplicate keys.

    Exits the process with status 1 if the root is not <VRTDataset> or the
    file contains no <VRTRasterBand>; in that case nothing is written.
    """
    parser = etree.XMLParser(remove_blank_text=True)
    tree = etree.parse(str(infile), parser)
    root = tree.getroot()
    if root.tag != "VRTDataset":
        print("BAD VRT file? {}".format(infile))
        sys.exit(1)  # non-zero so batch callers can detect the failure

    olist = None
    m_elem = None
    raster_found = False
    base_meta = defaultdict(set)
    for child in root:
        if child.tag == "Metadata":
            m_elem = child
        elif child.tag == "OverviewList":
            olist = child
        elif child.tag == "VRTRasterBand":
            _, band_meta = process_raster_band(child, infile)
            for key, vals in band_meta.items():
                base_meta[key].update(vals)
            raster_found = True

    if not raster_found:
        print("No VRTRasterBand in the VRT?")
        sys.exit(1)

    if olist is None:
        print("Adding OverviewList")
        olist = etree.SubElement(root, "OverviewList", {"resampling": "nearest"})
        olist.text = "2 4 8 16 32 64 128"
    if m_elem is None:
        print("Adding Metadata section")
        m_elem = etree.Element("Metadata")
        root.insert(0, m_elem)
    for key, vals in base_meta.items():
        if len(vals) == 1:
            new_meta = etree.SubElement(m_elem, "MDI", {"key": key})
            new_meta.text = str(next(iter(vals)))

    with open(infile, "wb") as f:
        print("Writing out {}".format(infile))
        tree.write(f, pretty_print=True)
def make_new_vrt(directory, globfor):
    """Build a VRT mosaic of the files in *directory* matching *globfor*.

    Runs ``gdalbuildvrt`` with *directory* as its working directory, passing
    the matching files as paths relative to *directory*. The output name is
    the longest common prefix of those relative paths, with trailing "_",
    "/" and "0" characters and leading "." and "/" characters stripped. It
    falls back to "out" when nothing remains, and ".vrt" is appended.

    Args:
        directory: directory to search and to run gdalbuildvrt in (str or Path).
        globfor: glob pattern, relative to *directory*, selecting source rasters.

    Returns:
        Path of the new VRT, or None when no file matched or gdalbuildvrt
        exited with a non-zero status.
    """
    directory = Path(directory)
    # Sorted so the source order given to gdalbuildvrt (which determines the
    # winning file where sources overlap) does not depend on filesystem order.
    source_files = sorted(f.relative_to(directory) for f in directory.glob(globfor))
    if not source_files:
        print("NO FILES MATCHING {} IN {}".format(globfor, directory))
        return None
    common_name = commonprefix(source_files).rstrip("_/0").lstrip("./")
    if not common_name:
        common_name = "out"
    outname = common_name + ".vrt"
    args = ["gdalbuildvrt", outname, *map(str, source_files)]
    result = subprocess.run(
        args,
        stderr=subprocess.PIPE,
        stdout=subprocess.PIPE,
        cwd=directory,
        universal_newlines=True,
    )
    print(result.stdout)
    print(result.stderr)
    if result.returncode != 0:
        print("CANNOT CREATE VRT AT {}".format(outname))
        return None
    return directory / outname
def main():
    """Regenerate the ETa VRT mosaic in every date directory under BASE_DIR.

    Counts the ``*ETa*.tif`` files in each BASE_DIR/*/*/ directory and deletes
    any existing ``*ETa*.vrt`` there. Each directory holding at least 12 such
    tifs then gets a fresh VRT from make_new_vrt, which is post-processed by
    process_vrt.
    """
    tif_counts = defaultdict(int)
    # Expected layout: base/2018/2018_01_01/my_ETa_.cog.tif
    for tif in Path(BASE_DIR).glob("./*/*/*ETa*.tif"):
        if tif.is_file():
            tif_counts[tif.parent] += 1

    for day_dir, n_tifs in tif_counts.items():
        # Stale VRTs are always removed first, so every directory is reported
        # as having no VRT before (possibly) being rebuilt.
        stale_vrts = list(day_dir.glob("./*ETa*.vrt"))
        for stale in stale_vrts:
            stale.unlink()
        print("No VRT for {}: {} files found.".format(day_dir, n_tifs))
        if n_tifs < 12:  # require 12 tifs before we can have a vrt
            continue
        vrt_path = make_new_vrt(day_dir, "./*ETa*.tif")
        if vrt_path is not None:
            process_vrt(vrt_path)
# Run the batch VRT rebuild only when executed as a script, not on import.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment