@PeterCorless
Created March 11, 2019 21:57
Scylla Migrator
import com.datastax.spark.connector.cql.{CassandraConnector, CassandraConnectorConf}

val connector =
  new CassandraConnector(CassandraConnectorConf(sparkContext.getConf))
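For context (not part of the original gist), the connector can also run ad-hoc CQL directly; a minimal sketch, where the keyspace and table names are placeholders:

// Sketch only: borrow a session from the connector's pool and run a query.
val rowCount: Long = connector.withSessionDo { session =>
  session.execute("SELECT count(*) FROM keyspace.table").one().getLong(0)
}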
import com.datastax.spark.connector.cql.TableDef

case class TableDef(
  keyspaceName: String,
  tableName: String,
  partitionKey: Seq[ColumnDef],
  clusteringColumns: Seq[ColumnDef],
  regularColumns: Seq[ColumnDef],
  indexes: Seq[IndexDef] = Seq.empty,
  isView: Boolean = false
)
import com.datastax.spark.connector.cql.Schema
val tableDef: TableDef = Schema.tableFromCassandra(connector, "keyspace", "table")
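As a quick illustration (not in the original gist), the resulting TableDef can be inspected directly; the column names in the comments are just the placeholders used throughout this example:

// Sketch only: list the columns the connector discovered for the table.
println(tableDef.partitionKey.map(_.columnName))      // e.g. List(key1)
println(tableDef.clusteringColumns.map(_.columnName)) // e.g. List(key2)
println(tableDef.regularColumns.map(_.columnName))    // e.g. List(regular1, regular2)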
case class StructType(fields: Array[StructField])

case class StructField(
  name: String,
  dataType: DataType,
  nullable: Boolean = true,
  metadata: Metadata = Metadata.empty)
def toStructField(column: ColumnDef): StructField =
  StructField(column.columnName, catalystDataType(column.columnType, nullable = true))
val schema = StructType(tableDef.columns.map(DataTypeConverter.toStructField))
def select(columns: ColumnRef*): Self
val refs: Seq[ColumnRef] = tableDef.allColumns.map(_.ref)
import com.datastax.spark.connector.{TTL, WriteTime}

val projection: Seq[ColumnRef] =
  tableDef.allColumns.flatMap { columnDef =>
    val colName = columnDef.columnName
    List(
      columnDef,
      TTL(colName).as(s"${colName}_ttl"),
      WriteTime(colName).as(s"${colName}_writetime")
    )
  }
// TTL and WRITETIME are only defined for regular (non-key) columns, so they are
// selected only for those; key columns are projected as-is.
val projection: Seq[ColumnRef] =
  tableDef.partitionKey.map(_.ref) ++
  tableDef.clusteringColumns.map(_.ref) ++
  tableDef.regularColumns.flatMap { columnDef =>
    val colName = columnDef.columnName
    List(
      columnDef,
      TTL(colName).as(s"${colName}_ttl"),
      WriteTime(colName).as(s"${colName}_writetime")
    )
  }
val rdd: CassandraRDD[CassandraSQLRow] =
  spark.sparkContext
    .cassandraTable[CassandraSQLRow](source.keyspace, source.table)
    .select(projection: _*)
val modifiedSchema =
  StructType(
    for {
      origField <- schema.fields
      isRegular = tableDef.regularColumns.exists(_.ref.columnName == origField.name)
      field <- if (isRegular)
                 List(
                   origField,
                   StructField(s"${origField.name}_ttl", LongType, true),
                   StructField(s"${origField.name}_writetime", LongType, true))
               else List(origField)
    } yield field
  )
val dataframe =
  spark.createDataset(rdd)(RowEncoder(modifiedSchema))
INSERT INTO table (key1, key2, regular1, regular2)
VALUES ('a', 1, 'reg1', 'reg2')
USING TTL 86400;
- key1: "a"
- key2: 1
- regular1: "reg1", TTL: 10, WRITETIME: 1000
- regular2: "reg2", TTL: 10, WRITETIME: 2000
- regular3: "reg3", TTL: 20, WRITETIME: 3000
- regular4: "reg4", TTL: 20, WRITETIME: 3000
INSERT INTO table (key1, key2, regular1)
VALUES ('a', 1, 'reg1')
USING TTL 10 AND TIMESTAMP 1000;

INSERT INTO table (key1, key2, regular2)
VALUES ('a', 1, 'reg2')
USING TTL 10 AND TIMESTAMP 2000;

INSERT INTO table (key1, key2, regular3, regular4)
VALUES ('a', 1, 'reg3', 'reg4')
USING TTL 20 AND TIMESTAMP 3000;
sealed trait CassandraOption[+A] extends Product with Serializable

object CassandraOption {
  case class Value[+A](value: A) extends CassandraOption[A]
  case object Unset extends CassandraOption[Nothing]
  case object Null extends CassandraOption[Nothing]
}
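A minimal usage sketch (not part of the original gist), assuming the connector's CassandraOption shown above and the implicit RDD functions from com.datastax.spark.connector._ are in scope: Value writes the given value, Null writes a tombstone, and Unset leaves whatever is already stored untouched.

import com.datastax.spark.connector._

// Sketch only: per-column write behaviour driven by CassandraOption.
val updates = spark.sparkContext.parallelize(Seq(
  ("a", 1, CassandraOption.Value("new-reg1"), CassandraOption.Unset), // leave regular2 as it is
  ("b", 2, CassandraOption.Null, CassandraOption.Value("new-reg2"))   // delete regular1
))
updates.saveToCassandra("keyspace", "table", SomeColumns("key1", "key2", "regular1", "regular2"))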
def flatMap[U : Encoder](func: T => TraversableOnce[U]): Dataset[U]
def indexFields(fieldNames: List[String],
                tableDef: TableDef) = {
  val fieldIndices = fieldNames.zipWithIndex.toMap

  val primaryKeyIndices: Map[String, Int] =
    (for {
      fieldName <- fieldNames
      if tableDef.primaryKey.exists(_.ref.columnName == fieldName)
      index <- fieldIndices.get(fieldName)
    } yield fieldName -> index).toMap

  val regularKeyIndices: Map[String, (Int, Int, Int)] =
    (for {
      fieldName <- fieldNames
      if tableDef.regularColumns.exists(_.ref.columnName == fieldName)
      fieldIndex <- fieldIndices.get(fieldName)
      ttlIndex <- fieldIndices.get(s"${fieldName}_ttl")
      writetimeIndex <- fieldIndices.get(s"${fieldName}_writetime")
    } yield fieldName -> (fieldIndex, ttlIndex, writetimeIndex)).toMap

  (primaryKeyIndices, regularKeyIndices)
}
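For illustration (not part of the original gist), with a modified schema that interleaves each regular column with its _ttl and _writetime fields, indexFields would produce something like the following (showing just two regular columns):

// Hypothetical field order in the modified schema:
//   key1, key2, regular1, regular1_ttl, regular1_writetime, regular2, regular2_ttl, regular2_writetime
//
// indexFields(fieldNames, tableDef) would then return:
//   primaryKeyIndices = Map("key1" -> 0, "key2" -> 1)
//   regularKeyIndices = Map("regular1" -> (2, 3, 4), "regular2" -> (5, 6, 7))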
def explodeRow(row: Row,
               schema: StructType,
               primaryKeyIndices: Map[String, Int],
               regularKeyIndices: Map[String, (Int, Int, Int)]) =
  if (regularKeyIndices.isEmpty) List(row)
  else
    regularKeyIndices
      .map {
        case (fieldName, (ordinal, ttlOrdinal, writetimeOrdinal)) =>
          (fieldName,
           if (row.isNullAt(ordinal)) CassandraOption.Null
           else CassandraOption.Value(row.get(ordinal)),
           if (row.isNullAt(ttlOrdinal)) None
           else Some(row.getLong(ttlOrdinal)),
           row.getLong(writetimeOrdinal))
      }
      .groupBy {
        case (_, _, ttl, writetime) => (ttl, writetime)
      }
      .mapValues { fieldGroups =>
        fieldGroups
          .map {
            case (fieldName, value, _, _) => fieldName -> value
          }
          .toMap
      }
      .map {
        case ((ttl, writetime), fields) =>
          val newValues = schema.fields.map { field =>
            primaryKeyIndices
              .get(field.name)
              .flatMap { ord =>
                if (row.isNullAt(ord)) None
                else Some(row.get(ord))
              }
              .getOrElse(fields.getOrElse(field.name, CassandraOption.Unset))
          } ++ Seq(ttl.getOrElse(0L), writetime)

          Row(newValues: _*)
      }
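To make the behaviour concrete (this walkthrough is not part of the original gist), explodeRow applied to the example row from the CQL section above splits it by distinct (TTL, WRITETIME) pairs:

// Input row (primary key "a" / 1):
//   regular1 -> ("reg1", TTL 10, WRITETIME 1000)
//   regular2 -> ("reg2", TTL 10, WRITETIME 2000)
//   regular3 -> ("reg3", TTL 20, WRITETIME 3000)
//   regular4 -> ("reg4", TTL 20, WRITETIME 3000)
//
// Output: one row per (ttl, writetime) group, with columns outside the group left Unset
// and the group's ttl and writetime appended as the last two values:
//   Row("a", 1, Value("reg1"), Unset, Unset, Unset, 10, 1000)
//   Row("a", 1, Unset, Value("reg2"), Unset, Unset, 10, 2000)
//   Row("a", 1, Unset, Unset, Value("reg3"), Value("reg4"), 20, 3000)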
val (primaryKeyOrdinals, regularKeyOrdinals) = indexFields(
  df.schema.fields.map(_.name).toList,
  tableDef)

val broadcastPrimaryKeyOrdinals = spark.sparkContext.broadcast(primaryKeyOrdinals)
val broadcastRegularKeyOrdinals = spark.sparkContext.broadcast(regularKeyOrdinals)
val broadcastSchema = spark.sparkContext.broadcast(origSchema)

df.flatMap {
  explodeRow(
    _,
    broadcastSchema.value,
    broadcastPrimaryKeyOrdinals.value,
    broadcastRegularKeyOrdinals.value)
}
val colSelector: ColumnSelector =
  SomeColumns(origSchema.fields.map(x => x.name: ColumnRef): _*)

val writeConf =
  WriteConf.fromSparkConf(spark.sparkContext.getConf)
    .copy(
      ttl = TTLOption.perRow("ttl"),
      timestamp = TimestampOption.perRow("writetime")
    )
df.rdd.saveToCassandra(
  keyspaceName,
  tableName,
  colSelector,
  writeConf
)
abstract class AccumulatorV2[IN, OUT] {
  def isZero: Boolean
  def copy(): AccumulatorV2[IN, OUT]
  def reset(): Unit
  def add(v: IN): Unit
  def merge(other: AccumulatorV2[IN, OUT]): Unit
  def value: OUT
}
import java.util.concurrent.atomic.AtomicReference
import java.util.function.UnaryOperator

import com.datastax.spark.connector.rdd.partitioner.CqlTokenRange
import org.apache.spark.util.AccumulatorV2

class TokenRangeAccumulator(acc: AtomicReference[Set[CqlTokenRange[_, _]]])
    extends AccumulatorV2[Set[CqlTokenRange[_, _]], Set[CqlTokenRange[_, _]]] {

  override def add(v: Set[CqlTokenRange[_, _]]): Unit =
    acc.getAndUpdate(
      new UnaryOperator[Set[CqlTokenRange[_, _]]] {
        override def apply(t: Set[CqlTokenRange[_, _]]): Set[CqlTokenRange[_, _]] =
          t ++ v
      }
    )

  override def value: Set[CqlTokenRange[_, _]] = acc.get()

  // isZero, copy, reset and merge are omitted from this excerpt.
}
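A minimal sketch (not in the original gist) of constructing the accumulator and registering it with Spark so that executors can add to it; the accumulator name is arbitrary. The later snippet treats the accumulator as an Option, hence the foreach there.

// Sketch only: create and register the accumulator before starting the copy.
val acc = new TokenRangeAccumulator(
  new AtomicReference(Set.empty[CqlTokenRange[_, _]]))
spark.sparkContext.register(acc, "copiedTokenRanges")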
def extractTokenRange(partitionId: Int): Iterable[CqlTokenRange[_, _]] =
  partitions.lift(partitionId) match {
    case Some(CassandraPartition(_, _, ranges, _)) => ranges
    case _ => List()
  }
tokenRangeAcc.foreach(_.add(tokenRanges.toSet))
def startSavepointSchedule(svc: ScheduledThreadPoolExecutor,
                           config: MigratorConfig,
                           acc: TokenRangeAccumulator): Unit = {
  val runnable = new Runnable {
    override def run(): Unit =
      try dumpAccumulatorState(config, acc, "schedule")
      catch {
        case e: Throwable =>
          log.error("Could not create the savepoint. This will be retried.", e)
      }
  }

  log.info(
    s"Starting savepoint schedule; will write a savepoint every ${config.savepoints.intervalSeconds} seconds")

  svc.scheduleAtFixedRate(runnable, 0, config.savepoints.intervalSeconds, TimeUnit.SECONDS)
}
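As a usage sketch (not in the original gist), the schedule can be started with a single-threaded executor before the copy begins; migratorConfig and acc stand in for values built elsewhere:

import java.util.concurrent.ScheduledThreadPoolExecutor

// Sketch only: one background thread is enough for periodic savepoints.
val savepointScheduler = new ScheduledThreadPoolExecutor(1)
startSavepointSchedule(savepointScheduler, migratorConfig, acc)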
val tokenRanges =
  partition.tokenRanges.filter { cqlRange =>
    val (start, end) = (cqlRange.range.start.value, cqlRange.range.end.value) match {
      case (s: Long, e: Long) => (s, e)
      case _ =>
        throw new Exception(
          "Encountered TokenRanges that use tokens of a type that isn't Long. " +
            "This probably means that the server is using a Random partitioner, which is currently " +
            s"unsupported. Range: ${cqlRange.range}")
    }

    tokenRangeFilter(start, end)
  }
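tokenRangeFilter is referenced above but not defined in this gist; a hypothetical sketch of such a predicate, assuming a savepoint supplies the (start, end) pairs that were already copied:

// Hypothetical sketch only: skip ranges recorded in a previously loaded savepoint.
val skippedRanges: Set[(Long, Long)] = Set.empty // would be loaded from the savepoint file
def tokenRangeFilter(start: Long, end: Long): Boolean =
  !skippedRanges.contains((start, end))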