Skip to content

Commit

Permalink
Use 2d weight and bias texture for conv2d quantized op (pytorch#114902)
Browse files Browse the repository at this point in the history
Summary:
The performance with 2D texture for weight and bias is better for quantized conv2d, the un-quantized version of conv2d also uses 2D texture.
The performance gain is:

With 3D:
Kernel Name              Workgroup Size         Duration P50 (ns)
===========              ==============         =================
vulkan.quantized_conv2d  {96, 72, 2}                      5965440
vulkan.quantized_conv2d  {96, 72, 2}                     11316968
vulkan.quantized_conv2d_dw{96, 72, 2}                      2735564
vulkan.quantized_conv2d_pw_2x2{96, 72, 2}                      1645696

With 2D:
vulkan.quantized_conv2d  {96, 72, 2}                      4295772
vulkan.quantized_conv2d  {96, 72, 2}                      7874620
vulkan.quantized_conv2d_dw{96, 72, 2}                      2658552
vulkan.quantized_conv2d_pw_2x2{96, 72, 2}                      1632020

Test Plan:
Ensure all vulkan quantize tests pass:
buck2 run --target-platforms ovr_configplatform/macos:arm64-fbsourcexplat/caffe2:pt_vulkan_quantized_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output"
Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
[==========] Running 78 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 78 tests from VulkanAPITest
....
[----------] 78 tests from VulkanAPITest (1519 ms total)
[----------] Global test environment tear-down
[==========] 78 tests from 1 test suite ran. (1519 ms total)
[  PASSED  ] 78 tests.

buck2 run --target-platforms ovr_config//platform/macos:arm64-fbsource  //xplat/caffe2:pt_vulkan_api_test_binAppleMac\#macosx-arm64 -c pt.vulkan_full_precision=1 --show-output"

Running main() from third-party/googletest/1.11.0/googletest/googletest/src/gtest_main.cc
[==========] Running 395 tests from 1 test suite.
[----------] Global test environment set-up.
[----------] 395 tests from VulkanAPITest
......
----------] 395 tests from VulkanAPITest (6515 ms total)

[----------] Global test environment tear-down
[==========] 395 tests from 1 test suite ran. (6515 ms total)
[  PASSED  ] 394 tests.
[  SKIPPED ] 1 test, listed below:
[  SKIPPED ] VulkanAPITest.querypool_flushed_shader_log

  YOU HAVE 5 DISABLED TESTS

Reviewed By: yipjustin

Differential Revision: D50997534

Pull Request resolved: pytorch#114902
Approved by: https://github.com/yipjustin
  • Loading branch information
shubhraprakash1 authored and pytorchmergebot committed Dec 4, 2023
1 parent 6317a03 commit 8dbae73
Show file tree
Hide file tree
Showing 8 changed files with 237 additions and 22 deletions.
53 changes: 53 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/nchw_to_image2d_int32.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/*
* Output Image
*/
layout(set = 0, binding = 0, rgba32i) uniform PRECISION restrict writeonly iimage2D uImage;

/*
* Input Buffer
*/
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
int data[];
}
uBuffer;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// xyz contain the extents of the output texture, w contains HxW to help
// calculate buffer offsets
ivec4 out_extents;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
return;
}

const int base_index =
pos.x + uBlock.out_extents.x * pos.y + (4 * uBlock.out_extents.w) * pos.z;
const ivec4 buf_indices =
base_index + ivec4(0, 1, 2, 3) * uBlock.out_extents.w;

int val_x = uBuffer.data[buf_indices.x];
int val_y = uBuffer.data[buf_indices.y];
int val_z = uBuffer.data[buf_indices.z];
int val_w = uBuffer.data[buf_indices.w];

imageStore(uImage, pos.xy, ivec4(val_x, val_y, val_z, val_w));
}
81 changes: 81 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/nchw_to_image2d_int8.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/*
* Output Image
*/
layout(set = 0, binding = 0, rgba8i) uniform PRECISION restrict writeonly iimage2D uImage;

/*
* Input Buffer
*/
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
int data[];
}
uBuffer;

/*
* Extends sign of int8
*/
int extend_sign(int x) {
if (x >> 7 == 1) {
return x | 0xFFFFFF00;
}
return x;
}

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// xyz contain the extents of the output texture, w contains HxW to help
// calculate buffer offsets
ivec4 out_extents;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
return;
}

const int base_index =
pos.x + uBlock.out_extents.x * pos.y + (4 * uBlock.out_extents.w) * pos.z;
const ivec4 buf_indices =
base_index + ivec4(0, 1, 2, 3) * uBlock.out_extents.w;

int shift = (1 << 8) - 1;
ivec4 masks;
masks.x = shift << 8 * (buf_indices.x % 4);
masks.y = shift << 8 * (buf_indices.y % 4);
masks.z = shift << 8 * (buf_indices.z % 4);
masks.w = shift << 8 * (buf_indices.w % 4);

int buf_in_1 = uBuffer.data[buf_indices.x / 4];
int val_x = (buf_in_1 & masks.x) >> 8 * (buf_indices.x % 4);
val_x = extend_sign(val_x);

int buf_in_2 = uBuffer.data[buf_indices.y / 4];
int val_y = (buf_in_2 & masks.y) >> 8 * (buf_indices.y % 4);
val_y = extend_sign(val_y);

int buf_in_3 = uBuffer.data[buf_indices.z / 4];
int val_z = (buf_in_3 & masks.z) >> 8 * (buf_indices.z % 4);
val_z = extend_sign(val_z);

int buf_in_4 = uBuffer.data[buf_indices.w / 4];
int val_w = (buf_in_4 & masks.w) >> 8 * (buf_indices.w % 4);
val_w = extend_sign(val_w);

imageStore(uImage, pos.xy, ivec4(val_x, val_y, val_z, val_w));
}
67 changes: 67 additions & 0 deletions aten/src/ATen/native/vulkan/glsl/nchw_to_image2d_uint8.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#version 450 core
#define PRECISION $precision
#define FORMAT $format

layout(std430) buffer;

/*
* Output Image
*/
layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimage2D uImage;

/*
* Input Buffer
*/
layout(set = 0, binding = 1) buffer PRECISION restrict readonly Buffer {
uint data[];
}
uBuffer;

/*
* Params Buffer
*/
layout(set = 0, binding = 2) uniform PRECISION restrict Block {
// xyz contain the extents of the output texture, w contains HxW to help
// calculate buffer offsets
ivec4 out_extents;
}
uBlock;

/*
* Local Work Group Size
*/
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

void main() {
const ivec3 pos = ivec3(gl_GlobalInvocationID);

if (any(greaterThanEqual(pos, uBlock.out_extents.xyz))) {
return;
}

const int base_index =
pos.x + uBlock.out_extents.x * pos.y + (4 * uBlock.out_extents.w) * pos.z;
const ivec4 buf_indices =
base_index + ivec4(0, 1, 2, 3) * uBlock.out_extents.w;

int shift = (1 << 8) - 1;
ivec4 masks;
masks.x = shift << 8 * (buf_indices.x % 4);
masks.y = shift << 8 * (buf_indices.y % 4);
masks.z = shift << 8 * (buf_indices.z % 4);
masks.w = shift << 8 * (buf_indices.w % 4);

uint buf_in_1 = uBuffer.data[buf_indices.x / 4];
uint a_v = (buf_in_1 & masks.x) >> 8 * (buf_indices.x % 4);

uint buf_in_2 = uBuffer.data[buf_indices.y / 4];
uint b_v = (buf_in_2 & masks.y) >> 8 * (buf_indices.y % 4);

uint buf_in_3 = uBuffer.data[buf_indices.z / 4];
uint g_v = (buf_in_3 & masks.z) >> 8 * (buf_indices.z % 4);

uint buf_in_4 = uBuffer.data[buf_indices.w / 4];
uint r_v = (buf_in_4 & masks.w) >> 8 * (buf_indices.w % 4);

imageStore(uImage, pos.xy, uvec4(a_v, b_v, g_v, r_v));
}
14 changes: 7 additions & 7 deletions aten/src/ATen/native/vulkan/glsl/quantized_conv2d.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimag
* Input Textures
*/
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION isampler3D uKernel;
layout(set = 0, binding = 3) uniform PRECISION isampler3D uBias;
layout(set = 0, binding = 2) uniform PRECISION isampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION isampler2D uBias;

/*
* Params Buffer
Expand Down Expand Up @@ -103,7 +103,7 @@ void main() {
kstart.y += pos.z * uBlock.kernel_size.y;

vec4 sum = dequantize(
texelFetch(uBias, ivec3(pos.z, 0, 0), 0),
texelFetch(uBias, ivec2(pos.z, 0), 0),
uBlock.scales.w,
uBlock.zero_points.w);

Expand Down Expand Up @@ -153,25 +153,25 @@ void main() {
// which is what is expressed in the following calculations.

const vec4 ktex_0 = dequantize(
texelFetch(uKernel, ivec3(kx + 0, ky, 0), 0),
texelFetch(uKernel, ivec2(kx + 0, ky), 0),
uBlock.scales.z,
uBlock.zero_points.z);
sum = fma(in_tex.xxxx, ktex_0, sum);

const vec4 ktex_1 = dequantize(
texelFetch(uKernel, ivec3(kx + 1, ky, 0), 0),
texelFetch(uKernel, ivec2(kx + 1, ky), 0),
uBlock.scales.z,
uBlock.zero_points.z);
sum = fma(in_tex.yyyy, ktex_1, sum);

const vec4 ktex_2 = dequantize(
texelFetch(uKernel, ivec3(kx + 2, ky, 0), 0),
texelFetch(uKernel, ivec2(kx + 2, ky), 0),
uBlock.scales.z,
uBlock.zero_points.z);
sum = fma(in_tex.zzzz, ktex_2, sum);

const vec4 ktex_3 = dequantize(
texelFetch(uKernel, ivec3(kx + 3, ky, 0), 0),
texelFetch(uKernel, ivec2(kx + 3, ky), 0),
uBlock.scales.z,
uBlock.zero_points.z);
sum = fma(in_tex.wwww, ktex_3, sum);
Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/native/vulkan/glsl/quantized_conv2d_dw.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimag
* Input Textures
*/
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION isampler3D uKernel;
layout(set = 0, binding = 3) uniform PRECISION isampler3D uBias;
layout(set = 0, binding = 2) uniform PRECISION isampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION isampler2D uBias;

/*
* Params Buffer
Expand Down Expand Up @@ -90,7 +90,7 @@ void main() {
const ivec2 kstart = (start - ipos) / uBlock.dilate;

vec4 sum = dequantize(
texelFetch(uBias, ivec3(pos.z, 0, 0), 0),
texelFetch(uBias, ivec2(pos.z, 0), 0),
uBlock.scales.w,
uBlock.zero_points.w);

Expand All @@ -104,7 +104,7 @@ void main() {
const int k_ind = kx + ky * uBlock.kernel_size.x;

const vec4 k_tex = dequantize(
texelFetch(uKernel, ivec3(k_ind, pos.z, 0), 0),
texelFetch(uKernel, ivec2(k_ind, pos.z), 0),
uBlock.scales.z,
uBlock.zero_points.z);
const vec4 in_tex = dequantize(
Expand Down
14 changes: 7 additions & 7 deletions aten/src/ATen/native/vulkan/glsl/quantized_conv2d_pw_2x2.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ layout(set = 0, binding = 0, rgba8ui) uniform PRECISION restrict writeonly uimag
* Input Textures
*/
layout(set = 0, binding = 1) uniform PRECISION isampler3D uInput;
layout(set = 0, binding = 2) uniform PRECISION isampler3D uKernel;
layout(set = 0, binding = 3) uniform PRECISION isampler3D uBias;
layout(set = 0, binding = 2) uniform PRECISION isampler2D uKernel;
layout(set = 0, binding = 3) uniform PRECISION isampler2D uBias;

/*
* Params Buffer
Expand Down Expand Up @@ -100,7 +100,7 @@ void main() {

vec4 sum[4];
sum[0] = dequantize(
texelFetch(uBias, ivec3(gpos.z, 0, 0), 0),
texelFetch(uBias, ivec2(gpos.z, 0), 0),
uBlock.scales.w,
uBlock.zero_points.w);
for (int i = 1; i < 4; ++i) {
Expand All @@ -114,19 +114,19 @@ void main() {
// channel (IC) dim is along the x axis, and the batch (OC) dim is along
// the z axis.
const vec4 ktex_0 = dequantize(
texelFetch(uKernel, ivec3(z + 0, gpos.z, 0), 0),
texelFetch(uKernel, ivec2(z + 0, gpos.z), 0),
uBlock.scales.z,
uBlock.zero_points.z);
const vec4 ktex_1 = dequantize(
texelFetch(uKernel, ivec3(z + 1, gpos.z, 0), 0),
texelFetch(uKernel, ivec2(z + 1, gpos.z), 0),
uBlock.scales.z,
uBlock.zero_points.z);
const vec4 ktex_2 = dequantize(
texelFetch(uKernel, ivec3(z + 2, gpos.z, 0), 0),
texelFetch(uKernel, ivec2(z + 2, gpos.z), 0),
uBlock.scales.z,
uBlock.zero_points.z);
const vec4 ktex_3 = dequantize(
texelFetch(uKernel, ivec3(z + 3, gpos.z, 0), 0),
texelFetch(uKernel, ivec2(z + 3, gpos.z), 0),
uBlock.scales.z,
uBlock.zero_points.z);

Expand Down
14 changes: 14 additions & 0 deletions aten/src/ATen/native/vulkan/impl/Packing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,20 @@ api::ShaderInfo get_nchw_to_image_shader(const vTensor& v_dst) {
"Vulkan quantization currently not supported for dtype ",
v_dst.dtype());
}
case api::StorageType::TEXTURE_2D:
switch (v_dst.dtype()) {
case c10::ScalarType::QUInt8:
return VK_KERNEL(nchw_to_image2d_uint8);
case c10::ScalarType::QInt8:
return VK_KERNEL(nchw_to_image2d_int8);
case c10::ScalarType::QInt32:
return VK_KERNEL(nchw_to_image2d_int32);
default:
TORCH_CHECK(
false,
"Vulkan quantization currently not supported for dtype ",
v_dst.dtype());
}
default:
TORCH_CHECK(false, "No kernel available!");
case api::StorageType::BUFFER:
Expand Down
8 changes: 4 additions & 4 deletions aten/src/ATen/native/vulkan/ops/Convolution.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/ParamUtils.h>

#include <ATen/Context.h>

#include <ATen/native/ConvUtils.h>
#include <ATen/native/utils/ParamUtils.h>
#include <ATen/native/vulkan/api/Utils.h>
#include <ATen/native/vulkan/ops/Common.h>
#include <ATen/native/vulkan/ops/Convolution.h>
Expand Down Expand Up @@ -529,7 +529,7 @@ vTensor pack_weights(
api::context(),
weight_rearranged.sizes(),
weight_arg.scalar_type(),
quantized ? api::StorageType::TEXTURE_3D : api::StorageType::TEXTURE_2D,
api::StorageType::TEXTURE_2D,
};

if (quantized) {
Expand Down Expand Up @@ -557,7 +557,7 @@ vTensor pack_biases(
api::context(),
bias_rearranged.sizes(),
bias_rearranged.scalar_type(),
quantized ? api::StorageType::TEXTURE_3D : api::StorageType::TEXTURE_2D,
api::StorageType::TEXTURE_2D,
};

if (quantized) {
Expand Down

0 comments on commit 8dbae73

Please sign in to comment.