From cddd11307956cb024b9c6372f05d01b0c5b50830 Mon Sep 17 00:00:00 2001 From: Vladislav Date: Sat, 2 Sep 2023 12:47:10 +0800 Subject: [PATCH] mask approach fit method rework --- src/timediffusion/frameworks.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/src/timediffusion/frameworks.py b/src/timediffusion/frameworks.py index 6c75e39..422fb14 100644 --- a/src/timediffusion/frameworks.py +++ b/src/timediffusion/frameworks.py @@ -49,7 +49,7 @@ def device(self): return next(self.model.parameters()).device def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.ndarray, torch.Tensor] = None, - epochs: int = 20, batch_size: int = 2, steps_per_epoch: int = 32, + mask_fill: str = "noise", epochs: int = 20, batch_size: int = 2, steps_per_epoch: int = 32, lr: float = 4e-4, distance_loss: Union[str, nn.Module] = "MAE", distribution_loss: Union[str, nn.Module] = "kl_div", distrib_loss_coef = 1e-2, verbose: bool = False, seed=42) -> list[float]: @@ -63,6 +63,8 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda `mask` - None for full model fitting on `example` or same shape as `example` for not fitting in points, that masked with 1 + `mask_fill` - "noise" / "original" + `epochs` - number of training epochs `batch_size` - number of random noises to train on @@ -110,6 +112,8 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda raise NotImplementedError(f"Distribution loss should be 'kl_div' or nn.Module got {type(distribution_loss)}") # mask check + if mask is not None and mask_fill not in ("noise", "original"): + raise ValueError(f"mask_fill should be 'noise' or 'original', got {mask_fill}") if mask is not None and mask.shape != example.shape: raise ValueError(f"Mask should None or the same shape as example, got {example.shape = } and {mask.shape = }") @@ -120,7 +124,7 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda X = train_tensor.repeat(batch_size, *[1] * (len(train_tensor.shape) - 1)) if mask is not None: - mask_tensor = ~ torch.tensor(mask, dtype=torch.bool, device=self.device()) + mask_tensor = ~ torch.tensor(mask, dtype=torch.bool, device=self.device()).unsqueeze(0) optim = torch.optim.Adam(self.parameters(), lr=lr) losses = [] @@ -129,6 +133,9 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda for epoch in (tqdm(range(1, epochs + 1)) if verbose else range(1, epochs + 1)): self.model.train() + if mask is not None and mask_fill == "noise": + X[~ mask_tensor] = torch.rand(batch_size * (~ mask_tensor).sum()) + noise = torch.rand(*X.shape, device=self.device(), dtype=self.dtype()) # noise_level = torch.rand(X.shape).to(device=self.device(), dtype=self.dtype()) # noise *= noise_level @@ -148,6 +155,7 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda noise -= y_hat losses.append(loss.item()) + # saving some training parameters, could be useful in inference self.training_steps_per_epoch = steps_per_epoch self.training_example = example self.distance_loss = distance_loss @@ -199,6 +207,10 @@ def restore(self, example: Union[None, np.ndarray, torch.Tensor] = None, shape: torch.random.manual_seed(seed) X = torch.rand(*shape).to(device=self.device(), dtype=self.dtype()) + + # no real meaning behind masking random noise + # maybe for fun, but setting it here as None for stability + mask = None else: if len(self.input_dims) != len(example.shape): raise ValueError(f"Model fitted with {len(self.input_dims)} dims, but got {len(example.shape)}")