2626from ppocr .utils .logging import get_logger
2727from ppocr .utils .network import maybe_download_params
2828
29+ try :
30+ import encryption # Attempt to import the encryption module for AIStudio's encryption model
31+
32+ encrypted = encryption .is_encryption_needed ()
33+ except ImportError :
34+ get_logger ().warning ("Skipping import of the encryption module." )
35+ encrypted = False # Encryption is not needed if the module cannot be imported
36+
2937__all__ = ["load_model" ]
3038
3139
@@ -278,13 +286,11 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
278286 else :
279287 train_results = {}
280288 train_results ["model_name" ] = config ["Global" ]["pdx_model_name" ]
281- label_dict_path = os .path .abspath (
282- config ["Global" ].get ("character_dict_path" , "" )
283- )
289+ label_dict_path = config ["Global" ].get ("character_dict_path" , "" )
284290 if label_dict_path != "" :
291+ label_dict_path = os .path .abspath (label_dict_path )
285292 if not os .path .exists (label_dict_path ):
286293 label_dict_path = ""
287- label_dict_path = label_dict_path
288294 train_results ["label_dict" ] = label_dict_path
289295 train_results ["train_log" ] = "train.log"
290296 train_results ["visualdl_log" ] = ""
@@ -305,9 +311,20 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
305311 raise ValueError ("No metric score found." )
306312 train_results ["models" ]["best" ]["score" ] = metric_score
307313 for tag in save_model_tag :
308- train_results ["models" ]["best" ][tag ] = os .path .join (
309- prefix , f"{ prefix } .{ tag } " if tag != "pdstates" else f"{ prefix } .states"
310- )
314+ if tag == "pdparams" and encrypted :
315+ train_results ["models" ]["best" ][tag ] = os .path .join (
316+ prefix ,
317+ (
318+ f"{ prefix } .encrypted.{ tag } "
319+ if tag != "pdstates"
320+ else f"{ prefix } .states"
321+ ),
322+ )
323+ else :
324+ train_results ["models" ]["best" ][tag ] = os .path .join (
325+ prefix ,
326+ f"{ prefix } .{ tag } " if tag != "pdstates" else f"{ prefix } .states" ,
327+ )
311328 for tag in save_inference_tag :
312329 train_results ["models" ]["best" ][tag ] = os .path .join (
313330 prefix ,
@@ -329,9 +346,20 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
329346 metric_score = 0
330347 train_results ["models" ][f"last_{ 1 } " ]["score" ] = metric_score
331348 for tag in save_model_tag :
332- train_results ["models" ][f"last_{ 1 } " ][tag ] = os .path .join (
333- prefix , f"{ prefix } .{ tag } " if tag != "pdstates" else f"{ prefix } .states"
334- )
349+ if tag == "pdparams" and encrypted :
350+ train_results ["models" ][f"last_{ 1 } " ][tag ] = os .path .join (
351+ prefix ,
352+ (
353+ f"{ prefix } .encrypted.{ tag } "
354+ if tag != "pdstates"
355+ else f"{ prefix } .states"
356+ ),
357+ )
358+ else :
359+ train_results ["models" ][f"last_{ 1 } " ][tag ] = os .path .join (
360+ prefix ,
361+ f"{ prefix } .{ tag } " if tag != "pdstates" else f"{ prefix } .states" ,
362+ )
335363 for tag in save_inference_tag :
336364 train_results ["models" ][f"last_{ 1 } " ][tag ] = os .path .join (
337365 prefix ,
0 commit comments