Skip to content

Commit

Permalink
sms: evict agt entry when dcache refill (#2437)
Browse files Browse the repository at this point in the history
* sms: evict agt entry when dcache refill

* fix compile

* sms: evict on any region match
  • Loading branch information
happy-lx authored Oct 31, 2023
1 parent 88e7a6d commit 6005a7e
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/main/scala/xiangshan/backend/MemBlock.scala
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ class MemBlockImp(outer: MemBlock) extends LazyModuleImp(outer)
val stData = stdExeUnits.map(_.io.out)
val exeUnits = loadUnits ++ storeUnits
val l1_pf_req = Wire(Decoupled(new L1PrefetchReq()))
dcache.io.sms_agt_evict_req.ready := false.B
val prefetcherOpt: Option[BasePrefecher] = coreParams.prefetcher.map {
case _: SMSParams =>
val sms = Module(new SMSPrefetcher())
Expand All @@ -262,6 +263,7 @@ class MemBlockImp(outer: MemBlock) extends LazyModuleImp(outer)
sms.io_act_threshold := RegNextN(io.ooo_to_mem.csrCtrl.l1D_pf_active_threshold, 2, Some(12.U))
sms.io_act_stride := RegNextN(io.ooo_to_mem.csrCtrl.l1D_pf_active_stride, 2, Some(30.U))
sms.io_stride_en := false.B
sms.io_dcache_evict <> dcache.io.sms_agt_evict_req
sms
}
prefetcherOpt.foreach{ pf => pf.io.l1_req.ready := false.B }
Expand Down
2 changes: 2 additions & 0 deletions src/main/scala/xiangshan/cache/dcache/DCacheWrapper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ class DCacheIO(implicit p: Parameters) extends DCacheBundle {
val lqEmpty = Input(Bool())
val pf_ctrl = Output(new PrefetchControlBundle)
val force_write = Input(Bool())
val sms_agt_evict_req = DecoupledIO(new AGTEvictReq)
val debugTopDown = new DCacheTopDownIO
val debugRolling = Flipped(new RobDebugRollingIO)
}
Expand Down Expand Up @@ -847,6 +848,7 @@ class DCacheImp(outer: DCache) extends LazyModuleImp(outer) with HasDCacheParame
missQueue.io.hartId := io.hartId
missQueue.io.l2_pf_store_only := RegNext(io.l2_pf_store_only, false.B)
missQueue.io.debugTopDown <> io.debugTopDown
missQueue.io.sms_agt_evict_req <> io.sms_agt_evict_req
io.memSetPattenDetected := missQueue.io.memSetPattenDetected

val errors = ldu.map(_.io.error) ++ // load error
Expand Down
24 changes: 24 additions & 0 deletions src/main/scala/xiangshan/cache/dcache/mainpipe/MissQueue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,8 @@ class MissEntry(edge: TLEdgeOut)(implicit p: Parameters) extends DCacheModule {
val forwardInfo = Output(new MissEntryForwardIO)
val l2_pf_store_only = Input(Bool())

val sms_agt_evict_req = ValidIO(new AGTEvictReq)

// whether the pipeline reg has send out an acquire
val acquire_fired_by_pipe_reg = Input(Bool())
val memSetPattenDetected = Input(Bool())
Expand Down Expand Up @@ -367,6 +369,8 @@ class MissEntry(edge: TLEdgeOut)(implicit p: Parameters) extends DCacheModule {
val should_refill_data_reg = Reg(Bool())
val should_refill_data = WireInit(should_refill_data_reg)

val should_replace = RegInit(false.B)

// val full_overwrite = req.isFromStore && req_store_mask.andR
val full_overwrite = Reg(Bool())

Expand Down Expand Up @@ -431,6 +435,9 @@ class MissEntry(edge: TLEdgeOut)(implicit p: Parameters) extends DCacheModule {
when (!miss_req_pipe_reg_bits.hit && miss_req_pipe_reg_bits.replace_coh.isValid() && !miss_req_pipe_reg_bits.isFromAMO) {
s_replace_req := false.B
w_replace_resp := false.B
should_replace := true.B
}.otherwise {
should_replace := false.B
}

when (miss_req_pipe_reg_bits.isFromAMO) {
Expand Down Expand Up @@ -717,6 +724,9 @@ class MissEntry(edge: TLEdgeOut)(implicit p: Parameters) extends DCacheModule {
refill.alias := req.vaddr(13, 12) // TODO
assert(!io.refill_pipe_req.valid || (refill.meta.coh =/= ClientMetadata(Nothing)), "refill modifies meta to Nothing, should not happen")

// Evict hint towards the SMS prefetcher: asserted in the cycle the refill request fires,
// and only when this miss was recorded as replacing a line (should_replace) and the
// entry currently holds a valid request.
io.sms_agt_evict_req.valid := io.refill_pipe_req.fire && should_replace && req_valid
// The hint address is assembled from three pieces, VAddrBits wide in total:
//   - the upper bits of the replaced line's tag (tag bits tagBits-1 .. 2),
//   - bits 13:12 of the requesting virtual address (the same bits used as refill.alias),
//   - zeros for the remaining low (VAddrBits - tagBits) bits.
// NOTE(review): replace_tag looks like a physical tag, so this is presumably only an
// approximation of the evicted line's virtual address - confirm against the AGT tag hash.
io.sms_agt_evict_req.bits.vaddr := Cat(req.replace_tag(tagBits - 1, 2), req.vaddr(13, 12), 0.U((VAddrBits - tagBits).W))

io.main_pipe_req.valid := !s_mainpipe_req && w_grantlast
io.main_pipe_req.bits := DontCare
io.main_pipe_req.bits.miss := true.B
Expand Down Expand Up @@ -826,6 +836,8 @@ class MissQueue(edge: TLEdgeOut)(implicit p: Parameters) extends DCacheModule wi
val tag = UInt(tagBits.W) // paddr
}))

val sms_agt_evict_req = DecoupledIO(new AGTEvictReq)

// forward missqueue
val forward = Vec(LoadPipelineWidth, new LduToMissqueueForwardIO)
val l2_pf_store_only = Input(Bool())
Expand Down Expand Up @@ -1024,6 +1036,18 @@ class MissQueue(edge: TLEdgeOut)(implicit p: Parameters) extends DCacheModule wi
io.main_pipe_req.bits := Mux1H(main_pipe_req_vec.map(_.valid), main_pipe_req_vec.map(_.bits))
assert(PopCount(VecInit(main_pipe_req_vec.map(_.valid))) <= 1.U, "multi main pipe req")

// send evict hint to sms
// The hint is buffered in a one-deep register stage: a miss entry raises a single-cycle
// valid, which is latched here and held until the consumer accepts it.
// High when any miss entry raises an evict hint this cycle (at most one may, see assert below).
val sms_agt_evict_valid = Cat(entries.map(_.io.sms_agt_evict_req.valid)).orR
// Outgoing valid: set by a new hint, cleared once the buffered hint fires.
val sms_agt_evict_valid_reg = RegInit(false.B)
io.sms_agt_evict_req.valid := sms_agt_evict_valid_reg
// Payload is captured whenever an entry raises a hint, so a newer hint overwrites one
// that is still pending (the older hint is dropped).
io.sms_agt_evict_req.bits := RegEnable(Mux1H(entries.map(_.io.sms_agt_evict_req.valid), entries.map(_.io.sms_agt_evict_req.bits)), sms_agt_evict_valid)
// Set has priority over clear: if a new hint arrives in the same cycle the buffered one
// fires, valid stays high for the newly captured payload.
when(sms_agt_evict_valid) {
sms_agt_evict_valid_reg := true.B
}.elsewhen(io.sms_agt_evict_req.fire) {
sms_agt_evict_valid_reg := false.B
}
// Mux1H above is only meaningful with a one-hot select, hence this check.
assert(PopCount(VecInit(entries.map(_.io.sms_agt_evict_req.valid))) <= 1.U, "multi sms_agt_evict req")

io.probe_block := Cat(probe_block_vec).orR

io.full := ~Cat(entries.map(_.io.primary_ready)).andR
Expand Down
36 changes: 32 additions & 4 deletions src/main/scala/xiangshan/mem/prefetch/SMSPrefetcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,10 @@ class PfGenReq()(implicit p: Parameters) extends XSBundle with HasSMSModuleHelpe
val debug_source_type = UInt(log2Up(nSourceType).W)
}

/** Evict request sent from the dcache side to the SMS active generation table (AGT).
  *
  * Carries a single address; the AGT derives a region tag from it and evicts any
  * entry whose tag matches.
  */
class AGTEvictReq()(implicit p: Parameters) extends XSBundle {
// Address of the block being evicted, VAddrBits wide.
val vaddr = UInt(VAddrBits.W)
}

class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasSMSModuleHelper {
val io = IO(new Bundle() {
val agt_en = Input(Bool())
Expand All @@ -267,6 +271,8 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
val region_paddr = UInt(REGION_ADDR_BITS.W)
val region_vaddr = UInt(REGION_ADDR_BITS.W)
}))
// dcache has released a block, evict it from agt
val s0_dcache_evict = Flipped(DecoupledIO(new AGTEvictReq))
val s1_sel_stride = Output(Bool())
val s2_stride_hit = Input(Bool())
// if agt/stride missed, try lookup pht
Expand All @@ -287,6 +293,10 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
val s0_lookup = io.s0_lookup.bits
val s0_lookup_valid = io.s0_lookup.valid

val s0_dcache_evict = io.s0_dcache_evict.bits
val s0_dcache_evict_valid = io.s0_dcache_evict.valid
val s0_dcache_evict_tag = block_hash_tag(s0_dcache_evict.vaddr).head(REGION_TAG_WIDTH)

val prev_lookup = RegEnable(s0_lookup, s0_lookup_valid)
val prev_lookup_valid = RegNext(s0_lookup_valid, false.B)

Expand All @@ -306,6 +316,14 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
val any_region_p1_match = Cat(region_p1_match_vec_s0).orR && s0_lookup.allow_cross_region_p1
val any_region_m1_match = Cat(region_m1_match_vec_s0).orR && s0_lookup.allow_cross_region_m1

// Per-entry match of the dcache-evict region tag against the table
// (presumably one bit per AGT entry - confirm against gen_match_vec).
val region_match_vec_dcache_evict_s0 = gen_match_vec(s0_dcache_evict_tag)
val any_region_dcache_evict_match = Cat(region_match_vec_dcache_evict_s0).orR
// s0: the dcache evict targets an entry that is already being replaced in s1
val s0_dcache_evict_conflict = Cat(VecInit(region_match_vec_dcache_evict_s0).asUInt & s1_replace_mask_w).orR
// The evict only takes effect when the request handshakes AND some entry matches;
// a request that fires without a match is consumed and ignored.
val s0_do_dcache_evict = io.s0_dcache_evict.fire && any_region_dcache_evict_match

// Training lookups have priority: an evict request is accepted only in a cycle with no
// s0 lookup and no overlap with the entry selected for replacement in s1.
io.s0_dcache_evict.ready := !s0_lookup_valid && !s0_dcache_evict_conflict

val s0_region_hit = any_region_match
val s0_cross_region_hit = any_region_m1_match || any_region_p1_match
val s0_alloc = s0_lookup_valid && !s0_region_hit && !s0_match_prev
Expand Down Expand Up @@ -350,8 +368,13 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
val s1_cross_region_match = RegNext(s0_lookup_valid && s0_cross_region_hit, false.B)
val s1_alloc = RegNext(s0_alloc, false.B)
val s1_alloc_entry = s1_agt_entry
val s1_replace_mask = RegEnable(s0_replace_mask, s0_lookup_valid)
s1_replace_mask_w := s1_replace_mask & Fill(smsParams.active_gen_table_size, s1_alloc)
// Stage-1 copy of the "dcache evict accepted and matched" decision from stage 0.
val s1_do_dcache_evict = RegNext(s0_do_dcache_evict, false.B)
// Select which entry is evicted in s1: the entry matched by the dcache evict when one is
// in flight, otherwise the replacement mask registered from the s0 lookup.
val s1_replace_mask = Mux(
s1_do_dcache_evict,
RegEnable(VecInit(region_match_vec_dcache_evict_s0).asUInt, s0_do_dcache_evict),
RegEnable(s0_replace_mask, s0_lookup_valid)
)
// The mask exposed to stage 0 is zeroed unless s1 is really allocating or evicting.
s1_replace_mask_w := s1_replace_mask & Fill(smsParams.active_gen_table_size, s1_alloc || s1_do_dcache_evict)
// NOTE(review): Mux1H expects a one-hot select; confirm the dcache-evict match vector
// can have at most one bit set, otherwise the selected entry/valid are OR-ed together.
val s1_evict_entry = Mux1H(s1_replace_mask, entries)
val s1_evict_valid = Mux1H(s1_replace_mask, valids)
// pf gen
Expand Down Expand Up @@ -446,8 +469,9 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
io.s1_sel_stride := prev_lookup_valid && (s1_alloc && s1_cross_region_match || s1_update) && !s1_in_active_page

// stage2: gen pf reg / evict entry to pht
val s2_evict_entry = RegEnable(s1_evict_entry, s1_alloc)
val s2_evict_valid = RegNext(s1_alloc && s1_evict_valid, false.B)
val s2_do_dcache_evict = RegNext(s1_do_dcache_evict, false.B)
val s2_evict_entry = RegEnable(s1_evict_entry, s1_alloc || s1_do_dcache_evict)
val s2_evict_valid = RegNext((s1_alloc || s1_do_dcache_evict) && s1_evict_valid, false.B)
val s2_paddr_valid = RegEnable(s1_pf_gen_paddr_valid, s1_pf_gen_valid)
val s2_pf_gen_region_tag = RegEnable(s1_pf_gen_region_tag, s1_pf_gen_valid)
val s2_pf_gen_decr_mode = RegEnable(s1_pf_gen_decr_mode, s1_pf_gen_valid)
Expand Down Expand Up @@ -489,6 +513,8 @@ class ActiveGenerationTable()(implicit p: Parameters) extends XSModule with HasS
)
}
XSPerfAccumulate("sms_agt_evict", s2_evict_valid)
XSPerfAccumulate("sms_agt_evict_by_plru", s2_evict_valid && !s2_do_dcache_evict)
XSPerfAccumulate("sms_agt_evict_by_dcache", s2_evict_valid && s2_do_dcache_evict)
XSPerfAccumulate("sms_agt_evict_one_hot_pattern", s2_evict_valid && (s2_evict_entry.access_cnt === 1.U))
}

Expand Down Expand Up @@ -1066,6 +1092,7 @@ class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSM
val io_pht_en = IO(Input(Bool()))
val io_act_threshold = IO(Input(UInt(REGION_OFFSET.W)))
val io_act_stride = IO(Input(UInt(6.W)))
val io_dcache_evict = IO(Flipped(DecoupledIO(new AGTEvictReq)))

val train_filter = Module(new SMSTrainFilter)

Expand Down Expand Up @@ -1135,6 +1162,7 @@ class SMSPrefetcher()(implicit p: Parameters) extends BasePrefecher with HasSMSM
active_gen_table.io.s0_lookup.bits.region_paddr := train_region_paddr_s0
active_gen_table.io.s0_lookup.bits.region_vaddr := train_region_vaddr_s0
active_gen_table.io.s2_stride_hit := stride.io.s2_gen_req.valid
active_gen_table.io.s0_dcache_evict <> io_dcache_evict

stride.io.stride_en := io_stride_en
stride.io.s0_lookup.valid := train_vld_s0
Expand Down

0 comments on commit 6005a7e

Please sign in to comment.