diff --git a/imgx/task/diffusion_segmentation/experiment.py b/imgx/task/diffusion_segmentation/experiment.py index c6e79b9..2198924 100644 --- a/imgx/task/diffusion_segmentation/experiment.py +++ b/imgx/task/diffusion_segmentation/experiment.py @@ -387,10 +387,6 @@ def train_init( aug_rng = jax.random.PRNGKey(self.config["seed"]) batch = aug_fn(aug_rng, batch) - # check image size - image_shape = self.dataset_info.image_spatial_shape - chex.assert_equal(batch[IMAGE].shape[1:-1], image_shape) - # init train state on cpu first dtype = get_half_precision_dtype(self.config.half_precision) model = instantiate(self.config.task.model, dtype=dtype) diff --git a/imgx/task/segmentation/experiment.py b/imgx/task/segmentation/experiment.py index 5c66ba2..10c45a5 100644 --- a/imgx/task/segmentation/experiment.py +++ b/imgx/task/segmentation/experiment.py @@ -261,10 +261,6 @@ def train_init( aug_rng = jax.random.PRNGKey(self.config["seed"]) batch = aug_fn(aug_rng, batch) - # check image size - image_shape = self.dataset_info.image_spatial_shape - chex.assert_equal(batch[IMAGE].shape[1:-1], image_shape) - # init train state on cpu first dtype = get_half_precision_dtype(self.config.half_precision) model = instantiate(self.config.task.model, dtype=dtype)