-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
50 lines (41 loc) · 1.31 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import Iterator
import torch
from torch import Tensor, nn
def pad_dims_like(x: Tensor, other: Tensor) -> Tensor:
    """Append trailing singleton dimensions to `x` until it has as many dims as `other`.

    Only the *number* of dimensions is matched: the sizes of the existing
    dimensions of `x` are left untouched. For example, a tensor of shape
    ``(B,)`` padded like a tensor of shape ``(B, C, H, W)`` becomes
    ``(B, 1, 1, 1)``, which broadcasts against `other`.

    Parameters
    ----------
    x : Tensor
        Tensor to be padded.
    other : Tensor
        Tensor whose number of dimensions (``other.ndim``) is the target.

    Returns
    -------
    Tensor
        A view of `x` with ``max(other.ndim, x.ndim)`` dimensions, sharing
        storage with `x`. If `x` already has at least as many dimensions as
        `other`, its shape is returned unchanged.
    """
    # Clamp at zero: when x already has more dims than other there is
    # nothing to append (and x keeps its original shape).
    missing_dims = max(other.ndim - x.ndim, 0)
    return x.view(*x.shape, *((1,) * missing_dims))
@torch.no_grad()
def update_ema_model_(
    ema_model: nn.Module, online_model: nn.Module, ema_decay_rate: float
) -> nn.Module:
    """Update the weights of a moving-average model in place from an online model.

    For every parameter pair the update is::

        ema = ema_decay_rate * ema + (1 - ema_decay_rate) * online

    Parameters are paired by their position in ``.parameters()``, so both
    models are expected to share the same architecture (NOTE(review): not
    checked here beyond the tensor-list length check done by the foreach
    ops). Only parameters are updated; buffers (``.buffers()``) are left
    untouched.

    Parameters
    ----------
    ema_model : nn.Module
        Moving-average model; its parameters are modified in place.
    online_model : nn.Module
        Online or source model; read only.
    ema_decay_rate : float
        Weight kept from the current moving average. Larger values make the
        moving average change more slowly.

    Returns
    -------
    nn.Module
        The same `ema_model` object, after the update.
    """
    online_params = [p.data for p in online_model.parameters()]
    ema_params = [p.data for p in ema_model.parameters()]

    # The foreach kernels reject empty tensor lists, so a model without
    # parameters must be handled explicitly as a no-op.
    if not ema_params and not online_params:
        return ema_model

    # ema <- ema * decay + online * (1 - decay), done with fused multi-tensor ops.
    torch._foreach_mul_(ema_params, ema_decay_rate)
    torch._foreach_add_(ema_params, online_params, alpha=1 - ema_decay_rate)
    return ema_model