Skip to content

Commit

Permalink
style(MemBlock): wipe out selectOldest and selectOldestRedirect
Browse files Browse the repository at this point in the history
This commit clears up `selectOldest` and `selectOldestRedirect` function
that repeatedly appear in MemBlock and abstract them into one object in
`MemCommon`.
  • Loading branch information
linjuanZ committed Dec 11, 2024
1 parent 7dc438a commit 9f5a6f0
Show file tree
Hide file tree
Showing 11 changed files with 156 additions and 199 deletions.
4 changes: 3 additions & 1 deletion src/main/scala/xiangshan/Bundle.scala
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,12 @@ class MicroOp(implicit p: Parameters) extends CfCtrl {
}
}

class XSBundleWithMicroOp(implicit p: Parameters) extends XSBundle {
trait HasMicroOp { this: XSBundle =>
val uop = new DynInst
}

class XSBundleWithMicroOp(implicit p: Parameters) extends XSBundle with HasMicroOp

class MicroOpRbExt(implicit p: Parameters) extends XSBundleWithMicroOp {
val flag = UInt(1.W)
}
Expand Down
11 changes: 1 addition & 10 deletions src/main/scala/xiangshan/backend/MemBlock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1308,17 +1308,8 @@ class MemBlockInlinedImp(outer: MemBlockInlined) extends LazyModuleImp(outer)
lsq.io.brqRedirect <> redirect

// violation rollback
def selectOldestRedirect(xs: Seq[Valid[Redirect]]): Vec[Bool] = {
val compareVec = (0 until xs.length).map(i => (0 until i).map(j => isAfter(xs(j).bits.robIdx, xs(i).bits.robIdx)))
val resultOnehot = VecInit((0 until xs.length).map(i => Cat((0 until xs.length).map(j =>
(if (j < i) !xs(j).valid || compareVec(i)(j)
else if (j == i) xs(i).valid
else !xs(j).valid || !compareVec(j)(i))
)).andR))
resultOnehot
}
val allRedirect = loadUnits.map(_.io.rollback) ++ hybridUnits.map(_.io.ldu_io.rollback) ++ lsq.io.nack_rollback ++ lsq.io.nuke_rollback
val oldestOneHot = selectOldestRedirect(allRedirect)
val oldestOneHot = Redirect.selectOldestRedirect(allRedirect)
val oldestRedirect = WireDefault(Mux1H(oldestOneHot, allRedirect))
// memory replay would not cause IAF/IPF/IGPF
oldestRedirect.bits.cfiUpdate.backendIAF := false.B
Expand Down
101 changes: 99 additions & 2 deletions src/main/scala/xiangshan/mem/MemCommon.scala
Original file line number Diff line number Diff line change
Expand Up @@ -96,10 +96,107 @@ object shiftMaskToHigh {
}
}

object SelectOldest {
def apply[T <: Bundle](
valids: Seq[Bool],
bits: Seq[T],
f: (T, T) => T,
groupSizeOpt: Option[Int] = None,
flushOpt: Option[Valid[Redirect]] = None
): (Seq[Bool], Seq[T]) = {

val len = valids.length
/**
* When `groupSizeOpt` is defined, select oldest entry within the group size in a cycle and iterate the selection
* in the next cycle.
* When `flushOpt` is defined, the selected result needs to be flushed in each cycle.
* Therefore `flushOpt` can be defined only if `groupSizeOpt` is defined.
*/
val needGroup = groupSizeOpt.isDefined
val needFlush = flushOpt.isDefined
require(valids.length == bits.length)
require(!needFlush || needGroup)
require(!needFlush || bits.head.isInstanceOf[HasMicroOp])

if (needGroup) {
val groupSize = groupSizeOpt.get
require(groupSize > 0)
val groups = scala.math.ceil(len.toFloat / groupSize).toInt

val validGroups = valids.grouped(groupSize).toList
val bitsGroups = bits.grouped(groupSize).toList
val selects = (0 until groups).map { case g =>
val (selValid, selBits) = apply(validGroups(g), bitsGroups(g), f)
val selValidNext = RegNext(selValid.head)
val selBitsNext = RegEnable(selBits.head, selValid.head)
val doFlush = flushOpt match {
case Some(r) => selBitsNext.asInstanceOf[HasMicroOp].uop.robIdx.needFlush(RegNext(r))
case None => false.B
}
(selValidNext && !doFlush, selBitsNext)
}
if (groups <= 1) (selects.map(_._1), selects.map(_._2))
else apply(selects.map(_._1), selects.map(_._2), f, groupSizeOpt, flushOpt)

} else {
if (len == 0 || len == 1) {
(valids, bits)
} else if (len == 2) {
val oldest = Mux(
valids(0) && valids(1),
f(bits(0), bits(1)),
Mux(valids(0) && !valids(1), bits(0), bits(1))
)
(Seq(valids(0) || valids(1)), Seq(oldest))
} else {
val left = apply(valids.take(len / 2), bits.take(len / 2), f)
val right = apply(valids.takeRight(len - (len / 2)), bits.takeRight(len - (len / 2)), f)
apply(left._1 ++ right._1, left._2 ++ right._2, f)
}
}
}
}

object SelectOldestRobIdx {
def apply[T <: Bundle](
valids: Seq[Bool],
bits: Seq[T],
groupSizeOpt: Option[Int] = None,
flushOpt: Option[Valid[Redirect]] = None
): (Seq[Bool], Seq[T]) = {
require(bits.head.isInstanceOf[HasMicroOp])
SelectOldest(valids, bits,
(a: T, b: T) => {
val aRobIdx = a.asInstanceOf[HasMicroOp].uop.robIdx
val bRobIdx = b.asInstanceOf[HasMicroOp].uop.robIdx
Mux(aRobIdx > bRobIdx, b, a)
}, groupSizeOpt, flushOpt
)
}
}

object SelectOldestUopIdx {
def apply[T <: Bundle](
valids: Seq[Bool],
bits: Seq[T],
groupSizeOpt: Option[Int] = None,
flushOpt: Option[Valid[Redirect]] = None
): (Seq[Bool], Seq[T]) = {
require(bits.head.isInstanceOf[HasMicroOp])
SelectOldest(valids, bits,
(a: T, b: T) => {
val au = a.asInstanceOf[HasMicroOp].uop
val bu = b.asInstanceOf[HasMicroOp].uop
Mux(au.robIdx > bu.robIdx || au.robIdx === bu.robIdx && au.uopIdx > bu.uopIdx, b, a)
}, groupSizeOpt, flushOpt
)
}
}

class LsPipelineBundle(implicit p: Parameters) extends XSBundle
with HasDCacheParameters
with HasVLSUParameters {
val uop = new DynInst
with HasVLSUParameters
with HasMicroOp {
val vaddr = UInt(VAddrBits.W)
// For exception vaddr generate
val fullva = UInt(XLEN.W)
Expand Down
24 changes: 1 addition & 23 deletions src/main/scala/xiangshan/mem/lsqueue/LoadExceptionBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,29 +69,7 @@ class LqExceptionBuffer(implicit p: Parameters) extends XSModule with HasCircula
req_valid := true.B
}

def selectOldest[T <: LqWriteBundle](valid: Seq[Bool], bits: Seq[T]): (Seq[Bool], Seq[T]) = {
assert(valid.length == bits.length)
if (valid.length == 0 || valid.length == 1) {
(valid, bits)
} else if (valid.length == 2) {
val res = Seq.fill(2)(Wire(ValidIO(chiselTypeOf(bits(0)))))
for (i <- res.indices) {
res(i).valid := valid(i)
res(i).bits := bits(i)
}
val oldest = Mux(valid(0) && valid(1),
Mux(isAfter(bits(0).uop.robIdx, bits(1).uop.robIdx) ||
(bits(0).uop.robIdx === bits(1).uop.robIdx && bits(0).uop.uopIdx > bits(1).uop.uopIdx), res(1), res(0)),
Mux(valid(0) && !valid(1), res(0), res(1)))
(Seq(oldest.valid), Seq(oldest.bits))
} else {
val left = selectOldest(valid.take(valid.length / 2), bits.take(bits.length / 2))
val right = selectOldest(valid.takeRight(valid.length - (valid.length / 2)), bits.takeRight(bits.length - (bits.length / 2)))
selectOldest(left._1 ++ right._1, left._2 ++ right._2)
}
}

val reqSel = selectOldest(s2_enqueue, s2_req)
val reqSel = SelectOldestUopIdx(s2_enqueue, s2_req)

when (req_valid) {
req := Mux(
Expand Down
22 changes: 0 additions & 22 deletions src/main/scala/xiangshan/mem/lsqueue/LoadMisalignBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -91,28 +91,6 @@ class LoadMisalignBuffer(implicit p: Parameters) extends XSModule
truncateData(XLEN - 1, 0)
}

def selectOldest[T <: LqWriteBundle](valid: Seq[Bool], bits: Seq[T]): (Seq[Bool], Seq[T]) = {
assert(valid.length == bits.length)
if (valid.length == 0 || valid.length == 1) {
(valid, bits)
} else if (valid.length == 2) {
val res = Seq.fill(2)(Wire(ValidIO(chiselTypeOf(bits(0)))))
for (i <- res.indices) {
res(i).valid := valid(i)
res(i).bits := bits(i)
}
val oldest = Mux(valid(0) && valid(1),
Mux(isAfter(bits(0).uop.robIdx, bits(1).uop.robIdx) ||
(bits(0).uop.robIdx === bits(1).uop.robIdx && bits(0).uop.uopIdx > bits(1).uop.uopIdx), res(1), res(0)),
Mux(valid(0) && !valid(1), res(0), res(1)))
(Seq(oldest.valid), Seq(oldest.bits))
} else {
val left = selectOldest(valid.take(valid.length / 2), bits.take(bits.length / 2))
val right = selectOldest(valid.takeRight(valid.length - (valid.length / 2)), bits.takeRight(bits.length - (bits.length / 2)))
selectOldest(left._1 ++ right._1, left._2 ++ right._2)
}
}

val io = IO(new Bundle() {
val redirect = Flipped(Valid(new Redirect))
val req = Vec(enqPortNum, Flipped(Decoupled(new LqWriteBundle)))
Expand Down
50 changes: 6 additions & 44 deletions src/main/scala/xiangshan/mem/lsqueue/LoadQueueRAW.scala
Original file line number Diff line number Diff line change
Expand Up @@ -248,49 +248,6 @@ class LoadQueueRAW(implicit p: Parameters) extends XSModule
val lgSelectGroupSize = log2Ceil(SelectGroupSize)
val TotalSelectCycles = scala.math.ceil(log2Ceil(LoadQueueRAWSize).toFloat / lgSelectGroupSize).toInt + 1

def selectPartialOldest[T <: XSBundleWithMicroOp](valid: Seq[Bool], bits: Seq[T]): (Seq[Bool], Seq[T]) = {
assert(valid.length == bits.length)
if (valid.length == 0 || valid.length == 1) {
(valid, bits)
} else if (valid.length == 2) {
val res = Seq.fill(2)(Wire(ValidIO(chiselTypeOf(bits(0)))))
for (i <- res.indices) {
res(i).valid := valid(i)
res(i).bits := bits(i)
}
val oldest = Mux(valid(0) && valid(1), Mux(isAfter(bits(0).uop.robIdx, bits(1).uop.robIdx), res(1), res(0)), Mux(valid(0) && !valid(1), res(0), res(1)))
(Seq(oldest.valid), Seq(oldest.bits))
} else {
val left = selectPartialOldest(valid.take(valid.length / 2), bits.take(bits.length / 2))
val right = selectPartialOldest(valid.takeRight(valid.length - (valid.length / 2)), bits.takeRight(bits.length - (bits.length / 2)))
selectPartialOldest(left._1 ++ right._1, left._2 ++ right._2)
}
}

def selectOldest[T <: XSBundleWithMicroOp](valid: Seq[Bool], bits: Seq[T]): (Seq[Bool], Seq[T]) = {
assert(valid.length == bits.length)
val numSelectGroups = scala.math.ceil(valid.length.toFloat / SelectGroupSize).toInt

// group info
val selectValidGroups = valid.grouped(SelectGroupSize).toList
val selectBitsGroups = bits.grouped(SelectGroupSize).toList
// select logic
if (valid.length <= SelectGroupSize) {
val (selValid, selBits) = selectPartialOldest(valid, bits)
val selValidNext = GatedValidRegNext(selValid(0))
val selBitsNext = RegEnable(selBits(0), selValid(0))
(Seq(selValidNext && !selBitsNext.uop.robIdx.needFlush(RegNext(io.redirect))), Seq(selBitsNext))
} else {
val select = (0 until numSelectGroups).map(g => {
val (selValid, selBits) = selectPartialOldest(selectValidGroups(g), selectBitsGroups(g))
val selValidNext = RegNext(selValid(0))
val selBitsNext = RegEnable(selBits(0), selValid(0))
(selValidNext && !selBitsNext.uop.robIdx.needFlush(io.redirect) && !selBitsNext.uop.robIdx.needFlush(RegNext(io.redirect)), selBitsNext)
})
selectOldest(select.map(_._1), select.map(_._2))
}
}

val storeIn = io.storeIn

def detectRollback(i: Int) = {
Expand All @@ -312,7 +269,12 @@ class LoadQueueRAW(implicit p: Parameters) extends XSModule
})

// select logic
val lqSelect: (Seq[Bool], Seq[XSBundleWithMicroOp]) = selectOldest(lqViolationSelVec, lqViolationSelUopExts)
val lqSelect: (Seq[Bool], Seq[XSBundleWithMicroOp]) = SelectOldestRobIdx(
valids = lqViolationSelVec,
bits = lqViolationSelUopExts,
groupSizeOpt = Some(SelectGroupSize),
flushOpt = Some(io.redirect)
)

// select one inst
val lqViolation = lqSelect._1(0)
Expand Down
11 changes: 1 addition & 10 deletions src/main/scala/xiangshan/mem/lsqueue/LoadQueueUncache.scala
Original file line number Diff line number Diff line change
Expand Up @@ -528,15 +528,6 @@ class LoadQueueUncache(implicit p: Parameters) extends XSModule
* rollback req
*
******************************************************************/
def selectOldestRedirect(xs: Seq[Valid[Redirect]]): Vec[Bool] = {
val compareVec = (0 until xs.length).map(i => (0 until i).map(j => isAfter(xs(j).bits.robIdx, xs(i).bits.robIdx)))
val resultOnehot = VecInit((0 until xs.length).map(i => Cat((0 until xs.length).map(j =>
(if (j < i) !xs(j).valid || compareVec(i)(j)
else if (j == i) xs(i).valid
else !xs(j).valid || !compareVec(j)(i))
)).andR))
resultOnehot
}
val reqNeedCheck = VecInit((0 until LoadPipelineWidth).map(w =>
s2_enqueue(w) && !s2_enqValidVec(w)
))
Expand All @@ -554,7 +545,7 @@ class LoadQueueUncache(implicit p: Parameters) extends XSModule
redirect.bits.debug_runahead_checkpoint_id := reqSelUops(i).debugInfo.runahead_checkpoint_id
redirect
})
val oldestOneHot = selectOldestRedirect(allRedirect)
val oldestOneHot = Redirect.selectOldestRedirect(allRedirect)
val oldestRedirect = Mux1H(oldestOneHot, allRedirect)
val lastCycleRedirect = Wire(Valid(new Redirect))
lastCycleRedirect.valid := RegNext(io.redirect.valid)
Expand Down
57 changes: 23 additions & 34 deletions src/main/scala/xiangshan/mem/lsqueue/StoreMisalignBuffer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -64,35 +64,6 @@ class StoreMisalignBuffer(implicit p: Parameters) extends XSModule
SD -> 0xff.U
))

def selectOldest[T <: LsPipelineBundle](valid: Seq[Bool], bits: Seq[T], index: Seq[UInt]): (Seq[Bool], Seq[T], Seq[UInt]) = {
assert(valid.length == bits.length)
if (valid.length == 0 || valid.length == 1) {
(valid, bits, index)
} else if (valid.length == 2) {
val res = Seq.fill(2)(Wire(ValidIO(chiselTypeOf(bits(0)))))
val resIndex = Seq.fill(2)(Wire(chiselTypeOf(index(0))))
for (i <- res.indices) {
res(i).valid := valid(i)
res(i).bits := bits(i)
resIndex(i) := index(i)
}
val oldest = Mux(valid(0) && valid(1),
Mux(isAfter(bits(0).uop.robIdx, bits(1).uop.robIdx) ||
(isNotBefore(bits(0).uop.robIdx, bits(1).uop.robIdx) && bits(0).uop.uopIdx > bits(1).uop.uopIdx), res(1), res(0)),
Mux(valid(0) && !valid(1), res(0), res(1)))

val oldestIndex = Mux(valid(0) && valid(1),
Mux(isAfter(bits(0).uop.robIdx, bits(1).uop.robIdx) ||
(bits(0).uop.robIdx === bits(1).uop.robIdx && bits(0).uop.uopIdx > bits(1).uop.uopIdx), resIndex(1), resIndex(0)),
Mux(valid(0) && !valid(1), resIndex(0), resIndex(1)))
(Seq(oldest.valid), Seq(oldest.bits), Seq(oldestIndex))
} else {
val left = selectOldest(valid.take(valid.length / 2), bits.take(bits.length / 2), index.take(index.length / 2))
val right = selectOldest(valid.takeRight(valid.length - (valid.length / 2)), bits.takeRight(bits.length - (bits.length / 2)), index.takeRight(index.length - (index.length / 2)))
selectOldest(left._1 ++ right._1, left._2 ++ right._2, left._3 ++ right._3)
}
}

val io = IO(new Bundle() {
val redirect = Flipped(Valid(new Redirect))
val req = Vec(enqPortNum, Flipped(Decoupled(new LsPipelineBundle)))
Expand Down Expand Up @@ -135,11 +106,30 @@ class StoreMisalignBuffer(implicit p: Parameters) extends XSModule
val s1_valid = VecInit(io.req.map(x => x.valid))

val s1_index = (0 until io.req.length).map(_.asUInt)
val reqSel = selectOldest(s1_valid, s1_req, s1_index)
val reqSel = ParallelOperation(s1_valid zip s1_req zip s1_index,
(a: ((Bool, LsPipelineBundle), UInt), b: ((Bool, LsPipelineBundle), UInt)) => {
val au = a._1._2.uop
val bu = b._1._2.uop
val aValid = a._1._1
val bValid = b._1._1
val bSel = au.robIdx > bu.robIdx || au.robIdx === bu.robIdx && au.uopIdx > bu.uopIdx
val bits = Mux(
aValid && bValid,
Mux(bSel, b._1._2, a._1._2),
Mux(aValid && !bValid, a._1._2, b._1._2)
)
val idx = Mux(
aValid && bValid,
Mux(bSel, b._2, a._2),
Mux(aValid && !bValid, a._2, b._2)
)
((aValid || bValid, bits), idx)
}
)

val reqSelValid = reqSel._1(0)
val reqSelBits = reqSel._2(0)
val reqSelPort = reqSel._3(0)
val reqSelValid = reqSel._1._1
val reqSelBits = reqSel._1._2
val reqSelPort = reqSel._2

val reqRedirect = reqSelBits.uop.robIdx.needFlush(io.redirect)

Expand Down Expand Up @@ -168,7 +158,6 @@ class StoreMisalignBuffer(implicit p: Parameters) extends XSModule
case (reqPort, index) => reqPort.ready := reqSelCanEnq(index) && (!req_valid || cross4KBPageBoundary && cross4KBPageEnq)
}


io.toVecStoreMergeBuffer.zipWithIndex.map{
case (toStMB, index) => {
toStMB.flush := req_valid && cross4KBPageBoundary && cross4KBPageEnq && UIntToOH(req.portIndex)(index)
Expand Down
Loading

0 comments on commit 9f5a6f0

Please sign in to comment.