Skip to content

Commit

Permalink
improve key-value store for benchmarks
Browse files Browse the repository at this point in the history
- fix request ordering
- fix concurrency issues in the DataManager
  • Loading branch information
Kiibou-chan committed Oct 29, 2024
1 parent 1a32b44 commit 91fb805
Show file tree
Hide file tree
Showing 14 changed files with 169 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package channels

import de.rmgk.delay.{Async, Callback}

import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, DataOutputStream, EOFException, IOException, InputStream, OutputStream}
import java.io.*

class SendingClosedException extends IOException

Expand Down
5 changes: 5 additions & 0 deletions Modules/Examples/Protocol Benchmarks/args/client-1-1
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# wait-for-res false
multiput key%n valueX%n 200
multiget key%n 200
# wait
exit
5 changes: 5 additions & 0 deletions Modules/Examples/Protocol Benchmarks/args/client-1-2
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# wait-for-res false
multiput key%n valueY%n 200
multiget key%n 200
# wait
exit
5 changes: 5 additions & 0 deletions Modules/Examples/Protocol Benchmarks/args/client-1-3
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# wait-for-res false
multiput key%n valueZ%n 200
multiget key%n 200
# wait
exit
Original file line number Diff line number Diff line change
@@ -1,31 +1,42 @@
package probench

import probench.data.{ClientNodeState, DataManager, KVOperation, Request, Response}
import probench.data.*
import rdts.base.{Bottom, LocalUid, Uid}
import rdts.datatypes.contextual.CausalQueue
import rdts.dotted.Dotted
import rdts.syntax.DeltaBuffer

import scala.io.StdIn
import scala.io.StdIn.readLine
import scala.util.matching.Regex

class Client {
class Client(val name: Uid) {

given localUid: LocalUid = LocalUid.gen()
given localUid: LocalUid = LocalUid(name)
private val dataManager = DataManager[ClientNodeState](localUid, Bottom[ClientNodeState].empty, onStateChange)

private val lock = new Object()
private var currentOp: Option[Request] = None
private var waitForOp: Boolean = true

val get: Regex = """get (\w+)""".r
val put: Regex = """put (\w+) (\w+)""".r
private val commented: Regex = """#.*""".r
private val waitForRes: Regex = """wait-for-res (true|false)""".r
private val get: Regex = """get ([\w%]+)""".r
private val put: Regex = """put ([\w%]+) ([\w%]+)""".r
private val multiget: Regex = """multiget ([\w%]+) (\d+)""".r
private val multiput: Regex = """multiput ([\w%]+) ([\w%]+) (\d+)""".r

private def onStateChange(oldState: ClientNodeState, newState: ClientNodeState): Unit = {
/* val diff = newState.responses.data.values.size - oldState.responses.data.values.size
if diff > 0 then {
println(s"Got $diff result(s): ${newState.responses.data.values.toList.reverseIterator.take(diff).toList.reverse.map(_.value)}")
} */

for {
op <- currentOp
CausalQueue.QueueElement(res@Response(req, _), _, _) <- newState.responses.data.values if req == op
op <- currentOp
CausalQueue.QueueElement(res @ Response(req, _), _, _) <- newState.responses.data.values if req == op
} {
println(res.response)
println(res.payload)

currentOp = None

Expand All @@ -41,12 +52,18 @@ class Client {
val req = Request(op)
currentOp = Some(req)

// println(s"Put $req")

dataManager.transform { current =>
current.mod(it => it.copy(requests = it.requests.mod(_.enqueue(req))))
}

lock.synchronized {
lock.wait()
// println(s"New Requests ${dataManager.mergedState.data.requests.data.values.toList.map(_.value)}")

if waitForOp then {
lock.synchronized {
lock.wait()
}
}
}

Expand All @@ -58,16 +75,28 @@ class Client {
handleOp(KVOperation.Write(key, value))
}

private def multiget(key: String, times: Int): Unit = {
for i <- 1 to times do read(key.replace("%n", i.toString))
}

private def multiput(key: String, value: String, times: Int): Unit = {
for i <- 1 to times do write(key.replace("%n", i.toString), value.replace("%n", i.toString))
}

def startCLI(): Unit = {
while true do {
print("client> ")
val line = readLine()

val line = Option(readLine())
line match {
case get(key) => read(key)
case put(key, value) => write(key, value)
case "exit" => System.exit(0)
case _ => println(s"Error parsing: $line")
case Some(commented()) => // ignore
case Some(get(key)) => read(key)
case Some(put(key, value)) => write(key, value)
case Some(multiget(key, times)) => multiget(key, times.toInt)
case Some(multiput(key, value, times)) => multiput(key, value, times.toInt)
case Some(waitForRes(flag)) => waitForOp = flag.toBoolean
case Some("wait") => lock.synchronized { lock.wait() }
case Some("exit") => System.exit(0)
case None | Some(_) => println(s"Error parsing: $line")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,49 +8,76 @@ import rdts.datatypes.experiments.protocols.{LogHack, Membership}
import rdts.dotted.Dotted
import rdts.syntax.DeltaBuffer

class Node(val name: String, val initialClusterIds: Set[Uid]) {
class Node(val name: Uid, val initialClusterIds: Set[Uid]) {

private type ClusterState = Membership[Request, Paxos, Paxos]

given localUid: LocalUid = LocalUid.predefined(name)
given localUid: LocalUid = LocalUid(name)
given LogHack = new LogHack(false)

private val clientDataManager =
DataManager[ClientNodeState](localUid, Bottom[ClientNodeState].empty, onClientStateChange)
private val clusterDataManager =
DataManager[ClusterState](localUid, Membership.init(initialClusterIds), onClusterStateChange)

private def onClientStateChange(oldState: ClientNodeState, newState: ClientNodeState): Unit = {
if newState.requests.data.values.nonEmpty then {
clusterDataManager.transform(_.mod(_.write(newState.requests.data.values.last.value)))
/*
val diff = newState.requests.data.values.size - oldState.requests.data.values.size
if diff > 0 then {
println(s"Requests: ${newState.requests.data.values.toList.map(_.value)}")
println(s"Sorted : ${newState.requests.data.values.toList.sortBy(_.order)(using VectorClock.vectorClockTotalOrdering).map(it => it.order -> it.value)}")
println(s"Dots : ${newState.requests.data.dots}")
println(s"Time : ${newState.requests.data.clock}")
}
*/

if newState.requests.data.values.size == 1 then {
clusterDataManager.transform(_.mod(_.write(newState.requests.data.head)))
}
}

given LogHack = new LogHack(true)

private def onClusterStateChange(oldState: ClusterState, newState: ClusterState): Unit = {
val upkept: ClusterState = newState.merge(newState.upkeep())
val delta = newState.upkeep()
val upkept: ClusterState = newState.merge(delta)

if !(upkept <= newState) then {
clusterDataManager.transform(_.mod(_ => upkept))
if !(upkept <= newState) || upkept.log.size > newState.log.size then {
clusterDataManager.transform(_.mod(_ => delta))
}

if newState.log.size > oldState.log.size then {
val op = newState.log.last
if upkept.log.size > oldState.log.size then {
val diff = upkept.log.size - oldState.log.size
// println(s"DIFF $diff")

val res: String = op match {
case Request(KVOperation.Read(key), _) =>
newState.log.reverseIterator.collectFirst {
case Request(KVOperation.Write(writeKey, value), _) if writeKey == key => value
}.getOrElse(s"Key $key has not been written to!")
case Request(KVOperation.Write(_, _), _) => "OK"
}
for op <- upkept.log.reverseIterator.take(diff).toList.reverseIterator do {

val res: String = op match {
case Request(KVOperation.Read(key), _) =>
upkept.log.reverseIterator.collectFirst {
case Request(KVOperation.Write(writeKey, value), _) if writeKey == key => s"$key=$value"
}.getOrElse(s"Key '$key' has not been written to!")
case Request(KVOperation.Write(key, value), _) => s"$key=$value; OK"
}

clientDataManager.transform(_.mod(state =>
state.copy(
requests = state.requests.mod(_.removeBy(_ == op)),
responses = state.responses.mod(_.enqueue(Response(op, res))),
)
))
clientDataManager.transform { it =>
if it.state.requests.data.values.exists { e => e.value == op } then {
it.mod { state =>
// println(s"Writing Response: $op -> $res")
val newState = state.copy(
requests = state.requests.mod(_.removeBy(_ == op)),
responses = state.responses.mod(_.enqueue(Response(op, res))),
)
//println(s"Remaining Requests: ${newState.requests.data.values.toList.map(_.value)}")
newState
}
} else it
}

val clientState = clientDataManager.mergedState.data

if clientState.requests.data.values.nonEmpty then {
clusterDataManager.transform(_.mod(_.write(clientState.requests.data.head)))
}
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ object cli {
private val ec: ExecutionContext = ExecutionContext.fromExecutor(executor)

def main(args: Array[String]): Unit = {
val name = named[String]("--name", "")

val clientPort = named[Int]("--listen-client-port", "")
val peerPort = named[Int]("--listen-peer-port", "")

val ipAndPort = """(.+):(\d*)""".r
val ipAndPort = """(.+):(\d+)""".r

given ipAndPortParser: ArgumentValueParser[(String, Int)] with
override def apply(args: List[String]): (Option[(String, Int)], List[String]) =
Expand All @@ -32,7 +32,7 @@ object cli {
case _ => (None, args)
}

override def valueDescription: String = "[<ip:port>]"
override def valueDescription: String = "<ip:port>"
end ipAndPortParser

given uidParser: ArgumentValueParser[Uid] with
Expand All @@ -42,18 +42,19 @@ object cli {
case _ => (None, args)
}

override def valueDescription: String = "[uid]"
override def valueDescription: String = "<uid>"
end uidParser

given JsonValueCodec[ClientNodeState] = JsonCodecMaker.make(CodecMakerConfig.withMapAsArray(true))

given JsonValueCodec[Membership[Request, Paxos, Paxos]] =
JsonCodecMaker.make(CodecMakerConfig.withMapAsArray(true))

val argparse = argumentParser {
inline def cluster = named[List[(String, Int)]]("--cluster", "[<ip:port>]")
inline def initialClusterIds = named[List[Uid]]("--initial-cluster-ids", "[name]")

inline def clientNode = named[(String, Int)]("--node", "<ip:port>")
inline def cluster = named[List[(String, Int)]]("--cluster", "")
inline def initialClusterIds = named[List[Uid]]("--initial-cluster-ids", "")
inline def clientNode = named[(String, Int)]("--node", "<ip:port>")
inline def name = named[Uid]("--name", "", Uid.gen())

subcommand("node", "starts a cluster node") {
val node = Node(name.value, initialClusterIds.value.toSet)
Expand All @@ -67,7 +68,7 @@ object cli {
}.value

subcommand("client", "starts a client to interact with a node") {
val client = Client()
val client = Client(name.value)

val (ip, port) = clientNode.value

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package probench.data

import rdts.base.{Bottom, Lattice, Uid}
import rdts.base.{Bottom, Lattice, LocalUid, Uid}
import rdts.datatypes.contextual.CausalQueue
import rdts.datatypes.{GrowOnlyList, GrowOnlyMap}
import rdts.dotted.Dotted
import rdts.time.VectorClock

enum KVOperation[Key, Value] {
def key: Key
Expand All @@ -13,9 +14,10 @@ enum KVOperation[Key, Value] {
}

case class Request(op: KVOperation[String, String], requestUid: Uid = Uid.gen())
case class Response(request: Request, response: String)
case class Response(request: Request, payload: String)

case class ClientNodeState(
requests: Dotted[CausalQueue[Request]],
responses: Dotted[CausalQueue[Response]],
) derives Lattice, Bottom

Original file line number Diff line number Diff line change
@@ -1,30 +1,34 @@
package probench.data

import channels.LatentConnection
import rdts.base.{Lattice, LocalUid}
import rdts.syntax.DeltaBuffer
import rdts.time.Dots
import replication.{ProtocolDots, DataManager as RepDataManager}
import replication.{ProtocolDots, ProtocolMessage, DataManager as RepDataManager}

class DataManager[State: Lattice](
val localReplicaId: LocalUid,
val initialState: State,
val onChange: (State, State) => Unit,
) {
given Lattice[ProtocolDots[State]] = Lattice.derived
private val dataManager = RepDataManager[State](localReplicaId, _ => (), receivedChanges)
private var mergedState: ProtocolDots[State] = dataManager.allDeltas.foldLeft(ProtocolDots(initialState, Dots.empty))(Lattice[ProtocolDots[State]].merge)
private val dataManager = RepDataManager[State](localReplicaId, _ => (), receivedChanges)

var mergedState: ProtocolDots[State] =
dataManager.allDeltas.foldLeft(ProtocolDots(initialState, Dots.empty))(Lattice[ProtocolDots[State]].merge)

private def receivedChanges(changes: ProtocolDots[State]): Unit = {
val oldState = mergedState
mergedState = mergedState.merge(changes)

dataManager.lock.synchronized {
mergedState = mergedState.merge(changes)
}
onChange(oldState.data, mergedState.data)
}

def transform(fun: DeltaBuffer[State] => DeltaBuffer[State]): Unit = dataManager.lock.synchronized {
val current: DeltaBuffer[State] = DeltaBuffer(mergedState.data)
val next: DeltaBuffer[State] = fun(current)

next.deltaBuffer.foreach { delta =>
dataManager.applyLocalDelta(ProtocolDots(
delta,
Expand Down
Loading

0 comments on commit 91fb805

Please sign in to comment.