diff --git a/kimm/_src/utils/timm_utils.py b/kimm/_src/utils/timm_utils.py index 427d162..b4530e2 100644 --- a/kimm/_src/utils/timm_utils.py +++ b/kimm/_src/utils/timm_utils.py @@ -22,6 +22,14 @@ def _is_non_trainable_weights(name: str): @kimm_export(parent_path=["kimm.timm_utils"]) def separate_torch_state_dict(state_dict: typing.OrderedDict): + """Separate the torch state dict into trainable and non-trainable parts. + + Args: + state_dict: A `collections.OrderedDict`. + + Returns: + A tuple containing the trainable and non-trainable state dicts. + """ trainable_state_dict = state_dict.copy() non_trainable_state_dict = state_dict.copy() trainable_remove_keys = [] @@ -44,6 +52,15 @@ def separate_torch_state_dict(state_dict: typing.OrderedDict): @kimm_export(parent_path=["kimm.timm_utils"]) def separate_keras_weights(keras_model: keras.Model): + """Separate the Keras model into trainable and non-trainable parts. + + Args: + keras_model: A `keras.Model` instance. + + Returns: + A tuple containing the trainable and non-trainable state lists. Each + list contains (`keras.Variable`, name) pairs. + """ trainable_weights = [] non_trainable_weights = [] for layer in keras_model.layers: @@ -75,6 +92,20 @@ def separate_keras_weights(keras_model: keras.Model): def assign_weights( keras_name: str, keras_weight: keras.Variable, torch_weight: np.ndarray ): + """Assign the torch weights to the keras weights based on the arguments. + + Some basic criterion: + 1. 4D must be a convolution weights (also check the name) + 2. 2D must be a dense weights + 3. 1D must be a vector weights + 4. 0D must be a scalar weights + + Args: + keras_name: A `str` representing the name of the target weights. + keras_weights: A `keras.Variable` representing the target weights. + torch_weights: A `numpy.ndarray` representing the original source + weights. + """ if len(keras_weight.shape) == 4: if ( "conv" in keras_name @@ -119,6 +150,19 @@ def is_same_weights( torch_name: str, torch_weights: np.ndarray, ): + """Check whether the given keras weights and torch weigths are the same. + + Args: + keras_name: A `str` representing the name of the target weights. + keras_weights: A `keras.Variable` representing the target weights. + torch_name: A `str` representing the name of the original source + weights. + torch_weights: A `numpy.ndarray` representing the original source + weights. + + Returns: + A boolean indicating whether the two weights are the same. + """ if np.sum(keras_weights.shape) != np.sum(torch_weights.shape): if np.sum(keras_weights.shape) == 0: # Deal with scalar if np.sum(torch_weights.shape) == 1: