Skip to content

Commit

Permalink
FastViT fixes. (huggingface#2452)
Browse files Browse the repository at this point in the history
* correct optional SE layer dimensions.
 * head_dim instead of num_heads is 32.
 * update test example output.
  • Loading branch information
janimo authored Aug 28, 2024
1 parent aafa24e commit 29e25c4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ If you have an addition to this list, please submit a pull request.
- Parler-TTS, text-to-speech model.
- Computer Vision Models.
- DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera.
ConvNeXTv2, MobileOne, EfficientVit (MSRA), MobileNetv4, Hiera, FastViT.
- yolo-v3, yolo-v8.
- Segment-Anything Model (SAM).
- SegFormer.
Expand Down
10 changes: 5 additions & 5 deletions candle-examples/examples/fastvit/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ $ cargo run --example fastvit --release -- --image candle-examples/examples/yolo
loaded image Tensor[dims 3, 256, 256; f32]
model built
mountain bike, all-terrain bike, off-roader: 43.45%
bicycle-built-for-two, tandem bicycle, tandem: 14.16%
unicycle, monocycle : 4.12%
crash helmet : 2.26%
alp : 1.40%
mountain bike, all-terrain bike, off-roader: 52.67%
bicycle-built-for-two, tandem bicycle, tandem: 7.93%
unicycle, monocycle : 3.46%
maillot : 1.32%
crash helmet : 1.28%
```
6 changes: 3 additions & 3 deletions candle-transformers/src/models/fastvit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ fn positional_encoding(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
fn attention(dim: usize, vb: VarBuilder) -> Result<Func<'static>> {
let qkv = linear_no_bias(dim, dim * 3, vb.pp("qkv"))?;
let proj = linear(dim, dim, vb.pp("proj"))?;
let num_heads = 32;
let head_dim = dim / num_heads;
let head_dim = 32;
let num_heads = dim / head_dim;
let scale = (head_dim as f64).powf(-0.5);

Ok(Func::new(move |xs| {
Expand Down Expand Up @@ -434,7 +434,7 @@ fn fastvit_patch_embed(
) -> Result<Func<'static>> {
let lk = conv_norm(in_channels, out_channels, 7, 2, vb.pp("proj.0.large_conv"))?;
let sk = conv_norm(in_channels, out_channels, 3, 2, vb.pp("proj.0.small_conv"))?;
let se = squeeze_and_excitation(out_channels, out_channels / 16, vb.pp("proj.0.se"));
let se = squeeze_and_excitation(out_channels, out_channels / 4, vb.pp("proj.0.se"));
let mb = mobileone_block(out_channels, out_channels, 1, 1, 0, true, vb.pp("proj.1"))?;

Ok(Func::new(move |xs| {
Expand Down

0 comments on commit 29e25c4

Please sign in to comment.