1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -5107,6 +5107,7 @@ name = "spikers"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"burn",
|
"burn",
|
||||||
|
"burn-autodiff",
|
||||||
"rstest",
|
"rstest",
|
||||||
"rstest_reuse",
|
"rstest_reuse",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -4,7 +4,8 @@ version = "0.1.0"
|
|||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray"] }
|
burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray", "cuda"] }
|
||||||
|
burn-autodiff = "0.20.1"
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
rstest = "0.26.1"
|
rstest = "0.26.1"
|
||||||
|
|||||||
41
examples/simple1d.rs
Normal file
41
examples/simple1d.rs
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
#![allow(non_snake_case)]
|
||||||
|
use burn::Tensor;
|
||||||
|
//
|
||||||
|
use burn::prelude::Backend;
|
||||||
|
use spikers::encoders;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let useGPU = {
|
||||||
|
// use burn::backend::Wgpu;
|
||||||
|
use burn::backend::wgpu::WgpuDevice;
|
||||||
|
match WgpuDevice::DefaultDevice {
|
||||||
|
WgpuDevice::Cpu => {
|
||||||
|
println!("GPU Not Found, using NDArray Backend");
|
||||||
|
false
|
||||||
|
}
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if useGPU {
|
||||||
|
use burn::backend::Wgpu;
|
||||||
|
use burn::backend::wgpu::WgpuDevice;
|
||||||
|
run::<Wgpu>(WgpuDevice::default());
|
||||||
|
} else {
|
||||||
|
use burn::backend::NdArray;
|
||||||
|
use burn::backend::ndarray::NdArrayDevice;
|
||||||
|
run::<NdArray>(NdArrayDevice::default());
|
||||||
|
}
|
||||||
|
// encode(data, numSteps);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run<B: Backend>(device: B::Device) {
|
||||||
|
println!("Using {:?} Device", B::name(&device));
|
||||||
|
// let someTens = Tensor::<B, 1>::random(Shape::new([3]), Distribution::Normal(0., 1.), &device);
|
||||||
|
let someTens = Tensor::<B, 1>::from_floats([0., 0.2, 0.4, 0.6, 0.8, 1.0], &device);
|
||||||
|
println!("{}", encoders::rate::encode::<B, 1, 2>(&someTens, 10));
|
||||||
|
println!(
|
||||||
|
"{}",
|
||||||
|
encoders::temporal::encode::<B, 1, 2>(&someTens, 10, true)
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
#[allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
use burn::{Tensor, prelude::*, tensor::Distribution};
|
use burn::{Tensor, prelude::Backend, tensor::Distribution};
|
||||||
|
|
||||||
pub fn encode<B: Backend, const D: usize, const D2: usize>(
|
pub fn encode<B: Backend, const D: usize, const D2: usize>(
|
||||||
data: &Tensor<B, D>,
|
data: &Tensor<B, D>,
|
||||||
|
|||||||
@@ -0,0 +1,120 @@
|
|||||||
|
#![allow(non_snake_case)]
|
||||||
|
use burn::{Tensor, prelude::Backend};
|
||||||
|
|
||||||
|
// Must test this with python implementation, across dimensions.
|
||||||
|
pub fn encode<B: Backend, const D: usize, const D2: usize>(
|
||||||
|
data: &Tensor<B, D>,
|
||||||
|
time_steps: usize,
|
||||||
|
linear: bool,
|
||||||
|
) -> Tensor<B, D2> {
|
||||||
|
if time_steps < 1 {
|
||||||
|
panic!("Time steps must be greater than 0");
|
||||||
|
}
|
||||||
|
// assumes the data is normalised.
|
||||||
|
if !linear {
|
||||||
|
panic!("Log conversion not implemented.")
|
||||||
|
}
|
||||||
|
let idxs = data
|
||||||
|
.clone()
|
||||||
|
.mul_scalar((time_steps - 1) as i32)
|
||||||
|
.round()
|
||||||
|
.unsqueeze_dim::<D2>(0)
|
||||||
|
.repeat_dim(0, time_steps)
|
||||||
|
.int();
|
||||||
|
let out = idxs.zeros_like();
|
||||||
|
let ones = out.ones_like();
|
||||||
|
out.scatter(0, idxs, ones, burn::tensor::IndexingUpdateOp::Add)
|
||||||
|
.float()
|
||||||
|
.div_scalar(time_steps as u32)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod temporalEncoderTests {
|
||||||
|
|
||||||
|
use burn::{
|
||||||
|
Tensor,
|
||||||
|
backend::{NdArray, ndarray::NdArrayDevice},
|
||||||
|
tensor::{Float, Shape},
|
||||||
|
};
|
||||||
|
type B = NdArray;
|
||||||
|
use super::encode;
|
||||||
|
use rstest::rstest;
|
||||||
|
use rstest_reuse::{self, *};
|
||||||
|
|
||||||
|
#[template]
|
||||||
|
#[rstest]
|
||||||
|
#[case(1, 1)]
|
||||||
|
#[case(10, 1)]
|
||||||
|
#[case(100, 1)]
|
||||||
|
#[case(1, 10)]
|
||||||
|
#[case(10, 10)]
|
||||||
|
#[case(100, 10)]
|
||||||
|
#[case(1, 100)]
|
||||||
|
#[case(10, 100)]
|
||||||
|
#[case(100, 100)]
|
||||||
|
|
||||||
|
fn testShapeStepComb(#[case] t: usize, #[case] s: usize) {}
|
||||||
|
|
||||||
|
#[apply(testShapeStepComb)]
|
||||||
|
fn test1DZeros(#[case] t: usize, #[case] s: usize) {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let t0 = Tensor::<B, 1>::zeros(Shape::new([s]), &device);
|
||||||
|
let out = encode::<B, 1, 2>(&t0, t, true);
|
||||||
|
assert_eq!(
|
||||||
|
out.shape(),
|
||||||
|
Shape::new([t, s]),
|
||||||
|
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||||
|
s,
|
||||||
|
t,
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert_eq!(out.sum().into_scalar(), s as f32);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[apply(testShapeStepComb)]
|
||||||
|
fn test1DOnes(#[case] t: usize, #[case] s: usize) {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let t0 = Tensor::<B, 1, Float>::ones(Shape::new([s]), &device);
|
||||||
|
let out = encode::<B, 1, 2>(&t0, t, true);
|
||||||
|
assert_eq!(
|
||||||
|
out.shape(),
|
||||||
|
Shape::new([t, s]),
|
||||||
|
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||||
|
t,
|
||||||
|
s,
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert_eq!(out.clone().int().sum().into_scalar(), (s) as i64);
|
||||||
|
}
|
||||||
|
#[apply(testShapeStepComb)]
|
||||||
|
fn test2DZeros(#[case] t: usize, #[case] s: usize) {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let t0 = Tensor::<B, 2>::zeros(Shape::new([s; 2]), &device);
|
||||||
|
let out = encode::<B, 2, 3>(&t0, t, true);
|
||||||
|
assert_eq!(
|
||||||
|
out.shape(),
|
||||||
|
Shape::new([t, s, s]),
|
||||||
|
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||||
|
s,
|
||||||
|
t,
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert_eq!(out.sum().into_scalar(), (s * s) as f32);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[apply(testShapeStepComb)]
|
||||||
|
fn test2DOnes(#[case] t: usize, #[case] s: usize) {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let t0 = Tensor::<B, 2, Float>::ones(Shape::new([s; 2]), &device);
|
||||||
|
let out = encode::<B, 2, 3>(&t0, t, true);
|
||||||
|
assert_eq!(
|
||||||
|
out.shape(),
|
||||||
|
Shape::new([t, s, s]),
|
||||||
|
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||||
|
t,
|
||||||
|
s,
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert_eq!(out.clone().int().sum().into_scalar(), (s * s) as i64);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
19
src/lib.rs
19
src/lib.rs
@@ -1,17 +1,4 @@
|
|||||||
#[allow(non_snake_case)]
|
#[allow(non_snake_case)]
|
||||||
mod encoders;
|
pub mod encoders;
|
||||||
|
pub mod neurons;
|
||||||
pub fn add(left: u64, right: u64) -> u64 {
|
pub mod surrogate;
|
||||||
left + right
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn it_works() {
|
|
||||||
let result = add(2, 2);
|
|
||||||
assert_eq!(result, 4);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
14
src/neurons/leaky.rs
Normal file
14
src/neurons/leaky.rs
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
use burn::{Tensor, module::Module, prelude::Backend};
|
||||||
|
|
||||||
|
#[derive(Debug, Module, Clone, Default)]
|
||||||
|
pub struct LIF;
|
||||||
|
|
||||||
|
impl LIF {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||||
|
input
|
||||||
|
}
|
||||||
|
}
|
||||||
1
src/neurons/mod.rs
Normal file
1
src/neurons/mod.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
mod leaky;
|
||||||
1
src/surrogate/fast_sigmoid.rs
Normal file
1
src/surrogate/fast_sigmoid.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
|
||||||
1
src/surrogate/mod.rs
Normal file
1
src/surrogate/mod.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
mod fast_sigmoid;
|
||||||
Reference in New Issue
Block a user