diff --git a/learning.py b/learning.py index 8347fbbef..fffbccf83 100644 --- a/learning.py +++ b/learning.py @@ -908,6 +908,16 @@ def score(learner, size): return [(size, mean([score(learner, size) for t in range(trials)])) for size in sizes] + +def grade_learner(predict, tests): + """Grades the given learner based on how many tests it passes. + tests is a list with each element in the form: (values, output).""" + correct = 0 + for t in tests: + if predict(t[0]) == t[1]: + correct += 1 + return correct + # ______________________________________________________________________________ # The rest of this file gives datasets for machine learning problems. diff --git a/tests/test_learning.py b/tests/test_learning.py index 1b4b825c1..1bac9a4cc 100644 --- a/tests/test_learning.py +++ b/tests/test_learning.py @@ -1,19 +1,19 @@ from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \ PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, \ NeuralNetLearner, PerceptronLearner, DecisionTreeLearner, \ - euclidean_distance + euclidean_distance, grade_learner from utils import DataFile def test_euclidean(): - distance = euclidean_distance([1,2], [3,4]) + distance = euclidean_distance([1, 2], [3, 4]) assert round(distance, 2) == 2.83 - distance = euclidean_distance([1,2,3], [4,5,6]) + distance = euclidean_distance([1, 2, 3], [4, 5, 6]) assert round(distance, 2) == 5.2 - distance = euclidean_distance([0,0,0], [0,0,0]) + distance = euclidean_distance([0, 0, 0], [0, 0, 0]) assert distance == 0 @@ -24,7 +24,7 @@ def test_exclude(): def test_parse_csv(): Iris = DataFile('iris.csv').read() - assert parse_csv(Iris)[0] == [5.1,3.5,1.4,0.2,'setosa'] + assert parse_csv(Iris)[0] == [5.1, 3.5, 1.4, 0.2,'setosa'] def test_weighted_mode(): @@ -47,25 +47,25 @@ def test_naive_bayes(): # Discrete nBD = NaiveBayesLearner(iris) - assert nBD([5,3,1,0.1]) == "setosa" + assert nBD([5, 3, 1, 0.1]) == "setosa" def test_k_nearest_neighbors(): iris = DataSet(name="iris") kNN = NearestNeighborLearner(iris,k=3) - assert kNN([5,3,1,0.1]) == "setosa" - assert kNN([6,5,3,1.5]) == "versicolor" - assert kNN([7.5,4,6,2]) == "virginica" + assert kNN([5, 3, 1, 0.1]) == "setosa" + assert kNN([6, 5, 3, 1.5]) == "versicolor" + assert kNN([7.5, 4, 6, 2]) == "virginica" def test_decision_tree_learner(): iris = DataSet(name="iris") dTL = DecisionTreeLearner(iris) - assert dTL([5,3,1,0.1]) == "setosa" - assert dTL([6,5,3,1.5]) == "versicolor" - assert dTL([7.5,4,6,2]) == "virginica" + assert dTL([5, 3, 1, 0.1]) == "setosa" + assert dTL([6, 5, 3, 1.5]) == "versicolor" + assert dTL([7.5, 4, 6, 2]) == "virginica" def test_neural_network_learner(): @@ -75,14 +75,11 @@ def test_neural_network_learner(): iris.classes_to_numbers(classes) nNL = NeuralNetLearner(iris, [5], 0.15, 75) - pred1 = nNL([5,3,1,0.1]) - pred2 = nNL([6,3,3,1.5]) - pred3 = nNL([7.5,4,6,2]) + tests = [([5, 3, 1, 0.1], 0), + ([6, 3, 3, 1.5], 1), + ([7.5, 4, 6, 2], 2)] - # NeuralNetLearner might be wrong. If it is, check if prediction is in range. - assert pred1 == 0 or pred1 in range(len(classes)) - assert pred2 == 1 or pred2 in range(len(classes)) - assert pred3 == 2 or pred3 in range(len(classes)) + assert grade_learner(nNL, tests) >= 2 def test_perceptron(): @@ -92,11 +89,8 @@ def test_perceptron(): classes_number = len(iris.values[iris.target]) perceptron = PerceptronLearner(iris) - pred1 = perceptron([5,3,1,0.1]) - pred2 = perceptron([6,3,4,1]) - pred3 = perceptron([7.5,4,6,2]) - - # PerceptronLearner might be wrong. If it is, check if prediction is in range. - assert pred1 == 0 or pred1 in range(classes_number) - assert pred2 == 1 or pred2 in range(classes_number) - assert pred3 == 2 or pred3 in range(classes_number) + tests = [([5, 3, 1, 0.1], 0), + ([6, 3, 4, 1.1], 1), + ([7.5, 4, 6, 2], 2)] + + assert grade_learner(perceptron, tests) >= 2 pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy