|
ree = new DecisionTree(datas);
return tree;
}
/**
* 构造随机森林
*/
public void constructRandomTree() {
DecisionTree tree;
random = new Random();
decisionForest = new ArrayList<>();
System.out.println("下面是随机森林中的决策树:");
// 构造决策树加入森林中
for (int i = 0; i < treeNum; i++) {
System.out.println("\n决策树" + (i+1));
tree = produceDecisionTree();
decisionForest.add(tree);
}
}
/**
* 根据给定的属性条件进行类别的决策
*
* @param features
* 给定的已知的属性描述
* @return
*/
public String judgeClassType(String features) {
// 结果类型值
String resultClassType = "";
String classType = "";
int count = 0;
Map type2Num = new HashMap();
for (DecisionTree tree : decisionForest) {
classType = tree.decideClassType(features);
if (type2Num.containsKey(classType)) {
// 如果类别已经存在,则使其计数加1
count = type2Num.get(classType);
count++;
} else {
count = 1;
}
type2Num.put(classType, count);
}
// 选出其中类别支持计数最多的一个类别值
count = -1;
for (Map.Entry entry : type2Num.entrySet()) {
if ((int) entry.getValue() > count) {
count = (int) entry.getValue();
resultClassType = (String) entry.getKey();
}
}
return resultClassType;
}
}
CART算法工具类CARTTool.java:
?
?
package DataMining_RandomForest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Queue;
/**
* CART分类回归树算法工具类
*
* @author lyq
*
*/
public class CARTTool {
// 类标号的值类型
private final String YES = "Yes";
private final String NO = "No";
// 所有属性的类型总数,在这里就是data源数据的列数
private int attrNum;
private String filePath;
// 初始源数据,用一个二维字符数组存放模仿表格数据
private String[][] data;
// 数据的属性行的名字
private String[] attrNames;
// 每个属性的值所有类型
private HashMap> attrValue;
public CARTTool(ArrayList dataArray) {
attrValue = new HashMap<>();
readData(dataArray);
}
/**
* 根据随机选取的样本数据进行初始化
* @param dataArray
* 已经读入的样本数据
*/
public void readData(ArrayList dataArray) {
data = new String[dataArray.size()][];
dataArray.toArray(data);
attrNum = data[0].length;
attrNames = data[0];
}
/**
* 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
*/
public void initAttrValue() {
ArrayList tempValues;
// 按照列的方式,从左往右找
for (int j = 1; j < attrNum; j++) {
// 从一列中的上往下开始寻找值
tempValues = new ArrayList<>();
for (int i = 1; i < data.length; i++) {
if (!tempValues.contains(data[i][j])) {
// 如果这个属性的值没有添加过,则添加
tempValues.add(data[i][j]);
}
}
// 一列属性的值已经遍历完毕,复制到map属性表中
attrValue.put(data[0][j], tempValues);
}
}
/**
* 计算机基尼指数
*
* @param remainData
* 剩余数据
* @param attrName
* 属性名称
* @param value
* 属性值
* @param beLongValue
* 分类是否属于此属性值
* @return
*/
public double computeGini(String[][] remainData, String attrName,
String value, boolean beLongValue) {
// 实例总数
int total = 0;
// 正实例数
int posNum = 0;
// 负实例数
int negNum = 0;
// 基尼指数
double gini = 0;
// 还是按列从左往右遍历属性
for (int j = 1; j < attrNames.length; j++) {
// 找到了指定的属性
if (attrName.equals(attrNames[j])) {
for (int i = 1; i < remainData.length; i++) {
// 统计正负实例按照属于和不属于值类型进行划分
if ((beLongValue && remainData[i][j].equals(value))
|| (!beLongValue && !remainData[i][j].equals(value))) {
if (remainData[i][attrNames.length - 1].equals(YES)) {
// 判断此行数据是否为正实例
posNum++;
} else {
negNum++;
}
}
}
}
}
total = posNum + negNum;
double posProbobly = (double) posNum / total;
double negProbobly = (double) negNum / total;
gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;
// 返回计算基尼指数
return gini;
}
/**
* 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中
*
* @param remainData
* 剩余谁
* @param attrName
* 属性名称
* @return
*/
public String[] computeAttrGini(String[][] remainData, String attrName) {
String[] str = new String[2];
// 最终该属性的划分类型值
String spiltValue = "";
// 临时变量
int tempNum = 0;
// 保存属性的值划分时的最小的基尼指数
double minGini = Integer.MAX_VALUE;
ArrayList valueTypes = attrValue.get(attrName);
// 属于此属性值的实例数
HashMap belongNum = new HashMap<>();
for (String string : valueTypes) {
// 重新计数的时候,数字归0
tempNum = 0;
// 按列从左往右遍历属性
for (int j = 1; j < attrNames.length; j++) {
// 找到了指定的属性
if (attrName.equals(attrNames[j])) {
for (int i = 1; i < remainData.length; i++) {
// 统计正负实例按照属于和不属于值类型进行划分
if (remainData[i][j].equals(string)) {
tempNum++;
}
}
}
}
belongNum.put(string, tempNum);
}
double tempGini = 0;
double posProbably = 1.0;
double negProbably = 1.0;
for (String string : valueTypes) {
tempGini = 0;
posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
negProbably = 1 - posProbably;
tempGini += posProbably
* computeGini(remainData, attrName, string, true);
tempGini += negProbably
* computeGini(remainData, attrName, string, false);
if (tempGini < minGini) {
minGini = tempGini;
spiltValue = string;
}
}
str[0] = spiltValue;
str[1] = minGini + "";
return str;
}
public void buildDecisionTree(TreeNode node, String parentAttrValue,
String[][] remainData, ArrayList remainAttr,
boolean beLongParentValue) {
// 属性划分值
String valueType = "";
// 划分属性名称
String spiltAttrName = "";
double minGini = Integer.MAX_VALUE;
double tempGini = 0;
// 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值
String[] giniArray;
if (beLongParentValue) {
node.setParentAttrValue(parentAttrValue);
} else {
node.setParentAttrValue("!" + parentAttrValue);
}
if (remainAttr.size() == 0) {
if (remainData.length > 1) {
ArrayList indexArray = new ArrayList<>();
for (int i = 1; i < remainData.length; i++) {
indexArray.add(remainData[i][0]);
}
node.setDataIndex(indexArray);
}
// System.out.println("attr remain null");
return;
}
for (String str : remainAttr) {
giniArray = computeAttrGini(remainData, str);
tempGini = Double.parseDouble(giniArray[1]);
if (tempGini < minGini) {
spiltAttrName = str;
minGini = tempGini;
valueType = giniArray[0];
}
}
// 移除划分属性
remainAttr.remove(spiltAttrName);
node.setAttrName(spiltAttrName);
// 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点
TreeNode[] childNode = new TreeNode[2];
String[][] rData;
boolean[] bArray = new boolean[] { true, false };
for (int i = 0; i < bArray.length; i++) {
// 二元划分属于属性值的划分
rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);
boolean sameClass = true;
ArrayList indexArray = new ArrayList<>();
for (int k = 1; k < rData.length; k++) {
indexArray.add(rData[k][0]);
// 判断是否为同一类的
if (!rData[k][attrNames.length - 1]
.equals(rData[1][attrNames.length - 1])) {
// 只要有1个不相等,就不是同类型的
sameClass = false;
break;
}
}
childNode[i] = new TreeNode();
if (!sameClass) {
// 创建新的对象属性,对象的同个引用会出错
ArrayList rAttr = new ArrayList<>();
for (String str : remainAttr) {
rAttr.add(str);
}
buildDecisionTree(childNode[i], valueType, rData, rAttr,
bArray[i]);
} else {
String pAtr = (bArray[i] ? valueType : "!" + valueType);
childNode[i].setParentAttrValue(pAtr);
childNode[i].setDataIndex(indexArray);
}
}
node.setChildAttrNode(childNode);
}
/**
* 属性划分完毕,进行数据的移除
*
* @param srcData
* 源数据
* @param attrName
* 划分的属性名称
* @param valueType
* 属性 |