added some lif logic
Some checks failed
/ Tests (push) Failing after 9m42s

This commit is contained in:
2026-05-28 13:09:40 +05:30
parent e32a4ac2d5
commit 080c04e992

View File

@@ -1,8 +1,7 @@
#![allow(non_snake_case)] #![allow(non_snake_case, dead_code)]
use crate::neurons::heaviside::heaviside; use crate::neurons::heaviside::heaviside;
use crate::surrogate::SurrogateFn; use crate::surrogate::SurrogateFn;
use burn::module::Param;
use burn::prelude::*; use burn::prelude::*;
use burn::{Tensor, module::Module, prelude::Backend}; use burn::{Tensor, module::Module, prelude::Backend};
@@ -15,50 +14,71 @@ pub struct LIFConfig {
} }
impl LIFConfig { impl LIFConfig {
pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> LIF<B, D> { pub fn init<B: Backend, const D: usize>(&self) -> LIF {
let initMem = Param::from_tensor(Tensor::<B, D>::zeros([1, self.neurons], device));
LIF { LIF {
beta: self.beta, beta: self.beta,
threshold: self.threshold, threshold: self.threshold,
neurons: self.neurons, neurons: self.neurons,
hidden: initMem,
} }
} }
} }
// TODO: tensor cloning and its lifecycle is probably wrong, may cause comp graph to drop. Refer burn example to find proper tensor handling.. // TODO: tensor cloning and its lifecycle is probably wrong, may cause comp graph to drop. Refer burn example to find proper tensor handling..
#[derive(Debug, Module)] #[derive(Debug, Module, Clone)]
pub struct LIF<B: Backend, const D: usize> { pub struct LIF {
beta: f32, beta: f32,
threshold: f32, threshold: f32,
neurons: usize, neurons: usize,
pub hidden: Param<Tensor<B, D>>,
} }
impl<B: Backend, const D: usize> LIF<B, D> { impl LIF {
pub fn forward(&mut self, input: Tensor<B, D>) -> Tensor<B, D> { pub fn forward<B: Backend, const D: usize>(
// leaky and Integrate &mut self,
// input: Tensor<B, D>,
let curMem = self.hidden.val(); mem: Tensor<B, D>,
let nxtMem = curMem.mul_scalar(self.beta).add(input.clone()); ) -> [Tensor<B, D>; 2] {
// fire // check if input shape and mem shape are same. init to zero of input shape if not.
let spikes = heaviside::<B, D>(nxtMem, self.threshold, SurrogateFn::FastSigmoid); if input.shape() != mem.shape() {
panic!(
"Input shape {} and memory shape {} are different",
input.shape(),
mem.shape()
)
}
self.hidden // memory reset at current state.
.val() let resetSignal = self.mem_reset(mem.clone());
.sub(spikes.clone().mul_scalar(self.threshold)); // requires update step fix. currently doesnt update.
spikes // Decay memory and add input (B*v + X)
let dmem = mem.mul_scalar(self.beta).add(input);
// Reset memory based on reset method.
let outMem = self.step_subtract(dmem, resetSignal);
// Generate output spikes
let spikes = heaviside(outMem.clone(), self.threshold, SurrogateFn::FastSigmoid);
[spikes, outMem]
} }
fn mem_reset(&self, mem: Tensor<B, D>) -> Tensor<B, D> { fn mem_reset<B: Backend, const D: usize>(&self, mem: Tensor<B, D>) -> Tensor<B, D> {
// Generates reset signal.
// Take diff of mem and threshold and pass through heaviside function.
mem.sub_scalar(self.threshold) mem.sub_scalar(self.threshold)
.greater_elem(0.0)
.float()
.detach()
} }
pub fn reset(&mut self) { fn step_subtract<B: Backend, const D: usize>(
// self.hidden = self.hidden.zeros_like(); &self,
input: Tensor<B, D>,
reset: Tensor<B, D>,
) -> Tensor<B, D> {
input - reset.mul_scalar(self.threshold)
} }
pub fn init(&mut self, batch: usize, device: &B::Device) { pub fn reset<B: Backend, const D: usize>(&self, mem: &Tensor<B, D>) -> Tensor<B, D> {
// self.hidden = Tensor::zeros(Shape::new([batch, self.neurons]), device); Tensor::zeros_like(mem)
} }
} }