@MrBago
Last active December 7, 2017 09:37
api.md

I wanted to circulate an alternative API for addressing the issue of model-specific optimizations in MLlib, https://issues.apache.org/jira/browse/SPARK-22126. This builds on top of the work by Weichen and the API he's proposed. The idea here is to separate the description of how multiple model instances can be fit in parallel from the actual execution of that fitting. I propose that we add a new method called fitMultiple (working title) which returns Array[Callable[Model[_]]]. fitMultiple(dataset, paramMaps).map(callable => callable.call()) should produce the same result as paramMaps.map(pm => fit(dataset, pm)), but could include model-specific performance optimizations. Callables returned by fitMultiple should be designed to be thread-safe so that they can be run in parallel (e.g. by CrossValidator).

We would provide a default implementation of fitMultiple(...) so that developers who write their own Estimators/Transformers would not need to implement this method unless they want to include performance optimizations.

// Default implementation: one independent, thread-safe Callable per ParamMap.
def fitMultiple(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[Callable[M]] = {
  paramMaps.map { paramMap =>
    new Callable[M] {
      override def call(): M = fit(dataset, paramMap)
    }
  }
}

We would also update fit(dataset, paramMaps) to be a thin wrapper around fitMultiple, so that it benefits from any performance optimizations while still running the fitting synchronously in the current thread.

def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
  fitMultiple(dataset, paramMaps).map(_.call())
}
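To make the contract concrete, here is a minimal, self-contained sketch of how a CrossValidator-style caller could consume the callables on a thread pool. ToyEstimator, ToyModel, and the Double-valued params are hypothetical stand-ins for the real MLlib types, not the actual API:

```scala
import java.util.concurrent.{Callable, Executors}

// Toy stand-ins: a "model" is a Double coefficient and a ParamMap is
// reduced to a single Double parameter.
case class ToyModel(coef: Double)

trait ToyEstimator {
  def fit(data: Seq[Double], param: Double): ToyModel

  // Default fitMultiple: one independent, thread-safe Callable per param.
  def fitMultiple(data: Seq[Double], params: Array[Double]): Array[Callable[ToyModel]] =
    params.map { p =>
      new Callable[ToyModel] {
        override def call(): ToyModel = fit(data, p)
      }
    }
}

val est = new ToyEstimator {
  def fit(data: Seq[Double], param: Double): ToyModel = ToyModel(data.sum * param)
}

// A CrossValidator-style caller schedules the callables on a thread pool;
// calling them serially instead reproduces fit(dataset, paramMaps).
val pool = Executors.newFixedThreadPool(2)
val futures = est.fitMultiple(Seq(1.0, 2.0, 3.0), Array(2.0, 3.0)).map(c => pool.submit(c))
val models = futures.map(_.get()) // Array(ToyModel(12.0), ToyModel(18.0))
pool.shutdown()
```

Because each Callable is independent here, the caller is free to pick any degree of parallelism without knowing anything about the Estimator.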
@sueann

sueann commented Dec 4, 2017

So this essentially makes this part: https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala#L149-L159 customizable per Estimator, right? When I create a new Estimator, I am expected to:

  • Provide a fit() implementation that trains a model for a particular param set value
  • Optionally provide a fitCallables() implementation that trains a set of models for a set of param set values

Is that right?

@MrBago

MrBago commented Dec 5, 2017

@sueann, yep, that's the idea. fitCallable will ideally be lightweight (non-blocking) and allow the requester to schedule the actual computation on some thread pool, but that's more of a guideline than a requirement.

@WeichenXu123

WeichenXu123 commented Dec 5, 2017

I thought about this API: def fitCallables(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[Callable[M]], but I am afraid it cannot support some kinds of model-specific optimization. Let me use an example from @jkbradley to explain:

It could be possible to still use the current parallelism and still allow for model-specific optimizations. For example, suppose we are doing cross validation with a param grid of regParam = (0.1, 0.3) and maxIter = (5, 10). Let's say the cross validator knows that maxIter is optimized for the model being evaluated (e.g. via a new method in Estimator that returns such params). It would then be straightforward for the cross validator to remove maxIter from the param map that will be parallelized over and use it to create 2 arrays of paramMaps: ((regParam=0.1, maxIter=5), (regParam=0.1, maxIter=10)) and ((regParam=0.3, maxIter=5), (regParam=0.3, maxIter=10)). It could then fit these 2 in parallel with calls to def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M].

In this example, the Seq[M] for ((regParam=0.1, maxIter=5), (regParam=0.1, maxIter=10)) can only be computed in one thread, and the Seq[M] for ((regParam=0.3, maxIter=5), (regParam=0.3, maxIter=10)) in another. With the API def fitCallables(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[Callable[M]], it would have to return 4 Callable[M] objects, but the logic computing the Seq[M] for ((regParam=0.1, maxIter=5), (regParam=0.1, maxIter=10)) cannot be split into two callables if we want to do this optimization. So although there are 4 paramMaps, we can generate at most two threads to compute the models for them.

So, I think the API should be changed to the following style:

def fitCallables(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[Callable[Map[Int, M]]]

It allows callable.call() to return multiple models; the return type is Map[Int, M], where the key is the index of the paramMap that produced the corresponding model. Using the example above, there are 4 paramMaps but only 2 callable objects are returned: one for ((regParam=0.1, maxIter=5), (regParam=0.1, maxIter=10)) and another for ((regParam=0.3, maxIter=5), (regParam=0.3, maxIter=10)).

and the default fitCallables/fit with paramMaps can be implemented as follows:

def fitCallables(dataset: Dataset[_], paramMaps: Array[ParamMap]): Array[Callable[Map[Int, M]]] = {
  paramMaps.zipWithIndex.map { case (paramMap: ParamMap, index: Int) =>
    new Callable[Map[Int, M]] {
      override def call(): Map[Int, M] = Map(index -> fit(dataset, paramMap))
    }
  }
}
def fit(dataset: Dataset[_], paramMaps: Array[ParamMap]): Seq[M] = {
  fitCallables(dataset, paramMaps)
    .flatMap(_.call().toSeq)
    .sortBy(_._1)
    .map(_._2)
}
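To see the payoff, here is a hedged, self-contained sketch of an optimized fitCallables override for the regParam/maxIter example above. Pm and ToyModel are hypothetical stand-ins for ParamMap and the model type, and the "training pass" is faked with string construction:

```scala
import java.util.concurrent.Callable

// Toy stand-ins, not the real MLlib types.
case class Pm(regParam: Double, maxIter: Int)
case class ToyModel(desc: String)

def fitCallables(paramMaps: Array[Pm]): Array[Callable[Map[Int, ToyModel]]] = {
  // Group paramMap indices by regParam; each group shares one training pass.
  paramMaps.zipWithIndex.groupBy(_._1.regParam).values.map { group =>
    new Callable[Map[Int, ToyModel]] {
      override def call(): Map[Int, ToyModel] =
        // One optimization path over maxIter yields every model in the group.
        group.map { case (pm, idx) =>
          idx -> ToyModel(s"reg=${pm.regParam}, iter=${pm.maxIter}")
        }.toMap
    }
  }.toArray
}

val pms = Array(Pm(0.1, 5), Pm(0.1, 10), Pm(0.3, 5), Pm(0.3, 10))
val callables = fitCallables(pms) // 2 callables for 4 paramMaps
// Recover models in paramMap order, as in the default fit above.
val models = callables.flatMap(_.call().toSeq).sortBy(_._1).map(_._2)
```

The index keys are what let the caller stitch the grouped results back into paramMap order, regardless of how the callables were scheduled.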

If we use the API proposed above, the code in CrossValidator can be changed to:

      val trainingDataset = sparkSession.createDataFrame(training, schema).cache()
      val validationDataset = sparkSession.createDataFrame(validation, schema).cache()

      // Fit models in a Future for training in parallel
      val modelMapFutures = fitCallables(trainingDataset, paramMaps).map { callable =>
         Future[Map[Int, Model[_]]] {
            val modelMap = callable.call()
            if (collectSubModelsParam) {
               ...
            }
            modelMap
         } (executionContext)
      }

      // Unpersist training data only when all models have trained
      Future.sequence[Map[Int, Model[_]], Iterable](modelMapFutures)(implicitly, executionContext)
        .onComplete { _ => trainingDataset.unpersist() } (executionContext)

      // Evaluate models in a Future that will calculate a metric and allow the model to be cleaned up
      val foldMetricMapFutures = modelMapFutures.map { modelMapFuture =>
        modelMapFuture.map { modelMap =>
          modelMap.map { case (index: Int, model: Model[_]) =>
            val metric = eval.evaluate(model.transform(validationDataset, paramMaps(index)))
            (index, metric)
          }
        } (executionContext)
      }

      // Wait for metrics to be calculated before unpersisting validation dataset
      val foldMetrics = foldMetricMapFutures.map(ThreadUtils.awaitResult(_, Duration.Inf))
        .flatMap(_.toSeq).sortBy(_._1).map(_._2)

@jkbradley @MrBago @sueann

@thunterdb

I do not understand why we would need to batch callables together; this can be dealt with when creating the callables. For example, you can still produce 4 callables, but only 2 do the more complex work, and the last 2 just piggyback on the hard work of the first 2. Of course, in terms of scheduling one may make some mistakes, but for large numbers of configurations or large thread pools this is not going to be relevant.
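One way this piggy-backing could work, sketched with hypothetical toy types: each regParam group shares a single memoized computation via a thread-safe lazy val, so all four callables stay independent from the scheduler's point of view:

```scala
import java.util.concurrent.Callable

// Toy stand-in for a model, not the real MLlib API.
case class ToyModel(desc: String)

// One expensive fit per regParam yields models for all maxIter values along
// the optimization path; `lazy val` memoizes it thread-safely, so whichever
// callable runs first pays the cost and the others piggyback for free.
def groupCallables(regParam: Double, maxIters: Seq[Int]): Seq[Callable[ToyModel]] = {
  lazy val path: Map[Int, ToyModel] =
    maxIters.map(i => i -> ToyModel(s"reg=$regParam, iter=$i")).toMap
  maxIters.map { i =>
    new Callable[ToyModel] { override def call(): ToyModel = path(i) }
  }
}

val callables = Seq(0.1, 0.3).flatMap(reg => groupCallables(reg, Seq(5, 10)))
// 4 callables but only 2 underlying fits; safe to run in any order or thread.
val models = callables.map(_.call())
```

Because the shared work is memoized rather than handed off between callables, no callable ever blocks waiting on a sibling, so scheduling order does not matter.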

@tomasatdatabricks

What Weichen is proposing makes sense to me. I don't see how you would efficiently synchronize the Callables if they were piggy-backing on each other's results. And I think this would matter even with large thread pools, since the number of dependent tasks is not bounded and can be O(N) (if I understood Tim's comment correctly).

@MrBago

MrBago commented Dec 5, 2017

Let's move the discussion to the JIRA, https://issues.apache.org/jira/browse/SPARK-22126. I'm also updating the gist to name the method fitMultiple. Sorry if that makes this thread harder to read, but I want to present the API as fitMultiple in OSS.

@WeichenXu123

WeichenXu123 commented Dec 6, 2017

Making Callables piggy-back on each other adds complexity, and it seems to bring a risk of deadlock if they are scheduled in the wrong order. For example, with the default implementation the callables are scheduled serially; if a callable that piggy-backs on others is scheduled first, a deadlock occurs.
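The hazard arises when a piggy-backing callable blocks waiting on a sibling. A small hypothetical sketch, where a 100 ms timeout stands in for the infinite wait:

```scala
import java.util.concurrent.{Callable, CountDownLatch, TimeUnit}

// A callable that *waits* for work done by a sibling callable hangs if the
// sibling never runs first, e.g. under the serial default fit, which calls
// the callables in order.
val done = new CountDownLatch(1)

val worker = new Callable[String] {
  override def call(): String = { done.countDown(); "model" }
}
val piggybacker = new Callable[String] {
  override def call(): String = {
    // Blocks until `worker` has run; with serial, piggybacker-first
    // scheduling the wait is never satisfied (timeout stands in for the hang).
    if (done.await(100, TimeUnit.MILLISECONDS)) "model (shared)" else "deadlocked"
  }
}

// Serial execution in the "wrong" order: piggybacker first.
val results = Seq(piggybacker, worker).map(_.call())
// results == Seq("deadlocked", "model")
```

Note this only affects designs where callables hand results off to each other; memoized sharing (where whichever callable runs first does the work itself) does not block and so cannot deadlock this way.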
