
Commit

Add docstrings for kimm.timm_utils.* (#51)
james77777778 authored Jun 3, 2024
1 parent 927370b commit ec8e096
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions kimm/_src/utils/timm_utils.py
@@ -22,6 +22,14 @@ def _is_non_trainable_weights(name: str):

@kimm_export(parent_path=["kimm.timm_utils"])
def separate_torch_state_dict(state_dict: typing.OrderedDict):
"""Separate the torch state dict into trainable and non-trainable parts.
Args:
state_dict: A `collections.OrderedDict`.
Returns:
A tuple containing the trainable and non-trainable state dicts.
"""
trainable_state_dict = state_dict.copy()
non_trainable_state_dict = state_dict.copy()
trainable_remove_keys = []
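
For reference, a minimal usage sketch of the helper documented above. It assumes the `kimm_export` decorator exposes the function as `kimm.timm_utils.separate_torch_state_dict`; the `timm` model name is only illustrative.

import timm

import kimm

# Build any torch model and grab its state dict (the model name is illustrative).
torch_model = timm.create_model("mobilenetv3_small_050", pretrained=False)
state_dict = torch_model.state_dict()

# Split the state dict into trainable and non-trainable parts.
trainable, non_trainable = kimm.timm_utils.separate_torch_state_dict(state_dict)
print(len(trainable), len(non_trainable))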
@@ -44,6 +52,15 @@ def separate_torch_state_dict(state_dict: typing.OrderedDict):

@kimm_export(parent_path=["kimm.timm_utils"])
def separate_keras_weights(keras_model: keras.Model):
"""Separate the Keras model into trainable and non-trainable parts.
Args:
keras_model: A `keras.Model` instance.
Returns:
A tuple containing the trainable and non-trainable state lists. Each
list contains (`keras.Variable`, name) pairs.
"""
trainable_weights = []
non_trainable_weights = []
for layer in keras_model.layers:
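
A short sketch of how the companion helper might be called, again assuming the exported path is `kimm.timm_utils.separate_keras_weights`. The tiny functional model is made up for illustration; it just provides both trainable weights and non-trainable ones (the BatchNormalization moving statistics).

import keras

import kimm

# A toy model with one trainable layer and one layer holding non-trainable weights.
inputs = keras.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(8, 3, name="conv")(inputs)
outputs = keras.layers.BatchNormalization(name="bn")(x)
keras_model = keras.Model(inputs, outputs)

trainable, non_trainable = kimm.timm_utils.separate_keras_weights(keras_model)
for variable, name in trainable:
    print(name, variable.shape)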
@@ -75,6 +92,20 @@ def separate_keras_weights(keras_model: keras.Model):
def assign_weights(
    keras_name: str, keras_weight: keras.Variable, torch_weight: np.ndarray
):
"""Assign the torch weights to the keras weights based on the arguments.
Some basic criterion:
1. 4D must be a convolution weights (also check the name)
2. 2D must be a dense weights
3. 1D must be a vector weights
4. 0D must be a scalar weights
Args:
keras_name: A `str` representing the name of the target weights.
keras_weights: A `keras.Variable` representing the target weights.
torch_weights: A `numpy.ndarray` representing the original source
weights.
"""
if len(keras_weight.shape) == 4:
if (
"conv" in keras_name
@@ -119,6 +150,19 @@ def is_same_weights(
    torch_name: str,
    torch_weights: np.ndarray,
):
"""Check whether the given keras weights and torch weigths are the same.
Args:
keras_name: A `str` representing the name of the target weights.
keras_weights: A `keras.Variable` representing the target weights.
torch_name: A `str` representing the name of the original source
weights.
torch_weights: A `numpy.ndarray` representing the original source
weights.
Returns:
A boolean indicating whether the two weights are the same.
"""
if np.sum(keras_weights.shape) != np.sum(torch_weights.shape):
if np.sum(keras_weights.shape) == 0: # Deal with scalar
if np.sum(torch_weights.shape) == 1:
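
Finally, a sketch pairing the two helpers: check that a keras/torch weight pair matches before assigning it. The names and arrays are illustrative rather than taken from a real conversion script, and the export path is again assumed from the `kimm_export` decorators.

import keras
import numpy as np

import kimm

conv = keras.layers.Conv2D(8, 3, name="conv")
conv.build((None, 32, 32, 3))

keras_kernel = conv.kernel                              # shape (3, 3, 3, 8)
torch_kernel = np.zeros((8, 3, 3, 3), dtype="float32")  # torch (out, in, h, w) layout

# Only assign when the pairing looks consistent.
if kimm.timm_utils.is_same_weights("conv_kernel", keras_kernel, "conv.weight", torch_kernel):
    kimm.timm_utils.assign_weights("conv_kernel", keras_kernel, torch_kernel)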
