Skip to content

Commit 63458a9

Browse files
antmarakis authored and norvig committed
Learning: Naive Bayes Classifier (aimacode#618)
* add a simple naive bayes classifier
* Update test_learning.py
* spacing
* minor fix
* lists to strings
1 parent a58fe90 commit 63458a9

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

learning.py

Lines changed: 23 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -306,13 +306,35 @@ def predict(example):
306306
# ______________________________________________________________________________
307307

308308

309-
def NaiveBayesLearner(dataset, continuous=True):
309+
def NaiveBayesLearner(dataset, continuous=True, simple=False):
    """Build a naive Bayes predictor by delegating to one of three variants.

    The `simple` flag is checked first and wins over `continuous`: when it is
    truthy, `dataset` is handed to NaiveBayesSimple. Otherwise a truthy
    `continuous` selects NaiveBayesContinuous and a falsy one selects
    NaiveBayesDiscrete. Whatever the chosen variant returns is returned as-is.
    """
    if simple:
        learner = NaiveBayesSimple
    elif continuous:
        learner = NaiveBayesContinuous
    else:
        learner = NaiveBayesDiscrete
    return learner(dataset)
314316

315317

318+
def NaiveBayesSimple(distribution):
    """Return a naive Bayes predictor built from ready-made distributions.

    `distribution` is a dictionary whose keys are (ClassName, ClassProb)
    pairs and whose values are CountingProbDist objects, one per class.
    The returned function takes an example (an iterable of attribute
    values) and returns the class name with the highest score.
    """
    priors = {}
    likelihoods = {}
    # One pass over the input splits each (name, prior) key into the two
    # lookup tables used at prediction time.
    for (label, prior), counting_dist in distribution.items():
        priors[label] = prior
        likelihoods[label] = counting_dist

    def predict(example):
        """Score every class for `example` and return the best class name."""
        def class_probability(targetval):
            # Class prior times the product of the per-item probabilities
            # looked up in that class's distribution.
            per_class = likelihoods[targetval]
            return priors[targetval] * product(per_class[a] for a in example)

        return argmax(priors.keys(), key=class_probability)

    return predict
336+
337+
316338
def NaiveBayesDiscrete(dataset):
317339
"""Just count how many times each value of each input attribute
318340
occurs, conditional on the target value. Count the different

tests/test_learning.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,20 @@ def test_naive_bayes():
105105
assert nBC([6, 5, 3, 1.5]) == "versicolor"
106106
assert nBC([7, 3, 6.5, 2]) == "virginica"
107107

108+
# Simple
109+
data1 = 'a'*50 + 'b'*30 + 'c'*15
110+
dist1 = CountingProbDist(data1)
111+
data2 = 'a'*30 + 'b'*45 + 'c'*20
112+
dist2 = CountingProbDist(data2)
113+
data3 = 'a'*20 + 'b'*20 + 'c'*35
114+
dist3 = CountingProbDist(data3)
115+
116+
dist = {('First', 0.5): dist1, ('Second', 0.3): dist2, ('Third', 0.2): dist3}
117+
nBS = NaiveBayesLearner(dist, simple=True)
118+
assert nBS('aab') == 'First'
119+
assert nBS(['b', 'b']) == 'Second'
120+
assert nBS('ccbcc') == 'Third'
121+
108122

109123
def test_k_nearest_neighbors():
110124
iris = DataSet(name="iris")

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy