|
}
}
}
return classType;
}
/**
* 根据得到的数据行分类进行类别的决策
*
* @param dataIndex
* 根据分类的数据索引号
* @return
*/
public String judgeClassType(ArrayList dataIndex) {
// 结果类型值
String resultClassType = "";
String classType = "";
int count = 0;
int temp = 0;
Map type2Num = new HashMap();
for (String index : dataIndex) {
temp = Integer.parseInt(index);
// 取最后一列的决策类别数据
classType = datas.get(temp)[featureNames.length - 1];
if (type2Num.containsKey(classType)) {
// 如果类别已经存在,则使其计数加1
count = type2Num.get(classType);
count++;
} else {
count = 1;
}
type2Num.put(classType, count);
}
// 选出其中类别支持计数最多的一个类别值
count = -1;
for (Map.Entry entry : type2Num.entrySet()) {
if ((int) entry.getValue() > count) {
count = (int) entry.getValue();
resultClassType = (String) entry.getKey();
}
}
return resultClassType;
}
}
随机森林算法工具类RandomForestTool.java:
?
?
package DataMining_RandomForest;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.Random;
/**
* 随机森林算法工具类
*
* @author lyq
*
*/
public class RandomForestTool {
// 测试数据文件地址
private String filePath;
// 决策树的样本占总数的占比率
private double sampleNumRatio;
// 样本数据的采集特征数量占总特征的比例
private double featureNumRatio;
// 决策树的采样样本数
private int sampleNum;
// 样本数据的采集采样特征数
private int featureNum;
// 随机森林中的决策树的数目,等于总的数据数/用于构造每棵树的数据的数量
private int treeNum;
// 随机数产生器
private Random random;
// 样本数据列属性名称行
private String[] featureNames;
// 原始的总的数据
private ArrayList totalDatas;
// 决策树森林
private ArrayList decisionForest;
public RandomForestTool(String filePath, double sampleNumRatio,
double featureNumRatio) {
this.filePath = filePath;
this.sampleNumRatio = sampleNumRatio;
this.featureNumRatio = featureNumRatio;
readDataFile();
}
/**
* 从文件中读取数据
*/
private void readDataFile() {
File file = new File(filePath);
ArrayList dataArray = new ArrayList();
try {
BufferedReader in = new BufferedReader(new FileReader(file));
String str;
String[] tempArray;
while ((str = in.readLine()) != null) {
tempArray = str.split(" ");
dataArray.add(tempArray);
}
in.close();
} catch (IOException e) {
e.getStackTrace();
}
totalDatas = dataArray;
featureNames = totalDatas.get(0);
sampleNum = (int) ((totalDatas.size() - 1) * sampleNumRatio);
//算属性数量的时候需要去掉id属性和决策属性,用条件属性计算
featureNum = (int) ((featureNames.length -2) * featureNumRatio);
// 算数量的时候需要去掉首行属性名称行
treeNum = (totalDatas.size() - 1) / sampleNum;
}
/**
* 产生决策树
*/
private DecisionTree produceDecisionTree() {
int temp = 0;
DecisionTree tree;
String[] tempData;
//采样数据的随机行号组
ArrayList sampleRandomNum;
//采样属性特征的随机列号组
ArrayList featureRandomNum;
ArrayList datas;
sampleRandomNum = new ArrayList<>();
featureRandomNum = new ArrayList<>();
datas = new ArrayList<>();
for(int i=0; i 0){
array[0] = temp + "";
}
temp++;
}
t |