-
Notifications
You must be signed in to change notification settings - Fork 31
/
matrix_functions_types.py
116 lines (77 loc) · 4.11 KB
/
matrix_functions_types.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
"""
from dataclasses import dataclass
from commons import AbstractDataclass
@dataclass(init=False)
class PreconditionerComputationConfig(AbstractDataclass):
    """Configuration for preconditioner computation in Shampoo.

    Common base class for the preconditioner-computation configs in this module.
    It declares no fields of its own, and ``init=False`` means no ``__init__`` is
    generated for it. It derives from ``AbstractDataclass``, so it is presumably not
    meant to be instantiated directly (confirm against ``commons.AbstractDataclass``).
    """
@dataclass(init=False)
class RootInvConfig(PreconditionerComputationConfig):
    """Base dataclass for matrix root inverse method configurations in Shampoo.

    Adds no fields. Each concrete root-inverse method (e.g. eigendecomposition or a
    coupled iteration) is described by its own subclass. As with its parent,
    ``init=False`` suppresses the generated ``__init__``.
    """
@dataclass(kw_only=True)
class EigenConfig(RootInvConfig):
    """Configuration for eigendecomposition method in Shampoo.

    All fields are keyword-only (``kw_only=True``).

    Args:
        make_positive_semidefinite (bool): Whether to perturb the matrix eigenvalues so that the matrix is
            numerically positive semi-definite. (Default: True)
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if the
            lower-precision attempt fails due to a cuSOLVER failure. (Default: True)
        exponent_multiplier (float): Number multiplied into the numerator of the inverse root, i.e., eta where the
            exponent is -eta / (2 * p). (Default: 1.0)
    """

    make_positive_semidefinite: bool = True
    retry_double_precision: bool = True
    exponent_multiplier: float = 1.0


# Shared instance with every field at its default value.
# NOTE: EigenConfig is not frozen, so mutating this object would affect every user of it.
DefaultEigenConfig = EigenConfig()
@dataclass(kw_only=True)
class CoupledNewtonConfig(RootInvConfig):
    """Configuration for the coupled Newton method in Shampoo.

    All fields are keyword-only (``kw_only=True``).

    Args:
        max_iterations (int): Upper bound on the number of coupled Newton iterations. (Default: 100)
        tolerance (float): Tolerance used when computing the root inverse with the coupled Newton iteration.
            (Default: 1e-6)
    """

    # Iteration cap for the coupled Newton loop.
    max_iterations: int = 100
    # Tolerance for the coupled Newton root-inverse computation.
    tolerance: float = 1e-6
@dataclass(kw_only=True)
class CoupledHigherOrderConfig(RootInvConfig):
    """Configuration for coupled higher-order method in Shampoo.

    All fields are keyword-only (``kw_only=True``).

    Args:
        rel_epsilon (float): Relative epsilon for the coupled higher-order method. Adds
            rel_epsilon * lambda_max * I to the matrix before taking the matrix root, where lambda_max is an upper
            bound on the maximum eigenvalue. (Default: 0.0)
        max_iterations (int): Maximum number of iterations for the coupled higher-order method. (Default: 100)
        tolerance (float): Tolerance for computing the root inverse using the coupled higher-order method.
            (Default: 1e-8)
        order (int): Order of the method; must be >= 2. Higher-order methods accelerate convergence (fewer
            iterations) but can take more matmuls per iteration. order=2 corresponds to Newton's method.
            (Default: 3)
        disable_tf32 (bool): Whether to disable tf32 matmuls internally. Keeping this True is highly recommended,
            since tf32 is numerically challenging here. (Default: True)
    """

    rel_epsilon: float = 0.0
    max_iterations: int = 100
    tolerance: float = 1e-8
    # NOTE(review): the documented `order >= 2` constraint is not validated in this class;
    # presumably it is checked by the code that consumes this config — verify.
    order: int = 3
    disable_tf32: bool = True
@dataclass(init=False)
class EigenvalueCorrectionConfig(PreconditionerComputationConfig):
    """Base dataclass for matrix eigenvector method configurations in eigenvalue-corrected Shampoo.

    Declares no fields; each eigenvector-computation method is configured by a
    subclass. ``init=False`` suppresses the generated ``__init__``.
    """
@dataclass(kw_only=True)
class EighEigenvalueCorrectionConfig(EigenvalueCorrectionConfig):
    """Configuration for eigendecomposition method used in eigenvalue-corrected Shampoo.

    All fields are keyword-only (``kw_only=True``).

    Args:
        retry_double_precision (bool): Whether to retry the eigendecomposition in higher (double) precision if the
            lower-precision attempt fails due to a cuSOLVER failure. (Default: True)
    """

    retry_double_precision: bool = True


# Shared instance with every field at its default value.
# NOTE: EighEigenvalueCorrectionConfig is not frozen, so mutating this object would affect every user of it.
DefaultEighEigenvalueCorrectionConfig = EighEigenvalueCorrectionConfig()
@dataclass(kw_only=True)
class QREigenvalueCorrectionConfig(EigenvalueCorrectionConfig):
    """Configuration for orthogonal/simultaneous iterations (QR algorithm) used in eigenvalue-corrected Shampoo.

    All fields are keyword-only (``kw_only=True``).

    Args:
        max_iterations (int): Maximum number of iterations to perform. (Default: 1)
        tolerance (float): Convergence tolerance, measured as the relative change of the eigenvector estimate.
            (Default: 1e-5)
    """

    # Iteration budget for the QR / orthogonal iteration.
    max_iterations: int = 1
    # Relative-change threshold on the eigenvector estimate.
    tolerance: float = 1e-5