mlr3 Tutorial

Bernd Bischl & Marvin N. Wright

Introduction and Overview

mlr3 Book

Best entry point to mlr3

Tutorial follows the book closely

mlr3 Ecosystem and Package List

Package Installation and Loading

The mlr3verse meta-package attaches the most important packages of the mlr3 ecosystem

library(mlr3verse)

But loading packages explicitly is more transparent, and some of us prefer that

library(mlr3)
library(mlr3learners)
library(mlr3tuning)
library(mlr3pipelines)

mlr3 Design Principles

  • OO system (with R6)

  • Light on dependencies: only use packages with clear benefits, and avoid deep recursive dependency chains

  • Defensive programming and type safety

  • Flexibility and efficiency trump simplicity and convenience (the latter are still supported, conservatively)

  • Proper and consistent containers: data is tabular, including result data; we use data.table because it is fast and powerful (but note: heavy use of list columns)

  • Emphasis on computation instead of presentation

R6 Classes

R’s more recent OO system

Instances of a class are constructed with $new()

foo = Foo$new(bar = 1)

We often replace this with sugar functions (see later)

R6 - Fields, Active Bindings, Methods

Objects have mutable state, encapsulated in fields

Access and modify via $ operator

foo$bar = 2

Fields can also be active bindings (ABs), which look like normal field access or assignment but can run arbitrary code

(R6’s version of getters and setters)

So the assignment above could, e.g., validate the assigned data

Methods are functions associated with the object

They can also change the object's state

foo$change_me()
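Here is a minimal sketch of such a class. It assumes the R6 package is installed. Foo, bar and change_me() are the hypothetical placeholder names from above. bar is implemented as an active binding that validates assignments and stores the value in a private field .bar, which is also hypothetical.

```r
library(R6)

# Hypothetical class mirroring the Foo example above
Foo = R6Class("Foo",
  public = list(
    initialize = function(bar = 1) {
      self$bar = bar  # goes through the active binding, so it is validated too
    },
    # method that changes the object's state
    change_me = function() {
      private$.bar = private$.bar + 1
      invisible(self)
    }
  ),
  private = list(.bar = NULL),
  active = list(
    # active binding: reads and assignments look like field access,
    # but run this function (rhs is missing on read)
    bar = function(rhs) {
      if (missing(rhs)) return(private$.bar)
      if (!is.numeric(rhs) || length(rhs) != 1L) stop("bar must be a single number")
      private$.bar = rhs
    }
  )
)

foo = Foo$new(bar = 1)
foo$bar = 2        # validated assignment
foo$change_me()    # increments bar, now 3
# an invalid assignment is rejected by the active binding
rejected = tryCatch({ foo$bar = "a"; FALSE }, error = function(e) TRUE)
```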

R6 - Reference Semantics

Objects are environments

So we can change their state via assignments

And if we pass them to functions, a reference is passed (unlike R’s usual copy-on-modify semantics)

# does not create a copy of foo: foo2 refers to the same object
foo2 = foo

# changes the state of foo (and thus of foo2)
foo$bar = 3

# if do_something() modifies foo, the caller's object is changed too
do_something(foo)

# create an independent (deep) copy
foo3 = foo$clone(deep = TRUE)
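Here is a small runnable sketch of these reference semantics. It assumes the R6 package is installed. Counter and increment() are hypothetical names used only for illustration.

```r
library(R6)

# Minimal hypothetical class with a single public field
Counter = R6Class("Counter", public = list(n = 0))

# hypothetical helper that modifies its argument
increment = function(obj) {
  obj$n = obj$n + 1
  invisible(NULL)
}

a = Counter$new()
b = a                          # no copy: a and b point to the same object
b$n = 5                        # also visible through a
increment(a)                   # changes the caller's object: a$n is now 6
c_copy = a$clone(deep = TRUE)  # independent copy
c_copy$n = 100                 # does not affect a or b
```

Note that $clone() is shallow by default. With deep = TRUE, R6 objects stored in fields are cloned as well.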

Sugar Functions

Most mlr3 objects are created with sugar functions

Slightly shorter and more convenient

E.g., lrn() creates a learner object without calling $new()

lrn("regr.rpart")
<LearnerRegrRpart:regr.rpart>: Regression Tree
* Model: -
* Parameters: xval=0
* Packages: mlr3, rpart
* Predict Types:  [response]
* Feature Types: logical, integer, numeric, factor, ordered
* Properties: importance, missings, selected_features, weights

We combine these with dictionaries (see next), and sugar exists for more than just object creation
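The three calls below should all yield a regression tree learner. This sketch assumes mlr3 is installed. The first constructs the learner explicitly via the R6 class generator. The second is a plain dictionary lookup. The third uses the lrn() sugar, which also sets extra named arguments as hyperparameter values.

```r
library(mlr3)

# explicit construction via the R6 class generator
l1 = LearnerRegrRpart$new()
# plain dictionary lookup, no sugar
l2 = mlr_learners$get("regr.rpart")
# sugar: lookup and hyperparameter setting in one call
l3 = lrn("regr.rpart", cp = 0.1)
l3$param_set$values$cp
```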

Dictionaries

Common classes and constructors are organized in dictionaries (dicts)

This means you can list and search them

A dict associates keys with objects

mlr_tasks
<DictionaryTask> with 22 stored values
Keys: ames_housing, bike_sharing, boston_housing, breast_cancer, california_housing, german_credit, ilpd, iris, kc_housing, moneyball, mtcars, optdigits, penguins, penguins_simple, pima, ruspini, sonar, spam, titanic, usarrests, wine, zoo

Sugar for object retrieval

tsk("pima")

Dicts exist for tasks, learners, measures, pipeline ops, etc.

They can even be extended or modified locally by the user
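Here is a sketch of listing, searching and locally extending the task dict, assuming mlr3 is installed. The key mtcars_small and its 10-row task are hypothetical. The sketch registers a function that builds the task when it is retrieved.

```r
library(mlr3)

# list keys (optionally filtered by a regex) and check for a key
all_keys = mlr_tasks$keys()
penguin_keys = mlr_tasks$keys("^penguins")
has_pima = mlr_tasks$has("pima")

# register a hypothetical task locally: the function builds the task on retrieval
mlr_tasks$add("mtcars_small", function() {
  as_task_regr(mtcars[1:10, ], target = "mpg", id = "mtcars_small")
})
tsk_small = tsk("mtcars_small")

# remove the entry again to leave the dict as it was
mlr_tasks$remove("mtcars_small")
```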

Data and Basic Modeling

ML Process and Common Steps

Predefined Tasks

We ship some predefined tasks

Stored in mlr_tasks dict, access with tsk()

mlr_tasks
<DictionaryTask> with 22 stored values
Keys: ames_housing, bike_sharing, boston_housing, breast_cancer, california_housing, german_credit, ilpd, iris, kc_housing, moneyball, mtcars, optdigits, penguins, penguins_simple, pima, ruspini, sonar, spam, titanic, usarrests, wine, zoo
tsk_mtcars = tsk("mtcars")
tsk_mtcars
<TaskRegr:mtcars> (32 x 11): Motor Trends
* Target: mpg
* Properties: -
* Features (10):
  - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt

Construct Your Own Task

data("mtcars", package = "datasets")
tsk_mtcars = as_task_regr(mtcars, target = "mpg", id = "cars")
tsk_mtcars
<TaskRegr:cars> (32 x 11)
* Target: mpg
* Properties: -
* Features (10):
  - dbl (10): am, carb, cyl, disp, drat, gear, hp, qsec, vs, wt
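Here is a sketch of querying basic task metadata, assuming mlr3 is installed:

```r
library(mlr3)

data("mtcars", package = "datasets")
tsk_mtcars = as_task_regr(mtcars, target = "mpg", id = "cars")

n = tsk_mtcars$nrow               # number of observations
p = tsk_mtcars$ncol               # number of columns (features + target)
feats = tsk_mtcars$feature_names  # feature names
target = tsk_mtcars$target_names  # "mpg"
tsk_mtcars$head(3)                # first rows as a data.table
```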

Task Mutators

Select features

tsk_mtcars_small = tsk("mtcars") 
tsk_mtcars_small$select("cyl")

Filter rows

tsk_mtcars_small$filter(2:3)
tsk_mtcars_small$data()
     mpg   cyl
   <num> <num>
1:  21.0     6
2:  22.8     4

Tasks support many more operations; let’s move on for now

Help Pages

For functions use ? as usual

The help page of the mtcars task: ?mlr_tasks_mtcars

Some objects, like tasks and learners, also have a $help() method

tsk("mtcars")$help()

Learners

Stored in mlr_learners dict, access with lrn()

The dict is populated by the loaded mlr3 extension packages, e.g., mlr3learners, mlr3extralearners, mlr3proba

Dict sugar with lrn()

lrn("regr.rpart")
<LearnerRegrRpart:regr.rpart>: Regression Tree
* Model: -
* Parameters: xval=0
* Packages: mlr3, rpart
* Predict Types:  [response]
* Feature Types: logical, integer, numeric, factor, ordered
* Properties: importance, missings, selected_features, weights

Dependencies are auto-loaded but not auto-installed

Learner Metadata

  • $feature_types: what learner can handle
  • $packages: dependency list, auto-loaded
  • $properties: e.g. can handle “missings” natively
  • $predict_types: possible pred outputs, e.g. probs
  • $param_set: set of available hyperparameters

The dict can be searched for learners with certain properties
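Here is a sketch of inspecting this metadata and searching the dict. It assumes mlr3 and data.table are installed, and it uses the task_type, properties and key columns of as.data.table(mlr_learners).

```r
library(mlr3)
library(data.table)

lrn_rpart = lrn("regr.rpart")
lrn_rpart$feature_types   # feature types the learner can handle
lrn_rpart$packages        # packages loaded for training
lrn_rpart$properties      # e.g. "missings", "importance"
lrn_rpart$predict_types   # here only "response"

# search the dict: regression learners that handle missing values natively
tbl = as.data.table(mlr_learners)
regr_missings = tbl[task_type == "regr" &
  vapply(properties, function(p) "missings" %in% p, logical(1)), key]
```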

Learner Stages

Training

# train and inspect model
tsk_mtcars = tsk("mtcars")
lrn_rpart = lrn("regr.rpart")
lrn_rpart$train(tsk_mtcars)
lrn_rpart$model
n= 32 

node), split, n, deviance, yval
      * denotes terminal node

1) root 32 1126.00 20.09  
  2) cyl>=5 21  198.50 16.65  
    4) hp>=192.5 7   28.83 13.41 *
    5) hp< 192.5 14   59.87 18.26 *
  3) cyl< 5 11  203.40 26.66 *

Partition Data and Train on Subset

Randomly split task into disjoint sets

splits = partition(tsk_mtcars)
splits
$train
 [1]  1  2  5  7 10 12 13 14 15 16 17 20 21 22 23 24 25 27 29 30 32

$test
 [1]  3  4  6  8  9 11 18 19 26 28 31

$validation
integer(0)

(Later we often don’t do this manually, see resampling)
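Here is a sketch checking what partition() returns, assuming mlr3 is installed. The ratio argument sets the share of training rows. The seed only makes the split reproducible.

```r
library(mlr3)

tsk_mtcars = tsk("mtcars")
set.seed(42)  # partition() is random
splits = partition(tsk_mtcars, ratio = 0.8)

# train and test row ids are disjoint and together cover the task
overlap = intersect(splits$train, splits$test)
covered = setequal(c(splits$train, splits$test), tsk_mtcars$row_ids)
```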

Train on training set

lrn_rpart$train(tsk_mtcars, row_ids = splits$train)

Predict (on Subset)

prediction = lrn_rpart$predict(tsk_mtcars, row_ids = splits$test)

Returns a Prediction object

prediction
<PredictionRegr> for 11 observations:
 row_ids truth response
       3  22.8     23.5
       4  21.4     15.4
       6  18.1     15.4
     ---   ---      ---
      26  27.3     23.5
      28  30.4     23.5
      31  15.0     15.4

Get tabular form:

as.data.table(prediction)
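A trained learner can also predict on new data with $predict_newdata(). In this sketch the "new" data is just the first five rows of mtcars, for illustration only. The sketch assumes mlr3 and data.table are installed.

```r
library(mlr3)
library(data.table)

tsk_mtcars = tsk("mtcars")
lrn_rpart = lrn("regr.rpart")
lrn_rpart$train(tsk_mtcars)

# "new" data with the same feature columns (illustration only)
newdata = mtcars[1:5, ]
pred_new = lrn_rpart$predict_newdata(newdata)

# direct access to the stored predictions
resp = pred_new$response
dt = as.data.table(pred_new)
```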

Hyperparameters

Represented as ParamSet object:

lrn_rpart$param_set
<ParamSet(10)>
                id    class lower upper nlevels        default  value
            <char>   <char> <num> <num>   <num>         <list> <list>
 1:             cp ParamDbl     0     1     Inf           0.01 [NULL]
 2:     keep_model ParamLgl    NA    NA       2          FALSE [NULL]
 3:     maxcompete ParamInt     0   Inf     Inf              4 [NULL]
 4:       maxdepth ParamInt     1    30      30             30 [NULL]
 5:   maxsurrogate ParamInt     0   Inf     Inf              5 [NULL]
 6:      minbucket ParamInt     1   Inf     Inf <NoDefault[0]> [NULL]
 7:       minsplit ParamInt     1   Inf     Inf             20 [NULL]
 8: surrogatestyle ParamInt     0     1       2              0 [NULL]
 9:   usesurrogate ParamInt     0     2       3              2 [NULL]
10:           xval ParamInt     0   Inf     Inf             10      0

Used for overview, input validation and tuning setup

HP Classes

HP Class   HP Type
ParamDbl   Real-valued (numeric)
ParamInt   Integer
ParamFct   Categorical (factor)
ParamLgl   Logical / Boolean
ParamUty   Untyped

Getting and Setting HPs

During construction

lrn_rpart = lrn("regr.rpart", maxdepth = 1)
lrn_rpart$param_set$values
$maxdepth
[1] 1

$xval
[1] 0

Update after construction

# works, but sets only a single value
lrn_rpart$param_set$values$cp = 0.5
# better: set several values at once, keeping all others
lrn_rpart$param_set$set_values(xval = 2, cp = 0.5)
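Here is a sketch of input validation and of a common pitfall, assuming mlr3 is installed. An out-of-range value is rejected. Assigning a whole list to $values replaces all previously set values.

```r
library(mlr3)

lrn_rpart = lrn("regr.rpart", maxdepth = 1)
lrn_rpart$param_set$set_values(xval = 2, cp = 0.5)

# values are checked against the ParamSet: cp must lie in [0, 1]
res = try(lrn_rpart$param_set$set_values(cp = 2), silent = TRUE)

# caution: assigning a whole list drops maxdepth and xval
lrn_rpart$param_set$values = list(cp = 0.1)
vals = lrn_rpart$param_set$values
```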

Evaluation of Predictions

Important step of applied ML

Same code as before:

lrn_rpart = lrn("regr.rpart")
tsk_mtcars = tsk("mtcars")
splits = partition(tsk_mtcars)
lrn_rpart$train(tsk_mtcars, splits$train)
prediction = lrn_rpart$predict(tsk_mtcars, splits$test)

Measure

Compares the predictions in a Prediction object with the true labels

Returns a scalar

Stored in mlr_measures dict, access with msr()

as.data.table(msr())[c(3, 4, 6, 7, 17, 45, 55),
  .(key, label, task_type, predict_type)]
Key: <key>
           key                 label task_type predict_type
        <char>                <char>    <char>       <char>
1:          ci            Default CI      <NA>     response
2:    ci.con_z     Conservative-Z CI      <NA>     response
3:  ci.holdout            Holdout CI      <NA>     response
4:      ci.ncv          Nested CV CI      <NA>     response
5: classif.fdr  False Discovery Rate   classif     response
6:   clust.wss Within Sum of Squares     clust    partition
7:  regr.medse  Median Squared Error      regr     response

Use Measure to Score Pred Object

Mean absolute error over the n test observations

\(\mathrm{MAE}(y, \hat{y}) = \frac{1}{n} \sum_{i=1}^{n} | y_i - \hat{y}_i |\)

measure = msr("regr.mae")
measure
<MeasureRegrSimple:regr.mae>: Mean Absolute Error
* Packages: mlr3, mlr3measures
* Range: [0, Inf]
* Minimize: TRUE
* Average: macro
* Parameters: list()
* Properties: -
* Predict type: response
prediction$score(measure)
regr.mae 
   4.409 
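Here is a sketch that recomputes the MAE by hand from the truth and response stored in the Prediction object, assuming mlr3 is installed. Both numbers should agree.

```r
library(mlr3)

set.seed(1)
tsk_mtcars = tsk("mtcars")
splits = partition(tsk_mtcars)
lrn_rpart = lrn("regr.rpart")
lrn_rpart$train(tsk_mtcars, splits$train)
prediction = lrn_rpart$predict(tsk_mtcars, splits$test)

# MAE computed by hand: mean absolute difference of truth and response
mae_manual = mean(abs(prediction$truth - prediction$response))
mae_mlr3 = prediction$score(msr("regr.mae"))
```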

Classification

Very similar interface

Let’s do penguins and score with accuracy

tsk_penguins = tsk("penguins")
splits = partition(tsk_penguins)
lrn_rpart = lrn("classif.rpart")
measure = msr("classif.acc") 
lrn_rpart$train(tsk_penguins, splits$train)
lrn_rpart$predict(tsk_penguins, splits$test)$score(measure)
classif.acc 
     0.9474 

Binary Classification Tasks

sonar task is binary

tsk_sonar = tsk("sonar")
tsk_sonar
<TaskClassif:sonar> (208 x 61): Sonar: Mines vs. Rocks
* Target: Class
* Properties: twoclass
* Features (60):
  - dbl (60): V1, V10, V11, V12, V13, V14, V15, V16, V17, V18, V19, V2, V20, V21, V22, V23, V24, V25, V26, V27, V28, V29, V3, V30, V31, V32, V33, V34, V35, V36, V37, V38, V39, V4, V40, V41, V42, V43, V44, V45, V46, V47, V48, V49, V5, V50, V51, V52, V53, V54, V55, V56, V57, V58, V59, V6, V60, V7, V8, V9
tsk_sonar$class_names
[1] "M" "R"

Especially for ROC metrics, it is better to define the positive class explicitly

library(mlbench); data(Sonar)
tsk_sonar = as_task_classif(Sonar, target = "Class", positive = "M")
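The positive class can also be inspected and changed on an existing binary task. This sketch uses the german_credit task from later in this tutorial, which avoids the mlbench dependency. Its default positive class is "good".

```r
library(mlr3)

tsk_credit = tsk("german_credit")
pos_default = tsk_credit$positive   # "good" for this predefined task
# for binary tasks, assigning a class level switches the positive class
tsk_credit$positive = "bad"
```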

Classification Predictions

Predict hard labels with predict_type="response"

Or vector of probs with predict_type="prob"

lrn_rpart = lrn("classif.rpart", predict_type = "prob")
lrn_rpart$train(tsk_penguins, splits$train)
prediction = lrn_rpart$predict(tsk_penguins, splits$test)
prediction
<PredictionClassif> for 114 observations:
 row_ids     truth  response prob.Adelie prob.Chinstrap prob.Gentoo
       2    Adelie    Adelie      0.9895        0.01053           0
       3    Adelie    Adelie      0.9895        0.01053           0
       4    Adelie    Adelie      0.9895        0.01053           0
     ---       ---       ---         ---            ---         ---
     325 Chinstrap Chinstrap      0.0000        1.00000           0
     328 Chinstrap Chinstrap      0.0000        1.00000           0
     329 Chinstrap Chinstrap      0.0000        1.00000           0

Classification Measures

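To evaluate "response" predictions, you need measures with predict_type = "response"

To evaluate probability predictions, you need measures with predict_type = "prob"; probability predictions also contain hard labels, so response measures such as accuracy work on them too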

measures = msrs(c("classif.mbrier", "classif.logloss", "classif.acc"))
prediction$score(measures)
 classif.mbrier classif.logloss     classif.acc 
         0.1059          1.8262          0.9474 

Confusion Matrix

Table of predicted hard labels (rows) vs. true labels (columns); here 3x3 for the three penguin species

prediction$confusion
           truth
response    Adelie Chinstrap Gentoo
  Adelie        49         0      3
  Chinstrap      2        19      0
  Gentoo         1         0     40
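Accuracy can be read off the confusion matrix, because correct predictions sit on the diagonal. Here is a sketch, assuming mlr3 is installed, that recomputes it and compares the result with classif.acc:

```r
library(mlr3)

set.seed(1)
tsk_penguins = tsk("penguins")
splits = partition(tsk_penguins)
lrn_rpart = lrn("classif.rpart", predict_type = "prob")
lrn_rpart$train(tsk_penguins, splits$train)
prediction = lrn_rpart$predict(tsk_penguins, splits$test)

conf = prediction$confusion               # rows: response, columns: truth
acc_manual = sum(diag(conf)) / sum(conf)  # share of correct predictions
acc_mlr3 = prediction$score(msr("classif.acc"))
```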

Thresholding

By default, the predicted hard label is the class with the maximal posterior probability

For binary tasks, this corresponds to a threshold of 50% for the positive class

task_credit = tsk("german_credit")
split = partition(task_credit)

lrn_rf = lrn("classif.ranger", predict_type = "prob")
lrn_rf$train(task_credit, split$train)
prediction = lrn_rf$predict(task_credit, split$test)
prediction$score(msr("classif.acc"))
classif.acc 
     0.7667 
prediction$confusion
        truth
response good bad
    good  210  65
    bad    12  43

Thresholding

For imbalanced tasks or unequal misclassification costs, we can change the threshold (for binary tasks, the cutoff for the positive class)

prediction$set_threshold(0.85)
prediction$confusion
        truth
response good bad
    good   72   5
    bad   150 103

ROC Metrics

prediction$set_threshold(0.2)
prediction$score(msrs(c("classif.tpr", "classif.ppv", "classif.fbeta")))
  classif.tpr   classif.ppv classif.fbeta 
       0.9955        0.6717        0.8022 
prediction$set_threshold(0.5)
prediction$score(msrs(c("classif.tpr", "classif.ppv", "classif.fbeta")))
  classif.tpr   classif.ppv classif.fbeta 
       0.9459        0.7636        0.8451 
prediction$set_threshold(0.8)
prediction$score(msrs(c("classif.tpr", "classif.ppv", "classif.fbeta")))
  classif.tpr   classif.ppv classif.fbeta 
       0.4369        0.9238        0.5933 
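Here is a sketch of how TPR and PPV follow from the confusion matrix at each threshold, assuming mlr3 is installed. It uses classif.rpart instead of classif.ranger so that only mlr3 is needed. The thresholds are arbitrary choices.

```r
library(mlr3)

set.seed(1)
task_credit = tsk("german_credit")   # positive class: "good"
split = partition(task_credit)
lrn_tree = lrn("classif.rpart", predict_type = "prob")
lrn_tree$train(task_credit, split$train)
prediction = lrn_tree$predict(task_credit, split$test)

thresholds = c(0.3, 0.5, 0.7)
res = t(sapply(thresholds, function(th) {
  prediction$set_threshold(th)         # modifies the Prediction in place
  conf = prediction$confusion          # rows: response, columns: truth
  tp = conf["good", "good"]
  c(threshold = th,
    tpr_manual = tp / sum(conf[, "good"]),  # TP / (TP + FN)
    ppv_manual = tp / sum(conf["good", ]),  # TP / (TP + FP), NaN if nothing is predicted "good"
    prediction$score(msrs(c("classif.tpr", "classif.ppv"))))
}))
res
```

As the threshold increases, fewer observations are predicted as "good". TPR can therefore only stay equal or drop, while PPV tends to rise, as in the output above.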

Threshold vs. Performance

# cost matrix as given on the UCI page of the german credit data set
# https://archive.ics.uci.edu/ml/datasets/statlog+(german+credit+data)
costs = matrix(c(0, 1, 5, 0), nrow = 2)
dimnames(costs) = list(pred = task_credit$class_names, truth = task_credit$class_names)
print(costs)
      truth
pred   good bad
  good    0   5
  bad     1   0
mm = msr("classif.costs", id = "cost", costs = costs, normalize = FALSE)
autoplot(prediction, measure = mm, type = "threshold")
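The cost measure can be reproduced from the confusion matrix, and a simple grid search over thresholds picks the cheapest one. This sketch assumes mlr3 is installed and uses classif.rpart instead of classif.ranger. The threshold grid is an arbitrary choice.

```r
library(mlr3)

set.seed(1)
task_credit = tsk("german_credit")
split = partition(task_credit)
lrn_tree = lrn("classif.rpart", predict_type = "prob")
lrn_tree$train(task_credit, split$train)
prediction = lrn_tree$predict(task_credit, split$test)

# cost matrix as above: rows = predicted, columns = true class
costs = matrix(c(0, 1, 5, 0), nrow = 2)
dimnames(costs) = list(pred = task_credit$class_names, truth = task_credit$class_names)
mm = msr("classif.costs", id = "cost", costs = costs, normalize = FALSE)

thresholds = seq(0.1, 0.9, by = 0.1)
cost_tbl = t(sapply(thresholds, function(th) {
  prediction$set_threshold(th)
  conf = prediction$confusion   # rows: response, columns: truth
  # total cost: weight each cell of the confusion matrix by its cost
  manual = sum(conf * costs[rownames(conf), colnames(conf)])
  c(threshold = th, manual = manual, prediction$score(mm))
}))
best_threshold = cost_tbl[which.min(cost_tbl[, "cost"]), "threshold"]
```

Choosing the threshold on the same test data used for evaluation gives an optimistic cost estimate. In practice, tune it on separate validation data.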

ROC Curve

autoplot(prediction, type = "roc")

Evaluation and Benchmarking

Resampling

Resampling Strategy

Define how to repeatedly split data into train and test sets

Stored in mlr_resamplings dict, access with rsmp()

as.data.table(rsmp())[, .(key, label)]
Key: <key>
                   key                         label
                <char>                        <char>
 1:          bootstrap                     Bootstrap
 2:             custom                 Custom Splits
 3:          custom_cv Custom Split Cross-Validation
 4:                 cv              Cross-Validation
 5:            holdout                       Holdout
 6:           insample           Insample Resampling
 7:                loo                 Leave-One-Out
 8:          nested_cv                     Nested CV
 9: paired_subsampling            Paired Subsampling
10:        repeated_cv     Repeated Cross-Validation
11:        subsampling                   Subsampling

Cross Validation

Resampling Strategy

Holdout method

holdout = rsmp("holdout", ratio = 0.8)

3-fold CV

cv3 = rsmp("cv", folds = 3)

Subsampling with 3 repeats and 9/10 ratio

ss390 = rsmp("subsampling", repeats = 3, ratio = 0.9)

5-fold CV with 2 repeats

rcv25 = rsmp("repeated_cv", repeats = 2, folds = 5)
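Here is a sketch of instantiating a resampling strategy and inspecting its splits, assuming mlr3 is installed. For CV, the test sets are disjoint and together cover all rows.

```r
library(mlr3)

tsk_penguins = tsk("penguins")
cv3 = rsmp("cv", folds = 3)
cv3$instantiate(tsk_penguins)   # fix the splits for this task

n_iters = cv3$iters
train1 = cv3$train_set(1)
test1 = cv3$test_set(1)
all_test = unlist(lapply(seq_len(cv3$iters), cv3$test_set))

# repeated CV: repeats * folds iterations
rcv25 = rsmp("repeated_cv", repeats = 2, folds = 5)
```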

Resampling Experiments

resample() repeatedly fits the learner on each training set

Then predicts on the corresponding test set

Predictions are stored in a ResampleResult

tsk_penguins = tsk("penguins")
lrn_rpart = lrn("classif.rpart")

rr = resample(tsk_penguins, lrn_rpart, cv3, store_models = TRUE)
rr
<ResampleResult> with 3 resampling iterations
  task_id    learner_id resampling_id iteration     prediction_test warnings errors
 penguins classif.rpart            cv         1 <PredictionClassif>        0      0
 penguins classif.rpart            cv         2 <PredictionClassif>        0      0
 penguins classif.rpart            cv         3 <PredictionClassif>        0      0

Scoring comes later, as a separate operation

Score and Aggregate

Score and Aggregate Resample Result

Apply measure to each pred object via $score()

rr$score(msr("classif.ce"))
    task_id    learner_id resampling_id iteration classif.ce
     <char>        <char>        <char>     <int>      <num>
1: penguins classif.rpart            cv         1    0.06957
2: penguins classif.rpart            cv         2    0.06087
3: penguins classif.rpart            cv         3    0.06140
Hidden columns: task, learner, resampling, prediction_test

$aggregate() combines these scores into a single scalar (by default the macro average, i.e., the mean over iterations)

rr$aggregate(msr("classif.ce"))
classif.ce 
   0.06395 
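Here is a sketch, assuming mlr3 is installed, checking that the default aggregation of classif.ce is the plain mean of the per-iteration scores (the macro average):

```r
library(mlr3)

set.seed(1)
rr = resample(tsk("penguins"), lrn("classif.rpart"), rsmp("cv", folds = 3))

scores = rr$score(msr("classif.ce"))$classif.ce   # one value per fold
agg = rr$aggregate(msr("classif.ce"))              # macro average by default
```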

ResampleResult Objects

Prediction object for each resampling iteration

rr$predictions()[[1]]
<PredictionClassif> for 115 observations:
 row_ids     truth  response
       2    Adelie    Adelie
       3    Adelie    Adelie
       8    Adelie    Adelie
     ---       ---       ---
     329 Chinstrap Chinstrap
     332 Chinstrap Chinstrap
     339 Chinstrap Chinstrap

ResampleResult Objects

Can also be used for model inspection

rr$learners[[1]]$model
n= 229 

node), split, n, loss, yval, (yprob)
      * denotes terminal node

1) root 229 127 Adelie (0.44541 0.20524 0.34934)  
  2) bill_length< 42.35 98   2 Adelie (0.97959 0.01020 0.01020) *
  3) bill_length>=42.35 131  52 Gentoo (0.04580 0.35115 0.60305)  
    6) island=Dream,Torgersen 51   5 Chinstrap (0.09804 0.90196 0.00000) *
    7) island=Biscoe 80   1 Gentoo (0.01250 0.00000 0.98750) *

Summary

task = tsk("mtcars")
learner = lrn("regr.rpart")
resampling = rsmp("cv", folds = 3)

rr = resample(task, learner, resampling, store_models = TRUE)

rr$aggregate(msr("regr.mse"))
rr$score(msr("regr.mse"))

rr$predictions()[[1]]
rr$learners[[1]]$model

Benchmarking

Compare multiple learners on a single task

Or multiple learners on multiple tasks

tasks = tsks(c("german_credit", "sonar"))
learners = lrns(c("classif.rpart", "classif.ranger",
  "classif.featureless"), predict_type = "prob")
rsmp_cv2 = rsmp("cv", folds = 2)

design = benchmark_grid(tasks, learners, rsmp_cv2)
design
            task             learner resampling
          <char>              <char>     <char>
1: german_credit       classif.rpart         cv
2: german_credit      classif.ranger         cv
3: german_credit classif.featureless         cv
4:         sonar       classif.rpart         cv
5:         sonar      classif.ranger         cv
6:         sonar classif.featureless         cv

The design data.table can also be constructed manually for full control

Benchmarking

Runs resample() on each row of the design

Collects the results in a BenchmarkResult object

bmr = benchmark(design)
bmr
<BenchmarkResult> of 12 rows with 6 resampling runs
 nr       task_id          learner_id resampling_id iters warnings errors
  1 german_credit       classif.rpart            cv     2        0      0
  2 german_credit      classif.ranger            cv     2        0      0
  3 german_credit classif.featureless            cv     2        0      0
  4         sonar       classif.rpart            cv     2        0      0
  5         sonar      classif.ranger            cv     2        0      0
  6         sonar classif.featureless            cv     2        0      0

Score Benchmark Result

$score() for each test set, i.e., each resampling iteration (classif.ce by default)

bmr$score()
       nr       task_id          learner_id resampling_id iteration     prediction_test classif.ce
    <int>        <char>              <char>        <char>     <int>              <list>      <num>
 1:     1 german_credit       classif.rpart            cv         1 <PredictionClassif>     0.2780
 2:     1 german_credit       classif.rpart            cv         2 <PredictionClassif>     0.2800
 3:     2 german_credit      classif.ranger            cv         1 <PredictionClassif>     0.2420
 4:     2 german_credit      classif.ranger            cv         2 <PredictionClassif>     0.2260
 5:     3 german_credit classif.featureless            cv         1 <PredictionClassif>     0.3200
 6:     3 german_credit classif.featureless            cv         2 <PredictionClassif>     0.2800
 7:     4         sonar       classif.rpart            cv         1 <PredictionClassif>     0.2788
 8:     4         sonar       classif.rpart            cv         2 <PredictionClassif>     0.2308
 9:     5         sonar      classif.ranger            cv         1 <PredictionClassif>     0.1827
10:     5         sonar      classif.ranger            cv         2 <PredictionClassif>     0.1731
11:     6         sonar classif.featureless            cv         1 <PredictionClassif>     0.4712
12:     6         sonar classif.featureless            cv         2 <PredictionClassif>     0.4615
Hidden columns: uhash, task, learner, resampling

Aggregate Benchmark Result

$aggregate() for each row of the design, aggregated over resampling iterations

bmr$aggregate()
      nr       task_id          learner_id resampling_id iters classif.ce
   <int>        <char>              <char>        <char> <int>      <num>
1:     1 german_credit       classif.rpart            cv     2     0.2790
2:     2 german_credit      classif.ranger            cv     2     0.2340
3:     3 german_credit classif.featureless            cv     2     0.3000
4:     4         sonar       classif.rpart            cv     2     0.2548
5:     5         sonar      classif.ranger            cv     2     0.1779
6:     6         sonar classif.featureless            cv     2     0.4663
Hidden columns: resample_result

BenchmarkResult Objects

Collection of multiple ResampleResult objects

bmr$resample_result(1)
<ResampleResult> with 2 resampling iterations
       task_id    learner_id resampling_id iteration     prediction_test warnings errors
 german_credit classif.rpart            cv         1 <PredictionClassif>        0      0
 german_credit classif.rpart            cv         2 <PredictionClassif>        0      0

BenchmarkResult as Table

Convert to a data.table

as.data.table(bmr)
                                   uhash                        task                                         learner     resampling iteration          prediction
                                  <char>                      <list>                                          <list>         <list>     <int>              <list>
 1: d3e4712f-b956-432f-b36e-3673fe1bf9e4 <TaskClassif:german_credit>             <LearnerClassifRpart:classif.rpart> <ResamplingCV>         1 <PredictionClassif>
 2: d3e4712f-b956-432f-b36e-3673fe1bf9e4 <TaskClassif:german_credit>             <LearnerClassifRpart:classif.rpart> <ResamplingCV>         2 <PredictionClassif>
 3: 29a3313f-4cd4-4fe5-8436-f93c9487a0f7 <TaskClassif:german_credit>           <LearnerClassifRanger:classif.ranger> <ResamplingCV>         1 <PredictionClassif>
 4: 29a3313f-4cd4-4fe5-8436-f93c9487a0f7 <TaskClassif:german_credit>           <LearnerClassifRanger:classif.ranger> <ResamplingCV>         2 <PredictionClassif>
 5: 9ce70e43-be44-418c-ad99-d584b91adac7 <TaskClassif:german_credit> <LearnerClassifFeatureless:classif.featureless> <ResamplingCV>         1 <PredictionClassif>
 6: 9ce70e43-be44-418c-ad99-d584b91adac7 <TaskClassif:german_credit> <LearnerClassifFeatureless:classif.featureless> <ResamplingCV>         2 <PredictionClassif>
 7: 51bacc40-89dc-4694-95c9-66ad1bab4d6a         <TaskClassif:sonar>             <LearnerClassifRpart:classif.rpart> <ResamplingCV>         1 <PredictionClassif>
 8: 51bacc40-89dc-4694-95c9-66ad1bab4d6a         <TaskClassif:sonar>             <LearnerClassifRpart:classif.rpart> <ResamplingCV>         2 <PredictionClassif>
 9: 58a11a8c-4253-448c-b75c-2bc7bc20d108         <TaskClassif:sonar>           <LearnerClassifRanger:classif.ranger> <ResamplingCV>         1 <PredictionClassif>
10: 58a11a8c-4253-448c-b75c-2bc7bc20d108         <TaskClassif:sonar>           <LearnerClassifRanger:classif.ranger> <ResamplingCV>         2 <PredictionClassif>
11: d1292625-e45c-46a0-a2a3-1098b70fe8ae         <TaskClassif:sonar> <LearnerClassifFeatureless:classif.featureless> <ResamplingCV>         1 <PredictionClassif>
12: d1292625-e45c-46a0-a2a3-1098b70fe8ae         <TaskClassif:sonar> <LearnerClassifFeatureless:classif.featureless> <ResamplingCV>         2 <PredictionClassif>

Visualize Benchmark Results

autoplot(bmr, measure = msr("classif.acc"))

Hyperparameter Optimization

Hyperparameter Optimization Loop

Learner and Search Space

Decide which HPs to tune and in which range

as.data.table(lrn("classif.svm")$param_set)[1:12,
  .(id, class, lower, upper, nlevels)]
                 id    class lower upper nlevels
             <char>   <char> <num> <num>   <num>
 1:       cachesize ParamDbl  -Inf   Inf     Inf
 2:   class.weights ParamUty    NA    NA     Inf
 3:           coef0 ParamDbl  -Inf   Inf     Inf
 4:            cost ParamDbl     0   Inf     Inf
 5:           cross ParamInt     0   Inf     Inf
 6: decision.values ParamLgl    NA    NA       2
 7:          degree ParamInt     1   Inf     Inf
 8:         epsilon ParamDbl     0   Inf     Inf
 9:          fitted ParamLgl    NA    NA       2
10:           gamma ParamDbl     0   Inf     Inf
11:          kernel ParamFct    NA    NA       4
12:              nu ParamDbl  -Inf   Inf     Inf

TuneToken

to_tune() flags an HP for later tuning and sets its range

learner = lrn("classif.svm",
  type  = "C-classification",
  kernel = "radial",
  cost  = to_tune(1e-1, 1e5),
  gamma = to_tune(1e-1, 1)
)
learner
<LearnerClassifSVM:classif.svm>: Support Vector Machine
* Model: -
* Parameters: cost=<RangeTuneToken>, gamma=<RangeTuneToken>, kernel=radial, type=C-classification
* Packages: mlr3, mlr3learners, e1071
* Predict Types:  [response], prob
* Feature Types: logical, integer, numeric
* Properties: multiclass, twoclass

Terminator

Terminator              Construction and defaults
Clock Time              trm("clock_time")
Number of Evaluations   trm("evals", n_evals = 100, k = 0)
Performance Level       trm("perf_reached", level = 0.1)
Run Time                trm("run_time", secs = 30)
Stagnation              trm("stagnation", iters = 10, threshold = 0)

Terminator

trm("combo") allows combining multiple terminators

trm("none") is used by tuners that terminate on their own

Terminator   Construction and defaults
Combo        trm("combo", any = TRUE)
None         trm("none")
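Here is a sketch combining two terminators. It assumes bbotk, which mlr3tuning loads, is installed. The budgets of 50 evaluations and 60 seconds are arbitrary. With any = TRUE, tuning stops as soon as one of the terminators fires.

```r
library(bbotk)  # provides trm() and the terminators

# stop after 50 evaluations OR after 60 seconds, whichever comes first
trm_combo = trm("combo",
  terminators = list(
    trm("evals", n_evals = 50),
    trm("run_time", secs = 60)
  ),
  any = TRUE
)
```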

Tuning Instance

Sets up the tuning objective and collects all evaluations in an archive

tsk_sonar = tsk("sonar")
instance = ti(
  task = tsk_sonar,
  learner = learner,
  resampling = rsmp("cv", folds = 3),
  measures = msr("classif.ce"),
  terminator = trm("none")
)
instance
<TuningInstanceBatchSingleCrit>
* State:  Not optimized
* Objective: <ObjectiveTuningBatch:classif.svm_on_sonar>
* Search Space:
       id    class lower  upper nlevels
   <char>   <char> <num>  <num>   <num>
1:   cost ParamDbl   0.1 100000     Inf
2:  gamma ParamDbl   0.1      1     Inf
* Terminator: <TerminatorNone>

Tuners as Black Box Optimizers

Stored in mlr_tuners dict, access with tnr()

Tuner                      Construction          Package
Random Search              tnr("random_search")  mlr3tuning
Grid Search                tnr("grid_search")    mlr3tuning
CMA-ES                     tnr("cmaes")          adagio
Gen. Simulated Annealing   tnr("gensa")          GenSA
Nonlinear Optimization     tnr("nloptr")         nloptr
Iterated Racing            tnr("irace")          irace
Hyperband                  tnr("hyperband")      mlr3hyperband
Bayesian Optimization      tnr("mbo")            mlr3mbo

Control Parameters

Can be set via $param_set, as for learners, or passed to tnr() on construction

tuner = tnr("grid_search", resolution = 5, batch_size = 10)
tuner$param_set
<ParamSet(3)>
                  id    class lower upper nlevels        default  value
              <char>   <char> <num> <num>   <num>         <list> <list>
1:        batch_size ParamInt     1   Inf     Inf <NoDefault[0]>     10
2:        resolution ParamInt     1   Inf     Inf <NoDefault[0]>      5
3: param_resolutions ParamUty    NA    NA     Inf <NoDefault[0]> [NULL]

Triggering Tuning

Returns the best hyperparameter configuration (HPC) found and its performance

tuner$optimize(instance)
    cost gamma learner_param_vals  x_domain classif.ce
   <num> <num>             <list>    <list>      <num>
1: 75000   0.1          <list[4]> <list[2]>     0.3266

Result also stored in instance$result

instance$result
    cost gamma learner_param_vals  x_domain classif.ce
   <num> <num>             <list>    <list>      <num>
1: 75000   0.1          <list[4]> <list[2]>     0.3266

$learner_param_vals: optimal HPC plus the manually set HPs (here kernel and type)

$x_domain: HPC after applying transformations, see trafos

Analyze Results

Archive stores all evaluated HPCs

as.data.table(instance$archive)
        cost gamma classif.ce runtime_learners           timestamp warnings errors  x_domain batch_nr  resample_result
       <num> <num>      <num>            <num>              <POSc>    <int>  <int>    <list>    <int>           <list>
 1:  25000.1 0.325     0.5046            0.059 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
 2:  25000.1 0.775     0.5673            0.058 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
 3:  25000.1 1.000     0.5673            0.043 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
 4:  50000.1 0.325     0.5046            0.043 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
 5:  50000.1 0.775     0.5673            0.043 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
 6:  75000.0 0.100     0.3266            0.044 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
 7:  75000.0 0.550     0.5480            0.043 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
 8:  75000.0 0.775     0.5673            0.043 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
 9: 100000.0 0.550     0.5480            0.050 2025-03-24 08:20:06        0      0 <list[2]>        1 <ResampleResult>
10:      0.1 0.100     0.5673            0.042 2025-03-24 08:20:07        0      0 <list[2]>        2 <ResampleResult>
11:      0.1 0.550     0.5673            0.052 2025-03-24 08:20:07        0      0 <list[2]>        2 <ResampleResult>
12:  25000.1 0.100     0.3266            0.044 2025-03-24 08:20:07        0      0 <list[2]>        2 <ResampleResult>
13:  25000.1 0.550     0.5480            0.044 2025-03-24 08:20:07        0      0 <list[2]>        2 <ResampleResult>
14:  50000.1 0.100     0.3266            0.043 2025-03-24 08:20:07        0      0 <list[2]>        2 <ResampleResult>
15:  50000.1 1.000     0.5673            0.044 2025-03-24 08:20:07        0      0 <list[2]>        2 <ResampleResult>
16:  75000.0 0.325     0.5046            0.053 2025-03-24 08:20:07        0      0 <list[2]>        2 <ResampleResult>
17: 100000.0 0.325     0.5046            0.042 2025-03-24 08:20:07        0      0 <list[2]>        2 <ResampleResult>
18:      0.1 0.325     0.5673            0.052 2025-03-24 08:20:08        0      0 <list[2]>        3 <ResampleResult>
19:      0.1 0.775     0.5673            0.042 2025-03-24 08:20:08        0      0 <list[2]>        3 <ResampleResult>
20:      0.1 1.000     0.5673            0.050 2025-03-24 08:20:08        0      0 <list[2]>        3 <ResampleResult>
21:  50000.1 0.550     0.5480            0.044 2025-03-24 08:20:08        0      0 <list[2]>        3 <ResampleResult>
22:  75000.0 1.000     0.5673            0.042 2025-03-24 08:20:08        0      0 <list[2]>        3 <ResampleResult>
23: 100000.0 0.100     0.3266            0.054 2025-03-24 08:20:08        0      0 <list[2]>        3 <ResampleResult>
24: 100000.0 0.775     0.5673            0.043 2025-03-24 08:20:08        0      0 <list[2]>        3 <ResampleResult>
25: 100000.0 1.000     0.5673            0.045 2025-03-24 08:20:08        0      0 <list[2]>        3 <ResampleResult>
        cost gamma classif.ce runtime_learners           timestamp warnings errors  x_domain batch_nr  resample_result

Visualize Results

Visualize results as surface plot with mlr3viz

autoplot(instance, type = "surface")

Train Learner with Optimal HPC

lrn_svm_tuned = lrn("classif.svm")
lrn_svm_tuned$param_set$values = instance$result_learner_param_vals
lrn_svm_tuned$train(tsk_sonar)$model

Call:
svm.default(x = data, y = task$truth(), type = "C-classification", kernel = "radial", gamma = 0.1, cost = 75000.025, probability = (self$predict_type == "prob"))


Parameters:
   SVM-Type:  C-classification 
 SVM-Kernel:  radial 
       cost:  75000 

Number of Support Vectors:  205

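Trafos on Log Scale

HPs like cost often span several orders of magnitude. Sample uniformly on the log scale

cost = runif(1000, log(1e-5), log(1e5))

Then transform back with exp(), which spreads the values evenly over the orders of magnitude

exp_cost = exp(cost)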
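Here is a base R sketch of why this matters. It compares the share of samples below 1 for uniform sampling on the original scale and on the log scale. The range 1e-5 to 1e5 matches the example above. n and the seed are arbitrary.

```r
set.seed(1)
n = 10000

# uniform on the original scale: nearly all samples land in the top decades
x_raw = runif(n, 1e-5, 1e5)
# uniform on the log scale, then exp(): each decade gets about the same share
x_log = exp(runif(n, log(1e-5), log(1e5)))

# share of samples below 1, i.e. in the lower 5 of the 10 decades
share_raw = mean(x_raw < 1)
share_log = mean(x_log < 1)

# counts per decade for the log-scale sample
table(floor(log10(x_log)))
```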

Trafos on Log Scale

learner = lrn("classif.svm",
  cost  = to_tune(1e-5, 1e5, logscale = TRUE),
  gamma = to_tune(1e-5, 1e5, logscale = TRUE),
  kernel = "radial",
  type = "C-classification"
)
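to_tune(..., logscale = TRUE) sets up this trafo automatically. Here is a sketch that makes it visible, assuming mlr3 and paradox are installed. To avoid the e1071 dependency, it uses classif.rpart with the cp range from the tuning summary in this section instead of the SVM.

```r
library(mlr3)
library(paradox)  # provides to_tune() and generate_design_grid()

lrn_rpart = lrn("classif.rpart",
  cp = to_tune(1e-4, 1e-1, logscale = TRUE)
)

# the search space is defined on the log scale
search_space = lrn_rpart$param_set$search_space()
bounds = c(search_space$lower[["cp"]], search_space$upper[["cp"]])

# a 3-point grid on the log scale, transformed back to the learner's scale
design = generate_design_grid(search_space, resolution = 3)
xs = design$transpose(trafo = TRUE)
cp_values = vapply(xs, function(x) x$cp, numeric(1))
```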

Convenient Tuning

mlr3tuning has some sugar functions

Same setup as before, stored in objects for reuse

tnr_grid_search = tnr("grid_search", resolution = 5, batch_size = 5)
lrn_svm = lrn("classif.svm",
  cost  = to_tune(1e-5, 1e5, logscale = TRUE),
  gamma = to_tune(1e-5, 1e5, logscale = TRUE),
  kernel = "radial",
  type = "C-classification"
)
rsmp_cv3 = rsmp("cv", folds = 3)
msr_ce = msr("classif.ce")

Tuning with tune

tune() creates the tuning instance and runs $optimize() in a single call

instance = tune(
  tuner = tnr_grid_search,
  task = tsk_sonar,
  learner = lrn_svm,
  resampling = rsmp_cv3,
  measures = msr_ce
)
instance$result
    cost  gamma learner_param_vals  x_domain classif.ce
   <num>  <num>             <list>    <list>      <num>
1: 11.51 -5.756          <list[4]> <list[2]>     0.1443
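As a sketch of what tune() bundles, the same grid search can be written in two steps with ti() and the tuner's $optimize() method. trm("none") is used because grid search stops by itself once all 5 x 5 = 25 grid points have been evaluated. The result reports cost and gamma on the log scale; its x_domain column holds the back-transformed values actually passed to the SVM. The chosen configuration depends on the random CV splits.

```r
library(mlr3verse)  # attaches mlr3, mlr3learners, mlr3tuning, paradox

tsk_sonar = tsk("sonar")
lrn_svm = lrn("classif.svm",
  cost  = to_tune(1e-5, 1e5, logscale = TRUE),
  gamma = to_tune(1e-5, 1e5, logscale = TRUE),
  kernel = "radial",
  type = "C-classification"
)

# Step 1: create the tuning instance
instance = ti(
  task = tsk_sonar,
  learner = lrn_svm,
  resampling = rsmp("cv", folds = 3),
  measures = msr("classif.ce"),
  terminator = trm("none")
)

# Step 2: let the tuner optimize it (the instance is modified in place)
tuner = tnr("grid_search", resolution = 5, batch_size = 5)
tuner$optimize(instance)

instance$result$x_domain[[1]]  # best cost and gamma on the original scale
```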

Summary

lrn_rpart = lrn("classif.rpart",
  minsplit  = to_tune(2, 128, logscale = TRUE),
  minbucket = to_tune(1, 64, logscale = TRUE),
  cp        = to_tune(1e-04, 1e-1, logscale = TRUE)
)

instance = ti(
  task = tsk("pima"),
  learner = lrn_rpart,
  resampling = rsmp("cv", folds = 3),
  measures = msr("classif.ce"),
  terminator = trm("evals", n_evals = 100)
)
tuner = tnr("random_search", batch_size = 10)
tuner$optimize(instance)

Tuning with auto_tuner

Inherits from Learner

Runs tune() on the training data when $train() is called

Then trains the learner with the optimal HPC on the full training data

at = auto_tuner(
  tuner = tnr_grid_search,
  learner = lrn_svm,
  resampling = rsmp_cv3,
  measure = msr_ce
)
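Because an AutoTuner is a Learner, it is trained and used for prediction like any other learner. A sketch, assuming a simple 70/30 holdout split created with partition(): $train() tunes on the training rows only, $tuning_result shows the selected configuration, and $predict() uses the model refitted with that configuration.

```r
library(mlr3verse)

tsk_sonar = tsk("sonar")
split = partition(tsk_sonar, ratio = 0.7)  # random train/test row ids

at = auto_tuner(
  tuner = tnr("grid_search", resolution = 5, batch_size = 5),
  learner = lrn("classif.svm",
    cost  = to_tune(1e-5, 1e5, logscale = TRUE),
    gamma = to_tune(1e-5, 1e5, logscale = TRUE),
    kernel = "radial",
    type = "C-classification"
  ),
  resampling = rsmp("cv", folds = 3),
  measure = msr("classif.ce")
)

# Tuning happens inside $train(), using only the training rows
at$train(tsk_sonar, row_ids = split$train)
at$tuning_result

# The final model predicts the held-out rows
prediction = at$predict(tsk_sonar, row_ids = split$test)
prediction$score(msr("classif.ce"))
```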

Tuning with auto_tuner

at
<AutoTuner:classif.svm.tuned>
* Model: list
* Parameters: list()
* Packages: mlr3, mlr3tuning, mlr3learners, e1071
* Predict Types:  [response], prob
* Feature Types: logical, integer, numeric
* Properties: multiclass, twoclass
* Search Space:
       id    class  lower upper nlevels
   <char>   <char>  <num> <num>   <num>
1:   cost ParamDbl -11.51 11.51     Inf
2:  gamma ParamDbl -11.51 11.51     Inf

Nested Resampling

Nested Resampling AutoTuner

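Resampling an AutoTuner with resample() or benchmark() performs nested resampling automatically: tuning runs on each outer training set, and each outer test set is used only for evaluation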

at = auto_tuner(
  tuner = tnr_grid_search,
  learner = lrn_svm,
  resampling = rsmp("cv", folds = 4),
  measure = msr_ce,
  id = "svm"
)

rr = resample(tsk_sonar, at, rsmp_cv3, store_models = TRUE)

Estimated performance of the tuned learner: aggregated performance over all outer test sets

rr$aggregate()
classif.ce 
    0.1586 
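The index structure behind nested resampling can be sketched in base R (a conceptual illustration, not how mlr3 stores resampling instances): each of the 3 outer folds holds out a test set, and the 4 inner CV folds used for tuning are drawn only from the remaining outer training rows, so the outer test rows never influence the chosen configuration. The row count 208 is the size of the sonar task.

```r
set.seed(1)
n = 208  # number of rows in the sonar task

# Assign every row to one of 3 outer folds
outer_fold = sample(rep(1:3, length.out = n))

for (i in 1:3) {
  outer_test  = which(outer_fold == i)
  outer_train = which(outer_fold != i)

  # Inner 4-fold CV partitions the outer training rows only
  inner_fold = sample(rep(1:4, length.out = length(outer_train)))
  inner_sets = split(outer_train, inner_fold)

  # No outer test row is ever used during tuning
  stopifnot(length(intersect(unlist(inner_sets), outer_test)) == 0)
  cat(sprintf("outer fold %d: %d train rows (4 inner folds), %d test rows\n",
              i, length(outer_train), length(outer_test)))
}
```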

Inner Tuning Results

Optimal configurations across all outer folds

extract_inner_tuning_results(rr)
   iteration  cost  gamma classif.ce learner_param_vals  x_domain task_id learner_id resampling_id
       <int> <num>  <num>      <num>             <list>    <list>  <char>     <char>        <char>
1:         1 5.756 -5.756     0.1880          <list[4]> <list[2]>   sonar        svm            cv
2:         2 5.756 -5.756     0.2372          <list[4]> <list[2]>   sonar        svm            cv
3:         3 5.756 -5.756     0.1376          <list[4]> <list[2]>   sonar        svm            cv

Inner Tuning Archives

Full tuning archives

extract_inner_tuning_archives(rr)
    iteration    cost   gamma classif.ce x_domain_cost x_domain_gamma runtime_learners           timestamp warnings errors batch_nr  resample_result task_id learner_id resampling_id
        <int>   <num>   <num>      <num>         <num>          <num>            <num>              <POSc>    <int>  <int>    <int>           <list>  <char>     <char>        <char>
 1:         1 -11.513 -11.513     0.4317     1.000e-05      1.000e-05            0.043 2025-03-24 08:20:16        0      0        1 <ResampleResult>   sonar        svm            cv
 2:         1   0.000  -5.756     0.2447     1.000e+00      3.162e-03            0.051 2025-03-24 08:20:16        0      0        1 <ResampleResult>   sonar        svm            cv
 3:         1   0.000   5.756     0.4317     1.000e+00      3.162e+02            0.046 2025-03-24 08:20:16        0      0        1 <ResampleResult>   sonar        svm            cv
 4:         1   5.756   0.000     0.4317     3.162e+02      1.000e+00            0.047 2025-03-24 08:20:16        0      0        1 <ResampleResult>   sonar        svm            cv
 5:         1   5.756   5.756     0.4317     3.162e+02      3.162e+02            0.057 2025-03-24 08:20:16        0      0        1 <ResampleResult>   sonar        svm            cv
 6:         1 -11.513  -5.756     0.4317     1.000e-05      3.162e-03            0.043 2025-03-24 08:20:17        0      0        2 <ResampleResult>   sonar        svm            cv
 7:         1 -11.513   5.756     0.4317     1.000e-05      3.162e+02            0.043 2025-03-24 08:20:17        0      0        2 <ResampleResult>   sonar        svm            cv
 8:         1   0.000 -11.513     0.4317     1.000e+00      1.000e-05            0.054 2025-03-24 08:20:17        0      0        2 <ResampleResult>   sonar        svm            cv
 9:         1   0.000  11.513     0.4317     1.000e+00      1.000e+05            0.046 2025-03-24 08:20:17        0      0        2 <ResampleResult>   sonar        svm            cv
10:         1  11.513 -11.513     0.2158     1.000e+05      1.000e-05            0.043 2025-03-24 08:20:17        0      0        2 <ResampleResult>   sonar        svm            cv
11:         1  -5.756  11.513     0.4317     3.162e-03      1.000e+05            0.044 2025-03-24 08:20:17        0      0        3 <ResampleResult>   sonar        svm            cv
12:         1   0.000   0.000     0.4317     1.000e+00      1.000e+00            0.047 2025-03-24 08:20:17        0      0        3 <ResampleResult>   sonar        svm            cv
13:         1   5.756  -5.756     0.1880     3.162e+02      3.162e-03            0.050 2025-03-24 08:20:17        0      0        3 <ResampleResult>   sonar        svm            cv
14:         1  11.513  -5.756     0.1880     1.000e+05      3.162e-03            0.044 2025-03-24 08:20:17        0      0        3 <ResampleResult>   sonar        svm            cv
15:         1  11.513   0.000     0.4317     1.000e+05      1.000e+00            0.047 2025-03-24 08:20:17        0      0        3 <ResampleResult>   sonar        svm            cv
16:         1 -11.513   0.000     0.4317     1.000e-05      1.000e+00            0.065 2025-03-24 08:20:18        0      0        4 <ResampleResult>   sonar        svm            cv
17:         1  -5.756   5.756     0.4317     3.162e-03      3.162e+02            0.047 2025-03-24 08:20:18        0      0        4 <ResampleResult>   sonar        svm            cv
18:         1   5.756 -11.513     0.2450     3.162e+02      1.000e-05            0.039 2025-03-24 08:20:18        0      0        4 <ResampleResult>   sonar        svm            cv
19:         1   5.756  11.513     0.4317     3.162e+02      1.000e+05            0.060 2025-03-24 08:20:18        0      0        4 <ResampleResult>   sonar        svm            cv
20:         1  11.513  11.513     0.4317     1.000e+05      1.000e+05            0.047 2025-03-24 08:20:18        0      0        4 <ResampleResult>   sonar        svm            cv
21:         1 -11.513  11.513     0.4317     1.000e-05      1.000e+05            0.044 2025-03-24 08:20:18        0      0        5 <ResampleResult>   sonar        svm            cv
22:         1  -5.756 -11.513     0.4317     3.162e-03      1.000e-05            0.056 2025-03-24 08:20:18        0      0        5 <ResampleResult>   sonar        svm            cv
23:         1  -5.756  -5.756     0.4317     3.162e-03      3.162e-03            0.042 2025-03-24 08:20:18        0      0        5 <ResampleResult>   sonar        svm            cv
24:         1  -5.756   0.000     0.4317     3.162e-03      1.000e+00            0.047 2025-03-24 08:20:18        0      0        5 <ResampleResult>   sonar        svm            cv
25:         1  11.513   5.756     0.4317     1.000e+05      3.162e+02            0.055 2025-03-24 08:20:18        0      0        5 <ResampleResult>   sonar        svm            cv
26:         2 -11.513   5.756     0.5397     1.000e-05      3.162e+02            0.045 2025-03-24 08:20:19        0      0        1 <ResampleResult>   sonar        svm            cv
27:         2   0.000 -11.513     0.4611     1.000e+00      1.000e-05            0.059 2025-03-24 08:20:19        0      0        1 <ResampleResult>   sonar        svm            cv
28:         2   5.756  11.513     0.5397     3.162e+02      1.000e+05            0.045 2025-03-24 08:20:19        0      0        1 <ResampleResult>   sonar        svm            cv
29:         2  11.513 -11.513     0.2805     1.000e+05      1.000e-05            0.054 2025-03-24 08:20:19        0      0        1 <ResampleResult>   sonar        svm            cv
30:         2  11.513   5.756     0.5397     1.000e+05      3.162e+02            0.047 2025-03-24 08:20:19        0      0        1 <ResampleResult>   sonar        svm            cv
31:         2  -5.756   5.756     0.5397     3.162e-03      3.162e+02            0.044 2025-03-24 08:20:20        0      0        2 <ResampleResult>   sonar        svm            cv
32:         2   0.000   5.756     0.5397     1.000e+00      3.162e+02            0.060 2025-03-24 08:20:20        0      0        2 <ResampleResult>   sonar        svm            cv
33:         2   0.000  11.513     0.5397     1.000e+00      1.000e+05            0.045 2025-03-24 08:20:20        0      0        2 <ResampleResult>   sonar        svm            cv
34:         2   5.756 -11.513     0.3097     3.162e+02      1.000e-05            0.043 2025-03-24 08:20:20        0      0        2 <ResampleResult>   sonar        svm            cv
35:         2  11.513  11.513     0.5397     1.000e+05      1.000e+05            0.059 2025-03-24 08:20:20        0      0        2 <ResampleResult>   sonar        svm            cv
36:         2 -11.513 -11.513     0.4611     1.000e-05      1.000e-05            0.044 2025-03-24 08:20:20        0      0        3 <ResampleResult>   sonar        svm            cv
37:         2  -5.756  -5.756     0.4683     3.162e-03      3.162e-03            0.054 2025-03-24 08:20:20        0      0        3 <ResampleResult>   sonar        svm            cv
38:         2   0.000  -5.756     0.3237     1.000e+00      3.162e-03            0.044 2025-03-24 08:20:20        0      0        3 <ResampleResult>   sonar        svm            cv
39:         2   5.756  -5.756     0.2372     3.162e+02      3.162e-03            0.044 2025-03-24 08:20:20        0      0        3 <ResampleResult>   sonar        svm            cv
40:         2  11.513   0.000     0.5183     1.000e+05      1.000e+00            0.048 2025-03-24 08:20:20        0      0        3 <ResampleResult>   sonar        svm            cv
41:         2 -11.513   0.000     0.5254     1.000e-05      1.000e+00            0.048 2025-03-24 08:20:21        0      0        4 <ResampleResult>   sonar        svm            cv
42:         2 -11.513  11.513     0.5397     1.000e-05      1.000e+05            0.043 2025-03-24 08:20:21        0      0        4 <ResampleResult>   sonar        svm            cv
43:         2  -5.756 -11.513     0.4611     3.162e-03      1.000e-05            0.056 2025-03-24 08:20:21        0      0        4 <ResampleResult>   sonar        svm            cv
44:         2  -5.756  11.513     0.5397     3.162e-03      1.000e+05            0.044 2025-03-24 08:20:21        0      0        4 <ResampleResult>   sonar        svm            cv
45:         2   5.756   0.000     0.5183     3.162e+02      1.000e+00            0.046 2025-03-24 08:20:21        0      0        4 <ResampleResult>   sonar        svm            cv
46:         2 -11.513  -5.756     0.4683     1.000e-05      3.162e-03            0.045 2025-03-24 08:20:21        0      0        5 <ResampleResult>   sonar        svm            cv
47:         2  -5.756   0.000     0.5254     3.162e-03      1.000e+00            0.045 2025-03-24 08:20:21        0      0        5 <ResampleResult>   sonar        svm            cv
48:         2   0.000   0.000     0.5183     1.000e+00      1.000e+00            0.055 2025-03-24 08:20:21        0      0        5 <ResampleResult>   sonar        svm            cv
49:         2   5.756   5.756     0.5397     3.162e+02      3.162e+02            0.045 2025-03-24 08:20:21        0      0        5 <ResampleResult>   sonar        svm            cv
50:         2  11.513  -5.756     0.2372     1.000e+05      3.162e-03            0.042 2025-03-24 08:20:21        0      0        5 <ResampleResult>   sonar        svm            cv
51:         3 -11.513   5.756     0.4634     1.000e-05      3.162e+02            0.056 2025-03-24 08:20:13        0      0        1 <ResampleResult>   sonar        svm            cv
52:         3   5.756 -11.513     0.2242     3.162e+02      1.000e-05            0.042 2025-03-24 08:20:13        0      0        1 <ResampleResult>   sonar        svm            cv
53:         3   5.756  -5.756     0.1376     3.162e+02      3.162e-03            0.070 2025-03-24 08:20:13        0      0        1 <ResampleResult>   sonar        svm            cv
54:         3   5.756   0.000     0.4634     3.162e+02      1.000e+00            0.044 2025-03-24 08:20:13        0      0        1 <ResampleResult>   sonar        svm            cv
55:         3  11.513  11.513     0.4634     1.000e+05      1.000e+05            0.045 2025-03-24 08:20:13        0      0        1 <ResampleResult>   sonar        svm            cv
56:         3 -11.513  -5.756     0.4634     1.000e-05      3.162e-03            0.045 2025-03-24 08:20:14        0      0        2 <ResampleResult>   sonar        svm            cv
57:         3 -11.513   0.000     0.4634     1.000e-05      1.000e+00            0.055 2025-03-24 08:20:14        0      0        2 <ResampleResult>   sonar        svm            cv
58:         3  -5.756 -11.513     0.4634     3.162e-03      1.000e-05            0.042 2025-03-24 08:20:14        0      0        2 <ResampleResult>   sonar        svm            cv
59:         3   0.000  -5.756     0.2097     1.000e+00      3.162e-03            0.044 2025-03-24 08:20:14        0      0        2 <ResampleResult>   sonar        svm            cv
60:         3   0.000   0.000     0.4634     1.000e+00      1.000e+00            0.053 2025-03-24 08:20:14        0      0        2 <ResampleResult>   sonar        svm            cv
61:         3 -11.513  11.513     0.4634     1.000e-05      1.000e+05            0.043 2025-03-24 08:20:14        0      0        3 <ResampleResult>   sonar        svm            cv
62:         3  -5.756  -5.756     0.4634     3.162e-03      3.162e-03            0.044 2025-03-24 08:20:14        0      0        3 <ResampleResult>   sonar        svm            cv
63:         3  -5.756  11.513     0.4634     3.162e-03      1.000e+05            0.045 2025-03-24 08:20:14        0      0        3 <ResampleResult>   sonar        svm            cv
64:         3  11.513  -5.756     0.1376     1.000e+05      3.162e-03            0.049 2025-03-24 08:20:14        0      0        3 <ResampleResult>   sonar        svm            cv
65:         3  11.513   0.000     0.4634     1.000e+05      1.000e+00            0.043 2025-03-24 08:20:14        0      0        3 <ResampleResult>   sonar        svm            cv
66:         3  -5.756   5.756     0.4634     3.162e-03      3.162e+02            0.055 2025-03-24 08:20:15        0      0        4 <ResampleResult>   sonar        svm            cv
67:         3   0.000  11.513     0.4634     1.000e+00      1.000e+05            0.047 2025-03-24 08:20:15        0      0        4 <ResampleResult>   sonar        svm            cv
68:         3   5.756   5.756     0.4634     3.162e+02      3.162e+02            0.043 2025-03-24 08:20:15        0      0        4 <ResampleResult>   sonar        svm            cv
69:         3   5.756  11.513     0.4634     3.162e+02      1.000e+05            0.045 2025-03-24 08:20:15        0      0        4 <ResampleResult>   sonar        svm            cv
70:         3  11.513   5.756     0.4634     1.000e+05      3.162e+02            0.055 2025-03-24 08:20:15        0      0        4 <ResampleResult>   sonar        svm            cv
71:         3 -11.513 -11.513     0.4634     1.000e-05      1.000e-05            0.043 2025-03-24 08:20:16        0      0        5 <ResampleResult>   sonar        svm            cv
72:         3  -5.756   0.000     0.4634     3.162e-03      1.000e+00            0.044 2025-03-24 08:20:16        0      0        5 <ResampleResult>   sonar        svm            cv
73:         3   0.000 -11.513     0.4634     1.000e+00      1.000e-05            0.055 2025-03-24 08:20:16        0      0        5 <ResampleResult>   sonar        svm            cv
74:         3   0.000   5.756     0.4634     1.000e+00      3.162e+02            0.044 2025-03-24 08:20:16        0      0        5 <ResampleResult>   sonar        svm            cv
75:         3  11.513 -11.513     0.2395     1.000e+05      1.000e-05            0.044 2025-03-24 08:20:16        0      0        5 <ResampleResult>   sonar        svm            cv
    iteration    cost   gamma classif.ce x_domain_cost x_domain_gamma runtime_learners           timestamp warnings errors batch_nr  resample_result task_id learner_id resampling_id

Defining Search Spaces with ps

Full control over the search space with ps()

search_space = ps(
  cost  = p_dbl(lower = 1e-1, upper = 1e5),
  kernel = p_fct(c("radial", "linear")),
  shrinking = p_lgl()
)

Pass to tuning instance

ti(tsk_sonar, lrn("classif.svm", type = "C-classification"), rsmp_cv3,
  msr_ce, trm("none"), search_space = search_space)
<TuningInstanceBatchSingleCrit>
* State:  Not optimized
* Objective: <ObjectiveTuningBatch:classif.svm_on_sonar>
* Search Space:
          id    class lower  upper nlevels
      <char>   <char> <num>  <num>   <num>
1:      cost ParamDbl   0.1 100000     Inf
2:    kernel ParamFct    NA     NA       2
3: shrinking ParamLgl    NA     NA       2
* Terminator: <TerminatorNone>
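A ps() search space can also be inspected on its own, independently of any tuning instance. As a sketch, paradox's generate_design_random() draws a few random configurations, which is a quick way to check bounds and types before tuning:

```r
library(paradox)

search_space = ps(
  cost  = p_dbl(lower = 1e-1, upper = 1e5),
  kernel = p_fct(c("radial", "linear")),
  shrinking = p_lgl()
)

set.seed(1)
design = generate_design_random(search_space, n = 5)
design$data  # one row per random configuration
```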

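Complex Transformations

A per-parameter trafo transforms a single value; .extra_trafo runs afterwards on the whole configuration and can make one value depend on another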

search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) exp(x)),
  kernel = p_fct(c("polynomial", "radial")),
  .extra_trafo = function(x, param_set) {
    if (x$kernel == "polynomial") x$cost = x$cost + 2
    x
  }
)
search_space$trafo(list(cost = 1, kernel = "radial"))
$cost
[1] 2.718

$kernel
[1] "radial"
search_space$trafo(list(cost = 1, kernel = "polynomial"))
$cost
[1] 4.718

$kernel
[1] "polynomial"
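To see the combined effect on a whole grid, one can (as a sketch) generate a grid design on the search-space scale and let the design's transpose() method apply both trafos. With resolution = 3, cost takes the values -1, 0 and 1 before transformation, crossed with both kernels:

```r
library(paradox)

search_space = ps(
  cost = p_dbl(-1, 1, trafo = function(x) exp(x)),
  kernel = p_fct(c("polynomial", "radial")),
  .extra_trafo = function(x, param_set) {
    # runs after the per-parameter trafo: shift cost for polynomial kernels
    if (x$kernel == "polynomial") x$cost = x$cost + 2
    x
  }
)

design = generate_design_grid(search_space, resolution = 3)
configs = design$transpose(trafo = TRUE)  # list of transformed configurations

# Tabulate the transformed values
data.frame(
  kernel = vapply(configs, function(x) x$kernel, character(1)),
  cost   = vapply(configs, function(x) x$cost, numeric(1))
)
```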

Predefined Search Spaces

Can be stored in and retrieved with mlr3tuningspaces

library(mlr3tuningspaces)
as.data.table(mlr_tuning_spaces)[1:3, .(key, label)]
Key: <key>
                      key                             label
                   <char>                            <char>
1: classif.glmnet.default   Classification GLM with Default
2:    classif.glmnet.rbv1 Classification GLM with RandomBot
3:    classif.glmnet.rbv2 Classification GLM with RandomBot
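A stored search space is retrieved with lts() and can be attached to the matching learner. The sketch below assumes the key "classif.rpart.default", one of the predefined spaces in mlr3tuningspaces; the returned learner carries to_tune() tokens and can be passed directly to tune() or auto_tuner().

```r
library(mlr3)
library(mlr3tuningspaces)

space = lts("classif.rpart.default")
space  # stored ranges and log-scale settings

# Learner whose hyperparameters are already set to the stored to_tune() tokens
lrn_rpart = space$get_learner()
lrn_rpart$param_set$values
```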

Summary

lrn_rpart = lrn("classif.rpart",
  minsplit  = to_tune(2, 128, logscale = TRUE),
  minbucket = to_tune(1, 64, logscale = TRUE),
  cp        = to_tune(1e-04, 1e-1, logscale = TRUE)
)

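at = auto_tuner(
  tuner = tnr("random_search", batch_size = 10),
  learner = lrn_rpart,
  resampling = rsmp("cv", folds = 4),
  measure = msr("classif.ce"),
  # random search needs a terminator, otherwise it never stops
  terminator = trm("evals", n_evals = 100)
)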

rr = resample(tsk("pima"), at, rsmp("cv", folds = 3), store_models = TRUE)

Sequential Pipelines

Workflows including data preprocessing, ensemble models, or more complex meta-models

PipeOps are the building blocks

PipeOps are connected to form a Graph or pipeline

PipeOp

Short for Pipeline Operator

Includes a $train() and a $predict() method

Has a $param_set field that defines the hyperparameters

Constructed with the po() function

library(mlr3pipelines)

po_pca = po("pca", center = TRUE)
po_pca
PipeOp: <pca> (not trained)
values: <center=TRUE>
Input channels <name [train type, predict type]>:
  input [Task,Task]
Output channels <name [train type, predict type]>:
  output [Task,Task]

PipeOp Train

PipeOp includes a $train() and a $predict() method

po("pca") applies a principal component analysis to the task's features

tsk_small = tsk("penguins_simple")$select(c("bill_depth", "bill_length"))
poin = list(tsk_small$clone()$filter(1:5))
poout = po_pca$train(poin) # poin: Task in a list
poout # list with a single element 'output'
$output
<TaskClassif:penguins> (5 x 3): Simplified Palmer Penguins
* Target: species
* Properties: multiclass
* Features (2):
  - dbl (2): PC1, PC2
poout[[1]]$head(3)
   species    PC1       PC2
    <fctr>  <num>     <num>
1:  Adelie 0.1561  0.005716
2:  Adelie 1.2677  0.789534
3:  Adelie 1.5336 -0.174460

PipeOp State

The training phase typically generates a particular model of the data, which is saved as the internal $state field

The $state field of po("pca") contains the rotation matrix

po_pca$state
Standard deviations (1, .., p=2):
[1] 1.513 1.034

Rotation (n x k) = (2 x 2):
                PC1     PC2
bill_depth  -0.6116 -0.7911
bill_length  0.7911 -0.6116

PipeOp Predict

This state is then used during predictions and applied to new data

tsk_onepenguin = tsk_small$clone()$filter(42)
poin = list(tsk_onepenguin)
poout = po_pca$predict(poin)
poout[[1]]$data()
   species   PC1    PC2
    <fctr> <num>  <num>
1:  Adelie 1.555 -1.455
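Conceptually, po("pca") fits a PCA during training, keeps it as its state, and reuses it at prediction time. A minimal base-R sketch of that train/predict split with prcomp(), using two columns of the built-in iris data instead of the penguins task (an illustration of the idea, not mlr3pipelines' code): the centre and rotation learned on the training rows are applied unchanged to the new row.

```r
# "Train": fit PCA on the first 5 rows and keep the fitted object as the state
train_x = iris[1:5, c("Sepal.Width", "Sepal.Length")]
state = prcomp(train_x, center = TRUE)

# "Predict": apply the stored centring and rotation to a new row
new_x = iris[42, c("Sepal.Width", "Sepal.Length")]
scores = predict(state, newdata = new_x)
scores

# Same result by hand: subtract the training means, then rotate
manual = sweep(as.matrix(new_x), 2, state$center) %*% state$rotation
```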

Graph

PipeOps represent individual computational steps in machine learning pipelines

These pipelines themselves are defined by Graph objects

A Graph is a collection of PipeOps with “edges” that guide the flow of data

Graph

The most convenient way of building a Graph is to connect a sequence of PipeOps using the %>>% operator

po_mutate = po("mutate",
  mutation = list(bill_ratio = ~bill_length / bill_depth)
)
po_scale = po("scale")
graph = po_mutate %>>% po_scale
graph
Graph with 2 PipeOps:
     ID         State sccssors prdcssors
 <char>        <char>   <char>    <char>
 mutate <<UNTRAINED>>    scale          
  scale <<UNTRAINED>>             mutate
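The %>>% operator is a convenience. As a sketch, the same two-node Graph can be assembled step by step with Graph$new(), $add_pipeop() and $add_edge(), which makes the nodes and the edge explicit:

```r
library(mlr3)
library(mlr3pipelines)

graph = Graph$new()
graph$add_pipeop(po("mutate",
  mutation = list(bill_ratio = ~bill_length / bill_depth)
))
graph$add_pipeop(po("scale"))

# Data flows from the output of "mutate" into the input of "scale"
graph$add_edge("mutate", "scale")

graph$edges  # one row per edge: src_id, src_channel, dst_id, dst_channel
```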

Graph

graph$plot(horizontal = TRUE)

Sequential Learner-Pipelines

The most common application of mlr3pipelines is to preprocess data before feeding it into a Learner

Learners as PipeOps

Learner objects can be converted to PipeOps

lrn_logreg = lrn("classif.log_reg")
graph = po("imputesample") %>>% lrn_logreg
graph$plot(horizontal = TRUE)
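In the line above, %>>% performs the conversion automatically. As a sketch, it can also be done explicitly with po("learner", ...) or as_pipeop(), both of which wrap the Learner in a PipeOpLearner whose ID is taken from the learner ID:

```r
library(mlr3)
library(mlr3learners)
library(mlr3pipelines)

lrn_logreg = lrn("classif.log_reg")

po_logreg  = po("learner", lrn_logreg)  # explicit wrapping
po_logreg2 = as_pipeop(lrn_logreg)      # generic conversion

class(po_logreg)[1]
po_logreg$id
```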

Graphs as Learners

To use a Graph as a Learner with an identical interface, it can be wrapped in a GraphLearner object with as_learner()

glrn_sample = as_learner(graph)
glrn_mode = as_learner(po("imputemode") %>>% lrn_logreg)

design = benchmark_grid(tsk("pima"), list(glrn_sample, glrn_mode),
  rsmp("cv", folds = 3))
bmr = benchmark(design)
aggr = bmr$aggregate()[, .(learner_id, classif.ce)]
aggr
                     learner_id classif.ce
                         <char>      <num>
1: imputesample.classif.log_reg     0.2253
2:   imputemode.classif.log_reg     0.2292

Configuring Pipeline Hyperparameters

PipeOp hyperparameters are collected together in the $param_set of a graph and prefixed with the ID of the PipeOp

graph = po("scale", center = FALSE, scale = TRUE, id = "scale") %>>%
  po("scale", center = TRUE, scale = FALSE, id = "center") %>>%
  lrn("classif.rpart", cp = 1)
unlist(graph$param_set$values)
      scale.center        scale.scale       scale.robust      center.center       center.scale      center.robust   classif.rpart.cp classif.rpart.xval 
                 0                  1                  0                  1                  0                  0                  1                  0 
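Because of these prefixes, the hyperparameters of an individual PipeOp can be changed on the Graph (or on a GraphLearner wrapping it) by their prefixed names. A sketch using $set_values(), checking that the change reaches the wrapped rpart learner:

```r
library(mlr3)
library(mlr3pipelines)

graph = po("scale", center = FALSE, scale = TRUE, id = "scale") %>>%
  po("scale", center = TRUE, scale = FALSE, id = "center") %>>%
  lrn("classif.rpart", cp = 1)

# Naming scheme: <PipeOp ID>.<hyperparameter>
graph$param_set$set_values(classif.rpart.cp = 0.01)

graph$pipeops$classif.rpart$param_set$values$cp
```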

Non-Sequential Pipelines

Non-sequential pipelines can perform more complex operations

Using the gunion() function, we can instead combine multiple PipeOps, Graphs, or a mixture of both, into a parallel Graph

graph = po("scale", center = TRUE, scale = FALSE) %>>%
  gunion(list(
    po("missind"),
    po("imputemedian")
  )) %>>%
  po("featureunion")

graph$plot(horizontal = TRUE)
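Here po("missind") and po("imputemedian") both receive the centred task, and po("featureunion") combines the imputed features with the missingness indicator columns. As a sketch, training the Graph on pima, which contains missing values, shows the merged feature set:

```r
library(mlr3)
library(mlr3pipelines)

graph = po("scale", center = TRUE, scale = FALSE) %>>%
  gunion(list(
    po("missind"),
    po("imputemedian")
  )) %>>%
  po("featureunion")

tsk_pima = tsk("pima")
out = graph$train(tsk_pima)[[1]]  # $train() returns a list with one Task

out$feature_names  # imputed features plus missing_* indicator columns
anyNA(out$data(cols = out$feature_names))  # no gaps left after imputation
```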

Common Patterns and ppl()

Many common ML problems can be solved with the same pipeline patterns; ppl() constructs these predefined Graphs

ppl("bagging", graph) creates a bagging ensemble

ppl("branch", graphs) creates a branch

ppl("robustify") creates a pipeline of common preprocessing steps

ppl("stacking", base_learners, super_learner) creates a stacking ensemble
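As a sketch of one of these patterns, ppl("bagging") replicates a learner, fits each copy on a subsample, and averages the predictions. Here a classification tree is bagged 3 times on iris; passing the averager explicitly as po("classifavg", collect_multiplicity = TRUE) is a choice made for this example, not the only possible setup.

```r
library(mlr3)
library(mlr3pipelines)

graph_bag = ppl("bagging",
  graph = po("learner", lrn("classif.rpart")),
  iterations = 3,  # number of bagged trees
  frac = 0.7,      # fraction of rows subsampled for each tree
  averager = po("classifavg", collect_multiplicity = TRUE)
)
glrn_bag = as_learner(graph_bag)

tsk_iris = tsk("iris")
glrn_bag$train(tsk_iris)
pred = glrn_bag$predict(tsk_iris)
pred$score(msr("classif.ce"))
```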

Branching

po("branch") creates multiple paths such that data can only flow through one of these as determined by the selection hyperparameter

Use po("unbranch") (with the same arguments as po("branch")) to merge the outputs into one result object

Branching

To demonstrate alternative paths we will make use of the MNIST data (LeCun et al. 1998), which is useful for demonstrating preprocessing

library(mlr3oml)
otsk_mnist = otsk(id = 3573)
tsk_mnist = as_task(otsk_mnist)$
  filter(sample(70000, 1000))$
  select(otsk_mnist$feature_names[sample(700, 100)])

Branching

Do nothing: po("nop")

Apply PCA: po("pca")

Remove constant features with po("removeconstants"), then apply the Yeo-Johnson transform with po("yeojohnson")

paths = c("nop", "pca", "yeojohnson")

graph = po("branch", paths, id = "brnchPO") %>>%
  gunion(list(
    po("nop"),
    po("pca"),
    po("removeconstants", id = "rm_const") %>>%
      po("yeojohnson", id = "YJ")
  )) %>>% po("unbranch", paths, id = "unbrnchPO")

Branching

graph$plot(horizontal = TRUE)

Branching

The output of this Graph depends on the setting of the brnchPO.selection hyperparameter (the branch PipeOp's selection, prefixed with its ID)

# use the "PCA" path
graph$param_set$values$brnchPO.selection = "pca"
# new PCA columns
head(graph$train(tsk_mnist)[[1]]$feature_names)
[1] "PC1" "PC2" "PC3" "PC4" "PC5" "PC6"
# use the "No-Op" path
graph$param_set$values$brnchPO.selection = "nop"
# same features
head(graph$train(tsk_mnist)[[1]]$feature_names)
[1] "pixel3"  "pixel6"  "pixel24" "pixel29" "pixel32" "pixel34"

Tune Branch Pipeline

Branching can even be used to tune which of several learners is most appropriate for a given dataset

graph_learner = graph %>>%
  ppl("branch", lrns(c("classif.rpart", "classif.kknn")))
graph_learner$plot(horizontal = TRUE)

Tune Branch Pipeline

Tuning the selection hyperparameters can help determine which combination of preprocessing path and learner works best

graph_learner = as_learner(graph_learner)

graph_learner$param_set$set_values(
  brnchPO.selection = to_tune(paths),
  branch.selection = to_tune(c("classif.rpart", "classif.kknn")),
  classif.kknn.k = to_tune(p_int(1, 32,
    depends = branch.selection == "classif.kknn"))
)

instance = tune(tnr("grid_search"), tsk_mnist, graph_learner,
  rsmp("repeated_cv", folds = 3, repeats = 3), msr("classif.ce"))

instance$archive$data[order(classif.ce)[1:5],
  .(brnchPO.selection, classif.kknn.k, branch.selection, classif.ce)]

autoplot(instance)

Additional Features

  • Error handling with encapsulation and fallbacks
  • Parallelization with future, mlr3batchmark and rush
  • Controlled logging with lgr package
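As a sketch of the first point, a learner can be encapsulated so that errors are caught and logged instead of aborting a resampling, with a fallback learner supplying predictions for the failed iterations. This assumes a recent mlr3 version in which $encapsulate() takes the method and the fallback learner; classif.debug with error_train = 1 forces a training error in every fold.

```r
library(mlr3)

lrn_debug = lrn("classif.debug", error_train = 1)  # always fails in $train()
lrn_debug$encapsulate("try", fallback = lrn("classif.featureless"))

rr = resample(tsk("sonar"), lrn_debug, rsmp("cv", folds = 3))

rr$errors       # one logged error per failed iteration
rr$aggregate()  # still computable: predictions come from the fallback
```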