随机森林和GBDT的学习(三)

2015-07-24 11:55:25 · 作者: · 浏览: 30
} } } return classType; } /** * 根据得到的数据行分类进行类别的决策 * * @param dataIndex * 根据分类的数据索引号 * @return */ public String judgeClassType(ArrayList dataIndex) { // 结果类型值 String resultClassType = ""; String classType = ""; int count = 0; int temp = 0; Map type2Num = new HashMap(); for (String index : dataIndex) { temp = Integer.parseInt(index); // 取最后一列的决策类别数据 classType = datas.get(temp)[featureNames.length - 1]; if (type2Num.containsKey(classType)) { // 如果类别已经存在,则使其计数加1 count = type2Num.get(classType); count++; } else { count = 1; } type2Num.put(classType, count); } // 选出其中类别支持计数最多的一个类别值 count = -1; for (Map.Entry entry : type2Num.entrySet()) { if ((int) entry.getValue() > count) { count = (int) entry.getValue(); resultClassType = (String) entry.getKey(); } } return resultClassType; } } 随机森林算法工具类RandomForestTool.java:

?

?

package DataMining_RandomForest;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;

/**
 * 随机森林算法工具类
 * 
 * @author lyq
 * 
 */
public class RandomForestTool {
	// 测试数据文件地址
	private String filePath;
	// 决策树的样本占总数的占比率
	private double sampleNumRatio;
	// 样本数据的采集特征数量占总特征的比例
	private double featureNumRatio;
	// 决策树的采样样本数
	private int sampleNum;
	// 样本数据的采集采样特征数
	private int featureNum;
	// 随机森林中的决策树的数目,等于总的数据数/用于构造每棵树的数据的数量
	private int treeNum;
	// 随机数产生器
	private Random random;
	// 样本数据列属性名称行
	private String[] featureNames;
	// 原始的总的数据
	private ArrayList totalDatas;
	// 决策树森林
	private ArrayList decisionForest;

	public RandomForestTool(String filePath, double sampleNumRatio,
			double featureNumRatio) {
		this.filePath = filePath;
		this.sampleNumRatio = sampleNumRatio;
		this.featureNumRatio = featureNumRatio;

		readDataFile();
	}

	/**
	 * 从文件中读取数据
	 */
	private void readDataFile() {
		File file = new File(filePath);
		ArrayList dataArray = new ArrayList();

		try {
			BufferedReader in = new BufferedReader(new FileReader(file));
			String str;
			String[] tempArray;
			while ((str = in.readLine()) != null) {
				tempArray = str.split(" ");
				dataArray.add(tempArray);
			}
			in.close();
		} catch (IOException e) {
			e.getStackTrace();
		}

		totalDatas = dataArray;
		featureNames = totalDatas.get(0);
		sampleNum = (int) ((totalDatas.size() - 1) * sampleNumRatio);
		//算属性数量的时候需要去掉id属性和决策属性,用条件属性计算
		featureNum = (int) ((featureNames.length -2) * featureNumRatio);
		// 算数量的时候需要去掉首行属性名称行
		treeNum = (totalDatas.size() - 1) / sampleNum;
	}

	/**
	 * 产生决策树
	 */
	private DecisionTree produceDecisionTree() {
		int temp = 0;
		DecisionTree tree;
		String[] tempData;
		//采样数据的随机行号组
		ArrayList
sampleRandomNum; //采样属性特征的随机列号组 ArrayList featureRandomNum; ArrayList datas; sampleRandomNum = new ArrayList<>(); featureRandomNum = new ArrayList<>(); datas = new ArrayList<>(); for(int i=0; i 0){ array[0] = temp + ""; } temp++; } t