From 199b4ce4baefa00959e32fe299f561664bd35388 Mon Sep 17 00:00:00 2001 From: yucai Date: Mon, 11 Nov 2024 08:16:03 +0000 Subject: [PATCH] add _transformer_encoder_layer_fwd --- test/xpu/skip_list_common.py | 2 +- yaml/native/native_functions.yaml | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py index 74b2a8eea..e1f3e66f9 100644 --- a/test/xpu/skip_list_common.py +++ b/test/xpu/skip_list_common.py @@ -810,7 +810,7 @@ # https://github.com/intel/torch-xpu-ops/issues/761 # AssertionError: False is not true # CPU fallback failure. To support aten::transformer_encoder_layer_forward with proper priority. - "test_disable_fastpath_xpu", + # "test_disable_fastpath_xpu", # We have no mechanism to handle SDPBackend::ERROR so far. Will give a fully support when we support all SDPBackends. "test_dispatch_fails_no_backend_xpu", # Could not run 'aten::_to_copy' with arguments from the 'NestedTensorXPU' backend diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml index f4a580b40..67f1d4c37 100644 --- a/yaml/native/native_functions.yaml +++ b/yaml/native/native_functions.yaml @@ -5585,6 +5585,13 @@ XPU: _dirichlet_grad_xpu autogen: _dirichlet_grad.out +# Apparently, putting "forward" in the name will cause Python bindings to be skipped, so "fwd" it is. +- func: _transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor + variants: function + dispatch: + XPU: transformer_encoder_layer_forward + autogen: _transformer_encoder_layer_fwd.out + # Fused implementation detail for transformers. Adds in-projection bias to QKV and divides Q by sqrt(D/num_heads). - func: _transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor) dispatch: