CatBoostClassifier
A model type for constructing a CatBoost classifier, based on CatBoost.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
CatBoostClassifier = @load CatBoostClassifier pkg=CatBoost
Do model = CatBoostClassifier() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in CatBoostClassifier(iterations=...).
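As a minimal sketch of the construction step, the following builds one instance with default hyper-parameters and one with iterations overridden, then changes the value in place. It assumes, as is usual for MLJ models, that the wrapper's model struct is mutable; propertynames is plain Julia and is used here only to list the hyper-parameter names the wrapper exposes.

```julia
using MLJ
using CatBoost.MLJCatBoostInterface  # MLJ wrapper exporting CatBoostClassifier

# Instance with every hyper-parameter at its default value
default_model = CatBoostClassifier()

# Override one default with a keyword argument
model = CatBoostClassifier(iterations=50)

# Assumption: the MLJ model struct is mutable, so a hyper-parameter
# can also be changed after construction
model.iterations = 100

# List the hyper-parameter names exposed by the wrapper
println(propertynames(model))
```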
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, Finite, Textual; check column scitypes with schema(X). Textual columns will be passed to catboost as text_features, Multiclass columns will be passed to catboost as cat_features, and OrderedFactor columns will be converted to integers.
y: the target, which can be any AbstractVector whose element scitype is Finite; check the scitype with scitype(y).
Train the machine with fit!(mach, rows=...).
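A sketch of the scitype checks and the rows keyword, using a small made-up dataset shaped like the one in the Examples section (8 rows, values invented purely for illustration). partition is MLJ's row-splitting utility; with shuffle=true and a seeded rng it holds out a reproducible 25% of the row indices, and fit! then trains only on the remaining rows.

```julia
using MLJ
using CatBoost.MLJCatBoostInterface

# Hypothetical toy data: 8 rows, 4 of each class
X = (
    duration = [1.5, 4.1, 5.0, 6.7, 2.2, 3.3, 5.8, 1.1],
    n_phone_calls = [4, 5, 6, 7, 3, 8, 2, 5],
    department = coerce(["acc", "ops", "acc", "ops", "acc", "ops", "acc", "ops"], Multiclass),
)
y = coerce([0, 0, 1, 1, 0, 1, 1, 0], Multiclass)

schema(X)   # expect Continuous, Count and Multiclass{2} columns
scitype(y)  # expect AbstractVector{Multiclass{2}}

# Split row indices 75/25, shuffled with a fixed seed
train, test = partition(eachindex(y), 0.75; shuffle=true, rng=123)

mach = machine(CatBoostClassifier(iterations=5), X, y)
fit!(mach, rows=train)  # train on the `train` rows only
```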
Hyper-parameters
For details on the CatBoost hyper-parameters, see the Python documentation: https://catboost.ai/en/docs/concepts/python-reference_catboostclassifier#parameters
Operations
predict(mach, Xnew): probabilistic predictions of the target given new features Xnew having the same scitype as X above.
predict_mode(mach, Xnew): returns the mode of each of the predictions above.
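The following sketch, reusing the toy data from the Examples section, shows how the two operations relate: predict returns one UnivariateFinite distribution per row, pdf (broadcast over the rows) extracts the probability of a given class, and predict_mode picks the most probable class for each row.

```julia
using MLJ
using CatBoost.MLJCatBoostInterface

X = (
    duration = [1.5, 4.1, 5.0, 6.7],
    n_phone_calls = [4, 5, 6, 7],
    department = coerce(["acc", "ops", "acc", "ops"], Multiclass),
)
y = coerce([0, 0, 1, 1], Multiclass)

mach = machine(CatBoostClassifier(iterations=5), X, y)
fit!(mach, verbosity=0)

probs = predict(mach, X)       # one UnivariateFinite distribution per row
p1 = pdf.(probs, 1)            # probability of class 1 for each row
preds = predict_mode(mach, X)  # most probable class for each row
```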
Accessor functions
feature_importances(mach): returns a vector of feature importances, as feature::Symbol => importance::Real pairs.
Fitted parameters
The fields of fitted_params(mach) are:
model: The Python CatBoostClassifier model
Report
The fields of report(mach) are:
feature_importances: Vector{Pair{Symbol, Float64}} of feature importances
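A minimal sketch of inspecting a trained machine through the accessor, fitted parameters and report described above, again on the toy data from the Examples section. The actual importance values depend on the CatBoost training run, so none are shown here.

```julia
using MLJ
using CatBoost.MLJCatBoostInterface

X = (
    duration = [1.5, 4.1, 5.0, 6.7],
    n_phone_calls = [4, 5, 6, 7],
    department = coerce(["acc", "ops", "acc", "ops"], Multiclass),
)
y = coerce([0, 0, 1, 1], Multiclass)

mach = machine(CatBoostClassifier(iterations=5), X, y)
fit!(mach, verbosity=0)

# Accessor: feature => importance pairs
fi = feature_importances(mach)
for (feature, importance) in fi
    println(feature, " => ", importance)
end

# Fitted parameters: the wrapped Python CatBoostClassifier object
py_model = fitted_params(mach).model

# Report: the feature importances stored as a report field
rpt_fi = report(mach).feature_importances
```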
Examples
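using CatBoost.MLJCatBoostInterface  # provides the MLJ-wrapped CatBoostClassifier
using MLJ
# Toy feature table: two numeric columns and one categorical column
X = (
    duration = [1.5, 4.1, 5.0, 6.7],
    n_phone_calls = [4, 5, 6, 7],
    department = coerce(["acc", "ops", "acc", "ops"], Multiclass),  # passed as cat_features
)
y = coerce([0, 0, 1, 1], Multiclass)  # Finite target, as required
model = CatBoostClassifier(iterations=5)  # override the default number of iterations
mach = machine(model, X, y)
fit!(mach)
probs = predict(mach, X)        # probabilistic predictions
preds = predict_mode(mach, X)   # most probable class per row
See also catboost and the unwrapped model type CatBoost.CatBoostClassifier.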