Skip to content

Commit

Permalink
compile for different architectures
Browse files Browse the repository at this point in the history
  • Loading branch information
nkolot committed Jul 13, 2018
1 parent 6aded30 commit ffc2762
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 25 deletions.
4 changes: 2 additions & 2 deletions neural_renderer/cuda/create_texture_image_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ at::Tensor create_texture_image_cuda(
const auto tile_width = int(sqrt(num_faces - 1)) + 1;
const auto texture_size_out = image.size(1) / tile_width;

const int threads = 1024;
const int threads = 128;
const int image_size = image.numel();
const int blocks = (image_size / 3 - 1) / threads + 1;
const dim3 blocks ((image_size / 3 - 1) / threads + 1, 1, 1);

AT_DISPATCH_FLOATING_TYPES(image.type(), "create_texture_image_cuda", ([&] {
create_texture_image_cuda_kernel<scalar_t><<<blocks, threads>>>(
Expand Down
2 changes: 1 addition & 1 deletion neural_renderer/cuda/load_textures_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ at::Tensor load_textures_cuda(
const auto image_width = image.size(1);

const int threads = 1024;
const int blocks = (textures_size / 3 - 1) / threads + 1;
const dim3 blocks ((textures_size / 3 - 1) / threads + 1);

AT_DISPATCH_FLOATING_TYPES(image.type(), "load_textures_cuda", ([&] {
load_textures_cuda_kernel<scalar_t><<<blocks, threads>>>(
Expand Down
38 changes: 19 additions & 19 deletions neural_renderer/cuda/rasterize_cuda_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ __global__ void forward_face_index_map_cuda_kernel(
/* lock and update */
bool locked = false;
do {
if (locked = atomicCAS(&lock[index], 0, 1) == 0) {
if (zp < depth_map[index]) {
if ( locked = atomicCAS(&lock[index], 0, 1) == 0) {
if (zp <= depth_map[index]) {
depth_map[index] = zp;
face_index_map[index] = fn;
for (int k = 0; k < 3; k++)
Expand Down Expand Up @@ -566,13 +566,13 @@ __global__ void backward_textures_cuda_kernel(

template <typename scalar_t>
__global__ void backward_depth_map_cuda_kernel(
const scalar_t* __restrict__ faces,
const scalar_t* __restrict__ depth_map,
const int32_t* __restrict__ face_index_map,
const scalar_t* __restrict__ face_inv_map,
const scalar_t* __restrict__ weight_map,
scalar_t* __restrict__ grad_depth_map,
scalar_t* __restrict__ grad_faces,
const scalar_t* faces,
const scalar_t* depth_map,
const int32_t* face_index_map,
const scalar_t* face_inv_map,
const scalar_t* weight_map,
scalar_t* grad_depth_map,
scalar_t* grad_faces,
size_t batch_size,
size_t num_faces,
int image_size) {
Expand Down Expand Up @@ -633,8 +633,8 @@ std::vector<at::Tensor> forward_face_index_map_cuda(

const auto batch_size = faces.size(0);
const auto num_faces = faces.size(1);
const int threads = 1024;
const int blocks = (batch_size * num_faces - 1) / threads +1;
const int threads = 512;
const dim3 blocks ((batch_size * num_faces - 1) / threads +1);

AT_DISPATCH_FLOATING_TYPES(faces.type(), "forward_face_index_map_cuda", ([&] {
forward_face_index_map_cuda_kernel<scalar_t><<<blocks, threads>>>(
Expand Down Expand Up @@ -675,8 +675,8 @@ std::vector<at::Tensor> forward_texture_sampling_cuda(
const auto batch_size = faces.size(0);
const auto num_faces = faces.size(1);
const auto texture_size = textures.size(2);
const int threads = 1024;
const int blocks = (batch_size * image_size * image_size - 1) / threads + 1;
const int threads = 512;
const dim3 blocks ((batch_size * image_size * image_size - 1) / threads + 1);

AT_DISPATCH_FLOATING_TYPES(faces.type(), "forward_texture_sampling_cuda", ([&] {
forward_texture_sampling_cuda_kernel<scalar_t><<<blocks, threads>>>(
Expand Down Expand Up @@ -717,8 +717,8 @@ at::Tensor backward_pixel_map_cuda(

const auto batch_size = faces.size(0);
const auto num_faces = faces.size(1);
const int threads = 1024;
const int blocks = (batch_size * num_faces - 1) / threads + 1;
const int threads = 512;
const dim3 blocks ((batch_size * num_faces - 1) / threads + 1);

AT_DISPATCH_FLOATING_TYPES(faces.type(), "backward_pixel_map_cuda", ([&] {
backward_pixel_map_cuda_kernel<scalar_t><<<blocks, threads>>>(
Expand Down Expand Up @@ -755,8 +755,8 @@ at::Tensor backward_textures_cuda(
const auto batch_size = face_index_map.size(0);
const auto image_size = face_index_map.size(1);
const auto texture_size = grad_textures.size(2);
const int threads = 1024;
const int blocks = (batch_size * image_size * image_size - 1) / threads + 1;
const int threads = 512;
const dim3 blocks ((batch_size * image_size * image_size - 1) / threads + 1);

AT_DISPATCH_FLOATING_TYPES(sampling_weight_map.type(), "backward_textures_cuda", ([&] {
backward_textures_cuda_kernel<scalar_t><<<blocks, threads>>>(
Expand Down Expand Up @@ -789,8 +789,8 @@ at::Tensor backward_depth_map_cuda(

const auto batch_size = faces.size(0);
const auto num_faces = faces.size(1);
const int threads = 1024;
const int blocks = (batch_size * image_size * image_size - 1) / threads + 1;
const int threads = 512;
const dim3 blocks ((batch_size * image_size * image_size - 1) / threads + 1);

AT_DISPATCH_FLOATING_TYPES(faces.type(), "backward_depth_map_cuda", ([&] {
backward_depth_map_cuda_kernel<scalar_t><<<blocks, threads>>>(
Expand Down
18 changes: 15 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

# nvcc code-generation flags, passed to each CUDAExtension through
# extra_compile_args['nvcc'] so the kernels are compiled for several GPU
# architectures rather than a single one.

# One `code=sm_XX` entry per targeted architecture.
_SM_GENCODE_FLAGS = [
    '-gencode=arch=compute_30,code=sm_30',
    '-gencode=arch=compute_35,code=sm_35',
    '-gencode=arch=compute_50,code=sm_50',
    '-gencode=arch=compute_52,code=sm_52',
    '-gencode=arch=compute_60,code=sm_60',
    '-gencode=arch=compute_61,code=sm_61',
    '-gencode=arch=compute_70,code=sm_70',
]

# `code=compute_70` (rather than `sm_70`) requests the compute_70 virtual
# architecture itself; per the nvcc documentation this embeds PTX that the
# driver can JIT-compile on GPUs newer than those listed above.
_PTX_GENCODE_FLAGS = [
    '-gencode=arch=compute_70,code=compute_70',
]

CUDA_FLAGS = _SM_GENCODE_FLAGS + _PTX_GENCODE_FLAGS

def test_all():
test_loader = unittest.TestLoader()
test_suite = test_loader.discover('tests', pattern='test_*.py')
Expand All @@ -13,15 +22,18 @@ def test_all():
CUDAExtension('neural_renderer.cuda.load_textures', [
'neural_renderer/cuda/load_textures_cuda.cpp',
'neural_renderer/cuda/load_textures_cuda_kernel.cu',
]),
],
extra_compile_args={'cxx': [], 'nvcc': CUDA_FLAGS}),
CUDAExtension('neural_renderer.cuda.rasterize', [
'neural_renderer/cuda/rasterize_cuda.cpp',
'neural_renderer/cuda/rasterize_cuda_kernel.cu',
]),
],
extra_compile_args={'cxx': [], 'nvcc': CUDA_FLAGS}),
CUDAExtension('neural_renderer.cuda.create_texture_image', [
'neural_renderer/cuda/create_texture_image_cuda.cpp',
'neural_renderer/cuda/create_texture_image_cuda_kernel.cu',
]),
],
extra_compile_args={'cxx': [], 'nvcc': CUDA_FLAGS}),
]

INSTALL_REQUIREMENTS = ['numpy', 'torch', 'torchvision', 'scikit-image', 'tqdm', 'imageio']
Expand Down

0 comments on commit ffc2762

Please sign in to comment.