|
的值类型
* @param beLongValue true to return the rows whose attribute value equals valueType, false to return the rows whose value differs
*/
private String[][] removeData(String[][] srcData, String attrName,
        String valueType, boolean beLongValue) {
    // Partitions srcData by the value held in the attrName column: rows whose
    // value equals valueType form one partition, every other row the other.
    // Row 0 of srcData is treated as the header row and is never tested.

    // Find the column to filter on. The scan starts at 1, so column 0 is
    // never a candidate; the first column carrying the name is used.
    int column = -1;
    for (int j = 1; j < attrNames.length; j++) {
        if (attrNames[j].equals(attrName)) {
            column = j;
            break;
        }
    }

    // Matching rows, led by the attribute-name header.
    ArrayList<String[]> matched = new ArrayList<>();
    matched.add(attrNames);

    // Non-matching rows, led by srcData's own first row when there is one.
    ArrayList<String[]> remaining = new ArrayList<>();
    if (srcData.length > 0) {
        remaining.add(srcData[0]);
    }

    // Single pass over the data rows; when no column carries attrName,
    // nothing matches and every row lands in the non-matching partition.
    for (int i = 1; i < srcData.length; i++) {
        String[] row = srcData[i];
        if (column >= 0 && row[column].equals(valueType)) {
            matched.add(row);
        } else {
            remaining.add(row);
        }
    }

    // beLongValue selects which partition the caller wants back.
    ArrayList<String[]> result = beLongValue ? matched : remaining;
    return result.toArray(new String[result.size()][]);
}
/**
 * Builds the classification and regression tree, numbers its nodes,
 * prints it to standard output and returns the root node.
 *
 * @return the root node of the freshly built tree
 */
public TreeNode startBuildingTree() {
    initAttrValue();

    // Candidate split attributes: every attribute name except index 0 and the
    // last one (the original comment calls the last one the class-label attribute).
    ArrayList<String> remainAttr = new ArrayList<>();
    for (int i = 1; i < attrNames.length - 1; i++) {
        remainAttr.add(attrNames[i]);
    }

    TreeNode rootNode = new TreeNode();
    buildDecisionTree(rootNode, "", data, remainAttr, false);
    // Number the nodes starting from 1 and compute their alpha; false = no pruning.
    setIndexAndAlpah(rootNode, 0, false);
    showDecisionTree(rootNode, 1);

    return rootNode;
}
/**
 * Prints the subtree rooted at the given node to standard output, one node
 * per line.
 *
 * @param node
 *            node to print
 * @param blankNum
 *            number of leading spaces for this node's line; children are
 *            printed with twice this value
 */
private void showDecisionTree(TreeNode node, int blankNum) {
    // Every node starts on a fresh line, padded to its depth.
    System.out.println();
    int printed = 0;
    while (printed < blankNum) {
        System.out.print(" ");
        printed++;
    }

    // Branch label: the parent's attribute value, or a dash placeholder.
    System.out.print("--");
    boolean hasParentValue = node.getParentAttrValue() != null
            && node.getParentAttrValue().length() > 0;
    System.out.print(hasParentValue ? node.getParentAttrValue() : "--");
    System.out.print("--");

    boolean holdsData = node.getDataIndex() != null
            && node.getDataIndex().size() > 0;
    if (holdsData) {
        // Node carrying data rows: show the class label of its first row,
        // followed by all row indices it holds.
        String firstRow = node.getDataIndex().get(0);
        String label = data[Integer.parseInt(firstRow)][attrNames.length - 1];
        System.out.print("【" + node.getNodeIndex() + "】类别:" + label);
        System.out.print("[");
        for (String rowIndex : node.getDataIndex()) {
            System.out.print(rowIndex + ", ");
        }
        System.out.print("]");
        return;
    }

    // Node without data rows: show its index and split attribute.
    System.out.print("【" + node.getNodeIndex() + ":"
            + node.getAttrName() + "】");
    if (node.getChildAttrNode() == null) {
        System.out.print("【 Child Null】");
        return;
    }

    // Recurse into the children with doubled indentation.
    for (TreeNode child : node.getChildAttrNode()) {
        showDecisionTree(child, 2 * blankNum);
    }
}
/**
 * Numbers every node of the tree in breadth-first order and computes the
 * error cost (alpha) of each node that has children; optionally prunes the
 * node with the smallest alpha.
 *
 * @param node
 *            node to start from; the caller in this file passes the root
 * @param index
 *            number preceding the first one handed out: the start node
 *            receives {@code index + 1}
 * @param ifCutNode
 *            whether to prune the node with the smallest alpha by setting
 *            its children to {@code null}; nothing is pruned when no node
 *            has children
 */
private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode) {
    // Node with the smallest error cost so far, i.e. the pruning candidate.
    TreeNode minAlphaNode = null;
    double minAlpah = Double.MAX_VALUE;

    Queue<TreeNode> nodeQueue = new LinkedList<>();
    nodeQueue.add(node);

    while (!nodeQueue.isEmpty()) {
        TreeNode tempNode = nodeQueue.poll();
        index++;
        tempNode.setNodeIndex(index);

        // Nodes without children get a number but no alpha.
        if (tempNode.getChildAttrNode() == null) {
            continue;
        }

        for (TreeNode childNode : tempNode.getChildAttrNode()) {
            nodeQueue.add(childNode);
        }

        computeAlpha(tempNode);
        double alpha = tempNode.getAlpha();
        if (alpha < minAlpah) {
            minAlphaNode = tempNode;
            minAlpah = alpha;
        } else if (alpha == minAlpah && minAlphaNode != null
                && tempNode.getLeafNum() > minAlphaNode.getLeafNum()) {
            // Same error cost: prefer pruning the node covering more leaves.
            minAlphaNode = tempNode;
        }
    }

    // Prune by dropping the candidate's children; skipped when there is no
    // candidate, which previously raised a NullPointerException.
    if (ifCutNode && minAlphaNode != null) {
        minAlphaNode.setChildAttrNode(null);
    }
}
/**
* 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝
*
* @param node
* 待计算的非叶子节点
*/
private void computeAlpha(TreeNode node) {
double rt = 0;
double Rt = 0;
double alpha = 0;
// 当前节点的数据总数
int sumNum = 0;
// 最少的偏差数
int minNum = 0;
ArrayList dataIndex;
ArrayList leafNodes = new ArrayList<>();
addLeafNode(node, leafNodes);
node.setLeafNum(leafNodes.size());
for (TreeNode attrNode : leafNodes) {
dataIndex = attrNode.getDataIndex();
int num = 0;
sumNum += dataIndex.size();
for (String s : dataIndex) {
// 统计分类数据中的正负实例数
if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) {
num++;
}
}
minNum += num;
// 取小数量的值部分
if (1.0 * num / dataIndex.size() > 0.5) {
num = dataIndex.size() - num;
}
rt += (1.0 * num / (data.length - 1));
}
//同样取出少偏差的那部分
if (1.0 * minNum / sumNum > 0.5) {
minNum = sumNum - minNum;
}
Rt = 1.0 * minNum / (data.length - 1);
alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
|