From a8b74aecb9a2e523e5a05fa7cd4d82195e63960e Mon Sep 17 00:00:00 2001
From: Seungtae
Date: Fri, 11 Nov 2022 06:36:07 +0900
Subject: [PATCH 1/5] add levelwise compression, mask reconstruction

---
 TensoRF/compress.py | 544 ++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 495 insertions(+), 49 deletions(-)

diff --git a/TensoRF/compress.py b/TensoRF/compress.py
index dcab481..785b1bb 100644
--- a/TensoRF/compress.py
+++ b/TensoRF/compress.py
@@ -3,11 +3,32 @@
 from opt import config_parser
 from renderer import *
 from utils import *
+from scan import *
 from huffman import *
 from run_length_encoding.rle.np_impl import dense_to_rle, rle_to_dense
 from collections import OrderedDict
 from dataLoader import dataset_dict
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+def cubify(arr, newshape):
+    # split an array into equally sized blocks of `newshape`
+    oldshape = np.array(arr.shape)
+    repeats = (oldshape / newshape).astype(int)
+    tmpshape = np.column_stack([repeats, newshape]).ravel()
+    order = np.arange(len(tmpshape))
+    order = np.concatenate([order[::2], order[1::2]])
+    # newshape must divide oldshape evenly or else ValueError will be raised
+    return arr.reshape(tmpshape).transpose(order).reshape(-1, *newshape)
+
+
+def uncubify(arr, oldshape):
+    # inverse of cubify: reassemble the blocks into an array of `oldshape`
+    N, newshape = arr.shape[0], arr.shape[1:]
+    oldshape = np.array(oldshape)
+    repeats = (oldshape / newshape).astype(int)
+    tmpshape = np.concatenate([repeats, newshape])
+    order = np.arange(len(tmpshape)).reshape(2, -1).ravel(order='F')
+    return arr.reshape(tmpshape).transpose(order).reshape(oldshape)
+
 
 def bit2byte(enc):
     BIT = 8
@@ -61,8 +82,82 @@ def dequantize_int(inputs, scale, minvl):
     return (inputs - 1) * scale + minvl
 
 
+def split_grid(grid, level):
+    if level < 1:
+        return np.stack(grid)
+
+    H, W = grid.shape[-2:]
+    if H % 2 != 0 or W % 2 != 0:
+        raise ValueError("grid dimension is not divisible.")
+
+    grid = np.squeeze(cubify(grid, (1, H//2, W//2))) # (C*4, H//2, W//2)
+    idxs = np.arange(len(grid)) # number of channels
+
+    if level >= 1:
+        topleft = split_grid(grid[idxs%4 == 0, ...], level-1)
+        others = grid[idxs%4 != 0, ...]
+        return topleft, others
+
+
+def concat_grid(grids):
+    if len(grids) < 2:
+        raise ValueError("# of grids must be greater than 1.")
+    # the highest level of grid
+    topleft = grids[-1]
+    # high level (small) to low level (large)
+    for others in reversed(grids[:-1]):
+        # interleave blocks along channel axis
+        # [c1_1, c2_1, c2_2, c2_3, c1_2, c2_4, ...]
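+        # slots where idx % 4 == 0 receive the merged top-left (low-frequency)
+        # blocks; the other three slots receive their high-frequency siblings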
+        (c1, h1, w1), c2 = topleft.shape, others.shape[0]
+        temp = np.empty((c1+c2, h1, w1), dtype=topleft.dtype)
+        idxs = np.arange(c1+c2)
+        temp[idxs%4 == 0] = topleft
+        temp[idxs%4 != 0] = others
+        # uncubify ((c1+c2), 1, h, w) -> ((c1+c2)//4, h*2, w*2)
+        topleft = uncubify(temp[:, None, ...], ((c1+c2)//4, h1*2, w1*2))
+    return topleft
+
+
+def get_levelwise_shape(grids, dwt_level):
+    total_shapes = []
+    for i in range(3):
+        grid = grids[i]
+        shape_per_lv = []
+        # from low (large) to high (small)
+        for j in range(dwt_level):
+            # split level
+            topleft, others = grid
+            # save shape
+            shape_per_lv += [others.shape]
+            # update grid
+            grid = topleft
+        # save the shape of the last (highest) level as well
+        shape_per_lv += [topleft.shape]
+        total_shapes += [shape_per_lv]
+    return total_shapes
+
+
+def packbits_by_level(grids, dwt_level):
+    new_grids = []
+    for i in range(3):
+        grid = grids[i]
+        grid_per_lv = [] # dim: (level+1,)
+        # from low (large) to high (small)
+        for j in range(dwt_level):
+            # split level
+            topleft, others = grid
+            # pack the high-frequency mask bits channel-last (H, W, C)
+            grid_per_lv += [np.packbits(others.transpose(1, 2, 0))]
+            # update grid
+            grid = topleft
+        # pack the last (highest) level the same way
+        grid_per_lv += [np.packbits(topleft.transpose(1, 2, 0))]
+        new_grids += [grid_per_lv]
+    return new_grids
+
+
 @torch.no_grad()
-def compress_method1(args, device): # save grid + mask
+def compress_dwt(args): # save grid + mask
     # check if ckpt exists
     if not os.path.exists(args.ckpt):
         print("the ckpt path does not exists!")
@@ -73,11 +168,11 @@ def compress_method1(args, device): # save grid + mask
 
     # update kwargs
     kwargs = ckpt['kwargs']
-    kwargs.update({'device': device})
-    kwargs.update({'use_mask': args.use_mask})
-    kwargs.update({'use_dwt': args.use_dwt})
-    kwargs.update({'dwt_level': args.dwt_level})
+
+    # NOTE: temp code
+    if "trans_func" in kwargs:
+        del kwargs['trans_func']
 
     # make model
     tensorf = eval(args.model_name)(**kwargs)
@@ -86,32 +181,42 @@ def compress_method1(args, device): # save grid + mask
     # ship to cpu
     tensorf.to('cpu')
 
+    if args.reconstruct_mask:
+        # (1) mask reconstruction
+        den_plane_mask, den_line_mask = [], []
+        app_plane_mask, app_line_mask = [], []
+        for i in range(3):
+            den_plane_mask += [np.where(tensorf.density_plane[i] != 0, 1, 0)]
+            den_line_mask += [np.where(tensorf.density_line[i] != 0, 1, 0)]
+            app_plane_mask += [np.where(tensorf.app_plane[i] != 0, 1, 0)]
+            app_line_mask += [np.where(tensorf.app_line[i] != 0, 1, 0)]
+    else:
+        # (1) binarize mask
+        den_plane_mask, den_line_mask = [], []
+        app_plane_mask, app_line_mask = [], []
+        for i in range(3):
+            den_plane_mask += [np.where(tensorf.density_plane_mask[i]>=0, 1, 0)]
+            den_line_mask += [np.where(tensorf.density_line_mask[i]>=0, 1, 0)]
+            app_plane_mask += [np.where(tensorf.app_plane_mask[i]>=0, 1, 0)]
+            app_line_mask += [np.where(tensorf.app_line_mask[i]>=0, 1, 0)]
+
+    # mask shape
     mask_shape = {
-        "density_plane": [x.shape for x in tensorf.density_plane_mask],
-        "density_line": [x.shape for x in tensorf.density_line_mask],
-        "app_plane": [x.shape for x in tensorf.app_plane_mask],
-        "app_line": [x.shape for x in tensorf.app_line_mask]
+        "density_plane": [x.shape for x in den_plane_mask],
+        "density_line": [x.shape for x in den_line_mask],
+        "app_plane": [x.shape for x in app_plane_mask],
+        "app_line": [x.shape for x in app_line_mask]
     }
 
-    # (1) binarize mask
-    den_plane_mask, den_line_mask = [], []
-    app_plane_mask, app_line_mask = [], []
-    for i in range(3):
-        den_plane_mask += 
[torch.where(tensorf.density_plane_mask[i]>=0, 1, 0)] - den_line_mask += [torch.where(tensorf.density_line_mask[i]>=0, 1, 0)] - app_plane_mask += [torch.where(tensorf.app_plane_mask[i]>=0, 1, 0)] - app_line_mask += [torch.where(tensorf.app_line_mask[i]>=0, 1, 0)] - # (2) get non-masked values in the feature grids - den_plane, den_line = [], [] app_plane, app_line = [], [] for i in range(3): - den_plane += [tensorf.density_plane[i][(den_plane_mask[i] == 1)].flatten()] - den_line += [tensorf.density_line[i][(den_line_mask[i] == 1)].flatten()] - app_plane += [tensorf.app_plane[i][(app_plane_mask[i] == 1)].flatten()] - app_line += [tensorf.app_line[i][(app_line_mask[i] == 1)].flatten()] + den_plane += [tensorf.density_plane[i][(den_plane_mask[i][None, ...] == 1)].flatten()] + den_line += [tensorf.density_line[i][(den_line_mask[i][None, ...] == 1)].flatten()] + app_plane += [tensorf.app_plane[i][(app_plane_mask[i][None, ...] == 1)].flatten()] + app_line += [tensorf.app_line[i][(app_line_mask[i][None, ...] == 1)].flatten()] # scale & minimum value scale = {k: [0]*3 for k in mask_shape.keys()} @@ -130,13 +235,20 @@ def compress_method1(args, device): # save grid + mask den_line[i] = den_line[i].to(torch.uint8) app_plane[i] = app_plane[i].to(torch.uint8) app_line[i] = app_line[i].to(torch.uint8) + + # (5) pack bits to byte + for i in range(3): + den_plane_mask[i] = np.packbits(den_plane_mask[i]) + den_line_mask[i] = np.packbits(den_line_mask[i]) + app_plane_mask[i] = np.packbits(app_plane_mask[i]) + app_line_mask[i] = np.packbits(app_line_mask[i]) # (5) RLE masks for i in range(3): - den_plane_mask[i] = dense_to_rle(den_plane_mask[i].flatten(), np.int8) - den_line_mask[i] = dense_to_rle(den_line_mask[i].flatten(), np.int8) - app_plane_mask[i] = dense_to_rle(app_plane_mask[i].flatten(), np.int8) - app_line_mask[i] = dense_to_rle(app_line_mask[i].flatten(), np.int8) + den_plane_mask[i] = dense_to_rle(den_plane_mask[i].flatten(), np.int8).astype(np.int8) + den_line_mask[i] = dense_to_rle(den_line_mask[i].flatten(), np.int8).astype(np.int8) + app_plane_mask[i] = dense_to_rle(app_plane_mask[i].flatten(), np.int8).astype(np.int8) + app_line_mask[i] = dense_to_rle(app_line_mask[i].flatten(), np.int8).astype(np.int8) # (6) concatenate masks mask = np.concatenate([*den_plane_mask, *den_line_mask, *app_plane_mask, *app_line_mask]) @@ -147,15 +259,15 @@ def compress_method1(args, device): # save grid + mask "app_line": [r.shape[0] for r in app_line_mask] } - # (7) huffman coding masks + # (7) Huffman masks mask, mask_tree = huffman(mask) # (8) bit -> byte, numpy -> tensor mask = bit2byte(mask) + # mask = torch.ByteTensor(np.packbits(np.array(list(mask), np.uint8))) - # (9) save model + # (9) save params params = { - "kwargs": tensorf.get_kwargs(), "feature": { "density_plane": den_plane, "density_line": den_line, @@ -174,11 +286,18 @@ def compress_method1(args, device): # save grid + mask # set directory root_dir = args.ckpt.split('/')[:-1] - param_path = os.path.join(*root_dir, 'model_compressed.th') + param_path = os.path.join(*root_dir, 'params.th') torch.save(params, param_path) - param_size = os.path.getsize(param_path)/1024 - print(f"============> Grid + Mask + MLP + etc (kb): {param_size} <============") + param_size = os.path.getsize(param_path)/1024/1024 + print(f"============> Grid + Mask + MLP (mb): {param_size} <============") + + # (10) save kwargs + kwargs_path = os.path.join(*root_dir, 'kwargs.th') + torch.save({"kwargs": tensorf.get_kwargs()}, kwargs_path) + + kwargs_size = 
os.path.getsize(kwargs_path)/1024/1024
+    print(f"============> kwargs (mb): {kwargs_size} <============")
 
     if tensorf.alphaMask is not None:
         alpha_volume = tensorf.alphaMask.alpha_volume.bool().cpu().numpy()
@@ -191,14 +310,14 @@ def compress_method1(args, device): # save grid + mask
         alpha_mask_path = os.path.join(*root_dir, 'alpha_mask.th')
         torch.save(alpha_mask, alpha_mask_path)
 
-        mask_size = os.path.getsize(alpha_mask_path)/1024
-        print(f"============> Alpha mask (kb): {mask_size} <============")
+        mask_size = os.path.getsize(alpha_mask_path)/1024/1024
+        print(f"============> Alpha mask (mb): {mask_size} <============")
 
     print("encoding done.")
 
 
 @torch.no_grad()
-def decompress_method1(args):
+def decompress_dwt(args):
     # check if ckpt exists
     if not os.path.exists(args.ckpt):
         print("the ckpt path does not exists!")
@@ -206,36 +325,46 @@ def decompress_method1(args):
 
     # set directory
     root_dir = args.ckpt.split('/')[:-1]
-    param_path = os.path.join(*root_dir, 'model_compressed.th')
+    kwargs_path = os.path.join(*root_dir, 'kwargs.th')
+    param_path = os.path.join(*root_dir, 'params.th')
+
+    # load kwargs
+    kwargs = torch.load(kwargs_path, map_location='cpu')["kwargs"]
 
     # load checkpoint
     ckpt = torch.load(param_path, map_location='cpu')
 
+    # dictionary keys
+    state_keys = ["density_plane", "density_line", "app_plane", "app_line"]
+
     # (1) byte -> bit
     mask = byte2bit(ckpt["mask"])
+    # mask = np.unpackbits(ckpt["mask"].numpy())
 
     # (2) inverse huffman
     mask = dehuffman(ckpt["mask_tree"], mask)
 
     # (3) split an array into multiple arrays and inverse RLE
-    masks = OrderedDict({k: [] for k in ckpt["mask_shape"].keys()})
+    masks = OrderedDict({k: [] for k in state_keys})
 
     begin = 0
-    for key in masks.keys():
+    for key in state_keys:
         for length in ckpt["rle_length"][key]:
-            masks[key] += [torch.from_numpy(rle_to_dense(mask[begin:begin+length]))]
+            masks[key] += [np.unpackbits(rle_to_dense(mask[begin:begin+length]).astype(np.uint8)).astype(np.int8)] # int8 so the -1 below fits
             masks[key][-1][masks[key][-1] == 0] = -1
             begin += length
 
     # (4) reshape mask
-    for key in masks.keys():
+    for key in state_keys:
         for i in range(3):
             shape = ckpt["mask_shape"][key][i]
-            masks[key][i] = nn.Parameter(masks[key][i].to(torch.float32).reshape(*shape))
+            masks[key][i] = nn.Parameter(
+                torch.from_numpy(masks[key][i]).to(torch.float32).reshape(*shape)
+            )
         masks[key] = nn.ParameterList(masks[key])
 
     # (5) dequantize feature grid
-    features = {k: [] for k in masks.keys()}
+    features = {k: [] for k in state_keys}
     for key in features.keys():
         for i in range(3):
             feat = ckpt["feature"][key][i]
@@ -245,9 +374,7 @@ def decompress_method1(args):
             features[key][-1][masks[key][i] == 1] = dequantize_int(feat, scale, minvl)
         features[key] = nn.ParameterList(features[key])
 
-    # check kwargs
-    kwargs = ckpt['kwargs']
     kwargs.update({'device': device})
 
     # IMPORTANT: aabb to cuda
@@ -296,6 +423,321 @@ def decompress_method1(args):
     print(f'============> {args.expname} test all psnr: {np.mean(PSNRs_test)} <============')
 
 
+@torch.no_grad()
+def compress_dwt_levelwise(args):
+    # check if ckpt exists
+    if not os.path.exists(args.ckpt):
+        print("the ckpt path does not exist!")
+        return
+
+    # load checkpoint
+    ckpt = torch.load(args.ckpt, map_location=device)
+
+    # update kwargs
+    kwargs = ckpt['kwargs']
+    kwargs.update({'device': device})
+
+    # NOTE: temp code
+    if "trans_func" in kwargs:
+        del kwargs['trans_func']
+
+    # make model
+    tensorf = eval(args.model_name)(**kwargs)
+    tensorf.load(ckpt)
+
+    # ship to cpu
+    tensorf.to('cpu')
+
+    # dictionary keys
+    state_keys = ["density_plane", "density_line", 
"app_plane", "app_line"] + + # ---------------------- feature grid compression ---------------------- # + + if args.reconstruct_mask: + # (1) mask reconstruction + den_plane_mask, den_line_mask = [], [] + app_plane_mask, app_line_mask = [], [] + for i in range(3): + den_plane_mask += [np.where(tensorf.density_plane[i] != 0, 1, 0)] + den_line_mask += [np.where(tensorf.density_line[i] != 0, 1, 0)] + app_plane_mask += [np.where(tensorf.app_plane[i] != 0, 1, 0)] + app_line_mask += [np.where(tensorf.app_line[i] != 0, 1, 0)] + else: + # (1) binarize mask + den_plane_mask, den_line_mask = [], [] + app_plane_mask, app_line_mask = [], [] + for i in range(3): + den_plane_mask += [np.where(tensorf.density_plane_mask[i]>=0, 1, 0)] + den_line_mask += [np.where(tensorf.density_line_mask[i]>=0, 1, 0)] + app_plane_mask += [np.where(tensorf.app_plane_mask[i]>=0, 1, 0)] + app_line_mask += [np.where(tensorf.app_line_mask[i]>=0, 1, 0)] + + # (2) get non-masked values in the feature grids + den_plane, den_line = [], [] + app_plane, app_line = [], [] + for i in range(3): + den_plane += [tensorf.density_plane[i][(den_plane_mask[i][None, ...] == 1)].flatten()] + den_line += [tensorf.density_line[i][(den_line_mask[i][None, ...] == 1)].flatten()] + app_plane += [tensorf.app_plane[i][(app_plane_mask[i][None, ...] == 1)].flatten()] + app_line += [tensorf.app_line[i][(app_line_mask[i][None, ...] == 1)].flatten()] + + # scale & minimum value + scale = {k: [0]*3 for k in state_keys} + minvl = {k: [0]*3 for k in state_keys} + + # (3) quantize non-masked values + for i in range(3): + den_plane[i], scale["density_plane"][i], minvl["density_plane"][i] = quantize_int(den_plane[i], tensorf.grid_bit) + den_line[i], scale["density_line"][i], minvl["density_line"][i] = quantize_int(den_line[i], tensorf.grid_bit) + app_plane[i], scale["app_plane"][i], minvl["app_plane"][i] = quantize_int(app_plane[i], tensorf.grid_bit) + app_line[i], scale["app_line"][i], minvl["app_line"][i] = quantize_int(app_line[i], tensorf.grid_bit) + + # (4) convert dtype (float -> uint8) + for i in range(3): + den_plane[i] = den_plane[i].to(torch.uint8) + den_line[i] = den_line[i].to(torch.uint8) + app_plane[i] = app_plane[i].to(torch.uint8) + app_line[i] = app_line[i].to(torch.uint8) + + # ---------------------- mask compression ---------------------- # + + dwt_level = kwargs["dwt_level"] + + # (5) split by level: (((lv3 topleft, lv3 others), lv2 others), lv1 others) + for i in range(3): + den_plane_mask[i] = split_grid(den_plane_mask[i].squeeze(0), level=dwt_level) + app_plane_mask[i] = split_grid(app_plane_mask[i].squeeze(0), level=dwt_level) + + # mask shape for reconstruction + mask_shape = { + "density_plane": get_levelwise_shape(den_plane_mask, dwt_level), + "density_line": [x.shape for x in den_line_mask], + "app_plane": get_levelwise_shape(app_plane_mask, dwt_level), + "app_line": [x.shape for x in app_line_mask] + } + + # (6) pack bits by level + den_plane_mask = packbits_by_level(den_plane_mask, dwt_level) + app_plane_mask = packbits_by_level(app_plane_mask, dwt_level) + den_line_mask = [np.packbits(den_line_mask[i]) for i in range(3)] + app_line_mask = [np.packbits(app_line_mask[i]) for i in range(3)] + + # (7) RLE (masks), save rle length + rle_length = {k: [] for k in state_keys} + for i in range(3): + # RLE line + den_line_mask[i] = dense_to_rle(den_line_mask[i], np.int8).astype(np.int8) + app_line_mask[i] = dense_to_rle(app_line_mask[i], np.int8).astype(np.int8) + # save line length + rle_length["density_line"] += 
[den_line_mask[i].shape[0]]
+        rle_length["app_line"] += [app_line_mask[i].shape[0]]
+        # RLE plane container
+        den_plane_rle_length = []
+        app_plane_rle_length = []
+        for j in range(dwt_level+1):
+            # RLE plane by level
+            den_plane_mask[i][j] = dense_to_rle(den_plane_mask[i][j], np.int8).astype(np.int8)
+            app_plane_mask[i][j] = dense_to_rle(app_plane_mask[i][j], np.int8).astype(np.int8)
+            # save plane length
+            den_plane_rle_length += [den_plane_mask[i][j].shape[0]]
+            app_plane_rle_length += [app_plane_mask[i][j].shape[0]]
+        rle_length["density_plane"] += [den_plane_rle_length]
+        rle_length["app_plane"] += [app_plane_rle_length]
+        # concat mask by axis (x, y, z)
+        den_plane_mask[i] = np.concatenate(den_plane_mask[i])
+        app_plane_mask[i] = np.concatenate(app_plane_mask[i])
+
+    # (8) concatenate masks
+    mask = np.concatenate([*den_plane_mask, *den_line_mask, *app_plane_mask, *app_line_mask])
+
+    # (9) Huffman (masks)
+    mask, mask_tree = huffman(mask)
+
+    # (10) pack bits (string) to byte, numpy to tensor
+    mask = bit2byte(mask)
+
+    # (11) save params
+    params = {
+        "feature": {
+            "density_plane": den_plane,
+            "density_line": den_line,
+            "app_plane": app_plane,
+            "app_line": app_line
+        },
+        "scale": scale,
+        "minvl": minvl,
+        "mask": mask,
+        "mask_tree": mask_tree,
+        "mask_shape": mask_shape,
+        "rle_length": rle_length,
+        "render_module": tensorf.renderModule,
+        "basis_mat": tensorf.basis_mat
+    }
+
+    # set directory
+    root_dir = args.ckpt.split('/')[:-1]
+    param_path = os.path.join(*root_dir, 'params.th')
+    torch.save(params, param_path)
+
+    param_size = os.path.getsize(param_path)/1024/1024
+    print(f"============> Grid + Mask + MLP (mb): {param_size} <============")
+
+    # (12) save kwargs
+    kwargs_path = os.path.join(*root_dir, 'kwargs.th')
+    torch.save({"kwargs": tensorf.get_kwargs()}, kwargs_path)
+
+    kwargs_size = os.path.getsize(kwargs_path)/1024/1024
+    print(f"============> kwargs (mb): {kwargs_size} <============")
+
+    # (13) save alphaMask
+    if tensorf.alphaMask is not None:
+        alpha_volume = tensorf.alphaMask.alpha_volume.bool().cpu().numpy()
+        alpha_mask = {
+            'alphaMask.shape': alpha_volume.shape,
+            'alphaMask.mask': np.packbits(alpha_volume.reshape(-1)),
+            'alphaMask.aabb': tensorf.alphaMask.aabb.cpu()
+        }
+
+        alpha_mask_path = os.path.join(*root_dir, 'alpha_mask.th')
+        torch.save(alpha_mask, alpha_mask_path)
+
+        mask_size = os.path.getsize(alpha_mask_path)/1024/1024
+        print(f"============> Alpha mask (mb): {mask_size} <============")
+
+    print("encoding done.")
+
+
+@torch.no_grad()
+def decompress_dwt_levelwise(args):
+    # check if ckpt exists
+    if not os.path.exists(args.ckpt):
+        print("the ckpt path does not exist!")
+        return
+
+    # set directory
+    root_dir = args.ckpt.split('/')[:-1]
+    kwargs_path = os.path.join(*root_dir, 'kwargs.th')
+    param_path = os.path.join(*root_dir, 'params.th')
+
+    # load kwargs
+    kwargs = torch.load(kwargs_path, map_location='cpu')["kwargs"]
+
+    # load checkpoint
+    ckpt = torch.load(param_path, map_location='cpu')
+
+    # ---------------------- mask reconstruction ---------------------- #
+
+    # (1) unpack byte to bits
+    mask = byte2bit(ckpt["mask"])
+
+    # (2) inverse Huffman
+    mask = dehuffman(ckpt["mask_tree"], mask)
+
+    # dictionary keys
+    state_keys = ["density_plane", "density_line", "app_plane", "app_line"]
+
+    dwt_level = kwargs["dwt_level"]
+
+    # (3) split mask vector, inverse RLE, and unpack bits
+    begin = 0
+    masks = OrderedDict({k: [] for k in state_keys})
+    for key in state_keys:
+        for i in range(3):
+            rle_length = ckpt["rle_length"][key][i]
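+            # planes store per-level lists here; lines store a single length/shape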
mask_shape = ckpt["mask_shape"][key][i] + if key in ["density_plane", "app_plane"]: + mask_per_lv = [] + # from low level to high level + for j in range(dwt_level+1): + # unpack bits + mask_per_lv += [np.unpackbits(rle_to_dense(mask[begin:begin+rle_length[j]]).astype(np.uint8))] + # unpack(inv_reshape(inv_transpose(A))) = B + # reshape to transposed shape, then transpose + c, h, w = mask_shape[j] + mask_per_lv[-1] = mask_per_lv[-1].reshape((h, w, c)).transpose(2, 0, 1) + mask_per_lv[-1][mask_per_lv[-1] == 0] = -1 # to make masked area zero + begin += rle_length[j] + masks[key] += [mask_per_lv] + else: + masks[key] += [np.unpackbits(rle_to_dense(mask[begin:begin+rle_length]).astype(np.uint8))] + masks[key][-1] = masks[key][-1].reshape(mask_shape) + masks[key][-1][masks[key][-1] == 0] = -1 # to make masked area zero + begin += rle_length + + # (4) concatenate levelwise masks + for i in range(3): + masks["density_plane"][i] = concat_grid(masks["density_plane"][i])[None, ...] + masks["app_plane"][i] = concat_grid(masks["app_plane"][i])[None, ...] + + # (5) convert dtype: int8 -> float32 + for key in state_keys: + for i in range(3): + masks[key][i] = nn.Parameter( + torch.from_numpy(masks[key][i].astype(np.float32)) + ) + masks[key] = nn.ParameterList(masks[key]) + + # ---------------------- grid reconstruction ---------------------- # + + # (6) dequantize feature grid + features = {k: [] for k in masks.keys()} + for key in features.keys(): + for i in range(3): + feat = ckpt["feature"][key][i] + scale = ckpt["scale"][key][i] + minvl = ckpt["minvl"][key][i] + features[key] += [nn.Parameter(torch.zeros(masks[key][i].shape))] + features[key][-1][masks[key][i] == 1] = dequantize_int(feat, scale, minvl) + features[key] = nn.ParameterList(features[key]) + + # check kwargs + kwargs.update({'device': device}) + + # IMPORTANT: aabb to cuda + kwargs["aabb"] = kwargs["aabb"].to(device) + + # load params + tensorf = eval(args.model_name)(**kwargs) + tensorf.density_plane = features["density_plane"].to(device) + tensorf.density_line = features["density_line"].to(device) + tensorf.app_plane = features["app_plane"].to(device) + tensorf.app_line = features["app_line"].to(device) + tensorf.density_plane_mask = masks["density_plane"].to(device) + tensorf.density_line_mask = masks["density_line"].to(device) + tensorf.app_plane_mask = masks["app_plane"].to(device) + tensorf.app_line_mask = masks["app_line"].to(device) + tensorf.renderModule = ckpt["render_module"].to(device) + tensorf.basis_mat = ckpt["basis_mat"].to(device) + + # load alpha mask + alpha_mask_path = os.path.join(*root_dir, 'alpha_mask.th') + if os.path.exists(alpha_mask_path): + print("loading alpha mask...") + alpha_mask = torch.load(alpha_mask_path, map_location=device) + length = np.prod(alpha_mask['alphaMask.shape']) + alpha_volume = torch.from_numpy(np.unpackbits(alpha_mask['alphaMask.mask'])[:length].reshape(alpha_mask['alphaMask.shape'])) + tensorf.alphaMask = AlphaGridMask(device, alpha_mask['alphaMask.aabb'].to(device), alpha_volume.float().to(device)) + + print("model loaded.") + + if args.decompress_and_validate: + # renderder + renderer = OctreeRender_trilinear_fast + + # init dataset + dataset = dataset_dict[args.dataset_name] + test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True) + + white_bg = test_dataset.white_bg + ndc_ray = args.ndc_ray + + logfolder = os.path.dirname(args.ckpt) + + os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True) + PSNRs_test = 
evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/', + N_vis=args.N_vis, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) + print(f'============> {args.expname} test all psnr: {np.mean(PSNRs_test)} <============') + if __name__ == '__main__': @@ -303,12 +745,16 @@ def decompress_method1(args): torch.manual_seed(20211202) np.random.seed(20211202) - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - args = config_parser() if args.compress: - compress_method1(args, device) + if args.compress_levelwise: + compress_dwt_levelwise(args) + else: + compress_dwt(args) if args.decompress: - decompress_method1(args) \ No newline at end of file + if args.decompress_levelwise: + decompress_dwt_levelwise(args) + else: + decompress_dwt(args) \ No newline at end of file From dfc0099c78912693fa77d3c50d19a53d69ef1969 Mon Sep 17 00:00:00 2001 From: Seungtae Date: Fri, 11 Nov 2022 06:36:27 +0900 Subject: [PATCH 2/5] add encoding options --- TensoRF/opt.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/TensoRF/opt.py b/TensoRF/opt.py index ebbd5ec..6458227 100644 --- a/TensoRF/opt.py +++ b/TensoRF/opt.py @@ -78,13 +78,12 @@ def config_parser(cmd=None): parser.add_argument("--use_dwt", action='store_true') parser.add_argument("--dwt_level", type=int, default=2) - # Alpha mask - parser.add_argument("--alpha_offset", type=float, default=0.0, - help='add to alphamask threshold') - # encoding option + parser.add_argument("--reconstruct_mask", type=int, default=1) parser.add_argument("--compress", type=int, default=0) + parser.add_argument("--compress_levelwise", type=int, default=0) parser.add_argument("--decompress", type=int, default=0) + parser.add_argument("--decompress_levelwise", type=int, default=0) parser.add_argument("--decompress_and_validate", type=int, default=1) # network decoder From 8b0970be7da4159d416d86fdddd7cfa588c5beff Mon Sep 17 00:00:00 2001 From: Seungtae Date: Fri, 11 Nov 2022 06:36:59 +0900 Subject: [PATCH 3/5] add new scripts --- TensoRF/script.sh | 90 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 69 insertions(+), 21 deletions(-) diff --git a/TensoRF/script.sh b/TensoRF/script.sh index be66eb8..cff3c0b 100644 --- a/TensoRF/script.sh +++ b/TensoRF/script.sh @@ -1,31 +1,79 @@ #!/bin/bash -# weak-model (lego; 1e-11) -CUDA_VISIBLE_DEVICES=3 python compress.py \ - --config=configs/chair.txt \ - --use_mask \ - --mask_weight=1e-11 \ - --grid_bit=8 \ - --use_dwt \ - --dwt_level=1 \ - --datadir=/workspace/dataset/nerf_synthetic/lego \ - --ckpt=log/lego/weak_model_lego.th \ - --compress=1 \ - --decompress=1 \ - --decompress_and_validate=1 \ - --N_vis=-1 - -# # strong-model (chair; 1e-10) -# CUDA_VISIBLE_DEVICES=3 python compress.py \ +# # weak-model (lego; 1e-11) +# CUDA_VISIBLE_DEVICES=2 python compress.py \ # --config=configs/chair.txt \ # --use_mask \ -# --mask_weight=1e-10 \ +# --mask_weight=1e-11 \ # --grid_bit=8 \ # --use_dwt \ # --dwt_level=1 \ -# --datadir=/workspace/dataset/nerf_synthetic/chair \ -# --ckpt=log/chair/strong_model_chair.th \ +# --datadir=/workspace/dataset/nerf_synthetic/lego \ +# --ckpt=log/lego/weak_model_lego.th \ # --compress=1 \ # --decompress=1 \ # --decompress_and_validate=1 \ -# --N_vis=-1 \ No newline at end of file +# --N_vis=-1 + +########################### w/o mask reconstruction ########################### + +# DWT (compress, decompress) +CUDA_VISIBLE_DEVICES=0 python compress.py \ + --config=configs/chair.txt \ + 
--datadir=/workspace/dataset/nerf_synthetic/chair \
+    --ckpt=log/chair/lv4/lv4.th \
+    --reconstruct_mask=0 \
+    --compress=1 \
+    --decompress=1 \
+    --decompress_and_validate=1 \
+    --N_vis=5
+
+# DWT (levelwise compress, decompress)
+CUDA_VISIBLE_DEVICES=0 python compress.py \
+    --config=configs/chair.txt \
+    --datadir=/workspace/dataset/nerf_synthetic/chair \
+    --ckpt=log/chair/lv4/lv4.th \
+    --reconstruct_mask=0 \
+    --compress=1 \
+    --compress_levelwise=1 \
+    --decompress=1 \
+    --decompress_levelwise=1 \
+    --decompress_and_validate=1 \
+    --N_vis=5
+
+########################### w/ mask reconstruction ###########################
+
+# DWT (compress, decompress)
+CUDA_VISIBLE_DEVICES=0 python compress.py \
+    --config=configs/chair.txt \
+    --datadir=/workspace/dataset/nerf_synthetic/chair \
+    --ckpt=log/dwt/test.th \
+    --reconstruct_mask=1 \
+    --compress=1 \
+    --decompress=1 \
+    --decompress_and_validate=1 \
+    --N_vis=5
+
+# DWT (levelwise compress, decompress)
+CUDA_VISIBLE_DEVICES=0 python compress.py \
+    --config=configs/chair.txt \
+    --datadir=/workspace/dataset/nerf_synthetic/chair \
+    --ckpt=log/dwt/test.th \
+    --reconstruct_mask=1 \
+    --compress=1 \
+    --compress_levelwise=1 \
+    --decompress=1 \
+    --decompress_levelwise=1 \
+    --decompress_and_validate=1 \
+    --N_vis=5
+
+# DCT (compress, decompress)
+CUDA_VISIBLE_DEVICES=0 python compress_dct.py \
+    --config=configs/chair.txt \
+    --datadir=/workspace/dataset/nerf_synthetic/chair \
+    --ckpt=log/dct/test.th \
+    --reconstruct_mask=1 \
+    --compress=1 \
+    --decompress=1 \
+    --decompress_and_validate=1 \
+    --N_vis=5

From 5f378d80d7ed548bca92ef95f8b3eee88a674c63 Mon Sep 17 00:00:00 2001
From: Seungtae
Date: Fri, 11 Nov 2022 06:37:19 +0900
Subject: [PATCH 4/5] initial commit

---
 TensoRF/compress_dct.py | 360 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 360 insertions(+)
 create mode 100644 TensoRF/compress_dct.py

diff --git a/TensoRF/compress_dct.py b/TensoRF/compress_dct.py
new file mode 100644
index 0000000..2c94672
--- /dev/null
+++ b/TensoRF/compress_dct.py
@@ -0,0 +1,360 @@
+import os
+import math
+from opt import config_parser
+from renderer import *
+from utils import *
+from scan import *
+from huffman import *
+from run_length_encoding.rle.np_impl import dense_to_rle, rle_to_dense
+from collections import OrderedDict
+from dataLoader import dataset_dict
+
+
+def cubify(arr, newshape):
+    oldshape = np.array(arr.shape)
+    repeats = (oldshape / newshape).astype(int)
+    tmpshape = np.column_stack([repeats, newshape]).ravel()
+    order = np.arange(len(tmpshape))
+    order = np.concatenate([order[::2], order[1::2]])
+    # newshape must divide oldshape evenly or else ValueError will be raised
+    return arr.reshape(tmpshape).transpose(order).reshape(-1, *newshape)
+
+
+def uncubify(arr, oldshape):
+    N, newshape = arr.shape[0], arr.shape[1:]
+    oldshape = np.array(oldshape)
+    repeats = (oldshape / newshape).astype(int)
+    tmpshape = np.concatenate([repeats, newshape])
+    order = np.arange(len(tmpshape)).reshape(2, -1).ravel(order='F')
+    return arr.reshape(tmpshape).transpose(order).reshape(oldshape)
+
+
+def bit2byte(enc):
+    BIT = 8
+    length = len(enc)
+    total_int = math.ceil(length/BIT)
+
+    start, out = 0, []
+    for i in range(total_int):
+        target = enc[start:start+BIT]
+        out.append(int(target, 2))
+        start += BIT
+
+    last_target_length = length - BIT * (total_int - 1)
+    out.append(last_target_length)
+    enc_byte_tensor = torch.ByteTensor(out)
+    return enc_byte_tensor
+
+
+def byte2bit(bytes):
+    bit = []
+    bytecode = bytes[:-2]
+    for byte in 
bytecode:
+        b = format(byte, '08b')
+        bit.append(b)
+
+    last_ele = format(bytes[-2], 'b') # why wasn't '08b' used here?
+    last_tar_len = bytes[-1]
+    num_to_add_zeros = last_tar_len - len(last_ele)
+    output = ''.join(bit) + '0'*num_to_add_zeros + last_ele
+    return output
+
+
+def quantize_float(inputs, bits):
+    if bits == 32:
+        return inputs
+    n = float(2**(bits-1) - 1)
+    out = np.floor(np.abs(inputs) * n) / n
+    rounded = out * np.sign(inputs)
+    return rounded
+
+def quantize_int(inputs, bits):
+    if bits == 32:
+        return inputs
+    minvl = torch.amin(inputs)
+    maxvl = torch.amax(inputs)
+    scale = (maxvl - minvl).clip(min=1e-8) / (2**bits-2)
+    rounded = torch.round((inputs - minvl)/scale) + 1
+    return rounded, scale, minvl
+
+def dequantize_int(inputs, scale, minvl):
+    return (inputs - 1) * scale + minvl
+
+
+@torch.no_grad()
+def compress_dct(args, device):
+    # check if ckpt exists
+    if not os.path.exists(args.ckpt):
+        print("the ckpt path does not exist!")
+        return
+
+    # load checkpoint
+    ckpt = torch.load(args.ckpt, map_location=device)
+
+    # update kwargs
+    kwargs = ckpt['kwargs']
+    kwargs.update({'device': device})
+
+    # NOTE: temp code
+    if "trans_func" in kwargs:
+        del kwargs['trans_func']
+
+    # make model
+    tensorf = eval(args.model_name)(**kwargs)
+    tensorf.load(ckpt)
+
+    # ship to cpu
+    tensorf.to('cpu')
+
+    # (1) mask reconstruction
+    den_plane_mask, den_line_mask = [], []
+    app_plane_mask, app_line_mask = [], []
+    for i in range(3):
+        den_plane_mask += [np.where(tensorf.density_plane[i] != 0, 1, 0)]
+        den_line_mask += [np.where(tensorf.density_line[i] != 0, 1, 0)]
+        app_plane_mask += [np.where(tensorf.app_plane[i] != 0, 1, 0)]
+        app_line_mask += [np.where(tensorf.app_line[i] != 0, 1, 0)]
+
+    # mask shape
+    mask_shape = {
+        "density_plane": [x.shape for x in den_plane_mask],
+        "density_line": [x.shape for x in den_line_mask],
+        "app_plane": [x.shape for x in app_plane_mask],
+        "app_line": [x.shape for x in app_line_mask]
+    }
+
+    # (2) get non-masked values in the feature grids
+    den_plane, den_line = [], []
+    app_plane, app_line = [], []
+    for i in range(3):
+        den_plane += [tensorf.density_plane[i][(den_plane_mask[i][None, ...] == 1)].flatten()]
+        den_line += [tensorf.density_line[i][(den_line_mask[i][None, ...] == 1)].flatten()]
+        app_plane += [tensorf.app_plane[i][(app_plane_mask[i][None, ...] == 1)].flatten()]
+        app_line += [tensorf.app_line[i][(app_line_mask[i][None, ...] 
== 1)].flatten()]
+
+    # scale & minimum value
+    scale = {k: [0]*3 for k in mask_shape.keys()}
+    minvl = {k: [0]*3 for k in mask_shape.keys()}
+
+    # (3) quantize non-masked values
+    for i in range(3):
+        den_plane[i], scale["density_plane"][i], minvl["density_plane"][i] = quantize_int(den_plane[i], tensorf.grid_bit)
+        den_line[i], scale["density_line"][i], minvl["density_line"][i] = quantize_int(den_line[i], tensorf.grid_bit)
+        app_plane[i], scale["app_plane"][i], minvl["app_plane"][i] = quantize_int(app_plane[i], tensorf.grid_bit)
+        app_line[i], scale["app_line"][i], minvl["app_line"][i] = quantize_int(app_line[i], tensorf.grid_bit)
+
+    # (4) convert dtype (float -> uint8)
+    for i in range(3):
+        den_plane[i] = den_plane[i].to(torch.uint8)
+        den_line[i] = den_line[i].to(torch.uint8)
+        app_plane[i] = app_plane[i].to(torch.uint8)
+        app_line[i] = app_line[i].to(torch.uint8)
+
+    # (5) zigzag scan (channel-last layout)
+    for i in range(3):
+        den_plane_mask[i] = zigzag_block(den_plane_mask[i].transpose(0, 2, 3, 1))
+        app_plane_mask[i] = zigzag_block(app_plane_mask[i].transpose(0, 2, 3, 1))
+
+    # (6) pack bits to byte
+    for i in range(3):
+        den_plane_mask[i] = np.packbits(den_plane_mask[i])
+        den_line_mask[i] = np.packbits(den_line_mask[i])
+        app_plane_mask[i] = np.packbits(app_plane_mask[i])
+        app_line_mask[i] = np.packbits(app_line_mask[i])
+
+    # (7) RLE masks
+    for i in range(3):
+        den_plane_mask[i] = dense_to_rle(den_plane_mask[i].flatten(), np.int8).astype(np.int8)
+        den_line_mask[i] = dense_to_rle(den_line_mask[i].flatten(), np.int8).astype(np.int8)
+        app_plane_mask[i] = dense_to_rle(app_plane_mask[i].flatten(), np.int8).astype(np.int8)
+        app_line_mask[i] = dense_to_rle(app_line_mask[i].flatten(), np.int8).astype(np.int8)
+
+    # (8) concatenate masks
+    mask = np.concatenate([*den_plane_mask, *den_line_mask, *app_plane_mask, *app_line_mask])
+    rle_length = {
+        "density_plane": [r.shape[0] for r in den_plane_mask],
+        "density_line": [r.shape[0] for r in den_line_mask],
+        "app_plane": [r.shape[0] for r in app_plane_mask],
+        "app_line": [r.shape[0] for r in app_line_mask]
+    }
+
+    # (9) Huffman masks
+    mask, mask_tree = huffman(mask)
+
+    # (10) bit -> byte, numpy -> tensor
+    mask = bit2byte(mask)
+    # mask = torch.ByteTensor(np.packbits(np.array(list(mask), np.uint8)))
+
+    # (11) save params
+    params = {
+        "feature": {
+            "density_plane": den_plane,
+            "density_line": den_line,
+            "app_plane": app_plane,
+            "app_line": app_line
+        },
+        "scale": scale,
+        "minvl": minvl,
+        "mask": mask,
+        "mask_tree": mask_tree,
+        "mask_shape": mask_shape,
+        "rle_length": rle_length,
+        "render_module": tensorf.renderModule,
+        "basis_mat": tensorf.basis_mat
+    }
+
+    # set directory
+    root_dir = args.ckpt.split('/')[:-1]
+    param_path = os.path.join(*root_dir, 'params.th')
+    torch.save(params, param_path)
+
+    param_size = os.path.getsize(param_path)/1024/1024
+    print(f"============> Grid + Mask + MLP (mb): {param_size} <============")
+
+    # (12) save kwargs
+    kwargs_path = os.path.join(*root_dir, 'kwargs.th')
+    torch.save({"kwargs": tensorf.get_kwargs()}, kwargs_path)
+
+    kwargs_size = os.path.getsize(kwargs_path)/1024/1024
+    print(f"============> kwargs (mb): {kwargs_size} <============")
+
+    if tensorf.alphaMask is not None:
+        alpha_volume = tensorf.alphaMask.alpha_volume.bool().cpu().numpy()
+        alpha_mask = {
+            'alphaMask.shape': alpha_volume.shape,
+            'alphaMask.mask': np.packbits(alpha_volume.reshape(-1)),
+            'alphaMask.aabb': tensorf.alphaMask.aabb.cpu()
+        }
+
+        alpha_mask_path = os.path.join(*root_dir, 'alpha_mask.th')
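+        # saved separately from params.th, like the kwargs file above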
+        torch.save(alpha_mask, alpha_mask_path)
+
+        mask_size = os.path.getsize(alpha_mask_path)/1024/1024
+        print(f"============> Alpha mask (mb): {mask_size} <============")
+
+    print("encoding done.")
+
+
+@torch.no_grad()
+def decompress_dct(args):
+    # check if ckpt exists
+    if not os.path.exists(args.ckpt):
+        print("the ckpt path does not exist!")
+        return
+
+    # set directory
+    root_dir = args.ckpt.split('/')[:-1]
+    param_path = os.path.join(*root_dir, 'params.th')
+
+    # load checkpoint
+    ckpt = torch.load(param_path, map_location='cpu')
+
+    # dictionary keys
+    state_keys = ["density_plane", "density_line", "app_plane", "app_line"]
+
+    # (1) byte -> bit
+    mask = byte2bit(ckpt["mask"])
+    # mask = np.unpackbits(ckpt["mask"].numpy())
+
+    # (2) inverse Huffman
+    mask = dehuffman(ckpt["mask_tree"], mask)
+
+    # (3) split an array into multiple arrays and inverse RLE
+    masks = OrderedDict({k: [] for k in state_keys})
+
+    begin = 0
+    for key in masks.keys():
+        for length in ckpt["rle_length"][key]:
+            masks[key] += [np.unpackbits(rle_to_dense(mask[begin:begin+length]).astype(np.uint8)).astype(np.int8)] # int8 so the -1 below fits
+            masks[key][-1][masks[key][-1] == 0] = -1
+            begin += length
+
+    # (4) inverse zigzag and reshape
+    for key in state_keys:
+        for i in range(3):
+            B, C, H, W = ckpt["mask_shape"][key][i]
+            if key in ["density_plane", "app_plane"]:
+                mask = inverse_zigzag_block(masks[key][i].reshape(B, H*W, C), B, H, W, C).transpose(0, 3, 1, 2)
+            else:
+                mask = masks[key][i].reshape((B, C, H, W))
+            masks[key][i] = nn.Parameter(torch.from_numpy(mask).to(torch.float32))
+        masks[key] = nn.ParameterList(masks[key])
+
+    # (5) dequantize feature grid
+    features = {k: [] for k in state_keys}
+    for key in features.keys():
+        for i in range(3):
+            feat = ckpt["feature"][key][i]
+            scale = ckpt["scale"][key][i]
+            minvl = ckpt["minvl"][key][i]
+            features[key] += [nn.Parameter(torch.zeros(ckpt["mask_shape"][key][i]))]
+            features[key][-1][masks[key][i] == 1] = dequantize_int(feat, scale, minvl)
+        features[key] = nn.ParameterList(features[key])
+
+    # load kwargs
+    kwargs_path = os.path.join(*root_dir, 'kwargs.th')
+    kwargs = torch.load(kwargs_path, map_location='cpu')["kwargs"]
+
+    # check kwargs
+    kwargs.update({'device': device})
+
+    # IMPORTANT: aabb to cuda
+    kwargs["aabb"] = kwargs["aabb"].to(device)
+
+    # load params
+    tensorf = eval(args.model_name)(**kwargs)
+    tensorf.density_plane = features["density_plane"].to(device)
+    tensorf.density_line = features["density_line"].to(device)
+    tensorf.app_plane = features["app_plane"].to(device)
+    tensorf.app_line = features["app_line"].to(device)
+    tensorf.density_plane_mask = masks["density_plane"].to(device)
+    tensorf.density_line_mask = masks["density_line"].to(device)
+    tensorf.app_plane_mask = masks["app_plane"].to(device)
+    tensorf.app_line_mask = masks["app_line"].to(device)
+    tensorf.renderModule = ckpt["render_module"].to(device)
+    tensorf.basis_mat = ckpt["basis_mat"].to(device)
+
+    # load alpha mask
+    alpha_mask_path = os.path.join(*root_dir, 'alpha_mask.th')
+    if os.path.exists(alpha_mask_path):
+        print("loading alpha mask...")
+        alpha_mask = torch.load(alpha_mask_path, map_location=device)
+        length = np.prod(alpha_mask['alphaMask.shape'])
+        alpha_volume = torch.from_numpy(np.unpackbits(alpha_mask['alphaMask.mask'])[:length].reshape(alpha_mask['alphaMask.shape']))
+        tensorf.alphaMask = AlphaGridMask(device, alpha_mask['alphaMask.aabb'].to(device), alpha_volume.float().to(device))
+
+    print("model loaded.")
+
+    if args.decompress_and_validate:
+        # renderer
+        renderer = 
OctreeRender_trilinear_fast + + # init dataset + dataset = dataset_dict[args.dataset_name] + test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True) + + white_bg = test_dataset.white_bg + ndc_ray = args.ndc_ray + + logfolder = os.path.dirname(args.ckpt) + + os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True) + PSNRs_test = evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/', + N_vis=args.N_vis, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device) + print(f'============> {args.expname} test all psnr: {np.mean(PSNRs_test)} <============') + + + +if __name__ == '__main__': + torch.set_default_dtype(torch.float32) + torch.manual_seed(20211202) + np.random.seed(20211202) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + args = config_parser() + + if args.compress: + compress_dct(args, device) + + if args.decompress: + decompress_dct(args) \ No newline at end of file From 186ff9aab2c5f79716b3672759815765ff9383a8 Mon Sep 17 00:00:00 2001 From: Seungtae Date: Fri, 11 Nov 2022 06:56:50 +0900 Subject: [PATCH 5/5] deleted --- TensoRF/log/chair/strong_model_chair.th | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 TensoRF/log/chair/strong_model_chair.th diff --git a/TensoRF/log/chair/strong_model_chair.th b/TensoRF/log/chair/strong_model_chair.th deleted file mode 100644 index e73dbdb..0000000 --- a/TensoRF/log/chair/strong_model_chair.th +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:4510365b010ea0c5b33a94d5ad0f98305f0c5db7ae7eb626dbb7c0899d7bcd6b -size 137736322