-
-
Save nebulorum/f7978aa5519cab8bece65d4dac689d4f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package tensortest; | |
import static java.util.concurrent.Executors.newFixedThreadPool; | |
import com.google.common.util.concurrent.ThreadFactoryBuilder; | |
import java.util.concurrent.ScheduledThreadPoolExecutor; | |
import java.util.concurrent.TimeUnit; | |
import java.util.concurrent.atomic.AtomicLong; | |
import org.tensorflow.types.TInt64; | |
class AllocateStress { | |
AtomicLong nAlloc = new AtomicLong(); | |
int[] timeRange; | |
AtomicLong[] timeBuckets; | |
public AllocateStress(int[] timeRange) { | |
this.timeRange = timeRange; | |
timeBuckets = new AtomicLong[timeRange.length + 1]; | |
for (var i = 0; i < timeRange.length + 1; i++) { | |
timeBuckets[i] = new AtomicLong(); | |
} | |
} | |
public static void main(String[] args) throws InterruptedException { | |
AllocateStress runner = new AllocateStress(new int[]{10, 20, 30, 40, 50, 100, 200}); | |
var nThread = 32; | |
if (args.length > 0) { | |
nThread = Integer.parseInt(args[0]); | |
} | |
System.out.println("Running " + nThread + " threads for allocation."); | |
runner.run(nThread); | |
Thread.currentThread().join(); | |
} | |
private void run(int nThreads) { | |
var ses = new ScheduledThreadPoolExecutor(1, new ThreadFactoryBuilder().setNameFormat("sched-%d").build()); | |
ses.scheduleAtFixedRate(report(), 1, 1, TimeUnit.SECONDS); | |
var alloc = newFixedThreadPool(nThreads); | |
for (var i = 0; i < nThreads; i++) { | |
alloc.submit(this::allocForever); | |
} | |
safeSleep(60000); | |
} | |
private void allocForever() { | |
while (true) { | |
var time = recordTime(this::allocLoop); | |
recordLatency(time); | |
} | |
} | |
private void recordLatency(long time) { | |
var i = 0; | |
for (; i < timeRange.length; i++) { | |
if (time < timeRange[i]) { | |
break; | |
} | |
} | |
timeBuckets[i].incrementAndGet(); | |
} | |
private long recordTime(Runnable r) { | |
long start = System.currentTimeMillis(); | |
r.run(); | |
nAlloc.incrementAndGet(); | |
return System.currentTimeMillis() - start; | |
} | |
private void allocLoop() { | |
TInt64[] tensors = new TInt64[80]; | |
long[] nd = new long[200]; | |
for (var i = 0; i < tensors.length; i++) { | |
tensors[i] = TInt64.vectorOf(nd); | |
} | |
safeSleep(5); | |
for (var i = 0; i < tensors.length; i++) { | |
tensors[i].close(); | |
} | |
} | |
private Runnable report() { | |
return () -> { | |
System.out.print("nAlloc " + nAlloc.getAndSet(0) + " (total): "); | |
var i = 0; | |
for (; i < timeRange.length; i++) { | |
System.out.print(" " + format(timeBuckets[i].getAndSet(0)) + " (<" + timeRange[i] + "ms); "); | |
} | |
System.out.println(" " + format(timeBuckets[timeRange.length].getAndSet(0)) + " ( >" + timeRange[timeRange.length - 1] + "ms)"); | |
}; | |
} | |
private String format(long v) { | |
return v == 0 ? "-" : Long.toString(v); | |
} | |
private void safeSleep(long millis) { | |
try { | |
Thread.sleep(millis); | |
} catch (InterruptedException e) { | |
e.printStackTrace(); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment