From de3f3990ee0ce631af2a5cad0d466f35f33048fc Mon Sep 17 00:00:00 2001 From: Phani Pavan K Date: Thu, 2 Apr 2026 16:58:07 +0530 Subject: [PATCH] added temp enc --- Cargo.lock | 1 + Cargo.toml | 3 +- examples/simple1d.rs | 41 ++++++++++++ src/encoders/rate.rs | 4 +- src/encoders/temporal.rs | 120 ++++++++++++++++++++++++++++++++++ src/lib.rs | 19 +----- src/neurons/leaky.rs | 14 ++++ src/neurons/mod.rs | 1 + src/surrogate/fast_sigmoid.rs | 1 + src/surrogate/mod.rs | 1 + 10 files changed, 186 insertions(+), 19 deletions(-) create mode 100644 examples/simple1d.rs create mode 100644 src/neurons/leaky.rs create mode 100644 src/neurons/mod.rs create mode 100644 src/surrogate/fast_sigmoid.rs create mode 100644 src/surrogate/mod.rs diff --git a/Cargo.lock b/Cargo.lock index d5c79bb..90b5b89 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5107,6 +5107,7 @@ name = "spikers" version = "0.1.0" dependencies = [ "burn", + "burn-autodiff", "rstest", "rstest_reuse", ] diff --git a/Cargo.toml b/Cargo.toml index 3d54eba..2795819 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,8 @@ version = "0.1.0" edition = "2024" [dependencies] -burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray"] } +burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray", "cuda"] } +burn-autodiff = "0.20.1" [dev-dependencies] rstest = "0.26.1" diff --git a/examples/simple1d.rs b/examples/simple1d.rs new file mode 100644 index 0000000..549deb3 --- /dev/null +++ b/examples/simple1d.rs @@ -0,0 +1,41 @@ +#![allow(non_snake_case)] +use burn::Tensor; +// +use burn::prelude::Backend; +use spikers::encoders; + +fn main() { + let useGPU = { + // use burn::backend::Wgpu; + use burn::backend::wgpu::WgpuDevice; + match WgpuDevice::DefaultDevice { + WgpuDevice::Cpu => { + println!("GPU Not Found, using NDArray Backend"); + false + } + _ => true, + } + }; + + if useGPU { + use burn::backend::Wgpu; + use burn::backend::wgpu::WgpuDevice; + run::(WgpuDevice::default()); + } else { + use burn::backend::NdArray; + use burn::backend::ndarray::NdArrayDevice; + run::(NdArrayDevice::default()); + } + // encode(data, numSteps); +} + +fn run(device: B::Device) { + println!("Using {:?} Device", B::name(&device)); + // let someTens = Tensor::::random(Shape::new([3]), Distribution::Normal(0., 1.), &device); + let someTens = Tensor::::from_floats([0., 0.2, 0.4, 0.6, 0.8, 1.0], &device); + println!("{}", encoders::rate::encode::(&someTens, 10)); + println!( + "{}", + encoders::temporal::encode::(&someTens, 10, true) + ); +} diff --git a/src/encoders/rate.rs b/src/encoders/rate.rs index 393fb56..f5ab021 100644 --- a/src/encoders/rate.rs +++ b/src/encoders/rate.rs @@ -1,5 +1,5 @@ -#[allow(non_snake_case)] -use burn::{Tensor, prelude::*, tensor::Distribution}; +#![allow(non_snake_case)] +use burn::{Tensor, prelude::Backend, tensor::Distribution}; pub fn encode( data: &Tensor, diff --git a/src/encoders/temporal.rs b/src/encoders/temporal.rs index e69de29..219a791 100644 --- a/src/encoders/temporal.rs +++ b/src/encoders/temporal.rs @@ -0,0 +1,120 @@ +#![allow(non_snake_case)] +use burn::{Tensor, prelude::Backend}; + +// Must test this with python implementation, across dimensions. +pub fn encode( + data: &Tensor, + time_steps: usize, + linear: bool, +) -> Tensor { + if time_steps < 1 { + panic!("Time steps must be greater than 0"); + } + // assumes the data is normalised. + if !linear { + panic!("Log conversion not implemented.") + } + let idxs = data + .clone() + .mul_scalar((time_steps - 1) as i32) + .round() + .unsqueeze_dim::(0) + .repeat_dim(0, time_steps) + .int(); + let out = idxs.zeros_like(); + let ones = out.ones_like(); + out.scatter(0, idxs, ones, burn::tensor::IndexingUpdateOp::Add) + .float() + .div_scalar(time_steps as u32) +} + +#[cfg(test)] +mod temporalEncoderTests { + + use burn::{ + Tensor, + backend::{NdArray, ndarray::NdArrayDevice}, + tensor::{Float, Shape}, + }; + type B = NdArray; + use super::encode; + use rstest::rstest; + use rstest_reuse::{self, *}; + + #[template] + #[rstest] + #[case(1, 1)] + #[case(10, 1)] + #[case(100, 1)] + #[case(1, 10)] + #[case(10, 10)] + #[case(100, 10)] + #[case(1, 100)] + #[case(10, 100)] + #[case(100, 100)] + + fn testShapeStepComb(#[case] t: usize, #[case] s: usize) {} + + #[apply(testShapeStepComb)] + fn test1DZeros(#[case] t: usize, #[case] s: usize) { + let device = NdArrayDevice::default(); + let t0 = Tensor::::zeros(Shape::new([s]), &device); + let out = encode::(&t0, t, true); + assert_eq!( + out.shape(), + Shape::new([t, s]), + "Shape testing failed. Expected Shape: [{}, {}], got: {}", + s, + t, + out.shape() + ); + assert_eq!(out.sum().into_scalar(), s as f32); + } + + #[apply(testShapeStepComb)] + fn test1DOnes(#[case] t: usize, #[case] s: usize) { + let device = NdArrayDevice::default(); + let t0 = Tensor::::ones(Shape::new([s]), &device); + let out = encode::(&t0, t, true); + assert_eq!( + out.shape(), + Shape::new([t, s]), + "Shape testing failed. Expected Shape: [{}, {}], got: {}", + t, + s, + out.shape() + ); + assert_eq!(out.clone().int().sum().into_scalar(), (s) as i64); + } + #[apply(testShapeStepComb)] + fn test2DZeros(#[case] t: usize, #[case] s: usize) { + let device = NdArrayDevice::default(); + let t0 = Tensor::::zeros(Shape::new([s; 2]), &device); + let out = encode::(&t0, t, true); + assert_eq!( + out.shape(), + Shape::new([t, s, s]), + "Shape testing failed. Expected Shape: [{}, {}], got: {}", + s, + t, + out.shape() + ); + assert_eq!(out.sum().into_scalar(), (s * s) as f32); + } + + #[apply(testShapeStepComb)] + fn test2DOnes(#[case] t: usize, #[case] s: usize) { + let device = NdArrayDevice::default(); + let t0 = Tensor::::ones(Shape::new([s; 2]), &device); + let out = encode::(&t0, t, true); + assert_eq!( + out.shape(), + Shape::new([t, s, s]), + "Shape testing failed. Expected Shape: [{}, {}], got: {}", + t, + s, + out.shape() + ); + assert_eq!(out.clone().int().sum().into_scalar(), (s * s) as i64); + } +} diff --git a/src/lib.rs b/src/lib.rs index aa7f25d..aeeded2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,17 +1,4 @@ #[allow(non_snake_case)] -mod encoders; - -pub fn add(left: u64, right: u64) -> u64 { - left + right -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn it_works() { - let result = add(2, 2); - assert_eq!(result, 4); - } -} +pub mod encoders; +pub mod neurons; +pub mod surrogate; diff --git a/src/neurons/leaky.rs b/src/neurons/leaky.rs new file mode 100644 index 0000000..778738d --- /dev/null +++ b/src/neurons/leaky.rs @@ -0,0 +1,14 @@ +use burn::{Tensor, module::Module, prelude::Backend}; + +#[derive(Debug, Module, Clone, Default)] +pub struct LIF; + +impl LIF { + pub fn new() -> Self { + Self + } + + pub fn forward(&self, input: Tensor) -> Tensor { + input + } +} diff --git a/src/neurons/mod.rs b/src/neurons/mod.rs new file mode 100644 index 0000000..600bd49 --- /dev/null +++ b/src/neurons/mod.rs @@ -0,0 +1 @@ +mod leaky; diff --git a/src/surrogate/fast_sigmoid.rs b/src/surrogate/fast_sigmoid.rs new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/surrogate/fast_sigmoid.rs @@ -0,0 +1 @@ + diff --git a/src/surrogate/mod.rs b/src/surrogate/mod.rs new file mode 100644 index 0000000..418a990 --- /dev/null +++ b/src/surrogate/mod.rs @@ -0,0 +1 @@ +mod fast_sigmoid;