From e32a4ac2d56c264e06d9e37bb7b34df0d65e9db1 Mon Sep 17 00:00:00 2001 From: Phani Pavan K Date: Fri, 15 May 2026 12:02:18 +0530 Subject: [PATCH] added atan, hside, lif structure. lif update broken, needs fix. --- src/lib.rs | 2 +- src/neurons/heaviside.rs | 13 +++++++ src/neurons/leaky.rs | 66 ++++++++++++++++++++++++++++------- src/neurons/mod.rs | 1 + src/surrogate/atan.rs | 5 +++ src/surrogate/fast_sigmoid.rs | 4 +++ src/surrogate/mod.rs | 20 +++++++++++ 7 files changed, 97 insertions(+), 14 deletions(-) create mode 100644 src/neurons/heaviside.rs create mode 100644 src/surrogate/atan.rs diff --git a/src/lib.rs b/src/lib.rs index aeeded2..2181114 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,4 +1,4 @@ -#[allow(non_snake_case)] +#![allow(non_snake_case)] pub mod encoders; pub mod neurons; pub mod surrogate; diff --git a/src/neurons/heaviside.rs b/src/neurons/heaviside.rs new file mode 100644 index 0000000..b37070e --- /dev/null +++ b/src/neurons/heaviside.rs @@ -0,0 +1,13 @@ +#![allow(non_snake_case)] +use crate::surrogate::SurrogateFn; +use burn::prelude::*; + +pub fn heaviside( + input: Tensor, + thres: f32, + s: SurrogateFn, +) -> Tensor { + let y = input.clone().greater_equal_elem(thres).float(); + let yb = s.forward(input); + (y - yb.clone().detach()).detach() + yb +} diff --git a/src/neurons/leaky.rs b/src/neurons/leaky.rs index 426a694..60eeb6f 100644 --- a/src/neurons/leaky.rs +++ b/src/neurons/leaky.rs @@ -1,24 +1,64 @@ -use burn::{Tensor, module::Module, prelude::Backend, tensor::backend::AutodiffBackend}; +#![allow(non_snake_case)] -#[derive(Debug, Module)] -pub struct LIF { - pub v: Tensor, +use crate::neurons::heaviside::heaviside; +use crate::surrogate::SurrogateFn; +use burn::module::Param; +use burn::prelude::*; +use burn::{Tensor, module::Module, prelude::Backend}; + +#[derive(Config, Debug)] +pub struct LIFConfig { + beta: f32, + threshold: f32, + neurons: usize, + spikeGrad: SurrogateFn, } -impl LIF { - pub fn new(device: &B::Device) -> Self { - Self { - v: Tensor::zeros([1], device), +impl LIFConfig { + pub fn init(&self, device: &B::Device) -> LIF { + let initMem = Param::from_tensor(Tensor::::zeros([1, self.neurons], device)); + LIF { + beta: self.beta, + threshold: self.threshold, + neurons: self.neurons, + hidden: initMem, } } } -impl LIF { - pub fn forward(&mut self, input: Tensor) -> Tensor { - input +// TODO: tensor cloning and its lifecycle is probably wrong, may cause comp graph to drop. Refer burn example to find proper tensor handling.. +#[derive(Debug, Module)] +pub struct LIF { + beta: f32, + threshold: f32, + neurons: usize, + pub hidden: Param>, +} + +impl LIF { + pub fn forward(&mut self, input: Tensor) -> Tensor { + // leaky and Integrate + // + let curMem = self.hidden.val(); + let nxtMem = curMem.mul_scalar(self.beta).add(input.clone()); + // fire + let spikes = heaviside::(nxtMem, self.threshold, SurrogateFn::FastSigmoid); + + self.hidden + .val() + .sub(spikes.clone().mul_scalar(self.threshold)); // requires update step fix. currently doesnt update. + spikes } - pub fn backward(&mut self, x: Tensor) -> Tensor { - x + fn mem_reset(&self, mem: Tensor) -> Tensor { + mem.sub_scalar(self.threshold) + } + + pub fn reset(&mut self) { + // self.hidden = self.hidden.zeros_like(); + } + + pub fn init(&mut self, batch: usize, device: &B::Device) { + // self.hidden = Tensor::zeros(Shape::new([batch, self.neurons]), device); } } diff --git a/src/neurons/mod.rs b/src/neurons/mod.rs index 600bd49..b7c9a21 100644 --- a/src/neurons/mod.rs +++ b/src/neurons/mod.rs @@ -1 +1,2 @@ +mod heaviside; mod leaky; diff --git a/src/surrogate/atan.rs b/src/surrogate/atan.rs new file mode 100644 index 0000000..8629703 --- /dev/null +++ b/src/surrogate/atan.rs @@ -0,0 +1,5 @@ +use burn::prelude::*; + +pub fn atan(input: Tensor) -> Tensor { + (1.0 / std::f32::consts::PI) * (input.mul_scalar(std::f32::consts::PI).atan()) +} diff --git a/src/surrogate/fast_sigmoid.rs b/src/surrogate/fast_sigmoid.rs index 8b13789..8c0baf4 100644 --- a/src/surrogate/fast_sigmoid.rs +++ b/src/surrogate/fast_sigmoid.rs @@ -1 +1,5 @@ +use burn::prelude::*; +pub fn fast_sigmoid(input: Tensor) -> Tensor { + input.clone() / (1 + input.abs()) +} diff --git a/src/surrogate/mod.rs b/src/surrogate/mod.rs index 418a990..e4ef76e 100644 --- a/src/surrogate/mod.rs +++ b/src/surrogate/mod.rs @@ -1 +1,21 @@ +mod atan; mod fast_sigmoid; + +use atan::atan; +use burn::{Tensor, config::Config, tensor::backend::Backend}; +use fast_sigmoid::fast_sigmoid; + +#[derive(Config, Debug)] +pub enum SurrogateFn { + FastSigmoid, + ATan, +} + +impl SurrogateFn { + pub fn forward(&self, input: Tensor) -> Tensor { + match self { + SurrogateFn::FastSigmoid => fast_sigmoid(input), + SurrogateFn::ATan => atan(input), + } + } +}