This class implements the Connection Weights method investigated by Olden et al. (2004), which yields a feature relevance score for each input variable. The basic idea is to multiply the weights along every path connecting an input feature to an output node and then to sum these products over all such paths. Moreover, it is a global interpretation method, i.e. it is independent of the input data. For example, for a neural network with two hidden layers and weight matrices $W_1$, $W_2$ and $W_3$, this method reduces to the matrix product $W_1 \cdot W_2 \cdot W_3$. The resulting matrix has size (dim_in, dim_out), and its entry $(i, j)$ is the relevance score of input feature $i$ for output $j$.

## References

• J. D. Olden et al. (2004) An accurate comparison of methods for quantifying variable importance in artificial neural networks using simulated data. Ecological Modelling 178, pp. 389–397

## Public fields

converter

The converter of class Converter with the stored and torch-converted model.

channels_first

The data format of the result, i.e. whether the channels are on the last dimension (FALSE) or on the first dimension (TRUE). If the data has no channels, use the default value TRUE.

dtype

The type of the data and parameters (either 'float' for torch::torch_float or 'double' for torch::torch_double).

result

The method's result as a torch tensor of size (dim_in, dim_out) with data type dtype.

output_idx

This vector determines the outputs to which the method is applied. By default (NULL), all outputs are considered, limited to the first 10.

## Methods

### Method get_result()

This method returns the result of the Connection Weights method in the requested data type.

#### Arguments

type

The data type of the result. Use one of 'array', 'torch.tensor', 'torch_tensor' or 'data.frame' (default: 'array').

#### Returns

The result of this method in the chosen data type.

### Method plot()

This method visualizes the result of the ConnectionWeights method in a ggplot2::ggplot. You can use the argument output_idx to select individual output nodes for the plot. The different results for the selected outputs are visualized using the method ggplot2::facet_grid. You can also use the as_plotly argument to generate an interactive plot based on the plot function plotly::plot_ly.

### Method clone()

The objects of this class are cloneable with this method.

#### Arguments

deep

Whether to make a deep clone.

## Examples

#----------------------- Example 1: Torch ----------------------------------
library(torch)

# Create nn_sequential model
model <- nn_sequential(
nn_linear(5, 12),
nn_relu(),
nn_linear(12, 1),
nn_sigmoid()
)

# Create Converter with input names
converter <- Converter$new(model,
  input_dim = c(5),
  input_names = list(c("Car", "Cat", "Dog", "Plane", "Horse"))
)

# Apply method Connection Weights
cw <- ConnectionWeights$new(converter)
#> Backward pass 'ConnectionWeights':
#>
#>   |                                                                      |   0%
#>   |===================================                                   |  50%
#>   |======================================================================| 100%

# Print the result as a data.frame
cw$get_result("data.frame")
#>   feature class        value
#> 1     Car    Y1 -0.006045442
#> 2     Cat    Y1  0.055774860
#> 3     Dog    Y1 -0.048691414
#> 4   Plane    Y1 -0.023640137
#> 5   Horse    Y1 -0.064227775

# Plot the result
plot(cw)

#----------------------- Example 2: Neuralnet ------------------------------
library(neuralnet)
data(iris)

# Train a Neural Network
nn <- neuralnet((Species == "setosa") ~ Petal.Length + Petal.Width,
  iris,
  linear.output = FALSE,
  hidden = c(3, 2), act.fct = "tanh", rep = 1
)

# Convert the trained model
converter <- Converter$new(nn)

# Apply the Connection Weights method
cw <- ConnectionWeights$new(converter)
#> Backward pass 'ConnectionWeights':
#>
#>   |                                                                      |   0%
#>   |=======================                                               |  33%
#>   |===============================================                       |  67%
#>   |======================================================================| 100%

# Get the result as a torch tensor
cw$get_result(type = "torch.tensor")
#> torch_tensor
#>  1.4929
#> -12.9609
#> [ CPUFloatType{2,1} ]

# Plot the result
plot(cw)

#----------------------- Example 3: Keras ----------------------------------
library(keras)

if (is_keras_available()) {
# Define a model
model <- keras_model_sequential()
model %>%
layer_conv_1d(
input_shape = c(64, 3), kernel_size = 16, filters = 8,
activation = "softplus"
) %>%
layer_conv_1d(kernel_size = 16, filters = 4, activation = "tanh") %>%
layer_conv_1d(kernel_size = 16, filters = 2, activation = "relu") %>%
layer_flatten() %>%
layer_dense(units = 64, activation = "relu") %>%
layer_dense(units = 2, activation = "softmax")

# Convert the model
converter <- Converter$new(model)

# Apply the Connection Weights method
cw <- ConnectionWeights$new(converter)

# Get the result as data.frame
cw$get_result(type = "data.frame")

# Plot the result for all classes
plot(cw, output_idx = 1:2)
}

# ------------------------- Advanced: Plotly -------------------------------
# If you want to create an interactive plot of your results with custom
# changes, you can make use of the method plotly::ggplotly
library(ggplot2)
library(plotly)
#>
#> Attaching package: ‘plotly’
#> The following object is masked from ‘package:ggplot2’:
#>
#>     last_plot
#> The following object is masked from ‘package:stats’:
#>
#>     filter
#> The following object is masked from ‘package:graphics’:
#>
#>     layout
library(neuralnet)
data(iris)

nn <- neuralnet(Species ~ .,
  iris,
  linear.output = FALSE,
  hidden = c(10, 8), act.fct = "tanh", rep = 1, threshold = 0.5
)
# create a converter for this model
converter <- Converter$new(nn)

# create new instance of 'ConnectionWeights'
cw <- ConnectionWeights$new(converter)
#> Backward pass 'ConnectionWeights':
#>
#>   |                                                                      |   0%
#>   |=======================                                               |  33%
#>   |===============================================                       |  67%
#>   |======================================================================| 100%

library(plotly)

p <- plot(cw, output_idx = 1) +
theme_bw() +
scale_fill_gradient2(low = "green", mid = "black", high = "blue")
#> Scale for 'fill' is already present. Adding another scale for 'fill', which
#> will replace the existing scale.

# Now apply the method plotly::ggplotly with argument tooltip = "text"
plotly::ggplotly(p, tooltip = "text")
#> Warning: gather_() was deprecated in tidyr 1.2.0.
#> Please use gather() instead.
#> This warning is displayed once every 8 hours.
#> Call lifecycle::last_lifecycle_warnings() to see where this warning was generated.

{"x":{"data":[{"x":[2.7,2.7,3.3,3.3,2.7],"y":[0,3.84830331802368,3.84830331802368,0,0],"text":"<b><\/br> Relative Importance : 3.848 <\/b><br /> <\/br> Output:      setosa <\/br> Feature:     Petal.Length","type":"scatter","mode":"lines","line":{"width":1.88976377952756,"color":"transparent","dash":"solid"},"fill":"toself","fillcolor":"rgba(21,9,35,1)","hoveron":"fills","showlegend":false,"xaxis":"x","yaxis":"y","hoverinfo":"text","frame":null},{"x":[1.7,1.7,2.3,2.3,1.7],"y":[0,10.735897064209,10.735897064209,0,0],"text":"<b><\/br> Relative Importance : 10.74 <\/b><br /> <\/br> Output:      setosa <\/br> Feature:     Sepal.Width","type":"scatter","mode":"lines","line":{"width":1.88976377952756,"color":"transparent","dash":"solid"},"fill":"toself","fillcolor":"rgba(33,16,87,1)","hoveron":"fills","showlegend":false,"xaxis":"x","yaxis":"y","hoverinfo":"text","frame":null},{"x":[0.7,0.7,1.3,1.3,0.7],"y":[0,12.806604385376,12.806604385376,0,0],"text":"<b><\/br> Relative Importance : 12.81 <\/b><br /> <\/br> Output:      setosa <\/br> Feature:     Sepal.Length","type":"scatter","mode":"lines","line":{"width":1.88976377952756,"color":"transparent","dash":"solid"},"fill":"toself","fillcolor":"rgba(35,17,105,1)","hoveron":"fills","showlegend":false,"xaxis":"x","yaxis":"y","hoverinfo":"text","frame":null},{"x":[3.7,3.7,4.3,4.3,3.7],"y":[0,29.1362800598145,29.1362800598145,0,0],"text":"<b><\/br> Relative Importance : 29.14 <\/b><br /> <\/br> Output:      setosa <\/br> Feature:     
Petal.Width","type":"scatter","mode":"lines","line":{"width":1.88976377952756,"color":"transparent","dash":"solid"},"fill":"toself","fillcolor":"rgba(0,0,255,1)","hoveron":"fills","showlegend":false,"xaxis":"x","yaxis":"y","hoverinfo":"text","frame":null},{"x":[0.4,4.6],"y":[0,0],"text":"","type":"scatter","mode":"lines","line":{"width":1.88976377952756,"color":"rgba(0,0,0,1)","dash":"solid"},"hoveron":"points","showlegend":false,"xaxis":"x","yaxis":"y","hoverinfo":"text","frame":null}],"layout":{"margin":{"t":34.9954337899543,"r":18.9954337899543,"b":37.2602739726027,"l":37.2602739726027},"plot_bgcolor":"rgba(255,255,255,1)","paper_bgcolor":"rgba(255,255,255,1)","font":{"color":"rgba(0,0,0,1)","family":"","size":14.6118721461187},"xaxis":{"domain":[0,1],"automargin":true,"type":"linear","autorange":false,"range":[0.4,4.6],"tickmode":"array","ticktext":["Sepal.Length","Sepal.Width","Petal.Length","Petal.Width"],"tickvals":[1,2,3,4],"categoryorder":"array","categoryarray":["Sepal.Length","Sepal.Width","Petal.Length","Petal.Width"],"nticks":null,"ticks":"outside","tickcolor":"rgba(51,51,51,1)","ticklen":3.65296803652968,"tickwidth":0.66417600664176,"showticklabels":true,"tickfont":{"color":"rgba(77,77,77,1)","family":"","size":11.689497716895},"tickangle":-0,"showline":false,"linecolor":null,"linewidth":0,"showgrid":true,"gridcolor":"rgba(235,235,235,1)","gridwidth":0.66417600664176,"zeroline":false,"anchor":"y","title":"","hoverformat":".2f"},"annotations":[{"text":"Feature","x":0.5,"y":0,"showarrow":false,"ax":0,"ay":0,"font":{"color":"rgba(0,0,0,1)","family":"","size":14.6118721461187},"xref":"paper","yref":"paper","textangle":-0,"xanchor":"center","yanchor":"top","annotationType":"axis","yshift":-21.9178082191781},{"text":"Relative 
Importance","x":0,"y":0.5,"showarrow":false,"ax":0,"ay":0,"font":{"color":"rgba(0,0,0,1)","family":"","size":14.6118721461187},"xref":"paper","yref":"paper","textangle":-90,"xanchor":"right","yanchor":"center","annotationType":"axis","xshift":-21.9178082191781},{"text":"setosa","x":0.5,"y":1,"showarrow":false,"ax":0,"ay":0,"font":{"color":"rgba(26,26,26,1)","family":"","size":11.689497716895},"xref":"paper","yref":"paper","textangle":-0,"xanchor":"center","yanchor":"bottom"}],"yaxis":{"domain":[0,1],"automargin":true,"type":"linear","autorange":false,"range":[-1.45681400299072,30.5930940628052],"tickmode":"array","ticktext":["0","10","20","30"],"tickvals":[0,10,20,30],"categoryorder":"array","categoryarray":["0","10","20","30"],"nticks":null,"ticks":"outside","tickcolor":"rgba(51,51,51,1)","ticklen":3.65296803652968,"tickwidth":0.66417600664176,"showticklabels":true,"tickfont":{"color":"rgba(77,77,77,1)","family":"","size":11.689497716895},"tickangle":-0,"showline":false,"linecolor":null,"linewidth":0,"showgrid":true,"gridcolor":"rgba(235,235,235,1)","gridwidth":0.66417600664176,"zeroline":false,"anchor":"x","title":"","hoverformat":".2f"},"shapes":[{"type":"rect","fillcolor":"transparent","line":{"color":"rgba(51,51,51,1)","width":0.66417600664176,"linetype":"solid"},"yref":"paper","xref":"paper","x0":0,"x1":1,"y0":0,"y1":1},{"type":"rect","fillcolor":"rgba(217,217,217,1)","line":{"color":"rgba(51,51,51,1)","width":0.66417600664176,"linetype":"solid"},"yref":"paper","xref":"paper","x0":0,"x1":1,"y0":0,"y1":23.37899543379,"yanchor":1,"ysizemode":"pixel"}],"showlegend":false,"legend":{"bgcolor":"rgba(255,255,255,1)","bordercolor":"transparent","borderwidth":1.88976377952756,"font":{"color":"rgba(0,0,0,1)","family":"","size":11.689497716895},"title":{"text":"","font":{"color":"rgba(0,0,0,1)","family":"","size":14.6118721461187}}},"hovermode":"closest","barmode":"relative"},"config":{"doubleClick":"reset","modeBarButtonsToAdd":["hoverclosest","hovercompare"],"showSendT
oCloud":false},"source":"A","attrs":{"2a6e4c1faec7":{"xmin":{},"xmax":{},"ymin":{},"ymax":{},"fill":{},"text":{},"type":"scatter"},"2a6e12d4b962":{"yintercept":{}}},"cur_data":"2a6e4c1faec7","visdat":{"2a6e4c1faec7":["function (y) ","x"],"2a6e12d4b962":["function (y) ","x"]},"highlight":{"on":"plotly_click","persistent":false,"dynamic":false,"selectize":false,"opacityDim":0.2,"selected":{"opacity":1},"debounce":0},"shinyEvents":["plotly_hover","plotly_click","plotly_selected","plotly_relayout","plotly_brushed","plotly_brushing","plotly_clickannotation","plotly_doubleclick","plotly_deselect","plotly_afterplot","plotly_sunburstclick"],"base_url":"https://plot.ly"},"evals":[],"jsHooks":[]}