Skip to content

Commit c6f81c5

Browse files
committed
Fill out elimination_ask() (uncommented).
1 parent bb8b235 commit c6f81c5

File tree

1 file changed

+59
-12
lines changed

1 file changed

+59
-12
lines changed

probability.py

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -287,26 +287,73 @@ def enumerate_all(vars, e, bn):
287287

288288
#______________________________________________________________________________
289289

290-
def elimination_ask(X, e, bn, order=reversed):
291-
"[Fig. 14.11]"
290+
def elimination_ask(X, e, bn):
291+
"""[Fig. 14.11]
292+
>>> elimination_ask('Burglary', dict(JohnCalls=T, MaryCalls=T), burglary
293+
... ).show_approx()
294+
'False: 0.716, True: 0.284'"""
292295
factors = []
293-
for var in order(bn.vars):
294-
factors.append(Factor(var, e))
296+
for var in reversed(bn.vars):
297+
factors.append(make_factor(var, e, bn))
295298
if is_hidden(var, X, e):
296-
factors = sum_out(var, factors)
297-
return pointwise_product(factors).normalize()
299+
factors = sum_out(var, factors, bn)
300+
return pointwise_product(factors, bn).normalize()
298301

299302
def is_hidden(var, X, e):
300303
return var != X and var not in e
301304

302-
def Factor(var, e):
303-
unimplemented()
305+
def make_factor(var, e, bn):
306+
node = bn.variable_node(var)
307+
vars = [X for X in [var] + node.parents if X not in e]
308+
cpt = dict((event_values(e1, vars), node.p(e1[var], e1))
309+
for e1 in all_events(vars, bn, e))
310+
return Factor(vars, cpt)
311+
312+
def pointwise_product(factors, bn):
313+
return reduce(lambda f, g: f.pointwise_product(g, bn), factors)
314+
315+
def sum_out(var, factors, bn):
316+
result, var_factors = [], []
317+
for f in factors:
318+
(var_factors if var in f.vars else result).append(f)
319+
result.append(pointwise_product(var_factors, bn).sum_out(var, bn))
320+
return result
321+
322+
class Factor:
323+
324+
def __init__(self, vars, cpt):
325+
update(self, vars=vars, cpt=cpt)
326+
327+
def pointwise_product(self, other, bn):
328+
vars = list(set(self.vars) | set(other.vars))
329+
cpt = dict((event_values(e, vars), self.p(e) * other.p(e))
330+
for e in all_events(vars, bn, {}))
331+
return Factor(vars, cpt)
332+
333+
def sum_out(self, var, bn):
334+
vars = [X for X in self.vars if X != var]
335+
cpt = dict((event_values(e, vars),
336+
sum(self.p(extend(e, var, val))
337+
for val in bn.variable_values(var)))
338+
for e in all_events(vars, bn, {}))
339+
return Factor(vars, cpt)
304340

305-
def pointwise_product(factors):
306-
unimplemented()
341+
def normalize(self):
342+
assert len(self.vars) == 1
343+
return ProbDist(self.vars[0],
344+
dict((k, v) for ((k,), v) in self.cpt.items()))
307345

308-
def sum_out(var, factors):
309-
unimplemented()
346+
def p(self, e):
347+
return self.cpt[event_values(e, self.vars)]
348+
349+
def all_events(vars, bn, e1):
350+
if not vars:
351+
yield e1
352+
else:
353+
X, rest = vars[0], vars[1:]
354+
for e in all_events(rest, bn, e1):
355+
for x in bn.variable_values(X):
356+
yield extend(e, X, x)
310357

311358
#______________________________________________________________________________
312359

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy