I wanted to circulate an alternative API for addressing the issue of model-specific optimizations in MLlib, https://issues.apache.org/jira/browse/SPARK-22126. This builds on top of the work by Weichen and the API he's proposed. The idea here is to separate the description of how multiple model instances can be fit in parallel from the actual execution of that fitting. I propose that we add a new method called `fitMultiple` (working title) which returns `Array[Callable[Model[_]]]`. `fitMultiple(dataset, paramMaps).map(callable => callable.call())` should be the same as `paramMaps.map(pm => fit(dataset, pm))`, but could include some model-specific performance optimizations. Callables returned by `fitMultiple` should be designed to be thread safe so that they can be run in parallel (e.g. by `CrossValidator`).
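As a rough illustration of that parallel use, here is a Spark-free sketch. `ToyModel`, the `regParams`-based `fitMultiple`, and `runInParallel` are hypothetical stand-ins rather than Spark APIs, and a fixed-size `java.util.concurrent` thread pool plays the role of the caller (for example `CrossValidator`). `ExecutorService.invokeAll` waits for every task and returns the futures in the same order as the input, so the results line up with the param settings.

```scala
import java.util.concurrent.{Callable, Executors}
import scala.collection.JavaConverters._

object ParallelFitSketch {
  // Hypothetical stand-in for a fitted model; not a Spark class.
  final case class ToyModel(regParam: Double)

  // Hypothetical stand-in for fitMultiple: one independent Callable per setting.
  // Each Callable only reads its own captured value, so it is safe to run concurrently.
  def fitMultiple(regParams: Array[Double]): Array[Callable[ToyModel]] =
    regParams.map { r =>
      new Callable[ToyModel] {
        override def call(): ToyModel = ToyModel(r)
      }
    }

  // A caller (e.g. CrossValidator) running the Callables on a fixed thread pool.
  // invokeAll blocks until every task finishes and keeps the input order.
  def runInParallel[M](callables: Array[Callable[M]], parallelism: Int): List[M] = {
    val pool = Executors.newFixedThreadPool(parallelism)
    try {
      pool.invokeAll(callables.toSeq.asJava).asScala.map(_.get()).toList
    } finally {
      pool.shutdown()
    }
  }
}
```

For example, `runInParallel(fitMultiple(Array(0.1, 0.3, 1.0)), 2)` returns the three models in input order, just as calling each Callable sequentially would.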
We would provide a default implementation of `fitMultiple(...)` so that even developers who write their own Estimators/Transformers would not need to implement this method unless they wanted to include performance optimizations.
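```scala
import java.util.concurrent.Callable

// Default implementation: one Callable per ParamMap, each delegating to fit(dataset, paramMap).
def fitMultiple(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[Callable[M]] = {
  paramMaps.map { paramMap =>
    new Callable[M] {
      override def call(): M = fit(dataset, paramMap)
    }
  }
}
```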
We would also update `fit(dataset, paramMaps)` to be a wrapper around `fitMultiple`, so that it would benefit from any performance optimizations but would still run the fitting synchronously in the current thread.
```scala
// Runs every Callable sequentially in the calling thread.
def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
  fitMultiple(dataset, paramMaps).map(_.call())
}
```
I thought about this API:

```scala
def fitCallables(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[Callable[M]]
```

but I am afraid that it cannot support some kinds of model-specific optimization. Let me reference an example from @jkbradley to explain this. In this example, the `Seq[M]` for `((regParam=0.1, maxIter=5), (regParam=0.1, maxIter=10))` can only be computed together in one thread, and the `Seq[M]` for `((regParam=0.3, maxIter=5), (regParam=0.3, maxIter=10))` in another thread. If we use this API, it will return 4 `Callable[M]` objects, so the logic for computing the `Seq[M]` for `((regParam=0.1, maxIter=5), (regParam=0.1, maxIter=10))` cannot be split into two callable functions if we want to do this optimization. In this example there are 4 paramMaps, but we can use at most two threads to compute their models. So I think the API should be changed as follows.
It allows `callable.call()` to return multiple models: the return type is `Map[Int, M]`, where the integer key marks the index of the paramMap for the corresponding model. Using the example above, there are 4 paramMaps but only 2 callable objects are returned, one for `((regParam=0.1, maxIter=5), (regParam=0.1, maxIter=10))` and another for `((regParam=0.3, maxIter=5), (regParam=0.3, maxIter=10))`. The default `fitCallables` and `fit` with `paramMaps` can be implemented on top of this API, and the code in `CrossValidator` can be changed to use it.
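A minimal, Spark-free sketch of this design follows, assuming `fitCallables` returns `Array[Callable[Map[Int, M]]]` (that exact signature is an assumption here). `Params`, `ToyModel`, `trainPath`, `ToyEstimator`, `GroupedEstimator` and `fitInParallel` are hypothetical stand-ins rather than Spark classes, and the toy trainer just moves a weight halfway toward `1 / (1 + regParam)` on each iteration. The sketch shows a default one-Callable-per-paramMap implementation, a model-specific override that trains each `regParam` group once up to its largest `maxIter`, and a `CrossValidator`-style caller that runs the Callables on a thread pool and reassembles the models by index.

```scala
import java.util.concurrent.{Callable, Executors}
import scala.collection.JavaConverters._

object FitCallablesSketch {
  // Hypothetical stand-ins for ParamMap and a fitted model; not Spark classes.
  final case class Params(regParam: Double, maxIter: Int)
  final case class ToyModel(regParam: Double, maxIter: Int, weight: Double)

  // Hypothetical toy trainer (not an MLlib algorithm): runs up to the largest
  // requested maxIter and records the weight at each requested iteration.
  def trainPath(regParam: Double, maxIters: Set[Int]): Map[Int, Double] = {
    var w = 0.0
    var snapshots = Map.empty[Int, Double]
    for (iter <- 1 to maxIters.max) {
      w += 0.5 * (1.0 / (1.0 + regParam) - w)
      if (maxIters.contains(iter)) snapshots += iter -> w
    }
    snapshots
  }

  // Hypothetical, Spark-free stand-in for an Estimator with the proposed API.
  trait ToyEstimator {
    def fit(params: Params): ToyModel

    // Default: one Callable per setting, returning a single-entry map keyed
    // by that setting's index in paramMaps.
    def fitCallables(paramMaps: Array[Params]): Array[Callable[Map[Int, ToyModel]]] =
      paramMaps.zipWithIndex.map { case (p, i) =>
        new Callable[Map[Int, ToyModel]] {
          override def call(): Map[Int, ToyModel] = Map(i -> fit(p))
        }
      }

    // Synchronous wrapper: runs every Callable in the current thread and
    // returns the models in paramMaps order.
    def fit(paramMaps: Array[Params]): Seq[ToyModel] = {
      val byIndex = fitCallables(paramMaps).map(_.call()).foldLeft(Map.empty[Int, ToyModel])(_ ++ _)
      paramMaps.indices.map(byIndex)
    }
  }

  // Model-specific override: settings sharing a regParam go into one Callable,
  // which trains once up to the largest maxIter in its group.
  class GroupedEstimator extends ToyEstimator {
    def fit(params: Params): ToyModel = {
      require(params.maxIter >= 1, "maxIter must be >= 1")
      ToyModel(params.regParam, params.maxIter,
        trainPath(params.regParam, Set(params.maxIter))(params.maxIter))
    }

    override def fitCallables(paramMaps: Array[Params]): Array[Callable[Map[Int, ToyModel]]] = {
      require(paramMaps.forall(_.maxIter >= 1), "maxIter must be >= 1")
      paramMaps.zipWithIndex
        .groupBy { case (p, _) => p.regParam }
        .toArray
        .sortBy { case (_, group) => group.map(_._2).min } // deterministic group order
        .map { case (regParam, group) =>
          new Callable[Map[Int, ToyModel]] {
            // All mutable state lives inside trainPath, so Callables can run concurrently.
            override def call(): Map[Int, ToyModel] = {
              val path = trainPath(regParam, group.map(_._1.maxIter).toSet)
              group.map { case (p, i) => i -> ToyModel(regParam, p.maxIter, path(p.maxIter)) }.toMap
            }
          }
        }
    }
  }

  // How a caller such as CrossValidator might run the Callables on a thread
  // pool and put the models back in paramMaps order.
  def fitInParallel(est: ToyEstimator, paramMaps: Array[Params], parallelism: Int): Seq[ToyModel] = {
    val pool = Executors.newFixedThreadPool(parallelism)
    try {
      val byIndex = pool.invokeAll(est.fitCallables(paramMaps).toSeq.asJava).asScala
        .map(_.get())
        .foldLeft(Map.empty[Int, ToyModel])(_ ++ _)
      paramMaps.indices.map(byIndex)
    } finally {
      pool.shutdown()
    }
  }
}
```

With the four paramMaps from the example, `GroupedEstimator.fitCallables` returns two Callables instead of four, and in this toy the merged models equal those from fitting each paramMap separately, since the `maxIter=10` run passes through the exact state where a `maxIter=5` run would stop. Keying results by paramMap index means the caller does not need to know how an estimator grouped its work.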
@jkbradley @MrBago @sueann