Skip to content

Instantly share code, notes, and snippets.

@slaymaker1907
Last active June 28, 2024 21:55
Show Gist options
  • Save slaymaker1907/9a12adc7eee311a8d8a5ce97d8737793 to your computer and use it in GitHub Desktop.
Save slaymaker1907/9a12adc7eee311a8d8a5ce97d8737793 to your computer and use it in GitHub Desktop.
Generators in pure Java using (virtual) threads

This shows how forEach-style functions can be converted into full iterators, using virtual threads to emulate continuations.

Note that virtual threads are used only to improve performance; the technique works just as well with regular platform threads.

import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Random;
import java.util.concurrent.Semaphore;
import java.util.function.Consumer;
/**
 * Demonstrates turning internal ("forEach"-style) iteration into external iteration
 * ({@link Iterator}) by running the forEach on a separate virtual thread and handing
 * elements across with a pair of semaphores, emulating a generator.
 */
public class Main {
    /**
     * A mutable single-value holder, used so lambdas can share state that they could not
     * capture as a reassignable local variable.
     *
     * @param <T> type of the held value
     */
    public static class Box<T> {
        private T item;

        public Box(T item) {
            this.item = item;
        }

        public T getItem() {
            return item;
        }

        public void setItem(T item) {
            this.item = item;
        }
    }

    /**
     * Converts a forEach-style function into a lazy {@link Iterator}.
     *
     * <p>{@code forEachFunc} is started on a new virtual thread and given a sink. Each
     * call to the sink blocks until the consumer asks for the next element, so producer
     * and consumer run in lock-step, one element at a time.
     *
     * <p>If {@code forEachFunc} throws, the exception is rethrown from the iterator's
     * {@code hasNext()} / {@code next()} on the consuming thread: unchecked exceptions and
     * errors as-is, anything else wrapped in a {@link RuntimeException}. After that the
     * iterator reports no further elements.
     *
     * <p>Caveat: if the iterator is abandoned before it is exhausted, the producer thread
     * stays parked inside the sink, waiting for a request that never comes.
     *
     * <p>The returned iterator keeps unsynchronized state; confine it to one consuming
     * thread.
     *
     * @param forEachFunc function that feeds every element to the consumer it is given
     * @param <T> element type
     * @return an iterator over the elements {@code forEachFunc} produces, in order
     */
    public static <T> Iterator<T> forEachToIterator(Consumer<Consumer<T>> forEachFunc) {
        // One-slot mailbox from producer to consumer, plus completion/failure flags.
        Box<T> current = new Box<>(null);
        Box<Boolean> threadDone = new Box<>(false);
        Box<Throwable> failure = new Box<>(null);
        // requested: released by the consumer to ask for exactly one element.
        // delivered: released by the producer once it has answered that request.
        // Each release/acquire pair also publishes the Box writes to the other thread.
        Semaphore requested = new Semaphore(0);
        Semaphore delivered = new Semaphore(0);

        Consumer<T> processor = item -> {
            try {
                requested.acquire();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
            current.setItem(item);
            delivered.release();
            // Perf hack: yield so the consumer's virtual thread can pick the item up promptly.
            Thread.yield();
        };

        Runnable runner = () -> {
            try {
                forEachFunc.accept(processor);
            } catch (Throwable t) {
                // Forward the failure instead of dying and leaving the consumer blocked forever.
                failure.setItem(t);
            } finally {
                // Answer the consumer's next request with "done". Uninterruptible so that an
                // interrupt at this point cannot strand the consumer.
                requested.acquireUninterruptibly();
                threadDone.setItem(true);
                delivered.release();
            }
        };
        Thread.ofVirtual().start(runner);

        return new Iterator<T>() {
            private T itCurrent = null;
            // True while itCurrent holds an element that next() has not returned yet.
            private boolean ready = false;
            // True once the producer has reported completion; no more handshakes happen.
            private boolean finished = false;
            // True while a request has been issued but not yet answered (e.g. after an
            // interrupted wait), so that the request is never issued twice.
            private boolean awaiting = false;

            @Override
            public boolean hasNext() {
                if (!ready && !finished) {
                    fetch();
                }
                return ready;
            }

            @Override
            public T next() {
                if (!hasNext()) {
                    throw new NoSuchElementException();
                }
                T item = itCurrent;
                itCurrent = null;
                ready = false;
                return item;
            }

            /** Requests one element from the producer and waits for its answer. */
            private void fetch() {
                if (!awaiting) {
                    requested.release();
                    awaiting = true;
                    // Perf hack: give the producer's virtual thread a chance to run.
                    Thread.yield();
                }
                try {
                    delivered.acquire();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
                awaiting = false;
                if (threadDone.getItem()) {
                    finished = true;
                    rethrowIfFailed(failure.getItem());
                } else {
                    itCurrent = current.getItem();
                    ready = true;
                }
            }
        };
    }

    /**
     * Rethrows a failure captured on the producer thread; does nothing when {@code error}
     * is null.
     */
    private static void rethrowIfFailed(Throwable error) {
        if (error == null) {
            return;
        }
        if (error instanceof RuntimeException runtime) {
            throw runtime;
        }
        if (error instanceof Error fatal) {
            throw fatal;
        }
        throw new RuntimeException(error);
    }

    /**
     * An immutable pair of values.
     *
     * @param first the first component
     * @param second the second component
     */
    public record Pair<T1, T2>(T1 first, T2 second) {}

    /**
     * Feeds pairs of corresponding elements of {@code it1} and {@code it2} to
     * {@code onEach}, stopping as soon as either iterator is exhausted.
     *
     * @param it1 source of first components
     * @param it2 source of second components
     * @param onEach receives each pair in order
     */
    public static <T1, T2> void zipForEach(Iterator<T1> it1, Iterator<T2> it2, Consumer<Pair<T1, T2>> onEach) {
        while (it1.hasNext() && it2.hasNext()) {
            var pair = new Pair<T1, T2>(it1.next(), it2.next());
            onEach.accept(pair);
        }
    }

    /**
     * Lazily zips two iterators by adapting {@link #zipForEach} with
     * {@link #forEachToIterator}. The source iterators are advanced on the generator's
     * producer thread, not on the caller's thread.
     *
     * @param it1 source of first components
     * @param it2 source of second components
     * @return iterator over pairs, as long as the shorter input
     */
    public static <T1, T2> Iterator<Pair<T1, T2>> zipIterator(Iterator<T1> it1, Iterator<T2> it2) {
        return forEachToIterator(onEach -> zipForEach(it1, it2, onEach));
    }

    /**
     * Small benchmark: zips two lists of 100,000 random ints through the generator, then
     * prints their total and the elapsed milliseconds.
     */
    public static void main(String[] args) throws InterruptedException {
        // Run the consumer on a virtual thread too, so both sides of the handshake are virtual.
        Thread.ofVirtual().start(() -> {
            var nums1 = new ArrayList<Integer>();
            var nums2 = new ArrayList<Integer>();
            Random gen = new Random();
            for (int i = 0; i < 100000; i++) {
                nums1.add(gen.nextInt());
                nums2.add(gen.nextInt());
            }
            Iterator<Integer> it1 = nums1.iterator();
            Iterator<Integer> it2 = nums2.iterator();
            Instant start = Instant.now();
            Iterator<Pair<Integer, Integer>> zipIt = zipIterator(it1, it2);
            long sum = 0;
            while (zipIt.hasNext()) {
                var current = zipIt.next();
                // Widen before adding: int + int overflows for large random values.
                sum += (long) current.first() + current.second();
            }
            Duration timeElapsed = Duration.between(start, Instant.now());
            System.out.println(sum); // Print out sum to force evaluation.
            System.out.println(timeElapsed.toMillis());
        }).join();
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment