Skip to content

Commit 60f022a

Browse files
committed
Added code to make a replicated training set.
1 parent acd54b3 commit 60f022a

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

learning.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Learn to estimate functions from examples. (Chapters 18-20)"""
22

33
from utils import *
4-
import heapq, math, random
4+
import copy, heapq, math, random
55
from collections import defaultdict
66

77
#______________________________________________________________________________
@@ -433,6 +433,39 @@ def weighted_mode(values, weights):
433433
totals[v] += w
434434
return max(totals.keys(), key=totals.get)
435435

436+
#_____________________________________________________________________________
437+
# Adapting an unweighted learner for AdaBoost
438+
439+
def WeightedLearner(unweighted_learner):
440+
"""Given a learner that takes just an unweighted dataset, return
441+
one that takes also a weight for each example. [p. 749 footnote 14]"""
442+
def train(dataset, weights):
443+
return unweighted_learner(replicated_dataset(dataset, weights))
444+
return train
445+
446+
def replicated_dataset(dataset, weights, n=None):
447+
"""Copy dataset, replicating each example in proportion to the
448+
corresponding weight."""
449+
n = n or len(dataset.examples)
450+
result = copy.copy(dataset)
451+
result.examples = weighted_replicate(dataset.examples, weights, n)
452+
return result
453+
454+
def weighted_replicate(seq, weights, n):
455+
"""Return n selections from seq, with the count of each element of
456+
seq proportional to the corresponding weight (filling in fractions
457+
randomly).
458+
>>> weighted_replicate('ABC', [1,2,1], 4)
459+
['A', 'B', 'B', 'C']"""
460+
assert len(seq) == len(weights)
461+
weights = normalize(weights)
462+
wholes = [int(w*n) for w in weights]
463+
fractions = [(w*n) % 1 for w in weights]
464+
return (flatten([x] * nx for x, nx in zip(seq, wholes))
465+
+ weighted_sample_with_replacement(seq, fractions, n - sum(wholes)))
466+
467+
def flatten(seqs): return sum(seqs, [])
468+
436469
#_____________________________________________________________________________
437470
# Functions for testing learners on examples
438471

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy