From a0f03a45acf52510b4abd0a927c03cf4dfd45254 Mon Sep 17 00:00:00 2001
From: Manjunath Kudlur
Date: Wed, 23 Oct 2024 16:17:48 -0700
Subject: [PATCH] Correctly type the Input layers

Also expose model config from the Moonshine class
---
 moonshine/model.py | 27 ++++++++++++++++-----------
 1 file changed, 16 insertions(+), 11 deletions(-)

diff --git a/moonshine/model.py b/moonshine/model.py
index 35f3e70..1ccdd43 100644
--- a/moonshine/model.py
+++ b/moonshine/model.py
@@ -221,7 +221,7 @@ def __init__(self, n_layers, dim, inner_dim, n_head, ff_mult=4, ff_swiglu=False)
             axis=-1, epsilon=1e-5, center=False, scale=True
         )
         inputs = keras.layers.Input(shape=[None, dim])
-        seq_len = keras.layers.Input(shape=[], batch_size=1)
+        seq_len = keras.layers.Input(shape=[], batch_size=1, dtype="int32")
         pos_emb = rot_pos_emb(Arange()(inputs=seq_len))
 
         x = inputs
@@ -504,9 +504,9 @@ def __init__(
         )
 
     def get_uncached_call(self, dim, rot_embed_dim):
-        inputs = keras.layers.Input(shape=[None])
-        seq_len = keras.layers.Input(shape=[], batch_size=1)
-        context = keras.layers.Input(shape=[None, dim])
+        inputs = keras.layers.Input(shape=[None], dtype="int32")
+        seq_len = keras.layers.Input(shape=[], batch_size=1, dtype="int32")
+        context = keras.layers.Input(shape=[None, dim], dtype="float32")
         rot_pos_emb = RotaryEmbedding(rot_embed_dim)
 
         x = inputs
@@ -526,17 +526,17 @@ def get_uncached_call(self, dim, rot_embed_dim):
         return Model(inputs=[inputs, context, seq_len], outputs=[logits] + outputs)
 
     def get_cached_call(self, dim, rot_embed_dim, key_dim, n_head, n_layers):
-        inputs = keras.layers.Input(shape=[None])
-        seq_len = keras.layers.Input(shape=[], batch_size=1)
-        context = keras.layers.Input(shape=[None, dim])
+        inputs = keras.layers.Input(shape=[None], dtype="int32")
+        seq_len = keras.layers.Input(shape=[], batch_size=1, dtype="int32")
+        context = keras.layers.Input(shape=[None, dim], dtype="float32")
         rot_pos_emb = RotaryEmbedding(rot_embed_dim)
         cache = [
             [
-                keras.layers.Input(shape=[None, n_head, key_dim]),
-                keras.layers.Input(shape=[None, n_head, key_dim]),
-                keras.layers.Input(shape=[None, n_head, key_dim]),
-                keras.layers.Input(shape=[None, n_head, key_dim]),
+                keras.layers.Input(shape=[None, n_head, key_dim], dtype="float32"),
+                keras.layers.Input(shape=[None, n_head, key_dim], dtype="float32"),
+                keras.layers.Input(shape=[None, n_head, key_dim], dtype="float32"),
+                keras.layers.Input(shape=[None, n_head, key_dim], dtype="float32"),
             ]
             for _ in range(n_layers)
         ]
@@ -596,6 +596,11 @@ def __init__(
         self.decoder = Decoder(
             dec_n_layers, dim, inner_dim, n_head, vocab_size, dec_ff_mult, dec_ff_swiglu
         )
+        self.dim = dim
+        self.inner_dim = inner_dim
+        self.n_head = n_head
+        self.enc_n_layers = enc_n_layers
+        self.dec_n_layers = dec_n_layers
 
     def _load_weights(self, preprocessor_weights, encoder_weights, decoder_weights):
         self.preprocessor.preprocess.load_weights(preprocessor_weights)