Skip to content

Commit 17fac54

Browse files
antmarakisnorvig
authored andcommitted
Learning: Grade Learner (#496)
* Add grade_learner * Update test_learning.py
1 parent fb503e6 commit 17fac54

File tree

2 files changed

+31
-27
lines changed

2 files changed

+31
-27
lines changed

learning.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,16 @@ def score(learner, size):
908908
return [(size, mean([score(learner, size) for t in range(trials)]))
909909
for size in sizes]
910910

911+
912+
def grade_learner(predict, tests):
913+
"""Grades the given learner based on how many tests it passes.
914+
tests is a list with each element in the form: (values, output)."""
915+
correct = 0
916+
for t in tests:
917+
if predict(t[0]) == t[1]:
918+
correct += 1
919+
return correct
920+
911921
# ______________________________________________________________________________
912922
# The rest of this file gives datasets for machine learning problems.
913923

tests/test_learning.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
11
from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \
22
PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, \
33
NeuralNetLearner, PerceptronLearner, DecisionTreeLearner, \
4-
euclidean_distance
4+
euclidean_distance, grade_learner
55
from utils import DataFile
66

77

88

99
def test_euclidean():
10-
distance = euclidean_distance([1,2], [3,4])
10+
distance = euclidean_distance([1, 2], [3, 4])
1111
assert round(distance, 2) == 2.83
1212

13-
distance = euclidean_distance([1,2,3], [4,5,6])
13+
distance = euclidean_distance([1, 2, 3], [4, 5, 6])
1414
assert round(distance, 2) == 5.2
1515

16-
distance = euclidean_distance([0,0,0], [0,0,0])
16+
distance = euclidean_distance([0, 0, 0], [0, 0, 0])
1717
assert distance == 0
1818

1919

@@ -24,7 +24,7 @@ def test_exclude():
2424

2525
def test_parse_csv():
2626
Iris = DataFile('iris.csv').read()
27-
assert parse_csv(Iris)[0] == [5.1,3.5,1.4,0.2,'setosa']
27+
assert parse_csv(Iris)[0] == [5.1, 3.5, 1.4, 0.2,'setosa']
2828

2929

3030
def test_weighted_mode():
@@ -47,25 +47,25 @@ def test_naive_bayes():
4747

4848
# Discrete
4949
nBD = NaiveBayesLearner(iris)
50-
assert nBD([5,3,1,0.1]) == "setosa"
50+
assert nBD([5, 3, 1, 0.1]) == "setosa"
5151

5252

5353
def test_k_nearest_neighbors():
5454
iris = DataSet(name="iris")
5555

5656
kNN = NearestNeighborLearner(iris,k=3)
57-
assert kNN([5,3,1,0.1]) == "setosa"
58-
assert kNN([6,5,3,1.5]) == "versicolor"
59-
assert kNN([7.5,4,6,2]) == "virginica"
57+
assert kNN([5, 3, 1, 0.1]) == "setosa"
58+
assert kNN([6, 5, 3, 1.5]) == "versicolor"
59+
assert kNN([7.5, 4, 6, 2]) == "virginica"
6060

6161

6262
def test_decision_tree_learner():
6363
iris = DataSet(name="iris")
6464

6565
dTL = DecisionTreeLearner(iris)
66-
assert dTL([5,3,1,0.1]) == "setosa"
67-
assert dTL([6,5,3,1.5]) == "versicolor"
68-
assert dTL([7.5,4,6,2]) == "virginica"
66+
assert dTL([5, 3, 1, 0.1]) == "setosa"
67+
assert dTL([6, 5, 3, 1.5]) == "versicolor"
68+
assert dTL([7.5, 4, 6, 2]) == "virginica"
6969

7070

7171
def test_neural_network_learner():
@@ -75,14 +75,11 @@ def test_neural_network_learner():
7575
iris.classes_to_numbers(classes)
7676

7777
nNL = NeuralNetLearner(iris, [5], 0.15, 75)
78-
pred1 = nNL([5,3,1,0.1])
79-
pred2 = nNL([6,3,3,1.5])
80-
pred3 = nNL([7.5,4,6,2])
78+
tests = [([5, 3, 1, 0.1], 0),
79+
([6, 3, 3, 1.5], 1),
80+
([7.5, 4, 6, 2], 2)]
8181

82-
# NeuralNetLearner might be wrong. If it is, check if prediction is in range.
83-
assert pred1 == 0 or pred1 in range(len(classes))
84-
assert pred2 == 1 or pred2 in range(len(classes))
85-
assert pred3 == 2 or pred3 in range(len(classes))
82+
assert grade_learner(nNL, tests) >= 2
8683

8784

8885
def test_perceptron():
@@ -92,11 +89,8 @@ def test_perceptron():
9289
classes_number = len(iris.values[iris.target])
9390

9491
perceptron = PerceptronLearner(iris)
95-
pred1 = perceptron([5,3,1,0.1])
96-
pred2 = perceptron([6,3,4,1])
97-
pred3 = perceptron([7.5,4,6,2])
98-
99-
# PerceptronLearner might be wrong. If it is, check if prediction is in range.
100-
assert pred1 == 0 or pred1 in range(classes_number)
101-
assert pred2 == 1 or pred2 in range(classes_number)
102-
assert pred3 == 2 or pred3 in range(classes_number)
92+
tests = [([5, 3, 1, 0.1], 0),
93+
([6, 3, 4, 1.1], 1),
94+
([7.5, 4, 6, 2], 2)]
95+
96+
assert grade_learner(perceptron, tests) >= 2

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy