From 47656888a570c6189b56437db118e6fb995b8aeb Mon Sep 17 00:00:00 2001 From: daniel03c1 Date: Sun, 31 Jul 2022 13:53:26 +0000 Subject: [PATCH] add visualization for mask (train.py) --- PREF/train.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/PREF/train.py b/PREF/train.py index 331e9bf..41fbc55 100644 --- a/PREF/train.py +++ b/PREF/train.py @@ -231,6 +231,10 @@ def reconstruction(args, return_bbox=False, return_memory=False, loss_tv.detach().item(), global_step=iteration) + # mask + # total_loss += 1e-5 * sum([(m * (m>=0)).abs().mean() for m in phasorf.den_mask]) + # total_loss += 1e-5 * sum([(m * (m>=0)).abs().mean() for m in phasorf.app_mask]) + if TV_weight_app > 0: TV_weight_app *= lr_factor raise NotImplementedError('not implemented') @@ -318,7 +322,19 @@ def reconstruction(args, return_bbox=False, return_memory=False, phasorf.save(f'{logfolder}/{args.expname}.th') - breakpoint() + # test + numel = sum([p.numel() for p in phasorf.parameters()]) + if hasattr(phasorf, 'den_mask'): + numel -= sum([m.numel() for m in phasorf.den_mask]) + numel -= sum([m.numel() for m in phasorf.app_mask]) + + print(f'Total size: {numel*4/1_048_576:.4f}MB') + if hasattr(phasorf, 'den_mask'): + reduced = sum([d.numel() * (m < 0).float().mean() + for d, m in zip(phasorf.den, phasorf.den_mask)]) \ + + sum([d.numel() * (m < 0).float().mean() + for d, m in zip(phasorf.app, phasorf.app_mask)]) + print(f'reduced size: {(numel - reduced)*4/1_048_576:.4f}MB') if args.render_train: os.makedirs(f'{logfolder}/imgs_train_all', exist_ok=True)