Skip to content

Commit

Permalink
Merge branch 'Transpose' of github.com:niket-agarwal/nntrainer into Transpose

Browse files Browse the repository at this point in the history

Updating the transpose function with the recent GPU pipeline changes.
  • Loading branch information
niket-agarwal committed Oct 7, 2024
2 parents f80cd62 + 8f846e8 commit 8ea1e02
Showing 1 changed file with 168 additions and 0 deletions.
168 changes: 168 additions & 0 deletions nntrainer/tensor/cl_operations/blas_kernel_strings.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,87 @@ static const std::string sscal_cl_kernel_ =
X[i] *= alpha;
})";

// OpenCL C source for a float32 transpose over axis 0 (the "channel/height"
// swap): input is indexed as NCHW and the output index
// (n*C*H*W + h*C*W + c*W + w) corresponds to an NHCW layout.
// Each work-item is identified by a 2-D NDRange: dim 0 -> h, dim 1 -> w;
// the bounds check guards against a global size rounded up past (H, W).
// One work-item copies its (h, w) position for every batch and channel.
// NOTE(review): assumes the host enqueues a global size of at least
// (height, width) — confirm against the kernel-launch code.
static const std::string transpose_cl_kernel_axis0 =
R"(__kernel void transpose_cl_axis0(__global const float* in,
__global float* output,
const int batch_size,
const int channels,
const int height,
const int width) {
// Calculate h and w from the global IDs
int h = get_global_id(0);
int w = get_global_id(1);
if (h < height && w < width) {
for (int c = 0; c < channels; ++c) {
for (int n = 0; n < batch_size; ++n) {
// Calculate the input and output indices
int input_index = n * (channels * height * width) + c * (height * width) + h * width + w;
int output_index = n * (channels * height * width) + h * (channels * width) + c * width + w;
// Transpose channel and height, copying data from input to output
output[output_index] = in[input_index];
}
}
}
})";

// OpenCL C source for a float32 transpose over axis 1 (the "height/width"
// swap): input is indexed as NCHW and the output index
// (n*C*H*W + c*H*W + w*H + h) corresponds to an NCWH layout.
// 2-D NDRange: dim 0 -> h, dim 1 -> w, with an explicit (H, W) bounds check;
// each work-item copies its (h, w) position across all batches and channels.
// NOTE(review): reads are coalesced over w but writes stride by H —
// acceptable for a simple transpose; revisit if this becomes a hot path.
static const std::string transpose_cl_kernel_axis1 =
R"(__kernel void transpose_cl_axis1(__global const float* in,
__global float* output,
const int batch_size,
const int channels,
const int height,
const int width) {
// Calculate h and w from the global IDs
int h = get_global_id(0);
int w = get_global_id(1);
if (h < height && w < width) {
for (int c = 0; c < channels; ++c) {
for (int n = 0; n < batch_size; ++n) {
// Calculate the input and output indices
int input_index = n * (channels * height * width) + c * (height * width) + h * width + w;
int output_index = n * (channels * height * width) + c * (height * width) + w * height + h;
// Transpose height and width, copying data from input to output
output[output_index] = in[input_index];
}
}
}
})";

// OpenCL C source for a float32 transpose over axis 2 (the "channel/width"
// swap): input is indexed as NCHW and the output index
// (n*C*H*W + w*H*C + h*C + c) corresponds to an NWHC layout.
// Unlike the axis0/axis1 kernels, the 2-D NDRange here maps dim 0 -> c and
// dim 1 -> w (bounds-checked against C and W); the h and n dimensions are
// looped over inside each work-item.
static const std::string transpose_cl_kernel_axis2 =
R"(__kernel void transpose_cl_axis2(__global const float* in,
__global float* output,
const int batch_size,
const int channels,
const int height,
const int width) {
// Calculate c and w from the global IDs
int c = get_global_id(0);
int w = get_global_id(1);
if (c < channels && w < width) {
for (int h = 0; h < height; ++h) {
for (int n = 0; n < batch_size; ++n) {
// Calculate the input and output indices
int input_index = n * (channels * height * width) + c * (height * width) + h * width + w;
int output_index = n * (channels * height * width) + w * (height * channels) + h * channels + c;
// Transpose channel and width, copying data from input to output
output[output_index] = in[input_index];
}
}
}
})";

#ifdef ENABLE_FP16
static const std::string sgemv_cl_kernel_fp16_ =
R"(
Expand Down Expand Up @@ -244,6 +325,93 @@ static const std::string sscal_cl_kernel_fp16_ =
unsigned int i = get_global_id(0);
X[i] *= alpha;
})";

// Half-precision (fp16) twin of transpose_cl_kernel_axis0: identical index
// math (NCHW input, NHCW output via n*C*H*W + h*C*W + c*W + w), operating on
// `half` instead of `float`. The cl_khr_fp16 pragma is required for the
// `half` type to be usable as a kernel argument/buffer element; devices
// without this extension will fail to build the kernel at runtime.
// 2-D NDRange: dim 0 -> h, dim 1 -> w, bounds-checked against (H, W).
static const std::string transpose_cl_kernel_fp16_axis0 =
R"(
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void transpose_cl_fp16_axis0(__global const half* in,
__global half* output,
const int batch_size,
const int channels,
const int height,
const int width) {
// Calculate h and w from the global IDs
int h = get_global_id(0);
int w = get_global_id(1);
if (h < height && w < width) {
for (int c = 0; c < channels; ++c) {
for (int n = 0; n < batch_size; ++n) {
// Calculate the input and output indices
int input_index = n * (channels * height * width) + c * (height * width) + h * width + w;
int output_index = n * (channels * height * width) + h * (channels * width) + c * width + w;
// Transpose channel and height, copying data from input to output
output[output_index] = in[input_index];
}
}
}
})";

// Half-precision (fp16) twin of transpose_cl_kernel_axis1: identical index
// math (NCHW input, NCWH output via n*C*H*W + c*H*W + w*H + h), operating on
// `half` instead of `float`. Requires the cl_khr_fp16 device extension,
// enabled by the pragma at the top of the kernel source.
// 2-D NDRange: dim 0 -> h, dim 1 -> w, bounds-checked against (H, W).
static const std::string transpose_cl_kernel_fp16_axis1 =
R"(
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void transpose_cl_fp16_axis1(__global const half* in,
__global half* output,
const int batch_size,
const int channels,
const int height,
const int width) {
// Calculate h and w from the global IDs
int h = get_global_id(0);
int w = get_global_id(1);
if (h < height && w < width) {
for (int c = 0; c < channels; ++c) {
for (int n = 0; n < batch_size; ++n) {
// Calculate the input and output indices
int input_index = n * (channels * height * width) + c * (height * width) + h * width + w;
int output_index = n * (channels * height * width) + c * (height * width) + w * height + h;
// Transpose height and width, copying data from input to output
output[output_index] = in[input_index];
}
}
}
})";

// Half-precision (fp16) twin of transpose_cl_kernel_axis2: identical index
// math (NCHW input, NWHC output via n*C*H*W + w*H*C + h*C + c), operating on
// `half` instead of `float`. Requires the cl_khr_fp16 device extension,
// enabled by the pragma at the top of the kernel source.
// 2-D NDRange maps dim 0 -> c and dim 1 -> w (bounds-checked against C, W);
// h and n are looped over inside each work-item.
static const std::string transpose_cl_kernel_fp16_axis2 =
R"(
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void transpose_cl_fp16_axis2(__global const half* in,
__global half* output,
const int batch_size,
const int channels,
const int height,
const int width) {
// Calculate c and w from the global IDs
int c = get_global_id(0);
int w = get_global_id(1);
if (c < channels && w < width) {
for (int h = 0; h < height; ++h) {
for (int n = 0; n < batch_size; ++n) {
// Calculate the input and output indices
int input_index = n * (channels * height * width) + c * (height * width) + h * width + w;
int output_index = n * (channels * height * width) + w * (height * channels) + h * channels + c;
// Transpose channel and width, copying data from input to output
output[output_index] = in[input_index];
}
}
}
})";
#endif
} // namespace nntrainer
#endif /* __BLAS_KERNEL_INTERFACE_H__ */

0 comments on commit 8ea1e02

Please sign in to comment.