Skip to content

Instantly share code, notes, and snippets.

@dvruette
Last active January 18, 2024 22:12
Show Gist options
  • Save dvruette/72ecac9c623b89548ed3627d69acdf69 to your computer and use it in GitHub Desktop.
Save dvruette/72ecac9c623b89548ed3627d69acdf69 to your computer and use it in GitHub Desktop.
Faster Grid
diff --git a/grid.py b/grid.py
index f9f1557..8eafb91 100755
--- a/grid.py
+++ b/grid.py
@@ -5,8 +5,10 @@
# Written by Francois Fleuret <francois@fleuret.org>
-import math
-import torch, torchvision
+from concurrent.futures import ProcessPoolExecutor
+import os
+import tqdm
+import torch
import torch.nn.functional as F
######################################################################
@@ -166,24 +168,24 @@ class GridFactory:
"white_smoke",
][:nb_colors]
- def generate_scene(self):
- nb_items = torch.randint(self.max_nb_items - 1, (1,)).item() + 2
+ def generate_scene(self, generator=None):
+ nb_items = torch.randint(self.max_nb_items - 1, (1,), generator=generator).item() + 2
col = torch.full((self.size * self.size,), -1)
shp = torch.full((self.size * self.size,), -1)
- a = torch.randperm(len(self.name_colors) * len(self.name_shapes))[:nb_items]
+ a = torch.randperm(len(self.name_colors) * len(self.name_shapes), generator=generator)[:nb_items]
col[:nb_items] = a % len(self.name_colors)
shp[:nb_items] = a // len(self.name_colors)
- i = torch.randperm(self.size * self.size)
+ i = torch.randperm(self.size * self.size, generator=generator)
col = col[i]
shp = shp[i]
return col.reshape(self.size, self.size), shp.reshape(self.size, self.size)
- def random_transformations(self, scene):
+ def random_transformations(self, scene, generator=None):
col, shp = scene
descriptions = []
- nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,)).item()
- transformations = torch.randint(5, (nb_transformations,))
+ nb_transformations = torch.randint(self.max_nb_transformations + 1, (1,), generator=generator).item()
+ transformations = torch.randint(5, (nb_transformations,), generator=generator)
for t in transformations:
if t == 0:
@@ -284,11 +286,15 @@ class GridFactory:
return properties
- def generate_scene_and_questions(self):
+ def generate_scene_and_questions(self, seed=None): # seed is the sole positional argument so executor.map can pass one seed per call; None leaves the generator unseeded
+ rng = torch.Generator()
+ if seed is not None:
+ rng.manual_seed(seed)
+
while True:
while True:
- start_scene = self.generate_scene()
- scene, transformations = self.random_transformations(start_scene)
+ start_scene = self.generate_scene(generator=rng)
+ scene, transformations = self.random_transformations(start_scene, generator=rng)
true = self.all_properties(scene)
if len(true) >= self.nb_questions:
break
@@ -296,7 +302,7 @@ class GridFactory:
for a in range(10):
col, shp = scene
col, shp = col.view(-1), shp.view(-1)
- p = torch.randperm(col.size(0))
+ p = torch.randperm(col.size(0), generator=rng)
col, shp = col[p], shp[p]
other_scene = (
col.view(self.size, self.size),
@@ -308,8 +314,8 @@ class GridFactory:
# We sometime add properties from a totally different
# scene to have negative "there is a xxx xxx"
# properties
- if torch.rand(1).item() < 0.2:
- other_scene = self.generate_scene()
+ if torch.rand(1, generator=rng).item() < 0.2:
+ other_scene = self.generate_scene(generator=rng)
false += self.all_properties(other_scene)
false = list(set(false) - set(true))
@@ -319,13 +325,13 @@ class GridFactory:
if a < 10:
break
- true = [true[k] for k in torch.randperm(len(true))[: self.nb_questions]]
- false = [false[k] for k in torch.randperm(len(false))[: self.nb_questions]]
+ true = [true[k] for k in torch.randperm(len(true), generator=rng)[: self.nb_questions]]
+ false = [false[k] for k in torch.randperm(len(false), generator=rng)[: self.nb_questions]]
true = ["<prop> " + q + " <ans> true" for q in true]
false = ["<prop> " + q + " <ans> false" for q in false]
union = true + false
- questions = [union[k] for k in torch.randperm(len(union))[: self.nb_questions]]
+ questions = [union[k] for k in torch.randperm(len(union), generator=rng)[: self.nb_questions]]
result = " ".join(
["<obj> " + x for x in self.grid_positions(start_scene)]
@@ -335,15 +341,28 @@ class GridFactory:
return start_scene, scene, result
- def generate_samples(self, nb, progress_bar=None):
+ def generate_samples(self, nb, show_progress=False, num_workers="auto", seed=None):
result = []
- r = range(nb)
- if progress_bar is not None:
- r = progress_bar(r)
-
- for _ in r:
- result.append(self.generate_scene_and_questions()[2])
+ rng = torch.Generator()
+ if seed is None:
+ seed = torch.randint(0, 2 ** 32 - 1, (1,)).item()
+ rng.manual_seed(seed)
+
+ if isinstance(num_workers, str):
+ num_workers = os.cpu_count()
+
+ with tqdm.tqdm(total=nb, smoothing=0.01, disable=not show_progress) as pbar:
+ if num_workers == 1:
+ for _ in range(nb):
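+ result.append(self.generate_scene_and_questions()[2]) # NOTE(review): no seed is passed here, so `seed`/`rng` above are unused on this path and every call builds an unseeded torch.Generator -- confirm samples actually differ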
+ pbar.update()
+ else:
+ with ProcessPoolExecutor(max_workers=num_workers) as executor:
+ seeds = torch.randint(0, 2 ** 32 - 1, (nb,), generator=rng).tolist()
+ for sample in executor.map(self.generate_scene_and_questions, seeds, chunksize=32):
+ result.append(sample[2])
+ pbar.update()
return result
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment