@pazzo83
Created July 6, 2021 00:22
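using MLJModelInterface, MLJBase, TSVD

# Truncated SVD transformer: computes the `nvals` largest singular values/vectors.
MLJModelInterface.@mlj_model mutable struct TSVDTransformer <: MLJModelInterface.Unsupervised
    nvals::Int = 2        # number of singular values/vectors to compute
    maxiter::Int = 1000   # iteration cap passed on to `tsvd`
end

# Fit result: singular values and right singular vectors.
struct TSVDTransformerResult
    singular_values::Vector{Float64}
    components::Matrix{Float64}
end

function MLJBase.fit(transformer::TSVDTransformer, verbosity::Int, X)
    U, s, V = tsvd(X, transformer.nvals; maxiter=transformer.maxiter)
    res = TSVDTransformerResult(s, V)
    return res, nothing, NamedTuple()  # (fitresult, cache, report)
end

# Project the data onto the right singular vectors.
function MLJBase.transform(::TSVDTransformer, result, X)
    X_transformed = X * result.components
    return X_transformed
end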
@ablaom

ablaom commented Jul 20, 2021

This looks basically good, thanks!

types

Similar transformers in MLJ (e.g., PCA from MultivariateStats) take tabular data as input, and the output is also tabular (with made-up names for the columns). See here for an example.

Now I'm guessing you may want to allow sparse data here. While you can wrap sparse arrays as tables (using Tables.table), this performs poorly if you have lots of features (columns). Later we may address this with a sparse tabular format, but this does not exist yet. I suggest we add AbstractMatrix{<:AbstractFloat} as an allowed type (as you have in your code at present), in which case a matrix should also be returned by transform.

So you will do something like this:

using ScientificTypesBase
MLJModelInterface.input_scitype(::Type{<:TSVDTransformer}) = Union{Table(Continuous),AbstractMatrix{<:AbstractFloat}}
MLJModelInterface.output_scitype(::Type{<:TSVDTransformer}) = Union{Table(Continuous),AbstractMatrix{<:AbstractFloat}}

This will guarantee your fit/transform methods receive either a Tables.jl-compatible table (which you are probably just going to call Tables.matrix on) or a bona fide matrix. However, your fitresult should encode which is the case, so you know what to return.
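One way to record the input kind is sketched below, using only Tables.jl. The helper names `unpack` and `repack` are hypothetical (they are not part of MLJ or Tables.jl); the flag returned by `unpack` is what would be stored in the fitresult.

```julia
using Tables

# Hypothetical helper: return the data as a matrix, plus a flag recording
# whether the caller passed a Tables.jl table.
function unpack(X)
    if Tables.istable(X)
        return Tables.matrix(X), true
    else
        return X, false
    end
end

# Hypothetical helper: return the same kind of container that was seen in `fit`.
repack(Xt::AbstractMatrix, was_table::Bool) = was_table ? Tables.table(Xt) : Xt

# A NamedTuple of equal-length vectors is a Tables.jl column table.
Xtable = (a = [1.0, 2.0, 3.0], b = [4.0, 5.0, 6.0])
M, was_table = unpack(Xtable)
```

In `fit`, `was_table` would be saved alongside the singular values and components; in `transform`, `repack(X * components, was_table)` would then return a table or a matrix accordingly.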

where to add the interface

This could be hosted by TSVD.jl if you happen to have close connections to that package. Otherwise, I suggest we create an interface-only package MLJTSVDInterface.jl (as we have done many times) to live at JuliaAI. I've created this already for you, just in case (currently private).

details

In your code, get rid of MLJBase. Everything you need should be in MLJModelInterface and, possibly, ScientificTypesBase.
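A minimal sketch of the gist's code with the MLJBase dependency dropped might look like the following. It assumes that extending `MLJModelInterface.fit` and `MLJModelInterface.transform` directly is sufficient, and it handles matrix input only; the `MMI` alias is just a local shorthand.

```julia
using MLJModelInterface, TSVD
const MMI = MLJModelInterface  # local shorthand, not required

MMI.@mlj_model mutable struct TSVDTransformer <: MMI.Unsupervised
    nvals::Int = 2
    maxiter::Int = 1000
end

struct TSVDTransformerResult
    singular_values::Vector{Float64}
    components::Matrix{Float64}
end

# Same logic as the original, but the methods extend MLJModelInterface
# functions instead of MLJBase ones.
function MMI.fit(transformer::TSVDTransformer, verbosity::Int, X)
    U, s, V = tsvd(X, transformer.nvals; maxiter=transformer.maxiter)
    return TSVDTransformerResult(s, V), nothing, NamedTuple()
end

MMI.transform(::TSVDTransformer, fitresult, Xnew) = Xnew * fitresult.components
```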

If you are happy to proceed with a PR to MLJTSVDInterface.jl, then I will try to make a prompt code review.

@pazzo83
Author

pazzo83 commented Jul 20, 2021

Sounds good! I will submit a PR into the new package and we can go from there!

@pazzo83
Author

pazzo83 commented Jul 20, 2021

I wasn't sure about the workflow and I think I might have corrupted the repository you created. Would you mind recreating it?

@ablaom

ablaom commented Jul 20, 2021

No worries. I'll also set up a bit more of a proper skeleton to make it easier.

@pazzo83
Author

pazzo83 commented Jul 21, 2021

Great, thank you!

@ablaom

ablaom commented Jul 21, 2021

Done: https://github.com/JuliaAI/MLJTSVDInterface.jl .

It's still private for now.
