initial commit
This commit is contained in:
7095
Cargo.lock
generated
Normal file
7095
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
11
Cargo.toml
Normal file
11
Cargo.toml
Normal file
@@ -0,0 +1,11 @@
|
||||
[package]
|
||||
name = "spikers"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray"] }
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = "0.26.1"
|
||||
rstest_reuse = "0.7.0"
|
||||
2
src/encoders/mod.rs
Normal file
2
src/encoders/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod rate;
|
||||
pub mod temporal;
|
||||
113
src/encoders/rate.rs
Normal file
113
src/encoders/rate.rs
Normal file
@@ -0,0 +1,113 @@
|
||||
#[allow(non_snake_case)]
|
||||
use burn::{Tensor, prelude::*, tensor::Distribution};
|
||||
|
||||
pub fn encode<B: Backend, const D: usize, const D2: usize>(
|
||||
data: &Tensor<B, D>,
|
||||
numSteps: usize,
|
||||
) -> Tensor<B, D2> {
|
||||
if D2 != D + 1 {
|
||||
panic!(
|
||||
"Output Dims must be 1 more than input Dims. Given D2: {}, expected: {}",
|
||||
D2,
|
||||
D + 1
|
||||
)
|
||||
}
|
||||
let newData = data.clone().unsqueeze_dim::<D2>(0).repeat_dim(0, numSteps);
|
||||
let shape = newData.shape();
|
||||
match numSteps {
|
||||
..=0 => panic!("numSteps cannot be non-positive, received {numSteps}"),
|
||||
_ => Tensor::<B, D2>::random(shape, Distribution::Uniform(0., 1.), &data.device())
|
||||
.lower(newData)
|
||||
.float(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod rateEncoderTests {
|
||||
use burn::{
|
||||
Tensor,
|
||||
backend::{NdArray, ndarray::NdArrayDevice},
|
||||
tensor::{Float, Shape},
|
||||
};
|
||||
type B = NdArray;
|
||||
use super::encode;
|
||||
use rstest::rstest;
|
||||
use rstest_reuse::{self, *};
|
||||
|
||||
#[template]
|
||||
#[rstest]
|
||||
#[case(1, 1)]
|
||||
#[case(10, 1)]
|
||||
#[case(100, 1)]
|
||||
#[case(1, 10)]
|
||||
#[case(10, 10)]
|
||||
#[case(100, 10)]
|
||||
#[case(1, 100)]
|
||||
#[case(10, 100)]
|
||||
#[case(100, 100)]
|
||||
|
||||
fn testShapeStepComb(#[case] t: usize, #[case] s: usize) {}
|
||||
|
||||
#[apply(testShapeStepComb)]
|
||||
fn test1DZeros(#[case] t: usize, #[case] s: usize) {
|
||||
let device = NdArrayDevice::default();
|
||||
let t0 = Tensor::<B, 1>::zeros(Shape::new([s]), &device);
|
||||
let out = encode::<B, 1, 2>(&t0, t);
|
||||
assert_eq!(
|
||||
out.shape(),
|
||||
Shape::new([t, s]),
|
||||
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||
s,
|
||||
t,
|
||||
out.shape()
|
||||
);
|
||||
assert_eq!(out.sum().into_scalar(), 0.);
|
||||
}
|
||||
|
||||
#[apply(testShapeStepComb)]
|
||||
fn test1DOnes(#[case] t: usize, #[case] s: usize) {
|
||||
let device = NdArrayDevice::default();
|
||||
let t0 = Tensor::<B, 1, Float>::ones(Shape::new([s]), &device);
|
||||
let out = encode::<B, 1, 2>(&t0, t);
|
||||
assert_eq!(
|
||||
out.shape(),
|
||||
Shape::new([t, s]),
|
||||
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||
t,
|
||||
s,
|
||||
out.shape()
|
||||
);
|
||||
assert_eq!(out.clone().int().sum().into_scalar(), (t * s) as i64);
|
||||
}
|
||||
#[apply(testShapeStepComb)]
|
||||
fn test2DZeros(#[case] t: usize, #[case] s: usize) {
|
||||
let device = NdArrayDevice::default();
|
||||
let t0 = Tensor::<B, 2>::zeros(Shape::new([s; 2]), &device);
|
||||
let out = encode::<B, 2, 3>(&t0, t);
|
||||
assert_eq!(
|
||||
out.shape(),
|
||||
Shape::new([t, s, s]),
|
||||
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||
s,
|
||||
t,
|
||||
out.shape()
|
||||
);
|
||||
assert_eq!(out.sum().into_scalar(), 0.);
|
||||
}
|
||||
|
||||
#[apply(testShapeStepComb)]
|
||||
fn test2DOnes(#[case] t: usize, #[case] s: usize) {
|
||||
let device = NdArrayDevice::default();
|
||||
let t0 = Tensor::<B, 2, Float>::ones(Shape::new([s; 2]), &device);
|
||||
let out = encode::<B, 2, 3>(&t0, t);
|
||||
assert_eq!(
|
||||
out.shape(),
|
||||
Shape::new([t, s, s]),
|
||||
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||
t,
|
||||
s,
|
||||
out.shape()
|
||||
);
|
||||
assert_eq!(out.clone().int().sum().into_scalar(), (t * s * s) as i64);
|
||||
}
|
||||
}
|
||||
0
src/encoders/temporal.rs
Normal file
0
src/encoders/temporal.rs
Normal file
17
src/lib.rs
Normal file
17
src/lib.rs
Normal file
@@ -0,0 +1,17 @@
|
||||
#[allow(non_snake_case)]
|
||||
mod encoders;
|
||||
|
||||
pub fn add(left: u64, right: u64) -> u64 {
|
||||
left + right
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn it_works() {
|
||||
let result = add(2, 2);
|
||||
assert_eq!(result, 4);
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user