随机森林和GBDT的学习(二)

2015-07-24 11:55:25 · 作者: · 浏览: 28
tAttrValue; } public void setParentAttrValue(String parentAttrValue) { this.parentAttrValue = parentAttrValue; } public TreeNode[] getChildAttrNode() { return childAttrNode; } public void setChildAttrNode(TreeNode[] childAttrNode) { this.childAttrNode = childAttrNode; } public ArrayList getDataIndex() { return dataIndex; } public void setDataIndex(ArrayList dataIndex) { this.dataIndex = dataIndex; } public int getLeafNum() { return leafNum; } public void setLeafNum(int leafNum) { this.leafNum = leafNum; } } 决策树类DecisionTree.java:

?

?

package DataMining_RandomForest;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

/**
 * 决策树
 * 
 * @author lyq
 * 
 */
public class DecisionTree {
	// 树的根节点
	TreeNode rootNode;
	// 数据的属性列名称
	String[] featureNames;
	// 这棵树所包含的数据
	ArrayList datas;
	// 决策树构造的的工具类
	CARTTool tool;

	public DecisionTree(ArrayList datas) {
		this.datas = datas;
		this.featureNames = datas.get(0);

		tool = new CARTTool(datas);
		// 通过CART工具类进行决策树的构建,并返回树的根节点
		rootNode = tool.startBuildingTree();
	}

	/**
	 * 根据给定的数据特征描述进行类别的判断
	 * 
	 * @param features
	 * @return
	 */
	public String decideClassType(String features) {
		String classType = "";
		// 查询属性组
		String[] queryFeatures;
		// 在本决策树中对应的查询的属性值描述
		ArrayList featureStrs;

		featureStrs = new ArrayList<>();
		queryFeatures = features.split(",");

		String[] array;
		for (String name : featureNames) {
			for (String featureva lue : queryFeatures) {
				array = featureva lue.split("=");
				// 将对应的属性值加入到列表中
				if (array[0].equals(name)) {
					featureStrs.add(array);
				}
			}
		}

		// 开始从根据节点往下递归搜索
		classType = recusiveSearchClassType(rootNode, featureStrs);

		return classType;
	}

	/**
	 * 递归搜索树,查询属性的分类类别
	 * 
	 * @param node
	 *            当前搜索到的节点
	 * @param remainFeatures
	 *            剩余未判断的属性
	 * @return
	 */
	private String recusiveSearchClassType(TreeNode node,
			ArrayList
remainFeatures) { String classType = null; // 如果节点包含了数据的id索引,说明已经分类到底了 if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { classType = judgeClassType(node.getDataIndex()); return classType; } // 取出剩余属性中的一个匹配属性作为当前的判断属性名称 String[] currentFeature = null; for (String[] featureva lue : remainFeatures) { if (node.getAttrName().equals(featureva lue[0])) { currentFeature = featureva lue; break; } } for (TreeNode childNode : node.getChildAttrNode()) { // 寻找子节点中属于此属性值的分支 if (childNode.getParentAttrValue().equals(currentFeature[1])) { remainFeatures.remove(currentFeature); classType = recusiveSearchClassType(childNode, remainFeatures); // 如果找到了分类结果,则直接挑出循环 break; }else{ //进行第二种情况的判断加上!符号的情况 String value = childNode.getParentAttrValue(); if(value.charAt(0) == '!'){ //去掉第一个!字符 value = value.substring(1, value.length()); if(!value.equals(currentFeature[1])){ remainFeatures.remove(currentFeature); classType = recusiveSearchClassType(childNode, remainFeatures); break; }