From 780609b88310463ade7883aabcc4fdf518952e58 Mon Sep 17 00:00:00 2001
From: Lucas Moura
Date: Wed, 8 Mar 2017 15:47:30 -0300
Subject: [PATCH 1/4] Add new test to test_learning.py

---
 tests/test_learning.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/tests/test_learning.py b/tests/test_learning.py
index ec2cf18bd..6010907d2 100644
--- a/tests/test_learning.py
+++ b/tests/test_learning.py
@@ -1,11 +1,10 @@
 from learning import parse_csv, weighted_mode, weighted_replicate, DataSet, \
     PluralityLearner, NaiveBayesLearner, NearestNeighborLearner, \
     NeuralNetLearner, PerceptronLearner, DecisionTreeLearner, \
-    euclidean_distance, grade_learner, err_ratio
+    euclidean_distance, grade_learner, err_ratio, train_and_test
 from utils import DataFile
 
-
 def test_euclidean():
     distance = euclidean_distance([1, 2], [3, 4])
     assert round(distance, 2) == 2.83
 
@@ -124,3 +123,14 @@ def test_perceptron():
 
     assert grade_learner(perceptron, tests) > 1/2
     assert err_ratio(perceptron, iris) < 0.4
+
+
+def test_train_and_test():
+    dataset = DataSet(name="iris")
+    start = 50
+    end = 100
+
+    train_set, validation_set = train_and_test(dataset, start, end)
+
+    assert len(train_set) == 100
+    assert len(validation_set) == 50

From eeed91fed770d555d730104b1e30e78bbafa7729 Mon Sep 17 00:00:00 2001
From: Lucas Moura
Date: Wed, 8 Mar 2017 17:06:28 -0300
Subject: [PATCH 2/4] Update print_table on utils.py

Update how table values are formatted, since tuples could be received as well
---
 utils.py | 14 +++++++++++++-
 1 file changed, 13 insertions(+), 1 deletion(-)

diff --git a/utils.py b/utils.py
index 5afa43760..b703a8a55 100644
--- a/utils.py
+++ b/utils.py
@@ -311,6 +311,18 @@ def issequence(x):
     return isinstance(x, collections.abc.Sequence)
 
 
+def format_table_value(value, numfmt):
+    if isnumber(value):
+        value = numfmt.format(value)
+    elif type(value) is tuple:
+        tmp = []
+        for v in value:
+            tmp.append(format_table_value(v, numfmt))
+        value = tuple(tmp)
+
+    return value
+
+
 def print_table(table, header=None, sep='   ', numfmt='{}'):
     """Print a list of lists as a table, so that columns line up nicely.
     header, if specified, will be printed as the first row.
@@ -322,7 +334,7 @@ def print_table(table, header=None, sep='   ', numfmt='{}'):
     if header:
         table.insert(0, header)
 
-    table = [[numfmt.format(x) if isnumber(x) else x for x in row]
+    table = [[format_table_value(x, numfmt) for x in row]
              for row in table]
 
     sizes = list(

From 594a8230eefd760af1987158031de3625863e057 Mon Sep 17 00:00:00 2001
From: Lucas Moura
Date: Wed, 8 Mar 2017 17:09:23 -0300
Subject: [PATCH 3/4] Update cross_validation implementation

---
 learning.py | 64 +++++++++++++++++++++++++++++++----------------------
 1 file changed, 37 insertions(+), 27 deletions(-)

diff --git a/learning.py b/learning.py
index 06a719745..c214abda3 100644
--- a/learning.py
+++ b/learning.py
@@ -293,7 +293,7 @@ def sample(self):
 # ______________________________________________________________________________
 
 
-def PluralityLearner(dataset):
+def PluralityLearner(dataset, size=None):
     """A very dumb algorithm: always pick the result that was most popular
     in the training data. Makes a baseline for comparison."""
     most_popular = mode([e[dataset.target] for e in dataset.examples])
 
@@ -306,14 +306,14 @@ def predict(example):
 # ______________________________________________________________________________
 
 
-def NaiveBayesLearner(dataset, continuous=True):
+def NaiveBayesLearner(dataset, size=None, continuous=True):
     if(continuous):
-        return NaiveBayesContinuous(dataset)
+        return NaiveBayesContinuous(dataset, size)
     else:
-        return NaiveBayesDiscrete(dataset)
+        return NaiveBayesDiscrete(dataset, size)
 
 
-def NaiveBayesDiscrete(dataset):
+def NaiveBayesDiscrete(dataset, size):
     """Just count how many times each value of each input attribute
     occurs, conditional on the target value. Count the different
     target values too."""
@@ -341,7 +341,7 @@ def class_probability(targetval):
     return predict
 
 
-def NaiveBayesContinuous(dataset):
+def NaiveBayesContinuous(dataset, size):
     """Count how many times each target value occurs. Also, find the means
     and deviations of input attribute values for each target value."""
     means, deviations = dataset.find_means_and_deviations()
@@ -426,7 +426,7 @@ def __repr__(self):
 # ______________________________________________________________________________
 
 
-def DecisionTreeLearner(dataset):
+def DecisionTreeLearner(dataset, size=None):
     """[Figure 18.5]"""
 
     target, values = dataset.target, dataset.values
@@ -911,66 +911,76 @@ def train_and_test(dataset, start, end):
     return train, val
 
 
+def partition(dataset, fold, k):
+    num_examples = len(dataset.examples)
+    return train_and_test(dataset, fold * (num_examples / k), (fold + 1) * (num_examples / k))
+
+
 def cross_validation(learner, size, dataset, k=10, trials=1):
     """Do k-fold cross_validate and return their mean.
     That is, keep out 1/k of the examples for testing on each of k runs.
     Shuffle the examples first; if trials>1, average over several shuffles.
     Returns Training error, Validataion error"""
-    if k is None:
-        k = len(dataset.examples)
     if trials > 1:
         trial_errT = 0
         trial_errV = 0
+
         for t in range(trials):
-            errT, errV = cross_validation(learner, size, dataset,
-                                          k=10, trials=1)
+            errT, errV = cross_validation(learner, size, dataset, k)
             trial_errT += errT
             trial_errV += errV
+
        return trial_errT / trials, trial_errV / trials
     else:
         fold_errT = 0
         fold_errV = 0
-        n = len(dataset.examples)
+
         examples = dataset.examples
         for fold in range(k):
             random.shuffle(dataset.examples)
-            train_data, val_data = train_and_test(dataset, fold * (n / k),
-                                                  (fold + 1) * (n / k))
-            dataset.examples = train_data
+            train_data, val_data = partition(dataset, fold, k)
             h = learner(dataset, size)
+
             fold_errT += err_ratio(h, dataset, train_data)
             fold_errV += err_ratio(h, dataset, val_data)
+
+            # Reverting back to original once test is completed
             dataset.examples = examples
+
         return fold_errT / k, fold_errV / k
 
 
+def leave_one_out(learner, dataset, size=None):
+    """Leave one out cross-validation over the dataset."""
+    return cross_validation(learner, size, dataset, k=len(dataset.examples))
+
+
+def converges(err_val):
+    """Check for convergence provided err_val has at least two values"""
+    return len(err_val) >= 2 and isclose(err_val[-2], err_val[-1], rel_tol=1e-6)
+
+
 def cross_validation_wrapper(learner, dataset, k=10, trials=1):
     """[Fig 18.8] Return the optimal value of size having minimum error
-    on validataion set.
+    on validation set.
     err_train: A training error array, indexed by size
     err_val: A validataion error array, indexed by size
     """
-    err_val = []
     err_train = []
+    err_val = []
+
     size = 1
     while True:
         errT, errV = cross_validation(learner, size, dataset, k)
-        # Check for convergence provided err_val is not empty
-        if (err_val and isclose(err_val[-1], errV, rel_tol=1e-6)):
-            best_size = size
-            return learner(dataset, best_size)
-
         err_val.append(errV)
         err_train.append(errT)
-        print(err_val)
-        size += 1
+        if converges(err_val):
+            best_size = size
+            return learner(dataset, best_size)
 
 
-def leave_one_out(learner, dataset, size=None):
-    """Leave one out cross-validation over the dataset."""
-    return cross_validation(learner, size, dataset, k=len(dataset.examples))
+        size += 1
 
 
 def learningcurve(learner, dataset, trials=10, sizes=None):

From 685e2a57d07cf199a437ec320cfcf6cd9187af5b Mon Sep 17 00:00:00 2001
From: Lucas Moura
Date: Wed, 8 Mar 2017 17:09:58 -0300
Subject: [PATCH 4/4] Update compare function in learning.py

---
 learning.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/learning.py b/learning.py
index c214abda3..0c8eeaf9a 100644
--- a/learning.py
+++ b/learning.py
@@ -1105,14 +1105,14 @@ def ContinuousXor(n):
 # ______________________________________________________________________________
 
 
-def compare(algorithms=[PluralityLearner, NaiveBayesLearner,
-                        NearestNeighborLearner, DecisionTreeLearner],
-            datasets=[iris, orings, zoo, restaurant, SyntheticRestaurant(20),
-                      Majority(7, 100), Parity(7, 100), Xor(100)],
-            k=10, trials=1):
+def compare(algorithms=[PluralityLearner, NaiveBayesLearner, NearestNeighborLearner,
+                        DecisionTreeLearner],
+            datasets=[iris, orings, zoo, restaurant, SyntheticRestaurant(20), Majority(7, 100),
+                      Parity(7, 100), Xor(100)],
+            k=10, size=3, trials=1):
     """Compare various learners on various datasets using cross-validation.
     Print results as a table."""
     print_table([[a.__name__.replace('Learner', '')] +
-                 [cross_validation(a, d, k, trials) for d in datasets]
+                 [cross_validation(a, size, d, k, trials) for d in datasets]
                  for a in algorithms],
-                header=[''] + [d.name[0:7] for d in datasets], numfmt='%.2f')
+                header=[''] + [d.name[0:7] for d in datasets], numfmt='{:.2f}')
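
Note on how the patches fit together: after PATCH 3/4, cross_validation returns an (errT, errV) tuple, and compare() (PATCH 4/4) passes those tuples straight into print_table, which is why PATCH 2/4 teaches print_table to format tuples. The standalone sketch below illustrates what format_table_value produces; it is not part of the patch series, isnumber is stubbed here with a simple duck-type check (an assumption, so the snippet runs without utils.py), and the sample values are made up for illustration.

    # Standalone sketch of the tuple-aware formatting introduced in PATCH 2/4.
    def isnumber(x):
        # Stand-in for utils.isnumber: anything convertible to int counts as a number.
        return hasattr(x, '__int__')


    def format_table_value(value, numfmt):
        # Numbers are formatted with numfmt; tuples are formatted element by element.
        if isnumber(value):
            value = numfmt.format(value)
        elif type(value) is tuple:
            tmp = []
            for v in value:
                tmp.append(format_table_value(v, numfmt))
            value = tuple(tmp)

        return value


    print(format_table_value(0.123456, '{:.2f}'))         # -> 0.12
    print(format_table_value((0.123456, 0.7), '{:.2f}'))  # -> ('0.12', '0.70')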