1
- from typing import Optional
1
+ from collections import defaultdict
2
+ from dataclasses import dataclass
3
+ from typing import Optional , List , Tuple
2
4
3
5
import numpy as np
4
6
5
7
from ctc_decoder .language_model import LanguageModel
6
8
7
- LOG_ZERO = float ("-inf" )
8
9
10
def log(x: float) -> float:
    """Return the natural logarithm of x, mapping 0 to -inf without a warning."""
    # Suppress numpy's divide-by-zero RuntimeWarning so log(0) quietly yields -inf.
    with np.errstate(divide='ignore'):
        result = np.log(x)
    return result
15
@dataclass
class BeamEntry:
    """Information about one single beam at a specific time-step.

    All probabilities are kept in log-space: log(0) == -inf marks an
    impossible path and log(1) == 0.0 is the neutral language-model score.
    """
    pr_total: float = log(0)  # log-prob over all paths (blank and non-blank)
    pr_non_blank: float = log(0)  # log-prob of paths ending with a non-blank label
    pr_blank: float = log(0)  # log-prob of paths ending with a blank
    pr_text: float = log(1)  # LM score (log-space), neutral until the LM is applied
    lm_applied: bool = False  # flag if LM was already applied to this beam
    labeling: tuple = ()  # beam-labeling: tuple of label indices into the char set
23
class BeamList:
    """Information about all beams at a specific time-step."""

    def __init__(self) -> None:
        # Unknown labelings get a fresh default BeamEntry on first access.
        self.entries = defaultdict(BeamEntry)

    def normalize(self) -> None:
        """Length-normalise the LM score of every beam.

        Longer beams accumulate more (negative) log-LM terms, so each score
        is divided by its labeling length to keep beams of different
        lengths comparable.
        """
        # Iterate values directly instead of re-looking-up each key.
        for entry in self.entries.values():
            labeling_len = len(entry.labeling)
            # An empty labeling is normalised by 1 to avoid division by zero.
            entry.pr_text = (1.0 / (labeling_len if labeling_len else 1.0)) * entry.pr_text

    def sort_labelings(self) -> List[Tuple[int, ...]]:
        """Return beam-labelings, sorted by probability (best first)."""
        beams = self.entries.values()
        sorted_beams = sorted(beams, reverse=True, key=lambda x: x.pr_total + x.pr_text)
        return [x.labeling for x in sorted_beams]
42
def apply_lm(parent_beam: BeamEntry, child_beam: BeamEntry, chars: str, lm: LanguageModel,
             lm_factor: float = 0.01) -> None:
    """Calculate the LM score of the child beam.

    Takes the LM score of the parent beam and adds the weighted
    log-probability of the most recent character(s): a character bigram when
    the child labeling has at least two labels, a unigram otherwise.

    Args:
        parent_beam: Beam the child was extended from; supplies the prefix LM score.
        child_beam: Beam to score; mutated in place (pr_text, lm_applied).
            NOTE(review): assumes a non-empty labeling — callers only pass
            extended beams, confirm before reusing elsewhere.
        chars: String mapping label indices to characters.
        lm: Language model; if falsy, the function is a no-op.
        lm_factor: Influence of the language model on the total score
            (previously a hard-coded constant; default preserves old behavior).
    """
    if not lm or child_beam.lm_applied:
        return

    # take bigram if beam length at least 2
    if len(child_beam.labeling) > 1:
        c = chars[child_beam.labeling[-2]]
        d = chars[child_beam.labeling[-1]]
        ngram_prob = lm.get_char_bigram(c, d)
    # otherwise take unigram
    else:
        c = chars[child_beam.labeling[-1]]
        ngram_prob = lm.get_char_unigram(c)

    # probability of the char sequence so far, accumulated in log-space
    child_beam.pr_text = parent_beam.pr_text + lm_factor * log(ngram_prob)
    child_beam.lm_applied = True  # only apply LM once per beam entry
def beam_search (mat : np .ndarray , chars : str , beam_width : int = 25 , lm : Optional [LanguageModel ] = None ) -> str :
@@ -78,46 +81,38 @@ def beam_search(mat: np.ndarray, chars: str, beam_width: int = 25, lm: Optional[
78
81
max_T , max_C = mat .shape
79
82
80
83
# initialise beam state
81
- last = BeamState ()
84
+ last = BeamList ()
82
85
labeling = ()
83
86
last .entries [labeling ] = BeamEntry ()
84
- last .entries [labeling ].pr_blank = LOG_ZERO
85
- last .entries [labeling ].pr_total = LOG_ZERO
87
+ last .entries [labeling ].pr_blank = log ( 1 )
88
+ last .entries [labeling ].pr_total = log ( 1 )
86
89
87
90
# go over all time-steps
88
91
for t in range (max_T ):
89
- curr = BeamState ()
92
+ curr = BeamList ()
90
93
91
94
# get beam-labelings of best beams
92
- best_labelings = last .sort ()[0 :beam_width ]
95
+ best_labelings = last .sort_labelings ()[:beam_width ]
93
96
94
97
# go over best beams
95
98
for labeling in best_labelings :
96
99
97
100
# probability of paths ending with a non-blank
98
- pr_non_blank = LOG_ZERO
101
+ pr_non_blank = log ( 0 )
99
102
# in case of non-empty beam
100
103
if labeling :
101
104
# probability of paths with repeated last char at the end
102
- if last .entries [labeling ].pr_non_blank == LOG_ZERO :
103
- pr_non_blank = np .log (mat [t , labeling [- 1 ]]) # cannot add to -inf
104
- else :
105
- pr_non_blank = last .entries [labeling ].pr_non_blank + np .log (mat [t , labeling [- 1 ]])
105
+ pr_non_blank = last .entries [labeling ].pr_non_blank + log (mat [t , labeling [- 1 ]])
106
106
107
107
# probability of paths ending with a blank
108
- if last .entries [labeling ].pr_total == LOG_ZERO :
109
- pr_blank = np .log (mat [t , blank_idx ]) # cannot add to -inf
110
- else :
111
- pr_blank = last .entries [labeling ].pr_total + np .log (mat [t , blank_idx ])
112
-
113
- # add beam at current time-step if needed
114
- add_beam (curr , labeling )
108
+ pr_blank = last .entries [labeling ].pr_total + log (mat [t , blank_idx ])
115
109
116
- # fill in data
110
+ # fill in data for current beam
117
111
curr .entries [labeling ].labeling = labeling
118
112
curr .entries [labeling ].pr_non_blank = np .logaddexp (curr .entries [labeling ].pr_non_blank , pr_non_blank )
119
113
curr .entries [labeling ].pr_blank = np .logaddexp (curr .entries [labeling ].pr_blank , pr_blank )
120
- curr .entries [labeling ].pr_total = np .logaddexp (curr .entries [labeling ].pr_total , np .logaddexp (pr_blank , pr_non_blank ))
114
+ curr .entries [labeling ].pr_total = np .logaddexp (curr .entries [labeling ].pr_total ,
115
+ np .logaddexp (pr_blank , pr_non_blank ))
121
116
curr .entries [labeling ].pr_text = last .entries [labeling ].pr_text
122
117
curr .entries [labeling ].lm_applied = True # LM already applied at previous time-step for this beam-labeling
123
118
@@ -128,21 +123,14 @@ def beam_search(mat: np.ndarray, chars: str, beam_width: int = 25, lm: Optional[
128
123
129
124
# if new labeling contains duplicate char at the end, only consider paths ending with a blank
130
125
if labeling and labeling [- 1 ] == c :
131
- # if pr_blank is 0 then we cannot extend the beam with a dupe char
132
- # so pr_non_blank should still be 0 (-inf in log-space)
133
- pr_non_blank = last .entries [labeling ].pr_blank + np .log (mat [t , c ])
126
+ pr_non_blank = last .entries [labeling ].pr_blank + log (mat [t , c ])
134
127
else :
135
- if last .entries [labeling ].pr_total == LOG_ZERO :
136
- pr_non_blank = np .log (mat [t , c ]) # cannot add to -inf
137
- else :
138
- pr_non_blank = last .entries [labeling ].pr_total + np .log (mat [t , c ])
139
-
140
- # add beam at current time-step if needed
141
- add_beam (curr , new_labeling )
128
+ pr_non_blank = last .entries [labeling ].pr_total + log (mat [t , c ])
142
129
143
130
# fill in data
144
131
curr .entries [new_labeling ].labeling = new_labeling
145
- curr .entries [new_labeling ].pr_non_blank = np .logaddexp (curr .entries [new_labeling ].pr_non_blank , pr_non_blank )
132
+ curr .entries [new_labeling ].pr_non_blank = np .logaddexp (curr .entries [new_labeling ].pr_non_blank ,
133
+ pr_non_blank )
146
134
curr .entries [new_labeling ].pr_total = np .logaddexp (curr .entries [new_labeling ].pr_total , pr_non_blank )
147
135
148
136
# apply LM
@@ -152,11 +140,11 @@ def beam_search(mat: np.ndarray, chars: str, beam_width: int = 25, lm: Optional[
152
140
last = curr
153
141
154
142
# normalise LM scores according to beam-labeling-length
155
- last .norm ()
143
+ last .normalize ()
156
144
157
145
# sort by probability
158
- best_labeling = last .sort ()[0 ] # get most probable labeling
146
+ best_labeling = last .sort_labelings ()[0 ] # get most probable labeling
159
147
160
148
# map label string to char string
161
- res = '' .join ([chars [l ] for l in best_labeling ])
149
+ res = '' .join ([chars [label ] for label in best_labeling ])
162
150
return res
0 commit comments