|
tAttrValue;
}
/**
 * Stores the branch value that leads from the parent node to this node.
 * During classification this value is compared with the query's value for the
 * parent's split attribute; a leading '!' marks a "not equal to" branch.
 *
 * @param parentAttrValue the branch value to store
 */
public void setParentAttrValue(String parentAttrValue) {
    this.parentAttrValue = parentAttrValue;
}
/**
 * Gives access to the child nodes of this node.
 * The returned array is the stored reference itself, not a copy.
 *
 * @return the child node array (may be whatever was last set, including null)
 */
public TreeNode[] getChildAttrNode() {
    return this.childAttrNode;
}
/**
 * Replaces the child nodes of this node. The array is stored by reference.
 *
 * @param childAttrNode the new child node array
 */
public void setChildAttrNode(TreeNode[] childAttrNode) {
    this.childAttrNode = childAttrNode;
}
/**
 * Gives access to the data-row indices stored on this node.
 * The tree search treats a non-empty list as meaning this node is a leaf.
 * NOTE(review): the element type is not visible here; the list is returned by reference.
 *
 * @return the stored index list (may be null)
 */
public ArrayList getDataIndex() {
    return this.dataIndex;
}
/**
 * Replaces the data-row indices stored on this node. The list is stored by reference.
 *
 * @param dataIndex the new index list
 */
public void setDataIndex(ArrayList dataIndex) {
    this.dataIndex = dataIndex;
}
/**
 * Reads the stored {@code leafNum} value.
 * NOTE(review): presumably a leaf count for this subtree — confirm where it is computed.
 *
 * @return the stored leafNum
 */
public int getLeafNum() {
    return this.leafNum;
}
/**
 * Overwrites the stored {@code leafNum} value.
 *
 * @param leafNum the value to store
 */
public void setLeafNum(int leafNum) {
    this.leafNum = leafNum;
}
}
决策树类 DecisionTree.java:
package DataMining_RandomForest;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
/**
* 决策树
*
* @author lyq
*
*/
public class DecisionTree {
// Root node of the tree, as returned by the CART builder.
TreeNode rootNode;
// Feature (column) names; taken from the header row (row 0) of the data set.
String[] featureNames;
// The data set this tree was built from; each row is a String[], row 0 is the header.
ArrayList<String[]> datas;
// CART helper used to construct the tree.
CARTTool tool;
/**
 * Builds a decision tree from the given data set.
 *
 * <p>Row 0 of {@code datas} is treated as the header: its entries become this tree's
 * feature names. The whole list, header included, is handed to {@link CARTTool},
 * which builds the tree and returns its root node.
 *
 * @param datas the data set, one {@code String[]} per row; row 0 must be the header
 * @throws IllegalArgumentException if {@code datas} is empty (no header row)
 * @throws NullPointerException if {@code datas} is null
 */
public DecisionTree(ArrayList<String[]> datas) {
    // Fail fast with a clear message instead of an IndexOutOfBoundsException from get(0).
    if (datas.isEmpty()) {
        throw new IllegalArgumentException(
                "datas must contain at least a header row of feature names");
    }
    this.datas = datas;
    this.featureNames = datas.get(0);
    tool = new CARTTool(datas);
    // Let the CART helper build the tree and keep a reference to its root.
    rootNode = tool.startBuildingTree();
}
/**
* 根据给定的数据特征描述进行类别的判断
*
* @param features
* @return
*/
public String decideClassType(String features) {
String classType = "";
// 查询属性组
String[] queryFeatures;
// 在本决策树中对应的查询的属性值描述
ArrayList featureStrs;
featureStrs = new ArrayList<>();
queryFeatures = features.split(",");
String[] array;
for (String name : featureNames) {
for (String featureva lue : queryFeatures) {
array = featureva lue.split("=");
// 将对应的属性值加入到列表中
if (array[0].equals(name)) {
featureStrs.add(array);
}
}
}
// 开始从根据节点往下递归搜索
classType = recusiveSearchClassType(rootNode, featureStrs);
return classType;
}
/**
* 递归搜索树,查询属性的分类类别
*
* @param node
* 当前搜索到的节点
* @param remainFeatures
* 剩余未判断的属性
* @return
*/
private String recusiveSearchClassType(TreeNode node,
ArrayList remainFeatures) {
String classType = null;
// 如果节点包含了数据的id索引,说明已经分类到底了
if (node.getDataIndex() != null && node.getDataIndex().size() > 0) {
classType = judgeClassType(node.getDataIndex());
return classType;
}
// 取出剩余属性中的一个匹配属性作为当前的判断属性名称
String[] currentFeature = null;
for (String[] featureva lue : remainFeatures) {
if (node.getAttrName().equals(featureva lue[0])) {
currentFeature = featureva lue;
break;
}
}
for (TreeNode childNode : node.getChildAttrNode()) {
// 寻找子节点中属于此属性值的分支
if (childNode.getParentAttrValue().equals(currentFeature[1])) {
remainFeatures.remove(currentFeature);
classType = recusiveSearchClassType(childNode, remainFeatures);
// 如果找到了分类结果,则直接挑出循环
break;
}else{
//进行第二种情况的判断加上!符号的情况
String value = childNode.getParentAttrValue();
if(value.charAt(0) == '!'){
//去掉第一个!字符
value = value.substring(1, value.length());
if(!value.equals(currentFeature[1])){
remainFeatures.remove(currentFeature);
classType = recusiveSearchClassType(childNode, remainFeatures);
break;
}
|