Skip to content

Commit 7be52df

Browse files
committed
Refactor CountingProbDist:
* Use weighted_sampler() for the sampling.
* By not using a defaultdict, keep n_obs always up to date.
* Don't inherit from ProbDist, at least for now, since we don't use anything from the superclass, and if we did it would break (e.g. __setitem__).
* Remove __len__ since nobody uses it.
* Use heapq.nlargest.

Also separate random doctests from others, and tweak docs.
1 parent 2cf81dc commit 7be52df

File tree

2 files changed

+51
-57
lines changed

2 files changed

+51
-57
lines changed

text.py

Lines changed: 45 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@
99

1010
from utils import *
1111
from math import log, exp
12-
import re, probability, string, search
12+
import heapq, re, search
1313

14-
class CountingProbDist(probability.ProbDist):
14+
class CountingProbDist:
1515
"""A probability distribution formed by observing and counting examples.
16-
If P is an instance of this class and o
17-
is an observed value, then there are 3 main operations:
16+
If p is an instance of this class and o is an observed value, then
17+
there are 3 main operations:
1818
p.add(o) increments the count for observation o by 1.
1919
p.sample() returns a random element from the distribution.
2020
p[o] returns the probability for o (as in a regular ProbDist)."""
@@ -23,49 +23,40 @@ def __init__(self, observations=[], default=0):
2323
"""Create a distribution, and optionally add in some observations.
2424
By default this is an unsmoothed distribution, but saying default=1,
2525
for example, gives you add-one smoothing."""
26-
update(self, dictionary=DefaultDict(default), needs_recompute=False,
27-
table=[], n_obs=0)
26+
update(self, dictionary={}, n_obs=0.0, default=default, sampler=None)
2827
for o in observations:
2928
self.add(o)
3029

3130
def add(self, o):
32-
"""Add an observation o to the distribution."""
31+
"Add an observation o to the distribution."
32+
self.smooth_for(o)
3333
self.dictionary[o] += 1
3434
self.n_obs += 1
35-
self.needs_recompute = True
35+
self.sampler = None
3636

37-
def sample(self):
38-
"""Return a random sample from the distribution."""
39-
if self.needs_recompute: self._recompute()
40-
if self.n_obs == 0:
41-
return None
42-
i = bisect.bisect_left(self.table, (1 + random.randrange(self.n_obs),))
43-
(count, o) = self.table[i]
44-
return o
37+
def smooth_for(self, o):
38+
"""Include o among the possible observations, whether or not
39+
it's been observed yet."""
40+
if o not in self.dictionary:
41+
self.dictionary[o] = self.default
42+
self.n_obs += self.default
43+
self.sampler = None
4544

4645
def __getitem__(self, item):
47-
"""Return an estimate of the probability of item."""
48-
if self.needs_recompute: self._recompute()
46+
"Return an estimate of the probability of item."
47+
self.smooth_for(item)
4948
return self.dictionary[item] / self.n_obs
5049

51-
def __len__(self):
52-
if self.needs_recompute: self._recompute()
53-
return self.n_obs
54-
5550
def top(self, n):
5651
"Return (count, obs) tuples for the n most frequent observations."
57-
items = [(v, k) for (k, v) in self.dictionary.items()]
58-
items.sort(); items.reverse()
59-
return items[0:n]
60-
61-
def _recompute(self):
62-
"""Recompute the total count n_obs and the table of entries."""
63-
n_obs = 0
64-
table = []
65-
for (o, count) in self.dictionary.items():
66-
n_obs += count
67-
table.append((n_obs, o))
68-
update(self, n_obs=float(n_obs), table=table, needs_recompute=False)
52+
return heapq.nlargest(n, [(v, k) for (k, v) in self.dictionary.items()])
53+
54+
def sample(self):
55+
"Return a random sample from the distribution."
56+
if self.sampler is None:
57+
self.sampler = weighted_sampler(self.dictionary.keys(),
58+
self.dictionary.values())
59+
return self.sampler()
6960

7061
#______________________________________________________________________________
7162

@@ -81,7 +72,7 @@ def samples(self, n):
8172
class NgramTextModel(CountingProbDist):
8273
"""This is a discrete probability distribution over n-tuples of words.
8374
You can add, sample or get P[(word1, ..., wordn)]. The method P.samples(n)
84-
builds up an n-word sequence; P.add_text and P.add_sequence add data."""
75+
builds up an n-word sequence; P.add and P.add_sequence add data."""
8576

8677
def __init__(self, n, observation_sequence=[]):
8778
## In addition to the dictionary of n-tuples, cond_prob is a
@@ -91,7 +82,7 @@ def __init__(self, n, observation_sequence=[]):
9182
self.cond_prob = DefaultDict(CountingProbDist())
9283
self.add_sequence(observation_sequence)
9384

94-
## sample, __len__, __getitem__ inherited from CountingProbDist
85+
## sample, __getitem__ inherited from CountingProbDist
9586
## Note they deal with tuples, not strings, as inputs
9687

9788
def add(self, ngram):
@@ -113,13 +104,12 @@ def samples(self, nwords):
113104
n = self.n
114105
nminus1gram = ('',) * (n-1)
115106
output = []
116-
while len(output) < nwords:
107+
for i in range(nwords):
108+
if nminus1gram not in self.cond_prob:
109+
nminus1gram = ('',) * (n-1) # Cannot continue, so restart.
117110
wn = self.cond_prob[nminus1gram].sample()
118-
if wn:
119-
output.append(wn)
120-
nminus1gram = nminus1gram[1:] + (wn,)
121-
else: ## Cannot continue, so restart.
122-
nminus1gram = ('',) * (n-1)
111+
output.append(wn)
112+
nminus1gram = nminus1gram[1:] + (wn,)
123113
return ' '.join(output)
124114

125115
#______________________________________________________________________________
@@ -404,24 +394,14 @@ def goal_test(self, state):
404394
True
405395
"""
406396

407-
__doc__ += random_tests("""
397+
__doc__ += ("""
408398
## Compare 1-, 2-, and 3-gram word models of the same text.
409399
>>> flatland = DataFile("EN-text/flatland.txt").read()
410400
>>> wordseq = words(flatland)
411401
>>> P1 = UnigramTextModel(wordseq)
412402
>>> P2 = NgramTextModel(2, wordseq)
413403
>>> P3 = NgramTextModel(3, wordseq)
414404
415-
## Generate random text from the N-gram models
416-
>>> P1.samples(20)
417-
'you thought known but were insides of see in depend by us dodecahedrons just but i words are instead degrees'
418-
419-
>>> P2.samples(20)
420-
'flatland well then can anything else more into the total destruction and circles teach others confine women must be added'
421-
422-
>>> P3.samples(20)
423-
'flatland by edwin a abbott 1884 to the wake of a certificate from nature herself proving the equal sided triangle'
424-
425405
## The most frequent entries in each model
426406
>>> P1.top(10)
427407
[(2081, 'the'), (1479, 'of'), (1021, 'and'), (1008, 'to'), (850, 'a'), (722, 'i'), (640, 'in'), (478, 'that'), (399, 'is'), (348, 'you')]
@@ -431,6 +411,18 @@ def goal_test(self, state):
431411
432412
>>> P3.top(10)
433413
[(30, ('a', 'straight', 'line')), (19, ('of', 'three', 'dimensions')), (16, ('the', 'sense', 'of')), (13, ('by', 'the', 'sense')), (13, ('as', 'well', 'as')), (12, ('of', 'the', 'circles')), (12, ('of', 'sight', 'recognition')), (11, ('the', 'number', 'of')), (11, ('that', 'i', 'had')), (11, ('so', 'as', 'to'))]
414+
""")
415+
416+
__doc__ += random_tests("""
417+
## Generate random text from the N-gram models
418+
>>> P1.samples(20)
419+
'you thought known but were insides of see in depend by us dodecahedrons just but i words are instead degrees'
420+
421+
>>> P2.samples(20)
422+
'flatland well then can anything else more into the total destruction and circles teach others confine women must be added'
423+
424+
>>> P3.samples(20)
425+
'flatland by edwin a abbott 1884 to the wake of a certificate from nature herself proving the equal sided triangle'
434426
435427
## Probabilities of some common n-grams
436428
>>> P1['the']

utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -505,13 +505,15 @@ def weighted_sample_with_replacement(seq, weights, n):
505505
"""Pick n samples from seq at random, with replacement, with the
506506
probability of each element in proportion to its corresponding
507507
weight."""
508+
sample = weighted_sampler(seq, weights)
509+
return [sample() for s in range(n)]
510+
511+
def weighted_sampler(seq, weights):
512+
"Return a random-sample function that picks from seq weighted by weights."
508513
totals = []
509514
for w in weights:
510515
totals.append(w + totals[-1] if totals else w)
511-
def sample():
512-
r = random.uniform(0, totals[-1])
513-
return seq[bisect.bisect(totals, r)]
514-
return [sample() for s in range(n)]
516+
return lambda: seq[bisect.bisect(totals, random.uniform(0, totals[-1]))]
515517

516518
def num_or_str(x):
517519
"""The argument is a string; convert to a number if possible, or strip it.

Comments: 0 commit comments
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy