4
4
5
5
from ctc_decoder .language_model import LanguageModel
6
6
7
+ LOG_ZERO = float ("-inf" ) # log(0): log-space value of a zero probability, used as the initial score of beams
7
8
8
9
class BeamEntry :
9
10
"""Information about one single beam at specific time-step."""
10
11
11
12
def __init__ (self ):
12
- self .pr_total = 0 # blank and non-blank
13
- self .pr_non_blank = 0 # non-blank
14
- self .pr_blank = 0 # blank
15
- self .pr_text = 1 # LM score
13
+ self .pr_total = LOG_ZERO # log-probability of all paths of this beam (ending in blank or non-blank)
14
+ self .pr_non_blank = LOG_ZERO # log-probability of the paths ending in a non-blank
15
+ self .pr_blank = LOG_ZERO # log-probability of the paths ending in a blank
16
+ self .pr_text = 0 # LM score in log-space; 0 is log(1), i.e. neutral until the LM is applied
16
17
self .lm_applied = False # flag if LM was already applied to this beam
17
18
self .labeling = () # beam-labeling: tuple of indices into the label string, empty for the initial beam
18
19
20
+ def is_empty (self ): # lets callers test for a beam without any label yet
21
+ return len (self .labeling ) == 0 # True only while the labeling is the empty tuple
19
22
20
23
class BeamState :
21
24
"""Information about all beams at specific time-step."""
@@ -27,12 +30,12 @@ def norm(self):
27
30
"""Length-normalise LM score."""
28
31
for k in self .entries .keys ():
29
32
labeling_len = len (self .entries [k ].labeling ) # number of labels in this beam
30
- self .entries [k ].pr_text = self . entries [ k ]. pr_text ** (1.0 / (labeling_len if labeling_len else 1.0 ))
33
+ self .entries [k ].pr_text = (1.0 / (labeling_len if labeling_len else 1.0 )) * self . entries [ k ]. pr_text # log-space form of the labeling_len-th root; an empty labeling is treated as length 1 to avoid dividing by zero
31
34
32
35
def sort (self ):
33
36
"""Return beam-labelings, sorted by probability."""
34
37
beams = [v for (_ , v ) in self .entries .items ()] # only the entry objects are needed, the dict keys are dropped
35
- sorted_beams = sorted (beams , reverse = True , key = lambda x : x .pr_total * x .pr_text )
38
+ sorted_beams = sorted (beams , reverse = True , key = lambda x : x .pr_total + x .pr_text ) # descending by combined score: path log-probability plus LM log-score (a sum in log-space replaces the former product)
36
39
return [x .labeling for x in sorted_beams ] # labelings only, best-scoring beam first
37
40
38
41
@@ -42,8 +45,11 @@ def apply_lm(parent_beam, child_beam, labels, lm):
42
45
c1 = labels [parent_beam .labeling [- 1 ] if parent_beam .labeling else labels .index (' ' )] # first char
43
46
c2 = labels [child_beam .labeling [- 1 ]] # second char
44
47
lm_factor = 0.01 # influence of language model
45
- bigram_prob = lm .get_char_bigram (c1 , c2 ) ** lm_factor
46
- child_beam .pr_text = parent_beam .pr_text * bigram_prob # probability of char sequence
48
+ bigram_prob = lm_factor * np .log (lm .get_char_bigram (c1 , c2 ))
49
+ if parent_beam .is_empty ():
50
+ child_beam .pr_text = bigram_prob # first char in beam
51
+ else :
52
+ child_beam .pr_text = parent_beam .pr_text + bigram_prob # probability of char sequence
47
53
child_beam .lm_applied = True # only apply LM once per beam entry
48
54
49
55
@@ -75,8 +81,8 @@ def beam_search(mat: np.ndarray, labels: str, beam_width: int = 25, lm: Optional
75
81
last = BeamState ()
76
82
labeling = ()
77
83
last .entries [labeling ] = BeamEntry ()
78
- last .entries [labeling ].pr_blank = 1
79
- last .entries [labeling ].pr_total = 1
84
+ last .entries [labeling ].pr_blank = LOG_ZERO
85
+ last .entries [labeling ].pr_total = LOG_ZERO
80
86
81
87
# go over all time-steps
82
88
for t in range (max_T ):
@@ -89,23 +95,29 @@ def beam_search(mat: np.ndarray, labels: str, beam_width: int = 25, lm: Optional
89
95
for labeling in best_labelings :
90
96
91
97
# probability of paths ending with a non-blank
92
- pr_non_blank = 0
98
+ pr_non_blank = LOG_ZERO
93
99
# in case of non-empty beam
94
100
if labeling :
95
101
# probability of paths with repeated last char at the end
96
- pr_non_blank = last .entries [labeling ].pr_non_blank * mat [t , labeling [- 1 ]]
102
+ if last .entries [labeling ].pr_non_blank == LOG_ZERO :
103
+ pr_non_blank = np .log (mat [t , labeling [- 1 ]]) # cannot add to -inf
104
+ else :
105
+ pr_non_blank = last .entries [labeling ].pr_non_blank + np .log (mat [t , labeling [- 1 ]])
97
106
98
107
# probability of paths ending with a blank
99
- pr_blank = last .entries [labeling ].pr_total * mat [t , blank_idx ]
108
+ if last .entries [labeling ].pr_total == LOG_ZERO :
109
+ pr_blank = np .log (mat [t , blank_idx ]) # cannot add to -inf
110
+ else :
111
+ pr_blank = last .entries [labeling ].pr_total + np .log (mat [t , blank_idx ])
100
112
101
113
# add beam at current time-step if needed
102
114
add_beam (curr , labeling )
103
115
104
116
# fill in data
105
117
curr .entries [labeling ].labeling = labeling
106
- curr .entries [labeling ].pr_non_blank += pr_non_blank
107
- curr .entries [labeling ].pr_blank += pr_blank
108
- curr .entries [labeling ].pr_total += pr_blank + pr_non_blank
118
+ curr .entries [labeling ].pr_non_blank = np . logaddexp ( curr . entries [ labeling ]. pr_non_blank , pr_non_blank )
119
+ curr .entries [labeling ].pr_blank = np . logaddexp ( curr . entries [ labeling ]. pr_blank , pr_blank )
120
+ curr .entries [labeling ].pr_total = np . logaddexp ( curr . entries [ labeling ]. pr_total , np . logaddexp ( pr_blank , pr_non_blank ))
109
121
curr .entries [labeling ].pr_text = last .entries [labeling ].pr_text
110
122
curr .entries [labeling ].lm_applied = True # LM already applied at previous time-step for this beam-labeling
111
123
@@ -116,17 +128,22 @@ def beam_search(mat: np.ndarray, labels: str, beam_width: int = 25, lm: Optional
116
128
117
129
# if new labeling contains duplicate char at the end, only consider paths ending with a blank
118
130
if labeling and labeling [- 1 ] == c :
119
- pr_non_blank = mat [t , c ] * last .entries [labeling ].pr_blank
131
+ # if pr_blank is 0 then we cannot extend the beam with a dupe char
132
+ # so pr_non_blank should still be 0 (-inf in log-space)
133
+ pr_non_blank = last .entries [labeling ].pr_blank + np .log (mat [t , c ])
120
134
else :
121
- pr_non_blank = mat [t , c ] * last .entries [labeling ].pr_total
135
+ if last .entries [labeling ].pr_total == LOG_ZERO :
136
+ pr_non_blank = np .log (mat [t , c ]) # cannot add to -inf
137
+ else :
138
+ pr_non_blank = last .entries [labeling ].pr_total + np .log (mat [t , c ])
122
139
123
140
# add beam at current time-step if needed
124
141
add_beam (curr , new_labeling )
125
142
126
143
# fill in data
127
144
curr .entries [new_labeling ].labeling = new_labeling
128
- curr .entries [new_labeling ].pr_non_blank += pr_non_blank
129
- curr .entries [new_labeling ].pr_total += pr_non_blank
145
+ curr .entries [new_labeling ].pr_non_blank = np . logaddexp ( curr . entries [ new_labeling ]. pr_non_blank , pr_non_blank )
146
+ curr .entries [new_labeling ].pr_total = np . logaddexp ( curr . entries [ new_labeling ]. pr_total , pr_non_blank )
130
147
131
148
# apply LM
132
149
apply_lm (curr .entries [labeling ], curr .entries [new_labeling ], labels , lm )
0 commit comments