Commit
krasserm#96 initial commit. Write side journal table definition.
zapletal-martin committed Oct 18, 2015
1 parent dd7f787 commit 2fc6d2d
Showing 6 changed files with 504 additions and 256 deletions.
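Only two of the six changed files are expanded below; the table definition referenced by the commit message lives in the CassandraStatements changes, which this page does not show. Judging from the new bind order of preparedWriteMessage in the diff (journal id, partition number, journal sequence number, persistence id, sequence number, serialized message), the new write-side table plausibly looks something like the sketch below. The keyspace, table name, column types and primary key are assumptions for illustration, not the committed definition.

// Hypothetical sketch only: inferred from the preparedWriteMessage bind order in the diff,
// not copied from the CassandraStatements changes in this commit.
object WriteSideTableSketch {
  val keyspace = "akka"   // assumed keyspace name
  val table = "journal"   // assumed table name

  val createTable: String =
    s"""CREATE TABLE IF NOT EXISTS $keyspace.$table (
       |  journal_id text,
       |  partition_nr bigint,
       |  journal_sequence_nr bigint,
       |  persistence_id text,
       |  sequence_nr bigint,
       |  message blob,
       |  PRIMARY KEY ((journal_id, partition_nr), journal_sequence_nr))""".stripMargin
}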
@@ -1,27 +1,33 @@
package akka.persistence.cassandra.journal

import java.lang.{ Long => JLong }
import java.lang.{Long => JLong}
import java.nio.ByteBuffer

import com.datastax.driver.core.policies.{LoggingRetryPolicy, RetryPolicy}
import com.datastax.driver.core.policies.RetryPolicy.RetryDecision

import scala.concurrent._
import scala.collection.immutable.Seq
import scala.collection.JavaConversions._
import scala.math.min
import scala.util.{Success, Failure, Try}
import scala.collection.immutable.Seq
import scala.concurrent._
import scala.util.{Failure, Success, Try}

import akka.persistence.journal.AsyncWriteJournal
import akka.persistence._
import akka.persistence.cassandra._
import akka.persistence.journal.AsyncWriteJournal
import akka.serialization.SerializationExtension

import com.datastax.driver.core._
import com.datastax.driver.core.policies.RetryPolicy.RetryDecision
import com.datastax.driver.core.policies.{LoggingRetryPolicy, RetryPolicy}
import com.datastax.driver.core.utils.Bytes

class CassandraJournal extends AsyncWriteJournal with CassandraRecovery with CassandraStatements {

// TODO: journalId management.
// TODO: Cluster membership can change and new Journal instances may be added and old removed.
// TODO: We need to ensure globally unique journalId. Conflicts would violate the single writer requirement.
// TODO: Garbage collecting or infinitely growing journalId set?
private[this] val journalId = context.self.path.toString
println(journalId)

private[this] var journalSequenceNr = 0L

val config = new CassandraJournalConfig(context.system.settings.config.getConfig("cassandra-journal"))
val serialization = SerializationExtension(context.system)

@@ -50,28 +56,28 @@ class CassandraJournal extends AsyncWriteJournal with CassandraRecovery with CassandraStatements {
session.execute(writeConfig, CassandraJournalConfig.TargetPartitionProperty, config.targetPartitionSize.toString)

val preparedWriteMessage = session.prepare(writeMessage)
val preparedDeletePermanent = session.prepare(deleteMessage)
val preparedSelectMessages = session.prepare(selectMessages).setConsistencyLevel(readConsistency)
val preparedCheckInUse = session.prepare(selectInUse).setConsistencyLevel(readConsistency)
val preparedWriteInUse = session.prepare(writeInUse)
val preparedSelectHighestSequenceNr = session.prepare(selectHighestSequenceNr).setConsistencyLevel(readConsistency)
val preparedSelectDeletedTo = session.prepare(selectDeletedTo).setConsistencyLevel(readConsistency)
val preparedInsertDeletedTo = session.prepare(insertDeletedTo).setConsistencyLevel(writeConsistency)

def asyncWriteMessages(messages: Seq[AtomicWrite]): Future[Seq[Try[Unit]]] = {
override def asyncWriteMessages(messages: Seq[AtomicWrite]): Future[Seq[Try[Unit]]] = {
// we need to preserve the order / size of this sequence even though we don't map
// AtomicWrites 1:1 with a C* insert
val serialized = messages.map(aw => Try { SerializedAtomicWrite(
aw.payload.head.persistenceId,
aw.payload.map(pr => Serialized(pr.sequenceNr, persistentToByteBuffer(pr))))
})
val newJournalSequenceNr = journalSequenceNr + messages.size

val serialized = (journalSequenceNr to newJournalSequenceNr)
.zip(messages)
.map(aw => Try { SerializedAtomicWrite(
aw._2.payload.head.persistenceId,
aw._2.payload.map(pr => Serialized(aw._1, pr.sequenceNr, persistentToByteBuffer(pr))))
})
journalSequenceNr = newJournalSequenceNr

val result = serialized.map(a => a.map(_ => ()))

val byPersistenceId = serialized.collect({ case Success(caw) => caw }).groupBy(_.persistenceId).values
val boundStatements = byPersistenceId.map(statementGroup)

val batchStatements = boundStatements.map({ unit =>
executeBatch(batch => unit.foreach(batch.add))
executeBatch(batch => unit.foreach(batch.add))
})
val promise = Promise[Seq[Try[Unit]]]()

@@ -84,9 +90,12 @@ class CassandraJournal extends AsyncWriteJournal with CassandraRecovery with CassandraStatements {
}
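The range-and-zip above hands out one journal sequence number per AtomicWrite while keeping the result aligned with the input order and size. A tiny worked example of that counter handling, with batch sizes and the starting counter invented for illustration:

// Worked example of the (journalSequenceNr to newJournalSequenceNr).zip(messages) pairing above;
// values are invented, and an exclusive range is used here for the same effect as zip truncation.
object JournalSeqNrSketch extends App {
  var journalSequenceNr = 0L
  def assign(batchSize: Int): List[Long] = {
    val newJournalSequenceNr = journalSequenceNr + batchSize
    val assigned = (journalSequenceNr until newJournalSequenceNr).toList
    journalSequenceNr = newJournalSequenceNr
    assigned
  }
  println(assign(3))   // List(0, 1, 2)
  println(assign(2))   // List(3, 4) -- contiguous with the previous batch
}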

private def statementGroup(atomicWrites: Seq[SerializedAtomicWrite]): Seq[BoundStatement] = {
val maxPnr = partitionNr(atomicWrites.last.payload.last.sequenceNr)
val firstJournalSequenceNr = atomicWrites.last.payload.last.journaSequenceNr
val lastJournalSequenceNr = atomicWrites.head.payload.head.journaSequenceNr

val maxPnr = partitionNr(firstJournalSequenceNr)
val firstSeq = atomicWrites.head.payload.head.sequenceNr
val minPnr = partitionNr(firstSeq)
val minPnr = partitionNr(lastJournalSequenceNr)
val persistenceId: String = atomicWrites.head.persistenceId
val all = atomicWrites.flatMap(_.payload)

@@ -95,49 +104,28 @@ class CassandraJournal extends AsyncWriteJournal with CassandraRecovery with CassandraStatements {
require(maxPnr - minPnr <= 1, "Do not support AtomicWrites that span 3 partitions. Keep AtomicWrites <= max partition size.")

val writes: Seq[BoundStatement] = all.map { m =>
preparedWriteMessage.bind(persistenceId, maxPnr: JLong, m.sequenceNr: JLong, m.serialized)
preparedWriteMessage.bind(journalId, maxPnr: JLong, m.journaSequenceNr: JLong, persistenceId, m.sequenceNr: JLong, m.serialized)
}
// in case we skip an entire partition we want to make sure the empty partition has an in-use flag so scans
// keep going when they encounter it
if (partitionNew(firstSeq) && minPnr != maxPnr) writes :+ preparedWriteInUse.bind(persistenceId, minPnr: JLong)
if (partitionNew(firstJournalSequenceNr) && minPnr != maxPnr) writes :+ preparedWriteInUse.bind(journalId, minPnr: JLong)
else writes

}

def asyncDeleteMessagesTo(persistenceId: String, toSequenceNr: Long): Future[Unit] = {
val logicalDelete = session.executeAsync(preparedInsertDeletedTo.bind(persistenceId, toSequenceNr: JLong))

val fromSequenceNr = readLowestSequenceNr(persistenceId, 1L)
val lowestPartition = partitionNr(fromSequenceNr)
val highestPartition = partitionNr(toSequenceNr) + 1 // may have been moved to the next partition
val partitionInfos = (lowestPartition to highestPartition).map(partitionInfo(persistenceId, _, toSequenceNr))

partitionInfos.map( future => future.flatMap( pi => {
Future.sequence((pi.minSequenceNr to pi.maxSequenceNr).grouped(config.maxMessageBatchSize).map { group => {
val delete = asyncDeleteMessages(pi.partitionNr, group map (MessageId(persistenceId, _)))
delete.onFailure {
case e => log.warning(s"Unable to complete deletes for persistence id ${persistenceId}, toSequenceNr ${toSequenceNr}. The plugin will continue to function correctly but you will need to manually delete the old messages.", e)
}
delete
}
})
}))

logicalDelete.map(_ => ())
}
// TODO: FIX
override def asyncDeleteMessagesTo(persistenceId: String, toSequenceNr: Long): Future[Unit] =
Future(())

private def partitionInfo(persistenceId: String, partitionNr: Long, maxSequenceNr: Long): Future[PartitionInfo] = {
session.executeAsync(preparedSelectHighestSequenceNr.bind(persistenceId, partitionNr: JLong))
.map(rs => Option(rs.one()))
.map(row => row.map(s => PartitionInfo(partitionNr, minSequenceNr(partitionNr), min(s.getLong("sequence_nr"), maxSequenceNr)))
.getOrElse(PartitionInfo(partitionNr, minSequenceNr(partitionNr), -1)))
}
// TODO: FIX
override def asyncReadHighestSequenceNr(
persistenceId: String,
fromSequenceNr: Long): Future[Long] = Future(1l)

private def asyncDeleteMessages(partitionNr: Long, messageIds: Seq[MessageId]): Future[Unit] = executeBatch({ batch =>
messageIds.foreach { mid =>
batch.add(preparedDeletePermanent.bind(mid.persistenceId, partitionNr: JLong, mid.sequenceNr: JLong))
}
}, Some(config.deleteRetries))
override def asyncReplayMessages(
persistenceId: String,
fromSequenceNr: Long,
toSequenceNr: Long,
max: Long)(recoveryCallback: (PersistentRepr) => Unit): Future[Unit] = Future(())

private def executeBatch(body: BatchStatement => Unit, retries: Option[Int] = None): Future[Unit] = {
val batch = new BatchStatement().setConsistencyLevel(writeConsistency).asInstanceOf[BatchStatement]
@@ -168,7 +156,7 @@ class CassandraJournal extends AsyncWriteJournal with CassandraRecovery with CassandraStatements {
}

private case class SerializedAtomicWrite(persistenceId: String, payload: Seq[Serialized])
private case class Serialized(sequenceNr: Long, serialized: ByteBuffer)
private case class Serialized(journaSequenceNr: Long, sequenceNr: Long, serialized: ByteBuffer)
private case class PartitionInfo(partitionNr: Long, minSequenceNr: Long, maxSequenceNr: Long)
}

@@ -181,3 +169,4 @@ class FixedRetryPolicy(number: Int) extends RetryPolicy {
if (nbRetry < number) RetryDecision.retry(cl) else RetryDecision.rethrow()
}
}
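The partition arithmetic in statementGroup above relies on partitionNr, whose definition is not part of this diff. A minimal sketch, assuming the usual (sequenceNr - 1) / targetPartitionSize mapping and an invented partition size, shows why a write that fits within one partition size can touch at most two adjacent partitions, which is what the maxPnr - minPnr <= 1 requirement enforces:

// Sketch only: partitionNr's real definition lives elsewhere in CassandraJournal;
// (n - 1) / targetPartitionSize is an assumption, as is the partition size used here.
object PartitionSpanSketch extends App {
  val targetPartitionSize = 500000L
  def partitionNr(sequenceNr: Long): Long = (sequenceNr - 1L) / targetPartitionSize

  // A batch starting near the end of partition 0 spills into partition 1, but never further,
  // as long as the batch itself is no larger than a partition.
  val firstNr = targetPartitionSize - 10L   // near the end of partition 0
  val lastNr  = firstNr + 100L              // last entry of a batch of 101
  println((partitionNr(firstNr), partitionNr(lastNr)))   // prints (0,1)
}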

@@ -1,162 +1,10 @@
package akka.persistence.cassandra.journal

import java.lang.{ Long => JLong }

import akka.actor.ActorLogging

import scala.concurrent._

import com.datastax.driver.core.{ResultSet, Row}

import akka.persistence.PersistentRepr

trait CassandraRecovery extends ActorLogging {
this: CassandraJournal =>
import config._

implicit lazy val replayDispatcher = context.system.dispatchers.lookup(replayDispatcherId)

def asyncReplayMessages(persistenceId: String, fromSequenceNr: Long, toSequenceNr: Long, max: Long)(replayCallback: (PersistentRepr) => Unit): Future[Unit] = Future {
replayMessages(persistenceId, fromSequenceNr, toSequenceNr, max)(replayCallback)
}

def asyncReadHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Future[Long] = Future {
readHighestSequenceNr(persistenceId, fromSequenceNr)
}

def readHighestSequenceNr(persistenceId: String, fromSequenceNr: Long): Long = {
findHighestSequenceNr(persistenceId, math.max(fromSequenceNr, highestDeletedSequenceNumber(persistenceId)))
}

def readLowestSequenceNr(persistenceId: String, fromSequenceNr: Long): Long = {
new MessageIterator(persistenceId, fromSequenceNr, Long.MaxValue, Long.MaxValue).find(!_.deleted).map(_.sequenceNr).getOrElse(fromSequenceNr)
}

def replayMessages(persistenceId: String, fromSequenceNr: Long, toSequenceNr: Long, max: Long)(replayCallback: (PersistentRepr) => Unit): Unit = {
new MessageIterator(persistenceId, fromSequenceNr, toSequenceNr, max).foreach( msg => {
replayCallback(msg)
})
}

/**
* Iterator over messages, crossing partition boundaries.
*/
class MessageIterator(persistenceId: String, fromSequenceNr: Long, toSequenceNr: Long, max: Long) extends Iterator[PersistentRepr] {

import PersistentRepr.Undefined

private val initialFromSequenceNr = math.max(highestDeletedSequenceNumber(persistenceId) + 1, fromSequenceNr)
log.debug("Starting message scan from {}", initialFromSequenceNr)

private val iter = new RowIterator(persistenceId, initialFromSequenceNr, toSequenceNr)
private var mcnt = 0L

private var c: PersistentRepr = null
private var n: PersistentRepr = PersistentRepr(Undefined)

fetch()

def hasNext: Boolean =
n != null && mcnt < max

def next(): PersistentRepr = {
fetch()
mcnt += 1
c
}

/**
* Make next message n the current message c, complete c
* and pre-fetch new n.
*/
private def fetch(): Unit = {
c = n
n = null
while (iter.hasNext && n == null) {
val row = iter.next()
val snr = row.getLong("sequence_nr")
val m = persistentFromByteBuffer(row.getBytes("message"))
// there may be duplicates returned by iter
// (on scan boundaries within a partition)
if (snr == c.sequenceNr) c = m else n = m
}
}
}


private def findHighestSequenceNr(persistenceId: String, fromSequenceNr: Long) = {
@annotation.tailrec
def find(currentPnr: Long, currentSnr: Long): Long = {
// if every message has been deleted and thus no sequence_nr the driver gives us back 0 for "null" :(
val next = Option(session.execute(preparedSelectHighestSequenceNr.bind(persistenceId, currentPnr: JLong)).one())
.map(row => (row.getBool("used"), row.getLong("sequence_nr")))
next match {
// never been to this partition
case None => currentSnr
// don't currently explicitly set false
case Some((false, _)) => currentSnr
// everything deleted in this partition, move to the next
case Some((true, 0)) => find(currentPnr+1, currentSnr)
case Some((_, nextHighest)) => find(currentPnr+1, nextHighest)
}
}
find(partitionNr(fromSequenceNr), fromSequenceNr)
}

private def highestDeletedSequenceNumber(persistenceId: String): Long = {
Option(session.execute(preparedSelectDeletedTo.bind(persistenceId)).one())
.map(_.getLong("deleted_to")).getOrElse(0)
}

/**
* Iterates over rows, crossing partition boundaries.
*/
class RowIterator(persistenceId: String, fromSequenceNr: Long, toSequenceNr: Long) extends Iterator[Row] {
var currentPnr = partitionNr(fromSequenceNr)
var currentSnr = fromSequenceNr

var fromSnr = fromSequenceNr
var toSnr = toSequenceNr

var iter = newIter()

def newIter() = {
session.execute(preparedSelectMessages.bind(persistenceId, currentPnr: JLong, fromSnr: JLong, toSnr: JLong)).iterator
}

def inUse: Boolean = {
val execute: ResultSet = session.execute(preparedCheckInUse.bind(persistenceId, currentPnr: JLong))
if (execute.isExhausted) false
else execute.one().getBool("used")
}

@annotation.tailrec
final def hasNext: Boolean = {
if (iter.hasNext) {
// more entries available in current resultset
true
} else if (!inUse) {
// partition has never been in use so stop
false
} else {
// all entries consumed, try next partition
currentPnr += 1
fromSnr = currentSnr
iter = newIter()
hasNext
}
}

def next(): Row = {
val row = iter.next()
currentSnr = row.getLong("sequence_nr")
row
}

private def sequenceNrMin(partitionNr: Long): Long =
partitionNr * targetPartitionSize + 1L

private def sequenceNrMax(partitionNr: Long): Long =
(partitionNr + 1L) * targetPartitionSize
}
}
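The RowIterator removed above is what lets a scan cross partition boundaries: after exhausting one partition's rows it consults the in-use marker and, if the partition was ever written to, moves on to the next one. Because that control flow is easy to lose in a deletion-heavy diff, here is a stripped-down sketch of the same pattern with the Cassandra calls replaced by plain functions; fetchPartition and partitionInUse are stand-in names for illustration, not part of the plugin.

// Sketch of the partition-crossing scan pattern deleted above; the two function parameters
// stand in for the prepared-statement queries and are assumptions for illustration.
class PartitionCrossingIterator[A](
    fetchPartition: Long => Iterator[A],   // rows of a single partition
    partitionInUse: Long => Boolean,       // the "used" marker check for a partition
    startPartition: Long) extends Iterator[A] {

  private var currentPnr = startPartition
  private var iter = fetchPartition(currentPnr)

  @annotation.tailrec
  final def hasNext: Boolean =
    if (iter.hasNext) true                        // more rows in the current partition
    else if (!partitionInUse(currentPnr)) false   // partition never used: end of scan
    else {                                        // all rows consumed, try the next partition
      currentPnr += 1
      iter = fetchPartition(currentPnr)
      hasNext
    }

  def next(): A = iter.next()
}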