A first pass at code to produce tiles for CNN training
import xarray as xr
import numpy as np
# This is written for _all.nc files but I'm not sure those have the data they should - should verify.
# This is a starting point, nothing here is holy, any of it can be changed according to needs I've failed to anticipate.
# Good luck!
def create_tiles(ds: xr.Dataset, tile_size: int = 32, overlap: float = 0.1) -> xr.Dataset:
    """
    Creates a series of tiles for each time step in the dataset.

    Parameters:
    - ds (xarray.Dataset): The xarray dataset to be processed.
    - tile_size (int): Size of the tiles to be produced in both x and y dimensions. Default is 32.
      NOTE: these netcdfs have pretty small x,y dimensions, so while I'm used to seeing 256x256, we'll need 'em smaller.
      ALSO NOTE: I originally wrote this with 256 as the default, so the examples below sometimes say 256. It's late.
    - overlap (float): Fraction of overlap between adjacent tiles. Default is 0.1 (10% overlap).

    The function will:
    1. Iterate through the x and y dimensions to create tiles with slight overlap for each time step.
    2. Pad tiles where the input dimensions are not evenly divisible by the tile size.
    3. Assign new coordinates to each tile in the range 1 to tile_size (e.g., if tile_size is 32, coordinates will run from 1 to 32).
       This step ensures each tile has consistent spatial coordinates, making downstream processing simpler.
    4. Concatenate all tiles along a new 'tiles' dimension.

    Returns:
    - xarray.Dataset: A new dataset containing the tiled data.

    The purpose of this function is to preprocess a given wildfire dataset by splitting it into smaller,
    consistently sized spatial tiles, so that the input fed into a convolutional neural network is uniform
    in size. That's probably a good idea for training efficiency/performance. We use overlapping windows so
    that spatial context at tile boundaries can be learned from, and buffering/padding so that edge tiles
    have the correct dimensions. This preprocessing step makes the data ready for training.

    Note: Padding introduces artificial data that could affect model training. Consider the impact of
    zero-padding on model performance. There are alternatives, but they're not great either (a commented-out
    sketch of them sits next to the pad call below):
    - Zero Padding: The most common form of padding; zeros are added around the borders of the input tile.
      Simple and computationally efficient, but can introduce artificial boundaries. Zero values do not
      naturally occur in most data, but neural networks are pretty good at approximating arbitrary functions ¯\_(ツ)_/¯
    - Reflect Padding: Borders are padded with values that mirror the adjacent pixels. Preserves edge
      continuity and is often preferred when natural continuity is important (e.g., image data).
    - Constant Padding: Padding with a constant value other than zero. Less common, but can be used if a
      specific value has significance (e.g., the mean value of the dataset).
    - Replicate Padding: Extends the border values into the padding area, repeating the last pixel value.
      Also preserves edge continuity, similar to reflect padding.
    """
    # Validate the overlap parameter.
    if not (0 <= overlap < 1):
        raise ValueError("overlap must be in the range [0, 1)")
    # It's worth making sure that all input features are present in the dataset (the targets are derived
    # from two of them, so checking the inputs covers everything).
    # In the netcdfs I was looking at, some of these values were just empty across the board. Looked into it and
    # it appeared as though both h5netcdf and netcdf4 were failing to open some of these files. Examples to investigate:
    # s3://maap-ops-workspace/shared/jiannamar07/wildfire_ids_14/F10146_fireline.nc
    # s3://maap-ops-workspace/shared/jiannamar07/wildfire_ids_14/F10146_perimeter.nc
    input_features: list[str] = ['Fuels', 'GPM.LATE.v5_FWI', 'GEOS-5_FWI', 'isPerimeter', 'isFireline']
    target_features: list[str] = ['isPerimeter_target', 'isFireline_target']
    for var in input_features:
        if var not in ds:
            raise KeyError(f"Variable '{var}' not found in the dataset")
    # Combine all input features into a single DataArray along a new dimension called 'combined_channels'.
    # These variables carry a 'channels' dimension; we take only the first entry via isel(channels=0).
    input_da: xr.DataArray = xr.concat([ds[var].isel(channels=0) for var in input_features], dim='combined_channels')
    # Assign the names of the combined channels (input features) to the newly created dimension.
    input_da = input_da.assign_coords(combined_channels=input_features)
    # Shift the target features back in time by 1 step so that the target at time t is the observed
    # state at time t+1; we're predicting future values. (Note: shift(time=1) would pair each input
    # with the *previous* state, the opposite of what we want here.)
    # For simplicity, fill missing values (due to the shift) with 0; could consider dropping the last time step instead.
    is_perimeter_da: xr.DataArray = ds['isPerimeter'].isel(channels=0).shift(time=-1, fill_value=0)
    is_perimeter_da = is_perimeter_da.expand_dims('combined_channels').assign_coords(combined_channels=['isPerimeter_target'])
    is_fireline_da: xr.DataArray = ds['isFireline'].isel(channels=0).shift(time=-1, fill_value=0)
    is_fireline_da = is_fireline_da.expand_dims('combined_channels').assign_coords(combined_channels=['isFireline_target'])
    # Concatenate the input features and target features into one combined DataArray along the 'combined_channels' dimension.
    combined_da: xr.DataArray = xr.concat([input_da, is_perimeter_da, is_fireline_da], dim='combined_channels')
    # Assign the names of all features (input + target) to the 'combined_channels' dimension.
    combined_da = combined_da.assign_coords(combined_channels=input_features + target_features).transpose("x", "y", "time", "combined_channels")
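    # Sanity-check sketch: given the transpose above, the dims should now be exactly
    # (x, y, time, combined_channels). Cheap to leave in while the upstream files are
    # still suspect; drop it if it gets in the way.
    assert combined_da.dims == ('x', 'y', 'time', 'combined_channels')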
    # Calculate the mean and standard deviation over the x, y, and time dimensions for each feature.
    # These stats are stashed in the output attrs so normalization can happen downstream, before training.
    means_dict = {}
    stds_dict = {}
    for feature in input_features + target_features:
        feature_da = combined_da.sel(combined_channels=feature)
        mean_value = feature_da.mean(dim=['x', 'y', 'time'], skipna=True).compute().values.item()
        std_value = feature_da.std(dim=['x', 'y', 'time'], skipna=True).compute().values.item()
        means_dict[f'{feature}_mean'] = mean_value
        stds_dict[f'{feature}_std'] = std_value
        print(f"Feature: {feature}, Mean: {mean_value}, Std: {std_value}")  # Debugging print statement
    tiles: list[xr.DataArray] = []
    # Calculate the step size for moving the tile window, considering the overlap parameter.
    # For example, if tile_size is 32 and overlap is 0.1 (10%), the step size is int(32 * 0.9) = 28.
    # The max(1, ...) guard keeps the window moving even if overlap is very close to 1.
    step_size: int = max(1, int(tile_size * (1 - overlap)))
    # Loop through each time step to generate tiles for each time frame independently.
    for t in range(combined_da.sizes['time']):
        # Extract the data for the current time step.
        time_slice: xr.DataArray = combined_da.isel(time=t)
        # Loop through the x and y dimensions to create tiles.
        for x_start in range(0, combined_da.sizes['x'], step_size):
            for y_start in range(0, combined_da.sizes['y'], step_size):
                # Calculate the end index for the current tile in the x and y dimensions.
                # For example, if x_start is 0 and tile_size is 256, x_end will be min(0 + 256, combined_da.sizes['x']).
                x_end: int = min(x_start + tile_size, combined_da.sizes['x'])
                y_end: int = min(y_start + tile_size, combined_da.sizes['y'])
                # Extract the tile using the calculated indices.
                tile: xr.DataArray = time_slice.isel(x=slice(x_start, x_end), y=slice(y_start, y_end))
                # Calculate how much padding is needed if the tile is smaller than the desired tile size.
                # For example, if x_end is 200 and x_start is 0, pad_x will be 256 - (200 - 0) = 56 (for tiles sized 256).
                pad_x: int = tile_size - (x_end - x_start)
                pad_y: int = tile_size - (y_end - y_start)
                # If padding is needed (i.e., the tile is smaller than the desired size), pad the tile with zeros.
                if pad_x > 0 or pad_y > 0:
                    print(f"Padding tile at x_start: {x_start}, y_start: {y_start}, pad_x: {pad_x}, pad_y: {pad_y}")  # Debugging print statement
                    tile = tile.pad(x=(0, pad_x), y=(0, pad_y), constant_values=0)
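                    # Sketch of the alternative padding modes discussed in the docstring. xarray's
                    # pad mirrors np.pad's modes; untested against these files, so treat this as a
                    # starting point rather than a drop-in:
                    #   tile = tile.pad(x=(0, pad_x), y=(0, pad_y), mode='reflect')  # mirror adjacent pixels
                    #   tile = tile.pad(x=(0, pad_x), y=(0, pad_y), mode='edge')     # replicate the last pixel
                    #   tile = tile.pad(x=(0, pad_x), y=(0, pad_y), constant_values=0.5)  # arbitrary constant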
                # Assign new coordinates to the x and y dimensions of the tile, ranging from 1 to tile_size.
                # For example, if tile_size is 256, the coordinates for x and y will be set from 1 to 256.
                tile = tile.assign_coords(x=np.arange(1, tile_size + 1), y=np.arange(1, tile_size + 1))
                # Add the generated tile to the list of tiles.
                tiles.append(tile)
                print(f"Generated tile with x_start: {x_start}, y_start: {y_start}, shape: {tile.shape}")  # Debugging print statement
    # Concatenate all the tiles along a new 'tiles' dimension.
    tiled_da: xr.DataArray = xr.concat(tiles, dim='tiles')
    # Assign unique indices to the 'tiles' dimension, from 0 to the number of tiles - 1.
    tiled_da = tiled_da.assign_coords(tiles=np.arange(len(tiles)))
    # Create a new dataset with the tiled DataArray and attach the per-feature stats as attributes.
    tiled_dataset: xr.Dataset = xr.Dataset({'combined_features': tiled_da})
    tiled_dataset = tiled_dataset.assign_attrs({**means_dict, **stds_dict})
    return tiled_dataset
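

# create_tiles computes per-feature means/stds and stashes them in the output attrs but never
# applies them. Below is a minimal sketch of how the stored stats might be applied downstream as
# (value - mean) / std. The function name and the choice to normalize after tiling are assumptions,
# not settled decisions; the targets are binary masks, so you may want to exclude those channels
# from normalization entirely.
def normalize_tiles(tiled: xr.Dataset) -> xr.Dataset:
    da = tiled['combined_features']
    normalized = []
    for feature in da.combined_channels.values:
        mean = tiled.attrs[f'{feature}_mean']
        std = tiled.attrs[f'{feature}_std']
        channel = da.sel(combined_channels=feature)
        # Guard against zero variance (e.g., an all-empty mask) to avoid dividing by zero.
        normalized.append((channel - mean) / std if std != 0 else channel - mean)
    out = xr.concat(normalized, dim='combined_channels').transpose(*da.dims)
    return tiled.assign(combined_features=out)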
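

# Minimal usage sketch, assuming a local copy of one of the _all.nc files that actually contains
# the variables expected above. The filenames here are hypothetical:
if __name__ == "__main__":
    ds = xr.open_dataset("F10146_all.nc")
    tiled = create_tiles(ds, tile_size=32, overlap=0.1)
    print(tiled)
    tiled.to_netcdf("F10146_all_tiles.nc")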