Skip to content

Commit

Permalink
add visualization for mask (train.py)
Browse files Browse the repository at this point in the history
  • Loading branch information
daniel03c1 committed Jul 31, 2022
1 parent 73f0a31 commit 4765688
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion PREF/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ def reconstruction(args, return_bbox=False, return_memory=False,
loss_tv.detach().item(),
global_step=iteration)

# mask
# total_loss += 1e-5 * sum([(m * (m>=0)).abs().mean() for m in phasorf.den_mask])
# total_loss += 1e-5 * sum([(m * (m>=0)).abs().mean() for m in phasorf.app_mask])

if TV_weight_app > 0:
TV_weight_app *= lr_factor
raise NotImplementedError('not implemented')
Expand Down Expand Up @@ -318,7 +322,19 @@ def reconstruction(args, return_bbox=False, return_memory=False,

phasorf.save(f'{logfolder}/{args.expname}.th')

breakpoint()
# test
numel = sum([p.numel() for p in phasorf.parameters()])
if hasattr(phasorf, 'den_mask'):
numel -= sum([m.numel() for m in phasorf.den_mask])
numel -= sum([m.numel() for m in phasorf.app_mask])

print(f'Total size: {numel*4/1_048_576:.4f}MB')
if hasattr(phasorf, 'den_mask'):
reduced = sum([d.numel() * (m < 0).float().mean()
for d, m in zip(phasorf.den, phasorf.den_mask)]) \
+ sum([d.numel() * (m < 0).float().mean()
for d, m in zip(phasorf.app, phasorf.app_mask)])
print(f'reduced size: {(numel - reduced)*4/1_048_576:.4f}MB')

if args.render_train:
os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)
Expand Down

0 comments on commit 4765688

Please sign in to comment.