Information
Shannon, the founder of information theory, held that "information is what removes random uncertainty," and information content quantifies how much uncertainty is removed. The lower the probability of an event, the more information its occurrence carries.
The information content of an event $x$ that occurs is

$$h(x) = -\log_2 P(x)$$

The smaller the probability, the larger the information content.
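For example, learning that an event with probability $1/8$ has occurred conveys $-\log_2(1/8) = 3$ bits, while a certain event ($P(x) = 1$) conveys $0$ bits.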
Information entropy is the expected value of the information content over all possible events:

$$H(X) = -\sum_{x \in X} p(x) \log_2 p(x)$$
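For a fair coin, $H = -2 \times 0.5 \log_2 0.5 = 1$ bit; a coin biased to heads with $p = 0.9$ has $H \approx 0.469$ bits, reflecting its lower uncertainty.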
In the Python code below, the last column of each row in data holds the class label; the calcShannonEntropy function computes the information entropy of that label column.
```python
from math import log

# data: list of rows [f1, f2, ..., label]; the last column is the class label
def calcShannonEntropy(data):
    m = len(data)
    dic = {}
    # count how many rows fall into each class
    for v in data:
        label = v[-1]
        if label not in dic:
            dic[label] = 0
        dic[label] += 1
    shannonEntropy = 0.0
    # H = -sum(p * log2(p)) over all class labels
    for label in dic:
        p = dic[label] / float(m)
        shannonEntropy -= p * log(p, 2)
    return shannonEntropy
```
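A quick sanity check on a small made-up dataset (the rows and labels below are hypothetical, purely for illustration):

```python
# hypothetical toy data: two binary features plus a yes/no label
data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'],
        [0, 1, 'no'], [0, 1, 'no']]
print(calcShannonEntropy(data))  # ~0.971 bits for a 2-vs-3 class split
```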
splitData selects the vectors whose feature at index axis equals value, and removes the axis dimension from each selected vector.
```python
# keep only the rows whose feature at index `axis` equals `value`,
# and drop that feature column from each kept row
def splitData(data, axis, value):
    ret = []
    for v in data:
        if v[axis] == value:
            t = v[:axis]
            t.extend(v[axis+1:])
            ret.append(t)
    return ret
```
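For example, splitting the same hypothetical dataset on feature 0 with value 1 keeps the three matching rows and drops that column:

```python
print(splitData(data, 0, 1))
# [[1, 'yes'], [1, 'yes'], [0, 'no']] -- the axis-0 column is removed
```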
In the chooseBestFeatureToSplit function, the post-split information entropy is computed for each feature in turn, and the feature giving the lowest entropy is chosen.
```python
# try splitting on each feature; pick the one whose split yields the
# lowest weighted entropy (i.e. the highest information gain)
def chooseBestFeatureToSplit(data):
    n = len(data[0]) - 1          # number of features (last column is the label)
    bestFeature = 0
    bestShannonEntropy = calcShannonEntropy(data)
    for i in range(n):
        featureList = [v[i] for v in data]
        values = set(featureList)
        shannonEntropy = 0.0
        # weighted average entropy over the subsets produced by this feature
        for value in values:
            subdata = splitData(data, i, value)
            p = len(subdata) / float(len(data))
            shannonEntropy += p * calcShannonEntropy(subdata)
        if shannonEntropy < bestShannonEntropy:
            bestFeature = i
            bestShannonEntropy = shannonEntropy
    return bestFeature
```
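Choosing the feature with the lowest weighted post-split entropy is the classic ID3 information-gain criterion: since the entropy $H(D)$ of the full dataset is fixed, minimizing the weighted entropy maximizes the gain

$$\mathrm{Gain}(D, X_i) = H(D) - \sum_{v \in \mathrm{values}(X_i)} \frac{|D_v|}{|D|}\, H(D_v)$$

where $D_v$ is the subset of $D$ whose feature $X_i$ takes value $v$.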
createTree and classify

createTree builds the tree recursively, falling back to a majority vote (majorityCount) when no features are left; classify walks a test vector down the finished tree.
```python
import operator

# return the class label that occurs most often in classList
def majorityCount(classList):
    dic = {}
    for v in classList:
        if v not in dic:
            dic[v] = 0
        dic[v] += 1
    sortedDic = sorted(dic.items(), key=operator.itemgetter(1), reverse=True)
    return sortedDic[0][0]

def createTree(data, labels):
    classList = [a[-1] for a in data]
    # all rows share one class: return that class as a leaf
    if classList.count(classList[0]) == len(classList):
        return classList[0]
    # no features left (only the label column remains): majority vote
    if len(data[0]) == 1:
        return majorityCount(classList)
    bestFeature = chooseBestFeatureToSplit(data)
    bestFeatureLabel = labels[bestFeature]
    tree = {bestFeatureLabel: {}}
    del(labels[bestFeature])
    features = set([a[bestFeature] for a in data])
    # recurse on each subset produced by the chosen feature
    for feature in features:
        subLabels = labels[:]
        tree[bestFeatureLabel][feature] = createTree(
            splitData(data, bestFeature, feature), subLabels)
    return tree

def classify(inputTree, featureLabels, testVec):
    firstStr = list(inputTree.keys())[0]
    subTree = inputTree[firstStr]
    featIndex = featureLabels.index(firstStr)
    for key in subTree.keys():
        if testVec[featIndex] == key:
            if isinstance(subTree[key], dict):
                # internal node: keep walking down the tree
                classLabel = classify(subTree[key], featureLabels, testVec)
            else:
                # leaf node: the stored value is the class label
                classLabel = subTree[key]
    return classLabel
```
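Putting it all together on the same hypothetical dataset (the feature names are made up; note that createTree mutates the labels list via del, so pass it a copy):

```python
data = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'],
        [0, 1, 'no'], [0, 1, 'no']]
labels = ['f1', 'f2']

tree = createTree(data, labels[:])  # pass a copy; createTree deletes used labels
print(tree)  # {'f1': {0: 'no', 1: {'f2': {0: 'no', 1: 'yes'}}}}
print(classify(tree, labels, [1, 1]))  # 'yes'
```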