This is a pretty experimental WMSDataset, but you are welcome to copy it and play with it. Basically, you build it with a URL string pointing to the server's capabilities document, the resolution you want to use and a layer name. You can optionally pass a CRS string; otherwise the first one provided by the server will be used.
aerial = WMSDataset("http://localhost:8080/geoserver/ows?service=WMS&version=1.3.0&request=GetCapabilities",
1.0, layer="ian:aerial_rgb25cm", crs="27700")
...
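import matplotlib.pyplot as plt
from torchgeo.datasets.utils import BoundingBox
from torchgeo.samplers import RandomGeoSampler, Units
# Draw one random window of size 500 (in CRS units) from inside the region of interest
sampler = RandomGeoSampler(train_dset, size=500, length=1, units=Units.CRS,
    roi=BoundingBox(290000.0,550000.0,390000.0,620000.0,0,0))
batch = next(iter(dataloader))
# Show three bands of the first image, moving channels last for matplotlib
plt.imshow(batch['image'][0][1:4].permute(1,2,0))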
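To make the resolution and window size concrete, here is a minimal sketch of how one sampled window could be turned into a WMS 1.3.0 GetMap request. This is not the WMSDataset implementation, just an illustration of the arithmetic: the helper name `getmap_url`, the example bounding box and the `image/png` format are assumptions, and the `EPSG:27700` spelling of the CRS is what the WMS request itself expects rather than what the dataset constructor takes.

```python
from urllib.parse import urlencode


def getmap_url(base_url, layer, crs, bbox, res, fmt="image/png"):
    """Build a WMS 1.3.0 GetMap URL for one window (hypothetical helper).

    bbox is (minx, miny, maxx, maxy) in CRS units; res is CRS units per pixel.
    """
    minx, miny, maxx, maxy = bbox
    # Pixel dimensions are the window extent divided by the resolution.
    width = round((maxx - minx) / res)
    height = round((maxy - miny) / res)
    params = {
        "service": "WMS",
        "version": "1.3.0",
        "request": "GetMap",
        "layers": layer,
        "styles": "",
        "crs": crs,
        # EPSG:27700 uses easting, northing order. Geographic CRSs such as
        # EPSG:4326 expect latitude first in WMS 1.3.0, so this would need swapping.
        "bbox": f"{minx},{miny},{maxx},{maxy}",
        "width": width,
        "height": height,
        "format": fmt,
    }
    return f"{base_url}?{urlencode(params)}", (width, height)


# A 500 m x 500 m window at 1.0 m per pixel (made-up coordinates inside the ROI)
url, shape = getmap_url(
    "http://localhost:8080/geoserver/ows",
    layer="ian:aerial_rgb25cm",
    crs="EPSG:27700",
    bbox=(300000.0, 560000.0, 300500.0, 560500.0),
    res=1.0,
)
print(shape)
print(url)
```

With `size=500` in CRS units and a resolution of 1.0, each sample asks the server for a 500 x 500 pixel image, so every random sample is a fresh request to the WMS.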
I'd recommend cascading a remote WMS through a local GeoServer or some other cache, so that your repeated requests don't really annoy the owner of the WMS you are using. If I ever finish writing the current paper, I'll see if the TorchGeo guys want it.
Licence: MIT
Acknowledgement: The author received support from the UK Research and Innovation Future Leaders Fellowships "Missing Data as Useful Data", grant number MR/Y011856/1, and "Indicative Data: Extracting 3D Models of Cities from Unavailability and Degradation of Global Navigation Satellite Systems (GNSS)", grant number MR/S01795X/2, and from the Alan Turing Institute-DSO partnership project on "Multi-Lingual and Multi-Modal Location Information Extraction".