diff --git a/TensoRF/LICENSE b/TensoRF/LICENSE
new file mode 100644
index 0000000..3eac45c
--- /dev/null
+++ b/TensoRF/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2022 Anpei Chen
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/TensoRF/README.md b/TensoRF/README.md
new file mode 100644
index 0000000..c4f8369
--- /dev/null
+++ b/TensoRF/README.md
@@ -0,0 +1,97 @@
+# TensoRF
+## [Project page](https://apchenstu.github.io/TensoRF/) | [Paper](https://arxiv.org/abs/2203.09517)
+This repository contains a PyTorch implementation of the paper [TensoRF: Tensorial Radiance Fields](https://arxiv.org/abs/2203.09517). Our work presents a novel approach to modeling and reconstructing radiance fields that achieves
+**fast** training, a **compact** memory footprint, and **state-of-the-art** rendering quality.
+
+
+https://user-images.githubusercontent.com/16453770/158920837-3fafaa17-6ed9-4414-a0b1-a80dc9e10301.mp4
+## Installation
+
+#### Tested on Ubuntu 20.04 + PyTorch 1.10.1
+
+Install environment:
+```
+conda create -n TensoRF python=3.8
+conda activate TensoRF
+pip install torch torchvision
+pip install tqdm scikit-image opencv-python configargparse lpips imageio-ffmpeg kornia tensorboard
+```
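+
+To quickly verify the environment (optional sanity check):
+```
+python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+```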
+
+
+## Dataset
+* [Synthetic-NeRF](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
+* [Synthetic-NSVF](https://dl.fbaipublicfiles.com/nsvf/dataset/Synthetic_NSVF.zip)
+* [Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip)
+* [Forward-facing](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)
+
+
+
+## Quick Start
+The training script is `train.py`. To train a TensoRF model:
+
+```
+python train.py --config configs/lego.txt
+```
+
+
+We provide a few example configurations in the `configs` folder; please note:
+
+ `dataset_name`, choices = ['blender', 'llff', 'nsvf', 'tankstemple'];
+
+ `shadingMode`, choices = ['MLP_Fea', 'SH'];
+
+ `model_name`, choices = ['TensorVMSplit', 'TensorCP'], corresponding to the VM and CP decompositions.
+ Uncomment the last few lines of the configuration file if you want to train with the TensorCP model (see the example snippet after this list);
+
+ `n_lamb_sigma` and `n_lamb_sh` specify the number of density and appearance basis components along the XYZ
+dimensions;
+
+ `N_voxel_init` and `N_voxel_final` control the initial and final resolutions of the matrix and vector factors;
+
+ `N_vis` and `vis_every` control the visualization during training;
+
+ Set `--render_test 1` / `--render_path 1` if you want to render the testing views or a render path after training.
+
+For more options, refer to `opt.py`.
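+
+For example, switching a config to the CP decomposition amounts to setting the following options, which simply mirrors the commented-out block at the bottom of the provided config files:
+```
+model_name = TensorCP
+n_lamb_sigma = [96]
+n_lamb_sh = [288]
+N_voxel_final = 125000000 # 500**3
+L1_weight_inital = 1e-5
+L1_weight_rest = 1e-5
+```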
+
+### For pretrained checkpoints and results please see:
+[https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm](https://1drv.ms/u/s!Ard0t_p4QWIMgQ2qSEAs7MUk8hVw?e=dc6hBm)
+
+
+
+## Rendering
+
+```
+python train.py --config configs/lego.txt --ckpt path/to/your/checkpoint --render_only 1 --render_test 1
+```
+
+Pass `--render_only 1` and `--ckpt path/to/your/checkpoint` to render images from a pre-trained
+checkpoint. You may also need to specify what to render with `--render_test 1`, `--render_train 1`, or `--render_path 1`.
+The rendering results are saved in your checkpoint folder.
+
+## Extracting mesh
+You can also export the mesh by passing `--export_mesh 1`:
+```
+python train.py --config configs/lego.txt --ckpt path/to/your/checkpoint --export_mesh 1
+```
+Note: please re-train the model for mesh extraction instead of using the pretrained checkpoints provided above,
+because some rendering parameters have changed.
+
+## Training with your own data
+We provide two options for training on your own image set:
+
+1. Follow the instructions in the [NSVF repo](https://github.com/facebookresearch/NSVF#prepare-your-own-dataset), then set `dataset_name` to 'tankstemple'.
+2. Calibrate the images with the script from [NGP](https://github.com/NVlabs/instant-ngp/blob/master/docs/nerf_dataset_tips.md):
+`python dataLoader/colmap2nerf.py --colmap_matcher exhaustive --run_colmap`, then adjust `datadir` in `configs/your_own_data.txt`. Please check `scene_bbox` and `near_far` if you get abnormal results.
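+
+Putting option 2 together, a run could look like the following (the `./data/xxx` paths are placeholders matching `configs/your_own_data.txt`; adjust `--out` to wherever your config's `datadir` expects the transforms file):
+```
+python dataLoader/colmap2nerf.py --images ./data/xxx/images --out ./data/xxx/transforms.json --colmap_matcher exhaustive --run_colmap
+python train.py --config configs/your_own_data.txt
+```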
+
+
+## Citation
+If you find our code or paper helpful, please consider citing:
+```
+@article{tensorf,
+ title={TensoRF: Tensorial Radiance Fields},
+ author={Chen, Anpei and Xu, Zexiang and Geiger, Andreas and Yu, Jingyi and Su, Hao},
+ journal={arXiv preprint arXiv:2203.09517},
+ year={2022}
+}
+```
diff --git a/TensoRF/configs/chair.txt b/TensoRF/configs/chair.txt
new file mode 100644
index 0000000..4eaf055
--- /dev/null
+++ b/TensoRF/configs/chair.txt
@@ -0,0 +1,44 @@
+
+dataset_name = blender
+datadir = ../nerf_synthetic/chair
+expname = tensorf_chair_VM
+basedir = ./log
+
+n_iters = 30000
+batch_size = 4096
+
+N_voxel_init = 2097156 # 128**3
+N_voxel_final = 27000000 # 300**3
+upsamp_list = [2000, 3000, 4000, 5500, 7000]
+update_AlphaMask_list = [2000, 4000]
+
+N_vis = 5
+vis_every = 10000
+
+# lr_init = 0.005 # 0.001 # 0.5 # 0.02 # test
+# lr_basis = 0.005 # 0.001 # 0.02 # 0.001 # test
+
+render_test = 1
+
+n_lamb_sigma = [16, 16, 16]
+n_lamb_sh = [48, 48, 48]
+model_name = TensorVMSplit
+
+shadingMode = MLP_Fea
+fea2denseAct = softplus
+
+view_pe = 2
+fea_pe = 2
+
+L1_weight_inital = 0 # 8e-5
+L1_weight_rest = 0 # 4e-5
+rm_weight_mask_thre = 1e-4
+
+## please uncomment the following configuration to train with the CP model
+#model_name = TensorCP
+#n_lamb_sigma = [96]
+#n_lamb_sh = [288]
+#N_voxel_final = 125000000 # 500**3
+#L1_weight_inital = 1e-5
+#L1_weight_rest = 1e-5
+
diff --git a/TensoRF/configs/flower.txt b/TensoRF/configs/flower.txt
new file mode 100644
index 0000000..3a1c8ee
--- /dev/null
+++ b/TensoRF/configs/flower.txt
@@ -0,0 +1,35 @@
+
+dataset_name = llff
+datadir = ./data/nerf_llff_data/flower
+expname = tensorf_flower_VM
+basedir = ./log
+
+downsample_train = 4.0
+ndc_ray = 1
+
+n_iters = 25000
+batch_size = 4096
+
+N_voxel_init = 2097156 # 128**3
+N_voxel_final = 262144000 # 640**3
+upsamp_list = [2000,3000,4000,5500]
+update_AlphaMask_list = [2500]
+
+N_vis = -1 # vis all testing images
+vis_every = 10000
+
+render_test = 1
+render_path = 1
+
+n_lamb_sigma = [16,4,4]
+n_lamb_sh = [48,12,12]
+
+shadingMode = MLP_Fea
+fea2denseAct = relu
+
+view_pe = 0
+fea_pe = 0
+
+TV_weight_density = 1.0
+TV_weight_app = 1.0
+
diff --git a/TensoRF/configs/lego.txt b/TensoRF/configs/lego.txt
new file mode 100644
index 0000000..d768244
--- /dev/null
+++ b/TensoRF/configs/lego.txt
@@ -0,0 +1,41 @@
+
+dataset_name = blender
+datadir = ./data/nerf_synthetic/lego
+expname = tensorf_lego_VM
+basedir = ./log
+
+n_iters = 30000
+batch_size = 4096
+
+N_voxel_init = 2097156 # 128**3
+N_voxel_final = 27000000 # 300**3
+upsamp_list = [2000,3000,4000,5500,7000]
+update_AlphaMask_list = [2000,4000]
+
+N_vis = 5
+vis_every = 10000
+
+render_test = 1
+
+n_lamb_sigma = [16,16,16]
+n_lamb_sh = [48,48,48]
+model_name = TensorVMSplit
+
+
+shadingMode = MLP_Fea
+fea2denseAct = softplus
+
+view_pe = 2
+fea_pe = 2
+
+L1_weight_inital = 8e-5
+L1_weight_rest = 4e-5
+rm_weight_mask_thre = 1e-4
+
+## please uncomment the following configuration to train with the CP model
+#model_name = TensorCP
+#n_lamb_sigma = [96]
+#n_lamb_sh = [288]
+#N_voxel_final = 125000000 # 500**3
+#L1_weight_inital = 1e-5
+#L1_weight_rest = 1e-5
diff --git a/TensoRF/configs/lego2.txt b/TensoRF/configs/lego2.txt
new file mode 100644
index 0000000..331963e
--- /dev/null
+++ b/TensoRF/configs/lego2.txt
@@ -0,0 +1,39 @@
+dataset_name = blender
+datadir = ../nerf_synthetic/lego
+expname = tensorf_lego_VM
+basedir = ./log
+
+n_iters = 30000
+batch_size = 4096
+
+N_voxel_init = 2097156 # 128**3
+N_voxel_final = 27000000 # 300**3
+upsamp_list = [2000,3000,4000,5500,7000]
+update_AlphaMask_list = [2000,4000]
+
+N_vis = 5
+vis_every = 10000
+
+render_test = 1
+
+n_lamb_sigma = [16,16,16]
+n_lamb_sh = [48,48,48]
+model_name = PREF
+
+shadingMode = MLP_Fea
+fea2denseAct = softplus
+
+view_pe = 2
+fea_pe = 2
+
+L1_weight_inital = 8e-5
+L1_weight_rest = 4e-5
+rm_weight_mask_thre = 1e-4
+
+## please uncomment the following configuration to train with the CP model
+#model_name = TensorCP
+#n_lamb_sigma = [96]
+#n_lamb_sh = [288]
+#N_voxel_final = 125000000 # 500**3
+#L1_weight_inital = 1e-5
+#L1_weight_rest = 1e-5
diff --git a/TensoRF/configs/truck.txt b/TensoRF/configs/truck.txt
new file mode 100644
index 0000000..6a4545b
--- /dev/null
+++ b/TensoRF/configs/truck.txt
@@ -0,0 +1,40 @@
+
+
+dataset_name = tankstemple
+datadir = ./data/TanksAndTemple/Truck
+expname = tensorf_truck_VM
+basedir = ./log
+
+n_iters = 30000
+batch_size = 4096
+
+N_voxel_init = 2097156 # 128**3
+N_voxel_final = 27000000 # 300**3
+upsamp_list = [2000,3000,4000,5500,7000]
+update_AlphaMask_list = [2000,4000]
+
+N_vis = 5
+vis_every = 10000
+
+render_test = 1
+
+n_lamb_sigma = [16,16,16]
+n_lamb_sh = [48,48,48]
+
+shadingMode = MLP_Fea
+fea2denseAct = softplus
+
+view_pe = 2
+fea_pe = 2
+
+TV_weight_density = 0.1
+TV_weight_app = 0.01
+
+## please uncomment the following configuration to train with the CP model
+#model_name = TensorCP
+#n_lamb_sigma = [96]
+#n_lamb_sh = [288]
+#N_voxel_final = 125000000 # 500**3
+#L1_weight_inital = 1e-5
+#L1_weight_rest = 1e-5
+
diff --git a/TensoRF/configs/wineholder.txt b/TensoRF/configs/wineholder.txt
new file mode 100644
index 0000000..4b945ea
--- /dev/null
+++ b/TensoRF/configs/wineholder.txt
@@ -0,0 +1,39 @@
+
+dataset_name = nsvf
+datadir = ./data/Synthetic_NSVF/Wineholder
+expname = tensorf_Wineholder_VM
+basedir = ./log
+
+n_iters = 30000
+batch_size = 4096
+
+N_voxel_init = 2097156 # 128**3
+N_voxel_final = 27000000 # 300**3
+upsamp_list = [2000,3000,4000,5500,7000]
+update_AlphaMask_list = [2000,4000]
+
+N_vis = 5
+vis_every = 10000
+
+render_test = 1
+
+n_lamb_sigma = [16,16,16]
+n_lamb_sh = [48,48,48]
+
+shadingMode = MLP_Fea
+fea2denseAct = softplus
+
+view_pe = 2
+fea_pe = 2
+
+L1_weight_inital = 8e-5
+L1_weight_rest = 4e-5
+rm_weight_mask_thre = 1e-4
+
+## please uncomment the following configuration to train with the CP model
+#model_name = TensorCP
+#n_lamb_sigma = [96]
+#n_lamb_sh = [288]
+#N_voxel_final = 125000000 # 500**3
+#L1_weight_inital = 1e-5
+#L1_weight_rest = 1e-5
diff --git a/TensoRF/configs/your_own_data.txt b/TensoRF/configs/your_own_data.txt
new file mode 100644
index 0000000..6d3b0a2
--- /dev/null
+++ b/TensoRF/configs/your_own_data.txt
@@ -0,0 +1,45 @@
+
+dataset_name = own_data
+datadir = ./data/xxx
+expname = tensorf_xxx_VM
+basedir = ./log
+
+n_iters = 30000
+batch_size = 4096
+
+N_voxel_init = 2097156 # 128**3
+N_voxel_final = 27000000 # 300**3
+upsamp_list = [2000,3000,4000,5500,7000]
+update_AlphaMask_list = [2000,4000]
+
+N_vis = 5
+vis_every = 10000
+
+render_test = 1
+
+n_lamb_sigma = [16,16,16]
+n_lamb_sh = [48,48,48]
+model_name = TensorVMSplit
+
+
+shadingMode = MLP_Fea
+fea2denseAct = softplus
+
+view_pe = 2
+fea_pe = 2
+
+TV_weight_density = 0.1
+TV_weight_app = 0.01
+
+rm_weight_mask_thre = 1e-4
+
+## please uncomment the following configuration to train with the CP model
+#model_name = TensorCP
+#n_lamb_sigma = [96]
+#n_lamb_sh = [288]
+#N_voxel_final = 125000000 # 500**3
+#L1_weight_inital = 1e-5
+#L1_weight_rest = 1e-5
diff --git a/TensoRF/dataLoader/__init__.py b/TensoRF/dataLoader/__init__.py
new file mode 100644
index 0000000..62d441b
--- /dev/null
+++ b/TensoRF/dataLoader/__init__.py
@@ -0,0 +1,13 @@
+from .llff import LLFFDataset
+from .blender import BlenderDataset
+from .nsvf import NSVF
+from .tankstemple import TanksTempleDataset
+from .your_own_data import YourOwnDataset
+
+
+
+dataset_dict = {'blender': BlenderDataset,
+ 'llff':LLFFDataset,
+ 'tankstemple':TanksTempleDataset,
+ 'nsvf':NSVF,
+ 'own_data':YourOwnDataset}
\ No newline at end of file
diff --git a/TensoRF/dataLoader/blender.py b/TensoRF/dataLoader/blender.py
new file mode 100644
index 0000000..630ecd0
--- /dev/null
+++ b/TensoRF/dataLoader/blender.py
@@ -0,0 +1,127 @@
+import torch,cv2
+from torch.utils.data import Dataset
+import json
+from tqdm import tqdm
+import os
+from PIL import Image
+from torchvision import transforms as T
+
+
+from .ray_utils import *
+
+
+class BlenderDataset(Dataset):
+ def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
+
+ self.N_vis = N_vis
+ self.root_dir = datadir
+ self.split = split
+ self.is_stack = is_stack
+ self.img_wh = (int(800/downsample),int(800/downsample))
+ self.define_transforms()
+
+ self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
+ self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+ self.read_meta()
+ self.define_proj_mat()
+
+ self.white_bg = True
+ self.near_far = [2.0,6.0]
+
+ self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
+ self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
+ self.downsample=downsample
+
+ def read_depth(self, filename):
+ depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
+ return depth
+
+ def read_meta(self):
+
+ with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
+ self.meta = json.load(f)
+
+ w, h = self.img_wh
+ self.focal = 0.5 * 800 / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
+ self.focal *= self.img_wh[0] / 800 # modify focal length to match size self.img_wh
+
+
+ # ray directions for all pixels, same for all images (same H, W, focal)
+ self.directions = get_ray_directions(h, w, [self.focal,self.focal]) # (h, w, 3)
+ self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
+ self.intrinsics = torch.tensor([[self.focal,0,w/2],[0,self.focal,h/2],[0,0,1]]).float()
+
+ self.image_paths = []
+ self.poses = []
+ self.all_rays = []
+ self.all_rgbs = []
+ self.all_masks = []
+ self.all_depth = []
+ self.downsample=1.0
+
+ img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
+ idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
+ for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#
+
+ frame = self.meta['frames'][i]
+ pose = np.array(frame['transform_matrix']) @ self.blender2opencv
+ c2w = torch.FloatTensor(pose)
+ self.poses += [c2w]
+
+ image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
+ self.image_paths += [image_path]
+ img = Image.open(image_path)
+
+ if self.downsample!=1.0:
+ img = img.resize(self.img_wh, Image.LANCZOS)
+ img = self.transform(img) # (4, h, w)
+ img = img.view(4, -1).permute(1, 0) # (h*w, 4) RGBA
+ img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
+ self.all_rgbs += [img]
+
+
+ rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
+ self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
+
+
+ self.poses = torch.stack(self.poses)
+ if not self.is_stack:
+ self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
+ self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
+
+# self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames])*h*w, 3)
+ else:
+ self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
+ self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
+ # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
+
+
+ def define_transforms(self):
+ self.transform = T.ToTensor()
+
+ def define_proj_mat(self):
+ self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]
+
+ def world2ndc(self,points,lindisp=None):
+ device = points.device
+ return (points - self.center.to(device)) / self.radius.to(device)
+
+ def __len__(self):
+ return len(self.all_rgbs)
+
+ def __getitem__(self, idx):
+
+ if self.split == 'train': # use data in the buffers
+ sample = {'rays': self.all_rays[idx],
+ 'rgbs': self.all_rgbs[idx]}
+
+ else: # create data for each image separately
+
+ img = self.all_rgbs[idx]
+ rays = self.all_rays[idx]
+ mask = self.all_masks[idx] # for quantity evaluation
+
+ sample = {'rays': rays,
+ 'rgbs': img,
+ 'mask': mask}
+ return sample
diff --git a/TensoRF/dataLoader/colmap2nerf.py b/TensoRF/dataLoader/colmap2nerf.py
new file mode 100644
index 0000000..b91bbf0
--- /dev/null
+++ b/TensoRF/dataLoader/colmap2nerf.py
@@ -0,0 +1,305 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
+#
+# NVIDIA CORPORATION and its licensors retain all intellectual property
+# and proprietary rights in and to this software, related documentation
+# and any modifications thereto. Any use, reproduction, disclosure or
+# distribution of this software and related documentation without an express
+# license agreement from NVIDIA CORPORATION is strictly prohibited.
+
+import argparse
+import os
+from pathlib import Path, PurePosixPath
+
+import numpy as np
+import json
+import sys
+import math
+import cv2
+import shutil
+
+def parse_args():
+ parser = argparse.ArgumentParser(description="convert a text colmap export to nerf format transforms.json; optionally convert video to images, and optionally run colmap in the first place")
+
+ parser.add_argument("--video_in", default="", help="run ffmpeg first to convert a provided video file into a set of images. uses the video_fps parameter also")
+ parser.add_argument("--video_fps", default=2)
+ parser.add_argument("--time_slice", default="", help="time (in seconds) in the format t1,t2 within which the images should be generated from the video. eg: \"--time_slice '10,300'\" will generate images only from 10th second to 300th second of the video")
+ parser.add_argument("--run_colmap", action="store_true", help="run colmap first on the image folder")
+ parser.add_argument("--colmap_matcher", default="sequential", choices=["exhaustive","sequential","spatial","transitive","vocab_tree"], help="select which matcher colmap should use. sequential for videos, exhaustive for adhoc images")
+ parser.add_argument("--colmap_db", default="colmap.db", help="colmap database filename")
+ parser.add_argument("--images", default="images", help="input path to the images")
+ parser.add_argument("--text", default="colmap_text", help="input path to the colmap text files (set automatically if run_colmap is used)")
+ parser.add_argument("--aabb_scale", default=16, choices=["1","2","4","8","16"], help="large scene scale factor. 1=scene fits in unit cube; power of 2 up to 16")
+ parser.add_argument("--skip_early", default=0, help="skip this many images from the start")
+ parser.add_argument("--out", default="transforms.json", help="output path")
+ args = parser.parse_args()
+ return args
+
+def do_system(arg):
+ print(f"==== running: {arg}")
+ err = os.system(arg)
+ if err:
+ print("FATAL: command failed")
+ sys.exit(err)
+
+def run_ffmpeg(args):
+ if not os.path.isabs(args.images):
+ args.images = os.path.join(os.path.dirname(args.video_in), args.images)
+ images = args.images
+ video = args.video_in
+ fps = float(args.video_fps) or 1.0
+ print(f"running ffmpeg with input video file={video}, output image folder={images}, fps={fps}.")
+ if (input(f"warning! folder '{images}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
+ sys.exit(1)
+ try:
+ shutil.rmtree(images)
+ except:
+ pass
+ do_system(f"mkdir {images}")
+
+ time_slice_value = ""
+ time_slice = args.time_slice
+ if time_slice:
+ start, end = time_slice.split(",")
+ time_slice_value = f",select='between(t\,{start}\,{end})'"
+ do_system(f"ffmpeg -i {video} -qscale:v 1 -qmin 1 -vf \"fps={fps}{time_slice_value}\" {images}/%04d.jpg")
+
+def run_colmap(args):
+ db=args.colmap_db
+ images=args.images
+ db_noext=str(Path(db).with_suffix(""))
+
+ if args.text=="text":
+ args.text=db_noext+"_text"
+ text=args.text
+ sparse=db_noext+"_sparse"
+ print(f"running colmap with:\n\tdb={db}\n\timages={images}\n\tsparse={sparse}\n\ttext={text}")
+ if (input(f"warning! folders '{sparse}' and '{text}' will be deleted/replaced. continue? (Y/n)").lower().strip()+"y")[:1] != "y":
+ sys.exit(1)
+ if os.path.exists(db):
+ os.remove(db)
+ do_system(f"colmap feature_extractor --ImageReader.camera_model OPENCV --SiftExtraction.estimate_affine_shape=true --SiftExtraction.domain_size_pooling=true --ImageReader.single_camera 1 --database_path {db} --image_path {images}")
+ do_system(f"colmap {args.colmap_matcher}_matcher --SiftMatching.guided_matching=true --database_path {db}")
+ try:
+ shutil.rmtree(sparse)
+ except:
+ pass
+ do_system(f"mkdir {sparse}")
+ do_system(f"colmap mapper --database_path {db} --image_path {images} --output_path {sparse}")
+ do_system(f"colmap bundle_adjuster --input_path {sparse}/0 --output_path {sparse}/0 --BundleAdjustment.refine_principal_point 1")
+ try:
+ shutil.rmtree(text)
+ except:
+ pass
+ do_system(f"mkdir {text}")
+ do_system(f"colmap model_converter --input_path {sparse}/0 --output_path {text} --output_type TXT")
+
+def variance_of_laplacian(image):
+ return cv2.Laplacian(image, cv2.CV_64F).var()
+
+def sharpness(imagePath):
+ image = cv2.imread(imagePath)
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
+ fm = variance_of_laplacian(gray)
+ return fm
+
+def qvec2rotmat(qvec):
+ return np.array([
+ [
+ 1 - 2 * qvec[2]**2 - 2 * qvec[3]**2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2]
+ ], [
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1]**2 - 2 * qvec[3]**2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1]
+ ], [
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1]**2 - 2 * qvec[2]**2
+ ]
+ ])
+
+def rotmat(a, b):
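+    # Rotation matrix that rotates unit vector a onto unit vector b (Rodrigues' rotation formula).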
+ a, b = a / np.linalg.norm(a), b / np.linalg.norm(b)
+ v = np.cross(a, b)
+ c = np.dot(a, b)
+ s = np.linalg.norm(v)
+ kmat = np.array([[0, -v[2], v[1]], [v[2], 0, -v[0]], [-v[1], v[0], 0]])
+ return np.eye(3) + kmat + kmat.dot(kmat) * ((1 - c) / (s ** 2 + 1e-10))
+
+def closest_point_2_lines(oa, da, ob, db): # returns point closest to both rays of form o+t*d, and a weight factor that goes to 0 if the lines are parallel
+ da = da / np.linalg.norm(da)
+ db = db / np.linalg.norm(db)
+ c = np.cross(da, db)
+ denom = np.linalg.norm(c)**2
+ t = ob - oa
+ ta = np.linalg.det([t, db, c]) / (denom + 1e-10)
+ tb = np.linalg.det([t, da, c]) / (denom + 1e-10)
+ if ta > 0:
+ ta = 0
+ if tb > 0:
+ tb = 0
+ return (oa+ta*da+ob+tb*db) * 0.5, denom
+
+if __name__ == "__main__":
+ args = parse_args()
+ if args.video_in != "":
+ run_ffmpeg(args)
+ if args.run_colmap:
+ run_colmap(args)
+ AABB_SCALE = int(args.aabb_scale)
+ SKIP_EARLY = int(args.skip_early)
+ IMAGE_FOLDER = args.images
+ TEXT_FOLDER = args.text
+ OUT_PATH = args.out
+ print(f"outputting to {OUT_PATH}...")
+ with open(os.path.join(TEXT_FOLDER,"cameras.txt"), "r") as f:
+ angle_x = math.pi / 2
+ for line in f:
+ # 1 SIMPLE_RADIAL 2048 1536 1580.46 1024 768 0.0045691
+ # 1 OPENCV 3840 2160 3178.27 3182.09 1920 1080 0.159668 -0.231286 -0.00123982 0.00272224
+ # 1 RADIAL 1920 1080 1665.1 960 540 0.0672856 -0.0761443
+ if line[0] == "#":
+ continue
+ els = line.split(" ")
+ w = float(els[2])
+ h = float(els[3])
+ fl_x = float(els[4])
+ fl_y = float(els[4])
+ k1 = 0
+ k2 = 0
+ p1 = 0
+ p2 = 0
+ cx = w / 2
+ cy = h / 2
+ if els[1] == "SIMPLE_PINHOLE":
+ cx = float(els[5])
+ cy = float(els[6])
+ elif els[1] == "PINHOLE":
+ fl_y = float(els[5])
+ cx = float(els[6])
+ cy = float(els[7])
+ elif els[1] == "SIMPLE_RADIAL":
+ cx = float(els[5])
+ cy = float(els[6])
+ k1 = float(els[7])
+ elif els[1] == "RADIAL":
+ cx = float(els[5])
+ cy = float(els[6])
+ k1 = float(els[7])
+ k2 = float(els[8])
+ elif els[1] == "OPENCV":
+ fl_y = float(els[5])
+ cx = float(els[6])
+ cy = float(els[7])
+ k1 = float(els[8])
+ k2 = float(els[9])
+ p1 = float(els[10])
+ p2 = float(els[11])
+ else:
+ print("unknown camera model ", els[1])
+ # fl = 0.5 * w / tan(0.5 * angle_x);
+ angle_x = math.atan(w / (fl_x * 2)) * 2
+ angle_y = math.atan(h / (fl_y * 2)) * 2
+ fovx = angle_x * 180 / math.pi
+ fovy = angle_y * 180 / math.pi
+
+ print(f"camera:\n\tres={w,h}\n\tcenter={cx,cy}\n\tfocal={fl_x,fl_y}\n\tfov={fovx,fovy}\n\tk={k1,k2} p={p1,p2} ")
+
+ with open(os.path.join(TEXT_FOLDER,"images.txt"), "r") as f:
+ i = 0
+ bottom = np.array([0.0, 0.0, 0.0, 1.0]).reshape([1, 4])
+ out = {
+ "camera_angle_x": angle_x,
+ "camera_angle_y": angle_y,
+ "fl_x": fl_x,
+ "fl_y": fl_y,
+ "k1": k1,
+ "k2": k2,
+ "p1": p1,
+ "p2": p2,
+ "cx": cx,
+ "cy": cy,
+ "w": w,
+ "h": h,
+ "aabb_scale": AABB_SCALE,
+ "frames": [],
+ }
+
+ up = np.zeros(3)
+ for line in f:
+ line = line.strip()
+ if line[0] == "#":
+ continue
+ i = i + 1
+ if i < SKIP_EARLY*2:
+ continue
+ if i % 2 == 1:
+ elems=line.split(" ") # 1-4 is quat, 5-7 is trans, 9ff is filename (9, if filename contains no spaces)
+ #name = str(PurePosixPath(Path(IMAGE_FOLDER, elems[9])))
+ # why is this requireing a relitive path while using ^
+ image_rel = os.path.relpath(IMAGE_FOLDER)
+ name = str(f"./{image_rel}/{'_'.join(elems[9:])}")
+ b=sharpness(name)
+ print(name, "sharpness=",b)
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ R = qvec2rotmat(-qvec)
+ t = tvec.reshape([3,1])
+ m = np.concatenate([np.concatenate([R, t], 1), bottom], 0)
+ c2w = np.linalg.inv(m)
+ c2w[0:3,2] *= -1 # flip the y and z axis
+ c2w[0:3,1] *= -1
+ c2w = c2w[[1,0,2,3],:] # swap y and z
+ c2w[2,:] *= -1 # flip whole world upside down
+
+ up += c2w[0:3,1]
+
+ frame={"file_path":name,"sharpness":b,"transform_matrix": c2w}
+ out["frames"].append(frame)
+ nframes = len(out["frames"])
+ up = up / np.linalg.norm(up)
+ print("up vector was", up)
+ R = rotmat(up,[0,0,1]) # rotate up vector to [0,0,1]
+ R = np.pad(R,[0,1])
+ R[-1, -1] = 1
+
+
+ for f in out["frames"]:
+ f["transform_matrix"] = np.matmul(R, f["transform_matrix"]) # rotate up to be the z axis
+
+ # find a central point they are all looking at
+ print("computing center of attention...")
+ totw = 0.0
+ totp = np.array([0.0, 0.0, 0.0])
+ for f in out["frames"]:
+ mf = f["transform_matrix"][0:3,:]
+ for g in out["frames"]:
+ mg = g["transform_matrix"][0:3,:]
+ p, w = closest_point_2_lines(mf[:,3], mf[:,2], mg[:,3], mg[:,2])
+ if w > 0.01:
+ totp += p*w
+ totw += w
+ totp /= totw
+ print(totp) # the cameras are looking at totp
+ for f in out["frames"]:
+ f["transform_matrix"][0:3,3] -= totp
+
+ avglen = 0.
+ for f in out["frames"]:
+ avglen += np.linalg.norm(f["transform_matrix"][0:3,3])
+ avglen /= nframes
+ print("avg camera distance from origin", avglen)
+ for f in out["frames"]:
+ f["transform_matrix"][0:3,3] *= 4.0 / avglen # scale to "nerf sized"
+
+ for f in out["frames"]:
+ f["transform_matrix"] = f["transform_matrix"].tolist()
+ print(nframes,"frames")
+ print(f"writing {OUT_PATH}")
+ with open(OUT_PATH, "w") as outfile:
+ json.dump(out, outfile, indent=2)
\ No newline at end of file
diff --git a/TensoRF/dataLoader/llff.py b/TensoRF/dataLoader/llff.py
new file mode 100644
index 0000000..3b31db9
--- /dev/null
+++ b/TensoRF/dataLoader/llff.py
@@ -0,0 +1,242 @@
+import torch
+from torch.utils.data import Dataset
+import glob
+import numpy as np
+import os
+from PIL import Image
+from torchvision import transforms as T
+
+from .ray_utils import *
+
+
+def normalize(v):
+ """Normalize a vector."""
+ return v / np.linalg.norm(v)
+
+
+def average_poses(poses):
+ """
+ Calculate the average pose, which is then used to center all poses
+ using @center_poses. Its computation is as follows:
+ 1. Compute the center: the average of pose centers.
+ 2. Compute the z axis: the normalized average z axis.
+ 3. Compute axis y': the average y axis.
+ 4. Compute x' = y' cross product z, then normalize it as the x axis.
+ 5. Compute the y axis: z cross product x.
+
+ Note that at step 3, we cannot directly use y' as y axis since it's
+ not necessarily orthogonal to z axis. We need to pass from x to y.
+ Inputs:
+ poses: (N_images, 3, 4)
+ Outputs:
+ pose_avg: (3, 4) the average pose
+ """
+ # 1. Compute the center
+ center = poses[..., 3].mean(0) # (3)
+
+ # 2. Compute the z axis
+ z = normalize(poses[..., 2].mean(0)) # (3)
+
+ # 3. Compute axis y' (no need to normalize as it's not the final output)
+ y_ = poses[..., 1].mean(0) # (3)
+
+ # 4. Compute the x axis
+ x = normalize(np.cross(z, y_)) # (3)
+
+ # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
+ y = np.cross(x, z) # (3)
+
+ pose_avg = np.stack([x, y, z, center], 1) # (3, 4)
+
+ return pose_avg
+
+
+def center_poses(poses, blender2opencv):
+ """
+ Center the poses so that we can use NDC.
+ See https://github.com/bmild/nerf/issues/34
+ Inputs:
+ poses: (N_images, 3, 4)
+ Outputs:
+ poses_centered: (N_images, 3, 4) the centered poses
+ pose_avg: (3, 4) the average pose
+ """
+ poses = poses @ blender2opencv
+ pose_avg = average_poses(poses) # (3, 4)
+ pose_avg_homo = np.eye(4)
+    pose_avg_homo[:3] = pose_avg
+    # convert to homogeneous coordinates by adding [0, 0, 0, 1] as the last row
+ last_row = np.tile(np.array([0, 0, 0, 1]), (len(poses), 1, 1)) # (N_images, 1, 4)
+ poses_homo = \
+ np.concatenate([poses, last_row], 1) # (N_images, 4, 4) homogeneous coordinate
+
+ poses_centered = np.linalg.inv(pose_avg_homo) @ poses_homo # (N_images, 4, 4)
+ # poses_centered = poses_centered @ blender2opencv
+ poses_centered = poses_centered[:, :3] # (N_images, 3, 4)
+
+ return poses_centered, pose_avg_homo
+
+
+def viewmatrix(z, up, pos):
+ vec2 = normalize(z)
+ vec1_avg = up
+ vec0 = normalize(np.cross(vec1_avg, vec2))
+ vec1 = normalize(np.cross(vec2, vec0))
+ m = np.eye(4)
+ m[:3] = np.stack([-vec0, vec1, vec2, pos], 1)
+ return m
+
+
+def render_path_spiral(c2w, up, rads, focal, zdelta, zrate, N_rots=2, N=120):
+ render_poses = []
+ rads = np.array(list(rads) + [1.])
+
+ for theta in np.linspace(0., 2. * np.pi * N_rots, N + 1)[:-1]:
+ c = np.dot(c2w[:3, :4], np.array([np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.]) * rads)
+ z = normalize(c - np.dot(c2w[:3, :4], np.array([0, 0, -focal, 1.])))
+ render_poses.append(viewmatrix(z, up, c))
+ return render_poses
+
+
+def get_spiral(c2ws_all, near_fars, rads_scale=1.0, N_views=120):
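+    # Build a spiral novel-view render path around the average pose of a forward-facing capture,
+    # following the convention of the original NeRF LLFF code; returns (N_views, 4, 4) poses.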
+ # center pose
+ c2w = average_poses(c2ws_all)
+
+ # Get average pose
+ up = normalize(c2ws_all[:, :3, 1].sum(0))
+
+ # Find a reasonable "focus depth" for this dataset
+ dt = 0.75
+ close_depth, inf_depth = near_fars.min() * 0.9, near_fars.max() * 5.0
+ focal = 1.0 / (((1.0 - dt) / close_depth + dt / inf_depth))
+
+ # Get radii for spiral path
+ zdelta = near_fars.min() * .2
+ tt = c2ws_all[:, :3, 3]
+ rads = np.percentile(np.abs(tt), 90, 0) * rads_scale
+ render_poses = render_path_spiral(c2w, up, rads, focal, zdelta, zrate=.5, N=N_views)
+ return np.stack(render_poses)
+
+
+class LLFFDataset(Dataset):
+ def __init__(self, datadir, split='train', downsample=4, is_stack=False, hold_every=8):
+ """
+ spheric_poses: whether the images are taken in a spheric inward-facing manner
+ default: False (forward-facing)
+ val_num: number of val images (used for multigpu training, validate same image for all gpus)
+ """
+
+ self.root_dir = datadir
+ self.split = split
+ self.hold_every = hold_every
+ self.is_stack = is_stack
+ self.downsample = downsample
+ self.define_transforms()
+
+ self.blender2opencv = np.eye(4)#np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+ self.read_meta()
+ self.white_bg = False
+
+ # self.near_far = [np.min(self.near_fars[:,0]),np.max(self.near_fars[:,1])]
+ self.near_far = [0.0, 1.0]
+ self.scene_bbox = torch.tensor([[-1.5, -1.67, -1.0], [1.5, 1.67, 1.0]])
+ # self.scene_bbox = torch.tensor([[-1.67, -1.5, -1.0], [1.67, 1.5, 1.0]])
+ self.center = torch.mean(self.scene_bbox, dim=0).float().view(1, 1, 3)
+ self.invradius = 1.0 / (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
+
+ def read_meta(self):
+
+
+ poses_bounds = np.load(os.path.join(self.root_dir, 'poses_bounds.npy')) # (N_images, 17)
+ self.image_paths = sorted(glob.glob(os.path.join(self.root_dir, 'images_4/*')))
+ # load full resolution image then resize
+ if self.split in ['train', 'test']:
+ assert len(poses_bounds) == len(self.image_paths), \
+ 'Mismatch between number of images and number of poses! Please rerun COLMAP!'
+
+ poses = poses_bounds[:, :15].reshape(-1, 3, 5) # (N_images, 3, 5)
+ self.near_fars = poses_bounds[:, -2:] # (N_images, 2)
+ hwf = poses[:, :, -1]
+
+ # Step 1: rescale focal length according to training resolution
+ H, W, self.focal = poses[0, :, -1] # original intrinsics, same for all images
+ self.img_wh = np.array([int(W / self.downsample), int(H / self.downsample)])
+ self.focal = [self.focal * self.img_wh[0] / W, self.focal * self.img_wh[1] / H]
+
+ # Step 2: correct poses
+ # Original poses has rotation in form "down right back", change to "right up back"
+ # See https://github.com/bmild/nerf/issues/34
+ poses = np.concatenate([poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1)
+ # (N_images, 3, 4) exclude H, W, focal
+ self.poses, self.pose_avg = center_poses(poses, self.blender2opencv)
+
+ # Step 3: correct scale so that the nearest depth is at a little more than 1.0
+ # See https://github.com/bmild/nerf/issues/34
+ near_original = self.near_fars.min()
+ scale_factor = near_original * 0.75 # 0.75 is the default parameter
+ # the nearest depth is at 1/0.75=1.33
+ self.near_fars /= scale_factor
+ self.poses[..., 3] /= scale_factor
+
+ # build rendering path
+ N_views, N_rots = 120, 2
+ tt = self.poses[:, :3, 3] # ptstocam(poses[:3,3,:].T, c2w).T
+ up = normalize(self.poses[:, :3, 1].sum(0))
+ rads = np.percentile(np.abs(tt), 90, 0)
+
+ self.render_path = get_spiral(self.poses, self.near_fars, N_views=N_views)
+
+ # distances_from_center = np.linalg.norm(self.poses[..., 3], axis=1)
+ # val_idx = np.argmin(distances_from_center) # choose val image as the closest to
+ # center image
+
+ # ray directions for all pixels, same for all images (same H, W, focal)
+ W, H = self.img_wh
+ self.directions = get_ray_directions_blender(H, W, self.focal) # (H, W, 3)
+
+ average_pose = average_poses(self.poses)
+ dists = np.sum(np.square(average_pose[:3, 3] - self.poses[:, :3, 3]), -1)
+ i_test = np.arange(0, self.poses.shape[0], self.hold_every) # [np.argmin(dists)]
+ img_list = i_test if self.split != 'train' else list(set(np.arange(len(self.poses))) - set(i_test))
+
+ # use first N_images-1 to train, the LAST is val
+ self.all_rays = []
+ self.all_rgbs = []
+ for i in img_list:
+ image_path = self.image_paths[i]
+ c2w = torch.FloatTensor(self.poses[i])
+
+ img = Image.open(image_path).convert('RGB')
+ if self.downsample != 1.0:
+ img = img.resize(self.img_wh, Image.LANCZOS)
+ img = self.transform(img) # (3, h, w)
+
+ img = img.view(3, -1).permute(1, 0) # (h*w, 3) RGB
+ self.all_rgbs += [img]
+ rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
+ rays_o, rays_d = ndc_rays_blender(H, W, self.focal[0], 1.0, rays_o, rays_d)
+ # viewdir = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
+
+ self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
+
+ if not self.is_stack:
+ self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
+ self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w,3)
+ else:
+ self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h,w, 3)
+ self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
+
+
+ def define_transforms(self):
+ self.transform = T.ToTensor()
+
+ def __len__(self):
+ return len(self.all_rgbs)
+
+ def __getitem__(self, idx):
+
+ sample = {'rays': self.all_rays[idx],
+ 'rgbs': self.all_rgbs[idx]}
+
+ return sample
\ No newline at end of file
diff --git a/TensoRF/dataLoader/nsvf.py b/TensoRF/dataLoader/nsvf.py
new file mode 100644
index 0000000..f9dc0a9
--- /dev/null
+++ b/TensoRF/dataLoader/nsvf.py
@@ -0,0 +1,160 @@
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+import os
+from PIL import Image
+from torchvision import transforms as T
+
+from .ray_utils import *
+
+trans_t = lambda t : torch.Tensor([
+ [1,0,0,0],
+ [0,1,0,0],
+ [0,0,1,t],
+ [0,0,0,1]]).float()
+
+rot_phi = lambda phi : torch.Tensor([
+ [1,0,0,0],
+ [0,np.cos(phi),-np.sin(phi),0],
+ [0,np.sin(phi), np.cos(phi),0],
+ [0,0,0,1]]).float()
+
+rot_theta = lambda th : torch.Tensor([
+ [np.cos(th),0,-np.sin(th),0],
+ [0,1,0,0],
+ [np.sin(th),0, np.cos(th),0],
+ [0,0,0,1]]).float()
+
+
+def pose_spherical(theta, phi, radius):
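+    # Camera-to-world matrix for a camera on a sphere of the given radius, parameterized by
+    # azimuth `theta` and elevation `phi` in degrees; used to build the circular test render path.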
+ c2w = trans_t(radius)
+ c2w = rot_phi(phi/180.*np.pi) @ c2w
+ c2w = rot_theta(theta/180.*np.pi) @ c2w
+ c2w = torch.Tensor(np.array([[-1,0,0,0],[0,0,1,0],[0,1,0,0],[0,0,0,1]])) @ c2w
+ return c2w
+
+class NSVF(Dataset):
+ """NSVF Generic Dataset."""
+ def __init__(self, datadir, split='train', downsample=1.0, wh=[800,800], is_stack=False):
+ self.root_dir = datadir
+ self.split = split
+ self.is_stack = is_stack
+ self.downsample = downsample
+ self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
+ self.define_transforms()
+
+ self.white_bg = True
+ self.near_far = [0.5,6.0]
+ self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)
+ self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+ self.read_meta()
+ self.define_proj_mat()
+
+ self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
+ self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
+
+ def bbox2corners(self):
+ corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
+ for i in range(3):
+ corners[i,[0,1],i] = corners[i,[1,0],i]
+ return corners.view(-1,3)
+
+
+ def read_meta(self):
+ with open(os.path.join(self.root_dir, "intrinsics.txt")) as f:
+ focal = float(f.readline().split()[0])
+ self.intrinsics = np.array([[focal,0,400.0],[0,focal,400.0],[0,0,1]])
+ self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([800,800])).reshape(2,1)
+
+ pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
+ img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
+
+ if self.split == 'train':
+ pose_files = [x for x in pose_files if x.startswith('0_')]
+ img_files = [x for x in img_files if x.startswith('0_')]
+ elif self.split == 'val':
+ pose_files = [x for x in pose_files if x.startswith('1_')]
+ img_files = [x for x in img_files if x.startswith('1_')]
+ elif self.split == 'test':
+ test_pose_files = [x for x in pose_files if x.startswith('2_')]
+ test_img_files = [x for x in img_files if x.startswith('2_')]
+ if len(test_pose_files) == 0:
+ test_pose_files = [x for x in pose_files if x.startswith('1_')]
+ test_img_files = [x for x in img_files if x.startswith('1_')]
+ pose_files = test_pose_files
+ img_files = test_img_files
+
+ # ray directions for all pixels, same for all images (same H, W, focal)
+ self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
+ self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
+
+
+ self.render_path = torch.stack([pose_spherical(angle, -30.0, 4.0) for angle in np.linspace(-180,180,40+1)[:-1]], 0)
+
+ self.poses = []
+ self.all_rays = []
+ self.all_rgbs = []
+
+ assert len(img_files) == len(pose_files)
+ for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
+ image_path = os.path.join(self.root_dir, 'rgb', img_fname)
+ img = Image.open(image_path)
+ if self.downsample!=1.0:
+ img = img.resize(self.img_wh, Image.LANCZOS)
+ img = self.transform(img) # (4, h, w)
+ img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA
+ if img.shape[-1]==4:
+ img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
+ self.all_rgbs += [img]
+
+ c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname)) #@ self.blender2opencv
+ c2w = torch.FloatTensor(c2w)
+ self.poses.append(c2w) # C2W
+ rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
+ self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 8)
+
+# w2c = torch.inverse(c2w)
+#
+
+ self.poses = torch.stack(self.poses)
+ if 'train' == self.split:
+ if self.is_stack:
+ self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (len(self.meta['frames])*h*w, 3)
+ self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames])*h*w, 3)
+ else:
+ self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames])*h*w, 3)
+ self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames])*h*w, 3)
+ else:
+ self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames]),h*w, 3)
+ self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames]),h,w,3)
+
+
+ def define_transforms(self):
+ self.transform = T.ToTensor()
+
+ def define_proj_mat(self):
+ self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
+
+ def world2ndc(self, points):
+ device = points.device
+ return (points - self.center.to(device)) / self.radius.to(device)
+
+ def __len__(self):
+ if self.split == 'train':
+ return len(self.all_rays)
+ return len(self.all_rgbs)
+
+ def __getitem__(self, idx):
+
+ if self.split == 'train': # use data in the buffers
+ sample = {'rays': self.all_rays[idx],
+ 'rgbs': self.all_rgbs[idx]}
+
+ else: # create data for each image separately
+
+ img = self.all_rgbs[idx]
+ rays = self.all_rays[idx]
+
+ sample = {'rays': rays,
+ 'rgbs': img}
+ return sample
\ No newline at end of file
diff --git a/TensoRF/dataLoader/ray_utils.py b/TensoRF/dataLoader/ray_utils.py
new file mode 100644
index 0000000..c7f0437
--- /dev/null
+++ b/TensoRF/dataLoader/ray_utils.py
@@ -0,0 +1,275 @@
+import torch, re
+import numpy as np
+from torch import searchsorted
+from kornia import create_meshgrid
+
+
+# from utils import index_point_feature
+
+def depth2dist(z_vals, cos_angle):
+ # z_vals: [N_ray N_sample]
+ device = z_vals.device
+ dists = z_vals[..., 1:] - z_vals[..., :-1]
+ dists = torch.cat([dists, torch.Tensor([1e10]).to(device).expand(dists[..., :1].shape)], -1) # [N_rays, N_samples]
+ dists = dists * cos_angle.unsqueeze(-1)
+ return dists
+
+
+def ndc2dist(ndc_pts, cos_angle):
+ dists = torch.norm(ndc_pts[:, 1:] - ndc_pts[:, :-1], dim=-1)
+ dists = torch.cat([dists, 1e10 * cos_angle.unsqueeze(-1)], -1) # [N_rays, N_samples]
+ return dists
+
+
+def get_ray_directions(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0] + 0.5
+
+ i, j = grid.unbind(-1)
+    # note: +0.5 is added to the grid above so that rays pass through pixel centers
+    # (see https://github.com/bmild/nerf/issues/24)
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], (j - cent[1]) / focal[1], torch.ones_like(i)], -1) # (H, W, 3)
+
+ return directions
+
+
+def get_ray_directions_blender(H, W, focal, center=None):
+ """
+ Get ray directions for all pixels in camera coordinate.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ H, W, focal: image height, width and focal length
+ Outputs:
+ directions: (H, W, 3), the direction of the rays in camera coordinate
+ """
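+    # Unlike get_ray_directions above (OpenCV-style: +y down, camera looks along +z), this variant
+    # uses the Blender/OpenGL convention with +y up and the camera looking along -z.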
+ grid = create_meshgrid(H, W, normalized_coordinates=False)[0]+0.5
+ i, j = grid.unbind(-1)
+    # note: +0.5 is added to the grid above so that rays pass through pixel centers
+    # (see https://github.com/bmild/nerf/issues/24)
+ cent = center if center is not None else [W / 2, H / 2]
+ directions = torch.stack([(i - cent[0]) / focal[0], -(j - cent[1]) / focal[1], -torch.ones_like(i)],
+ -1) # (H, W, 3)
+
+ return directions
+
+
+def get_rays(directions, c2w):
+ """
+ Get ray origin and normalized directions in world coordinate for all pixels in one image.
+ Reference: https://www.scratchapixel.com/lessons/3d-basic-rendering/
+ ray-tracing-generating-camera-rays/standard-coordinate-systems
+ Inputs:
+ directions: (H, W, 3) precomputed ray directions in camera coordinate
+ c2w: (3, 4) transformation matrix from camera coordinate to world coordinate
+ Outputs:
+ rays_o: (H*W, 3), the origin of the rays in world coordinate
+ rays_d: (H*W, 3), the normalized direction of the rays in world coordinate
+ """
+ # Rotate ray directions from camera coordinate to the world coordinate
+ rays_d = directions @ c2w[:3, :3].T # (H, W, 3)
+ # rays_d = rays_d / torch.norm(rays_d, dim=-1, keepdim=True)
+ # The origin of all rays is the camera origin in world coordinate
+ rays_o = c2w[:3, 3].expand(rays_d.shape) # (H, W, 3)
+
+ rays_d = rays_d.view(-1, 3)
+ rays_o = rays_o.view(-1, 3)
+
+ return rays_o, rays_d
+
+
+def ndc_rays_blender(H, W, focal, near, rays_o, rays_d):
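+    # Map rays of a forward-facing scene into NeRF's normalized device coordinate (NDC) space:
+    # origins are shifted onto the near plane and the viewing frustum is warped into a cube,
+    # so that sampling linearly in NDC depth corresponds to sampling linearly in disparity.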
+ # Shift ray origins to near plane
+ t = -(near + rays_o[..., 2]) / rays_d[..., 2]
+ rays_o = rays_o + t[..., None] * rays_d
+
+ # Projection
+ o0 = -1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
+ o1 = -1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
+ o2 = 1. + 2. * near / rays_o[..., 2]
+
+ d0 = -1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
+ d1 = -1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
+ d2 = -2. * near / rays_o[..., 2]
+
+ rays_o = torch.stack([o0, o1, o2], -1)
+ rays_d = torch.stack([d0, d1, d2], -1)
+
+ return rays_o, rays_d
+
+def ndc_rays(H, W, focal, near, rays_o, rays_d):
+ # Shift ray origins to near plane
+ t = (near - rays_o[..., 2]) / rays_d[..., 2]
+ rays_o = rays_o + t[..., None] * rays_d
+
+ # Projection
+ o0 = 1. / (W / (2. * focal)) * rays_o[..., 0] / rays_o[..., 2]
+ o1 = 1. / (H / (2. * focal)) * rays_o[..., 1] / rays_o[..., 2]
+ o2 = 1. - 2. * near / rays_o[..., 2]
+
+ d0 = 1. / (W / (2. * focal)) * (rays_d[..., 0] / rays_d[..., 2] - rays_o[..., 0] / rays_o[..., 2])
+ d1 = 1. / (H / (2. * focal)) * (rays_d[..., 1] / rays_d[..., 2] - rays_o[..., 1] / rays_o[..., 2])
+ d2 = 2. * near / rays_o[..., 2]
+
+ rays_o = torch.stack([o0, o1, o2], -1)
+ rays_d = torch.stack([d0, d1, d2], -1)
+
+ return rays_o, rays_d
+
+# Hierarchical sampling (section 5.2)
+def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
+ device = weights.device
+ # Get pdf
+ weights = weights + 1e-5 # prevent nans
+ pdf = weights / torch.sum(weights, -1, keepdim=True)
+ cdf = torch.cumsum(pdf, -1)
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
+
+ # Take uniform samples
+ if det:
+ u = torch.linspace(0., 1., steps=N_samples, device=device)
+ u = u.expand(list(cdf.shape[:-1]) + [N_samples])
+ else:
+ u = torch.rand(list(cdf.shape[:-1]) + [N_samples], device=device)
+
+ # Pytest, overwrite u with numpy's fixed random numbers
+ if pytest:
+ np.random.seed(0)
+ new_shape = list(cdf.shape[:-1]) + [N_samples]
+ if det:
+ u = np.linspace(0., 1., N_samples)
+ u = np.broadcast_to(u, new_shape)
+ else:
+ u = np.random.rand(*new_shape)
+ u = torch.Tensor(u)
+
+ # Invert CDF
+ u = u.contiguous()
+ inds = searchsorted(cdf.detach(), u, right=True)
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
+ inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
+
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
+
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
+ t = (u - cdf_g[..., 0]) / denom
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
+
+ return samples
+
+
+def dda(rays_o, rays_d, bbox_3D):
+ inv_ray_d = 1.0 / (rays_d + 1e-6)
+ t_min = (bbox_3D[:1] - rays_o) * inv_ray_d # N_rays 3
+ t_max = (bbox_3D[1:] - rays_o) * inv_ray_d
+ t = torch.stack((t_min, t_max)) # 2 N_rays 3
+ t_min = torch.max(torch.min(t, dim=0)[0], dim=-1, keepdim=True)[0]
+ t_max = torch.min(torch.max(t, dim=0)[0], dim=-1, keepdim=True)[0]
+ return t_min, t_max
+
+
+def ray_marcher(rays,
+ N_samples=64,
+ lindisp=False,
+ perturb=0,
+ bbox_3D=None):
+ """
+ sample points along the rays
+ Inputs:
+ rays: ()
+
+ Returns:
+
+ """
+
+ # Decompose the inputs
+ N_rays = rays.shape[0]
+ rays_o, rays_d = rays[:, 0:3], rays[:, 3:6] # both (N_rays, 3)
+ near, far = rays[:, 6:7], rays[:, 7:8] # both (N_rays, 1)
+
+ if bbox_3D is not None:
+ # cal aabb boundles
+ near, far = dda(rays_o, rays_d, bbox_3D)
+
+ # Sample depth points
+ z_steps = torch.linspace(0, 1, N_samples, device=rays.device) # (N_samples)
+ if not lindisp: # use linear sampling in depth space
+ z_vals = near * (1 - z_steps) + far * z_steps
+ else: # use linear sampling in disparity space
+ z_vals = 1 / (1 / near * (1 - z_steps) + 1 / far * z_steps)
+
+ z_vals = z_vals.expand(N_rays, N_samples)
+
+ if perturb > 0: # perturb sampling depths (z_vals)
+ z_vals_mid = 0.5 * (z_vals[:, :-1] + z_vals[:, 1:]) # (N_rays, N_samples-1) interval mid points
+ # get intervals between samples
+ upper = torch.cat([z_vals_mid, z_vals[:, -1:]], -1)
+ lower = torch.cat([z_vals[:, :1], z_vals_mid], -1)
+
+ perturb_rand = perturb * torch.rand(z_vals.shape, device=rays.device)
+ z_vals = lower + (upper - lower) * perturb_rand
+
+ xyz_coarse_sampled = rays_o.unsqueeze(1) + \
+ rays_d.unsqueeze(1) * z_vals.unsqueeze(2) # (N_rays, N_samples, 3)
+
+ return xyz_coarse_sampled, rays_o, rays_d, z_vals
+
+
+def read_pfm(filename):
+ file = open(filename, 'rb')
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().decode('utf-8').rstrip()
+ if header == 'PF':
+ color = True
+ elif header == 'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8'))
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ file.close()
+ return data, scale
+
+
+def ndc_bbox(all_rays):
+ near_min = torch.min(all_rays[...,:3].view(-1,3),dim=0)[0]
+ near_max = torch.max(all_rays[..., :3].view(-1, 3), dim=0)[0]
+ far_min = torch.min((all_rays[...,:3]+all_rays[...,3:6]).view(-1,3),dim=0)[0]
+ far_max = torch.max((all_rays[...,:3]+all_rays[...,3:6]).view(-1, 3), dim=0)[0]
+ print(f'===> ndc bbox near_min:{near_min} near_max:{near_max} far_min:{far_min} far_max:{far_max}')
+ return torch.stack((torch.minimum(near_min,far_min),torch.maximum(near_max,far_max)))
\ No newline at end of file
diff --git a/TensoRF/dataLoader/tankstemple.py b/TensoRF/dataLoader/tankstemple.py
new file mode 100644
index 0000000..4215803
--- /dev/null
+++ b/TensoRF/dataLoader/tankstemple.py
@@ -0,0 +1,216 @@
+import torch
+from torch.utils.data import Dataset
+from tqdm import tqdm
+import os
+from PIL import Image
+from torchvision import transforms as T
+
+from .ray_utils import *
+
+
+def circle(radius=3.5, h=0.0, axis='z', t0=0, r=1):
+ if axis == 'z':
+ return lambda t: [radius * np.cos(r * t + t0), radius * np.sin(r * t + t0), h]
+ elif axis == 'y':
+ return lambda t: [radius * np.cos(r * t + t0), h, radius * np.sin(r * t + t0)]
+ else:
+ return lambda t: [h, radius * np.cos(r * t + t0), radius * np.sin(r * t + t0)]
+
+
+def cross(x, y, axis=0):
+ T = torch if isinstance(x, torch.Tensor) else np
+ return T.cross(x, y, axis)
+
+
+def normalize(x, axis=-1, order=2):
+ if isinstance(x, torch.Tensor):
+ l2 = x.norm(p=order, dim=axis, keepdim=True)
+ return x / (l2 + 1e-8), l2
+
+ else:
+ l2 = np.linalg.norm(x, order, axis)
+ l2 = np.expand_dims(l2, axis)
+ l2[l2 == 0] = 1
+ return x / l2,
+
+
+def cat(x, axis=1):
+ if isinstance(x[0], torch.Tensor):
+ return torch.cat(x, dim=axis)
+ return np.concatenate(x, axis=axis)
+
+
+def look_at_rotation(camera_position, at=None, up=None, inverse=False, cv=False):
+ """
+ This function takes a vector 'camera_position' which specifies the location
+ of the camera in world coordinates and two vectors `at` and `up` which
+ indicate the position of the object and the up directions of the world
+ coordinate system respectively. The object is assumed to be centered at
+ the origin.
+ The output is a rotation matrix representing the transformation
+ from world coordinates -> view coordinates.
+ Input:
+ camera_position: 3
+ at: 1 x 3 or N x 3 (0, 0, 0) in default
+ up: 1 x 3 or N x 3 (0, 1, 0) in default
+ """
+
+ if at is None:
+ at = torch.zeros_like(camera_position)
+ else:
+ at = torch.tensor(at).type_as(camera_position)
+ if up is None:
+ up = torch.zeros_like(camera_position)
+ up[2] = -1
+ else:
+ up = torch.tensor(up).type_as(camera_position)
+
+ z_axis = normalize(at - camera_position)[0]
+ x_axis = normalize(cross(up, z_axis))[0]
+ y_axis = normalize(cross(z_axis, x_axis))[0]
+
+ R = cat([x_axis[:, None], y_axis[:, None], z_axis[:, None]], axis=1)
+ return R
+
+
+def gen_path(pos_gen, at=(0, 0, 0), up=(0, -1, 0), frames=180):
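+    # Generate `frames` camera-to-world matrices along the circle produced by `pos_gen`,
+    # each looking at `at`; used to build the novel-view render path.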
+ c2ws = []
+ for t in range(frames):
+ c2w = torch.eye(4)
+ cam_pos = torch.tensor(pos_gen(t * (360.0 / frames) / 180 * np.pi))
+ cam_rot = look_at_rotation(cam_pos, at=at, up=up, inverse=False, cv=True)
+ c2w[:3, 3], c2w[:3, :3] = cam_pos, cam_rot
+ c2ws.append(c2w)
+ return torch.stack(c2ws)
+
+class TanksTempleDataset(Dataset):
+    """Tanks and Temples dataset stored in the NSVF data format."""
+ def __init__(self, datadir, split='train', downsample=1.0, wh=[1920,1080], is_stack=False):
+ self.root_dir = datadir
+ self.split = split
+ self.is_stack = is_stack
+ self.downsample = downsample
+ self.img_wh = (int(wh[0]/downsample),int(wh[1]/downsample))
+ self.define_transforms()
+
+ self.white_bg = True
+ self.near_far = [0.01,6.0]
+ self.scene_bbox = torch.from_numpy(np.loadtxt(f'{self.root_dir}/bbox.txt')).float()[:6].view(2,3)*1.2
+
+ self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+ self.read_meta()
+ self.define_proj_mat()
+
+ self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
+ self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
+
+ def bbox2corners(self):
+ corners = self.scene_bbox.unsqueeze(0).repeat(4,1,1)
+ for i in range(3):
+ corners[i,[0,1],i] = corners[i,[1,0],i]
+ return corners.view(-1,3)
+
+
+ def read_meta(self):
+
+ self.intrinsics = np.loadtxt(os.path.join(self.root_dir, "intrinsics.txt"))
+ self.intrinsics[:2] *= (np.array(self.img_wh)/np.array([1920,1080])).reshape(2,1)
+ pose_files = sorted(os.listdir(os.path.join(self.root_dir, 'pose')))
+ img_files = sorted(os.listdir(os.path.join(self.root_dir, 'rgb')))
+
+ if self.split == 'train':
+ pose_files = [x for x in pose_files if x.startswith('0_')]
+ img_files = [x for x in img_files if x.startswith('0_')]
+ elif self.split == 'val':
+ pose_files = [x for x in pose_files if x.startswith('1_')]
+ img_files = [x for x in img_files if x.startswith('1_')]
+ elif self.split == 'test':
+ test_pose_files = [x for x in pose_files if x.startswith('2_')]
+ test_img_files = [x for x in img_files if x.startswith('2_')]
+ if len(test_pose_files) == 0:
+ test_pose_files = [x for x in pose_files if x.startswith('1_')]
+ test_img_files = [x for x in img_files if x.startswith('1_')]
+ pose_files = test_pose_files
+ img_files = test_img_files
+
+ # ray directions for all pixels, same for all images (same H, W, focal)
+ self.directions = get_ray_directions(self.img_wh[1], self.img_wh[0], [self.intrinsics[0,0],self.intrinsics[1,1]], center=self.intrinsics[:2,2]) # (h, w, 3)
+ self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
+
+
+
+ self.poses = []
+ self.all_rays = []
+ self.all_rgbs = []
+
+ assert len(img_files) == len(pose_files)
+ for img_fname, pose_fname in tqdm(zip(img_files, pose_files), desc=f'Loading data {self.split} ({len(img_files)})'):
+ image_path = os.path.join(self.root_dir, 'rgb', img_fname)
+ img = Image.open(image_path)
+ if self.downsample!=1.0:
+ img = img.resize(self.img_wh, Image.LANCZOS)
+ img = self.transform(img) # (4, h, w)
+ img = img.view(img.shape[0], -1).permute(1, 0) # (h*w, 4) RGBA
+ if img.shape[-1]==4:
+ img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
+ self.all_rgbs.append(img)
+
+
+ c2w = np.loadtxt(os.path.join(self.root_dir, 'pose', pose_fname))# @ cam_trans
+ c2w = torch.FloatTensor(c2w)
+ self.poses.append(c2w) # C2W
+ rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
+ self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
+
+ self.poses = torch.stack(self.poses)
+
+ center = torch.mean(self.scene_bbox, dim=0)
+ radius = torch.norm(self.scene_bbox[1]-center)*1.2
+ up = torch.mean(self.poses[:, :3, 1], dim=0).tolist()
+ pos_gen = circle(radius=radius, h=-0.2*up[1], axis='y')
+ self.render_path = gen_path(pos_gen, up=up,frames=200)
+ self.render_path[:, :3, 3] += center
+
+
+
+ if 'train' == self.split:
+ if self.is_stack:
+ self.all_rays = torch.stack(self.all_rays, 0).reshape(-1,*self.img_wh[::-1], 6) # (N_images, h, w, 6)
+ self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
+ else:
+ self.all_rays = torch.cat(self.all_rays, 0) # (N_images*h*w, 6)
+ self.all_rgbs = torch.cat(self.all_rgbs, 0) # (N_images*h*w, 3)
+ else:
+ self.all_rays = torch.stack(self.all_rays, 0) # (N_images, h*w, 6)
+ self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (N_images, h, w, 3)
+
+
+ def define_transforms(self):
+ self.transform = T.ToTensor()
+
+ def define_proj_mat(self):
+ self.proj_mat = torch.from_numpy(self.intrinsics[:3,:3]).unsqueeze(0).float() @ torch.inverse(self.poses)[:,:3]
+
+ def world2ndc(self, points):
+ device = points.device
+ return (points - self.center.to(device)) / self.radius.to(device)
+
+ def __len__(self):
+ if self.split == 'train':
+ return len(self.all_rays)
+ return len(self.all_rgbs)
+
+ def __getitem__(self, idx):
+
+ if self.split == 'train': # use data in the buffers
+ sample = {'rays': self.all_rays[idx],
+ 'rgbs': self.all_rgbs[idx]}
+
+ else: # create data for each image separately
+
+ img = self.all_rgbs[idx]
+ rays = self.all_rays[idx]
+
+ sample = {'rays': rays,
+ 'rgbs': img}
+ return sample
\ No newline at end of file
diff --git a/TensoRF/dataLoader/your_own_data.py b/TensoRF/dataLoader/your_own_data.py
new file mode 100644
index 0000000..79313e2
--- /dev/null
+++ b/TensoRF/dataLoader/your_own_data.py
@@ -0,0 +1,129 @@
+import torch,cv2
+from torch.utils.data import Dataset
+import json
+from tqdm import tqdm
+import os
+from PIL import Image
+from torchvision import transforms as T
+
+
+from .ray_utils import *
+
+
+class YourOwnDataset(Dataset):
+ def __init__(self, datadir, split='train', downsample=1.0, is_stack=False, N_vis=-1):
+
+ self.N_vis = N_vis
+ self.root_dir = datadir
+ self.split = split
+ self.is_stack = is_stack
+ self.downsample = downsample
+ self.define_transforms()
+
+ self.scene_bbox = torch.tensor([[-1.5, -1.5, -1.5], [1.5, 1.5, 1.5]])
+ self.blender2opencv = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])
+ self.read_meta()
+ self.define_proj_mat()
+
+ self.white_bg = True
+ self.near_far = [0.1,100.0]
+
+ self.center = torch.mean(self.scene_bbox, axis=0).float().view(1, 1, 3)
+ self.radius = (self.scene_bbox[1] - self.center).float().view(1, 1, 3)
+ self.downsample=downsample
+
+ def read_depth(self, filename):
+ depth = np.array(read_pfm(filename)[0], dtype=np.float32) # (800, 800)
+ return depth
+
+ def read_meta(self):
+
+ with open(os.path.join(self.root_dir, f"transforms_{self.split}.json"), 'r') as f:
+ self.meta = json.load(f)
+
+ w, h = int(self.meta['w']/self.downsample), int(self.meta['h']/self.downsample)
+ self.img_wh = [w,h]
+ self.focal_x = 0.5 * w / np.tan(0.5 * self.meta['camera_angle_x']) # original focal length
+ self.focal_y = 0.5 * h / np.tan(0.5 * self.meta['camera_angle_y']) # original focal length
+ self.cx, self.cy = self.meta['cx'],self.meta['cy']
+
+
+ # ray directions for all pixels, same for all images (same H, W, focal)
+ self.directions = get_ray_directions(h, w, [self.focal_x,self.focal_y], center=[self.cx, self.cy]) # (h, w, 3)
+ self.directions = self.directions / torch.norm(self.directions, dim=-1, keepdim=True)
+ self.intrinsics = torch.tensor([[self.focal_x,0,self.cx],[0,self.focal_y,self.cy],[0,0,1]]).float()
+
+ self.image_paths = []
+ self.poses = []
+ self.all_rays = []
+ self.all_rgbs = []
+ self.all_masks = []
+ self.all_depth = []
+
+
+ img_eval_interval = 1 if self.N_vis < 0 else len(self.meta['frames']) // self.N_vis
+ idxs = list(range(0, len(self.meta['frames']), img_eval_interval))
+ for i in tqdm(idxs, desc=f'Loading data {self.split} ({len(idxs)})'):#img_list:#
+
+ frame = self.meta['frames'][i]
+ pose = np.array(frame['transform_matrix']) @ self.blender2opencv
+ c2w = torch.FloatTensor(pose)
+ self.poses += [c2w]
+
+ image_path = os.path.join(self.root_dir, f"{frame['file_path']}.png")
+ self.image_paths += [image_path]
+ img = Image.open(image_path)
+
+ if self.downsample!=1.0:
+ img = img.resize(self.img_wh, Image.LANCZOS)
+ img = self.transform(img) # (4, h, w)
+ img = img.view(-1, w*h).permute(1, 0) # (h*w, 4) RGBA
+ if img.shape[-1]==4:
+ img = img[:, :3] * img[:, -1:] + (1 - img[:, -1:]) # blend A to RGB
+ self.all_rgbs += [img]
+
+
+ rays_o, rays_d = get_rays(self.directions, c2w) # both (h*w, 3)
+ self.all_rays += [torch.cat([rays_o, rays_d], 1)] # (h*w, 6)
+
+
+ self.poses = torch.stack(self.poses)
+ if not self.is_stack:
+ self.all_rays = torch.cat(self.all_rays, 0) # (len(self.meta['frames'])*h*w, 6)
+ self.all_rgbs = torch.cat(self.all_rgbs, 0) # (len(self.meta['frames'])*h*w, 3)
+
+# self.all_depth = torch.cat(self.all_depth, 0) # (len(self.meta['frames'])*h*w, 3)
+ else:
+ self.all_rays = torch.stack(self.all_rays, 0) # (len(self.meta['frames']), h*w, 6)
+ self.all_rgbs = torch.stack(self.all_rgbs, 0).reshape(-1,*self.img_wh[::-1], 3) # (len(self.meta['frames']), h, w, 3)
+ # self.all_masks = torch.stack(self.all_masks, 0).reshape(-1,*self.img_wh[::-1]) # (len(self.meta['frames]),h,w,3)
+
+
+ def define_transforms(self):
+ self.transform = T.ToTensor()
+
+ def define_proj_mat(self):
+ self.proj_mat = self.intrinsics.unsqueeze(0) @ torch.inverse(self.poses)[:,:3]
+
+ def world2ndc(self,points,lindisp=None):
+ device = points.device
+ return (points - self.center.to(device)) / self.radius.to(device)
+
+ def __len__(self):
+ return len(self.all_rgbs)
+
+ def __getitem__(self, idx):
+
+ if self.split == 'train': # use data in the buffers
+ sample = {'rays': self.all_rays[idx],
+ 'rgbs': self.all_rgbs[idx]}
+
+ else: # create data for each image separately
+
+ img = self.all_rgbs[idx]
+ rays = self.all_rays[idx]
+ # self.all_masks is never populated by read_meta, so it is not indexed here;
+ # add mask loading above before using masks for quantitative evaluation
+
+ sample = {'rays': rays,
+ 'rgbs': img}
+ return sample
diff --git a/TensoRF/extra/auto_run_paramsets.py b/TensoRF/extra/auto_run_paramsets.py
new file mode 100644
index 0000000..52b4f1a
--- /dev/null
+++ b/TensoRF/extra/auto_run_paramsets.py
@@ -0,0 +1,207 @@
+import os
+import threading, queue
+import numpy as np
+import time
+
+
+def getFolderLocker(logFolder):
+ while True:
+ try:
+ os.makedirs(logFolder+"/lockFolder")
+ break
+ except OSError:
+ time.sleep(0.01)
+
+def releaseFolderLocker(logFolder):
+ os.removedirs(logFolder+"/lockFolder")
+
+def getStopFolder(logFolder):
+ return os.path.isdir(logFolder+"/stopFolder")
+
+
+def get_param_str(key, val):
+ if key == 'data_name':
+ return f'--datadir {datafolder}/{val} '
+ else:
+ return f'--{key} {val} '
+
+def get_param_list(param_dict):
+ param_keys = list(param_dict.keys())
+ param_modes = len(param_keys)
+ param_nums = [len(param_dict[key]) for key in param_keys]
+
+ param_ids = np.zeros(param_nums+[param_modes], dtype=int)
+ for i in range(param_modes):
+ broad_tuple = np.ones(param_modes, dtype=int).tolist()
+ broad_tuple[i] = param_nums[i]
+ broad_tuple = tuple(broad_tuple)
+ print(broad_tuple)
+ param_ids[...,i] = np.arange(param_nums[i]).reshape(broad_tuple)
+ param_ids = param_ids.reshape(-1, param_modes)
+ # print(param_ids)
+ print(len(param_ids))
+
+ params = []
+ expnames = []
+ for i in range(param_ids.shape[0]):
+ one = ""
+ name = ""
+ param_id = param_ids[i]
+ for j in range(param_modes):
+ key = param_keys[j]
+ val = param_dict[key][param_id[j]]
+ if type(key) is tuple:
+ assert len(key) == len(val)
+ for k in range(len(key)):
+ one += get_param_str(key[k], val[k])
+ name += f'{val[k]},'
+ name=name[:-1]+'-'
+ else:
+ one += get_param_str(key, val)
+ name += f'{val}-'
+ params.append(one)
+ name=name.replace(' ','')
+ print(name)
+ expnames.append(name[:-1])
+ # print(params)
+ return params, expnames
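+# Example (hypothetical values): get_param_list expands param_dict into its Cartesian product, e.g.
+#   get_param_list({'data_name': ['lego', 'ship'], 'data_dim_color': [13, 27]})
+# returns 4 argument strings such as '--datadir <datafolder>/lego --data_dim_color 13 '
+# together with matching experiment names such as 'lego-13'.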
+
+
+
+
+
+
+
+if __name__ == '__main__':
+
+
+
+ # nerf
+ expFolder = "nerf/"
+ # parameters to iterate, use tuple to couple multiple parameters
+ datafolder = '/mnt/new_disk_2/anpei/Dataset/nerf_synthetic/'
+ param_dict = {
+ 'data_name': ['ship', 'mic', 'chair', 'lego', 'drums', 'ficus', 'hotdog', 'materials'],
+ 'data_dim_color': [13, 27, 54]
+ }
+
+ # n_iters = 30000
+ # for data_name in ['Robot']:#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
+ # cmd = f'CUDA_VISIBLE_DEVICES={cuda} python train.py ' \
+ # f'--dataset_name nsvf --datadir /mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/{data_name} '\
+ # f'--expname {data_name} --batch_size {batch_size} ' \
+ # f'--n_iters {n_iters} ' \
+ # f'--N_voxel_init {128**3} --N_voxel_final {300**3} '\
+ # f'--N_vis {5} ' \
+ # f'--n_lamb_sigma "[16,16,16]" --n_lamb_sh "[48,48,48]" ' \
+ # f'--upsamp_list "[2000, 3000, 4000, 5500,7000]" --update_AlphaMask_list "[3000,4000]" ' \
+ # f'--shadingMode MLP_Fea --fea2denseAct softplus --view_pe {2} --fea_pe {2} ' \
+ # f'--L1_weight_inital {8e-5} --L1_weight_rest {4e-5} --rm_weight_mask_thre {1e-4} --add_timestamp 0 ' \
+ # f'--render_test 1 '
+ # print(cmd)
+ # os.system(cmd)
+
+ # nsvf
+ # expFolder = "nsvf_0227/"
+ # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/Synthetic_NSVF/'
+ # param_dict = {
+ # 'data_name': ['Robot','Steamtrain','Bike','Lifestyle','Palace','Spaceship','Toad','Wineholder'],#'Bike','Lifestyle','Palace','Robot','Spaceship','Steamtrain','Toad','Wineholder'
+ # 'shadingMode': ['SH'],
+ # ('n_lamb_sigma', 'n_lamb_sh'): [ ("[8,8,8]", "[8,8,8]")],
+ # ('view_pe', 'fea_pe', 'featureC','fea2denseAct','N_voxel_init') : [(2, 2, 128, 'softplus',128**3)],
+ # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'):[(4e-5, 4e-5, 1e-4)],
+ # ('n_iters','N_voxel_final'): [(30000,300**3)],
+ # ('dataset_name','N_vis','render_test') : [("nsvf",5,1)],
+ # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[3000,4000]")]
+ #
+ # }
+
+ # tankstemple
+ # expFolder = "tankstemple_0304/"
+ # datafolder = '/mnt/new_disk_2/anpei/Dataset/TeRF/TanksAndTemple/'
+ # param_dict = {
+ # 'data_name': ['Truck','Barn','Caterpillar','Family','Ignatius'],
+ # 'shadingMode': ['MLP_Fea'],
+ # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,16,16]", "[48,48,48]")],
+ # ('view_pe', 'fea_pe','fea2denseAct','N_voxel_init','render_test') : [(2, 2, 'softplus',128**3,1)],
+ # ('TV_weight_density','TV_weight_app'):[(0.1,0.01)],
+ # # ('L1_weight_inital', 'L1_weight_rest', 'rm_weight_mask_thre'): [(4e-5, 4e-5, 1e-4)],
+ # ('n_iters','N_voxel_final'): [(15000,300**3)],
+ # ('dataset_name','N_vis') : [("tankstemple",5)],
+ # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2000,4000]")]
+ # }
+
+ # llff
+ # expFolder = "real_iconic/"
+ # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/real_iconic/'
+ # List = os.listdir(datafolder)
+ # param_dict = {
+ # 'data_name': List,
+ # ('shadingMode', 'view_pe', 'fea_pe','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 'relu',512,128**3)],
+ # ('n_lamb_sigma', 'n_lamb_sh') : [("[16,4,4]", "[48,12,12]")],
+ # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
+ # ('n_iters','N_voxel_final'): [(25000,640**3)],
+ # ('dataset_name','downsample_train','ndc_ray','N_vis','render_path') : [("llff",4.0, 1,-1,1)],
+ # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
+ # }
+
+ # expFolder = "llff/"
+ # datafolder = '/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data'
+ # param_dict = {
+ # 'data_name': ['fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'],#'fern', 'flower', 'room', 'leaves', 'horns', 'trex', 'fortress', 'orchids'
+ # ('n_lamb_sigma', 'n_lamb_sh'): [("[16,4,4]", "[48,12,12]")],
+ # ('shadingMode', 'view_pe', 'fea_pe', 'featureC','fea2denseAct', 'nSamples','N_voxel_init') : [('MLP_Fea', 0, 0, 128, 'relu',512,128**3),('SH', 0, 0, 128, 'relu',512,128**3)],
+ # ('TV_weight_density', 'TV_weight_app'):[(1.0,1.0)],
+ # ('n_iters','N_voxel_final'): [(25000,640**3)],
+ # ('dataset_name','downsample_train','ndc_ray','N_vis','render_test','render_path') : [("llff",4.0, 1,-1,1,1)],
+ # ('upsamp_list','update_AlphaMask_list'): [("[2000,3000,4000,5500,7000]","[2500]")],
+ # }
+
+ #setting available gpus
+ gpus_que = queue.Queue(3)
+ for i in [1,2,3]:
+ gpus_que.put(i)
+
+ os.makedirs(f"log/{expFolder}", exist_ok=True)
+
+ def run_program(gpu, expname, param):
+ cmd = f'CUDA_VISIBLE_DEVICES={gpu} python train.py ' \
+ f'--expname {expname} --basedir ./log/{expFolder} --config configs/lego.txt ' \
+ f'{param}' \
+ f'> "log/{expFolder}{expname}/{expname}.txt"'
+ print(cmd)
+ os.system(cmd)
+ gpus_que.put(gpu)
+
+ params, expnames = get_param_list(param_dict)
+
+
+ logFolder=f"log/{expFolder}"
+ os.makedirs(logFolder, exist_ok=True)
+
+ ths = []
+ for i in range(len(params)):
+
+ if getStopFolder(logFolder):
+ break
+
+
+ targetFolder = f"log/{expFolder}{expnames[i]}"
+ gpu = gpus_que.get()
+ getFolderLocker(logFolder)
+ if os.path.isdir(targetFolder):
+ releaseFolderLocker(logFolder)
+ gpus_que.put(gpu)
+ continue
+ else:
+ os.makedirs(targetFolder, exist_ok=True)
+ print("making",targetFolder, "running",expnames[i], params[i])
+ releaseFolderLocker(logFolder)
+
+
+ t = threading.Thread(target=run_program, args=(gpu, expnames[i], params[i]), daemon=True)
+ t.start()
+ ths.append(t)
+
+ for th in ths:
+ th.join()
\ No newline at end of file
diff --git a/TensoRF/extra/compute_metrics.py b/TensoRF/extra/compute_metrics.py
new file mode 100644
index 0000000..59efcb2
--- /dev/null
+++ b/TensoRF/extra/compute_metrics.py
@@ -0,0 +1,182 @@
+import os, math
+import numpy as np
+import scipy.signal
+from typing import List, Optional
+from PIL import Image
+import os
+import torch
+import configargparse
+
+__LPIPS__ = {}
+def init_lpips(net_name, device):
+ assert net_name in ['alex', 'vgg']
+ import lpips
+ print(f'init_lpips: lpips_{net_name}')
+ return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
+
+def rgb_lpips(np_gt, np_im, net_name, device):
+ if net_name not in __LPIPS__:
+ __LPIPS__[net_name] = init_lpips(net_name, device)
+ gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
+ im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
+ return __LPIPS__[net_name](gt, im, normalize=True).item()
+
+
+def findItem(items, target):
+ for one in items:
+ if one[:len(target)]==target:
+ return one
+ return None
+
+
+''' Evaluation metrics (ssim, lpips)
+'''
+def rgb_ssim(img0, img1, max_val,
+ filter_size=11,
+ filter_sigma=1.5,
+ k1=0.01,
+ k2=0.03,
+ return_map=False):
+ # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
+ assert len(img0.shape) == 3
+ assert img0.shape[-1] == 3
+ assert img0.shape == img1.shape
+
+ # Construct a 1D Gaussian blur filter.
+ hw = filter_size // 2
+ shift = (2 * hw - filter_size + 1) / 2
+ f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
+ filt = np.exp(-0.5 * f_i)
+ filt /= np.sum(filt)
+
+ # Blur in x and y (faster than the 2D convolution).
+ def convolve2d(z, f):
+ return scipy.signal.convolve2d(z, f, mode='valid')
+
+ filt_fn = lambda z: np.stack([
+ convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
+ for i in range(z.shape[-1])], -1)
+ mu0 = filt_fn(img0)
+ mu1 = filt_fn(img1)
+ mu00 = mu0 * mu0
+ mu11 = mu1 * mu1
+ mu01 = mu0 * mu1
+ sigma00 = filt_fn(img0**2) - mu00
+ sigma11 = filt_fn(img1**2) - mu11
+ sigma01 = filt_fn(img0 * img1) - mu01
+
+ # Clip the variances and covariances to valid values.
+ # Variance must be non-negative:
+ sigma00 = np.maximum(0., sigma00)
+ sigma11 = np.maximum(0., sigma11)
+ sigma01 = np.sign(sigma01) * np.minimum(
+ np.sqrt(sigma00 * sigma11), np.abs(sigma01))
+ c1 = (k1 * max_val)**2
+ c2 = (k2 * max_val)**2
+ numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
+ denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
+ ssim_map = numer / denom
+ ssim = np.mean(ssim_map)
+ return ssim_map if return_map else ssim
+
+
+if __name__ == '__main__':
+
+ parser = configargparse.ArgumentParser()
+ parser.add_argument("--exp", type=str, help="folder of exps")
+ parser.add_argument("--paramStr", type=str, help="str of params")
+ args = parser.parse_args()
+
+
+ # datanames = ['drums','hotdog','materials','ficus','lego','mic','ship','chair'] #['ship']#
+ # gtFolder = "/home/code-base/user_space/codes/nerf/data/nerf_synthetic"
+ # expFolder = "/home/code-base/user_space/codes/TensoRF/log/"+args.exp
+
+ # datanames = ['room','fortress', 'flower','orchids','leaves','horns','trex','fern'] #['ship']#
+ # gtFolder = "/mnt/new_disk_2/anpei/Dataset/MVSNeRF/nerf_llff_data/"
+ # expFolder = "/mnt/new_disk_2/anpei/code/TensoRF/log/"+args.exp
+ paramStr = args.paramStr
+ fileNum = 200
+
+
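+ # NOTE: `datanames`, `gtFolder` and `expFolder` are only defined in the commented-out
+ # blocks above; uncomment the block matching your dataset before running this script.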
+ expitems = os.listdir(expFolder)
+ finalFolder = f'{expFolder}/finals/{paramStr}'
+ outFile = f'{finalFolder}/{paramStr}_metrics.txt'
+ os.makedirs(finalFolder, exist_ok=True)
+
+ expitems.sort(reverse=True)
+
+
+ with open(outFile, 'w') as f:
+ all_psnr = []
+ all_ssim = []
+ all_alex = []
+ all_vgg = []
+ for dataname in datanames:
+
+
+ gtstr = gtFolder+"/"+dataname+"/test/r_%d.png"
+ expname = findItem(expitems, f'{paramStr}-{dataname}')
+ print("expname: ", expname)
+ if expname is None:
+ print("no ",dataname, "exists")
+ continue
+ resultstr = expFolder+"/"+expname+"/imgs_test_all/"+ dataname+"-"+paramStr+ "_%03d.png"
+ metric_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_mean.txt'
+ video_file = f'{expFolder}/{expname}/imgs_test_all/{paramStr}-{dataname}_video.mp4'
+
+ exist_metric=False
+ if os.path.isfile(metric_file):
+ metrics = np.loadtxt(metric_file)
+ print(metrics, metrics.tolist())
+ if metrics.size == 4:
+ psnr, ssim, l_a, l_v = metrics.tolist()
+ exist_metric = True
+ os.system(f"cp {video_file} {finalFolder}/")
+
+ if not exist_metric:
+ psnrs = []
+ ssims = []
+ l_alex = []
+ l_vgg = []
+ for i in range(fileNum):
+ gt = np.asarray(Image.open(gtstr%i),dtype=np.float32) / 255.0
+ gtmask = gt[...,[3]]
+ gt = gt[...,:3]
+ gt = gt*gtmask + (1-gtmask)
+ img = np.asarray(Image.open(resultstr%i),dtype=np.float32)[...,:3] / 255.0
+ # print(gt[0,0],img[0,0],gt.shape, img.shape, gt.max(), img.max())
+
+
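+ # PSNR for images in [0, 1]: -10 * log10(mean squared error)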
+ psnr = -10. * np.log10(np.mean(np.square(img - gt)))
+ ssim = rgb_ssim(img, gt, 1)
+ lpips_alex = rgb_lpips(gt, img, 'alex','cuda')
+ lpips_vgg = rgb_lpips(gt, img, 'vgg','cuda')
+
+ print(i, psnr, ssim, lpips_alex, lpips_vgg)
+ psnrs.append(psnr)
+ ssims.append(ssim)
+ l_alex.append(lpips_alex)
+ l_vgg.append(lpips_vgg)
+ psnr = np.mean(np.array(psnrs))
+ ssim = np.mean(np.array(ssims))
+ l_a = np.mean(np.array(l_alex))
+ l_v = np.mean(np.array(l_vgg))
+
+ rS=f'{dataname} : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}'
+ print(rS)
+ f.write(rS+"\n")
+
+ all_psnr.append(psnr)
+ all_ssim.append(ssim)
+ all_alex.append(l_a)
+ all_vgg.append(l_v)
+
+ psnr = np.mean(np.array(all_psnr))
+ ssim = np.mean(np.array(all_ssim))
+ l_a = np.mean(np.array(all_alex))
+ l_v = np.mean(np.array(all_vgg))
+
+ rS=f'mean : psnr {psnr} ssim {ssim} l_a {l_a} l_v {l_v}'
+ print(rS)
+ f.write(rS+"\n")
\ No newline at end of file
diff --git a/TensoRF/models/__init__.py b/TensoRF/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/TensoRF/models/cosine_transform.py b/TensoRF/models/cosine_transform.py
new file mode 100644
index 0000000..4285b41
--- /dev/null
+++ b/TensoRF/models/cosine_transform.py
@@ -0,0 +1,107 @@
+import torch
+
+
+def dct(coefs, coords=None):
+ '''
+ coefs: [..., C] # C: n_coefs
+ coords: [..., S] # S: n_samples
+ '''
+ if coords is None:
+ coords = torch.ones_like(coefs) \
+ * torch.arange(coefs.size(-1)).to(coefs.device) # \
+ # / coefs.size(-1)
+ # cos = torch.cos(torch.pi * coords.unsqueeze(-1)
+ cos = torch.cos(torch.pi * (coords.unsqueeze(-1) + 0.5) / coefs.size(-1)
+ * (torch.arange(coefs.size(-1)).to(coefs.device) + 0.5))
+ # cos = torch.cos(torch.pi * (coords.unsqueeze(-1) + 0.) / coefs.size(-1)
+ # * (torch.arange(coefs.size(-1)).to(coefs.device) + 0.))
+ return torch.einsum('...C,...SC->...S', coefs*(2/coefs.size(-1))**0.5, cos)
+
+
+def dctn(coefs, axes=None):
+ if axes is None:
+ axes = tuple(range(len(coefs.shape)))
+ out = coefs
+ for ax in axes:
+ out = out.transpose(-1, ax)
+ out = dct(out)
+ out = out.transpose(-1, ax)
+ return out
+
+
+def idctn(coefs, axes=None, n_out=None, **kwargs):
+ if axes is None:
+ axes = tuple(range(len(coefs.shape)))
+
+ if n_out is None or isinstance(n_out, int):
+ n_out = [n_out] * len(axes)
+
+ out = coefs
+ for ax, n_o in zip(axes, n_out):
+ out = out.transpose(-1, ax)
+ out = idct(out, n_o, **kwargs)
+ out = out.transpose(-1, ax)
+ return out
+
+
+def idct(coefs, n_out=None):
+ N = coefs.size(-1)
+ if n_out is None:
+ n_out = N
+ '''
+ # TYPE II
+ out = torch.cos(torch.pi * (torch.arange(N).unsqueeze(-1) + 0.5)
+ * torch.arange(1, N) / N)
+ out = 2 * torch.einsum('...C,...SC->...S', coefs[..., 1:], out)
+ return out + coefs[..., :1]
+ '''
+ # TYPE IV
+ out = torch.cos(torch.pi * (torch.arange(N).to(coefs.device) + 0.5) / N
+ * (torch.linspace(0, N-1, n_out).unsqueeze(-1).to(coefs.device) + 0.5))
+ # CCT version
+ # out = torch.cos(torch.pi / N * (torch.arange(N).to(coefs.device))
+ # * (torch.linspace(0, N-1, n_out).unsqueeze(-1).to(coefs.device)))
+ # return 2 * torch.einsum('...C,...SC->...S', coefs, out)
+ return torch.einsum('...C,...SC->...S', coefs*(2/N)**0.5, out)
+
+
+if __name__ == '__main__':
+ from scipy.fftpack import dct as org_dct
+ from scipy.fftpack import dctn as org_dctn
+ from scipy.fftpack import idct as org_idct
+ from scipy.fftpack import idctn as org_idctn
+
+ arr = torch.randn((1, 8, 240, 250)) * 10
+ print((arr - dctn(idctn(arr, (-2, -1)), (-2, -1))).abs().max())
+ print((arr - idctn(dctn(arr, (-2, -1)), (-2, -1))).abs().max())
+ print((arr - dctn(idctn(arr, (-2,)), (-2,))).abs().max())
+ print((arr - idctn(dctn(arr, (-2,)), (-2,))).abs().max())
+ print((arr - idctn(dctn(arr, (-2,)), (-2,))).abs().max())
+ print((org_idctn(arr.numpy(), 4, axes=(-2, -1), norm='ortho')
+ - idctn(arr, (-2, -1)).numpy()).max())
+ '''
+ arr = torch.randn((3, 8))
+
+ print(arr) # org_idct(arr.numpy(), 4))
+ print(dct(idct(arr)))
+ print(idct(dct(arr)))
+ print(idct(arr).numpy())
+ print(org_idctn(arr.numpy(), 4, axes=(-2, -1), norm='ortho') - idctn(arr).numpy())
+
+ print(arr)
+ print(org_dct(arr.numpy()))
+ print(org_dct(arr.numpy()) - dct(arr, torch.arange(8) / 8).numpy())
+ print()
+ print(org_dct(org_dct(arr.numpy()), type=3))
+ print(org_dct(org_dct(arr.numpy()), type=3)
+ - idct(dct(arr, torch.arange(8) / 8)).numpy())
+ '''
+
+ # print(idct(dct(arr, torch.arange(16) / 16)) / torch.sqrt(torch.tensor(16)))
+ # print(idct(dct(arr)))
+ # print(org_dct(arr.numpy()) - dct(arr).numpy())
+
+ # ndarr = torch.randn((3, 2, 4, 5))
+ # axes = (3, ) # (1, 2, 3)
+ # print(org_dctn(ndarr.numpy(), axes=axes) - dctn(ndarr, axes=axes).numpy())
+
diff --git a/TensoRF/models/sh.py b/TensoRF/models/sh.py
new file mode 100644
index 0000000..27e3cad
--- /dev/null
+++ b/TensoRF/models/sh.py
@@ -0,0 +1,133 @@
+import torch
+
+################## sh function ##################
+C0 = 0.28209479177387814
+C1 = 0.4886025119029199
+C2 = [
+ 1.0925484305920792,
+ -1.0925484305920792,
+ 0.31539156525252005,
+ -1.0925484305920792,
+ 0.5462742152960396
+]
+C3 = [
+ -0.5900435899266435,
+ 2.890611442640554,
+ -0.4570457994644658,
+ 0.3731763325901154,
+ -0.4570457994644658,
+ 1.445305721320277,
+ -0.5900435899266435
+]
+C4 = [
+ 2.5033429417967046,
+ -1.7701307697799304,
+ 0.9461746957575601,
+ -0.6690465435572892,
+ 0.10578554691520431,
+ -0.6690465435572892,
+ 0.47308734787878004,
+ -1.7701307697799304,
+ 0.6258357354491761,
+]
+
+def eval_sh(deg, sh, dirs):
+ """
+ Evaluate spherical harmonics at unit directions
+ using hardcoded SH polynomials.
+ Works with torch/np/jnp.
+ ... Can be 0 or more batch dimensions.
+ :param deg: int SH max degree. Currently, 0-4 supported
+ :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2)
+ :param dirs: torch.Tensor unit directions (..., 3)
+ :return: (..., C)
+ """
+ assert deg <= 4 and deg >= 0
+ assert (deg + 1) ** 2 == sh.shape[-1]
+ C = sh.shape[-2]
+
+ result = C0 * sh[..., 0]
+ if deg > 0:
+ x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3]
+ result = (result -
+ C1 * y * sh[..., 1] +
+ C1 * z * sh[..., 2] -
+ C1 * x * sh[..., 3])
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result = (result +
+ C2[0] * xy * sh[..., 4] +
+ C2[1] * yz * sh[..., 5] +
+ C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] +
+ C2[3] * xz * sh[..., 7] +
+ C2[4] * (xx - yy) * sh[..., 8])
+
+ if deg > 2:
+ result = (result +
+ C3[0] * y * (3 * xx - yy) * sh[..., 9] +
+ C3[1] * xy * z * sh[..., 10] +
+ C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] +
+ C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] +
+ C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] +
+ C3[5] * z * (xx - yy) * sh[..., 14] +
+ C3[6] * x * (xx - 3 * yy) * sh[..., 15])
+ if deg > 3:
+ result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] +
+ C4[1] * yz * (3 * xx - yy) * sh[..., 17] +
+ C4[2] * xy * (7 * zz - 1) * sh[..., 18] +
+ C4[3] * yz * (7 * zz - 3) * sh[..., 19] +
+ C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] +
+ C4[5] * xz * (7 * zz - 3) * sh[..., 21] +
+ C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] +
+ C4[7] * xz * (xx - 3 * yy) * sh[..., 23] +
+ C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24])
+ return result
+
+def eval_sh_bases(deg, dirs):
+ """
+ Evaluate spherical harmonics bases at unit directions,
+ without taking linear combination.
+ At each point, the final result may then be
+ obtained through simple multiplication.
+ :param deg: int SH max degree. Currently, 0-4 supported
+ :param dirs: torch.Tensor (..., 3) unit directions
+ :return: torch.Tensor (..., (deg+1) ** 2)
+ """
+ assert deg <= 4 and deg >= 0
+ result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device)
+ result[..., 0] = C0
+ if deg > 0:
+ x, y, z = dirs.unbind(-1)
+ result[..., 1] = -C1 * y
+ result[..., 2] = C1 * z
+ result[..., 3] = -C1 * x
+ if deg > 1:
+ xx, yy, zz = x * x, y * y, z * z
+ xy, yz, xz = x * y, y * z, x * z
+ result[..., 4] = C2[0] * xy
+ result[..., 5] = C2[1] * yz
+ result[..., 6] = C2[2] * (2.0 * zz - xx - yy)
+ result[..., 7] = C2[3] * xz
+ result[..., 8] = C2[4] * (xx - yy)
+
+ if deg > 2:
+ result[..., 9] = C3[0] * y * (3 * xx - yy)
+ result[..., 10] = C3[1] * xy * z
+ result[..., 11] = C3[2] * y * (4 * zz - xx - yy)
+ result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy)
+ result[..., 13] = C3[4] * x * (4 * zz - xx - yy)
+ result[..., 14] = C3[5] * z * (xx - yy)
+ result[..., 15] = C3[6] * x * (xx - 3 * yy)
+
+ if deg > 3:
+ result[..., 16] = C4[0] * xy * (xx - yy)
+ result[..., 17] = C4[1] * yz * (3 * xx - yy)
+ result[..., 18] = C4[2] * xy * (7 * zz - 1)
+ result[..., 19] = C4[3] * yz * (7 * zz - 3)
+ result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3)
+ result[..., 21] = C4[5] * xz * (7 * zz - 3)
+ result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1)
+ result[..., 23] = C4[7] * xz * (xx - 3 * yy)
+ result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy))
+ return result
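+
+# Example (a minimal sketch): eval_sh_bases pairs with per-channel SH coefficients the same
+# way SHRender in models/tensorBase.py does, e.g. for degree 2 (9 basis values per channel):
+#   sh_mult = eval_sh_bases(2, viewdirs)[:, None]          # (N, 1, 9)
+#   rgb_sh = features.view(-1, 3, sh_mult.shape[-1])       # (N, 3, 9)
+#   rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)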
diff --git a/TensoRF/models/tensoRF.py b/TensoRF/models/tensoRF.py
new file mode 100644
index 0000000..e237b4e
--- /dev/null
+++ b/TensoRF/models/tensoRF.py
@@ -0,0 +1,445 @@
+from .tensorBase import *
+
+
+def min_max_quantize(inputs, bits):
+ if bits == 32:
+ return inputs
+
+ # rounding: quantize to 2**bits uniformly spaced levels over the min-max range
+ min_value = torch.amin(inputs)
+ max_value = torch.amax(inputs)
+ scale = (max_value - min_value).clamp(min=1e-8) / (2 ** bits - 1)
+
+ rounded = torch.round((inputs - min_value) / scale) * scale + min_value
+
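+ # straight-through estimator: the forward pass returns the quantized values while the
+ # backward pass treats rounding as identity, so gradients flow to `inputs`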
+ return (rounded - inputs).detach() + inputs
+
+
+class TensorVM(TensorBase):
+ def __init__(self, aabb, gridSize, device, **kargs):
+ super(TensorVM, self).__init__(aabb, gridSize, device, **kargs)
+
+ def init_svd_volume(self, res, device):
+ self.plane_coef = torch.nn.Parameter(
+ 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, res), device=device))
+ self.line_coef = torch.nn.Parameter(
+ 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, 1), device=device))
+ self.basis_mat = torch.nn.Linear(self.app_n_comp * 3, self.app_dim, bias=False, device=device)
+
+ def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
+ grad_vars = [{'params': self.line_coef, 'lr': lr_init_spatialxyz}, {'params': self.plane_coef, 'lr': lr_init_spatialxyz},
+ {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
+ if isinstance(self.renderModule, torch.nn.Module):
+ grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
+ return grad_vars
+
+ def compute_features(self, xyz_sampled):
+ coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach()
+ coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
+ coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach()
+
+ plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, align_corners=True).view(
+ -1, *xyz_sampled.shape[:1])
+ line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view(
+ -1, *xyz_sampled.shape[:1])
+
+ sigma_feature = torch.sum(plane_feats * line_feats, dim=0)
+
+ plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(3 * self.app_n_comp, -1)
+ line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(3 * self.app_n_comp, -1)
+
+ app_features = self.basis_mat((plane_feats * line_feats).T)
+
+ return sigma_feature, app_features
+
+ def compute_densityfeature(self, xyz_sampled):
+ coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
+ coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
+ coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
+
+ plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, align_corners=True).view(
+ -1, *xyz_sampled.shape[:1])
+ line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view(
+ -1, *xyz_sampled.shape[:1])
+
+ sigma_feature = torch.sum(plane_feats * line_feats, dim=0)
+
+ return sigma_feature
+
+ def compute_appfeature(self, xyz_sampled):
+ coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2)
+ coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
+ coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
+
+ plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(3 * self.app_n_comp, -1)
+ line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(3 * self.app_n_comp, -1)
+
+ app_features = self.basis_mat((plane_feats * line_feats).T)
+
+ return app_features
+
+ def vectorDiffs(self, vector_comps):
+ total = 0
+
+ for idx in range(len(vector_comps)):
+ # print(self.line_coef.shape, vector_comps[idx].shape)
+ n_comp, n_size = vector_comps[idx].shape[:-1]
+
+ dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2))
+ # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape)
+ non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
+ # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape,non_diagonal.shape)
+ total = total + torch.mean(torch.abs(non_diagonal))
+ return total
+
+ def vector_comp_diffs(self):
+ return self.vectorDiffs(self.line_coef[:,-self.density_n_comp:]) + self.vectorDiffs(self.line_coef[:,:self.app_n_comp])
+
+ @torch.no_grad()
+ def up_sampling_VM(self, plane_coef, line_coef, res_target):
+ for i in range(len(self.vecMode)):
+ vec_id = self.vecMode[i]
+ mat_id_0, mat_id_1 = self.matMode[i]
+
+ plane_coef[i] = torch.nn.Parameter(
+ F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear',
+ align_corners=True))
+ line_coef[i] = torch.nn.Parameter(
+ F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
+
+ return plane_coef, line_coef
+
+ @torch.no_grad()
+ def upsample_volume_grid(self, res_target):
+ scale = res_target[0]/self.line_coef.shape[2] #assuming xyz have the same scale
+ plane_coef = F.interpolate(self.plane_coef.detach().data, scale_factor=scale, mode='bilinear',align_corners=True)
+ line_coef = F.interpolate(self.line_coef.detach().data, size=(res_target[0],1), mode='bilinear',align_corners=True)
+ self.plane_coef, self.line_coef = torch.nn.Parameter(plane_coef), torch.nn.Parameter(line_coef)
+ self.compute_stepSize(res_target)
+ print(f'upsampling to {res_target}')
+
+
+class TensorVMSplit(TensorBase):
+ def __init__(self, aabb, gridSize, device, **kargs):
+ super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs)
+
+ def init_svd_volume(self, res, device):
+ self.density_plane, self.density_line = self.init_one_svd(
+ self.density_n_comp, self.gridSize, 0.1, device)
+ self.app_plane, self.app_line = self.init_one_svd(
+ self.app_n_comp, self.gridSize, 0.1, device)
+ self.basis_mat = torch.nn.Linear(
+ sum(self.app_n_comp), self.app_dim, bias=False).to(device)
+
+ def init_one_svd(self, n_component, gridSize, scale, device):
+ plane_coef, line_coef = [], []
+
+ for i in range(len(self.vecMode)):
+ vec_id = self.vecMode[i]
+ mat_id_0, mat_id_1 = self.matMode[i]
+ plane_coef.append(torch.nn.Parameter(
+ scale * torch.randn((1, n_component[i], gridSize[mat_id_1],
+ gridSize[mat_id_0]))))
+ line_coef.append(torch.nn.Parameter(
+ scale * torch.randn((1, n_component[i], gridSize[vec_id], 1))))
+
+ return (torch.nn.ParameterList(plane_coef).to(device),
+ torch.nn.ParameterList(line_coef).to(device))
+
+ def get_optparam_groups(self, lr_init_spatialxyz=0.02,
+ lr_init_network=0.001):
+ grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
+ {'params': self.density_plane, 'lr': lr_init_spatialxyz},
+ {'params': self.app_line, 'lr': lr_init_spatialxyz},
+ {'params': self.app_plane, 'lr': lr_init_spatialxyz},
+ {'params': self.basis_mat.parameters(),
+ 'lr':lr_init_network}]
+ if isinstance(self.renderModule, torch.nn.Module):
+ grad_vars += [{'params':self.renderModule.parameters(),
+ 'lr':lr_init_network}]
+ return grad_vars
+
+ def compute_densityfeature(self, points):
+ # plane + line basis
+ # [3, B, 1, 2]
+ coordinate_plane = points[..., self.matMode].transpose(0, -2) \
+ .view(3, -1, 1, 2)
+ coordinate_line = points[..., self.vecMode, None].transpose(0, -2)
+ coordinate_line = F.pad(coordinate_line, (1, 0)).reshape(3, -1, 1, 2)
+
+ sigma_feature = torch.zeros((points.shape[0],), device=points.device)
+
+ for idx in range(len(self.density_plane)):
+ # plane = self.density_plane[idx]
+ # line = self.density_line[idx]
+ plane = min_max_quantize(self.density_plane[idx], self.grid_bit)
+ line = min_max_quantize(self.density_line[idx], self.grid_bit)
+
+ plane_coef_point = F.grid_sample(
+ plane, coordinate_plane[[idx]],
+ align_corners=True).view(-1, *points.shape[:1])
+
+ line_coef_point = F.grid_sample(
+ line, coordinate_line[[idx]],
+ align_corners=True).view(-1, *points.shape[:1])
+
+ sigma_feature += torch.sum(plane_coef_point*line_coef_point, dim=0)
+
+ return sigma_feature
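+ # The loop above evaluates the vector-matrix (VM) decomposition for density: the feature at a
+ # point is the sum over components of plane_XY(x, y) * line_Z(z) + plane_XZ(x, z) * line_Y(y)
+ # + plane_YZ(y, z) * line_X(x), following the matMode/vecMode ordering set in TensorBase.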
+
+ def compute_appfeature(self, points):
+ # plane + line basis
+ # [3, B, 1, 2]
+ coordinate_plane = points[..., self.matMode].transpose(0, -2) \
+ .view(3, -1, 1, 2)
+ coordinate_line = points[..., self.vecMode, None].transpose(0, -2)
+ coordinate_line = F.pad(coordinate_line, (1, 0)).reshape(3, -1, 1, 2)
+
+ plane_coef_point, line_coef_point = [], []
+ for idx in range(len(self.app_plane)):
+ # plane = self.app_plane[idx]
+ # line = self.app_line[idx]
+ plane = min_max_quantize(self.app_plane[idx], self.grid_bit)
+ line = min_max_quantize(self.app_line[idx], self.grid_bit)
+
+ plane_coef_point.append(F.grid_sample(
+ plane, coordinate_plane[[idx]],
+ align_corners=True).view(-1, *points.shape[:1]))
+ line_coef_point.append(F.grid_sample(
+ line, coordinate_line[[idx]],
+ align_corners=True).view(-1, *points.shape[:1]))
+
+ plane_coef_point = torch.cat(plane_coef_point)
+ line_coef_point = torch.cat(line_coef_point)
+
+ return self.basis_mat((plane_coef_point * line_coef_point).T)
+
+ @torch.no_grad()
+ def upsample_volume_grid(self, res_target):
+ self.app_plane, self.app_line = self.up_sampling_VM(
+ self.app_plane, self.app_line, res_target)
+ self.density_plane, self.density_line = self.up_sampling_VM(
+ self.density_plane, self.density_line, res_target)
+
+ self.update_stepSize(res_target)
+ print(f'upsampling to {res_target}')
+
+ @torch.no_grad()
+ def up_sampling_VM(self, plane_coef, line_coef, res_target):
+ for i in range(len(self.vecMode)):
+ vec_id = self.vecMode[i]
+ mat_id_0, mat_id_1 = self.matMode[i]
+ plane_coef[i] = torch.nn.Parameter(
+ F.interpolate(plane_coef[i].data,
+ size=(res_target[mat_id_1], res_target[mat_id_0]),
+ mode='bilinear', align_corners=True))
+ line_coef[i] = torch.nn.Parameter(
+ F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1),
+ mode='bilinear', align_corners=True))
+
+ return plane_coef, line_coef
+
+ @torch.no_grad()
+ def shrink(self, new_aabb):
+ print("====> shrinking ...")
+ xyz_min, xyz_max = new_aabb
+ t_l = (xyz_min - self.aabb[0]) / self.units
+ t_l = torch.round(t_l).long()
+
+ b_r = (xyz_max - self.aabb[0]) / self.units
+ b_r = torch.round(b_r).long() + 1
+ b_r = torch.stack([b_r, self.gridSize]).amin(0)
+
+ for i in range(len(self.vecMode)):
+ mode0 = self.vecMode[i]
+ self.density_line[i] = torch.nn.Parameter(
+ self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:])
+ self.app_line[i] = torch.nn.Parameter(
+ self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:])
+
+ mode0, mode1 = self.matMode[i]
+ self.density_plane[i] = torch.nn.Parameter(
+ self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],
+ t_l[mode0]:b_r[mode0]])
+ self.app_plane[i] = torch.nn.Parameter(
+ self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],
+ t_l[mode0]:b_r[mode0]])
+
+ if not torch.all(self.alphaMask.gridSize == self.gridSize):
+ t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
+ correct_aabb = torch.zeros_like(new_aabb)
+ correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
+ correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
+ print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
+ new_aabb = correct_aabb
+
+ newSize = b_r - t_l
+ self.aabb = new_aabb
+ self.update_stepSize((newSize[0], newSize[1], newSize[2]))
+
+ def vectorDiffs(self, vector_comps):
+ total = 0
+ for idx in range(len(vector_comps)):
+ n_comp, n_size = vector_comps[idx].shape[1:-1]
+
+ dotp = torch.matmul(
+ vector_comps[idx].view(n_comp, n_size),
+ vector_comps[idx].view(n_comp ,n_size).transpose(-1, -2))
+ non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1]
+ total = total + torch.mean(torch.abs(non_diagonal))
+ return total
+
+ def vector_comp_diffs(self):
+ return (self.vectorDiffs(self.density_line)
+ + self.vectorDiffs(self.app_line))
+
+ def density_L1(self):
+ total = 0
+ for idx in range(len(self.density_plane)):
+ total = total + torch.mean(torch.abs(self.density_plane[idx])) \
+ + torch.mean(torch.abs(self.density_line[idx]))
+ return total
+
+ def TV_loss_density(self, reg):
+ total = 0
+ for idx in range(len(self.density_plane)):
+ total = total + reg(self.density_plane[idx]) * 1e-2
+ return total
+
+ def TV_loss_app(self, reg):
+ total = 0
+ for idx in range(len(self.app_plane)):
+ total = total + reg(self.app_plane[idx]) * 1e-2
+ return total
+
+
+class TensorCP(TensorBase):
+ def __init__(self, aabb, gridSize, device, **kargs):
+ super(TensorCP, self).__init__(aabb, gridSize, device, **kargs)
+
+
+ def init_svd_volume(self, res, device):
+ self.density_line = self.init_one_svd(self.density_n_comp[0], self.gridSize, 0.2, device)
+ self.app_line = self.init_one_svd(self.app_n_comp[0], self.gridSize, 0.2, device)
+ self.basis_mat = torch.nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device)
+
+
+ def init_one_svd(self, n_component, gridSize, scale, device):
+ line_coef = []
+ for i in range(len(self.vecMode)):
+ vec_id = self.vecMode[i]
+ line_coef.append(
+ torch.nn.Parameter(scale * torch.randn((1, n_component, gridSize[vec_id], 1))))
+ return torch.nn.ParameterList(line_coef).to(device)
+
+
+ def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001):
+ grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz},
+ {'params': self.app_line, 'lr': lr_init_spatialxyz},
+ {'params': self.basis_mat.parameters(), 'lr':lr_init_network}]
+ if isinstance(self.renderModule, torch.nn.Module):
+ grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}]
+ return grad_vars
+
+ def compute_densityfeature(self, xyz_sampled):
+
+ coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
+ coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
+
+
+ line_coef_point = F.grid_sample(self.density_line[0], coordinate_line[[0]],
+ align_corners=True).view(-1, *xyz_sampled.shape[:1])
+ line_coef_point = line_coef_point * F.grid_sample(self.density_line[1], coordinate_line[[1]],
+ align_corners=True).view(-1, *xyz_sampled.shape[:1])
+ line_coef_point = line_coef_point * F.grid_sample(self.density_line[2], coordinate_line[[2]],
+ align_corners=True).view(-1, *xyz_sampled.shape[:1])
+ sigma_feature = torch.sum(line_coef_point, dim=0)
+
+
+ return sigma_feature
+
+ def compute_appfeature(self, xyz_sampled):
+
+ coordinate_line = torch.stack(
+ (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]]))
+ coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2)
+
+
+ line_coef_point = F.grid_sample(self.app_line[0], coordinate_line[[0]],
+ align_corners=True).view(-1, *xyz_sampled.shape[:1])
+ line_coef_point = line_coef_point * F.grid_sample(self.app_line[1], coordinate_line[[1]],
+ align_corners=True).view(-1, *xyz_sampled.shape[:1])
+ line_coef_point = line_coef_point * F.grid_sample(self.app_line[2], coordinate_line[[2]],
+ align_corners=True).view(-1, *xyz_sampled.shape[:1])
+
+ return self.basis_mat(line_coef_point.T)
+
+
+ @torch.no_grad()
+ def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target):
+
+ for i in range(len(self.vecMode)):
+ vec_id = self.vecMode[i]
+ density_line_coef[i] = torch.nn.Parameter(
+ F.interpolate(density_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
+ app_line_coef[i] = torch.nn.Parameter(
+ F.interpolate(app_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True))
+
+ return density_line_coef, app_line_coef
+
+ @torch.no_grad()
+ def upsample_volume_grid(self, res_target):
+ self.density_line, self.app_line = self.up_sampling_Vector(self.density_line, self.app_line, res_target)
+
+ self.update_stepSize(res_target)
+ print(f'upsampling to {res_target}')
+
+ @torch.no_grad()
+ def shrink(self, new_aabb):
+ print("====> shrinking ...")
+ xyz_min, xyz_max = new_aabb
+ t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units
+
+ t_l, b_r = torch.round(t_l).long(), torch.round(b_r).long() + 1
+ b_r = torch.stack([b_r, self.gridSize]).amin(0)
+
+
+ for i in range(len(self.vecMode)):
+ mode0 = self.vecMode[i]
+ self.density_line[i] = torch.nn.Parameter(
+ self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:]
+ )
+ self.app_line[i] = torch.nn.Parameter(
+ self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:]
+ )
+
+ if not torch.all(self.alphaMask.gridSize == self.gridSize):
+ t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1)
+ correct_aabb = torch.zeros_like(new_aabb)
+ correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1]
+ correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1]
+ print("aabb", new_aabb, "\ncorrect aabb", correct_aabb)
+ new_aabb = correct_aabb
+
+ newSize = b_r - t_l
+ self.aabb = new_aabb
+ self.update_stepSize((newSize[0], newSize[1], newSize[2]))
+
+ def density_L1(self):
+ total = 0
+ for idx in range(len(self.density_line)):
+ total = total + torch.mean(torch.abs(self.density_line[idx]))
+ return total
+
+ def TV_loss_density(self, reg):
+ total = 0
+ for idx in range(len(self.density_line)):
+ total = total + reg(self.density_line[idx]) * 1e-3
+ return total
+
+ def TV_loss_app(self, reg):
+ total = 0
+ for idx in range(len(self.app_line)):
+ total = total + reg(self.app_line[idx]) * 1e-3
+ return total
diff --git a/TensoRF/models/tensorBase.py b/TensoRF/models/tensorBase.py
new file mode 100644
index 0000000..487d35c
--- /dev/null
+++ b/TensoRF/models/tensorBase.py
@@ -0,0 +1,479 @@
+import torch
+import torch.nn
+import torch.nn.functional as F
+from .sh import eval_sh_bases
+import numpy as np
+import time
+
+
+def positional_encoding(positions, freqs):
+ freq_bands = 2 ** torch.arange(freqs, dtype=torch.float32,
+ device=positions.device)
+ pts = (positions[..., None] * freq_bands).reshape(
+ positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF)
+ pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1)
+ return pts
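+# positional_encoding returns 2 * freqs * D channels per point: sin and cos of the input
+# scaled by each power-of-two frequency band.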
+
+
+def raw2alpha(sigma, dist):
+ # sigma, dist [N_rays, N_samples]
+ alpha = 1. - torch.exp(-sigma*dist)
+
+ # T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. - alpha + 1e-10], -1), -1)
+ T = torch.cumprod(torch.cat([
+ torch.ones(alpha.shape[0], 1, device=alpha.device),
+ 1. - alpha + 1e-10], -1), -1)
+
+ weights = alpha * T[:, :-1] # [N_rays, N_samples]
+ return alpha, weights, T[:,-1:]
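+# raw2alpha implements standard volume-rendering compositing:
+#   alpha_i = 1 - exp(-sigma_i * dist_i), T_i = prod_{j<i}(1 - alpha_j), weight_i = alpha_i * T_i;
+# the last returned tensor is the transmittance remaining after the final sample.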
+
+
+def SHRender(xyz_sampled, viewdirs, features):
+ sh_mult = eval_sh_bases(2, viewdirs)[:, None]
+ rgb_sh = features.view(-1, 3, sh_mult.shape[-1])
+ rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5)
+ return rgb
+
+
+def RGBRender(xyz_sampled, viewdirs, features):
+
+ rgb = features
+ return rgb
+
+class AlphaGridMask(torch.nn.Module):
+ def __init__(self, device, aabb, alpha_volume):
+ super(AlphaGridMask, self).__init__()
+ self.device = device
+
+ self.aabb = aabb.to(self.device)
+ self.aabbSize = self.aabb[1] - self.aabb[0]
+ self.invgridSize = 1.0/self.aabbSize * 2
+ self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:])
+ self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device)
+
+ def sample_alpha(self, xyz_sampled):
+ xyz_sampled = self.normalize_coord(xyz_sampled)
+ alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1)
+
+ return alpha_vals
+
+ def normalize_coord(self, xyz_sampled):
+ return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1
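+ # normalize_coord maps world coordinates inside the aabb to [-1, 1], the range expected
+ # by F.grid_sample in sample_alpha above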
+
+
+class MLPRender_Fea(torch.nn.Module):
+ def __init__(self,inChanel, viewpe=6, feape=6, featureC=128):
+ super(MLPRender_Fea, self).__init__()
+
+ self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel
+ self.viewpe = viewpe
+ self.feape = feape
+ layer1 = torch.nn.Linear(self.in_mlpC, featureC)
+ layer2 = torch.nn.Linear(featureC, featureC)
+ layer3 = torch.nn.Linear(featureC,3)
+
+ self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
+ torch.nn.init.constant_(self.mlp[-1].bias, 0)
+
+ def forward(self, pts, viewdirs, features):
+ indata = [features, viewdirs]
+ if self.feape > 0:
+ indata += [positional_encoding(features, self.feape)]
+ if self.viewpe > 0:
+ indata += [positional_encoding(viewdirs, self.viewpe)]
+ mlp_in = torch.cat(indata, dim=-1)
+ rgb = self.mlp(mlp_in)
+ rgb = torch.sigmoid(rgb)
+
+ return rgb
+
+class MLPRender_PE(torch.nn.Module):
+ def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128):
+ super(MLPRender_PE, self).__init__()
+
+ self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel #
+ self.viewpe = viewpe
+ self.pospe = pospe
+ layer1 = torch.nn.Linear(self.in_mlpC, featureC)
+ layer2 = torch.nn.Linear(featureC, featureC)
+ layer3 = torch.nn.Linear(featureC,3)
+
+ self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
+ torch.nn.init.constant_(self.mlp[-1].bias, 0)
+
+ def forward(self, pts, viewdirs, features):
+ indata = [features, viewdirs]
+ if self.pospe > 0:
+ indata += [positional_encoding(pts, self.pospe)]
+ if self.viewpe > 0:
+ indata += [positional_encoding(viewdirs, self.viewpe)]
+ mlp_in = torch.cat(indata, dim=-1)
+ rgb = self.mlp(mlp_in)
+ rgb = torch.sigmoid(rgb)
+
+ return rgb
+
+class MLPRender(torch.nn.Module):
+ def __init__(self,inChanel, viewpe=6, featureC=128):
+ super(MLPRender, self).__init__()
+
+ self.in_mlpC = (3+2*viewpe*3) + inChanel
+ self.viewpe = viewpe
+
+ layer1 = torch.nn.Linear(self.in_mlpC, featureC)
+ layer2 = torch.nn.Linear(featureC, featureC)
+ layer3 = torch.nn.Linear(featureC,3)
+
+ self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3)
+ torch.nn.init.constant_(self.mlp[-1].bias, 0)
+
+ def forward(self, pts, viewdirs, features):
+ indata = [features, viewdirs]
+ if self.viewpe > 0:
+ indata += [positional_encoding(viewdirs, self.viewpe)]
+ mlp_in = torch.cat(indata, dim=-1)
+ rgb = self.mlp(mlp_in)
+ rgb = torch.sigmoid(rgb)
+
+ return rgb
+
+
+
+class TensorBase(torch.nn.Module):
+ def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27,
+ shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0],
+ density_shift = -10, alphaMask_thres=0.001, distance_scale=25, rayMarch_weight_thres=0.0001,
+ pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0,
+ fea2denseAct = 'softplus', grid_bit=32):
+ super(TensorBase, self).__init__()
+
+ self.density_n_comp = density_n_comp
+ self.app_n_comp = appearance_n_comp
+ self.app_dim = app_dim
+ self.aabb = aabb
+ self.alphaMask = alphaMask
+ self.device=device
+
+ self.density_shift = density_shift
+ self.alphaMask_thres = alphaMask_thres
+ self.distance_scale = distance_scale
+ self.rayMarch_weight_thres = rayMarch_weight_thres
+ self.fea2denseAct = fea2denseAct
+
+ self.near_far = near_far
+ self.step_ratio = step_ratio
+
+
+ self.update_stepSize(gridSize)
+
+ self.matMode = [[0,1], [0,2], [1,2]]
+ self.vecMode = [2, 1, 0]
+ self.comp_w = [1,1,1]
+
+ self.grid_bit = grid_bit
+
+ self.init_svd_volume(gridSize[0], device)
+
+ self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC
+ self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device)
+
+ def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device):
+ if shadingMode == 'MLP_PE':
+ self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device)
+ elif shadingMode == 'MLP_Fea':
+ self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device)
+ elif shadingMode == 'MLP':
+ self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device)
+ elif shadingMode == 'SH':
+ self.renderModule = SHRender
+ elif shadingMode == 'RGB':
+ assert self.app_dim == 3
+ self.renderModule = RGBRender
+ else:
+ print("Unrecognized shading module")
+ exit()
+ print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe)
+ print(self.renderModule)
+
+ def update_stepSize(self, gridSize):
+ print("aabb", self.aabb.view(-1))
+ print("grid size", gridSize)
+ self.aabbSize = self.aabb[1] - self.aabb[0]
+ self.invaabbSize = 2.0/self.aabbSize
+ self.gridSize= torch.LongTensor(gridSize).to(self.device)
+ self.units=self.aabbSize / (self.gridSize-1)
+ self.stepSize=torch.mean(self.units)*self.step_ratio
+ self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize)))
+ self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1
+ print("sampling step size: ", self.stepSize)
+ print("sampling number: ", self.nSamples)
+
+ def init_svd_volume(self, res, device):
+ pass
+
+ def compute_features(self, xyz_sampled):
+ pass
+
+ def compute_densityfeature(self, xyz_sampled):
+ pass
+
+ def compute_appfeature(self, xyz_sampled):
+ pass
+
+ def normalize_coord(self, xyz_sampled):
+ return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1
+
+ def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001):
+ pass
+
+ def get_kwargs(self):
+ return {
+ 'aabb': self.aabb,
+ 'gridSize':self.gridSize.tolist(),
+ 'density_n_comp': self.density_n_comp,
+ 'appearance_n_comp': self.app_n_comp,
+ 'app_dim': self.app_dim,
+
+ 'density_shift': self.density_shift,
+ 'alphaMask_thres': self.alphaMask_thres,
+ 'distance_scale': self.distance_scale,
+ 'rayMarch_weight_thres': self.rayMarch_weight_thres,
+ 'fea2denseAct': self.fea2denseAct,
+
+ 'near_far': self.near_far,
+ 'step_ratio': self.step_ratio,
+
+ 'shadingMode': self.shadingMode,
+ 'pos_pe': self.pos_pe,
+ 'view_pe': self.view_pe,
+ 'fea_pe': self.fea_pe,
+ 'featureC': self.featureC,
+
+ 'grid_bit': self.grid_bit
+ }
+
+ def save(self, path):
+ kwargs = self.get_kwargs()
+ ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()}
+ if self.alphaMask is not None:
+ alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy()
+ ckpt.update({'alphaMask.shape':alpha_volume.shape})
+ ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))})
+ ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()})
+ torch.save(ckpt, path)
+
+ def load(self, ckpt):
+ if 'alphaMask.aabb' in ckpt.keys():
+ length = np.prod(ckpt['alphaMask.shape'])
+ alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape']))
+ self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device))
+ self.load_state_dict(ckpt['state_dict'])
+
+
+ def sample_ray_ndc(self, rays_o, rays_d, is_train=True, N_samples=-1):
+ N_samples = N_samples if N_samples > 0 else self.nSamples
+ near, far = self.near_far
+ # interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o)
+ interpx = torch.linspace(near, far, N_samples, device=rays_o.device)
+ interpx = interpx.unsqueeze(0)
+
+ if is_train:
+ interpx += torch.rand_like(interpx) * ((far - near) / N_samples)
+
+ rays_pts = rays_o[..., None, :] \
+ + rays_d[..., None, :] * interpx[..., None]
+ mask_outbbox = ((self.aabb[0] > rays_pts)
+ | (rays_pts > self.aabb[1])).any(dim=-1)
+ return rays_pts, interpx, ~mask_outbbox
+
+ def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1):
+ N_samples = N_samples if N_samples>0 else self.nSamples
+ stepsize = self.stepSize
+ near, far = self.near_far
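+ # slab test: per-axis distances to the aabb planes; t_min is the ray's entry distance into the box, clamped to [near, far]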
+ vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d)
+ rate_a = (self.aabb[1] - rays_o) / vec
+ rate_b = (self.aabb[0] - rays_o) / vec
+ t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far)
+
+ rng = torch.arange(N_samples, dtype=torch.float32, device=rays_o.device)
+ rng = rng[None]
+ if is_train:
+ rng = rng.repeat(rays_d.shape[-2],1)
+ rng += torch.rand_like(rng[:, [0]])
+ step = stepsize * rng
+ interpx = (t_min[...,None] + step)
+
+ rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None]
+ mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1)
+
+ return rays_pts, interpx, ~mask_outbbox
+
+
+ def shrink(self, new_aabb, voxel_size):
+ pass
+
+ @torch.no_grad()
+ def getDenseAlpha(self,gridSize=None):
+ gridSize = self.gridSize if gridSize is None else gridSize
+
+ samples = torch.stack(torch.meshgrid(
+ torch.linspace(0, 1, gridSize[0], device=self.device),
+ torch.linspace(0, 1, gridSize[1], device=self.device),
+ torch.linspace(0, 1, gridSize[2], device=self.device),
+ ), -1)
+ dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples
+
+ # dense_xyz = dense_xyz
+ # print(self.stepSize, self.distance_scale*self.aabbDiag)
+ alpha = torch.zeros_like(dense_xyz[...,0])
+ for i in range(gridSize[0]):
+ alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.stepSize).view((gridSize[1], gridSize[2]))
+ return alpha, dense_xyz
+
+ @torch.no_grad()
+ def updateAlphaMask(self, gridSize=(200,200,200)):
+
+ alpha, dense_xyz = self.getDenseAlpha(gridSize)
+ dense_xyz = dense_xyz.transpose(0,2).contiguous()
+ alpha = alpha.clamp(0,1).transpose(0,2).contiguous()[None,None]
+ total_voxels = gridSize[0] * gridSize[1] * gridSize[2]
+
+ ks = 3
+ alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1])
+ alpha[alpha>=self.alphaMask_thres] = 1
+ alpha[alpha<self.alphaMask_thres] = 0
+
+ self.alphaMask = AlphaGridMask(self.device, self.aabb, alpha)
+
+ valid_xyz = dense_xyz[alpha>0.5]
+
+ xyz_min = valid_xyz.amin(0)
+ xyz_max = valid_xyz.amax(0)
+
+ new_aabb = torch.stack((xyz_min, xyz_max))
+
+ total = torch.sum(alpha)
+ print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100))
+ return new_aabb
+
+ @torch.no_grad()
+ def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False):
+ print('========> filtering rays ...')
+ tt = time.time()
+
+ N = torch.tensor(all_rays.shape[:-1]).prod()
+
+ mask_filtered = []
+ idx_chunks = torch.split(torch.arange(N), chunk)
+ for idx_chunk in idx_chunks:
+ rays_chunk = all_rays[idx_chunk].to(self.device)
+
+ rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6]
+ if bbox_only:
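+ # keep a ray iff its aabb entry distance (t_min) is smaller than its exit distance (t_max)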
+ vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d)
+ rate_a = (self.aabb[1] - rays_o) / vec
+ rate_b = (self.aabb[0] - rays_o) / vec
+ t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far)
+ t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far)
+ mask_inbbox = t_max > t_min
+
+ else:
+ xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False)
+ mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1)
+
+ mask_filtered.append(mask_inbbox.cpu())
+
+ mask_filtered = torch.cat(mask_filtered).view(all_rgbs.shape[:-1])
+
+ print(f'Ray filtering done! takes {time.time()-tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}')
+ return all_rays[mask_filtered], all_rgbs[mask_filtered]
+
+
+ def feature2density(self, density_features):
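+ # map raw grid features to density; the softplus branch adds density_shift so a zero feature gives near-zero density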
+ if self.fea2denseAct == "softplus":
+ return F.softplus(density_features+self.density_shift)
+ elif self.fea2denseAct == "relu":
+ return F.relu(density_features)
+
+
+ def compute_alpha(self, xyz_locs, length=1):
+
+ if self.alphaMask is not None:
+ alphas = self.alphaMask.sample_alpha(xyz_locs)
+ alpha_mask = alphas > 0
+ else:
+ alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool)
+
+
+ sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device)
+
+ if alpha_mask.any():
+ xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask])
+ sigma_feature = self.compute_densityfeature(xyz_sampled)
+ validsigma = self.feature2density(sigma_feature)
+ sigma[alpha_mask] = validsigma
+
+
+ alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1])
+
+ return alpha
+
+
+ def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1):
+
+ # sample points
+ viewdirs = rays_chunk[:, 3:6]
+ if ndc_ray:
+ xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
+ dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
+ rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True)
+ dists = dists * rays_norm
+ viewdirs = viewdirs / rays_norm
+ else:
+ xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples)
+ dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1)
+ viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape)
+
+ if self.alphaMask is not None:
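+ # mark samples flagged as empty space by the cached alpha mask as invalid so they are skipped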
+ alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid])
+ alpha_mask = alphas > 0
+ ray_invalid = ~ray_valid
+ ray_invalid[ray_valid] |= (~alpha_mask)
+ ray_valid = ~ray_invalid
+
+
+ sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device)
+ rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device)
+
+ if ray_valid.any():
+ xyz_sampled = self.normalize_coord(xyz_sampled)
+ sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid])
+
+ validsigma = self.feature2density(sigma_feature)
+ sigma[ray_valid] = validsigma
+
+
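+ # volume rendering: turn densities into per-sample alphas and accumulated compositing weights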
+ alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale)
+
+ app_mask = weight > self.rayMarch_weight_thres
+
+ if app_mask.any():
+ app_features = self.compute_appfeature(xyz_sampled[app_mask])
+ valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features)
+ rgb[app_mask] = valid_rgbs
+
+ acc_map = torch.sum(weight, -1)
+ rgb_map = torch.sum(weight[..., None] * rgb, -2)
+
+ if white_bg or (is_train and torch.rand((1,))<0.5):
+ rgb_map = rgb_map + (1. - acc_map[..., None])
+
+
+ rgb_map = rgb_map.clamp(0,1)
+
+ with torch.no_grad():
+ depth_map = torch.sum(weight * z_vals, -1)
+ depth_map = depth_map + (1. - acc_map) * rays_chunk[..., -1]
+
+ return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight
+
diff --git a/TensoRF/models/voxel_based.py b/TensoRF/models/voxel_based.py
new file mode 100644
index 0000000..c42c33d
--- /dev/null
+++ b/TensoRF/models/voxel_based.py
@@ -0,0 +1,147 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import models.cosine_transform as ct
+
+
+class PREF(nn.Module):
+ def __init__(self, res, ch):
+ """
+ INPUTS
+ res: resolution
+ ch: channel
+ """
+ super(PREF, self).__init__()
+ reduced_res = np.ceil(np.log2(res)+1).astype('int')
+ self.res = res
+ self.ch = ch
+ self.reduced_res = reduced_res
+
+ self.phasor = nn.ParameterList([
+ # nn.Parameter(0.*torch.randn((1, reduced_res[0]*ch, res[1], res[2]),
+ nn.Parameter(0.*torch.randn((1, reduced_res[0]*ch, res[1], res[2]),
+ dtype=torch.float32),
+ requires_grad=True),
+ nn.Parameter(0.*torch.randn((1, reduced_res[1]*ch, res[0], res[2]),
+ dtype=torch.float32),
+ requires_grad=True),
+ nn.Parameter(0.*torch.randn((1, reduced_res[2]*ch, res[0], res[1]),
+ dtype=torch.float32),
+ requires_grad=True)])
+
+ def forward(self, inputs):
+ inputs = inputs.reshape(1, 1, *inputs.shape) # [B, 3] to [1, 1, B, 3]
+ Pu = self.phasor[0]
+ Pv = self.phasor[1]
+ Pw = self.phasor[2]
+
+ Pu = F.grid_sample(Pu, inputs[..., (1, 2)], mode='bilinear',
+ align_corners=True)
+ Pu = Pu.transpose(1, 3).reshape(-1, self.ch, self.reduced_res[0])
+ Pv = F.grid_sample(Pv, inputs[..., (0, 2)], mode='bilinear',
+ align_corners=True)
+ Pv = Pv.transpose(1, 3).reshape(-1, self.ch, self.reduced_res[1])
+ Pw = F.grid_sample(Pw, inputs[..., (0, 1)], mode='bilinear',
+ align_corners=True)
+ Pw = Pw.transpose(1, 3).reshape(-1, self.ch, self.reduced_res[2])
+
+ Pu = self.numerical_integration(Pu, inputs[0, 0, ..., 0])
+ Pv = self.numerical_integration(Pv, inputs[0, 0, ..., 1])
+ Pw = self.numerical_integration(Pw, inputs[0, 0, ..., 2])
+
+ outputs = Pu + Pv + Pw
+ return outputs
+
+ def numerical_integration(self, inputs, coords):
+ # assume coords in [-1, 1]
+ N = inputs.size(-1)  # per-axis number of sampled frequencies (matches PREFFFT)
+ coords = (coords + 1) / 2 * ((2**(N-1)) - 1)
+
+ '''
+ out = torch.cos(torch.pi * (coords.unsqueeze(-1) + 0.5)
+ * (2 ** torch.arange(N-1).to(coords.device)) / (2**N))
+ out = 2 * torch.einsum('...C,...SC->...S', out, inputs[..., 1:])
+ return out + inputs[..., 0]
+ '''
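+ # evaluate the sampled frequency features as a cosine series at the continuous coordinate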
+ out = torch.cos(torch.pi * (coords.unsqueeze(-1) + 0.5)
+ * (2 ** torch.arange(N).to(coords.device)-0.5) / (2**N))
+ out = 2 * torch.einsum('...C,...SC->...S', out, inputs)
+ return out
+
+ def compute_tv(self):
+ weight = (2 ** torch.arange(self.reduced_res[0]).to(self.phasor[0].device) - 1).repeat(self.ch).reshape(-1, 1, 1)
+ return (self.phasor[0]*weight).square().mean() \
+ + (self.phasor[1]*weight).square().mean() \
+ + (self.phasor[2]*weight).square().mean()
+
+
+class PREFFFT(nn.Module):
+ def __init__(self, res, ch):
+ """
+ INPUTS
+ res: resolution
+ ch: channel
+ """
+ super(PREFFFT, self).__init__()
+ # reduced_res = (np.ceil(np.log2(res)) + 1).astype('int')
+ reduced_res = (np.ceil(np.log2(res)) + 0).astype('int')
+ self.res = res
+ self.ch = ch
+ self.reduced_res = reduced_res
+
+ self.phasor = nn.ParameterList([
+ nn.Parameter(0.001*torch.randn((1, 2*reduced_res[0]*ch, res[1], res[2]),
+ dtype=torch.float32),
+ requires_grad=True),
+ nn.Parameter(0.001*torch.randn((1, 2*reduced_res[1]*ch, res[0], res[2]),
+ dtype=torch.float32),
+ requires_grad=True),
+ nn.Parameter(0.001*torch.randn((1, 2*reduced_res[2]*ch, res[0], res[1]),
+ dtype=torch.float32),
+ requires_grad=True)])
+
+ def forward(self, inputs):
+ inputs = inputs.reshape(1, 1, *inputs.shape) # [B, 3] to [1, 1, B, 3]
+ Pu = self.phasor[0]
+ Pv = self.phasor[1]
+ Pw = self.phasor[2]
+
+ Pu = F.grid_sample(Pu, inputs[..., (1, 2)], mode='bilinear',
+ align_corners=True)
+ Pu = Pu.transpose(1, 3).reshape(-1, 2*self.ch, self.reduced_res[0])
+ Pv = F.grid_sample(Pv, inputs[..., (0, 2)], mode='bilinear',
+ align_corners=True)
+ Pv = Pv.transpose(1, 3).reshape(-1, 2*self.ch, self.reduced_res[1])
+ Pw = F.grid_sample(Pw, inputs[..., (0, 1)], mode='bilinear',
+ align_corners=True)
+ Pw = Pw.transpose(1, 3).reshape(-1, 2*self.ch, self.reduced_res[2])
+
+ Pu = self.numerical_integration(Pu, inputs[0, 0, ..., 0])
+ Pv = self.numerical_integration(Pv, inputs[0, 0, ..., 1])
+ Pw = self.numerical_integration(Pw, inputs[0, 0, ..., 2])
+
+ outputs = Pu + Pv + Pw
+ return outputs
+
+ def numerical_integration(self, inputs, coords):
+ # assume coords in [-1, 1]
+ N = inputs.size(-1)
+ '''
+ coords = (coords + 1) / 2 * ((2**(N-1)) - 1)
+
+ out = torch.cos(torch.pi * (coords.unsqueeze(-1) + 0.5)
+ * (2 ** torch.arange(N-1).to(coords.device)) / (2**N))
+ out = 2 * torch.einsum('...C,...SC->...S', out, inputs[..., 1:])
+ return out + inputs[..., 0]
+ '''
+ # inputs: [B, C, D]
+ inputs = torch.stack(torch.split(inputs, self.ch, dim=1), -1)
+ inputs = torch.view_as_complex(inputs)
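+ # paired channels form complex Fourier coefficients; evaluate the series at the coordinate and keep the real part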
+ coords = (coords + 1) / 2 * (2**N - 1)
+ coef = torch.cat([torch.zeros((1,)), 2**torch.arange(N-1)]).to(inputs.device)
+ out = torch.exp(2j* torch.pi * coords.unsqueeze(-1) * coef / (2**N))
+ out = torch.einsum('...C,...SC->...S', out, inputs)
+ return out.real
+
diff --git a/TensoRF/models/voxel_based_test.py b/TensoRF/models/voxel_based_test.py
new file mode 100644
index 0000000..aeb658f
--- /dev/null
+++ b/TensoRF/models/voxel_based_test.py
@@ -0,0 +1,20 @@
+import torch
+import torch.nn as nn
+import unittest
+from voxel_based import *
+
+
+class UtilsTest(unittest.TestCase):
+ def test_PREF(self):
+ inputs = torch.rand((32, 3)) # 3D coordinates
+ ch = 12
+ resolution = (64, 128, 48)
+
+ # PREF takes a 3D resolution and a channel count, and returns a `ch`-dim feature per query point
+ net = PREF(resolution, ch)
+
+ self.assertEqual(net(inputs).shape, (32, ch))
+
+
+if __name__ == '__main__':
+ unittest.main()
+
diff --git a/TensoRF/opt.py b/TensoRF/opt.py
new file mode 100644
index 0000000..1988add
--- /dev/null
+++ b/TensoRF/opt.py
@@ -0,0 +1,136 @@
+import configargparse
+
+def config_parser(cmd=None):
+ parser = configargparse.ArgumentParser()
+ parser.add_argument('--config', is_config_file=True,
+ help='config file path')
+ parser.add_argument("--expname", type=str,
+ help='experiment name')
+ parser.add_argument("--basedir", type=str, default='./log',
+ help='where to store ckpts and logs')
+ parser.add_argument("--add_timestamp", type=int, default=0,
+ help='add timestamp to dir')
+ parser.add_argument("--datadir", type=str, default='./data/llff/fern',
+ help='input data directory')
+ parser.add_argument("--progress_refresh_rate", type=int, default=10,
+ help='how often (in iterations) to refresh the progress bar with PSNR and loss')
+
+ parser.add_argument('--with_depth', action='store_true')
+ parser.add_argument('--downsample_train', type=float, default=1.0)
+ parser.add_argument('--downsample_test', type=float, default=1.0)
+
+ parser.add_argument('--model_name', type=str, default='TensorVMSplit',
+ choices=['TensorVMSplit', 'TensorCP'])
+
+ # loader options
+ parser.add_argument("--batch_size", type=int, default=4096)
+ parser.add_argument("--n_iters", type=int, default=30000)
+
+ parser.add_argument('--dataset_name', type=str, default='blender',
+ choices=['blender', 'llff', 'nsvf', 'dtu','tankstemple', 'own_data'])
+
+
+ # training options
+ # learning rate
+ parser.add_argument("--lr_init", type=float, default=0.02,
+ help='learning rate')
+ parser.add_argument("--lr_basis", type=float, default=1e-3,
+ help='learning rate')
+ parser.add_argument("--lr_decay_iters", type=int, default=-1,
+ help='number of iterations over which the lr decays to the target ratio; -1 sets it to n_iters')
+ parser.add_argument("--lr_decay_target_ratio", type=float, default=0.1,
+ help='the target decay ratio; after lr_decay_iters the initial lr has decayed to lr*ratio')
+ parser.add_argument("--lr_upsample_reset", type=int, default=1,
+ help='reset lr to its initial value after upsampling')
+
+ # loss
+ parser.add_argument("--L1_weight_inital", type=float, default=0.0,
+ help='loss weight')
+ parser.add_argument("--L1_weight_rest", type=float, default=0,
+ help='loss weight')
+ parser.add_argument("--Ortho_weight", type=float, default=0.0,
+ help='loss weight')
+ parser.add_argument("--TV_weight_density", type=float, default=0.0,
+ help='loss weight')
+ parser.add_argument("--TV_weight_app", type=float, default=0.0,
+ help='loss weight')
+
+ # model
+ # volume options
+ parser.add_argument("--n_lamb_sigma", type=int, action="append")
+ parser.add_argument("--n_lamb_sh", type=int, action="append")
+ parser.add_argument("--data_dim_color", type=int, default=27)
+
+ parser.add_argument("--rm_weight_mask_thre", type=float, default=0.0001,
+ help='mask points in ray marching')
+ parser.add_argument("--alpha_mask_thre", type=float, default=0.0001,
+ help='threshold for creating alpha mask volume')
+ parser.add_argument("--distance_scale", type=float, default=25,
+ help='scaling sampling distance for computation')
+ parser.add_argument("--density_shift", type=float, default=-10,
+ help='shift density in softplus; making density = 0 when feature == 0')
+
+ parser.add_argument("--grid_bit", type=int, default=32)
+
+ # network decoder
+ parser.add_argument("--shadingMode", type=str, default="MLP_PE",
+ help='which shading mode to use')
+ parser.add_argument("--pos_pe", type=int, default=6,
+ help='number of pe for pos')
+ parser.add_argument("--view_pe", type=int, default=6,
+ help='number of pe for view')
+ parser.add_argument("--fea_pe", type=int, default=6,
+ help='number of pe for features')
+ parser.add_argument("--featureC", type=int, default=128,
+ help='hidden feature channel in MLP')
+
+
+
+ parser.add_argument("--ckpt", type=str, default=None,
+ help='specific weights npy file to reload for coarse network')
+ parser.add_argument("--render_only", type=int, default=0)
+ parser.add_argument("--render_test", type=int, default=0)
+ parser.add_argument("--render_train", type=int, default=0)
+ parser.add_argument("--render_path", type=int, default=0)
+ parser.add_argument("--export_mesh", type=int, default=0)
+
+ # rendering options
+ parser.add_argument('--lindisp', default=False, action="store_true",
+ help='use disparity depth sampling')
+ parser.add_argument("--perturb", type=float, default=1.,
+ help='set to 0. for no jitter, 1. for jitter')
+ parser.add_argument("--accumulate_decay", type=float, default=0.998)
+ parser.add_argument("--fea2denseAct", type=str, default='softplus')
+ parser.add_argument('--ndc_ray', type=int, default=0)
+ parser.add_argument('--nSamples', type=int, default=1e6,
+ help='max number of sample points per ray; pass 1e6 to let it be set automatically')
+ parser.add_argument('--step_ratio',type=float,default=0.5)
+
+
+ ## blender flags
+ parser.add_argument("--white_bkgd", action='store_true',
+ help='set to render synthetic data on a white bkgd (always use for dvoxels)')
+
+
+
+ parser.add_argument('--N_voxel_init',
+ type=int,
+ default=100**3)
+ parser.add_argument('--N_voxel_final',
+ type=int,
+ default=300**3)
+ parser.add_argument("--upsamp_list", type=int, action="append")
+ parser.add_argument("--update_AlphaMask_list", type=int, action="append")
+
+ parser.add_argument('--idx_view',
+ type=int,
+ default=0)
+ # logging/saving options
+ parser.add_argument("--N_vis", type=int, default=5,
+ help='number of test images to visualize')
+ parser.add_argument("--vis_every", type=int, default=10000,
+ help='how often (in iterations) to visualize test images')
+ if cmd is not None:
+ return parser.parse_args(cmd)
+ else:
+ return parser.parse_args()
diff --git a/TensoRF/renderer.py b/TensoRF/renderer.py
new file mode 100644
index 0000000..9ea0646
--- /dev/null
+++ b/TensoRF/renderer.py
@@ -0,0 +1,145 @@
+import torch,os,imageio,sys
+from tqdm.auto import tqdm
+from dataLoader.ray_utils import get_rays
+from models.tensoRF import TensorVM, TensorCP, raw2alpha, TensorVMSplit, AlphaGridMask
+from utils import *
+from dataLoader.ray_utils import ndc_rays_blender
+
+
+def OctreeRender_trilinear_fast(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda'):
+
+ rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], []
+ N_rays_all = rays.shape[0]
+ for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)):
+ rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device)
+
+ rgb_map, depth_map = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples)
+
+ rgbs.append(rgb_map)
+ depth_maps.append(depth_map)
+
+ return torch.cat(rgbs), None, torch.cat(depth_maps), None, None
+
+@torch.no_grad()
+def evaluation(test_dataset,tensorf, args, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
+ white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
+ PSNRs, rgb_maps, depth_maps = [], [], []
+ ssims,l_alex,l_vgg=[],[],[]
+ os.makedirs(savePath, exist_ok=True)
+ os.makedirs(savePath+"/rgbd", exist_ok=True)
+
+ try:
+ tqdm._instances.clear()
+ except Exception:
+ pass
+
+ near_far = test_dataset.near_far
+ img_eval_interval = 1 if N_vis < 0 else max(test_dataset.all_rays.shape[0] // N_vis,1)
+ idxs = list(range(0, test_dataset.all_rays.shape[0], img_eval_interval))
+ for idx, samples in tqdm(enumerate(test_dataset.all_rays[0::img_eval_interval]), file=sys.stdout):
+
+ W, H = test_dataset.img_wh
+ rays = samples.view(-1,samples.shape[-1])
+
+ rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=4096, N_samples=N_samples,
+ ndc_ray=ndc_ray, white_bg = white_bg, device=device)
+ rgb_map = rgb_map.clamp(0.0, 1.0)
+
+ rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
+
+ depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
+ if len(test_dataset.all_rgbs):
+ gt_rgb = test_dataset.all_rgbs[idxs[idx]].view(H, W, 3)
+ loss = torch.mean((rgb_map - gt_rgb) ** 2)
+ PSNRs.append(-10.0 * np.log(loss.item()) / np.log(10.0))
+
+ if compute_extra_metrics:
+ ssim = rgb_ssim(rgb_map, gt_rgb, 1)
+ l_a = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'alex', tensorf.device)
+ l_v = rgb_lpips(gt_rgb.numpy(), rgb_map.numpy(), 'vgg', tensorf.device)
+ ssims.append(ssim)
+ l_alex.append(l_a)
+ l_vgg.append(l_v)
+
+ rgb_map = (rgb_map.numpy() * 255).astype('uint8')
+ # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
+ rgb_maps.append(rgb_map)
+ depth_maps.append(depth_map)
+ if savePath is not None:
+ imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
+ rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
+ imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
+
+ imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=10)
+ imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=10)
+
+ if PSNRs:
+ psnr = np.mean(np.asarray(PSNRs))
+ if compute_extra_metrics:
+ ssim = np.mean(np.asarray(ssims))
+ l_a = np.mean(np.asarray(l_alex))
+ l_v = np.mean(np.asarray(l_vgg))
+ np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
+ else:
+ np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
+
+
+ return PSNRs
+
+@torch.no_grad()
+def evaluation_path(test_dataset,tensorf, c2ws, renderer, savePath=None, N_vis=5, prtx='', N_samples=-1,
+ white_bg=False, ndc_ray=False, compute_extra_metrics=True, device='cuda'):
+ PSNRs, rgb_maps, depth_maps = [], [], []
+ ssims,l_alex,l_vgg=[],[],[]
+ os.makedirs(savePath, exist_ok=True)
+ os.makedirs(savePath+"/rgbd", exist_ok=True)
+
+ try:
+ tqdm._instances.clear()
+ except Exception:
+ pass
+
+ near_far = test_dataset.near_far
+ for idx, c2w in tqdm(enumerate(c2ws)):
+
+ W, H = test_dataset.img_wh
+
+ c2w = torch.FloatTensor(c2w)
+ rays_o, rays_d = get_rays(test_dataset.directions, c2w) # both (h*w, 3)
+ if ndc_ray:
+ rays_o, rays_d = ndc_rays_blender(H, W, test_dataset.focal[0], 1.0, rays_o, rays_d)
+ rays = torch.cat([rays_o, rays_d], 1) # (h*w, 6)
+
+ rgb_map, _, depth_map, _, _ = renderer(rays, tensorf, chunk=8192, N_samples=N_samples,
+ ndc_ray=ndc_ray, white_bg = white_bg, device=device)
+ rgb_map = rgb_map.clamp(0.0, 1.0)
+
+ rgb_map, depth_map = rgb_map.reshape(H, W, 3).cpu(), depth_map.reshape(H, W).cpu()
+
+ depth_map, _ = visualize_depth_numpy(depth_map.numpy(),near_far)
+
+ rgb_map = (rgb_map.numpy() * 255).astype('uint8')
+ # rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
+ rgb_maps.append(rgb_map)
+ depth_maps.append(depth_map)
+ if savePath is not None:
+ imageio.imwrite(f'{savePath}/{prtx}{idx:03d}.png', rgb_map)
+ rgb_map = np.concatenate((rgb_map, depth_map), axis=1)
+ imageio.imwrite(f'{savePath}/rgbd/{prtx}{idx:03d}.png', rgb_map)
+
+ imageio.mimwrite(f'{savePath}/{prtx}video.mp4', np.stack(rgb_maps), fps=30, quality=8)
+ imageio.mimwrite(f'{savePath}/{prtx}depthvideo.mp4', np.stack(depth_maps), fps=30, quality=8)
+
+ if PSNRs:
+ psnr = np.mean(np.asarray(PSNRs))
+ if compute_extra_metrics:
+ ssim = np.mean(np.asarray(ssims))
+ l_a = np.mean(np.asarray(l_alex))
+ l_v = np.mean(np.asarray(l_vgg))
+ np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr, ssim, l_a, l_v]))
+ else:
+ np.savetxt(f'{savePath}/{prtx}mean.txt', np.asarray([psnr]))
+
+
+ return PSNRs
+
diff --git a/TensoRF/train.py b/TensoRF/train.py
new file mode 100644
index 0000000..37f06bc
--- /dev/null
+++ b/TensoRF/train.py
@@ -0,0 +1,332 @@
+
+import os
+from tqdm.auto import tqdm
+from opt import config_parser
+
+
+
+import json, random
+from renderer import *
+from utils import *
+from torch.utils.tensorboard import SummaryWriter
+import datetime
+
+from dataLoader import dataset_dict
+import sys
+
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+renderer = OctreeRender_trilinear_fast
+
+
+def count_params(module):
+ return sum(map(lambda x: x.numel(), module.parameters()))
+
+
+def tensorf_param_count(module):
+ total = count_params(module)
+ non_grid = count_params(module.renderModule) \
+ + count_params(module.basis_mat)
+ return total - non_grid, non_grid
+
+
+class SimpleSampler:
+ def __init__(self, total, batch):
+ self.total = total
+ self.batch = batch
+ self.curr = total
+ self.ids = None
+
+ def nextids(self):
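+ # advance the cursor and re-shuffle all ray indices whenever the next batch would run past the end of the current permutation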
+ self.curr+=self.batch
+ if self.curr + self.batch > self.total:
+ self.ids = torch.LongTensor(np.random.permutation(self.total))
+ self.curr = 0
+ return self.ids[self.curr:self.curr+self.batch]
+
+
+@torch.no_grad()
+def export_mesh(args):
+
+ ckpt = torch.load(args.ckpt, map_location=device)
+ kwargs = ckpt['kwargs']
+ kwargs.update({'device': device})
+ tensorf = eval(args.model_name)(**kwargs)
+ tensorf.load(ckpt)
+
+ alpha,_ = tensorf.getDenseAlpha()
+ convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=tensorf.aabb.cpu(), level=0.005)
+
+
+@torch.no_grad()
+def render_test(args):
+ # init dataset
+ dataset = dataset_dict[args.dataset_name]
+ test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
+ white_bg = test_dataset.white_bg
+ ndc_ray = args.ndc_ray
+
+ if not os.path.exists(args.ckpt):
+ print('the ckpt path does not exist!!')
+ return
+
+ ckpt = torch.load(args.ckpt, map_location=device)
+ kwargs = ckpt['kwargs']
+ kwargs.update({'device': device})
+ tensorf = eval(args.model_name)(**kwargs)
+ tensorf.load(ckpt)
+
+ logfolder = os.path.dirname(args.ckpt)
+ if args.render_train:
+ os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
+ train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
+ PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
+ N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
+ print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
+
+ if args.render_test:
+ os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
+ evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
+ N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
+
+ if args.render_path:
+ c2ws = test_dataset.render_path
+ os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
+ evaluation_path(test_dataset,tensorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
+ N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
+
+def reconstruction(args):
+
+ # init dataset
+ dataset = dataset_dict[args.dataset_name]
+ train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
+ test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
+ white_bg = train_dataset.white_bg
+ near_far = train_dataset.near_far
+ ndc_ray = args.ndc_ray
+
+ # init resolution
+ upsamp_list = args.upsamp_list
+ update_AlphaMask_list = args.update_AlphaMask_list
+ n_lamb_sigma = args.n_lamb_sigma
+ n_lamb_sh = args.n_lamb_sh
+
+
+ if args.add_timestamp:
+ logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
+ else:
+ logfolder = f'{args.basedir}/{args.expname}'
+
+
+ # init log file
+ os.makedirs(logfolder, exist_ok=True)
+ os.makedirs(f'{logfolder}/imgs_vis', exist_ok=True)
+ os.makedirs(f'{logfolder}/imgs_rgba', exist_ok=True)
+ os.makedirs(f'{logfolder}/rgba', exist_ok=True)
+ summary_writer = SummaryWriter(logfolder)
+
+ # init parameters
+ aabb = train_dataset.scene_bbox.to(device)
+ reso_cur = N_to_reso(args.N_voxel_init, aabb)
+ nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
+
+ if args.ckpt is not None:
+ ckpt = torch.load(args.ckpt, map_location=device)
+ kwargs = ckpt['kwargs']
+ kwargs.update({'device':device})
+ tensorf = eval(args.model_name)(**kwargs)
+ tensorf.load(ckpt)
+ else:
+ tensorf = eval(args.model_name)(aabb, reso_cur, device,
+ density_n_comp=n_lamb_sigma, appearance_n_comp=n_lamb_sh, app_dim=args.data_dim_color, near_far=near_far,
+ shadingMode=args.shadingMode, alphaMask_thres=args.alpha_mask_thre, density_shift=args.density_shift, distance_scale=args.distance_scale,
+ pos_pe=args.pos_pe, view_pe=args.view_pe, fea_pe=args.fea_pe, featureC=args.featureC, step_ratio=args.step_ratio, fea2denseAct=args.fea2denseAct, grid_bit=args.grid_bit)
+
+ print(tensorf)
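+ # rough model size in MiB assuming 16-bit parameters (8_388_608 bits per MiB)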
+ print(sum([p.numel() for p in tensorf.parameters()]) * 16 / 8_388_608)
+
+ grad_vars = tensorf.get_optparam_groups(args.lr_init, args.lr_basis)
+ if args.lr_decay_iters > 0:
+ lr_factor = args.lr_decay_target_ratio**(1/args.lr_decay_iters)
+ else:
+ args.lr_decay_iters = args.n_iters
+ lr_factor = args.lr_decay_target_ratio**(1/args.n_iters)
+
+ print("lr decay", args.lr_decay_target_ratio, args.lr_decay_iters)
+
+ optimizer = torch.optim.Adam(grad_vars, betas=(0.9,0.99))
+
+
+ # upsampling schedule: voxel counts spaced linearly in logarithmic space
+ N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(args.N_voxel_init), np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]
+
+
+ torch.cuda.empty_cache()
+ PSNRs,PSNRs_test = [],[0]
+
+ allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
+ if not args.ndc_ray:
+ allrays, allrgbs = tensorf.filtering_rays(allrays, allrgbs, bbox_only=True)
+ allrays = allrays.cuda()
+ allrgbs = allrgbs.cuda()
+ trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
+
+ Ortho_reg_weight = args.Ortho_weight
+ print("initial Ortho_reg_weight", Ortho_reg_weight)
+
+ L1_reg_weight = args.L1_weight_inital
+ print("initial L1_reg_weight", L1_reg_weight)
+ TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
+ tvreg = TVLoss()
+ print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}")
+
+ pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
+ for iteration in pbar:
+ ray_idx = trainingSampler.nextids()
+ rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx] # .to(device)
+
+ #rgb_map, alphas_map, depth_map, weights, uncertainty
+ rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(rays_train, tensorf, chunk=args.batch_size,
+ N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True)
+
+ loss = torch.mean((rgb_map - rgb_train) ** 2)
+
+ # loss
+ total_loss = loss
+ if Ortho_reg_weight > 0:
+ loss_reg = tensorf.vector_comp_diffs()
+ total_loss += Ortho_reg_weight*loss_reg
+ summary_writer.add_scalar('train/reg', loss_reg.detach().item(), global_step=iteration)
+ if L1_reg_weight > 0:
+ loss_reg_L1 = tensorf.density_L1()
+ total_loss += L1_reg_weight*loss_reg_L1
+ summary_writer.add_scalar('train/reg_l1', loss_reg_L1.detach().item(), global_step=iteration)
+
+ if TV_weight_density>0:
+ TV_weight_density *= lr_factor
+ loss_tv = tensorf.TV_loss_density(tvreg) * TV_weight_density
+ total_loss = total_loss + loss_tv
+ summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
+ if TV_weight_app>0:
+ TV_weight_app *= lr_factor
+ loss_tv = tensorf.TV_loss_app(tvreg)*TV_weight_app
+ total_loss = total_loss + loss_tv
+ summary_writer.add_scalar('train/reg_tv_app', loss_tv.detach().item(), global_step=iteration)
+
+ optimizer.zero_grad()
+ total_loss.backward()
+ optimizer.step()
+
+ loss = loss.detach().item()
+
+ PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
+ summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
+ summary_writer.add_scalar('train/mse', loss, global_step=iteration)
+
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = param_group['lr'] * lr_factor
+
+ # Print the current values of the losses.
+ if iteration % args.progress_refresh_rate == 0:
+ pbar.set_description(
+ f'Iteration {iteration:05d}:'
+ + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
+ + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
+ + f' mse = {loss:.6f}'
+ )
+ PSNRs = []
+
+ if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:
+ PSNRs_test = evaluation(test_dataset,tensorf, args, renderer, f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
+ prtx=f'{iteration:06d}_', N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, compute_extra_metrics=False)
+ summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
+
+ if iteration in update_AlphaMask_list:
+ if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256**3:
+ # update volume resolution
+ reso_mask = reso_cur
+
+ new_aabb = tensorf.updateAlphaMask(tuple(reso_mask))
+
+ if iteration == update_AlphaMask_list[0]:
+ tensorf.shrink(new_aabb)
+ # tensorVM.alphaMask = None
+ L1_reg_weight = args.L1_weight_rest
+ print("continuing L1_reg_weight", L1_reg_weight)
+
+ if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
+ # filter rays outside the bbox
+ allrays,allrgbs = tensorf.filtering_rays(allrays,allrgbs)
+ trainingSampler = SimpleSampler(allrgbs.shape[0],
+ args.batch_size)
+ allrays = allrays.cuda()
+ allrgbs = allrgbs.cuda()
+
+ if iteration in upsamp_list:
+ n_voxels = N_voxel_list.pop(0)
+ reso_cur = N_to_reso(n_voxels, tensorf.aabb)
+ nSamples = min(args.nSamples,
+ cal_n_samples(reso_cur, args.step_ratio))
+ tensorf.upsample_volume_grid(reso_cur)
+
+ if args.lr_upsample_reset:
+ print("reset lr to initial")
+ lr_scale = 1 #0.1 ** (iteration / args.n_iters)
+ else:
+ lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
+ grad_vars = tensorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
+ optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
+
+ tensorf.save(f'{logfolder}/{args.expname}.th')
+
+ grid, non_grid = tensorf_param_count(tensorf)
+ grid_bytes = grid * args.grid_bit / 8
+ non_grid_bytes = non_grid * 4
+ print(f'total: {(grid_bytes + non_grid_bytes)/1_048_576:.3f}MB '
+ f'(G ({args.grid_bit}bit): {grid_bytes/1_048_576:.3f}MB) '
+ f'(N: {non_grid_bytes/1_048_576:.3f}MB)')
+
+ if args.render_train:
+ os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
+ train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
+ PSNRs_test = evaluation(train_dataset,tensorf, args, renderer, f'{logfolder}/imgs_train_all/',
+ N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
+ print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
+
+ if args.render_test:
+ os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
+ PSNRs_test = evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/imgs_test_all/',
+ N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
+ summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
+
+ print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
+
+ if args.render_path:
+ c2ws = test_dataset.render_path
+ # c2ws = test_dataset.poses
+ print('========>',c2ws.shape)
+ os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
+ evaluation_path(test_dataset,tensorf, c2ws, renderer,
+ f'{logfolder}/imgs_path_all/',
+ N_vis=-1, N_samples=-1, white_bg=white_bg,
+ ndc_ray=ndc_ray,device=device)
+
+
+if __name__ == '__main__':
+ torch.set_default_dtype(torch.float32)
+ torch.manual_seed(20211202)
+ np.random.seed(20211202)
+
+ args = config_parser()
+ print(args)
+
+ if args.export_mesh:
+ export_mesh(args)
+
+ if args.render_only and (args.render_test or args.render_path):
+ render_test(args)
+ else:
+ reconstruction(args)
+
diff --git a/TensoRF/utils.py b/TensoRF/utils.py
new file mode 100644
index 0000000..3c29586
--- /dev/null
+++ b/TensoRF/utils.py
@@ -0,0 +1,221 @@
+import cv2,torch
+import numpy as np
+from PIL import Image
+import torchvision.transforms as T
+import torch.nn.functional as F
+import scipy.signal
+
+mse2psnr = lambda x : -10. * torch.log(x) / torch.log(torch.Tensor([10.]))
+
+
+def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
+ """
+ depth: (H, W)
+ """
+
+ x = np.nan_to_num(depth) # change nan to 0
+ if minmax is None:
+ mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
+ ma = np.max(x)
+ else:
+ mi,ma = minmax
+
+ x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
+ x = (255*x).astype(np.uint8)
+ x_ = cv2.applyColorMap(x, cmap)
+ return x_, [mi,ma]
+
+def init_log(log, keys):
+ for key in keys:
+ log[key] = torch.tensor([0.0], dtype=float)
+ return log
+
+def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
+ """
+ depth: (H, W)
+ """
+ if type(depth) is not np.ndarray:
+ depth = depth.cpu().numpy()
+
+ x = np.nan_to_num(depth) # change nan to 0
+ if minmax is None:
+ mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
+ ma = np.max(x)
+ else:
+ mi,ma = minmax
+
+ x = (x-mi)/(ma-mi+1e-8) # normalize to 0~1
+ x = (255*x).astype(np.uint8)
+ x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
+ x_ = T.ToTensor()(x_) # (3, H, W)
+ return x_, [mi,ma]
+
+def N_to_reso(n_voxels, bbox):
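+ # pick a per-axis grid resolution with (near-)cubic voxels whose total count is roughly n_voxels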
+ xyz_min, xyz_max = bbox
+ dim = len(xyz_min)
+ voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / dim)
+ return ((xyz_max - xyz_min) / voxel_size).long().tolist()
+
+def cal_n_samples(reso, step_ratio=0.5):
+ return int(np.linalg.norm(reso)/step_ratio)
+
+
+
+
+__LPIPS__ = {}
+def init_lpips(net_name, device):
+ assert net_name in ['alex', 'vgg']
+ import lpips
+ print(f'init_lpips: lpips_{net_name}')
+ return lpips.LPIPS(net=net_name, version='0.1').eval().to(device)
+
+def rgb_lpips(np_gt, np_im, net_name, device):
+ if net_name not in __LPIPS__:
+ __LPIPS__[net_name] = init_lpips(net_name, device)
+ gt = torch.from_numpy(np_gt).permute([2, 0, 1]).contiguous().to(device)
+ im = torch.from_numpy(np_im).permute([2, 0, 1]).contiguous().to(device)
+ return __LPIPS__[net_name](gt, im, normalize=True).item()
+
+
+def findItem(items, target):
+ for one in items:
+ if one[:len(target)]==target:
+ return one
+ return None
+
+
+''' Evaluation metrics (ssim, lpips)
+'''
+def rgb_ssim(img0, img1, max_val,
+ filter_size=11,
+ filter_sigma=1.5,
+ k1=0.01,
+ k2=0.03,
+ return_map=False):
+ # Modified from https://github.com/google/mipnerf/blob/16e73dfdb52044dcceb47cda5243a686391a6e0f/internal/math.py#L58
+ assert len(img0.shape) == 3
+ assert img0.shape[-1] == 3
+ assert img0.shape == img1.shape
+
+ # Construct a 1D Gaussian blur filter.
+ hw = filter_size // 2
+ shift = (2 * hw - filter_size + 1) / 2
+ f_i = ((np.arange(filter_size) - hw + shift) / filter_sigma)**2
+ filt = np.exp(-0.5 * f_i)
+ filt /= np.sum(filt)
+
+ # Blur in x and y (faster than the 2D convolution).
+ def convolve2d(z, f):
+ return scipy.signal.convolve2d(z, f, mode='valid')
+
+ filt_fn = lambda z: np.stack([
+ convolve2d(convolve2d(z[...,i], filt[:, None]), filt[None, :])
+ for i in range(z.shape[-1])], -1)
+ mu0 = filt_fn(img0)
+ mu1 = filt_fn(img1)
+ mu00 = mu0 * mu0
+ mu11 = mu1 * mu1
+ mu01 = mu0 * mu1
+ sigma00 = filt_fn(img0**2) - mu00
+ sigma11 = filt_fn(img1**2) - mu11
+ sigma01 = filt_fn(img0 * img1) - mu01
+
+ # Clip the variances and covariances to valid values.
+ # Variance must be non-negative:
+ sigma00 = np.maximum(0., sigma00)
+ sigma11 = np.maximum(0., sigma11)
+ sigma01 = np.sign(sigma01) * np.minimum(
+ np.sqrt(sigma00 * sigma11), np.abs(sigma01))
+ c1 = (k1 * max_val)**2
+ c2 = (k2 * max_val)**2
+ numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
+ denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
+ ssim_map = numer / denom
+ ssim = np.mean(ssim_map)
+ return ssim_map if return_map else ssim
+
+
+import torch.nn as nn
+class TVLoss(nn.Module):
+ def __init__(self,TVLoss_weight=1):
+ super(TVLoss,self).__init__()
+ self.TVLoss_weight = TVLoss_weight
+
+ def forward(self,x):
+ batch_size = x.size()[0]
+ h_x = x.size()[2]
+ w_x = x.size()[3]
+ count_h = self._tensor_size(x[:,:,1:,:])
+ count_w = self._tensor_size(x[:,:,:,1:])
+ h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
+ w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
+ return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
+
+ def _tensor_size(self,t):
+ return t.size()[1]*t.size()[2]*t.size()[3]
+
+
+
+import plyfile
+import skimage.measure
+def convert_sdf_samples_to_ply(
+ pytorch_3d_sdf_tensor,
+ ply_filename_out,
+ bbox,
+ level=0.5,
+ offset=None,
+ scale=None,
+):
+ """
+ Convert sdf samples to .ply
+
+ :param pytorch_3d_sdf_tensor: a torch.FloatTensor of shape (n,n,n)
+ :voxel_grid_origin: a list of three floats: the bottom, left, down origin of the voxel grid
+ :voxel_size: float, the size of the voxels
+ :ply_filename_out: string, path of the filename to save to
+
+ This function adapted from: https://github.com/RobotLocomotion/spartan
+ """
+
+ numpy_3d_sdf_tensor = pytorch_3d_sdf_tensor.numpy()
+ voxel_size = list((bbox[1]-bbox[0]) / np.array(pytorch_3d_sdf_tensor.shape))
+
+ verts, faces, normals, values = skimage.measure.marching_cubes(
+ numpy_3d_sdf_tensor, level=level, spacing=voxel_size
+ )
+ faces = faces[...,::-1] # inverse face orientation
+
+ # transform from voxel coordinates to camera coordinates
+ # note x and y are flipped in the output of marching_cubes
+ mesh_points = np.zeros_like(verts)
+ mesh_points[:, 0] = bbox[0,0] + verts[:, 0]
+ mesh_points[:, 1] = bbox[0,1] + verts[:, 1]
+ mesh_points[:, 2] = bbox[0,2] + verts[:, 2]
+
+ # apply additional offset and scale
+ if scale is not None:
+ mesh_points = mesh_points / scale
+ if offset is not None:
+ mesh_points = mesh_points - offset
+
+ # try writing to the ply file
+
+ num_verts = verts.shape[0]
+ num_faces = faces.shape[0]
+
+ verts_tuple = np.zeros((num_verts,), dtype=[("x", "f4"), ("y", "f4"), ("z", "f4")])
+
+ for i in range(0, num_verts):
+ verts_tuple[i] = tuple(mesh_points[i, :])
+
+ faces_building = []
+ for i in range(0, num_faces):
+ faces_building.append(((faces[i, :].tolist(),)))
+ faces_tuple = np.array(faces_building, dtype=[("vertex_indices", "i4", (3,))])
+
+ el_verts = plyfile.PlyElement.describe(verts_tuple, "vertex")
+ el_faces = plyfile.PlyElement.describe(faces_tuple, "face")
+
+ ply_data = plyfile.PlyData([el_verts, el_faces])
+ print("saving mesh to %s" % (ply_filename_out))
+ ply_data.write(ply_filename_out)
diff --git a/TensoRF/vis_utils.py b/TensoRF/vis_utils.py
new file mode 100644
index 0000000..f27add2
--- /dev/null
+++ b/TensoRF/vis_utils.py
@@ -0,0 +1,37 @@
+import cv2
+import numpy as np
+from PIL import Image
+import torchvision.transforms as T
+
+
+def visualize_depth_numpy(depth, minmax=None, cmap=cv2.COLORMAP_JET):
+ """ depth: (H, W) """
+ x = np.nan_to_num(depth) # change nan to 0
+ if minmax is None:
+ mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
+ ma = np.max(x)
+ else:
+ mi, ma = minmax
+
+ x = (x-mi) / (ma-mi+1e-8) # normalize to 0~1
+ x = (255 * x).astype(np.uint8)
+ x_ = cv2.applyColorMap(x, cmap)
+ return x_, [mi, ma]
+
+
+def visualize_depth(depth, minmax=None, cmap=cv2.COLORMAP_JET):
+ """ depth: (H, W) """
+ if type(depth) is not np.ndarray:
+ depth = depth.cpu().numpy()
+
+ x = np.nan_to_num(depth) # change nan to 0
+ if minmax is None:
+ mi = np.min(x[x>0]) # get minimum positive depth (ignore background)
+ ma = np.max(x)
+ else:
+ mi, ma = minmax
+
+ x = (x-mi) / (ma-mi+1e-8) # normalize to 0~1
+ x = (255 * x).astype(np.uint8)
+ x_ = Image.fromarray(cv2.applyColorMap(x, cmap))
+ x_ = T.ToTensor()(x_) # (3, H, W)
+ return x_, [mi, ma]
+