设为首页 加入收藏

TOP

随机森林和GBDT的学习(五)
2015-07-24 11:55:25 来源: 作者: 【 】 浏览:16
Tags:随机 森林 GBDT 学习
的值类型 * @parame beLongValue 分类是否属于此值类型 */ private String[][] removeData(String[][] srcData, String attrName, String valueType, boolean beLongValue) { String[][] desDataArray; ArrayList desData = new ArrayList<>(); // 待删除数据 ArrayList selectData = new ArrayList<>(); selectData.add(attrNames); // 数组数据转化到列表中,方便移除 for (int i = 0; i < srcData.length; i++) { desData.add(srcData[i]); } // 还是从左往右一列列的查找 for (int j = 1; j < attrNames.length; j++) { if (attrNames[j].equals(attrName)) { for (int i = 1; i < desData.size(); i++) { if (desData.get(i)[j].equals(valueType)) { // 如果匹配这个数据,则移除其他的数据 selectData.add(desData.get(i)); } } } } if (beLongValue) { desDataArray = new String[selectData.size()][]; selectData.toArray(desDataArray); } else { // 属性名称行不移除 selectData.remove(attrNames); // 如果是划分不属于此类型的数据时,进行移除 desData.removeAll(selectData); desDataArray = new String[desData.size()][]; desData.toArray(desDataArray); } return desDataArray; } /** * 构造分类回归树,并返回根节点 * @return */ public TreeNode startBuildingTree() { initAttrValue(); ArrayList remainAttr = new ArrayList<>(); // 添加属性,除了最后一个类标号属性 for (int i = 1; i < attrNames.length - 1; i++) { remainAttr.add(attrNames[i]); } TreeNode rootNode = new TreeNode(); buildDecisionTree(rootNode, "", data, remainAttr, false); setIndexAndAlpah(rootNode, 0, false); showDecisionTree(rootNode, 1); return rootNode; } /** * 显示决策树 * * @param node * 待显示的节点 * @param blankNum * 行空格符,用于显示树型结构 */ private void showDecisionTree(TreeNode node, int blankNum) { System.out.println(); for (int i = 0; i < blankNum; i++) { System.out.print(" "); } System.out.print("--"); // 显示分类的属性值 if (node.getParentAttrValue() != null && node.getParentAttrValue().length() > 0) { System.out.print(node.getParentAttrValue()); } else { System.out.print("--"); } System.out.print("--"); if (node.getDataIndex() != null && node.getDataIndex().size() > 0) { String i = node.getDataIndex().get(0); System.out.print("【" + node.getNodeIndex() + "】类别:" + data[Integer.parseInt(i)][attrNames.length - 1]); System.out.print("["); for (String index : node.getDataIndex()) { System.out.print(index + ", "); } System.out.print("]"); } else { // 递归显示子节点 System.out.print("【" + node.getNodeIndex() + ":" + node.getAttrName() + "】"); if (node.getChildAttrNode() != null) { for (TreeNode childNode : node.getChildAttrNode()) { showDecisionTree(childNode, 2 * blankNum); } } else { System.out.print("【 Child Null】"); } } } /** * 为节点设置序列号,并计算每个节点的误差率,用于后面剪枝 * * @param node * 开始的时候传入的是根节点 * @param index * 开始的索引号,从1开始 * @param ifCutNode * 是否需要剪枝 */ private void setIndexAndAlpah(TreeNode node, int index, boolean ifCutNode) { TreeNode tempNode; // 最小误差代价节点,即将被剪枝的节点 TreeNode minAlphaNode = null; double minAlpah = Integer.MAX_VALUE; Queue nodeQueue = new LinkedList(); nodeQueue.add(node); while (nodeQueue.size() > 0) { index++; // 从队列头部获取首个节点 tempNode = nodeQueue.poll(); tempNode.setNodeIndex(index); if (tempNode.getChildAttrNode() != null) { for (TreeNode childNode : tempNode.getChildAttrNode()) { nodeQueue.add(childNode); } computeAlpha(tempNode); if (tempNode.getAlpha() < minAlpah) { minAlphaNode = tempNode; minAlpah = tempNode.getAlpha(); } else if (tempNode.getAlpha() == minAlpah) { // 如果误差代价值一样,比较包含的叶子节点个数,剪枝有多叶子节点数的节点 if (tempNode.getLeafNum() > minAlphaNode.getLeafNum()) { minAlphaNode = tempNode; } } } } if (ifCutNode) { // 进行树的剪枝,让其左右孩子节点为null minAlphaNode.setChildAttrNode(null); } } /** * 为非叶子节点计算误差代价,这里的后剪枝法用的是CCP代价复杂度剪枝 * * @param node * 待计算的非叶子节点 */ private void computeAlpha(TreeNode node) { double rt = 0; double Rt = 0; double alpha = 0; // 当前节点的数据总数 int sumNum = 0; // 最少的偏差数 int minNum = 0; ArrayList dataIndex; ArrayList leafNodes = new ArrayList<>(); addLeafNode(node, leafNodes); node.setLeafNum(leafNodes.size()); for (TreeNode attrNode : leafNodes) { dataIndex = attrNode.getDataIndex(); int num = 0; sumNum += dataIndex.size(); for (String s : dataIndex) { // 统计分类数据中的正负实例数 if (data[Integer.parseInt(s)][attrNames.length - 1].equals(YES)) { num++; } } minNum += num; // 取小数量的值部分 if (1.0 * num / dataIndex.size() > 0.5) { num = dataIndex.size() - num; } rt += (1.0 * num / (data.length - 1)); } //同样取出少偏差的那部分 if (1.0 * minNum / sumNum > 0.5) { minNum = sumNum - minNum; } Rt = 1.0 * minNum / (data.length - 1); alpha = 1.0 * (Rt - rt) / (leafNodes.size() - 1);
首页 上一页 2 3 4 5 6 下一页 尾页 5/6/6
】【打印繁体】【投稿】【收藏】 【推荐】【举报】【评论】 【关闭】 【返回顶部
分享到: 
上一篇【翻译自mos文章】在alter/drop表.. 下一篇干货分享:DBA专家门诊一期:索引..

评论

帐  号: 密码: (新用户注册)
验 证 码:
表  情:
内  容:

·Announcing October (2025-12-24 15:18:16)
·MySQL有什么推荐的学 (2025-12-24 15:18:13)
·到底应该用MySQL还是 (2025-12-24 15:18:11)
·进入Linux世界大门的 (2025-12-24 14:51:47)
·Download Linux | Li (2025-12-24 14:51:44)