Skip to content

Instantly share code, notes, and snippets.

Created February 9, 2024 20:24
Show Gist options
  • Save satyaog/eb664202daba42dfb60d0d9bf2883c5f to your computer and use it in GitHub Desktop.
Save satyaog/eb664202daba42dfb60d0d9bf2883c5f to your computer and use it in GitHub Desktop.
from dataclasses import dataclass
import io
import os
from pathlib import Path
import re
import tarfile
import datasets
from datasets.data_files import DataFilesDict
from import DownloadManager
from datasets.features import Features
from import DatasetInfo
import numpy as np
from huggingface_hub import HfFileSystem
import huggingface_hub
def strip(string: str, chars=r"\s"):
    """Remove the run of trailing characters matched by ``chars``.

    ``chars`` is a regular-expression fragment describing one character to
    strip; the default matches any whitespace. The fragment is repeated one
    or more times and anchored to the very end of ``string``. Pass a
    character class such as ``r"[.,;]"`` to strip several different
    characters.

    Returns ``string`` unchanged when it does not end with a match.
    """
    # `\Z` anchors to the true end of the string. `$` also matches right
    # before a final newline, so strip("a..\n", r"\.") used to return "a\n"
    # even though the string does not end with dots.
    return re.sub(rf"{chars}+\Z", "", string)
class SplittedFile(io.RawIOBase):
class Split:
name: str
pos: int
size: int
def __init__(self, filesplits: list, mode: str = "rb") -> None:
self._splits = [
self.Split(fn, 0, Path(fn).stat().st_size)
for fn in filesplits
self._size = 0
for split in self._splits:
split.pos = self._size
self._size += split.size
self._file: io.IOBase = None
self._mode = mode
self._split_index = None
def __enter__(self):
if self.closed:
return self
def __exit__(self, *args, **kwargs):
del args, kwargs
def close(self) -> None:
if not self.closed:
self._file = None
self._split_index = None
def closed(self):
return self._file is None or self._file.closed
def _current_split(self) -> "SplittedFile.Split | None":
return self._splits[self._split_index] if self._split_index is not None else None
def flush(self) -> None:
def isatty(self) -> bool:
return False
def readable(self) -> bool:
return True
def read(self, size: int = -1) -> bytes | None:
buffer = np.empty(size if size > -1 else self._size, dtype="<u1")
size = self.readinto(memoryview(buffer))
return bytes(buffer[:size])
def readall(self) -> bytes:
def readinto(self, buffer: io.IOBase) -> int | None:
if not isinstance(buffer, memoryview):
buffer = memoryview(buffer)
cum_bytes_read = 0
while cum_bytes_read < len(buffer):
bytes_read = self._file.readinto(buffer[cum_bytes_read:])
cum_bytes_read += bytes_read
if not bytes_read:
if self._split_index + 1 >= len(self._splits):
# Open the next split to read from
self._open_split(self._split_index + 1)
return cum_bytes_read
def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
if self.closed:
raise ValueError("seek of closed file")
if whence == io.SEEK_CUR:
offset += self.tell()
elif whence == io.SEEK_END:
offset += self._size
for i, split in enumerate(self._splits):
if offset < split.pos + split.size:
self._open_split(i, offset - split.pos)
return offset
def seekable(self) -> bool:
return True
def tell(self) -> int:
if self.closed:
raise ValueError("I/O operation on closed file")
return self._current_split.pos + self._file.tell()
def writable(self) -> bool:
return False
def _open_split(self, split_index, split_offset=0) -> None:
split = self._splits[split_index]
if split_index != self._split_index:
self._file = open(, self._mode)
self._split_index = split_index
# Hugging Face `datasets` builder for a Hub dataset repository whose data
# files are tar archives, possibly split into several parts. One split is
# generated per key of `data_files`. Each archive member for which
# `tarfile.extractfile` returns a file object becomes one example
# {"filename": <member path>, "bytes": <member content>}.
# NOTE(review): this listing lost all indentation and several tokens when it
# was copied. Examples: the `**config_kwargs` parameter and the `):` that
# closes __init__'s signature, the closing `"""` of the method docstrings,
# and the right-hand sides at `downloaded_files = {s: ...}`, `tf =`,
# `tarinfo =` and `b =`. It does not parse as shown; restore it from the
# original gist before relying on any of the notes below.
class MilaDatasetBuilder(datasets.GeneratorBasedBuilder):
def __init__(
self, cache_dir: str | None = None, dataset_name: str | None = None,
config_name: str | None = None, hash: str | None = None, base_path:
str | None = None, info: DatasetInfo | None = None, features:
Features | None = None, token: bool | str | None = None,
use_auth_token="deprecated", repo_id: str | None = None, data_files:
str | list | dict | DataFilesDict | None = None, data_dir: str |
None = None, storage_options: dict | None = None, writer_batch_size:
int | None = None, name="deprecated",
# Default the dataset name to the repo id with "/" replaced by "___".
if not dataset_name:
dataset_name = repo_id.replace("/", "___")
# Without an explicit base_path, fall back to config_name. Anything that is
# not a local directory becomes an hf://datasets/<repo_id>@<version>/...
# URL, with the version taken from config_kwargs["version"].
# NOTE(review): `config_kwargs.get('version', self.DEFAULT_VERSION)`
# evaluates self.DEFAULT_VERSION even when "version" is given. It is not
# defined in this class, so confirm that the base class provides it.
if not base_path and config_name:
base_path = config_name
if not os.path.isdir(base_path or ""):
base_path = f"hf://datasets/{repo_id}@{config_kwargs.get('version', self.DEFAULT_VERSION)}/{base_path or ''}".rstrip("/")
# A "*" data_files key is renamed to datasets.Split.ALL unless ALL is
# already present. This mutates the mapping passed in by the caller.
if any(isinstance(data_files, t) for t in (dict, DataFilesDict)):
if "*" in data_files and datasets.Split.ALL not in data_files:
data_files[datasets.Split.ALL] = data_files["*"]
del data_files["*"]
super().__init__(cache_dir, dataset_name, config_name, hash, base_path, info, features, token, use_auth_token, repo_id, data_files, data_dir, storage_options, writer_batch_size, name, **config_kwargs)
# Keep this builder's downloads under <downloads dir>/<repo_id>/<version>.
self._cache_downloaded_dir = str(Path(self._cache_downloaded_dir) / self.repo_id / self._version())
def _build_cache_dir(self):
cache_dir = super()._build_cache_dir()
version = self._version()
# NOTE(review): the expression wrapped in str(...) was lost. As shown,
# `str()` is "" and `"" / self.repo_id` raises TypeError. Presumably the
# intent is to nest the cache dir under <repo_id>/<version>, like
# _cache_downloaded_dir in __init__ — confirm.
return str(
) / self.repo_id / version
# Returns an empty DatasetInfo: no features are declared up front.
def _info(self) -> DatasetInfo:
"""Construct the DatasetInfo object. See `DatasetInfo` for details.
Warning: This function is only called once and the result is cached for all
following .info() calls.
info: (DatasetInfo) The dataset information
return DatasetInfo()
# Downloads the data files of every split. For each downloaded file it
# computes a path next to it that mirrors the file's path inside the repo
# (`symlink`), and passes those paths to _generate_examples as
# gen_kwargs["filepaths"], one SplitGenerator per split.
def _split_generators(self, dl_manager: DownloadManager):
"""Specify feature dictionary generators and dataset splits.
This function returns a list of `SplitGenerator`s defining how to generate
data and what splits to use.
return [
gen_kwargs={'file': ''},
gen_kwargs={'file': ''},
The above code will first call `_generate_examples(file='')`
to write the train data, then `_generate_examples(file='')` to
write the test data.
Datasets are typically split into different subsets to be used at various
stages of training and evaluation.
Note that for datasets without a `VALIDATION` split, you can use a
fraction of the `TRAIN` data for evaluation as you iterate on your model
so as not to overfit to the `TEST` data.
For downloads and extractions, use the given `download_manager`.
Note that the `DownloadManager` caches downloads, so it is fine to have each
generator attempt to download the source data.
A good practice is to download all data in this function, and then
distribute the relevant parts to each split with the `gen_kwargs` argument
dl_manager (`DownloadManager`):
Download manager to download the data
# NOTE(review): the download expression after `s:` is missing; presumably
# dl_manager.download(files) — confirm.
downloaded_files = {s: for s, files in self.config.data_files.items()}
symlinks = {}
for s in self.config.data_files:
for _file, _downloaded_file in zip(self.config.data_files[s], downloaded_files[s]):
# Drop everything up to the repo id, then the first path component
# (presumably the "@<revision>" part of an hf:// URL). What remains is the
# file's path inside the repository.
url_path = _file.split(self.repo_id)[-1]
url_path = "/".join(url_path.split("/")[1:])
symlink = Path(_downloaded_file).parent / url_path
# NOTE(review): the statements under these two `if`s were lost, presumably
# creating the parent directory and the symlink to _downloaded_file. So was
# the `.append(...)` after setdefault. As shown, nothing is ever added to
# the per-split lists — confirm against the original.
if not symlink.parent.exists():
if not symlink.exists():
symlinks.setdefault(s, [])
return [
datasets.SplitGenerator(name=s, gen_kwargs={"filepaths": files})
for s, files in symlinks.items()
# Reads the concatenated parts of a split (via SplittedFile) as one tar
# archive and yields one (key, example) pair per member that has a file
# object, with consecutive integer keys starting at 0.
def _generate_examples(self, filepaths, **_kwargs):
"""Default function generating examples for each `SplitGenerator`.
This function preprocess the examples from the raw data to the preprocessed
dataset files.
This function is called once for each `SplitGenerator` defined in
`_split_generators`. The examples yielded here will be written on
**kwargs (additional keyword arguments):
Arguments forwarded from the SplitGenerator.gen_kwargs
key: `str` or `int`, a unique deterministic example identification key.
* Unique: An error will be raised if two examples are yield with the
same key.
* Deterministic: When generating the dataset twice, the same example
should have the same key.
Good keys can be the image id, or line number if examples are extracted
from a text file.
The key will be hashed and sorted to shuffle examples deterministically,
such as generating the dataset multiple times keep examples in the
same order.
example: `dict<str feature_name, feature_value>`, a feature dictionary
ready to be encoded and written to disk. The example will be
encoded with `{...})`.
id_ = 0
with SplittedFile(filepaths) as sf:
# NOTE(review): the following appear to be missing from this listing:
# the tarfile.open(fileobj=sf, ...) call, presumably `tf.next()` for
# `tarinfo`, the `break` under `if tarinfo is None:`, and presumably
# `f.read()` for `b`. Confirm against the original.
tf =
while True:
tarinfo =
if tarinfo is None:
f = tf.extractfile(tarinfo)
if f is not None:
b =
# NOTE(review): debugging output. It prints the MD5 of every member and
# runs `import hashlib` on each iteration; consider removing.
import hashlib
print(f"{hashlib.md5(b).hexdigest()} {tarinfo.path}")
yield id_, {"filename":tarinfo.path, "bytes":b}
id_ += 1
# Version used in cache paths. It is config.version when that is above
# "0.0.0". Otherwise it is whatever follows the repo id in base_path, with
# leading/trailing "@" removed (e.g. "<revision>" for
# hf://datasets/<repo_id>@<revision>).
# NOTE(review): if base_path ends with a sub-directory (e.g. config_name),
# that sub-path is included in the returned value too — confirm intended.
def _version(self):
return self.config.version if self.config.version > "0.0.0" else self.base_path.split(self.repo_id)[-1].strip("@")
if __name__ == "__main__":
    # Map each split name to the glob of its (possibly split) tar.gz archive
    # parts inside the Hub repository, then download and prepare them all.
    split_patterns = {
        "S1": ["S1/**.tar.gz*"],
        "S2": ["S2/**.tar.gz*"],
    }
    builder = MilaDatasetBuilder(
        repo_id="satyaortiz-gagne/bigearthnet",
        data_files=split_patterns,
    )
    builder.download_and_prepare()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment