Skip to content

Commit 53b7e8e

Browse files
committed
Change Filtfilt to use macros for 1 to 6 dimensions
Similar to lfilter, using macros to help generalize to multiple dimensions. Correspondingly add a test case for 1-dim input.
1 parent bd5506b commit 53b7e8e

File tree

1 file changed

+143
-83
lines changed

1 file changed

+143
-83
lines changed

sci-rs/src/signal/filter/filtfilt.rs

Lines changed: 143 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -319,96 +319,95 @@ where
319319
}
320320
}
321321

322-
impl<S, T> FiltFilt<S, T, 2> for ArrayBase<S, Dim<[Ix; 2]>>
323-
where
324-
S: Data<Elem = T>,
325-
{
326-
fn filtfilt<'a>(
327-
b: ArrayView1<'a, T>,
328-
a: ArrayView1<'a, T>,
329-
x: Self,
330-
axis: Option<isize>,
331-
padding: Option<FiltFiltPad>,
332-
) -> Result<Array<T, Dim<[Ix; 2]>>>
333-
where
334-
T: Clone + Add<T, Output = T> + Sub<T, Output = T> + num_traits::One,
335-
Dim<[Ix; 2]>: Dimension,
336-
T: nalgebra::RealField + Copy + core::iter::Sum, // From lfilter_zi_dyn
337-
{
338-
let axis = {
339-
if axis.is_some_and(|axis| {
340-
!(if axis < 0 {
341-
axis.unsigned_abs() <= 2
342-
} else {
343-
axis.unsigned_abs() < 2
344-
})
345-
}) {
346-
return Err(Error::InvalidArg {
322+
macro_rules! filtfilt_for_dim {
323+
($N: literal) => {
324+
impl<S, T> FiltFilt<S, T, $N> for ArrayBase<S, Dim<[Ix; $N]>>
325+
where
326+
S: Data<Elem = T>,
327+
{
328+
fn filtfilt<'a>(
329+
b: ArrayView1<'a, T>,
330+
a: ArrayView1<'a, T>,
331+
x: Self,
332+
axis: Option<isize>,
333+
padding: Option<FiltFiltPad>,
334+
) -> Result<Array<T, Dim<[Ix; $N]>>>
335+
where
336+
T: Clone + Add<T, Output = T> + Sub<T, Output = T> + num_traits::One,
337+
Dim<[Ix; $N]>: Dimension,
338+
T: nalgebra::RealField + Copy + core::iter::Sum, // From lfilter_zi_dyn
339+
{
340+
let axis = check_and_get_axis_dyn(axis, &x).map_err(|_| Error::InvalidArg {
347341
arg: "axis".into(),
348342
reason: "index out of range.".into(),
349-
});
350-
}
351-
352-
// We make a best effort to convert into appropriate usize.
353-
let axis: isize = axis.unwrap_or(-1);
354-
if axis >= 0 {
355-
axis.unsigned_abs()
356-
} else {
357-
a.ndim()
358-
.checked_add_signed(axis)
359-
.expect("Invalid add to `axis` option")
360-
}
361-
};
362-
let (edge, ext) = validate_pad(padding, x.view(), axis, a.len().max(b.len()))?;
363-
364-
let zi: Array<T, Dim<[Ix; 2]>> = {
365-
let mut zi = lfilter_zi_dyn(b.as_slice().unwrap(), a.as_slice().unwrap());
366-
let mut sh = [1; 2];
367-
sh[axis] = zi.len(); // .size()?
368-
369-
zi.into_shape_with_order(sh)
370-
.map_err(|_| Error::InvalidArg {
371-
arg: "b/a".into(),
372-
reason: "Generated lfilter_zi from given b or a resulted in an error.".into(),
373-
})?
374-
};
375-
let (y, _) = {
376-
let x0 = axis_slice_unsafe(&ext, None, Some(1), None, axis, ext.ndim())?;
377-
let zi_arg = zi.clone() * x0; // Is it possible to not need to clone?
378-
ArrayBase::<_, Dim<[Ix; 2]>>::lfilter(
379-
b.view(),
380-
a.view(),
381-
ext,
382-
Some(axis as _),
383-
Some(zi_arg.view()),
384-
)?
385-
};
386-
387-
let (y, _) = {
388-
let y0 = axis_slice_unsafe(&y, Some(-1), None, None, axis, y.ndim())?;
389-
let zi_arg = zi * y0; // originally zi * y0
390-
ArrayView::<T, Dim<[Ix; 2]>>::lfilter(
391-
b.view(),
392-
a.view(),
393-
unsafe { axis_reverse_unsafe(&y, axis, 2) },
394-
Some(axis as _),
395-
Some(zi_arg.view()),
396-
)?
397-
};
343+
})?;
344+
let (edge, ext) = validate_pad(padding, x.view(), axis, a.len().max(b.len()))?;
345+
346+
let zi: Array<T, Dim<[Ix; $N]>> = {
347+
let mut zi = lfilter_zi_dyn(b.as_slice().unwrap(), a.as_slice().unwrap());
348+
let mut sh = [1; $N];
349+
sh[axis] = zi.len(); // .size()?
350+
351+
zi.into_shape_with_order(sh)
352+
.map_err(|_| Error::InvalidArg {
353+
arg: "b/a".into(),
354+
reason: "Generated lfilter_zi from given b or a resulted in an error."
355+
.into(),
356+
})?
357+
};
358+
let (y, _) = {
359+
let x0 = axis_slice_unsafe(&ext, None, Some(1), None, axis, ext.ndim())?;
360+
let zi_arg = zi.clone() * x0; // Is it possible to not need to clone?
361+
ArrayBase::<_, Dim<[Ix; $N]>>::lfilter(
362+
b.view(),
363+
a.view(),
364+
ext,
365+
Some(axis as _),
366+
Some(zi_arg.view()),
367+
)?
368+
};
398369

399-
let y = unsafe { axis_reverse_unsafe(&y, axis, 2) };
370+
let (y, _) = {
371+
let y0 = axis_slice_unsafe(&y, Some(-1), None, None, axis, y.ndim())?;
372+
let zi_arg = zi * y0; // originally zi * y0
373+
ArrayView::<T, Dim<[Ix; $N]>>::lfilter(
374+
b.view(),
375+
a.view(),
376+
unsafe { axis_reverse_unsafe(&y, axis, $N) },
377+
Some(axis as _),
378+
Some(zi_arg.view()),
379+
)?
380+
};
400381

401-
if edge > 0 {
402-
let y = unsafe {
403-
axis_slice_unsafe(&y, Some(edge as _), Some(-(edge as isize)), None, axis, 2)
404-
}?;
405-
Ok(y.to_owned())
406-
} else {
407-
Ok(y.to_owned())
382+
let y = unsafe { axis_reverse_unsafe(&y, axis, $N) };
383+
384+
if edge > 0 {
385+
let y = unsafe {
386+
axis_slice_unsafe(
387+
&y,
388+
Some(edge as _),
389+
Some(-(edge as isize)),
390+
None,
391+
axis,
392+
$N,
393+
)
394+
}?;
395+
Ok(y.to_owned())
396+
} else {
397+
Ok(y.to_owned())
398+
}
399+
}
408400
}
409-
}
401+
};
410402
}
411403

404+
filtfilt_for_dim!(1);
405+
filtfilt_for_dim!(2);
406+
filtfilt_for_dim!(3);
407+
filtfilt_for_dim!(4);
408+
filtfilt_for_dim!(5);
409+
filtfilt_for_dim!(6);
410+
412411
#[cfg(test)]
413412
mod test {
414413
use super::*;
@@ -596,6 +595,67 @@ mod test {
596595
assert_eq!(result, expected);
597596
}
598597

598+
/// Tests that filtfilt works with default padding with a FIR filter.
599+
#[test]
600+
fn filtfilt_1d_fir_default_pad_small() {
601+
let x =
602+
array![0., 0.6389613, 0.890577, 0.9830277, 0.9992535, 0.9756868, 0.9304659, 0.8734051];
603+
let b = array![0.5, 0.5];
604+
let a = array![1.];
605+
let result = Array::<_, Dim<[_; 1]>>::filtfilt(
606+
b.view(),
607+
a.view(),
608+
x,
609+
None,
610+
Some(FiltFiltPad::default()),
611+
)
612+
.expect("Could not filtfilt none_pad");
613+
let expected =
614+
array![0., 0.5421249, 0.8507858, 0.9639715, 0.9893054, 0.9702733, 0.9275059, 0.8734051];
615+
Zip::from(&result)
616+
.and(&expected)
617+
.for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6, epsilon = 1e-10));
618+
}
619+
620+
/// Tests that filtfilt works with default padding with a FIR filter.
621+
#[test]
622+
fn filtfilt_1d_fir_default_pad_big() {
623+
// n_elems = 25
624+
// x = np.sin(np.log(np.linspace(1., n_elems, n_elems)))
625+
// b = firwin(8, 0.2)
626+
// a = np.array([1.])
627+
// expected = filtfilt(b, a, x)
628+
629+
let x = array![
630+
0., 0.6389613, 0.890577, 0.9830277, 0.9992535, 0.9756868, 0.9304659, 0.8734051,
631+
0.8101266, 0.7439803, 0.6770137, 0.6104955, 0.5452131, 0.481649, 0.4200881, 0.3606866,
632+
0.3035148, 0.2485867, 0.1958789, 0.1453437, 0.0969178, 0.0505287, 0.0060984,
633+
-0.0364531, -0.0772063
634+
];
635+
let b = array![
636+
0.0087547, 0.0479489, 0.1640244, 0.279272, 0.279272, 0.1640244, 0.0479489, 0.0087547
637+
];
638+
let a = array![1.];
639+
let result = Array::<_, Dim<[_; 1]>>::filtfilt(
640+
b.view(),
641+
a.view(),
642+
x,
643+
None,
644+
Some(FiltFiltPad::default()),
645+
)
646+
.expect("Could not filtfilt none_pad");
647+
let expected = array![
648+
0., 0.3503788, 0.6340265, 0.8172474, 0.9055143, 0.9253101, 0.9036955, 0.8594274,
649+
0.8033733, 0.7414859, 0.6771011, 0.6121664, 0.5478511, 0.4848631, 0.4236259, 0.3643826,
650+
0.3072603, 0.25231, 0.1995331, 0.1488972, 0.1003401, 0.0537529, 0.0089268, -0.0345238,
651+
-0.0772063
652+
];
653+
654+
Zip::from(&result)
655+
.and(&expected)
656+
.for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-5, epsilon = 1e-10));
657+
}
658+
599659
/// Tests that filtfilt works with no padding with a FIR filter.
600660
#[test]
601661
fn filtfilt_2d_fir_none_pad() {

0 commit comments

Comments
 (0)