
Commit a0f03a4
Correctly type the Input layers
Also expose model config from the Moonshine class
keveman committed Oct 23, 2024
1 parent 8aabfe9 commit a0f03a4
Showing 1 changed file with 16 additions and 11 deletions.
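
Background on the fix (a minimal sketch, not part of the diff below): keras.layers.Input defaults to the global floatx dtype, usually "float32", so inputs meant to carry token ids or sequence lengths become float tensors unless an integer dtype is requested explicitly:

    import keras

    # Without an explicit dtype, Input takes the global floatx dtype
    # (typically "float32"), which is wrong for token ids or lengths.
    tokens_untyped = keras.layers.Input(shape=[None])
    # Integer-valued inputs need dtype="int32" spelled out.
    tokens_typed = keras.layers.Input(shape=[None], dtype="int32")

    print(tokens_untyped.dtype)  # float32
    print(tokens_typed.dtype)    # int32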
27 changes: 16 additions & 11 deletions moonshine/model.py
@@ -221,7 +221,7 @@ def __init__(self, n_layers, dim, inner_dim, n_head, ff_mult=4, ff_swiglu=False)
             axis=-1, epsilon=1e-5, center=False, scale=True
         )
         inputs = keras.layers.Input(shape=[None, dim])
-        seq_len = keras.layers.Input(shape=[], batch_size=1)
+        seq_len = keras.layers.Input(shape=[], batch_size=1, dtype="int32")
         pos_emb = rot_pos_emb(Arange()(inputs=seq_len))
 
         x = inputs
@@ -504,9 +504,9 @@ def __init__(
         )
 
     def get_uncached_call(self, dim, rot_embed_dim):
-        inputs = keras.layers.Input(shape=[None])
-        seq_len = keras.layers.Input(shape=[], batch_size=1)
-        context = keras.layers.Input(shape=[None, dim])
+        inputs = keras.layers.Input(shape=[None], dtype="int32")
+        seq_len = keras.layers.Input(shape=[], batch_size=1, dtype="int32")
+        context = keras.layers.Input(shape=[None, dim], dtype="float32")
         rot_pos_emb = RotaryEmbedding(rot_embed_dim)
 
         x = inputs
@@ -526,17 +526,17 @@ def get_uncached_call(self, dim, rot_embed_dim):
         return Model(inputs=[inputs, context, seq_len], outputs=[logits] + outputs)
 
     def get_cached_call(self, dim, rot_embed_dim, key_dim, n_head, n_layers):
-        inputs = keras.layers.Input(shape=[None])
-        seq_len = keras.layers.Input(shape=[], batch_size=1)
-        context = keras.layers.Input(shape=[None, dim])
+        inputs = keras.layers.Input(shape=[None], dtype="int32")
+        seq_len = keras.layers.Input(shape=[], batch_size=1, dtype="int32")
+        context = keras.layers.Input(shape=[None, dim], dtype="float32")
         rot_pos_emb = RotaryEmbedding(rot_embed_dim)
 
         cache = [
             [
-                keras.layers.Input(shape=[None, n_head, key_dim]),
-                keras.layers.Input(shape=[None, n_head, key_dim]),
-                keras.layers.Input(shape=[None, n_head, key_dim]),
-                keras.layers.Input(shape=[None, n_head, key_dim]),
+                keras.layers.Input(shape=[None, n_head, key_dim], dtype="float32"),
+                keras.layers.Input(shape=[None, n_head, key_dim], dtype="float32"),
+                keras.layers.Input(shape=[None, n_head, key_dim], dtype="float32"),
+                keras.layers.Input(shape=[None, n_head, key_dim], dtype="float32"),
             ]
             for _ in range(n_layers)
         ]
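
For orientation, an illustrative sketch of feeding the cache declared above; the four tensors per layer are presumably the self-attention and cross-attention key/value pairs, and all sizes here are hypothetical:

    import numpy as np

    n_layers, n_head, key_dim = 6, 8, 36  # hypothetical sizes
    cached_len = 10                       # tokens decoded so far

    # One four-tensor group per decoder layer, matching the float32
    # Input declarations above: (batch, seq, n_head, key_dim).
    cache = [
        [
            np.zeros((1, cached_len, n_head, key_dim), dtype="float32")
            for _ in range(4)
        ]
        for _ in range(n_layers)
    ]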
@@ -596,6 +596,11 @@ def __init__(
         self.decoder = Decoder(
             dec_n_layers, dim, inner_dim, n_head, vocab_size, dec_ff_mult, dec_ff_swiglu
         )
+        self.dim = dim
+        self.inner_dim = inner_dim
+        self.n_head = n_head
+        self.enc_n_layers = enc_n_layers
+        self.dec_n_layers = dec_n_layers
 
     def _load_weights(self, preprocessor_weights, encoder_weights, decoder_weights):
         self.preprocessor.preprocess.load_weights(preprocessor_weights)
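Exposing the configuration lets downstream code read the architecture back off a constructed model. A hypothetical usage sketch follows; the keyword names mirror the attributes the commit adds, and the real constructor also takes vocab_size and feed-forward options that are omitted here:

    model = Moonshine(
        dim=288, inner_dim=1152, n_head=8,
        enc_n_layers=6, dec_n_layers=6,
    )
    # Readable after construction thanks to this commit:
    print(model.dim, model.inner_dim, model.n_head)
    print(model.enc_n_layers, model.dec_n_layers)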
