Spark中机器学习源码分析
Spark的机器学习包中包含了几个重要的父类,分别介绍如下:
- Params
- Model
- Pipeline
- PipelineModel
- PipelineStage
首先介绍一下Spark中Params
这个Traits(特征,这是Scala中的基本概念,参考Scala和Spark中乱七八糟的符号)。Params
是Spark机器学习库中机器学习组件与参数相关的契约(Contract)。是所有机器学习模型都会继承并用到的类,主要是用来“管理”模型参数的。在org.apache.spark.ml
中的算法都是一个Transformer或者是一个Estimator,而这两个东西都是继承自PipelineStage的,PipelineStage是所有模型的父类,它本身集成了Params特征,因此,所有的算法也都拥有Params的特征了。Params建立的目标是为了方便获取的模型中的参数以及参数的相关解释。
Params里面定义了一系列关于不同变量类型的类,作为后面定义参数时用的类型。例如,如果你要定义一个整形(int)的参数,可以使用params.scala下面的IntParam类型
/**
* :: DeveloperApi ::
* Specialized version of `Param[Int]` for Java.
*/
@DeveloperApi
class IntParam(parent: String, name: String, doc: String, isValid: Int => Boolean)
extends Param[Int](parent, name, doc, isValid) {
def this(parent: String, name: String, doc: String) =
this(parent, name, doc, ParamValidators.alwaysTrue)
def this(parent: Identifiable, name: String, doc: String, isValid: Int => Boolean) =
this(parent.uid, name, doc, isValid)
def this(parent: Identifiable, name: String, doc: String) = this(parent.uid, name, doc)
/** Creates a param pair with the given value (for Java). */
override def w(value: Int): ParamPair[Int] = super.w(value)
override def jsonEncode(value: Int): String = {
compact(render(JInt(value)))
}
override def jsonDecode(json: String): Int = {
implicit val formats = DefaultFormats
parse(json).extract[Int]
}
}
定义方法如下,即首先初始化一个IntParam并给出名字和含义解释,然后调用Params特征下setDefault
方法即可设置该变量的默认值了:
val numBuckets = new IntParam(this, "numBuckets", "Number of buckets (quantiles, or categories) into which data points are grouped. Must be >= 2.", ParamValidators.gtEq(2))
setDefault(numBuckets -> 2)
Params特征里面只有一个属性常量,是一个懒加载的方法得到的所有属性结果,它只在第一次被需要的时候加载。它是利用Java反射原理得到所有参数,首先列出来所有的公共的没有参数的方法,并只保留返回值类型是Param[_]类型的结果,这样就能得到所有声明为Param参数的属性了。也就是说如果模型的参数定义成Params下面的类型,那么就能用这个常量获取到:
lazy val params: Array[Param[_]] = {
val methods = this.getClass.getMethods
methods.filter { m =>
Modifier.isPublic(m.getModifiers) &&
classOf[Param[_]].isAssignableFrom(m.getReturnType) &&
m.getParameterTypes.isEmpty
}.sortBy(_.getName)
.map(m => m.invoke(this).asInstanceOf[Param[_]])
}
剩下的Params中的方法主要是获取变量、解释变量、设置变量等:
explainParam(Param[_]):String - 解释某个变量
explainParams():String - 解释所有变量
isSet(Param[_]):Boolean - 某个变量是否被设置过值
isDefined(Param[_]):Boolean - 某个变量是否被定义过
hasParam(String):Boolean - 是否有某个变量
getParam(String):Param[_] - 获取某个变量
set[T](Param[T],T):this.type - 设置某个变量
set(String, Any):this.type - 通过参数名设置某个变量
set(ParamPair[_]):this.type - 通过ParamPair设置某个对象
get[T](Param[T]):Option[T] - 获取某个用户设置的变量(Option类型,可以是null)
clear(Param[T]):this.type - 清空某个变量的值
getOrDefault[T](Param[T]):T - 获取某个变量的值,没有则获取默认值
$[T](Param[T]):T - 这个是$()表达式,是getOrDefault的傀儡方法
setDefault[T](Param[T],T):this.type - 设置默认值
setDefault[T](ParamPair[_]*):this.type - 设置默认值
getDefault[T](Param[T]):Option[T] - 获取默认值
hasDefault[T](Param[T]):Boolean - 是否有默认值
copy(extra: ParamMap):Params - 复制一个当前实例的副本,包含相同的UID,以及一些额外的参数,集成了Params的子类需要实现这个方法,并定义合适的返回值。注意,一般这个用在transformer中比较多,把转换后的列加到原来的DataFrame中去
defaultCopy[T <: Params](extra: ParamMap): T - 这个方法是上面copy的默认实现方式
extractParamMap(extra: ParamMap): ParamMap - 抽取原来的参数及其用户设置的值,并和提供的额外的参数组合在一起返回
extractParamMap(): ParamMap - 抽取原来的参数及其用户设置的值
还有个比较重要的方法是copyValue经常能用到,它是把当前石丽霞默认的参数及其值传给一个其他的实例to(注意,两个实例之间必须有共享的参数),同时把to中的参数添加额外的参数extra,在Transformer和Estimator中经常能用到:
/**
* Copies param values from this instance to another instance for params shared by them.
*
* This handles default Params and explicitly set Params separately.
* Default Params are copied from and to `defaultParamMap`, and explicitly set Params are
* copied from and to `paramMap`.
* Warning: This implicitly assumes that this [[Params]] instance and the target instance
* share the same set of default Params.
*
* @param to the target instance, which should work with the same set of default Params as this
* source instance
* @param extra extra params to be copied to the target's `paramMap`
* @return the target instance with param values copied
*/
protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
// 将原来的参数paramMap和新加的参数集合连接起来
val map = paramMap ++ extra
params.foreach { param =>
// copy default Params
if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
}
// copy explicitly set Params
if (map.contains(param) && to.hasParam(param.name)) {
to.set(param.name, map(param))
}
}
to
}
我们看一个而例子就知道了
package org.apache.spark.ml.feature
import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators, Params}
import org.apache.spark.ml.util.Identifiable
/**
* Created by d00454735 on 2018/9/8.
*/
object ParamTest {
def main(args: Array[String]): Unit = {
val test = new ParamTest()
test.setA(4)
println(test.a)
println(test.explainParams())
val test2 = new ParamTest2()
println("------------------------before copy------------------------")
println(test2.explainParams())
println("------------------------after copy------------------------")
test.copyValues(test2)
println(test2.explainParams())
}
}
class ParamTest(override val uid: String) extends Params {
val a = new IntParam(this, "a", "param test int param, must >2", ParamValidators.gtEq(2))
setDefault(a -> 2)
def setA(value: Int): this.type = set(a, value)
def getA: Int = getOrDefault(a)
override def copy(extra: ParamMap): ParamTest = defaultCopy(extra)
def this() = this(Identifiable.randomUID("paramTest"))
}
class ParamTest2(override val uid: String) extends Params {
val b = new IntParam(this, "b", "another param")
setDefault(b -> 0)
val a = new IntParam(this, "a", "anther")
setDefault(a -> 0)
def this() = this(Identifiable.randomUID("paramTest2"))
override def copy(extra: ParamMap): Params = defaultCopy(extra)
}
其输出结果如下,可以看到test2已经成功将test中的值copy到了自己:
paramTest_aad4637f7033__a
a: param test int param, must >2 (default: 2, current: 4)
------------------------before copy------------------------
a: anther (default: 0)
b: another param (default: 0)
------------------------after copy------------------------
a: anther (default: 2, current: 4)
b: another param (default: 0)
Model.scala
欢迎大家关注DataLearner官方微信,接受最新的AI技术推送
