diff --git a/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py b/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py index 440587a5a6..142f617ee7 100644 --- a/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py +++ b/mlc_llm/dispatch/gpt_neox/dolly_v2_3b.py @@ -806,7 +806,7 @@ def fused_layer_norm1_cast6(sch: tir.Schedule): sch.annotate(block_or_loop=b3, ann_key="meta_schedule.unroll_explicit", ann_val=v10) l11, l12, l13 = sch.get_loops(block=b1) l14 = sch.fuse(l11, l12, l13, preserve_unit_iters=True) - l15, l16, l17 = sch.split(loop=l14, factors=[None, 256, 1024], preserve_unit_iters=True) + l15, l16, l17 = sch.split(loop=l14, factors=[None, 256, 256], preserve_unit_iters=True) sch.reorder(l16, l17, l15) sch.bind(loop=l16, thread_axis="blockIdx.x") sch.bind(loop=l17, thread_axis="threadIdx.x")