Skip to content

Commit e8247fd

Browse files
committed
add decision tree
1 parent e67ee1c commit e8247fd

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

heuristic/decision_tree.hpp

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
#pragma once
2+
#include <cassert>
3+
#include <sstream>
4+
#include <string>
5+
#include <vector>
6+
7+
/// Binary decision tree classifier over boolean feature vectors.
///
/// Training samples are packed into one 64-bit word each: bit `j` (0 <= j < dim)
/// holds feature `j`, and bit `D - 1` holds the boolean label. This is why
/// `fit` requires `dim <= 63`. Splits are chosen greedily by minimum weighted
/// Gini impurity. The tree is stored as a flat vector of nodes; node 0 is the
/// root of the most recently fitted (or decoded) tree.
struct decision_tree {
    decision_tree() = default;

    using ull = unsigned long long;

    /// Width of a packed sample word; bit `D - 1` carries the label.
    static const int D = 64;

    /// One node of the tree. Children are indices into `decision_tree::nodes`.
    struct Node {
        bool is_leaf = false;  ///< true when the node makes a prediction
        bool mixed = false;    ///< leaf that still contained both labels (no usable split)
        bool label = false;    ///< predicted label (meaningful for leaves)

        int split_by = -1;  ///< feature index tested by an internal node
        int ch1 = -1;       ///< child followed when the tested feature is set
        int ch0 = -1;       ///< child followed when the tested feature is clear

        /// Serializes the node as "<flags><split_by>/<ch1>/<ch0>" where `flags`
        /// is the single digit `is_leaf * 4 + mixed * 2 + label`.
        std::string encode() const {
            const int flags = is_leaf * 4 + mixed * 2 + label;
            std::string ret(1, char('0' + flags));
            ret += std::to_string(split_by);
            ret += '/';
            ret += std::to_string(ch1);
            ret += '/';
            ret += std::to_string(ch0);
            return ret;
        }

        /// Parses a string produced by `Node::encode`.
        /// Throws `std::out_of_range` if `s` is empty. Integer fields that fail
        /// to parse are not validated.
        static Node decode(const std::string &s) {
            const int flags = s.at(0) - '0';
            int split_by = -1, ch1 = -1, ch0 = -1;
            char sep;
            std::stringstream ss(s.substr(1));
            ss >> split_by >> sep >> ch1 >> sep >> ch0;
            return Node{bool(flags / 4 % 2), bool(flags / 2 % 2), bool(flags % 2),
                        split_by, ch1, ch0};
        }
    };

    /// Flat node storage; `predict` starts at index 0.
    std::vector<Node> nodes;

    /// Serializes the whole tree as the encoded nodes joined by single spaces.
    /// An empty tree encodes to the empty string.
    std::string encode() const {
        std::string ret;
        for (const Node &n : nodes) {
            if (!ret.empty()) ret += ' ';
            ret += n.encode();
        }
        return ret;
    }

    /// Rebuilds a tree from the output of `encode`.
    /// Whitespace-only or empty input yields an empty tree; leading/trailing
    /// whitespace is ignored.
    static decision_tree decode(const std::string &s) {
        std::stringstream ss(s);
        decision_tree ret;
        std::string token;
        // Loop on the extraction itself: testing eof() before reading would
        // process one bogus empty token for empty input or trailing whitespace.
        while (ss >> token) ret.nodes.push_back(Node::decode(token));
        return ret;
    }

    /// Gini impurity of a two-class distribution with positive rate `p`.
    static double GiniImpurity(double p) { return 2 * p * (1 - p); }

    /// Recursively grows the subtree for the samples in `Xy` and returns the
    /// index of its root in `nodes`.
    ///
    /// @param Xy        packed samples of this subtree (consumed/reordered)
    /// @param pos_xsum  per-feature count of positive samples with the bit set (modified)
    /// @param neg_xsum  per-feature count of negative samples with the bit set (modified)
    /// @param npos      number of positive samples in `Xy`
    /// @param nneg      number of negative samples in `Xy`
    /// @param dim       number of feature bits considered for splitting
    int rec_fit(std::vector<ull> &Xy, std::vector<int> &pos_xsum, std::vector<int> &neg_xsum,
                int npos, int nneg, int dim) {
        const int node_id = int(nodes.size());
        nodes.push_back(Node());

        // Pure (or empty) sample set: emit a leaf right away.
        if (npos == 0 || nneg == 0) {
            nodes.back().is_leaf = true;
            nodes.back().label = npos > 0;
            return node_id;
        }

        // Pick the feature whose split minimizes the weighted Gini impurity.
        // Features on which every sample agrees cannot split and are skipped.
        double best_impurity = 1e30;
        int best_feature = -1;
        for (int c = 0; c < dim; ++c) {
            const double ch1_rate = 1.0 * (pos_xsum[c] + neg_xsum[c]) / (npos + nneg);
            if (0.0 < ch1_rate && ch1_rate < 1.0) {
                const double impurity =
                    ch1_rate * GiniImpurity(1.0 * pos_xsum[c] / (pos_xsum[c] + neg_xsum[c])) +
                    (1 - ch1_rate) * GiniImpurity(1.0 * (npos - pos_xsum[c]) /
                                                  (npos + nneg - pos_xsum[c] - neg_xsum[c]));
                if (impurity < best_impurity) {
                    best_impurity = impurity;
                    best_feature = c;
                }
            }
        }

        // No feature separates the remaining samples: majority-vote leaf.
        if (best_feature < 0) {
            Node &leaf = nodes.at(node_id);
            leaf.is_leaf = true;
            leaf.mixed = true;
            leaf.label = npos > nneg;
            return node_id;
        }

        const int nb_ch1 = pos_xsum.at(best_feature) + neg_xsum.at(best_feature);
        const int nb_ch0 = npos + nneg - nb_ch1;

        // Only moved samples pay the O(dim) counter update, so always move the
        // smaller side out of `Xy`. `keep_ones` == true keeps the bit-set
        // samples in place and moves the bit-clear ones.
        const bool keep_ones = nb_ch1 >= nb_ch0;

        std::vector<ull> Xy1;
        std::vector<int> pos_xsum1(dim), neg_xsum1(dim);
        int npos1 = 0, nneg1 = 0;

        for (int i = 0; i < int(Xy.size());) {
            const ull row = Xy[i];
            const bool bit = (row >> best_feature) & 1;
            if (bit == keep_ones) {
                ++i;
                continue;
            }
            Xy1.push_back(row);
            const bool y = (row >> (D - 1)) & 1;
            --(y ? npos : nneg);
            ++(y ? npos1 : nneg1);
            for (int j = 0; j < dim; ++j) {
                if ((row >> j) & 1) {
                    --(y ? pos_xsum : neg_xsum)[j];
                    ++(y ? pos_xsum1 : neg_xsum1)[j];
                }
            }
            // Unordered erase: overwrite with the last element and re-examine slot i.
            std::swap(Xy[i], Xy.back());
            Xy.pop_back();
        }

        // Recursion may reallocate `nodes`, so take the child indices into
        // locals first and only then touch nodes.at(node_id).
        const int kept_child = rec_fit(Xy, pos_xsum, neg_xsum, npos, nneg, dim);
        const int moved_child = rec_fit(Xy1, pos_xsum1, neg_xsum1, npos1, nneg1, dim);

        Node &self = nodes.at(node_id);
        self.split_by = best_feature;
        self.ch1 = keep_ones ? kept_child : moved_child;
        self.ch0 = keep_ones ? moved_child : kept_child;
        return node_id;
    }

    /// Trains the tree on samples `X` with labels `y`, replacing any
    /// previously fitted tree.
    ///
    /// @param X    one row per sample; element `j` is interpreted as a boolean feature
    /// @param y    one label per sample, interpreted as boolean
    /// @param dim  number of features to use; 0 means "take it from the first row".
    ///             Must be at most 63, and every row needs at least `dim` elements.
    template <class T1, class T2>
    void fit(const std::vector<std::vector<T1>> &X, const std::vector<T2> &y, int dim = 0) {
        if (dim == 0 && !X.empty()) dim = int(X.front().size());
        assert(dim <= 63);
        assert(X.size() == y.size());

        // Start from scratch so the new root lands at index 0, where predict() looks.
        nodes.clear();

        std::vector<ull> Xy(X.size());
        std::vector<int> pos_xsum(dim), neg_xsum(dim);
        int npos = 0, nneg = 0;
        for (int i = 0; i < int(X.size()); ++i) {
            const bool yi = y[i];
            ++(yi ? npos : nneg);
            ull packed = ull(yi) << (D - 1);
            for (int j = 0; j < dim; ++j) {
                if (X[i][j]) {
                    packed |= ull(1) << j;
                    ++(yi ? pos_xsum : neg_xsum)[j];
                }
            }
            Xy[i] = packed;
        }
        rec_fit(Xy, pos_xsum, neg_xsum, npos, nneg, dim);
    }

    /// Classifies the feature vector `x` by walking from the root to a leaf.
    /// Throws `std::out_of_range` if the tree is empty or if `x` is shorter
    /// than a feature index used by the tree.
    template <class T> bool predict(const std::vector<T> &x) const {
        int now = 0;
        while (!nodes.at(now).is_leaf) {
            const Node &cur = nodes.at(now);
            now = x.at(cur.split_by) ? cur.ch1 : cur.ch0;
        }
        return nodes.at(now).label;
    }
};

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy