Correcting Class Imbalance
Oversampling and undersampling methods
Models providing oversampling or undersampling methods, to correct for class imbalance, are listed under Class Imbalance. In particular, several popular algorithms are provided by the Imbalance.jl package, which includes detailed documentation and tutorials.
Incorporating class imbalance in supervised learning pipelines
One or more oversampling/undersampling algorithms can be fused with an MLJ classifier using the BalancedModel
wrapper. This creates a new classifier which can be treated like any other; resampling to correct for class imbalance, relevant only for training of the atomic classifier, is then carried out internally. If, for example, one applies cross-validation to the wrapped classifier (using evaluate!
, say) then this means over/undersampling is then repeated for each training fold automatically.
Refer to the MLJBalancing.jl documentation for further details.
MLJBalancing.BalancedModel
— FunctionBalancedModel(; model=nothing, balancer1=balancer_model1, balancer2=balancer_model2, ...)
BalancedModel(model; balancer1=balancer_model1, balancer2=balancer_model2, ...)
Given a classification model, and one or more balancer models that all implement the MLJModelInterface
, BalancedModel
allows constructing a sequential pipeline that wraps an arbitrary number of balancing models and a classifier together in a sequential pipeline.
Operation
- During training, data is first passed to
balancer1
and the result is passed tobalancer2
and so on, the result from the final balancer is then passed to the classifier for training. - During prediction, the balancers have no effect.
Arguments
model::Supervised
: A classification model that implements theMLJModelInterface
.balancer1::Static=...
: The first balancer model to pass the data to. This keyword argument can have any name.balancer2::Static=...
: The second balancer model to pass the data to. This keyword argument can have any name.- and so on for an arbitrary number of balancers.
Returns
- An instance of type ProbabilisticBalancedModel or DeterministicBalancedModel, depending on the prediction type of model.
Example
using MLJ
using Imbalance
# generate data
X, y = Imbalance.generate_imbalanced_data(1000, 5; class_probs=[0.2, 0.3, 0.5])
# prepare classification and balancing models
SMOTENC = @load SMOTENC pkg=Imbalance verbosity=0
TomekUndersampler = @load TomekUndersampler pkg=Imbalance verbosity=0
LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels verbosity=0
oversampler = SMOTENC(k=5, ratios=1.0, rng=42)
undersampler = TomekUndersampler(min_ratios=0.5, rng=42)
logistic_model = LogisticClassifier()
# wrap them in a BalancedModel
balanced_model = BalancedModel(model=logistic_model, balancer1=oversampler, balancer2=undersampler)
# now this behaves as a unified model that can be trained, validated, fine-tuned, etc.
mach = machine(balanced_model, X, y)
fit!(mach)