From 5bb2fce9985b8fa4bfa4f289ac0c76b8a3471121 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 7 Aug 2023 07:53:05 +0200 Subject: [PATCH] Implement group-norm. (#334) * Implement group-norm. * Add some testing for group-norm. --- .../examples/stable-diffusion/clip.rs | 3 +- .../examples/stable-diffusion/utils.rs | 5 - candle-nn/src/group_norm.rs | 47 ++++++-- candle-nn/src/ops.rs | 6 + candle-nn/tests/group_norm.rs | 103 ++++++++++++++++++ 5 files changed, 150 insertions(+), 14 deletions(-) create mode 100644 candle-nn/tests/group_norm.rs diff --git a/candle-examples/examples/stable-diffusion/clip.rs b/candle-examples/examples/stable-diffusion/clip.rs index be798ad011..227660b131 100644 --- a/candle-examples/examples/stable-diffusion/clip.rs +++ b/candle-examples/examples/stable-diffusion/clip.rs @@ -6,6 +6,7 @@ //! //! https://github.com/openai/CLIP use candle::{Device, Result, Tensor, D}; +use candle_nn as nn; #[derive(Debug, Clone, Copy)] pub enum Activation { @@ -16,7 +17,7 @@ pub enum Activation { impl Activation { fn forward(&self, xs: &Tensor) -> Result { match self { - Activation::QuickGelu => xs * crate::utils::sigmoid(&(xs * 1.702f64)?)?, + Activation::QuickGelu => xs * nn::ops::sigmoid(&(xs * 1.702f64)?)?, Activation::Gelu => xs.gelu(), } } diff --git a/candle-examples/examples/stable-diffusion/utils.rs b/candle-examples/examples/stable-diffusion/utils.rs index 90fe3f9a7e..4294d82393 100644 --- a/candle-examples/examples/stable-diffusion/utils.rs +++ b/candle-examples/examples/stable-diffusion/utils.rs @@ -1,10 +1,5 @@ use candle::{Device, Result, Tensor}; -pub fn sigmoid(xs: &Tensor) -> Result { - // TODO: Add sigmoid as binary ops. - (xs.neg()?.exp()? - 1.0)?.recip() -} - pub fn avg_pool2d(_: &Tensor) -> Result { todo!() } diff --git a/candle-nn/src/group_norm.rs b/candle-nn/src/group_norm.rs index 4b9bed7307..e277ae8561 100644 --- a/candle-nn/src/group_norm.rs +++ b/candle-nn/src/group_norm.rs @@ -1,10 +1,9 @@ //! Group Normalization. //! //! This layer applies Group Normalization over a mini-batch of inputs. -use candle::{Result, Tensor}; +use candle::{DType, Result, Tensor}; // This group norm version handles both weight and bias so removes the mean. -#[allow(dead_code)] #[derive(Debug)] pub struct GroupNorm { weight: Tensor, @@ -21,18 +20,50 @@ impl GroupNorm { num_channels: usize, num_groups: usize, eps: f64, - ) -> Self { - Self { + ) -> Result { + if num_channels % num_groups != 0 { + candle::bail!( + "GroupNorm: num_groups ({num_groups}) must divide num_channels ({num_channels})" + ) + } + Ok(Self { weight, bias, eps, num_channels, num_groups, - } + }) } - pub fn forward(&self, _: &Tensor) -> Result { - todo!() + pub fn forward(&self, x: &Tensor) -> Result { + let x_shape = x.dims(); + if x_shape.len() <= 2 { + candle::bail!("input rank for GroupNorm should be at least 3"); + } + let (b_sz, n_channels) = (x_shape[0], x_shape[1]); + let hidden_size = x_shape[2..].iter().product::() * n_channels / self.num_groups; + if n_channels != self.num_channels { + candle::bail!( + "unexpected num-channels in GroupNorm ({n_channels} <> {}", + self.num_channels + ) + } + let x_dtype = x.dtype(); + let internal_dtype = match x_dtype { + DType::F16 | DType::BF16 => DType::F32, + d => d, + }; + let x = x.reshape((b_sz, self.num_groups, hidden_size))?; + let x = x.to_dtype(internal_dtype)?; + let mean_x = (x.sum_keepdim(2)? / hidden_size as f64)?; + let x = x.broadcast_sub(&mean_x)?; + let norm_x = (x.sqr()?.sum_keepdim(2)? / hidden_size as f64)?; + let x_normed = x.broadcast_div(&(norm_x + self.eps)?.sqrt()?)?; + x_normed + .to_dtype(x_dtype)? + .broadcast_mul(&self.weight)? + .broadcast_add(&self.bias)? + .reshape(x_shape) } } @@ -44,5 +75,5 @@ pub fn group_norm( ) -> Result { let weight = vb.get_or_init(num_channels, "weight", crate::Init::Const(1.))?; let bias = vb.get_or_init(num_channels, "bias", crate::Init::Const(0.))?; - Ok(GroupNorm::new(weight, bias, num_channels, num_groups, eps)) + GroupNorm::new(weight, bias, num_channels, num_groups, eps) } diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 29cc697312..397674f3de 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -34,5 +34,11 @@ pub fn log_softmax(xs: &Tensor, d: D) -> Result { } pub fn silu(xs: &Tensor) -> Result { + // TODO: Should we have a specialized op for this? xs / (xs.neg()?.exp()? + 1.0)? } + +pub fn sigmoid(xs: &Tensor) -> Result { + // TODO: Should we have a specialized op for this? + (xs.neg()?.exp()? + 1.0)?.recip() +} diff --git a/candle-nn/tests/group_norm.rs b/candle-nn/tests/group_norm.rs new file mode 100644 index 0000000000..d48b69f655 --- /dev/null +++ b/candle-nn/tests/group_norm.rs @@ -0,0 +1,103 @@ +/* Equivalent PyTorch code. +import torch +from torch.nn.functional import group_norm +t = torch.tensor( + [[[-0.3034, 0.2726, -0.9659], + [-1.1845, -1.3236, 0.0172], + [ 1.9507, 1.2554, -0.8625], + [ 1.0682, 0.3604, 0.3985], + [-0.4957, -0.4461, -0.9721], + [ 1.5157, -0.1546, -0.5596]], + + [[-1.6698, -0.4040, -0.7927], + [ 0.3736, -0.0975, -0.1351], + [-0.9461, 0.5461, -0.6334], + [-1.0919, -0.1158, 0.1213], + [-0.9535, 0.1281, 0.4372], + [-0.2845, 0.3488, 0.5641]]]) +print(group_norm(t, num_groups=2)) +print(group_norm(t, num_groups=3)) +*/ +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use anyhow::Result; +use candle::{Device, Tensor}; +use candle_nn::GroupNorm; +mod test_utils; +use test_utils::to_vec3_round; + +#[test] +fn group_norm() -> Result<()> { + let device = &Device::Cpu; + let w = Tensor::new(&[1f32], device)?; + let b = Tensor::new(&[0f32], device)?; + let gn2 = GroupNorm::new(w.clone(), b.clone(), 6, 2, 1e-5)?; + let gn3 = GroupNorm::new(w, b, 6, 3, 1e-5)?; + + let input = Tensor::new( + &[ + [ + [-0.3034f32, 0.2726, -0.9659], + [-1.1845, -1.3236, 0.0172], + [1.9507, 1.2554, -0.8625], + [1.0682, 0.3604, 0.3985], + [-0.4957, -0.4461, -0.9721], + [1.5157, -0.1546, -0.5596], + ], + [ + [-1.6698, -0.4040, -0.7927], + [0.3736, -0.0975, -0.1351], + [-0.9461, 0.5461, -0.6334], + [-1.0919, -0.1158, 0.1213], + [-0.9535, 0.1281, 0.4372], + [-0.2845, 0.3488, 0.5641], + ], + ], + device, + )?; + assert_eq!( + to_vec3_round(gn2.forward(&input)?, 4)?, + &[ + [ + [-0.1653, 0.3748, -0.7866], + [-0.9916, -1.1220, 0.1353], + [1.9485, 1.2965, -0.6896], + [1.2769, 0.3628, 0.4120], + [-0.7427, -0.6786, -1.3578], + [1.8547, -0.3022, -0.8252] + ], + [ + [-1.9342, 0.0211, -0.5793], + [1.2223, 0.4945, 0.4365], + [-0.8163, 1.4887, -0.3333], + [-1.7960, -0.0392, 0.3875], + [-1.5469, 0.3998, 0.9561], + [-0.3428, 0.7970, 1.1845] + ] + ] + ); + assert_eq!( + to_vec3_round(gn3.forward(&input)?, 4)?, + &[ + [ + [0.4560, 1.4014, -0.6313], + [-0.9901, -1.2184, 0.9822], + [1.4254, 0.6360, -1.7682], + [0.4235, -0.3800, -0.3367], + [-0.3890, -0.3268, -0.9862], + [2.1325, 0.0386, -0.4691] + ], + [ + [-1.8797, 0.0777, -0.5234], + [1.2802, 0.5517, 0.4935], + [-1.0102, 1.5327, -0.4773], + [-1.2587, 0.4047, 0.8088], + [-1.9074, 0.1691, 0.7625], + [-0.6230, 0.5928, 1.0061] + ] + ] + ); + + Ok(()) +}