Skip to content

Instantly share code, notes, and snippets.

@schwehr
Last active September 5, 2015 16:35
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save schwehr/dd75f73cedf8f7b5357e to your computer and use it in GitHub Desktop.
Save schwehr/dd75f73cedf8f7b5357e to your computer and use it in GitHub Desktop.
GDAL autotests2 raster testing helpers
# Copyright 2014 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Support for the gdal raster driver tests.
Provides tools that simplify testing a driver, report which drivers are
available, and locate test files.
This is a rewrite of the GDALTest class from:
http://trac.osgeo.org/gdal/browser/trunk/autotest/pymod/gdaltest.py#L284
"""
import contextlib
import logging
# TODO(schwehr): Drop optparse for argparse.
from optparse import OptionParser
import os
import sys
import unittest
from osgeo import gdal
from osgeo import osr
# FIX: from gflags import flags
# NOTE(review): `flags` is not imported anywhere in this file (the import
# above is commented out), so the next line raises NameError when this module
# is imported until a flags module providing FLAGS is restored.
# GetTestFilePath reads FLAGS.test_srcdir.
FLAGS = flags.FLAGS
# Lower-cased short names of every driver GDAL reports via
# GetDriverCount()/GetDriver(i); SkipIfDriverMissing tests membership here.
drivers = [gdal.GetDriver(i).ShortName.lower()
for i in range(gdal.GetDriverCount())]
# Lower-cased GDAL driver short names, suitable for SkipIfDriverMissing and
# for comparison against the module-level `drivers` list.
AAIGRID_DRIVER = 'aaigrid'
ACE2_DRIVER = 'ace2'
ADRG_DRIVER = 'adrg'
AIG_DRIVER = 'aig'
AIRSAR_DRIVER = 'airsar'
ARG_DRIVER = 'arg'
BAG_DRIVER = 'bag'
BIGGIF_DRIVER = 'biggif'
BLX_DRIVER = 'blx'
BMP_DRIVER = 'bmp'
BSB_DRIVER = 'bsb'
BT_DRIVER = 'bt'
CEOS_DRIVER = 'ceos'
COASP_DRIVER = 'coasp'
COSAR_DRIVER = 'cosar'
CPG_DRIVER = 'cpg'
CTABLE2_DRIVER = 'ctable2'
CTG_DRIVER = 'ctg'
DIMAP_DRIVER = 'dimap'
DIPEX_DRIVER = 'dipex'
DOQ1_DRIVER = 'doq1'
DOQ2_DRIVER = 'doq2'
DTED_DRIVER = 'dted'
E00GRID_DRIVER = 'e00grid'
ECRGTOC_DRIVER = 'ecrgtoc'
ECW_DRIVER = 'ecw'
EHDR_DRIVER = 'ehdr'
EIR_DRIVER = 'eir'
ELAS_DRIVER = 'elas'
ENVI_DRIVER = 'envi'
ERS_DRIVER = 'ers'
ESAT_DRIVER = 'esat'
FAST_DRIVER = 'fast'
FIT_DRIVER = 'fit'
FITS_DRIVER = 'fits'
FUJIBAS_DRIVER = 'fujibas'
GENBIN_DRIVER = 'genbin'
GFF_DRIVER = 'gff'
GIF_DRIVER = 'gif'
GMT_DRIVER = 'gmt'
GRASS_DRIVER = 'grass'
GRASSASCIIGRID_DRIVER = 'grassasciigrid'
GRIB_DRIVER = 'grib'
GS7BG_DRIVER = 'gs7bg'
GSAG_DRIVER = 'gsag'
GSBG_DRIVER = 'gsbg'
GSC_DRIVER = 'gsc'
GTIFF_DRIVER = 'gtiff'
GTX_DRIVER = 'gtx'
GXF_DRIVER = 'gxf'
HDF5_DRIVER = 'hdf5'
HDF5IMAGE_DRIVER = 'hdf5image'
HF2_DRIVER = 'hf2'
HFA_DRIVER = 'hfa'
HTTP_DRIVER = 'http'
IDA_DRIVER = 'ida'
ILWIS_DRIVER = 'ilwis'
INGR_DRIVER = 'ingr'
IRIS_DRIVER = 'iris'
ISIS2_DRIVER = 'isis2'
ISIS3_DRIVER = 'isis3'
JAXAPALSAR_DRIVER = 'jaxapalsar'
JDEM_DRIVER = 'jdem'
JP2ECW_DRIVER = 'jp2ecw'
JP2KAK_DRIVER = 'jp2kak'
JPEG2000_DRIVER = 'jpeg2000'
JP2MRSID_DRIVER = 'jp2mrsid'
JP2OPENJPEG_DRIVER = 'jp2openjpeg'
# Legacy names without the _DRIVER suffix, kept so existing callers work.
JP2MRSID = JP2MRSID_DRIVER
JP2OPENJPEG = JP2OPENJPEG_DRIVER
JPEG_DRIVER = 'jpeg'
JPIPKAK_DRIVER = 'jpipkak'
KMLSUPEROVERLAY_DRIVER = 'kmlsuperoverlay'
KRO_DRIVER = 'kro'
L1B_DRIVER = 'l1b'
LAN_DRIVER = 'lan'
LCP_DRIVER = 'lcp'
LEVELLER_DRIVER = 'leveller'
LOSLAS_DRIVER = 'loslas'
MAP_DRIVER = 'map'
MBTILES_DRIVER = 'mbtiles'
MEM_DRIVER = 'mem'
MFF_DRIVER = 'mff'
MFF2_DRIVER = 'mff2'
MG4LIDAR_DRIVER = 'mg4lidar'
MRSID_DRIVER = 'mrsid'
MSGN_DRIVER = 'msgn'
NDF_DRIVER = 'ndf'
NETCDF_DRIVER = 'netcdf'
NGSGEOID_DRIVER = 'ngsgeoid'
NITF_DRIVER = 'nitf'
NTV2_DRIVER = 'ntv2'
NWT_GRC_DRIVER = 'nwt_grc'
NWT_GRD_DRIVER = 'nwt_grd'
OZI_DRIVER = 'ozi'
PAUX_DRIVER = 'paux'
PCIDSK_DRIVER = 'pcidsk'
PCRASTER_DRIVER = 'pcraster'
PDF_DRIVER = 'pdf'
PDS_DRIVER = 'pds'
PNG_DRIVER = 'png'
PNM_DRIVER = 'pnm'
POSTGISRASTER_DRIVER = 'postgisraster'
R_DRIVER = 'r'
RASTERLITE_DRIVER = 'rasterlite'
RIK_DRIVER = 'rik'
RMF_DRIVER = 'rmf'
RPFTOC_DRIVER = 'rpftoc'
RS2_DRIVER = 'rs2'
RST_DRIVER = 'rst'
SAGA_DRIVER = 'saga'
SAR_CEOS_DRIVER = 'sar_ceos'
SDTS_DRIVER = 'sdts'
SGI_DRIVER = 'sgi'
SNODAS_DRIVER = 'snodas'
SRP_DRIVER = 'srp'
SRTMHGT_DRIVER = 'srtmhgt'
TERRAGEN_DRIVER = 'terragen'
TIL_DRIVER = 'til'
TSX_DRIVER = 'tsx'
USGSDEM_DRIVER = 'usgsdem'
VRT_DRIVER = 'vrt'
WCS_DRIVER = 'wcs'
WEBP_DRIVER = 'webp'
WMS_DRIVER = 'wms'
XPM_DRIVER = 'xpm'
XYZ_DRIVER = 'xyz'
ZMAP_DRIVER = 'zmap'
def SkipIfDriverMissing(driver_name):
  """Decorator that only runs a test if a required driver is found.

  Args:
    driver_name: Short name of a driver, e.g. 'dted'. Matching is
      case-insensitive because the module-level `drivers` list holds
      lower-cased names.

  Returns:
    A pass through function if the test should be run, or the unittest skip
    decorator if the test or TestCase should not be run.
  """
  def _IdReturn(obj):
    return obj

  # `drivers` is lower-cased, so normalize the request to match it.
  driver_name = driver_name.lower()
  debug = gdal.GetConfigOption('CPL_DEBUG')
  if driver_name not in drivers:
    if debug:
      logging.info('Debug: Skipping test. Driver not found: %s', driver_name)
    # unittest.skip is the documented public name for unittest.case.skip.
    return unittest.skip('Skipping "%s" driver dependent test.' %
                         driver_name)
  if debug:
    logging.info('Debug: Running test. Found driver: %s', driver_name)
  return _IdReturn
def GetTestFilePath(filename):
  """Builds the path to a file in the testdata directory beside this module.

  Args:
    filename: Name of a file under the testdata directory.

  Returns:
    str path to the requested test file.
  """
  srcdir = FLAGS.test_srcdir
  module_dir = os.path.dirname(os.path.abspath(__file__))
  # NOTE(review): module_dir is absolute and os.path.join drops every
  # component before an absolute one, so srcdir never affects the result
  # (it is still read, so FLAGS must define test_srcdir). Confirm intent.
  return os.path.join(srcdir, module_dir, 'testdata', filename)
def CreateParser():
  """Builds the command-line parser shared by the raster driver tests.

  Returns:
    An optparse.OptionParser that understands --temp-dir, --pam-dir and
    --verbose.
  """
  pam_help = ('Where to store the .aux.xml files created by the persistent '
              'auxiliary metadata system. Defaults to temp-directory/pam.')
  result = OptionParser()
  result.add_option('-t', '--temp-dir', metavar='DIR', default=os.getcwd(),
                    help='Where to put temporary files.')
  result.add_option('-p', '--pam-dir', metavar='DIR', default=None,
                    help=pam_help)
  result.add_option('-v', '--verbose', action='store_true', default=False,
                    help='Put the unittest run into verbose mode.')
  return result
def Setup(options):
  """Applies parsed command-line options to the GDAL configuration.

  Sets CPL_TMPDIR to the absolute temp directory and GDAL_PAM_PROXY_DIR to
  the PAM directory, creating the PAM directory when it does not exist.

  Args:
    options: Values from CreateParser().parse_args(). temp_dir is made
      absolute and pam_dir is filled in when unset; both are updated in place.
  """
  if options.verbose:
    logging.basicConfig(level=logging.INFO)
  options.temp_dir = os.path.abspath(options.temp_dir)
  gdal.SetConfigOption('CPL_TMPDIR', options.temp_dir)
  logging.info('CPL_TMPDIR: %s', options.temp_dir)
  options.pam_dir = options.pam_dir or os.path.join(options.temp_dir, 'pam')
  # makedirs rather than mkdir: os.mkdir fails when the parent directory
  # (e.g. a --temp-dir that does not exist yet) is missing.
  if not os.path.isdir(options.pam_dir):
    os.makedirs(options.pam_dir)
  gdal.SetConfigOption('GDAL_PAM_PROXY_DIR', options.pam_dir)
  logging.info('GDAL_PAM_PROXY_DIR: %s', options.pam_dir)
def Main():
  """Parses the command line, configures GDAL, then runs unittest.main."""
  options, _ = CreateParser().parse_args()
  Setup(options)
  # Keep only the program name; our own options must not reach unittest.
  unittest_argv = list(sys.argv[:1])
  unittest_argv += ['-v'] if options.verbose else []
  unittest.main(argv=unittest_argv)
class TempFiles(object):
  """Hands out numbered temporary file paths in a single directory.

  The directory is chosen on the first TempFile() call, in this order:
  - the GDAL config option CPL_TMPDIR (which Setup() sets from --temp-dir);
  - the config option TMPDIR;
  - tempfile.gettempdir().
  """

  def __init__(self):
    self.count = 0  # Paths handed out so far; used as the numeric suffix.
    self.tmp_dir = None  # Resolved lazily by TempFile().

  def TempFile(self, basename, ext=''):
    """Returns a new path of the form <tmp_dir>/<basename><NNN><ext>.

    Args:
      basename: Prefix for the file name.
      ext: Suffix appended verbatim after the counter, e.g. '.tif'.

    Returns:
      str path; the counter is incremented so each call yields a new name.
    """
    if not self.tmp_dir:
      # Setup() writes CPL_TMPDIR, so honor it before the generic TMPDIR.
      self.tmp_dir = (gdal.GetConfigOption('CPL_TMPDIR') or
                      gdal.GetConfigOption('TMPDIR'))
    if not self.tmp_dir:
      # logging.fatal does not stop execution; the old code then crashed in
      # os.path.join(None, ...). Fall back to the platform temp dir instead.
      import tempfile  # pylint: disable=g-import-not-at-top
      self.tmp_dir = tempfile.gettempdir()
      logging.warning('No CPL_TMPDIR or TMPDIR set; using %s', self.tmp_dir)
    filepath = os.path.join(self.tmp_dir,
                            '%s%03d%s' % (basename, self.count, ext))
    self.count += 1
    return filepath


_temp_files = TempFiles()
@contextlib.contextmanager
def ConfigOption(key, value, default=None):
  """Temporarily overrides a GDAL config option within a `with` block.

  On exit (normal or via exception) the option is put back to the value it
  had on entry, or to `default` if it had none.

  TODO(schwehr): This would be better as part of gcore_util.py.

  Args:
    key: String naming the config option.
    value: String value the option holds inside the context.
    default: String value to restore when the option had no value on entry.

  Yields:
    None
  """
  saved = gdal.GetConfigOption(key, default)
  gdal.SetConfigOption(key, value)
  try:
    yield
  finally:
    # Runs even if the body raised, so the override never leaks out.
    gdal.SetConfigOption(key, saved)
class DriverTestCase(unittest.TestCase):
  """Checks the basic functioning of a single raster driver.

  Assumes that only one driver is registered for the file type.

  CheckOpen has a critical side effect: it puts the opened dataset in the
  src attribute. The Check* methods defined after CheckOpen assume that
  self.src is the original open file. CheckCreateCopy also stores the
  re-opened copy in self.dst.
  """

  def setUp(self, driver_name, ext):
    """Looks up the driver under test and resets GDAL's error state.

    Args:
      driver_name: GDAL short name of the driver; stored lower-cased.
      ext: Extension appended verbatim to file names written by
        CheckCreateCopy, e.g. '.tif'.
    """
    super(DriverTestCase, self).setUp()
    # assertTrue rather than a bare assert so the check survives python -O.
    self.assertTrue(driver_name, 'driver_name is required')
    self.driver_name = driver_name.lower()
    self.driver = gdal.GetDriverByName(driver_name)
    self.assertTrue(self.driver, 'GDAL driver not found: %s' % driver_name)
    self.ext = ext
    # Start with a clean slate.
    gdal.ErrorReset()

  @staticmethod
  def _IsStringLike(value):
    """Returns True if value is text: str, or unicode on Python 2."""
    try:
      string_types = (str, unicode)  # pylint: disable=undefined-variable
    except NameError:  # Python 3 has no separate unicode type.
      string_types = (str,)
    return isinstance(value, string_types)

  def assertIterAlmostEqual(self, first, second, places=None, msg=None,
                            delta=None):
    """Asserts two sequences have equal length and almost-equal elements.

    Args:
      first: First sized iterable of numbers.
      second: Second sized iterable of numbers.
      places: Forwarded to assertAlmostEqual for every element pair.
      msg: Extra failure message text.
      delta: Forwarded to assertAlmostEqual for every element pair.
    """
    msg = msg or ''
    self.assertEqual(len(first), len(second), 'lists not same length ' + msg)
    for a, b in zip(first, second):
      self.assertAlmostEqual(a, b, places=places, msg=msg, delta=delta)

  def CheckDriver(self):
    """Asserts the driver found in setUp has the requested short name."""
    self.assertEqual(self.driver_name, self.driver.ShortName.lower())

  def CheckOpen(self, filepath, check_driver=True):
    """Open the test file and keep it open as self.src.

    Args:
      filepath: str, Path to a file to open with GDAL.
      check_driver: If True, make sure that the file opened with the
        default driver for this test. If it is a str, then check that
        the driver used matches the string. If False, then do not
        check the driver.
    """
    # Absolute non-/vsi paths can be checked on the local filesystem first
    # for a clearer failure message than gdal.Open returning None.
    if filepath.startswith(os.path.sep) and not filepath.startswith('/vsi'):
      self.assertTrue(os.path.isfile(filepath), 'Does not exist: ' + filepath)
    self.src = gdal.Open(filepath, gdal.GA_ReadOnly)
    self.assertTrue(self.src, '%s driver unable to open %s' % (self.driver_name,
                                                               filepath))
    if check_driver:
      driver_name = self.src.GetDriver().ShortName.lower()
      # _IsStringLike avoids referencing `unicode` directly, which raised
      # NameError on Python 3 whenever check_driver was not a str.
      if self._IsStringLike(check_driver):
        self.assertEqual(check_driver, driver_name)
      else:
        self.assertEqual(self.driver_name, driver_name)

  def CheckGeoTransform(self, gt_expected, gt_delta=None):
    """Asserts self.src's geotransform is close to gt_expected.

    Args:
      gt_expected: 6-element geotransform to compare against.
      gt_delta: Absolute tolerance per element. When None, defaults to 1% of
        abs(gt_expected[1]) + abs(gt_expected[2]).
    """
    gt = self.src.GetGeoTransform()
    if not gt and not gt_expected:
      return
    self.assertEqual(len(gt_expected), 6)
    # Compare with None so an explicit gt_delta=0 means an exact match rather
    # than silently falling back to the default tolerance.
    if gt_delta is None:
      gt_delta = (abs(gt_expected[1]) + abs(gt_expected[2])) / 100.0
    for idx in range(6):
      self.assertAlmostEqual(gt[idx], gt_expected[idx], delta=gt_delta)

  def CheckProjection(self, prj_expected):
    """Asserts self.src's projection matches prj_expected.

    Args:
      prj_expected: Any string osr.SpatialReference.SetFromUserInput
        accepts. When both it and the source projection are empty, the
        check passes.
    """
    prj = self.src.GetProjection()
    if not prj and not prj_expected:
      return
    src_osr = osr.SpatialReference(wkt=prj)
    prj2 = osr.SpatialReference()
    prj2.SetFromUserInput(prj_expected)
    self.assertTrue(src_osr.IsSame(prj2))

  def CheckBand(self, band_num, checksum, gdal_type=None, nodata=None):
    """Asserts the checksum, and optionally data type and nodata, of a band.

    Args:
      band_num: 1-based band index in self.src.
      checksum: Expected value of band.Checksum().
      gdal_type: If not None, expected band.DataType.
      nodata: If not None, expected band.GetNoDataValue().
    """
    band = self.src.GetRasterBand(band_num)
    self.assertEqual(band.Checksum(), checksum)
    if gdal_type is not None:
      self.assertEqual(gdal_type, band.DataType)
    if nodata is not None:
      self.assertEqual(nodata, band.GetNoDataValue())

  def CheckBandSubRegion(self, band_num, checksum, xoff, yoff, xsize, ysize):
    """Asserts the checksum of a window of one band of self.src."""
    band = self.src.GetRasterBand(band_num)
    self.assertEqual(checksum, band.Checksum(xoff, yoff, xsize, ysize))

  # TODO(schwehr): Add assertCreateCopyInterrupt method.
  def CheckCreateCopy(self, check_checksums=True, check_stats=True,
                      check_geotransform=True,
                      check_projection=True, options=None, strict=True,
                      vsimem=False):
    """Compare a copy to the currently open file.

    Args:
      check_checksums: Set to False to not check checksums.
      check_stats: Compare band statistics if true.
      check_geotransform: Set to False to skip checking the geotransform.
      check_projection: Set to False to skip checking the projection.
      options: List of options to pass to CreateCopy.
      strict: Set to False to have the CreateCopy operation in loose mode.
      vsimem: If true, copy to memory.

    Returns:
      Open gdal raster Dataset (also stored as self.dst).
    """
    # TODO(schwehr): Complain if options is a str or unicode.
    # TODO(schwehr): Use gdal.GetConfigOption('TMPDIR') if available.
    options = options or []
    basename = os.path.basename(self.src.GetFileList()[0])
    if vsimem:
      dst_file = os.path.join('/vsimem/', basename + self.ext)
    else:
      dst_file = _temp_files.TempFile(basename, self.ext)
    dst = self.driver.CreateCopy(dst_file, self.src, strict=strict,
                                 options=options)
    self.assertTrue(dst)
    self.assertEqual(dst.GetDriver().ShortName.lower(), self.driver_name)
    # TODO(schwehr): Pre-close tests.
    del dst  # Flush the file.
    self.dst = gdal.Open(dst_file)
    self.assertTrue(self.dst)
    self.assertEqual(self.dst.RasterCount, self.src.RasterCount)
    for band_num in range(1, self.dst.RasterCount + 1):
      src_band = self.src.GetRasterBand(band_num)
      dst_band = self.dst.GetRasterBand(band_num)
      if check_checksums:
        dst_checksum = dst_band.Checksum()
        self.assertEqual(dst_checksum, src_band.Checksum())
        # Min/max comparison is skipped when the checksum is 0.
        if dst_checksum and check_stats:
          self.assertEqual(dst_band.ComputeRasterMinMax(),
                           src_band.ComputeRasterMinMax())
    # CheckGeoTransform/CheckProjection compare self.src to the given value,
    # so passing the copy's metadata compares source against copy.
    if check_geotransform:
      self.CheckGeoTransform(self.dst.GetGeoTransform())
    if check_projection:
      self.CheckProjection(self.dst.GetProjection())
    return self.dst
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment