1616import yaml
1717import json
1818import copy
19+ import shutil
1920import paddle
2021import paddle .nn as nn
2122from paddle .jit import to_static
2223
2324from collections import OrderedDict
25+ from packaging import version
2426from argparse import ArgumentParser , RawDescriptionHelpFormatter
2527from ppocr .modeling .architectures import build_model
2628from ppocr .postprocess import build_post_process
@@ -39,21 +41,23 @@ def setup_orderdict():
3941def dump_infer_config (config , path , logger ):
4042 setup_orderdict ()
4143 infer_cfg = OrderedDict ()
44+ if not os .path .exists (os .path .dirname (path )):
45+ os .makedirs (os .path .dirname (path ))
4246 if config ["Global" ].get ("pdx_model_name" , None ):
4347 infer_cfg ["Global" ] = {"model_name" : config ["Global" ]["pdx_model_name" ]}
4448 if config ["Global" ].get ("uniform_output_enabled" , None ):
4549 arch_config = config ["Architecture" ]
4650 if arch_config ["algorithm" ] in ["SVTR_LCNet" , "SVTR_HGNet" ]:
4751 common_dynamic_shapes = {
48- "x" : [[1 , 3 , 48 , 320 ], [1 , 3 , 48 , 320 ], [8 , 3 , 48 , 320 ]]
52+ "x" : [[1 , 3 , 24 , 160 ], [1 , 3 , 48 , 320 ], [8 , 3 , 96 , 640 ]]
4953 }
5054 elif arch_config ["model_type" ] == "det" :
5155 common_dynamic_shapes = {
5256 "x" : [[1 , 3 , 160 , 160 ], [1 , 3 , 160 , 160 ], [1 , 3 , 1280 , 1280 ]]
5357 }
5458 elif arch_config ["algorithm" ] == "SLANet" :
5559 common_dynamic_shapes = {
56- "x" : [[1 , 3 , 32 , 32 ], [1 , 3 , 64 , 448 ], [8 , 3 , 192 , 672 ]]
60+ "x" : [[1 , 3 , 32 , 32 ], [1 , 3 , 64 , 448 ], [8 , 3 , 488 , 488 ]]
5761 }
5862 elif arch_config ["algorithm" ] == "LaTeXOCR" :
5963 common_dynamic_shapes = {
@@ -101,9 +105,7 @@ def dump_infer_config(config, path, logger):
101105 logger .info ("Export inference config file to {}" .format (os .path .join (path )))
102106
103107
104- def export_single_model (
105- model , arch_config , save_path , logger , input_shape = None , quanter = None
106- ):
108+ def dynamic_to_static (model , arch_config , logger , input_shape = None ):
107109 if arch_config ["algorithm" ] == "SRN" :
108110 max_text_length = arch_config ["Head" ]["max_text_length" ]
109111 other_shape = [
@@ -262,9 +264,46 @@ def export_single_model(
262264 for layer in model .sublayers ():
263265 if hasattr (layer , "rep" ) and not getattr (layer , "is_repped" ):
264266 layer .rep ()
267+ return model
268+
269+
270+ def export_single_model (
271+ model , arch_config , save_path , logger , yaml_path , input_shape = None , quanter = None
272+ ):
273+
274+ model = dynamic_to_static (model , arch_config , logger , input_shape )
265275
266276 if quanter is None :
267- paddle .jit .save (model , save_path )
277+ paddle_version = version .parse (paddle .__version__ )
278+ if (
279+ paddle_version >= version .parse ("3.0.0b2" )
280+ or paddle_version == version .parse ("0.0.0" )
281+ ) and os .environ .get ("FLAGS_enable_pir_api" , None ) not in ["0" , "False" ]:
282+ save_path = os .path .dirname (save_path )
283+ for enable_pir in [True , False ]:
284+ if not enable_pir :
285+ save_path_no_pir = os .path .join (save_path , "inference" )
286+ model .forward .rollback ()
287+ with paddle .pir_utils .OldIrGuard ():
288+ model = dynamic_to_static (
289+ model , arch_config , logger , input_shape
290+ )
291+ paddle .jit .save (model , save_path_no_pir )
292+ else :
293+ save_path_pir = os .path .join (
294+ os .path .dirname (save_path ),
295+ f"{ os .path .basename (save_path )} _pir" ,
296+ "inference" ,
297+ )
298+ paddle .jit .save (model , save_path_pir )
299+ shutil .copy (
300+ yaml_path ,
301+ os .path .join (
302+ os .path .dirname (save_path_pir ), os .path .basename (yaml_path )
303+ ),
304+ )
305+ else :
306+ paddle .jit .save (model , save_path )
268307 else :
269308 quanter .save_quantized_model (model , save_path )
270309 logger .info ("inference model is saved to {}" .format (save_path ))
@@ -362,19 +401,22 @@ def export(config, base_model=None, save_path=None):
362401 input_shape = rec_rs [0 ]["ABINetRecResizeImg" ]["image_shape" ] if rec_rs else None
363402 else :
364403 input_shape = None
365-
404+ dump_infer_config ( config , yaml_path , logger )
366405 if arch_config ["algorithm" ] in [
367406 "Distillation" ,
368407 ]: # distillation model
369408 archs = list (arch_config ["Models" ].values ())
370409 for idx , name in enumerate (model .model_name_list ):
371410 sub_model_save_path = os .path .join (save_path , name , "inference" )
372411 export_single_model (
373- model .model_list [idx ], archs [idx ], sub_model_save_path , logger
412+ model .model_list [idx ],
413+ archs [idx ],
414+ sub_model_save_path ,
415+ logger ,
416+ yaml_path ,
374417 )
375418 else :
376419 save_path = os .path .join (save_path , "inference" )
377420 export_single_model (
378- model , arch_config , save_path , logger , input_shape = input_shape
421+ model , arch_config , save_path , logger , yaml_path , input_shape = input_shape
379422 )
380- dump_infer_config (config , yaml_path , logger )
0 commit comments