This commit is contained in:
@@ -1,8 +1,7 @@
|
|||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case, dead_code)]
|
||||||
|
|
||||||
use crate::neurons::heaviside::heaviside;
|
use crate::neurons::heaviside::heaviside;
|
||||||
use crate::surrogate::SurrogateFn;
|
use crate::surrogate::SurrogateFn;
|
||||||
use burn::module::Param;
|
|
||||||
use burn::prelude::*;
|
use burn::prelude::*;
|
||||||
use burn::{Tensor, module::Module, prelude::Backend};
|
use burn::{Tensor, module::Module, prelude::Backend};
|
||||||
|
|
||||||
@@ -15,50 +14,71 @@ pub struct LIFConfig {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl LIFConfig {
|
impl LIFConfig {
|
||||||
pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> LIF<B, D> {
|
pub fn init<B: Backend, const D: usize>(&self) -> LIF {
|
||||||
let initMem = Param::from_tensor(Tensor::<B, D>::zeros([1, self.neurons], device));
|
|
||||||
LIF {
|
LIF {
|
||||||
beta: self.beta,
|
beta: self.beta,
|
||||||
threshold: self.threshold,
|
threshold: self.threshold,
|
||||||
neurons: self.neurons,
|
neurons: self.neurons,
|
||||||
hidden: initMem,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: tensor cloning and its lifecycle is probably wrong, may cause comp graph to drop. Refer burn example to find proper tensor handling..
|
// TODO: tensor cloning and its lifecycle is probably wrong, may cause comp graph to drop. Refer burn example to find proper tensor handling..
|
||||||
#[derive(Debug, Module)]
|
#[derive(Debug, Module, Clone)]
|
||||||
pub struct LIF<B: Backend, const D: usize> {
|
pub struct LIF {
|
||||||
beta: f32,
|
beta: f32,
|
||||||
threshold: f32,
|
threshold: f32,
|
||||||
neurons: usize,
|
neurons: usize,
|
||||||
pub hidden: Param<Tensor<B, D>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend, const D: usize> LIF<B, D> {
|
impl LIF {
|
||||||
pub fn forward(&mut self, input: Tensor<B, D>) -> Tensor<B, D> {
|
pub fn forward<B: Backend, const D: usize>(
|
||||||
// leaky and Integrate
|
&mut self,
|
||||||
//
|
input: Tensor<B, D>,
|
||||||
let curMem = self.hidden.val();
|
mem: Tensor<B, D>,
|
||||||
let nxtMem = curMem.mul_scalar(self.beta).add(input.clone());
|
) -> [Tensor<B, D>; 2] {
|
||||||
// fire
|
// check if input shape and mem shape are same. init to zero of input shape if not.
|
||||||
let spikes = heaviside::<B, D>(nxtMem, self.threshold, SurrogateFn::FastSigmoid);
|
if input.shape() != mem.shape() {
|
||||||
|
panic!(
|
||||||
|
"Input shape {} and memory shape {} are different",
|
||||||
|
input.shape(),
|
||||||
|
mem.shape()
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
self.hidden
|
// memory reset at current state.
|
||||||
.val()
|
let resetSignal = self.mem_reset(mem.clone());
|
||||||
.sub(spikes.clone().mul_scalar(self.threshold)); // requires update step fix. currently doesnt update.
|
|
||||||
spikes
|
// Decay memory and add input (B*v + X)
|
||||||
|
let dmem = mem.mul_scalar(self.beta).add(input);
|
||||||
|
|
||||||
|
// Reset memory based on reset method.
|
||||||
|
let outMem = self.step_subtract(dmem, resetSignal);
|
||||||
|
|
||||||
|
// Generate output spikes
|
||||||
|
let spikes = heaviside(outMem.clone(), self.threshold, SurrogateFn::FastSigmoid);
|
||||||
|
|
||||||
|
[spikes, outMem]
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mem_reset(&self, mem: Tensor<B, D>) -> Tensor<B, D> {
|
fn mem_reset<B: Backend, const D: usize>(&self, mem: Tensor<B, D>) -> Tensor<B, D> {
|
||||||
|
// Generates reset signal.
|
||||||
|
// Take diff of mem and threshold and pass through heaviside function.
|
||||||
mem.sub_scalar(self.threshold)
|
mem.sub_scalar(self.threshold)
|
||||||
|
.greater_elem(0.0)
|
||||||
|
.float()
|
||||||
|
.detach()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reset(&mut self) {
|
fn step_subtract<B: Backend, const D: usize>(
|
||||||
// self.hidden = self.hidden.zeros_like();
|
&self,
|
||||||
|
input: Tensor<B, D>,
|
||||||
|
reset: Tensor<B, D>,
|
||||||
|
) -> Tensor<B, D> {
|
||||||
|
input - reset.mul_scalar(self.threshold)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn init(&mut self, batch: usize, device: &B::Device) {
|
pub fn reset<B: Backend, const D: usize>(&self, mem: &Tensor<B, D>) -> Tensor<B, D> {
|
||||||
// self.hidden = Tensor::zeros(Shape::new([batch, self.neurons]), device);
|
Tensor::zeros_like(mem)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user