Skip to content

Commit 359ab6c

Browse files
fix latex_ocr inference (#14498)
* add * update * add * add
1 parent ed6fe28 commit 359ab6c

File tree

1 file changed

+17
-20
lines changed

1 file changed

+17
-20
lines changed

ppocr/modeling/backbones/rec_resnetv2.py

Lines changed: 17 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,19 @@ def __init__(
8888
self.export = is_export
8989
self.eps = eps
9090

91+
self.running_mean = paddle.zeros([self._out_channels], dtype="float32")
92+
self.running_variance = paddle.ones([self._out_channels], dtype="float32")
93+
orin_shape = self.weight.shape
94+
new_weight = F.batch_norm(
95+
self.weight.reshape([1, self._out_channels, -1]),
96+
self.running_mean,
97+
self.running_variance,
98+
momentum=0.0,
99+
epsilon=self.eps,
100+
use_global_stats=False,
101+
).reshape(orin_shape)
102+
self.weight.set_value(new_weight.numpy())
103+
91104
def forward(self, x):
92105
if not self.training:
93106
self.export = True
@@ -96,30 +109,14 @@ def forward(self, x):
96109
x = pad_same_export(x, self._kernel_size, self._stride, self._dilation)
97110
else:
98111
x = pad_same(x, self._kernel_size, self._stride, self._dilation)
99-
running_mean = paddle.to_tensor([0] * self._out_channels, dtype="float32")
100-
running_variance = paddle.to_tensor([1] * self._out_channels, dtype="float32")
101112
if self.export:
102-
weight = paddle.reshape(
103-
F.batch_norm(
104-
self.weight.reshape([1, self._out_channels, -1]).cast(
105-
paddle.float32
106-
),
107-
running_mean,
108-
running_variance,
109-
momentum=0.0,
110-
epsilon=self.eps,
111-
use_global_stats=False,
112-
),
113-
self.weight.shape,
114-
)
113+
weight = self.weight
115114
else:
116115
weight = paddle.reshape(
117116
F.batch_norm(
118-
self.weight.reshape([1, self._out_channels, -1]).cast(
119-
paddle.float32
120-
),
121-
running_mean,
122-
running_variance,
117+
self.weight.reshape([1, self._out_channels, -1]),
118+
self.running_mean,
119+
self.running_variance,
123120
training=True,
124121
momentum=0.0,
125122
epsilon=self.eps,

0 commit comments

Comments
 (0)