Skip to content

Instantly share code, notes, and snippets.

@ian-gla
Last active March 26, 2024 15:45
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ian-gla/49bd29c30b62a4f83b4e53cd15b1b091 to your computer and use it in GitHub Desktop.
Save ian-gla/49bd29c30b62a4f83b4e53cd15b1b091 to your computer and use it in GitHub Desktop.
A WMS dataset for torchgeo

This is a pretty experimental WMSDataset, but you are welcome to copy it and play with it. Basically, you build it from a URL string pointing to the server's GetCapabilities document, the resolution you want to use, and a layer name. You can optionally pass a CRS string; otherwise the first CRS provided by the server will be used.

# Build the dataset from a WMS GetCapabilities URL, a resolution (the `res`
# argument; presumably CRS units per pixel — confirm), a layer name and an EPSG
# code (passed to CRS.from_epsg by WMSDataset.__init__).
aerial = WMSDataset("http://localhost:8080/geoserver/ows?service=WMS&version=1.3.0&request=GetCapabilities", 
                    1.0, layer="ian:aerial_rgb25cm", crs="27700")

...
from torchgeo.datasets.utils import BoundingBox

# NOTE(review): train_dset and dataloader come from the elided part ("...");
# presumably train_dset wraps `aerial` — confirm.
# roi is given as (minx, maxx, miny, maxy, mint, maxt) in CRS units.
sampler = RandomGeoSampler(train_dset, size=500, length=1, units=Units.CRS, 
                           roi=BoundingBox(290000.0,550000.0,390000.0,620000.0,0,0))
batch = next(iter(dataloader))
# Take channels 1..3 of the first image in the batch and move the channel axis
# last (ToTensor yields channel-first tensors) so matplotlib can display it.
plt.imshow(batch['image'][0][1:4].permute(1,2,0))

image

I'd recommend cascading a remote WMS through a local GeoServer or another cache, to avoid really annoying the owner of the WMS you are using. If I ever finish writing the current paper, I'll see whether the TorchGeo maintainers want it.

Licence: MIT

Acknowledgement: The author received support from the UK Research and Innovation Future Leaders Fellowships "Missing Data as Useful Data", grant number MR/Y011856/1, “Indicative Data: Extracting 3D Models of Cities from Unavailability and Degradation of Global Navigation Satellite Systems (GNSS)”, grant number MR/S01795X/2, and the Alan Turing Institute-DSO partnership project on "Multi-Lingual and Multi-Modal Location Information Extraction"

import logging
import sys
from io import BytesIO
from typing import Any

import rasterio
import torchvision.transforms as transforms
from owslib.wms import WebMapService
from PIL import Image
from rasterio.coords import BoundingBox
from rasterio.crs import CRS
from rtree.index import Index, Property
from torch import Tensor
from torchgeo.datasets import GeoDataset
class WMSDataset(GeoDataset):
    """Fetch imagery on demand from an OGC Web Map Service (WMS).

    The dataset serves a single WMS layer. Its spatial index is built from the
    layer's advertised bounding box, and every ``__getitem__`` call issues one
    ``GetMap`` request for the queried bounding box.

    Args:
        url: GetCapabilities URL of the WMS server.
        res: Resolution of the requested imagery in dataset-CRS units per
            pixel; each ``GetMap`` request is sized as query extent / ``res``.
        layer: Name of the layer to serve. If the server does not advertise
            it, no layer is selected and :meth:`layer` must be called before
            sampling.
        transforms: Optional callable applied to each sample dict.
        crs: EPSG code (int or numeric string) of the CRS to request. If
            omitted, the selected layer's first advertised CRS is used.
    """

    _logger = logging.getLogger(__name__)

    _url = None
    _wms = None
    _layer = None
    _layer_name = ""
    # True: samples store the fetched raster under "image"; False: under "mask".
    is_image = True

    # Shared, stateless PIL -> tensor converter (built once, not per request).
    _to_tensor = transforms.ToTensor()

    def __init__(self, url, res, layer=None, transforms=None, crs=None):
        super().__init__(transforms)
        self._url = url
        self._res = res
        # Per-instance list (no shared mutable class-level default).
        self._layers: list[str] = []
        if crs is not None:
            self._crs = CRS.from_epsg(crs)
        self._logger.debug("initial CRS: %s", self.crs)
        self._wms = WebMapService(url)
        self._logger.debug("WMS version: %s", self._wms.identification.version)
        # Request imagery in the first output format the server advertises.
        self._format = self._wms.getOperationByName("GetMap").formatOptions[0]
        self._layers = list(self._wms.contents)
        if layer in self._layers:
            self.layer(layer, crs)

    def layer(self, layer, crs=None):
        """Select ``layer`` as the layer served by this dataset.

        Rebuilds the spatial index from the layer's advertised bounding box and
        sets the dataset CRS.

        Args:
            layer: Name of a layer advertised by the server.
            crs: EPSG code (int or numeric string) to use. If None, the
                layer's first entry in ``crsOptions`` is used.

        Raises:
            KeyError: if the server does not advertise ``layer`` (raised by
                owslib when looking the layer up).
            ValueError: if the layer advertises no bounding box.
        """
        self._layer = self._wms[layer]
        self._layer_name = layer
        coords = self._layer.boundingBox
        if coords is None:
            raise ValueError(f"WMS layer {layer!r} does not advertise a bounding box")
        minx, miny, maxx, maxy = (float(c) for c in coords[:4])
        self._logger.debug("layer %s bounding box: %s", layer, coords)
        # NOTE(review): the box is expressed in the CRS named by coords[4] (if
        # any), which need not match the dataset CRS chosen below — confirm.
        self.index = Index(interleaved=False, properties=Property(dimension=3))
        # interleaved=False expects (minx, maxx, miny, maxy, mint, maxt); the
        # layer has no time dimension, so it spans all time.
        self.index.insert(0, (minx, maxx, miny, maxy, 0, sys.maxsize))
        if crs is not None:
            # Honour an explicit CRS (previously only used as a flag).
            self._crs = CRS.from_epsg(crs)
        else:
            # crsOptions holds capability identifiers (e.g. "EPSG:27700"), not
            # bare EPSG codes, so from_epsg cannot parse them.
            self._crs = CRS.from_user_input(self._layer.crsOptions[0])
        self._logger.debug(
            "layer %s selected; index bounds %s; CRS %s",
            self._layer_name, self.index.bounds, self._crs,
        )

    def getlayer(self):
        """Return the owslib metadata object of the selected layer (or None)."""
        return self._layer

    def layers(self):
        """Return the names of all layers advertised by the server."""
        return self._layers

    def _pixel_size(self, query) -> tuple[int, int]:
        """Return the (width, height) in pixels that covers ``query`` at ``res``."""
        res = self._res
        if isinstance(res, (tuple, list)):
            xres, yres = res
        else:
            xres = yres = res
        if xres <= 0 or yres <= 0:
            raise ValueError(f"resolution must be positive, got {res!r}")
        width = max(1, round((query.maxx - query.minx) / xres))
        height = max(1, round((query.maxy - query.miny) / yres))
        return width, height

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve image/mask and metadata indexed by query.

        Issues one WMS GetMap request for the query's spatial extent, sized so
        that each pixel covers ``res`` dataset-CRS units.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
                (a torchgeo-style bounding box exposing minx/maxx/miny/maxy).

        Returns:
            sample dict with "crs", "bbox" and either "image" or "mask"
            (chosen by ``is_image``), passed through ``transforms`` if set.

        Raises:
            IndexError: if query does not intersect the selected layer's
                extent (including when no layer has been selected).
        """
        if not list(self.index.intersection(tuple(query))):
            raise IndexError(f"query: {query} not found in index")
        width, height = self._pixel_size(query)
        self._logger.debug(
            "query size = %s %s -> %dx%d px",
            query.maxx - query.minx, query.maxy - query.miny, width, height,
        )
        response = self._wms.getmap(
            layers=[self._layer_name],
            srs="epsg:" + str(self.crs.to_epsg()),
            bbox=(query.minx, query.miny, query.maxx, query.maxy),
            size=(width, height),
            format=self._format,
            transparent=True,
        )
        img_tensor = self._to_tensor(Image.open(BytesIO(response.read())))
        sample = {"crs": self.crs, "bbox": query}
        sample["image" if self.is_image else "mask"] = img_tensor
        if self.transforms is not None:
            sample = self.transforms(sample)
        return sample
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment