@@ -319,96 +319,95 @@ where
319319 }
320320}
321321
322- impl < S , T > FiltFilt < S , T , 2 > for ArrayBase < S , Dim < [ Ix ; 2 ] > >
323- where
324- S : Data < Elem = T > ,
325- {
326- fn filtfilt < ' a > (
327- b : ArrayView1 < ' a , T > ,
328- a : ArrayView1 < ' a , T > ,
329- x : Self ,
330- axis : Option < isize > ,
331- padding : Option < FiltFiltPad > ,
332- ) -> Result < Array < T , Dim < [ Ix ; 2 ] > > >
333- where
334- T : Clone + Add < T , Output = T > + Sub < T , Output = T > + num_traits:: One ,
335- Dim < [ Ix ; 2 ] > : Dimension ,
336- T : nalgebra:: RealField + Copy + core:: iter:: Sum , // From lfilter_zi_dyn
337- {
338- let axis = {
339- if axis. is_some_and ( |axis| {
340- !( if axis < 0 {
341- axis. unsigned_abs ( ) <= 2
342- } else {
343- axis. unsigned_abs ( ) < 2
344- } )
345- } ) {
346- return Err ( Error :: InvalidArg {
322+ macro_rules! filtfilt_for_dim {
323+ ( $N: literal) => {
324+ impl <S , T > FiltFilt <S , T , $N> for ArrayBase <S , Dim <[ Ix ; $N] >>
325+ where
326+ S : Data <Elem = T >,
327+ {
328+ fn filtfilt<' a>(
329+ b: ArrayView1 <' a, T >,
330+ a: ArrayView1 <' a, T >,
331+ x: Self ,
332+ axis: Option <isize >,
333+ padding: Option <FiltFiltPad >,
334+ ) -> Result <Array <T , Dim <[ Ix ; $N] >>>
335+ where
336+ T : Clone + Add <T , Output = T > + Sub <T , Output = T > + num_traits:: One ,
337+ Dim <[ Ix ; $N] >: Dimension ,
338+ T : nalgebra:: RealField + Copy + core:: iter:: Sum , // From lfilter_zi_dyn
339+ {
340+ let axis = check_and_get_axis_dyn( axis, & x) . map_err( |_| Error :: InvalidArg {
347341 arg: "axis" . into( ) ,
348342 reason: "index out of range." . into( ) ,
349- } ) ;
350- }
351-
352- // We make a best effort to convert into appropriate usize.
353- let axis: isize = axis. unwrap_or ( -1 ) ;
354- if axis >= 0 {
355- axis. unsigned_abs ( )
356- } else {
357- a. ndim ( )
358- . checked_add_signed ( axis)
359- . expect ( "Invalid add to `axis` option" )
360- }
361- } ;
362- let ( edge, ext) = validate_pad ( padding, x. view ( ) , axis, a. len ( ) . max ( b. len ( ) ) ) ?;
363-
364- let zi: Array < T , Dim < [ Ix ; 2 ] > > = {
365- let mut zi = lfilter_zi_dyn ( b. as_slice ( ) . unwrap ( ) , a. as_slice ( ) . unwrap ( ) ) ;
366- let mut sh = [ 1 ; 2 ] ;
367- sh[ axis] = zi. len ( ) ; // .size()?
368-
369- zi. into_shape_with_order ( sh)
370- . map_err ( |_| Error :: InvalidArg {
371- arg : "b/a" . into ( ) ,
372- reason : "Generated lfilter_zi from given b or a resulted in an error." . into ( ) ,
373- } ) ?
374- } ;
375- let ( y, _) = {
376- let x0 = axis_slice_unsafe ( & ext, None , Some ( 1 ) , None , axis, ext. ndim ( ) ) ?;
377- let zi_arg = zi. clone ( ) * x0; // Is it possible to not need to clone?
378- ArrayBase :: < _ , Dim < [ Ix ; 2 ] > > :: lfilter (
379- b. view ( ) ,
380- a. view ( ) ,
381- ext,
382- Some ( axis as _ ) ,
383- Some ( zi_arg. view ( ) ) ,
384- ) ?
385- } ;
386-
387- let ( y, _) = {
388- let y0 = axis_slice_unsafe ( & y, Some ( -1 ) , None , None , axis, y. ndim ( ) ) ?;
389- let zi_arg = zi * y0; // originally zi * y0
390- ArrayView :: < T , Dim < [ Ix ; 2 ] > > :: lfilter (
391- b. view ( ) ,
392- a. view ( ) ,
393- unsafe { axis_reverse_unsafe ( & y, axis, 2 ) } ,
394- Some ( axis as _ ) ,
395- Some ( zi_arg. view ( ) ) ,
396- ) ?
397- } ;
343+ } ) ?;
344+ let ( edge, ext) = validate_pad( padding, x. view( ) , axis, a. len( ) . max( b. len( ) ) ) ?;
345+
346+ let zi: Array <T , Dim <[ Ix ; $N] >> = {
347+ let mut zi = lfilter_zi_dyn( b. as_slice( ) . unwrap( ) , a. as_slice( ) . unwrap( ) ) ;
348+ let mut sh = [ 1 ; $N] ;
349+ sh[ axis] = zi. len( ) ; // .size()?
350+
351+ zi. into_shape_with_order( sh)
352+ . map_err( |_| Error :: InvalidArg {
353+ arg: "b/a" . into( ) ,
354+ reason: "Generated lfilter_zi from given b or a resulted in an error."
355+ . into( ) ,
356+ } ) ?
357+ } ;
358+ let ( y, _) = {
359+ let x0 = axis_slice_unsafe( & ext, None , Some ( 1 ) , None , axis, ext. ndim( ) ) ?;
360+ let zi_arg = zi. clone( ) * x0; // Is it possible to not need to clone?
361+ ArrayBase :: <_, Dim <[ Ix ; $N] >>:: lfilter(
362+ b. view( ) ,
363+ a. view( ) ,
364+ ext,
365+ Some ( axis as _) ,
366+ Some ( zi_arg. view( ) ) ,
367+ ) ?
368+ } ;
398369
399- let y = unsafe { axis_reverse_unsafe ( & y, axis, 2 ) } ;
370+ let ( y, _) = {
371+ let y0 = axis_slice_unsafe( & y, Some ( -1 ) , None , None , axis, y. ndim( ) ) ?;
372+ let zi_arg = zi * y0; // originally zi * y0
373+ ArrayView :: <T , Dim <[ Ix ; $N] >>:: lfilter(
374+ b. view( ) ,
375+ a. view( ) ,
376+ unsafe { axis_reverse_unsafe( & y, axis, $N) } ,
377+ Some ( axis as _) ,
378+ Some ( zi_arg. view( ) ) ,
379+ ) ?
380+ } ;
400381
401- if edge > 0 {
402- let y = unsafe {
403- axis_slice_unsafe ( & y, Some ( edge as _ ) , Some ( -( edge as isize ) ) , None , axis, 2 )
404- } ?;
405- Ok ( y. to_owned ( ) )
406- } else {
407- Ok ( y. to_owned ( ) )
382+ let y = unsafe { axis_reverse_unsafe( & y, axis, $N) } ;
383+
384+ if edge > 0 {
385+ let y = unsafe {
386+ axis_slice_unsafe(
387+ & y,
388+ Some ( edge as _) ,
389+ Some ( -( edge as isize ) ) ,
390+ None ,
391+ axis,
392+ $N,
393+ )
394+ } ?;
395+ Ok ( y. to_owned( ) )
396+ } else {
397+ Ok ( y. to_owned( ) )
398+ }
399+ }
408400 }
409- }
401+ } ;
410402}
411403
404+ filtfilt_for_dim ! ( 1 ) ;
405+ filtfilt_for_dim ! ( 2 ) ;
406+ filtfilt_for_dim ! ( 3 ) ;
407+ filtfilt_for_dim ! ( 4 ) ;
408+ filtfilt_for_dim ! ( 5 ) ;
409+ filtfilt_for_dim ! ( 6 ) ;
410+
412411#[ cfg( test) ]
413412mod test {
414413 use super :: * ;
@@ -596,6 +595,67 @@ mod test {
596595 assert_eq ! ( result, expected) ;
597596 }
598597
598+ /// Tests that filtfilt works with default padding with a FIR filter.
599+ #[ test]
600+ fn filtfilt_1d_fir_default_pad_small ( ) {
601+ let x =
602+ array ! [ 0. , 0.6389613 , 0.890577 , 0.9830277 , 0.9992535 , 0.9756868 , 0.9304659 , 0.8734051 ] ;
603+ let b = array ! [ 0.5 , 0.5 ] ;
604+ let a = array ! [ 1. ] ;
605+ let result = Array :: < _ , Dim < [ _ ; 1 ] > > :: filtfilt (
606+ b. view ( ) ,
607+ a. view ( ) ,
608+ x,
609+ None ,
610+ Some ( FiltFiltPad :: default ( ) ) ,
611+ )
612+ . expect ( "Could not filtfilt none_pad" ) ;
613+ let expected =
614+ array ! [ 0. , 0.5421249 , 0.8507858 , 0.9639715 , 0.9893054 , 0.9702733 , 0.9275059 , 0.8734051 ] ;
615+ Zip :: from ( & result)
616+ . and ( & expected)
617+ . for_each ( |& r, & e| assert_relative_eq ! ( r, e, max_relative = 1e-6 , epsilon = 1e-10 ) ) ;
618+ }
619+
620+ /// Tests that filtfilt works with default padding with a FIR filter.
621+ #[ test]
622+ fn filtfilt_1d_fir_default_pad_big ( ) {
623+ // n_elems = 25
624+ // x = np.sin(np.log(np.linspace(1., n_elems, n_elems)))
625+ // b = firwin(8, 0.2)
626+ // a = np.array([1.])
627+ // expected = filtfilt(b, a, x)
628+
629+ let x = array ! [
630+ 0. , 0.6389613 , 0.890577 , 0.9830277 , 0.9992535 , 0.9756868 , 0.9304659 , 0.8734051 ,
631+ 0.8101266 , 0.7439803 , 0.6770137 , 0.6104955 , 0.5452131 , 0.481649 , 0.4200881 , 0.3606866 ,
632+ 0.3035148 , 0.2485867 , 0.1958789 , 0.1453437 , 0.0969178 , 0.0505287 , 0.0060984 ,
633+ -0.0364531 , -0.0772063
634+ ] ;
635+ let b = array ! [
636+ 0.0087547 , 0.0479489 , 0.1640244 , 0.279272 , 0.279272 , 0.1640244 , 0.0479489 , 0.0087547
637+ ] ;
638+ let a = array ! [ 1. ] ;
639+ let result = Array :: < _ , Dim < [ _ ; 1 ] > > :: filtfilt (
640+ b. view ( ) ,
641+ a. view ( ) ,
642+ x,
643+ None ,
644+ Some ( FiltFiltPad :: default ( ) ) ,
645+ )
646+ . expect ( "Could not filtfilt none_pad" ) ;
647+ let expected = array ! [
648+ 0. , 0.3503788 , 0.6340265 , 0.8172474 , 0.9055143 , 0.9253101 , 0.9036955 , 0.8594274 ,
649+ 0.8033733 , 0.7414859 , 0.6771011 , 0.6121664 , 0.5478511 , 0.4848631 , 0.4236259 , 0.3643826 ,
650+ 0.3072603 , 0.25231 , 0.1995331 , 0.1488972 , 0.1003401 , 0.0537529 , 0.0089268 , -0.0345238 ,
651+ -0.0772063
652+ ] ;
653+
654+ Zip :: from ( & result)
655+ . and ( & expected)
656+ . for_each ( |& r, & e| assert_relative_eq ! ( r, e, max_relative = 1e-5 , epsilon = 1e-10 ) ) ;
657+ }
658+
599659 /// Tests that filtfilt works with no padding with a FIR filter.
600660 #[ test]
601661 fn filtfilt_2d_fir_none_pad ( ) {
0 commit comments