From 080c04e99211ef17115750fbecf9c957e7146bb1 Mon Sep 17 00:00:00 2001 From: Phani Pavan K Date: Thu, 28 May 2026 13:09:40 +0530 Subject: [PATCH] added some lif logic --- src/neurons/leaky.rs | 70 ++++++++++++++++++++++++++++---------------- 1 file changed, 45 insertions(+), 25 deletions(-) diff --git a/src/neurons/leaky.rs b/src/neurons/leaky.rs index 60eeb6f..e81d040 100644 --- a/src/neurons/leaky.rs +++ b/src/neurons/leaky.rs @@ -1,8 +1,7 @@ -#![allow(non_snake_case)] +#![allow(non_snake_case, dead_code)] use crate::neurons::heaviside::heaviside; use crate::surrogate::SurrogateFn; -use burn::module::Param; use burn::prelude::*; use burn::{Tensor, module::Module, prelude::Backend}; @@ -15,50 +14,71 @@ pub struct LIFConfig { } impl LIFConfig { - pub fn init(&self, device: &B::Device) -> LIF { - let initMem = Param::from_tensor(Tensor::::zeros([1, self.neurons], device)); + pub fn init(&self) -> LIF { LIF { beta: self.beta, threshold: self.threshold, neurons: self.neurons, - hidden: initMem, } } } // TODO: tensor cloning and its lifecycle is probably wrong, may cause comp graph to drop. Refer burn example to find proper tensor handling.. -#[derive(Debug, Module)] -pub struct LIF { +#[derive(Debug, Module, Clone)] +pub struct LIF { beta: f32, threshold: f32, neurons: usize, - pub hidden: Param>, } -impl LIF { - pub fn forward(&mut self, input: Tensor) -> Tensor { - // leaky and Integrate - // - let curMem = self.hidden.val(); - let nxtMem = curMem.mul_scalar(self.beta).add(input.clone()); - // fire - let spikes = heaviside::(nxtMem, self.threshold, SurrogateFn::FastSigmoid); +impl LIF { + pub fn forward( + &mut self, + input: Tensor, + mem: Tensor, + ) -> [Tensor; 2] { + // check if input shape and mem shape are same. init to zero of input shape if not. + if input.shape() != mem.shape() { + panic!( + "Input shape {} and memory shape {} are different", + input.shape(), + mem.shape() + ) + } - self.hidden - .val() - .sub(spikes.clone().mul_scalar(self.threshold)); // requires update step fix. currently doesnt update. - spikes + // memory reset at current state. + let resetSignal = self.mem_reset(mem.clone()); + + // Decay memory and add input (B*v + X) + let dmem = mem.mul_scalar(self.beta).add(input); + + // Reset memory based on reset method. + let outMem = self.step_subtract(dmem, resetSignal); + + // Generate output spikes + let spikes = heaviside(outMem.clone(), self.threshold, SurrogateFn::FastSigmoid); + + [spikes, outMem] } - fn mem_reset(&self, mem: Tensor) -> Tensor { + fn mem_reset(&self, mem: Tensor) -> Tensor { + // Generates reset signal. + // Take diff of mem and threshold and pass through heaviside function. mem.sub_scalar(self.threshold) + .greater_elem(0.0) + .float() + .detach() } - pub fn reset(&mut self) { - // self.hidden = self.hidden.zeros_like(); + fn step_subtract( + &self, + input: Tensor, + reset: Tensor, + ) -> Tensor { + input - reset.mul_scalar(self.threshold) } - pub fn init(&mut self, batch: usize, device: &B::Device) { - // self.hidden = Tensor::zeros(Shape::new([batch, self.neurons]), device); + pub fn reset(&self, mem: &Tensor) -> Tensor { + Tensor::zeros_like(mem) } }