Skip to content

Instantly share code, notes, and snippets.

@alkalait
Last active June 3, 2021 19:33
Show Gist options
  • Save alkalait/c99213c164df691b5e37cd96d5ab9ab2 to your computer and use it in GitHub Desktop.
PyTorch SpaceNet7 Dataset
{
"split": {
"L15-0331E-1257N_1327_3160_13": "train",
"L15-0357E-1223N_1429_3296_13": "train",
"L15-0487E-1246N_1950_3207_13": "train",
"L15-0571E-1075N_2287_3888_13": "train",
"L15-0614E-0946N_2459_4406_13": "train",
"L15-0924E-1108N_3699_3757_13": "train",
"L15-0977E-1187N_3911_3441_13": "train",
"L15-1014E-1375N_4056_2688_13": "train",
"L15-1015E-1062N_4061_3941_13": "train",
"L15-1025E-1366N_4102_2726_13": "train",
"L15-1138E-1216N_4553_3325_13": "train",
"L15-1172E-1306N_4688_2967_13": "train",
"L15-1185E-0935N_4742_4450_13": "train",
"L15-1200E-0847N_4802_4803_13": "train",
"L15-1204E-1204N_4819_3372_13": "train",
"L15-1296E-1198N_5184_3399_13": "train",
"L15-1298E-1322N_5193_2903_13": "train",
"L15-1335E-1166N_5342_3524_13": "train",
"L15-1439E-1134N_5759_3655_13": "train",
"L15-1479E-1101N_5916_3785_13": "train",
"L15-1481E-1119N_5927_3715_13": "train",
"L15-1538E-1163N_6154_3539_13": "train",
"L15-1617E-1207N_6468_3360_13": "train",
"L15-1672E-1207N_6691_3363_13": "train",
"L15-1703E-1219N_6813_3313_13": "train",
"L15-1709E-1112N_6838_3742_13": "train",
"L15-1716E-1211N_6864_3345_13": "train",
"L15-1748E-1247N_6993_3202_13": "train",
"L15-1848E-0793N_7394_5018_13": "train",
"L15-0683E-1006N_2732_4164_13": "train",
"L15-0760E-0887N_3041_4643_13": "train",
"L15-0434E-1218N_1736_3318_13": "train",
"L15-0368E-1245N_1474_3210_13": "train",
"L15-0632E-0892N_2528_4620_13": "train",
"L15-1049E-1370N_4196_2710_13": "train",
"L15-1210E-1025N_4840_4088_13": "train",
"L15-1289E-1169N_5156_3514_13": "train",
"L15-0361E-1300N_1446_2989_13": "val",
"L15-1209E-1113N_4838_3737_13": "val",
"L15-0566E-1185N_2265_3451_13": "val",
"L15-1276E-1107N_5105_3761_13": "val",
"L15-1438E-1134N_5753_3655_13": "test",
"L15-0586E-1127N_2345_3680_13": "test",
"L15-0358E-1220N_1433_3310_13": "test",
"L15-1389E-1284N_5557_3054_13": "test",
"L15-0595E-1278N_2383_3079_13": "rejected"
}
}
# Author: Freddie Kalaitzis
# License: MIT
# Source: https://gist.github.com/alkalait/c99213c164df691b5e37cd96d5ab9ab2
import functools
import itertools
import os
import re
import warnings
import numpy as np
import pandas as pd
import torch
import xarray
from collections import OrderedDict
from datetime import datetime
from glob import glob
from kornia.contrib import ExtractTensorPatches
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.notebook import tqdm
###############################################################################
## Machine-specific data location -- adapt this to your machine.
## The default expects the data to have been copied into shared memory with:
##     rsync -arhP /data/spacenet/train /dev/shm
## To read from the original location instead, set MNT_SN7 = '/data/spacenet/train/'.
MNT_SN7 = '/dev/shm/train/'
###############################################################################
## The train/val/test split file is expected to sit next to this module.
TRAIN_VAL_TEST_SPLIT_SN7 = os.path.join(os.path.dirname(__file__), 'scene_split.json')
## Per-band mean / standard deviation constants for Sentinel-2 (12 values each);
## presumably meant for input normalisation -- TODO confirm.
S2_MEAN = np.asarray((1972, 1989, 2137, 2245, 2514, 2825, 2962, 3048, 3047, 3740, 2422, 2017))
S2_STD = np.asarray((471, 517, 522, 576, 565, 590, 617, 674, 640, 766, 596, 561))
def spacenet7_index(
    data_dir: str = MNT_SN7,
    folder_planet: str = "images",
    folder_sentinel: str = "S2L2A",
    split_json: str = TRAIN_VAL_TEST_SPLIT_SN7,
) -> pd.DataFrame:
    """
    Creates a dataframe with the file index of the SpaceNet7 dataset.

    Images are looked up at {data_dir}/{scene_id}/{folder}/{filename}.

    Args:
        data_dir: root folder of the dataset (default=MNT_SN7).
        folder_planet: folder of the PlanetScope images (default='images'):
            {data_dir}/{scene_id}/{folder_planet}.
        folder_sentinel: folder of the Sentinel-2 images (default='S2L2A'):
            {data_dir}/{scene_id}/{folder_sentinel}.
        split_json: path of the JSON file with train/val/test split data,
            shaped as {"split": {scene_id: split_name}}.

    Returns:
        Dataframe with columns [path, basename, year, month, day, split],
        indexed by (sat, scene, datetime) and sorted. `sat` is either
        'sentinel' or 'planet'; scenes absent from `split_json` get a NaN split.
    """
    paths_planet = sorted(glob(os.path.join(data_dir, f'*/{folder_planet}/*')))
    paths_sentinel = sorted(glob(os.path.join(data_dir, f'*/{folder_sentinel}/*')))
    ## TODO find folder for sentinel cloud masks.
    df = pd.concat([pd.DataFrame(dict(path=paths_sentinel, sat='sentinel')),
                    pd.DataFrame(dict(path=paths_planet, sat='planet'))], ignore_index=True)
    ## Scene name = first path component below data_dir. Taking it relative to
    ## data_dir (instead of a fixed split index) works for any root depth,
    ## including relative roots.
    df['scene'] = df['path'].map(lambda p: os.path.relpath(p, data_dir).split(os.path.sep)[0])
    df['basename'] = df['path'].map(os.path.basename)
    ## Day defaults to 1: planet file names only encode year and month.
    df = df.assign(year=1, month=1, day=1)
    ## Year, month, day from names (sentinel; '...YYYY-MM-DD...').
    ix = df.query('sat=="sentinel"').index
    df.loc[ix, ['year', 'month', 'day']] = (df.loc[ix, 'basename']
                                            .str.extract(r'(\d{4})-(\d{2})-(\d{2})')
                                            .rename(columns={0: 'year', 1: 'month', 2: 'day'})
                                            .astype(int))
    ## Year, month from names (planet; different format: '...monthly_YYYY_MM...').
    ix = df.query('sat=="planet"').index
    df.loc[ix, ['year', 'month']] = (df.loc[ix, 'basename']
                                     .str.extract(r'monthly_(\d{4})_(\d{2})')
                                     .rename(columns={0: 'year', 1: 'month'})
                                     .astype(int))
    ## Datetimes assembled from the year, month, day columns (vectorised).
    df['datetime'] = pd.to_datetime(df[['year', 'month', 'day']])
    df.set_index(['sat', 'scene', 'datetime'], inplace=True)
    df.sort_index(inplace=True)
    ## Split scenes into train / val / test.
    scenes = df.index.get_level_values('scene')
    partition = pd.read_json(split_json).to_dict()['split']
    df['split'] = scenes.map(partition)
    return df
class SN7Dataset(Dataset):
    """ SpaceNet7 dataset.

    Each item is one spatial window (patch) of one scene, stacked over all the
    scene's dates for the chosen product and passed through `transform`
    (by default `torch.as_tensor`, giving a (T, C, H, W) tensor).
    """
    def __init__(
        self,
        product: str = 'planet',
        root: str = MNT_SN7,
        df: Optional[pd.DataFrame] = None,
        window_size: Optional[int] = None,
        stride: Optional[int] = None,
        date_range: Optional[pd.DatetimeIndex] = None,
        bands: Optional[List[int]] = None,
        mode: str = 'train',
        transform: Optional[Callable] = None,
        precache: bool = False,
    ) -> None:
        """
        Args:
            product: name of the product, i.e. the `sat` level of the index
                ('planet' (default) or 'sentinel').
            root: data root path (default=MNT_SN7). Used to build the index
                when `df` is not given.
            df: (optional) dataframe of image paths, see spacenet7_index().
            window_size: (optional) size of the square tiling window. By default,
                no tiling is performed (one patch per scene).
            stride: (optional) controls the stride to apply to the sliding window and
                regulates the overlapping between the extracted patches.
                Defaults to `window_size` (non-overlapping tiles).
            date_range (pandas.DatetimeIndex, optional): only images whose date
                is contained in this index are used.
            bands: band indices to read. All bands are read by default.
            mode: one of [train (default) | val | test].
            transform: (optional) torchvision transform.
            precache: (default=False) pre-cache all data in memory. Use with caution.

        Raises:
            ValueError: if `mode` is not one of train / val / test.
            TypeError: if `date_range` is neither None nor a pandas.DatetimeIndex.
        """
        if mode not in ['train', 'val', 'test']:
            raise ValueError(f"mode must be one of [train | val | test]. Got {mode}.")
        self.product = product
        self.root = root
        self.date_range = date_range
        self.window_size = window_size
        self.stride = stride or window_size
        self.bands = bands if bands is not None else slice(None)  # All bands by default
        self.mode = mode
        self.transform = transform or self.default_transform()
        self.xarrays: Dict[str, xarray.DataArray] = {}  # Scene readers
        self.patches: List[Tuple[str, Tuple[slice, slice]]] = []  # Patch slice storage
        ## Per-instance memo of transformed items, only filled when precache=True.
        ## A per-instance dict (instead of functools.lru_cache on the method) can
        ## hold the whole dataset and does not keep instances alive after use.
        self._cache: Dict[int, Any] = {}
        self._use_cache = precache
        self.df = df if df is not None else spacenet7_index(data_dir=root)
        ## Restrict to scenes of this mode.
        self.df = self.df.query(f'split=="{mode}"')
        ## Restrict to scenes of this date range.
        if isinstance(self.date_range, pd.DatetimeIndex):
            ix = self.df.index.get_level_values('datetime').isin(self.date_range)
            self.df = self.df.loc[ix]
        elif self.date_range is not None:
            raise TypeError(f'date_range must be a pandas.DatetimeIndex. Got {type(self.date_range)}.')
        self.scenes = self.df.index.get_level_values('scene').unique()
        self._prepare_dataset()
        if precache:
            self._precache()

    def default_transform(self) -> Callable:
        """ Default transform: convert the array to a torch tensor. """
        return transforms.Compose([torch.as_tensor])

    def _prepare_dataset(self) -> None:
        """ Open the TIFF images of each scene as one time-stacked array, then tile. """
        for scene in tqdm(self.scenes, desc='Reading scenes...'):
            paths = self.df.loc[(self.product, scene)]['path']
            ## NOTE(review): xarray.open_rasterio is deprecated in recent xarray
            ## releases (rioxarray.open_rasterio replaces it) -- confirm the pinned version.
            X = [xarray.open_rasterio(path, parse_coordinates=False)[self.bands] for path in paths]
            dates = xarray.Variable('time', paths.index)
            self.xarrays[scene] = xarray.concat(X, dim=dates)
        self.patches = self._extract_patch_slices()

    def _extract_patch_slices(self) -> List[Tuple[str, Tuple[slice, slice]]]:
        """
        Extracts slices by tiling each scene with a sliding window.

        Returns:
            A list of tuples, with the scene ID and a slice window (tuple of
            slices), grouped scene by scene. Windows that would cross the
            image border are not produced.
        """
        def tile_slices(
            image_size: Tuple[int, int],
            window_size: Tuple[int, int],
            stride: Tuple[int, int],
        ) -> List[Tuple[slice, slice]]:
            """ All full windows of `window_size` stepped by `stride`, row-major. """
            (height, width), (win_h, win_w), (step_y, step_x) = image_size, window_size, stride
            y_slices = [slice(i, i + win_h) for i in range(0, height - win_h + 1, step_y)]
            x_slices = [slice(j, j + win_w) for j in range(0, width - win_w + 1, step_x)]
            return list(itertools.product(y_slices, x_slices))

        patches: List[Tuple[str, Tuple[slice, slice]]] = []
        for scene in self.scenes:
            ims = self.xarrays[scene].shape[2:]  # (H, W) of the (T, C, H, W) stack
            ws = (self.window_size, self.window_size) if self.window_size is not None else ims
            s = (self.stride, self.stride) if self.stride is not None else ws
            ## Pair each window with the scene it was tiled from.
            patches.extend((scene, window) for window in tile_slices(ims, ws, s))
        return patches

    def _precache(self) -> None:
        """ Cache full dataset simply by getting all items. """
        for i in tqdm(range(len(self)), desc='Pre-caching...'):
            self[i]

    def __len__(self) -> int:
        return len(self.patches)

    def _get_image(self, index: int) -> np.ndarray:
        """ Copy of the raw pixel data of patch `index`, shape (T, C, H, W). """
        scene, window = self.patches[index]
        return self.xarrays[scene][:, :, window[0], window[1]].data.copy()  # (T, C, H, W)

    def __getitem__(self, index: int) -> Union[Tensor, xarray.DataArray]:
        """ Transformed patch `index`; served from the cache when precached. """
        if index in self._cache:
            return self._cache[index]
        item = self.transform(self._get_image(index))  # (T, C, H, W)
        if self._use_cache:
            self._cache[index] = item
        return item

    def __repr__(self) -> str:
        if self.date_range is not None:
            date_str = f"start={self.date_range.min().date()}, end={self.date_range.max().date()}"
        else:
            date_str = "start=None, end=None"
        return (
            f"root={self.root}\n"
            f"split={self.mode}\n"
            f"{date_str}\n"
            f"scenes={len(self.scenes)}, samples={len(self)}\n"
            f"window_size={self.window_size}"
        )
class ConcatDataset(Dataset):
    """ One dataset to rule them all.

    Zips several named datasets: item `i` is a dict mapping each keyword name
    to the `i`-th item of the corresponding dataset. The length is that of the
    shortest dataset; a warning is emitted when the lengths differ.
    """
    def __init__(self, **datasets: Dataset) -> None:
        """
        Args:
            **datasets: the datasets to combine, keyed by the name under which
                their items appear in each returned dict.
        """
        super().__init__()
        lengths = [len(ds) for ds in datasets.values()]
        if len(set(lengths)) > 1:
            warnings.warn(f'Datasets with non-equal length: {lengths}. '
                          'Truncating to the shortest.', category=UserWarning)
        ## With no datasets the combined length is 0 (min() of [] would raise).
        self._len = min(lengths, default=0)
        self.datasets = datasets

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """ Return {name: dataset[index]} for every wrapped dataset.

        Negative indices count from the end of this (truncated) dataset, so all
        components are read at the same position even when their lengths differ.

        Raises:
            IndexError: if `index` is out of range.
        """
        if index < 0:
            index += self._len
        if not 0 <= index < self._len:
            raise IndexError(f'index out of range for ConcatDataset of length {self._len}.')
        return {name: ds[index] for name, ds in self.datasets.items()}

    def __len__(self) -> int:
        return self._len
class TransformDataset(Dataset):
    """
    Dataset wrapper that applies a transform to every item of another dataset.

    Args:
        dataset: the dataset to wrap.
        transform: callable (e.g. a torchvision transform) applied to each item.
    """
    def __init__(self, dataset: Dataset, transform: Callable) -> None:
        super().__init__()
        self.transform = transform
        self.dataset = dataset

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Any:
        # Fetch the raw item from the wrapped dataset, then transform it.
        raw = self.dataset[index]
        return self.transform(raw)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment