Commit df33955: tidy up train.py

daniel03c1 committed Jul 30, 2022
1 parent 49f0acb, commit df33955
Showing 1 changed file with 110 additions and 67 deletions.

PREF/train.py (177 changes: 110 additions, 67 deletions)
@@ -3,11 +3,12 @@
 import os
 import random
 import sys
-from torch.utils.tensorboard import SummaryWriter
-from tqdm.auto import tqdm
 
 from dataLoader import dataset_dict
 from opt import config_parser
 from renderer import *
+from torch.utils.tensorboard import SummaryWriter
+from tqdm.auto import tqdm
 from utils import *
 
+
@@ -40,14 +41,16 @@ def export_mesh(args):
     phasorf.load(ckpt)
 
     alpha,_ = phasorf.getDenseAlpha()
-    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',bbox=phasorf.aabb.cpu(), level=0.005)
+    convert_sdf_samples_to_ply(alpha.cpu(), f'{args.ckpt[:-3]}.ply',
+                               bbox=phasorf.aabb.cpu(), level=0.005)
 
 
 @torch.no_grad()
 def render_test(args):
     # init dataset
     dataset = dataset_dict[args.dataset_name]
-    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
+    test_dataset = dataset(args.datadir, split='test',
+                           downsample=args.downsample_train, is_stack=True)
     white_bg = test_dataset.white_bg
     ndc_ray = args.ndc_ray
 
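
For context on the hunk above: export_mesh thresholds the dense alpha (opacity) grid returned by getDenseAlpha at level=0.005 and writes a PLY mesh. convert_sdf_samples_to_ply is defined in this repo's utils; purely as a mental model, a helper of that shape could look like the sketch below, assuming scikit-image and trimesh are available. This is illustrative only, not the repo's actual implementation.

    # Hypothetical sketch of a convert_sdf_samples_to_ply-style helper.
    # Assumes scikit-image and trimesh; NOT the repo's actual implementation.
    import numpy as np
    import trimesh
    from skimage import measure

    def alpha_grid_to_ply(alpha, path, bbox, level=0.005):
        """alpha: (X, Y, Z) opacity grid; bbox: (2, 3) world min/max corners."""
        verts, faces, _, _ = measure.marching_cubes(alpha, level=level)
        # map voxel-index coordinates into the world-space bounding box
        scale = (bbox[1] - bbox[0]) / (np.array(alpha.shape) - 1)
        verts = verts * scale + bbox[0]
        trimesh.Trimesh(vertices=verts, faces=faces).export(path)
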
@@ -64,30 +67,43 @@ def render_test(args):
     logfolder = os.path.dirname(args.ckpt)
     if args.render_train:
         os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
-        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
-        PSNRs_test = evaluation(train_dataset,phasorf, args, renderer, f'{logfolder}/imgs_train_all/',
-                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
-        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
+        train_dataset = dataset(args.datadir, split='train',
+                                downsample=args.downsample_train, is_stack=True)
+        PSNRs_test = evaluation(train_dataset, phasorf, args, renderer,
+                                f'{logfolder}/imgs_train_all/', N_vis=-1,
+                                N_samples=-1, white_bg=white_bg,
+                                ndc_ray=ndc_ray, device=device)
+        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} '
+              f'<========================')
 
     if args.render_test:
         os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
-        PSNRs_test = evaluation(test_dataset,phasorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
-                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
-        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
+        PSNRs_test = evaluation(test_dataset, phasorf, args, renderer,
+                                f'{logfolder}/{args.expname}/imgs_test_all/',
+                                N_vis=-1, N_samples=-1, white_bg=white_bg,
+                                ndc_ray=ndc_ray, device=device)
+        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} '
+              f'<========================')
 
     if args.render_path:
         c2ws = test_dataset.render_path
         os.makedirs(f'{logfolder}/{args.expname}/imgs_path_all', exist_ok=True)
-        PSNRs_test = evaluation_path(test_dataset,phasorf, c2ws, renderer, f'{logfolder}/{args.expname}/imgs_path_all/',
-                                     N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
-        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} <========================')
+        PSNRs_test = evaluation_path(test_dataset, phasorf, c2ws, renderer,
+                                     f'{logfolder}/{args.expname}/imgs_path_all/',
+                                     N_vis=-1, N_samples=-1, white_bg=white_bg,
+                                     ndc_ray=ndc_ray, device=device)
+        print(f'======> {args.expname} path all psnr: {np.mean(PSNRs_test)} '
+              f'<========================')
 
 
-def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False):
+def reconstruction(args, return_bbox=False, return_memory=False,
+                   bbox_only=False):
     # init dataset
     dataset = dataset_dict[args.dataset_name]
-    train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=False)
-    test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
+    train_dataset = dataset(args.datadir, split='train',
+                            downsample=args.downsample_train, is_stack=False)
+    test_dataset = dataset(args.datadir, split='test',
+                           downsample=args.downsample_train, is_stack=True)
     white_bg = train_dataset.white_bg
     near_far = train_dataset.near_far
     ndc_ray = args.ndc_ray
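
A note on the PSNR figures printed above: the evaluation and evaluation_path helpers come from renderer.py and return per-image PSNRs, which are averaged with np.mean. For pixels scaled to [0, 1], PSNR follows directly from the mean squared error, the same identity the training loop applies later in this file:

    # PSNR from MSE for [0, 1]-scaled pixels: PSNR = -10 * log10(MSE).
    import numpy as np

    mse = 1e-3                                  # example per-image MSE
    psnr = -10.0 * np.log(mse) / np.log(10.0)   # same form as the training loop
    print(f'{psnr:.1f} dB')                     # -> 30.0 dB
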
@@ -97,7 +113,8 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
     update_AlphaMask_list = args.update_AlphaMask_list
 
     if args.add_timestamp:
-        logfolder = f'{args.basedir}/{args.expname}{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
+        logfolder = f'{args.basedir}/{args.expname}' \
+                    f'{datetime.datetime.now().strftime("-%Y%m%d-%H%M%S")}'
     else:
         logfolder = f'{args.basedir}/{args.expname}'

@@ -112,7 +129,8 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
 
     # init parameters
     if not bbox_only and args.dataset_name=='blender':
-        # use tight bbox pre-extracted and stored in misc.py, which takes 2k iters
+        # use tight bbox pre-extracted and stored in misc.py,
+        # which takes 2k iters
         data = args.datadir.split('/')[-1]
         from misc import blender_aabb
         aabb = torch.tensor(blender_aabb[data]).reshape(2,3).to(device)
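
The reshape(2,3) above implies that blender_aabb in misc.py maps each scene name to six floats: the min and max corners of a tight, pre-extracted axis-aligned bounding box. The numbers below are made up for illustration; the real per-scene values live in misc.py.

    # Hypothetical entry; the real per-scene values are stored in misc.py.
    import torch

    blender_aabb = {
        # scene:  xmin   ymin   zmin  xmax  ymax  zmax  (illustrative numbers)
        'lego': [-0.65, -1.20, -0.50, 0.65, 1.20, 1.05],
    }
    aabb = torch.tensor(blender_aabb['lego']).reshape(2, 3)
    # aabb[0] is the min corner, aabb[1] the max corner
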
@@ -174,52 +192,58 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
                        np.log(args.N_voxel_final), len(upsamp_list)+1))).long()).tolist()[1:]
 
     torch.cuda.empty_cache()
-    PSNRs,PSNRs_test = [],[0]
+    PSNRs, PSNRs_test = [], [0]
 
     allrays, allrgbs = train_dataset.all_rays, train_dataset.all_rgbs
     if not args.ndc_ray:
-        allrays, allrgbs = phasorf.filtering_rays(allrays, allrgbs, bbox_only=True)
+        allrays, allrgbs = phasorf.filtering_rays(allrays, allrgbs,
+                                                  bbox_only=True)
     allrays = allrays.cuda()
     allrgbs = allrgbs.cuda()
     trainingSampler = SimpleSampler(allrays.shape[0], args.batch_size)
 
-    TV_weight_density, TV_weight_app = args.TV_weight_density, args.TV_weight_app
-    print(f"initial TV_weight density: {TV_weight_density} appearance: {TV_weight_app}")
-
-
-    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate, file=sys.stdout)
+    TV_weight_density = args.TV_weight_density
+    TV_weight_app = args.TV_weight_app
+    print(f"initial TV_weight density: {TV_weight_density} "
+          f"appearance: {TV_weight_app}")
+
+    pbar = tqdm(range(args.n_iters), miniters=args.progress_refresh_rate)
 
     for iteration in pbar:
         ray_idx = trainingSampler.nextids()
         rays_train, rgb_train = allrays[ray_idx], allrgbs[ray_idx].to(device)
 
-        rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(rays_train, phasorf, chunk=args.batch_size,
-                                                                        N_samples=nSamples, white_bg = white_bg, ndc_ray=ndc_ray, device=device, is_train=True)
+        rgb_map, alphas_map, depth_map, weights, uncertainty = renderer(
+            rays_train, phasorf, chunk=args.batch_size, N_samples=nSamples,
+            white_bg=white_bg, ndc_ray=ndc_ray, device=device, is_train=True)
 
         loss = torch.mean((rgb_map - rgb_train) ** 2)
 
         # loss
         total_loss = loss
-        loss_tv = torch.tensor([0.0]).cuda()
+        loss_tv = torch.tensor([0.]).cuda()
 
         if TV_weight_density > 0 and (iteration % args.TV_step == 0):
             TV_weight_density *= lr_factor
             loss_tv = phasorf.Parseval_Loss() * TV_weight_density
             total_loss = total_loss + loss_tv
-            summary_writer.add_scalar('train/reg_tv_density', loss_tv.detach().item(), global_step=iteration)
+            summary_writer.add_scalar('train/reg_tv_density',
+                                      loss_tv.detach().item(),
+                                      global_step=iteration)
 
-        if TV_weight_app>0:
+        if TV_weight_app > 0:
             TV_weight_app *= lr_factor
             raise NotImplementedError('not implemented')
 
         optimizer.zero_grad()
         total_loss.backward()
         optimizer.step()
 
         loss = loss.detach().item()
 
         PSNRs.append(-10.0 * np.log(loss) / np.log(10.0))
-        summary_writer.add_scalar('train/PSNR', PSNRs[-1], global_step=iteration)
+        summary_writer.add_scalar('train/PSNR', PSNRs[-1],
+                                  global_step=iteration)
         summary_writer.add_scalar('train/mse', loss, global_step=iteration)
 
         for param_group in optimizer.param_groups:
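
For context: SimpleSampler comes from utils (imported via from utils import *), and its nextids() hands out batch_size ray indices per iteration. A sketch of the assumed behavior, a permutation sampler that reshuffles once a pass is exhausted (the repo's actual class may differ):

    # Sketch of the assumed SimpleSampler behavior: shuffled ray indices,
    # one batch at a time, reshuffled after each full pass.
    import torch

    class SimpleSamplerSketch:
        def __init__(self, total, batch):
            self.total, self.batch = total, batch
            self.ids = torch.randperm(total)
            self.curr = 0

        def nextids(self):
            if self.curr + self.batch > self.total:   # pass finished: reshuffle
                self.ids = torch.randperm(self.total)
                self.curr = 0
            batch_ids = self.ids[self.curr:self.curr + self.batch]
            self.curr += self.batch
            return batch_ids

    sampler = SimpleSamplerSketch(total=640000, batch=4096)
    ray_idx = sampler.nextids()   # 4096 unique indices per call within a pass
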
@@ -229,25 +253,27 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
         if iteration % args.progress_refresh_rate == 0:
             pbar.set_description(
                 f'Iteration {iteration:05d}:'
-                + f' train_psnr = {float(np.mean(PSNRs)):.2f}'
-                + f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
-                + f' mse = {loss:.6f}'
-                + f' tv_loss = {loss_tv.detach().item():.6f}'
-            )
+                f' train_psnr = {float(np.mean(PSNRs)):.2f}'
+                f' test_psnr = {float(np.mean(PSNRs_test)):.2f}'
+                f' mse = {loss:.6f} tv_loss = {loss_tv.detach().item():.6f}')
             PSNRs = []
 
         if iteration % args.vis_every == args.vis_every - 1 and args.N_vis!=0:
             PSNRs_test = evaluation(test_dataset,phasorf, args, renderer,
-                                    f'{logfolder}/imgs_vis/', N_vis=args.N_vis,
-                                    prtx=f'{iteration:06d}_', N_samples=nSamples,
+                                    f'{logfolder}/imgs_vis/',
+                                    prtx=f'{iteration:06d}_',
+                                    N_samples=nSamples, N_vis=args.N_vis,
                                     white_bg = white_bg, ndc_ray=ndc_ray,
                                     compute_extra_metrics=args.compute_extra_metric)
             print(np.mean(PSNRs_test))
-            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test), global_step=iteration)
+            summary_writer.add_scalar('test/psnr', np.mean(PSNRs_test),
+                                      global_step=iteration)
 
         # TODO: to accelerate
-        if update_AlphaMask_list is not None and iteration in update_AlphaMask_list:
-            if reso_cur[0] * reso_cur[1] * reso_cur[2]<256**3:# update volume resolution
+        if update_AlphaMask_list is not None \
+                and iteration in update_AlphaMask_list:
+            # update volume resolution
+            if reso_cur[0] * reso_cur[1] * reso_cur[2] < 256**3:
                 reso_mask = reso_cur
             new_aabb = phasorf.updateAlphaMask(tuple(reso_mask))
@@ -262,49 +288,63 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
             # phasorf.shrink(new_aabb)
             if args.TV_weight_density_reset >= 0:
                 TV_weight_density = args.TV_weight_density_reset
-                print(f'TV weight density reset to {args.TV_weight_density_reset}')
+                print(f'TV weight density reset to '
+                      f'{args.TV_weight_density_reset}')
 
             if not args.ndc_ray and iteration == update_AlphaMask_list[1]:
                 # filter rays outside the bbox
                 allrays,allrgbs = phasorf.filtering_rays(allrays,allrgbs)
-                trainingSampler = SimpleSampler(allrgbs.shape[0], args.batch_size)
+                trainingSampler = SimpleSampler(allrgbs.shape[0],
+                                                args.batch_size)
                 allrays = allrays.cuda()
                 allrgbs = allrgbs.cuda()
 
         # TODO:
         if upsamp_list is not None and iteration in upsamp_list:
             n_voxels = N_voxel_list.pop(0)
             reso_cur = N_to_reso(n_voxels, phasorf.aabb)
-            nSamples = min(args.nSamples, cal_n_samples(reso_cur,args.step_ratio))
+            nSamples = min(args.nSamples,
+                           cal_n_samples(reso_cur, args.step_ratio))
             phasorf.upsample_volume_grid(reso_cur)
 
             if args.lr_upsample_reset:
                 print("reset lr to initial")
                 lr_scale = 1 #0.1 ** (iteration / args.n_iters)
             else:
-                lr_scale = args.lr_decay_target_ratio ** (iteration / args.n_iters)
+                lr_scale = args.lr_decay_target_ratio**(iteration/args.n_iters)
             print(f'lr set {lr_scale}')
-            grad_vars = phasorf.get_optparam_groups(args.lr_init*lr_scale, args.lr_basis*lr_scale)
+            grad_vars = phasorf.get_optparam_groups(args.lr_init*lr_scale,
+                                                    args.lr_basis*lr_scale)
             optimizer = torch.optim.Adam(grad_vars, betas=(0.9, 0.99))
 
     phasorf.save(f'{logfolder}/{args.expname}.th')
 
     breakpoint()
 
     if args.render_train:
         os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
-        train_dataset = dataset(args.datadir, split='train', downsample=args.downsample_train, is_stack=True)
-        PSNRs_test = evaluation(train_dataset,phasorf, args, renderer, f'{logfolder}/imgs_train_all/',
-                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
-        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
+        train_dataset = dataset(args.datadir, split='train',
+                                downsample=args.downsample_train, is_stack=True)
+        PSNRs_test = evaluation(train_dataset, phasorf, args, renderer,
+                                f'{logfolder}/imgs_train_all/', N_vis=-1,
+                                N_samples=-1, white_bg=white_bg,
+                                ndc_ray=ndc_ray, device=device)
+        print(f'======> {args.expname} train all psnr: {np.mean(PSNRs_test)} '
+              f'<========================')
 
     if args.render_test:
         os.makedirs(f'{logfolder}/imgs_test_all', exist_ok=True)
-        PSNRs_test = evaluation(test_dataset, phasorf, args, renderer, f'{logfolder}/imgs_test_all/',
-                                N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
-        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test), global_step=iteration)
-        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} <========================')
+        PSNRs_test = evaluation(test_dataset, phasorf, args, renderer,
+                                f'{logfolder}/imgs_test_all/', N_vis=-1,
+                                N_samples=-1, white_bg=white_bg,
+                                ndc_ray=ndc_ray, device=device)
+        summary_writer.add_scalar('test/psnr_all', np.mean(PSNRs_test),
+                                  global_step=iteration)
+        print(f'======> {args.expname} test all psnr: {np.mean(PSNRs_test)} '
+              f'<========================')
 
     if return_memory:
-        memory = np.sum([v.numel() * v.element_size() for k, v in phasorf.named_parameters()])/2**20
+        memory = np.sum([v.numel() * v.element_size()
+                         for k, v in phasorf.named_parameters()]) / 2**20
         return np.mean(PSNRs_test), memory
 
     return np.mean(PSNRs_test)
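
The return_memory branches above report the model's parameter footprint in MiB: element count times bytes per element, summed over named parameters and divided by 2**20. The same arithmetic on a toy module:

    # Parameter footprint in MiB, same formula as the return_memory branches.
    import numpy as np
    import torch

    model = torch.nn.Linear(256, 256)   # toy stand-in for phasorf
    memory = np.sum([v.numel() * v.element_size()
                     for _, v in model.named_parameters()]) / 2**20
    print(f'{memory:.3f} MiB')          # (256*256 + 256) float32 params ~ 0.251 MiB
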
@@ -313,22 +353,25 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
         c2ws = test_dataset.render_path
         print('========>',c2ws.shape)
         os.makedirs(f'{logfolder}/imgs_path_all', exist_ok=True)
-        evaluation_path(test_dataset, phasorf, c2ws, renderer, f'{logfolder}/imgs_path_all/',
-                        N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
+        evaluation_path(test_dataset, phasorf, c2ws, renderer,
+                        f'{logfolder}/imgs_path_all/', N_vis=-1, N_samples=-1,
+                        white_bg=white_bg, ndc_ray=ndc_ray, device=device)
 
     if not args.render_test:
         PSNRs_test = evaluation(test_dataset,phasorf, args, renderer,
-                                f'{logfolder}/imgs_vis_all/', N_vis=10, N_samples=nSamples,
-                                white_bg = white_bg, ndc_ray=ndc_ray,
-                                compute_extra_metrics=args.compute_extra_metric)
+                                f'{logfolder}/imgs_vis_all/', N_vis=10,
+                                N_samples=nSamples, white_bg=white_bg,
+                                ndc_ray=ndc_ray,
+                                compute_extra_metrics=args.compute_extra_metric)
     if return_memory:
-        memory = np.sum([v.numel() * v.element_size() for k, v in phasorf.named_parameters()])/2**20
+        memory = np.sum([v.numel() * v.element_size()
+                         for _, v in phasorf.named_parameters()]) / 2**20
         return np.mean(PSNRs_test), memory
 
     return np.mean(PSNRs_test)
 
-if __name__ == '__main__':
-
+
+if __name__ == '__main__':
     torch.set_default_dtype(torch.float32)
     seed = 2020233254
     torch.manual_seed(seed)
@@ -337,7 +380,7 @@ def reconstruction(args, return_bbox=False, return_memory=False, bbox_only=False
     args = config_parser()
     print(args)
 
-    if args.export_mesh:
+    if args.export_mesh:
         export_mesh(args)
 
     if args.render_only and (args.render_test or args.render_path):