Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stanislav #4

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions musco/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from . import tf
101 changes: 48 additions & 53 deletions musco/tf/compressor/compress.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tensorflow import keras
from musco.tf.compressor.decompositions.cp3 import get_cp3_seq
from musco.tf.compressor.decompositions.cp4 import get_cp4_seq
from musco.tf.compressor.decompositions.svd import get_svd_seq
from musco.tf.compressor.decompositions.svd import get_svd_seq, get_svd_new_layer
from musco.tf.compressor.decompositions.tucker2 import get_tucker2_seq
from musco.tf.compressor.exceptions.compression_error import CompressionError
from tqdm import tqdm
Expand Down Expand Up @@ -39,7 +39,7 @@ def compress_seq(model, decompose_info, optimize_rank=False, vbmf=True, vbmf_wea
:return: new tf.keras.Model with compressed layers
"""

x = model.input
x = model.input # to fix bug input->Input (on graph)
new_model = keras.Sequential([])

for idx, layer in enumerate(tqdm(model.layers)):
Expand All @@ -50,42 +50,53 @@ def compress_seq(model, decompose_info, optimize_rank=False, vbmf=True, vbmf_wea
continue

decompose, decomp_rank = decompose_info[layer.name]
if decompose.lower() == "svd":
logging.info("SVD layer {}".format(layer.name))
new_layer = get_svd_seq(layer, rank=decomp_rank)
elif decompose.lower() == "cp3":
logging.info("CP3 layer {}".format(layer.name))
new_layer = get_cp3_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank)
elif decompose.lower() == "cp4":
logging.info("CP4 layer {}".format(layer.name))
new_layer = get_cp4_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank)
elif decompose.lower() == "tucker2":
logging.info("Tucker2 layer {}".format(layer.name))
new_layer = get_tucker2_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank,
vbmf=vbmf,
vbmf_weaken_factor=vbmf_weaken_factor)
else:
logging.info("Incorrect decompositions type for the layer {}".format(layer.name))
raise NameError(
"Wrong Decomposition Name. You should use one of: [\"svd\", \"cp3\", \"cp4\", \"tucker-2\"]")


new_layer = get_new_layer(decompose, decomp_rank, layer, optimize_rank, vbmf, vbmf_weaken_factor)
x = new_layer(x)
new_model.add(new_layer)

return new_model

def get_new_layer(decompose, decomp_rank, layer, optimize_rank, vbmf, vbmf_weaken_factor):
if decompose.lower() == "svd":
logging.info("SVD layer {}".format(layer.name))
if isinstance(layer, keras.layers.TimeDistributed):
new_layer = get_svd_seq(layer.layer, rank=decomp_rank)
new_layer = keras.layers.TimeDistributed(new_layer)
elif isinstance(layer, keras.layers.Dense) or isinstance(layer, keras.Sequential):
new_layer = get_svd_seq(layer, rank=decomp_rank)
else:
new_layer = get_svd_new_layer(layer, rank=decomp_rank)

elif decompose.lower() == "cp3":
logging.info("CP3 layer {}".format(layer.name))
new_layer = get_cp3_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank)
elif decompose.lower() == "cp4":
logging.info("CP4 layer {}".format(layer.name))
new_layer = get_cp4_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank)
elif decompose.lower() == "tucker2":
logging.info("Tucker2 layer {}".format(layer.name))
new_layer = get_tucker2_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank,
vbmf=vbmf,
vbmf_weaken_factor=vbmf_weaken_factor)
else:
logging.info("Incorrect decompositions type for the layer {}".format(layer.name))
raise NameError(
"Wrong Decomposition Name. You should use one of: [\"svd\", \"cp3\", \"cp4\", \"tucker-2\"]")

return new_layer


def insert_layer_noseq(model, layer_regexs):
# Auxiliary dictionary to describe the network graph.
network_dict = dict(input_layers_of={}, new_output_tensor_of={})
current_session = tf.keras.backend.get_session()

# current_session = tf.compat.v1.keras.backend.get_session()
# Set the input layers of each layer.
for layer in model.layers:
try:
Expand Down Expand Up @@ -146,9 +157,11 @@ def insert_layer_noseq(model, layer_regexs):
network_dict["new_output_tensor_of"].update({layer.name: x})

# Reconstruct graph.
tf.reset_default_graph()
new_sess = tf.Session()
tf.keras.backend.set_session(new_sess)
# Do not need sessions in tf-v2
#
# tf.compat.v1.reset_default_graph()
# new_sess = tf.compat.v1.Session()
# tf.compat.v1.keras.backend.set_session(new_sess)

input_constructor, input_conf, _ = conenctions[layers_order[0]]
new_model_input = input_constructor.from_config(input_conf)
Expand All @@ -170,40 +183,22 @@ def insert_layer_noseq(model, layer_regexs):
network_dict["new_output_tensor_of"].update({layer.name: x})

new_model = Model(inputs=new_model_input.input, outputs=x)
current_session.close()
# current_session.close()

return new_model


def compress_noseq(model, decompose_info, optimize_rank=False, vbmf=True, vbmf_weaken_factor=0.8):
new_model = model
layer_regexs = dict()
layer_regexs = dict()

for idx, layer in enumerate(model.layers[1:]):
if layer.name not in decompose_info:
continue

decompose, decomp_rank = decompose_info[layer.name]

try:
if decompose.lower() == "svd":
layer_regexs[layer.name] = get_svd_seq(layer, rank=decomp_rank)
elif decompose.lower() == "cp3":
layer_regexs[layer.name] = get_cp3_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank)
elif decompose.lower() == "cp4":
layer_regexs[layer.name] = get_cp4_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank)
elif decompose.lower() == "tucker2":
layer_regexs[layer.name] = get_tucker2_seq(layer,
rank=decomp_rank,
optimize_rank=optimize_rank,
vbmf=vbmf,
vbmf_weaken_factor=vbmf_weaken_factor)
except ValueError:
continue
layer_regexs[layer.name] = get_new_layer(decompose, decomp_rank, layer, optimize_rank, vbmf, vbmf_weaken_factor)

new_model = insert_layer_noseq(new_model, layer_regexs)

Expand Down
3 changes: 1 addition & 2 deletions musco/tf/compressor/decompositions/constructor.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import tensorflow as tf
from tensorflow import keras
import keras as K


def check_layer_type(layer, accepted_layers):
return isinstance(layer, accepted_layers)


def check_data_format(layer):
if isinstance(layer, keras.Sequential) or isinstance(layer, K.Sequential):
if isinstance(layer, keras.Sequential):
return any(layer.data_format != "channel_last" for layer in layer.layers)
else:
return layer.data_format != "channel_last"
Expand Down
27 changes: 27 additions & 0 deletions musco/tf/compressor/decompositions/decomp_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import numpy as np

def uv_decompose(matrix, max_rank=None, epsilon=1e-8):
"""
Calcultes a skeleton decomposition using SVD
Factor U takes left singular vectors multiplied by singular values,
factor V takes right singular vectors.
Parameters
----------
matrix: np.array, matrix to decompose
max_rank: int, maximal value of rank (included)
epsilon: maximal difference in Frobenius norm
of the resulting decomposition
Return
------
uv: tuple, two factors of the decomposition
"""
# import pdb; pdb.set_trace()
u, s, v = np.linalg.svd(matrix, full_matrices=False)
errors = np.sqrt(np.abs(np.sum(s**2) - np.cumsum(s**2)))
rank_num = np.argmax(errors < epsilon)
if rank_num == 0: # none of the errors < epsilon => full rank
rank_num = s.shape[0]
if max_rank is None:
max_rank = s.shape[0]
rank = min(max_rank, rank_num)
return np.dot(u[:, :rank+1], np.diag(s[:rank+1])), v[:rank+1, :]
21 changes: 20 additions & 1 deletion musco/tf/compressor/decompositions/svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import numpy as np
from musco.tf.compressor.decompositions.constructor import construct_compressor
import tensorflow as tf
from tensorflow import keras
from musco.tf.compressor.common.utils import del_keys
from musco.tf.compressor import layers as musco_layers


def get_params(layer):
Expand All @@ -20,7 +22,6 @@ def get_params(layer):

return params


def get_truncated_svd(weights, rank):
u, s, v_adj = np.linalg.svd(weights, full_matrices=False)

Expand Down Expand Up @@ -73,3 +74,21 @@ def get_config(layer, copy_conf):

get_svd_seq = construct_compressor(get_params, None, get_svd_factors, get_layers_params_for_factors, get_config,
(keras.layers.Dense, keras.Sequential))

def get_svd_new_layer(layer, rank=2):

conf = layer.get_config()
if isinstance(layer, keras.layers.RNN):
layer = layer.cell

if isinstance(layer, keras.layers.LSTM)\
or isinstance(layer, tf.python.keras.layers.recurrent.LSTMCell)\
or isinstance(layer, keras.layers.LSTMCell):
cell = musco_layers.decomp_recurrent_cell.FusedSVDLSTMCell(
units=layer.units,
parent_layer=layer,
rank=rank
)
new_layer = keras.layers.RNN(cell, return_sequences=conf['return_sequences'])

return new_layer
6 changes: 6 additions & 0 deletions musco/tf/compressor/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from . import svd_layer
from . import group_conv_2d
from . import decomp_recurrent_cell
from . import testing_layers

__all__ = ['group_conv_2d', 'svd_layer', 'decomp_recurrent_cell', 'testing_layers']
Loading