
Commit cab6687: finish model
1 parent 04c989b

File tree: 19 files changed, +27928 −1 lines

configs/rec/MixTex.yml

Lines changed: 129 additions & 0 deletions
Global:
  use_gpu: True
  epoch_num: 5
  log_smooth_window: 20  # for logging metrics during the training procedure
  print_batch_step: 10
  save_model_dir: ./output/MixTex
  save_epoch_step: 5
  max_seq_len: 256
  eval_batch_step: [0, 5]
  cal_metric_during_train: true
  pretrained_model:
  checkpoints:
  save_inference_dir:
  use_visualdl: false
  infer_img: doc/datasets/pme_demo/0000013.png
  infer_mode: False
  use_space_char: False
  rec_char_dict_path: ./ppocr/utils/dict/mixtex
  save_res_path: ./output/rec/predicts_mixtex.txt
  d2s_train_image_shape: [3, 400, 500]

Optimizer:
  name: AdamW
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Piecewise
    decay_epochs: [3]
    values: [0.0005, 0.00005]
    warmup_epoch: 5
  regularizer:
    name: L2
    factor: 3.0e-05

Architecture:
  model_type: rec
  algorithm: MixTex
  in_channels: 3
  Transform:
  Backbone:
    name: SwinTransformer_tiny_patch4_window7_224
    img_size: 224
    patch_size: 4
    num_classes: 25678  # class num of vocab
    input_channel: 1
    is_predict: False
    is_export: False
  Head:
    name: RobertHead
    pad_value: -100
    is_export: False
    decoder_args:
      vocab_size: 25681
      cross_attend: True
      rel_pos_bias: False
      use_scalenorm: False
      attention_probs_dropout_prob: 0.1
      bos_token_id: 0
      chunk_size_feed_forward: 0
      diversity_penalty: 0.0
      do_sample: False
      eos_token_id: 2
      hidden_act: gelu
      hidden_dropout_prob: 0.1
      hidden_size: 768
      max_length: 20
      max_position_embeddings: 3000
      min_length: 0
      num_attention_heads: 12
      num_hidden_layers: 4
      pad_token_id: 1
      temperature: 1.0
      tie_word_embeddings: True
      top_k: 50
      top_p: 1.0
      intermediate_size: 3072
      type_vocab_size: 1
      initializer_range: 0.02

Loss:
  name: LaTeXOCRLoss

PostProcess:
  name: MixTexDecode
  rec_char_dict_path: ./ppocr/utils/dict/mixtex

Metric:
  name: MixTexMetric
  main_indicator: exp_rate
  cal_blue_score: True

Train:
  dataset:
    name: MixTexDataSet
    data_dir: ./data/Pseudo-Latext-ZhEn
    batch_size_per_pair: 24
    transforms:
      - RecResizeImg:
          image_shape: [3, 400, 500]
      - RescaleImage:
          scale: 0.00392156862745098  # 1 / 255
      - KeepKeys:
          keep_keys: ['image']
  loader:
    shuffle: True
    batch_size_per_card: 1
    drop_last: False
    num_workers: 0

Eval:
  dataset:
    name: MixTexDataSet
    data_dir: ./data/Pseudo-Latext-ZhEn
    data:
      batch_size_per_pair: 24
    transforms:
      - RecResizeImg:
          image_shape: [3, 400, 500]
      - RescaleImage:
          scale: 0.00392156862745098  # 1 / 255
      - KeepKeys:
          keep_keys: ['image']
  loader:
    shuffle: True
    batch_size_per_card: 1
    drop_last: False
    num_workers: 0
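For readers unfamiliar with PaddleOCR's config-driven optimizer builder, the Optimizer block above roughly corresponds to the plain Paddle code below. This is a minimal sketch, not part of the commit: steps_per_epoch and the stand-in model are placeholders, and PaddleOCR's own builder additionally applies the warmup_epoch setting as a warmup wrapper around the schedule.

import paddle

steps_per_epoch = 1000  # placeholder; in practice derived from the dataloader length
model = paddle.nn.Linear(768, 768)  # stand-in for the real recognition model

# Piecewise schedule: 5e-4 until the end of epoch 3, then 5e-5 (decay_epochs: [3]).
lr = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[3 * steps_per_epoch],
    values=[0.0005, 0.00005],
)
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr,
    beta1=0.9,
    beta2=0.999,
    weight_decay=3.0e-05,  # the L2 regularizer factor
    parameters=model.parameters(),
)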

ppocr/data/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -39,6 +39,7 @@
 from ppocr.data.pubtab_dataset import PubTabDataSet
 from ppocr.data.multi_scale_sampler import MultiScaleSampler
 from ppocr.data.latexocr_dataset import LaTeXOCRDataSet
+from ppocr.data.mixtex_dataset import MixTexDataSet

 # for PaddleX dataset_type
 TextDetDataset = SimpleDataSet
@@ -97,6 +98,7 @@ def build_dataloader(config, mode, device, logger, seed=None):
         "PubTabTableRecDataset",
         "KieDataset",
         "LaTeXOCRDataSet",
+        "MixTexDataSet",
     ]
     module_name = config[mode]["dataset"]["name"]
     assert module_name in support_dict, Exception(
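With MixTexDataSet registered in the support list, the training loader can be built the usual PaddleOCR way. A minimal sketch, assuming the config file from this commit and a CPU device (not part of the commit itself):

import logging

import paddle
import yaml

from ppocr.data import build_dataloader

# Load the new config and hand it to build_dataloader, which resolves the
# dataset name "MixTexDataSet" registered above.
with open("configs/rec/MixTex.yml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

device = paddle.set_device("cpu")
logger = logging.getLogger(__name__)
train_loader = build_dataloader(config, "Train", device, logger)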

ppocr/data/imaug/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@
     RFLRecResizeImg,
     SVTRRecAug,
     ParseQRecAug,
+    RescaleImage,
 )
 from .ssl_img_aug import SSLRotateResize
 from .randaugment import RandAugment

ppocr/data/imaug/label_ops.py

Lines changed: 20 additions & 0 deletions
@@ -1855,6 +1855,7 @@ def encode(
             )
             for encoding in encodings
         ]
+
         sanitized_tokens = {}
         for key in tokens_and_encodings[0][0].keys():
             stack = [e for item, _ in tokens_and_encodings for e in item[key]]
@@ -2191,3 +2192,22 @@ def __call__(self, data):
         data["label"] = np.array(topk["input_ids"]).astype(np.int64)[0]
         data["attention_mask"] = np.array(topk["attention_mask"]).astype(np.int64)[0]
         return data
+
+
+class MixTexLabelEncode:
+    def __init__(
+        self,
+        rec_char_dict_path,
+        **kwargs,
+    ):
+        from paddlenlp.transformers.roberta.tokenizer import RobertaTokenizer
+
+        self.tokenizer = RobertaTokenizer.from_pretrained(
+            pretrained_model_name_or_path=rec_char_dict_path
+        )
+
+    def __call__(
+        self, target_text, padding="max_length", max_length=256, truncation=True
+    ):
+        target = self.tokenizer(target_text, padding=padding, max_length=max_length, truncation=truncation).input_ids
+        return target
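A quick usage sketch for the new encoder, assuming a RoBERTa tokenizer directory exists at ./ppocr/utils/dict/mixtex as the config expects (this example is illustrative, not part of the commit):

from ppocr.data.imaug.label_ops import MixTexLabelEncode

encode = MixTexLabelEncode(rec_char_dict_path="./ppocr/utils/dict/mixtex")
# Returns a fixed-length list of token ids, padded/truncated to max_length.
token_ids = encode(r"\alpha + \beta = \gamma", max_length=256)
print(len(token_ids), token_ids[:8])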

ppocr/data/imaug/rec_img_aug.py

Lines changed: 13 additions & 0 deletions
@@ -578,6 +578,19 @@ def __call__(self, data):
         return data


+class RescaleImage(object):
+    def __init__(self, scale, dtype=np.float32, **kwargs):
+        self.scale = scale
+        self.dtype = dtype
+
+    def __call__(self, data):
+        img = data["image"]
+        rescaled_image = img * self.scale
+        rescaled_image = rescaled_image.astype(self.dtype)
+        data = {"image": rescaled_image}
+        return data
+
+
 def resize_norm_img_sar(img, image_shape, width_downsample_ratio=0.25):
     imgC, imgH, imgW_min, imgW_max = image_shape
     h = img.shape[0]
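A standalone sketch of what the new operator does with the scale value used in the config (0.00392156862745098, i.e. 1/255); the sample array is a made-up input:

import numpy as np

from ppocr.data.imaug import RescaleImage

op = RescaleImage(scale=1 / 255.0)
sample = {"image": np.full((400, 500, 3), 255, dtype=np.uint8)}
out = op(sample)
print(out["image"].dtype, out["image"].max())  # float32 1.0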

ppocr/data/mixtex_dataset.py

Lines changed: 68 additions & 0 deletions
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
from datasets import load_dataset

from paddle.io import Dataset
from .imaug.label_ops import MixTexLabelEncode
from .imaug import transform, create_operators

from paddlenlp.transformers.roberta.tokenizer import RobertaTokenizer


class MixTexDataSet(Dataset):
    def __init__(self, config, mode, logger, seed=None):
        super(MixTexDataSet, self).__init__()
        self.logger = logger
        self.mode = mode.lower()

        global_config = config["Global"]
        dataset_config = config[mode]["dataset"]
        loader_config = config[mode]["loader"]

        self.data_dir = dataset_config["data_dir"]
        self.image_size = global_config["d2s_train_image_shape"]
        self.batchsize = dataset_config["batch_size_per_pair"]
        self.max_seq_len = global_config["max_seq_len"]
        self.rec_char_dict_path = global_config["rec_char_dict_path"]
        self.tokenizer = MixTexLabelEncode(self.rec_char_dict_path)

        self.dataframe = load_dataset(self.data_dir)

        self.ops = create_operators(dataset_config["transforms"], global_config)
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", 2)
        self.need_reset = True

    def __getitem__(self, idx):
        image = self.dataframe["train"][idx]["image"].convert("RGB")
        image = np.asarray(image)
        data = {"image": image}
        pixel_values = transform(data, self.ops)
        target_text = self.dataframe["train"][idx]["text"]
        target = self.tokenizer.tokenizer(
            target_text,
            padding="max_length",
            max_length=self.max_seq_len,
            truncation=True,
        ).input_ids
        # Mask padding positions with -100 so CrossEntropyLoss ignores them.
        labels = [
            label if label != self.tokenizer.tokenizer.pad_token_id else -100
            for label in target
        ]
        return (pixel_values, labels)

    def __len__(self):
        return len(self.dataframe["train"])
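A minimal construction sketch, not part of the commit, that mirrors the Train section of configs/rec/MixTex.yml; the Hugging Face dataset under ./data/Pseudo-Latext-ZhEn and the tokenizer directory are assumed to exist locally:

import logging

from ppocr.data.mixtex_dataset import MixTexDataSet

config = {
    "Global": {
        "d2s_train_image_shape": [3, 400, 500],
        "max_seq_len": 256,
        "rec_char_dict_path": "./ppocr/utils/dict/mixtex",
    },
    "Train": {
        "dataset": {
            "name": "MixTexDataSet",
            "data_dir": "./data/Pseudo-Latext-ZhEn",
            "batch_size_per_pair": 24,
            "transforms": [
                {"RecResizeImg": {"image_shape": [3, 400, 500]}},
                {"RescaleImage": {"scale": 0.00392156862745098}},
                {"KeepKeys": {"keep_keys": ["image"]}},
            ],
        },
        "loader": {"shuffle": True, "batch_size_per_card": 1},
    },
}

dataset = MixTexDataSet(config, mode="Train", logger=logging.getLogger(__name__))
pixel_values, labels = dataset[0]  # transformed image and -100-masked token ids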

ppocr/losses/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@
 from .rec_cppd_loss import CPPDLoss
 from .rec_latexocr_loss import LaTeXOCRLoss
 from .rec_unimernet_loss import UniMERNetLoss
+from .rec_mixtex_loss import MixTexLoss

 # cls loss
 from .cls_loss import ClsLoss

ppocr/losses/rec_mixtex_loss.py

Lines changed: 47 additions & 0 deletions
# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This code is adapted from:
https://github.com/lucidrains/x-transformers/blob/main/x_transformers/autoregressive_wrapper.py
"""

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np


class MixTexLoss(nn.Layer):
    """
    MixTex adopts CrossEntropyLoss for network training.
    """

    def __init__(self):
        super(MixTexLoss, self).__init__()
        self.ignore_index = -100
        self.cross = nn.CrossEntropyLoss(
            reduction="mean", ignore_index=self.ignore_index
        )

    def forward(self, preds, batch):
        word_probs = preds
        # batch[1] holds the token ids; drop the first token so the labels
        # align with the next-token predictions.
        labels = batch[1][:, 1:]
        word_loss = self.cross(
            paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
            paddle.reshape(labels, [-1]),
        )

        loss = word_loss
        return {"loss": loss}
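A small sanity-check sketch for the loss with random tensors (shapes are assumptions: per-step vocabulary logits from the head, and token ids in batch[1] that are one step longer than the logits):

import paddle

from ppocr.losses.rec_mixtex_loss import MixTexLoss

batch_size, seq_len, vocab_size = 2, 255, 25681
logits = paddle.randn([batch_size, seq_len, vocab_size])
images = paddle.zeros([batch_size, 3, 400, 500])  # batch[0] is unused by the loss
labels = paddle.randint(0, vocab_size, [batch_size, seq_len + 1])

loss_fn = MixTexLoss()
print(loss_fn(logits, [images, labels]))  # {'loss': Tensor(...)}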

ppocr/metrics/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -22,7 +22,7 @@
 __all__ = ["build_metric"]

 from .det_metric import DetMetric, DetFCEMetric
-from .rec_metric import RecMetric, CNTMetric, CANMetric, LaTeXOCRMetric
+from .rec_metric import RecMetric, CNTMetric, CANMetric, LaTeXOCRMetric, MixTexMetric
 from .cls_metric import ClsMetric
 from .e2e_metric import E2EMetric
 from .distillation_metric import DistillationMetric
@@ -51,6 +51,7 @@ def build_metric(config):
         "CNTMetric",
         "CANMetric",
         "LaTeXOCRMetric",
+        "MixTexMetric",
     ]

     config = copy.deepcopy(config)
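With MixTexMetric added to the support list, the metric can be constructed from the Metric section of the config. A minimal sketch (the MixTexMetric class itself lives in rec_metric.py, which is not shown in this excerpt, so its keyword arguments are taken on faith from the config):

from ppocr.metrics import build_metric

metric = build_metric(
    {
        "name": "MixTexMetric",
        "main_indicator": "exp_rate",
        "cal_blue_score": True,
    }
)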
