设为首页 加入收藏

TOP

xgboost on spark
2018-12-02 17:52:29 】 浏览:354
Tags:xgboost spark
背景
项目需要预测出每一类别的概率,spark ml、mlib中自带算法只能预测出所属类别满足不了需求,因此找到此算法。

版本
spark1.6只能用XGBoost0.7之前的版本,此版本训练及预测只能使用rdd不能用df造成一定的不便,预测出的结果只有概率值,需自己与原始数据关联得到完整的记录,最大概率所属类别需自己算出。因此选择了spark2.0与XGBoost0.7。

scala代码
/**
 * train XGBoost model with the DataFrame-represented data
 *  trainingData the trainingset represented as DataFrame
 *  params Map containing the parameters to configure XGBoost
 *  round the number of iterations
 *  nWorkers the number of xgboost workers, 0 by default which means that the number of
 *                 workers equals to the partition number of trainingData RDD
 *  obj the user-defined objective function, null by default
 *  eva l the user-defined eva luation function, null by default
 *  useExternalMemory indicate whether to use external memory cache, by setting this flag as
 *                           true, the user may save the RAM cost for running XGBoost within Spark
 * missing the value represented the missing value in the dataset
 * featureCol the name of input column, "features" as default value
 *  labelCol the name of output column, "label" as default value
 */

val maxDepth = args(0).toInt
val numRound = args(1).toInt
val nworker = args(2).toInt
val paramMap = List(
  "eta" -> 0.01, //学习率
  "gamma" -> 0.1, //用于控制是否后剪枝的参数,越大越保守,一般0.1、0.2这样子。
  "lambda" -> 2, //控制模型复杂度的权重值的L2正则化项参数,参数越大,模型越不容易过拟合。
  "subsample" -> 0.8, //随机采样训练样本
  "colsample_bytree" -> 0.8, //生成树时进行的列采样
  "max_depth" -> maxDepth, //构建树的深度,越大越容易过拟合
  "min_child_weight" -> 5,
  "objective" -> "multi:softprob",  //定义学习任务及相应的学习目标
  "eva l_metric" -> "merror",
  "num_class" -> 21
).toMap

val model:XGBoostModel = XGBoost.trainWithDataFrame(vecDF, paramMap, numRound, nworker,
  useExternalMemory = true,
  featureCol = "features",
  labelCol = "label",
  missing = 0.0f)

//predict the test set
val predict:DataFrame = model.transform(vecDF)

注意partition、work、excutor的对应关系
】【打印繁体】【投稿】【收藏】 【推荐】【举报】【评论】 【关闭】 【返回顶部
上一篇Spark架构及运算逻辑 下一篇Spark调优(数据序列化和内存调优..

最新文章

热门文章

Hot 文章

Python

C 语言

C++基础

大数据基础

linux编程基础

C/C++面试题目