Trait declarations
Two trait functions allow the implementer to restrict the types of data `X`, `y` and `Xnew` discussed above. The MLJ task interface uses these traits for data type checks, but also for model search. If they are omitted (and your model is registered), then a general user may attempt to use your model with inappropriately typed data.
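As a rough sketch of the kind of check these traits enable (using ScientificTypes.jl; the data here is made up for illustration):

```julia
using ScientificTypes  # provides scitype and the Table scitype constructor

# A column table whose columns are Float64 vectors:
X = (x1 = rand(5), x2 = rand(5))

# AbstractFloat machine type implies Continuous element scitype, so this
# table satisfies an input_scitype declaration of Table(Continuous):
scitype(X) <: Table(Continuous)  # true
```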
The trait functions `input_scitype` and `target_scitype` take scientific data types as values. We assume here familiarity with ScientificTypes.jl (see Getting Started for the basics).
For example, to ensure that the `X` presented to the `DecisionTreeClassifier` `fit` method is a table whose columns all have `Continuous` element scitype (and hence `AbstractFloat` machine type), one declares

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
```
or, equivalently,

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Continuous)
```
If, instead, columns were allowed to have either: (i) a mixture of `Continuous` and `Missing` values, or (ii) `Count` (i.e., integer) values, then the declaration would be

```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = Table(Union{Continuous,Missing}, Count)
```
Similarly, to ensure the target is an `AbstractVector` whose elements have `Finite` scitype (and hence `CategoricalValue` machine type) we declare

```julia
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:Finite}
```
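A target satisfying this declaration can be checked as follows (a sketch, assuming ScientificTypes.jl and CategoricalArrays.jl are loaded; the labels are made up):

```julia
using CategoricalArrays, ScientificTypes

# A categorical vector: elements have CategoricalValue machine type
y = categorical(["sick", "healthy", "sick"])

# Its element scitype is Multiclass{2}, a subtype of Finite:
scitype(y) <: AbstractVector{<:Finite}  # true
```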
Multivariate targets
The above remarks continue to hold unchanged for the case of multivariate targets. For example, if we declare

```julia
target_scitype(SomeSupervisedModel) = Table(Continuous)
```
then this constrains the target to be any table whose columns have `Continuous` element scitype (i.e., `AbstractFloat`), while

```julia
target_scitype(SomeSupervisedModel) = Table(Continuous, Finite{2})
```

restricts to tables with continuous or binary (ordered or unordered) columns.
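For instance, a table target like the following (an illustrative sketch using ScientificTypes.jl and CategoricalArrays.jl; the column names are invented) satisfies the second declaration:

```julia
using CategoricalArrays, ScientificTypes

# One continuous column and one binary (two-class categorical) column:
y = (temperature = rand(4),
     infected    = categorical([true, false, true, true]))

scitype(y) <: Table(Continuous, Finite{2})  # true
```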
For predicting variable-length sequences of, say, binary values (`CategoricalValue`s with some common size-two pool) we declare

```julia
target_scitype(SomeSupervisedModel) = AbstractVector{<:NTuple{<:Finite{2}}}
```
The trait functions controlling the form of data are summarized as follows:
| method | return type | declarable return values | fallback value |
| --- | --- | --- | --- |
| `input_scitype` | `Type` | some scientific type | `Unknown` |
| `target_scitype` | `Type` | some scientific type | `Unknown` |
Additional trait functions tell MLJ's `@load` macro how to find your model if it is registered, and provide other self-explanatory metadata about the model:
| method | return type | declarable return values | fallback value |
| --- | --- | --- | --- |
| `load_path` | `String` | unrestricted | `"unknown"` |
| `package_name` | `String` | unrestricted | `"unknown"` |
| `package_uuid` | `String` | unrestricted | `"unknown"` |
| `package_url` | `String` | unrestricted | `"unknown"` |
| `package_license` | `String` | unrestricted | `"unknown"` |
| `is_pure_julia` | `Bool` | `true` or `false` | `false` |
| `supports_weights` | `Bool` | `true` or `false` | `false` |
| `supports_class_weights` | `Bool` | `true` or `false` | `false` |
| `supports_training_losses` | `Bool` | `true` or `false` | `false` |
| `reports_feature_importances` | `Bool` | `true` or `false` | `false` |
Here is the complete list of trait function declarations for `DecisionTreeClassifier`, whose core algorithms are provided by DecisionTree.jl, but whose interface actually lives at MLJDecisionTreeInterface.jl:
```julia
MMI.input_scitype(::Type{<:DecisionTreeClassifier}) = MMI.Table(MMI.Continuous)
MMI.target_scitype(::Type{<:DecisionTreeClassifier}) = AbstractVector{<:MMI.Finite}
MMI.load_path(::Type{<:DecisionTreeClassifier}) = "MLJDecisionTreeInterface.DecisionTreeClassifier"
MMI.package_name(::Type{<:DecisionTreeClassifier}) = "DecisionTree"
MMI.package_uuid(::Type{<:DecisionTreeClassifier}) = "7806a523-6efd-50cb-b5f6-3fa6f1930dbb"
MMI.package_url(::Type{<:DecisionTreeClassifier}) = "https://github.com/bensadeghi/DecisionTree.jl"
MMI.is_pure_julia(::Type{<:DecisionTreeClassifier}) = true
```
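Once declared, these traits can be queried directly on the model type (a sketch; the return values follow from the declarations above):

```julia
# Traits defined on ::Type{<:DecisionTreeClassifier} are called on the type:
MMI.package_name(DecisionTreeClassifier)    # "DecisionTree"
MMI.is_pure_julia(DecisionTreeClassifier)   # true
MMI.target_scitype(DecisionTreeClassifier)  # the declared target scitype
```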
Alternatively, these traits can also be declared using the `MMI.metadata_pkg` and `MMI.metadata_model` helper functions as:
```julia
MMI.metadata_pkg(
    DecisionTreeClassifier,
    name="DecisionTree",
    package_uuid="7806a523-6efd-50cb-b5f6-3fa6f1930dbb",
    package_url="https://github.com/bensadeghi/DecisionTree.jl",
    is_pure_julia=true
)

MMI.metadata_model(
    DecisionTreeClassifier,
    input_scitype=MMI.Table(MMI.Continuous),
    target_scitype=AbstractVector{<:MMI.Finite},
    load_path="MLJDecisionTreeInterface.DecisionTreeClassifier"
)
```
Important. Do not omit the `load_path` specification. If unsure what it should be, post an issue at MLJ.
MLJModelInterface.metadata_pkg — Function

```julia
metadata_pkg(T; args...)
```

Helper function to write the metadata for a package providing model `T`. Use it with broadcasting to define the metadata of the package providing a series of models.
Keywords

- `package_name="unknown"`: package name
- `package_uuid="unknown"`: package uuid
- `package_url="unknown"`: package url
- `is_pure_julia=missing`: whether the package is pure Julia
- `package_license="unknown"`: package license
- `is_wrapper=false`: whether the package is a wrapper
Example

```julia
metadata_pkg.((KNNRegressor, KNNClassifier),
    package_name="NearestNeighbors",
    package_uuid="b8a86587-4115-5ab1-83bc-aa920d37bbce",
    package_url="https://github.com/KristofferC/NearestNeighbors.jl",
    is_pure_julia=true,
    package_license="MIT",
    is_wrapper=false)
```
MLJModelInterface.metadata_model — Function

```julia
metadata_model(T; args...)
```

Helper function to write the metadata for a model `T`.
Keywords

- `input_scitype=Unknown`: allowed scientific type of the input data
- `target_scitype=Unknown`: allowed scitype of the target (supervised)
- `output_scitype=Unknown`: allowed scitype of the transformed data (unsupervised)
- `supports_weights=false`: whether the model supports sample weights
- `supports_class_weights=false`: whether the model supports class weights
- `load_path="unknown"`: where the model is (usually `PackageName.ModelName`)
- `human_name=nothing`: human name of the model
- `supports_training_losses=nothing`: whether the (necessarily iterative) model can report training losses
- `reports_feature_importances=nothing`: whether the model reports feature importances
Example

```julia
metadata_model(KNNRegressor,
    input_scitype=MLJModelInterface.Table(MLJModelInterface.Continuous),
    target_scitype=AbstractVector{MLJModelInterface.Continuous},
    supports_weights=true,
    load_path="NearestNeighbors.KNNRegressor")
```