diff --git a/guided_diffusion/custom_dataset_loader.py b/guided_diffusion/custom_dataset_loader.py
index f497aa73..612a99df 100644
--- a/guided_diffusion/custom_dataset_loader.py
+++ b/guided_diffusion/custom_dataset_loader.py
@@ -23,8 +23,8 @@ def __init__(self, args, data_path , transform = None, mode = 'Training',plane =
         images = sorted(glob(os.path.join(path, "images/*.png")))
         masks = sorted(glob(os.path.join(path, "masks/*.png")))
 
-        self.name_list = images[:2]
-        self.label_list = masks[:2]
+        self.name_list = images
+        self.label_list = masks
         self.data_path = path
         self.mode = mode
 
@@ -44,10 +44,10 @@ def __getitem__(self, index):
         img = Image.open(img_path).convert('RGB')
         mask = Image.open(msk_path).convert('L')
 
-        if self.mode == 'Training':
-            label = 0 if self.label_list[index] == 'benign' else 1
-        else:
-            label = int(self.label_list[index])
+        # if self.mode == 'Training':
+        #     label = 0 if self.label_list[index] == 'benign' else 1
+        # else:
+        #     label = int(self.label_list[index])
 
         if self.transform:
             state = torch.get_rng_state()
@@ -55,7 +55,8 @@
             torch.set_rng_state(state)
             mask = self.transform(mask)
 
-        if self.mode == 'Training':
-            return (img, mask, name)
-        else:
-            return (img, mask, name)
\ No newline at end of file
+        return (img, mask, name)
+        # if self.mode == 'Training':
+        #     return (img, mask, name)
+        # else:
+        #     return (img, mask, name)
diff --git a/guided_diffusion/unet.py b/guided_diffusion/unet.py
index fe7ecdb2..ddcc39fb 100644
--- a/guided_diffusion/unet.py
+++ b/guided_diffusion/unet.py
@@ -735,6 +735,17 @@ def convert_to_fp32(self):
         self.middle_block.apply(convert_module_to_f32)
         self.output_blocks.apply(convert_module_to_f32)
 
+    def load_part_state_dict(self, state_dict):
+
+        own_state = self.state_dict()
+        for name, param in state_dict.items():
+            if name not in own_state:
+                continue
+            if isinstance(param, th.nn.Parameter):
+                # backwards compatibility for serialized parameters
+                param = param.data
+            own_state[name].copy_(param)
+
     def enhance(self, c, h):
         cu = layer_norm(c.size()[1:])(c)
         hu = layer_norm(h.size()[1:])(h)
diff --git a/scripts/segmentation_sample.py b/scripts/segmentation_sample.py
index b96f7c1d..94299829 100644
--- a/scripts/segmentation_sample.py
+++ b/scripts/segmentation_sample.py
@@ -17,6 +17,7 @@
 from guided_diffusion import dist_util, logger
 from guided_diffusion.bratsloader import BRATSDataset, BRATSDataset3D
 from guided_diffusion.isicloader import ISICDataset
+from guided_diffusion.custom_dataset_loader import CustomDataset
 import torchvision.utils as vutils
 from guided_diffusion.utils import staple
 from guided_diffusion.script_util import (
@@ -58,6 +59,13 @@ def main():
 
         ds = BRATSDataset3D(args.data_dir,transform_test)
         args.in_ch = 5
+    else:
+        tran_list = [transforms.Resize((args.image_size,args.image_size)), transforms.ToTensor()]
+        transform_test = transforms.Compose(tran_list)
+
+        ds = CustomDataset(args, args.data_dir, transform_test, mode = 'Test')
+        args.in_ch = 4
+
     datal = th.utils.data.DataLoader(
         ds,
         batch_size=args.batch_size,
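
Usage note (not part of the diff): a minimal sketch of how the new load_part_state_dict helper could be called when restoring a pretrained checkpoint into the modified UNet. The model object and args.model_path are assumptions based on how scripts/segmentation_sample.py loads checkpoints elsewhere.

    # Hypothetical sketch, not taken from the repository.
    # dist_util.load_state_dict is assumed to return a plain state dict
    # (it wraps torch.load in guided_diffusion/dist_util.py).
    state_dict = dist_util.load_state_dict(args.model_path, map_location="cpu")
    # Copies only the entries whose names still exist in the current model,
    # skipping any keys that were removed or renamed.
    model.load_part_state_dict(state_dict)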