Skip to content

Instantly share code, notes, and snippets.

@libetl
Created August 31, 2020 19:08
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save libetl/71b826a0db248e6770a2c0b5c0ae6d18 to your computer and use it in GitHub Desktop.
Save libetl/71b826a0db248e6770a2c0b5c0ae6d18 to your computer and use it in GitHub Desktop.
package com.mycompany.infra.config
import kotlinx.coroutines.CoroutineExceptionHandler
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.channels.SendChannel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import org.springframework.stereotype.Component
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
import kotlin.coroutines.CoroutineContext
/**
 * Runs a suspending operation over batches of inputs using a pool of worker coroutines.
 *
 * Each call to [batchOf] creates its own task channel and its own pool of `workers` coroutines. The pool is
 * started on the first invocation of the returned function and is reused by every later invocation of that
 * same function. Each batch receives its results on a channel of its own, so concurrent batches never see
 * each other's results.
 *
 * @param T input type of the operation.
 * @param U result type of the operation.
 */
@Component
class BatchCoroutinesStrategy<T, U> {

    /**
     * Becomes `true` once the worker pool behind at least one function returned by [batchOf] has been started.
     *
     * Informational only: every [batchOf] result tracks its own pool, so an earlier [batchOf] call never
     * prevents a later one from starting its workers.
     */
    val started = AtomicBoolean(false)

    /**
     * Builds a batch runner for [operation].
     *
     * @param workers number of worker coroutines that process tasks concurrently; must be positive.
     * @param coroutineContext context in which the workers and each batch's scheduler are launched.
     * @param operation the work applied to every input.
     * @return a function that submits a list of inputs and returns a [Deferred]. The deferred completes with a
     * map from each input to its result. Inputs whose operation threw are left out of the map. When an input
     * occurs several times, the map keeps the result that arrived last.
     * @throws IllegalArgumentException if [workers] is not positive.
     */
    override fun batchOf(
        workers: Int,
        coroutineContext: CoroutineContext,
        operation: suspend (T) -> U
    ): suspend (List<T>) -> Deferred<Map<T, U>> {
        require(workers > 0) { "workers must be positive, was $workers" }
        val tasks = Channel<Task<T, U>>()
        // Tracked per returned function: each batchOf call owns its own channel and needs its own workers.
        val poolStarted = AtomicBoolean(false)
        return { inputs: List<T> ->
            if (poolStarted.compareAndSet(false, true)) {
                started.set(true)
                init(workers, coroutineContext, operation, tasks)
            }
            CoroutineScope(coroutineContext).async { scheduler(inputs, tasks) }
        }
    }

    /** Launches [workers] independent worker coroutines that consume [tasks] for as long as their scope lives. */
    private fun init(
        workers: Int,
        coroutineContext: CoroutineContext,
        operation: suspend (T) -> U,
        tasks: ReceiveChannel<Task<T, U>>
    ) {
        val scope = CoroutineScope(coroutineContext)
        // One launch per worker, so the workers run side by side instead of waiting on each other.
        repeat(workers) { scope.launch { worker(tasks, operation) } }
    }

    /**
     * Processes tasks one at a time and replies to each task's batch with its outcome.
     *
     * A throwing [operation] is reported as [Outcome.Failure] instead of escaping. This way one bad input
     * neither kills the worker nor leaves its batch waiting forever. Cancellation is rethrown untouched.
     */
    private suspend fun worker(tasks: ReceiveChannel<Task<T, U>>, operation: suspend (T) -> U) {
        for (task in tasks) {
            val outcome: Outcome<T, U> = try {
                Outcome.Success(task.input, operation(task.input))
            } catch (cancellation: java.util.concurrent.CancellationException) {
                throw cancellation
            } catch (failure: Throwable) {
                Outcome.Failure(task.input, failure)
            }
            task.replyTo.send(outcome)
        }
    }

    /**
     * Submits every input of one batch to the worker pool and waits for exactly one outcome per input.
     *
     * Inputs are submitted from a child coroutine, so collecting outcomes never waits on submission.
     * The reply channel is buffered to the batch size, so a worker can always deliver its outcome
     * even if this batch has been cancelled and has stopped reading.
     */
    private suspend fun scheduler(
        inputs: List<T>,
        tasks: SendChannel<Task<T, U>>
    ): Map<T, U> = coroutineScope {
        val outcomes = Channel<Outcome<T, U>>(inputs.size)
        launch { inputs.forEach { input -> tasks.send(Task(input, outcomes)) } }
        val results = mutableMapOf<T, U>()
        repeat(inputs.size) {
            val outcome = outcomes.receive()
            if (outcome is Outcome.Success) results[outcome.input] = outcome.value
        }
        results.toMap()
    }

    /** One unit of work: an input plus the channel of the batch that submitted it. */
    private class Task<I, R>(val input: I, val replyTo: SendChannel<Outcome<I, R>>)

    /** Result of applying the operation to one input. */
    private sealed class Outcome<out I, out R> {
        abstract val input: I

        class Success<I, R>(override val input: I, val value: R) : Outcome<I, R>()

        class Failure<I>(override val input: I, val error: Throwable) : Outcome<I, Nothing>()
    }
}
package com.mycompany.infra.config
import com.google.common.cache.CacheBuilder
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
import org.springframework.stereotype.Component
import java.net.InetAddress
import java.time.Duration
import java.time.temporal.ChronoUnit
import java.util.UUID
import java.util.concurrent.atomic.AtomicBoolean
import kotlin.coroutines.CoroutineContext
/**
 * Factory for [AsyncCoroutinesCache]. Each cache runs computations in the background and lets callers poll
 * their status and result through a task id of the form `<machineName>-<uuid>`.
 */
@Component
class CachingCoroutinesStrategy<T, U> {
    enum class Status {
        NOT_EXECUTED_HERE, NOT_STARTED, IN_PROGRESS, FINISHED
    }

    companion object {
        /** Host name of this machine; it prefixes every task id issued here. */
        val machineName = InetAddress.getLocalHost().hostName

        /**
         * Full shape of a task id issued by this machine: `<machineName>-<uuid>`.
         * The host name is escaped because it may contain `.`, which would otherwise match any character.
         */
        val taskNameRegex = Regex(
            Regex.escape(machineName) + "-[0-9a-fA-F]{8}\\-[0-9a-fA-F]{4}\\-" +
                "[0-9a-fA-F]{4}\\-[0-9a-fA-F]{4}\\-[0-9a-fA-F]{12}"
        )
    }

    override fun cache(
        workers: Int,
        coroutineContext: CoroutineContext,
        process: suspend (T) -> U
    ): AsyncCache<T, U> = AsyncCoroutinesCache<T, U>(workers, coroutineContext, process)

    /**
     * [AsyncCache] that runs [operation] on [howManyWorkers] worker coroutines.
     *
     * How a task moves through the cache:
     * - [start] records the input under a fresh UUID in [assignedTasksId] and queues it on [startTask].
     * - A worker computes the value and hands it to a single collector coroutine through [endTask].
     * - The collector moves the task into [results].
     *
     * [results] holds at most 1000 entries, each expiring 10 minutes after its last access.
     * [readStatusOf] reports an evicted task as `NOT_STARTED`. Because the status type has no failure state,
     * a task whose [operation] threw is also reported as `NOT_STARTED`.
     */
    class AsyncCoroutinesCache<T, U>(
        val howManyWorkers: Int = 20,
        val coroutineContext: CoroutineContext,
        val operation: suspend (T) -> U
    ) : AsyncCache<T, U> {
        /**
         * Queued or running tasks, keyed by id. Callers of [start] and the collector coroutine write to it
         * concurrently, so it is a synchronized map.
         */
        val assignedTasksId: MutableMap<UUID, T> =
            java.util.Collections.synchronizedMap(mutableMapOf<UUID, T>())

        /** Finished tasks: id -> (input, result). */
        val results = CacheBuilder.newBuilder()
            .concurrencyLevel(howManyWorkers)
            .maximumSize(1000)
            .expireAfterAccess(Duration.of(10, ChronoUnit.MINUTES))
            .build<UUID, Pair<T, U>>()
        val startTask: Channel<Pair<UUID, T>> = Channel<Pair<UUID, T>>()
        val endTask: Channel<Triple<UUID, T, U>> = Channel<Triple<UUID, T, U>>()
        val started = AtomicBoolean(false)

        /**
         * Takes queued tasks forever and forwards each computed value to [endTask].
         *
         * A throwing [operation] does not end the worker. The task is dropped from [assignedTasksId], so it
         * no longer reports as in progress, and the worker moves on. Cancellation is rethrown.
         */
        private fun worker() = CoroutineScope(coroutineContext).launch {
            while (true) {
                val (uuid, key) = startTask.receive()
                val value = try {
                    operation(key)
                } catch (cancellation: java.util.concurrent.CancellationException) {
                    throw cancellation
                } catch (failure: Throwable) {
                    assignedTasksId.remove(uuid)
                    continue
                }
                endTask.send(Triple(uuid, key, value))
            }
        }

        /**
         * Publishes finished results.
         *
         * The result is stored before the id leaves [assignedTasksId]. As a result, a task that is no longer
         * in flight is already visible in [results].
         */
        private fun cacheAdder() = CoroutineScope(coroutineContext).launch {
            while (true) {
                val (uuid, key, value) = endTask.receive()
                results.put(uuid, key to value)
                assignedTasksId.remove(uuid)
            }
        }

        /** Starts the workers and the collector exactly once. */
        fun init() {
            if (started.compareAndSet(false, true)) {
                repeat(howManyWorkers) { worker() }
                cacheAdder()
            }
        }

        /** Queues [input] under [newUUID] without suspending the caller. */
        fun send(newUUID: UUID, input: T) {
            CoroutineScope(coroutineContext).launch {
                startTask.send(newUUID to input)
            }
        }

        /** Registers and queues [input], returning its task id `<machineName>-<uuid>`. */
        override fun start(input: T): String {
            val newUUID = UUID.randomUUID()
            assignedTasksId[newUUID] = input
            init()
            send(newUUID, input)
            return "$machineName-$newUUID"
        }

        /**
         * Reports the status of [taskId].
         *
         * @return `NOT_EXECUTED_HERE` for ids not issued by this machine, `IN_PROGRESS` with the input while
         * the task is queued or running, `FINISHED` with (input, result) once done, and `NOT_STARTED` otherwise.
         */
        override fun readStatusOf(
            taskId: String
        ): Pair<CachingStrategy.Status, Pair<T, U?>?> {
            // The regex is anchored and starts with the escaped host name, so a match also proves the prefix.
            if (!taskId.matches(taskNameRegex))
                return CachingStrategy.Status.NOT_EXECUTED_HERE to null
            val uuid = UUID.fromString(taskId.removePrefix("$machineName-"))
            // Read the in-flight map first, and only once: cacheAdder publishes the result before removing
            // the id, so once the id is gone the result is already in `results` (unless it has been evicted).
            assignedTasksId[uuid]?.let { input -> return CachingStrategy.Status.IN_PROGRESS to (input to null) }
            results.getIfPresent(uuid)?.let { result -> return CachingStrategy.Status.FINISHED to result }
            return CachingStrategy.Status.NOT_STARTED to null
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment