diff --git a/src/main/scala/xiangshan/frontend/BPU.scala b/src/main/scala/xiangshan/frontend/BPU.scala index a86f8c90e5d..584c2fe055a 100644 --- a/src/main/scala/xiangshan/frontend/BPU.scala +++ b/src/main/scala/xiangshan/frontend/BPU.scala @@ -258,7 +258,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H for (i <- 0 until numOfStage - 1) { topdown_stages(i + 1) := topdown_stages(i) } - + // ctrl signal @@ -392,7 +392,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H for (((((s1_fire, s2_flush), s2_fire), s2_valid), s1_flush) <- s1_fire_dup zip s2_flush_dup zip s2_fire_dup zip s2_valid_dup zip s1_flush_dup) { - + when (s2_flush) { s2_valid := false.B } .elsewhen(s1_fire) { s2_valid := !s1_flush } .elsewhen(s2_fire) { s2_valid := false.B } @@ -405,7 +405,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H for (((((s2_fire, s3_flush), s3_fire), s3_valid), s2_flush) <- s2_fire_dup zip s3_flush_dup zip s3_fire_dup zip s3_valid_dup zip s2_flush_dup) { - + when (s3_flush) { s3_valid := false.B } .elsewhen(s2_fire) { s3_valid := !s2_flush } .elsewhen(s3_fire) { s3_valid := false.B } @@ -445,7 +445,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H } .otherwise { full_pred_diff_stage := 3.U } - } + } } XSError(full_pred_diff, "Full prediction difference detected!") @@ -471,13 +471,13 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H // s1 val s1_possible_predicted_ghist_ptrs_dup = s1_ghist_ptr_dup.map(ptr => (0 to numBr).map(ptr - _.U)) val s1_predicted_ghist_ptr_dup = s1_possible_predicted_ghist_ptrs_dup.zip(resp.s1.lastBrPosOH).map{ case (ptr, oh) => Mux1H(oh, ptr)} - val s1_possible_predicted_fhs_dup = + val s1_possible_predicted_fhs_dup = for (((((fgh, afh), br_num_oh), t), br_pos_oh) <- s1_folded_gh_dup zip s1_ahead_fh_oldest_bits_dup zip s1_last_br_num_oh_dup zip resp.s1.brTaken zip 
resp.s1.lastBrPosOH) yield (0 to numBr).map(i => fgh.update(afh, br_num_oh, i, t & br_pos_oh(i)) ) - val s1_predicted_fh_dup = resp.s1.lastBrPosOH.zip(s1_possible_predicted_fhs_dup).map{ case (oh, fh) => Mux1H(oh, fh)} + val s1_predicted_fh_dup = resp.s1.lastBrPosOH.zip(s1_possible_predicted_fhs_dup).map{ case (oh, fh) => Mux1H(oh, fh)} val s1_ahead_fh_ob_src_dup = dup_wire(new AllAheadFoldedHistoryOldestBits(foldedGHistInfos)) s1_ahead_fh_ob_src_dup.zip(s1_ghist_ptr_dup).map{ case (src, ptr) => src.read(ghv, ptr)} @@ -521,21 +521,35 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H } class PreviousPredInfo extends Bundle { + val hit = Vec(numDup, Bool()) val target = Vec(numDup, UInt(VAddrBits.W)) val lastBrPosOH = Vec(numDup, Vec(numBr+1, Bool())) val taken = Vec(numDup, Bool()) + val takenMask = Vec(numDup, Vec(numBr, Bool())) val cfiIndex = Vec(numDup, UInt(log2Ceil(PredictWidth).W)) } def preds_needs_redirect_vec_dup(x: PreviousPredInfo, y: BranchPredictionBundle) = { - val target_diff = x.target.zip(y.getTarget).map {case (t1, t2) => t1 =/= t2 } - val lastBrPosOH_diff = x.lastBrPosOH.zip(y.lastBrPosOH).map {case (oh1, oh2) => oh1.asUInt =/= oh2.asUInt} - val taken_diff = x.taken.zip(y.taken).map {case (t1, t2) => t1 =/= t2} - val takenOffset_diff = x.cfiIndex.zip(y.cfiIndex).zip(x.taken).zip(y.taken).map {case (((i1, i2), xt), yt) => xt && yt && i1 =/= i2.bits} + // Timing optimization + // We first compare all target with previous stage target, + // then select the difference by taken & hit + // Usually target is generated quicker than taken, so do target compare before select can help timing + val targetDiffVec: IndexedSeq[Vec[Bool]] = + x.target.zip(y.getAllTargets).map { + case (t1, t2) => VecInit(t2.map(_ =/= t1)) + } // [0:numDup][flattened all Target comparison] + val targetDiff : IndexedSeq[Bool] = + targetDiffVec.zip(x.hit).zip(x.takenMask).map { + case ((diff, hit), takenMask) => selectByTaken(takenMask, hit, diff) 
+ } + + val lastBrPosOHDiff: IndexedSeq[Bool] = x.lastBrPosOH.zip(y.lastBrPosOH).map { case (oh1, oh2) => oh1.asUInt =/= oh2.asUInt } + val takenDiff : IndexedSeq[Bool] = x.taken.zip(y.taken).map { case (t1, t2) => t1 =/= t2 } + val takenOffsetDiff: IndexedSeq[Bool] = x.cfiIndex.zip(y.cfiIndex).zip(x.taken).zip(y.taken).map { case (((i1, i2), xt), yt) => xt && yt && i1 =/= i2.bits } VecInit( for ((((tgtd, lbpohd), tkd), tod) <- - target_diff zip lastBrPosOH_diff zip taken_diff zip takenOffset_diff) - yield VecInit(tgtd, lbpohd, tkd, tod) + targetDiff zip lastBrPosOHDiff zip takenDiff zip takenOffsetDiff) + yield VecInit(tgtd, lbpohd, tkd, tod) // x.shouldShiftVec.asUInt =/= y.shouldShiftVec.asUInt, // x.brTaken =/= y.brTaken ) @@ -545,13 +559,13 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H val s2_possible_predicted_ghist_ptrs_dup = s2_ghist_ptr_dup.map(ptr => (0 to numBr).map(ptr - _.U)) val s2_predicted_ghist_ptr_dup = s2_possible_predicted_ghist_ptrs_dup.zip(resp.s2.lastBrPosOH).map{ case (ptr, oh) => Mux1H(oh, ptr)} - val s2_possible_predicted_fhs_dup = + val s2_possible_predicted_fhs_dup = for ((((fgh, afh), br_num_oh), full_pred) <- s2_folded_gh_dup zip s2_ahead_fh_oldest_bits_dup zip s2_last_br_num_oh_dup zip resp.s2.full_pred) yield (0 to numBr).map(i => fgh.update(afh, br_num_oh, i, if (i > 0) full_pred.br_taken_mask(i-1) else false.B) ) - val s2_predicted_fh_dup = resp.s2.lastBrPosOH.zip(s2_possible_predicted_fhs_dup).map{ case (oh, fh) => Mux1H(oh, fh)} + val s2_predicted_fh_dup = resp.s2.lastBrPosOH.zip(s2_possible_predicted_fhs_dup).map{ case (oh, fh) => Mux1H(oh, fh)} val s2_ahead_fh_ob_src_dup = dup_wire(new AllAheadFoldedHistoryOldestBits(foldedGHistInfos)) s2_ahead_fh_ob_src_dup.zip(s2_ghist_ptr_dup).map{ case (src, ptr) => src.read(ghv, ptr)} @@ -580,10 +594,12 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H ) val s1_pred_info = Wire(new PreviousPredInfo) + s1_pred_info.hit 
:= resp.s1.full_pred.map(_.hit) s1_pred_info.target := resp.s1.getTarget s1_pred_info.lastBrPosOH := resp.s1.lastBrPosOH s1_pred_info.taken := resp.s1.taken - s1_pred_info.cfiIndex := resp.s1.cfiIndex.map{case x => x.bits} + s1_pred_info.takenMask := resp.s1.full_pred.map(_.taken_mask_on_slot) + s1_pred_info.cfiIndex := resp.s1.cfiIndex.map { case x => x.bits } val previous_s1_pred_info = RegEnable(s1_pred_info, 0.U.asTypeOf(new PreviousPredInfo), s1_fire_dup(0)) @@ -630,7 +646,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H yield (0 to numBr).map(i => fgh.update(afh, br_num_oh, i, if (i > 0) full_pred.br_taken_mask(i-1) else false.B) ) - val s3_predicted_fh_dup = resp.s3.lastBrPosOH.zip(s3_possible_predicted_fhs_dup).map{ case (oh, fh) => Mux1H(oh, fh)} + val s3_predicted_fh_dup = resp.s3.lastBrPosOH.zip(s3_possible_predicted_fhs_dup).map{ case (oh, fh) => Mux1H(oh, fh)} val s3_ahead_fh_ob_src_dup = dup_wire(new AllAheadFoldedHistoryOldestBits(foldedGHistInfos)) s3_ahead_fh_ob_src_dup.zip(s3_ghist_ptr_dup).map{ case (src, ptr) => src.read(ghv, ptr)} @@ -714,7 +730,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H predictors.io.update := RegNext(io.ftq_to_bpu.update) predictors.io.update.bits.ghist := RegNext(getHist(io.ftq_to_bpu.update.bits.spec_info.histPtr)) - + val redirect_dup = do_redirect_dup.map(_.bits) predictors.io.redirect := do_redirect_dup(0) @@ -737,7 +753,7 @@ class Predictor(implicit p: Parameters) extends XSModule with HasBPUConst with H val oldPtr_dup = redirect_dup.map(_.cfiUpdate.histPtr) val oldFh_dup = redirect_dup.map(_.cfiUpdate.folded_hist) val updated_ptr_dup = oldPtr_dup.zip(shift_dup).map {case (oldPtr, shift) => oldPtr - shift} - val updated_fh_dup = + val updated_fh_dup = for ((((((oldFh, afhob), lastBrNumOH), taken), addIntoHist), shift) <- oldFh_dup zip afhob_dup zip lastBrNumOH_dup zip taken_dup zip addIntoHist_dup zip shift_dup) yield VecInit((0 to 
numBr).map(i => oldFh.update(afhob, lastBrNumOH, i, taken && addIntoHist)))(shift) diff --git a/src/main/scala/xiangshan/frontend/FrontendBundle.scala b/src/main/scala/xiangshan/frontend/FrontendBundle.scala index b39b1b0f234..8f25914d44d 100644 --- a/src/main/scala/xiangshan/frontend/FrontendBundle.scala +++ b/src/main/scala/xiangshan/frontend/FrontendBundle.scala @@ -412,6 +412,17 @@ trait BasicPrediction extends HasXSParameter { def fallThruError: Bool } +// selectByTaken selects some data according to takenMask +// allTargets should be in flattened 2-dim Vec, like [taken, not taken, not hit, taken, ...] +object selectByTaken { + def apply[T <: Data](takenMask: Vec[Bool], hit: Bool, allTargets: Vec[T]): T = { + val selVecOH = + takenMask.zipWithIndex.map { case (t, i) => !takenMask.take(i).fold(false.B)(_ || _) && t && hit } :+ + (!takenMask.asUInt.orR && hit) :+ !hit + Mux1H(selVecOH, allTargets) + } +} + class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUConst with BasicPrediction { val br_taken_mask = Vec(numBr, Bool()) @@ -455,7 +466,7 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC def real_slot_taken_mask(): Vec[Bool] = { VecInit(taken_mask_on_slot.map(_ && hit)) } - + // len numBr def real_br_taken_mask(): Vec[Bool] = { VecInit( @@ -482,12 +493,17 @@ class FullBranchPrediction(implicit p: Parameters) extends XSBundle with HasBPUC def brTaken = (br_valids zip br_taken_mask).map{ case (a, b) => a && b && hit}.reduce(_||_) def target(pc: UInt): UInt = { - val targetVec = targets :+ fallThroughAddr :+ (pc + (FetchWidth * 4).U) - val tm = taken_mask_on_slot - val selVecOH = - tm.zipWithIndex.map{ case (t, i) => !tm.take(i).fold(false.B)(_||_) && t && hit} :+ - (!tm.asUInt.orR && hit) :+ !hit - Mux1H(selVecOH, targetVec) + selectByTaken(taken_mask_on_slot, hit, allTarget(pc)) + } + + // allTarget return a flattened 2-dim Vec of all possible target of a BP stage + // in the following order: 
[0:totalSlot][taken_targets, fallThroughAddr, not hit (plus fetch width)] + // after flatten looks like [t0, f0, n0, t1, f1, n1, ...] (t, f, n stand for taken, fallthrough, not hit) + // + // This exposes internal targets for timing optimization, + // since usually targets are generated quicker than taken + def allTarget(pc: UInt): Vec[UInt] = { + VecInit(targets :+ fallThroughAddr :+ (pc + (FetchWidth * 4).U)) } def fallThruError: Bool = hit && fallThroughErr @@ -564,7 +580,8 @@ class BranchPredictionBundle(implicit p: Parameters) extends XSBundle def target(pc: UInt) = VecInit(full_pred.map(_.target(pc))) - def targets(pc: Vec[UInt]) = VecInit(pc.zipWithIndex.map{case (a, i) => full_pred(i).target(a)}) + def targets(pc: Vec[UInt]) = VecInit(pc.zipWithIndex.map{case (pc, idx) => full_pred(idx).target(pc)}) + def allTargets(pc: Vec[UInt]) = VecInit(pc.zipWithIndex.map{case (pc, idx) => full_pred(idx).allTarget(pc)}) def cfiIndex = VecInit(full_pred.map(_.cfiIndex)) def lastBrPosOH = VecInit(full_pred.map(_.lastBrPosOH)) def brTaken = VecInit(full_pred.map(_.brTaken)) @@ -574,6 +591,7 @@ class BranchPredictionBundle(implicit p: Parameters) extends XSBundle def taken = VecInit(cfiIndex.map(_.valid)) def getTarget = targets(pc) + def getAllTargets = allTargets(pc) def display(cond: Bool): Unit = { XSDebug(cond, p"[pc] ${Hexadecimal(pc(0))}\n")