Skip to content

Commit 60a20a4

Browse files
committed
Fixed normalize() to work on []. Added AdaBoost.
1 parent 004f873 commit 60a20a4

File tree

2 files changed

+34
-8
lines changed

2 files changed

+34
-8
lines changed

learning.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Learn to estimate functions from examples. (Chapters 18-20)"""
22

33
from utils import *
4-
import heapq, random
4+
import heapq, math, random
55

66
#______________________________________________________________________________
77

@@ -318,9 +318,7 @@ def split_by(attr, examples):
318318
def information_content(values):
319319
"Number of bits to represent the probability distribution in values."
320320
# If the values do not sum to 1, normalize them to make them a Prob. Dist.
321-
values = removeall(0, values)
322-
s = float(sum(values))
323-
if s != 1.0: values = [v/s for v in values]
321+
values = normalize(removeall(0, values))
324322
return sum([- v * log2(v) for v in values])
325323

326324
#______________________________________________________________________________
@@ -394,6 +392,34 @@ def predict(example):
394392
return predict
395393
return train
396394

395+
#______________________________________________________________________________
396+
397+
def AdaBoost(L, K):
398+
"""[Fig. 18.34]"""
399+
def train(dataset):
400+
examples, target = dataset.examples, dataset.target
401+
N = len(examples)
402+
w = [1./N] * N
403+
h, z = [], []
404+
for k in range(K):
405+
h_k = L(dataset.examples, w)
406+
h.append(h_k)
407+
error = sum(weight for example, weight in zip(examples, w)
408+
if example[target] != h_k(example))
409+
if error == 0:
410+
break
411+
assert error < 1, "AdaBoost's sub-learner misclassified everything"
412+
for j, example in enumerate(examples):
413+
if example[target] == h[k](example):
414+
w[j] *= error / (1. - error)
415+
w = normalize(w)
416+
z.append(math.log((1. - error) / error))
417+
return WeightedMajority(h, z)
418+
return train
419+
420+
def WeightedMajority(h, z):
421+
raise NotImplementedError
422+
397423
#_____________________________________________________________________________
398424
# Functions for testing learners on examples
399425

utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,13 +518,13 @@ def num_or_str(x):
518518
except ValueError:
519519
return str(x).strip()
520520

521-
def normalize(numbers, total=1.0):
522-
"""Multiply each number by a constant such that the sum is 1.0 (or total).
521+
def normalize(numbers):
522+
"""Multiply each number by a constant such that the sum is 1.0
523523
>>> normalize([1,2,1])
524524
[0.25, 0.5, 0.25]
525525
"""
526-
k = total / sum(numbers)
527-
return [k * n for n in numbers]
526+
total = float(sum(numbers))
527+
return [n / total for n in numbers]
528528

529529
## OK, the following are not as widely useful utilities as some of the other
530530
## functions here, but they do show up wherever we have 2D grids: Wumpus and

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy