From 1bbba085328081b50eb01ad54ba786824e46b3d5 Mon Sep 17 00:00:00 2001 From: Nikos Kolotouros Date: Sun, 12 Aug 2018 20:55:30 -0400 Subject: [PATCH 1/4] added texture wrapping --- neural_renderer/cuda/load_textures_cuda.cpp | 10 +- .../cuda/load_textures_cuda_kernel.cu | 124 ++++++++++++------ neural_renderer/load_obj.py | 17 ++- 3 files changed, 105 insertions(+), 46 deletions(-) diff --git a/neural_renderer/cuda/load_textures_cuda.cpp b/neural_renderer/cuda/load_textures_cuda.cpp index 1a12a51..a4a92d2 100644 --- a/neural_renderer/cuda/load_textures_cuda.cpp +++ b/neural_renderer/cuda/load_textures_cuda.cpp @@ -6,7 +6,9 @@ at::Tensor load_textures_cuda( at::Tensor image, at::Tensor faces, at::Tensor textures, - at::Tensor is_update); + at::Tensor is_update, + int texture_wrapping, + int use_bilinear); // C++ interface @@ -19,14 +21,16 @@ at::Tensor load_textures( at::Tensor image, at::Tensor faces, at::Tensor textures, - at::Tensor is_update) { + at::Tensor is_update, + int texture_wrapping, + int use_bilinear) { CHECK_INPUT(image); CHECK_INPUT(faces); CHECK_INPUT(is_update); CHECK_INPUT(textures); - return load_textures_cuda(image, faces, textures, is_update); + return load_textures_cuda(image, faces, textures, is_update, texture_wrapping, use_bilinear); } diff --git a/neural_renderer/cuda/load_textures_cuda_kernel.cu b/neural_renderer/cuda/load_textures_cuda_kernel.cu index 6d088f4..3c10431 100644 --- a/neural_renderer/cuda/load_textures_cuda_kernel.cu +++ b/neural_renderer/cuda/load_textures_cuda_kernel.cu @@ -3,17 +3,26 @@ #include #include + namespace { + +const int REPEAT = 0; +const int MIRRORED_REPEAT = 1; +const int CLAMP_TO_EDGE = 2; +const int CLAMP_TO_BORDER = 3; + template __global__ void load_textures_cuda_kernel( - const scalar_t* __restrict__ image, - const scalar_t* __restrict__ faces, + const scalar_t* image, const int32_t* __restrict__ is_update, + scalar_t* __restrict__ faces, scalar_t* __restrict__ textures, - size_t textures_size, - size_t texture_size, - size_t image_height, - size_t image_width) { + int textures_size, + int texture_size, + int image_height, + int image_width, + int texture_wrapping, + bool use_bilinear) { const int i = blockIdx.x * blockDim.x + threadIdx.x; if (i >= textures_size / 3) { return; @@ -29,35 +38,70 @@ __global__ void load_textures_cuda_kernel( dim1 /= sum; dim2 /= sum; } - const scalar_t* face = &faces[fn * 3 * 2]; - scalar_t* texture = &textures[i * 3]; - if (is_update[fn] == 0) return; - - const scalar_t pos_x = ( - (face[2 * 0 + 0] * dim0 + face[2 * 1 + 0] * dim1 + face[2 * 2 + 0] * dim2) * (image_width - 1)); - const scalar_t pos_y = ( - (face[2 * 0 + 1] * dim0 + face[2 * 1 + 1] * dim1 + face[2 * 2 + 1] * dim2) * (image_height - 1)); - if (1) { - /* bilinear sampling */ - const scalar_t weight_x1 = pos_x - (int)pos_x; - const scalar_t weight_x0 = 1 - weight_x1; - const scalar_t weight_y1 = pos_y - (int)pos_y; - const scalar_t weight_y0 = 1 - weight_y1; - for (int k = 0; k < 3; k++) { - scalar_t c = 0; - c += image[((int)pos_y * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y0); - c += image[((int)(pos_y + 1) * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y1); - c += image[((int)pos_y * image_width + ((int)pos_x) + 1) * 3 + k] * (weight_x1 * weight_y0); - c += image[((int)(pos_y + 1)* image_width + ((int)pos_x) + 1) * 3 + k] * (weight_x1 * weight_y1); - texture[k] = c; - } - } else { - /* nearest neighbor */ - const int pos_xi = round(pos_x); - const int pos_yi = round(pos_y); - for (int k = 0; k < 3; k++) { - texture[k] = image[(pos_yi * image_width + pos_xi) * 3 + k]; - } + scalar_t* face = &faces[fn * 3 * 2]; + scalar_t* texture_ = &textures[i * 3]; + + if (is_update[fn] != 0) { + if (texture_wrapping == REPEAT) { + #pragma unroll + for (int i = 0; i < 6; ++i) { + face[i] = fmod(face[i], (scalar_t)2.); + // face[i] = 1; + } + } + else if (texture_wrapping == MIRRORED_REPEAT) { + #pragma unroll + for (int i = 0; i < 6; ++i) { + if ( ((int)face[i] / 2) % 2 == 1) { + face[i] = 1 - fmod(face[i], (scalar_t)2.); + } + else { + face[i] = fmod(face[i], (scalar_t)2.); + } + } + } + else if (texture_wrapping == CLAMP_TO_EDGE) { + #pragma unroll + for (int i = 0; i < 6; ++i) { + face[i] = max(min(face[i], (scalar_t) 1), (scalar_t) 0); + } + } + const scalar_t pos_x = ( + (face[2 * 0 + 0] * dim0 + face[2 * 1 + 0] * dim1 + face[2 * 2 + 0] * dim2) * (image_width - 1)); + const scalar_t pos_y = ( + (face[2 * 0 + 1] * dim0 + face[2 * 1 + 1] * dim1 + face[2 * 2 + 1] * dim2) * (image_height - 1)); + if (use_bilinear) { + /* bilinear sampling */ + const scalar_t weight_x1 = pos_x - (int)pos_x; + const scalar_t weight_x0 = 1 - weight_x1; + const scalar_t weight_y1 = pos_y - (int)pos_y; + const scalar_t weight_y0 = 1 - weight_y1; + for (int k = 0; k < 3; k++) { + if (texture_wrapping != CLAMP_TO_BORDER || (pos_x > 0 && pos_x < image_width-1) || (pos_y > 0 && pos_y < image_height-1)) { + scalar_t c = 0; + c += image[((int)pos_y * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y0); + c += image[((int)(pos_y + 1) * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y1); + c += image[((int)pos_y * image_width + ((int)pos_x) + 1) * 3 + k] * (weight_x1 * weight_y0); + c += image[((int)(pos_y + 1)* image_width + ((int)pos_x) + 1) * 3 + k] * (weight_x1 * weight_y1); + texture_[k] = c; + } + else { + texture_[k] = 0; + } + } + } else { + /* nearest neighbor */ + const int pos_xi = round(pos_x); + const int pos_yi = round(pos_y); + for (int k = 0; k < 3; k++) { + if (texture_wrapping != CLAMP_TO_BORDER || (pos_xi > 0 && pos_xi < image_width) || (pos_yi > 0 && pos_yi < image_height)) { + texture_[k] = image[(pos_yi * image_width + pos_xi) * 3 + k]; + } + else { + texture_[k] = 0; + } + } + } } } } @@ -66,7 +110,9 @@ at::Tensor load_textures_cuda( at::Tensor image, at::Tensor faces, at::Tensor textures, - at::Tensor is_update) { + at::Tensor is_update, + int texture_wrapping, + int use_bilinear) { // textures_size = size of the textures tensor const auto textures_size = textures.numel(); // notice that texture_size != texture_size @@ -80,13 +126,15 @@ at::Tensor load_textures_cuda( AT_DISPATCH_FLOATING_TYPES(image.type(), "load_textures_cuda", ([&] { load_textures_cuda_kernel<<>>( image.data(), - faces.data(), is_update.data(), + faces.data(), textures.data(), textures_size, texture_size, image_height, - image_width); + image_width, + texture_wrapping, + use_bilinear); })); cudaError_t err = cudaGetLastError(); diff --git a/neural_renderer/load_obj.py b/neural_renderer/load_obj.py index 7043495..4f08b05 100644 --- a/neural_renderer/load_obj.py +++ b/neural_renderer/load_obj.py @@ -7,6 +7,9 @@ import neural_renderer.cuda.load_textures as load_textures_cuda +texture_wrapping_dict = {'REPEAT': 0, 'MIRRORED_REPEAT': 1, + 'CLAMP_TO_EDGE': 2, 'CLAMP_TO_BORDER': 3} + def load_mtl(filename_mtl): ''' load color (Kd) and filename of textures from *.mtl @@ -26,7 +29,7 @@ def load_mtl(filename_mtl): return colors, texture_filenames -def load_textures(filename_obj, filename_mtl, texture_size): +def load_textures(filename_obj, filename_mtl, texture_size, texture_wrapping='REPEAT', use_bilinear=True): # load vertices vertices = [] with open(filename_obj) as f: @@ -68,7 +71,6 @@ def load_textures(filename_obj, filename_mtl, texture_size): faces = np.vstack(faces).astype(np.int32) - 1 faces = vertices[faces] faces = torch.from_numpy(faces).cuda() - faces[1 < faces] = faces[1 < faces] % 1 colors, texture_filenames = load_mtl(filename_mtl) @@ -98,10 +100,13 @@ def load_textures(filename_obj, filename_mtl, texture_size): image = torch.from_numpy(image.copy()).cuda() is_update = (np.array(material_names) == material_name).astype(np.int32) is_update = torch.from_numpy(is_update).cuda() - textures = load_textures_cuda.load_textures(image, faces, textures, is_update) + textures = load_textures_cuda.load_textures(image, faces, textures, is_update, + texture_wrapping_dict[texture_wrapping], + use_bilinear) return textures -def load_obj(filename_obj, normalization=True, texture_size=4, load_texture=False): +def load_obj(filename_obj, normalization=True, texture_size=4, load_texture=False, + texture_wrapping='REPEAT', use_bilinear=True): """ Load Wavefront .obj file. This function only supports vertices (v x x x) and faces (f x x x). @@ -140,7 +145,9 @@ def load_obj(filename_obj, normalization=True, texture_size=4, load_texture=Fals for line in lines: if line.startswith('mtllib'): filename_mtl = os.path.join(os.path.dirname(filename_obj), line.split()[1]) - textures = load_textures(filename_obj, filename_mtl, texture_size) + textures = load_textures(filename_obj, filename_mtl, texture_size, + texture_wrapping=texture_wrapping, + use_bilinear=use_bilinear) if textures is None: raise Exception('Failed to load textures.') From b067de5ad33952415e335a5f8b93f8587ad20ebb Mon Sep 17 00:00:00 2001 From: Nikos Kolotouros Date: Mon, 20 Aug 2018 15:26:39 -0400 Subject: [PATCH 2/4] fixed out of bounds accessing bug --- neural_renderer/cuda/load_textures_cuda_kernel.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/neural_renderer/cuda/load_textures_cuda_kernel.cu b/neural_renderer/cuda/load_textures_cuda_kernel.cu index 3c10431..9eb7ee8 100644 --- a/neural_renderer/cuda/load_textures_cuda_kernel.cu +++ b/neural_renderer/cuda/load_textures_cuda_kernel.cu @@ -77,12 +77,12 @@ __global__ void load_textures_cuda_kernel( const scalar_t weight_y1 = pos_y - (int)pos_y; const scalar_t weight_y0 = 1 - weight_y1; for (int k = 0; k < 3; k++) { - if (texture_wrapping != CLAMP_TO_BORDER || (pos_x > 0 && pos_x < image_width-1) || (pos_y > 0 && pos_y < image_height-1)) { + if (texture_wrapping != CLAMP_TO_BORDER) { scalar_t c = 0; c += image[((int)pos_y * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y0); - c += image[((int)(pos_y + 1) * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y1); - c += image[((int)pos_y * image_width + ((int)pos_x) + 1) * 3 + k] * (weight_x1 * weight_y0); - c += image[((int)(pos_y + 1)* image_width + ((int)pos_x) + 1) * 3 + k] * (weight_x1 * weight_y1); + c += image[((int)min((pos_y + 1), image_height-1) * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y1); + c += image[((int)pos_y * image_width + (min((int)pos_x) + 1, image_width-1)) * 3 + k] * (weight_x1 * weight_y0); + c += image[(min((int)(pos_y + 1), image_height-1)* image_width + min(((int)pos_x) + 1), image_width-1) * 3 + k] * (weight_x1 * weight_y1); texture_[k] = c; } else { @@ -94,7 +94,7 @@ __global__ void load_textures_cuda_kernel( const int pos_xi = round(pos_x); const int pos_yi = round(pos_y); for (int k = 0; k < 3; k++) { - if (texture_wrapping != CLAMP_TO_BORDER || (pos_xi > 0 && pos_xi < image_width) || (pos_yi > 0 && pos_yi < image_height)) { + if (texture_wrapping != CLAMP_TO_BORDER) { texture_[k] = image[(pos_yi * image_width + pos_xi) * 3 + k]; } else { From a5d870cb05b6b3bf04b47d85412d4691b0a2bef0 Mon Sep 17 00:00:00 2001 From: Nikos Kolotouros Date: Mon, 20 Aug 2018 15:33:55 -0400 Subject: [PATCH 3/4] fixed compilation issues --- neural_renderer/cuda/load_textures_cuda_kernel.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/neural_renderer/cuda/load_textures_cuda_kernel.cu b/neural_renderer/cuda/load_textures_cuda_kernel.cu index 9eb7ee8..6092ad8 100644 --- a/neural_renderer/cuda/load_textures_cuda_kernel.cu +++ b/neural_renderer/cuda/load_textures_cuda_kernel.cu @@ -80,9 +80,9 @@ __global__ void load_textures_cuda_kernel( if (texture_wrapping != CLAMP_TO_BORDER) { scalar_t c = 0; c += image[((int)pos_y * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y0); - c += image[((int)min((pos_y + 1), image_height-1) * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y1); - c += image[((int)pos_y * image_width + (min((int)pos_x) + 1, image_width-1)) * 3 + k] * (weight_x1 * weight_y0); - c += image[(min((int)(pos_y + 1), image_height-1)* image_width + min(((int)pos_x) + 1), image_width-1) * 3 + k] * (weight_x1 * weight_y1); + c += image[(min((int)(pos_y + 1), image_height-1) * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y1); + c += image[((int)pos_y * image_width + (min((int)pos_x + 1, image_width-1))) * 3 + k] * (weight_x1 * weight_y0); + c += image[(min((int)(pos_y + 1), image_height-1)* image_width + min((int)pos_x + 1, image_width-1)) * 3 + k] * (weight_x1 * weight_y1); texture_[k] = c; } else { From 1bc8387354fd7b4323ed0f23d6d034b9532ef456 Mon Sep 17 00:00:00 2001 From: Nikos Kolotouros Date: Tue, 21 Aug 2018 15:39:26 -0400 Subject: [PATCH 4/4] fixed out of bounds bug --- .../cuda/load_textures_cuda_kernel.cu | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/neural_renderer/cuda/load_textures_cuda_kernel.cu b/neural_renderer/cuda/load_textures_cuda_kernel.cu index 6092ad8..848f0a9 100644 --- a/neural_renderer/cuda/load_textures_cuda_kernel.cu +++ b/neural_renderer/cuda/load_textures_cuda_kernel.cu @@ -3,6 +3,15 @@ #include #include +template +static __inline__ __device__ scalar_t mod(scalar_t x, scalar_t y) { + if (x > 0) { + return fmod(x,y); + } + else { + return y + fmod(x,y); + } +} namespace { @@ -14,8 +23,8 @@ const int CLAMP_TO_BORDER = 3; template __global__ void load_textures_cuda_kernel( const scalar_t* image, - const int32_t* __restrict__ is_update, - scalar_t* __restrict__ faces, + const int32_t* is_update, + scalar_t* faces, scalar_t* __restrict__ textures, int textures_size, int texture_size, @@ -45,18 +54,17 @@ __global__ void load_textures_cuda_kernel( if (texture_wrapping == REPEAT) { #pragma unroll for (int i = 0; i < 6; ++i) { - face[i] = fmod(face[i], (scalar_t)2.); - // face[i] = 1; + face[i] = mod(face[i], (scalar_t)1.); } } else if (texture_wrapping == MIRRORED_REPEAT) { #pragma unroll for (int i = 0; i < 6; ++i) { - if ( ((int)face[i] / 2) % 2 == 1) { - face[i] = 1 - fmod(face[i], (scalar_t)2.); + if (mod(face[i], (scalar_t)2) < 1) { + face[i] = mod(face[i], (scalar_t)1.); } else { - face[i] = fmod(face[i], (scalar_t)2.); + face[i] = 1 - mod(face[i], (scalar_t)1.); } } } @@ -79,10 +87,10 @@ __global__ void load_textures_cuda_kernel( for (int k = 0; k < 3; k++) { if (texture_wrapping != CLAMP_TO_BORDER) { scalar_t c = 0; - c += image[((int)pos_y * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y0); - c += image[(min((int)(pos_y + 1), image_height-1) * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y1); - c += image[((int)pos_y * image_width + (min((int)pos_x + 1, image_width-1))) * 3 + k] * (weight_x1 * weight_y0); - c += image[(min((int)(pos_y + 1), image_height-1)* image_width + min((int)pos_x + 1, image_width-1)) * 3 + k] * (weight_x1 * weight_y1); + c += image[(int)pos_y * image_width * 3 + (int)pos_x * 3 + k] * (weight_x0 * weight_y0); + c += image[min((int)(pos_y + 1), image_height-1) * image_width * 3 + (int)pos_x * 3 + k] * (weight_x0 * weight_y1); + c += image[(int)pos_y * image_width * 3 + min((int)pos_x + 1, image_width-1) * 3 + k] * (weight_x1 * weight_y0); + c += image[min((int)(pos_y + 1), image_height-1) * image_width * 3 + min((int)pos_x + 1, image_width-1) * 3 + k] * (weight_x1 * weight_y1); texture_[k] = c; } else { @@ -95,7 +103,7 @@ __global__ void load_textures_cuda_kernel( const int pos_yi = round(pos_y); for (int k = 0; k < 3; k++) { if (texture_wrapping != CLAMP_TO_BORDER) { - texture_[k] = image[(pos_yi * image_width + pos_xi) * 3 + k]; + texture_[k] = image[pos_yi * image_width * 3 + pos_xi * 3 + k]; } else { texture_[k] = 0;