@ed1d1a8d · Last active July 14, 2022 21:51
ffcv-tqdm-thread-leak: a minimal reproduction of a thread leak that appears when an ffcv Loader is iterated inside tqdm.
gen_beton.py:

import torch.utils.data
import torchvision

from ffcv.fields import IntField, RGBImageField
from ffcv.writer import DatasetWriter

# Write the first 64 CIFAR-10 test images and labels into an ffcv .beton file.
ds = torch.utils.data.Subset(
    dataset=torchvision.datasets.CIFAR10(
        "/var/tmp", train=False, download=True
    ),
    indices=range(64),
)

writer = DatasetWriter(
    "test.beton",
    {"image": RGBImageField(write_mode="raw"), "label": IntField()},
    num_workers=4,
)
writer.from_indexed_dataset(ds)
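# A minimal sanity check (a sketch, assuming test.beton lands in the current
# working directory, as in the writer call above): confirm the file exists
# and is non-empty before moving on to the loader.
import os

assert os.path.isfile("test.beton"), "test.beton was not written"
print(f"test.beton size: {os.path.getsize('test.beton')} bytes")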
Reproduction script:

import os

import psutil
from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
from ffcv.loader import Loader
from ffcv.transforms import ToDevice, ToTensor
from tqdm.auto import tqdm

CUR_PROCESS = psutil.Process(os.getpid())
MAX_THREADS: int = 0


def print_max_threads_encountered(idx: int):
    # psutil reports the OS-level thread count of this process, so it also
    # catches threads created by native code.
    global MAX_THREADS
    cur_threads = CUR_PROCESS.num_threads()
    MAX_THREADS = max(cur_threads, MAX_THREADS)
    print(f"Max threads: {MAX_THREADS}; cur threads: {cur_threads}; idx={idx}")
def get_loader(
    batch_size: int,
    num_workers: int,
    device: str = "cpu",  # BUG occurs on "cpu" or "cuda"!
) -> Loader:
    label_pipeline = [
        IntDecoder(),
        ToTensor(),
        ToDevice(device),
    ]
    image_pipeline = [
        SimpleRGBImageDecoder(),
        ToTensor(),
        ToDevice(device),
    ]
    return Loader(
        "test.beton",
        batch_size=batch_size,
        num_workers=num_workers,
        os_cache=True,  # BUG occurs with os_cache=False or True
        pipelines={"image": image_pipeline, "label": label_pipeline},
    )
def main():
    loader = get_loader(batch_size=4, num_workers=4)

    cnt: int = 0
    while True:
        # This has a thread leak
        for _ in tqdm(loader):
            pass

        # This also has a thread leak
        # with tqdm(loader) as pbar:
        #     for _ in pbar:
        #         pass

        # Without tqdm, there is no thread leak!
        # for _ in loader:
        #     pass

        # Manual tqdm is also okay!
        # with tqdm(total=len(loader)) as pbar:
        #     for _ in loader:
        #         pbar.update(1)

        cnt += 1
        print_max_threads_encountered(idx=cnt)


if __name__ == "__main__":
    main()
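With the leaking variants, the "Max threads" and "cur threads" values printed after each pass over the loader keep climbing as cnt grows; with the non-leaking variants they should plateau once the loader's worker threads have been created.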
Pinned package versions:

ffcv==0.0.3
torch==1.12.0
torchvision==0.13.0
tqdm==4.64.0
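Note that the reproduction script also imports psutil, which is not pinned above; any reasonably recent release should work, since only Process.num_threads() is used. The pinned packages can be installed with, for example, pip install ffcv==0.0.3 torch==1.12.0 torchvision==0.13.0 tqdm==4.64.0 (ffcv additionally requires its system-level dependencies, as described in its installation instructions).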
You can either run the gen_beton.py script above to generate the test.beton file, or download it from
https://drive.google.com/file/d/1d_XT5CAG9MxZ8gIao5qOxNVbF8HNuXfg/view?usp=sharing
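To fetch it programmatically, a minimal sketch, assuming the third-party gdown package is installed (plain urllib does not handle Google Drive's confirmation flow), would be:

import gdown

# fuzzy=True lets gdown extract the file id from a .../file/d/<id>/view URL.
gdown.download(
    "https://drive.google.com/file/d/1d_XT5CAG9MxZ8gIao5qOxNVbF8HNuXfg/view?usp=sharing",
    "test.beton",
    fuzzy=True,
)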