Skip to content

Commit

Permalink
mask approach fit method rework
Browse files Browse the repository at this point in the history
  • Loading branch information
timetoai committed Sep 2, 2023
1 parent 6fa07b1 commit cddd113
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/timediffusion/frameworks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def device(self):
return next(self.model.parameters()).device

def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.ndarray, torch.Tensor] = None,
epochs: int = 20, batch_size: int = 2, steps_per_epoch: int = 32,
mask_fill: str = "noise", epochs: int = 20, batch_size: int = 2, steps_per_epoch: int = 32,
lr: float = 4e-4, distance_loss: Union[str, nn.Module] = "MAE",
distribution_loss: Union[str, nn.Module] = "kl_div", distrib_loss_coef = 1e-2,
verbose: bool = False, seed=42) -> list[float]:
Expand All @@ -63,6 +63,8 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda
`mask` - None for full model fitting on `example`
or same shape as `example` for not fitting in points, that masked with 1
`mask_fill` - "noise" / "original"
`epochs` - number of training epochs
`batch_size` - number of random noises to train on
Expand Down Expand Up @@ -110,6 +112,8 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda
raise NotImplementedError(f"Distribution loss should be 'kl_div' or nn.Module got {type(distribution_loss)}")

# mask check
if mask is not None and mask_fill not in ("noise", "original"):
raise ValueError(f"mask_fill should be 'noise' or 'original', got {mask_fill}")
if mask is not None and mask.shape != example.shape:
raise ValueError(f"Mask should None or the same shape as example, got {example.shape = } and {mask.shape = }")

Expand All @@ -120,7 +124,7 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda
X = train_tensor.repeat(batch_size, *[1] * (len(train_tensor.shape) - 1))

if mask is not None:
mask_tensor = ~ torch.tensor(mask, dtype=torch.bool, device=self.device())
mask_tensor = ~ torch.tensor(mask, dtype=torch.bool, device=self.device()).unsqueeze(0)

optim = torch.optim.Adam(self.parameters(), lr=lr)
losses = []
Expand All @@ -129,6 +133,9 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda
for epoch in (tqdm(range(1, epochs + 1)) if verbose else range(1, epochs + 1)):
self.model.train()

if mask is not None and mask_fill == "noise":
X[~ mask_tensor] = torch.rand(batch_size * (~ mask_tensor).sum())

noise = torch.rand(*X.shape, device=self.device(), dtype=self.dtype())
# noise_level = torch.rand(X.shape).to(device=self.device(), dtype=self.dtype())
# noise *= noise_level
Expand All @@ -148,6 +155,7 @@ def fit(self, example: Union[np.ndarray, torch.Tensor], mask: Union[None, np.nda
noise -= y_hat
losses.append(loss.item())

# saving some training parameters, could be useful in inference
self.training_steps_per_epoch = steps_per_epoch
self.training_example = example
self.distance_loss = distance_loss
Expand Down Expand Up @@ -199,6 +207,10 @@ def restore(self, example: Union[None, np.ndarray, torch.Tensor] = None, shape:

torch.random.manual_seed(seed)
X = torch.rand(*shape).to(device=self.device(), dtype=self.dtype())

# no real meaning behind masking random noise
# maybe for fun, but setting it here as None for stability
mask = None
else:
if len(self.input_dims) != len(example.shape):
raise ValueError(f"Model fitted with {len(self.input_dims)} dims, but got {len(example.shape)}")
Expand Down

0 comments on commit cddd113

Please sign in to comment.