@@ -257,11 +257,8 @@ def __init__(self, attr, attrname=None, branches=None):
257
257
258
258
def predict (self , example ):
259
259
"Given an example, use the tree to classify the example."
260
- child = self .branches [example [self .attr ]]
261
- if isinstance (child , DecisionTree ):
262
- return child .predict (example )
263
- else :
264
- return child
260
+ attrvalue = example [self .attr ]
261
+ return decision_tree_predict (self .branches [attrvalue ], example )
265
262
266
263
def add (self , val , subtree ):
267
264
"Add a branch. If self.attr = val, go to the given subtree."
@@ -280,19 +277,18 @@ def display(self, indent=0):
280
277
def __repr__ (self ):
281
278
return ('DecisionTree(%r, %r, %r)'
282
279
% (self .attr , self .attrname , self .branches ))
283
-
284
- Yes , No = True , False
280
+
281
+ def decision_tree_predict (tree , example ):
282
+ "Treat a non-DecisionTree as a leaf."
283
+ return tree .predict (example ) if isinstance (tree , DecisionTree ) else tree
285
284
286
285
#______________________________________________________________________________
287
286
288
287
class DecisionTreeLearner (Learner ):
289
288
"[Fig. 18.5]"
290
289
291
290
def predict (self , example ):
292
- if isinstance (self .dt , DecisionTree ):
293
- return self .dt .predict (example )
294
- else :
295
- return self .dt
291
+ return decision_tree_predict (self .dt , example )
296
292
297
293
def train (self , dataset ):
298
294
self .dataset = dataset
0 commit comments