Skip to content

Instantly share code, notes, and snippets.

@scottstanie
Last active April 26, 2023 21:59
Show Gist options
  • Save scottstanie/1528a4c11b30e57d31f479ce58a77d0b to your computer and use it in GitHub Desktop.
Save scottstanie/1528a4c11b30e57d31f479ce58a77d0b to your computer and use it in GitHub Desktop.
Chunked Dask Array from a list of GDAL-readable files
import operator

import numpy as np

from dolphin.io import load_gdal, get_raster_xysize, get_raster_nodata, get_raster_dtype, get_raster_chunk_size
class GDALReader:
    """Lazy 2-D view of a GDAL-readable raster, for ``dask.array.from_array``.

    Implements the array-like protocol that ``dask.array.from_array`` uses
    (``shape``, ``dtype``, ``ndim``, ``__getitem__``). Each ``__getitem__``
    call reads only the requested window via ``load_gdal``.
    See https://docs.dask.org/en/stable/generated/dask.array.from_array.html

    Parameters
    ----------
    filename
        Path to a GDAL-readable raster.
    masked
        Forwarded to ``load_gdal`` on every read.
    """

    def __init__(self, filename, masked=True):
        self.filename = filename
        # get_raster_xysize / get_raster_chunk_size presumably return (x, y)
        # order; reverse into numpy (rows, cols). Stored as tuples to match
        # numpy's convention for ``shape`` and per-axis chunk sizes.
        self._shape = tuple(get_raster_xysize(filename)[::-1])
        self._nodata = get_raster_nodata(filename)  # stored, not used in this class
        self._masked = masked
        self._dtype = get_raster_dtype(filename)
        self._chunks = tuple(get_raster_chunk_size(filename)[::-1])

    @staticmethod
    def _normalize_index(index, shape):
        """Expand ``index`` into one slice per axis of ``shape``.

        Returns ``(slices, int_axes)``. For forward-stepping slices, ``slices``
        has explicit, in-bounds ``start``/``stop`` values. ``int_axes`` lists the
        axes that were indexed by an integer; each such axis becomes a length-1
        slice.

        Raises
        ------
        IndexError
            If there are more keys than axes, or an integer is out of bounds.
        TypeError
            If a key is neither a slice nor an integer.
        """
        if not isinstance(index, tuple):
            index = (index,)
        if len(index) > len(shape):
            raise IndexError(
                f"too many indices: {len(index)} given for {len(shape)}-D data"
            )
        # Missing trailing keys select the whole axis, as in numpy.
        index = index + (slice(None),) * (len(shape) - len(index))
        slices = []
        int_axes = []
        for axis, (key, size) in enumerate(zip(index, shape)):
            if isinstance(key, slice):
                start, stop, step = key.indices(size)
                if step > 0:
                    # Resolve None/negative/overshooting bounds; stop < start
                    # becomes an empty window rather than a negative size.
                    key = slice(start, max(start, stop), key.step)
                slices.append(key)
                continue
            # operator.index accepts Python ints and numpy integer scalars
            # (isinstance(np.int64(1), int) is False).
            i = operator.index(key)
            if i < 0:
                i += size
            if not 0 <= i < size:
                raise IndexError(
                    f"index {key} is out of bounds for axis {axis} with size {size}"
                )
            slices.append(slice(i, i + 1))
            int_axes.append(axis)
        return tuple(slices), tuple(int_axes)

    @staticmethod
    def _drop_axes(data, int_axes, ndim):
        """Remove ``int_axes`` (positions among the last ``ndim`` axes) from ``data``."""
        if not int_axes:
            return data
        key = tuple(0 if axis in int_axes else slice(None) for axis in range(ndim))
        return data[(Ellipsis, *key)]

    def __getitem__(self, index):
        """Read the window selected by ``index`` from disk.

        ``index`` is a single key or a ``(rows, cols)`` tuple. Each key is a
        slice or an integer (Python or numpy). Integer-indexed axes are dropped
        from the result. Slices always keep their axis, even when they have
        length 1.
        """
        slices, int_axes = self._normalize_index(index, self._shape)
        rows, cols = slices
        data = load_gdal(self.filename, rows=rows, cols=cols, masked=self._masked)
        # Drop only the integer-indexed axes. A blanket ``squeeze()`` would also
        # collapse length-1 slices (e.g. a 1-row edge chunk), so dask would get
        # a block whose shape differs from the chunk it requested.
        # Assumes load_gdal returns a (rows, cols) array here — TODO confirm.
        return self._drop_axes(data, int_axes, len(self._shape))

    def __len__(self):
        return self._shape[0]

    @property
    def ndim(self):
        return 2

    @property
    def dtype(self):
        return self._dtype

    @property
    def shape(self):
        return self._shape

    @property
    def chunks(self):
        return self._chunks
class GDALStackReader:
    """Lazy 3-D ``(n_files, rows, cols)`` view over same-sized GDAL rasters.

    Implements the array-like protocol that ``dask.array.from_array`` uses.
    Each ``__getitem__`` call reads only the requested window from each
    selected file via ``load_gdal`` and stacks the results along axis 0.

    Parameters
    ----------
    filenames
        Rasters to stack along axis 0. All must report the same
        ``get_raster_xysize`` as the first file.
    chunk_depth
        If True, the advertised ``chunks`` span every file along axis 0.
        If False, each chunk covers a single file.
    masked
        Forwarded to ``load_gdal`` on every read.

    Raises
    ------
    ValueError
        If ``filenames`` is empty or the rasters differ in size.
    """

    def __init__(
        self,
        filenames: "list[Filename]",  # quoted: ``Filename`` is not imported in this module
        chunk_depth: bool = True,
        masked: bool = True,
    ):
        if len(filenames) == 0:
            raise ValueError("filenames must not be empty")
        self.filenames = filenames
        f0 = filenames[0]
        xysize = get_raster_xysize(f0)
        # Real exception rather than ``assert`` (stripped under ``python -O``).
        # The reference size is queried once, not once per file.
        mismatched = [f for f in filenames[1:] if get_raster_xysize(f) != xysize]
        if mismatched:
            raise ValueError(
                f"raster sizes differ from {f0!r} ({xysize}): {mismatched}"
            )
        # Reverse (x, y)-ordered sizes into numpy (rows, cols) order.
        self._shape = (len(filenames), *xysize[::-1])
        self._nodata = get_raster_nodata(f0)  # stored, not used in this class
        self._dtype = get_raster_dtype(f0)
        chunks2d = tuple(get_raster_chunk_size(f0)[::-1])
        depth = len(filenames) if chunk_depth else 1
        # Always a tuple (previously a list when chunk_depth=False).
        self._chunks = (depth, *chunks2d)
        self._masked = masked

    @staticmethod
    def _normalize_index(index, shape):
        """Expand ``index`` into one slice per axis of ``shape``.

        Returns ``(slices, int_axes)``. For forward-stepping slices, ``slices``
        has explicit, in-bounds ``start``/``stop`` values. ``int_axes`` lists the
        axes that were indexed by an integer; each such axis becomes a length-1
        slice.

        Raises
        ------
        IndexError
            If there are more keys than axes, or an integer is out of bounds.
        TypeError
            If a key is neither a slice nor an integer.
        """
        if not isinstance(index, tuple):
            index = (index,)
        if len(index) > len(shape):
            raise IndexError(
                f"too many indices: {len(index)} given for {len(shape)}-D data"
            )
        # Missing trailing keys select the whole axis, as in numpy.
        index = index + (slice(None),) * (len(shape) - len(index))
        slices = []
        int_axes = []
        for axis, (key, size) in enumerate(zip(index, shape)):
            if isinstance(key, slice):
                start, stop, step = key.indices(size)
                if step > 0:
                    # Resolve None/negative/overshooting bounds; stop < start
                    # becomes an empty window rather than a negative size.
                    key = slice(start, max(start, stop), key.step)
                slices.append(key)
                continue
            # operator.index accepts Python ints and numpy integer scalars
            # (isinstance(np.int64(1), int) is False).
            i = operator.index(key)
            if i < 0:
                i += size
            if not 0 <= i < size:
                raise IndexError(
                    f"index {key} is out of bounds for axis {axis} with size {size}"
                )
            slices.append(slice(i, i + 1))
            int_axes.append(axis)
        return tuple(slices), tuple(int_axes)

    @staticmethod
    def _drop_axes(data, int_axes, ndim):
        """Remove ``int_axes`` (positions among the last ``ndim`` axes) from ``data``."""
        if not int_axes:
            return data
        key = tuple(0 if axis in int_axes else slice(None) for axis in range(ndim))
        return data[(Ellipsis, *key)]

    def __getitem__(self, index):
        """Read the ``(files, rows, cols)`` window selected by ``index``.

        ``index`` is a single key or a tuple of up to three keys. Each key is a
        slice or an integer (Python or numpy). Integer-indexed axes are dropped
        from the result, as in numpy. Slices always keep their axis, even when
        they have length 1. This matters with ``chunk_depth=False``, where dask
        requests depth-1 blocks.
        """
        slices, int_axes = self._normalize_index(index, self._shape)
        files_sl, rows, cols = slices
        files = self.filenames[files_sl]
        if len(files) == 0:
            # Nothing to read (e.g. a zero-size probe). np.stack would raise on
            # an empty list, so build the empty result directly.
            out_shape = (
                0,
                len(range(*rows.indices(self._shape[1]))),
                len(range(*cols.indices(self._shape[2]))),
            )
            data = np.empty(out_shape, dtype=self._dtype)
            if self._masked:
                data = np.ma.masked_array(data)
        else:
            layers = [
                load_gdal(f, rows=rows, cols=cols, masked=self._masked)
                for f in files
            ]
            # np.stack (via np.concatenate) does not preserve MaskedArray masks;
            # np.ma.stack does.
            is_masked = any(isinstance(layer, np.ma.MaskedArray) for layer in layers)
            stack = np.ma.stack if is_masked else np.stack
            data = stack(layers, axis=0)
        return self._drop_axes(data, int_axes, len(self._shape))

    @property
    def ndim(self):
        return 3

    @property
    def dtype(self):
        return self._dtype

    @property
    def shape(self):
        return self._shape

    @property
    def chunks(self):
        return self._chunks

    def __len__(self):
        return self._shape[0]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment