The predict method
A compulsory predict method has the form
MMI.predict(model::SomeSupervisedModel, fitresult, Xnew) -> yhat
Here Xnew will have the same form as the X passed to fit.
Note that while Xnew generally consists of multiple observations (e.g., has multiple rows in the case of a table), it is assumed, in view of the i.i.d. assumption recalled above, that calling predict(..., Xnew) is equivalent to broadcasting some method predict_one(..., x) over the individual observations x in Xnew (a method implementing the conditional probability distribution p(y | X) above).
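The equivalence between batch prediction and per-observation prediction can be illustrated with a minimal sketch in plain Julia. The linear rule, the coefficient vector coefs, and the helper predict_all are hypothetical and chosen only for illustration; they are not part of MLJModelInterface.

```julia
# Hypothetical single-observation predictor: a linear rule with fixed coefficients.
predict_one(coefs, x) = sum(coefs .* x)

# Batch prediction written as a map over the rows (observations) of a matrix.
predict_all(coefs, Xnew) = [predict_one(coefs, x) for x in eachrow(Xnew)]

coefs = [1.0, -2.0]
Xnew = [1.0 0.0;
        0.0 1.0;
        2.0 1.0]

yhat = predict_all(coefs, Xnew)

# Predicting on a single row gives the same value as the matching batch entry.
yhat_row2 = predict_all(coefs, Xnew[2:2, :])
```

Because each prediction depends only on its own row, the order and grouping of observations in Xnew do not affect the individual results.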
Prediction types for deterministic responses
In the case of Deterministic models, yhat should have the same scitype as the y passed to fit (see above). If y is a CategoricalVector (classification), then the elements of the prediction yhat must have a pool equal (==) to the pool of the target y presented in training, even if not all levels appear in the training data or in the prediction itself.
Unfortunately, code not written with the preservation of categorical levels in mind poses special problems. To help with this, MLJModelInterface provides some utilities: MLJModelInterface.int (for converting a CategoricalValue into an integer, the ordering of these integers being consistent with that of the pool) and MLJModelInterface.decoder (for constructing a callable object that decodes the integers back into CategoricalValue objects). Refer to Convenience methods below for important details.
Note that a decoder created during fit may need to be bundled with fitresult to make it available to predict during re-encoding. So, for example, if the core algorithm being wrapped by fit expects a nominal target yint of type Vector{<:Integer}, then a fit method may look something like this:
function MMI.fit(model::SomeSupervisedModel, verbosity, X, y)
    yint = MMI.int(y)
    a_target_element = y[1]                 # a CategoricalValue/String
    decode = MMI.decoder(a_target_element)  # can be called on integers
    core_fitresult = SomePackage.fit(X, yint, verbosity=verbosity)
    fitresult = (decode, core_fitresult)
    cache = nothing
    report = nothing
    return fitresult, cache, report
end
while a corresponding deterministic predict operation might look like this:
function MMI.predict(model::SomeSupervisedModel, fitresult, Xnew)
    decode, core_fitresult = fitresult
    yhat = SomePackage.predict(core_fitresult, Xnew)
    return decode.(yhat)
end
For a concrete example, refer to the code for SVMClassifier.
Of course, if you are coding a learning algorithm from scratch, rather than wrapping an existing one, these extra measures may be unnecessary.
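The encode/decode round trip used in the fit and predict sketches can be tried on its own. The following is a minimal sketch, assuming CategoricalArrays.jl and MLJBase.jl are installed and that loading MLJBase is what activates MMI.int and MMI.decoder; the integer predictions [3, 1] stand in for the output of some hypothetical core algorithm.

```julia
using CategoricalArrays
using MLJBase                      # assumed to activate the MLJModelInterface methods
import MLJModelInterface as MMI    # `import ... as` needs Julia 1.6 or later

# The pool has three levels, but only "a" and "b" occur in the training target.
y = categorical(["b", "a", "c", "a"])[1:2]

yint = MMI.int(y)              # integer codes, ordered consistently with the pool
decode = MMI.decoder(y[1])     # any element of y carries the full pool

# Pretend the core algorithm predicted the integer codes 3 and 1.
yhat = decode.([3, 1])
```

Note that the decoded predictions can include the level "c", which never appeared in y, and that each decoded element still refers to the complete pool.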
Prediction types for probabilistic responses
In the case of Probabilistic models with univariate targets, yhat must be an AbstractVector or table whose elements are distributions. In the common case of a vector (single target), this means one distribution per row of Xnew.
A distribution is some object that, at the least, implements Base.rand (i.e., is something that can be sampled). Currently, all performance measures (metrics) defined in MLJBase.jl additionally assume that a distribution is either:
An instance of some subtype of Distributions.Distribution, an abstract type defined in the Distributions.jl package; or
An instance of CategoricalDistributions.UnivariateFinite, from the CategoricalDistributions.jl package, which should be used for all probabilistic classifiers, i.e., for predictors whose target has scientific type <:AbstractVector{<:Finite}.
All such distributions implement the probability mass or density function Distributions.pdf. If your model's predictions cannot be represented by objects of this form, then you will need to implement appropriate performance measures to buy into MLJ's performance evaluation apparatus.
An implementation can avoid CategoricalDistributions.jl as a dependency by using the "dummy" constructor MLJModelInterface.UnivariateFinite, which is bound to the true one when MLJBase.jl is loaded.
For efficiency, one should not construct UnivariateFinite instances one at a time. Rather, once a probability vector, matrix, or dictionary is known, construct an instance of UnivariateFiniteVector <: AbstractArray{<:UnivariateFinite,1} to return. Both UnivariateFinite and UnivariateFiniteVector objects are constructed using the single UnivariateFinite function.
For example, suppose the target y arrives as a subsample of some ybig and is missing some classes:
ybig = categorical([:a, :b, :a, :a, :b, :a, :rare, :a, :b])
y = ybig[1:6]
Your fit method has bundled the first element of y with the fitresult to make it available to predict for purposes of tracking the complete pool of classes. Let's call this an_element = y[1]. Then, supposing the corresponding probabilities of the observed classes [:a, :b] are in an n x 2 matrix probs (where n is the number of rows of Xnew), you return
yhat = MLJModelInterface.UnivariateFinite([:a, :b], probs, pool=an_element)
This object automatically assigns zero probability to the unseen class :rare (i.e., pdf.(yhat, :rare) works and returns a zero vector). If you would like to assign :rare non-zero probabilities, simply add it to the first vector (the support) and supply a larger probs matrix.
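The handling of the unseen class can be checked with a small self-contained sketch. It assumes CategoricalArrays.jl and MLJBase.jl are installed, that loading MLJBase binds MLJModelInterface.UnivariateFinite to the real constructor, and that pdf is reachable as MLJBase.pdf; the 3 x 2 matrix probs is made-up data for three hypothetical rows of Xnew.

```julia
using CategoricalArrays
using MLJBase                      # assumed to bind the "dummy" constructor to the real one
import MLJModelInterface as MMI    # `import ... as` needs Julia 1.6 or later

ybig = categorical([:a, :b, :a, :a, :b, :a, :rare, :a, :b])
y = ybig[1:6]          # :rare is in the pool but absent from y
an_element = y[1]      # carries the complete pool

# Made-up probabilities for classes :a and :b, one row per observation.
probs = [0.9 0.1;
         0.3 0.7;
         0.5 0.5]

yhat = MMI.UnivariateFinite([:a, :b], probs, pool=an_element)

p_rare = MLJBase.pdf.(yhat, :rare)   # probabilities of the class outside the support
p_a = MLJBase.pdf.(yhat, :a)         # first column of probs
```

Here yhat is a single array-like object holding three distributions, rather than three separately constructed UnivariateFinite instances.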
In a binary classification problem, it suffices to specify a single vector of probabilities, provided you specify augment=true, as in the following example. Note carefully that these probabilities are associated with the last (second) class you specify in the constructor:
y = categorical([:TRUE, :FALSE, :FALSE, :TRUE, :TRUE])
an_element = y[1]
probs = rand(10)
yhat = MLJModelInterface.UnivariateFinite([:FALSE, :TRUE], probs, augment=true, pool=an_element)
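To confirm which class receives the supplied probabilities, the example above can be extended with a quick check. This is a sketch under the same assumptions as before (CategoricalArrays.jl and MLJBase.jl installed, MLJBase loaded so that the constructor and pdf are available); the fixed probs vector replaces rand(10) only to make the result reproducible.

```julia
using CategoricalArrays
using MLJBase                      # assumed to bind the "dummy" constructor to the real one
import MLJModelInterface as MMI    # `import ... as` needs Julia 1.6 or later

y = categorical([:TRUE, :FALSE, :FALSE, :TRUE, :TRUE])
an_element = y[1]

probs = [0.2, 0.5, 0.9]   # made-up probabilities, one per observation

yhat = MMI.UnivariateFinite([:FALSE, :TRUE], probs, augment=true, pool=an_element)

p_true = MLJBase.pdf.(yhat, :TRUE)     # the supplied probabilities (last class listed)
p_false = MLJBase.pdf.(yhat, :FALSE)   # the complements, filled in by augment=true
```

Listing the classes in the opposite order, [:TRUE, :FALSE], would attach probs to :FALSE instead, so the order of the support matters.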
The constructor has a lot of options, including passing a dictionary instead of vectors. See CategoricalDistributions.UnivariateFinite for details.
See LinearBinaryClassifier for an example of a Probabilistic classifier implementation.
Important note on binary classifiers. There is no "Binary" scitype distinct from Multiclass{2} or OrderedFactor{2}; Binary is just an alias for Union{Multiclass{2},OrderedFactor{2}}. The target_scitype of a binary classifier will generally be AbstractVector{<:Binary} and, according to the MLJ scitype convention, elements of y have type CategoricalValue, and not Bool. See BinaryClassifier for an example.
Report items returned by predict
A predict method, or other operation such as transform, can contribute to the report accessible in any machine associated with a model. See Reporting byproducts of a static transformation below for details.