@marmbrus · Created October 12, 2016
import java.nio.charset.Charset
import java.util.concurrent.TimeoutException

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.language.reflectiveCalls

import org.apache.kafka.clients._
import org.apache.kafka.clients.consumer.ConsumerConfig
import org.apache.kafka.common.{Cluster, Node, TopicPartition}
import org.apache.kafka.common.metrics.{MetricConfig, Metrics}
import org.apache.kafka.common.network.{ChannelBuilder, Selector}
import org.apache.kafka.common.protocol.ApiKeys
import org.apache.kafka.common.protocol.types.Struct
import org.apache.kafka.common.record.MemoryRecords
import org.apache.kafka.common.requests._
import org.apache.kafka.common.serialization.BytesDeserializer
import org.apache.kafka.common.utils.{SystemTime, Utils}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.streaming.StreamTest
import org.apache.spark.sql.test.SharedSQLContext
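
// Experiment: talk to a Kafka broker with the low-level o.a.k.clients.NetworkClient
// directly, bypassing KafkaConsumer. The test bootstraps cluster metadata, asks the
// partition leader for the earliest offset of partition 0 of topic "test", fetches one
// batch of records, and prints them. The broker address and topic below are from the
// original environment; point them at your own broker to run this.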
class RawTest extends StreamTest with SharedSQLContext with Logging {
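
  // Structural type exposing the few config accessors the test needs. ConsumerConfig's
  // constructor is not public, so the instance is built via reflection in the test body
  // and then viewed through this type.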
  type Config = {
    def getLong(key: String): Long
    def getInt(key: String): Int
    def getList(key: String): java.util.List[String]
    def values(): java.util.HashMap[String, String]
  }

test("test") {
    val ctr = classOf[ConsumerConfig].getDeclaredConstructors.head
    ctr.setAccessible(true)
    val map = Map(
      "bootstrap.servers" -> "192.168.2.109:9092",
      "value.deserializer" -> classOf[BytesDeserializer],
      "key.deserializer" -> classOf[BytesDeserializer])
    val config = ctr.newInstance(map.asJava).asInstanceOf[Config]

    val channelBuilder: ChannelBuilder = ClientUtils.createChannelBuilder(config.values)
    val time = new SystemTime
    val servers =
      ClientUtils.parseAndValidateAddresses(
        config.getList(ConsumerConfig.BOOTSTRAP_SERVERS_CONFIG))
    val cluster = Cluster.bootstrap(servers)
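    // Client-side metadata cache, seeded with only the bootstrap nodes for now
    // (refresh backoff 10ms, expiry taken from the config).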
    val metadata = new Metadata(
      10,
      config.getLong(ConsumerConfig.METADATA_MAX_AGE_CONFIG))
    metadata.update(cluster, time.milliseconds())
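
    // Hand-assemble the NetworkClient that KafkaConsumer normally builds internally:
    // an NIO Selector plus the metadata cache, with buffer sizes and timeouts pulled
    // from the parsed config. "-1" is the client id; 100 caps in-flight requests per
    // connection (as the original comment notes, any large enough value works).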
    val netClient: NetworkClient =
      new NetworkClient(
        new Selector(
          config.getLong(ConsumerConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG),
          new Metrics(new MetricConfig(), Nil, time),
          time,
          "",
          channelBuilder),
        metadata,
        "-1",
        100, // a fixed large enough value will suffice
        config.getLong(ConsumerConfig.RECONNECT_BACKOFF_MS_CONFIG),
        config.getInt(ConsumerConfig.SEND_BUFFER_CONFIG),
        config.getInt(ConsumerConfig.RECEIVE_BUFFER_CONFIG),
        config.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG),
        time)
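
    // Blocking request/response helper: wait until the connection to `node` is ready,
    // send a single request, then poll until the completion callback fires or we give up.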
    def syncSend(
        apiKey: Short,
        version: Short,
        node: Node,
        requestStruct: Struct,
        client: String = "spark"): ClientResponse = {
      while (!netClient.ready(node, time.milliseconds())) {
        println(s"Bootstrap connection to $node")
        netClient.poll(1000, time.milliseconds())
      }
      val header = new RequestHeader(apiKey, version, client, 0)
      val requestSend =
        new RequestSend(node.idString(), header, requestStruct)
      var result: ClientResponse = null
      val request = new ClientRequest(
        time.milliseconds(),
        true,
        requestSend,
        new RequestCompletionHandler {
          override def onComplete(response: ClientResponse): Unit = {
            println("callback!")
            result = response
          }
        })
      netClient.send(request, time.milliseconds())
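      // poll() performs the actual network I/O; spin it until the callback above has
      // delivered a response, giving up after five one-second polls.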
      var tries = 0
      while (result == null && tries < 5) {
        println("waiting for response")
        netClient.poll(1000L, time.milliseconds())
        tries += 1
      }
      if (result == null) throw new TimeoutException()
      result
    }
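
    // Step 1: fetch cluster metadata from the bootstrap node (an empty topic list in a
    // v0 MetadataRequest means "all topics") and locate partition 0 of topic "test".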
    val bootstrapNode = cluster.nodes().get(0)
    val initialMetadata =
      new MetadataResponse(
        syncSend(
          ApiKeys.METADATA.id,
          0,
          bootstrapNode,
          new MetadataRequest(Nil).toStruct).responseBody())
    metadata.update(initialMetadata.cluster(), time.milliseconds())
    val topicMetadata =
      initialMetadata.topicMetadata
        .find(_.topic == "test")
        .getOrElse(sys.error("topic missing"))
    val topicPartition = new TopicPartition(topicMetadata.topic(), 0)
    val partitionMetadata = topicMetadata.partitionMetadata().head
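
    // Step 2: ask the partition leader for offsets. In a v0 ListOffsetRequest the
    // timestamp -2 means "earliest available offset" (-1 would mean latest), and the
    // second argument caps the number of offsets returned at 1.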
    val offsets =
      new ListOffsetResponse(
        syncSend(
          ApiKeys.LIST_OFFSETS.id,
          0,
          partitionMetadata.leader(),
          new ListOffsetRequest(
            Map(topicPartition -> new ListOffsetRequest.PartitionData(-2, 1))
          ).toStruct).responseBody())
    val partitionOffset = offsets.responseData().get(topicPartition)
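
    // Step 3: fetch up to 1024 bytes starting at that offset (replica id -1 marks an
    // ordinary consumer fetch; max wait 5s, min bytes 1). Note the request goes to the
    // first in-sync replica, which is not necessarily the leader.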
    val fetchData = new FetchRequest.PartitionData(partitionOffset.offsets.get(0), 1024)
    val fetchPartitions = Map(topicPartition -> fetchData).asJava
    val fetchRequest = new FetchRequest(-1, 1000 * 5, 1, fetchPartitions)
    val fetchResponse =
      new FetchResponse(
        syncSend(
          ApiKeys.FETCH.id,
          0,
          partitionMetadata.isr().get(0),
          fetchRequest.toStruct).responseBody())
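
    // Step 4: decode the raw record set and print each record's offset and UTF-8 value.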
    val records =
      MemoryRecords.readableRecords(fetchResponse.responseData().asScala.head._2.recordSet)
    println("=== RECORDS ===")
    records.asScala.foreach { rec =>
      println(rec.offset())
      println(new String(Utils.toArray(rec.record().value()), Charset.forName("utf-8")))
    }
  }
}