Skip to content

Commit 0697d24

Browse files
support export with pir and no pir (#14379)
1 parent 04c989b commit 0697d24

File tree

2 files changed

+53
-10
lines changed

2 files changed

+53
-10
lines changed

ppocr/utils/export_model.py

Lines changed: 52 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import yaml
1717
import json
1818
import copy
19+
import shutil
1920
import paddle
2021
import paddle.nn as nn
2122
from paddle.jit import to_static
2223

2324
from collections import OrderedDict
25+
from packaging import version
2426
from argparse import ArgumentParser, RawDescriptionHelpFormatter
2527
from ppocr.modeling.architectures import build_model
2628
from ppocr.postprocess import build_post_process
@@ -39,21 +41,23 @@ def setup_orderdict():
3941
def dump_infer_config(config, path, logger):
4042
setup_orderdict()
4143
infer_cfg = OrderedDict()
44+
if not os.path.exists(os.path.dirname(path)):
45+
os.makedirs(os.path.dirname(path))
4246
if config["Global"].get("pdx_model_name", None):
4347
infer_cfg["Global"] = {"model_name": config["Global"]["pdx_model_name"]}
4448
if config["Global"].get("uniform_output_enabled", None):
4549
arch_config = config["Architecture"]
4650
if arch_config["algorithm"] in ["SVTR_LCNet", "SVTR_HGNet"]:
4751
common_dynamic_shapes = {
48-
"x": [[1, 3, 48, 320], [1, 3, 48, 320], [8, 3, 48, 320]]
52+
"x": [[1, 3, 24, 160], [1, 3, 48, 320], [8, 3, 96, 640]]
4953
}
5054
elif arch_config["model_type"] == "det":
5155
common_dynamic_shapes = {
5256
"x": [[1, 3, 160, 160], [1, 3, 160, 160], [1, 3, 1280, 1280]]
5357
}
5458
elif arch_config["algorithm"] == "SLANet":
5559
common_dynamic_shapes = {
56-
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 192, 672]]
60+
"x": [[1, 3, 32, 32], [1, 3, 64, 448], [8, 3, 488, 488]]
5761
}
5862
elif arch_config["algorithm"] == "LaTeXOCR":
5963
common_dynamic_shapes = {
@@ -101,9 +105,7 @@ def dump_infer_config(config, path, logger):
101105
logger.info("Export inference config file to {}".format(os.path.join(path)))
102106

103107

104-
def export_single_model(
105-
model, arch_config, save_path, logger, input_shape=None, quanter=None
106-
):
108+
def dynamic_to_static(model, arch_config, logger, input_shape=None):
107109
if arch_config["algorithm"] == "SRN":
108110
max_text_length = arch_config["Head"]["max_text_length"]
109111
other_shape = [
@@ -262,9 +264,46 @@ def export_single_model(
262264
for layer in model.sublayers():
263265
if hasattr(layer, "rep") and not getattr(layer, "is_repped"):
264266
layer.rep()
267+
return model
268+
269+
270+
def export_single_model(
271+
model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None
272+
):
273+
274+
model = dynamic_to_static(model, arch_config, logger, input_shape)
265275

266276
if quanter is None:
267-
paddle.jit.save(model, save_path)
277+
paddle_version = version.parse(paddle.__version__)
278+
if (
279+
paddle_version >= version.parse("3.0.0b2")
280+
or paddle_version == version.parse("0.0.0")
281+
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]:
282+
save_path = os.path.dirname(save_path)
283+
for enable_pir in [True, False]:
284+
if not enable_pir:
285+
save_path_no_pir = os.path.join(save_path, "inference")
286+
model.forward.rollback()
287+
with paddle.pir_utils.OldIrGuard():
288+
model = dynamic_to_static(
289+
model, arch_config, logger, input_shape
290+
)
291+
paddle.jit.save(model, save_path_no_pir)
292+
else:
293+
save_path_pir = os.path.join(
294+
os.path.dirname(save_path),
295+
f"{os.path.basename(save_path)}_pir",
296+
"inference",
297+
)
298+
paddle.jit.save(model, save_path_pir)
299+
shutil.copy(
300+
yaml_path,
301+
os.path.join(
302+
os.path.dirname(save_path_pir), os.path.basename(yaml_path)
303+
),
304+
)
305+
else:
306+
paddle.jit.save(model, save_path)
268307
else:
269308
quanter.save_quantized_model(model, save_path)
270309
logger.info("inference model is saved to {}".format(save_path))
@@ -362,19 +401,22 @@ def export(config, base_model=None, save_path=None):
362401
input_shape = rec_rs[0]["ABINetRecResizeImg"]["image_shape"] if rec_rs else None
363402
else:
364403
input_shape = None
365-
404+
dump_infer_config(config, yaml_path, logger)
366405
if arch_config["algorithm"] in [
367406
"Distillation",
368407
]: # distillation model
369408
archs = list(arch_config["Models"].values())
370409
for idx, name in enumerate(model.model_name_list):
371410
sub_model_save_path = os.path.join(save_path, name, "inference")
372411
export_single_model(
373-
model.model_list[idx], archs[idx], sub_model_save_path, logger
412+
model.model_list[idx],
413+
archs[idx],
414+
sub_model_save_path,
415+
logger,
416+
yaml_path,
374417
)
375418
else:
376419
save_path = os.path.join(save_path, "inference")
377420
export_single_model(
378-
model, arch_config, save_path, logger, input_shape=input_shape
421+
model, arch_config, save_path, logger, yaml_path, input_shape=input_shape
379422
)
380-
dump_infer_config(config, yaml_path, logger)

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@ requests
1414
albumentations==1.4.10
1515
# to be compatible with albumentations
1616
albucore==0.0.13
17+
packaging

0 commit comments

Comments
 (0)