5
5
6
6
#______________________________________________________________________________
7
7
8
def rms_error(predictions, targets):
    """Return the root-mean-square error between predictions and targets.

    This is the square root of ms_error(predictions, targets), so the two
    sequences are paired positionally exactly as ms_error pairs them.
    """
    return math.sqrt(ms_error(predictions, targets))
10
+
11
def ms_error(predictions, targets):
    """Return the mean of the squared differences (p - t) ** 2.

    Elements are paired positionally with zip, so when one sequence is longer
    than the other its extra elements are ignored.  The averaging itself is
    delegated to the mean helper.
    """
    return mean([(p - t) ** 2 for p, t in zip(predictions, targets)])
13
+
14
def mean_error(predictions, targets):
    """Return the mean of the absolute differences abs(p - t).

    Elements are paired positionally with zip, so when one sequence is longer
    than the other its extra elements are ignored.  The averaging itself is
    delegated to the mean helper.
    """
    return mean([abs(p - t) for p, t in zip(predictions, targets)])
16
+
17
def mean_boolean_error(predictions, targets):
    """Return the mean of the mismatch flags (p != t) over paired elements.

    Only the != comparison is applied to the elements, so no arithmetic on
    the values themselves is required.  Elements are paired positionally with
    zip, so when one sequence is longer than the other its extra elements are
    ignored.  The averaging itself is delegated to the mean helper.
    """
    return mean([(p != t) for p, t in zip(predictions, targets)])
19
+
20
+ #______________________________________________________________________________
21
+
8
22
class DataSet :
9
23
"""A data set for a machine learning problem. It has the following fields:
10
24
@@ -19,21 +33,25 @@ class DataSet:
19
33
values for the corresponding attribute. If initially None,
20
34
it is computed from the known examples by self.setproblem.
21
35
If not None, an erroneous value raises ValueError.
36
+ d.distance A function from a pair of examples to a nonnegative number.
37
+ Should be symmetric, etc. Defaults to mean_boolean_error
38
+ since that can handle any field types.
22
39
d.name Name of the data set (for output display only).
23
40
d.source URL or other source where the data came from.
24
41
25
42
Normally, you call the constructor and you're done; then you just
26
43
access fields like d.examples and d.target and d.inputs."""
27
44
28
45
def __init__ (self , examples = None , attrs = None , attrnames = None , target = - 1 ,
29
- inputs = None , values = None , name = '' , source = '' , exclude = ()):
46
+ inputs = None , values = None , distance = mean_boolean_error ,
47
+ name = '' , source = '' , exclude = ()):
30
48
"""Accepts any of DataSet's fields. Examples can also be a
31
49
string or file from which to parse examples using parse_csv.
32
50
Optional parameter: exclude, as documented in .setproblem().
33
51
>>> DataSet(examples='1, 2, 3')
34
52
<DataSet(): 1 examples, 3 attributes>
35
53
"""
36
- update (self , name = name , source = source , values = values )
54
+ update (self , name = name , source = source , values = values , distance = distance )
37
55
# Initialize .examples from string or list or data directory
38
56
if isinstance (examples , str ):
39
57
self .examples = parse_csv (examples )
@@ -121,19 +139,6 @@ def parse_csv(input, delim=','):
121
139
lines = [line for line in input .splitlines () if line .strip () is not '' ]
122
140
return [map (num_or_str , line .split (delim )) for line in lines ]
123
141
124
def rms_error(predictions, targets):
    """Return the root-mean-square error between predictions and targets.

    This is the square root of ms_error(predictions, targets), so the two
    sequences are paired positionally exactly as ms_error pairs them.
    """
    return math.sqrt(ms_error(predictions, targets))
126
-
127
def ms_error(predictions, targets):
    """Return the mean of the squared differences (p - t) ** 2.

    Elements are paired positionally with zip, so when one sequence is longer
    than the other its extra elements are ignored.  The averaging itself is
    delegated to the mean helper.
    """
    return mean([(p - t) ** 2 for p, t in zip(predictions, targets)])
129
-
130
def mean_error(predictions, targets):
    """Return the mean of the absolute differences abs(p - t).

    Elements are paired positionally with zip, so when one sequence is longer
    than the other its extra elements are ignored.  The averaging itself is
    delegated to the mean helper.
    """
    return mean([abs(p - t) for p, t in zip(predictions, targets)])
132
-
133
def mean_boolean_error(predictions, targets):
    """Return the mean of the mismatch flags (p != t) over paired elements.

    Only the != comparison is applied to the elements, so no arithmetic on
    the values themselves is required.  Elements are paired positionally with
    zip, so when one sequence is longer than the other its extra elements are
    ignored.  The averaging itself is delegated to the mean helper.
    """
    return mean([(p != t) for p, t in zip(predictions, targets)])
135
-
136
-
137
142
#______________________________________________________________________________
138
143
139
144
class Learner :
@@ -223,24 +228,21 @@ def predict(self, example):
223
228
With k>1, find k closest, and have them vote for the best."""
224
229
if self .k == 1 :
225
230
neighbor = argmin (self .dataset .examples ,
226
- lambda e : self .distance (e , example ))
231
+ lambda e : self .dataset . distance (e , example ))
227
232
return neighbor [self .dataset .target ]
228
233
else :
229
234
## Maintain a sorted list of (distance, example) pairs.
230
235
## For very large k, a PriorityQueue would be better
231
236
best = []
232
237
for e in self .dataset .examples :
233
- d = self .distance (e , example )
238
+ d = self .dataset . distance (e , example )
234
239
if len (best ) < self .k :
235
240
best .append ((d , e ))
236
241
elif d < best [- 1 ][0 ]:
237
242
best [- 1 ] = (d , e )
238
243
best .sort ()
239
244
return mode ([e [self .dataset .target ] for (d , e ) in best ])
240
245
241
- def distance (self , e1 , e2 ):
242
- return mean_boolean_error (e1 , e2 )
243
-
244
246
#______________________________________________________________________________
245
247
246
248
class DecisionTree :
0 commit comments