1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -5107,6 +5107,7 @@ name = "spikers"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"burn",
|
||||
"burn-autodiff",
|
||||
"rstest",
|
||||
"rstest_reuse",
|
||||
]
|
||||
|
||||
@@ -4,7 +4,8 @@ version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray"] }
|
||||
burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray", "cuda"] }
|
||||
burn-autodiff = "0.20.1"
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = "0.26.1"
|
||||
|
||||
41
examples/simple1d.rs
Normal file
41
examples/simple1d.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
#![allow(non_snake_case)]
|
||||
use burn::Tensor;
|
||||
//
|
||||
use burn::prelude::Backend;
|
||||
use spikers::encoders;
|
||||
|
||||
fn main() {
|
||||
let useGPU = {
|
||||
// use burn::backend::Wgpu;
|
||||
use burn::backend::wgpu::WgpuDevice;
|
||||
match WgpuDevice::DefaultDevice {
|
||||
WgpuDevice::Cpu => {
|
||||
println!("GPU Not Found, using NDArray Backend");
|
||||
false
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
};
|
||||
|
||||
if useGPU {
|
||||
use burn::backend::Wgpu;
|
||||
use burn::backend::wgpu::WgpuDevice;
|
||||
run::<Wgpu>(WgpuDevice::default());
|
||||
} else {
|
||||
use burn::backend::NdArray;
|
||||
use burn::backend::ndarray::NdArrayDevice;
|
||||
run::<NdArray>(NdArrayDevice::default());
|
||||
}
|
||||
// encode(data, numSteps);
|
||||
}
|
||||
|
||||
fn run<B: Backend>(device: B::Device) {
|
||||
println!("Using {:?} Device", B::name(&device));
|
||||
// let someTens = Tensor::<B, 1>::random(Shape::new([3]), Distribution::Normal(0., 1.), &device);
|
||||
let someTens = Tensor::<B, 1>::from_floats([0., 0.2, 0.4, 0.6, 0.8, 1.0], &device);
|
||||
println!("{}", encoders::rate::encode::<B, 1, 2>(&someTens, 10));
|
||||
println!(
|
||||
"{}",
|
||||
encoders::temporal::encode::<B, 1, 2>(&someTens, 10, true)
|
||||
);
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
#[allow(non_snake_case)]
|
||||
use burn::{Tensor, prelude::*, tensor::Distribution};
|
||||
#![allow(non_snake_case)]
|
||||
use burn::{Tensor, prelude::Backend, tensor::Distribution};
|
||||
|
||||
pub fn encode<B: Backend, const D: usize, const D2: usize>(
|
||||
data: &Tensor<B, D>,
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
#![allow(non_snake_case)]
|
||||
use burn::{Tensor, prelude::Backend};
|
||||
|
||||
// Must test this with python implementation, across dimensions.
|
||||
pub fn encode<B: Backend, const D: usize, const D2: usize>(
|
||||
data: &Tensor<B, D>,
|
||||
time_steps: usize,
|
||||
linear: bool,
|
||||
) -> Tensor<B, D2> {
|
||||
if time_steps < 1 {
|
||||
panic!("Time steps must be greater than 0");
|
||||
}
|
||||
// assumes the data is normalised.
|
||||
if !linear {
|
||||
panic!("Log conversion not implemented.")
|
||||
}
|
||||
let idxs = data
|
||||
.clone()
|
||||
.mul_scalar((time_steps - 1) as i32)
|
||||
.round()
|
||||
.unsqueeze_dim::<D2>(0)
|
||||
.repeat_dim(0, time_steps)
|
||||
.int();
|
||||
let out = idxs.zeros_like();
|
||||
let ones = out.ones_like();
|
||||
out.scatter(0, idxs, ones, burn::tensor::IndexingUpdateOp::Add)
|
||||
.float()
|
||||
.div_scalar(time_steps as u32)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod temporalEncoderTests {
    use burn::{
        Tensor,
        backend::{NdArray, ndarray::NdArrayDevice},
        tensor::{Float, Shape},
    };
    type B = NdArray;
    use super::encode;
    use rstest::rstest;
    use rstest_reuse::{self, *};

    /// Shared cartesian product of (time_steps, input size) cases.
    #[template]
    #[rstest]
    #[case(1, 1)]
    #[case(10, 1)]
    #[case(100, 1)]
    #[case(1, 10)]
    #[case(10, 10)]
    #[case(100, 10)]
    #[case(1, 100)]
    #[case(10, 100)]
    #[case(100, 100)]
    fn testShapeStepComb(#[case] t: usize, #[case] s: usize) {}

    /// 1-D all-zeros input: shape must be [t, s] and every value spikes
    /// once, so the sum equals the element count `s`.
    #[apply(testShapeStepComb)]
    fn test1DZeros(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 1>::zeros(Shape::new([s]), &device);
        let out = encode::<B, 1, 2>(&t0, t, true);
        assert_eq!(
            out.shape(),
            Shape::new([t, s]),
            // FIX: message previously printed the expected dims as [s, t].
            "Shape testing failed. Expected Shape: [{}, {}], got: {}",
            t,
            s,
            out.shape()
        );
        assert_eq!(out.sum().into_scalar(), s as f32);
    }

    /// 1-D all-ones input: shape [t, s], one spike per element.
    #[apply(testShapeStepComb)]
    fn test1DOnes(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 1, Float>::ones(Shape::new([s]), &device);
        let out = encode::<B, 1, 2>(&t0, t, true);
        assert_eq!(
            out.shape(),
            Shape::new([t, s]),
            "Shape testing failed. Expected Shape: [{}, {}], got: {}",
            t,
            s,
            out.shape()
        );
        assert_eq!(out.clone().int().sum().into_scalar(), (s) as i64);
    }

    /// 2-D all-zeros input: shape [t, s, s], one spike per element.
    #[apply(testShapeStepComb)]
    fn test2DZeros(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 2>::zeros(Shape::new([s; 2]), &device);
        let out = encode::<B, 2, 3>(&t0, t, true);
        assert_eq!(
            out.shape(),
            Shape::new([t, s, s]),
            // FIX: message previously had two placeholders (and swapped
            // s/t) for a three-dimensional expected shape.
            "Shape testing failed. Expected Shape: [{}, {}, {}], got: {}",
            t,
            s,
            s,
            out.shape()
        );
        assert_eq!(out.sum().into_scalar(), (s * s) as f32);
    }

    /// 2-D all-ones input: shape [t, s, s], one spike per element.
    #[apply(testShapeStepComb)]
    fn test2DOnes(#[case] t: usize, #[case] s: usize) {
        let device = NdArrayDevice::default();
        let t0 = Tensor::<B, 2, Float>::ones(Shape::new([s; 2]), &device);
        let out = encode::<B, 2, 3>(&t0, t, true);
        assert_eq!(
            out.shape(),
            Shape::new([t, s, s]),
            // FIX: message previously had two placeholders for a
            // three-dimensional expected shape.
            "Shape testing failed. Expected Shape: [{}, {}, {}], got: {}",
            t,
            s,
            s,
            out.shape()
        );
        assert_eq!(out.clone().int().sum().into_scalar(), (s * s) as i64);
    }
}
|
||||
|
||||
19
src/lib.rs
19
src/lib.rs
@@ -1,17 +1,4 @@
|
||||
#[allow(non_snake_case)]
|
||||
mod encoders;
|
||||
|
||||
/// Returns the sum of `left` and `right`.
pub fn add(left: u64, right: u64) -> u64 {
    right + left
}
|
||||
|
||||
/// Sanity tests for the crate's public helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn it_works() {
        assert_eq!(add(2, 2), 4);
    }
}
|
||||
pub mod encoders;
|
||||
pub mod neurons;
|
||||
pub mod surrogate;
|
||||
|
||||
14
src/neurons/leaky.rs
Normal file
14
src/neurons/leaky.rs
Normal file
@@ -0,0 +1,14 @@
|
||||
use burn::{Tensor, module::Module, prelude::Backend};
|
||||
|
||||
#[derive(Debug, Module, Clone, Default)]
|
||||
pub struct LIF;
|
||||
|
||||
impl LIF {
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
}
|
||||
|
||||
pub fn forward<B: Backend, const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
|
||||
input
|
||||
}
|
||||
}
|
||||
1
src/neurons/mod.rs
Normal file
1
src/neurons/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
mod leaky;
|
||||
1
src/surrogate/fast_sigmoid.rs
Normal file
1
src/surrogate/fast_sigmoid.rs
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
1
src/surrogate/mod.rs
Normal file
1
src/surrogate/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
mod fast_sigmoid;
|
||||
Reference in New Issue
Block a user