
Commit 6b5c3dd

Author: Harald Scheidl (committed)
beam search: reworked log-prob implementation, reworked lang model (unigram and bigram)
1 parent 37a8a1f
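The log-prob rework drops the old `LOG_ZERO` special-casing: a small `log` helper maps 0 to -inf, sums of probabilities become `np.logaddexp`, and products become plain addition. Since `np.logaddexp(-inf, x) == x` and `-inf + x == -inf`, no branch for the zero case is needed anymore. A minimal sketch of the NumPy behaviour this relies on (illustration only, not code from this commit):

    import numpy as np

    def log(x: float) -> float:
        # log(0) evaluates to -inf; suppress NumPy's divide-by-zero warning
        with np.errstate(divide='ignore'):
            return np.log(x)

    zero = log(0)                         # -inf acts as the log-space zero
    print(np.logaddexp(zero, log(0.5)))   # log(0 + 0.5): the -inf term is absorbed
    print(zero + log(0.5))                # -inf: multiplying by a zero probability stays zero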

File tree

3 files changed (+87 / -89 lines)


ctc_decoder/beam_search.py

Lines changed: 54 additions & 66 deletions
@@ -1,62 +1,65 @@
-from typing import Optional
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Optional, List, Tuple

 import numpy as np

 from ctc_decoder.language_model import LanguageModel

-LOG_ZERO = float("-inf")

+def log(x: float) -> float:
+    with np.errstate(divide='ignore'):
+        return np.log(x)
+
+
+@dataclass
 class BeamEntry:
     """Information about one single beam at specific time-step."""
+    pr_total: float = log(0)  # blank and non-blank
+    pr_non_blank: float = log(0)  # non-blank
+    pr_blank: float = log(0)  # blank
+    pr_text: float = log(1)  # LM score
+    lm_applied: bool = False  # flag if LM was already applied to this beam
+    labeling: tuple = ()  # beam-labeling

-    def __init__(self):
-        self.pr_total = LOG_ZERO  # blank and non-blank
-        self.pr_non_blank = LOG_ZERO  # non-blank
-        self.pr_blank = LOG_ZERO  # blank
-        self.pr_text = 0  # LM score
-        self.lm_applied = False  # flag if LM was already applied to this beam
-        self.labeling = ()  # beam-labeling
-
-    def is_empty(self):
-        return len(self.labeling) == 0

-class BeamState:
+class BeamList:
     """Information about all beams at specific time-step."""

-    def __init__(self):
-        self.entries = {}
+    def __init__(self) -> None:
+        self.entries = defaultdict(BeamEntry)

-    def norm(self):
+    def normalize(self) -> None:
         """Length-normalise LM score."""
         for k in self.entries.keys():
             labeling_len = len(self.entries[k].labeling)
             self.entries[k].pr_text = (1.0 / (labeling_len if labeling_len else 1.0)) * self.entries[k].pr_text

-    def sort(self):
+    def sort_labelings(self) -> List[Tuple[int]]:
         """Return beam-labelings, sorted by probability."""
-        beams = [v for (_, v) in self.entries.items()]
+        beams = self.entries.values()
         sorted_beams = sorted(beams, reverse=True, key=lambda x: x.pr_total + x.pr_text)
         return [x.labeling for x in sorted_beams]


-def apply_lm(parent_beam, child_beam, chars, lm):
+def apply_lm(parent_beam: BeamEntry, child_beam: BeamEntry, chars: str, lm: LanguageModel) -> None:
     """Calculate LM score of child beam by taking score from parent beam and bigram probability of last two chars."""
-    if lm and not child_beam.lm_applied:
-        c1 = chars[parent_beam.labeling[-1] if parent_beam.labeling else chars.index(' ')]  # first char
-        c2 = chars[child_beam.labeling[-1]]  # second char
-        lm_factor = 0.01  # influence of language model
-        bigram_prob = lm_factor * np.log(lm.get_char_bigram(c1, c2))
-        if parent_beam.is_empty():
-            child_beam.pr_text = bigram_prob  # first char in beam
-        else:
-            child_beam.pr_text = parent_beam.pr_text + bigram_prob  # probability of char sequence
-        child_beam.lm_applied = True  # only apply LM once per beam entry
+    if not lm or child_beam.lm_applied:
+        return

+    # take bigram if beam length at least 2
+    if len(child_beam.labeling) > 1:
+        c = chars[child_beam.labeling[-2]]
+        d = chars[child_beam.labeling[-1]]
+        ngram_prob = lm.get_char_bigram(c, d)
+    # otherwise take unigram
+    else:
+        c = chars[child_beam.labeling[-1]]
+        ngram_prob = lm.get_char_unigram(c)

-def add_beam(beam_state, labeling):
-    """Add beam if it does not yet exist."""
-    if labeling not in beam_state.entries:
-        beam_state.entries[labeling] = BeamEntry()
+    lm_factor = 0.01  # influence of language model
+    child_beam.pr_text = parent_beam.pr_text + lm_factor * log(ngram_prob)  # probability of char sequence
+    child_beam.lm_applied = True  # only apply LM once per beam entry


 def beam_search(mat: np.ndarray, chars: str, beam_width: int = 25, lm: Optional[LanguageModel] = None) -> str:
@@ -78,46 +81,38 @@ def beam_search(mat: np.ndarray, chars: str, beam_width: int = 25, lm: Optional[
     max_T, max_C = mat.shape

     # initialise beam state
-    last = BeamState()
+    last = BeamList()
     labeling = ()
     last.entries[labeling] = BeamEntry()
-    last.entries[labeling].pr_blank = LOG_ZERO
-    last.entries[labeling].pr_total = LOG_ZERO
+    last.entries[labeling].pr_blank = log(1)
+    last.entries[labeling].pr_total = log(1)

     # go over all time-steps
     for t in range(max_T):
-        curr = BeamState()
+        curr = BeamList()

         # get beam-labelings of best beams
-        best_labelings = last.sort()[0:beam_width]
+        best_labelings = last.sort_labelings()[:beam_width]

         # go over best beams
         for labeling in best_labelings:

             # probability of paths ending with a non-blank
-            pr_non_blank = LOG_ZERO
+            pr_non_blank = log(0)
             # in case of non-empty beam
             if labeling:
                 # probability of paths with repeated last char at the end
-                if last.entries[labeling].pr_non_blank == LOG_ZERO:
-                    pr_non_blank = np.log(mat[t, labeling[-1]])  # cannot add to -inf
-                else:
-                    pr_non_blank = last.entries[labeling].pr_non_blank + np.log(mat[t, labeling[-1]])
+                pr_non_blank = last.entries[labeling].pr_non_blank + log(mat[t, labeling[-1]])

             # probability of paths ending with a blank
-            if last.entries[labeling].pr_total == LOG_ZERO:
-                pr_blank = np.log(mat[t, blank_idx])  # cannot add to -inf
-            else:
-                pr_blank = last.entries[labeling].pr_total + np.log(mat[t, blank_idx])
-
-            # add beam at current time-step if needed
-            add_beam(curr, labeling)
+            pr_blank = last.entries[labeling].pr_total + log(mat[t, blank_idx])

-            # fill in data
+            # fill in data for current beam
             curr.entries[labeling].labeling = labeling
             curr.entries[labeling].pr_non_blank = np.logaddexp(curr.entries[labeling].pr_non_blank, pr_non_blank)
             curr.entries[labeling].pr_blank = np.logaddexp(curr.entries[labeling].pr_blank, pr_blank)
-            curr.entries[labeling].pr_total = np.logaddexp(curr.entries[labeling].pr_total, np.logaddexp(pr_blank, pr_non_blank))
+            curr.entries[labeling].pr_total = np.logaddexp(curr.entries[labeling].pr_total,
+                                                           np.logaddexp(pr_blank, pr_non_blank))
             curr.entries[labeling].pr_text = last.entries[labeling].pr_text
             curr.entries[labeling].lm_applied = True  # LM already applied at previous time-step for this beam-labeling

@@ -128,21 +123,14 @@ def beam_search(mat: np.ndarray, chars: str, beam_width: int = 25, lm: Optional[

                 # if new labeling contains duplicate char at the end, only consider paths ending with a blank
                 if labeling and labeling[-1] == c:
-                    # if pr_blank is 0 then we cannot extend the beam with a dupe char
-                    # so pr_non_blank should still be 0 (-inf in log-space)
-                    pr_non_blank = last.entries[labeling].pr_blank + np.log(mat[t, c])
+                    pr_non_blank = last.entries[labeling].pr_blank + log(mat[t, c])
                 else:
-                    if last.entries[labeling].pr_total == LOG_ZERO:
-                        pr_non_blank = np.log(mat[t, c])  # cannot add to -inf
-                    else:
-                        pr_non_blank = last.entries[labeling].pr_total + np.log(mat[t, c])
-
-                # add beam at current time-step if needed
-                add_beam(curr, new_labeling)
+                    pr_non_blank = last.entries[labeling].pr_total + log(mat[t, c])

                 # fill in data
                 curr.entries[new_labeling].labeling = new_labeling
-                curr.entries[new_labeling].pr_non_blank = np.logaddexp(curr.entries[new_labeling].pr_non_blank, pr_non_blank)
+                curr.entries[new_labeling].pr_non_blank = np.logaddexp(curr.entries[new_labeling].pr_non_blank,
+                                                                       pr_non_blank)
                 curr.entries[new_labeling].pr_total = np.logaddexp(curr.entries[new_labeling].pr_total, pr_non_blank)

                 # apply LM
@@ -152,11 +140,11 @@ def beam_search(mat: np.ndarray, chars: str, beam_width: int = 25, lm: Optional[
         last = curr

     # normalise LM scores according to beam-labeling-length
-    last.norm()
+    last.normalize()

     # sort by probability
-    best_labeling = last.sort()[0]  # get most probable labeling
+    best_labeling = last.sort_labelings()[0]  # get most probable labeling

     # map label string to char string
-    res = ''.join([chars[l] for l in best_labeling])
+    res = ''.join([chars[label] for label in best_labeling])
     return res
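For context, a minimal usage sketch of the reworked decoder. The package-level import and the convention that the blank is the last column of mat are assumptions based on this repo's README; the toy matrix is made up for illustration:

    import numpy as np
    from ctc_decoder import beam_search

    chars = 'ab'
    # 3 time-steps, 3 classes: 'a', 'b', blank (last column); each row sums to 1
    mat = np.array([[0.4, 0.0, 0.6],
                    [0.4, 0.0, 0.6],
                    [0.0, 1.0, 0.0]])

    # summing over paths gives P('ab') = 0.64 vs P('b') = 0.36
    print(beam_search(mat, chars))  # expected: 'ab'

Note that mat contains exact zeros; with the reworked log-prob handling these simply become -inf scores instead of needing special cases.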

ctc_decoder/language_model.py

Lines changed: 31 additions & 21 deletions
@@ -3,33 +3,43 @@ class LanguageModel:

     def __init__(self, txt: str, chars: str) -> None:
         """Create language model from text corpus."""
-        txt = ' ' + txt + ' '  # ensure first/last characters appear next to whitespace
-        self._init_char_bigrams(txt, chars)

-    def _init_char_bigrams(self, txt: str, chars: str) -> None:
-        """Initialize table of character bigrams."""
-
-        # init bigrams with 0 values
-        self.bigram = {c: {d: 0 for d in chars} for c in chars}
+        # compute unigrams
+        self._unigram = {c: 0 for c in chars}
+        for c in txt:
+            # ignore unknown chars
+            if c not in self._unigram:
+                continue
+            self._unigram[c] += 1

-        # go through text and add each char bigram
+        # compute bigrams
+        self._bigram = {c: {d: 0 for d in chars} for c in chars}
         for i in range(len(txt) - 1):
-            first = txt[i]
-            second = txt[i + 1]
+            c = txt[i]
+            d = txt[i + 1]

             # ignore unknown chars
-            if first not in self.bigram or second not in self.bigram[first]:
+            if c not in self._bigram or d not in self._bigram[c]:
                 continue

-            self.bigram[first][second] += 1
+            self._bigram[c][d] += 1
+
+        # normalize
+        sum_unigram = sum(self._unigram.values())
+        for c in chars:
+            self._unigram[c] /= sum_unigram
+
+        for c in chars:
+            sum_bigram = sum(self._bigram[c].values())
+            if sum_bigram == 0:
+                continue
+            for d in chars:
+                self._bigram[c][d] /= sum_bigram

-    def get_char_bigram(self, first: str, second: str) -> float:
-        """Probability that first character is followed by second one."""
-        first = first if first else ' '  # map start to word beginning
-        second = second if second else ' '  # map end to word end
+    def get_char_unigram(self, c: str) -> float:
+        """Probability of character c."""
+        return self._unigram[c]

-        # number of bigrams starting with given char
-        num_bigrams = sum(self.bigram[first].values())
-        if num_bigrams == 0:
-            return 0
-        return self.bigram[first][second] / num_bigrams
+    def get_char_bigram(self, c: str, d: str) -> float:
+        """Probability that character c is followed by character d."""
+        return self._bigram[c][d]
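A quick sketch of the reworked language model's normalized scores (same assumption about the package-level import; the toy corpus is invented):

    from ctc_decoder import LanguageModel

    lm = LanguageModel('aab ab', chars='ab ')
    print(lm.get_char_unigram('a'))      # 0.5: 'a' occurs 3 times among 6 chars
    print(lm.get_char_bigram('a', 'b'))  # ~0.667: of the 3 bigrams starting with 'a', 2 are 'ab'

Normalizing the counts once in __init__ means get_char_unigram and get_char_bigram are plain table lookups, instead of re-summing counts on every call as the old get_char_bigram did.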

setup.py

Lines changed: 2 additions & 2 deletions
@@ -2,11 +2,11 @@

 setup(
     name='ctc-decoder',
-    version='1.0.0',
+    version='1.0.1',
     description='Connectionist Temporal Classification decoders.',
     author='Harald Scheidl',
     packages=['ctc_decoder'],
     url="https://github.com/githubharald/CTCDecoder",
     install_requires=['editdistance', 'numpy'],
-    python_requires=">=3.6"
+    python_requires='>=3.7'
 )

0 commit comments
