Skip to content

Commit e314510

Browse files
import encryption for aistudio & fix sync bn
1 parent 4f7476d commit e314510

File tree

2 files changed

+57
-10
lines changed

2 files changed

+57
-10
lines changed

ppocr/utils/export_model.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,6 +331,12 @@ def export_single_model(
331331
model = dynamic_to_static(model, arch_config, logger, input_shape)
332332

333333
if quanter is None:
334+
try:
335+
import encryption # Attempt to import the encryption module for AIStudio's encryption model
336+
except (
337+
ModuleNotFoundError
338+
): # Encryption is not needed if the module cannot be imported
339+
print("Skipping import of the encryption module")
334340
if config["Global"].get("export_with_pir", False):
335341
paddle_version = version.parse(paddle.__version__)
336342
assert (
@@ -349,6 +355,18 @@ def export_single_model(
349355
return
350356

351357

358+
def convert_bn(model):
359+
for n, m in model.named_children():
360+
if isinstance(m, nn.SyncBatchNorm):
361+
bn = nn.BatchNorm2D(
362+
m._num_features, m._momentum, m._epsilon, m._weight_attr, m._bias_attr
363+
)
364+
bn.set_dict(m.state_dict())
365+
setattr(model, n, bn)
366+
else:
367+
convert_bn(m)
368+
369+
352370
def export(config, base_model=None, save_path=None):
353371
if paddle.distributed.get_rank() != 0:
354372
return
@@ -424,6 +442,7 @@ def export(config, base_model=None, save_path=None):
424442
else:
425443
model = build_model(config["Architecture"])
426444
load_model(config, model, model_type=config["Architecture"]["model_type"])
445+
convert_bn(model)
427446
model.eval()
428447

429448
if not save_path:

ppocr/utils/save_load.py

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,14 @@
2626
from ppocr.utils.logging import get_logger
2727
from ppocr.utils.network import maybe_download_params
2828

29+
try:
30+
import encryption # Attempt to import the encryption module for AIStudio's encryption model
31+
32+
encrypted = encryption.is_encryption_needed()
33+
except ImportError:
34+
get_logger().warning("Skipping import of the encryption module.")
35+
encrypted = False # Encryption is not needed if the module cannot be imported
36+
2937
__all__ = ["load_model"]
3038

3139

@@ -278,13 +286,11 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
278286
else:
279287
train_results = {}
280288
train_results["model_name"] = config["Global"]["pdx_model_name"]
281-
label_dict_path = os.path.abspath(
282-
config["Global"].get("character_dict_path", "")
283-
)
289+
label_dict_path = config["Global"].get("character_dict_path", "")
284290
if label_dict_path != "":
291+
label_dict_path = os.path.abspath(label_dict_path)
285292
if not os.path.exists(label_dict_path):
286293
label_dict_path = ""
287-
label_dict_path = label_dict_path
288294
train_results["label_dict"] = label_dict_path
289295
train_results["train_log"] = "train.log"
290296
train_results["visualdl_log"] = ""
@@ -305,9 +311,20 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
305311
raise ValueError("No metric score found.")
306312
train_results["models"]["best"]["score"] = metric_score
307313
for tag in save_model_tag:
308-
train_results["models"]["best"][tag] = os.path.join(
309-
prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
310-
)
314+
if tag == "pdparams" and encrypted:
315+
train_results["models"]["best"][tag] = os.path.join(
316+
prefix,
317+
(
318+
f"{prefix}.encrypted.{tag}"
319+
if tag != "pdstates"
320+
else f"{prefix}.states"
321+
),
322+
)
323+
else:
324+
train_results["models"]["best"][tag] = os.path.join(
325+
prefix,
326+
f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states",
327+
)
311328
for tag in save_inference_tag:
312329
train_results["models"]["best"][tag] = os.path.join(
313330
prefix,
@@ -329,9 +346,20 @@ def update_train_results(config, prefix, metric_info, done_flag=False, last_num=
329346
metric_score = 0
330347
train_results["models"][f"last_{1}"]["score"] = metric_score
331348
for tag in save_model_tag:
332-
train_results["models"][f"last_{1}"][tag] = os.path.join(
333-
prefix, f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states"
334-
)
349+
if tag == "pdparams" and encrypted:
350+
train_results["models"][f"last_{1}"][tag] = os.path.join(
351+
prefix,
352+
(
353+
f"{prefix}.encrypted.{tag}"
354+
if tag != "pdstates"
355+
else f"{prefix}.states"
356+
),
357+
)
358+
else:
359+
train_results["models"][f"last_{1}"][tag] = os.path.join(
360+
prefix,
361+
f"{prefix}.{tag}" if tag != "pdstates" else f"{prefix}.states",
362+
)
335363
for tag in save_inference_tag:
336364
train_results["models"][f"last_{1}"][tag] = os.path.join(
337365
prefix,

0 commit comments

Comments
 (0)