Skip to content

Instantly share code, notes, and snippets.

@pecigonzalo
Created August 5, 2021 12:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save pecigonzalo/26469014b96b26b7da53f66519350498 to your computer and use it in GitHub Desktop.
Save pecigonzalo/26469014b96b26b7da53f66519350498 to your computer and use it in GitHub Desktop.
Kotlin Kinesis Producer
package app
import com.amazonaws.auth.AWSStaticCredentialsProvider
import com.amazonaws.auth.BasicSessionCredentials
import com.amazonaws.auth.DefaultAWSCredentialsProviderChain
import com.amazonaws.services.kinesis.producer.KinesisProducer
import com.amazonaws.services.kinesis.producer.KinesisProducerConfiguration
import com.amazonaws.services.kinesis.producer.UnexpectedMessageException
import com.amazonaws.services.kinesis.producer.UserRecordFailedException
import com.amazonaws.services.kinesis.producer.UserRecordResult
import com.amazonaws.services.securitytoken.AWSSecurityTokenService
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder
import com.amazonaws.services.securitytoken.model.AssumeRoleRequest
import com.amazonaws.services.securitytoken.model.AssumeRoleResult
import com.amazonaws.services.securitytoken.model.Credentials
import com.google.common.collect.Iterables
import com.google.common.util.concurrent.FutureCallback
import com.google.common.util.concurrent.Futures
import java.math.BigInteger
import java.nio.ByteBuffer
import java.time.Instant
import java.util.Random
import java.util.concurrent.Executors
import java.util.concurrent.ScheduledExecutorService
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong
import org.json.JSONObject
import software.amazon.awssdk.services.s3.model.*
/** Shared random source used for record contents and explicit hash keys. */
val RANDOM = Random()

/** Epoch-millis string captured once at initialization; passed as the partition key of every record. */
val TIMESTAMP = System.currentTimeMillis().toString()

/** Kinesis data stream that records are put into. */
const val STREAM_NAME = "inviting-coyote"

/** AWS region of [STREAM_NAME]. */
const val AWS_REGION = "eu-central-1"

/** How long, in seconds, the put loop runs. */
const val SECONDS_TO_RUN = 10

/** Target number of records put per second. */
const val RECORDS_PER_SECOND = 100 * 10

/** Ticker symbols sampled when generating records. */
val TICKERS = arrayOf("AAPL", "AMZN", "MSFT", "INTC", "TBV")

/** Role session name used when assuming [KINESIS_ROLE]. */
const val SERVICE_NAME = "my-producer-service"

/** IAM role assumed to obtain credentials for writing to [STREAM_NAME]. */
// NOTE(review): "writter" looks like a typo, but the ARN must match the real IAM role name; left unchanged — confirm.
const val KINESIS_ROLE = "arn:aws:iam::123123123123:role/inviting-coyote-writter-role"
/**
 * Draws a random non-negative integer of at most 128 bits from [RANDOM] and renders it in base 10.
 *
 * The result is passed as the explicit hash key to `addUserRecord`.
 */
fun randomExplicitHashKey(): String = BigInteger(128, RANDOM).toString()
/**
 * Assumes [roleARN] through STS and returns the resulting temporary session credentials.
 *
 * The STS call itself is authenticated with the default AWS credentials provider chain.
 * The returned object is a plain value holder: nothing here renews it when the STS session expires.
 * It therefore only suits runs shorter than the session lifetime.
 *
 * @param roleARN ARN of the IAM role to assume.
 * @param roleSessionName session name recorded for the assumed-role session.
 * @return access key, secret key and session token of the assumed-role session.
 */
fun assumeGivenRole(roleARN: String, roleSessionName: String): BasicSessionCredentials {
    val stsClient: AWSSecurityTokenService =
        AWSSecurityTokenServiceClientBuilder.standard()
            .withCredentials(DefaultAWSCredentialsProviderChain())
            .build()
    try {
        val roleRequest = AssumeRoleRequest().withRoleArn(roleARN).withRoleSessionName(roleSessionName)
        val roleResponse: AssumeRoleResult = stsClient.assumeRole(roleRequest)
        val sessionCredentials: Credentials = roleResponse.credentials
        return BasicSessionCredentials(
            sessionCredentials.accessKeyId,
            sessionCredentials.secretAccessKey,
            sessionCredentials.sessionToken
        )
    } finally {
        // The client is only needed for this single call; release its resources instead of leaking them.
        stsClient.shutdown()
    }
}
/**
 * Builds one synthetic stock-tick record and returns its JSON text as bytes in a [ByteBuffer].
 *
 * The JSON object has these fields:
 * - `EVENT_TIME`: the current [Instant] in its ISO-8601 string form.
 * - `TICKER`: an entry of [TICKERS] chosen at random.
 * - `PRICE`: a random double in [0, 100).
 */
fun generateData(): ByteBuffer {
    val ticker = TICKERS[RANDOM.nextInt(TICKERS.size)]
    val json = JSONObject().apply {
        put("EVENT_TIME", Instant.now().toString())
        put("TICKER", ticker)
        put("PRICE", RANDOM.nextDouble() * 100)
    }
    val payload = json.toString()
    // println("Record: " + payload)
    return ByteBuffer.wrap(payload.toByteArray())
}
/**
 * Entry point of the sample producer. It:
 * 1. assumes [KINESIS_ROLE] and builds a KPL [KinesisProducer] for [AWS_REGION] with aggregation enabled;
 * 2. puts [RECORDS_PER_SECOND] generated records per second into [STREAM_NAME] for [SECONDS_TO_RUN] seconds,
 *    printing progress once per second;
 * 3. flushes outstanding records, destroys the producer and releases the callback pool.
 */
fun main() {
    val kinesisCreds = assumeGivenRole(KINESIS_ROLE, SERVICE_NAME)
    val kConfig =
        KinesisProducerConfiguration()
            .setCredentialsProvider(AWSStaticCredentialsProvider(kinesisCreds))
            .setRegion(AWS_REGION)
            // .setMaxConnections(1)
            .setAggregationEnabled(true)
    val producer = KinesisProducer(kConfig)
    // Number of records handed to addUserRecord so far (advanced by the put loop).
    val sequenceNumber = AtomicLong(0)
    // Number of records whose put future completed successfully.
    val completed = AtomicLong(0)
    // Executor on which the put-result callbacks run.
    val callbackThreadPool = Executors.newCachedThreadPool()
    // A single thread runs both the put loop and the progress reporter.
    val executor: ScheduledExecutorService = Executors.newScheduledThreadPool(1)

    val callback: FutureCallback<UserRecordResult> =
        object : FutureCallback<UserRecordResult> {
            override fun onFailure(t: Throwable) {
                when (t) {
                    is UserRecordFailedException -> {
                        val attempts = t.result.attempts
                        val last = attempts.lastOrNull()
                        // Second-to-last attempt; present whenever the record was tried at least twice.
                        val previous = attempts.getOrNull(attempts.size - 2)
                        when {
                            last == null -> println("Record failed to put with no recorded attempts - " + t)
                            previous != null -> println(
                                "Record failed to put - %s : %s. Previous failure - %s : %s".format(
                                    last.errorCode,
                                    last.errorMessage,
                                    previous.errorCode,
                                    previous.errorMessage
                                )
                            )
                            else -> println(
                                "Record failed to put - %s : %s".format(last.errorCode, last.errorMessage)
                            )
                        }
                    }
                    is UnexpectedMessageException -> println(
                        "Record failed to put due to unexpected message received from native layer - " + t
                    )
                    else -> println("Exception during put: " + t)
                }
            }

            override fun onSuccess(result: UserRecordResult?) {
                completed.getAndIncrement()
            }
        }

    // Generates one record and hands it to the KPL; the outcome is reported through `callback`.
    val putOneRecord = Runnable {
        val data: ByteBuffer = generateData()
        val f = producer.addUserRecord(STREAM_NAME, TIMESTAMP, randomExplicitHashKey(), data)
        Futures.addCallback(f, callback, callbackThreadPool)
    }

    // Every second, prints how many records were put/completed relative to the planned total.
    fun printProgress(exec: ScheduledExecutorService, startTime: Long) {
        val total: Long = RECORDS_PER_SECOND.toLong() * SECONDS_TO_RUN
        exec.scheduleAtFixedRate(
            {
                val secondsRun = (System.nanoTime() - startTime) / 1e9
                val put = sequenceNumber.get()
                val done = completed.get()
                println(
                    "Put %d of %d so far (%.2f %%), %d have completed (%.2f %%). PPS: %.2f".format(
                        put,
                        total,
                        100.0 * put / total,
                        done,
                        100.0 * done / total,
                        done / secondsRun
                    )
                )
                println(
                    "Oldest future as of now in millis is %s".format(producer.getOldestRecordTimeInMillis())
                )
            },
            1,
            1,
            TimeUnit.SECONDS
        )
    }

    // Runs `task` often enough that `counter` tracks ratePerSecond * elapsed seconds,
    // then shuts `exec` down once durationSeconds have elapsed.
    fun executeAtTargetRate(
        exec: ScheduledExecutorService,
        task: Runnable,
        counter: AtomicLong,
        durationSeconds: Int,
        ratePerSecond: Int
    ) {
        val startTime = System.nanoTime()
        exec.scheduleWithFixedDelay(
            {
                val secondsRun = (System.nanoTime() - startTime) / 1e9
                // Number of tasks that should have run by now; stops growing after durationSeconds.
                val targetCount = minOf(durationSeconds.toDouble(), secondsRun) * ratePerSecond
                while (counter.get() < targetCount) {
                    counter.getAndIncrement()
                    try {
                        task.run()
                    } catch (e: Exception) {
                        println("Error running task: " + e)
                        System.exit(1)
                    }
                }
                if (secondsRun >= durationSeconds) {
                    exec.shutdown()
                }
            },
            0,
            1,
            TimeUnit.MILLISECONDS
        )
    }

    println("Starting producer")
    println(
        "Stream name: %s Region: %s secondsToRun %d".format(STREAM_NAME, AWS_REGION, SECONDS_TO_RUN)
    )
    // Estimate throughput from the size of a real generated record instead of assuming 1 KB per record.
    val sampleRecordBytes = generateData().remaining()
    println("Will attempt to run the KPL at %f MB/s...".format(sampleRecordBytes * RECORDS_PER_SECOND / 1_000_000.0))
    printProgress(executor, System.nanoTime())
    executeAtTargetRate(executor, putOneRecord, sequenceNumber, SECONDS_TO_RUN, RECORDS_PER_SECOND)
    if (!executor.awaitTermination(SECONDS_TO_RUN + 1L, TimeUnit.SECONDS)) {
        // Stop rescheduling the put loop so it does not feed the producer destroyed below.
        println("Put loop still running after timeout; stopping it")
        executor.shutdownNow()
    }
    println("Waiting for remaining puts to finish...")
    producer.flushSync()
    println("All records complete")
    producer.destroy()
    // Let already-queued callbacks finish, then release the pool so its idle workers do not delay JVM exit.
    callbackThreadPool.shutdown()
    println("Finished")
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment