决策树算法Matlab实现(train+test)

news/2024/6/17 17:22:03 标签: 机器学习, matlab

决策树是一种特别简单的机器学习分类算法。决策树想法来源于人类的决策过程。举个最简单的例子,人类发现下雨的时候,往往会有刮东风,然后天色变暗。对应于决策树模型,预测天气模型中的刮东风和天色变暗就是我们收集的特征,是否下雨就是类别标签。构建的决策树如下图所示 
这里写图片描述 
决策树模型构建过程为,在特征集合中无放回的依次递归抽选特征作为决策树的节点——当前节点信息增益或者增益率最大,当前节点的值作为当前节点分支出来的有向边(实际上主要选择的是这些边,这个由信息增益的计算公式就可以得到)。对于这个进行直观解释 
来说一个极端情况,如果有一个特征下,特征取不同值的时候,对应的类别标签都是纯的,决策者肯定会选择这个特征,作为鉴别未知数据的判别准则。由下面的计算信息增益的公式可以发现这时候对应的信息增益是最大的。 
g(D,A)=H(D)-H(D|A) 
g(D,A):表示特征A对训练数据集D的信息增益 
H(D):表示数据集合D的经验熵 
H(D|A):表示特征A给定条件下数据集合D的条件熵。 
反之,当某个特征它的各个取值下对应的类别标签均匀分布的时候H(D|A)最大,又对于所有的特征H(D)是都一样的。因此,这时候的g(D,A)最小。 
总之一句话,我们要挑选的特征是:当前特征下各个取值包含的分类信息最明确。 
下面我们来看一个MATLAB编写的决策树算法,帮助理解 
树终止条件为 
1、特征数为空 
2、树为纯的 
3、信息增益或增益率小于阀值

一、模型训练部分 
训练模型主函数:

function decisionTreeModel=decisionTree(data,label,propertyName,delta)

global Node;

Node=struct('level',-1,'fatherNodeName',[],'EdgeProperty',[],'NodeName',[]);
BuildTree(-1,'root','Stem',data,label,propertyName,delta);
Node(1)=[];
model.Node=Node;
decisionTreeModel=model;
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9

递归构建决策树部分

matlab has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;">function BuildTree(fatherlevel,fatherNodeName,edge,data,label,propertyName,delta)

global Node;
sonNode=struct('level',0,'fatherNodeName',[],'EdgeProperty',[],'NodeName',[]);
sonNode.level=fatherlevel+1;
sonNode.fatherNodeName=fatherNodeName;
sonNode.EdgeProperty=edge;
if length(unique(label))==1
    sonNode.NodeName=label(1);
    Node=[Node sonNode];
    return;
end
if length(propertyName)<1
    labelSet=unique(label);
    k=length(labelSet);
    labelNum=zeros(k,1);
    for i=1:k
        labelNum(i)=length(find(label==labelSet(i)));
    end
    [~,labelIndex]=max(labelNum);
    sonNode.NodeName=labelSet(labelIndex);
    Node=[Node sonNode];
    return;
end
[sonIndex,BuildNode]=CalcuteNode(data,label,delta);
if BuildNode
    dataRowIndex=setdiff(1:length(propertyName),sonIndex);
    sonNode.NodeName=propertyName{sonIndex};
    Node=[Node sonNode];
    propertyName(sonIndex)=[];
    sonData=data(:,sonIndex);
    sonEdge=unique(sonData);

    for i=1:length(sonEdge)
        edgeDataIndex=find(sonData==sonEdge(i));
        BuildTree(sonNode.level,sonNode.NodeName,sonEdge(i),data(edgeDataIndex,dataRowIndex),label(edgeDataIndex,:),propertyName,delta);
    end
else
    labelSet=unique(label);
    k=length(labelSet);
    labelNum=zeros(k,1);
    for i=1:k
        labelNum(i)=length(find(label==labelSet(i)));
    end
    [~,labelIndex]=max(labelNum);
    sonNode.NodeName=labelSet(labelIndex);
    Node=[Node sonNode];
    return;
end
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49

计算决策树下一个节点特征

matlab has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;">function [NodeIndex,BuildNode]=CalcuteNode(data,label,delta)

LargeEntropy=CEntropy(label);
[m,n]=size(data);
EntropyGain=LargeEntropy*ones(1,n);
BuildNode=true;
for i=1:n
    pData=data(:,i);
    itemList=unique(pData);
    for j=1:length(itemList)
        itemIndex=find(pData==itemList(j));
        EntropyGain(i)=EntropyGain(i)-length(itemIndex)/m*CEntropy(label(itemIndex));
    end
    % 此处运行则为增益率,注释掉则为增益
    % EntropyGain(i)=EntropyGain(i)/CEntropy(pData); 
end
[maxGainEntropy,NodeIndex]=max(EntropyGain);
if maxGainEntropy<delta
    BuildNode=false;
end
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20

计算熵

matlab has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;">function result=CEntropy(propertyList)

result=0;
totalLength=length(propertyList);
itemList=unique(propertyList);
pNum=length(itemList);
for i=1:pNum
    itemLength=length(find(propertyList==itemList(i)));
    pItem=itemLength/totalLength;
    result=result-pItem*log2(pItem);
end
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

二、模型预测 
下面这个函数是根据训练好的决策树模型,输入测试样本集合和特征名,对每个测试样本预测输出结果。

function label=decisionTreeTest(decisionTreeModel,sampleSet,propertyName)

lengthSample=size(sampleSet,1);
label=zeros(lengthSample,1);
for sampleIndex=1:lengthSample
    sample=sampleSet(sampleIndex,:);
    Nodes=decisionTreeModel.Node;
    rootNode=Nodes(1);
    head=rootNode.NodeName;
    index=GetFeatureNum(propertyName,head);
    edge=sample(index);
    k=1;
    level=1;
    while k<length(Nodes)
        k=k+1;
        if Nodes(k).level==level
            if strcmp(Nodes(k).fatherNodeName,head)
                if Nodes(k).EdgeProperty==edge
                    if Nodes(k).NodeName<10
                        label(sampleIndex)=Nodes(k).NodeName;
                        break;
                    else
                        head=Nodes(k).NodeName;
                        index=GetFeatureNum(propertyName,head);
                        edge=sample(index);
                        level=level+1;
                    end
                end
            end
        end
    end
end
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

由于训练好的决策树模型里面保存的是节点名,所以在预测的时候需要将节点名对应的特征得到。下面这个函数是为了方便得到特征维数序号。

matlab has-numbering" style="display: block; padding: 0px; background: transparent; color: inherit; box-sizing: border-box; font-family: "Source Code Pro", monospace;font-size:undefined; white-space: pre; border-radius: 0px; word-wrap: normal;">function result=GetFeatureNum(propertyName,str)
result=0;
for i=1:length(propertyName)
    if strcmp(propertyName{i},str)==1
        result=i;
        break;
    end
end
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

三、决策树实验 
这是很多书本上都有的一个例子,可以看出预测结果准确率100%。

clear;clc;

% OutlookType=struct('Sunny',1,'Rainy',2,'Overcast',3);
% TemperatureType=struct('hot',1,'warm',2,'cool',3);
% HumidityType=struct('high',1,'norm',2);
% WindyType={'True',1,'False',0};
% PlayGolf={'Yes',1,'No',0};
% data=struct('Outlook',[],'Temperature',[],'Humidity',[],'Windy',[],'PlayGolf',[]);

Outlook=[1,1,3,2,2,2,3,1,1,2,1,3,3,2]';
Temperature=[1,1,1,2,3,3,3,2,3,3,2,2,1,2]';
Humidity=[1,1,1,1,2,2,2,1,2,2,2,1,2,1]';
Windy=[0,1,0,0,0,1,1,0,0,0,1,1,0,1]';

data=[Outlook Temperature Humidity Windy];
PlayGolf=[0,0,1,1,1,0,1,0,1,1,1,1,1,0]';
propertyName={'Outlook','Temperature','Humidity','Windy'};
delta=0.1;
decisionTreeModel=decisionTree(data,PlayGolf,propertyName,delta);

label=decisionTreeTest(decisionTreeModel,data,propertyName);


http://www.niftyadmin.cn/n/889513.html

相关文章

标签页 插件 html,基于Bootstrap以多标签页加载页面的Tabs插件

bTabs是一款以多标签页加载页面的jquery Tabs插件。bTabs可以方便的和bootstrap2和bootstrap3集成使用&#xff0c;通过简单配置即可让页面变为Ext、EasyUI之类多标签页模式&#xff0c;丰富业务展示模式。bTabs插件的特点有&#xff1a;支持Bootstrap2、3的UI环境。以多标签页…

机器学习之逻辑回归算法

下面是转载的内容&#xff0c;主要是介绍逻辑回归的理论知识&#xff0c;先总结一下自己看完的心得 简单来说线性回归就是直接将特征值和其对应的概率进行相乘得到一个结果&#xff0c;逻辑回归则是这样的结果上加上一个逻辑函数 这里选用的就是Sigmoid函数&#xff0c;在坐标尺…

在VSCode中成功安装Go相关插件问题:tools failed to install.

一、介绍 目的&#xff1a;本文将主要介绍在windows使用VSCode配置Go语言环境 软件&#xff1a;VSCode 二、安装出现的问题 完整信息如下 Installing 8 tools at D:\GoPath\bingo-outlinego-symbolsgurugorenamedlvgodefgoreturnsgolintInstalling golang.org/x/tools/cmd/guru…

html div添加关闭按钮,大神你好,请问怎么在以下代码的div中添加一个关闭按钮?...

该楼层疑似违规已被系统折叠 隐藏此楼查看此楼var hidefunction(){var divsdocument.getElementsByTagName("div");for (var i0;i{divs[i].style.display"none";}}window.οnlοadfunction(){hide();var adocument.getElementsByTagName("td");f…

core text html,用Core Text计算一段文本绘制在屏幕上后的高度

Core Text提供了一系列方便的函数&#xff0c;可以很容易的把文本绘制在屏幕上&#xff0c;对于一个Frame来说&#xff0c;一般并不需要担心文本的排列问题&#xff0c;这些Core Text的函数都可以直接搞定&#xff0c;只要给他一个大小合适的CGRect就可以。但&#xff0c;在某些…

OpenCV中的findContours函数参数详解

OpenCV中通过使用findContours函数&#xff0c;简单几个的步骤就可以检测出物体的轮廓&#xff0c;很方便。这些准备继续探讨一下 findContours方法中各参数的含义及用法&#xff0c;比如要求只检测最外层轮廓该怎么办&#xff1f;contours里边的数据结构是怎样 的&#xff1f;…

错误代码CS0051可访问性不一致_解决方案

一、问题的出现 用C#在写多线程时报错 二、解决方案 1&#xff0c;分析思路 本来对BaseStruct设置为私有访问&#xff0c;但调用时又想公开化&#xff0c;从而造成了编译错误。 2&#xff0c;解决 将红色部分改为公有 3&#xff0c;总结 注意public、pravite、和internal关键字…

OpenCV之meanshift分割详解

1. 原理 用meanshift做图像平滑和分割&#xff0c;其实是一回事。其本质是经过迭代&#xff0c;将收敛点的像素值代替原来的像素值&#xff0c;从而去除了局部相似的纹理&#xff0c;同时保留了边缘等差异较大的特征。 OpenCV中自带有基于meanshift的分割方法pyrMeanShiftFilte…