随机森林和GBDT的学习(四)

2015-07-24 11:55:25 · 作者: · 浏览: 29
ree = new DecisionTree(datas); return tree; } /** * 构造随机森林 */ public void constructRandomTree() { DecisionTree tree; random = new Random(); decisionForest = new ArrayList<>(); System.out.println("下面是随机森林中的决策树:"); // 构造决策树加入森林中 for (int i = 0; i < treeNum; i++) { System.out.println("\n决策树" + (i+1)); tree = produceDecisionTree(); decisionForest.add(tree); } } /** * 根据给定的属性条件进行类别的决策 * * @param features * 给定的已知的属性描述 * @return */ public String judgeClassType(String features) { // 结果类型值 String resultClassType = ""; String classType = ""; int count = 0; Map type2Num = new HashMap(); for (DecisionTree tree : decisionForest) { classType = tree.decideClassType(features); if (type2Num.containsKey(classType)) { // 如果类别已经存在,则使其计数加1 count = type2Num.get(classType); count++; } else { count = 1; } type2Num.put(classType, count); } // 选出其中类别支持计数最多的一个类别值 count = -1; for (Map.Entry entry : type2Num.entrySet()) { if ((int) entry.getValue() > count) { count = (int) entry.getValue(); resultClassType = (String) entry.getKey(); } } return resultClassType; } } CART算法工具类CARTTool.java:

?

?

package DataMining_RandomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Queue;

/**
 * CART分类回归树算法工具类
 * 
 * @author lyq
 * 
 */
public class CARTTool {
	// 类标号的值类型
	private final String YES = "Yes";
	private final String NO = "No";

	// 所有属性的类型总数,在这里就是data源数据的列数
	private int attrNum;
	private String filePath;
	// 初始源数据,用一个二维字符数组存放模仿表格数据
	private String[][] data;
	// 数据的属性行的名字
	private String[] attrNames;
	// 每个属性的值所有类型
	private HashMap> attrValue;

	public CARTTool(ArrayList dataArray) {
		attrValue = new HashMap<>();
		readData(dataArray);
	}

	/**
	 * 根据随机选取的样本数据进行初始化
	 * @param dataArray
	 * 已经读入的样本数据
	 */
	public void readData(ArrayList dataArray) {
		data = new String[dataArray.size()][];
		dataArray.toArray(data);
		attrNum = data[0].length;
		attrNames = data[0];
	}

	/**
	 * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
	 */
	public void initAttrValue() {
		ArrayList tempValues;

		// 按照列的方式,从左往右找
		for (int j = 1; j < attrNum; j++) {
			// 从一列中的上往下开始寻找值
			tempValues = new ArrayList<>();
			for (int i = 1; i < data.length; i++) {
				if (!tempValues.contains(data[i][j])) {
					// 如果这个属性的值没有添加过,则添加
					tempValues.add(data[i][j]);
				}
			}

			// 一列属性的值已经遍历完毕,复制到map属性表中
			attrValue.put(data[0][j], tempValues);
		}
	}

	/**
	 * 计算机基尼指数
	 * 
	 * @param remainData
	 *            剩余数据
	 * @param attrName
	 *            属性名称
	 * @param value
	 *            属性值
	 * @param beLongValue
	 *            分类是否属于此属性值
	 * @return
	 */
	public double computeGini(String[][] remainData, String attrName,
			String value, boolean beLongValue) {
		// 实例总数
		int total = 0;
		// 正实例数
		int posNum = 0;
		// 负实例数
		int negNum = 0;
		// 基尼指数
		double gini = 0;

		// 还是按列从左往右遍历属性
		for (int j = 1; j < attrNames.length; j++) {
			// 找到了指定的属性
			if (attrName.equals(attrNames[j])) {
				for (int i = 1; i < remainData.length; i++) {
					// 统计正负实例按照属于和不属于值类型进行划分
					if ((beLongValue && remainData[i][j].equals(value))
							|| (!beLongValue && !remainData[i][j].equals(value))) {
						if (remainData[i][attrNames.length - 1].equals(YES)) {
							// 判断此行数据是否为正实例
							posNum++;
						} else {
							negNum++;
						}
					}
				}
			}
		}

		total = posNum + negNum;
		double posProbobly = (double) posNum / total;
		double negProbobly = (double) negNum / total;
		gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;

		// 返回计算基尼指数
		return gini;
	}

	/**
	 * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中
	 * 
	 * @param remainData
	 *            剩余谁
	 * @param attrName
	 *            属性名称
	 * @return
	 */
	public String[] computeAttrGini(String[][] remainData, String attrName) {
		String[] str = new String[2];
		// 最终该属性的划分类型值
		String spiltValue = "";
		// 临时变量
		int tempNum = 0;
		// 保存属性的值划分时的最小的基尼指数
		double minGini = Integer.MAX_VALUE;
		ArrayList
valueTypes = attrValue.get(attrName); // 属于此属性值的实例数 HashMap belongNum = new HashMap<>(); for (String string : valueTypes) { // 重新计数的时候,数字归0 tempNum = 0; // 按列从左往右遍历属性 for (int j = 1; j < attrNames.length; j++) { // 找到了指定的属性 if (attrName.equals(attrNames[j])) { for (int i = 1; i < remainData.length; i++) { // 统计正负实例按照属于和不属于值类型进行划分 if (remainData[i][j].equals(string)) { tempNum++; } } } } belongNum.put(string, tempNum); } double tempGini = 0; double posProbably = 1.0; double negProbably = 1.0; for (String string : valueTypes) { tempGini = 0; posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1); negProbably = 1 - posProbably; tempGini += posProbably * computeGini(remainData, attrName, string, true); tempGini += negProbably * computeGini(remainData, attrName, string, false); if (tempGini < minGini) { minGini = tempGini; spiltValue = string; } } str[0] = spiltValue; str[1] = minGini + ""; return str; } public void buildDecisionTree(TreeNode node, String parentAttrValue, String[][] remainData, ArrayList remainAttr, boolean beLongParentValue) { // 属性划分值 String valueType = ""; // 划分属性名称 String spiltAttrName = ""; double minGini = Integer.MAX_VALUE; double tempGini = 0; // 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值 String[] giniArray; if (beLongParentValue) { node.setParentAttrValue(parentAttrValue); } else { node.setParentAttrValue("!" + parentAttrValue); } if (remainAttr.size() == 0) { if (remainData.length > 1) { ArrayList indexArray = new ArrayList<>(); for (int i = 1; i < remainData.length; i++) { indexArray.add(remainData[i][0]); } node.setDataIndex(indexArray); } // System.out.println("attr remain null"); return; } for (String str : remainAttr) { giniArray = computeAttrGini(remainData, str); tempGini = Double.parseDouble(giniArray[1]); if (tempGini < minGini) { spiltAttrName = str; minGini = tempGini; valueType = giniArray[0]; } } // 移除划分属性 remainAttr.remove(spiltAttrName); node.setAttrName(spiltAttrName); // 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点 TreeNode[] childNode = new TreeNode[2]; String[][] rData; boolean[] bArray = new boolean[] { true, false }; for (int i = 0; i < bArray.length; i++) { // 二元划分属于属性值的划分 rData = removeData(remainData, spiltAttrName, valueType, bArray[i]); boolean sameClass = true; ArrayList indexArray = new ArrayList<>(); for (int k = 1; k < rData.length; k++) { indexArray.add(rData[k][0]); // 判断是否为同一类的 if (!rData[k][attrNames.length - 1] .equals(rData[1][attrNames.length - 1])) { // 只要有1个不相等,就不是同类型的 sameClass = false; break; } } childNode[i] = new TreeNode(); if (!sameClass) { // 创建新的对象属性,对象的同个引用会出错 ArrayList rAttr = new ArrayList<>(); for (String str : remainAttr) { rAttr.add(str); } buildDecisionTree(childNode[i], valueType, rData, rAttr, bArray[i]); } else { String pAtr = (bArray[i] ? valueType : "!" + valueType); childNode[i].setParentAttrValue(pAtr); childNode[i].setDataIndex(indexArray); } } node.setChildAttrNode(childNode); } /** * 属性划分完毕,进行数据的移除 * * @param srcData * 源数据 * @param attrName * 划分的属性名称 * @param valueType * 属性