Skip to content

Commit

Permalink
Merge pull request #2 from daniel03c1/snam
Browse files Browse the repository at this point in the history
Snam
  • Loading branch information
stnamjef authored Nov 7, 2022
2 parents ee1e0e4 + 3690c9f commit 534d925
Show file tree
Hide file tree
Showing 7 changed files with 652 additions and 0 deletions.
314 changes: 314 additions & 0 deletions TensoRF/compress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,314 @@
import os
import math
from opt import config_parser
from renderer import *
from utils import *
from huffman import *
from run_length_encoding.rle.np_impl import dense_to_rle, rle_to_dense
from collections import OrderedDict
from dataLoader import dataset_dict


def bit2byte(enc):
BIT = 8
length = len(enc)
total_int = math.ceil(length/BIT)

start, out = 0, []
for i in range(total_int):
target = enc[start:start+BIT]
out.append(int(target, 2))
start += BIT

last_target_length = length - BIT * (total_int - 1)
out.append(last_target_length)
enc_byte_tensor = torch.ByteTensor(out)
return enc_byte_tensor


def byte2bit(bytes):
bit = []
bytecode = bytes[:-2]
for byte in bytecode:
b = format(byte, '08b')
bit.append(b)

last_ele = format(bytes[-2], 'b') # 이걸 왜 08로 안했지?
last_tar_len = bytes[-1]
num_to_add_zeros = last_tar_len - len(last_ele)
output =''.join(bit) + '0'*num_to_add_zeros + last_ele
return output


def quantize_float(inputs, bits):
if bits == 32:
return inputs
n = float(2**(bits-1) - 1)
out = np.floor(np.abs(inputs) * n) / n
rounded = out * np.sign(inputs)
return rounded

def quantize_int(inputs, bits):
if bits == 32:
return inputs
minvl = torch.amin(inputs)
maxvl = torch.amax(inputs)
scale = (maxvl - minvl).clip(min=1e-8) / (2**bits-2)
rounded = torch.round((inputs - minvl)/scale) + 1
return rounded, scale, minvl

def dequantize_int(inputs, scale, minvl):
return (inputs - 1) * scale + minvl


@torch.no_grad()
def compress_method1(args, device): # save grid + mask
# check if ckpt exists
if not os.path.exists(args.ckpt):
print("the ckpt path does not exists!")
return

# load checkpoint
ckpt = torch.load(args.ckpt, map_location=device)

# update kwargs
kwargs = ckpt['kwargs']

kwargs.update({'device': device})
kwargs.update({'use_mask': args.use_mask})
kwargs.update({'use_dwt': args.use_dwt})
kwargs.update({'dwt_level': args.dwt_level})

# make model
tensorf = eval(args.model_name)(**kwargs)
tensorf.load(ckpt)

# ship to cpu
tensorf.to('cpu')

# mask shape
mask_shape = {
"density_plane": [x.shape for x in tensorf.density_plane_mask],
"density_line": [x.shape for x in tensorf.density_line_mask],
"app_plane": [x.shape for x in tensorf.app_plane_mask],
"app_line": [x.shape for x in tensorf.app_line_mask]
}

# (1) binarize mask
den_plane_mask, den_line_mask = [], []
app_plane_mask, app_line_mask = [], []
for i in range(3):
den_plane_mask += [torch.where(tensorf.density_plane_mask[i]>=0, 1, 0)]
den_line_mask += [torch.where(tensorf.density_line_mask[i]>=0, 1, 0)]
app_plane_mask += [torch.where(tensorf.app_plane_mask[i]>=0, 1, 0)]
app_line_mask += [torch.where(tensorf.app_line_mask[i]>=0, 1, 0)]

# (2) get non-masked values in the feature grids

den_plane, den_line = [], []
app_plane, app_line = [], []
for i in range(3):
den_plane += [tensorf.density_plane[i][(den_plane_mask[i] == 1)].flatten()]
den_line += [tensorf.density_line[i][(den_line_mask[i] == 1)].flatten()]
app_plane += [tensorf.app_plane[i][(app_plane_mask[i] == 1)].flatten()]
app_line += [tensorf.app_line[i][(app_line_mask[i] == 1)].flatten()]

# scale & minimum value
scale = {k: [0]*3 for k in mask_shape.keys()}
minvl = {k: [0]*3 for k in mask_shape.keys()}

# (3) quantize non-masked values
for i in range(3):
den_plane[i], scale["density_plane"][i], minvl["density_plane"][i] = quantize_int(den_plane[i], tensorf.grid_bit)
den_line[i], scale["density_line"][i], minvl["density_line"][i] = quantize_int(den_line[i], tensorf.grid_bit)
app_plane[i], scale["app_plane"][i], minvl["app_plane"][i] = quantize_int(app_plane[i], tensorf.grid_bit)
app_line[i], scale["app_line"][i], minvl["app_line"][i] = quantize_int(app_line[i], tensorf.grid_bit)

# (4) convert dtype (float -> uint8)
for i in range(3):
den_plane[i] = den_plane[i].to(torch.uint8)
den_line[i] = den_line[i].to(torch.uint8)
app_plane[i] = app_plane[i].to(torch.uint8)
app_line[i] = app_line[i].to(torch.uint8)

# (5) RLE masks
for i in range(3):
den_plane_mask[i] = dense_to_rle(den_plane_mask[i].flatten(), np.int8)
den_line_mask[i] = dense_to_rle(den_line_mask[i].flatten(), np.int8)
app_plane_mask[i] = dense_to_rle(app_plane_mask[i].flatten(), np.int8)
app_line_mask[i] = dense_to_rle(app_line_mask[i].flatten(), np.int8)

# (6) concatenate masks
mask = np.concatenate([*den_plane_mask, *den_line_mask, *app_plane_mask, *app_line_mask])
rle_length = {
"density_plane": [r.shape[0] for r in den_plane_mask],
"density_line": [r.shape[0] for r in den_line_mask],
"app_plane": [r.shape[0] for r in app_plane_mask],
"app_line": [r.shape[0] for r in app_line_mask]
}

# (7) huffman coding masks
mask, mask_tree = huffman(mask)

# (8) bit -> byte, numpy -> tensor
mask = bit2byte(mask)

# (9) save model
params = {
"kwargs": tensorf.get_kwargs(),
"feature": {
"density_plane": den_plane,
"density_line": den_line,
"app_plane": app_plane,
"app_line": app_line
},
"scale": scale,
"minvl": minvl,
"mask": mask,
"mask_tree": mask_tree,
"mask_shape": mask_shape,
"rle_length": rle_length,
"render_module": tensorf.renderModule,
"basis_mat": tensorf.basis_mat
}

# set directory
root_dir = args.ckpt.split('/')[:-1]
param_path = os.path.join(*root_dir, 'model_compressed.th')
torch.save(params, param_path)

param_size = os.path.getsize(param_path)/1024
print(f"============> Grid + Mask + MLP + etc (kb): {param_size} <============")

if tensorf.alphaMask is not None:
alpha_volume = tensorf.alphaMask.alpha_volume.bool().cpu().numpy()
alpha_mask = {
'alphaMask.shape': alpha_volume.shape,
'alphaMask.mask': np.packbits(alpha_volume.reshape(-1)),
'alphaMask.aabb': tensorf.alphaMask.aabb.cpu()
}

alpha_mask_path = os.path.join(*root_dir, 'alpha_mask.th')
torch.save(alpha_mask, alpha_mask_path)

mask_size = os.path.getsize(alpha_mask_path)/1024
print(f"============> Alpha mask (kb): {mask_size} <============")

print("encoding done.")


@torch.no_grad()
def decompress_method1(args):
# check if ckpt exists
if not os.path.exists(args.ckpt):
print("the ckpt path does not exists!")
return

# set directory
root_dir = args.ckpt.split('/')[:-1]
param_path = os.path.join(*root_dir, 'model_compressed.th')

# load checkpoint
ckpt = torch.load(param_path, map_location='cpu')

# (1) byte -> bit
mask = byte2bit(ckpt["mask"])

# (2) inverse huffman
mask = dehuffman(ckpt["mask_tree"], mask)

# (3) split an array into multiple arrays and inverse RLE
masks = OrderedDict({k: [] for k in ckpt["mask_shape"].keys()})

begin = 0
for key in masks.keys():
for length in ckpt["rle_length"][key]:
masks[key] += [torch.from_numpy(rle_to_dense(mask[begin:begin+length]))]
masks[key][-1][masks[key][-1] == 0] = -1
begin += length

# (4) reshape mask
for key in masks.keys():
for i in range(3):
shape = ckpt["mask_shape"][key][i]
masks[key][i] = nn.Parameter(masks[key][i].to(torch.float32).reshape(*shape))
masks[key] = nn.ParameterList(masks[key])

# (5) dequantize feature grid
features = {k: [] for k in masks.keys()}
for key in features.keys():
for i in range(3):
feat = ckpt["feature"][key][i]
scale = ckpt["scale"][key][i]
minvl = ckpt["minvl"][key][i]
features[key] += [nn.Parameter(torch.zeros(ckpt["mask_shape"][key][i]))]
features[key][-1][masks[key][i] == 1] = dequantize_int(feat, scale, minvl)
features[key] = nn.ParameterList(features[key])


# check kwargs
kwargs = ckpt['kwargs']
kwargs.update({'device': device})

# IMPORTANT: aabb to cuda
kwargs["aabb"] = kwargs["aabb"].to(device)

# load params
tensorf = eval(args.model_name)(**kwargs)
tensorf.density_plane = features["density_plane"].to(device)
tensorf.density_line = features["density_line"].to(device)
tensorf.app_plane = features["app_plane"].to(device)
tensorf.app_line = features["app_line"].to(device)
tensorf.density_plane_mask = masks["density_plane"].to(device)
tensorf.density_line_mask = masks["density_line"].to(device)
tensorf.app_plane_mask = masks["app_plane"].to(device)
tensorf.app_line_mask = masks["app_line"].to(device)
tensorf.renderModule = ckpt["render_module"].to(device)
tensorf.basis_mat = ckpt["basis_mat"].to(device)

# load alpha mask
alpha_mask_path = os.path.join(*root_dir, 'alpha_mask.th')
if os.path.exists(alpha_mask_path):
print("loading alpha mask...")
alpha_mask = torch.load(alpha_mask_path, map_location=device)
length = np.prod(alpha_mask['alphaMask.shape'])
alpha_volume = torch.from_numpy(np.unpackbits(alpha_mask['alphaMask.mask'])[:length].reshape(alpha_mask['alphaMask.shape']))
tensorf.alphaMask = AlphaGridMask(device, alpha_mask['alphaMask.aabb'].to(device), alpha_volume.float().to(device))

print("model loaded.")

if args.decompress_and_validate:
# renderder
renderer = OctreeRender_trilinear_fast

# init dataset
dataset = dataset_dict[args.dataset_name]
test_dataset = dataset(args.datadir, split='test', downsample=args.downsample_train, is_stack=True)

white_bg = test_dataset.white_bg
ndc_ray = args.ndc_ray

logfolder = os.path.dirname(args.ckpt)

os.makedirs(f'{logfolder}/{args.expname}/imgs_test_all', exist_ok=True)
PSNRs_test = evaluation(test_dataset, tensorf, args, renderer, f'{logfolder}/{args.expname}/imgs_test_all/',
N_vis=args.N_vis, N_samples=-1, white_bg = white_bg, ndc_ray=ndc_ray,device=device)
print(f'============> {args.expname} test all psnr: {np.mean(PSNRs_test)} <============')




if __name__ == '__main__':
torch.set_default_dtype(torch.float32)
torch.manual_seed(20211202)
np.random.seed(20211202)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = config_parser()

if args.compress:
compress_method1(args, device)

if args.decompress:
decompress_method1(args)
73 changes: 73 additions & 0 deletions TensoRF/huffman.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from collections import Counter
import numpy as np


class NodeTree(object):
def __init__(self, left=None, right=None):
self.left = left
self.right = right

def children(self):
return self.left, self.right

def __str__(self):
return self.left, self.right


def huffman_code_tree(node, binString=''):
'''
Function to find Huffman Code
'''
if type(node) in (np.float32, np.float64, np.int8, np.int64, str) : # str for debugging
return {node: binString}
(l, r) = node.children()
# import pdb; pdb.set_trace()
d = dict()
d.update(huffman_code_tree(l, binString + '0'))
d.update(huffman_code_tree(r, binString + '1'))
return d


def make_tree(nodes):
'''
Function to make tree
:param nodes: Nodes
:return: Root of the tree
'''
while len(nodes) > 1:
(key1, c1) = nodes[-1]
(key2, c2) = nodes[-2]
nodes = nodes[:-2]
node = NodeTree(key1, key2)
nodes.append((node, c1 + c2))
nodes = sorted(nodes, key=lambda x: x[1], reverse=True)
return nodes[0][0]


def huffman(inputs):
count_dict = dict(Counter(inputs))
count_dict = sorted(count_dict.items(), key=lambda x: x[1], reverse=True)
count_tree = make_tree(count_dict)
huff_table = huffman_code_tree(count_tree)
encoded = list(map(huff_table.get, inputs))
encoded = ''.join(map(str, encoded))
return encoded, count_tree


def dehuffman(root, enc):
ret = []
curr = root
_len = len(enc)
for i in range(_len):
if enc[i] == '0':
curr = curr.left
elif enc[i] == '1':
curr = curr.right
else:
print(enc[i])
raise NotImplementedError

if (type(curr) in (np.int8, np.int64)):
ret.append(curr)
curr = root
return np.array(ret)
Loading

0 comments on commit 534d925

Please sign in to comment.