Skip to content

Commit 488aebb

Browse files
committed
Move distance function into DataSet where users can change it.
1 parent 426b06f commit 488aebb

File tree

1 file changed

+22
-20
lines changed

1 file changed

+22
-20
lines changed

learning.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,20 @@
55

66
#______________________________________________________________________________
77

8+
def rms_error(predictions, targets):
    """Return the root-mean-square error between predictions and targets.

    This is the square root of the value ms_error reports for the same
    two sequences.
    """
    mean_squared = ms_error(predictions, targets)
    return math.sqrt(mean_squared)
10+
11+
def ms_error(predictions, targets):
    """Return the mean of the squared differences, paired by position.

    Pairs are produced by zip, so surplus items in the longer sequence
    are ignored.  The averaging is delegated to `mean`, a helper that is
    defined outside this block.
    """
    squared_diffs = [(guess - actual) ** 2
                     for guess, actual in zip(predictions, targets)]
    return mean(squared_diffs)
13+
14+
def mean_error(predictions, targets):
    """Return the mean of the absolute differences, paired by position.

    Pairs are produced by zip, so surplus items in the longer sequence
    are ignored.  The averaging is delegated to `mean`, a helper that is
    defined outside this block.
    """
    absolute_diffs = [abs(guess - actual)
                      for guess, actual in zip(predictions, targets)]
    return mean(absolute_diffs)
16+
17+
def mean_boolean_error(predictions, targets):
    """Return the mean of the per-position "values differ" flags.

    Each zipped pair contributes True when its two items compare unequal
    and False otherwise; only != is applied to the items, so no
    arithmetic on them is required.  The averaging is delegated to
    `mean`, a helper that is defined outside this block.
    """
    mismatches = [(guess != actual)
                  for guess, actual in zip(predictions, targets)]
    return mean(mismatches)
19+
20+
#______________________________________________________________________________
21+
822
class DataSet:
923
"""A data set for a machine learning problem. It has the following fields:
1024
@@ -19,21 +33,25 @@ class DataSet:
1933
values for the corresponding attribute. If initially None,
2034
it is computed from the known examples by self.setproblem.
2135
If not None, an erroneous value raises ValueError.
36+
d.distance A function from a pair of examples to a nonnegative number.
37+
Should be symmetric, etc. Defaults to mean_boolean_error
38+
since that can handle any field types.
2239
d.name Name of the data set (for output display only).
2340
d.source URL or other source where the data came from.
2441
2542
Normally, you call the constructor and you're done; then you just
2643
access fields like d.examples and d.target and d.inputs."""
2744

2845
def __init__(self, examples=None, attrs=None, attrnames=None, target=-1,
29-
inputs=None, values=None, name='', source='', exclude=()):
46+
inputs=None, values=None, distance=mean_boolean_error,
47+
name='', source='', exclude=()):
3048
"""Accepts any of DataSet's fields. Examples can also be a
3149
string or file from which to parse examples using parse_csv.
3250
Optional parameter: exclude, as documented in .setproblem().
3351
>>> DataSet(examples='1, 2, 3')
3452
<DataSet(): 1 examples, 3 attributes>
3553
"""
36-
update(self, name=name, source=source, values=values)
54+
update(self, name=name, source=source, values=values, distance=distance)
3755
# Initialize .examples from string or list or data directory
3856
if isinstance(examples, str):
3957
self.examples = parse_csv(examples)
@@ -121,19 +139,6 @@ def parse_csv(input, delim=','):
121139
lines = [line for line in input.splitlines() if line.strip() is not '']
122140
return [map(num_or_str, line.split(delim)) for line in lines]
123141

124-
def rms_error(predictions, targets):
125-
return math.sqrt(ms_error(predictions, targets))
126-
127-
def ms_error(predictions, targets):
128-
return mean([(p - t)**2 for p, t in zip(predictions, targets)])
129-
130-
def mean_error(predictions, targets):
131-
return mean([abs(p - t) for p, t in zip(predictions, targets)])
132-
133-
def mean_boolean_error(predictions, targets):
134-
return mean([(p != t) for p, t in zip(predictions, targets)])
135-
136-
137142
#______________________________________________________________________________
138143

139144
class Learner:
@@ -223,24 +228,21 @@ def predict(self, example):
223228
With k>1, find k closest, and have them vote for the best."""
224229
if self.k == 1:
225230
neighbor = argmin(self.dataset.examples,
226-
lambda e: self.distance(e, example))
231+
lambda e: self.dataset.distance(e, example))
227232
return neighbor[self.dataset.target]
228233
else:
229234
## Maintain a sorted list of (distance, example) pairs.
230235
## For very large k, a PriorityQueue would be better
231236
best = []
232237
for e in self.dataset.examples:
233-
d = self.distance(e, example)
238+
d = self.dataset.distance(e, example)
234239
if len(best) < self.k:
235240
best.append((d, e))
236241
elif d < best[-1][0]:
237242
best[-1] = (d, e)
238243
best.sort()
239244
return mode([e[self.dataset.target] for (d, e) in best])
240245

241-
def distance(self, e1, e2):
242-
return mean_boolean_error(e1, e2)
243-
244246
#______________________________________________________________________________
245247

246248
class DecisionTree:

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

pFad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy