Trait declarations

Two trait functions allow the implementer to restrict the types of data X, y and Xnew discussed above. The MLJ task interface uses these traits for data type checks but also for model search. If they are omitted (and your model is registered) then a general user may attempt to use your model with inappropriately typed data.

The trait functions input_scitype and target_scitype take scientific data types as values. We assume here familiarity with ScientificTypes.jl (see Getting Started for the basics).

For example, to ensure that the X presented to the DecisionTreeClassifier fit method is a table whose columns all have Continuous element type (and hence AbstractFloat machine type), one declares

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)

or, equivalently,

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Continuous)

If, instead, columns were allowed to have either: (i) a mixture of Continuous and Missing values, or (ii) Count (i.e., integer) values, then the declaration would be

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Union{Continuous,Missing},Count)

Similarly, to ensure the target is an AbstractVector whose elements have Finite scitype (and hence CategoricalValue machine type) we declare

MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:Finite}
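As a rough illustration of what these declarations constrain, the sketch below checks some sample data against the scitypes above, using ScientificTypes.jl directly. The data (`X`, `Xmixed`, `y`) is made up for the example, and the plain subtype test only approximates the check MLJ performs; it is not a copy of MLJ's internals.

```julia
using ScientificTypes  # assumed installed; provides scitype, coerce, Table, Continuous, ...

# Made-up data, for illustration only.
X      = (height = [1.85, 1.67, 1.50], weight = [80.0, 65.5, 52.3])  # Float64 columns only
Xmixed = (height = [1.85, 1.67, 1.50], age = [30, 41, 25])           # one Int column
y      = coerce(["yes", "no", "yes"], Multiclass)                    # categorical target

# Data is acceptable when its scitype is a subtype of the declared scitype.
X_ok           = scitype(X)      <: Table(Continuous)
mixed_strict   = scitype(Xmixed) <: Table(Continuous)
mixed_relaxed  = scitype(Xmixed) <: Table(Union{Continuous,Missing}, Count)
y_ok           = scitype(y)      <: AbstractVector{<:Finite}
```

Here the integer column makes `Xmixed` fail the strict `Table(Continuous)` declaration but pass the relaxed one that also admits `Count` columns.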

Multivariate targets

The above remarks continue to hold unchanged for the case of multivariate targets. For example, if we declare

MMI.target_scitype(::Type{<:SomeSupervisedModel}) = Table(Continuous)

then this constrains the target to be any table whose columns have Continuous element scitype (i.e., AbstractFloat), while

MMI.target_scitype(::Type{<:SomeSupervisedModel}) = Table(Continuous, Finite{2})

restricts to tables with continuous or binary (ordered or unordered) columns.
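To make the binary restriction concrete, here is a minimal sketch that tests two made-up target tables against `Table(Continuous, Finite{2})`, again using ScientificTypes.jl directly. The column names and values are hypothetical, and the subtype test is only a stand-in for MLJ's own check.

```julia
using ScientificTypes  # assumed installed

# Hypothetical two-column target: one continuous column, one binary column.
y2 = (
    price = [10.5, 20.0, 15.25],
    sold  = coerce(["yes", "no", "yes"], OrderedFactor),
)

# Same first column, but the second column now has three classes.
y3 = (
    price = [10.5, 20.0, 15.25],
    grade = coerce(["a", "b", "c"], Multiclass),
)

binary_ok      = scitype(y2) <: Table(Continuous, Finite{2})
three_class_ok = scitype(y3) <: Table(Continuous, Finite{2})
```

The first table satisfies the declaration whether the binary column is ordered or unordered; the second does not, because its categorical column has three classes.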

For predicting variable-length sequences of, say, binary values (CategoricalValues with some common size-two pool), we declare

MMI.target_scitype(::Type{<:SomeSupervisedModel}) = AbstractVector{<:NTuple{<:Finite{2}}}

The trait functions controlling the form of data are summarized as follows:

| method | return type | declarable return values | fallback value |
|---|---|---|---|
| input_scitype | Type | some scientific type | Unknown |
| target_scitype | Type | some scientific type | Unknown |

Additional trait functions tell MLJ's @load macro how to find your model if it is registered, and provide other self-explanatory metadata about the model:

| method | return type | declarable return values | fallback value |
|---|---|---|---|
| load_path | String | unrestricted | "unknown" |
| package_name | String | unrestricted | "unknown" |
| package_uuid | String | unrestricted | "unknown" |
| package_url | String | unrestricted | "unknown" |
| package_license | String | unrestricted | "unknown" |
| is_pure_julia | Bool | true or false | false |
| supports_weights | Bool | true or false | false |
| supports_class_weights | Bool | true or false | false |
| supports_training_losses | Bool | true or false | false |
| reports_feature_importances | Bool | true or false | false |
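The fallback column can be seen in action with a small sketch. `ToyRegressor` and the package name `"ToyPackage"` are hypothetical, introduced only for illustration, and the example assumes MLJModelInterface is installed.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical model type, defined only for this illustration.
mutable struct ToyRegressor <: MMI.Deterministic end

# Nothing has been declared yet, so the fallback value applies.
default_pure = MMI.is_pure_julia(ToyRegressor)

# Declare two traits. Dispatch is on the model type, not on an instance.
MMI.is_pure_julia(::Type{<:ToyRegressor}) = true
MMI.package_name(::Type{<:ToyRegressor}) = "ToyPackage"  # made-up name

declared_pure = MMI.is_pure_julia(ToyRegressor)
declared_name = MMI.package_name(ToyRegressor)
```

Any trait left undeclared keeps its fallback value from the table above.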

Here is the complete list of trait function declarations for DecisionTreeClassifier, whose core algorithms are provided by DecisionTree.jl, but whose interface actually lives at MLJDecisionTreeInterface.jl.

MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:MMI.Finite}
MMI.load_path(::Type{<:DecisionTreeClassifier}) = "MLJDecisionTreeInterface.DecisionTreeClassifier"
MMI.package_name(::Type{<:DecisionTreeClassifier}) = "DecisionTree"
MMI.package_uuid(::Type{<:DecisionTreeClassifier}) = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
MMI.package_url(::Type{<:DecisionTreeClassifier}) = "https://github.com/bensadeghi/DecisionTree.jl"
MMI.is_pure_julia(::Type{<:DecisionTreeClassifier}) = true

Alternatively, these traits can be declared using the MMI.metadata_pkg and MMI.metadata_model helper functions:

MMI.metadata_pkg(
  DecisionTreeClassifier,
  package_name="DecisionTree",
  package_uuid="7806a523-6efd-50cb-b5f6-3fa6f1930dbb",
  package_url="https://github.com/bensadeghi/DecisionTree.jl",
  is_pure_julia=true
)

MMI.metadata_model(
  DecisionTreeClassifier,
  input_scitype=MMI.Table(MMI.Continuous),
  target_scitype=AbstractVector{<:MMI.Finite},
  load_path="MLJDecisionTreeInterface.DecisionTreeClassifier"
)

Important. Do not omit the load_path specification. If unsure what it should be, post an issue at MLJ.
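As a quick sanity check, the traits set by the helpers can be queried afterwards. The following is a minimal sketch under stated assumptions: `ToyClassifier`, `"ToyPackage"` and the URL are hypothetical placeholders, and the keyword names are those listed in the docstrings below.

```julia
import MLJModelInterface
const MMI = MLJModelInterface

# Hypothetical model type and package details, used only for illustration.
mutable struct ToyClassifier <: MMI.Probabilistic end

MMI.metadata_pkg(
    ToyClassifier,
    package_name="ToyPackage",
    package_url="https://example.com/ToyPackage.jl",
    is_pure_julia=true,
)

MMI.metadata_model(
    ToyClassifier,
    input_scitype=MMI.Table(MMI.Continuous),
    target_scitype=AbstractVector{<:MMI.Finite},
    load_path="ToyPackage.ToyClassifier",
)

# Query the traits that the helpers declared.
toy_path  = MMI.load_path(ToyClassifier)
toy_name  = MMI.package_name(ToyClassifier)
toy_pure  = MMI.is_pure_julia(ToyClassifier)
toy_input = MMI.input_scitype(ToyClassifier)
```

Querying `load_path` this way is a cheap way to confirm that the specification was not accidentally left out.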

MLJModelInterface.metadata_pkg - Function
metadata_pkg(T; args...)

Helper function to write the metadata for a package providing model T. Use it with broadcasting to define the metadata of the package providing a series of models.

Keywords

  • package_name="unknown": package name
  • package_uuid="unknown": package UUID
  • package_url="unknown": package URL
  • is_pure_julia=missing: whether the package is pure Julia
  • package_license="unknown": package license
  • is_wrapper=false: whether the package is a wrapper

Example

metadata_pkg.((KNNRegressor, KNNClassifier),
    package_name="NearestNeighbors",
    package_uuid="b8a86587-4115-5ab1-83bc-aa920d37bbce",
    package_url="https://github.com/KristofferC/NearestNeighbors.jl",
    is_pure_julia=true,
    package_license="MIT",
    is_wrapper=false)
MLJModelInterface.metadata_model - Function
metadata_model(T; args...)

Helper function to write the metadata for a model T.

Keywords

  • input_scitype=Unknown: allowed scientific type of the input data
  • target_scitype=Unknown: allowed scitype of the target (supervised)
  • output_scitype=Unknown: allowed scitype of the transformed data (unsupervised)
  • supports_weights=false: whether the model supports sample weights
  • supports_class_weights=false: whether the model supports class weights
  • load_path="unknown": where the model is (usually PackageName.ModelName)
  • human_name=nothing: human name of the model
  • supports_training_losses=nothing: whether the (necessarily iterative) model can report training losses
  • reports_feature_importances=nothing: whether the model reports feature importances

Example

metadata_model(KNNRegressor,
    input_scitype=MLJModelInterface.Table(MLJModelInterface.Continuous),
    target_scitype=AbstractVector{MLJModelInterface.Continuous},
    supports_weights=true,
    load_path="NearestNeighbors.KNNRegressor")