@satyaog
Created February 9, 2024 20:24
MilaDatasetBuilder
from dataclasses import dataclass
import hashlib
import io
import os
from pathlib import Path
import re
import tarfile

import datasets
from datasets.data_files import DataFilesDict
from datasets.download.download_manager import DownloadManager
from datasets.features import Features
from datasets.info import DatasetInfo
import numpy as np
from huggingface_hub import HfFileSystem
import huggingface_hub


def strip(string: str, chars=r"\s"):
    return re.sub(f"{chars}+$", "", string)
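
# Example: `strip` removes a trailing run of `chars` from `string`, e.g.
#   strip("some/cache/dir///", "/") -> "some/cache/dir"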


class SplittedFile(io.RawIOBase):
    """Read-only file object exposing a sequence of file splits as one contiguous stream."""

    @dataclass
    class Split:
        name: str
        pos: int
        size: int

    def __init__(self, filesplits: list, mode: str = "rb") -> None:
        super().__init__()
        self._splits = [
            self.Split(fn, 0, Path(fn).stat().st_size)
            for fn in filesplits
        ]
        self._size = 0
        for split in self._splits:
            split.pos = self._size
            self._size += split.size
        self._file: io.IOBase | None = None
        self._mode = mode
        self._split_index = None

    def __enter__(self):
        if self.closed:
            self._open_split(0)
        return self

    def __exit__(self, *args, **kwargs):
        del args, kwargs
        self.close()

    def close(self) -> None:
        if not self.closed:
            self._file.close()
        self._file = None
        self._split_index = None

    @property
    def closed(self):
        return self._file is None or self._file.closed

    @property
    def _current_split(self) -> "SplittedFile.Split | None":
        return self._splits[self._split_index] if self._split_index is not None else None

    def flush(self) -> None:
        pass

    def isatty(self) -> bool:
        return False

    def readable(self) -> bool:
        return True

    def read(self, size: int = -1) -> bytes | None:
        buffer = np.empty(size if size > -1 else self._size, dtype="<u1")
        size = self.readinto(memoryview(buffer))
        return bytes(buffer[:size])

    def readall(self) -> bytes:
        return self.read(-1)

    def readinto(self, buffer) -> int | None:
        # `buffer` is any writable bytes-like object (bytearray, memoryview, ...)
        if not isinstance(buffer, memoryview):
            buffer = memoryview(buffer)
        cum_bytes_read = 0
        while cum_bytes_read < len(buffer):
            bytes_read = self._file.readinto(buffer[cum_bytes_read:])
            cum_bytes_read += bytes_read
            if not bytes_read:
                if self._split_index + 1 >= len(self._splits):
                    break
                # Open the next split to read from
                self._open_split(self._split_index + 1)
        return cum_bytes_read

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        if self.closed:
            raise ValueError("seek of closed file")
        if whence == io.SEEK_CUR:
            offset += self.tell()
        elif whence == io.SEEK_END:
            offset += self._size
        for i, split in enumerate(self._splits):
            if offset < split.pos + split.size:
                self._open_split(i, offset - split.pos)
                break
        return offset

    def seekable(self) -> bool:
        return True

    def tell(self) -> int:
        if self.closed:
            raise ValueError("I/O operation on closed file")
        return self._current_split.pos + self._file.tell()

    def writable(self) -> bool:
        return False

    def _open_split(self, split_index, split_offset=0) -> None:
        split = self._splits[split_index]
        if split_index != self._split_index:
            self.close()
            self._file = open(split.name, self._mode)
            self._split_index = split_index
        self._file.seek(split_offset)
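

# Minimal usage sketch (the part names below are hypothetical, e.g. the output of
# `split -b`): SplittedFile stitches the parts together so reads and seeks cross
# the split boundaries transparently.
#
#     with SplittedFile(["data.tar.gz.aa", "data.tar.gz.ab"]) as sf:
#         with tarfile.open(fileobj=sf) as tf:
#             first_member = tf.next()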


class MilaDatasetBuilder(datasets.GeneratorBasedBuilder):
    DEFAULT_VERSION = "main"

    def __init__(
        self, cache_dir: str | None = None, dataset_name: str | None = None,
        config_name: str | None = None, hash: str | None = None,
        base_path: str | None = None, info: DatasetInfo | None = None,
        features: Features | None = None, token: bool | str | None = None,
        use_auth_token="deprecated", repo_id: str | None = None,
        data_files: str | list | dict | DataFilesDict | None = None,
        data_dir: str | None = None, storage_options: dict | None = None,
        writer_batch_size: int | None = None, name="deprecated",
        **config_kwargs
    ):
        if not dataset_name:
            dataset_name = repo_id.replace("/", "___")
        if not base_path and config_name:
            base_path = config_name
        if not os.path.isdir(base_path or ""):
            # Resolve base_path to the repo on the Hugging Face Hub at the requested revision
            base_path = f"hf://datasets/{repo_id}@{config_kwargs.get('version', self.DEFAULT_VERSION)}/{base_path or ''}".rstrip("/")
        if any(isinstance(data_files, t) for t in (dict, DataFilesDict)):
            if "*" in data_files and datasets.Split.ALL not in data_files:
                # Map the catch-all "*" pattern onto the ALL split
                data_files[datasets.Split.ALL] = data_files["*"]
                del data_files["*"]
        super().__init__(cache_dir, dataset_name, config_name, hash, base_path, info, features, token, use_auth_token, repo_id, data_files, data_dir, storage_options, writer_batch_size, name, **config_kwargs)
        self._cache_downloaded_dir = str(Path(self._cache_downloaded_dir) / self.repo_id / self._version())

    def _build_cache_dir(self):
        cache_dir = super()._build_cache_dir()
        version = self._version()
        return str(
            Path(
                strip(
                    "/".join(cache_dir.split(self.dataset_name)[0:-1]),
                    "[^a-zA-Z-]"
                )
            ) / self.repo_id / version
        )

    def _info(self) -> DatasetInfo:
        """Construct the DatasetInfo object. See `DatasetInfo` for details.

        Warning: This function is only called once and the result is cached for all
        following .info() calls.

        Returns:
            info: (DatasetInfo) The dataset information
        """
        return DatasetInfo()
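
    # Note: because _info() returns an empty DatasetInfo (no explicit Features),
    # the `datasets` writer will typically infer the schema from the examples
    # yielded by _generate_examples below.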

    def _split_generators(self, dl_manager: DownloadManager):
        """Specify feature dictionary generators and dataset splits.

        This function returns a list of `SplitGenerator`s defining how to generate
        data and what splits to use.

        Example:

            return [
                datasets.SplitGenerator(
                    name=datasets.Split.TRAIN,
                    gen_kwargs={'file': 'train_data.zip'},
                ),
                datasets.SplitGenerator(
                    name=datasets.Split.TEST,
                    gen_kwargs={'file': 'test_data.zip'},
                ),
            ]

        The above code will first call `_generate_examples(file='train_data.zip')`
        to write the train data, then `_generate_examples(file='test_data.zip')` to
        write the test data.

        Datasets are typically split into different subsets to be used at various
        stages of training and evaluation. Note that for datasets without a
        `VALIDATION` split, you can use a fraction of the `TRAIN` data for
        evaluation as you iterate on your model so as not to overfit to the `TEST`
        data.

        For downloads and extractions, use the given `download_manager`. Note that
        the `DownloadManager` caches downloads, so it is fine to have each generator
        attempt to download the source data. A good practice is to download all data
        in this function, and then distribute the relevant parts to each split with
        the `gen_kwargs` argument.

        Args:
            dl_manager (`DownloadManager`):
                Download manager to download the data

        Returns:
            `list<SplitGenerator>`.
        """
        downloaded_files = {
            s: dl_manager.download(files) for s, files in self.config.data_files.items()
        }
        symlinks = {}
        for s in self.config.data_files:
            for _file, _downloaded_file in zip(self.config.data_files[s], downloaded_files[s]):
                url_path = _file.split(self.repo_id)[-1]
                url_path = "/".join(url_path.split("/")[1:])
                symlink = Path(_downloaded_file).parent / url_path
                if not symlink.parent.exists():
                    symlink.parent.mkdir()
                if not symlink.exists():
                    symlink.symlink_to(_downloaded_file)
                symlinks.setdefault(s, [])
                symlinks[s].append(str(symlink))
        return [
            datasets.SplitGenerator(name=s, gen_kwargs={"filepaths": files})
            for s, files in symlinks.items()
        ]
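
    # Illustration (hypothetical file name): a resolved data file such as
    #   hf://datasets/<repo_id>@main/S1/patches.tar.gz.aa
    # is downloaded by dl_manager.download() to a local cached copy; the loop above
    # then symlinks that copy back to its repo-relative path, e.g.
    #   <download_cache>/S1/patches.tar.gz.aa
    # so each split receives its parts under their original, ordered names.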

    def _generate_examples(self, filepaths, **_kwargs):
        """Default function generating examples for each `SplitGenerator`.

        This function preprocesses the examples from the raw data into the
        preprocessed dataset files. It is called once for each `SplitGenerator`
        defined in `_split_generators`. The examples yielded here will be written
        on disk.

        Args:
            **kwargs (additional keyword arguments):
                Arguments forwarded from the SplitGenerator.gen_kwargs

        Yields:
            key: `str` or `int`, a unique deterministic example identification key.
                * Unique: An error will be raised if two examples are yielded with
                  the same key.
                * Deterministic: When generating the dataset twice, the same example
                  should have the same key.
                Good keys can be the image id, or line number if examples are
                extracted from a text file. The key will be hashed and sorted to
                shuffle examples deterministically, so that generating the dataset
                multiple times keeps examples in the same order.
            example: `dict<str feature_name, feature_value>`, a feature dictionary
                ready to be encoded and written to disk. The example will be
                encoded with `self.info.features.encode_example({...})`.
        """
        id_ = 0
        with SplittedFile(filepaths) as sf:
            tf = tarfile.open(fileobj=sf)
            while True:
                tarinfo = tf.next()
                if tarinfo is None:
                    break
                f = tf.extractfile(tarinfo)
                if f is not None:
                    b = f.read()
                    # Debug output: log the checksum and path of each extracted member
                    print(f"{hashlib.md5(b).hexdigest()} {tarinfo.path}")
                    yield id_, {"filename": tarinfo.path, "bytes": b}
                    id_ += 1

    def _version(self):
        return self.config.version if self.config.version > "0.0.0" else self.base_path.split(self.repo_id)[-1].strip("@")


if __name__ == "__main__":
    MilaDatasetBuilder(
        repo_id="satyaortiz-gagne/bigearthnet",
        data_files={"S1": ["S1/**.tar.gz*"], "S2": ["S2/**.tar.gz*"]},
    ).download_and_prepare()
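
    # Usage sketch (assumes the standard GeneratorBasedBuilder API): after
    # download_and_prepare() has run, the prepared splits can be loaded with
    # as_dataset(), e.g.
    #
    #     builder = MilaDatasetBuilder(
    #         repo_id="satyaortiz-gagne/bigearthnet",
    #         data_files={"S1": ["S1/**.tar.gz*"], "S2": ["S2/**.tar.gz*"]},
    #     )
    #     builder.download_and_prepare()
    #     ds = builder.as_dataset(split="S1")  # each example: {"filename": ..., "bytes": ...}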