Skip to content

Commit

Permalink
update compress.py
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel03c1 committed Feb 16, 2023
1 parent 1e0fe19 commit 5326a09
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions compress.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
import math
import os
from opt import config_parser
from renderer import *
from utils import *
Expand All @@ -8,9 +8,11 @@
from run_length_encoding.rle.np_impl import dense_to_rle, rle_to_dense
from collections import OrderedDict
from dataLoader import dataset_dict
from models.dwt import inverse

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def cubify(arr, newshape):
oldshape = np.array(arr.shape)
repeats = (oldshape / newshape).astype(int)
Expand Down Expand Up @@ -355,7 +357,6 @@ def decompress_dwt_levelwise(args):
ckpt = torch.load(param_path, map_location='cpu')

# ---------------------- mask reconstruction ---------------------- #

# (1) unpack byte to bits
mask = byte2bit(ckpt["mask"])

Expand Down Expand Up @@ -420,10 +421,7 @@ def decompress_dwt_levelwise(args):
# (5) convert dtype: int8 -> float32
for key in state_keys:
for i in range(3):
masks[key][i] = nn.Parameter(
torch.from_numpy(masks[key][i].astype(np.float32))
)
masks[key] = nn.ParameterList(masks[key])
masks[key][i] = torch.from_numpy(masks[key][i].astype(np.float32))

# ---------------------- grid reconstruction ---------------------- #

Expand All @@ -434,26 +432,33 @@ def decompress_dwt_levelwise(args):
feat = ckpt["feature"][key][i]
scale = ckpt["scale"][key][i]
minvl = ckpt["minvl"][key][i]
features[key] += [nn.Parameter(torch.zeros(masks[key][i].shape))]
features[key][-1][masks[key][i] == 1] = dequantize_int(feat, scale, minvl)
features[key] = nn.ParameterList(features[key])
features[key] += [torch.zeros(masks[key][i].shape)]
features[key][-1][masks[key][i] == 1] = dequantize_int(
feat, scale, minvl)
if 'plane' in key and args.use_dwt:
features[key][-1] = inverse(features[key][-1], args.dwt_level,
args.trans_func)

for key in state_keys:
masks[key] = nn.ParameterList(
[nn.Parameter(m) for m in masks[key]])

for key in features.keys():
features[key] = nn.ParameterList(
[nn.Parameter(m) for m in features[key]])

# check kwargs
kwargs.update({'device': device})

# IMPORTANT: aabb to cuda
kwargs["aabb"] = kwargs["aabb"].to(device)
kwargs["use_dwt"] = False
kwargs["use_mask"] = False

# load params
tensorf = eval(args.model_name)(**kwargs)
tensorf.density_plane = features["density_plane"].to(device)
tensorf.density_line = features["density_line"].to(device)
tensorf.app_plane = features["app_plane"].to(device)
tensorf.app_line = features["app_line"].to(device)
tensorf.density_plane_mask = masks["density_plane"].to(device)
tensorf.density_line_mask = masks["density_line"].to(device)
tensorf.app_plane_mask = masks["app_plane"].to(device)
tensorf.app_line_mask = masks["app_line"].to(device)
tensorf.renderModule = ckpt["render_module"].to(device)
tensorf.basis_mat = ckpt["basis_mat"].to(device)

Expand All @@ -463,28 +468,29 @@ def decompress_dwt_levelwise(args):
tensorf.updateAlphaMask((X,Y,Z))

print("model loaded.")

args.use_dwt = True
if args.decompress_and_validate:
# renderder
renderer = OctreeRender_trilinear_fast

# init dataset
dataset = dataset_dict[args.dataset_name]
test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)
test_dataset = dataset(args.datadir, split='test',
downsample=args.downsample_train, is_stack=True)

white_bg = test_dataset.white_bg
ndc_ray = args.ndc_ray

logfolder = os.path.dirname(args.ckpt)

os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
PSNRs_test = evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
N_vis=-1, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
PSNRs_test = evaluation(test_dataset, tensorf, args, renderer,
f'{logfolder}/{args.expname}/imgs_test_all/',
N_vis=args.N_vis, N_samples=-1,
white_bg=white_bg,
ndc_ray=ndc_ray, device=device)
print(f'============> {args.expname} test all psnr: {np.mean(PSNRs_test)} <============')



if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(20211202)
Expand All @@ -496,4 +502,5 @@ def decompress_dwt_levelwise(args):
compress_dwt_levelwise(args)

if args.decompress:
decompress_dwt_levelwise(args)
decompress_dwt_levelwise(args)

0 comments on commit 5326a09

Please sign in to comment.