CatBoostRegressor
A model type for constructing a CatBoost regressor, based on CatBoost.jl, and implementing the MLJ model interface.
From MLJ, the type can be imported using
CatBoostRegressor = @load CatBoostRegressor pkg=CatBoost
Do model = CatBoostRegressor() to construct an instance with default hyper-parameters. Provide keyword arguments to override hyper-parameter defaults, as in CatBoostRegressor(iterations=...).
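The following is a minimal sketch of loading and configuring the model. It assumes MLJ and CatBoost.jl are both installed in the active environment. The variable name default_model is illustrative only.

```julia
using MLJ

# Load the model type through MLJ's model registry
# (CatBoost.jl must be installed in the active environment).
CatBoostRegressor = @load CatBoostRegressor pkg=CatBoost verbosity=0

# Instance with all hyper-parameters at their defaults
default_model = CatBoostRegressor()

# Override a default with a keyword argument
model = CatBoostRegressor(iterations=50)

# MLJ models are mutable structs, so a hyper-parameter
# can also be changed on an existing instance
model.iterations = 100
```

Mutating a field, as in the last line, is convenient when tuning. The same value could equally be passed at construction time.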
Training data
In MLJ or MLJBase, bind an instance model to data with
mach = machine(model, X, y)
where
- X: any table of input features (e.g., a DataFrame) whose columns each have one of the following element scitypes: Continuous, Count, Finite, Textual; check column scitypes with schema(X). Textual columns will be passed to catboost as text_features, Multiclass columns will be passed to catboost as cat_features, and OrderedFactor columns will be converted to integers.
- y: the target, which can be any AbstractVector whose element scitype is Continuous; check the scitype with scitype(y).
Train the machine with fit!(mach, rows=...).
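The sketch below checks the scitypes described above and then trains on a subset of rows. It reuses the toy table from the Examples section and assumes MLJ and CatBoost.jl are installed. Training on rows 1:3 is an arbitrary illustrative choice.

```julia
using CatBoost.MLJCatBoostInterface
using MLJ

X = (
    duration = [1.5, 4.1, 5.0, 6.7],                                # Continuous
    n_phone_calls = [4, 5, 6, 7],                                   # Count
    department = coerce(["acc", "ops", "acc", "ops"], Multiclass),  # passed as cat_features
)
y = [2.0, 4.0, 6.0, 7.0]  # Continuous target

schema(X)   # lists the element scitype of each column
scitype(y)  # AbstractVector{Continuous}

mach = machine(CatBoostRegressor(iterations=5), X, y)
fit!(mach, rows=1:3, verbosity=0)  # train on the first three rows only
```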
Hyper-parameters
For details on the CatBoost hyper-parameters, see the Python documentation: https://catboost.ai/en/docs/concepts/python-reference_catboostclassifier#parameters
Operations
predict(mach, Xnew): deterministic (point) predictions of the target given new features Xnew having the same scitype as X above.
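Here is a sketch of predicting on unseen rows. It assumes a model trained on two numeric columns. Xnew and yhat are illustrative names, and the new table keeps the same column names and scitypes as the training table.

```julia
using CatBoost.MLJCatBoostInterface
using MLJ

# Toy training data with two numeric features
X = (duration = [1.5, 4.1, 5.0, 6.7], n_phone_calls = [4, 5, 6, 7])
y = [2.0, 4.0, 6.0, 7.0]
mach = machine(CatBoostRegressor(iterations=5), X, y)
fit!(mach, verbosity=0)

# New rows: same column names and scitypes as X
Xnew = (duration = [2.0, 6.0], n_phone_calls = [4, 7])
yhat = predict(mach, Xnew)  # one prediction per row of Xnew
```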
Accessor functions
feature_importances(mach): return a vector of feature importances, in the form of feature::Symbol => importance::Real pairs.
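One way to rank features is to sort the returned pairs by their second element. The sketch below uses toy data, and fi and ranked are illustrative names.

```julia
using CatBoost.MLJCatBoostInterface
using MLJ

# Toy training data with two numeric features
X = (duration = [1.5, 4.1, 5.0, 6.7], n_phone_calls = [4, 5, 6, 7])
y = [2.0, 4.0, 6.0, 7.0]
mach = machine(CatBoostRegressor(iterations=5), X, y)
fit!(mach, verbosity=0)

# Vector of feature::Symbol => importance::Real pairs
fi = feature_importances(mach)

# Sort from most to least important and print each pair
ranked = sort(fi, by=last, rev=true)
for (feature, importance) in ranked
    println(feature, " => ", importance)
end
```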
Fitted parameters
The fields of fitted_params(mach) are:
- model: the Python CatBoostRegressor model
Report
The fields of report(mach) are:
- feature_importances: Vector{Pair{Symbol, Float64}} of feature importances
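The sketch below reads the fitted parameters and the report of a trained machine. It uses only the field names documented above, with toy data, and fp and r are illustrative names.

```julia
using CatBoost.MLJCatBoostInterface
using MLJ

# Toy training data with two numeric features
X = (duration = [1.5, 4.1, 5.0, 6.7], n_phone_calls = [4, 5, 6, 7])
y = [2.0, 4.0, 6.0, 7.0]
mach = machine(CatBoostRegressor(iterations=5), X, y)
fit!(mach, verbosity=0)

fp = fitted_params(mach)
py_model = fp.model    # the wrapped Python CatBoostRegressor object

r = report(mach)
r.feature_importances  # feature => importance pairs
```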
Examples
```julia
using CatBoost.MLJCatBoostInterface
using MLJ

# Two numeric columns plus one Multiclass column (passed to catboost as cat_features)
X = (
    duration = [1.5, 4.1, 5.0, 6.7],
    n_phone_calls = [4, 5, 6, 7],
    department = coerce(["acc", "ops", "acc", "ops"], Multiclass),
)
y = [2.0, 4.0, 6.0, 7.0]  # Continuous target

model = CatBoostRegressor(iterations=5)
mach = machine(model, X, y)
fit!(mach)
preds = predict(mach, X)
```
See also catboost and the unwrapped model type CatBoost.CatBoostRegressor.