added atan, hside, lif structure. lif update broken, needs fix.
Some checks failed
/ Tests (push) Failing after 3m46s
Some checks failed
/ Tests (push) Failing after 3m46s
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
#[allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
pub mod encoders;
|
pub mod encoders;
|
||||||
pub mod neurons;
|
pub mod neurons;
|
||||||
pub mod surrogate;
|
pub mod surrogate;
|
||||||
|
|||||||
13
src/neurons/heaviside.rs
Normal file
13
src/neurons/heaviside.rs
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
#![allow(non_snake_case)]
|
||||||
|
use crate::surrogate::SurrogateFn;
|
||||||
|
use burn::prelude::*;
|
||||||
|
|
||||||
|
pub fn heaviside<B: Backend, const D: usize>(
|
||||||
|
input: Tensor<B, D>,
|
||||||
|
thres: f32,
|
||||||
|
s: SurrogateFn,
|
||||||
|
) -> Tensor<B, D> {
|
||||||
|
let y = input.clone().greater_equal_elem(thres).float();
|
||||||
|
let yb = s.forward(input);
|
||||||
|
(y - yb.clone().detach()).detach() + yb
|
||||||
|
}
|
||||||
@@ -1,24 +1,64 @@
|
|||||||
use burn::{Tensor, module::Module, prelude::Backend, tensor::backend::AutodiffBackend};
|
#![allow(non_snake_case)]
|
||||||
|
|
||||||
|
use crate::neurons::heaviside::heaviside;
|
||||||
|
use crate::surrogate::SurrogateFn;
|
||||||
|
use burn::module::Param;
|
||||||
|
use burn::prelude::*;
|
||||||
|
use burn::{Tensor, module::Module, prelude::Backend};
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
pub struct LIFConfig {
|
||||||
|
beta: f32,
|
||||||
|
threshold: f32,
|
||||||
|
neurons: usize,
|
||||||
|
spikeGrad: SurrogateFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LIFConfig {
|
||||||
|
pub fn init<B: Backend, const D: usize>(&self, device: &B::Device) -> LIF<B, D> {
|
||||||
|
let initMem = Param::from_tensor(Tensor::<B, D>::zeros([1, self.neurons], device));
|
||||||
|
LIF {
|
||||||
|
beta: self.beta,
|
||||||
|
threshold: self.threshold,
|
||||||
|
neurons: self.neurons,
|
||||||
|
hidden: initMem,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: tensor cloning and its lifecycle is probably wrong, may cause comp graph to drop. Refer burn example to find proper tensor handling..
|
||||||
#[derive(Debug, Module)]
|
#[derive(Debug, Module)]
|
||||||
pub struct LIF<B: Backend> {
|
pub struct LIF<B: Backend, const D: usize> {
|
||||||
pub v: Tensor<B, 1>,
|
beta: f32,
|
||||||
|
threshold: f32,
|
||||||
|
neurons: usize,
|
||||||
|
pub hidden: Param<Tensor<B, D>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> LIF<B> {
|
impl<B: Backend, const D: usize> LIF<B, D> {
|
||||||
pub fn new(device: &B::Device) -> Self {
|
pub fn forward(&mut self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||||
Self {
|
// leaky and Integrate
|
||||||
v: Tensor::zeros([1], device),
|
//
|
||||||
}
|
let curMem = self.hidden.val();
|
||||||
}
|
let nxtMem = curMem.mul_scalar(self.beta).add(input.clone());
|
||||||
|
// fire
|
||||||
|
let spikes = heaviside::<B, D>(nxtMem, self.threshold, SurrogateFn::FastSigmoid);
|
||||||
|
|
||||||
|
self.hidden
|
||||||
|
.val()
|
||||||
|
.sub(spikes.clone().mul_scalar(self.threshold)); // requires update step fix. currently doesnt update.
|
||||||
|
spikes
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: AutodiffBackend> LIF<B> {
|
fn mem_reset(&self, mem: Tensor<B, D>) -> Tensor<B, D> {
|
||||||
pub fn forward(&mut self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
mem.sub_scalar(self.threshold)
|
||||||
input
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn backward(&mut self, x: Tensor<B, 2>) -> Tensor<B, 2> {
|
pub fn reset(&mut self) {
|
||||||
x
|
// self.hidden = self.hidden.zeros_like();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn init(&mut self, batch: usize, device: &B::Device) {
|
||||||
|
// self.hidden = Tensor::zeros(Shape::new([batch, self.neurons]), device);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +1,2 @@
|
|||||||
|
mod heaviside;
|
||||||
mod leaky;
|
mod leaky;
|
||||||
|
|||||||
5
src/surrogate/atan.rs
Normal file
5
src/surrogate/atan.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
use burn::prelude::*;
|
||||||
|
|
||||||
|
pub fn atan<B: Backend, const D: usize>(input: Tensor<B, D>) -> Tensor<B, D> {
|
||||||
|
(1.0 / std::f32::consts::PI) * (input.mul_scalar(std::f32::consts::PI).atan())
|
||||||
|
}
|
||||||
@@ -1 +1,5 @@
|
|||||||
|
use burn::prelude::*;
|
||||||
|
|
||||||
|
pub fn fast_sigmoid<B: Backend, const D: usize>(input: Tensor<B, D>) -> Tensor<B, D> {
|
||||||
|
input.clone() / (1 + input.abs())
|
||||||
|
}
|
||||||
|
|||||||
@@ -1 +1,21 @@
|
|||||||
|
mod atan;
|
||||||
mod fast_sigmoid;
|
mod fast_sigmoid;
|
||||||
|
|
||||||
|
use atan::atan;
|
||||||
|
use burn::{Tensor, config::Config, tensor::backend::Backend};
|
||||||
|
use fast_sigmoid::fast_sigmoid;
|
||||||
|
|
||||||
|
#[derive(Config, Debug)]
|
||||||
|
pub enum SurrogateFn {
|
||||||
|
FastSigmoid,
|
||||||
|
ATan,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SurrogateFn {
|
||||||
|
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||||
|
match self {
|
||||||
|
SurrogateFn::FastSigmoid => fast_sigmoid(input),
|
||||||
|
SurrogateFn::ATan => atan(input),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user