114 lines
3.3 KiB
Rust
114 lines
3.3 KiB
Rust
#[allow(non_snake_case)]
|
|
use burn::{Tensor, prelude::*, tensor::Distribution};
|
|
|
|
pub fn encode<B: Backend, const D: usize, const D2: usize>(
|
|
data: &Tensor<B, D>,
|
|
numSteps: usize,
|
|
) -> Tensor<B, D2> {
|
|
if D2 != D + 1 {
|
|
panic!(
|
|
"Output Dims must be 1 more than input Dims. Given D2: {}, expected: {}",
|
|
D2,
|
|
D + 1
|
|
)
|
|
}
|
|
let newData = data.clone().unsqueeze_dim::<D2>(0).repeat_dim(0, numSteps);
|
|
let shape = newData.shape();
|
|
match numSteps {
|
|
..=0 => panic!("numSteps cannot be non-positive, received {numSteps}"),
|
|
_ => Tensor::<B, D2>::random(shape, Distribution::Uniform(0., 1.), &data.device())
|
|
.lower(newData)
|
|
.float(),
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod rate_encoder_tests {
    use burn::{
        Tensor,
        backend::{NdArray, ndarray::NdArrayDevice},
        tensor::{Float, Shape},
    };

    type B = NdArray;

    use super::encode;

    use rstest::rstest;

    use rstest_reuse::{self, *};

    /// Asserts that `out` has exactly the shape `dims`, reporting both the
    /// expected and the actual shape (in that order) on failure.
    fn assert_shape<const D: usize>(out: &Tensor<B, D>, dims: [usize; D]) {
        let expected = Shape::new(dims);
        let actual = out.shape();
        assert_eq!(
            actual, expected,
            "Shape testing failed. Expected Shape: {:?}, got: {:?}",
            expected, actual
        );
    }

    // Shared (t = number of time steps, s = input side length) combinations
    // applied to every test below.
    #[template]
    #[rstest]
    #[case(1, 1)]
    #[case(10, 1)]
    #[case(100, 1)]
    #[case(1, 10)]
    #[case(10, 10)]
    #[case(100, 10)]
    #[case(1, 100)]
    #[case(10, 100)]
    #[case(100, 100)]
    fn shape_step_cases(#[case] t: usize, #[case] s: usize) {}

    #[apply(shape_step_cases)]
    fn test_1d_zeros(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let input = Tensor::<B, 1>::zeros(Shape::new([s]), &device);

        let out = encode::<B, 1, 2>(&input, t);

        assert_shape(&out, [t, s]);
        // An all-zero input must never produce a spike.
        assert_eq!(out.sum().into_scalar(), 0.);
    }

    #[apply(shape_step_cases)]
    fn test_1d_ones(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let input = Tensor::<B, 1, Float>::ones(Shape::new([s]), &device);

        let out = encode::<B, 1, 2>(&input, t);

        assert_shape(&out, [t, s]);
        // An all-one input is expected to spike on every step and position.
        assert_eq!(out.int().sum().into_scalar(), (t * s) as i64);
    }

    #[apply(shape_step_cases)]
    fn test_2d_zeros(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let input = Tensor::<B, 2>::zeros(Shape::new([s; 2]), &device);

        let out = encode::<B, 2, 3>(&input, t);

        assert_shape(&out, [t, s, s]);
        // An all-zero input must never produce a spike.
        assert_eq!(out.sum().into_scalar(), 0.);
    }

    #[apply(shape_step_cases)]
    fn test_2d_ones(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let input = Tensor::<B, 2, Float>::ones(Shape::new([s; 2]), &device);

        let out = encode::<B, 2, 3>(&input, t);

        assert_shape(&out, [t, s, s]);
        // An all-one input is expected to spike on every step and position.
        assert_eq!(out.int().sum().into_scalar(), (t * s * s) as i64);
    }
}
|