设为首页 加入收藏

TOP

随机森林和GBDT的学习(四)
2015-07-24 11:55:25 来源: 作者: 【 】 浏览:17
Tags:随机 森林 GBDT 学习
ree = new DecisionTree(datas); return tree; } /** * 构造随机森林 */ public void constructRandomTree() { DecisionTree tree; random = new Random(); decisionForest = new ArrayList<>(); System.out.println("下面是随机森林中的决策树:"); // 构造决策树加入森林中 for (int i = 0; i < treeNum; i++) { System.out.println("\n决策树" + (i+1)); tree = produceDecisionTree(); decisionForest.add(tree); } } /** * 根据给定的属性条件进行类别的决策 * * @param features * 给定的已知的属性描述 * @return */ public String judgeClassType(String features) { // 结果类型值 String resultClassType = ""; String classType = ""; int count = 0; Map type2Num = new HashMap(); for (DecisionTree tree : decisionForest) { classType = tree.decideClassType(features); if (type2Num.containsKey(classType)) { // 如果类别已经存在,则使其计数加1 count = type2Num.get(classType); count++; } else { count = 1; } type2Num.put(classType, count); } // 选出其中类别支持计数最多的一个类别值 count = -1; for (Map.Entry entry : type2Num.entrySet()) { if ((int) entry.getValue() > count) { count = (int) entry.getValue(); resultClassType = (String) entry.getKey(); } } return resultClassType; } } CART算法工具类CARTTool.java:

?

?

package DataMining_RandomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.Queue;

/**
 * CART分类回归树算法工具类
 * 
 * @author lyq
 * 
 */
public class CARTTool {
	// 类标号的值类型
	private final String YES = "Yes";
	private final String NO = "No";

	// 所有属性的类型总数,在这里就是data源数据的列数
	private int attrNum;
	private String filePath;
	// 初始源数据,用一个二维字符数组存放模仿表格数据
	private String[][] data;
	// 数据的属性行的名字
	private String[] attrNames;
	// 每个属性的值所有类型
	private HashMap> attrValue;

	public CARTTool(ArrayList dataArray) {
		attrValue = new HashMap<>();
		readData(dataArray);
	}

	/**
	 * 根据随机选取的样本数据进行初始化
	 * @param dataArray
	 * 已经读入的样本数据
	 */
	public void readData(ArrayList dataArray) {
		data = new String[dataArray.size()][];
		dataArray.toArray(data);
		attrNum = data[0].length;
		attrNames = data[0];
	}

	/**
	 * 首先初始化每种属性的值的所有类型,用于后面的子类熵的计算时用
	 */
	public void initAttrValue() {
		ArrayList tempValues;

		// 按照列的方式,从左往右找
		for (int j = 1; j < attrNum; j++) {
			// 从一列中的上往下开始寻找值
			tempValues = new ArrayList<>();
			for (int i = 1; i < data.length; i++) {
				if (!tempValues.contains(data[i][j])) {
					// 如果这个属性的值没有添加过,则添加
					tempValues.add(data[i][j]);
				}
			}

			// 一列属性的值已经遍历完毕,复制到map属性表中
			attrValue.put(data[0][j], tempValues);
		}
	}

	/**
	 * 计算机基尼指数
	 * 
	 * @param remainData
	 *            剩余数据
	 * @param attrName
	 *            属性名称
	 * @param value
	 *            属性值
	 * @param beLongValue
	 *            分类是否属于此属性值
	 * @return
	 */
	public double computeGini(String[][] remainData, String attrName,
			String value, boolean beLongValue) {
		// 实例总数
		int total = 0;
		// 正实例数
		int posNum = 0;
		// 负实例数
		int negNum = 0;
		// 基尼指数
		double gini = 0;

		// 还是按列从左往右遍历属性
		for (int j = 1; j < attrNames.length; j++) {
			// 找到了指定的属性
			if (attrName.equals(attrNames[j])) {
				for (int i = 1; i < remainData.length; i++) {
					// 统计正负实例按照属于和不属于值类型进行划分
					if ((beLongValue && remainData[i][j].equals(value))
							|| (!beLongValue && !remainData[i][j].equals(value))) {
						if (remainData[i][attrNames.length - 1].equals(YES)) {
							// 判断此行数据是否为正实例
							posNum++;
						} else {
							negNum++;
						}
					}
				}
			}
		}

		total = posNum + negNum;
		double posProbobly = (double) posNum / total;
		double negProbobly = (double) negNum / total;
		gini = 1 - posProbobly * posProbobly - negProbobly * negProbobly;

		// 返回计算基尼指数
		return gini;
	}

	/**
	 * 计算属性划分的最小基尼指数,返回最小的属性值划分和最小的基尼指数,保存在一个数组中
	 * 
	 * @param remainData
	 *            剩余谁
	 * @param attrName
	 *            属性名称
	 * @return
	 */
	public String[] computeAttrGini(String[][] remainData, String attrName) {
		String[] str = new String[2];
		// 最终该属性的划分类型值
		String spiltValue = "";
		// 临时变量
		int tempNum = 0;
		// 保存属性的值划分时的最小的基尼指数
		double minGini = Integer.MAX_VALUE;
		ArrayList valueTypes = attrValue.get(attrName);
		// 属于此属性值的实例数
		HashMap belongNum = new HashMap<>();

		for (String string : valueTypes) {
			// 重新计数的时候,数字归0
			tempNum = 0;
			// 按列从左往右遍历属性
			for (int j = 1; j < attrNames.length; j++) {
				// 找到了指定的属性
				if (attrName.equals(attrNames[j])) {
					for (int i = 1; i < remainData.length; i++) {
						// 统计正负实例按照属于和不属于值类型进行划分
						if (remainData[i][j].equals(string)) {
							tempNum++;
						}
					}
				}
			}

			belongNum.put(string, tempNum);
		}

		double tempGini = 0;
		double posProbably = 1.0;
		double negProbably = 1.0;
		for (String string : valueTypes) {
			tempGini = 0;

			posProbably = 1.0 * belongNum.get(string) / (remainData.length - 1);
			negProbably = 1 - posProbably;

			tempGini += posProbably
					* computeGini(remainData, attrName, string, true);
			tempGini += negProbably
					* computeGini(remainData, attrName, string, false);

			if (tempGini < minGini) {
				minGini = tempGini;
				spiltValue = string;
			}
		}

		str[0] = spiltValue;
		str[1] = minGini + "";

		return str;
	}

	public void buildDecisionTree(TreeNode node, String parentAttrValue,
			String[][] remainData, ArrayList remainAttr,
			boolean beLongParentValue) {
		// 属性划分值
		String valueType = "";
		// 划分属性名称
		String spiltAttrName = "";
		double minGini = Integer.MAX_VALUE;
		double tempGini = 0;
		// 基尼指数数组,保存了基尼指数和此基尼指数的划分属性值
		String[] giniArray;

		if (beLongParentValue) {
			node.setParentAttrValue(parentAttrValue);
		} else {
			node.setParentAttrValue("!" + parentAttrValue);
		}

		if (remainAttr.size() == 0) {
			if (remainData.length > 1) {
				ArrayList indexArray = new ArrayList<>();
				for (int i = 1; i < remainData.length; i++) {
					indexArray.add(remainData[i][0]);
				}
				node.setDataIndex(indexArray);
			}
		//	System.out.println("attr remain null");
			return;
		}

		for (String str : remainAttr) {
			giniArray = computeAttrGini(remainData, str);
			tempGini = Double.parseDouble(giniArray[1]);

			if (tempGini < minGini) {
				spiltAttrName = str;
				minGini = tempGini;
				valueType = giniArray[0];
			}
		}
		// 移除划分属性
		remainAttr.remove(spiltAttrName);
		node.setAttrName(spiltAttrName);

		// 孩子节点,分类回归树中,每次二元划分,分出2个孩子节点
		TreeNode[] childNode = new TreeNode[2];
		String[][] rData;

		boolean[] bArray = new boolean[] { true, false };
		for (int i = 0; i < bArray.length; i++) {
			// 二元划分属于属性值的划分
			rData = removeData(remainData, spiltAttrName, valueType, bArray[i]);

			boolean sameClass = true;
			ArrayList indexArray = new ArrayList<>();
			for (int k = 1; k < rData.length; k++) {
				indexArray.add(rData[k][0]);
				// 判断是否为同一类的
				if (!rData[k][attrNames.length - 1]
						.equals(rData[1][attrNames.length - 1])) {
					// 只要有1个不相等,就不是同类型的
					sameClass = false;
					break;
				}
			}

			childNode[i] = new TreeNode();
			if (!sameClass) {
				// 创建新的对象属性,对象的同个引用会出错
				ArrayList rAttr = new ArrayList<>();
				for (String str : remainAttr) {
					rAttr.add(str);
				}
				buildDecisionTree(childNode[i], valueType, rData, rAttr,
						bArray[i]);
			} else {
				String pAtr = (bArray[i] ? valueType : "!" + valueType);
				childNode[i].setParentAttrValue(pAtr);
				childNode[i].setDataIndex(indexArray);
			}
		}

		node.setChildAttrNode(childNode);
	}

	/**
	 * 属性划分完毕,进行数据的移除
	 * 
	 * @param srcData
	 *            源数据
	 * @param attrName
	 *            划分的属性名称
	 * @param valueType
	 *            属性
首页 上一页 1 2 3 4 5 6 下一页 尾页 4/6/6
】【打印繁体】【投稿】【收藏】 【推荐】【举报】【评论】 【关闭】 【返回顶部
分享到: 
上一篇【翻译自mos文章】在alter/drop表.. 下一篇干货分享:DBA专家门诊一期:索引..

评论

帐  号: 密码: (新用户注册)
验 证 码:
表  情:
内  容:

·Announcing October (2025-12-24 15:18:16)
·MySQL有什么推荐的学 (2025-12-24 15:18:13)
·到底应该用MySQL还是 (2025-12-24 15:18:11)
·进入Linux世界大门的 (2025-12-24 14:51:47)
·Download Linux | Li (2025-12-24 14:51:44)