Skip to content

Instantly share code, notes, and snippets.

@dmitrii-artuhov
Created August 10, 2023 23:38
Show Gist options
  • Save dmitrii-artuhov/b6b3b66f99ebb6026451df0255459037 to your computer and use it in GitHub Desktop.
Save dmitrii-artuhov/b6b3b66f99ebb6026451df0255459037 to your computer and use it in GitHub Desktop.
Threadpool
package org.hse.java.threadpool;
/**
 * Checked exception reported to callers of {@code LightFuture.get()} when the underlying task
 * did not produce a result.
 *
 * <p>The constructors mirror those of {@link Exception} one-to-one and add no behaviour.
 */
public class LightExecutionException extends Exception {

    /** Creates an exception with neither a detail message nor a cause. */
    public LightExecutionException() {
        // Implicitly delegates to Exception().
    }

    /**
     * Creates an exception carrying only a detail message.
     *
     * @param message human-readable description of the failure
     */
    public LightExecutionException(String message) {
        super(message);
    }

    /**
     * Creates an exception carrying only a cause.
     *
     * @param cause the throwable that made the task fail
     */
    public LightExecutionException(Throwable cause) {
        super(cause);
    }

    /**
     * Creates an exception carrying both a detail message and a cause.
     *
     * @param message human-readable description of the failure
     * @param cause   the throwable that made the task fail
     */
    public LightExecutionException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Full-control constructor for subclasses; forwards every argument to
     * {@link Exception#Exception(String, Throwable, boolean, boolean)}.
     *
     * @param message            human-readable description of the failure
     * @param cause              the throwable that made the task fail
     * @param enableSuppression  whether suppressed exceptions are recorded
     * @param writableStackTrace whether the stack trace is writable
     */
    protected LightExecutionException(
            String message,
            Throwable cause,
            boolean enableSuppression,
            boolean writableStackTrace) {
        super(message, cause, enableSuppression, writableStackTrace);
    }
}
package org.hse.java.threadpool;
import java.util.function.Function;
import org.jetbrains.annotations.NotNull;
/**
 * Handle to the eventual outcome of a task handed to a {@code ThreadPool}.
 *
 * @param <R> type of the value the task produces
 */
public interface LightFuture<R> {

    /**
     * Waits for the task to finish and returns its value.
     *
     * @return the value produced by the task
     * @throws LightExecutionException if the task did not produce a value; the implementation
     *                                 in this package wraps the exception recorded for the task
     */
    @NotNull R get() throws LightExecutionException;

    /**
     * Non-blocking completion probe.
     *
     * <p>The implementation in this package ({@code LightFutureImpl}) answers {@code true} only
     * once the task has finished successfully.
     *
     * @return {@code true} if the result is available
     */
    boolean isReady();

    /**
     * Derives a new task that feeds this task's value into {@code function}.
     *
     * @param function transformation applied to this task's value
     * @param <R1>     type produced by {@code function}
     * @return a future tracking the derived task
     */
    <R1> @NotNull LightFuture<R1> thenApply(Function<R, R1> function);
}
package org.hse.java.threadpool;
import java.util.function.Function;
import java.util.function.Supplier;
import org.hse.java.threadpool.TaskState.ExecutionStatus;
import org.jetbrains.annotations.NotNull;
/**
 * {@link LightFuture} backed by a shared {@link TaskState}.
 *
 * <p>All reads and writes of the task state happen while holding the state's lock, so the
 * future may be queried from any thread.
 *
 * @param <R> type of the value the task produces
 */
public class LightFutureImpl<R> implements LightFuture<R> {
    private final TaskState<R> taskState;

    /**
     * @param taskState_ state object shared with the code that executes the task
     */
    public LightFutureImpl(TaskState<R> taskState_) {
        taskState = taskState_;
    }

    /**
     * Reports whether the task has completed successfully.
     *
     * <p>A task that ended with an exception is <em>not</em> reported as ready; use
     * {@link #get()} to observe the failure.
     *
     * @return {@code true} iff the task status is {@code COMPLETED}
     */
    @Override
    public boolean isReady() {
        // Read under the lock so that a status written by a worker thread is visible here.
        taskState.lockState();
        try {
            return taskState.getExecState() == ExecutionStatus.COMPLETED;
        } finally {
            taskState.unlockState();
        }
    }

    /**
     * Blocks until the task leaves the {@code QUEUED}/{@code IN_PROGRESS} states and returns
     * its result.
     *
     * <p>NOTE(review): the value is returned as stored, so a supplier that returns
     * {@code null} makes this method return {@code null} despite the {@code @NotNull}
     * annotation inherited from the interface.
     *
     * @return the value stored for the task
     * @throws LightExecutionException if the task ended in {@code EXCEPTION_THROWN} or
     *                                 {@code INTERRUPTED}; the recorded exception is the cause
     * @throws RuntimeException        if the waiting thread is interrupted (the interrupt flag
     *                                 is restored before throwing)
     */
    @Override
    public @NotNull R get() throws LightExecutionException {
        taskState.lockState();
        try {
            while (isPending(taskState.getExecState())) {
                try {
                    taskState.await();
                } catch (InterruptedException e) {
                    // Keep the interruption observable for code further up the stack.
                    Thread.currentThread().interrupt();
                    throw new RuntimeException(e);
                }
            }
            ExecutionStatus status = taskState.getExecState();
            if (status == ExecutionStatus.COMPLETED) {
                return taskState.getResult();
            }
            if (isFailure(status)) {
                throw new LightExecutionException(taskState.getThrownException());
            }
            throw new RuntimeException("Unreachable code inside was reached");
        } finally {
            taskState.unlockState();
        }
    }

    /**
     * Creates a dependent task that applies {@code function} to this task's result.
     *
     * <ul>
     *   <li>If this task is still pending, a completion callback is registered atomically
     *       with the status check, so the registration can never be missed. When the callback
     *       fires, the dependent task is scheduled on the pool passed to the callback.</li>
     *   <li>If this task has already completed, no pool reference is available here, so
     *       {@code function} is applied on the calling thread and the returned future is
     *       already settled.</li>
     *   <li>If this task failed (before or after this call), the dependent future fails with
     *       the same recorded exception and {@code function} is not invoked.</li>
     * </ul>
     *
     * @param function transformation applied to this task's value
     * @param <R1>     type produced by {@code function}
     * @return a future tracking the dependent task
     */
    @Override
    public @NotNull <R1> LightFuture<R1> thenApply(Function<R, R1> function) {
        TaskState<R1> newTaskState = new TaskState<>();
        LightFuture<R1> dependant = new LightFutureImpl<>(newTaskState);

        ExecutionStatus parentStatus;
        R parentResult;
        Exception parentFailure;
        taskState.lockState();
        try {
            parentStatus = taskState.getExecState();
            if (isPending(parentStatus)) {
                taskState.addOnCompletedCallback((result, threadPool) ->
                        onParentFinished(newTaskState, function, result, threadPool));
                return dependant;
            }
            parentResult = taskState.getResult();
            parentFailure = taskState.getThrownException();
        } finally {
            taskState.unlockState();
        }

        // The parent has already finished. The new state has not been handed out yet, so it
        // has neither waiters nor callbacks and can be settled without a pool.
        if (isFailure(parentStatus)) {
            settle(newTaskState, parentStatus, null, parentFailure, null);
            return dependant;
        }

        R1 value = null;
        Exception failure = null;
        try {
            // Run outside the parent's lock: user code must not execute under our locks.
            value = function.apply(parentResult);
        } catch (Exception e) {
            failure = e;
        }
        if (failure == null) {
            settle(newTaskState, ExecutionStatus.COMPLETED, value, null, null);
        } else {
            settle(newTaskState, ExecutionStatus.EXCEPTION_THROWN, null, failure, null);
        }
        return dependant;
    }

    /** Completion callback body: schedules the dependent task or propagates the failure. */
    private <R1> void onParentFinished(TaskState<R1> target, Function<R, R1> function,
            R result, ThreadPoolImpl threadPool) {
        ExecutionStatus parentStatus;
        Exception parentFailure;
        // The lock is reentrant, so this is safe whether or not the invoker already holds it.
        taskState.lockState();
        try {
            parentStatus = taskState.getExecState();
            parentFailure = taskState.getThrownException();
        } finally {
            taskState.unlockState();
        }
        if (isFailure(parentStatus)) {
            settle(target, parentStatus, null, parentFailure, threadPool);
        } else {
            threadPool.scheduleTask(target, () -> function.apply(result));
        }
    }

    /**
     * Stores a final outcome in {@code state}, wakes waiters and, when a pool is supplied,
     * notifies the callbacks registered on {@code state}.
     *
     * @param threadPool pool handed to the callbacks, or {@code null} when {@code state} is
     *                   known to have no callbacks yet
     */
    private static <T> void settle(TaskState<T> state, ExecutionStatus status, T value,
            Exception failure, ThreadPoolImpl threadPool) {
        state.lockState();
        try {
            if (status == ExecutionStatus.COMPLETED) {
                state.setResult(value);
            } else {
                state.setThrownException(failure);
            }
            state.setExecState(status);
            state.signal();
            if (threadPool != null) {
                for (var callback : state.getOnCompletedCallbacks()) {
                    callback.accept(value, threadPool);
                }
            }
        } finally {
            state.unlockState();
        }
    }

    /** Whether the task has not reached a final status yet. */
    private static boolean isPending(ExecutionStatus status) {
        return status == ExecutionStatus.QUEUED || status == ExecutionStatus.IN_PROGRESS;
    }

    /** Whether the status denotes a task that ended without a result. */
    private static boolean isFailure(ExecutionStatus status) {
        return status == ExecutionStatus.EXCEPTION_THROWN || status == ExecutionStatus.INTERRUPTED;
    }
}
package org.hse.java.threadpool;
/**
 * Base type of the work items stored in the pool's queue.
 *
 * <p>Declares no members of its own; subclasses supply {@link Runnable#run()}.
 */
public abstract class Task implements Runnable {
}
package org.hse.java.threadpool;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
/**
 * Mutable record of one task's lifecycle, shared between the thread that executes the task
 * and the futures that observe it.
 *
 * <p>The object exposes its own lock ({@link #lockState()} / {@link #unlockState()}) and a
 * condition ({@link #await()} / {@link #signal()}). The result, exception and status setters
 * perform no locking themselves; callers are expected to hold the lock around compound
 * updates. The callback list is guarded internally by the same (reentrant) lock.
 *
 * @param <R> type of the value the task produces
 */
public class TaskState<R> {
    /** Lifecycle stages of a task. */
    public enum ExecutionStatus {
        EXCEPTION_THROWN,
        IN_PROGRESS,
        INTERRUPTED,
        COMPLETED,
        QUEUED
    }

    private final Lock stateLock = new ReentrantLock();
    private final Condition taskFinished = stateLock.newCondition();
    private R resultValue;
    private Exception thrownException;
    // volatile: the status may be polled without the lock and must not be read stale.
    private volatile ExecutionStatus execState = ExecutionStatus.QUEUED;
    private final List<BiConsumer<R, ThreadPoolImpl>> dependantTasksCallbacks = new ArrayList<>();

    /**
     * Registers a callback to be invoked when the task finishes.
     *
     * <p>The list is modified under the state lock so that registration cannot interleave
     * with another thread taking a snapshot of the callbacks.
     *
     * @param callback receives the task's result and the pool to schedule follow-up work on
     */
    public void addOnCompletedCallback(BiConsumer<R, ThreadPoolImpl> callback) {
        stateLock.lock();
        try {
            dependantTasksCallbacks.add(callback);
        } finally {
            stateLock.unlock();
        }
    }

    /**
     * Returns the callbacks registered so far.
     *
     * @return an unmodifiable snapshot; callbacks added later do not appear in it, so it can
     *         be iterated safely while other threads keep registering callbacks
     */
    public List<BiConsumer<R, ThreadPoolImpl>> getOnCompletedCallbacks() {
        stateLock.lock();
        try {
            return Collections.unmodifiableList(new ArrayList<>(dependantTasksCallbacks));
        } finally {
            stateLock.unlock();
        }
    }

    /** Acquires the state lock (reentrant). */
    public void lockState() {
        stateLock.lock();
    }

    /** Releases the state lock; must be paired with a preceding {@link #lockState()}. */
    public void unlockState() {
        stateLock.unlock();
    }

    /**
     * Wakes every thread blocked in {@link #await()}.
     *
     * <p>All waiters are released, not just one: several threads may be waiting for the same
     * task, and waking a single one would leave the others blocked forever. The caller must
     * hold the state lock.
     */
    public void signal() {
        taskFinished.signalAll();
    }

    /**
     * Blocks until {@link #signal()} is called. The caller must hold the state lock and should
     * re-check the status in a loop after waking.
     *
     * @throws InterruptedException if the waiting thread is interrupted
     */
    public void await() throws InterruptedException {
        taskFinished.await();
    }

    /** Stores the task's result; performs no locking. */
    public void setResult(R value) {
        resultValue = value;
    }

    /** @return the stored result, or {@code null} if none has been stored */
    public R getResult() {
        return resultValue;
    }

    /** Stores the task's status; performs no locking. */
    public void setExecState(ExecutionStatus value) {
        execState = value;
    }

    /** @return the current status; initially {@code QUEUED} */
    public ExecutionStatus getExecState() {
        return execState;
    }

    /** Stores the exception that ended the task; performs no locking. */
    public void setThrownException(Exception e) {
        thrownException = e;
    }

    /** @return the stored exception, or {@code null} if none has been stored */
    public Exception getThrownException() {
        return thrownException;
    }
}
package org.hse.java.threadpool;
import java.util.function.Supplier;
import org.jetbrains.annotations.NotNull;
/**
 * Minimal thread pool: accepts value-producing tasks and hands back {@link LightFuture}s.
 */
public interface ThreadPool {

    /**
     * Queues {@code supplier} for execution.
     *
     * @param supplier computation whose value becomes the future's result
     * @param <R>      type of the produced value
     * @return a future tracking the submitted task
     */
    <R> @NotNull LightFuture<R> submit(Supplier<R> supplier);

    /**
     * Stops the pool. The implementation in this package interrupts its worker threads and
     * waits for them to terminate.
     */
    void shutdown();

    /**
     * @return the number of worker threads; the implementation in this package counts the
     *         threads that are currently alive
     */
    int getNumberOfThreads();

    /**
     * Factory for the default implementation, {@link ThreadPoolImpl}.
     *
     * @param threads number of worker threads to start
     * @return a new pool
     */
    static @NotNull ThreadPool create(int threads) {
        return new ThreadPoolImpl(threads);
    }
}
package org.hse.java.threadpool;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Supplier;
import org.hse.java.threadpool.TaskState.ExecutionStatus;
import org.jetbrains.annotations.NotNull;
/**
 * Fixed-size thread pool backed by a lock-protected FIFO queue.
 *
 * <p>Locking discipline: {@code queueLock} guards only the queue and is never held while a
 * task runs or while a task state is being updated. Task-state locks may be held while
 * {@code queueLock} is acquired (completion callbacks schedule follow-up tasks), never the
 * other way round.
 */
public class ThreadPoolImpl implements ThreadPool {
    private static final String SHUTDOWN_MESSAGE =
            "Thread pool was shut down before the task could be executed";

    private final List<Thread> threads;
    private final Lock queueLock = new ReentrantLock();
    private final Condition queueNotEmpty = queueLock.newCondition();
    private final AtomicBoolean shutdownRequested = new AtomicBoolean(false);
    final Queue<Task> queue = new LinkedList<>();

    /**
     * Starts {@code n} worker threads.
     *
     * @param n number of worker threads; must be positive
     * @throws IllegalArgumentException if {@code n <= 0} (such a pool could never run a task)
     */
    public ThreadPoolImpl(int n) {
        if (n <= 0) {
            throw new IllegalArgumentException("Number of threads must be positive, got " + n);
        }
        threads = new ArrayList<>(n);
        for (int i = 0; i < n; ++i) {
            Thread t = new Thread(this::runWorker);
            threads.add(t);
            t.start();
        }
    }

    /**
     * Queues {@code supplier} for execution.
     *
     * @return a future for the task; if the pool is already shut down the future is failed
     *         immediately instead of waiting forever
     */
    @Override
    public @NotNull <R> LightFuture<R> submit(Supplier<R> supplier) {
        TaskState<R> state = new TaskState<>();
        scheduleTask(state, supplier);
        return new LightFutureImpl<>(state);
    }

    /**
     * Queues a task that computes {@code supplier} and records the outcome in {@code state}.
     *
     * <p>If shutdown has been requested the task is not queued; {@code state} is marked
     * {@code INTERRUPTED} right away so that nobody waits for a task that will never run.
     *
     * @param state    state object that receives the outcome
     * @param supplier computation to run on a worker thread
     */
    public <R> void scheduleTask(TaskState<R> state, Supplier<R> supplier) {
        boolean accepted;
        queueLock.lock();
        try {
            // Checked under the lock so a task cannot slip into the queue after shutdown()
            // has drained it.
            accepted = !shutdownRequested.get();
            if (accepted) {
                queue.add(new PoolTask<>(state, supplier));
                queueNotEmpty.signal();
            }
        } finally {
            queueLock.unlock();
        }
        if (!accepted) {
            // Done outside queueLock to respect the lock ordering described on the class.
            finish(state, ExecutionStatus.INTERRUPTED, null,
                    new IllegalStateException(SHUTDOWN_MESSAGE));
        }
    }

    /**
     * Stops the pool: rejects new tasks, fails tasks that are still queued, interrupts the
     * workers and waits for them to terminate. Safe to call more than once and from a task
     * running on this pool.
     *
     * @throws RuntimeException if the calling thread is interrupted while waiting (the
     *                          interrupt flag is restored before throwing)
     */
    @Override
    public void shutdown() {
        shutdownRequested.set(true);

        List<Task> abandoned = new ArrayList<>();
        queueLock.lock();
        try {
            abandoned.addAll(queue);
            queue.clear();
        } finally {
            queueLock.unlock();
        }

        for (Thread t : threads) {
            t.interrupt();
        }
        for (Task task : abandoned) {
            if (task instanceof PoolTask) {
                ((PoolTask<?>) task).cancel();
            }
        }

        Thread current = Thread.currentThread();
        for (Thread t : threads) {
            if (t == current) {
                continue; // a worker calling shutdown() must not join itself
            }
            try {
                t.join();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RuntimeException(e);
            }
        }
    }

    /** @return the number of worker threads that are currently alive */
    @Override
    public int getNumberOfThreads() {
        int activeThreadsCount = 0;
        for (Thread t : threads) {
            if (t.isAlive()) {
                ++activeThreadsCount;
            }
        }
        return activeThreadsCount;
    }

    /** Worker loop: take a task, run it with no pool lock held, repeat until shutdown. */
    private void runWorker() {
        while (!shutdownRequested.get()) {
            Task task = takeTask();
            if (task != null) {
                task.run();
            }
        }
    }

    /**
     * Blocks until a task is available.
     *
     * @return the next task, or {@code null} if the wait was interrupted or shutdown was
     *         requested; the caller re-checks the shutdown flag
     */
    private Task takeTask() {
        queueLock.lock();
        try {
            while (queue.isEmpty() && !shutdownRequested.get()) {
                queueNotEmpty.await();
            }
            return shutdownRequested.get() ? null : queue.remove();
        } catch (InterruptedException ignored) {
            // shutdown() interrupts workers after setting the flag; a stray interrupt left
            // behind by a task is simply cleared and the worker keeps serving.
            return null;
        } finally {
            queueLock.unlock();
        }
    }

    /**
     * Records a final outcome in {@code state}, wakes waiters and notifies the registered
     * callbacks.
     *
     * <p>Callbacks are notified for failures as well (with a {@code null} result) so that
     * dependent tasks are never left waiting for a notification that will not come.
     */
    private <R> void finish(TaskState<R> state, ExecutionStatus status, R result,
            Exception failure) {
        state.lockState();
        try {
            if (status == ExecutionStatus.COMPLETED) {
                state.setResult(result);
            } else {
                state.setThrownException(failure);
            }
            state.setExecState(status);
            // Waiters cannot proceed until the lock is released, so signalling before the
            // callbacks is equivalent to signalling after them, and it still happens if a
            // callback throws.
            state.signal();
            for (BiConsumer<R, ThreadPoolImpl> callback : state.getOnCompletedCallbacks()) {
                callback.accept(result, this);
            }
        } finally {
            state.unlockState();
        }
    }

    /** Queue entry that runs a supplier and publishes the outcome through a task state. */
    private final class PoolTask<R> extends Task {
        private final TaskState<R> state;
        private final Supplier<R> supplier;

        PoolTask(TaskState<R> state, Supplier<R> supplier) {
            this.state = state;
            this.supplier = supplier;
        }

        @Override
        public void run() {
            state.lockState();
            try {
                state.setExecState(ExecutionStatus.IN_PROGRESS);
            } finally {
                state.unlockState();
            }

            R result = null;
            Exception failure = null;
            try {
                result = supplier.get();
            } catch (Exception e) {
                failure = e;
            } catch (Throwable t) {
                // Errors must also settle the future; otherwise waiters would hang and the
                // worker would die. Wrapped because the state stores an Exception.
                failure = new RuntimeException(t);
            }

            if (failure == null) {
                finish(state, ExecutionStatus.COMPLETED, result, null);
            } else {
                finish(state, ExecutionStatus.EXCEPTION_THROWN, null, failure);
            }
        }

        /** Fails the task without running it; used for tasks still queued at shutdown. */
        void cancel() {
            finish(state, ExecutionStatus.INTERRUPTED, null,
                    new IllegalStateException(SHUTDOWN_MESSAGE));
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment