Skip to content

Commit fe60246

Browse files
authored
Merge pull request #70 from RapidAI/add_param_for_ocr
Add param for ocr
2 parents eaaf4d3 + 1dbfec3 commit fe60246

File tree

8 files changed

+205
-35
lines changed

8 files changed

+205
-35
lines changed

README.md

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@
1313
</div>
1414

1515
### 最近更新
16-
- **2024.10.13**
17-
- 补充最新paddlex-SLANet-plus 测评结果(已集成模型到[RapidTable](https://github.com/RapidAI/RapidTable)仓库)
1816
- **2024.10.22**
1917
- 补充复杂背景多表格检测提取方案[RapidTableDet](https://github.com/RapidAI/RapidTableDetection)
2018
- **2024.10.29**
2119
- 使用yolo11重新训练表格分类器,修正wired_table_rec v2逻辑坐标还原错误,并更新测评
20+
- **2024.11.12**
21+
- 抽离模型识别和处理过程核心阈值,方便大家进行微调适配自己的场景[微调入参参考](#核心参数)
2222

2323
### 简介
24-
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。
25-
24+
💖该仓库是用来对文档中表格做结构化识别的推理库,包括来自阿里读光有线和无线表格识别模型,llaipython(微信)贡献的有线表格模型,网易Qanything内置表格分类模型等。\
25+
[快速开始](#安装) [模型评测](#指标结果) [使用建议](#使用建议) [表格旋转及透视修正](#表格旋转及透视修正) [微调入参参考](#核心参数) [常见问题](#FAQ) [更新计划](#更新计划)
2626
#### 特点
2727

2828
**⚡ 采用ONNXRuntime作为推理引擎,cpu下单图推理1-7s**
@@ -68,6 +68,7 @@
6868
wired_table_rec_v2(有线表格精度最高): 通用场景有线表格(论文,杂志,期刊, 收据,单据,账单)
6969

7070
paddlex-SLANet-plus(综合精度最高): 文档场景表格(论文,杂志,期刊中的表格)
71+
[微调入参参考](#核心参数)
7172

7273
### 安装
7374

@@ -158,7 +159,30 @@ for i, res in enumerate(result):
158159
# cv2.imwrite(f"{out_dir}/{file_name}-visualize.jpg", img)
159160
```
160161

161-
## FAQ (Frequently Asked Questions)
162+
### 核心参数
163+
```python
164+
wired_table_rec = WiredTableRecognition()
165+
html, elasp, polygons, logic_points, ocr_res = wired_table_rec(
166+
img_path,
167+
version="v2", #默认使用v2线框模型,切换阿里读光模型可改为v1
168+
morph_close=True, # 是否进行形态学操作,辅助找到更多线框,默认为True
169+
more_h_lines=True, # 是否基于线框检测结果进行更多水平线检查,辅助找到更小线框, 默认为True
170+
h_lines_threshold = 100, # 必须开启more_h_lines, 连接横线检测像素阈值,小于该值会生成新横线,默认为100
171+
more_v_lines=True, # 是否基于线框检测结果进行更多垂直线检查,辅助找到更小线框, 默认为True
172+
v_lines_threshold = 15, # 必须开启more_v_lines, 连接竖线检测像素阈值,小于该值会生成新竖线,默认为15
173+
extend_line=True, # 是否基于线框检测结果进行线段延长,辅助找到更多线框, 默认为True
174+
need_ocr=True, # 是否进行OCR识别, 默认为True
175+
rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
176+
)
177+
lineless_table_rec = LinelessTableRecognition()
178+
html, elasp, polygons, logic_points, ocr_res = lineless_table_rec(
    img_path,
179+
need_ocr=True, # 是否进行OCR识别, 默认为True
180+
rec_again=True,# 是否针对未识别到文字的表格框,进行单独截取再识别,默认为True
181+
)
182+
```
183+
184+
185+
## FAQ
162186
1. **问:识别框丢失了内部文字信息**
163187
- 答:默认使用的rapidocr小模型,如果需要更高精度的效果,可以从 [模型列表](https://rapidai.github.io/RapidOCRDocs/model_list/#_1)
164188
下载更高精度的ocr模型,在执行时传入ocr_result即可
@@ -168,7 +192,7 @@ for i, res in enumerate(result):
168192
主要耗时在ocr阶段,可以参考 [rapidocr_paddle](https://rapidai.github.io/RapidOCRDocs/install_usage/rapidocr_paddle/usage/#_3)
169193
加速ocr识别过程
170194

171-
### TODO List
195+
### 更新计划
172196

173197
- [x] 图片小角度偏移修正方法补充
174198
- [x] 增加数据集数量,增加更多评测对比

lineless_table_rec/main.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,23 +51,37 @@ def __call__(
5151
self,
5252
content: InputType,
5353
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
54+
**kwargs
5455
):
5556
ss = time.perf_counter()
57+
rec_again = True
58+
need_ocr = True
59+
if kwargs:
60+
rec_again = kwargs.get("rec_again", True)
61+
need_ocr = kwargs.get("need_ocr", True)
5662
img = self.load_img(content)
57-
if self.ocr is None and ocr_result is None:
58-
raise ValueError(
59-
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
60-
)
61-
if ocr_result is None:
62-
ocr_result, _ = self.ocr(img)
6363
input_info = self.preprocess(img)
6464
try:
6565
polygons, slct_logi = self.infer(input_info)
6666
logi_points = self.filter_logi_points(slct_logi)
67+
if not need_ocr:
68+
sorted_polygons, idx_list = sorted_ocr_boxes(
69+
[box_4_2_poly_to_box_4_1(box) for box in polygons]
70+
)
71+
return (
72+
"",
73+
time.perf_counter() - ss,
74+
sorted_polygons,
75+
logi_points[idx_list],
76+
[],
77+
)
78+
79+
if ocr_result is None and need_ocr:
80+
ocr_result, _ = self.ocr(img)
6781
# ocr 结果匹配
6882
cell_box_det_map, no_match_ocr_det = match_ocr_cell(ocr_result, polygons)
6983
# 如果有识别框没有ocr结果,直接进行rec补充
70-
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
84+
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
7185
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
7286
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
7387
# 拆分包含和重叠的识别框
@@ -81,7 +95,6 @@ def __call__(
8195
]
8296
# 生成行列对应的二维表格, 合并同行同列识别框中的的ocr识别框
8397
t_rec_ocr_list, grid = self.handle_overlap_row_col(t_rec_ocr_list)
84-
# todo 根据grid 及 not_match_orc_boxes,尝试将ocr识别填入单行单列中
8598
# 将同一个识别框中的ocr结果排序并同行合并
8699
t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list)
87100
# 渲染为html
@@ -192,11 +205,11 @@ def infer(self, input_content: Dict[str, Any]) -> Tuple[np.ndarray, np.ndarray]:
192205
def sort_and_gather_ocr_res(self, res):
193206
for i, dict_res in enumerate(res):
194207
_, sorted_idx = sorted_ocr_boxes(
195-
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.5
208+
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
196209
)
197210
dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
198211
dict_res["t_ocr_res"] = gather_ocr_list_by_row(
199-
dict_res["t_ocr_res"], thehold=0.5
212+
dict_res["t_ocr_res"], thehold=0.3
200213
)
201214
return res
202215

@@ -263,12 +276,17 @@ def re_rec(
263276
img: np.ndarray,
264277
sorted_polygons: np.ndarray,
265278
cell_box_map: Dict[int, List[str]],
279+
rec_again=True,
266280
) -> Dict[int, List[any]]:
267281
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
268282
#
269283
for i in range(sorted_polygons.shape[0]):
270284
if cell_box_map.get(i):
271285
continue
286+
if not rec_again:
287+
box = sorted_polygons[i]
288+
cell_box_map[i] = [[box, "", 1]]
289+
continue
272290
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
273291
pad_img = cv2.copyMakeBorder(
274292
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)
175 KB
Loading

tests/test_lineless_table_rec.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,39 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
244244
assert (
245245
html_output == expected_html
246246
), f"Expected HTML does not match. Got: {html_output}"
247+
248+
249+
@pytest.mark.parametrize(
250+
"img_path, table_str_len, td_nums",
251+
[
252+
("table.jpg", 2870, 160),
253+
],
254+
)
255+
def test_no_rec_again(img_path, table_str_len, td_nums):
256+
img_path = test_file_dir / img_path
257+
img = cv2.imread(str(img_path))
258+
259+
table_str, *_ = table_recog(img, rec_again=False)
260+
261+
assert len(table_str) >= table_str_len
262+
assert table_str.count("td") == td_nums
263+
264+
265+
@pytest.mark.parametrize(
266+
"img_path, html_output, points_len",
267+
[
268+
("table.jpg", "", 77),
269+
("lineless_table_recognition.jpg", "", 51),
270+
],
271+
)
272+
def test_no_ocr(img_path, html_output, points_len):
273+
img_path = test_file_dir / img_path
274+
275+
html, elasp, polygons, logic_points, ocr_res = table_recog(
276+
str(img_path), need_ocr=False
277+
)
278+
assert len(ocr_res) == 0
279+
assert len(polygons) > points_len
280+
assert len(logic_points) > points_len
281+
assert len(polygons) == len(logic_points)
282+
assert html == html_output

tests/test_wired_table_rec.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,22 @@ def test_input_normal(img_path, gt_td_nums, gt2):
6565
assert td_nums >= gt_td_nums
6666

6767

68+
@pytest.mark.parametrize(
69+
"img_path, gt_td_nums",
70+
[
71+
("wired_big_box.png", 70),
72+
],
73+
)
74+
def test_input_normal(img_path, gt_td_nums):
75+
img_path = test_file_dir / img_path
76+
77+
ocr_result, _ = ocr_engine(img_path)
78+
table_str, *_ = table_recog(str(img_path), ocr_result)
79+
td_nums = get_td_nums(table_str)
80+
81+
assert td_nums >= gt_td_nums
82+
83+
6884
@pytest.mark.parametrize(
6985
"box1, box2, threshold, expected",
7086
[
@@ -264,3 +280,40 @@ def test_plot_html_table(logi_points, cell_box_map, expected_html):
264280
assert (
265281
html_output == expected_html
266282
), f"Expected HTML does not match. Got: {html_output}"
283+
284+
285+
@pytest.mark.parametrize(
286+
"img_path, gt_td_nums, gt2",
287+
[
288+
("table_recognition.jpg", 35, "d colsp"),
289+
],
290+
)
291+
def test_no_rec_again(img_path, gt_td_nums, gt2):
292+
img_path = test_file_dir / img_path
293+
294+
ocr_result, _ = ocr_engine(img_path)
295+
table_str, *_ = table_recog(str(img_path), ocr_result, rec_again=False)
296+
td_nums = get_td_nums(table_str)
297+
298+
assert td_nums >= gt_td_nums
299+
300+
301+
@pytest.mark.parametrize(
302+
"img_path, html_output, points_len",
303+
[
304+
("table2.jpg", "", 20),
305+
("row_span.png", "", 14),
306+
],
307+
)
308+
def test_no_ocr(img_path, html_output, points_len):
309+
img_path = test_file_dir / img_path
310+
311+
ocr_result, _ = ocr_engine(img_path)
312+
html, elasp, polygons, logic_points, ocr_res = table_recog(
313+
str(img_path), ocr_result, need_ocr=False
314+
)
315+
assert len(ocr_res) == 0
316+
assert len(polygons) > points_len
317+
assert len(logic_points) > points_len
318+
assert len(polygons) == len(logic_points)
319+
assert html == html_output

wired_table_rec/main.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,16 +50,21 @@ def __call__(
5050
self,
5151
img: InputType,
5252
ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None,
53+
**kwargs,
5354
) -> Tuple[str, float, Any, Any, Any]:
5455
if self.ocr is None and ocr_result is None:
5556
raise ValueError(
5657
"One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed."
5758
)
5859

5960
s = time.perf_counter()
60-
61+
rec_again = True
62+
need_ocr = True
63+
if kwargs:
64+
rec_again = kwargs.get("rec_again", True)
65+
need_ocr = kwargs.get("need_ocr", True)
6166
img = self.load_img(img)
62-
polygons = self.table_line_rec(img)
67+
polygons = self.table_line_rec(img, **kwargs)
6368
if polygons is None:
6469
logging.warning("polygons is None.")
6570
return "", 0.0, None, None, None
@@ -71,12 +76,22 @@ def __call__(
7176
polygons[:, 3, :].copy(),
7277
polygons[:, 1, :].copy(),
7378
)
74-
if ocr_result is None:
79+
if not need_ocr:
80+
sorted_polygons, idx_list = sorted_ocr_boxes(
81+
[box_4_2_poly_to_box_4_1(box) for box in polygons]
82+
)
83+
return (
84+
"",
85+
time.perf_counter() - s,
86+
sorted_polygons,
87+
logi_points[idx_list],
88+
[],
89+
)
90+
if ocr_result is None and need_ocr:
7591
ocr_result, _ = self.ocr(img)
7692
cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons)
7793
# 如果有识别框没有ocr结果,直接进行rec补充
78-
# cell_box_det_map = self.re_rec_high_precise(img, polygons, cell_box_det_map)
79-
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map)
94+
cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again)
8095
# 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理
8196
t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points)
8297
# 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式
@@ -139,11 +154,11 @@ def transform_res(
139154
def sort_and_gather_ocr_res(self, res):
140155
for i, dict_res in enumerate(res):
141156
_, sorted_idx = sorted_ocr_boxes(
142-
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.5
157+
[ocr_det[0] for ocr_det in dict_res["t_ocr_res"]], threhold=0.3
143158
)
144159
dict_res["t_ocr_res"] = [dict_res["t_ocr_res"][i] for i in sorted_idx]
145160
dict_res["t_ocr_res"] = gather_ocr_list_by_row(
146-
dict_res["t_ocr_res"], threhold=0.5
161+
dict_res["t_ocr_res"], threhold=0.3
147162
)
148163
return res
149164

@@ -152,12 +167,16 @@ def re_rec(
152167
img: np.ndarray,
153168
sorted_polygons: np.ndarray,
154169
cell_box_map: Dict[int, List[str]],
170+
rec_again=True,
155171
) -> Dict[int, List[Any]]:
156172
"""找到poly对应为空的框,尝试将直接将poly框直接送到识别中"""
157-
#
158173
for i in range(sorted_polygons.shape[0]):
159174
if cell_box_map.get(i):
160175
continue
176+
if not rec_again:
177+
box = sorted_polygons[i]
178+
cell_box_map[i] = [[box, "", 1]]
179+
continue
161180
crop_img = get_rotate_crop_image(img, sorted_polygons[i])
162181
pad_img = cv2.copyMakeBorder(
163182
crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255)

wired_table_rec/table_line_rec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(self, model_path: Optional[str] = None):
3636

3737
self.session = OrtInferSession(model_path)
3838

39-
def __call__(self, img: np.ndarray) -> Optional[np.ndarray]:
39+
def __call__(self, img: np.ndarray, **kwargs) -> Optional[np.ndarray]:
4040
img_info = self.preprocess(img)
4141
pred = self.infer(img_info)
4242
polygons = self.postprocess(pred)

0 commit comments

Comments (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy