Last active
September 5, 2015 16:35
-
-
Save schwehr/dd75f73cedf8f7b5357e to your computer and use it in GitHub Desktop.
GDAL autotests2 raster testing helps
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2014 Google Inc. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Support for the gdal raster driver tests. | |
Provides tools to simplify testing a driver, which drivers are | |
available, and where to find test files. | |
Rewrite of GDALTest class: | |
http://trac.osgeo.org/gdal/browser/trunk/autotest/pymod/gdaltest.py#L284 | |
""" | |
import contextlib | |
import logging | |
# TODO(schwehr): Drop optparse for argparse. | |
from optparse import OptionParser | |
import os | |
import sys | |
import unittest | |
from osgeo import gdal | |
from osgeo import osr | |
# FIX: from gflags import flags | |
# NOTE(review): the gflags import just above is commented out ("FIX"), so
# `flags` is not bound anywhere in the visible part of this file -- confirm
# how it is meant to be provided before relying on FLAGS.
FLAGS = flags.FLAGS

# Lower-cased short name of every driver index GDAL reports at import time.
drivers = [gdal.GetDriver(i).ShortName.lower()
           for i in range(gdal.GetDriverCount())]
# Lower-cased GDAL driver short names, spelled the same way as the entries of
# the module-level `drivers` list so they can be compared directly.
AAIGRID_DRIVER = 'aaigrid'
ACE2_DRIVER = 'ace2'
ADRG_DRIVER = 'adrg'
AIG_DRIVER = 'aig'
AIRSAR_DRIVER = 'airsar'
ARG_DRIVER = 'arg'
BAG_DRIVER = 'bag'
BIGGIF_DRIVER = 'biggif'
BLX_DRIVER = 'blx'
BMP_DRIVER = 'bmp'
BSB_DRIVER = 'bsb'
BT_DRIVER = 'bt'
CEOS_DRIVER = 'ceos'
COASP_DRIVER = 'coasp'
COSAR_DRIVER = 'cosar'
CPG_DRIVER = 'cpg'
CTABLE2_DRIVER = 'ctable2'
CTG_DRIVER = 'ctg'
DIMAP_DRIVER = 'dimap'
DIPEX_DRIVER = 'dipex'
DOQ1_DRIVER = 'doq1'
DOQ2_DRIVER = 'doq2'
DTED_DRIVER = 'dted'
E00GRID_DRIVER = 'e00grid'
ECRGTOC_DRIVER = 'ecrgtoc'
ECW_DRIVER = 'ecw'
EHDR_DRIVER = 'ehdr'
EIR_DRIVER = 'eir'
ELAS_DRIVER = 'elas'
ENVI_DRIVER = 'envi'
ERS_DRIVER = 'ers'
ESAT_DRIVER = 'esat'
FAST_DRIVER = 'fast'
FIT_DRIVER = 'fit'
FITS_DRIVER = 'fits'
FUJIBAS_DRIVER = 'fujibas'
GENBIN_DRIVER = 'genbin'
GFF_DRIVER = 'gff'
GIF_DRIVER = 'gif'
GMT_DRIVER = 'gmt'
GRASS_DRIVER = 'grass'
GRASSASCIIGRID_DRIVER = 'grassasciigrid'
GRIB_DRIVER = 'grib'
GS7BG_DRIVER = 'gs7bg'
GSAG_DRIVER = 'gsag'
GSBG_DRIVER = 'gsbg'
GSC_DRIVER = 'gsc'
GTIFF_DRIVER = 'gtiff'
GTX_DRIVER = 'gtx'
GXF_DRIVER = 'gxf'
HDF5_DRIVER = 'hdf5'
HDF5IMAGE_DRIVER = 'hdf5image'
HF2_DRIVER = 'hf2'
HFA_DRIVER = 'hfa'
HTTP_DRIVER = 'http'
IDA_DRIVER = 'ida'
ILWIS_DRIVER = 'ilwis'
INGR_DRIVER = 'ingr'
IRIS_DRIVER = 'iris'
ISIS2_DRIVER = 'isis2'
ISIS3_DRIVER = 'isis3'
JAXAPALSAR_DRIVER = 'jaxapalsar'
JDEM_DRIVER = 'jdem'
JP2ECW_DRIVER = 'jp2ecw'
JP2KAK_DRIVER = 'jp2kak'
JPEG2000_DRIVER = 'jpeg2000'
JP2MRSID_DRIVER = 'jp2mrsid'
JP2OPENJPEG_DRIVER = 'jp2openjpeg'
# Original spellings without the _DRIVER suffix; kept as aliases so that
# existing users of these two names keep working.
JP2MRSID = JP2MRSID_DRIVER
JP2OPENJPEG = JP2OPENJPEG_DRIVER
JPEG_DRIVER = 'jpeg'
JPIPKAK_DRIVER = 'jpipkak'
KMLSUPEROVERLAY_DRIVER = 'kmlsuperoverlay'
KRO_DRIVER = 'kro'
L1B_DRIVER = 'l1b'
LAN_DRIVER = 'lan'
LCP_DRIVER = 'lcp'
LEVELLER_DRIVER = 'leveller'
LOSLAS_DRIVER = 'loslas'
MAP_DRIVER = 'map'
MBTILES_DRIVER = 'mbtiles'
MEM_DRIVER = 'mem'
MFF_DRIVER = 'mff'
MFF2_DRIVER = 'mff2'
MG4LIDAR_DRIVER = 'mg4lidar'
MRSID_DRIVER = 'mrsid'
MSGN_DRIVER = 'msgn'
NDF_DRIVER = 'ndf'
NETCDF_DRIVER = 'netcdf'
NGSGEOID_DRIVER = 'ngsgeoid'
NITF_DRIVER = 'nitf'
NTV2_DRIVER = 'ntv2'
NWT_GRC_DRIVER = 'nwt_grc'
NWT_GRD_DRIVER = 'nwt_grd'
OZI_DRIVER = 'ozi'
PAUX_DRIVER = 'paux'
PCIDSK_DRIVER = 'pcidsk'
PCRASTER_DRIVER = 'pcraster'
PDF_DRIVER = 'pdf'
PDS_DRIVER = 'pds'
PNG_DRIVER = 'png'
PNM_DRIVER = 'pnm'
POSTGISRASTER_DRIVER = 'postgisraster'
R_DRIVER = 'r'
RASTERLITE_DRIVER = 'rasterlite'
RIK_DRIVER = 'rik'
RMF_DRIVER = 'rmf'
RPFTOC_DRIVER = 'rpftoc'
RS2_DRIVER = 'rs2'
RST_DRIVER = 'rst'
SAGA_DRIVER = 'saga'
SAR_CEOS_DRIVER = 'sar_ceos'
SDTS_DRIVER = 'sdts'
SGI_DRIVER = 'sgi'
SNODAS_DRIVER = 'snodas'
SRP_DRIVER = 'srp'
SRTMHGT_DRIVER = 'srtmhgt'
TERRAGEN_DRIVER = 'terragen'
TIL_DRIVER = 'til'
TSX_DRIVER = 'tsx'
USGSDEM_DRIVER = 'usgsdem'
VRT_DRIVER = 'vrt'
WCS_DRIVER = 'wcs'
WEBP_DRIVER = 'webp'
WMS_DRIVER = 'wms'
XPM_DRIVER = 'xpm'
XYZ_DRIVER = 'xyz'
ZMAP_DRIVER = 'zmap'
def SkipIfDriverMissing(driver_name):
  """Decorator that only runs a test if a required driver is found.

  The name is lower-cased before the lookup because the module-level
  `drivers` list holds lower-cased short names.  Without that, a mixed case
  name such as 'GTiff' would never match and the decorated test would be
  skipped silently on every run.

  Args:
    driver_name: Short name of a driver. e.g. 'dted'.

  Returns:
    A pass through function if the test should be run or the unittest skip
    decorator if the test or TestCase should not be run.
  """
  def _IdReturn(obj):
    return obj

  # CPL_DEBUG only controls whether the decision is logged, not the decision.
  debug = gdal.GetConfigOption('CPL_DEBUG')

  if driver_name.lower() not in drivers:
    if debug:
      logging.info('Debug: Skipping test. Driver not found: %s', driver_name)
    return unittest.skip('Skipping "%s" driver dependent test.' %
                         driver_name)

  if debug:
    logging.info('Debug: Running test. Found driver: %s', driver_name)
  return _IdReturn
def GetTestFilePath(filename):
  """Build the path of a file inside the 'testdata' directory of this module.

  Args:
    filename: Name of the test file, relative to the testdata directory.

  Returns:
    The joined path as a string.
  """
  source_root = FLAGS.test_srcdir
  # NOTE(review): module_dir is absolute and os.path.join drops everything
  # before an absolute component, so source_root does not appear to reach the
  # result -- confirm whether that is intended.
  module_dir = os.path.dirname(os.path.abspath(__file__))
  return os.path.join(source_root, module_dir, 'testdata', filename)
def CreateParser():
  """Build the command line parser used to configure a test run.

  Returns:
    An OptionParser with the --temp-dir, --pam-dir and --verbose options.
  """
  pam_help = ('Where to store the .aux.xml files created '
              'by the persistent auxiliary metadata system. '
              'Defaults to temp-directory/pam.')

  # One (flag names, keyword settings) entry per option, registered in order.
  option_table = (
      (('-t', '--temp-dir'),
       dict(default=os.getcwd(), metavar='DIR',
            help='Where to put temporary files.')),
      (('-p', '--pam-dir'),
       dict(default=None, metavar='DIR', help=pam_help)),
      (('-v', '--verbose'),
       dict(default=False, action='store_true',
            help='Put the unittest run into verbose mode.')),
  )

  option_parser = OptionParser()
  for flag_names, settings in option_table:
    option_parser.add_option(*flag_names, **settings)
  return option_parser
def Setup(options):
  """Apply parsed command line options to logging and the GDAL config.

  Side effects: configures logging at INFO level when verbose, rewrites
  options.temp_dir and options.pam_dir in place, creates the PAM directory
  when it does not exist, and sets the CPL_TMPDIR and GDAL_PAM_PROXY_DIR
  GDAL config options.

  Args:
    options: Object with verbose, temp_dir and pam_dir attributes, such as
      the options value returned by the parser from CreateParser.

  Raises:
    OSError: If the PAM directory cannot be created.
  """
  if options.verbose:
    logging.basicConfig(level=logging.INFO)

  options.temp_dir = os.path.abspath(options.temp_dir)
  gdal.SetConfigOption('CPL_TMPDIR', options.temp_dir)
  logging.info('CPL_TMPDIR: %s', options.temp_dir)

  # An empty or missing pam_dir falls back to <temp_dir>/pam.
  options.pam_dir = options.pam_dir or os.path.join(options.temp_dir, 'pam')
  if not os.path.isdir(options.pam_dir):
    # makedirs rather than mkdir so that a PAM directory whose parent (for
    # example a not yet existing temp dir) is missing can still be created.
    os.makedirs(options.pam_dir)
  gdal.SetConfigOption('GDAL_PAM_PROXY_DIR', options.pam_dir)
  logging.info('GDAL_PAM_PROXY_DIR: %s', options.pam_dir)
def Main():
  """Parse the command line, configure GDAL, then hand over to unittest.

  Only the program name, plus '-v' in verbose mode, is passed to
  unittest.main; positional arguments left over by the option parser are not
  forwarded.
  """
  options, _ = CreateParser().parse_args()
  Setup(options)

  verbosity_flags = ['-v'] if options.verbose else []
  unittest.main(argv=sys.argv[:1] + verbosity_flags)
class TempFiles(object):
  """Hands out numbered file paths inside a temporary directory."""

  def __init__(self):
    # Number of paths handed out so far; used as the numeric suffix.
    self.count = 0
    # Directory for the files; looked up on the first TempFile call if unset.
    self.tmp_dir = None

  def TempFile(self, basename, ext=''):
    """Return the next numbered path; the file itself is not created.

    Args:
      basename: Leading part of the file name.
      ext: Text appended after the three digit counter, e.g. '.tif'.

    Returns:
      Path of the form <tmp_dir>/<basename><NNN><ext>.

    Raises:
      RuntimeError: If no temporary directory is configured.
    """
    if not self.tmp_dir:
      # Setup() in this module stores the directory under CPL_TMPDIR, so look
      # there first; TMPDIR is kept as the fallback that was used before.
      self.tmp_dir = (gdal.GetConfigOption('CPL_TMPDIR') or
                      gdal.GetConfigOption('TMPDIR'))
    if not self.tmp_dir:
      # logging.fatal only logs; raise so the failure is reported here rather
      # than as an obscure error from os.path.join(None, ...).
      logging.fatal('Do not have a tmp_dir!!!')
      raise RuntimeError('Do not have a tmp_dir!!!')

    filepath = os.path.join(self.tmp_dir,
                            basename + '%03d' % self.count + ext)
    self.count += 1
    return filepath


_temp_files = TempFiles()
@contextlib.contextmanager
def ConfigOption(key, value, default=None):
  """Temporarily override a gdal config option for the body of a with block.

  The value read before the override is written back when the block exits,
  including when it exits through an exception.

  TODO(schwehr): This would be better as part of gcore_util.py.

  Args:
    key: String naming the config option.
    value: String value the option has while the context is active.
    default: String value to restore when the option had no starting value.

  Yields:
    None
  """
  saved = gdal.GetConfigOption(key, default)
  gdal.SetConfigOption(key, value)
  try:
    yield
  finally:
    # Put back what was read on entry (or `default` if nothing was set).
    gdal.SetConfigOption(key, saved)
class DriverTestCase(unittest.TestCase):
  """Checks the basic functioning of a single raster driver.

  Assumes that only one driver is registered for the file type.

  CheckOpen has a critical side effect that it puts the open data
  source in the src attribute.  Checks below CheckOpen in this class
  assume that self.src is the original open file.
  """

  def setUp(self, driver_name, ext):
    """Record the driver under test and reset the GDAL error state.

    This setUp takes arguments, so it has to be called explicitly with them.

    Args:
      driver_name: Short name of the GDAL driver being tested.
      ext: Text appended verbatim to the names of files written by
        CheckCreateCopy.
    """
    super(DriverTestCase, self).setUp()
    # assertTrue instead of a bare assert so the checks survive `python -O`.
    self.assertTrue(driver_name, 'A driver name is required.')
    self.driver_name = driver_name.lower()
    self.driver = gdal.GetDriverByName(driver_name)
    self.assertTrue(self.driver, 'GDAL driver not found: %s' % driver_name)
    self.ext = ext
    # Start with a clean slate.
    gdal.ErrorReset()

  def assertIterAlmostEqual(self, first, second, places=None, msg=None,
                            delta=None):
    """Assert two sequences have equal length and pairwise close values.

    Args:
      first: First sequence of numbers.
      second: Second sequence of numbers.
      places: Passed through to assertAlmostEqual.
      msg: Optional text added to the failure messages.
      delta: Passed through to assertAlmostEqual.
    """
    msg = msg or ''
    self.assertEqual(len(first), len(second), 'lists not same length ' + msg)
    for a, b in zip(first, second):
      self.assertAlmostEqual(a, b, places=places, msg=msg, delta=delta)

  def CheckDriver(self):
    """Assert the driver found in setUp has the requested short name."""
    self.assertEqual(self.driver_name, self.driver.ShortName.lower())

  def CheckOpen(self, filepath, check_driver=True):
    """Open the test file and keep it open as self.src.

    Args:
      filepath: str, Path to a file to open with GDAL.
      check_driver: If True, make sure that the file opened with the
        default driver for this test.  If it is a str, then check that
        the driver used matches the string.  If False, then do not
        check the driver.
    """
    # Only absolute paths outside of GDAL's /vsi... namespace are looked up
    # on the local filesystem.
    if filepath.startswith(os.path.sep) and not filepath.startswith('/vsi'):
      self.assertTrue(os.path.isfile(filepath), 'Does not exist: ' + filepath)
    self.src = gdal.Open(filepath, gdal.GA_ReadOnly)
    self.assertTrue(self.src, '%s driver unable to open %s' % (self.driver_name,
                                                               filepath))
    if check_driver:
      driver_name = self.src.GetDriver().ShortName.lower()
      # type(u'') is `unicode` on Python 2 and `str` on Python 3; spelling it
      # this way avoids the NameError the bare name `unicode` raises on 3.
      if isinstance(check_driver, (str, type(u''))):
        self.assertEqual(check_driver, driver_name)
      else:
        self.assertEqual(self.driver_name, driver_name)

  def CheckGeoTransform(self, gt_expected, gt_delta=None):
    """Assert the geotransform of self.src is close to gt_expected.

    Args:
      gt_expected: Sequence of the 6 expected geotransform coefficients.
      gt_delta: Allowed absolute difference per coefficient.  When falsy, one
        hundredth of abs(gt_expected[1]) + abs(gt_expected[2]) is used.
    """
    gt = self.src.GetGeoTransform()
    if not gt and not gt_expected:
      return
    self.assertEqual(len(gt_expected), 6)
    gt_delta = gt_delta or ((abs(gt_expected[1]) + abs(gt_expected[2])) / 100.0)
    for idx in range(6):
      self.assertAlmostEqual(gt[idx], gt_expected[idx], delta=gt_delta)

  def CheckProjection(self, prj_expected):
    """Assert the projection of self.src is the same as prj_expected.

    Args:
      prj_expected: Projection definition accepted by
        osr.SpatialReference.SetFromUserInput.
    """
    prj = self.src.GetProjection()
    if not prj and not prj_expected:
      return
    src_osr = osr.SpatialReference(wkt=prj)
    prj2 = osr.SpatialReference()
    prj2.SetFromUserInput(prj_expected)
    self.assertTrue(src_osr.IsSame(prj2))

  def CheckBand(self, band_num, checksum, gdal_type=None, nodata=None):
    """Assert checksum and, when given, data type and nodata of a band.

    Args:
      band_num: Number of the band of self.src to check.
      checksum: Expected result of the band's Checksum().
      gdal_type: Expected band DataType, or None to skip that check.
      nodata: Expected nodata value, or None to skip that check.
    """
    band = self.src.GetRasterBand(band_num)
    self.assertEqual(band.Checksum(), checksum)
    if gdal_type is not None:
      self.assertEqual(gdal_type, band.DataType)
    if nodata is not None:
      self.assertEqual(nodata, band.GetNoDataValue())

  def CheckBandSubRegion(self, band_num, checksum, xoff, yoff, xsize, ysize):
    """Assert the checksum of a window of a band of self.src.

    Args:
      band_num: Number of the band of self.src to check.
      checksum: Expected checksum of the window.
      xoff: Passed to Checksum as the x offset of the window.
      yoff: Passed to Checksum as the y offset of the window.
      xsize: Passed to Checksum as the width of the window.
      ysize: Passed to Checksum as the height of the window.
    """
    band = self.src.GetRasterBand(band_num)
    self.assertEqual(checksum, band.Checksum(xoff, yoff, xsize, ysize))

  # TODO(schwehr): Add assertCreateCopyInterrupt method.

  def CheckCreateCopy(self, check_checksums=True, check_stats=True,
                      check_geotransform=True,
                      check_projection=True, options=None, strict=True,
                      vsimem=False):
    """Compare a copy to the currently open file.

    The copy is written with the driver under test, reopened and kept in
    self.dst.

    Args:
      check_checksums: Set to False to not check checksums.
      check_stats: Compare band statistics if true.  Only evaluated for bands
        whose checksum was checked and is non-zero.
      check_geotransform: Set to False to skip checking the geotransform.
      check_projection: Set to False to skip checking the projection.
      options: List of options to pass to CreateCopy.
      strict: Set to False to have the CreateCopy operation in loose mode.
      vsimem: If true, copy to memory.

    Returns:
      Open gdal raster Dataset.
    """
    # TODO(schwehr): Complain if options is a str or unicode.
    # TODO(schwehr): Use gdal.GetConfigOption('TMPDIR') if available.
    options = options or []
    basename = os.path.basename(self.src.GetFileList()[0])
    if vsimem:
      dst_file = os.path.join('/vsimem/', basename + self.ext)
    else:
      dst_file = _temp_files.TempFile(basename, self.ext)
    dst = self.driver.CreateCopy(dst_file, self.src, strict=strict,
                                 options=options)
    self.assertTrue(dst)
    self.assertEqual(dst.GetDriver().ShortName.lower(), self.driver_name)
    # TODO(schwehr): Pre-close tests.
    del dst  # Flush the file.

    self.dst = gdal.Open(dst_file)
    self.assertTrue(self.dst)
    self.assertEqual(self.dst.RasterCount, self.src.RasterCount)
    for band_num in range(1, self.dst.RasterCount + 1):
      src_band = self.src.GetRasterBand(band_num)
      dst_band = self.dst.GetRasterBand(band_num)
      if check_checksums:
        dst_checksum = dst_band.Checksum()
        self.assertEqual(dst_checksum, src_band.Checksum())
        if dst_checksum and check_stats:
          self.assertEqual(dst_band.ComputeRasterMinMax(),
                           src_band.ComputeRasterMinMax())

    # CheckGeoTransform/CheckProjection read from self.src, so passing the
    # values of the copy compares the source against the copy.
    if check_geotransform:
      self.CheckGeoTransform(self.dst.GetGeoTransform())
    if check_projection:
      self.CheckProjection(self.dst.GetProjection())
    return self.dst
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment