|
4 | 4 |
|
5 | 5 |
|
6 | 6 | def ctcBestPath(mat, classes):
|
7 |
| - "implements best path decoding as shown by Graves (Dissertation, p63)" |
| 7 | + "implements best path decoding as shown by Graves (Dissertation, p63)" |
8 | 8 |
|
9 |
| - # dim0=t, dim1=c |
10 |
| - maxT, maxC = mat.shape |
11 |
| - label = '' |
12 |
| - blankIdx = len(classes) |
13 |
| - lastMaxIdx = maxC # init with invalid label |
| 9 | + # get list of char indices along best path |
| 10 | + best_path = np.argmax(mat, axis=1) |
14 | 11 |
|
15 |
| - for t in range(maxT): |
16 |
| - maxIdx = np.argmax(mat[t, :]) |
| 12 | + # collapse best path and map char indices to string |
| 13 | + blank_idx = len(classes) |
| 14 | + last_char_idx = blank_idx |
| 15 | + res = '' |
| 16 | + for char_idx in best_path: |
| 17 | + if char_idx != last_char_idx and char_idx != blank_idx: |
| 18 | + res += classes[char_idx] |
| 19 | + last_char_idx = char_idx |
17 | 20 |
|
18 |
| - if maxIdx != lastMaxIdx and maxIdx != blankIdx: |
19 |
| - label += classes[maxIdx] |
20 |
| - |
21 |
| - lastMaxIdx = maxIdx |
22 |
| - |
23 |
| - return label |
| 21 | + return res |
24 | 22 |
|
25 | 23 |
|
26 | 24 | def testBestPath():
|
27 |
| - "test decoder" |
28 |
| - classes = 'ab' |
29 |
| - mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]]) |
30 |
| - print('Test best path decoding') |
31 |
| - expected = '' |
32 |
| - actual = ctcBestPath(mat, classes) |
33 |
| - print('Expected: "' + expected + '"') |
34 |
| - print('Actual: "' + actual + '"') |
35 |
| - print('OK' if expected == actual else 'ERROR') |
| 25 | + "test decoder" |
| 26 | + classes = 'ab' |
| 27 | + mat = np.array([[0.4, 0, 0.6], [0.4, 0, 0.6]]) |
| 28 | + print('Test best path decoding') |
| 29 | + expected = '' |
| 30 | + actual = ctcBestPath(mat, classes) |
| 31 | + print('Expected: "' + expected + '"') |
| 32 | + print('Actual: "' + actual + '"') |
| 33 | + print('OK' if expected == actual else 'ERROR') |
36 | 34 |
|
37 | 35 |
|
38 | 36 | if __name__ == '__main__':
|
39 |
| - testBestPath() |
| 37 | + testBestPath() |
0 commit comments