@@ -88,6 +88,19 @@ def __init__(
8888 self .export = is_export
8989 self .eps = eps
9090
91+ self .running_mean = paddle .zeros ([self ._out_channels ], dtype = "float32" )
92+ self .running_variance = paddle .ones ([self ._out_channels ], dtype = "float32" )
93+ orin_shape = self .weight .shape
94+ new_weight = F .batch_norm (
95+ self .weight .reshape ([1 , self ._out_channels , - 1 ]),
96+ self .running_mean ,
97+ self .running_variance ,
98+ momentum = 0.0 ,
99+ epsilon = self .eps ,
100+ use_global_stats = False ,
101+ ).reshape (orin_shape )
102+ self .weight .set_value (new_weight .numpy ())
103+
91104 def forward (self , x ):
92105 if not self .training :
93106 self .export = True
@@ -96,30 +109,14 @@ def forward(self, x):
96109 x = pad_same_export (x , self ._kernel_size , self ._stride , self ._dilation )
97110 else :
98111 x = pad_same (x , self ._kernel_size , self ._stride , self ._dilation )
99- running_mean = paddle .to_tensor ([0 ] * self ._out_channels , dtype = "float32" )
100- running_variance = paddle .to_tensor ([1 ] * self ._out_channels , dtype = "float32" )
101112 if self .export :
102- weight = paddle .reshape (
103- F .batch_norm (
104- self .weight .reshape ([1 , self ._out_channels , - 1 ]).cast (
105- paddle .float32
106- ),
107- running_mean ,
108- running_variance ,
109- momentum = 0.0 ,
110- epsilon = self .eps ,
111- use_global_stats = False ,
112- ),
113- self .weight .shape ,
114- )
113+ weight = self .weight
115114 else :
116115 weight = paddle .reshape (
117116 F .batch_norm (
118- self .weight .reshape ([1 , self ._out_channels , - 1 ]).cast (
119- paddle .float32
120- ),
121- running_mean ,
122- running_variance ,
117+ self .weight .reshape ([1 , self ._out_channels , - 1 ]),
118+ self .running_mean ,
119+ self .running_variance ,
123120 training = True ,
124121 momentum = 0.0 ,
125122 epsilon = self .eps ,
0 commit comments