initial commit

This commit is contained in:
2026-03-12 22:20:39 +05:30
parent f6e7433bea
commit d9258bc71a
6 changed files with 7238 additions and 0 deletions

7095
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

11
Cargo.toml Normal file
View File

@@ -0,0 +1,11 @@
[package]
name = "spikers"
version = "0.1.0"
edition = "2024"
[dependencies]
burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray"] }
[dev-dependencies]
rstest = "0.26.1"
rstest_reuse = "0.7.0"

2
src/encoders/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod rate;
pub mod temporal;

113
src/encoders/rate.rs Normal file
View File

@@ -0,0 +1,113 @@
#[allow(non_snake_case)]
use burn::{Tensor, prelude::*, tensor::Distribution};
pub fn encode<B: Backend, const D: usize, const D2: usize>(
data: &Tensor<B, D>,
numSteps: usize,
) -> Tensor<B, D2> {
if D2 != D + 1 {
panic!(
"Output Dims must be 1 more than input Dims. Given D2: {}, expected: {}",
D2,
D + 1
)
}
let newData = data.clone().unsqueeze_dim::<D2>(0).repeat_dim(0, numSteps);
let shape = newData.shape();
match numSteps {
..=0 => panic!("numSteps cannot be non-positive, received {numSteps}"),
_ => Tensor::<B, D2>::random(shape, Distribution::Uniform(0., 1.), &data.device())
.lower(newData)
.float(),
}
}
#[cfg(test)]
mod rateEncoderTests {
use burn::{
Tensor,
backend::{NdArray, ndarray::NdArrayDevice},
tensor::{Float, Shape},
};
type B = NdArray;
use super::encode;
use rstest::rstest;
use rstest_reuse::{self, *};
#[template]
#[rstest]
#[case(1, 1)]
#[case(10, 1)]
#[case(100, 1)]
#[case(1, 10)]
#[case(10, 10)]
#[case(100, 10)]
#[case(1, 100)]
#[case(10, 100)]
#[case(100, 100)]
fn testShapeStepComb(#[case] t: usize, #[case] s: usize) {}
#[apply(testShapeStepComb)]
fn test1DZeros(#[case] t: usize, #[case] s: usize) {
let device = NdArrayDevice::default();
let t0 = Tensor::<B, 1>::zeros(Shape::new([s]), &device);
let out = encode::<B, 1, 2>(&t0, t);
assert_eq!(
out.shape(),
Shape::new([t, s]),
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
s,
t,
out.shape()
);
assert_eq!(out.sum().into_scalar(), 0.);
}
#[apply(testShapeStepComb)]
fn test1DOnes(#[case] t: usize, #[case] s: usize) {
let device = NdArrayDevice::default();
let t0 = Tensor::<B, 1, Float>::ones(Shape::new([s]), &device);
let out = encode::<B, 1, 2>(&t0, t);
assert_eq!(
out.shape(),
Shape::new([t, s]),
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
t,
s,
out.shape()
);
assert_eq!(out.clone().int().sum().into_scalar(), (t * s) as i64);
}
#[apply(testShapeStepComb)]
fn test2DZeros(#[case] t: usize, #[case] s: usize) {
let device = NdArrayDevice::default();
let t0 = Tensor::<B, 2>::zeros(Shape::new([s; 2]), &device);
let out = encode::<B, 2, 3>(&t0, t);
assert_eq!(
out.shape(),
Shape::new([t, s, s]),
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
s,
t,
out.shape()
);
assert_eq!(out.sum().into_scalar(), 0.);
}
#[apply(testShapeStepComb)]
fn test2DOnes(#[case] t: usize, #[case] s: usize) {
let device = NdArrayDevice::default();
let t0 = Tensor::<B, 2, Float>::ones(Shape::new([s; 2]), &device);
let out = encode::<B, 2, 3>(&t0, t);
assert_eq!(
out.shape(),
Shape::new([t, s, s]),
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
t,
s,
out.shape()
);
assert_eq!(out.clone().int().sum().into_scalar(), (t * s * s) as i64);
}
}

0
src/encoders/temporal.rs Normal file
View File

17
src/lib.rs Normal file
View File

@@ -0,0 +1,17 @@
#[allow(non_snake_case)]
mod encoders;
pub fn add(left: u64, right: u64) -> u64 {
left + right
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn it_works() {
let result = add(2, 2);
assert_eq!(result, 4);
}
}