Files
spikers/src/encoders/rate.rs
2026-03-12 22:20:39 +05:30

114 lines
3.3 KiB
Rust

#[allow(non_snake_case)]
use burn::{Tensor, prelude::*, tensor::Distribution};
/// Rate-encodes `data` into a spike train of `numSteps` time steps.
///
/// The input is repeated along a new leading (time) axis. Each element of the result
/// is `1.0` when a fresh sample drawn from `Uniform(0, 1)` is strictly below the
/// corresponding input value, and `0.0` otherwise. An input value `p` in `[0, 1]`
/// therefore fires with probability `p`, independently per step and per element.
/// Inputs are not clamped; values `<= 0` never fire and values `>= 1` always fire
/// (assuming the backend samples `Uniform(0, 1)` from `[0, 1)`).
///
/// # Type parameters
/// - `D`: rank of the input tensor.
/// - `D2`: rank of the output tensor; must equal `D + 1`.
///
/// # Arguments
/// - `data`: per-element firing probabilities. Only borrowed; cloned internally.
/// - `numSteps`: number of time steps, i.e. the size of the new leading axis.
///
/// # Returns
/// A float tensor of shape `[numSteps, ..data.shape()]` holding only `0.0` / `1.0`.
///
/// # Panics
/// - If `D2 != D + 1`.
/// - If `numSteps == 0`.
// The `#[allow]` attached to the `use` item does not reach this fn, so allow the
// camelCase parameter name here.
#[allow(non_snake_case)]
pub fn encode<B: Backend, const D: usize, const D2: usize>(
    data: &Tensor<B, D>,
    numSteps: usize,
) -> Tensor<B, D2> {
    // Stable Rust cannot express `D2 == D + 1` in the signature, so check it at runtime.
    assert!(
        D2 == D + 1,
        "Output Dims must be 1 more than input Dims. Given D2: {}, expected: {}",
        D2,
        D + 1
    );
    // Validate before any tensor work so a zero step count always reports this
    // message rather than whatever the backend does with a zero-length repeat.
    assert!(
        numSteps > 0,
        "numSteps cannot be non-positive, received {numSteps}"
    );

    let device = data.device();
    // [..shape] -> [1, ..shape] -> [numSteps, ..shape]
    let probabilities = data.clone().unsqueeze_dim::<D2>(0).repeat_dim(0, numSteps);
    let noise = Tensor::<B, D2>::random(
        probabilities.shape(),
        Distribution::Uniform(0., 1.),
        &device,
    );
    // Spike wherever the noise sample falls below the firing probability.
    noise.lower(probabilities).float()
}
#[cfg(test)]
// Module and test function names in this file are camelCase.
#[allow(non_snake_case)]
mod rateEncoderTests {
    use burn::{
        Tensor,
        backend::{NdArray, ndarray::NdArrayDevice},
        tensor::{Float, Shape},
    };
    type B = NdArray;
    use super::encode;
    use rstest::rstest;
    use rstest_reuse::{self, *};

    /// Asserts that `out` has shape `expected`, printing both shapes on failure.
    fn assertShape<const D: usize>(out: &Tensor<B, D>, expected: Shape) {
        let actual = out.shape();
        assert_eq!(
            actual, expected,
            "Shape testing failed. Expected Shape: {:?}, got: {:?}",
            expected, actual
        );
    }

    /// Cartesian product of time steps `t` and per-dimension size `s`.
    #[template]
    #[rstest]
    #[case(1, 1)]
    #[case(10, 1)]
    #[case(100, 1)]
    #[case(1, 10)]
    #[case(10, 10)]
    #[case(100, 10)]
    #[case(1, 100)]
    #[case(10, 100)]
    #[case(100, 100)]
    fn testShapeStepComb(#[case] t: usize, #[case] s: usize) {}

    #[apply(testShapeStepComb)]
    fn test1DZeros(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 1>::zeros(Shape::new([s]), &device);
        let out = encode::<B, 1, 2>(&t0, t);
        assertShape(&out, Shape::new([t, s]));
        // Probability 0 must never fire.
        assert_eq!(out.sum().into_scalar(), 0.);
    }

    #[apply(testShapeStepComb)]
    fn test1DOnes(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 1, Float>::ones(Shape::new([s]), &device);
        let out = encode::<B, 1, 2>(&t0, t);
        assertShape(&out, Shape::new([t, s]));
        // Probability 1 must fire at every step for every element.
        assert_eq!(out.int().sum().into_scalar(), (t * s) as i64);
    }

    #[apply(testShapeStepComb)]
    fn test2DZeros(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 2>::zeros(Shape::new([s; 2]), &device);
        let out = encode::<B, 2, 3>(&t0, t);
        assertShape(&out, Shape::new([t, s, s]));
        assert_eq!(out.sum().into_scalar(), 0.);
    }

    #[apply(testShapeStepComb)]
    fn test2DOnes(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 2, Float>::ones(Shape::new([s; 2]), &device);
        let out = encode::<B, 2, 3>(&t0, t);
        assertShape(&out, Shape::new([t, s, s]));
        assert_eq!(out.int().sum().into_scalar(), (t * s * s) as i64);
    }

    /// For intermediate probabilities every output element must still be a spike
    /// (1.0) or silence (0.0).
    #[apply(testShapeStepComb)]
    fn test1DHalfIsBinary(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 1, Float>::full(Shape::new([s]), 0.5, &device);
        let out = encode::<B, 1, 2>(&t0, t);
        assertShape(&out, Shape::new([t, s]));
        // x * (x - 1) is zero exactly when x is 0 or 1.
        let nonBinary = out
            .clone()
            .mul(out.sub_scalar(1.0))
            .abs()
            .sum()
            .into_scalar();
        assert_eq!(nonBinary, 0.);
    }

    #[test]
    #[should_panic(expected = "Output Dims must be 1 more than input Dims")]
    fn testRejectsWrongOutputRank() {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 1>::zeros(Shape::new([4]), &device);
        let _ = encode::<B, 1, 3>(&t0, 2);
    }

    #[test]
    #[should_panic]
    fn testRejectsZeroSteps() {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 1>::zeros(Shape::new([4]), &device);
        let _ = encode::<B, 1, 2>(&t0, 0);
    }
}