Supervised Learning I

Train, predict, evaluate with mlr3

Goal

Our goal for this exercise sheet is to learn the basics of mlr3 for supervised learning by training a first simple model on training data and by evaluating its performance on hold-out/test data.

German Credit Dataset

The German credit dataset was donated by Prof. Dr. Hans Hofmann of the University of Hamburg in 1994 and contains 1,000 data points reflecting bank customers. The goal is to classify people as a good or bad credit risk based on 20 personal, demographic, and financial features. The dataset is available at the UCI repository as the Statlog (German Credit Data) Data Set.

Motivation of Risk Prediction

Customers who do not repay the distributed loan on time represent an enormous risk for a bank: First, because they create an unintended gap in the bank’s planning, and second, because collecting the repayment amount causes additional time and cost for the bank.

On the other hand, (interest rates for) loans are an important revenue stream for banks. If a person’s loan is rejected, even though they would have met the repayment deadlines, revenue is lost, as well as potential upselling opportunities.

Banks are therefore highly interested in a risk prediction model that accurately predicts the risk of future customers. This is where supervised learning models come into play.

Data Overview

n = 1,000 observations of bank customers

  • credit_risk: is the customer a good or bad credit risk?
  • age: age in years
  • amount: amount asked by applicant
  • credit_history: past credit history of applicant at this bank
  • duration: duration of the credit in months
  • employment_duration: present employment since
  • foreign_worker: is applicant foreign worker?
  • housing: type of apartment rented, owned, for free / no payment
  • installment_rate: installment rate in percentage of disposable income
  • job: current job information
  • number_credits: number of existing credits at this bank
  • other_debtors: other debtors/guarantors present?
  • other_installment_plans: other installment plans the applicant is paying
  • people_liable: number of people being liable to provide maintenance
  • personal_status_sex: combination of sex and personal status of applicant
  • present_residence: present residence since
  • property: properties that applicant has
  • purpose: reason customer is applying for a loan
  • savings: savings accounts/bonds at this bank
  • status: status/balance of checking account at this bank
  • telephone: is there any telephone registered for this customer?

Preprocessing

We first load the data from the rchallenge package (you may need to install it first) and get a brief overview.

# install.packages("rchallenge")
library("rchallenge")
data("german")
skimr::skim(german)
Data summary
Name german
Number of rows 1000
Number of columns 21
_______________________
Column type frequency:
factor 18
numeric 3
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
status 0 1 FALSE 4 …: 394, no : 274, …: 269, 0<=: 63
credit_history 0 1 FALSE 5 no : 530, all: 293, exi: 88, cri: 49
purpose 0 1 FALSE 10 fur: 280, oth: 234, car: 181, car: 103
savings 0 1 FALSE 5 unk: 603, …: 183, …: 103, 100: 63
employment_duration 0 1 FALSE 5 1 <: 339, >= : 253, 4 <: 174, < 1: 172
installment_rate 0 1 TRUE 4 < 2: 476, 25 : 231, 20 : 157, >= : 136
personal_status_sex 0 1 FALSE 4 mal: 548, fem: 310, fem: 92, mal: 50
other_debtors 0 1 FALSE 3 non: 907, gua: 52, co-: 41
present_residence 0 1 TRUE 4 >= : 413, 1 <: 308, 4 <: 149, < 1: 130
property 0 1 FALSE 4 bui: 332, unk: 282, car: 232, rea: 154
other_installment_plans 0 1 FALSE 3 non: 814, ban: 139, sto: 47
housing 0 1 FALSE 3 ren: 714, for: 179, own: 107
number_credits 0 1 TRUE 4 1: 633, 2-3: 333, 4-5: 28, >= : 6
job 0 1 FALSE 4 ski: 630, uns: 200, man: 148, une: 22
people_liable 0 1 FALSE 2 0 t: 845, 3 o: 155
telephone 0 1 FALSE 2 no: 596, yes: 404
foreign_worker 0 1 FALSE 2 no: 963, yes: 37
credit_risk 0 1 FALSE 2 goo: 700, bad: 300

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
duration 0 1 20.90 12.06 4 12.0 18.0 24.0 72 ▇▇▂▁▁
amount 0 1 3271.25 2822.75 250 1365.5 2319.5 3972.2 18424 ▇▂▁▁▁
age 0 1 35.54 11.35 19 27.0 33.0 42.0 75 ▇▆▃▁▁

Exercises:

Now, we can start building a model. To do so, we need to address the following questions:

  • What is the problem we are trying to solve?
  • What is an appropriate learning algorithm?
  • How do we evaluate “good” performance?

More systematically, these questions can be addressed in mlr3 via five components:

  • The Task definition.
  • The Learner definition.
  • The training via $train().
  • The prediction via $predict().
  • The evaluation via $score().
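As a preview, the five components can be sketched end to end. This is a minimal illustration using the built-in german_credit task and a decision tree as a placeholder learner; the exercises below walk through each step in detail.

```r
library(mlr3verse)

task = tsk("german_credit")                        # 1. Task definition
learner = lrn("classif.rpart")                     # 2. Learner definition
learner$train(task, row_ids = 1:800)               # 3. Training
pred = learner$predict(task, row_ids = 801:1000)   # 4. Prediction
pred$score(msr("classif.ce"))                      # 5. Evaluation
```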

Create a Classification Task

Install and load the mlr3verse package, a collection of add-on packages from the mlr3 universe (if installing mlr3verse fails, try to install and load only the mlr3 and mlr3learners packages). Then, create a classification task using the data as input and credit_risk as the target variable. By defining an mlr3 task, we conceptualize the ML problem we want to solve (here, a classification task). Make sure to properly specify the positive class, i.e., the class label for which we would like to predict probabilities - here good, since we are interested in predicting the probability that a customer is creditworthy.

Hint 1:

Use e.g. as_task_classif() to create a classification task.

Hint 2:
library(mlr3verse)
task = as_task_classif(x = ..., target = ..., ... = "good")

Solution:

Click me:

To initialize a TaskClassif object, two equivalent calls exist:

library("mlr3verse")
task = TaskClassif$new("german_credit", backend = german, target = "credit_risk", positive = "good")
task = as_task_classif(german, target = "credit_risk", positive = "good")
task
## <TaskClassif:german> (1000 x 21)
## * Target: credit_risk
## * Properties: twoclass
## * Features (20):
##   - fct (14): credit_history, employment_duration, foreign_worker, housing, job, other_debtors,
##     other_installment_plans, people_liable, personal_status_sex, property, purpose, savings,
##     status, telephone
##   - int (3): age, amount, duration
##   - ord (3): installment_rate, number_credits, present_residence

Alternatively, we can directly use the task included in mlr3:

library("mlr3verse")
task = tsk("german_credit")
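Either way, the resulting Task object can be inspected via its fields, e.g. (assuming the `task` object created above):

```r
# Basic fields of the Task object
task$nrow                # number of observations: 1000
task$ncol                # target + features: 21
task$target_names        # name of the target column
task$positive            # the positive class label
head(task$feature_names) # first few feature names
```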

Split Data in Training and Test Data

Your task is to split the dataset into 70 % training data and 30 % test data by randomly sampling rows. Later, we will use the training data to learn an ML model and use the test data to assess its performance.

Recap: Why do we need train and test data?

We use part of the available data (the training data) to train our model. The remaining/hold-out data (test data) is used to evaluate the trained model. This is exactly how we anticipate using the model in practice: We want to fit the model to existing data and then make predictions on new, unseen data points for which we do not know the outcome/target values.

Note: Hold-out splitting requires a dataset that is sufficiently large such that both the training and test dataset are suitable representations of the target population. What “sufficiently large” means depends on the dataset at hand and the complexity of the problem.

The ratio of training to test data is also context dependent. In practice, a 70% to 30% (~ 2:1) ratio is a good starting point.
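Conceptually, such a split is nothing more than random sampling of row indices. A base-R sketch of a 70/30 split (equivalent to what partition() does with ratio = 0.7) could look like this, with n = 1000 matching the German credit data:

```r
set.seed(100L)
n = 1000L
# sample 70 % of the row indices for training
train_ids = sample(n, size = floor(0.7 * n))
# the remaining 30 % form the test set
test_ids = setdiff(seq_len(n), train_ids)
length(train_ids)  # 700
length(test_ids)   # 300
```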

Hint 1:

Use partition() on the task to create a train index and a test index.

Set a seed (e.g., set.seed(100L)) to make your results reproducible.

Hint 2:
# Sample ids for training and test split
set.seed(100L)
splits = partition(...)

Solution:

Click me:

We sample row ids by using partition(). Note that partition() uses a training ratio of 2/3 by default; you can pass ratio = 0.7 for an exact 70/30 split. Here, we keep the default, which is close to the suggested ratio.

set.seed(100L)
splits = partition(task)
splits
## $train
##   [1]    1    2    4    6    7    9   10   11   12   13   14   15   16   18   19   21   23   24   26   27   28
##  [22]   31   32   34   36   37   38   42   45   46   47   48   53   54   55   56   60   61   64   65   66   67
##  [43]   68   71   72   76   77   78   79   82   83   84   85   86   87   88   90   91   92   93   94   95   96
##  [64]   98  100  102  104  107  108  111  112  114  115  116  117  118  119  120  121  124  125  128  129  130
##  [85]  131  133  134  135  136  137  138  140  142  144  145  146  147  150  153  154  155  156  157  158  161
## [106]  163  165  166  167  168  169  170  171  173  174  175  177  183  184  185  187  192  193  194  197  198
## [127]  201  202  203  205  206  209  210  213  214  218  220  221  222  223  224  227  228  229  232  233  237
## [148]  239  240  241  242  243  244  246  247  248  249  250  251  254  255  256  258  259  261  262  263  264
## [169]  267  268  269  270  272  273  274  276  277  278  280  281  282  283  284  286  287  288  289  290  291
## [190]  292  293  296  297  298  299  300  301  302  303  304  305  306  307  308  309  310  313  314  316  318
## [211]  321  322  323  324  325  326  327  328  329  330  332  334  335  336  337  338  339  340  341  342  345
## [232]  346  347  349  350  351  353  354  355  358  359  362  363  364  366  367  370  371  372  374  376  377
## [253]  378  379  380  383  385  387  389  390  392  393  394  395  396  397  400  401  402  403  404  405  406
## [274]  410  411  415  417  420  421  422  423  424  425  426  427  428  429  430  431  435  436  437  439  442
## [295]  444  447  448  449  450  456  457  458  459  461  462  463  464  465  466  467  469  470  471  473  474
## [316]  476  477  478  479  480  481  482  483  486  487  488  489  490  491  492  493  494  495  496  497  500
## [337]  502  503  504  505  506  507  509  510  514  515  516  518  519  520  521  523  525  526  528  530  532
## [358]  533  536  537  538  541  543  544  546  548  549  550  551  552  553  554  555  556  557  558  559  560
## [379]  564  567  568  570  572  573  576  577  580  584  585  587  590  592  593  594  596  597  598  599  600
## [400]  605  606  607  609  611  613  614  615  616  618  619  623  624  625  627  628  630  631  632  633  634
## [421]  636  637  639  640  641  642  643  644  646  648  650  651  652  653  655  656  658  659  660  661  662
## [442]  663  666  668  669  670  671  675  679  682  683  685  689  690  692  693  694  695  696  697  699  702
## [463]  703  704  705  706  708  709  710  711  713  714  715  716  717  718  719  720  721  722  723  728  729
## [484]  730  731  732  733  734  735  737  738  739  742  745  746  747  749  753  755  756  757  758  759  760
## [505]  761  762  766  768  769  770  771  772  773  775  778  779  780  784  785  787  789  790  791  792  793
## [526]  794  795  796  797  799  800  801  803  805  806  807  809  811  813  814  816  817  818  819  822  823
## [547]  825  826  829  831  832  834  835  836  838  839  841  842  843  844  846  847  848  849  851  852  853
## [568]  856  857  858  859  861  862  863  864  865  866  867  868  869  871  873  877  878  881  882  883  885
## [589]  888  889  891  892  893  894  895  897  898  899  900  901  902  903  905  908  910  911  914  915  916
## [610]  917  918  919  920  923  924  925  926  927  928  931  932  934  935  936  938  939  940  942  943  944
## [631]  945  946  947  948  949  950  951  952  953  954  955  956  957  958  960  966  967  968  969  970  971
## [652]  972  973  975  976  977  978  981  982  983  985  988  989  990  991  993  994  996  998 1000
## 
## $test
##   [1]   3   5   8  17  20  22  25  29  30  33  35  39  40  41  43  44  49  50  51  52  57  58  59  62  63  69
##  [27]  70  73  74  75  80  81  89  97  99 101 103 105 106 109 110 113 122 123 126 127 132 139 141 143 148 149
##  [53] 151 152 159 160 162 164 172 176 178 179 180 181 182 186 188 189 190 191 195 196 199 200 204 207 208 211
##  [79] 212 215 216 217 219 225 226 230 231 234 235 236 238 245 252 253 257 260 265 266 271 275 279 285 294 295
## [105] 311 312 315 317 319 320 331 333 343 344 348 352 356 357 360 361 365 368 369 373 375 381 382 384 386 388
## [131] 391 398 399 407 408 409 412 413 414 416 418 419 432 433 434 438 440 441 443 445 446 451 452 453 454 455
## [157] 460 468 472 475 484 485 498 499 501 508 511 512 513 517 522 524 527 529 531 534 535 539 540 542 545 547
## [183] 561 562 563 565 566 569 571 574 575 578 579 581 582 583 586 588 589 591 595 601 602 603 604 608 610 612
## [209] 617 620 621 622 626 629 635 638 645 647 649 654 657 664 665 667 672 673 674 676 677 678 680 681 684 686
## [235] 687 688 691 698 700 701 707 712 724 725 726 727 736 740 741 743 744 748 750 751 752 754 763 764 765 767
## [261] 774 776 777 781 782 783 786 788 798 802 804 808 810 812 815 820 821 824 827 828 830 833 837 840 845 850
## [287] 854 855 860 870 872 874 875 876 879 880 884 886 887 890 896 904 906 907 909 912 913 921 922 929 930 933
## [313] 937 941 959 961 962 963 964 965 974 979 980 984 986 987 992 995 997 999
## 
## $validation
## integer(0)

Train a Model on the Training Dataset

The created Task contains the data we want to work with. Now that we conceptualized the ML task (i.e., classification) in a Task object, it is time to train our first supervised learning method. We start with a simple classifier: a logistic regression model. During this course, you will, of course, also gain experience with more complex models.

Fit a logistic regression model to the german_credit task using only the training data.

Hint 1:

Use lrn() to initialize a Learner object. The shortcut and therefore input to this method is "classif.log_reg".

To train a model, use the $train() method of your instantiated learner with the task of the previous exercise as an input.

Hint 2:
logreg = lrn("classif.log_reg")
logreg$train(..., row_ids = ...)

Solution:

Click me:

By using the syntactic sugar method lrn(), we first initialize a LearnerClassif model. Using the $train() method, we derive optimal parameters (i.e., coefficients) for our logistic regression model. With row_ids = splits$train, we specify that only the training data should be used.

logreg = lrn("classif.log_reg")
logreg$train(task, row_ids = splits$train)

Inspect the Model

Have a look at the coefficients by using summary(). Name at least two features that have a significant effect on the outcome.

Hint 1:

Use the summary() method on the model field of our trained model. By looking at task$positive, we can see which of the two classes, good or bad, is used as the positive class (i.e., the class to which the model predictions refer).

Hint 2:
summary(yourmodel$model)

Solution:

Click me:

Similar to models fitted via glm() or lm(), we can obtain a summary of the coefficients (including p-values) using summary().

summary(logreg$model)
## 
## Call:
## stats::glm(formula = form, family = "binomial", data = data, 
##     model = FALSE)
## 
## Coefficients:
##                                                              Estimate  Std. Error z value       Pr(>|z|)    
## (Intercept)                                                -3.1543532   1.3701908   -2.30        0.02133 *  
## age                                                         0.0144578   0.0118059    1.22        0.22072    
## amount                                                     -0.0001367   0.0000577   -2.37        0.01784 *  
## credit_historycritical account/other credits elsewhere      0.3696507   0.7352414    0.50        0.61513    
## credit_historyno credits taken/all credits paid back duly   1.2150620   0.5739264    2.12        0.03425 *  
## credit_historyexisting credits paid back duly till now      0.9700566   0.6264085    1.55        0.12148    
## credit_historyall credits at this bank paid back duly       1.9434478   0.5820557    3.34        0.00084 ***
## duration                                                   -0.0295915   0.0122968   -2.41        0.01611 *  
## employment_duration< 1 yr                                  -0.3372155   0.5427375   -0.62        0.53439    
## employment_duration1 <= ... < 4 yrs                        -0.0218998   0.5238527   -0.04        0.96665    
## employment_duration4 <= ... < 7 yrs                         0.4897266   0.5609702    0.87        0.38266    
## employment_duration>= 7 yrs                                 0.1104749   0.5437570    0.20        0.83900    
## foreign_workeryes                                           1.7140875   0.7698682    2.23        0.02598 *  
## housingrent                                                 0.5610512   0.2923332    1.92        0.05496 .  
## housingown                                                  0.7080619   0.6038881    1.17        0.24099    
## installment_rate.L                                         -0.8074326   0.2728663   -2.96        0.00309 ** 
## installment_rate.Q                                          0.0594248   0.2479357    0.24        0.81058    
## installment_rate.C                                         -0.0351134   0.2540488   -0.14        0.89007    
## jobunskilled - resident                                    -0.3768399   0.9394487   -0.40        0.68833    
## jobskilled employee/official                               -0.3329590   0.9131823   -0.36        0.71540    
## jobmanager/self-empl/highly qualif. employee                0.1366757   0.9209889    0.15        0.88203    
## number_credits.L                                            0.1877171   1.0070714    0.19        0.85213    
## number_credits.Q                                            0.3365599   0.8389922    0.40        0.68831    
## number_credits.C                                           -0.3752939   0.5937729   -0.63        0.52735    
## other_debtorsco-applicant                                  -0.3244795   0.4943771   -0.66        0.51161    
## other_debtorsguarantor                                      1.2699170   0.5751648    2.21        0.02725 *  
## other_installment_plansstores                               0.7719189   0.5507474    1.40        0.16104    
## other_installment_plansnone                                 0.9052032   0.3082859    2.94        0.00332 ** 
## people_liable3 or more                                     -0.3713109   0.3276947   -1.13        0.25717    
## personal_status_sexfemale : non-single or male : single     1.0714349   0.5242157    2.04        0.04097 *  
## personal_status_sexmale : married/widowed                   1.5086800   0.5164705    2.92        0.00349 ** 
## personal_status_sexfemale : single                          0.8203908   0.6011491    1.36        0.17235    
## present_residence.L                                        -0.1005453   0.2661614   -0.38        0.70561    
## present_residence.Q                                         0.5619816   0.2521788    2.23        0.02585 *  
## present_residence.C                                        -0.5056580   0.2615088   -1.93        0.05316 .  
## propertycar or other                                       -0.4732924   0.3183653   -1.49        0.13711    
## propertybuilding soc. savings agr. / life insurance        -0.0186494   0.2964523   -0.06        0.94984    
## propertyreal estate                                        -0.7649722   0.5501163   -1.39        0.16436    
## purposecar (new)                                            1.4977168   0.4521313    3.31        0.00092 ***
## purposecar (used)                                           0.4228846   0.3261269    1.30        0.19474    
## purposefurniture/equipment                                  0.7755526   0.3224826    2.40        0.01617 *  
## purposeradio/television                                    -0.2887552   0.9525555   -0.30        0.76179    
## purposedomestic appliances                                 -0.5714243   0.6578263   -0.87        0.38504    
## purposerepairs                                             -0.1004203   0.5291260   -0.19        0.84948    
## purposevacation                                            15.2660806 543.5461693    0.03        0.97759    
## purposeretraining                                           1.0739961   0.4348230    2.47        0.01351 *  
## purposebusiness                                             3.2625392   1.4748082    2.21        0.02695 *  
## savings... < 100 DM                                         0.3224494   0.3700117    0.87        0.38350    
## savings100 <= ... < 500 DM                                  0.1003430   0.4800239    0.21        0.83442    
## savings500 <= ... < 1000 DM                                 1.6775820   0.6328090    2.65        0.00803 ** 
## savings... >= 1000 DM                                       0.8220917   0.3243537    2.53        0.01126 *  
## status... < 0 DM                                            0.4836803   0.2758264    1.75        0.07950 .  
## status0<= ... < 200 DM                                      1.1400132   0.4682444    2.43        0.01491 *  
## status... >= 200 DM / salary for at least 1 year            1.9430728   0.2993987    6.49 0.000000000086 ***
## telephoneyes (under customer name)                          0.1103368   0.2496382    0.44        0.65850    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 825.22  on 669  degrees of freedom
## Residual deviance: 576.59  on 615  degrees of freedom
## AIC: 686.6
## 
## Number of Fisher Scoring iterations: 14

According to the summary, e.g., credit_history and status significantly influence the creditworthiness and the bank’s risk assessment. By looking at task$positive, we see that the class good (creditworthy client) is the positive class. This means that a positive sign of an estimated coefficient indicates that the corresponding feature value increases the predicted probability of being a creditworthy client (while a negative sign decreases it).

task$positive
## [1] "good"

For example, the positive coefficient of credit_history = all credits at this bank paid back duly indicates a higher probability of being a creditworthy client compared to the reference class credit_history = delay in paying off in the past. Likewise, the positive coefficients of status = ... >= 200 DM / salary for at least 1 year and status = 0 <= ... < 200 DM indicate a positive influence relative to the reference class of status.
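Since logistic regression coefficients live on the log-odds scale, exponentiating them yields odds ratios, which are often easier to interpret: a value above 1 means the feature level increases the odds of the positive class good relative to the reference level. A short sketch, assuming the trained `logreg` learner from above:

```r
# Coefficients are on the log-odds scale; exponentiating yields odds ratios
odds_ratios = exp(coef(logreg$model))
# show the levels with the strongest positive association with "good"
head(sort(odds_ratios, decreasing = TRUE))
```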

Predict on the Test Dataset

Use the trained model to predict on the hold-out/test dataset.

Hint 1

Use $predict() with row_ids.

Hint 2
pred = yourmodel$predict(..., row_ids = ...)

Solution:

Click me:
pred_logreg = logreg$predict(task, row_ids = splits$test)
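The resulting PredictionClassif object can be inspected directly, e.g., by viewing the first predictions as a table or by cross-tabulating true and predicted labels:

```r
# First rows of the prediction (row_ids, truth, response)
head(as.data.table(pred_logreg))
# Confusion matrix: true vs. predicted labels
pred_logreg$confusion
```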

Evaluation

What is the classification error on the test data (330 observations)?

Hint 1:

The classification error gives the rate of observations that were misclassified. Use the $score() method on the corresponding PredictionClassif object of the previous exercise.

Hint 2:
pred_logreg$score()

Solution:

Click me:

By using the $score() method, we obtain an estimate for the classification error of our model.

pred_logreg$score()
## classif.ce 
##    0.24848

The classification error is about 0.248 - so 24.8 % of the test instances were misclassified by our logistic regression model.
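As a sanity check, the classification error can also be computed by hand as the share of misclassified test observations, or requested via an explicit measure object:

```r
# Manual computation: share of test observations where the
# predicted label differs from the true label
mean(pred_logreg$truth != pred_logreg$response)
# Equivalent, via an explicit measure object
pred_logreg$score(msr("classif.ce"))
```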

Predicting Probabilities Instead of Labels

Similarly, we can assess the performance of our model using the Brier score. However, this requires predicted probabilities instead of predicted labels. Evaluate the model using the Brier score. To do so, retrain the model with a learner that returns probabilities.

Hint 1:

You can generate predictions with probabilities by specifying a predict_type argument inside the lrn() function call when constructing a learner.

Hint 2:

You can get an overview of performance measures in mlr3 using as.data.table(msr()).

Solution:

Click me:
# Train a learner that predicts probabilities, using only the training data
logreg = lrn("classif.log_reg", predict_type = "prob")
logreg$train(task, row_ids = splits$train)
# Generate predictions on the test data
pred_logreg = logreg$predict(task, row_ids = splits$test)
# Evaluate performance using the Brier score
measure = msr("classif.bbrier")
pred_logreg$score(measure)
## classif.bbrier 
##        0.14075
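The binary Brier score can also be verified by hand: it is the mean squared difference between the predicted probability for the positive class and the 0/1-encoded truth. A sketch, assuming the probability predictions from above:

```r
# Predicted probability for the positive class "good"
p = pred_logreg$prob[, "good"]
# Encode truth as 1 (good) / 0 (bad)
y = as.integer(pred_logreg$truth == "good")
# Mean squared difference = binary Brier score
mean((p - y)^2)
```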

Summary

In this exercise sheet, we learned how to fit a logistic regression model on a training task and how to assess its performance on unseen test data with the help of mlr3. We showed how to split data manually into training and test sets; in most scenarios, however, this is handled by a call to resample() or benchmark(). We will learn more about this in the next sections.