Skip to content

Commit

Permalink
implement basic key-value store with Membership using Paxos consensus
Browse files Browse the repository at this point in the history
  • Loading branch information
Kiibou-chan committed Oct 19, 2024
1 parent a78185e commit fdccf6d
Show file tree
Hide file tree
Showing 10 changed files with 314 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object TCP {
println(s"handling new connection")
val conn = new JIOStreamConnection(socket.getInputStream, socket.getOutputStream, () => socket.close())
executionContext.execute: () =>
println(s"exeecuting task")
println(s"executing task")
conn.loopHandler(incoming)
conn
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package channels

import de.rmgk.delay.{Async, Callback}

import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, IOException, InputStream, OutputStream}
import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, IOException, InputStream, OutputStream}

class SendingClosedException extends IOException

Expand All @@ -16,7 +16,6 @@ class JioInputStreamAdapter(in: InputStream) {
inputStream.readFully(bytes, 0, size)

ArrayMessageBuffer(bytes)

}

def loopReceive(handler: Callback[MessageBuffer]): Unit = {
Expand All @@ -36,7 +35,7 @@ class JioOutputStreamAdapter(out: OutputStream) {

def send(data: MessageBuffer): Unit = {
val outArray = data.asArray
outputStream.writeInt(outArray.size)
outputStream.writeInt(outArray.length)
outputStream.write(outArray)
outputStream.flush()
}
Expand All @@ -53,15 +52,15 @@ class JIOStreamConnection(in: InputStream, out: OutputStream, doClose: () => Uni
// connection interface

def send(data: MessageBuffer): Async[Any, Unit] = Async {
println(s"sending data on jio stream")
// println(s"sending data on jio stream")
outputStream.send(data)
}

def close(): Unit = doClose()

// frame parsing

def loopHandler(handler: Handler[MessageBuffer]) =
def loopHandler(handler: Handler[MessageBuffer]): Unit =
inputStream.loopReceive(handler.getCallbackFor(this))

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package probench

import probench.data.{ClientNodeState, DataManager, KVOperation, Request, Response}
import rdts.base.{Bottom, LocalUid, Uid}
import rdts.datatypes.contextual.CausalQueue
import rdts.dotted.Dotted
import rdts.syntax.DeltaBuffer

import scala.io.StdIn.readLine
import scala.util.matching.Regex

class Client {

given localUid: LocalUid = LocalUid.gen()
private val dataManager = DataManager[ClientNodeState](localUid, Bottom[ClientNodeState].empty, onStateChange)

private val lock = new Object()
private var currentOp: Option[Request] = None

val get: Regex = """get (\w+)""".r
val put: Regex = """put (\w+) (\w+)""".r

private def onStateChange(oldState: ClientNodeState, newState: ClientNodeState): Unit = {
for {
op <- currentOp
CausalQueue.QueueElement(res@Response(req, _), _, _) <- newState.responses.data.values if req == op
} {
println(res.response)

currentOp = None

dataManager.transform(_.mod(state => state.copy(responses = state.responses.mod(_.removeBy(_ == res)))))

lock.synchronized {
lock.notifyAll()
}
}
}

private def handleOp(op: KVOperation[String, String]): Unit = {
val req = Request(op)
currentOp = Some(req)

dataManager.transform { current =>
current.mod(it => it.copy(requests = it.requests.mod(_.enqueue(req))))
}

lock.synchronized {
lock.wait()
}
}

def read(key: String): Unit = {
handleOp(KVOperation.Read(key))
}

def write(key: String, value: String): Unit = {
handleOp(KVOperation.Write(key, value))
}

def startCLI(): Unit = {
while true do {
print("client> ")
val line = readLine()

line match {
case get(key) => read(key)
case put(key, value) => write(key, value)
case "exit" => System.exit(0)
case _ => println(s"Error parsing: $line")
}
}
}

export dataManager.addLatentConnection

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package probench

import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec
import probench.data.*
import rdts.base.{Bottom, LocalUid, Uid}
import rdts.datatypes.experiments.protocols.simplified.Paxos
import rdts.datatypes.experiments.protocols.{LogHack, Membership}
import rdts.dotted.Dotted
import rdts.syntax.DeltaBuffer

class Node(val name: String, val initialClusterIds: Set[Uid]) {

private type ClusterState = Membership[Request, Paxos, Paxos]

given localUid: LocalUid = LocalUid.predefined(name)
private val clientDataManager =
DataManager[ClientNodeState](localUid, Bottom[ClientNodeState].empty, onClientStateChange)
private val clusterDataManager =
DataManager[ClusterState](localUid, Membership.init(initialClusterIds), onClusterStateChange)

private def onClientStateChange(oldState: ClientNodeState, newState: ClientNodeState): Unit = {
if newState.requests.data.values.nonEmpty then {
clusterDataManager.transform(_.mod(_.write(newState.requests.data.values.last.value)))
}
}

given LogHack = new LogHack(true)

private def onClusterStateChange(oldState: ClusterState, newState: ClusterState): Unit = {
val upkept: ClusterState = newState.merge(newState.upkeep())

if !(upkept <= newState) then {
clusterDataManager.transform(_.mod(_ => upkept))
}

if newState.log.size > oldState.log.size then {
val op = newState.log.last

val res: String = op match {
case Request(KVOperation.Read(key), _) =>
newState.log.reverseIterator.collectFirst {
case Request(KVOperation.Write(writeKey, value), _) if writeKey == key => value
}.getOrElse(s"Key $key has not been written to!")
case Request(KVOperation.Write(_, _), _) => "OK"
}

clientDataManager.transform(_.mod(state =>
state.copy(
requests = state.requests.mod(_.removeBy(_ == op)),
responses = state.responses.mod(_.enqueue(Response(op, res))),
)
))

}
}

export clientDataManager.addLatentConnection as addClientConnection
export clusterDataManager.addLatentConnection as addClusterConnection

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package probench

import channels.TCP
import com.github.plokhotnyuk.jsoniter_scala.core.JsonValueCodec
import com.github.plokhotnyuk.jsoniter_scala.macros.{CodecMakerConfig, JsonCodecMaker}
import de.rmgk.options.*
import probench.data.{ClientNodeState, KVOperation, Request}
import rdts.base.Uid
import rdts.datatypes.experiments.protocols.Membership
import rdts.datatypes.experiments.protocols.simplified.Paxos

import java.net.Socket
import java.util.concurrent.{ExecutorService, Executors}
import scala.concurrent.ExecutionContext

object cli {

private val executor: ExecutorService = Executors.newCachedThreadPool()
private val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)

def main(args: Array[String]): Unit = {
val name = named[String]("--name", "")
val clientPort = named[Int]("--listen-client-port", "")
val peerPort = named[Int]("--listen-peer-port", "")

val ipAndPort = """(.+):(\d*)""".r

given ipAndPortParser: ArgumentValueParser[(String, Int)] with
override def apply(args: List[String]): (Option[(String, Int)], List[String]) =
args match {
case ipAndPort(ip, port) :: rest => (Some((ip, Integer.parseInt(port))), rest)
case _ => (None, args)
}

override def valueDescription: String = "[<ip:port>]"
end ipAndPortParser

given uidParser: ArgumentValueParser[Uid] with
override def apply(args: List[String]): (Option[Uid], List[String]) =
args match {
case string :: rest => (Some(Uid.predefined(string)), rest)
case _ => (None, args)
}

override def valueDescription: String = "[uid]"
end uidParser

given JsonValueCodec[ClientNodeState] = JsonCodecMaker.make(CodecMakerConfig.withMapAsArray(true))
given JsonValueCodec[Membership[Request, Paxos, Paxos]] =
JsonCodecMaker.make(CodecMakerConfig.withMapAsArray(true))

val argparse = argumentParser {
inline def cluster = named[List[(String, Int)]]("--cluster", "[<ip:port>]")
inline def initialClusterIds = named[List[Uid]]("--initial-cluster-ids", "[name]")

inline def clientNode = named[(String, Int)]("--node", "<ip:port>")

subcommand("node", "starts a cluster node") {
val node = Node(name.value, initialClusterIds.value.toSet)

node.addClientConnection(TCP.listen(TCP.defaultSocket("localhost", clientPort.value), ec))
node.addClusterConnection(TCP.listen(TCP.defaultSocket("localhost", peerPort.value), ec))

cluster.value.foreach { (ip, port) =>
node.addClusterConnection(TCP.connect(() => Socket(ip, port), ec))
}
}.value

subcommand("client", "starts a client to interact with a node") {
val client = Client()

val (ip, port) = clientNode.value

client.addLatentConnection(TCP.connect(() => Socket(ip, port), ec))

client.startCLI()
}.value

subcommand("benchmark", "") {}.value
}

argparse.parse(args.toList).printHelp()
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package probench.data

import rdts.base.{Bottom, Lattice, Uid}
import rdts.datatypes.contextual.CausalQueue
import rdts.datatypes.{GrowOnlyList, GrowOnlyMap}
import rdts.dotted.Dotted

enum KVOperation[Key, Value] {
def key: Key

case Read(key: Key)
case Write(key: Key, value: Value)
}

case class Request(op: KVOperation[String, String], requestUid: Uid = Uid.gen())
case class Response(request: Request, response: String)

case class ClientNodeState(
requests: Dotted[CausalQueue[Request]],
responses: Dotted[CausalQueue[Response]],
) derives Lattice, Bottom
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package probench.data

import rdts.base.{Lattice, LocalUid}
import rdts.syntax.DeltaBuffer
import rdts.time.Dots
import replication.{ProtocolDots, DataManager as RepDataManager}

class DataManager[State: Lattice](
val localReplicaId: LocalUid,
val initialState: State,
val onChange: (State, State) => Unit,
) {
given Lattice[ProtocolDots[State]] = Lattice.derived
private val dataManager = RepDataManager[State](localReplicaId, _ => (), receivedChanges)
private var mergedState: ProtocolDots[State] = dataManager.allDeltas.foldLeft(ProtocolDots(initialState, Dots.empty))(Lattice[ProtocolDots[State]].merge)

private def receivedChanges(changes: ProtocolDots[State]): Unit = {
val oldState = mergedState
mergedState = mergedState.merge(changes)

onChange(oldState.data, mergedState.data)
}

def transform(fun: DeltaBuffer[State] => DeltaBuffer[State]): Unit = dataManager.lock.synchronized {
val current: DeltaBuffer[State] = DeltaBuffer(mergedState.data)
val next: DeltaBuffer[State] = fun(current)

next.deltaBuffer.foreach { delta =>
dataManager.applyLocalDelta(ProtocolDots(
delta,
Dots.single(mergedState.context.nextDot(dataManager.replicaId.uid))
))
}
}

export dataManager.addLatentConnection

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,30 @@ case class Membership[A, C[_], D[_]](
innerConsensus: D[A],
log: List[A],
membershipChanging: Boolean = false
)(using
Consensus[C],
Consensus[D]
) {
private def unchanged: Membership[A, C, D] = Membership(
private def unchanged(using Consensus[C], Consensus[D]): Membership[A, C, D] = Membership(
counter = counter,
membersConsensus = Consensus[C].empty,
innerConsensus = Consensus[D].empty,
log = List()
)

override def toString: String =
s"Membership(counter: $counter, members: $currentMembers,log: $log, membershipChanging: $membershipChanging)".stripMargin
s"Membership(counter: $counter, members: $membersConsensus,log: $log, membershipChanging: $membershipChanging)".stripMargin

def currentMembers: Set[Uid] =
def currentMembers(using Consensus[C], Consensus[D]): Set[Uid] =
assert(membersConsensus.members == innerConsensus.members, "Membership of both consensus protocols is the same")
membersConsensus.members

def addMember(id: Uid)(using LocalUid): Membership[A, C, D] =
def addMember(id: Uid)(using LocalUid, Consensus[C], Consensus[D]): Membership[A, C, D] =
if isMember then
unchanged.copy(
membershipChanging = true,
membersConsensus = membersConsensus.write(currentMembers + id)
)
else unchanged

def removeMember(id: Uid)(using LocalUid): Membership[A, C, D] =
def removeMember(id: Uid)(using LocalUid, Consensus[C], Consensus[D]): Membership[A, C, D] =
if currentMembers.size > 1 && isMember then // cannot remove last member
unchanged.copy(
membershipChanging = true,
Expand All @@ -51,16 +48,16 @@ case class Membership[A, C[_], D[_]](

def read: List[A] = log

def write(value: A)(using LocalUid): Membership[A, C, D] =
def write(value: A)(using LocalUid, Consensus[C], Consensus[D]): Membership[A, C, D] =
if !membershipChanging && isMember then
unchanged.copy(
innerConsensus = innerConsensus.write(value)
)
else unchanged

def isMember(using LocalUid): Boolean = currentMembers.contains(replicaId)
def isMember(using LocalUid, Consensus[C], Consensus[D]): Boolean = currentMembers.contains(replicaId)

def upkeep()(using rid: LocalUid, logger: LogHack): Membership[A, C, D] =
def upkeep()(using rid: LocalUid, logger: LogHack, cc: Consensus[C], cd: Consensus[D]): Membership[A, C, D] =
if !isMember then return unchanged // do nothing if we are not a member anymore
val newMembers = membersConsensus.merge(membersConsensus.upkeep())
val newInner = innerConsensus.merge(innerConsensus.upkeep())
Expand Down
Loading

0 comments on commit fdccf6d

Please sign in to comment.