Skip to content

Commit

Permalink
Merge pull request hiroharu-kato#14 from daniilidis-group/texture_sam…
Browse files Browse the repository at this point in the history
…pling

Add texture wrapping
  • Loading branch information
nkolot authored Aug 22, 2018
2 parents 29a7615 + 1bc8387 commit 55a05a2
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 47 deletions.
10 changes: 7 additions & 3 deletions neural_renderer/cuda/load_textures_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ at::Tensor load_textures_cuda(
at::Tensor image,
at::Tensor faces,
at::Tensor textures,
at::Tensor is_update);
at::Tensor is_update,
int texture_wrapping,
int use_bilinear);

// C++ interface

Expand All @@ -19,14 +21,16 @@ at::Tensor load_textures(
at::Tensor image,
at::Tensor faces,
at::Tensor textures,
at::Tensor is_update) {
at::Tensor is_update,
int texture_wrapping,
int use_bilinear) {

CHECK_INPUT(image);
CHECK_INPUT(faces);
CHECK_INPUT(is_update);
CHECK_INPUT(textures);

return load_textures_cuda(image, faces, textures, is_update);
return load_textures_cuda(image, faces, textures, is_update, texture_wrapping, use_bilinear);

}

Expand Down
134 changes: 95 additions & 39 deletions neural_renderer/cuda/load_textures_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,35 @@
#include <cuda.h>
#include <cuda_runtime.h>

template <typename scalar_t>
static __inline__ __device__ scalar_t mod(scalar_t x, scalar_t y) {
if (x > 0) {
return fmod(x,y);
}
else {
return y + fmod(x,y);
}
}

namespace {

const int REPEAT = 0;
const int MIRRORED_REPEAT = 1;
const int CLAMP_TO_EDGE = 2;
const int CLAMP_TO_BORDER = 3;

template <typename scalar_t>
__global__ void load_textures_cuda_kernel(
const scalar_t* __restrict__ image,
const scalar_t* __restrict__ faces,
const int32_t* __restrict__ is_update,
const scalar_t* image,
const int32_t* is_update,
scalar_t* faces,
scalar_t* __restrict__ textures,
size_t textures_size,
size_t texture_size,
size_t image_height,
size_t image_width) {
int textures_size,
int texture_size,
int image_height,
int image_width,
int texture_wrapping,
bool use_bilinear) {
const int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i >= textures_size / 3) {
return;
Expand All @@ -29,35 +47,69 @@ __global__ void load_textures_cuda_kernel(
dim1 /= sum;
dim2 /= sum;
}
const scalar_t* face = &faces[fn * 3 * 2];
scalar_t* texture = &textures[i * 3];
if (is_update[fn] == 0) return;

const scalar_t pos_x = (
(face[2 * 0 + 0] * dim0 + face[2 * 1 + 0] * dim1 + face[2 * 2 + 0] * dim2) * (image_width - 1));
const scalar_t pos_y = (
(face[2 * 0 + 1] * dim0 + face[2 * 1 + 1] * dim1 + face[2 * 2 + 1] * dim2) * (image_height - 1));
if (1) {
/* bilinear sampling */
const scalar_t weight_x1 = pos_x - (int)pos_x;
const scalar_t weight_x0 = 1 - weight_x1;
const scalar_t weight_y1 = pos_y - (int)pos_y;
const scalar_t weight_y0 = 1 - weight_y1;
for (int k = 0; k < 3; k++) {
scalar_t c = 0;
c += image[((int)pos_y * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y0);
c += image[((int)(pos_y + 1) * image_width + (int)pos_x) * 3 + k] * (weight_x0 * weight_y1);
c += image[((int)pos_y * image_width + ((int)pos_x) + 1) * 3 + k] * (weight_x1 * weight_y0);
c += image[((int)(pos_y + 1)* image_width + ((int)pos_x) + 1) * 3 + k] * (weight_x1 * weight_y1);
texture[k] = c;
}
} else {
/* nearest neighbor */
const int pos_xi = round(pos_x);
const int pos_yi = round(pos_y);
for (int k = 0; k < 3; k++) {
texture[k] = image[(pos_yi * image_width + pos_xi) * 3 + k];
}
scalar_t* face = &faces[fn * 3 * 2];
scalar_t* texture_ = &textures[i * 3];

if (is_update[fn] != 0) {
if (texture_wrapping == REPEAT) {
#pragma unroll
for (int i = 0; i < 6; ++i) {
face[i] = mod(face[i], (scalar_t)1.);
}
}
else if (texture_wrapping == MIRRORED_REPEAT) {
#pragma unroll
for (int i = 0; i < 6; ++i) {
if (mod(face[i], (scalar_t)2) < 1) {
face[i] = mod(face[i], (scalar_t)1.);
}
else {
face[i] = 1 - mod(face[i], (scalar_t)1.);
}
}
}
else if (texture_wrapping == CLAMP_TO_EDGE) {
#pragma unroll
for (int i = 0; i < 6; ++i) {
face[i] = max(min(face[i], (scalar_t) 1), (scalar_t) 0);
}
}
const scalar_t pos_x = (
(face[2 * 0 + 0] * dim0 + face[2 * 1 + 0] * dim1 + face[2 * 2 + 0] * dim2) * (image_width - 1));
const scalar_t pos_y = (
(face[2 * 0 + 1] * dim0 + face[2 * 1 + 1] * dim1 + face[2 * 2 + 1] * dim2) * (image_height - 1));
if (use_bilinear) {
/* bilinear sampling */
const scalar_t weight_x1 = pos_x - (int)pos_x;
const scalar_t weight_x0 = 1 - weight_x1;
const scalar_t weight_y1 = pos_y - (int)pos_y;
const scalar_t weight_y0 = 1 - weight_y1;
for (int k = 0; k < 3; k++) {
if (texture_wrapping != CLAMP_TO_BORDER) {
scalar_t c = 0;
c += image[(int)pos_y * image_width * 3 + (int)pos_x * 3 + k] * (weight_x0 * weight_y0);
c += image[min((int)(pos_y + 1), image_height-1) * image_width * 3 + (int)pos_x * 3 + k] * (weight_x0 * weight_y1);
c += image[(int)pos_y * image_width * 3 + min((int)pos_x + 1, image_width-1) * 3 + k] * (weight_x1 * weight_y0);
c += image[min((int)(pos_y + 1), image_height-1) * image_width * 3 + min((int)pos_x + 1, image_width-1) * 3 + k] * (weight_x1 * weight_y1);
texture_[k] = c;
}
else {
texture_[k] = 0;
}
}
} else {
/* nearest neighbor */
const int pos_xi = round(pos_x);
const int pos_yi = round(pos_y);
for (int k = 0; k < 3; k++) {
if (texture_wrapping != CLAMP_TO_BORDER) {
texture_[k] = image[pos_yi * image_width * 3 + pos_xi * 3 + k];
}
else {
texture_[k] = 0;
}
}
}
}
}
}
Expand All @@ -66,7 +118,9 @@ at::Tensor load_textures_cuda(
at::Tensor image,
at::Tensor faces,
at::Tensor textures,
at::Tensor is_update) {
at::Tensor is_update,
int texture_wrapping,
int use_bilinear) {
// textures_size = size of the textures tensor
const auto textures_size = textures.numel();
// notice that texture_size != texture_size
Expand All @@ -80,13 +134,15 @@ at::Tensor load_textures_cuda(
AT_DISPATCH_FLOATING_TYPES(image.type(), "load_textures_cuda", ([&] {
load_textures_cuda_kernel<scalar_t><<<blocks, threads>>>(
image.data<scalar_t>(),
faces.data<scalar_t>(),
is_update.data<int32_t>(),
faces.data<scalar_t>(),
textures.data<scalar_t>(),
textures_size,
texture_size,
image_height,
image_width);
image_width,
texture_wrapping,
use_bilinear);
}));

cudaError_t err = cudaGetLastError();
Expand Down
17 changes: 12 additions & 5 deletions neural_renderer/load_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@

import neural_renderer.cuda.load_textures as load_textures_cuda

texture_wrapping_dict = {'REPEAT': 0, 'MIRRORED_REPEAT': 1,
'CLAMP_TO_EDGE': 2, 'CLAMP_TO_BORDER': 3}

def load_mtl(filename_mtl):
'''
load color (Kd) and filename of textures from *.mtl
Expand All @@ -26,7 +29,7 @@ def load_mtl(filename_mtl):
return colors, texture_filenames


def load_textures(filename_obj, filename_mtl, texture_size):
def load_textures(filename_obj, filename_mtl, texture_size, texture_wrapping='REPEAT', use_bilinear=True):
# load vertices
vertices = []
with open(filename_obj) as f:
Expand Down Expand Up @@ -68,7 +71,6 @@ def load_textures(filename_obj, filename_mtl, texture_size):
faces = np.vstack(faces).astype(np.int32) - 1
faces = vertices[faces]
faces = torch.from_numpy(faces).cuda()
faces[1 < faces] = faces[1 < faces] % 1

colors, texture_filenames = load_mtl(filename_mtl)

Expand Down Expand Up @@ -98,10 +100,13 @@ def load_textures(filename_obj, filename_mtl, texture_size):
image = torch.from_numpy(image.copy()).cuda()
is_update = (np.array(material_names) == material_name).astype(np.int32)
is_update = torch.from_numpy(is_update).cuda()
textures = load_textures_cuda.load_textures(image, faces, textures, is_update)
textures = load_textures_cuda.load_textures(image, faces, textures, is_update,
texture_wrapping_dict[texture_wrapping],
use_bilinear)
return textures

def load_obj(filename_obj, normalization=True, texture_size=4, load_texture=False):
def load_obj(filename_obj, normalization=True, texture_size=4, load_texture=False,
texture_wrapping='REPEAT', use_bilinear=True):
"""
Load Wavefront .obj file.
This function only supports vertices (v x x x) and faces (f x x x).
Expand Down Expand Up @@ -140,7 +145,9 @@ def load_obj(filename_obj, normalization=True, texture_size=4, load_texture=Fals
for line in lines:
if line.startswith('mtllib'):
filename_mtl = os.path.join(os.path.dirname(filename_obj), line.split()[1])
textures = load_textures(filename_obj, filename_mtl, texture_size)
textures = load_textures(filename_obj, filename_mtl, texture_size,
texture_wrapping=texture_wrapping,
use_bilinear=use_bilinear)
if textures is None:
raise Exception('Failed to load textures.')

Expand Down

0 comments on commit 55a05a2

Please sign in to comment.