Trait declarations
Two trait functions allow the implementer to restrict the types of data X, y and Xnew discussed above. The MLJ task interface uses these traits for data type checks, but also for model search. If they are omitted (and your model is registered), then a general user may attempt to use your model with inappropriately typed data.
The trait functions input_scitype and target_scitype take scientific data types as values. We assume here familiarity with ScientificTypes.jl (see Getting Started for the basics).
For example, to ensure that the X presented to the DecisionTreeClassifier fit method is a table whose columns all have Continuous element type (and hence AbstractFloat machine type), one declares
```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
```

or, equivalently,

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Continuous)
```

If, instead, columns were allowed to have either: (i) a mixture of Continuous and Missing values, or (ii) Count (i.e., integer) values, then the declaration would be

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Union{Continuous,Missing}, Count)
```

Similarly, to ensure the target is an AbstractVector whose elements have Finite scitype (and hence CategoricalValue machine type), we declare

```julia
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:Finite}
```

Multivariate targets
The above remarks continue to hold unchanged for the case of multivariate targets. For example, if we declare
```julia
target_scitype(SomeSupervisedModel) = Table(Continuous)
```

then this constrains the target to be any table whose columns have Continuous element scitype (i.e., AbstractFloat), while

```julia
target_scitype(SomeSupervisedModel) = Table(Continuous, Finite{2})
```

restricts to tables with continuous or binary (ordered or unordered) columns.

For predicting variable-length sequences of, say, binary values (CategoricalValues with some common size-two pool), we declare

```julia
target_scitype(SomeSupervisedModel) = AbstractVector{<:NTuple{<:Finite{2}}}
```

The trait functions controlling the form of data are summarized as follows:
| method | return type | declarable return values | fallback value | 
|---|---|---|---|
| input_scitype | Type | some scientific type | Unknown |
| target_scitype | Type | some scientific type | Unknown |
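To illustrate how such declarations behave when queried, here is a minimal sketch; MyClassifier is a hypothetical model type introduced purely for illustration:

```julia
using MLJModelInterface
const MMI = MLJModelInterface

# A hypothetical model type, for illustration only:
mutable struct MyClassifier <: MMI.Probabilistic end

# Declare the allowed input and target scitypes:
MMI.input_scitype(::Type{<:MyClassifier}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:MyClassifier}) = AbstractVector{<:MMI.Finite}

# The declared traits can then be queried on the model type:
MMI.input_scitype(MyClassifier)   # == MMI.Table(MMI.Continuous)
MMI.target_scitype(MyClassifier)  # == AbstractVector{<:MMI.Finite}
```

In a full MLJ session (with ScientificTypes loaded), these return values are what MLJ compares against `scitype(X)` and `scitype(y)` when checking data.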
Additional trait functions tell MLJ's @load macro how to find your model if it is registered, and provide other self-explanatory metadata about the model:
| method | return type | declarable return values | fallback value | 
|---|---|---|---|
| load_path | String | unrestricted | "unknown" |
| package_name | String | unrestricted | "unknown" |
| package_uuid | String | unrestricted | "unknown" |
| package_url | String | unrestricted | "unknown" |
| package_license | String | unrestricted | "unknown" |
| is_pure_julia | Bool | true or false | false |
| supports_weights | Bool | true or false | false |
| supports_class_weights | Bool | true or false | false |
| supports_training_losses | Bool | true or false | false |
| reports_feature_importances | Bool | true or false | false |
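The fallback values in the table can be observed directly; in this sketch, SomeModel is a hypothetical model type with no trait declarations of its own:

```julia
using MLJModelInterface
const MMI = MLJModelInterface

# A hypothetical model type with *no* trait declarations:
mutable struct SomeModel <: MMI.Deterministic end

# The fallbacks listed in the table above apply:
MMI.package_name(SomeModel)      # "unknown"
MMI.is_pure_julia(SomeModel)     # false
MMI.supports_weights(SomeModel)  # false
```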
Here is the complete list of trait function declarations for DecisionTreeClassifier, whose core algorithms are provided by DecisionTree.jl, but whose interface actually lives at MLJDecisionTreeInterface.jl.
```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:MMI.Finite}
MMI.load_path(::Type{<:DecisionTreeClassifier}) = "MLJDecisionTreeInterface.DecisionTreeClassifier"
MMI.package_name(::Type{<:DecisionTreeClassifier}) = "DecisionTree"
MMI.package_uuid(::Type{<:DecisionTreeClassifier}) = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
MMI.package_url(::Type{<:DecisionTreeClassifier}) = "https://github.com/bensadeghi/DecisionTree.jl"
MMI.is_pure_julia(::Type{<:DecisionTreeClassifier}) = true
```

Alternatively, these traits can be declared using the MMI.metadata_pkg and MMI.metadata_model helper functions:
```julia
MMI.metadata_pkg(
  DecisionTreeClassifier,
  name="DecisionTree",
  package_uuid="7806a523-6efd-50cb-b5f6-3fa6f1930dbb",
  package_url="https://github.com/bensadeghi/DecisionTree.jl",
  is_pure_julia=true
)

MMI.metadata_model(
  DecisionTreeClassifier,
  input_scitype=MMI.Table(MMI.Continuous),
  target_scitype=AbstractVector{<:MMI.Finite},
  load_path="MLJDecisionTreeInterface.DecisionTreeClassifier"
)
```

Important. Do not omit the load_path specification. If unsure what it should be, post an issue at MLJ.
MLJModelInterface.metadata_pkg — Function

```julia
metadata_pkg(T; args...)
```

Helper function to write the metadata for a package providing model T. Use it with broadcasting to define the metadata of the package providing a series of models.
Keywords
- package_name="unknown": package name
- package_uuid="unknown": package uuid
- package_url="unknown": package url
- is_pure_julia=missing: whether the package is pure julia
- package_license="unknown": package license
- is_wrapper=false: whether the package is a wrapper
Example
```julia
metadata_pkg.((KNNRegressor, KNNClassifier),
    package_name="NearestNeighbors",
    package_uuid="b8a86587-4115-5ab1-83bc-aa920d37bbce",
    package_url="https://github.com/KristofferC/NearestNeighbors.jl",
    is_pure_julia=true,
    package_license="MIT",
    is_wrapper=false)
```

MLJModelInterface.metadata_model — Function

```julia
metadata_model(T; args...)
```

Helper function to write the metadata for a model T.
Keywords
- input_scitype=Unknown: allowed scientific type of the input data
- target_scitype=Unknown: allowed scitype of the target (supervised)
- output_scitype=Unknown: allowed scitype of the transformed data (unsupervised)
- supports_weights=false: whether the model supports sample weights
- supports_class_weights=false: whether the model supports class weights
- load_path="unknown": where the model is (usually PackageName.ModelName)
- human_name=nothing: human name of the model
- supports_training_losses=nothing: whether the (necessarily iterative) model can report training losses
- reports_feature_importances=nothing: whether the model reports feature importances
Example
```julia
metadata_model(KNNRegressor,
    input_scitype=MLJModelInterface.Table(MLJModelInterface.Continuous),
    target_scitype=AbstractVector{MLJModelInterface.Continuous},
    supports_weights=true,
    load_path="NearestNeighbors.KNNRegressor")
```