Skip to content

Commit 8766399

Browse files
authored
Merge pull request #23 from a-sneddon/log-probs
Use log probabilities in beam search.
2 parents 62ae2b0 + 2e2ad04 commit 8766399

File tree

1 file changed

+37
-20
lines changed

1 file changed

+37
-20
lines changed

ctc_decoder/beam_search.py

Lines changed: 37 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -4,18 +4,21 @@
44

55
from ctc_decoder.language_model import LanguageModel
66

7+
LOG_ZERO = float("-inf")
78

89
class BeamEntry:
910
"""Information about one single beam at specific time-step."""
1011

1112
def __init__(self):
12-
self.pr_total = 0 # blank and non-blank
13-
self.pr_non_blank = 0 # non-blank
14-
self.pr_blank = 0 # blank
15-
self.pr_text = 1 # LM score
13+
self.pr_total = LOG_ZERO # blank and non-blank
14+
self.pr_non_blank = LOG_ZERO # non-blank
15+
self.pr_blank = LOG_ZERO # blank
16+
self.pr_text = 0 # LM score
1617
self.lm_applied = False # flag if LM was already applied to this beam
1718
self.labeling = () # beam-labeling
1819

20+
def is_empty(self):
21+
return len(self.labeling) == 0
1922

2023
class BeamState:
2124
"""Information about all beams at specific time-step."""
@@ -27,12 +30,12 @@ def norm(self):
2730
"""Length-normalise LM score."""
2831
for k in self.entries.keys():
2932
labeling_len = len(self.entries[k].labeling)
30-
self.entries[k].pr_text = self.entries[k].pr_text ** (1.0 / (labeling_len if labeling_len else 1.0))
33+
self.entries[k].pr_text = (1.0 / (labeling_len if labeling_len else 1.0)) * self.entries[k].pr_text
3134

3235
def sort(self):
3336
"""Return beam-labelings, sorted by probability."""
3437
beams = [v for (_, v) in self.entries.items()]
35-
sorted_beams = sorted(beams, reverse=True, key=lambda x: x.pr_total * x.pr_text)
38+
sorted_beams = sorted(beams, reverse=True, key=lambda x: x.pr_total + x.pr_text)
3639
return [x.labeling for x in sorted_beams]
3740

3841

@@ -42,8 +45,11 @@ def apply_lm(parent_beam, child_beam, labels, lm):
4245
c1 = labels[parent_beam.labeling[-1] if parent_beam.labeling else labels.index(' ')] # first char
4346
c2 = labels[child_beam.labeling[-1]] # second char
4447
lm_factor = 0.01 # influence of language model
45-
bigram_prob = lm.get_char_bigram(c1, c2) ** lm_factor
46-
child_beam.pr_text = parent_beam.pr_text * bigram_prob # probability of char sequence
48+
bigram_prob = lm_factor * np.log(lm.get_char_bigram(c1, c2))
49+
if parent_beam.is_empty():
50+
child_beam.pr_text = bigram_prob # first char in beam
51+
else:
52+
child_beam.pr_text = parent_beam.pr_text + bigram_prob # probability of char sequence
4753
child_beam.lm_applied = True # only apply LM once per beam entry
4854

4955

@@ -75,8 +81,8 @@ def beam_search(mat: np.ndarray, labels: str, beam_width: int = 25, lm: Optional
7581
last = BeamState()
7682
labeling = ()
7783
last.entries[labeling] = BeamEntry()
78-
last.entries[labeling].pr_blank = 1
79-
last.entries[labeling].pr_total = 1
84+
last.entries[labeling].pr_blank = LOG_ZERO
85+
last.entries[labeling].pr_total = LOG_ZERO
8086

8187
# go over all time-steps
8288
for t in range(max_T):
@@ -89,23 +95,29 @@ def beam_search(mat: np.ndarray, labels: str, beam_width: int = 25, lm: Optional
8995
for labeling in best_labelings:
9096

9197
# probability of paths ending with a non-blank
92-
pr_non_blank = 0
98+
pr_non_blank = LOG_ZERO
9399
# in case of non-empty beam
94100
if labeling:
95101
# probability of paths with repeated last char at the end
96-
pr_non_blank = last.entries[labeling].pr_non_blank * mat[t, labeling[-1]]
102+
if last.entries[labeling].pr_non_blank == LOG_ZERO:
103+
pr_non_blank = np.log(mat[t, labeling[-1]]) # cannot add to -inf
104+
else:
105+
pr_non_blank = last.entries[labeling].pr_non_blank + np.log(mat[t, labeling[-1]])
97106

98107
# probability of paths ending with a blank
99-
pr_blank = last.entries[labeling].pr_total * mat[t, blank_idx]
108+
if last.entries[labeling].pr_total == LOG_ZERO:
109+
pr_blank = np.log(mat[t, blank_idx]) # cannot add to -inf
110+
else:
111+
pr_blank = last.entries[labeling].pr_total + np.log(mat[t, blank_idx])
100112

101113
# add beam at current time-step if needed
102114
add_beam(curr, labeling)
103115

104116
# fill in data
105117
curr.entries[labeling].labeling = labeling
106-
curr.entries[labeling].pr_non_blank += pr_non_blank
107-
curr.entries[labeling].pr_blank += pr_blank
108-
curr.entries[labeling].pr_total += pr_blank + pr_non_blank
118+
curr.entries[labeling].pr_non_blank = np.logaddexp(curr.entries[labeling].pr_non_blank, pr_non_blank)
119+
curr.entries[labeling].pr_blank = np.logaddexp(curr.entries[labeling].pr_blank, pr_blank)
120+
curr.entries[labeling].pr_total = np.logaddexp(curr.entries[labeling].pr_total, np.logaddexp(pr_blank, pr_non_blank))
109121
curr.entries[labeling].pr_text = last.entries[labeling].pr_text
110122
curr.entries[labeling].lm_applied = True # LM already applied at previous time-step for this beam-labeling
111123

@@ -116,17 +128,22 @@ def beam_search(mat: np.ndarray, labels: str, beam_width: int = 25, lm: Optional
116128

117129
# if new labeling contains duplicate char at the end, only consider paths ending with a blank
118130
if labeling and labeling[-1] == c:
119-
pr_non_blank = mat[t, c] * last.entries[labeling].pr_blank
131+
# if pr_blank is 0 then we cannot extend the beam with a dupe char
132+
# so pr_non_blank should still be 0 (-inf in log-space)
133+
pr_non_blank = last.entries[labeling].pr_blank + np.log(mat[t, c])
120134
else:
121-
pr_non_blank = mat[t, c] * last.entries[labeling].pr_total
135+
if last.entries[labeling].pr_total == LOG_ZERO:
136+
pr_non_blank = np.log(mat[t, c]) # cannot add to -inf
137+
else:
138+
pr_non_blank = last.entries[labeling].pr_total + np.log(mat[t, c])
122139

123140
# add beam at current time-step if needed
124141
add_beam(curr, new_labeling)
125142

126143
# fill in data
127144
curr.entries[new_labeling].labeling = new_labeling
128-
curr.entries[new_labeling].pr_non_blank += pr_non_blank
129-
curr.entries[new_labeling].pr_total += pr_non_blank
145+
curr.entries[new_labeling].pr_non_blank = np.logaddexp(curr.entries[new_labeling].pr_non_blank, pr_non_blank)
146+
curr.entries[new_labeling].pr_total = np.logaddexp(curr.entries[new_labeling].pr_total, pr_non_blank)
130147

131148
# apply LM
132149
apply_lm(curr.entries[labeling], curr.entries[new_labeling], labels, lm)

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy