Skip to content

Commit

Permalink
Ruff stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
casblaauw committed Nov 18, 2024
1 parent c2158b3 commit a550777
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/crested/tl/zoo/utils/_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def conv_block(
)(inputs)
if normalization == "batch":
x = keras.layers.BatchNormalization(
momentum=batchnorm_momentum,
momentum=batchnorm_momentum,
name=name_prefix + "_batchnorm" if name_prefix else None
)(x)
elif normalization == "layer":
Expand All @@ -164,16 +164,16 @@ def conv_block(
if res:
if filters != residual.shape[2]:
residual = keras.layers.Convolution1D(
filters=filters,
kernel_size=1,
filters=filters,
kernel_size=1,
strides=1,
name=name_prefix + "_resconv" if name_prefix else None,
)(residual)
x = keras.layers.Add()([x, residual])

if pool_size > 1:
x = keras.layers.MaxPooling1D(
pool_size=pool_size,
pool_size=pool_size,
padding=padding,
name=name_prefix + "_pool" if name_prefix else None,
)(x)
Expand Down Expand Up @@ -355,7 +355,7 @@ def conv_block_bs(
else:
bn_layer = keras.layers.BatchNormalization
current = bn_layer(
momentum=bn_momentum,
momentum=bn_momentum,
gamma_initializer=bn_gamma,
name=name_prefix + "_bnorm" if name_prefix else None,
)(current)
Expand All @@ -382,7 +382,7 @@ def conv_block_bs(
else:
pool_layer = keras.layers.MaxPool1D
current = pool_layer(
pool_size=pool_size,
pool_size=pool_size,
padding=padding,
name=name_prefix + "_pool" if name_prefix else None,
)(current)
Expand Down

0 comments on commit a550777

Please sign in to comment.