-
Notifications
You must be signed in to change notification settings - Fork 23
/
utils.py
117 lines (78 loc) · 3.16 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared utilities for the networks module."""
from typing import Any
import torch
from einops import pack, rearrange, unpack
def time2batch(x: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Fold the time axis of a ``(b, c, t, h, w)`` tensor into the batch axis.

    Args:
        x: Tensor with five axes, interpreted as ``b c t h w``.

    Returns:
        A pair ``(folded, batch_size)``: ``folded`` has layout
        ``((b t), c, h, w)`` and ``batch_size`` is the original size of the
        leading axis, which ``batch2time`` needs to undo the fold.
    """
    folded = rearrange(x, "b c t h w -> (b t) c h w")
    return folded, x.shape[0]
def batch2time(x: torch.Tensor, batch_size: int) -> torch.Tensor:
    """Split a combined ``(b t)`` leading axis back into batch and time.

    Args:
        x: Tensor with four axes, interpreted as ``(b t) c h w``.
        batch_size: Size of the ``b`` factor of the leading axis.

    Returns:
        The tensor rearranged to ``b c t h w``.
    """
    restored = rearrange(x, "(b t) c h w -> b c t h w", b=batch_size)
    return restored
def space2batch(x: torch.Tensor) -> tuple[torch.Tensor, int, int]:
    """Fold both spatial axes of a ``(b, c, t, h, w)`` tensor into the batch axis.

    Args:
        x: Tensor with five axes, interpreted as ``b c t h w``.

    Returns:
        A triple ``(folded, batch_size, height)``: ``folded`` has layout
        ``((b h w), c, t)``; ``batch_size`` and ``height`` are the original
        sizes of the ``b`` and ``h`` axes, which ``batch2space`` needs to undo
        the fold.
    """
    # The return annotation is a 3-tuple: the height is returned alongside the
    # batch size so the (b h w) axis can be split unambiguously later.
    batch_size, height = x.shape[0], x.shape[-2]
    return rearrange(x, "b c t h w -> (b h w) c t"), batch_size, height
def batch2space(x: torch.Tensor, batch_size: int, height: int) -> torch.Tensor:
    """Split a combined ``(b h w)`` leading axis back into batch and space.

    Args:
        x: Tensor with three axes, interpreted as ``(b h w) c t``.
        batch_size: Size of the ``b`` factor of the leading axis.
        height: Size of the ``h`` factor of the leading axis.

    Returns:
        The tensor rearranged to ``b c t h w``; the width is inferred from the
        remaining factor of the leading axis.
    """
    restored = rearrange(
        x,
        "(b h w) c t -> b c t h w",
        b=batch_size,
        h=height,
    )
    return restored
def cast_tuple(t: Any, length: int = 1) -> Any:
    """Return ``t`` unchanged if it is a tuple, else repeat it into a tuple.

    Args:
        t: Any value. Tuples are passed through as-is, regardless of length.
        length: Number of repetitions used when ``t`` is not a tuple.

    Returns:
        ``t`` itself when it is a tuple, otherwise ``(t,) * length``.
    """
    if isinstance(t, tuple):
        return t
    return (t,) * length
def replication_pad(x):
    """Prepend a copy of the first slice along dim 2 to ``x``.

    The result is one element longer than ``x`` along dim 2; index 0 and
    index 1 of that axis both hold the original first slice.
    NOTE(review): dim 2 is presumably the time axis of a ``b c t h w``
    tensor -- confirm against callers.
    """
    leading_slice = x[:, :, :1, ...]
    return torch.cat([leading_slice, x], dim=2)
def divisible_by(num: int, den: int) -> bool:
    """Return whether ``num`` is an exact multiple of ``den``."""
    remainder = num % den
    return remainder == 0
def is_odd(n: int) -> bool:
    """Return whether ``n`` is not evenly divisible by two."""
    return not (n % 2 == 0)
def nonlinearity(x):
    """Apply the swish / SiLU activation: ``x * sigmoid(x)``, element-wise."""
    gate = torch.sigmoid(x)
    return x * gate
def Normalize(in_channels, num_groups=32):
    """Build an affine ``GroupNorm`` layer with ``eps=1e-6``.

    Args:
        in_channels: Number of channels the layer normalizes.
        num_groups: Number of groups the channels are split into.

    Returns:
        A new ``torch.nn.GroupNorm`` module.
    """
    group_norm = torch.nn.GroupNorm(num_groups, in_channels, eps=1e-6, affine=True)
    return group_norm
class CausalNormalize(torch.nn.Module):
    """Affine ``GroupNorm`` (``eps=1e-6``) with an optional per-frame mode.

    When ``num_groups == 1`` the input is passed through ``time2batch`` before
    normalization and through ``batch2time`` afterwards, so the time axis is
    treated as part of the batch axis by the norm. For any other
    ``num_groups`` the norm is applied to the input directly.
    """

    def __init__(self, in_channels, num_groups=1):
        super().__init__()
        self.num_groups = num_groups
        self.norm = torch.nn.GroupNorm(
            num_groups, in_channels, eps=1e-6, affine=True
        )

    def forward(self, x):
        # Legacy path kept for backward compatibility: with num_groups != 1 the
        # GroupNorm runs jointly over space and time, so causality is not
        # guaranteed. All new models should use num_groups=1.
        if self.num_groups != 1:
            return self.norm(x)
        frames, batch_size = time2batch(x)
        return batch2time(self.norm(frames), batch_size)
def exists(v):
    """Return ``True`` for any value other than ``None``."""
    return not (v is None)
def default(*args):
    """Return the first argument that is not ``None``.

    Falsy values such as ``0`` or ``""`` count as present. Returns ``None``
    when every argument is ``None`` or no arguments are given.
    """
    return next((candidate for candidate in args if candidate is not None), None)
def pack_one(t, pattern):
    """Run einops ``pack`` on the single tensor ``t`` using ``pattern``.

    Returns:
        Whatever ``pack`` returns for the one-element list ``[t]``: the packed
        tensor together with the packed-shape info that ``unpack_one`` takes
        as ``ps``.
    """
    packed, packed_shapes = pack([t], pattern)
    return packed, packed_shapes
def unpack_one(t, ps, pattern):
    """Run einops ``unpack`` and return only the first resulting tensor.

    Args:
        t: Packed tensor.
        ps: Packed-shape info passed through to ``unpack``.
        pattern: The einops pattern passed through to ``unpack``.
    """
    pieces = unpack(t, ps, pattern)
    return pieces[0]
def round_ste(z: torch.Tensor) -> torch.Tensor:
    """Round ``z`` while letting gradients pass straight through.

    The rounding offset is detached, so the forward value is the rounded
    tensor while the backward pass sees the identity function.
    """
    rounding_offset = torch.round(z) - z
    return z + rounding_offset.detach()
def log(t, eps=1e-5):
    """Return the natural log of ``t`` after clamping it from below at ``eps``.

    The clamp keeps zero or negative entries from producing ``-inf`` / ``nan``.
    """
    floored = torch.clamp(t, min=eps)
    return torch.log(floored)
def entropy(prob):
    """Return ``sum(-p * log(p))`` over the last axis of ``prob``.

    Uses this module's clamped ``log``, so zero entries contribute zero
    instead of ``nan``.
    """
    pointwise = -prob * log(prob)
    return pointwise.sum(dim=-1)