diff --git a/mlc_llm/dispatch/llama/main.py b/mlc_llm/dispatch/llama/main.py
index 6a993474d8..654e258743 100644
--- a/mlc_llm/dispatch/llama/main.py
+++ b/mlc_llm/dispatch/llama/main.py
@@ -855,11 +855,11 @@ def matmul1_before(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, m
     for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), n):
         with T.block("matmul"):
             v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
-            T.reads(rxplaceholder[T.int64(0), v_i1, v_i2, v_k], rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3])
+            T.reads(rxplaceholder[v_i0, v_i1, v_i2, v_k], rxplaceholder_1[v_i0, v_i1, v_k, v_i3])
             T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
             with T.init():
                 matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
-            matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[T.int64(0), v_i1, v_i2, v_k] * rxplaceholder_1[T.int64(0), v_i1, v_k, v_i3]
+            matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + rxplaceholder[v_i0, v_i1, v_i2, v_k] * rxplaceholder_1[v_i0, v_i1, v_k, v_i3]
 
 
 @T.prim_func
@@ -2827,7 +2827,142 @@ def fused_NT_matmul_divide_maximum_minimum_cast_sch_func():
     return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
+@T.prim_func
+def fused_NT_matmul_divide_maximum_minimum_before(lv1540: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32"), p_lv1541: T.handle, p_lv1517: T.handle, p_output0: T.handle):
+    T.func_attr({"tir.noalias": T.bool(True)})
+    n = T.int64()
+    lv1541 = T.match_buffer(p_lv1541, (T.int64(1), T.int64(32), n, T.int64(128)))
+    lv1517 = T.match_buffer(p_lv1517, (T.int64(1), T.int64(1), T.int64(1), n))
+    var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), T.int64(1), n))
+    # with T.block("root"):
+    var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
+    var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
+    var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1), n))
+    for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), n, T.int64(128)):
+        with T.block("NT_matmul"):
+            v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+            T.reads(lv1540[v_i0, v_i1, v_i2, v_k], lv1541[v_i0, v_i1, v_i3, v_k])
+            T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
+            with T.init():
+                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
+            var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv1540[v_i0, v_i1, v_i2, v_k] * lv1541[v_i0, v_i1, v_i3, v_k]
+    for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
+        with T.block("T_divide"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605)
+    for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
+        with T.block("T_maximum"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38))
+    for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), T.int64(1), n):
+        with T.block("T_minimum"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1517[v_ax0, T.int64(0), v_ax2, v_ax3])
+            T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv1517[v_ax0, T.int64(0), v_ax2, v_ax3])
+def fused_NT_matmul_divide_maximum_minimum_sch_func():
+    sch = tvm.tir.Schedule(fused_NT_matmul_divide_maximum_minimum_before)
+    b0 = sch.get_block("NT_matmul")
+    sch.pad_einsum(b0, [1, 1, 1, 32, 1])
+    l1, l2, l3, l4, l5 = sch.get_loops(b0)
+    l6, l7 = sch.split(l4, [None, 32])
+    sch.reorder(l6, l1, l2, l3, l7, l5)
+
+    b0 = sch.get_block(name="NT_matmul", func_name="main")
+    b1 = sch.get_block(name="T_divide", func_name="main")
+    b2 = sch.get_block(name="T_maximum", func_name="main")
+    b3 = sch.get_block(name="T_minimum", func_name="main")
+    b4 = sch.get_block(name="root", func_name="main")
+    sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SSSRRSRS")
+    _, l5, l6, l7, l8, l9 = sch.get_loops(block=b0)
+    v10, v11, v12, v13, v14 = sch.sample_perfect_tile(loop=l5, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1])
+    l15, l16, l17, l18, l19 = sch.split(loop=l5, factors=[v10, v11, v12, v13, v14], preserve_unit_iters=True)
+    v20, v21, v22, v23, v24 = sch.sample_perfect_tile(loop=l6, n=5, max_innermost_factor=64, decision=[8, 1, 4, 1, 1])
+    l25, l26, l27, l28, l29 = sch.split(loop=l6, factors=[v20, v21, v22, v23, v24], preserve_unit_iters=True)
+    v30, v31, v32, v33, v34 = sch.sample_perfect_tile(loop=l7, n=5, max_innermost_factor=64, decision=[1, 1, 1, 1, 1])
+    l35, l36, l37, l38, l39 = sch.split(loop=l7, factors=[v30, v31, v32, v33, v34], preserve_unit_iters=True)
+    v40, v41, v42, v43, v44 = sch.sample_perfect_tile(loop=l8, n=5, max_innermost_factor=64, decision=[2, 1, 16, 1, 1])
+    l45, l46, l47, l48, l49 = sch.split(loop=l8, factors=[v40, v41, v42, v43, v44], preserve_unit_iters=True)
+    v50, v51, v52 = sch.sample_perfect_tile(loop=l9, n=3, max_innermost_factor=64, decision=[4, 4, 8])
+    l53, l54, l55 = sch.split(loop=l9, factors=[v50, v51, v52], preserve_unit_iters=True)
+    sch.reorder(l15, l25, l35, l45, l16, l26, l36, l46, l17, l27, l37, l47, l53, l54, l18, l28, l38, l48, l55, l19, l29, l39, l49)
+    l56 = sch.fuse(l15, l25, l35, l45, preserve_unit_iters=True)
+    sch.bind(loop=l56, thread_axis="blockIdx.x")
+    l57 = sch.fuse(l16, l26, l36, l46, preserve_unit_iters=True)
+    sch.bind(loop=l57, thread_axis="vthread.x")
+    l58 = sch.fuse(l17, l27, l37, l47, preserve_unit_iters=True)
+    sch.bind(loop=l58, thread_axis="threadIdx.x")
+    sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_low_inclusive", ann_val=32)
+    sch.annotate(block_or_loop=b0, ann_key="meta_schedule.thread_extent_high_inclusive", ann_val=256)
+    b59 = sch.cache_write(block=b0, write_buffer_index=0, storage_scope="local")
+    sch.reverse_compute_at(block=b59, loop=l58, preserve_unit_loops=True, index=-1)
+    b60 = sch.cache_read(block=b0, read_buffer_index=0, storage_scope="shared", consumer_blocks=[b0])
+    sch.compute_at(block=b60, loop=l53, preserve_unit_loops=True, index=-1)
+    _, l61, l62, l63, l64, l65, l66, l67, l68 = sch.get_loops(block=b60)
+    l69 = sch.fuse(l65, l66, l67, l68, preserve_unit_iters=True)
+    v70 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0)
+    sch.annotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch", ann_val=v70)
+    b71 = sch.cache_read(block=b0, read_buffer_index=1, storage_scope="shared", consumer_blocks=[b0])
+    sch.compute_at(block=b71, loop=l53, preserve_unit_loops=True, index=-1)
+    _, l72, l73, l74, l75, l76, l77, l78, l79 = sch.get_loops(block=b71)
+    l80 = sch.fuse(l76, l77, l78, l79, preserve_unit_iters=True)
+    v81 = sch.sample_categorical(candidates=[1, 2, 4, 8], probs=[0.25, 0.25, 0.25, 0.25], decision=0)
+    sch.annotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch", ann_val=v81)
+    sch.reverse_compute_inline(block=b3)
+    sch.compute_inline(block=b1)
+    v82 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=1)
+    sch.annotate(block_or_loop=b4, ann_key="meta_schedule.unroll_explicit", ann_val=v82)
+
+    # inline ewise
+    sch.reverse_compute_inline(b2)
+    # l83, l84, l85, l86 = sch.get_loops(block=b2)
+    # l87 = sch.fuse(l83, l84, l85, l86, preserve_unit_iters=True)
+    # v88 = sch.sample_categorical(candidates=[32, 64, 128, 256], probs=[0.25, 0.25, 0.25, 0.25], decision=0)
+    # l89, l90 = sch.split(loop=l87, factors=[None, v88], preserve_unit_iters=True)
+    # sch.bind(loop=l89, thread_axis="blockIdx.x")
+    # sch.bind(loop=l90, thread_axis="threadIdx.x")
+
+    sch.enter_postproc()
+    sch.unannotate(block_or_loop=b60, ann_key="meta_schedule.cooperative_fetch")
+    _, l91, l92, l93, l94, l95 = sch.get_loops(block=b60)
+    l96, l97 = sch.split(loop=l95, factors=[None, 64], preserve_unit_iters=True)
+    sch.bind(loop=l97, thread_axis="threadIdx.x")
+    sch.unannotate(block_or_loop=b71, ann_key="meta_schedule.cooperative_fetch")
+    _, l98, l99, l100, l101, l102 = sch.get_loops(block=b71)
+    l103, l104 = sch.split(loop=l102, factors=[None, 64], preserve_unit_iters=True)
+    sch.bind(loop=l104, thread_axis="threadIdx.x")
+    b105 = sch.get_block(name="root", func_name="main")
+    sch.unannotate(block_or_loop=b105, ann_key="meta_schedule.unroll_explicit")
+    _, b106, b107, b108, b109, _ = sch.get_child_blocks(b105)
+    _, l111, l112, l113, l114, l115, l116 = sch.get_loops(block=b106)
+    sch.annotate(block_or_loop=l111, ann_key="pragma_auto_unroll_max_step", ann_val=16)
+    sch.annotate(block_or_loop=l111, ann_key="pragma_unroll_explicit", ann_val=1)
+    _, l117, l118, l119, l120, l121, l122 = sch.get_loops(block=b107)
+    sch.annotate(block_or_loop=l117, ann_key="pragma_auto_unroll_max_step", ann_val=16)
+    sch.annotate(block_or_loop=l117, ann_key="pragma_unroll_explicit", ann_val=1)
+    _, l123, l124, l125, l126, l127, l128, l129, l130, l131, l132, l133, l134, l135, l136 = sch.get_loops(block=b108)
+    sch.annotate(block_or_loop=l123, ann_key="pragma_auto_unroll_max_step", ann_val=16)
+    sch.annotate(block_or_loop=l123, ann_key="pragma_unroll_explicit", ann_val=1)
+    _, l137, l138, l139, l140, l141, l142, l143 = sch.get_loops(block=b109)
+    sch.annotate(block_or_loop=l137, ann_key="pragma_auto_unroll_max_step", ann_val=16)
+    sch.annotate(block_or_loop=l137, ann_key="pragma_unroll_explicit", ann_val=1)
+
+    b146 = sch.get_block(name="NT_matmul", func_name="main")
+    l0, l147, l148, l149, l150, l151, l152, l153, l154, l155, l156, l157, l158, l159, l160 = sch.get_loops(block=b146)
+    sch.bind(l0, "blockIdx.y")
+    b161 = sch.decompose_reduction(block=b146, loop=l150)
+
+    b1 = sch.get_block("lv1541_pad")
+    sch.compute_inline(b1)
+    b2 = sch.get_block("var_NT_matmul_intermediate_pad")
+    sch.reverse_compute_inline(b2)
+
+    return sch.mod["main"].with_attr("tir.is_scheduled", 1)
 
 
 @T.prim_func
 def fused_NT_matmul1_add3_before(p_lv39: T.handle, lv1848: T.Buffer((T.int64(4096), T.int64(4096)), "float16"), p_lv2: T.handle, p_output0: T.handle):
@@ -3206,6 +3341,123 @@ def fused_NT_matmul2_divide1_maximum1_minimum1_cast3_after(p_lv22: T.handle, p_l
                                 if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < m:
                                     var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.Cast("float32", T.min(T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088397790055248615), T.float16(-65504)), lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3]))
 
+@T.prim_func
+def fused_NT_matmul2_divide1_maximum1_minimum1_before(p_lv28: T.handle, p_lv29: T.handle, p_lv5: T.handle, p_output0: T.handle):
+    T.func_attr({"tir.noalias": T.bool(True)})
+    n = T.int64()
+    lv28 = T.match_buffer(p_lv28, (T.int64(1), T.int64(32), n, T.int64(128)))
+    m = T.int64()
+    lv29 = T.match_buffer(p_lv29, (T.int64(1), T.int64(32), m, T.int64(128)))
+    lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m))
+    var_T_minimum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m))
+    # with T.block("root"):
+    var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
+    var_T_divide_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
+    var_T_maximum_intermediate = T.alloc_buffer((T.int64(1), T.int64(32), n, m))
+    for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), n, m, T.int64(128)):
+        with T.block("NT_matmul"):
+            v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+            T.reads(lv28[v_i0, v_i1, v_i2, v_k], lv29[v_i0, v_i1, v_i3, v_k])
+            T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3])
+            with T.init():
+                var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
+            var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2, v_i3] + lv28[v_i0, v_i1, v_i2, v_k] * lv29[v_i0, v_i1, v_i3, v_k]
+    for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
+        with T.block("T_divide"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            T.writes(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] * T.float32(0.088388349161020605)
+    for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
+        with T.block("T_maximum"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            T.writes(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.max(var_T_divide_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], T.float32(-3.4028234663852886e+38))
+    for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(32), n, m):
+        with T.block("T_minimum"):
+            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
+            T.reads(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
+            T.writes(var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3])
+            var_T_minimum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3] = T.min(var_T_maximum_intermediate[v_ax0, v_ax1, v_ax2, v_ax3], lv5[v_ax0, T.int64(0), v_ax2, v_ax3])
+
+@T.prim_func
+def fused_NT_matmul2_divide1_maximum1_minimum1_after(p_lv22: T.handle, p_lv23: T.handle, p_lv5: T.handle, p_output0: T.handle):
+    T.func_attr({"tir.noalias": T.bool(True), "tir.is_scheduled": 1})
+    n = T.int64()
+    m = T.int64()
+    lv22 = T.match_buffer(p_lv22, (T.int64(1), T.int64(32), n, T.int64(128)), "float32")
+    lv23 = T.match_buffer(p_lv23, (T.int64(1), T.int64(32), m, T.int64(128)), "float32")
+    lv5 = T.match_buffer(p_lv5, (T.int64(1), T.int64(1), n, m), "float32")
+    var_T_maximum_intermediate = T.match_buffer(p_output0, (T.int64(1), T.int64(32), n, m))
+    # with T.block("root"):
+    for i2_0_i3_0_fused in T.thread_binding((n + T.int64(31)) // T.int64(32) * ((m + T.int64(31)) // T.int64(32)), thread="blockIdx.y"):
+        with T.block("NT_matmul_o"):
+            v_i0 = T.axis.spatial(T.int64(1), T.int64(0))
+            v_i2_o = T.axis.spatial((n + T.int64(31)) // T.int64(32), i2_0_i3_0_fused // ((m + T.int64(31)) // T.int64(32)))
+            v_i3_o = T.axis.spatial((m + T.int64(31)) // T.int64(32), i2_0_i3_0_fused % ((m + T.int64(31)) // T.int64(32)))
+            T.reads(lv22[T.int64(0), T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv23[T.int64(0), T.int64(0):T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32), T.int64(0):T.int64(128)], lv5[v_i0, T.int64(0), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)])
+            T.writes(var_T_maximum_intermediate[v_i0, T.int64(0):T.int64(32), v_i2_o * T.int64(32):v_i2_o * T.int64(32) + T.int64(32), v_i3_o * T.int64(32):v_i3_o * T.int64(32) + T.int64(32)])
+            C_pad_local = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(32)), "float32", scope="local")
+            A_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float32", scope="shared")
+            B_pad_shared = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(32), T.int64(128)), "float32", scope="shared")
+            for i0_0_i1_0_i2_1_0_i3_1_0_fused in T.thread_binding(T.int64(128), thread="blockIdx.x", annotations={"pragma_auto_unroll_max_step": 512, "pragma_unroll_explicit": 1}):
+                for i0_1_i1_1_i2_1_1_i3_1_1_fused in T.thread_binding(T.int64(4), thread="vthread.x"):
+                    for i0_2_i1_2_i2_1_2_i3_1_2_fused in T.thread_binding(T.int64(64), thread="threadIdx.x"):
+                        for i1_3_init, i2_1_3_init, i3_1_3_init, i1_4_init, i2_1_4_init, i3_1_4_init in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
+                            with T.block("NT_matmul_init"):
+                                v_i1_i = T.axis.spatial(T.int64(32), i1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3_init)
+                                v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3_init + i2_1_4_init)
+                                v_i3_i = T.axis.spatial(T.int64(32), i3_1_4_init + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3_init)
+                                T.reads()
+                                T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i])
+                                T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = T.float32(0)
+                        for k_0 in range(T.int64(16)):
+                            for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)):
+                                for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
+                                    for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)):
+                                        with T.block("A_pad_shared"):
+                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                            v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4))
+                                            v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8))
+                                            v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8))
+                                            T.reads(lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3])
+                                            T.writes(A_pad_shared[v0, v1, v2, v3])
+                                            A_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i2_o * T.int64(32) + v2 < n, lv22[v0, v1, v_i2_o * T.int64(32) + v2, v3], T.float32(0))
+                            for ax0_ax1_ax2_ax3_fused_0 in range(T.int64(1)):
+                                for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(T.int64(64), thread="threadIdx.x"):
+                                    for ax0_ax1_ax2_ax3_fused_2 in T.vectorized(T.int64(2)):
+                                        with T.block("B_pad_shared"):
+                                            v0 = T.axis.spatial(T.int64(1), T.int64(0))
+                                            v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4))
+                                            v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) // T.int64(8))
+                                            v3 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + (ax0_ax1_ax2_ax3_fused_0 * T.int64(128) + ax0_ax1_ax2_ax3_fused_1 * T.int64(2) + ax0_ax1_ax2_ax3_fused_2) % T.int64(8))
+                                            T.reads(lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3])
+                                            T.writes(B_pad_shared[v0, v1, v2, v3])
+                                            B_pad_shared[v0, v1, v2, v3] = T.if_then_else(v_i3_o * T.int64(32) + v2 < m, lv23[v0, v1, v_i3_o * T.int64(32) + v2, v3], T.float32(0))
+                            for k_1, i0_3, i1_3, i2_1_3, i3_1_3, k_2, i0_4, i1_4, i2_1_4, i3_1_4 in T.grid(T.int64(4), T.int64(1), T.int64(1), T.int64(1), T.int64(1), T.int64(2), T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
+                                with T.block("NT_matmul_update"):
+                                    v_i1_i = T.axis.spatial(T.int64(32), i1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + i1_3)
+                                    v_i2_i = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + i2_1_3 + i2_1_4)
+                                    v_i3_i = T.axis.spatial(T.int64(32), i3_1_4 + i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + i3_1_3)
+                                    v_k_i = T.axis.reduce(T.int64(128), k_0 * T.int64(8) + k_1 * T.int64(2) + k_2)
+                                    T.reads(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i], A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i], B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i])
+                                    T.writes(C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i])
+                                    T.block_attr({"meta_schedule.thread_extent_high_inclusive": 256, "meta_schedule.thread_extent_low_inclusive": 32, "meta_schedule.tiling_structure": "SSSRRSRS"})
+                                    C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] = C_pad_local[T.int64(0), v_i1_i, v_i2_i, v_i3_i] + A_pad_shared[T.int64(0), v_i1_i, v_i2_i, v_k_i] * B_pad_shared[T.int64(0), v_i1_i, v_i3_i, v_k_i]
+                        for ax0, ax1, ax2, ax3 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
+                            with T.block("C_pad_local"):
+                                v0 = T.axis.spatial(T.int64(1), ax0)
+                                v1 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused // T.int64(4) + ax1)
+                                v2 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(4) // T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused // T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused // T.int64(8) + ax2)
+                                v3 = T.axis.spatial(T.int64(32), i0_0_i1_0_i2_1_0_i3_1_0_fused % T.int64(2) * T.int64(16) + i0_1_i1_1_i2_1_1_i3_1_1_fused % T.int64(2) * T.int64(8) + i0_2_i1_2_i2_1_2_i3_1_2_fused % T.int64(8) + ax3)
+                                T.reads(C_pad_local[v0, v1, v2, v3], lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3])
+                                T.writes(var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3])
+                                # if T.int64(0) <= v_i0 and v_i0 < T.int64(1) and T.int64(0) <= v_i2_o * T.int64(32) + v2 and v_i2_o * T.int64(32) + v2 < n and T.int64(0) <= v_i3_o * T.int64(32) + v3 and v_i3_o * T.int64(32) + v3 < n:
+                                if v_i2_o * T.int64(32) + v2 < n and v_i3_o * T.int64(32) + v3 < m:
+                                    var_T_maximum_intermediate[v_i0 + v0, v1, v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3] = T.min(T.max(C_pad_local[v0, v1, v2, v3] * T.float32(0.088397790055248615), T.float16(-65504)), lv5[v_i0 + v0, T.int64(0), v_i2_o * T.int64(32) + v2, v_i3_o * T.int64(32) + v3])
+
 def fused_NT_matmul2_divide1_add2_maximum1_sch_func(func):
     sch = tvm.tir.Schedule(func)
     b0 = sch.get_block("NT_matmul")
@@ -4255,7 +4507,7 @@ def fused_decode3_matmul1_before(lv2931: T.Buffer((T.int64(512), T.int64(32000))
 
 @T.prim_func
 def fused_decode3_matmul1_after(lv1123: T.Buffer((T.int64(512), T.int64(32000)), "uint32"), lv1124: T.Buffer((T.int64(128), T.int64(32000)), "uint32"), lv1511: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), var_matmul_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(32000)), "float32")):
-    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1})
     # with T.block("root"):
     var_decode_intermediate_pad_local = T.alloc_buffer((T.int64(4096), T.int64(32000)), scope="local")
     var_matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(32000)), scope="local")
@@ -4675,7 +4927,7 @@ def fused_decode6_fused_matmul9_add3_before(lv1623: T.Buffer((T.int64(1376), T.i
 
 @T.prim_func
 def fused_decode6_fused_matmul9_add3_after(lv1158: T.Buffer((T.int64(1376), T.int64(4096)), "uint32"), lv1159: T.Buffer((T.int64(344), T.int64(4096)), "uint32"), lv6: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float32"), lv4: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float32")):
-    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+    T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True), "tir.is_scheduled": 1})
     # with T.block("root"):
     var_decode_intermediate_local = T.alloc_buffer((T.int64(11008), T.int64(4096)), scope="local")
     var_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), scope="local")
@@ -6398,9 +6650,11 @@ def get_dict_key(func):
     get_dict_key(fused_NT_matmul2_silu_before): fused_NT_matmul2_silu_after,
     get_dict_key(fused_NT_matmul3_add1_before): fused_NT_matmul3_add1_after,
     get_dict_key(fused_NT_matmul_divide_maximum_minimum_cast_before): fused_NT_matmul_divide_maximum_minimum_cast_sch_func(),
+    get_dict_key(fused_NT_matmul_divide_maximum_minimum_before): fused_NT_matmul_divide_maximum_minimum_sch_func(),
     get_dict_key(fused_NT_matmul1_add3_before): fused_NT_matmul1_add3_sch_func(),
     get_dict_key(fused_NT_matmul2_divide1_add2_maximum1_before): fused_NT_matmul2_divide1_add2_maximum1_sch_func(fused_NT_matmul2_divide1_add2_maximum1_before),
     get_dict_key(fused_NT_matmul2_divide1_maximum1_minimum1_cast3_before): fused_NT_matmul2_divide1_maximum1_minimum1_cast3_after,
+    get_dict_key(fused_NT_matmul2_divide1_maximum1_minimum1_before): fused_NT_matmul2_divide1_maximum1_minimum1_after,
    get_dict_key(fused_NT_matmul3_multiply1_before): fused_NT_matmul3_multiply1_sch_func(),
     get_dict_key(fused_NT_matmul3_silu1_before): fused_NT_matmul3_silu1_sch_func(),
     get_dict_key(fused_NT_matmul4_add3_before): fused_NT_matmul4_add3_sch_func(),
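
Note on the scheduling pattern: the core trick in the new fused_NT_matmul_divide_maximum_minimum_sch_func is the pad_einsum prologue, which rounds the dynamic sequence axis n up to a multiple of 32 so that the padded loop can be split into a fixed 32-wide inner tile, with the variable-extent outer loop later bound to blockIdx.y (and the pad blocks inlined away in postproc). The sketch below reproduces just that idea on a toy NT-matmul; toy_nt_matmul and its buffer shapes are illustrative, not part of this patch, and assume a recent TVM with TVMScript support.

import tvm
from tvm.script import tir as T

@T.prim_func
def toy_nt_matmul(a: T.handle, b: T.handle, c: T.handle):
    # Hypothetical kernel: C[128, n] = A[128, 128] @ B[n, 128]^T, n dynamic.
    n = T.int64()
    A = T.match_buffer(a, (T.int64(128), T.int64(128)))
    B = T.match_buffer(b, (n, T.int64(128)))
    C = T.match_buffer(c, (T.int64(128), n))
    for i, j, k in T.grid(T.int64(128), n, T.int64(128)):
        with T.block("NT_matmul"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]

sch = tvm.tir.Schedule(toy_nt_matmul)
blk = sch.get_block("NT_matmul")
sch.pad_einsum(blk, [1, 32, 1])    # pad the dynamic j (= n) axis to a multiple of 32
i, j, k = sch.get_loops(blk)
jo, ji = sch.split(j, [None, 32])  # fixed 32-wide inner tile over the padded extent
sch.reorder(jo, i, ji, k)          # hoist the dynamic-extent loop outermost
sch.bind(jo, "blockIdx.y")         # one grid row per tile, as in the patch above
print(sch.mod)                     # the inserted pad blocks can then be compute_inline'd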
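
The two new dictionary entries in the last hunk follow the file's lookup-table pattern: each unscheduled "before" PrimFunc is reduced to a key via get_dict_key and mapped to its hand-scheduled replacement. Below is a hedged sketch of how such a table can be built and consulted, assuming get_dict_key is based on structural hashing; the real get_dict_key is defined earlier in main.py and may normalize differently, and the names here refer to the functions added by this patch.

import tvm

def get_dict_key(func):
    # One plausible implementation: hash the PrimFunc up to variable renaming.
    # The actual helper in main.py may strip attributes or normalize further.
    return tvm.ir.structural_hash(func, map_free_vars=True)

# before-func key -> scheduled replacement, mirroring the entries in the diff
DISPATCH = {
    get_dict_key(fused_NT_matmul_divide_maximum_minimum_before):
        fused_NT_matmul_divide_maximum_minimum_sch_func(),
    get_dict_key(fused_NT_matmul2_divide1_maximum1_minimum1_before):
        fused_NT_matmul2_divide1_maximum1_minimum1_after,
}

def dispatch(func):
    # Swap in the hand-written schedule when an unscheduled func matches a key;
    # otherwise return the function unchanged.
    return DISPATCH.get(get_dict_key(func), func)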