Last active
June 3, 2021 19:33
-
-
Save alkalait/c99213c164df691b5e37cd96d5ab9ab2 to your computer and use it in GitHub Desktop.
PyTorch SpaceNet7 Dataset
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"split": { | |
"L15-0331E-1257N_1327_3160_13": "train", | |
"L15-0357E-1223N_1429_3296_13": "train", | |
"L15-0487E-1246N_1950_3207_13": "train", | |
"L15-0571E-1075N_2287_3888_13": "train", | |
"L15-0614E-0946N_2459_4406_13": "train", | |
"L15-0924E-1108N_3699_3757_13": "train", | |
"L15-0977E-1187N_3911_3441_13": "train", | |
"L15-1014E-1375N_4056_2688_13": "train", | |
"L15-1015E-1062N_4061_3941_13": "train", | |
"L15-1025E-1366N_4102_2726_13": "train", | |
"L15-1138E-1216N_4553_3325_13": "train", | |
"L15-1172E-1306N_4688_2967_13": "train", | |
"L15-1185E-0935N_4742_4450_13": "train", | |
"L15-1200E-0847N_4802_4803_13": "train", | |
"L15-1204E-1204N_4819_3372_13": "train", | |
"L15-1296E-1198N_5184_3399_13": "train", | |
"L15-1298E-1322N_5193_2903_13": "train", | |
"L15-1335E-1166N_5342_3524_13": "train", | |
"L15-1439E-1134N_5759_3655_13": "train", | |
"L15-1479E-1101N_5916_3785_13": "train", | |
"L15-1481E-1119N_5927_3715_13": "train", | |
"L15-1538E-1163N_6154_3539_13": "train", | |
"L15-1617E-1207N_6468_3360_13": "train", | |
"L15-1672E-1207N_6691_3363_13": "train", | |
"L15-1703E-1219N_6813_3313_13": "train", | |
"L15-1709E-1112N_6838_3742_13": "train", | |
"L15-1716E-1211N_6864_3345_13": "train", | |
"L15-1748E-1247N_6993_3202_13": "train", | |
"L15-1848E-0793N_7394_5018_13": "train", | |
"L15-0683E-1006N_2732_4164_13": "train", | |
"L15-0760E-0887N_3041_4643_13": "train", | |
"L15-0434E-1218N_1736_3318_13": "train", | |
"L15-0368E-1245N_1474_3210_13": "train", | |
"L15-0632E-0892N_2528_4620_13": "train", | |
"L15-1049E-1370N_4196_2710_13": "train", | |
"L15-1210E-1025N_4840_4088_13": "train", | |
"L15-1289E-1169N_5156_3514_13": "train", | |
"L15-0361E-1300N_1446_2989_13": "val", | |
"L15-1209E-1113N_4838_3737_13": "val", | |
"L15-0566E-1185N_2265_3451_13": "val", | |
"L15-1276E-1107N_5105_3761_13": "val", | |
"L15-1438E-1134N_5753_3655_13": "test", | |
"L15-0586E-1127N_2345_3680_13": "test", | |
"L15-0358E-1220N_1433_3310_13": "test", | |
"L15-1389E-1284N_5557_3054_13": "test", | |
"L15-0595E-1278N_2383_3079_13": "rejected" | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Author: Freddie Kalaitzis | |
# License: MIT | |
# Source: https://gist.github.com/alkalait/c99213c164df691b5e37cd96d5ab9ab2 | |
import functools
import itertools
import json
import os
import re
import warnings
from collections import OrderedDict
from datetime import datetime
from glob import glob
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
import torch
import xarray
from kornia.contrib import ExtractTensorPatches
from torch import Tensor
from torch.utils.data import Dataset
from torchvision import transforms
from tqdm.notebook import tqdm
###############################################################################
## Root folder of the SpaceNet7 training data -- adapt this to your machine.
## The default assumes the data were first copied into shared memory:
##     rsync -arhP /data/spacenet/train /dev/shm
## To read straight from disk instead, use e.g. '/data/spacenet/train/'.
MNT_SN7 = '/dev/shm/train/'
###############################################################################

## JSON file, stored next to this module, that assigns each scene ID to a split.
TRAIN_VAL_TEST_SPLIT_SN7 = os.path.join(os.path.dirname(__file__), 'scene_split.json')

## Per-band mean / standard deviation (12 values each).
## NOTE(review): presumably Sentinel-2 normalization statistics -- confirm their source.
S2_MEAN = np.array([
    1972, 1989, 2137, 2245, 2514, 2825,
    2962, 3048, 3047, 3740, 2422, 2017,
])
S2_STD = np.array([
    471, 517, 522, 576, 565, 590,
    617, 674, 640, 766, 596, 561,
])
def spacenet7_index(
    data_dir: str = MNT_SN7,
    folder_planet: str = "images",
    folder_sentinel: str = "S2L2A",
    split_json: str = TRAIN_VAL_TEST_SPLIT_SN7,
) -> pd.DataFrame:
    """
    Creates a dataframe with the file index of the SpaceNet7 dataset.

    Expected layout (one folder per scene):
        {data_dir}/{scene_id}/{folder_planet}/*    PlanetScope files named like `...monthly_YYYY_MM...`
        {data_dir}/{scene_id}/{folder_sentinel}/*  Sentinel-2 files named like `...YYYY-MM-DD...`

    Args:
        data_dir: root folder of the dataset (default=MNT_SN7).
        folder_planet: per-scene folder of the PlanetScope images (default='images').
        folder_sentinel: per-scene folder of the Sentinel-2 images (default='S2L2A').
        split_json: path of the JSON file with the train/val/test split, of the form
            {"split": {scene_id: split_name, ...}}.

    Returns:
        Dataframe indexed by (sat, scene, datetime) and sorted, with columns
        `path`, `basename`, `year`, `month`, `day` and `split`. PlanetScope names carry
        no day, so their day is 1. Scenes missing from `split_json` get a NaN split.

    Raises:
        ValueError: if a file name does not contain a date in the expected format.
    """
    paths_planet = sorted(glob(os.path.join(data_dir, f'*/{folder_planet}/*')))
    paths_sentinel = sorted(glob(os.path.join(data_dir, f'*/{folder_sentinel}/*')))
    ## TODO find folder for sentinel cloud masks (and for the building labels).
    df = pd.concat([pd.DataFrame(dict(path=paths_sentinel, sat='sentinel')),
                    pd.DataFrame(dict(path=paths_planet, sat='planet'))], ignore_index=True)

    ## Scene = first path component below `data_dir`. Resolving it relative to `data_dir`
    ## (instead of at a fixed depth of the full path) works wherever the dataset lives.
    df['scene'] = df['path'].map(lambda p: os.path.relpath(p, data_dir).split(os.sep)[0])
    df['basename'] = df['path'].map(os.path.basename)
    df = df.assign(year=1, month=1, day=1)

    ## Year, month, day from names (sentinel).
    ix = df.index[df['sat'] == 'sentinel']
    df.loc[ix, ['year', 'month', 'day']] = (df.loc[ix, 'basename']
                                            .str.extract(r'(\d{4})-(\d{2})-(\d{2})')
                                            .rename(columns={0: 'year', 1: 'month', 2: 'day'})
                                            .astype(int))
    ## Year, month from names (planet; different format, day stays 1).
    ix = df.index[df['sat'] == 'planet']
    df.loc[ix, ['year', 'month']] = (df.loc[ix, 'basename']
                                     .str.extract(r'monthly_(\d{4})_(\d{2})')
                                     .rename(columns={0: 'year', 1: 'month'})
                                     .astype(int))

    ## Datetimes from year, month, day (vectorized instead of a row-wise apply).
    df['datetime'] = pd.to_datetime(df[['year', 'month', 'day']])
    df.set_index(['sat', 'scene', 'datetime'], inplace=True)
    df.sort_index(inplace=True)

    ## Split scenes into train / val / test.
    with open(split_json, encoding='utf-8') as f:
        partition = json.load(f)['split']
    df['split'] = df.index.get_level_values('scene').map(partition)
    return df
class SN7Dataset(Dataset):
    """
    SpaceNet7 dataset.

    Each item is the (T, C, H, W) time stack of one spatial patch of one scene,
    passed through `transform`. Without tiling, one item is one full scene.
    """

    def __init__(
        self,
        product: str = 'planet',
        root: str = MNT_SN7,
        df: Optional[pd.DataFrame] = None,
        window_size: Optional[int] = None,
        stride: Optional[int] = None,
        date_range: Optional[pd.DatetimeIndex] = None,
        bands: Optional[List[int]] = None,
        mode: str = 'train',
        transform: Optional[Callable] = None,
        precache: bool = False,
    ) -> None:
        """
        Args:
            product: name of the product, i.e. the `sat` level of the index (default='planet').
            root: data root path (default=MNT_SN7). Used to build the index when `df` is None.
            df: (optional), dataframe of image paths, see spacenet7_index().
            window_size: (optional) size of the square tiling window. By default, no tiling is performed.
            stride: (optional) stride of the sliding window, which regulates the overlap
                between the extracted patches. Defaults to `window_size` (no overlap).
            date_range (pandas.DatetimeIndex, optional): keep only the images whose date is in it.
            bands: band indices to read. All bands are read by default.
            mode: one of [train (default) | val | test].
            transform: (optional) torchvision transform. Defaults to `torch.as_tensor`.
            precache: (default=False) pre-load all patches in memory. Use with caution.

        Raises:
            ValueError: if `mode` is not one of train / val / test.
            TypeError: if `date_range` is neither None nor a pandas.DatetimeIndex.
        """
        if mode not in ['train', 'val', 'test']:
            raise ValueError(f"mode must be one of [train | val | test]. Got {mode}.")
        ## Validate before doing any (potentially slow) indexing work.
        if date_range is not None and not isinstance(date_range, pd.DatetimeIndex):
            raise TypeError(f'date_range must be a pandas.DatetimeIndex. Got {type(date_range)}.')
        self.product = product
        self.root = root
        self.date_range = date_range
        self.window_size = window_size
        self.stride = stride or window_size
        self.bands = bands if bands is not None else slice(None)  # All bands by default
        self.mode = mode
        self.transform = transform or self.default_transform()
        self.xarrays: Dict[str, xarray.DataArray] = {}  # Scene readers
        self.patches: List[Tuple[str, Tuple[slice, slice]]] = []  # (scene, (y_slice, x_slice))
        self._cache: Dict[int, np.ndarray] = {}  # Raw patches, filled only by _precache()

        self.df = df if df is not None else spacenet7_index(data_dir=root)
        ## Restrict to images of this product, so that every kept scene has images to read.
        self.df = self.df[self.df.index.get_level_values('sat') == product]
        ## Restrict to scenes of this mode.
        self.df = self.df.query(f'split=="{mode}"')
        ## Restrict to images of this date range.
        if date_range is not None:
            ix = self.df.index.get_level_values('datetime').isin(date_range)
            self.df = self.df.loc[ix]
        self.scenes = self.df.index.get_level_values('scene').unique()

        self._prepare_dataset()
        if precache:
            self._precache()

    def default_transform(self) -> Callable:
        """ Transform used when none is given: convert the array to a tensor. """
        return transforms.Compose([torch.as_tensor])

    def _prepare_dataset(self) -> None:
        """ Open the TIFF images of each scene as one time-stacked DataArray, then tile the scenes. """
        for scene in tqdm(self.scenes, desc='Reading scenes...'):
            paths = self.df.loc[(self.product, scene)]['path']  # Series indexed by datetime
            # NOTE(review): xarray.open_rasterio is deprecated in recent xarray releases
            # (rioxarray.open_rasterio replaces it) -- confirm against the pinned version.
            X = [xarray.open_rasterio(path, parse_coordinates=False)[self.bands] for path in paths]
            dates = xarray.Variable('time', paths.index)
            self.xarrays[scene] = xarray.concat(X, dim=dates)
        self.patches = self._extract_patch_slices()

    def _extract_patch_slices(self) -> List[Tuple[str, Tuple[slice, slice]]]:
        """
        Extracts slices by tiling every scene.

        Returns:
            A list of tuples, with the scene ID and a slice window (tuple of y and x slices).
            Windows that would extend past the image border are not produced.
        """
        def tile_slices(
            image_size: Tuple[int, int],
            window_size: Tuple[int, int],
            stride: Tuple[int, int],
        ) -> List[Tuple[slice, slice]]:
            ims, ws, s = image_size, window_size, stride
            y_slices = [slice(i, i + ws[0]) for i in np.arange(0, ims[0] - ws[0] + 1, s[0])]
            x_slices = [slice(i, i + ws[1]) for i in np.arange(0, ims[1] - ws[1] + 1, s[1])]
            return list(itertools.product(y_slices, x_slices))

        patches: List[Tuple[str, Tuple[slice, slice]]] = []
        for scene in self.scenes:
            ims = self.xarrays[scene].shape[2:]  # (H, W)
            ws = (self.window_size, self.window_size) if self.window_size is not None else ims
            s = (self.stride, self.stride) if self.stride is not None else ws
            # Pair each window with the scene it was cut from.
            patches += [(scene, window) for window in tile_slices(image_size=ims, window_size=ws, stride=s)]
        return patches

    def _precache(self) -> None:
        """ Read every patch once and keep the raw arrays in memory. """
        for i in tqdm(range(len(self)), desc='Pre-caching...'):
            self._cache[i] = self._get_image(i)

    def __len__(self) -> int:
        return len(self.patches)

    def _get_image(self, index: int) -> np.ndarray:
        """ Read patch `index` as a fresh (T, C, H, W) array. """
        scene, (y_slice, x_slice) = self.patches[index]
        return self.xarrays[scene][:, :, y_slice, x_slice].data.copy()  # (T, C, H, W)

    def __getitem__(self, index: int) -> Union[Tensor, xarray.DataArray]:
        """
        Returns the transformed (T, C, H, W) patch `index`.

        Raw arrays (not transformed outputs) are cached, so the transform runs on every
        access and random augmentations are re-drawn. Cached arrays are copied first,
        because `torch.as_tensor` shares memory with its numpy input and in-place ops on
        the returned tensor would otherwise corrupt the cache.
        """
        cached = self._cache.get(index)
        x = cached.copy() if cached is not None else self._get_image(index)  # (T, C, H, W)
        return self.transform(x)

    def __repr__(self) -> str:
        if self.date_range is not None:
            date_str = f"start={self.date_range.min().date()}, end={self.date_range.max().date()}"
        else:
            date_str = "start=None, end=None"
        return (
            f"root={self.root}\n"
            f"split={self.mode}\n"
            f"{date_str}\n"
            f"scenes={len(self.scenes)}, samples={len(self)}\n"
            f"window_size={self.window_size}"
        )
class ConcatDataset(Dataset):
    """
    One dataset to rule them all: zips several named datasets by index.

    Item `i` is the dict `{name: datasets[name][i]}`. The length is that of the
    shortest dataset (a UserWarning is issued when the lengths differ), so every
    index in `range(len(self))` is valid for all datasets. Note that, unlike
    `torch.utils.data.ConcatDataset`, the datasets are paired, not chained.
    """

    def __init__(self, **datasets: Dataset) -> None:
        super().__init__()
        lengths = [len(ds) for ds in datasets.values()]
        if len(set(lengths)) > 1:
            warnings.warn('Datasets with non-equal length.', category=UserWarning)
        # No datasets -> empty dataset rather than a crash in min().
        self._len = min(lengths, default=0)
        self.datasets = datasets

    def __getitem__(self, index: int) -> Dict[str, Tensor]:
        """
        Returns the items at `index` of all datasets, keyed by dataset name.

        Raises:
            IndexError: if `index` is outside [-len(self), len(self)).
        """
        # Resolve negative indices against the common length; forwarding them as-is
        # would pick a different position in each dataset of unequal length.
        if index < 0:
            index += self._len
        if not 0 <= index < self._len:
            raise IndexError(f'index out of range for ConcatDataset of length {self._len}.')
        return {name: ds[index] for name, ds in self.datasets.items()}

    def __len__(self) -> int:
        return self._len
class TransformDataset(Dataset):
    """
    Wraps a dataset and applies a transform to each item when it is accessed.

    Args:
        dataset: the dataset to wrap; only its `__getitem__` and `__len__` are used.
        transform: callable applied to every item (e.g. a torchvision transform).
    """

    def __init__(self, dataset: Dataset, transform: Callable) -> None:
        super().__init__()
        self.dataset, self.transform = dataset, transform

    def __len__(self) -> int:
        # One transformed item per wrapped item.
        return len(self.dataset)

    def __getitem__(self, index: int) -> Any:
        fetch = self.dataset.__getitem__
        return self.transform(fetch(index))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment