#[allow(non_snake_case)] use burn::{Tensor, prelude::*, tensor::Distribution}; pub fn encode( data: &Tensor, numSteps: usize, ) -> Tensor { if D2 != D + 1 { panic!( "Output Dims must be 1 more than input Dims. Given D2: {}, expected: {}", D2, D + 1 ) } let newData = data.clone().unsqueeze_dim::(0).repeat_dim(0, numSteps); let shape = newData.shape(); match numSteps { ..=0 => panic!("numSteps cannot be non-positive, received {numSteps}"), _ => Tensor::::random(shape, Distribution::Uniform(0., 1.), &data.device()) .lower(newData) .float(), } } #[cfg(test)] mod rateEncoderTests { use burn::{ Tensor, backend::{NdArray, ndarray::NdArrayDevice}, tensor::{Float, Shape}, }; type B = NdArray; use super::encode; use rstest::rstest; use rstest_reuse::{self, *}; #[template] #[rstest] #[case(1, 1)] #[case(10, 1)] #[case(100, 1)] #[case(1, 10)] #[case(10, 10)] #[case(100, 10)] #[case(1, 100)] #[case(10, 100)] #[case(100, 100)] fn testShapeStepComb(#[case] t: usize, #[case] s: usize) {} #[apply(testShapeStepComb)] fn test1DZeros(#[case] t: usize, #[case] s: usize) { let device = NdArrayDevice::default(); let t0 = Tensor::::zeros(Shape::new([s]), &device); let out = encode::(&t0, t); assert_eq!( out.shape(), Shape::new([t, s]), "Shape testing failed. Expected Shape: [{}, {}], got: {}", s, t, out.shape() ); assert_eq!(out.sum().into_scalar(), 0.); } #[apply(testShapeStepComb)] fn test1DOnes(#[case] t: usize, #[case] s: usize) { let device = NdArrayDevice::default(); let t0 = Tensor::::ones(Shape::new([s]), &device); let out = encode::(&t0, t); assert_eq!( out.shape(), Shape::new([t, s]), "Shape testing failed. Expected Shape: [{}, {}], got: {}", t, s, out.shape() ); assert_eq!(out.clone().int().sum().into_scalar(), (t * s) as i64); } #[apply(testShapeStepComb)] fn test2DZeros(#[case] t: usize, #[case] s: usize) { let device = NdArrayDevice::default(); let t0 = Tensor::::zeros(Shape::new([s; 2]), &device); let out = encode::(&t0, t); assert_eq!( out.shape(), Shape::new([t, s, s]), "Shape testing failed. Expected Shape: [{}, {}], got: {}", s, t, out.shape() ); assert_eq!(out.sum().into_scalar(), 0.); } #[apply(testShapeStepComb)] fn test2DOnes(#[case] t: usize, #[case] s: usize) { let device = NdArrayDevice::default(); let t0 = Tensor::::ones(Shape::new([s; 2]), &device); let out = encode::(&t0, t); assert_eq!( out.shape(), Shape::new([t, s, s]), "Shape testing failed. Expected Shape: [{}, {}], got: {}", t, s, out.shape() ); assert_eq!(out.clone().int().sum().into_scalar(), (t * s * s) as i64); } }