This is an implementation of the Deep Learning Important FeaTures (DeepLift) algorithm introduced by Shrikumar et al. (2017). It is a local method that interprets a single instance \(x\) of the dataset with respect to a reference value \(x'\) and returns the contribution of each input feature to the difference between the prediction (\(y = f(x)\)) and the reference prediction (\(y' = f(x')\)). The basic idea of this method is to decompose the difference-from-reference prediction with respect to the input features, i.e., $$\Delta y = y - y' = \sum_i C(x_i).$$ In contrast to Layer-wise Relevance Propagation (see LRP), DeepLift is an exact decomposition and not an approximation, so we obtain the true contributions of the input features to the difference-from-reference prediction. There are two ways to handle activation functions: the rescale rule ('rescale') and the reveal-cancel rule ('reveal_cancel').
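To make the decomposition concrete, here is a minimal base-R sketch of the rescale rule on a tiny 2-1-1 network. This is not the package's internal implementation; the weights, input, and reference are made-up illustrative values.

```r
# Tiny network: y = w2 * relu(w1 %*% x + b1) + b2 (all values illustrative)
relu <- function(z) pmax(z, 0)

w1 <- c(0.6, -0.4); b1 <- 0.1   # hidden layer weights and bias
w2 <- 1.5;          b2 <- -0.2  # output layer weight and bias

x     <- c(2, 1)   # input instance to explain
x_ref <- c(0, 0)   # reference input (the default used by DeepLift$new)

z <- sum(w1 * x) + b1;  z_ref <- sum(w1 * x_ref) + b1   # pre-activations
a <- relu(z);           a_ref <- relu(z_ref)            # activations
y <- w2 * a + b2;       y_ref <- w2 * a_ref + b2        # outputs

# Rescale rule: the ReLU's multiplier is Delta output / Delta input
m_relu <- (a - a_ref) / (z - z_ref)

# Chain of multipliers: input -> z (weights) -> a (rescale) -> y (weights)
contrib <- w1 * (x - x_ref) * m_relu * w2

# Exact decomposition: contributions sum to the difference-from-reference
all.equal(sum(contrib), y - y_ref)   # TRUE
```

Note that the biases cancel in the differences, and the feature contributions sum exactly to \(\Delta y\), which is the exactness property mentioned above.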

References

A. Shrikumar et al. (2017) Learning important features through propagating activation differences. ICML 2017, pp. 4844-4866

Super class

innsight::InterpretingMethod -> DeepLift

Public fields

x_ref

The reference input of size (1, dim_in) for the interpretation.

rule_name

Name of the applied rule to calculate the contributions for the non-linear part of a neural network layer. Either "rescale" or "reveal_cancel".
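The two rules differ in how they assign a multiplier to a nonlinearity. The following base-R sketch (with made-up pre-activation values, not innsight internals) contrasts them for a ReLU whose positive and negative input differences cancel:

```r
relu <- function(z) pmax(z, 0)

z_ref <- 0                    # reference pre-activation (illustrative)
dpos  <- 2; dneg <- -3        # positive / negative parts of Delta input
z     <- z_ref + dpos + dneg  # actual pre-activation: -1

# Rescale rule: one multiplier for the whole difference
m_rescale <- (relu(z) - relu(z_ref)) / (z - z_ref)   # 0: everything cancels

# RevealCancel rule: average the impact of dpos with and without dneg
# present (and vice versa), keeping the canceling parts separate
m_pos <- 0.5 * ((relu(z_ref + dpos) - relu(z_ref)) +
                (relu(z_ref + dneg + dpos) - relu(z_ref + dneg))) / dpos
m_neg <- 0.5 * ((relu(z_ref + dneg) - relu(z_ref)) +
                (relu(z_ref + dpos + dneg) - relu(z_ref + dpos))) / dneg

c(m_rescale, m_pos, m_neg)   # 0.0000  0.5000  0.3333
```

Here the rescale rule assigns a multiplier of zero, since the net difference cancels at the ReLU, whereas the reveal-cancel rule exposes the canceling positive and negative contributions separately; both decompositions still satisfy `m_pos * dpos + m_neg * dneg = relu(z) - relu(z_ref)`.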

Methods

Inherited methods


Method new()

Create a new instance of the DeepLift method.

Usage

DeepLift$new(
  converter,
  data,
  channels_first = TRUE,
  output_idx = NULL,
  ignore_last_act = TRUE,
  rule_name = "rescale",
  x_ref = NULL,
  dtype = "float"
)

Arguments

converter

An instance of the R6 class Converter.

data

The data for which the contribution scores are to be calculated. It has to be an array or array-like format of size (batch_size, dim_in).

channels_first

The channel position of the given data, i.e., channels on the last dimension (FALSE) or directly after the batch dimension (TRUE). If the data has no channels, use the default value TRUE.

output_idx

This vector determines for which outputs the method will be applied. By default (NULL), all outputs (but limited to the first 10) are considered.

ignore_last_act

Set this boolean value to TRUE to ignore the activation function of the last layer, or to FALSE to include it (default: TRUE). In some cases, the last activation leads to a saturation problem.

rule_name

Name of the applied rule to calculate the contributions. Use one of 'rescale' and 'reveal_cancel'.

x_ref

The reference input of size (1, dim_in) for the interpretation. With the default value NULL you use an input of zeros.

dtype

The data type for the calculations. Use either 'float' for torch::torch_float or 'double' for torch::torch_double.


Method plot()

This method visualizes the result of the selected method in a ggplot2::ggplot. You can use the argument data_idx to select the data points in the given data for the plot. In addition, the individual output nodes for the plot can be selected with the argument output_idx. The different results for the selected data points and outputs are visualized using the method ggplot2::facet_grid. You can also use the as_plotly argument to generate an interactive plot based on the plot function plotly::plot_ly.

Usage

DeepLift$plot(
  data_idx = 1,
  output_idx = NULL,
  aggr_channels = "sum",
  as_plotly = FALSE
)

Arguments

data_idx

An integer vector containing the numbers of the data points whose result is to be plotted, e.g. c(1,3) for the first and third data point in the given data. Default: c(1).

output_idx

An integer vector containing the numbers of the output indices whose result is to be plotted, e.g. c(1,4) for the first and fourth model output. However, this vector must be a subset of the output_idx argument passed at initialization; otherwise, no results were calculated for these output nodes and they cannot be plotted. By default (NULL), the smallest index of all calculated output nodes is used.

aggr_channels

Pass one of 'norm', 'sum', 'mean' or a custom function to aggregate the channels, e.g. the maximum (base::max) or minimum (base::min) over the channels or only individual channels with function(x) x[1]. By default ('sum'), the sum of all channels is used.
Note: This argument is used only for 2D and 3D inputs.

as_plotly

This boolean value (default: FALSE) can be used to create an interactive plot based on the library plotly. This function makes use of plotly::ggplotly; hence, make sure that the suggested package plotly is installed in your R session.
Advanced: You can first output the results as a ggplot (as_plotly = FALSE) and make custom changes to the plot, e.g., a different theme or fill color. Afterwards, you can manually call the function ggplotly to get an interactive plotly plot.

Returns

Returns either a ggplot2::ggplot (as_plotly = FALSE) or a plotly::plot_ly (as_plotly = TRUE) with the plotted results.


Method boxplot()

This method visualizes the results in a boxplot, where the type of visualization depends on the input dimension of the data. By default, a ggplot2::ggplot is returned, but with the argument as_plotly an interactive plotly::plot_ly plot can be created, which, however, requires a successful installation of the package plotly.

Usage

DeepLift$boxplot(
  output_idx = NULL,
  data_idx = "all",
  ref_data_idx = NULL,
  aggr_channels = "norm",
  preprocess_FUN = abs,
  as_plotly = FALSE,
  individual_data_idx = NULL,
  individual_max = 20
)

Arguments

output_idx

An integer vector containing the numbers of the output indices whose result is to be plotted, e.g. c(1,4) for the first and fourth model output. However, this vector must be a subset of the output_idx argument passed at initialization; otherwise, no results were calculated for these output nodes and they cannot be plotted. By default (NULL), the smallest index of all calculated output nodes is used.

data_idx

By default ("all"), all available data is used to calculate the boxplot information. However, this parameter can be used to select a subset of them by passing the indices. E.g. with data_idx = c(1:10, 25, 26) only the first 10 data points and the 25th and 26th are used to calculate the boxplots.

ref_data_idx

This integer number determines the index for the reference data point. In addition to the boxplots, it is displayed in red color and is used to compare an individual result with the summary statistics provided by the boxplot. With the default value (NULL) no individual data point is plotted. This index can be chosen with respect to all available data, even if only a subset is selected with argument data_idx.
Note: Because of the complexity of 3D inputs, this argument is used only for 1D and 2D inputs and disregarded for 3D inputs.

aggr_channels

Pass one of 'norm', 'sum', 'mean' or a custom function to aggregate the channels, e.g. the maximum (base::max) or minimum (base::min) over the channels or only individual channels with function(x) x[1]. By default ('norm'), the Euclidean norm of all channels is used.
Note: This argument is used only for 2D and 3D inputs.

preprocess_FUN

This function is applied to the method's result before calculating the boxplots. Since positive and negative values often cancel each other out, the absolute value (abs) is used by default. But you can also use the raw data (identity) to see the results' orientation, the squared data (function(x) x^2) to weight the outliers higher or any other function.

as_plotly

This boolean value (default: FALSE) can be used to create an interactive plot based on the library plotly instead of ggplot2. Make sure that the suggested package plotly is installed in your R session.

individual_data_idx

Only relevant for a plotly plot with input dimension 1 or 2! This integer vector of data indices determines the data points available in a dropdown menu, which can be drawn in individually, analogous to ref_data_idx but for additional data points. With the default value NULL, the first individual_max data points are used.
Note: If ref_data_idx is specified, this data point is added to those from individual_data_idx in the dropdown menu.

individual_max

Only relevant for a plotly plot with input dimension 1 or 2! This integer determines the maximum number of individual data points in the dropdown menu, not counting ref_data_idx. This means that if individual_data_idx has more than individual_max indices, only the first individual_max are used. Too high a number can significantly increase the runtime.

Returns

Returns either a ggplot2::ggplot (as_plotly = FALSE) or a plotly::plot_ly (as_plotly = TRUE) with the boxplots.


Method clone()

The objects of this class are cloneable with this method.

Usage

DeepLift$clone(deep = FALSE)

Arguments

deep

Whether to make a deep clone.

Examples

#----------------------- Example 1: Torch ----------------------------------
library(torch)

# Create nn_sequential model and data
model <- nn_sequential(
  nn_linear(5, 12),
  nn_relu(),
  nn_linear(12, 2),
  nn_softmax(dim = 2)
)
data <- torch_randn(25, 5)
ref <- torch_randn(1, 5)

# Create Converter
converter <- Converter$new(model, input_dim = c(5))

# Apply method DeepLift
deeplift <- DeepLift$new(converter, data, x_ref = ref)
#> Backward pass 'DeepLift':
#>   |======================================================================| 100%

# Print the result as a torch tensor for first two data points
deeplift$get_result("torch.tensor")[1:2]
#> torch_tensor
#> (1,.,.) = 
#>  -0.0659 -0.0058
#>   0.0335 -0.0145
#>   0.0072 -0.0262
#>   0.2201  0.2676
#>   0.2283 -0.0452
#> 
#> (2,.,.) = 
#>  -0.0551 -0.0215
#>   0.0392 -0.0136
#>   0.0362  0.0828
#>  -0.0059  0.1871
#>   0.0878  0.0303
#> [ CPUFloatType{2,5,2} ]

# Plot the result for both classes
plot(deeplift, output_idx = 1:2)


# Plot the boxplot of all datapoints
boxplot(deeplift, output_idx = 1:2)


# ------------------------- Example 2: Neuralnet ---------------------------
library(neuralnet)
data(iris)

# Train a neural network
nn <- neuralnet((Species == "setosa") ~ Petal.Length + Petal.Width,
  iris,
  linear.output = FALSE,
  hidden = c(3, 2), act.fct = "tanh", rep = 1
)

# Convert the model
converter <- Converter$new(nn)

# Apply DeepLift with rescale-rule and a reference input of the feature
# means
x_ref <- matrix(colMeans(iris[, c(3, 4)]), nrow = 1)
deeplift_rescale <- DeepLift$new(converter, iris[, c(3, 4)], x_ref = x_ref)
#> Backward pass 'DeepLift':
#>   |======================================================================| 100%

# Get the result as a dataframe and show first 5 rows
deeplift_rescale$get_result(type = "data.frame")[1:5, ]
#>     data      feature               class     value
#> 1 data_1 Petal.Length Species == "setosa" 0.1461408
#> 2 data_2 Petal.Length Species == "setosa" 0.1461408
#> 3 data_3 Petal.Length Species == "setosa" 0.1520031
#> 4 data_4 Petal.Length Species == "setosa" 0.1402526
#> 5 data_5 Petal.Length Species == "setosa" 0.1461408

# Plot the result for the first datapoint in the data
plot(deeplift_rescale, data_idx = 1)


# Plot the result as boxplots
boxplot(deeplift_rescale)


# ------------------------- Example 3: Keras -------------------------------
library(keras)

if (is_keras_available()) {
  data <- array(rnorm(10 * 32 * 32 * 3), dim = c(10, 32, 32, 3))

  model <- keras_model_sequential()
  model %>%
    layer_conv_2d(
      input_shape = c(32, 32, 3), kernel_size = 8, filters = 8,
      activation = "softplus", padding = "valid"
    ) %>%
    layer_conv_2d(
      kernel_size = 8, filters = 4, activation = "tanh",
      padding = "same"
    ) %>%
    layer_conv_2d(
      kernel_size = 4, filters = 2, activation = "relu",
      padding = "valid"
    ) %>%
    layer_flatten() %>%
    layer_dense(units = 64, activation = "relu") %>%
    layer_dense(units = 16, activation = "relu") %>%
    layer_dense(units = 2, activation = "softmax")

  # Convert the model
  converter <- Converter$new(model)

  # Apply the DeepLift method with reveal-cancel rule
  deeplift_revcancel <- DeepLift$new(converter, data,
    channels_first = FALSE,
    rule_name = "reveal_cancel"
  )

  # Plot the result for the first image and both classes
  plot(deeplift_revcancel, output_idx = 1:2)

  # Plot the result as boxplots for first class
  boxplot(deeplift_revcancel, output_idx = 1)

  # You can also create an interactive plot with plotly.
  # This is a suggested package, so make sure that it is installed
  library(plotly)
  boxplot(deeplift_revcancel, as_plotly = TRUE)
}
#> Backward pass 'DeepLift':
#>   |======================================================================| 100%