initial commit
This commit is contained in:
7095
Cargo.lock
generated
Normal file
7095
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
11
Cargo.toml
Normal file
11
Cargo.toml
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
[package]
|
||||||
|
name = "spikers"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
burn = { version = "0.20.1", features = ["wgpu", "std", "fusion", "ndarray"] }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
rstest = "0.26.1"
|
||||||
|
rstest_reuse = "0.7.0"
|
||||||
2
src/encoders/mod.rs
Normal file
2
src/encoders/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
pub mod rate;
|
||||||
|
pub mod temporal;
|
||||||
113
src/encoders/rate.rs
Normal file
113
src/encoders/rate.rs
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
#[allow(non_snake_case)]
|
||||||
|
use burn::{Tensor, prelude::*, tensor::Distribution};
|
||||||
|
|
||||||
|
pub fn encode<B: Backend, const D: usize, const D2: usize>(
|
||||||
|
data: &Tensor<B, D>,
|
||||||
|
numSteps: usize,
|
||||||
|
) -> Tensor<B, D2> {
|
||||||
|
if D2 != D + 1 {
|
||||||
|
panic!(
|
||||||
|
"Output Dims must be 1 more than input Dims. Given D2: {}, expected: {}",
|
||||||
|
D2,
|
||||||
|
D + 1
|
||||||
|
)
|
||||||
|
}
|
||||||
|
let newData = data.clone().unsqueeze_dim::<D2>(0).repeat_dim(0, numSteps);
|
||||||
|
let shape = newData.shape();
|
||||||
|
match numSteps {
|
||||||
|
..=0 => panic!("numSteps cannot be non-positive, received {numSteps}"),
|
||||||
|
_ => Tensor::<B, D2>::random(shape, Distribution::Uniform(0., 1.), &data.device())
|
||||||
|
.lower(newData)
|
||||||
|
.float(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod rateEncoderTests {
|
||||||
|
use burn::{
|
||||||
|
Tensor,
|
||||||
|
backend::{NdArray, ndarray::NdArrayDevice},
|
||||||
|
tensor::{Float, Shape},
|
||||||
|
};
|
||||||
|
type B = NdArray;
|
||||||
|
use super::encode;
|
||||||
|
use rstest::rstest;
|
||||||
|
use rstest_reuse::{self, *};
|
||||||
|
|
||||||
|
#[template]
|
||||||
|
#[rstest]
|
||||||
|
#[case(1, 1)]
|
||||||
|
#[case(10, 1)]
|
||||||
|
#[case(100, 1)]
|
||||||
|
#[case(1, 10)]
|
||||||
|
#[case(10, 10)]
|
||||||
|
#[case(100, 10)]
|
||||||
|
#[case(1, 100)]
|
||||||
|
#[case(10, 100)]
|
||||||
|
#[case(100, 100)]
|
||||||
|
|
||||||
|
fn testShapeStepComb(#[case] t: usize, #[case] s: usize) {}
|
||||||
|
|
||||||
|
#[apply(testShapeStepComb)]
|
||||||
|
fn test1DZeros(#[case] t: usize, #[case] s: usize) {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let t0 = Tensor::<B, 1>::zeros(Shape::new([s]), &device);
|
||||||
|
let out = encode::<B, 1, 2>(&t0, t);
|
||||||
|
assert_eq!(
|
||||||
|
out.shape(),
|
||||||
|
Shape::new([t, s]),
|
||||||
|
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||||
|
s,
|
||||||
|
t,
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert_eq!(out.sum().into_scalar(), 0.);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[apply(testShapeStepComb)]
|
||||||
|
fn test1DOnes(#[case] t: usize, #[case] s: usize) {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let t0 = Tensor::<B, 1, Float>::ones(Shape::new([s]), &device);
|
||||||
|
let out = encode::<B, 1, 2>(&t0, t);
|
||||||
|
assert_eq!(
|
||||||
|
out.shape(),
|
||||||
|
Shape::new([t, s]),
|
||||||
|
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||||
|
t,
|
||||||
|
s,
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert_eq!(out.clone().int().sum().into_scalar(), (t * s) as i64);
|
||||||
|
}
|
||||||
|
#[apply(testShapeStepComb)]
|
||||||
|
fn test2DZeros(#[case] t: usize, #[case] s: usize) {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let t0 = Tensor::<B, 2>::zeros(Shape::new([s; 2]), &device);
|
||||||
|
let out = encode::<B, 2, 3>(&t0, t);
|
||||||
|
assert_eq!(
|
||||||
|
out.shape(),
|
||||||
|
Shape::new([t, s, s]),
|
||||||
|
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||||
|
s,
|
||||||
|
t,
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert_eq!(out.sum().into_scalar(), 0.);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[apply(testShapeStepComb)]
|
||||||
|
fn test2DOnes(#[case] t: usize, #[case] s: usize) {
|
||||||
|
let device = NdArrayDevice::default();
|
||||||
|
let t0 = Tensor::<B, 2, Float>::ones(Shape::new([s; 2]), &device);
|
||||||
|
let out = encode::<B, 2, 3>(&t0, t);
|
||||||
|
assert_eq!(
|
||||||
|
out.shape(),
|
||||||
|
Shape::new([t, s, s]),
|
||||||
|
"Shape testing failed. Expected Shape: [{}, {}], got: {}",
|
||||||
|
t,
|
||||||
|
s,
|
||||||
|
out.shape()
|
||||||
|
);
|
||||||
|
assert_eq!(out.clone().int().sum().into_scalar(), (t * s * s) as i64);
|
||||||
|
}
|
||||||
|
}
|
||||||
0
src/encoders/temporal.rs
Normal file
0
src/encoders/temporal.rs
Normal file
17
src/lib.rs
Normal file
17
src/lib.rs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
#[allow(non_snake_case)]
|
||||||
|
mod encoders;
|
||||||
|
|
||||||
|
pub fn add(left: u64, right: u64) -> u64 {
|
||||||
|
left + right
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn it_works() {
|
||||||
|
let result = add(2, 2);
|
||||||
|
assert_eq!(result, 4);
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user