diff --git a/jflux/modules/layers.py b/jflux/modules/layers.py index 050d486..57898b9 100644 --- a/jflux/modules/layers.py +++ b/jflux/modules/layers.py @@ -214,6 +214,7 @@ def __init__( self.img_norm1 = nnx.LayerNorm( num_features=hidden_size, use_scale=False, + use_bias=False, epsilon=1e-6, rngs=rngs, param_dtype=param_dtype, @@ -229,6 +230,7 @@ def __init__( self.img_norm2 = nnx.LayerNorm( num_features=hidden_size, use_scale=False, + use_bias=False, epsilon=1e-6, rngs=rngs, param_dtype=param_dtype, @@ -257,6 +259,7 @@ def __init__( self.txt_norm1 = nnx.LayerNorm( num_features=hidden_size, use_scale=False, + use_bias=False, epsilon=1e-6, rngs=rngs, param_dtype=param_dtype, @@ -272,6 +275,7 @@ def __init__( self.txt_norm2 = nnx.LayerNorm( num_features=hidden_size, use_scale=False, + use_bias=False, epsilon=1e-6, rngs=rngs, param_dtype=param_dtype, @@ -382,6 +386,7 @@ def __init__( self.pre_norm = nnx.LayerNorm( num_features=hidden_size, use_scale=False, + use_bias=False, epsilon=1e-6, rngs=rngs, param_dtype=param_dtype, @@ -419,6 +424,7 @@ def __init__( self.norm_final = nnx.LayerNorm( num_features=hidden_size, use_scale=False, + use_bias=False, epsilon=1e-6, rngs=rngs, param_dtype=param_dtype, diff --git a/jflux/port.py b/jflux/port.py index 3f3d1a7..60a8590 100644 --- a/jflux/port.py +++ b/jflux/port.py @@ -1,116 +1,146 @@ from einops import rearrange +############################################################################################## +# AUTOENCODER MODEL PORTING +############################################################################################## + + +def port_group_norm(group_norm, tensors, prefix): + group_norm.scale.value = tensors[f"{prefix}.weight"] + group_norm.bias.value = tensors[f"{prefix}.bias"] + + return group_norm + + +def port_conv(conv, tensors, prefix): + conv.kernel.value = rearrange(tensors[f"{prefix}.weight"], "i o k1 k2 -> k1 k2 o i") + conv.bias.value = tensors[f"{prefix}.bias"] + + return conv + def port_attn_block(attn_block, tensors, prefix): # port the norm - attn_block.norm.scale.value = tensors[f"{prefix}.norm.weight"] - attn_block.norm.bias.value = tensors[f"{prefix}.norm.bias"] + attn_block.norm = port_group_norm( + group_norm=attn_block.norm, + tensors=tensors, + prefix=f"{prefix}.norm", + ) # port the k, q, v layers - attn_block.k.kernel.value = rearrange( - tensors[f"{prefix}.k.weight"], "i o k1 k2 -> k1 k2 o i" + attn_block.k = port_conv( + conv=attn_block.k, + tensors=tensors, + prefix=f"{prefix}.k", ) - attn_block.k.bias.value = tensors[f"{prefix}.k.bias"] - attn_block.q.kernel.value = rearrange( - tensors[f"{prefix}.q.weight"], "i o k1 k2 -> k1 k2 o i" + attn_block.q = port_conv( + conv=attn_block.q, + tensors=tensors, + prefix=f"{prefix}.q", ) - attn_block.q.bias.value = tensors[f"{prefix}.q.weight"] - attn_block.v.kernel.value = rearrange( - tensors[f"{prefix}.v.weight"], "i o k1 k2 -> k1 k2 o i" + attn_block.v = port_conv( + conv=attn_block.v, + tensors=tensors, + prefix=f"{prefix}.v", ) - attn_block.v.bias.value = tensors[f"{prefix}.v.weight"] # port the proj_out layer - attn_block.proj_out.kernel.value = rearrange( - tensors[f"{prefix}.proj_out.weight"], "i o k1 k2 -> k1 k2 o i" + attn_block.proj_out = port_conv( + conv=attn_block.proj_out, + tensors=tensors, + prefix=f"{prefix}.proj_out", ) - attn_block.proj_out.bias.value = tensors[f"{prefix}.proj_out.weight"] return attn_block def port_resent_block(resnet_block, tensors, prefix): # port the norm - resnet_block.norm1.scale.value = tensors[f"{prefix}.norm1.weight"] - resnet_block.norm1.bias.value = tensors[f"{prefix}.norm1.bias"] - - resnet_block.norm2.scale.value = tensors[f"{prefix}.norm2.weight"] - resnet_block.norm2.bias.value = tensors[f"{prefix}.norm2.bias"] + resnet_block.norm1 = port_group_norm( + group_norm=resnet_block.norm1, + tensors=tensors, + prefix=f"{prefix}.norm1", + ) + resnet_block.norm2 = port_group_norm( + group_norm=resnet_block.norm2, + tensors=tensors, + prefix=f"{prefix}.norm2", + ) # port the convs - resnet_block.conv1.kernel.value = rearrange( - tensors[f"{prefix}.conv1.weight"], "i o k1 k2 -> k1 k2 o i" + resnet_block.conv1 = port_conv( + conv=resnet_block.conv1, + tensors=tensors, + prefix=f"{prefix}.conv1", ) - resnet_block.conv1.bias.value = tensors[f"{prefix}.conv1.weight"] - - resnet_block.conv2.kernel.value = rearrange( - tensors[f"{prefix}.conv2.weight"], "i o k1 k2 -> k1 k2 o i" + resnet_block.conv2 = port_conv( + conv=resnet_block.conv2, + tensors=tensors, + prefix=f"{prefix}.conv2", ) - resnet_block.conv2.bias.value = tensors[f"{prefix}.conv2.weight"] if resnet_block.in_channels != resnet_block.out_channels: - resnet_block.nin_shortcut.kernel.value = rearrange( - tensors[f"{prefix}.nin_shortcut.weight"], "i o k1 k2 -> k1 k2 o i" + resnet_block.nin_shortcut = port_conv( + conv=resnet_block.nin_shortcut, + tensors=tensors, + prefix=f"{prefix}.nin_shortcut", ) - resnet_block.nin_shortcut.bias.value = tensors[f"{prefix}.nin_shortcut.bias"] return resnet_block def port_downsample(downsample, tensors, prefix): # port the conv - downsample.conv.kernel.value = rearrange( - tensors[f"{prefix}.conv.weight"], "i o k1 k2 -> k1 k2 o i" + downsample.conv = port_conv( + conv=downsample.conv, + tensors=tensors, + prefix=f"{prefix}.conv", ) - downsample.conv.bias.value = tensors[f"{prefix}.conv.bias"] + return downsample def port_upsample(upsample, tensors, prefix): # port the conv - upsample.conv.kernel.value = rearrange( - tensors[f"{prefix}.conv.weight"], "i o k1 k2 -> k1 k2 o i" + upsample.conv = port_conv( + conv=upsample.conv, + tensors=tensors, + prefix=f"{prefix}.conv", ) - upsample.conv.bias.value = tensors[f"{prefix}.conv.bias"] + return upsample def port_encoder(encoder, tensors, prefix): - # port downsampling - conv_in = encoder.conv_in - conv_in.kernel.value = rearrange( - tensors[f"{prefix}.conv_in.weight"], "i o k1 k2 -> k1 k2 o i" + # conv in + encoder.conv_in = port_conv( + conv=encoder.conv_in, + tensors=tensors, + prefix=f"{prefix}.conv_in", ) - conv_in.bias.value = tensors[f"{prefix}.conv_in.bias"] # down - down = encoder.down - for i in range(len(down.layers)): + for i, down_layer in enumerate(encoder.down.layers): # block - block = down.layers[i].block - for j in range(len(block.layers)): - resnet_block = block.layers[j] - resnet_block = port_resent_block( - resnet_block=resnet_block, + for j, block_layer in enumerate(down_layer.block.layers): + block_layer = port_resent_block( + resnet_block=block_layer, tensors=tensors, prefix=f"{prefix}.down.{i}.block.{j}", ) - # attn - attn = down.layers[i].attn - for j in range(len(attn.layers)): - attn_block = attn.layers[j] - attn_block = port_attn_block( - attn_block=attn_block, + for j, attn_layer in enumerate(down_layer.attn.layers): + attn_layer = port_attn_block( + attn_block=attn_layer, tensors=tensors, prefix=f"{prefix}.attn.{i}.block.{j}", ) # downsample if i != encoder.num_resolutions - 1: - downsample = down.layers[i].downsample + downsample = down_layer.downsample downsample = port_downsample( downsample=downsample, tensors=tensors, @@ -118,115 +148,336 @@ def port_encoder(encoder, tensors, prefix): ) # mid - mid = encoder.mid - mid_block_1 = mid.block_1 - mid_block_1 = port_resent_block( - resnet_block=mid_block_1, tensors=tensors, prefix=f"{prefix}.mid.block_1" + encoder.mid.block_1 = port_resent_block( + resnet_block=encoder.mid.block_1, + tensors=tensors, + prefix=f"{prefix}.mid.block_1", ) - - mid_attn_1 = mid.attn_1 - mid_attn_1 = port_attn_block( - attn_block=mid_attn_1, tensors=tensors, prefix=f"{prefix}.mid.attn_1" + encoder.mid.attn_1 = port_attn_block( + attn_block=encoder.mid.attn_1, + tensors=tensors, + prefix=f"{prefix}.mid.attn_1", ) - - mid_block_2 = mid.block_2 - mid_block_2 = port_resent_block( - resnet_block=mid_block_2, tensors=tensors, prefix=f"{prefix}.mid.block_2" + encoder.mid.block_2 = port_resent_block( + resnet_block=encoder.mid.block_2, + tensors=tensors, + prefix=f"{prefix}.mid.block_2", ) # norm out - norm_out = encoder.norm_out - norm_out.scale.value = tensors[f"{prefix}.norm_out.weight"] - norm_out.bias.value = tensors[f"{prefix}.norm_out.bias"] + encoder.norm_out = port_group_norm( + group_norm=encoder.norm_out, + tensors=tensors, + prefix=f"{prefix}.norm_out", + ) # conv out - conv_out = encoder.conv_out - conv_out.kernel.value = rearrange( - tensors[f"{prefix}.conv_out.weight"], "i o k1 k2 -> k1 k2 o i" + encoder.conv_out = port_conv( + conv=encoder.conv_out, + tensors=tensors, + prefix=f"{prefix}.conv_out", ) - conv_out.bias.value = tensors[f"{prefix}.conv_out.bias"] return encoder def port_decoder(decoder, tensors, prefix): - # port downsampling - conv_in = decoder.conv_in - - conv_in.kernel.value = rearrange( - tensors[f"{prefix}.conv_in.weight"], "i o k1 k2 -> k1 k2 o i" + # conv in + decoder.conv_in = port_conv( + conv=decoder.conv_in, + tensors=tensors, + prefix=f"{prefix}.conv_in", ) - conv_in.bias.value = tensors[f"{prefix}.conv_in.bias"] # mid - mid = decoder.mid - - mid_block_1 = mid.block_1 - mid_block_1 = port_resent_block( - resnet_block=mid_block_1, tensors=tensors, prefix=f"{prefix}.mid.block_1" + decoder.mid.block_1 = port_resent_block( + resnet_block=decoder.mid.block_1, + tensors=tensors, + prefix=f"{prefix}.mid.block_1", ) - - mid_attn_1 = mid.attn_1 - mid_attn_1 = port_attn_block( - attn_block=mid_attn_1, tensors=tensors, prefix=f"{prefix}.mid.attn_1" + decoder.mid.attn_1 = port_attn_block( + attn_block=decoder.mid.attn_1, + tensors=tensors, + prefix=f"{prefix}.mid.attn_1", ) - - mid_block_2 = mid.block_2 - mid_block_2 = port_resent_block( - resnet_block=mid_block_2, tensors=tensors, prefix=f"{prefix}.mid.block_2" + decoder.mid.block_2 = port_resent_block( + resnet_block=decoder.mid.block_2, + tensors=tensors, + prefix=f"{prefix}.mid.block_2", ) - # up - up = decoder.up - - for i in range(len(up.layers)): + for i, up_layer in enumerate(decoder.up.layers): # block - block = up.layers[i].block - for j in range(len(block.layers)): - resnet_block = block.layers[j] - resnet_block = port_resent_block( - resnet_block=resnet_block, + for j, block_layer in enumerate(up_layer.block.layers): + block_layer = port_resent_block( + resnet_block=block_layer, tensors=tensors, prefix=f"{prefix}.up.{i}.block.{j}", ) # attn - attn = up.layers[i].attn - for j in range(len(attn.layers)): - attn_block = attn.layers[j] - attn_block = port_attn_block( - attn_block=attn_block, + for j, attn_layer in enumerate(up_layer.attn.layers): + attn_layer = port_attn_block( + attn_block=attn_layer, tensors=tensors, prefix=f"{prefix}.up.{i}.attn.{j}", ) # upsample if i != 0: - upsample = up.layers[i].upsample - upsample = port_upsample( - upsample=upsample, tensors=tensors, prefix=f"{prefix}.up.{i}.upsample" + up_layer.upsample = port_upsample( + upsample=up_layer.upsample, + tensors=tensors, + prefix=f"{prefix}.up.{i}.upsample", ) # norm out - norm_out = decoder.norm_out - norm_out.scale.value = tensors[f"{prefix}.norm_out.weight"] - norm_out.bias.value = tensors[f"{prefix}.norm_out.bias"] + decoder.norm_out = port_group_norm( + group_norm=decoder.norm_out, + tensors=tensors, + prefix=f"{prefix}.norm_out", + ) # conv out - conv_out = decoder.conv_out - conv_out.kernel.value = rearrange( - tensors[f"{prefix}.conv_out.weight"], "i o k1 k2 -> k1 k2 o i" + decoder.conv_out = port_conv( + conv=decoder.conv_out, + tensors=tensors, + prefix=f"{prefix}.conv_out", ) - conv_out.bias.value = tensors[f"{prefix}.conv_out.bias"] return decoder def port_autoencoder(autoencoder, tensors): autoencoder.encoder = port_encoder( - encoder=autoencoder.encoder, tensors=tensors, prefix="encoder" + encoder=autoencoder.encoder, + tensors=tensors, + prefix="encoder", ) autoencoder.decoder = port_decoder( - decoder=autoencoder.decoder, tensors=tensors, prefix="decoder" + decoder=autoencoder.decoder, + tensors=tensors, + prefix="decoder", ) return autoencoder + + +############################################################################################## +# FLUX MODEL PORTING +############################################################################################## + + +def port_linear(linear, tensors, prefix): + linear.kernel.value = rearrange(tensors[f"{prefix}.weight"]) + linear.bias.value = rearrange(tensors[f"{prefix}.bias"]) + return linear + + +def port_modulation(modulation, tensors, prefix): + modulation.lin = port_linear( + linear=modulation.lin, tensors=tensors, prefix=f"{prefix}.lin" + ) + return modulation + + +def port_rms_norm(rms_norm, tensors, prefix): + rms_norm.scale.value = tensors[f"{prefix}.scale"] + return rms_norm + + +def port_qk_norm(qk_norm, tensors, prefix): + qk_norm.query_norm = port_rms_norm( + rms_norm=qk_norm.query_norm, + tensors=tensors, + prefix=f"{prefix}.query_norm", + ) + qk_norm.key_norm = port_rms_norm( + rms_norm=qk_norm.key_norm, + tensors=tensors, + prefix=f"{prefix}.key_norm", + ) + return qk_norm + + +def port_self_attention(self_attention, tensors, prefix): + self_attention.qkv = port_linear( + linear=self_attention.qkv, + tensors=tensors, + prefix=f"{prefix}.qkv", + ) + + self_attention.norm = port_qk_norm( + qk_norm=self_attention.norm, + tensors=tensors, + prefix=f"{prefix}.norm", + ) + + self_attention.proj = port_linear( + linear=self_attention.proj, + tensors=tensors, + prefix=f"{prefix}.proj", + ) + + return self_attention + + +def port_double_stream_block(double_stream_block, tensors, prefix): + double_stream_block.img_mod = port_modulation( + modulation=double_stream_block.img_mod, + tensors=tensors, + prefix=f"{prefix}.img_mod", + ) + + # double_stream_block.img_norm1 has no params + + double_stream_block.img_attn = port_self_attention( + self_attention=double_stream_block.img_attn, + tensors=tensors, + prefix="{prefix}.img_attn", + ) + + # double_stream_block.img_norm2 has no params + + double_stream_block.img_mlp.layers[0] = port_linear( + linear=double_stream_block.img_mlp.layers[0], + tensors=tensors, + prefix=f"{prefix}.img_mlp.0", + ) + double_stream_block.img_mlp.layers[2] = port_linear( + linear=double_stream_block.img_mlp.layers[2], + tensors=tensors, + prefix=f"{prefix}.img_mlp.2", + ) + + double_stream_block.txt_mod = port_modulation( + modulation=double_stream_block.txt_mod, + tensors=tensors, + prefix=f"{prefix}.txt_mod", + ) + + # double_stream_block.txt_norm1 has no params + + double_stream_block.txt_attn = port_self_attention( + self_attention=double_stream_block.txt_attn, + tensors=tensors, + prefix="{prefix}.txt_attn", + ) + + # double_stream_block.txt_norm2 has no params + + double_stream_block.txt_mlp.layers[0] = port_linear( + linear=double_stream_block.txt_mlp.layers[0], + tensors=tensors, + prefix=f"{prefix}.txt_mlp.0", + ) + double_stream_block.txt_mlp.layers[2] = port_linear( + linear=double_stream_block.txt_mlp.layers[2], + tensors=tensors, + prefix=f"{prefix}.txt_mlp.2", + ) + + return double_stream_block + + +def port_single_stream_block(single_stream_block, tensors, prefix): + single_stream_block.linear1 = port_linear( + linear=single_stream_block.linear1, tensors=tensors, prefix="{prefix}.linear1" + ) + single_stream_block.linear2 = port_linear( + linear=single_stream_block.linear2, tensors=tensors, prefix="{prefix}.linear2" + ) + + single_stream_block.norm = port_qk_norm( + qk_norm=single_stream_block.norm, tensors=tensors, prefix="{prefix}.norm" + ) + + # single_stream_block.pre_norm has no params + + single_stream_block.modulation = port_modulation( + modulation=single_stream_block.modulation, + tensors=tensors, + prefix="{prefix}.modulation", + ) + + return single_stream_block + + +def port_mlp_embedder(mlp_embedder, tensors, prefix): + mlp_embedder.in_layer = port_linear( + linear=mlp_embedder.in_layer, tensors=tensors, prefix=f"{prefix}.in_layer" + ) + + mlp_embedder.out_layer = port_linear( + linear=mlp_embedder.out_layer, tensors=tensors, prefix=f"{prefix}.out_layer" + ) + return mlp_embedder + + +def port_last_layer(last_layer, tensors, prefix): + # last_layer.norm_final has no params + last_layer.linear = port_linear( + linear=last_layer.linear, + tensors=tensors, + prefix=f"{prefix}.linear", + ) + + last_layer.adaLN_modulation.layers[1] = port_linear( + linear=last_layer.adaLN_modulation.layers[1], + tensors=tensors, + prefix=f"{prefix}.adaLN_modulation.1", + ) + + return last_layer + + +def port_flux(flux, tensors): + flux.img_in = port_linear( + linear=flux.img_in, + tensors=tensors, + prefix="img_in", + ) + + flux.time_in = port_mlp_embedder( + mlp_embedder=flux.time_in, + tensors=tensors, + prefix="time_in", + ) + + flux.vector_in = port_mlp_embedder( + mlp_embedder=flux.vector_in, + tensors=tensors, + prefix="vector_in", + ) + + if flux.params.guidance_embed: + flux.guidance_in = port_mlp_embedder( + mlp_embedder=flux.guidance_in, + tensors=tensors, + prefix="guidance_in", + ) + + flux.txt_in = port_linear( + linear=flux.txt_in, + tensors=tensors, + prefix="txt_in", + ) + + for i, layer in enumerate(flux.double_blocks.layers): + layer = port_double_stream_block( + double_stream_block=layer, + tensors=tensors, + prefix=f"double_blocks.{i}", + ) + + for i, layer in enumerate(flux.single_blocks.layers): + layer = port_single_stream_block( + single_stream_block=layer, + tensors=tensors, + prefix=f"single_blocks.{i}", + ) + + flux.last_layer = port_last_layer( + last_layer=flux.last_layer, + tensors=tensors, + prefix="last_layer", + )