From 31a1075f4b4799a4922fa0f617ee982baa5baa81 Mon Sep 17 00:00:00 2001 From: shua Date: Mon, 19 Aug 2024 09:06:17 +0200 Subject: [PATCH] onnx: implement LSTM op (#2268) use candle-nn LSTM --- candle-nn/src/rnn.rs | 4 + candle-onnx/src/eval.rs | 239 ++++++++++++++++++++ candle-onnx/tests/ops.rs | 466 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 709 insertions(+) diff --git a/candle-nn/src/rnn.rs b/candle-nn/src/rnn.rs index 30ad6ff57a..5934af852d 100644 --- a/candle-nn/src/rnn.rs +++ b/candle-nn/src/rnn.rs @@ -55,6 +55,10 @@ pub struct LSTMState { } impl LSTMState { + pub fn new(h: Tensor, c: Tensor) -> Self { + LSTMState { h, c } + } + /// The hidden state vector, which is also the output of the LSTM. pub fn h(&self) -> &Tensor { &self.h diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index d02d7ff09c..036f583873 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -66,6 +66,18 @@ impl Attr for GraphProto { } } +impl AttrOwned for Vec { + const TYPE: AttributeType = AttributeType::Strings; + fn get(attr: &onnx::AttributeProto) -> Result { + let mut ret = vec![]; + for bytes in attr.strings.iter() { + let s = String::from_utf8(bytes.clone()).map_err(candle::Error::wrap)?; + ret.push(s); + } + Ok(ret) + } +} + impl AttrOwned for Tensor { const TYPE: AttributeType = AttributeType::Tensor; fn get(attr: &onnx::AttributeProto) -> Result { @@ -1310,6 +1322,233 @@ fn simple_eval_( .broadcast_add(&c.broadcast_mul(&beta)?)?; values.insert(node.output[0].clone(), output); } + "LSTM" => { + let direction = get_attr_opt(node, "direction")?.unwrap_or("forward"); + if direction != "forward" { + bail!("LSTM currently only supports direction == \"forward\""); + } + let num_directions = if direction == "bidirectional" { 2 } else { 1 }; + let hidden_size: i64 = get_attr(node, "hidden_size").copied()?; + let input_forget = get_attr_opt(node, "input_forget")?.copied().unwrap_or(0); + if input_forget != 0 { + bail!("LSTM currently only supports input_forget == 0"); + } + let activations_default = vec![ + "Sigmoid".to_string(), + "Tanh".to_string(), + "Tanh".to_string(), + ]; + let activations = get_attr_opt_owned::>(node, "activations")? + .unwrap_or(activations_default.clone()); + if activations != activations_default { + bail!("LSTM currently only supports default activations ({activations_default:?})"); + } + // activation_alpha and activation_beta don't apply to (Sigmoid, Tanh, Tanh) so ignoring them is okay + if get_attr_opt::(node, "clip")?.is_some() { + bail!("LSTM does not currently support clip attribute"); + } + + // The shape format of inputs X, initial_h and outputs Y, Y_h. + // If 0, the following shapes are expected: + // X.shape = [seq_length, batch_size, input_size], + // Y.shape = [seq_length, num_directions, batch_size, hidden_size], + // initial_h.shape = Y_h.shape = [num_directions, batch_size, hidden_size]. + // If 1, the following shapes are expected: + // X.shape = [batch_size, seq_length, input_size], + // Y.shape = [batch_size, seq_length, num_directions, hidden_size], + // initial_h.shape = Y_h.shape = [batch_size, num_directions, hidden_size]. + let layout = get_attr_opt(node, "layout")?.copied().unwrap_or(0); + if layout != 0 { + bail!("LSTM currently only supports layout == 0"); + } + + // The input sequences packed (and potentially padded) into one 3-D tensor + // with the shape of `[seq_length, batch_size, input_size]`. + let x = get(&node.input[0])?; + // XXX: depends on layout + let (seq_length, batch_size, input_size) = x.dims3()?; + // The weight tensor for the gates. + // Concatenation of `W[iofc]` and `WB[iofc]` (if bidirectional) along dimension 0. + // The tensor has shape `[num_directions, 4*hidden_size, input_size]`. + let w = get(&node.input[1])?; + // The recurrence weight tensor. + // Concatenation of `R[iofc]` and `RB[iofc]` (if bidirectional) along dimension 0. + // This tensor has shape `[num_directions, 4*hidden_size, hidden_size]`. + let r = get(&node.input[2])?; + + let get_opt = |i: usize| { + node.input + .get(i) + .filter(|s: &&String| !s.is_empty()) + .map(|s| get(s)) + }; + + // The bias tensor for input gate. + // Concatenation of `[Wb[iofc], Rb[iofc]]`, and `[WBb[iofc], RBb[iofc]]` (if bidirectional) along dimension 0. + // This tensor has shape `[num_directions, 8*hidden_size]`. + // Optional: If not specified - assumed to be 0. + let b_default: Tensor; + let b = match get_opt(3) { + Some(n) => n?, + None => { + b_default = Tensor::zeros( + (num_directions, 8 * hidden_size as usize), + DType::F32, + x.device(), + )?; + &b_default + } + }; + + // Optional tensor specifying lengths of the sequences in a batch. + // If not specified - assumed all sequences in the batch to have length `seq_length`. + // It has shape `[batch_size]`. + let seq_lens_default: Tensor; + let seq_lens = match get_opt(4) { + Some(n) => n?, + None => { + seq_lens_default = + Tensor::full(seq_length as i64, (batch_size,), x.device())?; + &seq_lens_default + } + }; + let seq_lens_is_default = + (seq_lens.to_vec1::()?.iter()).all(|e| *e as usize == seq_length); + if !seq_lens_is_default { + bail!("LSTM currently only supports default value of seq_lens"); + } + + // Optional initial value of the hidden. If not specified - assumed to be 0. + // It has shape `[num_directions, batch_size, hidden_size]`. + let initial_h_default: Tensor; + let initial_h = match get_opt(5) { + Some(n) => n?, + _ => { + initial_h_default = Tensor::zeros( + (num_directions, batch_size, hidden_size as usize), + DType::F32, + x.device(), + )?; + &initial_h_default + } + }; + + // Optional initial value of the cell. + // If not specified - assumed to be 0. + // It has shape `[num_directions, batch_size, hidden_size]`. + let initial_c_default: Tensor; + let initial_c = match node.input.get(6) { + Some(n) if !n.is_empty() => get(n)?, + _ => { + initial_c_default = Tensor::zeros( + (num_directions, batch_size, hidden_size as usize), + DType::F32, + x.device(), + )?; + &initial_c_default + } + }; + + // The weight tensor for peepholes. + // Concatenation of `P[iof]` and `PB[iof]` (if bidirectional) along dimension 0. + // It has shape `[num_directions, 3*hidde_size]`. Optional: If not specified - assumed to be 0. + let p_default = Tensor::zeros( + (num_directions, 3 * hidden_size as usize), + DType::F32, + x.device(), + )?; + let p = get_opt(7).unwrap_or(Ok(&p_default))?; + let p_is_zeros = (p.to_vec2::()?.iter()).all(|v| v.iter().all(|e| *e == 0.0)); + if !p_is_zeros { + bail!( + "LSTM currently only supports default value of p (a Tensor of all zeroes)" + ); + } + + // these all have [num_directions, ...] shapes + let w = w.get(0)?; // w[iofc] has shape [4*hidden_size, input_size] + let r = r.get(0)?; // r[iofc] has shape [4*hidden_size, hidden_size] + let b = b.get(0)?; // concat of [wb[iofc],rb[iofc]] has shape [8*hidden_size] + let idx_wb = Tensor::arange(0 * hidden_size, 4 * hidden_size, x.device())?; + let idx_rb = Tensor::arange(4 * hidden_size, 8 * hidden_size, x.device())?; + let wb = b.index_select(&idx_wb, 0)?; + let rb = b.index_select(&idx_rb, 0)?; + let c = initial_c.get(0)?; + let h = initial_h.get(0)?; + + // w, r, wb, rb are all iofc but lstm expects ifco + // so we need to move some stuff around + let idx_i = Tensor::arange(0 * hidden_size, 1 * hidden_size, x.device())?; + let idx_o = Tensor::arange(1 * hidden_size, 2 * hidden_size, x.device())?; + let idx_f = Tensor::arange(2 * hidden_size, 3 * hidden_size, x.device())?; + let idx_c = Tensor::arange(3 * hidden_size, 4 * hidden_size, x.device())?; + let idx_ifco = Tensor::cat(&[&idx_i, &idx_f, &idx_c, &idx_o], 0)?; + let w = w.index_select(&idx_ifco, 0)?; + let r = r.index_select(&idx_ifco, 0)?; + let wb = wb.index_select(&idx_ifco, 0)?; + let rb = rb.index_select(&idx_ifco, 0)?; + let vmap = candle_nn::VarMap::new(); + vmap.data().lock().unwrap().extend([ + ("weight_ih_l0".to_string(), candle::Var::from_tensor(&w)?), + ("weight_hh_l0".to_string(), candle::Var::from_tensor(&r)?), + ("bias_ih_l0".to_string(), candle::Var::from_tensor(&wb)?), + ("bias_hh_l0".to_string(), candle::Var::from_tensor(&rb)?), + ]); + use candle_nn::rnn::RNN as _; + let lstm = candle_nn::rnn::lstm( + input_size, + hidden_size as usize, + candle_nn::rnn::LSTMConfig::default(), + candle_nn::VarBuilder::from_varmap(&vmap, w.dtype(), w.device()), + )?; + + let mut lstm_state = candle_nn::rnn::LSTMState::new(h, c); + let mut h_acc = if node.output.get(0).map(String::as_str).unwrap_or("") != "" { + Some(vec![]) + } else { + None + }; + for t in 0..seq_length { + let x = x.get(t)?; + lstm_state = lstm.step(&x, &lstm_state)?; + if let Some(h_acc) = &mut h_acc { + h_acc.push(lstm_state.clone()); + } + } + + assert_eq!(num_directions, 1, "if support for bidirectional is ever added, outputs will have to be concatenated, not simply reshaped"); + if let Some(name) = node.output.get(0) { + let h_acc = h_acc.as_ref().unwrap(); + let h_acc = lstm.states_to_tensor(h_acc)?; + let h_acc = h_acc.reshape(( + seq_length, + num_directions, + batch_size, + hidden_size as usize, + ))?; + values.insert(name.clone(), h_acc); + } + if let Some(name) = node.output.get(1) { + values.insert( + name.clone(), + lstm_state.h().reshape(( + num_directions, + batch_size, + hidden_size as usize, + ))?, + ); + } + if let Some(name) = node.output.get(2) { + values.insert( + name.clone(), + lstm_state.c().reshape(( + num_directions, + batch_size, + hidden_size as usize, + ))?, + ); + } + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 861b3741fb..51ee037e23 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -6,11 +6,13 @@ extern crate accelerate_src; use candle::test_utils::to_vec2_round; use candle::{DType, Device, NdArray, Result, Tensor}; +use candle_onnx::eval::Value; use candle_onnx::onnx::attribute_proto::AttributeType; use candle_onnx::onnx::tensor_proto::DataType; use candle_onnx::onnx::tensor_shape_proto::{dimension, Dimension}; use candle_onnx::onnx::{type_proto, TensorProto, TensorShapeProto, TypeProto}; use candle_onnx::onnx::{AttributeProto, GraphProto, ModelProto, NodeProto, ValueInfoProto}; +use candle_onnx::simple_eval; use std::collections::HashMap; const INPUT_X: &str = "x"; @@ -3514,3 +3516,467 @@ fn test_slice() -> Result<()> { Ok(()) } + +#[test] +fn test_lstm() -> Result<()> { + // values generated from pytorch, so at least it's close enough to what pytorch does + /* + #!/usr/bin/env python3 + + # torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None) + + import torch + + rand_gen = torch.Generator() + rand_gen.manual_seed(1) + input_size = 3 + hidden_size = 5 + batch_size = 1 + sequence_length = 4 + number_directions = 1 + rnn = torch.nn.LSTM(input_size,hidden_size) + weight_ih_l0 = torch.randn(rnn.weight_ih_l0.shape, generator=rand_gen) + weight_hh_l0 = torch.randn(rnn.weight_hh_l0.shape, generator=rand_gen) + bias_ih_l0 = torch.randn(rnn.bias_ih_l0.shape, generator=rand_gen) + bias_hh_l0 = torch.randn(rnn.bias_hh_l0.shape, generator=rand_gen) + rnn.weight_ih_l0 = torch.nn.Parameter(weight_ih_l0) + rnn.weight_hh_l0 = torch.nn.Parameter(weight_hh_l0) + rnn.bias_ih_l0 = torch.nn.Parameter(bias_ih_l0) + rnn.bias_hh_l0 = torch.nn.Parameter(bias_hh_l0) + input = torch.randn(sequence_length, batch_size, input_size, generator=rand_gen) + h0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen) + c0 = torch.randn(number_directions, batch_size, hidden_size, generator=rand_gen) + output, (hn, cn) = rnn(input, (h0, c0)) + + def fmt_tensor(t): + return "Tensor::from_vec::<_, f32>(vec!"+ str(t.flatten().tolist()) + ", (" + "".join([str(n)+"," for n in t.shape])+"), &Device::Cpu)?" + + print("let input_size = ", input_size, ";") + print("let hidden_size = ", hidden_size, ";") + print("let batch_size = ", batch_size, ";") + print("let sequence_length = ", sequence_length, ";") + print("let number_directions = ", number_directions, ";") + print("let weight_ih_l0 = ", fmt_tensor(rnn.weight_ih_l0), ";") + print("let weight_hh_l0 = ", fmt_tensor(rnn.weight_hh_l0), ";") + print("let bias_ih_l0 = ", fmt_tensor(rnn.bias_ih_l0), ";") + print("let bias_hh_l0 = ", fmt_tensor(rnn.bias_hh_l0), ";") + print("let input = ", fmt_tensor(input), ";") + print("let h0 = ", fmt_tensor(h0), ";") + print("let c0 = ", fmt_tensor(c0), ";") + print("let output = ", fmt_tensor(output), ";") + print("let hn = ", fmt_tensor(hn), ";") + print("let cn = ", fmt_tensor(cn), ";") + */ + let input_size = 3; + let hidden_size = 5; + let batch_size = 1; + let sequence_length = 4; + let number_directions = 1; + let weight_ih_l0 = Tensor::from_vec::<_, f32>( + vec![ + -1.5255959033966064, + -0.7502318024635315, + -0.6539809107780457, + -1.6094847917556763, + -0.1001671776175499, + -0.6091889142990112, + -0.9797722697257996, + -1.6090962886810303, + -0.7121446132659912, + 0.30372199416160583, + -0.777314305305481, + -0.25145524740219116, + -0.22227048873901367, + 1.6871134042739868, + 0.22842517495155334, + 0.46763551235198975, + -0.6969724297523499, + -1.1607614755630493, + 0.6995424032211304, + 0.1990816295146942, + 0.8656923770904541, + 0.2444038987159729, + -0.6629113554954529, + 0.8073082566261292, + 1.1016806364059448, + -0.1759360432624817, + -2.2455577850341797, + -1.4464579820632935, + 0.0611552819609642, + -0.6177444458007812, + -0.7980698347091675, + -0.13162320852279663, + 1.8793457746505737, + -0.07213178277015686, + 0.15777060389518738, + -0.7734549045562744, + 0.1990565061569214, + 0.04570277780294418, + 0.15295691788196564, + -0.47567880153656006, + -0.11101982742547989, + 0.2927352488040924, + -0.1578451544046402, + -0.028787139803171158, + 0.4532545804977417, + 1.1421611309051514, + 0.2486107051372528, + -1.7754007577896118, + -0.025502461940050125, + -1.023330569267273, + -0.5961851477622986, + -1.0055307149887085, + 0.42854228615760803, + 1.4760777950286865, + -1.7868678569793701, + 1.610317587852478, + -0.703956663608551, + -0.18526579439640045, + -0.9962350726127625, + -0.8312552571296692, + ], + (20, 3), + &Device::Cpu, + )?; + let weight_hh_l0 = Tensor::from_vec::<_, f32>( + vec![ + 0.4099724292755127, + 0.4084506630897522, + 0.25786539912223816, + 1.095021367073059, + -0.5064865946769714, + 0.09977540373802185, + -0.653973400592804, + 0.731693685054779, + -1.456732988357544, + 1.6089353561401367, + 0.09376997500658035, + -1.2597490549087524, + 0.25463348627090454, + -0.5019572973251343, + -1.041200041770935, + 0.7322672009468079, + 1.3075355291366577, + -1.1627987623214722, + 0.11963611096143723, + -0.1631353348493576, + 0.6614453196525574, + 1.1899205446243286, + 0.8165339231491089, + -0.9135236144065857, + -0.3538065254688263, + 0.7639270424842834, + -0.5889506936073303, + -0.7635973691940308, + 1.3352056741714478, + 0.6042736172676086, + -0.10344208031892776, + -0.15121692419052124, + 1.2465683221817017, + 0.505721390247345, + 0.9505112171173096, + 1.2966482639312744, + 0.873796284198761, + -0.5602594017982483, + 1.2857844829559326, + 0.8168238401412964, + -1.464799404144287, + -1.2629283666610718, + 1.122018814086914, + 1.5663341283798218, + 2.558138370513916, + -0.23336388170719147, + -0.013472129590809345, + 1.8606348037719727, + 1.549620509147644, + 0.34762924909591675, + 0.09300802648067474, + 0.6147403120994568, + 0.7123645544052124, + -1.7765072584152222, + 0.3538645803928375, + 1.1996132135391235, + -0.7122589349746704, + -0.620034396648407, + -0.22813494503498077, + -0.7892746329307556, + -1.6111117601394653, + -1.8716129064559937, + 0.5430836081504822, + 0.6606786251068115, + 0.270527720451355, + 0.5596919655799866, + -0.31839630007743835, + 1.5117206573486328, + -1.363267183303833, + -0.9832196235656738, + 1.5112667083740234, + 0.6418707370758057, + -0.7474458813667297, + -0.923438549041748, + 0.5733984112739563, + -0.10929951071739197, + 0.5181121230125427, + 0.10653535276651382, + 0.26924076676368713, + 1.3247679471969604, + 0.037456899881362915, + -0.6378393173217773, + -0.8147554397583008, + -0.6895065307617188, + 0.8436542749404907, + 1.1657012701034546, + 0.5269321799278259, + 1.6192532777786255, + -0.963976263999939, + 0.14152038097381592, + -0.1636609584093094, + -0.3582225739955902, + 1.7222793102264404, + -0.3035756051540375, + 0.23887419700622559, + 1.3440011739730835, + 0.1032256931066513, + 1.1003541946411133, + -0.3416801989078522, + 0.947338879108429, + ], + (20, 5), + &Device::Cpu, + )?; + let bias_ih_l0 = Tensor::from_vec::<_, f32>( + vec![ + -0.568515956401825, + 0.8375961780548096, + 1.783660650253296, + -0.1954246610403061, + 0.235193133354187, + 1.9142433404922485, + 1.8364111185073853, + 1.324532389640808, + -0.07051458209753036, + 0.34697940945625305, + -0.653679609298706, + 1.5586202144622803, + 0.2185661494731903, + -0.5743072628974915, + 1.4571250677108765, + 1.7709556818008423, + -2.0172998905181885, + 0.42350319027900696, + 0.5730220079421997, + -1.7962429523468018, + ], + (20,), + &Device::Cpu, + )?; + let bias_hh_l0 = Tensor::from_vec::<_, f32>( + vec![ + 1.2470403909683228, + 1.2738511562347412, + 0.3909492492675781, + 0.387210488319397, + 0.14440394937992096, + 0.7771684527397156, + -2.3381125926971436, + -0.829120397567749, + 1.1661391258239746, + 1.4786574840545654, + 0.26760873198509216, + 0.7561198472976685, + -0.5873361229896545, + -2.061920642852783, + 0.4304734766483307, + 0.3376566171646118, + -0.3437853455543518, + -0.6172260642051697, + 1.2529692649841309, + -0.05141742154955864, + ], + (20,), + &Device::Cpu, + )?; + let input = Tensor::from_vec::<_, f32>( + vec![ + 0.6472128033638, + -0.04116716980934143, + -0.17749308049678802, + -0.500039279460907, + 0.8672749400138855, + -0.27319222688674927, + -0.4607681334018707, + -0.0990937128663063, + 0.47284480929374695, + 1.0049484968185425, + -0.2871420383453369, + -1.1618621349334717, + ], + (4, 1, 3), + &Device::Cpu, + )?; + let h0 = Tensor::from_vec::<_, f32>( + vec![ + 0.02758178487420082, + 0.5652382373809814, + -0.011487378738820553, + 0.6706400513648987, + -0.4929250478744507, + ], + (1, 1, 5), + &Device::Cpu, + )?; + let c0 = Tensor::from_vec::<_, f32>( + vec![ + 1.505028486251831, + -2.32635498046875, + 1.6168899536132812, + -0.9026237726211548, + 0.17366823554039001, + ], + (1, 1, 5), + &Device::Cpu, + )?; + let output = Tensor::from_vec::<_, f32>( + vec![ + 0.5956016778945923, + -0.01723279245197773, + 0.11035571992397308, + -0.49323174357414246, + 0.047632161527872086, + 0.6358451843261719, + 0.040328118950128555, + -0.3788611590862274, + -0.7464339733123779, + 0.20080909132957458, + 0.5840265154838562, + 0.1453288197517395, + -0.7345298528671265, + -0.5214304327964783, + 0.21903817355632782, + 0.7420451641082764, + 0.31943878531455994, + -0.04726646468043327, + -0.2823849618434906, + 0.2713133990764618, + ], + (4, 1, 5), + &Device::Cpu, + )?; + let hn = Tensor::from_vec::<_, f32>( + vec![ + 0.7420451641082764, + 0.31943878531455994, + -0.04726646468043327, + -0.2823849618434906, + 0.2713133990764618, + ], + (1, 1, 5), + &Device::Cpu, + )?; + let cn = Tensor::from_vec::<_, f32>( + vec![ + 0.9630558490753174, + 1.0033069849014282, + -1.754899024963379, + -1.5967122316360474, + 0.8252924680709839, + ], + (1, 1, 5), + &Device::Cpu, + )?; + // end of generated values + + let model = create_model_proto_with_graph(Some(GraphProto { + node: vec![NodeProto { + op_type: "LSTM".to_string(), + name: "LSTM_test".to_string(), + attribute: vec![AttributeProto { + name: "hidden_size".to_string(), + r#type: AttributeType::Int.into(), + i: hidden_size as i64, + ..AttributeProto::default() + }], + input: vec![ + "input".to_string(), + "w".to_string(), + "r".to_string(), + "b".to_string(), // b + "".to_string(), // seq_lens + "h".to_string(), + "c".to_string(), + ], + output: vec!["output".to_string(), "hn".to_string(), "cn".to_string()], + ..NodeProto::default() + }], + input: ["input", "w", "r", "b", "h", "c"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + ..ValueInfoProto::default() + }) + .collect(), + output: ["output", "hn", "cn"] + .into_iter() + .map(|name| ValueInfoProto { + name: name.to_string(), + ..ValueInfoProto::default() + }) + .collect(), + ..GraphProto::default() + })); + // pytorch stores weight and bias as [ifco] but we want it as [iofc] + // so we need to re-arrange the tensors a bit + let idx_iofc = { + let stride = hidden_size as i64; + let dev = weight_ih_l0.device(); + let idx_i = Tensor::arange(0 * stride, 1 * stride, dev)?; + let idx_f = Tensor::arange(1 * stride, 2 * stride, dev)?; + let idx_g = Tensor::arange(2 * stride, 3 * stride, dev)?; + let idx_o = Tensor::arange(3 * stride, 4 * stride, dev)?; + + Tensor::cat(&[&idx_i, &idx_o, &idx_f, &idx_g], 0)? + }; + let w = weight_ih_l0.index_select(&idx_iofc, 0)?; + let w = w.reshape((number_directions, 4 * hidden_size, input_size))?; + let r = weight_hh_l0.index_select(&idx_iofc, 0)?; + let r = r.reshape((number_directions, 4 * hidden_size, hidden_size))?; + let wb = bias_ih_l0.index_select(&idx_iofc, 0)?; + let rb = bias_hh_l0.index_select(&idx_iofc, 0)?; + let b = Tensor::cat(&[wb, rb], 0)?.reshape((number_directions, 8 * hidden_size))?; + let output = output.reshape((sequence_length, number_directions, batch_size, hidden_size))?; + let result = simple_eval( + &model, + HashMap::from_iter([ + ("input".to_string(), input), + ("w".to_string(), w), + ("r".to_string(), r), + ("b".to_string(), b), + ("h".to_string(), h0), + ("c".to_string(), c0), + ]), + )?; + let actual_output = result.get("output").unwrap(); + assert_eq!(output.dims(), actual_output.dims()); + let actual_hn = result.get("hn").unwrap(); + assert_eq!(hn.dims(), actual_hn.dims()); + let actual_cn = result.get("cn").unwrap(); + assert_eq!(cn.dims(), actual_cn.dims()); + let diff_close_enough = |a: &Tensor, b| -> Result<_> { + let diffs = a.sub(b)?.flatten_all()?.to_vec1::()?; + Ok(diffs.iter().all(|f| f.abs() < 0.0001)) + }; + assert!( + diff_close_enough(&output, &actual_output)?, + "output did not match expected\n{actual_output}\n{output}", + ); + assert!( + diff_close_enough(&hn, &actual_hn)?, + "hn did not match expected\n{actual_hn}\n{hn}", + ); + assert!( + diff_close_enough(&cn, &actual_cn)?, + "cn did not match expected\n{actual_cn}\n{cn}", + ); + + Ok(()) +}