diff --git a/.Rprofile b/.Rprofile index 0201e1af91fa19a581c3afd2719a0d14bcda5d9e..4f507a9dae95fb5dbb9c33d4f46bea38fbf319a1 100644 --- a/.Rprofile +++ b/.Rprofile @@ -1,3 +1,7 @@ +if (requireNamespace("testthat", quietly = TRUE)) { + testthat::set_max_fails(Inf) +} + #' Helper function for package development #' #' This is a manual extension of [testthat::snapshot_review()] which works for the \code{.rds} files used in @@ -7,17 +11,19 @@ #' @param ... Additional arguments passed to [waldo::compare()] #' Gives the relative path to the test files to review #' -snapshot_review_man <- function(path, tolerance = NULL, ...) { - changed <- testthat:::snapshot_meta(path) - these_rds <- (tools::file_ext(changed$name) == "rds") - if (any(these_rds)) { - for (i in which(these_rds)) { - old <- readRDS(changed[i, "cur"]) - new <- readRDS(changed[i, "new"]) +snapshot_review_man <- function(path, tolerance = 10^(-5), max_diffs = 200, ...) { + if (requireNamespace("testthat", quietly = TRUE) && requireNamespace("waldo", quietly = TRUE)) { + changed <- testthat:::snapshot_meta(path) + these_rds <- (tools::file_ext(changed$name) == "rds") + if (any(these_rds)) { + for (i in which(these_rds)) { + old <- readRDS(changed[i, "cur"]) + new <- readRDS(changed[i, "new"]) - cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n")) - print(waldo::compare(old, new, max_diffs = 50, tolerance = tolerance, ...)) - browser() + cat(paste0("Difference for check ", changed[i, "name"], " in test ", changed[i, "test"], "\n")) + print(waldo::compare(old, new, max_diffs = max_diffs, tolerance = tolerance, ...)) + browser() + } } } } diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml index 2b496dba9419edd8a9aba06e37d69f3f3e577957..bdc738e129220d74724b103faa837719a5799114 100644 --- a/.github/workflows/R-CMD-check.yaml +++ b/.github/workflows/R-CMD-check.yaml @@ -19,9 +19,9 @@ on: push: - branches: [main, master, cranversion, devel] + 
branches: [main, master, cranversion, devel, 'shapr-1.0.0'] pull_request: - branches: [main, master, cranversion, devel] + branches: [main, master, cranversion, devel, 'shapr-1.0.0'] name: R-CMD-check diff --git a/.github/workflows/lint-changed-files.yaml b/.github/workflows/lint-changed-files.yaml index 7f71f45f01bc8f7db2c62d4759522788ab787acc..59375477005b76341ad0f849f43c016195b40f89 100644 --- a/.github/workflows/lint-changed-files.yaml +++ b/.github/workflows/lint-changed-files.yaml @@ -8,7 +8,7 @@ # Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help on: pull_request: - branches: [main, master] + branches: [main, master, cranversion, devel, 'shapr-1.0.0'] name: lint-changed-files diff --git a/.lintr b/.lintr index 321d0fbfbdaeea578a4a1408e9d3762ac5b4ef36..f88e5a08bd641a36438a6d4b9e99ee577adbbbd5 100644 --- a/.lintr +++ b/.lintr @@ -8,6 +8,7 @@ linters: linters_with_defaults( ) exclusions: list( "inst/scripts", + "inst/code_paper", "vignettes", "R/RcppExports.R", "R/zzz.R" diff --git a/DESCRIPTION b/DESCRIPTION index a823f1e19098faafdd1ea86e992111890274aa70..8494e672a10d91018937160f4a0ca09c8074164c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,20 +1,19 @@ Package: shapr -Version: 0.2.3.9200 +Version: 1.0.0 Title: Prediction Explanation with Dependence-Aware Shapley Values Description: Complex machine learning models are often hard to interpret. However, in many situations it is crucial to understand and explain why a model made a specific prediction. Shapley values is the only method for such prediction explanation framework with a solid theoretical foundation. Previously known methods for estimating the Shapley - values do, however, assume feature independence. This package implements the method - described in Aas, Jullum and Løland (2019) , which accounts for any feature + values do, however, assume feature independence. 
This package implements methods which account for any feature dependence, and thereby produce more accurate estimates of the true Shapley values. An accompanying Python wrapper (shaprpy) is available on GitHub. Authors@R: c( - person("Nikolai", "Sellereite", email = "nikolaisellereite@gmail.com", role = "aut", comment = c(ORCID = "0000-0002-4671-0337")), person("Martin", "Jullum", email = "Martin.Jullum@nr.no", role = c("cre", "aut"), comment = c(ORCID = "0000-0003-3908-5155")), person("Lars Henry Berge", "Olsen", email = "lholsen@math.uio.no", role = "aut", comment = c(ORCID = "0009-0006-9360-6993")), person("Annabelle", "Redelmeier", email = "Annabelle.Redelmeier@nr.no", role = "aut"), - person("Jon", "Lachmann", email = "Jon@lachmann.nu", role = "aut"), + person("Jon", "Lachmann", email = "Jon@lachmann.nu", role = "aut", comment = c(ORCID = "0000-0001-8396-5673")), + person("Nikolai", "Sellereite", email = "nikolaisellereite@gmail.com", role = "aut", comment = c(ORCID = "0000-0002-4671-0337")), person("Anders", "Løland", email = "Anders.Loland@nr.no", role = "ctb"), person("Jens Christian", "Wahl", email = "Jens.Christian.Wahl@nr.no", role = "ctb"), person("Camilla", "Lingjærde", role = "ctb"), @@ -27,7 +26,7 @@ Encoding: UTF-8 LazyData: true ByteCompile: true Language: en-US -RoxygenNote: 7.3.1 +RoxygenNote: 7.3.2 Depends: R (>= 3.5.0) Imports: stats, @@ -66,7 +65,8 @@ Suggests: yardstick, hardhat, rsample, - rlang + rlang, + cli LinkingTo: RcppArmadillo, Rcpp diff --git a/NAMESPACE b/NAMESPACE index 1fa9bc343af1248cfd9c3a24b7920b7f7bc0b5ec..5c0835022448111db01ae026c62bf5b255028969 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -52,19 +52,31 @@ S3method(setup_approach,regression_separate) S3method(setup_approach,regression_surrogate) S3method(setup_approach,timeseries) S3method(setup_approach,vaeac) +export(additional_regression_setup) export(aicc_full_single_cpp) +export(append_vS_list) +export(check_convergence) +export(cli_compute_vS) +export(cli_iter) 
+export(cli_startup) +export(coalition_matrix_cpp) +export(compute_estimates) export(compute_shapley_new) +export(compute_time) export(compute_vS) export(correction_matrix_cpp) +export(create_coalition_table) export(explain) export(explain_forecast) -export(feature_combinations) -export(feature_matrix_cpp) export(finalize_explanation) +export(finalize_explanation_forecast) export(get_cov_mat) export(get_data_specs) +export(get_extra_est_args_default) +export(get_iterative_args_default) export(get_model_specs) export(get_mu_vec) +export(get_output_args_default) export(get_supported_approaches) export(hat_matrix_cpp) export(mahalanobis_distance_cpp) @@ -73,19 +85,28 @@ export(plot_MSEv_eval_crit) export(plot_SV_several_approaches) export(predict_model) export(prepare_data) +export(prepare_data_causal) export(prepare_data_copula_cpp) +export(prepare_data_copula_cpp_caus) export(prepare_data_gaussian_cpp) +export(prepare_data_gaussian_cpp_caus) +export(prepare_next_iteration) +export(print_iter) export(regression.train_model) export(rss_cpp) +export(save_results) export(setup) export(setup_approach) export(setup_computation) +export(shapley_setup) +export(testing_cleanup) export(vaeac_get_evaluation_criteria) export(vaeac_get_extra_para_default) export(vaeac_plot_eval_crit) export(vaeac_plot_imputed_ggpairs) export(vaeac_train_model) export(vaeac_train_model_continue) +export(weight_matrix) export(weight_matrix_cpp) importFrom(Rcpp,sourceCpp) importFrom(data.table,":=") @@ -110,6 +131,7 @@ importFrom(stats,as.formula) importFrom(stats,contrasts) importFrom(stats,embed) importFrom(stats,formula) +importFrom(stats,median) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,predict) @@ -118,8 +140,10 @@ importFrom(stats,qt) importFrom(stats,rnorm) importFrom(stats,sd) importFrom(stats,setNames) +importFrom(utils,capture.output) importFrom(utils,head) importFrom(utils,methods) importFrom(utils,modifyList) +importFrom(utils,relist) 
importFrom(utils,tail) useDynLib(shapr, .registration = TRUE) diff --git a/NEWS.md b/NEWS.md index e5f8cb3d17842c17bf60bdbecee0fb973913e6a9..b892be33003fead1d44a733e9a52b65728cef61b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,8 +1,4 @@ -# shapr (development version) - -* Release a Python wrapper (`shaprpyr`, [#325](https://github.com/NorskRegnesentral/shapr/pull/325)) for explaining predictions from Python models (from Python) utilizing almost all functionality of `shapr`. The wrapper moves back and forth back and forth between Python and R, doing the prediction in Python, and almost everything else in R. This simplifies maintenance of `shaprpy` significantly. The wrapper is available [here](https://github.com/NorskRegnesentral/shapr/tree/master/python). -* Complete restructuring motivated by introducing the Python wrapper. The restructuring splits the explanation tasks into smaller pieces, which was necessary to allow the Python wrapper to move back and forth between R and Python. -* As part of the restructuring, we also did a number of design changes, resulting in a series of breaking changes described below. +# shapr 1.0.0 ### Breaking changes @@ -10,9 +6,19 @@ * Prediction and checking functions for custom models are now passed directly as arguments to `explain()` instead of being defined as functions of a specific class in the global env. * The previously exported function `make_dummies` used to explain `xgboost` models with categorical data, is removed to simplify the code base. This is rather handled with a custom prediction model. * The function `explain.ctree_comb_mincrit`, which allowed combining models with `approch=ctree` with different `mincrit` parameters, has been removed to simplify the code base. It may return in a completely general manner in later version of `shapr`. 
+* New argument names: `prediction_zero` -> `phi0`, `n_combinations` -> `max_n_coalitions`, `n_samples` -> `n_MC_samples`. ### New features +* Iterative Shapley value estimation with convergence detection +* New approaches: vaeac, regression_separate, regression_surrogate, timeseries, categorical +* verbose argument for explain() to control the amount of output +* Parallelized computation of v(S) with future, including progress updates +* Paired sampling of coalitions +* prev_shapr_object argument to explain() to continue explanation from a previous object +* Asymmetric and causal Shapley values +* Improved KernelSHAP estimation with adjusted weights for reduced variance +* Release a Python wrapper (`shaprpy`, [#325](https://github.com/NorskRegnesentral/shapr/pull/325)) for explaining predictions from Python models (from Python) utilizing almost all functionality of `shapr`. The wrapper moves back and forth between Python and R, doing the prediction in Python, and almost everything else in R. This simplifies maintenance of `shaprpy` significantly. The wrapper is available [here](https://github.com/NorskRegnesentral/shapr/tree/master/python). * Introduce batch computation of conditional expectations ([#244](https://github.com/NorskRegnesentral/shapr/issues/244)). This essentially compute $v(S)$ for a portion of the $S$-subsets at a time, to reduce the amount of data needed to be held in memory. The user can control the number of batches herself, but we set a reasonable value by default ([#327](https://github.com/NorskRegnesentral/shapr/pull/327)). @@ -49,6 +55,7 @@ Previously, this was not possible with the prediction functions defined internal ### Documentation improvements * The [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) has been updated to reflect the new framework for explaining predictions, and all the new package features/functionality. 
+* New vignettes also for the regression paradigm, vaeac and the asymmetric/causal Shapley values # shapr 0.2.3 (GitHub only) diff --git a/R/RcppExports.R b/R/RcppExports.R index 1f27325fe7c011ff824691dcf0bdef196e1470d2..1ab7b6196d11264c1a92deb7753aec9d3415c985 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -110,7 +110,7 @@ inv_gaussian_transform_cpp <- function(z, x) { #' Generate (Gaussian) Copula MC samples #' -#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the +#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the #' univariate standard normal. #' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations #' to explain on the original scale. @@ -118,7 +118,7 @@ inv_gaussian_transform_cpp <- function(z, x) { #' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been #' transformed to a standardized normal distribution. #' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -#' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of #' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. #' This is not a problem internally in shapr as the empty and grand coalitions treated differently. #' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -127,8 +127,8 @@ inv_gaussian_transform_cpp <- function(z, x) { #' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been #' transformed to a standardized normal distribution. 
#' -#' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where -#' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian +#' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where +#' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian #' copula MC samples for each explicand and coalition on the original scale. #' #' @export @@ -138,21 +138,51 @@ prepare_data_copula_cpp <- function(MC_samples_mat, x_explain_mat, x_explain_gau .Call(`_shapr_prepare_data_copula_cpp`, MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat) } +#' Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand +#' +#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the +#' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`. +#' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations to +#' explain on the original scale. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`. +#' @param x_explain_gaussian_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the +#' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been +#' transformed to a standardized normal distribution. +#' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. +#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of +#' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. 
+#' This is not a problem internally in shapr as the empty and grand coalitions treated differently. +#' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed +#' using the Gaussian transform, i.e., the samples have been transformed to a standardized normal distribution. +#' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance +#' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been +#' transformed to a standardized normal distribution. +#' +#' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`), +#' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contains the single +#' conditional Gaussian MC samples for each explicand and `S_ind` coalition. +#' +#' @export +#' @keywords internal +#' @author Lars Henry Berge Olsen +prepare_data_copula_cpp_caus <- function(MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat) { + .Call(`_shapr_prepare_data_copula_cpp_caus`, MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat) +} + #' Generate Gaussian MC samples #' -#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the +#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the #' univariate standard normal. #' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations #' to explain. -#' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of #' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. 
#' This is not a problem internally in shapr as the empty and grand coalitions treated differently. #' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature. #' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance #' between all pairs of features. #' -#' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where -#' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian +#' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where +#' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian #' MC samples for each explicand and coalition. #' #' @export @@ -162,6 +192,30 @@ prepare_data_gaussian_cpp <- function(MC_samples_mat, x_explain_mat, S, mu, cov_ .Call(`_shapr_prepare_data_gaussian_cpp`, MC_samples_mat, x_explain_mat, S, mu, cov_mat) } +#' Generate Gaussian MC samples for the causal setup with a single MC sample for each explicand +#' +#' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the +#' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`. +#' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations +#' to explain. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`. +#' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of +#' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. +#' This is not a problem internally in shapr as the empty and grand coalitions treated differently. +#' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature. 
+#' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance +#' between all pairs of features. +#' +#' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`), +#' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contains the single +#' conditional Gaussian MC samples for each explicand and `S_ind` coalition. +#' +#' @export +#' @keywords internal +#' @author Lars Henry Berge Olsen +prepare_data_gaussian_cpp_caus <- function(MC_samples_mat, x_explain_mat, S, mu, cov_mat) { + .Call(`_shapr_prepare_data_gaussian_cpp_caus`, MC_samples_mat, x_explain_mat, S, mu, cov_mat) +} + #' (Generalized) Mahalanobis distance #' #' Used to get the Euclidean distance as well by setting \code{mcov} = \code{diag(m)}. @@ -199,7 +253,7 @@ sample_features_cpp <- function(m, n_features) { #' #' @param xtest Numeric matrix. Represents a single test observation. #' -#' @param S Integer matrix of dimension \code{n_combinations x m}, where \code{n_combinations} equals +#' @param S Integer matrix of dimension \code{n_coalitions x m}, where \code{n_coalitions} equals #' the total number of sampled/non-sampled feature combinations and \code{m} equals #' the total number of unique features. Note that \code{m = ncol(xtrain)}. See details #' for more information. @@ -228,34 +282,34 @@ observation_impute_cpp <- function(index_xtrain, index_s, xtrain, xtest, S) { #' Calculate weight matrix #' -#' @param subsets List. Each of the elements equals an integer +#' @param coalitions List. Each of the elements equals an integer #' vector representing a valid combination of features/feature groups. #' @param m Integer. Number of features/feature groups #' @param n Integer. Number of combinations #' @param w Numeric vector of length \code{n}, i.e. \code{w[i]} equals #' the Shapley weight of feature/feature group combination \code{i}, represented by -#' \code{subsets[[i]]}. +#' \code{coalitions[[i]]}. 
#' #' @export #' @keywords internal #' #' @return Matrix of dimension n x m + 1 -#' @author Nikolai Sellereite -weight_matrix_cpp <- function(subsets, m, n, w) { - .Call(`_shapr_weight_matrix_cpp`, subsets, m, n, w) +#' @author Nikolai Sellereite, Martin Jullum +weight_matrix_cpp <- function(coalitions, m, n, w) { + .Call(`_shapr_weight_matrix_cpp`, coalitions, m, n, w) } -#' Get feature matrix +#' Get coalition matrix #' -#' @param features List -#' @param m Positive integer. Total number of features +#' @param coalitions List +#' @param m Positive integer. Total number of coalitions #' #' @export #' @keywords internal #' #' @return Matrix -#' @author Nikolai Sellereite -feature_matrix_cpp <- function(features, m) { - .Call(`_shapr_feature_matrix_cpp`, features, m) +#' @author Nikolai Sellereite, Martin Jullum +coalition_matrix_cpp <- function(coalitions, m) { + .Call(`_shapr_coalition_matrix_cpp`, coalitions, m) } diff --git a/R/approach.R b/R/approach.R index e0325ea3d449c81b388fa1f753a54d7db9b1e381..2b08c454f3489d44c5cfe99d4b53c5132ec3cbd0 100644 --- a/R/approach.R +++ b/R/approach.R @@ -9,17 +9,49 @@ #' #' @export setup_approach <- function(internal, ...) 
{ + verbose <- internal$parameters$verbose + approach <- internal$parameters$approach - this_class <- "" + iter <- length(internal$iter_list) + X <- internal$iter_list[[iter]]$X - if (length(approach) > 1) { - class(this_class) <- "combined" + + + needs_X <- c("regression_surrogate", "vaeac") + + run_now <- (isFALSE(any(needs_X %in% approach)) && isTRUE(is.null(X))) || + (isTRUE(any(needs_X %in% approach)) && isFALSE(is.null(X))) + + if (isFALSE(run_now)) { # Do nothing + return(internal) } else { - class(this_class) <- approach - } + if ("progress" %in% verbose) { + cli::cli_progress_step("Setting up approach(es)") + } + if ("vS_details" %in% verbose) { + if ("vaeac" %in% approach) { + pretrained_provided <- internal$parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model_provided + if (isFALSE(pretrained_provided)) { + cli::cli_h2("Extra info about the training/tuning of the vaeac model") + } else { + cli::cli_h2("Extra info about the pretrained vaeac model") + } + } + } + + this_class <- "" + + if (length(approach) > 1) { + class(this_class) <- "combined" + } else { + class(this_class) <- approach + } + + UseMethod("setup_approach", this_class) - UseMethod("setup_approach", this_class) + internal$timing_list$setup_approach <- Sys.time() + } } #' @inheritParams default_doc @@ -49,6 +81,10 @@ setup_approach.combined <- function(internal, ...) { #' @export #' @keywords internal prepare_data <- function(internal, index_features = NULL, ...) { + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + # Extract the used approach(es) approach <- internal$parameters$approach @@ -57,9 +93,9 @@ prepare_data <- function(internal, index_features = NULL, ...) { # Check if the user provided one or several approaches. 
if (length(approach) > 1) { - # Picks the relevant approach from the internal$objects$X table which list the unique approach of the batch + # Picks the relevant approach from the X table which list the unique approach of the batch # matches by index_features - class(this_class) <- internal$objects$X[id_combination == index_features[1], approach] + class(this_class) <- X[id_coalition == index_features[1], approach] } else { # Only one approach for all coalitions sizes class(this_class) <- approach diff --git a/R/approach_categorical.R b/R/approach_categorical.R index f29ea07f2107acb604625e14ea152068055e1f9e..e023e451fdb2a52fa05fb10c4e4b708d05bd8dba 100644 --- a/R/approach_categorical.R +++ b/R/approach_categorical.R @@ -7,7 +7,7 @@ #' #' @param categorical.epsilon Numeric value. (Optional) #' If \code{joint_probability_dt} is not supplied, probabilities/frequencies are -#' estimated using `x_train`. If certain observations occur in `x_train` and NOT in `x_explain`, +#' estimated using `x_train`. If certain observations occur in `x_explain` and NOT in `x_train`, #' then epsilon is used as the proportion of times that these observations occurs in the training data. #' In theory, this proportion should be zero, but this causes an error later in the Shapley computation. 
#' @@ -36,28 +36,36 @@ setup_approach.categorical <- function(internal, # estimate joint_prob_dt if it is not passed to the function if (is.null(joint_probability_dt)) { + # Get the frequency of the unique feature value combinations in the training data joint_prob_dt0 <- x_train[, .N, eval(feature_names)] - explain_not_in_train <- data.table::setkeyv(data.table::setDT(x_explain), feature_names)[!x_train] + # Get the feature value combinations in the explicands that are NOT in the training data and their frequency + explain_not_in_train <- data.table::setkeyv(data.table::setDT(data.table::copy(x_explain)), feature_names)[!x_train] N_explain_not_in_train <- nrow(unique(explain_not_in_train)) + # Add these feature value combinations, and their corresponding frequency, to joint_prob_dt0 if (N_explain_not_in_train > 0) { joint_prob_dt0 <- rbind(joint_prob_dt0, cbind(explain_not_in_train, N = categorical.epsilon)) } + # Compute the joint probability for each feature value combination joint_prob_dt0[, joint_prob := N / .N] joint_prob_dt0[, joint_prob := joint_prob / sum(joint_prob)] data.table::setkeyv(joint_prob_dt0, feature_names) + # Remove the frequency column and add an id column joint_probability_dt <- joint_prob_dt0[, N := NULL][, id_all := .I] } else { + # The `joint_probability_dt` is passed to explain by the user, and we do some checks. for (i in colnames(x_explain)) { + # Check that feature name is present is_error <- !(i %in% names(joint_probability_dt)) if (is_error > 0) { stop(paste0(i, " is in x_explain but not in joint_probability_dt.")) } + # Check that the feature has the same levels is_error <- !all(levels(x_explain[[i]]) %in% levels(joint_probability_dt[[i]])) if (is_error > 0) { @@ -65,6 +73,7 @@ setup_approach.categorical <- function(internal, } } + # Check that dt contains a `joint_prob` col all entries are probabilities between 0 and 1 (inclusive) and add to 1. 
is_error <- !("joint_prob" %in% names(joint_probability_dt)) | !all(joint_probability_dt$joint_prob <= 1) | !all(joint_probability_dt$joint_prob >= 0) | @@ -76,9 +85,11 @@ setup_approach.categorical <- function(internal, sum(joint_prob) must equal to 1.') } + # Add an id column joint_probability_dt <- joint_probability_dt[, id_all := .I] } + # Store the `joint_probability_dt` data table internal$parameters$categorical.joint_prob_dt <- joint_probability_dt return(internal) @@ -90,42 +101,39 @@ setup_approach.categorical <- function(internal, #' @rdname prepare_data #' @export #' @keywords internal +#' @author Annabelle Redelmeier and Lars Henry Berge Olsen prepare_data.categorical <- function(internal, index_features = NULL, ...) { - x_train <- internal$data$x_train - x_explain <- internal$data$x_explain - - joint_probability_dt <- internal$parameters$categorical.joint_prob_dt - - X <- internal$objects$X - S <- internal$objects$S - - if (is.null(index_features)) { # 2,3 - features <- X$features # list of [1], [2], [2, 3] - } else { - features <- X$features[index_features] # list of [1], + # Use a faster function when index_feature is only a single coalition, as in causal Shapley values. 
+ if (length(index_features) == 1) { + return(prepare_data_single_coalition(internal, index_features)) } - feature_names <- internal$parameters$feature_names - # 3 id columns: id, id_combination, and id_all + # 3 id columns: id, id_coalition, and id_all # id: for each x_explain observation - # id_combination: the rows of the S matrix + # id_coalition: the rows of the S matrix # id_all: identifies the unique combinations of feature values from # the training data (not necessarily the ones in the explain data) + # Extract the needed objects/variables + x_explain <- internal$data$x_explain + joint_probability_dt <- internal$parameters$categorical.joint_prob_dt + feature_names <- internal$parameters$feature_names feature_conditioned <- paste0(feature_names, "_conditioned") feature_conditioned_id <- c(feature_conditioned, "id") - S_dt <- data.table::data.table(S) + # Extract from iterative list + iter <- length(internal$iter_list) + S <- internal$iter_list[[iter]]$S + S_dt <- data.table::data.table(S[index_features, , drop = FALSE]) S_dt[S_dt == 0] <- NA - S_dt[, id_combination := seq_len(nrow(S_dt))] - - data.table::setnames(S_dt, c(feature_conditioned, "id_combination")) + S_dt[, id_coalition := index_features] + data.table::setnames(S_dt, c(feature_conditioned, "id_coalition")) # (1) Compute marginal probabilities - # multiply table of probabilities nrow(S) times - joint_probability_mult <- joint_probability_dt[rep(id_all, nrow(S))] + # multiply table of probabilities length(index_features) times + joint_probability_mult <- joint_probability_dt[rep(id_all, length(index_features))] data.table::setkeyv(joint_probability_mult, "id_all") j_S_dt <- cbind(joint_probability_mult, S_dt) # combine joint probability and S matrix @@ -153,21 +161,17 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) 
{ cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned] cond_dt[, cond_prob := joint_prob / marg_prob] - cond_dt[id_combination == 1, marg_prob := 0] - cond_dt[id_combination == 1, cond_prob := 1] # check marginal probabilities cond_dt_unique <- unique(cond_dt, by = feature_conditioned) - check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)), - by = "id_combination" - ][["sum_prob"]] + check <- cond_dt_unique[id_coalition != 1][, .(sum_prob = sum(marg_prob)), by = "id_coalition"][["sum_prob"]] if (!all(round(check) == 1)) { print("Warning - not all marginal probabilities sum to 1. There could be a problem with the joint probabilities. Consider checking.") } # make x_explain - data.table::setkeyv(cond_dt, c("id_combination", "id_all")) + data.table::setkeyv(cond_dt, c("id_coalition", "id_all")) x_explain_with_id <- data.table::copy(x_explain)[, id := .I] dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names] @@ -178,22 +182,67 @@ prepare_data.categorical <- function(internal, index_features = NULL, ...) { dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE] # check conditional probabilities - check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)), - by = c("id_combination", "id") - ][["sum_prob"]] + check <- dt[id_coalition != 1][, .(sum_prob = sum(cond_prob)), by = c("id_coalition", "id")][["sum_prob"]] if (!all(round(check) == 1)) { print("Warning - not all conditional probabilities sum to 1. There could be a problem with the joint probabilities. 
Consider checking.") } setnames(dt, "cond_prob", "w") - data.table::setkeyv(dt, c("id_combination", "id")) - - # here we merge so that we only return the combintations found in our actual explain data - # this merge does not change the number of rows in dt - # dt <- merge(dt, x$X[, .(id_combination, n_features)], by = "id_combination") - # dt[n_features %in% c(0, ncol(x_explain)), w := 1.0] - dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0] - ret_col <- c("id_combination", "id", feature_names, "w") - return(dt[id_combination %in% index_features, mget(ret_col)]) + data.table::setkeyv(dt, c("id_coalition", "id")) + + # Return the relevant columns + return(dt[, mget(c("id_coalition", "id", feature_names, "w"))]) +} + +#' Compute the conditional probabilities for a single coalition for the categorical approach +#' +#' The [shapr::prepare_data.categorical()] function is slow when evaluated for a single coalition. +#' This is a bottleneck for Causal Shapley values which call said function a lot with single coalitions. 
+#' +#' @inheritParams default_doc +#' +#' @keywords internal +#' @author Lars Henry Berge Olsen +prepare_data_single_coalition <- function(internal, index_features) { + # if (length(index_features) != 1) stop("`index_features` must be single integer.") + + # Extract the needed objects + x_explain <- internal$data$x_explain + feature_names <- internal$parameters$feature_names + joint_probability_dt <- internal$parameters$categorical.joint_prob_dt + + # Extract from iterative list + iter <- length(internal$iter_list) + S <- internal$iter_list[[iter]]$S + + # Add an id column to x_explain (copy as this changes `x_explain` outside the function) + x_explain_copy <- data.table::copy(x_explain)[, id := .I] + + # Extract the feature names of the features we are to condition on + cond_cols <- feature_names[S[index_features, ] == 1] + cond_cols_with_id <- c("id", cond_cols) + + # Extract the feature values to condition and including the id column + dt_conditional_feature_values <- x_explain_copy[, cond_cols_with_id, with = FALSE] + + # Merge (right outer join) the joint_probability_dt data with the conditional feature values + results_id_coalition <- data.table::merge.data.table(joint_probability_dt, + dt_conditional_feature_values, + by = cond_cols, + allow.cartesian = TRUE + ) + + # Get the weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands + results_id_coalition[, w := joint_prob / sum(joint_prob), by = id] + results_id_coalition[, c("id_all", "joint_prob") := NULL] + + # Set the index_features to their correct value + results_id_coalition[, id_coalition := index_features] + + # Set id_coalition and id to be the keys and the two first columns for consistency with other approaches + data.table::setkeyv(results_id_coalition, c("id_coalition", "id")) + data.table::setcolorder(results_id_coalition, c("id_coalition", "id", feature_names)) + + return(results_id_coalition) } diff --git a/R/approach_copula.R b/R/approach_copula.R index 
4e7f5e914bdb8ec4186ddb959af895493e325938..f89112f7e37aae27a3546627ea0dd13ea4c1bebb 100644 --- a/R/approach_copula.R +++ b/R/approach_copula.R @@ -47,25 +47,71 @@ setup_approach.copula <- function(internal, ...) { #' @author Lars Henry Berge Olsen prepare_data.copula <- function(internal, index_features, ...) { # Extract used variables - S <- internal$objects$S[index_features, , drop = FALSE] feature_names <- internal$parameters$feature_names n_explain <- internal$parameters$n_explain - n_samples <- internal$parameters$n_samples + n_MC_samples <- internal$parameters$n_MC_samples n_features <- internal$parameters$n_features - n_combinations_now <- length(index_features) + n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu copula.cov_mat <- internal$parameters$copula.cov_mat copula.x_explain_gaussian_mat <- as.matrix(internal$data$copula.x_explain_gaussian) + causal_sampling <- internal$parameters$causal_sampling + + iter <- length(internal$iter_list) + + S <- internal$iter_list[[iter]]$S[index_features, , drop = FALSE] + + if (causal_sampling) { + # Casual Shapley values (either symmetric or asymmetric) + + # Get if this is the first causal sampling step + causal_first_step <- isTRUE(internal$parameters$causal_first_step) # Only set when called from prepdare_data_causal + + # Set which copula data generating function to use + prepare_copula <- ifelse(causal_first_step, prepare_data_copula_cpp, prepare_data_copula_cpp_caus) + + # Set if we have to reshape the output of the prepare_gauss function + reshape_prepare_gauss_output <- ifelse(causal_first_step, TRUE, FALSE) + + # For not the first step, the number of MC samples for causal Shapley values are n_explain, see prepdare_data_causal + n_MC_samples_updated <- ifelse(causal_first_step, n_MC_samples, n_explain) + + # Update data when not in the first causal sampling step, see 
prepdare_data_causal for explanations + if (!causal_first_step) { + # Update the `copula.x_explain_gaussian_mat` + copula.x_explain_gaussian <- apply( + X = rbind(x_explain_mat, x_train_mat), + MARGIN = 2, + FUN = gaussian_transform_separate, + n_y = nrow(x_explain_mat) + ) + if (is.null(dim(copula.x_explain_gaussian))) copula.x_explain_gaussian <- t(as.matrix(copula.x_explain_gaussian)) + copula.x_explain_gaussian_mat <- as.matrix(copula.x_explain_gaussian) + } + } else { + # Regular Shapley values (either symmetric or asymmetric) + + # Set which copula data generating function to use + prepare_copula <- prepare_data_copula_cpp + + # Set if we have to reshape the output of the prepare_copula function + reshape_prepare_copula_output <- TRUE + + # Set that the number of updated MC samples, only used when sampling from N(0, 1) + n_MC_samples_updated <- n_MC_samples + } # Generate the MC samples from N(0, 1) - MC_samples_mat <- matrix(rnorm(n_samples * n_features), nrow = n_samples, ncol = n_features) + MC_samples_mat <- matrix(rnorm(n_MC_samples_updated * n_features), nrow = n_MC_samples_updated, ncol = n_features) # Use C++ to convert the MC samples to N(mu_{Sbar|S}, Sigma_{Sbar|S}), for all coalitions and explicands, # and then transforming them back to the original scale using the inverse Gaussian transform in C++. - # The object `dt` is a 3D array of dimension (n_samples, n_explain * n_coalitions, n_features). - dt <- prepare_data_copula_cpp( + # The `dt` object is a 3D array of dimension (n_MC_samples, n_explain * n_coalitions, n_features) for regular + # Shapley and in the first step for causal Shapley values. For later steps in the causal Shapley value framework, + # the `dt` object is a matrix of dimension (n_explain * n_coalitions, n_features). 
+ dt <- prepare_copula( MC_samples_mat = MC_samples_mat, x_explain_mat = x_explain_mat, x_explain_gaussian_mat = copula.x_explain_gaussian_mat, @@ -75,17 +121,17 @@ prepare_data.copula <- function(internal, index_features, ...) { cov_mat = copula.cov_mat ) - # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). - dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + # Reshape `dt` to a 2D array of dimension (n_MC_samples * n_explain * n_coalitions, n_features) when needed + if (reshape_prepare_copula_output) dim(dt) <- c(n_coalitions_now * n_explain * n_MC_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] - dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] - dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_MC_samples * n_explain)] + dt[, id := rep(seq(n_explain), each = n_MC_samples, times = nrow(S))] + dt[, w := 1 / n_MC_samples] + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } diff --git a/R/approach_ctree.R b/R/approach_ctree.R index 3c73c0d5adca46b4d67cce87da80470c758119d3..86e8b5e970629eeb7853138f5fbb3869ea625024 100644 --- a/R/approach_ctree.R +++ b/R/approach_ctree.R @@ -12,13 +12,13 @@ #' Determines the minimum sum of weights in a terminal node required for a split #' #' @param ctree.sample Boolean. (default = TRUE) -#' If TRUE, then the method always samples `n_samples` observations from the leaf nodes (with replacement). 
-#' If FALSE and the number of observations in the leaf node is less than `n_samples`, +#' If TRUE, then the method always samples `n_MC_samples` observations from the leaf nodes (with replacement). +#' If FALSE and the number of observations in the leaf node is less than `n_MC_samples`, #' the method will take all observations in the leaf. -#' If FALSE and the number of observations in the leaf node is more than `n_samples`, -#' the method will sample `n_samples` observations (with replacement). +#' If FALSE and the number of observations in the leaf node is more than `n_MC_samples`, +#' the method will sample `n_MC_samples` observations (with replacement). #' This means that there will always be sampling in the leaf unless -#' `sample` = FALSE AND the number of obs in the node is less than `n_samples`. +#' `sample` = FALSE AND the number of obs in the node is less than `n_MC_samples`. #' #' @inheritParams default_doc_explain #' @@ -46,7 +46,7 @@ prepare_data.ctree <- function(internal, index_features = NULL, ...) { x_train <- internal$data$x_train x_explain <- internal$data$x_explain n_explain <- internal$parameters$n_explain - n_samples <- internal$parameters$n_samples + n_MC_samples <- internal$parameters$n_MC_samples n_features <- internal$parameters$n_features ctree.mincriterion <- internal$parameters$ctree.mincriterion ctree.minsplit <- internal$parameters$ctree.minsplit @@ -54,7 +54,9 @@ prepare_data.ctree <- function(internal, index_features = NULL, ...) { ctree.sample <- internal$parameters$ctree.sample labels <- internal$objects$feature_specs$labels - X <- internal$objects$X + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X dt_l <- list() @@ -81,24 +83,24 @@ prepare_data.ctree <- function(internal, index_features = NULL, ...) 
{ l <- lapply( X = all_trees, FUN = sample_ctree, - n_samples = n_samples, + n_MC_samples = n_MC_samples, x_explain = x_explain[i, , drop = FALSE], x_train = x_train, n_features = n_features, sample = ctree.sample ) - dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination") - dt_l[[i]][, w := 1 / n_samples] + dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_coalition") + dt_l[[i]][, w := 1 / n_MC_samples] dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE) - dt[id_combination %in% c(1, 2^n_features), w := 1.0] + dt[id_coalition %in% c(1, 2^n_features), w := 1.0] # only return unique dt - dt2 <- dt[, sum(w), by = c("id_combination", labels, "id")] + dt2 <- dt[, sum(w), by = c("id_coalition", labels, "id")] setnames(dt2, "V1", "w") return(dt2) @@ -121,7 +123,7 @@ prepare_data.ctree <- function(internal, index_features = NULL, ...) { #' @param minbucket Numeric scalar. (default = 7) #' Determines the minimum sum of weights in a terminal node required for a split #' -#' @param use_partykit String. In some semi-rare cases `partyk::ctree` runs into an error related to the LINPACK +#' @param use_partykit String. In some semi-rare cases `partykit::ctree` runs into an error related to the LINPACK #' used by R. To get around this problem, one may fall back to using the newer (but slower) `partykit::ctree` #' function, which is a reimplementation of the same method. Setting this parameter to `"on_error"` (default) #' falls back to `partykit::ctree`, if `party::ctree` fails. Other options are `"never"`, which always @@ -202,7 +204,7 @@ create_ctree <- function(given_ind, #' @param tree List. Contains tree which is an object of type ctree built from the party package. #' Also contains given_ind, the features to condition upon. 
#' -#' @param n_samples Numeric. Indicates how many samples to use for MCMC. +#' @param n_MC_samples Numeric. Indicates how many samples to use for MCMC. #' #' @param x_explain Matrix, data.frame or data.table with the features of the observation whose #' predictions ought to be explained (test data). Dimension `1\timesp` or `p\times1`. @@ -213,15 +215,15 @@ create_ctree <- function(given_ind, #' #' @param sample Boolean. True indicates that the method samples from the terminal node #' of the tree whereas False indicates that the method takes all the observations if it is -#' less than n_samples. +#' less than n_MC_samples. #' -#' @return data.table with `n_samples` (conditional) Gaussian samples +#' @return data.table with `n_MC_samples` (conditional) Gaussian samples #' #' @keywords internal #' #' @author Annabelle Redelmeier sample_ctree <- function(tree, - n_samples, + n_MC_samples, x_explain, x_train, n_features, @@ -263,12 +265,12 @@ sample_ctree <- function(tree, rowno <- seq_len(nrow(x_train)) - use_all_obs <- !sample & (length(rowno[fit.nodes == pred.nodes]) <= n_samples) + use_all_obs <- !sample & (length(rowno[fit.nodes == pred.nodes]) <= n_MC_samples) if (use_all_obs) { newrowno <- rowno[fit.nodes == pred.nodes] } else { - newrowno <- sample(rowno[fit.nodes == pred.nodes], n_samples, + newrowno <- sample(rowno[fit.nodes == pred.nodes], n_MC_samples, replace = TRUE ) } diff --git a/R/approach_empirical.R b/R/approach_empirical.R index 00f182807750e23b9edb785ddc9f820b08f76ffd..cbf6a7c75e25bb42ac96fae14310d9925e8d18ce 100644 --- a/R/approach_empirical.R +++ b/R/approach_empirical.R @@ -12,7 +12,7 @@ #' `eta` is the \eqn{\eta} parameter in equation (15) of Aas et al (2021). #' #' @param empirical.fixed_sigma Positive numeric scalar. (default = 0.1) -#' Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations. 
+#' Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions. #' Only used when `empirical.type = "fixed_sigma"` #' #' @param empirical.n_samples_aicc Positive integer. (default = 1000) @@ -116,14 +116,17 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) { x_explain <- internal$data$x_explain empirical.cov_mat <- internal$parameters$empirical.cov_mat - X <- internal$objects$X - S <- internal$objects$S + + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S n_explain <- internal$parameters$n_explain empirical.type <- internal$parameters$empirical.type empirical.eta <- internal$parameters$empirical.eta empirical.fixed_sigma <- internal$parameters$empirical.fixed_sigma - n_samples <- internal$parameters$n_samples + n_MC_samples <- internal$parameters$n_MC_samples model <- internal$tmp$model predict_model <- internal$tmp$predict_model @@ -165,11 +168,11 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) { x_train = as.matrix(x_train), x_explain = as.matrix(x_explain[i, , drop = FALSE]), empirical.eta = empirical.eta, - n_samples = n_samples + n_MC_samples = n_MC_samples ) dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } } else { h_optim_mat <- matrix(NA, ncol = n_col, nrow = no_empirical) @@ -214,11 +217,11 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) 
{ x_train = as.matrix(x_train), x_explain = as.matrix(x_explain[i, , drop = FALSE]), empirical.eta = empirical.eta, - n_samples = n_samples + n_MC_samples = n_MC_samples ) dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } } @@ -235,9 +238,9 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) { #' Generate permutations of training data using test observations #' #' @param W_kernel Numeric matrix. Contains all nonscaled weights between training and test -#' observations for all feature combinations. The dimension equals `n_train x m`. -#' @param S Integer matrix of dimension `n_combinations x m`, where `n_combinations` -#' and `m` equals the total number of sampled/non-sampled feature combinations and +#' observations for all coalitions. The dimension equals `n_train x m`. +#' @param S Integer matrix of dimension `n_coalitions x m`, where `n_coalitions` +#' and `m` equals the total number of sampled/non-sampled coalitions and #' the total number of unique features, respectively. Note that `m = ncol(x_train)`. #' @param x_train Numeric matrix #' @param x_explain Numeric matrix @@ -249,15 +252,15 @@ prepare_data.empirical <- function(internal, index_features = NULL, ...) 
{ #' @keywords internal #' #' @author Nikolai Sellereite -observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = .7, n_samples = 1e3) { +observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = .7, n_MC_samples = 1e3) { # Check input stopifnot(is.matrix(W_kernel) & is.matrix(S)) stopifnot(nrow(W_kernel) == nrow(x_train)) stopifnot(ncol(W_kernel) == nrow(S)) stopifnot(all(S %in% c(0, 1))) - index_s <- index_x_train <- id_combination <- weight <- w <- wcum <- NULL # due to NSE notes in R CMD check + index_s <- index_x_train <- id_coalition <- weight <- w <- wcum <- NULL # due to NSE notes in R CMD check - # Find weights for all combinations and training data + # Find weights for all coalitions and training data dt <- data.table::as.data.table(W_kernel) nms_vec <- seq_len(ncol(dt)) names(nms_vec) <- colnames(dt) @@ -265,11 +268,11 @@ observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = dt_melt <- data.table::melt( dt, id.vars = "index_x_train", - variable.name = "id_combination", + variable.name = "id_coalition", value.name = "weight", variable.factor = FALSE ) - dt_melt[, index_s := nms_vec[id_combination]] + dt_melt[, index_s := nms_vec[id_coalition]] # Remove training data with small weight knms <- c("index_s", "weight") @@ -279,7 +282,7 @@ observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = dt_melt[, wcum := cumsum(weight), by = "index_s"] dt_melt <- dt_melt[wcum > 1 - empirical.eta][, wcum := NULL] } - dt_melt <- dt_melt[, tail(.SD, n_samples), by = "index_s"] + dt_melt <- dt_melt[, tail(.SD, n_MC_samples), by = "index_s"] # Generate data used for prediction dt_p <- observation_impute_cpp( @@ -293,7 +296,7 @@ observation_impute <- function(W_kernel, S, x_train, x_explain, empirical.eta = # Add keys dt_p <- data.table::as.data.table(dt_p) data.table::setnames(dt_p, colnames(x_train)) - dt_p[, id_combination := dt_melt[["index_s"]]] + dt_p[, id_coalition := 
dt_melt[["index_s"]]] dt_p[, w := dt_melt[["weight"]]] return(dt_p) @@ -362,19 +365,22 @@ compute_AICc_each_k <- function(internal, model, predict_model, index_features) n_train <- internal$parameters$n_train n_explain <- internal$parameters$n_explain empirical.n_samples_aicc <- internal$parameters$empirical.n_samples_aicc - n_combinations <- internal$parameters$n_combinations - n_features <- internal$parameters$n_features + n_shapley_values <- internal$parameters$n_shapley_values labels <- internal$objects$feature_specs$labels empirical.start_aicc <- internal$parameters$empirical.start_aicc empirical.eval_max_aicc <- internal$parameters$empirical.eval_max_aicc - X <- internal$objects$X - S <- internal$objects$S + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S + stopifnot( data.table::is.data.table(X), - !is.null(X[["id_combination"]]), - !is.null(X[["n_features"]]) + !is.null(X[["id_coalition"]]), + !is.null(X[["coalition_size"]]) ) optimsamp <- sample_combinations( @@ -386,7 +392,7 @@ compute_AICc_each_k <- function(internal, model, predict_model, index_features) empirical.n_samples_aicc <- nrow(optimsamp) nloops <- n_explain # No of observations in test data - h_optim_mat <- matrix(NA, ncol = n_features, nrow = n_combinations) + h_optim_mat <- matrix(NA, ncol = n_shapley_values, nrow = n_coalitions) if (is.null(index_features)) { index_features <- X[, .I] @@ -394,10 +400,10 @@ compute_AICc_each_k <- function(internal, model, predict_model, index_features) # Optimization is done only once for all distributions which conditions on # exactly k variables - these_k <- unique(X[, n_features[index_features]]) + these_k <- unique(X[, coalition_size[index_features]]) for (i in these_k) { - these_cond <- X[index_features][n_features == i, id_combination] + these_cond <- X[index_features][coalition_size == i, id_coalition] cutters <- 
seq_len(empirical.n_samples_aicc) no_cond <- length(these_cond) cond_samp <- cut( @@ -477,14 +483,16 @@ compute_AICc_full <- function(internal, model, predict_model, index_features) { n_train <- internal$parameters$n_train n_explain <- internal$parameters$n_explain empirical.n_samples_aicc <- internal$parameters$empirical.n_samples_aicc - n_combinations <- internal$parameters$n_combinations - n_features <- internal$parameters$n_features + n_shapley_values <- internal$parameters$n_shapley_values labels <- internal$objects$feature_specs$labels empirical.start_aicc <- internal$parameters$empirical.start_aicc empirical.eval_max_aicc <- internal$parameters$empirical.eval_max_aicc - X <- internal$objects$X - S <- internal$objects$S + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S ntest <- n_explain @@ -500,7 +508,7 @@ compute_AICc_full <- function(internal, model, predict_model, index_features) { ) nloops <- n_explain # No of observations in test data - h_optim_mat <- matrix(NA, ncol = n_features, nrow = n_combinations) + h_optim_mat <- matrix(NA, ncol = n_shapley_values, nrow = n_coalitions) if (is.null(index_features)) { index_features <- X[, .I] diff --git a/R/approach_gaussian.R b/R/approach_gaussian.R index 23dd34d984dee2cfd5a7404cd115d444278d03e1..45c8aefbe05646367f0689d0ebf1d5d93a0e97f9 100644 --- a/R/approach_gaussian.R +++ b/R/approach_gaussian.R @@ -51,40 +51,67 @@ setup_approach.gaussian <- function(internal, #' @author Lars Henry Berge Olsen prepare_data.gaussian <- function(internal, index_features, ...) 
{ # Extract used variables - S <- internal$objects$S[index_features, , drop = FALSE] feature_names <- internal$parameters$feature_names n_explain <- internal$parameters$n_explain n_features <- internal$parameters$n_features - n_samples <- internal$parameters$n_samples - n_combinations_now <- length(index_features) + n_MC_samples <- internal$parameters$n_MC_samples + n_coalitions_now <- length(index_features) x_explain_mat <- as.matrix(internal$data$x_explain) mu <- internal$parameters$gaussian.mu cov_mat <- internal$parameters$gaussian.cov_mat + causal_sampling <- internal$parameters$causal_sampling + + iter <- length(internal$iter_list) + + S <- internal$iter_list[[iter]]$S[index_features, , drop = FALSE] + + if (causal_sampling) { + # Casual Shapley values (either symmetric or asymmetric) + + # Get if this is the first causal sampling step + causal_first_step <- isTRUE(internal$parameters$causal_first_step) # Only set when called from prepdare_data_causal + + # Set which gaussian data generating function to use + prepare_gauss <- ifelse(causal_first_step, prepare_data_gaussian_cpp, prepare_data_gaussian_cpp_caus) + + # Set if we have to reshape the output of the prepare_gauss function + reshape_prepare_gauss_output <- ifelse(causal_first_step, TRUE, FALSE) + + # For not the first step, the number of MC samples for causal Shapley values are n_explain, see prepdare_data_causal + n_MC_samples_updated <- ifelse(causal_first_step, n_MC_samples, n_explain) + } else { + # Regular Shapley values (either symmetric or asymmetric) + + # Set which gaussian data generating function to use + prepare_gauss <- prepare_data_gaussian_cpp + + # Set if we have to reshape the output of the prepare_gauss function + reshape_prepare_gauss_output <- TRUE + + # Set that the number of updated MC samples, only used when sampling from N(0, 1) + n_MC_samples_updated <- n_MC_samples + } # Generate the MC samples from N(0, 1) - MC_samples_mat <- matrix(rnorm(n_samples * n_features), nrow = 
n_samples, ncol = n_features) + MC_samples_mat <- matrix(rnorm(n_MC_samples_updated * n_features), nrow = n_MC_samples_updated, ncol = n_features) - # Use Cpp to convert the MC samples to N(mu_{Sbar|S}, Sigma_{Sbar|S}) for all coalitions and explicands. - # The object `dt` is a 3D array of dimension (n_samples, n_explain * n_coalitions, n_features). - dt <- prepare_data_gaussian_cpp( - MC_samples_mat = MC_samples_mat, - x_explain_mat = x_explain_mat, - S = S, - mu = mu, - cov_mat = cov_mat - ) + # Use C++ to convert the MC samples to N(mu_{Sbar|S}, Sigma_{Sbar|S}) for all coalitions and explicands. + # The `dt` object is a 3D array of dimension (n_MC_samples, n_explain * n_coalitions, n_features) for regular + # Shapley and in the first step for causal Shapley values. For later steps in the causal Shapley value framework, + # the `dt` object is a matrix of dimension (n_explain * n_coalitions, n_features). + dt <- prepare_gauss(MC_samples_mat = MC_samples_mat, x_explain_mat = x_explain_mat, S = S, mu = mu, cov_mat = cov_mat) - # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). 
- dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + # Reshape `dt` to a 2D array of dimension (n_MC_samples * n_explain * n_coalitions, n_features) when needed + if (reshape_prepare_gauss_output) dim(dt) <- c(n_coalitions_now * n_explain * n_MC_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] - dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] - dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_MC_samples * n_explain)] + dt[, id := rep(seq(n_explain), each = n_MC_samples, times = nrow(S))] + dt[, w := 1 / n_MC_samples] + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } @@ -112,3 +139,34 @@ get_cov_mat <- function(x_train, min_eigen_value = 1e-06) { get_mu_vec <- function(x_train) { unname(colMeans(x_train)) } + +#' Generate marginal Gaussian data using Cholesky decomposition +#' +#' Given a multivariate Gaussian distribution, this function creates data from specified marginals of said distribution. +#' +#' @param n_MC_samples Integer. The number of samples to generate. +#' @param Sbar_features Vector of integers indicating which marginals to sample from. +#' @param mu Numeric vector containing the expected values for all features in the multivariate Gaussian distribution. +#' @param cov_mat Numeric matrix containing the covariance between all features +#' in the multivariate Gaussian distribution. 
+#' +#' @keywords internal +#' @author Lars Henry Berge Olsen +create_marginal_data_gaussian <- function(n_MC_samples, Sbar_features, mu, cov_mat) { + # Extract the sub covariance matrix for the selected features + cov_submat <- cov_mat[Sbar_features, Sbar_features] + + # Perform the Cholesky decomposition of the covariance matrix + chol_decomp <- chol(cov_submat) + + # Generate independent standard normal samples + Z <- matrix(rnorm(n_MC_samples * length(Sbar_features)), nrow = n_MC_samples) + + # Transform the standard normal samples to have the desired covariance structure + samples <- Z %*% chol_decomp + + # Shift by the mean vector + samples <- sweep(samples, 2, mu[Sbar_features], "+") + + return(data.table(samples)) +} diff --git a/R/approach_independence.R b/R/approach_independence.R index ba45b7e4bd362ff03e65b0f9e89ab1f7dcc17d9c..7effc7f4ecb506a29fb795cc07b81be2515749f6 100644 --- a/R/approach_independence.R +++ b/R/approach_independence.R @@ -20,19 +20,21 @@ prepare_data.independence <- function(internal, index_features = NULL, ...) { # Extract relevant parameters feature_specs <- internal$objects$feature_specs - n_samples <- internal$parameters$n_samples + n_MC_samples <- internal$parameters$n_MC_samples n_train <- internal$parameters$n_train n_explain <- internal$parameters$n_explain - X <- internal$objects$X - S <- internal$objects$S + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S if (is.null(index_features)) { - # Use all feature combinations/coalitions (only applies if a single approach is used) + # Use all coalitions (only applies if a single approach is used) index_features <- X[, .I] } - # Extract the relevant feature combinations/coalitions + # Extract the relevant coalitions # Set `drop = FALSE` to ensure that `S0` is a matrix. S0 <- S[index_features, , drop = FALSE] @@ -65,10 +67,10 @@ prepare_data.independence <- function(internal, index_features = NULL, ...) 
{ x_explain0_mat <- as.matrix(x_explain0) # Get coalition indices. - # We repeat each coalition index `min(n_samples, n_train)` times. We use `min` - # as we cannot sample `n_samples` unique indices if `n_train` is less than `n_samples`. - index_s <- rep(seq_len(nrow(S0)), each = min(n_samples, n_train)) - w0 <- 1 / min(n_samples, n_train) # The inverse of the number of samples being used in practice + # We repeat each coalition index `min(n_MC_samples, n_train)` times. We use `min` + # as we cannot sample `n_MC_samples` unique indices if `n_train` is less than `n_MC_samples`. + index_s <- rep(seq_len(nrow(S0)), each = min(n_MC_samples, n_train)) + w0 <- 1 / min(n_MC_samples, n_train) # The inverse of the number of samples being used in practice # Creat a list to store the MC samples, where ith entry is associated with ith explicand dt_l <- list() @@ -80,7 +82,7 @@ prepare_data.independence <- function(internal, index_features = NULL, ...) { # Sample the indices of the training observations we are going to splice the explicand with # and replicate these indices by the number of coalitions. - index_xtrain <- c(replicate(nrow(S0), sample(x = seq(n_train), size = min(n_samples, n_train), replace = FALSE))) + index_xtrain <- c(replicate(nrow(S0), sample(x = seq(n_train), size = min(n_MC_samples, n_train), replace = FALSE))) # Generate data used for prediction. This splices the explicand with # the other sampled training observations for all relevant coalitions. @@ -95,7 +97,7 @@ prepare_data.independence <- function(internal, index_features = NULL, ...) 
{ # Add keys dt_l[[i]] <- data.table::as.data.table(dt_p) data.table::setnames(dt_l[[i]], feature_specs$labels) - dt_l[[i]][, id_combination := index_features[index_s]] + dt_l[[i]][, id_coalition := index_features[index_s]] dt_l[[i]][, w := w0] dt_l[[i]][, id := i] } diff --git a/R/approach_regression_separate.R b/R/approach_regression_separate.R index 7104db548247a197263d9e096778ecaf369b3d3d..643acb3ebffc888877a09d43ea80d95a98fd15c7 100644 --- a/R/approach_regression_separate.R +++ b/R/approach_regression_separate.R @@ -11,8 +11,8 @@ #' The data.frame must contain the possible hyperparameter value combinations to try. #' The column names must match the names of the tuneable parameters specified in `regression.model`. #' If `regression.tune_values` is a function, then it should take one argument `x` which is the training data -#' for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above. -#' Using a function allows the hyperparameter values to change based on the size of the combination. See the regression +#' for the current coalition and returns a data.frame/data.table/tibble with the properties described above. +#' Using a function allows the hyperparameter values to change based on the size of the coalition See the regression #' vignette for several examples. #' Note, to make it easier to call `explain()` from Python, the `regression.tune_values` can also be a string #' containing an R function. 
For example, @@ -42,8 +42,6 @@ setup_approach.regression_separate <- function(internal, regression.check_namespaces() # Small printout to the user - if (internal$parameters$verbose == 2) message("Starting 'setup_approach.regression_separate'.") - if (internal$parameters$verbose == 2) regression.separate_time_mess() # TODO: maybe remove # Add the default parameter values for the non-user specified parameters for the separate regression approach defaults <- @@ -54,7 +52,6 @@ setup_approach.regression_separate <- function(internal, internal <- regression.check_parameters(internal = internal) # Small printout to the user - if (internal$parameters$verbose == 2) message("Done with 'setup_approach.regression_separate'.") return(internal) # Return the updated internal list } @@ -67,38 +64,42 @@ prepare_data.regression_separate <- function(internal, index_features = NULL, .. # Load `workflows`, needed when parallelized as we call predict with a workflow object. Checked installed above. requireNamespace("workflows", quietly = TRUE) + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + verbose <- internal$parameters$verbose + # Get the features in the batch - features <- internal$objects$X$features[index_features] + features <- X$features[index_features] - # Small printout to the user about which batch that are currently worked on - if (internal$parameters$verbose == 2) regression.prep_message_batch(internal, index_features) - # Initialize empty data table with specific column names and id_combination (transformed to integer later). The data + # Initialize empty data table with specific column names and id_coalition (transformed to integer later). The data # table will contain the contribution function values for the coalitions given by `index_features` and all explicands. 
- dt_res_column_names <- c("id_combination", paste0("p_hat1_", seq_len(internal$parameters$n_explain))) + dt_res_column_names <- c("id_coalition", paste0("p_hat1_", seq_len(internal$parameters$n_explain))) dt_res <- data.table(matrix(ncol = length(dt_res_column_names), nrow = 0, dimnames = list(NULL, dt_res_column_names))) # Iterate over the coalitions provided by index_features. # Note that index_features will never be NULL and never contain the empty or grand coalitions. for (comb_idx in seq_along(features)) { - # Get the column indices of the features in current coalition/combination + # Get the column indices of the features in current coalition current_comb <- features[[comb_idx]] # Extract the current training (and add y_hat as response) and explain data current_x_train <- internal$data$x_train[, ..current_comb][, "y_hat" := internal$data$x_train_y_hat] current_x_explain <- internal$data$x_explain[, ..current_comb] + # Fit the current separate regression model to the current training data - if (internal$parameters$verbose == 2) regression.prep_message_comb(internal, index_features, comb_idx) regression.current_fit <- regression.train_model( x = current_x_train, seed = internal$parameters$seed, - verbose = internal$parameters$verbose, + verbose = verbose, regression.model = internal$parameters$regression.model, regression.tune = internal$parameters$regression.tune, regression.tune_values = internal$parameters$regression.tune_values, regression.vfold_cv_para = internal$parameters$regression.vfold_cv_para, - regression.recipe_func = internal$parameters$regression.recipe_func + regression.recipe_func = internal$parameters$regression.recipe_func, + current_comb = current_comb ) # Compute the predicted response for the explicands, i.e., the v(S, x_i) for all explicands x_i. @@ -108,9 +109,9 @@ prepare_data.regression_separate <- function(internal, index_features = NULL, .. 
dt_res <- rbind(dt_res, data.table(index_features[comb_idx], matrix(pred_explicand, nrow = 1)), use.names = FALSE) } - # Set id_combination to be the key - dt_res[, id_combination := as.integer(id_combination)] - data.table::setkey(dt_res, id_combination) + # Set id_coalition to be the key + dt_res[, id_coalition := as.integer(id_coalition)] + data.table::setkey(dt_res, id_coalition) # Return the estimated contribution function values return(dt_res) @@ -139,14 +140,15 @@ prepare_data.regression_separate <- function(internal, index_features = NULL, .. #' @keywords internal regression.train_model <- function(x, seed = 1, - verbose = 0, + verbose = NULL, regression.model = parsnip::linear_reg(), regression.tune = FALSE, regression.tune_values = NULL, regression.vfold_cv_para = NULL, regression.recipe_func = NULL, regression.response_var = "y_hat", - regression.surrogate_n_comb = NULL) { + regression.surrogate_n_comb = NULL, + current_comb = NULL) { # Create a recipe to the augmented training data regression.recipe <- recipes::recipe(as.formula(paste(regression.response_var, "~ .")), data = x) @@ -203,9 +205,14 @@ regression.train_model <- function(x, grid = regression.grid, metrics = yardstick::metric_set(yardstick::rmse) ) - # Small printout to the user - if (verbose == 2) regression.cv_message(regression.results = regression.results, regression.grid = regression.grid) + if ("vS_details" %in% verbose) { + regression.cv_message( + regression.results = regression.results, + regression.grid = regression.grid, + current_comb = current_comb + ) + } # Set seed for reproducibility. 
Without this we get different results based on if we run in parallel or sequential set.seed(seed) @@ -320,6 +327,11 @@ regression.get_tune <- function(regression.model, regression.tune_values, x_trai #' @author Lars Henry Berge Olsen #' @keywords internal regression.check_parameters <- function(internal) { + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + + # Convert the objects to R-objects if they are strings if (is.character(internal$parameters$regression.model)) { internal$parameters$regression.model <- regression.get_string_to_R(internal$parameters$regression.model) @@ -343,7 +355,7 @@ regression.check_parameters <- function(internal) { # Check that `regression.check_sur_n_comb` is a valid value (only applicable for surrogate regression) regression.check_sur_n_comb( regression.surrogate_n_comb = internal$parameters$regression.surrogate_n_comb, - used_n_combinations = internal$parameters$used_n_combinations + n_coalitions = n_coalitions ) # Check and get if we are to tune the hyperparameters of the regression model @@ -432,43 +444,6 @@ regression.check_namespaces <- function() { } # Message functions ==================================================================================================== -#' Produce time message for separate regression -#' @author Lars Henry Berge Olsen -#' @keywords internal -regression.separate_time_mess <- function() { - message(paste( - "When using `approach = 'regression_separate'` the `explanation$timing$timing_secs` object \n", - "can be missleading as `setup_computation` does not contain the training times of the \n", - "regression models as they are trained on the fly in `compute_vS`. This is to reduce memory \n", - "usage and to improve efficency.\n" - )) # TODO: should we add the time somewhere else? 
-} - -#' Produce message about which batch prepare_data is working on -#' @inheritParams default_doc -#' @inheritParams default_doc_explain -#' @author Lars Henry Berge Olsen -#' @keywords internal -regression.prep_message_batch <- function(internal, index_features) { - message(paste0( - "Working on batch ", internal$objects$X[id_combination == index_features[1]]$batch, " of ", - internal$parameters$n_batches, " in `prepare_data.", internal$parameters$approach, "()`." - )) -} - -#' Produce message about which combination prepare_data is working on -#' @inheritParams default_doc -#' @inheritParams default_doc_explain -#' @param comb_idx Integer. The index of the combination in a specific batch. -#' @author Lars Henry Berge Olsen -#' @keywords internal -regression.prep_message_comb <- function(internal, index_features, comb_idx) { - message(paste0( - "Working on combination with id ", internal$objects$X$id_combination[index_features[comb_idx]], - " of ", internal$parameters$used_n_combinations, "." - )) -} - #' Produce message about which batch prepare_data is working on #' #' @param regression.results The results of the CV procedures. 
@@ -477,7 +452,7 @@ regression.prep_message_comb <- function(internal, index_features, comb_idx) { #' #' @author Lars Henry Berge Olsen #' @keywords internal -regression.cv_message <- function(regression.results, regression.grid, n_cv = 10) { +regression.cv_message <- function(regression.results, regression.grid, n_cv = 10, current_comb) { # Get the feature names and add evaluation metric rmse feature_names <- names(regression.grid) feature_names_rmse <- c(feature_names, "rmse", "rmse_std_err") @@ -494,8 +469,16 @@ regression.cv_message <- function(regression.results, regression.grid, n_cv = 10 regression.grid_best$rmse_std <- round(best_results$std_err, 2) width <- sapply(regression.grid_best, function(x) max(nchar(as.character(unique(x))))) - # Message title of the results - message(paste0("Results of the ", best_results$n[1], "-fold cross validation (top ", n_cv, " best configurations):")) + # Regression_separate adds the v(S), while separate does not add anything, but prints the Extra info thing + if (!is.null(current_comb)) { + this_vS <- paste0("for v(", paste0(current_comb, collapse = " "), ") ") + } else { + cli::cli_h2("Extra info about the tuning of the regression model") + this_vS <- "" + } + + msg0 <- paste0("Top ", n_cv, " best configs ", this_vS, "(using ", best_results$n[1], "-fold CV)") + msg <- NULL # Iterate over the n_cv best results and print out the hyper parameter values and the rmse and rmse_std_err for (row_idx in seq_len(nrow(best_results))) { @@ -509,8 +492,11 @@ regression.cv_message <- function(regression.results, regression.grid, n_cv = 10 seq_along(feature_values_rmse), function(x) format(as.character(feature_values_rmse[x]), width = width[x], justify = "left") ) - message(paste0("#", row_idx, ": ", paste(paste(feature_names_rmse, "=", values_fixed_len), collapse = " "), "")) + msg <- + c(msg, paste0("#", row_idx, ": ", paste(paste(feature_names_rmse, "=", values_fixed_len), collapse = " "), "\n")) } - - message("") # Empty message to 
get a blank line + cli::cli({ + cli::cli_h3(msg0) + for (i in seq_along(msg)) cli::cli_text(msg[i]) + }) } diff --git a/R/approach_regression_surrogate.R b/R/approach_regression_surrogate.R index a61890694236ba109fe37942d555fa2c61848924..845daaa231214e08caae94f39788222d1f1d619f 100644 --- a/R/approach_regression_surrogate.R +++ b/R/approach_regression_surrogate.R @@ -3,13 +3,17 @@ #' #' @inheritParams default_doc_explain #' @inheritParams setup_approach.regression_separate -#' @param regression.surrogate_n_comb Integer (default is `internal$parameters$used_n_combinations`) specifying the -#' number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is -#' "`internal$parameters$used_n_combinations` - 2". By default, we use all coalitions, but this can take a lot of memory -#' in larger dimensions. Note that by "all", we mean all coalitions chosen by `shapr` to be used. This will be all -#' \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if `shapr` is in the exact mode. If the -#' user sets a lower value than `internal$parameters$used_n_combinations`, then we sample this amount of unique -#' coalitions separately for each training observations. That is, on average, all coalitions should be equally trained. +#' @param regression.surrogate_n_comb Integer. +#' (default is `internal$iter_list[[length(internal$iter_list)]]$n_coalitions`) specifying the +#' number of unique coalitions to apply to each training observation. Maximum allowed value is +#' "`internal$iter_list[[length(internal$iter_list)]]$n_coalitions` - 2". +#' By default, we use all coalitions, but this can take a lot of memory in larger dimensions. +#' Note that by "all", we mean all coalitions chosen by `shapr` to be used. +#' This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if `shapr` is in +#' the exact mode. 
+#' If the user sets a lower value than `internal$iter_list[[length(internal$iter_list)]]$n_coalitions`, +#' then we sample this amount of unique coalitions separately for each training observation. +#' That is, on average, all coalitions should be equally trained. #' #' @export #' @author Lars Henry Berge Olsen @@ -19,13 +23,14 @@ setup_approach.regression_surrogate <- function(internal, regression.vfold_cv_para = NULL, regression.recipe_func = NULL, regression.surrogate_n_comb = - internal$parameters$used_n_combinations - 2, + internal$iter_list[[length(internal$iter_list)]]$n_coalitions - 2, ...) { + verbose <- internal$parameters$verbose + # Check that required libraries are installed regression.check_namespaces() - # Small printout to the user - if (internal$parameters$verbose == 2) message("Starting 'setup_approach.regression_surrogate'.") + # Add the default parameter values for the non-user specified parameters for the separate regression approach defaults <- mget(c( @@ -43,11 +48,10 @@ setup_approach.regression_surrogate <- function(internal, ) # Fit the surrogate regression model and store it in the internal list - if (internal$parameters$verbose == 2) message("Start training the surrogate model.") internal$objects$regression.surrogate_model <- regression.train_model( x = x_train_augmented, seed = internal$parameters$seed, - verbose = internal$parameters$verbose, + verbose = verbose, regression.model = internal$parameters$regression.model, regression.tune = internal$parameters$regression.tune, regression.tune_values = internal$parameters$regression.tune_values, @@ -56,8 +60,6 @@ setup_approach.regression_surrogate <- function(internal, regression.surrogate_n_comb = regression.surrogate_n_comb + 1 # Add 1 as augment_include_grand = TRUE above ) - # Small printout to the user - if (internal$parameters$verbose == 2) message("Done with 'setup_approach.regression_surrogate'.") return(internal) # Return the updated internal list } @@ -70,8 +72,6 @@ 
prepare_data.regression_surrogate <- function(internal, index_features = NULL, . # Load `workflows`, needed when parallelized as we call predict with a workflow object. Checked installed above. requireNamespace("workflows", quietly = TRUE) - # Small printout to the user about which batch that are currently worked on - if (internal$parameters$verbose == 2) regression.prep_message_batch(internal, index_features) # Augment the explicand data x_explain_aug <- regression.surrogate_aug_data(internal, x = internal$data$x_explain, index_features = index_features) @@ -81,8 +81,8 @@ prepare_data.regression_surrogate <- function(internal, index_features = NULL, . # Insert the predicted contribution functions values into a data table of the correct setup dt_res <- data.table(as.integer(index_features), matrix(pred_explicand, nrow = length(index_features))) - data.table::setnames(dt_res, c("id_combination", paste0("p_hat1_", seq_len(internal$parameters$n_explain)))) - data.table::setkey(dt_res, id_combination) # Set id_combination to be the key + data.table::setnames(dt_res, c("id_coalition", paste0("p_hat1_", seq_len(internal$parameters$n_explain)))) + data.table::setkey(dt_res, id_coalition) # Set id_coalition to be the key return(dt_res) } @@ -95,21 +95,21 @@ prepare_data.regression_surrogate <- function(internal, index_features = NULL, . #' @param y_hat Vector of numerics (optional) containing the predicted responses for the observations in `x`. #' @param index_features Array of integers (optional) containing which coalitions to consider. Must be provided if #' `x` is the explicands. -#' @param augment_add_id_comb Logical (default is `FALSE`). If `TRUE`, an additional column is adding containing +#' @param augment_add_id_coal Logical (default is `FALSE`). If `TRUE`, an additional column is added containing #' which coalition was applied. #' @param augment_include_grand Logical (default is `FALSE`). If `TRUE`, then the grand coalition is included. 
#' If `index_features` are provided, then `augment_include_grand` has no effect. Note that if we sample the -#' combinations then the grand coalition is equally likely to be samples as the other coalitions (or weighted if +#' coalitions then the grand coalition is equally likely to be sampled as the other coalitions (or weighted if #' `augment_comb_prob` is provided). #' @param augment_masks_as_factor Logical (default is `FALSE`). If `TRUE`, then the binary masks are converted #' to factors. If `FALSE`, then the binary masks are numerics. #' @param augment_comb_prob Array of numerics (default is `NULL`). The length of the array must match the number of -#' combinations being considered, where each entry specifies the probability of sampling the corresponding coalition. +#' coalitions being considered, where each entry specifies the probability of sampling the corresponding coalition. #' This is useful if we want to generate more training data for some specific coalitions. One possible choice would be -#' `augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_combinations] else NULL`. +#' `augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_coalitions] else NULL`. #' @param augment_weights String (optional). Specifying which type of weights to add to the observations. #' If `NULL` (default), then no weights are added. If `"Shapley"`, then the Shapley weights for the different -#' combinations are added to corresponding observations where the coalitions was applied. If `uniform`, then +#' coalitions are added to corresponding observations where the coalition was applied. If `uniform`, then #' all observations get an equal weight of one. #' #' @return A data.table containing the augmented data. 
@@ -121,25 +121,28 @@ regression.surrogate_aug_data <- function(internal, index_features = NULL, augment_masks_as_factor = FALSE, augment_include_grand = FALSE, - augment_add_id_comb = FALSE, + augment_add_id_coal = FALSE, augment_comb_prob = NULL, augment_weights = NULL) { + iter <- length(internal$iter_list) + # Get some of the parameters - S <- internal$objects$S - actual_n_combinations <- internal$parameters$used_n_combinations - 2 # Remove empty and grand coalitions + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S + actual_n_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Remove empty and grand coalitions regression.surrogate_n_comb <- internal$parameters$regression.surrogate_n_comb if (!is.null(index_features)) regression.surrogate_n_comb <- length(index_features) # Applicable from prep_data() if (augment_include_grand) { - actual_n_combinations <- actual_n_combinations + 1 # Add 1 to include the grand comb + actual_n_coalitions <- actual_n_coalitions + 1 # Add 1 to include the grand comb regression.surrogate_n_comb <- regression.surrogate_n_comb + 1 } - if (regression.surrogate_n_comb > actual_n_combinations) regression.surrogate_n_comb <- actual_n_combinations + if (regression.surrogate_n_comb > actual_n_coalitions) regression.surrogate_n_comb <- actual_n_coalitions # Small checks if (!is.null(augment_weights)) augment_weights <- match.arg(augment_weights, c("Shapley", "uniform")) - if (!is.null(augment_comb_prob) && length(augment_comb_prob) != actual_n_combinations) { - stop(paste("`augment_comb_prob` must be of length", actual_n_combinations, ".")) + if (!is.null(augment_comb_prob) && length(augment_comb_prob) != actual_n_coalitions) { + stop(paste("`augment_comb_prob` must be of length", actual_n_coalitions, ".")) } if (!is.null(augment_weights) && augment_include_grand && augment_weights == "Shapley") { @@ -164,11 +167,11 @@ regression.surrogate_aug_data <- function(internal, # Check if we are to augment the training 
data or the explicands if (is.null(index_features)) { # Training: get matrix (n_obs x regression.surrogate_n_comb) containing the indices of the active coalitions - if (regression.surrogate_n_comb >= actual_n_combinations) { # Start from two to exclude the empty set - comb_active_idx <- matrix(rep(seq(2, actual_n_combinations + 1), times = n_obs), ncol = n_obs) + if (regression.surrogate_n_comb >= actual_n_coalitions) { # Start from two to exclude the empty set + comb_active_idx <- matrix(rep(seq(2, actual_n_coalitions + 1), times = n_obs), ncol = n_obs) } else { comb_active_idx <- sapply(seq(n_obs), function(x) { # Add 1 as we want to exclude the empty set - sample.int(n = actual_n_combinations, size = regression.surrogate_n_comb, prob = augment_comb_prob) + 1 + sample.int(n = actual_n_coalitions, size = regression.surrogate_n_comb, prob = augment_comb_prob) + 1 }) } } else { @@ -178,8 +181,8 @@ regression.surrogate_aug_data <- function(internal, # Extract the active coalitions for each explicand. The number of rows are n_obs * n_comb_per_explicands, # where the first n_comb_per_explicands rows are connected to the first explicand and so on. Set the column names. 
- id_comb <- as.vector(comb_active_idx) - comb_active <- S[id_comb, , drop = FALSE] + id_coal <- as.vector(comb_active_idx) + comb_active <- S[id_coal, , drop = FALSE] colnames(comb_active) <- names(feature_classes) # Repeat the feature values as many times as there are active coalitions @@ -209,11 +212,11 @@ regression.surrogate_aug_data <- function(internal, # Add either uniform weights or Shapley kernel weights if (!is.null(augment_weights)) { - x_augmented[, "weight" := if (augment_weights == "Shapley") internal$objects$X$shapley_weight[id_comb] else 1] + x_augmented[, "weight" := if (augment_weights == "Shapley") X$shapley_weight[id_coal] else 1] } - # Add the id_comb as a factor - if (augment_add_id_comb) x_augmented[, "id_comb" := factor(id_comb)] + # Add the id_coal as a factor + if (augment_add_id_coal) x_augmented[, "id_coal" := factor(id_coal)] # Add repeated responses if provided if (!is.null(y_hat)) x_augmented[, "y_hat" := rep(y_hat, each = regression.surrogate_n_comb)] @@ -229,16 +232,16 @@ regression.surrogate_aug_data <- function(internal, #' Check that `regression.surrogate_n_comb` is either NULL or a valid integer. #' #' @inheritParams setup_approach.regression_surrogate -#' @param used_n_combinations Integer. The number of used combinations (including the empty and grand coalitions). +#' @param n_coalitions Integer. The number of used coalitions (including the empty and grand coalition). 
#' #' @author Lars Henry Berge Olsen #' @keywords internal -regression.check_sur_n_comb <- function(regression.surrogate_n_comb, used_n_combinations) { +regression.check_sur_n_comb <- function(regression.surrogate_n_comb, n_coalitions) { if (!is.null(regression.surrogate_n_comb)) { - if (regression.surrogate_n_comb < 1 || used_n_combinations - 2 < regression.surrogate_n_comb) { + if (regression.surrogate_n_comb < 1 || n_coalitions - 2 < regression.surrogate_n_comb) { stop(paste0( "`regression.surrogate_n_comb` (", regression.surrogate_n_comb, ") must be a positive integer less than or ", - "equal to `used_n_combinations` minus two (", used_n_combinations - 2, ")." + "equal to `n_coalitions` minus two (", n_coalitions - 2, ")." )) } } diff --git a/R/approach_timeseries.R b/R/approach_timeseries.R index 09f9fc113e2efe271d0619795bd31e2b6a2b5c66..c4be714f2bdce11e49447e6ed3408fbd4f998244 100644 --- a/R/approach_timeseries.R +++ b/R/approach_timeseries.R @@ -39,7 +39,7 @@ setup_approach.timeseries <- function(internal, #' @export #' @keywords internal prepare_data.timeseries <- function(internal, index_features = NULL, ...) { - id <- id_combination <- w <- NULL + id <- id_coalition <- w <- NULL x_train <- internal$data$x_train x_explain <- internal$data$x_explain @@ -48,8 +48,10 @@ prepare_data.timeseries <- function(internal, index_features = NULL, ...) { timeseries.upper_bound <- internal$parameters$timeseries.bounds[1] timeseries.lower_bound <- internal$parameters$timeseries.bounds[2] - X <- internal$objects$X - S <- internal$objects$S + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S if (is.null(index_features)) { features <- X$features @@ -134,12 +136,11 @@ prepare_data.timeseries <- function(internal, index_features = NULL, ...) { names(tmp[[j]]) <- names(tmp[[1]]) } - dt_l[[i]] <- rbindlist(tmp, idcol = "id_combination") - # dt_l[[i]][, w := 1 / .N, by = id_combination] # IS THIS NECESSARY? 
+ dt_l[[i]] <- rbindlist(tmp, idcol = "id_coalition") dt_l[[i]][, id := i] } dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE) - ret_col <- c("id_combination", "id", feature_names, "w") - return(dt[id_combination %in% index_features, mget(ret_col)]) + ret_col <- c("id_coalition", "id", feature_names, "w") + return(dt[id_coalition %in% index_features, mget(ret_col)]) } diff --git a/R/approach_vaeac.R b/R/approach_vaeac.R index 4ba03ba20da3d9c879e8f1f9caf55c74c8b40692..2eff261c3a21b8942b83c5a8c2e59fb0baaa7e3e 100644 --- a/R/approach_vaeac.R +++ b/R/approach_vaeac.R @@ -31,6 +31,8 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. vaeac.epochs = 100, vaeac.extra_parameters = list(), ...) { + verbose <- internal$parameters$verbose + # Check that torch is installed if (!requireNamespace("torch", quietly = TRUE)) { stop("`torch` is not installed. Please run `install.packages('torch')`.") @@ -38,13 +40,13 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. if (!torch::torch_is_installed()) stop("`torch` is not properly installed. Please run `torch::install_torch()`.") # Extract the objects we will use later - S <- internal$objects$S - X <- internal$objects$X + iter <- length(internal$iter_list) + X <- internal$iter_list[[iter]]$X + S <- internal$iter_list[[iter]]$S + S_causal <- internal$iter_list[[iter]]$S_causal_steps_unique_S # NULL if not causal sampling + causal_sampling <- internal$parameters$causal_sampling # NULL if not causal sampling parameters <- internal$parameters - # Small printout to user - if (parameters$verbose == 2) message("Setting up the `vaeac` approach.") - # Check if we are doing a combination of approaches combined_approaches <- length(parameters$approach) > 1 @@ -62,10 +64,8 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. 
vaeac_main_para <- mget(vaeac_main_para_names) # Add the default extra parameter values for the non-user specified extra parameters - parameters$vaeac.extra_parameters <- utils::modifyList(vaeac_get_extra_para_default(), - parameters$vaeac.extra_parameters, - keep.null = TRUE - ) + parameters$vaeac.extra_parameters <- + utils::modifyList(vaeac_get_extra_para_default(), parameters$vaeac.extra_parameters, keep.null = TRUE) # Add the default main parameter values for the non-user specified main parameters parameters <- utils::modifyList(vaeac_main_para, parameters, keep.null = TRUE) @@ -74,20 +74,31 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. parameters <- c(parameters[(length(vaeac_main_para) + 1):length(parameters)], parameters[seq_along(vaeac_main_para)]) # Check if vaeac is to be applied on a subset of coalitions. - if (!parameters$exact || parameters$is_groupwise || combined_approaches) { + if (isTRUE(causal_sampling)) { + # We are doing causal Shapley values. Then we do not want to train on the full + # coalitions, but rather the coalitions in the chain of sampling steps used + # to generate the full MC sample. Causal Shapley does not support combined + # approaches, so we do not have to check for that. All coalitions are + # done by vaeac, and we give them equal importance. Skip the empty and grand coalitions. + # Note that some steps occur more often (when features in Sbar are late in the causal ordering), + # and one can potentially consider to give this more weight. 
+ nrow_S_causal <- nrow(S_causal) + parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions <- S_causal[-c(1, nrow_S_causal), , drop = FALSE] + parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob <- rep(1, nrow_S_causal - 2) / (nrow_S_causal - 2) + } else if (!parameters$exact || parameters$is_groupwise || combined_approaches) { # We have either: - # 1) sampled `n_combinations` different subsets of coalitions (i.e., not exact), + # 1) sampled `n_coalitions` different subsets of coalitions (i.e., not exact), # 2) using the coalitions which respects the groups in group Shapley values, and/or # 3) using a combination of approaches where vaeac is only used on a subset of the coalitions. # Here, objects$S contains the coalitions while objects$X contains the information about the approach. # Extract the the coalitions / masks which are estimated using vaeac as a matrix parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions <- - S[X[approach == "vaeac"]$id_combination, , drop = FALSE] + S[X[approach == "vaeac"]$id_coalition, , drop = FALSE] # Extract the weights for the corresponding coalitions / masks. parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob <- - X$shapley_weight[X[approach == "vaeac"]$id_combination] + X$shapley_weight[X[approach == "vaeac"]$id_coalition] # Normalize the weights/probabilities such that they sum to one. parameters$vaeac.extra_parameters$vaeac.mask_gen_coalitions_prob <- @@ -101,8 +112,8 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. # Check if user provided a pre-trained vaeac model, otherwise, we train one from scratch. 
if (is.null(parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model)) { # We train a vaeac model with the parameters in `parameters`, as user did not provide pre-trained vaeac model - if (parameters$verbose == 2) { - message(paste0( + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( "Training the `vaeac` model with the provided parameters from scratch on ", ifelse(parameters$vaeac.extra_parameter$vaeac.cuda, "GPU", "CPU"), "." )) @@ -137,7 +148,7 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. # The pre-trained vaeac model is either: # 1. The explanation$internal$parameters$vaeac list of type "vaeac" from an earlier call to explain(). # 2. A string containing the path to where the "vaeac" model is stored on disk. - if (parameters$verbose == 2) message("Loading the provided `vaeac` model.") + if ("vS_details" %in% verbose) cli::cli_text("Loading the provided `vaeac` model.") # Boolean representing that a pre-trained vaeac model was provided parameters$vaeac.extra_parameters$vaeac.pretrained_vaeac_model_provided <- TRUE @@ -146,8 +157,8 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. parameters <- vaeac_update_pretrained_model(parameters = parameters) # Small printout informing about the location of the model - if (parameters$verbose == 2) { - message(paste0( + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( "The `vaeac` model runs/is trained on ", ifelse(parameters$vaeac$parameters$cuda, "GPU", "CPU"), "." )) } @@ -172,8 +183,18 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. # Update/overwrite the parameters list in the internal list. 
internal$parameters <- parameters - # Small printout to user - if (parameters$verbose == 2) message("Done with setting up the `vaeac` approach.\n") + if ("vS_details" %in% verbose) { + folder_to_save_model <- parameters$vaeac$parameters$folder_to_save_model + vaeac_save_file_names <- parameters$vaeac$parameters$vaeac_save_file_names + + cli::cli_alert_info(c( + "The trained `vaeac` models are saved to folder {.path {folder_to_save_model}} at\n", + "{.path {vaeac_save_file_names[1]}}\n", + "{.path {vaeac_save_file_names[2]}}\n", + "{.path {vaeac_save_file_names[3]}}" + )) + } + # Return the updated internal list. return(internal) @@ -185,24 +206,25 @@ setup_approach.vaeac <- function(internal, # add default values for vaeac here. #' @export #' @author Lars Henry Berge Olsen prepare_data.vaeac <- function(internal, index_features = NULL, ...) { + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + S <- internal$iter_list[[iter]]$S + # If not provided, then set `index_features` to all non trivial coalitions - if (is.null(index_features)) index_features <- seq(2, internal$parameters$n_combinations - 1) + if (is.null(index_features)) index_features <- seq(2, n_coalitions - 1) # Extract objects we are going to need later - S <- internal$objects$S seed <- internal$parameters$seed verbose <- internal$parameters$verbose x_explain <- internal$data$x_explain n_explain <- internal$parameters$n_explain - n_samples <- internal$parameters$n_samples + n_MC_samples <- internal$parameters$n_MC_samples vaeac.model <- internal$parameters$vaeac.model vaeac.sampler <- internal$parameters$vaeac.sampler vaeac.checkpoint <- internal$parameters$vaeac.checkpoint vaeac.batch_size_sampling <- internal$parameters$vaeac.extra_parameters$vaeac.batch_size_sampling - # Small printout to the user about which batch we are working on - if (verbose == 2) vaeac_prep_message_batch(internal = internal, index_features = index_features) - # Apply all coalitions 
to all explicands to get a data table where `vaeac` will impute the `NaN` values x_explain_extended <- vaeac_get_x_explain_extended(x_explain = x_explain, S = S, index_features = index_features) @@ -215,7 +237,7 @@ prepare_data.vaeac <- function(internal, index_features = NULL, ...) { x_explain_with_MC_samples_dt <- vaeac_impute_missing_entries( x_explain_with_NaNs = x_explain_extended, n_explain = n_explain, - n_samples = n_samples, + n_MC_samples = n_MC_samples, vaeac_model = vaeac.model, checkpoint = vaeac.checkpoint, sampler = vaeac.sampler, @@ -314,8 +336,8 @@ prepare_data.vaeac <- function(internal, index_features = NULL, ...) { #' `mask_gen_coalitions` is specified. #' @param mask_gen_coalitions Matrix (default is `NULL`). Matrix containing the coalitions that the #' `vaeac` model will be trained on, see [shapr::specified_masks_mask_generator()]. This parameter is used internally -#' in `shapr` when we only consider a subset of coalitions/combinations, i.e., when -#' `n_combinations` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +#' in `shapr` when we only consider a subset of coalitions, i.e., when +#' `n_coalitions` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., #' when `group` is specified in [shapr::explain()]. #' @param mask_gen_coalitions_prob Numeric array (default is `NULL`). Array of length equal to the height #' of `mask_gen_coalitions` containing the probabilities of sampling the corresponding coalitions in @@ -334,8 +356,6 @@ prepare_data.vaeac <- function(internal, index_features = NULL, ...) { #' Abalone data set), it can be advantageous to \eqn{\log} transform the data to unbounded form before using `vaeac`. #' If `TRUE`, then [shapr::vaeac_postprocess_data()] will take the \eqn{\exp} of the results to get back to strictly #' positive values when using the `vaeac` model to impute missing values/generate the Monte Carlo samples. -#' @param verbose Boolean. An integer specifying the level of verbosity. 
Use `0` (default) for no verbosity, -#' `1` for low verbose, and `2` for high verbose. #' @param seed Positive integer (default is `1`). Seed for reproducibility. Specifies the seed before any randomness #' based code is being run. #' @param which_vaeac_model String (default is `best`). The name of the `vaeac` model (snapshots from different @@ -344,6 +364,7 @@ prepare_data.vaeac <- function(internal, index_features = NULL, ...) { #' Note that additional choices are available if `vaeac.save_every_nth_epoch` is provided. For example, if #' `vaeac.save_every_nth_epoch = 5`, then `vaeac.which_vaeac_model` can also take the values `"epoch_5"`, `"epoch_10"`, #' `"epoch_15"`, and so on. +#' @inheritParams explain #' @param ... List of extra parameters, currently not used. #' #' @return A list containing the training/validation errors and paths to where the vaeac models are saved on the disk. @@ -472,14 +493,14 @@ vaeac_train_model <- function(x_train, # Add the number of trainable parameters in the vaeac model to the state list if (initialization_idx == 1) { state_list$n_trainable_parameters <- vaeac_model$n_train_param - if (verbose == 2) { - message(paste0("The vaeac model contains ", vaeac_model$n_train_param[1, 1], " trainable parameters.")) + if ("vS_details" %in% verbose) { + cli::cli_text(paste0("The vaeac model contains ", vaeac_model$n_train_param[1, 1], " trainable parameters.")) } } # Print which initialization vaeac the function is working on - if (verbose == 2) { - message(paste0("Initializing vaeac number ", initialization_idx, " of ", n_vaeacs_initialize, ".")) + if ("vS_details" %in% verbose) { + cli::cli_text(paste0("Initializing vaeac model number ", initialization_idx, " of ", n_vaeacs_initialize, ".")) } # Create the ADAM optimizer @@ -515,8 +536,8 @@ vaeac_train_model <- function(x_train, # Check if we are printing detailed debug information # Small printout to the user stating which initiated vaeac model was the best. 
- if (verbose == 2) { - message(paste0( + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( "Best vaeac inititalization was number ", vaeac_model_best_list$initialization_idx, " (of ", n_vaeacs_initialize, ") with a training VLB = ", round(as.numeric(vaeac_model_best_list$train_vlb[-1]$cpu()), 3), " after ", epochs_initiation_phase, " epochs. Continue to train this inititalization." @@ -705,20 +726,17 @@ vaeac_train_model_auxiliary <- function(vaeac_model, # Save if current vaeac model has the lowest validation IWAE error if ((max(val_iwae) <= val_iwae_now)$item() || is.null(best_epoch)) { best_epoch <- epoch - if (verbose == 2) message("Saving `best` vaeac model at epoch ", epoch, ".") vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[1]) } # Save if current vaeac model has the lowest running validation IWAE error if ((max(val_iwae_running) <= val_iwae_running_now)$item() || is.null(best_epoch_running)) { best_epoch_running <- epoch - if (verbose == 2) message("Saving `best_running` vaeac model at epoch ", epoch, ".") vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[2]) } # Save if we are in an n'th epoch and are to save every n'th epoch if (is.numeric(save_every_nth_epoch) && epoch %% save_every_nth_epoch == 0) { - if (verbose == 2) message("Saving `nth_epoch` vaeac model at epoch ", epoch, ".") vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[3 + epoch %/% save_every_nth_epoch]) } } @@ -742,8 +760,8 @@ vaeac_train_model_auxiliary <- function(vaeac_model, # Check if we are to apply early stopping, i.e., no improvement in the IWAE for `epochs_early_stopping` epochs. if (is.numeric(epochs_early_stopping)) { if (epoch - best_epoch >= epochs_early_stopping) { - if (verbose == 2) { - message(paste0( + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( "No IWAE improvment in ", epochs_early_stopping, " epochs. Apply early stopping at epoch ", epoch, "." 
)) @@ -771,11 +789,10 @@ vaeac_train_model_auxiliary <- function(vaeac_model, ) } else { # Save the vaeac model at the last epoch - if (verbose == 2) message("Saving `last` vaeac model at epoch ", epoch, ".") last_state <- vaeac_save_state(state_list = state_list, file_name = vaeac_save_file_names[3], return_state = TRUE) # Summary printout - if (verbose == 2) vaeac_print_train_summary(best_epoch, best_epoch_running, last_state) + if ("vS_details" %in% verbose) vaeac_print_train_summary(best_epoch, best_epoch_running, last_state) # Create a return list return_list <- list( @@ -825,14 +842,14 @@ vaeac_train_model_continue <- function(explanation, lr_new = NULL, x_train = NULL, save_data = FALSE, - verbose = 0, + verbose = NULL, seed = 1) { # Check the input if (!"shapr" %in% class(explanation)) stop("`explanation` must be a list of class `shapr`.") if (!"vaeac" %in% explanation$internal$parameters$approach) stop("`vaeac` is not an approach in `explanation`.") if (!is.null(lr_new)) vaeac_check_positive_numerics(list(lr_new = lr_new)) if (!is.null(x_train) && !data.table::is.data.table(x_train)) stop("`x_train` must be a `data.table` object.") - vaeac_check_verbose(verbose) + check_verbose(verbose) vaeac_check_positive_integers(list(epochs_new = epochs_new, seed = seed)) vaeac_check_logicals(list(save_data = save_data)) @@ -998,25 +1015,26 @@ vaeac_train_model_continue <- function(explanation, #' #' @inheritParams vaeac_train_model #' @param x_explain_with_NaNs A 2D matrix, where the missing entries to impute are represented by `NaN`. -#' @param n_samples Integer. The number of imputed versions we create for each row in `x_explain_with_NaNs`. +#' @param n_MC_samples Integer. The number of imputed versions we create for each row in `x_explain_with_NaNs`. #' @param index_features Optional integer vector. Used internally in shapr package to index the coalitions. #' @param n_explain Positive integer. The number of explicands. 
#' @param vaeac_model An initialized `vaeac` model that we are going to use to generate the MC samples. #' @param checkpoint List containing the parameters of the `vaeac` model. #' @param sampler A sampler object used to sample the MC samples. #' -#' @return A data.table where the missing values (`NaN`) in `x_explain_with_NaNs` have been imputed `n_samples` times. +#' @return A data.table where the missing values (`NaN`) in `x_explain_with_NaNs` have been imputed `n_MC_samples` +#' times. #' The data table will contain extra id columns if `index_features` and `n_explain` are provided. #' #' @keywords internal #' @author Lars Henry Berge Olsen vaeac_impute_missing_entries <- function(x_explain_with_NaNs, - n_samples, + n_MC_samples, vaeac_model, checkpoint, sampler, batch_size, - verbose = 0, + verbose = NULL, seed = NULL, n_explain = NULL, index_features = NULL) { @@ -1031,8 +1049,6 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, torch::torch_manual_seed(seed) } - if (verbose == 2) message("Preprocessing the explicands.") - # Preprocess `x_explain_with_NaNs`. Turn factor names into numerics 1,2,...,K, (vaeac only accepts numerics) and keep # track of the maping of names. Optionally log-transform the continuous features. Then, finally, normalize the data # using the training means and standard deviations. I.e., we assume that the new data follow the same distribution as @@ -1051,11 +1067,9 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, # Create a data loader that load/iterate over the data set in chronological order. dataloader <- torch::dataloader(dataset = dataset, batch_size = batch_size, shuffle = FALSE) - if (verbose == 2) message("Generating the MC samples.") - # Create an auxiliary list of lists to store the imputed values combined with the original values. The structure is # [[i'th MC sample]][[b'th batch]], where the entries are tensors of dimension batch_size x n_features. 
- results <- lapply(seq(n_samples), function(k) list()) + results <- lapply(seq(n_MC_samples), function(k) list()) # Generate the conditional Monte Carlo samples for the observation `x_explain_with_NaNs`, one batch at the time. coro::loop(for (batch in dataloader) { @@ -1079,10 +1093,14 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, # Do not need to keep track of the gradients, as we are not fitting the model. torch::with_no_grad({ # Compute the distribution parameters for the generative models inferred by the masked encoder and decoder. - # This is a tensor of shape [batch_size, n_samples, n_generative_parameters]. Note that, for only continuous + # This is a tensor of shape [batch_size, n_MC_samples, n_generative_parameters]. Note that, for only continuous # features we have that n_generative_parameters = 2*n_features, but for categorical data the number depends # on the number of categories. - samples_params <- vaeac_model$generate_samples_params(batch = batch_extended, mask = mask_extended, K = n_samples) + samples_params <- vaeac_model$generate_samples_params( + batch = batch_extended, + mask = mask_extended, + K = n_MC_samples + ) # Remove the parameters belonging to added instances in batch_extended. samples_params <- samples_params[1:batch$shape[1], , ] @@ -1094,7 +1112,7 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, batch_zeroed_nans[mask] <- 0 # Iterate over the number of imputations and generate the imputed samples - for (i in seq(n_samples)) { + for (i in seq(n_MC_samples)) { # Extract the i'th inferred generative parameters for the whole batch. # sample_params is a tensor of shape [batch_size, n_generative_parameters]. sample_params <- samples_params[, i, ] @@ -1110,24 +1128,26 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, # Make a deep copy and add it to correct location in the results list. 
results[[i]] <- append(results[[i]], sample$clone()$detach()$cpu()) - } # End of iterating over the n_samples + } # End of iterating over the n_MC_samples }) # End of iterating over the batches. Done imputing. - if (verbose == 2) message("Postprocessing the Monte Carlo samples.") - - # Order the MC samples into a tensor of shape [nrow(x_explain_with_NaNs), n_samples, n_features]. The lapply function + # Order the MC samples into a tensor of shape [nrow(x_explain_with_NaNs), n_MC_samples, n_features]. + # The lapply function # creates a list of tensors of shape [nrow(x_explain_with_NaNs), 1, n_features] by concatenating the batches for the # i'th MC sample to a tensor of shape [nrow(x_explain_with_NaNs), n_features] and then add unsqueeze to add a new # singleton dimension as the second dimension to get the shape [nrow(x_explain_with_NaNs), 1, n_features]. Then - # outside of the lapply function, we concatenate the n_samples torch elements to form a final torch result of shape - # [nrow(x_explain_with_NaNs), n_samples, n_features]. - result <- torch::torch_cat(lapply(seq(n_samples), function(i) torch::torch_cat(results[[i]])$unsqueeze(2)), dim = 2) + # outside of the lapply function, we concatenate the n_MC_samples torch elements to form a final torch result of shape + # [nrow(x_explain_with_NaNs), n_MC_samples, n_features]. + result <- torch::torch_cat(lapply( + seq(n_MC_samples), + function(i) torch::torch_cat(results[[i]])$unsqueeze(2) + ), dim = 2) # Get back to the original distribution by undoing the normalization by multiplying with the std and adding the mean result <- result * checkpoint$norm_std + checkpoint$norm_mean - # Convert from a tensor of shape [nrow(x_explain_with_NaNs), n_samples, n_features] - # to a matrix of shape [(nrow(x_explain_with_NaNs) * n_samples), n_features]. + # Convert from a tensor of shape [nrow(x_explain_with_NaNs), n_MC_samples, n_features] + # to a matrix of shape [(nrow(x_explain_with_NaNs) * n_MC_samples), n_features]. 
result <- data.table::as.data.table(as.matrix(result$view(c( result$shape[1] * result$shape[2], result$shape[3] @@ -1138,15 +1158,15 @@ vaeac_impute_missing_entries <- function(x_explain_with_NaNs, # If user provide `index_features`, then we add columns needed for shapr computations if (!is.null(index_features)) { - # Add id, id_combination and weights (uniform for the `vaeac` approach) to the result. - result[, c("id", "id_combination", "w") := list( - rep(x = seq(n_explain), each = length(index_features) * n_samples), - rep(x = index_features, each = n_samples, times = n_explain), - 1 / n_samples + # Add id, id_coalition and weights (uniform for the `vaeac` approach) to the result. + result[, c("id", "id_coalition", "w") := list( + rep(x = seq(n_explain), each = length(index_features) * n_MC_samples), + rep(x = index_features, each = n_MC_samples, times = n_explain), + 1 / n_MC_samples )] # Set the key in the data table - data.table::setkeyv(result, c("id", "id_combination")) + data.table::setkeyv(result, c("id", "id_coalition")) } return(result) @@ -1364,19 +1384,6 @@ vaeac_check_mask_gen <- function(mask_gen_coalitions, mask_gen_coalitions_prob, } } -#' Function that checks the verbose parameter -#' -#' @inheritParams vaeac_train_model -#' -#' @return The function does not return anything. -#' -#' @keywords internal -#' @author Lars Henry Berge Olsen -vaeac_check_verbose <- function(verbose) { - if (!is.numeric(verbose) || !(verbose %in% c(0, 1, 2))) { - stop("`vaeac.verbose` must be either `0` (no verbosity), `1` (low verbosity), or `2` (high verbosity).") - } -} #' Function that checks that the save folder exists and for a valid file name #' @@ -1529,7 +1536,7 @@ vaeac_check_parameters <- function(x_train, seed, ...) 
{ # Check verbose parameter - vaeac_check_verbose(verbose = verbose) + check_verbose(verbose = verbose) # Check that the activation function is valid torch::nn_module object vaeac_check_activation_func(activation_function = activation_function) @@ -1655,9 +1662,10 @@ vaeac_check_parameters <- function(x_train, #' during the training of the vaeac model. Used in [torch::dataloader()]. #' @param vaeac.batch_size_sampling Positive integer (default is `NULL`) The number of samples to include in #' each batch when generating the Monte Carlo samples. If `NULL`, then the function generates the Monte Carlo samples -#' for the provided coalitions/combinations and all explicands sent to [shapr::explain()] at the time. -#' The number of coalitions are determined by `n_batches` in [shapr::explain()]. We recommend to tweak `n_batches` -#' rather than `vaeac.batch_size_sampling`. Larger batch sizes are often much faster provided sufficient memory. +#' for the provided coalitions and all explicands sent to [shapr::explain()] at the time. +#' The number of coalitions are determined by the `n_batches` used by [shapr::explain()]. We recommend to tweak +#' `extra_computation_args$max_batch_size` and `extra_computation_args$min_n_batches` +#' rather than `vaeac.batch_size_sampling`. Larger batch sizes are often much faster provided sufficient memory. #' @param vaeac.running_avg_n_values Positive integer (default is `5`). The number of previous IWAE values to include #' when we compute the running means of the IWAE criterion. #' @param vaeac.skip_conn_layer Logical (default is `TRUE`). If `TRUE`, we apply identity skip connections in each @@ -1682,8 +1690,8 @@ vaeac_check_parameters <- function(x_train, #' `vaeac.mask_gen_coalitions` is specified. #' @param vaeac.mask_gen_coalitions Matrix (default is `NULL`). Matrix containing the coalitions that the #' `vaeac` model will be trained on, see [shapr::specified_masks_mask_generator()]. 
This parameter is used internally -#' in `shapr` when we only consider a subset of coalitions/combinations, i.e., when -#' `n_combinations` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +#' in `shapr` when we only consider a subset of coalitions, i.e., when +#' `n_coalitions` \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., #' when `group` is specified in [shapr::explain()]. #' @param vaeac.mask_gen_coalitions_prob Numeric array (default is `NULL`). Array of length equal to the height #' of `vaeac.mask_gen_coalitions` containing the probabilities of sampling the corresponding coalitions in @@ -1817,8 +1825,8 @@ vaeac_get_mask_generator_name <- function(mask_gen_coalitions, mask_generator_name <- "specified_masks_mask_generator" # Small printout - if (verbose == 2) { - message(paste0("Using 'specified_masks_mask_generator' with '", nrow(mask_gen_coalitions), "' coalitions.")) + if ("vS_details" %in% verbose) { + cli::cli_text(paste0("Using 'specified_masks_mask_generator' with '", nrow(mask_gen_coalitions), "' coalitions.")) } } else if (length(masking_ratio) == 1) { # We are going to use 'mcar_mask_generator' as masking_ratio is a singleton. @@ -1826,15 +1834,21 @@ vaeac_get_mask_generator_name <- function(mask_gen_coalitions, mask_generator_name <- "mcar_mask_generator" # Small printout - if (verbose == 2) message(paste0("Using 'mcar_mask_generator' with 'masking_ratio = ", masking_ratio, "'.")) + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( + "Using 'mcar_mask_generator' with 'masking_ratio = ", + masking_ratio, + "'." + )) + } } else if (length(masking_ratio) > 1) { # We are going to use 'specified_prob_mask_generator' as masking_ratio is a vector (of same length as ncol(x_train). # I.e., masking_ratio[5] specifies the probability of masking 5 features mask_generator_name <- "specified_prob_mask_generator" # We have an array of masking ratios. Then we are using the specified_prob_mask_generator. 
- if (verbose == 2) { - message(paste0( + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( "Using 'specified_prob_mask_generator' mask generator with 'masking_ratio = [", paste(masking_ratio, collapse = ", "), "]'." )) @@ -2104,10 +2118,12 @@ vaeac_get_data_objects <- function(x_train, # Ensure a valid batch size if (batch_size > length(train_indices)) { - message(paste0( - "Decrease `batch_size` (", batch_size, ") to largest allowed value (", length(train_indices), "), ", - "i.e., the number of training observations." - )) + if ("vS_details" %in% verbose) { + cli::cli_text(paste0( + "Decrease `batch_size` (", batch_size, ") to largest allowed value (", length(train_indices), "), ", + "i.e., the number of training observations." + )) + } batch_size <- length(train_indices) } @@ -2429,19 +2445,34 @@ Last epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f\n", last_state$val_iwae[-1]$cpu(), last_state$val_iwae_running[-1]$cpu() )) -} -#' Produce message about which batch prepare_data is working on -#' @inheritParams default_doc -#' @inheritParams default_doc_explain -#' @author Lars Henry Berge Olsen -#' @keywords internal -vaeac_prep_message_batch <- function(internal, index_features) { - id_batch <- internal$objects$X[id_combination == index_features[1]]$batch - n_batches <- internal$parameters$n_batches - message(paste0("Generating Monte Carlo samples using `vaeac` for batch ", id_batch, " of ", n_batches, ".")) + # Trying to replace the above, but have not succeeded really. + # msg <- c("\nResults of the `vaeac` training process:", + # sprintf("Best epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f", + # best_epoch, + # last_state$train_vlb[best_epoch]$cpu(), + # last_state$val_iwae[best_epoch]$cpu(), + # last_state$val_iwae_running[best_epoch]$cpu() + # ), + # sprintf("Best running avg epoch: %d. 
\tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f", + # best_epoch_running, + # last_state$train_vlb[best_epoch_running]$cpu(), + # last_state$val_iwae[best_epoch_running]$cpu(), + # last_state$val_iwae_running[best_epoch_running]$cpu() + # ), + # sprintf("Last epoch: %d. \tVLB = %.3f \tIWAE = %.3f \tIWAE_running = %.3f", + # last_state$epoch, + # last_state$train_vlb[-1]$cpu(), + # last_state$val_iwae[-1]$cpu(), + # last_state$val_iwae_running[-1]$cpu() + # ) + # ) + # + # + # cli::cli_text(msg) } + # Plot functions ======================================================================================================= #' Plot the training VLB and validation IWAE for `vaeac` models #' @@ -2500,8 +2531,8 @@ vaeac_prep_message_batch <- function(internal, index_features) { #' x_explain = x_explain, #' x_train = x_train, #' approach = approach, -#' prediction_zero = p0, -#' n_samples = 1, # As we are only interested in the training of the vaeac +#' phi0 = p0, +#' n_MC_samples = 1, # As we are only interested in the training of the vaeac #' vaeac.epochs = 10, # Should be higher in applications. #' vaeac.n_vaeacs_initialize = 1, #' vaeac.width = 16, @@ -2514,8 +2545,8 @@ vaeac_prep_message_batch <- function(internal, index_features) { #' x_explain = x_explain, #' x_train = x_train, #' approach = approach, -#' prediction_zero = p0, -#' n_samples = 1, # As we are only interested in the training of the vaeac +#' phi0 = p0, +#' n_MC_samples = 1, # As we are only interested in the training of the vaeac #' vaeac.epochs = 10, # Should be higher in applications. 
#' vaeac.width = 16, #' vaeac.depth = 2, @@ -2735,8 +2766,8 @@ vaeac_plot_eval_crit <- function(explanation_list, #' x_explain = x_explain, #' x_train = x_train, #' approach = "vaeac", -#' prediction_zero = mean(y_train), -#' n_samples = 1, +#' phi0 = mean(y_train), +#' n_MC_samples = 1, #' vaeac.epochs = 10, #' vaeac.n_vaeacs_initialize = 1 #' ) @@ -2815,7 +2846,7 @@ vaeac_plot_imputed_ggpairs <- function( checkpoint <- torch::torch_load(vaeac_model_path) # Get the number of observations in the x_true and features - n_samples <- if (is.null(x_true)) 500 else nrow(x_true) + n_MC_samples <- if (is.null(x_true)) 500 else nrow(x_true) n_features <- checkpoint$n_features # Checking for valid dimension @@ -2830,12 +2861,12 @@ vaeac_plot_imputed_ggpairs <- function( # Impute the missing entries using the vaeac approach. Here we generate x from p(x), so no conditioning. imputed_values <- vaeac_impute_missing_entries( - x_explain_with_NaNs = matrix(NaN, n_samples, checkpoint$n_features), - n_samples = 1, + x_explain_with_NaNs = matrix(NaN, n_MC_samples, checkpoint$n_features), + n_MC_samples = 1, vaeac_model = vaeac_model, checkpoint = checkpoint, sampler = explanation$internal$parameters$vaeac.sampler, - batch_size = n_samples, + batch_size = n_MC_samples, verbose = explanation$internal$parameters$verbose, seed = explanation$internal$parameters$seed ) @@ -2847,7 +2878,7 @@ vaeac_plot_imputed_ggpairs <- function( # Add type variable representing if they are imputed samples or from `x_true` combined_data$type <- - factor(rep(c("True", "Imputed"), times = c(ifelse(is.null(nrow(x_true)), 0, nrow(x_true)), n_samples))) + factor(rep(c("True", "Imputed"), times = c(ifelse(is.null(nrow(x_true)), 0, nrow(x_true)), n_MC_samples))) # Create the ggpairs figure and potentially add title based on the description of the used vaeac model figure <- GGally::ggpairs( diff --git a/R/approach_vaeac_torch_modules.R b/R/approach_vaeac_torch_modules.R index 
e353327aba41357be6aebcb09391e8c3b6765799..da0118f94b56750c54a74bc66dacedbaca6c31c1 100644 --- a/R/approach_vaeac_torch_modules.R +++ b/R/approach_vaeac_torch_modules.R @@ -1525,7 +1525,7 @@ gauss_cat_sampler_most_likely <- function(one_hot_max_sizes, min_sigma = 1e-4, m distr <- vaeac_categorical_parse_params(params, self$min_prob) # Create a categorical distr based on params col_sample <- torch::torch_argmax(distr$probs, -1)[, NULL]$to(dtype = torch::torch_float()) # Most lik class } - sample <- append(sample, col_sample) # Add the vector of sampled values for the i´th feature to the sample list + sample <- append(sample, col_sample) # Add the vector of sampled values for the i-th feature to the sample list } return(torch::torch_cat(sample, -1)) # Create a 2D torch by column binding the vectors in the list } @@ -1587,7 +1587,7 @@ gauss_cat_sampler_random <- function(one_hot_max_sizes, min_sigma = 1e-4, min_pr distr <- vaeac_categorical_parse_params(params, self$min_prob) # Create a categorical distr based on params col_sample <- distr$sample()$unsqueeze(-1)$to(dtype = torch::torch_float()) # Sample class using class prob } - sample <- append(sample, col_sample) # Add the vector of sampled values for the i´th feature to the sample list + sample <- append(sample, col_sample) # Add the vector of sampled values for the i-th feature to the sample list } return(torch::torch_cat(sample, -1)) # Create a 2D torch by column binding the vectors in the list } @@ -1656,7 +1656,7 @@ gauss_cat_parameters <- function(one_hot_max_sizes, min_sigma = 1e-4, min_prob = distr <- vaeac_categorical_parse_params(params, self$min_prob) # Create a categorical distr based on params current_parameters <- distr$probs # Extract the current probabilities for each classs } - parameters <- append(parameters, current_parameters) # Add the i´th feature's parameters to the parameters list + parameters <- append(parameters, current_parameters) # Add the i-th feature's parameters to the parameters list 
} return(torch::torch_cat(parameters, -1)) # Create a 2D torch_tensor by column binding the tensors in the list } @@ -1821,7 +1821,7 @@ categorical_to_one_hot_layer <- function(one_hot_max_sizes, add_nans_map_for_col # ONLY FOR CONTINUOUS FEATURES: out_cols now is a list of n_features tensors of shape n x size = n x 1 for # continuous variables. We concatenate them to a matrix of dim n x 2*n_features (in cont case) for prior net, but # for proposal net, it is n x 3*n_features, and they take the form - # [batch1, is.nan1, batch2, is.nan2, …, batch12, is.nan12, mask1, mask2, …, mask12] + # [batch1, is.nan1, batch2, is.nan2, ..., batch12, is.nan12, mask1, mask2, ..., mask12] return(out_cols) } ) diff --git a/R/asymmetric_and_casual_Shapley.R b/R/asymmetric_and_casual_Shapley.R new file mode 100644 index 0000000000000000000000000000000000000000..079883da96ebc9db9cf49ebfa68fb5eacda8d601 --- /dev/null +++ b/R/asymmetric_and_casual_Shapley.R @@ -0,0 +1,583 @@ +# Check functions ------------------------------------------------------------------------------------------------- +#' Check that all explicands has at least one valid MC sample in causal Shapley values +#' +#' @param dt Data.table containing the generated MC samples (and conditional values) after each sampling step +#' @inheritParams explain +#' @inheritParams create_marginal_data_categoric +#' @inheritParams create_marginal_data_training +#' +#' @keywords internal +#' +#' @author Lars Henry Berge Olsen +check_categorical_valid_MCsamp <- function(dt, n_explain, n_MC_samples, joint_probability_dt) { + dt_factor <- dt[, .SD, .SDcols = is.factor] # Get the columns that have been inserted into + dt_factor_names <- copy(names(dt_factor)) # Get their names. 
Copy as we are to change dt_factor + dt_factor[, id := rep(seq(n_explain), each = n_MC_samples)] # Add an id column + dt_valid_coals <- joint_probability_dt[, dt_factor_names, with = FALSE] # Get the valid feature coalitions + dt_invalid <- dt_factor[!dt_valid_coals, on = dt_factor_names] # Get non valid coalitions + explicand_all_invalid <- dt_invalid[, .N, by = id][N == n_MC_samples] # If all samples for an explicand are invalid + if (nrow(explicand_all_invalid) > 0) { + stop(paste0( + "An explicand has no valid MC feature coalitions. Increase `n_MC_samples` or provide ", + "`joint_prob_dt` containing the probaibilities for unlikely coalitions, too." + )) + } +} + +# Convert function ------------------------------------------------------------------------------------------------ +#' Convert feature names into feature indices +#' +#' Functions that takes a `causal_ordering` specified using strings and convert these strings to feature indices. +#' +#' @param labels Vector of strings containing (the order of) the feature names. +#' @param feat_group_txt String that is either "feature" or "group" based on +#' if `shapr` is computing feature- or group-wise Shapley values +#' @inheritParams explain +#' +#' @return The `causal_ordering` list, but with feature indices (w.r.t. `labels`) instead of feature names. 
+#' +#' @keywords internal +#' @author Lars Henry Berge Olsen +convert_feature_name_to_idx <- function(causal_ordering, labels, feat_group_txt) { + # Convert the feature names into feature indices + causal_ordering_match <- match(unlist(causal_ordering), labels) + + # Check that user only provided valid feature names + if (any(is.na(causal_ordering_match))) { + stop(paste0( + "`causal_ordering` contains ", feat_group_txt, " names (`", + paste0(unlist(causal_ordering)[is.na(causal_ordering_match)], collapse = "`, `"), "`) ", + "that are not in the data (`", paste0(labels, collapse = "`, `"), "`).\n" + )) + } + + # Recreate the causal_ordering list with the feature indices + causal_ordering <- relist(causal_ordering_match, causal_ordering) + return(causal_ordering) +} + + +# Create functions ------------------------------------------------------------------------------------------------ +#' Function that samples data from the empirical marginal training distribution +#' +#' @description Sample observations from the empirical distribution P(X) using the training dataset. +#' +#' @param n_explain Integer. The number of explicands/observations to explain. +#' @param Sbar_features Vector of integers containing the features indices to generate marginal observations for. +#' That is, if `Sbar_features` is `c(1,4)`, then we sample `n_MC_samples` observations from \eqn{P(X_1, X_4)} using the +#' empirical training observations (with replacements). That is, we sample the first and fourth feature values from +#' the same training observation, so we do not break the dependence between them. +#' @param n_explain Integer. The number of explicands/observations to explain. +#' @param stable_version Logical. If `TRUE` and `n_MC_samples` > `n_train`, then we include each training observation +#' `n_MC_samples %/% n_train` times and then sample the remaining `n_MC_samples %% n_train samples`. Only the latter is +#' done when `n_MC_samples < n_train`. 
This is done separately for each explicand. If `FALSE`, we randomly sample the +#' from the observations. +#' +#' @inheritParams explain +#' +#' @return Data table of dimension \eqn{`n_MC_samples` \times `length(Sbar_features)`} with the sampled observations. +#' +#' +#' @examples +#' \dontrun{ +#' data("airquality") +#' data <- data.table::as.data.table(airquality) +#' data <- data[complete.cases(data), ] +#' +#' x_var <- c("Solar.R", "Wind", "Temp", "Month") +#' y_var <- "Ozone" +#' +#' ind_x_explain <- 1:6 +#' x_train <- data[-ind_x_explain, ..x_var] +#' x_train +#' create_marginal_data__training(x_train = x_train, Sbar_features = c(1, 4), n_MC_samples = 10) +#' } +#' +#' @keywords internal +#' @author Lars Henry Berge Olsen +create_marginal_data_training <- function(x_train, + n_explain, + Sbar_features, + n_MC_samples = 1000, + stable_version = TRUE) { + # Get the number of training observations + n_train <- nrow(x_train) + + if (stable_version) { + # If n_MC_samples > n_train, then we include each training observations n_MC_samples %/% n_train times and + # then sample the remaining n_MC_samples %% n_train samples. Only the latter is done when n_MC_samples < n_train. + # This is done separately for each explicand + sampled_indices <- as.vector(sapply( + seq(n_explain), + function(x) { + c( + rep(seq(n_train), each = n_MC_samples %/% n_train), + sample(n_train, n_MC_samples %% n_train) + ) + } + )) + } else { + # sample everything and not guarantee that we use all training observations + sampled_indices <- sample(n_train, n_MC_samples * n_explain, replace = TRUE) + } + + # Sample the marginal data and return them + return(x_train[sampled_indices, Sbar_features, with = FALSE]) +} + +#' Create marginal categorical data for causal Shapley values +#' +#' @description +#' This function is used when we generate marginal data for the categorical approach when we have several sampling +#' steps. 
#' We need to treat this separately, as we here in the marginal step CANNOT make feature values such
#' that the combination of those and the feature values we condition on in S are NOT in
#' `categorical.joint_prob_dt`. If we do this, then we cannot progress further in the chain of sampling
#' steps. E.g., X1 in (1,2,3), X2 in (1,2,3), and X3 in (1,2,3).
#' We know X2 = 2, and let the causal structure be X1 -> X2 -> X3. Assume that
#' P(X1 = 1, X2 = 2, X3 = 3) = P(X1 = 2, X2 = 2, X3 = 3) = 1/2. Then there is no point in
#' generating X1 = 3, as we then cannot generate X3.
#' The solution is to only generate the values which can proceed through the whole
#' chain of sampling steps. To do that, we have to ensure that the marginal sampling
#' respects the valid feature coalitions for all sets of conditional features, i.e.,
#' the features in `features_steps_cond_on`.
#' We sample from the valid coalitions using the MARGINAL probabilities.
#'
#' @param Sbar_features Vector of integers containing the feature indices to generate marginal observations for.
#' That is, if `Sbar_features` is `c(1,4)`, then we sample `n_MC_samples` observations from \eqn{P(X_1, X_4)}.
#' That is, we sample the first and fourth feature values from the same valid feature coalition using
#' the marginal probability, so we do not break the dependence between them.
#' @param S_original Vector of integers containing the feature indices of the original coalition `S`. I.e., not the
#' features in the current sampling step, but the features that are known to us before starting the chain of
#' sampling steps.
#' @param joint_prob_dt Data.table containing the joint probability distribution for each coalition of feature values.
#' @inheritParams explain
#'
#' @return Data table of dimension \eqn{(`n_MC_samples` * `nrow(x_explain)`) \times `length(Sbar_features)`} with the
#' sampled observations.
#'
#' @keywords internal
#'
#' @author Lars Henry Berge Olsen
create_marginal_data_categoric <- function(n_MC_samples,
                                           x_explain,
                                           Sbar_features,
                                           S_original,
                                           joint_prob_dt) {
  # Get the names of the features
  feature_names <- colnames(x_explain)

  # Get the feature names of the features we are to generate
  Sbar_now_names <- feature_names[Sbar_features]

  # Make a copy of the explicands and add an id
  x_explain_copy <- data.table::copy(x_explain)[, id := .I]

  # Get the features that are in S originally and the features we are creating marginal values for
  S_original_names <- feature_names[S_original]
  S_original_names_with_id <- c("id", S_original_names)
  relevant_features <- sort(c(Sbar_features, S_original))
  relevant_features_names <- feature_names[relevant_features]

  # Get the marginal probabilities for the relevant feature coalitions
  marginal_prob_dt <- joint_prob_dt[, list(prob = sum(joint_prob)), by = relevant_features_names]

  # Get all valid feature coalitions for the relevant features. The columns are selected by NAME, like the
  # aggregation above, so the result does not depend on the column order of `joint_prob_dt`.
  dt_valid_coalitions <- unique(joint_prob_dt[, relevant_features_names, with = FALSE])

  # Get relevant feature coalitions that are valid for the explicands
  dt_valid_coalitions_relevant <- data.table::merge.data.table(
    x_explain_copy[, S_original_names_with_id, with = FALSE],
    dt_valid_coalitions,
    by = S_original_names,
    allow.cartesian = TRUE
  )

  # Merge the relevant feature coalitions with their marginal probabilities
  dt_valid_coal_marg_prob <- data.table::merge.data.table(
    dt_valid_coalitions_relevant,
    marginal_prob_dt,
    by = relevant_features_names
  )
  dt_valid_coal_marg_prob[, prob := prob / sum(prob), by = id] # Make prob sum to 1 for each explicand
  data.table::setkey(dt_valid_coal_marg_prob, "id") # Set id to key so id is in increasing order

  # Sample n_MC_samples from the valid coalitions using the marginal probabilities and extract the Sbar columns
  dt_return <-
    dt_valid_coal_marg_prob[, .SD[sample(.N, n_MC_samples, replace = TRUE, prob = prob)],
      by = id
    ][, Sbar_now_names, with = FALSE]
  return(dt_return)
}


# Get functions ---------------------------------------------------------------------------------------------------
#' Get all coalitions satisfying the causal ordering
#'
#' @description
#' This function is only relevant when we are computing asymmetric Shapley values.
#' For symmetric Shapley values (both regular and causal), all coalitions are allowed.
#'
#' @inheritParams explain
#'
#' @param sort_features_in_coalitions Boolean. If `TRUE`, then the feature indices in the
#' coalitions are sorted in increasing order. If `FALSE`, then the function maintains the
#' order of features within each group given in `causal_ordering`.
#'
#' @return List of vectors containing all coalitions that respect the causal ordering.
#' The first element is the empty coalition.
#' @keywords internal
#' @author Lars Henry Berge Olsen
get_valid_causal_coalitions <- function(causal_ordering, sort_features_in_coalitions = TRUE) {
  # Create a list to store the possible coalitions and start with the empty coalition
  coalitions <- list(numeric(0))

  # Iterate over the partial causal orderings (`seq_along()` gives zero iterations for an empty ordering)
  for (i in seq_along(causal_ordering)) {
    # Get the number of features in the ith component of the (partial) causal ordering
    ith_order_length <- length(causal_ordering[[i]])

    # Create a list of vectors containing all possible feature coalitions except the empty one (with temp indices)
    ith_order_coalitions <- unlist(
      lapply(
        seq_len(ith_order_length),
        function(m) utils::combn(ith_order_length, m, simplify = FALSE)
      ),
      recursive = FALSE
    )

    # Get the ancestors of the ith component: the last coalition added so far contains all previous components
    ancestors <- coalitions[[length(coalitions)]]
    n_ancestors <- length(ancestors)

    # Update the indices by adding the number of ancestors and concatenate the ancestors
    coalitions <- c(coalitions, lapply(ith_order_coalitions, function(x) c(ancestors, x + n_ancestors)))
  }

  # Sort the causal components such that the singletons are in the right order.
  # `lapply()` always returns a list, whereas `sapply()` returns a matrix when all components have equal length.
  if (sort_features_in_coalitions) causal_ordering <- lapply(causal_ordering, sort)

  # Convert the temporary indices to the correct feature indices (flatten the ordering only once)
  feature_idx <- unlist(causal_ordering)
  coalitions <- lapply(coalitions, function(x) feature_idx[x])

  # Sort the coalitions
  if (sort_features_in_coalitions) coalitions <- lapply(coalitions, sort)

  return(coalitions)
}

#' Get the number of coalitions that respects the causal ordering
#'
#' @inheritParams explain
#'
#' @details The function computes the number of coalitions that respects the causal ordering by computing the number
#' of coalitions in each partial causal component and then summing these. We compute
#' the number of coalitions in the \eqn{i}th partial causal component by \eqn{2^n - 1},
#' where \eqn{n} is the number of features in the \eqn{i}th partial causal component
#' and we subtract one as we do not want to include the situation where no features in
#' the \eqn{i}th partial causal component are present. In the end, we add 1 for the
#' empty coalition.
#'
#' @examples
#' \dontrun{
#' get_max_n_coalitions_causal(list(1:10)) # 2^10 = 1024 (no causal order)
#' get_max_n_coalitions_causal(list(1:3, 4:7, 8:10)) # 30
#' get_max_n_coalitions_causal(list(1:3, 4:5, 6:7, 8, 9:10)) # 18
#' get_max_n_coalitions_causal(list(1:3, c(4, 8), c(5, 7), 6, 9:10)) # 18
#' get_max_n_coalitions_causal(list(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) # 11
#' }
#'
#' @return Integer. The (maximum) number of coalitions that respects the causal ordering.
#' @keywords internal
#' @author Lars Henry Berge Olsen
get_max_n_coalitions_causal <- function(causal_ordering) {
  # Each component with n features contributes 2^n - 1 non-empty coalitions; the final `+ 1` is the empty coalition.
  # `lengths()` always returns an integer vector, so an empty `causal_ordering` gives 1 instead of an error
  # (`sapply(list(), length)` returns `list()`, which cannot be used as an exponent).
  return(sum(2^lengths(causal_ordering)) - length(causal_ordering) + 1)
}

#' Get the steps for generating MC samples for coalitions following a causal ordering
#'
#' @inheritParams explain
#'
#' @param S Integer matrix of dimension \code{n_coalitions_valid x m}, where \code{n_coalitions_valid} equals
#' the total number of valid coalitions that respect the causal ordering given in `causal_ordering` and \code{m} equals
#' the total number of features.
#' @param as_string Boolean.
#' If the returned object is to be a list of lists of integers or a list of vectors of strings.
#'
#' @return Depends on the value of the parameter `as_string`. If `TRUE`, then `results[[j]]` is a character vector
#' specifying the process of generating the samples for coalition `j`. The length of `results[[j]]` is the number of
#' steps, and `results[[j]][i]` is a string of the form `features_to_sample|features_to_condition_on`. If the
#' `features_to_condition_on` part is blank, then we are to sample from the marginal distribution.
#' For `as_string == FALSE`, we rather return a list where `results[[j]][[i]]` contains the elements
#' `Sbar` and `S` representing the features to sample and condition on, respectively.
#'
#' @examples
#' \dontrun{
#' m <- 5
#' causal_ordering <- list(1:2, 3:4, 5)
#' S <- shapr::feature_matrix_cpp(get_valid_causal_coalitions(causal_ordering = causal_ordering),
#'   m = m
#' )
#' confounding <- c(TRUE, TRUE, FALSE)
#' get_S_causal_steps(S, causal_ordering, confounding, as_string = TRUE)
#'
#' # Look at the effect of changing the confounding assumptions
#' SS1 <- get_S_causal_steps(S, causal_ordering,
#'   confounding = c(FALSE, FALSE, FALSE),
#'   as_string = TRUE
#' )
#' SS2 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, FALSE, FALSE), as_string = TRUE)
#' SS3 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, TRUE, FALSE), as_string = TRUE)
#' SS4 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, TRUE, TRUE), as_string = TRUE)
#'
#' all.equal(SS1, SS2)
#' SS1[[2]] # Condition on 1 as there is no confounding in the first component
#' SS2[[2]] # Do NOT condition on 1 as there is confounding in the first component
#' SS1[[3]]
#' SS2[[3]]
#'
#' all.equal(SS1, SS3)
#' SS1[[2]] # Condition on 1 as there is no confounding in the first component
#' SS3[[2]] # Do NOT condition on 1 as there is confounding in the first component
#' SS1[[5]] # Condition on 3 as there is no confounding in the second component
#' SS3[[5]] # Do NOT condition on 3 as there is confounding in the second component
#' SS1[[6]]
#' SS3[[6]]
#'
#' all.equal(SS2, SS3)
#' SS2[[5]]
#' SS3[[5]]
#' SS2[[6]]
#' SS3[[6]]
#'
#' all.equal(SS3, SS4) # No difference as the last component is a singleton
#' }
#' @author Lars Henry Berge Olsen
#' @keywords internal
get_S_causal_steps <- function(S, causal_ordering, confounding, as_string = FALSE) {
  n_coalitions <- nrow(S)
  feature_idx <- seq_len(ncol(S))

  # List to store the sampling process (entries stay NULL for coalitions without sampling steps)
  results <- vector("list", n_coalitions)
  names(results) <- paste0("id_coalition_", seq_len(n_coalitions))

  # Iterate over the coalitions, skipping the first and the last row of S.
  # `seq_len()` gives zero iterations when S has two or fewer rows, whereas `seq(2, nrow(S) - 1)` would
  # count DOWN (2, 1) and process rows that are meant to be skipped.
  for (j in seq_len(max(n_coalitions - 2L, 0L)) + 1L) {
    # Get the given and dependent features for this coalition
    index_given <- feature_idx[as.logical(S[j, ])]
    index_dependent <- feature_idx[as.logical(1 - S[j, ])]

    # Collect the sampling steps of this coalition locally and store them once at the end
    steps <- list()

    # Iterate over the causal orderings
    for (i in seq_along(causal_ordering)) {
      # check overlap between index_dependent and ith causal component
      to_sample <- intersect(causal_ordering[[i]], index_dependent)
      if (length(to_sample) == 0) next

      # Condition on all features in ancestor components (NULL for the first component)
      to_condition <- unlist(causal_ordering[seq_len(i - 1)])

      # If confounding is FALSE, add intervened features in the same component to the `to_condition` set.
      # If confounding is TRUE, then no extra conditioning.
      if (!confounding[i]) to_condition <- union(intersect(causal_ordering[[i]], index_given), to_condition)

      # Save Sbar and S (sorting is for the visual)
      steps[[length(steps) + 1L]] <- list(Sbar = sort(to_sample), S = sort(to_condition))
    }

    # Leave the entry as NULL when there is nothing to sample for this coalition
    if (length(steps) == 0) next

    if (as_string) {
      # One string per step of the form `features_to_sample|features_to_condition_on`
      results[[j]] <- vapply(
        steps,
        function(step) {
          paste0(paste0(step[["Sbar"]], collapse = ","), "|", paste0(step[["S"]], collapse = ","))
        },
        character(1)
      )
    } else {
      names(steps) <- paste0("step_", seq_along(steps))
      results[[j]] <- steps
    }
  }

  return(results)
}

# Prepare data function -------------------------------------------------------------------------------------------
#' Generate data used for predictions and Monte Carlo integration for causal Shapley values
#'
#' This function loops over the given coalitions, and for each coalition it extracts the
#' chain of relevant sampling steps provided in `internal$iter_list[[iter]]$S_causal_steps`. This chain
#' can contain sampling from marginal and conditional distributions. We use the approach given by
#' `internal$parameters$approach` to generate the samples from the conditional distributions, and
#' we iteratively call `prepare_data()` with a modified `internal_copy` list to reuse code.
#' However, this also means that chains with the same conditional distributions will retrain a
#' model of said conditional distributions several times.
#' For the marginal distribution, we sample from the Gaussian marginals when the approach is
#' `gaussian`, from the valid feature coalitions in `categorical.joint_prob_dt` when the approach is
#' `categorical` and the chain has several sampling steps, and from the marginals of the training data
#' otherwise. Note that we could extend the code to sample from the marginal (gaussian) copula, too,
#' when `approach` is `copula`.
#'
#' @inheritParams default_doc_explain
#' @param ... Further arguments passed on to `prepare_data()` in the conditional sampling steps.
#'
#' @return A data.table containing simulated data that respects the (partial) causal ordering and
#' the confounding assumptions. The data is used to estimate the contribution function by Monte Carlo integration.
#'
#' @export
#' @keywords internal
#' @author Lars Henry Berge Olsen
prepare_data_causal <- function(internal, index_features = NULL, ...) {
  # Recall that here, index_features is a vector of id_coalitions, i.e., indicating which rows in S to use.
  # Also note that we are guaranteed that index_features does not include the empty or grand coalition

  # Extract iteration specific variables
  iter <- length(internal$iter_list)
  S <- internal$iter_list[[iter]]$S
  S_causal_steps <- internal$iter_list[[iter]]$S_causal_steps

  # Extract the needed variables
  x_train <- internal$data$x_train
  approach <- internal$parameters$approach # Can only be single approach
  x_explain <- internal$data$x_explain
  n_explain <- internal$parameters$n_explain
  n_features <- internal$parameters$n_features
  n_MC_samples <- internal$parameters$n_MC_samples
  feature_names <- internal$parameters$feature_names

  # Create a list to store the populated data tables with the MC samples
  dt_list <- list()

  # Create a copy of the internal list. We will change its x_explain, n_explain, and n_MC_samples such
  # that we can reuse the prepare_data() function, which was not originally designed for the step-wise/iterative
  # sampling process which is needed for Causal Shapley values where we sample from P(Sbar_i | S_i) and
  # the S and Sbar change in the iterative process. So does the number of MC samples we need to generate.
  internal_copy <- copy(internal)

  # Loop over the coalitions in the batch
  for (index_feature_idx in seq_along(index_features)) {
    # Extract the index of the current coalition
    index_feature <- index_features[index_feature_idx]

    # Reset the internal_copy list for each new coalition
    if (index_feature_idx > 1) {
      internal_copy$data$x_explain <- x_explain
      internal_copy$parameters$n_explain <- n_explain
      internal_copy$parameters$n_MC_samples <- n_MC_samples
    }

    # Create the empty data table which we are to populate with the Monte Carlo samples for each coalition
    dt <- data.table(matrix(nrow = n_explain * n_MC_samples, ncol = n_features))
    # if (approach == "categorical") dt[, names(dt) := lapply(.SD, as.factor)] # Needed for the categorical approach
    colnames(dt) <- feature_names

    # Populate the data table with the features we condition on (each explicand repeated n_MC_samples times)
    S_names <- feature_names[as.logical(S[index_feature, ])]
    dt[, (S_names) := x_explain[rep(seq_len(n_explain), each = n_MC_samples), .SD, .SDcols = S_names]]

    # Get the iterative sampling process for the current coalition
    S_causal_steps_now <- S_causal_steps[[index_feature]]

    # Loop over the steps in the iterative sampling process to generate MC samples for the unconditional features
    for (sampling_step_idx in seq_along(S_causal_steps_now)) {
      # Set flag indicating whether or not we are in the first sampling step, as the gaussian and copula
      # approaches need to know this to change their sampling procedure to ensure correctly generated MC samples
      internal_copy$parameters$causal_first_step <- sampling_step_idx == 1

      # Get the S (the conditional features) and Sbar (the unconditional features) in the current sampling step
      S_now <- S_causal_steps_now[[sampling_step_idx]]$S # The features to condition on in this sampling step
      Sbar_now <- S_causal_steps_now[[sampling_step_idx]]$Sbar # The features to sample in this sampling step
      Sbar_now_names <- feature_names[Sbar_now]

      # Check if we are to sample from the marginal or conditional distribution.
      # NOTE(review): `is.null()` is FALSE for a zero-length vector such as `integer(0)`, so an empty but
      # non-NULL `S_now` takes the conditional branch - confirm that the steps never contain such an `S`.
      if (is.null(S_now)) {
        # Marginal distribution as there are no variables to condition on

        # Generate the marginal data either from the Gaussian or categorical distribution or the training data
        # TODO: Can extend to also sample from the marginals of the gaussian copula and vaeac
        if (approach == "gaussian") {
          # Sample marginal data from the marginal gaussian distribution
          dt_Sbar_now_marginal_values <- create_marginal_data_gaussian(
            n_MC_samples = n_MC_samples * n_explain,
            Sbar_features = Sbar_now,
            mu = internal$parameters$gaussian.mu,
            cov_mat = internal$parameters$gaussian.cov_mat
          )
        } else if (approach == "categorical" && length(S_causal_steps_now) > 1) {
          # For categorical approach with several sampling steps, we make sure to only sample feature coalitions
          # that are present in `categorical.joint_prob_dt` when combined with the features in `S_names`.
          dt_Sbar_now_marginal_values <- create_marginal_data_categoric(
            n_MC_samples = n_MC_samples,
            x_explain = x_explain,
            Sbar_features = Sbar_now,
            S_original = seq_len(n_features)[as.logical(S[index_feature, ])],
            joint_prob_dt = internal$parameters$categorical.joint_prob_dt
          )
        } else {
          # Sample from the training data for all approaches except the gaussian approach
          # and except the categorical approach for settings with several sampling steps
          dt_Sbar_now_marginal_values <- create_marginal_data_training(
            x_train = x_train,
            n_explain = n_explain,
            Sbar_features = Sbar_now,
            n_MC_samples = n_MC_samples,
            stable_version = TRUE
          )
        }

        # Insert the marginal values into the data table
        dt[, (Sbar_now_names) := dt_Sbar_now_marginal_values]
      } else {
        # Conditional distribution as there are variables to condition on

        # Create dummy versions of S and X only containing the current conditional features, and index_features is 1.
        internal_copy$iter_list[[iter]]$S <- matrix(0, ncol = n_features, nrow = 1)
        internal_copy$iter_list[[iter]]$S[1, S_now] <- 1
        internal_copy$iter_list[[iter]]$X <-
          data.table(id_coalition = 1, features = list(S_now), n_features = length(S_now))

        # Generate the MC samples conditioning on S_now
        dt_new <- prepare_data(internal_copy, index_features = 1, ...)

        if (approach %in% c("independence", "empirical", "ctree", "categorical")) {
          # These approaches produce weighted MC samples, i.e., they do not necessarily generate n_MC_samples MC
          # samples. We ensure n_MC_samples by weighted sampling (with replacements) those ids with not
          # n_MC_samples MC samples.
          n_samp_now <- internal_copy$parameters$n_MC_samples
          dt_new <-
            dt_new[, .SD[if (.N == n_samp_now) seq_len(.N) else sample(.N, n_samp_now, replace = TRUE, prob = w)],
              by = id
            ]

          # Check that dt_new has the right number of rows.
          if (nrow(dt_new) != n_explain * n_MC_samples) stop("`dt_new` does not have the right number of rows.\n")
        }

        # Insert/keep only the features in Sbar_now into dt
        dt[, (Sbar_now_names) := dt_new[, .SD, .SDcols = Sbar_now_names]]
      }

      # Here we check if all the generated samples are outside the joint_prob_dt
      if (approach == "categorical" && length(S_causal_steps_now) > 1) {
        check_categorical_valid_MCsamp(
          dt = dt,
          n_explain = n_explain,
          n_MC_samples = n_MC_samples,
          joint_probability_dt = internal$parameters$categorical.joint_prob_dt
        )
      }

      # Update the x_explain in internal_copy such that in the next sampling step use the values in dt
      # as the conditional feature values. Furthermore, we set n_MC_samples to 1 such that we in the next
      # step generate one new value for each of the n_MC_samples MC samples we have begun to generate.
      internal_copy$data$x_explain <- dt
      internal_copy$parameters$n_explain <- nrow(dt)
      internal_copy$parameters$n_MC_samples <- 1
    }

    # Save the now populated data table
    dt_list[[index_feature_idx]] <- dt
  }

  # Combine the list of data tables and add the id columns
  dt <- data.table::rbindlist(dt_list, fill = TRUE)
  dt[, id_coalition := rep(index_features, each = n_MC_samples * n_explain)]
  dt[, id := rep(seq_len(n_explain), each = n_MC_samples, times = length(index_features))]
  dt[, w := 1 / n_MC_samples]
  data.table::setcolorder(dt, c("id_coalition", "id", feature_names))

  # Aggregate the weights for the non-unique rows such that we only return a data table with unique rows.
  # Only done for these approaches as they are the only approaches that are likely to return duplicates.
  if (approach %in% c("independence", "empirical", "ctree", "categorical")) {
    dt <- dt[, list(w = sum(w)), by = c("id_coalition", "id", feature_names)]
  }

  return(dt)
}

# diff --git a/R/check_convergence.R b/R/check_convergence.R
# new file mode 100644
# index 0000000000000000000000000000000000000000..b260d9a77f0b1ad974306563a13b3b4c680bea3d
# --- /dev/null
# +++ b/R/check_convergence.R
# @@ -0,0 +1,82 @@

#' Checks the convergence according to the convergence threshold
#'
#' @inheritParams default_doc_explain
#'
#' @export
#' @keywords internal
check_convergence <- function(internal) {
  iter <- length(internal$iter_list)

  convergence_tol <- internal$parameters$iterative_args$convergence_tol
  max_iter <- internal$parameters$iterative_args$max_iter
  max_n_coalitions <- internal$parameters$iterative_args$max_n_coalitions
  paired_shap_sampling <- internal$parameters$paired_shap_sampling
  n_shapley_values <- internal$parameters$n_shapley_values

  exact <- internal$iter_list[[iter]]$exact

  dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est
  dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd

  n_sampled_coalitions <- internal$iter_list[[iter]]$n_coalitions - 2 # Subtract the zero and full predictions

  max_sd <- dt_shapley_sd[, max(.SD, na.rm = TRUE), .SDcols = -1, by = .I]$V1 # Max per prediction
  max_sd0 <- max_sd * sqrt(n_sampled_coalitions) # Scales UP the sd as it scales at this rate

  dt_shapley_est0 <- copy(dt_shapley_est)

  # Defaults reported when no sd-based convergence check is carried out
  est_required_coals_per_ex_id <- est_required_coalitions <- est_remaining_coalitions <- overall_conv_measure <- NA

  if (isTRUE(exact)) {
    converged_exact <- TRUE
    converged_sd <- FALSE
  } else {
    converged_exact <- FALSE
    if (!is.null(convergence_tol)) {
      # Range of the estimates per row; the first two columns are excluded from the range
      dt_shapley_est0[, maxval := max(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
      dt_shapley_est0[, minval := min(.SD, na.rm = TRUE), .SDcols = -c(1, 2), by = .I]
      dt_shapley_est0[, max_sd0 := max_sd0]
      dt_shapley_est0[, req_samples := (max_sd0 / ((maxval - minval) * convergence_tol))^2]
      dt_shapley_est0[, conv_measure := max_sd0 / ((maxval - minval) * sqrt(n_sampled_coalitions))]

      # Cap the requirement PER ROW. `pmin()` is element-wise, whereas `min()` without `by` collapses the whole
      # column into a single value and would give every row the smallest requirement across all rows.
      dt_shapley_est0[, req_samples := pmin(req_samples, 2^n_shapley_values - 2)]

      est_required_coalitions <- ceiling(dt_shapley_est0[, median(req_samples)]) # TODO:Consider other ways to do this
      if (isTRUE(paired_shap_sampling)) {
        # Round up to an even number
        est_required_coalitions <- ceiling(est_required_coalitions * 0.5) * 2
      }
      est_remaining_coalitions <- max(0, est_required_coalitions - (n_sampled_coalitions + 2))

      overall_conv_measure <- dt_shapley_est0[, median(conv_measure)] # TODO:Consider other ways to do this

      converged_sd <- (est_remaining_coalitions == 0)

      est_required_coals_per_ex_id <- dt_shapley_est0[, req_samples]
      names(est_required_coals_per_ex_id) <- paste0(
        "req_samples_explain_id_",
        seq_along(est_required_coals_per_ex_id)
      )
    } else {
      converged_sd <- FALSE
    }
  }

  converged_max_n_coalitions <- (n_sampled_coalitions + 2 >= max_n_coalitions)

  converged_max_iter <- (iter >= max_iter)

  converged <- converged_exact || converged_sd || converged_max_iter || converged_max_n_coalitions

  internal$iter_list[[iter]]$converged <- converged
  internal$iter_list[[iter]]$converged_exact <- converged_exact
  internal$iter_list[[iter]]$converged_sd <- converged_sd
  internal$iter_list[[iter]]$converged_max_iter <- converged_max_iter
  internal$iter_list[[iter]]$converged_max_n_coalitions <- converged_max_n_coalitions
  internal$iter_list[[iter]]$est_required_coalitions <- est_required_coalitions
  internal$iter_list[[iter]]$est_remaining_coalitions <- est_remaining_coalitions
  internal$iter_list[[iter]]$est_required_coals_per_ex_id <- as.list(est_required_coals_per_ex_id)
  internal$iter_list[[iter]]$overall_conv_measure <- overall_conv_measure

  internal$timing_list$check_convergence <- Sys.time()

  return(internal)
}

# diff --git a/R/cli.R b/R/cli.R
# new file mode 100644
# index
# 0000000000000000000000000000000000000000..694c7f7cddb97451276ab912f74c84c8e24249c6
# --- /dev/null
# +++ b/R/cli.R
# @@ -0,0 +1,125 @@

#' Printing startup messages with cli
#'
#' @param model_class String.
#' Class of the model as a string
#' @inheritParams default_doc_explain
#' @inheritParams explain
#'
#' @export
#' @keywords internal
cli_startup <- function(internal, model_class, verbose) {
  # NOTE: the local variables below are referenced by name inside the cli/glue templates
  # (e.g. `{approach}`), so they must keep their names even though they look unused.
  init_time <- internal$timing_list$init_time

  is_groupwise <- internal$parameters$is_groupwise
  approach <- internal$parameters$approach
  iterative <- internal$parameters$iterative
  n_shapley_values <- internal$parameters$n_shapley_values
  n_explain <- internal$parameters$n_explain
  saving_path <- internal$parameters$output_args$saving_path
  causal_ordering_names_string <- internal$parameters$causal_ordering_names_string
  max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal
  confounding_string <- internal$parameters$confounding_string

  feat_group_txt <- ifelse(is_groupwise, "group-wise", "feature-wise")

  testing <- internal$parameters$testing
  asymmetric <- internal$parameters$asymmetric
  confounding <- internal$parameters$confounding

  # Build the bullet list shown at startup
  line_vec <- "Model class: {.cls {model_class}}"
  line_vec <- c(line_vec, "Approach: {.emph {approach}}")
  line_vec <- c(line_vec, "Iterative estimation: {.emph {iterative}}")
  line_vec <- c(line_vec, "Number of {.emph {feat_group_txt}} Shapley values: {n_shapley_values}")
  line_vec <- c(line_vec, "Number of observations to explain: {n_explain}")
  if (isTRUE(asymmetric)) {
    line_vec <- c(line_vec, "Number of asymmetric coalitions: {max_n_coalitions_causal}")
  }
  if (isTRUE(asymmetric) || !is.null(confounding)) {
    line_vec <- c(line_vec, "Causal ordering: {causal_ordering_names_string}")
  }
  if (!is.null(confounding)) {
    line_vec <- c(line_vec, "Components with confounding: {confounding_string}")
  }
  if (isFALSE(testing)) {
    line_vec <- c(line_vec, "Computations (temporary) saved at: {.path {saving_path}}")
  }

  if ("basic" %in% verbose) {
    # The header contains the start time and is only printed when `testing` is FALSE
    if (isFALSE(testing)) {
      cli::cli_h1("Starting {.fn shapr::explain} at {round(init_time)}")
    }
    cli::cli_ul(line_vec)
  }

  if ("vS_details" %in% verbose) {
    if (any(c("regression_surrogate", "regression_separate") %in% approach)) {
      reg_desc <- paste0(utils::capture.output(internal$parameters$regression.model), collapse = "\n")
      cli::cli_h3("Additional details about the regression model")
      cli::cli_text(reg_desc)
    }
  }

  if ("basic" %in% verbose) {
    if (isTRUE(iterative)) {
      msg <- "iterative computation started"
    } else {
      msg <- "Main computation started"
    }
    cli::cli_h2(cli::col_blue(msg))
  }
}

#' Printing messages in compute_vS with cli
#'
#' @inheritParams default_doc_explain
#' @inheritParams explain
#'
#' @export
#' @keywords internal
cli_compute_vS <- function(internal) {
  verbose <- internal$parameters$verbose
  approach <- internal$parameters$approach

  if ("progress" %in% verbose) {
    cli::cli_progress_step("Computing vS")
  }
  if ("vS_details" %in% verbose) {
    if ("regression_separate" %in% approach) {
      tuning <- internal$parameters$regression.tune
      if (isTRUE(tuning)) {
        cli::cli_h2("Extra info about the tuning of the regression model")
      }
    }
  }
}

#' Printing messages in iterative procedure with cli
#'
#' @inheritParams default_doc_explain
#' @inheritParams explain
#'
#' @export
#' @keywords internal
cli_iter <- function(verbose, internal, iter) {
  iterative <- internal$parameters$iterative
  asymmetric <- internal$parameters$asymmetric

  if (!is.null(verbose) && isTRUE(iterative)) {
    cli::cli_h1("Iteration {iter}")
  }

  if ("basic" %in% verbose) {
    # `new_coal`, `tot_coal` and `all_coal` are referenced by name in the glue template below
    new_coal <- internal$iter_list[[iter]]$new_n_coalitions
    tot_coal <- internal$iter_list[[iter]]$n_coalitions
    all_coal <- ifelse(asymmetric, internal$parameters$max_n_coalitions, 2^internal$parameters$n_shapley_values)

    extra_msg <- ifelse(iterative, ", {new_coal} new", "")

    msg <- paste0("Using {tot_coal} of {all_coal} coalitions", extra_msg, ". ")

    cli::cli_alert_info(msg)
  }
}

# diff --git a/R/compute_estimates.R b/R/compute_estimates.R
# new file mode 100644
# index 0000000000000000000000000000000000000000..05b7fd1d82dab83ad4e20e51827745610a04916d
# --- /dev/null
# +++ b/R/compute_estimates.R
# @@ -0,0 +1,401 @@

#' Computes the Shapley values and their standard deviation given the `v(S)`
#'
#' @inheritParams default_doc_explain
#' @param vS_list List
#' Output from [compute_vS()]
#'
#' @export
#' @keywords internal
compute_estimates <- function(internal, vS_list) {
  verbose <- internal$parameters$verbose
  type <- internal$parameters$type

  internal$timing_list$compute_vS <- Sys.time()

  iter <- length(internal$iter_list)
  compute_sd <- internal$iter_list[[iter]]$compute_sd

  n_boot_samps <- internal$parameters$extra_computation_args$n_boot_samps

  processed_vS_list <- postprocess_vS_list(
    vS_list = vS_list,
    internal = internal
  )

  internal$timing_list$postprocess_vS <- Sys.time()

  if ("progress" %in% verbose) {
    cli::cli_progress_step("Computing Shapley value estimates")
  }

  # Compute the Shapley values
  dt_shapley_est <- compute_shapley_new(internal, processed_vS_list$dt_vS)

  internal$timing_list$compute_shapley <- Sys.time()

  if (compute_sd) {
    if ("progress" %in% verbose) {
      cli::cli_progress_step("Bootstrapping Shapley value sds")
    }

    dt_shapley_sd <- bootstrap_shapley(internal, dt_vS = processed_vS_list$dt_vS, n_boot_samps = n_boot_samps)

    internal$timing_list$compute_bootstrap <- Sys.time()
  } else {
    # No standard deviations requested: return a table of zeros with the same shape as the estimates
    dt_shapley_sd <- dt_shapley_est * 0
  }

  # Adding explain_id to the output dt
  if (type != "forecast") {
    dt_shapley_est[, explain_id := .I]
    setcolorder(dt_shapley_est, "explain_id")
    dt_shapley_sd[, explain_id := .I]
    setcolorder(dt_shapley_sd, "explain_id")
  }

  internal$iter_list[[iter]]$dt_shapley_est <- dt_shapley_est
  internal$iter_list[[iter]]$dt_shapley_sd <- dt_shapley_sd
  internal$iter_list[[iter]]$vS_list <- vS_list
  internal$iter_list[[iter]]$dt_vS <- processed_vS_list$dt_vS

  # Store the post-processed v(S) output (`dt_vS` and `dt_samp_for_vS`)
  internal$output <- processed_vS_list

  # NOTE(review): the progress steps above are started under "progress" but closed under "basic" - confirm intent
  if ("basic" %in% verbose) {
    cli::cli_progress_done()
  }

  return(internal)
}

#' @keywords internal
postprocess_vS_list <- function(vS_list, internal) {
  keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS
  phi0 <- internal$parameters$phi0
  n_explain <- internal$parameters$n_explain

  # Appending the zero-prediction to the list (a single row: id_coalition 1 followed by phi0 per explicand)
  dt_vS0 <- as.data.table(rbind(c(1, rep(phi0, n_explain))))

  # Extracting/merging the data tables from the batch running
  # TODO: Need a memory and speed optimized way to transform the output form dt_vS_list to two different lists,
  # I.e. without copying the data more than once. For now I have modified run_batch such that it
  # if keep_samp_for_vS=FALSE
  # then there is only one copy, but there are two if keep_samp_for_vS=TRUE. This might be OK since the
  # latter is used rarely
  if (keep_samp_for_vS) {
    # Each batch is a list of two elements: the v(S) table and the samples used to compute it
    names(dt_vS0) <- names(vS_list[[1]][[1]])

    vS_list[[length(vS_list) + 1]] <- list(dt_vS0, NULL)

    dt_vS <- rbindlist(lapply(vS_list, `[[`, 1))

    dt_samp_for_vS <- rbindlist(lapply(vS_list, `[[`, 2), use.names = TRUE)

    data.table::setorder(dt_samp_for_vS, id_coalition)
  } else {
    names(dt_vS0) <- names(vS_list[[1]])

    vS_list[[length(vS_list) + 1]] <- dt_vS0

    dt_vS <- rbindlist(vS_list)
    dt_samp_for_vS <- NULL
  }

  data.table::setorder(dt_vS, id_coalition)

  dt_vS <- unique(dt_vS, by = "id_coalition") # To remove duplicated full pred row in the iterative procedure

  output <- list(
    dt_vS = dt_vS,
    dt_samp_for_vS = dt_samp_for_vS
  )
  return(output)
}


#' Compute shapley values
#' @param dt_vS The contribution matrix.
#'
#' @inheritParams default_doc
#'
#' @return A `data.table` with Shapley values for each test observation.
#' @export
#' @keywords internal
compute_shapley_new <- function(internal, dt_vS) {
  is_groupwise <- internal$parameters$is_groupwise
  type <- internal$parameters$type

  iter <- length(internal$iter_list)

  W <- internal$iter_list[[iter]]$W

  shap_names <- internal$parameters$shap_names

  # If multiple horizons with explain_forecast are used, we only distribute value to those used at each horizon
  if (type == "forecast") {
    id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt
    horizon <- internal$parameters$horizon
    cols_per_horizon <- internal$objects$cols_per_horizon
    W_list <- internal$objects$W_list

    kshap_list <- list()
    for (i in seq_len(horizon)) {
      W0 <- W_list[[i]]

      dt_vS0 <- merge(dt_vS, id_coalition_mapper_dt[horizon == i], by = "id_coalition", all.y = TRUE)
      data.table::setorder(dt_vS0, horizon_id_coalition)
      # Columns holding the v(S) values for horizon i
      these_vS0_cols <- grep(paste0("p_hat", i, "_"), names(dt_vS0))

      kshap0 <- t(W0 %*% as.matrix(dt_vS0[, these_vS0_cols, with = FALSE]))
      kshap_list[[i]] <- data.table::as.data.table(kshap0)

      if (!is_groupwise) {
        names(kshap_list[[i]]) <- c("none", cols_per_horizon[[i]])
      } else {
        names(kshap_list[[i]]) <- c("none", shap_names)
      }
    }

    dt_kshap <- cbind(internal$parameters$output_labels, rbindlist(kshap_list, fill = TRUE))
  } else {
    kshap <- t(W %*% as.matrix(dt_vS[, -"id_coalition"]))
    dt_kshap <- data.table::as.data.table(kshap)
    colnames(dt_kshap) <- c("none", shap_names)
  }

  return(dt_kshap)
}

bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) {
  iter <- length(internal$iter_list)

  X <- internal$iter_list[[iter]]$X

  set.seed(seed)

  X_org <- copy(X)
  n_explain <- internal$parameters$n_explain
  n_features <- internal$parameters$n_features
  shap_names <- internal$parameters$shap_names
  paired_shap_sampling <- internal$parameters$paired_shap_sampling
  shapley_reweight <- internal$parameters$kernelSHAP_reweighting

  boot_sd_array <- array(NA, dim = c(n_explain, n_features + 1, n_boot_samps))

  X_keep <- X_org[c(1, .N), .(id_coalition, features, n_features, N, shapley_weight)]
  X_samp <- X_org[-c(1, .N), .(id_coalition, features, n_features, N, shapley_weight, sample_freq)]
  X_samp[, features_tmp := sapply(features, paste, collapse = " ")]

  n_coalitions_boot <- X_samp[, sum(sample_freq)]

  for (i in seq_len(n_boot_samps)) {
    if (paired_shap_sampling) {
      # Sample with replacement
      X_boot00 <- X_samp[
        sample.int(
          n = .N,
          size = ceiling(n_coalitions_boot / 2),
          replace = TRUE,
          prob = sample_freq
        ),
        .(id_coalition, features, n_features, N)
      ]

      X_boot00[, features_tmp := sapply(features, paste, collapse = " ")]
      # Not sure why I have to two the next two lines in two steps, but I don't get it to work otherwise
      boot_features_dup <- lapply(X_boot00$features, function(x) seq(n_features)[-x])
      X_boot00[, features_dup := boot_features_dup]
      X_boot00[, features_dup_tmp :=
these_vS0_cols, with = FALSE])) + kshap_list[[i]] <- data.table::as.data.table(kshap0) + + if (!is_groupwise) { + names(kshap_list[[i]]) <- c("none", cols_per_horizon[[i]]) + } else { + names(kshap_list[[i]]) <- c("none", shap_names) + } + } + + dt_kshap <- cbind(internal$parameters$output_labels, rbindlist(kshap_list, fill = TRUE)) + } else { + kshap <- t(W %*% as.matrix(dt_vS[, -"id_coalition"])) + dt_kshap <- data.table::as.data.table(kshap) + colnames(dt_kshap) <- c("none", shap_names) + } + + return(dt_kshap) +} + +bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) { + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + + set.seed(seed) + + X_org <- copy(X) + n_explain <- internal$parameters$n_explain + n_features <- internal$parameters$n_features + shap_names <- internal$parameters$shap_names + paired_shap_sampling <- internal$parameters$paired_shap_sampling + shapley_reweight <- internal$parameters$kernelSHAP_reweighting + + boot_sd_array <- array(NA, dim = c(n_explain, n_features + 1, n_boot_samps)) + + X_keep <- X_org[c(1, .N), .(id_coalition, features, n_features, N, shapley_weight)] + X_samp <- X_org[-c(1, .N), .(id_coalition, features, n_features, N, shapley_weight, sample_freq)] + X_samp[, features_tmp := sapply(features, paste, collapse = " ")] + + n_coalitions_boot <- X_samp[, sum(sample_freq)] + + for (i in seq_len(n_boot_samps)) { + if (paired_shap_sampling) { + # Sample with replacement + X_boot00 <- X_samp[ + sample.int( + n = .N, + size = ceiling(n_coalitions_boot / 2), + replace = TRUE, + prob = sample_freq + ), + .(id_coalition, features, n_features, N) + ] + + X_boot00[, features_tmp := sapply(features, paste, collapse = " ")] + # Not sure why I have to two the next two lines in two steps, but I don't get it to work otherwise + boot_features_dup <- lapply(X_boot00$features, function(x) seq(n_features)[-x]) + X_boot00[, features_dup := boot_features_dup] + X_boot00[, features_dup_tmp := 
sapply(features_dup, paste, collapse = " ")] + + # Extract the paired coalitions from X_samp + X_boot00_paired <- merge(X_boot00[, .(features_dup_tmp)], + X_samp[, .(id_coalition, features, n_features, N, features_tmp)], + by.x = "features_dup_tmp", by.y = "features_tmp" + ) + X_boot0 <- rbind( + X_boot00[, .(id_coalition, features, n_features, N)], + X_boot00_paired[, .(id_coalition, features, n_features, N)] + ) + } else { + X_boot0 <- X_samp[ + sample.int( + n = .N, + size = n_coalitions_boot, + replace = TRUE, + prob = sample_freq + ), + .(id_coalition, features, n_features, N) + ] + } + + + X_boot0[, shapley_weight := .N / n_coalitions_boot, by = "id_coalition"] + X_boot0 <- unique(X_boot0, by = "id_coalition") + + X_boot <- rbind(X_keep, X_boot0) + data.table::setorder(X_boot, id_coalition) + + kernelSHAP_reweighting(X_boot, reweight = shapley_reweight) # reweights the shapley weights by reference + + W_boot <- shapr::weight_matrix( + X = X_boot, + normalize_W_weights = TRUE, + is_groupwise = FALSE + ) + + kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_coalition %in% X_boot[, id_coalition], -"id_coalition"])) + + boot_sd_array[, , i] <- copy(kshap_boot) + } + + std_dev_mat <- apply(boot_sd_array, c(1, 2), sd) + + dt_kshap_boot_sd <- data.table::as.data.table(std_dev_mat) + colnames(dt_kshap_boot_sd) <- c("none", shap_names) + + return(dt_kshap_boot_sd) +} + +bootstrap_shapley <- function(internal, dt_vS, n_boot_samps = 100, seed = 123) { + iter <- length(internal$iter_list) + type <- internal$parameters$type + is_groupwise <- internal$parameters$is_groupwise + X_list <- internal$iter_list[[iter]]$X_list + + result <- list() + if (type == "forecast") { + n_explain <- internal$parameters$n_explain + for (i in seq_along(X_list)) { + X <- X_list[[i]] + if (is_groupwise) { + n_shapley_values <- length(internal$data$shap_names) + shap_names <- internal$data$shap_names + } else { + n_shapley_values <- length(internal$parameters$horizon_features[[i]]) + shap_names <- 
internal$parameters$horizon_features[[i]] + } + dt_cols <- c(1, seq_len(n_explain) + (i - 1) * n_explain + 1) + dt_vS_this <- dt_vS[, dt_cols, with = FALSE] + result[[i]] <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS_this, n_boot_samps, seed) + } + result <- rbindlist(result, fill = TRUE) + } else { + X <- internal$iter_list[[iter]]$X + n_shapley_values <- internal$parameters$n_shapley_values + shap_names <- internal$parameters$shap_names + result <- bootstrap_shapley_inner(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps, seed) + } + return(result) +} + +bootstrap_shapley_inner <- function(X, n_shapley_values, shap_names, internal, dt_vS, n_boot_samps = 100, seed = 123) { + type <- internal$parameters$type + iter <- length(internal$iter_list) + + set.seed(seed) + + n_explain <- internal$parameters$n_explain + paired_shap_sampling <- internal$parameters$paired_shap_sampling + shapley_reweight <- internal$parameters$kernelSHAP_reweighting + + X_org <- copy(X) + + boot_sd_array <- array(NA, dim = c(n_explain, n_shapley_values + 1, n_boot_samps)) + + X_keep <- X_org[c(1, .N), .(id_coalition, coalitions, coalition_size, N)] + X_samp <- X_org[-c(1, .N), .(id_coalition, coalitions, coalition_size, N, shapley_weight, sample_freq)] + X_samp[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")] + + n_coalitions_boot <- X_samp[, sum(sample_freq)] + + if (paired_shap_sampling) { + # Sample with replacement + X_boot00 <- X_samp[ + sample.int( + n = .N, + size = ceiling(n_coalitions_boot * n_boot_samps / 2), + replace = TRUE, + prob = sample_freq + ), + .(id_coalition, coalitions, coalition_size, N, sample_freq) + ] + + X_boot00[, boot_id := rep(seq(n_boot_samps), times = n_coalitions_boot / 2)] + + X_boot00_paired <- copy(X_boot00[, .(coalitions, boot_id)]) + X_boot00_paired[, coalitions := lapply(coalitions, function(x) seq(n_shapley_values)[-x])] + X_boot00_paired[, coalitions_tmp := sapply(coalitions, paste, collapse = " 
")] + + # Extract the paired coalitions from X_samp + X_boot00_paired <- merge(X_boot00_paired, + X_samp[, .(id_coalition, coalition_size, N, shapley_weight, coalitions_tmp)], + by = "coalitions_tmp" + ) + X_boot0 <- rbind( + X_boot00[, .(boot_id, id_coalition, coalitions, coalition_size, N)], + X_boot00_paired[, .(boot_id, id_coalition, coalitions, coalition_size, N)] + ) + + X_boot <- rbind(X_keep[rep(1:2, each = n_boot_samps), ][, boot_id := rep(seq(n_boot_samps), times = 2)], X_boot0) + setkey(X_boot, boot_id, id_coalition) + X_boot[, sample_freq := .N / n_coalitions_boot, by = .(id_coalition, boot_id)] + X_boot <- unique(X_boot, by = c("id_coalition", "boot_id")) + X_boot[, shapley_weight := sample_freq] + X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]] + } else { + X_boot0 <- X_samp[ + sample.int( + n = .N, + size = n_coalitions_boot * n_boot_samps, + replace = TRUE, + prob = sample_freq + ), + .(id_coalition, coalitions, coalition_size, N) + ] + X_boot <- rbind(X_keep[rep(1:2, each = n_boot_samps), ], X_boot0) + X_boot[, boot_id := rep(seq(n_boot_samps), times = n_coalitions_boot + 2)] + + setkey(X_boot, boot_id, id_coalition) + X_boot[, sample_freq := .N / n_coalitions_boot, by = .(id_coalition, boot_id)] + X_boot <- unique(X_boot, by = c("id_coalition", "boot_id")) + X_boot[, shapley_weight := sample_freq] + if (type == "forecast") { + id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt + full_ids <- id_coalition_mapper_dt$id_coalition[id_coalition_mapper_dt$full] + X_boot[coalition_size == 0 | id_coalition %in% full_ids, shapley_weight := X_org[1, shapley_weight]] + } else { + X_boot[coalition_size %in% c(0, n_shapley_values), shapley_weight := X_org[1, shapley_weight]] + } + } + + for (i in seq_len(n_boot_samps)) { + this_X <- X_boot[boot_id == i] # This is highly inefficient, but the best way to deal with the reweighting for now + kernelSHAP_reweighting(this_X, reweight = 
shapley_reweight) + + W_boot <- weight_matrix( + X = this_X, + normalize_W_weights = TRUE + ) + + kshap_boot <- t(W_boot %*% as.matrix(dt_vS[id_coalition %in% X_boot[ + boot_id == i, + id_coalition + ], -"id_coalition"])) + + boot_sd_array[, , i] <- copy(kshap_boot) + } + + std_dev_mat <- apply(boot_sd_array, c(1, 2), sd) + + dt_kshap_boot_sd <- data.table::as.data.table(std_dev_mat) + colnames(dt_kshap_boot_sd) <- c("none", shap_names) + + return(dt_kshap_boot_sd) +} diff --git a/R/compute_vS.R b/R/compute_vS.R index 1c6deb190890708c387b54cb6428c8a92d0cc31c..321a391d3f3226a0fce66bcc26780439da1e6f41 100644 --- a/R/compute_vS.R +++ b/R/compute_vS.R @@ -1,23 +1,35 @@ #' Computes `v(S)` for all features subsets `S`. #' +#' @inheritParams default_doc_explain #' @inheritParams default_doc -#' @inheritParams explain #' #' @param method Character #' Indicates whether the lappy method (default) or loop method should be used. +#' This is only used for testing purposes. #' #' @export +#' @keywords internal compute_vS <- function(internal, model, predict_model, method = "future") { - S_batch <- internal$objects$S_batch + iter <- length(internal$iter_list) + + S_batch <- internal$iter_list[[iter]]$S_batch + + # verbose + cli_compute_vS(internal) if (method == "future") { - ret <- future_compute_vS_batch(S_batch = S_batch, internal = internal, model = model, predict_model = predict_model) + vS_list <- future_compute_vS_batch( + S_batch = S_batch, + internal = internal, + model = model, + predict_model = predict_model + ) } else { # Doing the same as above without future without progressbar or paralellization - ret <- list() + vS_list <- list() for (i in seq_along(S_batch)) { S <- S_batch[[i]] - ret[[i]] <- batch_compute_vS( + vS_list[[i]] <- batch_compute_vS( S = S, internal = internal, model = model, @@ -26,7 +38,11 @@ compute_vS <- function(internal, model, predict_model, method = "future") { } } - return(ret) + #### Adds v_S output above to any vS_list already computed #### 
+ vS_list <- append_vS_list(vS_list, internal) + + + return(vS_list) } future_compute_vS_batch <- function(S_batch, internal, model, predict_model) { @@ -56,7 +72,8 @@ batch_compute_vS <- function(S, internal, model, predict_model, p = NULL) { if (regression) { dt_vS <- batch_prepare_vS_regression(S = S, internal = internal) } else { - # Here dt_vS is either only dt_vS or a list containing dt_vS and dt if internal$parameters$keep_samp_for_vS = TRUE + # Here dt_vS is either only dt_vS or a list containing dt_vS and dt if + # internal$parameters$output_args$keep_samp_for_vS = TRUE dt_vS <- batch_prepare_vS_MC(S = S, internal = internal, model = model, predict_model = predict_model) } @@ -70,25 +87,29 @@ batch_compute_vS <- function(S, internal, model, predict_model, p = NULL) { #' @keywords internal #' @author Lars Henry Berge Olsen batch_prepare_vS_regression <- function(S, internal) { - max_id_comb <- internal$parameters$n_combinations + iter <- length(internal$iter_list) + + X <- internal$iter_list[[iter]]$X + + max_id_coal <- X[, .N] x_explain_y_hat <- internal$data$x_explain_y_hat # Compute the contribution functions different based on if the grand coalition is in S or not - if (!(max_id_comb %in% S)) { + if (!(max_id_coal %in% S)) { dt <- prepare_data(internal, index_features = S) } else { # Remove the grand coalition. NULL is for the special case for when the batch only includes the grand coalition. 
- dt <- if (length(S) > 1) prepare_data(internal, index_features = S[S != max_id_comb]) else NULL + dt <- if (length(S) > 1) prepare_data(internal, index_features = S[S != max_id_coal]) else NULL # Add the results for the grand coalition (Need to add names in case the batch only contains the grand coalition) - dt <- rbind(dt, data.table(as.integer(max_id_comb), matrix(x_explain_y_hat, nrow = 1)), use.names = FALSE) + dt <- rbind(dt, data.table(as.integer(max_id_coal), matrix(x_explain_y_hat, nrow = 1)), use.names = FALSE) # Need to add column names if batch S only contains the grand coalition - if (length(S) == 1) setnames(dt, c("id_combination", paste0("p_hat1_", seq_len(internal$parameters$n_explain)))) + if (length(S) == 1) setnames(dt, c("id_coalition", paste0("p_hat1_", seq_len(internal$parameters$n_explain)))) } - # Set id_combination to be the key - setkey(dt, id_combination) + # Set id_coalition to be the key + setkey(dt, id_coalition) return(dt) } @@ -105,9 +126,11 @@ batch_prepare_vS_MC <- function(S, internal, model, predict_model) { explain_lags <- internal$parameters$explain_lags y <- internal$data$y xreg <- internal$data$xreg - keep_samp_for_vS <- internal$parameters$keep_samp_for_vS + keep_samp_for_vS <- internal$parameters$output_args$keep_samp_for_vS + causal_sampling <- internal$parameters$causal_sampling - dt <- batch_prepare_vS_MC_auxiliary(S = S, internal = internal) # Make it optional to store and return the dt_list + # Make it optional to store and return the dt_list + dt <- batch_prepare_vS_MC_auxiliary(S = S, internal = internal, causal_sampling = causal_sampling) pred_cols <- paste0("p_hat", seq_len(output_size)) @@ -132,27 +155,22 @@ batch_prepare_vS_MC <- function(S, internal, model, predict_model) { } #' @keywords internal -batch_prepare_vS_MC_auxiliary <- function(S, internal) { - max_id_combination <- internal$parameters$n_combinations +#' @author Lars Henry Berge Olsen and Martin Jullum +batch_prepare_vS_MC_auxiliary <- function(S, 
internal, causal_sampling) { x_explain <- internal$data$x_explain n_explain <- internal$parameters$n_explain + prepare_data_function <- if (causal_sampling) prepare_data_causal else prepare_data - # TODO: Check what is the fastest approach to deal with the last observation. - # Not doing this for the largest id combination (should check if this is faster or slower, actually) - # An alternative would be to delete rows from the dt which is provided by prepare_data. - if (!(max_id_combination %in% S)) { - # TODO: Need to handle the need for model for the AIC-versions here (skip for Python) - dt <- prepare_data(internal, index_features = S) + iter <- length(internal$iter_list) + X <- internal$iter_list[[iter]]$X + max_id_coalition <- X[, .N] + + if (max_id_coalition %in% S) { + dt <- if (length(S) == 1) NULL else prepare_data_function(internal, index_features = S[S != max_id_coalition]) + dt <- rbind(dt, data.table(id_coalition = max_id_coalition, x_explain, w = 1, id = seq_len(n_explain))) + setkey(dt, id, id_coalition) } else { - if (length(S) > 1) { - S <- S[S != max_id_combination] - dt <- prepare_data(internal, index_features = S) - } else { - dt <- NULL # Special case for when the batch only include the largest id - } - dt_max <- data.table(id_combination = max_id_combination, x_explain, w = 1, id = seq_len(n_explain)) - dt <- rbind(dt, dt_max) - setkey(dt, id, id_combination) + dt <- prepare_data_function(internal, index_features = S) } return(dt) } @@ -176,8 +194,8 @@ compute_preds <- function( if (type == "forecast") { dt[, (pred_cols) := predict_model( x = model, - newdata = .SD[, 1:n_endo], - newreg = .SD[, -(1:n_endo)], + newdata = .SD[, .SD, .SDcols = seq_len(n_endo)], + newreg = .SD[, .SD, .SDcols = seq_len(length(feature_names) - n_endo) + n_endo], horizon = horizon, explain_idx = explain_idx[id], explain_lags = explain_lags, @@ -193,13 +211,55 @@ compute_preds <- function( compute_MCint <- function(dt, pred_cols = "p_hat") { # Calculate contributions - 
dt_res <- dt[, lapply(.SD, function(x) sum(((x) * w) / sum(w))), .(id, id_combination), .SDcols = pred_cols] - data.table::setkeyv(dt_res, c("id", "id_combination")) - dt_mat <- data.table::dcast(dt_res, id_combination ~ id, value.var = pred_cols) + dt_res <- dt[, lapply(.SD, function(x) sum(((x) * w) / sum(w))), .(id, id_coalition), .SDcols = pred_cols] + data.table::setkeyv(dt_res, c("id", "id_coalition")) + dt_mat <- data.table::dcast(dt_res, id_coalition ~ id, value.var = pred_cols) if (length(pred_cols) == 1) { names(dt_mat)[-1] <- paste0(pred_cols, "_", names(dt_mat)[-1]) } - # dt_mat[, id_combination := NULL] + # dt_mat[, id_coalition := NULL] return(dt_mat) } + +#' Appends the new vS_list to the prev vS_list +#' +#' +#' @inheritParams compute_estimates +#' +#' @export +#' @keywords internal +append_vS_list <- function(vS_list, internal) { + iter <- length(internal$iter_list) + + # Adds v_S output above to any vS_list already computed + if (iter > 1) { + prev_coalition_map <- internal$iter_list[[iter - 1]]$coalition_map + prev_vS_list <- internal$iter_list[[iter - 1]]$vS_list + + # Need to map the old id_coalitions to the new numbers for this merging to work out + current_coalition_map <- internal$iter_list[[iter]]$coalition_map + + # Creates a mapper from the last id_coalition to the new id_coalition numbering + id_coalitions_mapper <- merge(prev_coalition_map, + current_coalition_map, + by = "coalitions_str", + suffixes = c("", "_new") + ) + prev_vS_list_new <- list() + + # Applies the mapper to update the prev_vS_list ot the new id_coalition numbering + for (k in seq_along(prev_vS_list)) { + prev_vS_list_new[[k]] <- merge(prev_vS_list[[k]], + id_coalitions_mapper[, .(id_coalition, id_coalition_new)], + by = "id_coalition" + ) + prev_vS_list_new[[k]][, id_coalition := id_coalition_new] + prev_vS_list_new[[k]][, id_coalition_new := NULL] + } + + # Merge the new vS_list with the old vS_list + vS_list <- c(prev_vS_list_new, vS_list) + } + return(vS_list) +} 
diff --git a/R/documentation.R b/R/documentation.R index 5fb5d0af6f97719024444402040acb52e531147d..eb7bab6133540cef894ace92919ffdb0644f9cb3 100644 --- a/R/documentation.R +++ b/R/documentation.R @@ -2,7 +2,8 @@ #' #' @param internal List. #' Holds all parameters, data, functions and computed objects used within [explain()] -#' The list contains one or more of the elements `parameters`, `data`, `objects`, `output`. +#' The list contains one or more of the elements `parameters`, `data`, `objects`, `iter_list`, `timing_list`, +#' `main_timing_list`, `output`, and `iter_timing_list`. #' #' @param model Objects. #' The model object that ought to be explained. @@ -30,13 +31,17 @@ default_doc <- function(internal, model, predict_model, output_size, extra, ...) #' Exported documentation helper function. #' -#' @param internal Not used. +#' @param iter Integer. +#' The iteration number. Only used internally. +#' +#' @param internal List. +#' Not used directly, but passed through from [explain()]. #' -#' @param index_features Positive integer vector. Specifies the indices of combinations to -#' apply to the present method. `NULL` means all combinations. Only used internally. +#' @param index_features Positive integer vector. Specifies the id_coalition to +#' apply to the present method. `NULL` means all coalitions. Only used internally. #' #' @keywords internal -default_doc_explain <- function(internal, index_features) { +default_doc_explain <- function(internal, iter, index_features) { NULL } @@ -46,7 +51,7 @@ default_doc_explain <- function(internal, index_features) { #' @description #' This helper function displays the specific arguments applicable to the different #' approaches. Note that when calling [shapr::explain()] from Python, the parameters -#' are renamed from the form `approach.parameter_name` to `approach_parameter_name`. +#' are renamed from the `approach.parameter_name` to `approach_parameter_name`. 
#' That is, an underscore has replaced the dot as the dot is reserved in Python. #' #' @inheritDotParams setup_approach.independence -internal diff --git a/R/explain.R b/R/explain.R index 3e1e10c9733436b588365fd0c7928c836c395b0d..caaaf0743305bf1237ddcb177906077315d9dd39 100644 --- a/R/explain.R +++ b/R/explain.R @@ -21,17 +21,19 @@ #' `"categorical"`, `"timeseries"`, `"independence"`, `"regression_separate"`, or `"regression_surrogate"`. #' The two regression approaches can not be combined with any other approach. See details for more information. #' -#' @param prediction_zero Numeric. +#' @param phi0 Numeric. #' The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any #' features. #' Typically we set this value equal to the mean of the response variable in our training data, but other choices #' such as the mean of the predictions in the training data are also reasonable. #' -#' @param n_combinations Integer. -#' If `group = NULL`, `n_combinations` represents the number of unique feature combinations to sample. -#' If `group != NULL`, `n_combinations` represents the number of unique group combinations to sample. -#' If `n_combinations = NULL`, the exact method is used and all combinations are considered. -#' The maximum number of combinations equals `2^m`, where `m` is the number of features. +#' @param max_n_coalitions Integer. +#' The upper limit on the number of unique feature/group coalitions to use in the iterative procedure +#' (if `iterative = TRUE`). +#' If `iterative = FALSE` it represents the number of feature/group coalitions to use directly. +#' The quantity refers to the number of unique feature coalitions if `group = NULL`, +#' and group coalitions if `group != NULL`. +#' `max_n_coalitions = NULL` corresponds to `max_n_coalitions=2^n_features`. #' #' @param group List. #' If `NULL` regular feature wise Shapley values are computed. @@ -39,39 +41,30 @@ #' the number of groups. 
The list element contains character vectors with the features included #' in each of the different groups. #' -#' @param n_samples Positive integer. -#' Indicating the maximum number of samples to use in the -#' Monte Carlo integration for every conditional expectation. See also details. -#' -#' @param n_batches Positive integer (or NULL). -#' Specifies how many batches the total number of feature combinations should be split into when calculating the -#' contribution function for each test observation. -#' The default value is NULL which uses a reasonable trade-off between RAM allocation and computation speed, -#' which depends on `approach` and `n_combinations`. -#' For models with many features, increasing the number of batches reduces the RAM allocation significantly. -#' This typically comes with a small increase in computation time. +#' @param n_MC_samples Positive integer. +#' Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +#' For `approach="ctree"`, `n_MC_samples` corresponds to the number of samples +#' from the leaf node (see an exception related to the `ctree.sample` argument [shapr::setup_approach.ctree()]). +#' For `approach="empirical"`, `n_MC_samples` is the \eqn{K} parameter in equations (14-15) of +#' Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +#' `empirical.eta` argument [shapr::setup_approach.empirical()]. #' #' @param seed Positive integer. #' Specifies the seed before any randomness based code is being run. -#' If `NULL` the seed will be inherited from the calling environment. -#' -#' @param keep_samp_for_vS Logical. -#' Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned -#' (in `internal$output`) +#' If `NULL` no seed is set in the calling environment. #' #' @param predict_model Function. #' The prediction function used when `model` is not natively supported. 
-#' (Run [get_supported_models()] for a list of natively supported -#' models.) +#' (Run [get_supported_models()] for a list of natively supported models.) #' The function must have two arguments, `model` and `newdata` which specify, respectively, the model -#' and a data.frame/data.table to compute predictions for. The function must give the prediction as a numeric vector. +#' and a data.frame/data.table to compute predictions for. +#' The function must give the prediction as a numeric vector. #' `NULL` (the default) uses functions specified internally. #' Can also be used to override the default function for natively supported model classes. #' #' @param get_model_specs Function. #' An optional function for checking model/data consistency when `model` is not natively supported. -#' (Run [get_supported_models()] for a list of natively supported -#' models.) +#' (Run [get_supported_models()] for a list of natively supported models.) #' The function takes `model` as argument and provides a list with 3 elements: #' \describe{ #' \item{labels}{Character vector with the names of each feature.} @@ -82,18 +75,102 @@ #' disabled for unsupported model classes. #' Can also be used to override the default function for natively supported model classes. #' -#' @param MSEv_uniform_comb_weights Logical. If `TRUE` (default), then the function weights the combinations -#' uniformly when computing the MSEv criterion. If `FALSE`, then the function use the Shapley kernel weights to -#' weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the -#' sampling frequency when not all combinations are considered. -#' -#' @param timing Logical. -#' Whether the timing of the different parts of the `explain()` should saved in the model object. #' -#' @param verbose An integer specifying the level of verbosity. If `0`, `shapr` will stay silent. -#' If `1`, it will print information about performance. 
If `2`, some additional information will be printed out. -#' Use `0` (default) for no verbosity, `1` for low verbose, and `2` for high verbose. -#' TODO: Make this clearer when we end up fixing this and if they should force a progressr bar. +#' @param verbose String vector or NULL. +#' Specifies the verbosity (printout detail level) through one or more of the strings `"basic"`, `"progress"`, +#' `"convergence"`, `"shapley"` and `"vS_details"`. +#' `"basic"` (default) displays basic information about the computation which is being performed. +#' `"progress"` displays information about where in the calculation process the function currently is. +#' `"convergence"` displays information on how close to convergence the Shapley value estimates are +#' (only when `iterative = TRUE`). +#' `"shapley"` displays intermediate Shapley value estimates and standard deviations (only when `iterative = TRUE`) +#' + the final estimates. +#' `"vS_details"` displays information about the v_S estimates. +#' This is most relevant for `approach %in% c("regression_separate", "regression_surrogate", "vaeac")`. +#' `NULL` means no printout. +#' Note that any combination of the five strings can be used. +#' E.g. `verbose = c("basic", "vS_details")` will display basic information + details about the vS estimation process. +#' +#' @param paired_shap_sampling Logical. +#' If `TRUE` (default), paired versions of all sampled coalitions are also included in the computation. +#' That is, if there are 5 features and e.g. coalition (1,3,5) is sampled, then coalition (2,4) is also used for +#' computing the Shapley values. This is done to reduce the variance of the Shapley value estimates. +#' +#' @param iterative Logical or NULL +#' If `NULL` (default), the argument is set to `TRUE` if there are more than 5 features/groups, and `FALSE` otherwise. +#' If eventually `TRUE`, the Shapley values are estimated in an iterative manner. 
+#' This provides sufficiently accurate Shapley value estimates faster. +#' First an initial number of coalitions is sampled, then bootstrapping is used to estimate the variance of the Shapley +#' values. +#' A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small. +#' If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more +#' coalitions. +#' The process is repeated until the variances are below the threshold. +#' Specifics related to the iterative process and convergence criterion are set through `iterative_args`. +#' +#' @param iterative_args Named list. +#' Specifies the arguments for the iterative procedure. +#' See [shapr::get_iterative_args_default()] for description of the arguments and their default values. +#' @param output_args Named list. +#' Specifies certain arguments related to the output of the function. +#' See [shapr::get_output_args_default()] for description of the arguments and their default values. +#' @param extra_computation_args Named list. +#' Specifies extra arguments related to the computation of the Shapley values. +#' See [shapr::get_extra_est_args_default()] for description of the arguments and their default values. +#' @param kernelSHAP_reweighting String. +#' How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing +#' the randomness and thereby the variance of the Shapley value estimates. +#' One of `'none'`, `'on_coal_size'`, `'on_N'`, `'on_all'`, `'on_all_cond'` (default). +#' `'none'` means no reweighting, i.e. the sampling frequency weights are used as is. +#' `'on_coal_size'` means the sampling frequencies are averaged over all coalitions of the same size. +#' `'on_N'` means the sampling frequencies are averaged over all coalitions with the same original sampling +#' probabilities. +#' `'on_all'` means the original sampling probabilities are used for all coalitions. 
+#' `'on_all_cond'` means the original sampling probabilities are used for all coalitions, while adjusting for the +#' probability that they are sampled at least once. +#' This method is preferred as it has performed the best in simulation studies. +#' +#' @param prev_shapr_object `shapr` object or string. +#' If an object of class `shapr` is provided, or a string with a path to where intermediate results are stored, +#' then the function will use the previous object to continue the computation. +#' This is useful if the computation is interrupted or you want higher accuracy than already obtained, and therefore +#' want to continue the iterative estimation. See the vignette for examples. +#' +#' @param asymmetric Logical. +#' Not applicable for (regular) non-causal or asymmetric explanations. +#' If `FALSE` (default), `explain` computes regular symmetric Shapley values. +#' If `TRUE`, then `explain` computes asymmetric Shapley values based on the (partial) causal ordering +#' given by `causal_ordering`. That is, `explain` only uses the feature combinations/coalitions that +#' respect the causal ordering when computing the asymmetric Shapley values. If `asymmetric` is `TRUE` and +#' `confounding` is `NULL` (default), then `explain` computes asymmetric conditional Shapley values as specified in +#' Frye et al. (2020). If `confounding` is provided, i.e., not `NULL`, then `explain` computes asymmetric causal +#' Shapley values as specified in Heskes et al. (2020). +#' +#' @param causal_ordering List. +#' Not applicable for (regular) non-causal or asymmetric explanations. +#' `causal_ordering` is an unnamed list of vectors specifying the components of the +#' partial causal ordering that the coalitions must respect. Each vector represents +#' a component and contains one or more features/groups identified by their names +#' (strings) or indices (integers). If `causal_ordering` is `NULL` (default), no causal +#' ordering is assumed and all possible coalitions are allowed. 
No causal ordering is +#' equivalent to a causal ordering with a single component that includes all features +#' (`list(1:n_features)`) or groups (`list(1:n_groups)`) for feature-wise and group-wise +#' Shapley values, respectively. For feature-wise Shapley values and +#' `causal_ordering = list(c(1, 2), c(3, 4))`, the interpretation is that features 1 and 2 +#' are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +#' Note: All features/groups must be included in the `causal_ordering` without any duplicates. +#' +#' @param confounding Logical vector. +#' Not applicable for (regular) non-causal or asymmetric explanations. +#' `confounding` is a vector of logicals specifying whether confounding is assumed or not for each component in the +#' `causal_ordering`. If `NULL` (default), then no assumption about the confounding structure is made and `explain` +#' computes asymmetric/symmetric conditional Shapley values, depending on the value of `asymmetric`. +#' If `confounding` is a single logical, i.e., `FALSE` or `TRUE`, then this assumption is set globally +#' for all components in the causal ordering. Otherwise, `confounding` must be a vector of logicals of the same +#' length as `causal_ordering`, indicating the confounding assumption for each component. When `confounding` is +#' specified, then `explain` computes asymmetric/symmetric causal Shapley values, depending on the value of +#' `asymmetric`. The `approach` cannot be `regression_separate` and `regression_surrogate` as the +#' regression-based approaches are not applicable to the causal Shapley value methodology. #' #' @param ... 
Further arguments passed to specific approaches #' @@ -108,57 +185,50 @@ #' @inheritDotParams setup_approach.regression_surrogate #' @inheritDotParams setup_approach.timeseries #' -#' @details The most important thing to notice is that `shapr` has implemented eight different -#' Monte Carlo-based approaches for estimating the conditional distributions of the data, namely `"empirical"`, -#' `"gaussian"`, `"copula"`, `"ctree"`, `"vaeac"`, `"categorical"`, `"timeseries"`, and `"independence"`. -#' `shapr` has also implemented two regression-based approaches `"regression_separate"` and `"regression_surrogate"`, -#' and see the separate vignette on the regression-based approaches for more information. -#' In addition, the user also has the option of combining the different Monte Carlo-based approaches. -#' E.g., if you're in a situation where you have trained a model that consists of 10 features, -#' and you'd like to use the `"gaussian"` approach when you condition on a single feature, -#' the `"empirical"` approach if you condition on 2-5 features, and `"copula"` version -#' if you condition on more than 5 features this can be done by simply passing -#' `approach = c("gaussian", rep("empirical", 4), rep("copula", 4))`. If -#' `"approach[i]" = "gaussian"` means that you'd like to use the `"gaussian"` approach -#' when conditioning on `i` features. Conditioning on all features needs no approach as that is given -#' by the complete prediction itself, and should thus not be part of the vector. -#' -#' For `approach="ctree"`, `n_samples` corresponds to the number of samples -#' from the leaf node (see an exception related to the `sample` argument). -#' For `approach="empirical"`, `n_samples` is the \eqn{K} parameter in equations (14-15) of -#' Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the -#' `empirical.eta` argument. 
-#' +#' @details The `shapr` package implements kernelSHAP estimation of dependence-aware Shapley values with +#' eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely +#' `"empirical"`, `"gaussian"`, `"copula"`, `"ctree"`, `"vaeac"`, `"categorical"`, `"timeseries"`, and `"independence"`. +#' `shapr` has also implemented two regression-based approaches `"regression_separate"` and `"regression_surrogate"`. +#' It is also possible to combine the different approaches, see the vignettes for more information. +#' +#' The package also supports the computation of causal and asymmetric Shapley values as introduced by +#' Heskes et al. (2020) and Frye et al. (2020). Asymmetric Shapley values were proposed by Frye et al. (2020) +#' as a way to incorporate causal knowledge in the real world by restricting the possible feature +#' combinations/coalitions when computing the Shapley values to those consistent with a (partial) causal ordering. +#' Causal Shapley values were proposed by Heskes et al. (2020) as a way to explain the total effect of features +#' on the prediction, taking into account their causal relationships, by adapting the sampling procedure in `shapr`. +#' +#' The package allows for parallelized computation with progress updates through the tightly connected +#' [future::future] and [progressr::progressr] packages. See the examples below. +#' For iterative estimation (`iterative=TRUE`), intermediate results may also be printed to the console +#' (according to the `verbose` argument). +#' Moreover, the intermediate results are written to disk. +#' This combined with iterative estimation with (optional) intermediate results printed to the console (and temporarily +#' written to disk), and batch computing of the v(S) values, enables fast and accurate estimation of the Shapley values +#' in a memory friendly manner. #' #' @return Object of class `c("shapr", "list")`.
Contains the following items: #' \describe{ -#' \item{shapley_values}{data.table with the estimated Shapley values} -#' \item{internal}{List with the different parameters, data and functions used internally} +#' \item{shapley_values_est}{data.table with the estimated Shapley values with explained observations in the rows and +#' features along the columns. +#' The column `none` is the prediction not devoted to any of the features (given by the argument `phi0`)} +#' \item{shapley_values_sd}{data.table with the standard deviation of the Shapley values reflecting the uncertainty. +#' Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by +#' definition 0 when all coalitions are used. +#' Only present when `extra_computation_args$compute_sd=TRUE`.} +#' \item{internal}{List with the different parameters, data, functions and other output used internally.} #' \item{pred_explain}{Numeric vector with the predictions for the explained observations} -#' \item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.} +#' \item{MSEv}{List with the values of the MSEv evaluation criterion for the approach. See the +#' \href{https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#msev-evaluation-criterion +#' }{MSEv evaluation section in the vignette for details}.} +#' \item{timing}{List containing timing information for the different parts of the computation. +#' `init_time` and `end_time` give the time stamps for the start and end of the computation. +#' `total_time_secs` gives the total time in seconds for the complete execution of `explain()`. +#' `main_timing_secs` gives the time in seconds for the main computations.
+#' `iter_timing_secs` gives for each iteration of the iterative estimation, the time spent on the different parts +#' of the iterative estimation routine.} #' } #' -#' `shapley_values` is a data.table where the number of rows equals -#' the number of observations you'd like to explain, and the number of columns equals `m +1`, -#' where `m` equals the total number of features in your model. -#' -#' If `shapley_values[i, j + 1] > 0` it indicates that the j-th feature increased the prediction for -#' the i-th observation. Likewise, if `shapley_values[i, j + 1] < 0` it indicates that the j-th feature -#' decreased the prediction for the i-th observation. -#' The magnitude of the value is also important to notice. E.g. if `shapley_values[i, k + 1]` and -#' `shapley_values[i, j + 1]` are greater than `0`, where `j != k`, and -#' `shapley_values[i, k + 1]` > `shapley_values[i, j + 1]` this indicates that feature -#' `j` and `k` both increased the value of the prediction, but that the effect of the k-th -#' feature was larger than the j-th feature. -#' -#' The first column in `dt`, called `none`, is the prediction value not assigned to any of the features -#' (\ifelse{html}{\eqn{\phi}\out{0}}{\eqn{\phi_0}}). -#' It's equal for all observations and set by the user through the argument `prediction_zero`. -#' The difference between the prediction and `none` is distributed among the other features. -#' In theory this value should be the expected prediction without conditioning on any features. -#' Typically we set this value equal to the mean of the response variable in our training data, but other choices -#' such as the mean of the predictions in the training data are also reasonable.
-#' #' @examples #' #' # Load example data @@ -181,14 +251,26 @@ #' # Explain predictions #' p <- mean(data_train[, y_var]) #' +#' \dontrun{ +#' # (Optionally) enable parallelization via the future package +#' if (requireNamespace("future", quietly = TRUE)) { +#' future::plan("multisession", workers = 2) +#' } +#' } +#' +#' # (Optionally) enable progress updates within every iteration via the progressr package +#' if (requireNamespace("progressr", quietly = TRUE)) { +#' progressr::handlers(global = TRUE) +#' } +#' #' # Empirical approach #' explain1 <- explain( #' model = model, #' x_explain = x_explain, #' x_train = x_train, #' approach = "empirical", -#' prediction_zero = p, -#' n_samples = 1e2 +#' phi0 = p, +#' n_MC_samples = 1e2 #' ) #' #' # Gaussian approach @@ -197,8 +279,8 @@ #' x_explain = x_explain, #' x_train = x_train, #' approach = "gaussian", -#' prediction_zero = p, -#' n_samples = 1e2 +#' phi0 = p, +#' n_MC_samples = 1e2 #' ) #' #' # Gaussian copula approach @@ -207,8 +289,8 @@ #' x_explain = x_explain, #' x_train = x_train, #' approach = "copula", -#' prediction_zero = p, -#' n_samples = 1e2 +#' phi0 = p, +#' n_MC_samples = 1e2 #' ) #' #' # ctree approach @@ -217,8 +299,8 @@ #' x_explain = x_explain, #' x_train = x_train, #' approach = "ctree", -#' prediction_zero = p, -#' n_samples = 1e2 +#' phi0 = p, +#' n_MC_samples = 1e2 #' ) #' #' # Combined approach @@ -228,12 +310,12 @@ #' x_explain = x_explain, #' x_train = x_train, #' approach = approach, -#' prediction_zero = p, -#' n_samples = 1e2 +#' phi0 = p, +#' n_MC_samples = 1e2 #' ) #' #' # Print the Shapley values -#' print(explain1$shapley_values) +#' print(explain1$shapley_values_est) #' #' # Plot the results #' if (requireNamespace("ggplot2", quietly = TRUE)) { @@ -250,10 +332,10 @@ #' x_train = x_train, #' group = group_list, #' approach = "empirical", -#' prediction_zero = p, -#' n_samples = 1e2 +#' phi0 = p, +#' n_MC_samples = 1e2 #' ) -#' print(explain_groups$shapley_values) +#' 
print(explain_groups$shapley_values_est) #' #' # Separate and surrogate regression approaches with linear regression models. #' # More complex regression models can be used, and we can use CV to @@ -265,7 +347,7 @@ #' model = model, #' x_explain = x_explain, #' x_train = x_train, -#' prediction_zero = p, +#' phi0 = p, #' approach = "regression_separate", #' regression.model = parsnip::linear_reg() #' ) @@ -274,40 +356,72 @@ #' model = model, #' x_explain = x_explain, #' x_train = x_train, -#' prediction_zero = p, +#' phi0 = p, #' approach = "regression_surrogate", #' regression.model = parsnip::linear_reg() #' ) #' +#' ## iterative estimation +#' # For illustration purposes only. By default not used for such small dimensions as here +#' +#' # Gaussian approach +#' explain_iterative <- explain( +#' model = model, +#' x_explain = x_explain, +#' x_train = x_train, +#' approach = "gaussian", +#' phi0 = p, +#' n_MC_samples = 1e2, +#' iterative = TRUE, +#' iterative_args = list(initial_n_coalitions = 10) +#' ) +#' #' @export #' #' @author Martin Jullum, Lars Henry Berge Olsen #' #' @references -#' Aas, K., Jullum, M., & Lland, A. (2021). Explaining individual predictions when features are dependent: -#' More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. +#' - Aas, K., Jullum, M., & Lland, A. (2021). Explaining individual predictions when features are dependent: +#' More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. +#' - Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values: +#' incorporating causal knowledge into model-agnostic explainability. +#' Advances in neural information processing systems, 33, 1229-1239. +#' - Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal shapley values: +#' Exploiting causal knowledge to explain individual predictions of complex models. +#' Advances in neural information processing systems, 33, 4778-4789. +#' - Olsen, L. H. B., Glad, I. 
K., Jullum, M., & Aas, K. (2024). A comparative study of methods for estimating +#' model-agnostic Shapley value explanations. Data Mining and Knowledge Discovery, 1-48. explain <- function(model, x_explain, x_train, approach, - prediction_zero, - n_combinations = NULL, + phi0, + iterative = NULL, + max_n_coalitions = NULL, group = NULL, - n_samples = 1e3, - n_batches = NULL, + paired_shap_sampling = TRUE, + n_MC_samples = 1e3, + kernelSHAP_reweighting = "on_all_cond", seed = 1, - keep_samp_for_vS = FALSE, + verbose = "basic", predict_model = NULL, get_model_specs = NULL, - MSEv_uniform_comb_weights = TRUE, - timing = TRUE, - verbose = 0, + prev_shapr_object = NULL, + asymmetric = FALSE, + causal_ordering = NULL, + confounding = NULL, + extra_computation_args = list(), + iterative_args = list(), + output_args = list(), ...) { # ... is further arguments passed to specific approaches - timing_list <- list(init_time = Sys.time()) - set.seed(seed) + init_time <- Sys.time() + + if (!is.null(seed)) { + set.seed(seed) + } # Gets and check feature specs from the model feature_specs <- get_feature_specs(get_model_specs, model) @@ -318,21 +432,27 @@ explain <- function(model, x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = n_combinations, + paired_shap_sampling = paired_shap_sampling, + phi0 = phi0, + max_n_coalitions = max_n_coalitions, group = group, - n_samples = n_samples, - n_batches = n_batches, + n_MC_samples = n_MC_samples, seed = seed, - keep_samp_for_vS = keep_samp_for_vS, feature_specs = feature_specs, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights, - timing = timing, verbose = verbose, + iterative = iterative, + iterative_args = iterative_args, + kernelSHAP_reweighting = kernelSHAP_reweighting, + init_time = init_time, + prev_shapr_object = prev_shapr_object, + asymmetric = asymmetric, + causal_ordering = causal_ordering, + confounding = confounding, + output_args = output_args, + 
extra_computation_args = extra_computation_args, ... ) - timing_list$setup <- Sys.time() # Gets predict_model (if not passed to explain) predict_model <- get_predict_model(predict_model = predict_model, model = model) @@ -345,55 +465,104 @@ explain <- function(model, internal = internal ) - timing_list$test_prediction <- Sys.time() + internal$timing_list$test_prediction <- Sys.time() + + + internal <- additional_regression_setup(internal, model = model, predict_model = predict_model) + + # Not called for approach %in% c("regression_surrogate","vaeac") + internal <- setup_approach(internal, model = model, predict_model = predict_model) + internal$main_timing_list <- internal$timing_list - # Add the predicted response of the training and explain data to the internal list for regression-based methods. - # Use isTRUE as `regression` is not present (NULL) for non-regression methods (i.e., Monte Carlo-based methods). - if (isTRUE(internal$parameters$regression)) { - internal <- regression.get_y_hat(internal = internal, model = model, predict_model = predict_model) + converged <- FALSE + iter <- length(internal$iter_list) + + if (!is.null(seed)) { + set.seed(seed) + } + + cli_startup(internal, class(model), verbose) + + + while (converged == FALSE) { + cli_iter(verbose, internal, iter) + + internal$timing_list <- list(init = Sys.time()) + + # Setup the Shapley framework + internal <- shapley_setup(internal) + + # Only actually called for approach %in% c("regression_surrogate","vaeac") + internal <- setup_approach(internal, model = model, predict_model = predict_model) + + # Compute the vS + vS_list <- compute_vS(internal, model, predict_model) + + # Compute shapley value estimated and bootstrapped standard deviations + internal <- compute_estimates(internal, vS_list) + + # Check convergence based on estimates and standard deviations (and thresholds) + internal <- check_convergence(internal) + + # Save intermediate results + save_results(internal) + + # Preparing 
parameters for next iteration (does not do anything if already converged) + internal <- prepare_next_iteration(internal) + + # Printing iteration information + print_iter(internal) + + # Setting globals for to simplify the loop + converged <- internal$iter_list[[iter]]$converged + + internal$timing_list$postprocess_res <- Sys.time() + + internal$iter_timing_list[[iter]] <- internal$timing_list + + iter <- iter + 1 } - # Sets up the Shapley (sampling) framework and prepares the - # conditional expectation computation for the chosen approach - # Note: model and predict_model are ONLY used by the AICc-methods of approach empirical to find optimal parameters - internal <- setup_computation(internal, model, predict_model) + internal$main_timing_list$main_computation <- Sys.time() + - timing_list$setup_computation <- Sys.time() + # Rerun after convergence to get the same output format as for the non-iterative approach + output <- finalize_explanation(internal = internal) - # Compute the v(S): - # MC: - # 1. Get the samples for the conditional distributions with the specified approach - # 2. Predict with these samples - # 3. Perform MC integration on these to estimate the conditional expectation (v(S)) - # Regression: - # 1. 
Directly estimate the conditional expectation (v(S)) using the fitted regression model(s) - vS_list <- compute_vS(internal, model, predict_model) + internal$main_timing_list$finalize_explanation <- Sys.time() - timing_list$compute_vS <- Sys.time() + output$timing <- compute_time(internal) - # Compute Shapley values based on conditional expectations (v(S)) - # Organize function output - output <- finalize_explanation(vS_list = vS_list, internal = internal) - timing_list$shapley_computation <- Sys.time() + # Some cleanup when doing testing + testing <- internal$parameters$testing + if (isTRUE(testing)) { + output <- testing_cleanup(output) + } - # Compute the elapsed time for the different steps - if (timing == TRUE) output$timing <- compute_time(timing_list) - # Temporary to avoid failing tests - output <- remove_outputs_to_pass_tests(output) return(output) } +#' Cleans out certain output arguments to allow perfect reproducability of the output +#' +#' @inheritParams default_doc_explain +#' +#' @export #' @keywords internal -#' @author Lars Henry Berge Olsen -remove_outputs_to_pass_tests <- function(output) { - output$internal$objects$id_combination_mapper_dt <- NULL - output$internal$objects$cols_per_horizon <- NULL - output$internal$objects$W_list <- NULL +#' @author Lars Henry Berge Olsen, Martin Jullum +testing_cleanup <- function(output) { + # Removing the timing of different function calls + output$timing <- NULL + # Clearing out the timing lists as well + output$internal$main_timing_list <- NULL + output$internal$iter_timing_list <- NULL + output$internal$timing_list <- NULL + + # Removing paths to non-reproducable vaeac model objects if (isFALSE(output$internal$parameters$vaeac.extra_parameters$vaeac.save_model)) { output$internal$parameters[c( "vaeac", "vaeac.sampler", "vaeac.model", "vaeac.activation_function", "vaeac.checkpoint" @@ -402,8 +571,16 @@ remove_outputs_to_pass_tests <- function(output) { NULL } - # Remove the `regression` parameter from the 
output list when we are not doing regression - if (isFALSE(output$internal$parameters$regression)) output$internal$parameters$regression <- NULL + # Removing the fit times for regression surrogate models + if ("regression_surrogate" %in% output$internal$parameters$approach) { + # Deletes the fit_times for approach = regression_surrogate to make tests pass. + # In the future we could delete this only when a new argument in explain called testing is TRUE + output$internal$objects$regression.surrogate_model$pre$mold$blueprint$recipe$fit_times <- NULL + } + + # Delete the saving_path + output$internal$parameters$output_args$saving_path <- NULL + output$saving_path <- NULL return(output) } diff --git a/R/explain_forecast.R b/R/explain_forecast.R index f182e0c6308978db54ecfa56860e7d490b03fa6b..eeaff7ca342d0b82c6b42edd6c1afaf00f5ee6e8 100644 --- a/R/explain_forecast.R +++ b/R/explain_forecast.R @@ -79,7 +79,7 @@ #' explain_y_lags = 2, #' horizon = 3, #' approach = "empirical", -#' prediction_zero = p0_ar, +#' phi0 = p0_ar, #' group_lags = FALSE #' ) #' @@ -93,24 +93,24 @@ explain_forecast <- function(model, explain_xreg_lags = explain_y_lags, horizon, approach, - prediction_zero, - n_combinations = NULL, + phi0, + max_n_coalitions = NULL, + iterative = NULL, + iterative_args = list(), + kernelSHAP_reweighting = "on_all_cond", group_lags = TRUE, group = NULL, - n_samples = 1e3, - n_batches = NULL, + n_MC_samples = 1e3, seed = 1, - keep_samp_for_vS = FALSE, predict_model = NULL, get_model_specs = NULL, - timing = TRUE, - verbose = 0, + verbose = "basic", ...) { # ... 
is further arguments passed to specific approaches - timing_list <- list( - init_time = Sys.time() - ) + init_time <- Sys.time() - set.seed(seed) + if (!is.null(seed)) { + set.seed(seed) + } # Gets and check feature specs from the model feature_specs <- get_feature_specs(get_model_specs, model) @@ -120,22 +120,23 @@ explain_forecast <- function(model, train_idx <- seq.int(from = max(c(explain_y_lags, explain_xreg_lags)), to = nrow(y))[-explain_idx] } - # Sets up and organizes input parameters # Checks the input parameters and their compatability # Checks data/model compatability internal <- setup( approach = approach, - prediction_zero = prediction_zero, + phi0 = phi0, output_size = horizon, - n_combinations = n_combinations, - n_samples = n_samples, - n_batches = n_batches, + max_n_coalitions = max_n_coalitions, + n_MC_samples = n_MC_samples, seed = seed, - keep_samp_for_vS = keep_samp_for_vS, feature_specs = feature_specs, type = "forecast", horizon = horizon, + iterative = iterative, + iterative_args = iterative_args, + kernelSHAP_reweighting = kernelSHAP_reweighting, + init_time = init_time, y = y, xreg = xreg, train_idx = train_idx, @@ -144,12 +145,10 @@ explain_forecast <- function(model, explain_xreg_lags = explain_xreg_lags, group_lags = group_lags, group = group, - timing = timing, verbose = verbose, ... 
) - timing_list$setup <- Sys.time() # Gets predict_model (if not passed to explain) predict_model <- get_predict_model( @@ -157,7 +156,6 @@ explain_forecast <- function(model, model = model ) - # Checks that predict_model gives correct format test_predict_model( x_test = head(internal$data$x_train, 2), @@ -166,60 +164,82 @@ explain_forecast <- function(model, internal = internal ) - timing_list$test_prediction <- Sys.time() + internal$timing_list$test_prediction <- Sys.time() + # Setup for approach + internal <- setup_approach(internal, model = model, predict_model = predict_model) - # Sets up the Shapley (sampling) framework and prepares the - # conditional expectation computation for the chosen approach - # Note: model and predict_model are ONLY used by the AICc-methods of approach empirical to find optimal parameters - internal <- setup_computation(internal, model, predict_model) + internal$main_timing_list <- internal$timing_list - timing_list$setup_computation <- Sys.time() + converged <- FALSE + iter <- length(internal$iter_list) + if (!is.null(seed)) { + set.seed(seed) + } - # Compute the v(S): - # Get the samples for the conditional distributions with the specified approach - # Predict with these samples - # Perform MC integration on these to estimate the conditional expectation (v(S)) - vS_list <- compute_vS(internal, model, predict_model, method = "regular") + cli_startup(internal, class(model), verbose) - timing_list$compute_vS <- Sys.time() + while (converged == FALSE) { + cli_iter(verbose, internal, iter) - # Compute Shapley values based on conditional expectations (v(S)) - # Organize function output - output <- finalize_explanation( - vS_list = vS_list, - internal = internal - ) + internal$timing_list <- list(init = Sys.time()) - if (timing == TRUE) { - output$timing <- compute_time(timing_list) - } + # setup the Shapley framework + internal <- shapley_setup_forecast(internal) - # Temporary to avoid failing tests - output <- 
remove_outputs_pass_tests_fore(output) + # May not need to be called here? + internal <- setup_approach(internal, model = model, predict_model = predict_model) - return(output) -} + # Compute the vS + vS_list <- compute_vS(internal, model, predict_model, method = "regular") + + # Compute Shapley values based on conditional expectations (v(S)) + internal <- compute_estimates( + vS_list = vS_list, + internal = internal + ) + + # Check convergence based on estimates and standard deviations (and thresholds) + internal <- check_convergence(internal) -#' @keywords internal -#' @author Lars Henry Berge Olsen -remove_outputs_pass_tests_fore <- function(output) { - # Temporary to avoid failing tests related to vaeac approach - if (isFALSE(output$internal$parameters$vaeac.extra_parameters$vaeac.save_model)) { - output$internal$parameters[c( - "vaeac", "vaeac.sampler", "vaeac.model", "vaeac.activation_function", "vaeac.checkpoint" - )] <- NULL - output$internal$parameters$vaeac.extra_parameters[c("vaeac.folder_to_save_model", "vaeac.model_description")] <- - NULL + # Save intermediate results + save_results(internal) + + # Preparing parameters for next iteration (does not do anything if already converged) + internal <- prepare_next_iteration(internal) + + # Printing iteration information + print_iter(internal) + + ### Setting globals for to simplify the loop + converged <- internal$iter_list[[iter]]$converged + + internal$timing_list$postprocess_res <- Sys.time() + + internal$iter_timing_list[[iter]] <- internal$timing_list + + iter <- iter + 1 } - # Remove the `regression` parameter from the output list when we are not doing regression - if (isFALSE(output$internal$parameters$regression)) output$internal$parameters$regression <- NULL + internal$main_timing_list$main_computation <- Sys.time() + + output <- finalize_explanation(internal = internal) + + internal$main_timing_list$finalize_explanation <- Sys.time() + + output$timing <- compute_time(internal) + + # Some cleanup 
when doing testing + testing <- internal$parameters$testing + if (isTRUE(testing)) { + output <- testing_cleanup(output) + } return(output) } + #' Set up data for explain_forecast #' #' @param y A matrix or numeric vector containing the endogenous variables for the model. @@ -326,6 +346,8 @@ get_data_forecast <- function(y, xreg, train_idx, explain_idx, explain_y_lags, e y = y, xreg = xreg, group = reg_fcast$group, + horizon_group = reg_fcast$horizon_group, + shap_names = names(data_lag$group), n_endo = ncol(data_lag$lagged), x_train = cbind( data.table::as.data.table(data_lag$lagged[train_idx, , drop = FALSE]), @@ -378,6 +400,7 @@ lag_data <- function(x, lags) { reg_forecast_setup <- function(x, horizon, group) { fcast <- matrix(NA, nrow(x) - horizon + 1, 0) names <- character() + horizon_group <- lapply(seq_len(horizon), function(i) names(group)[!(names(group) %in% colnames(x))]) for (i in seq_len(ncol(x))) { names_i <- paste0(colnames(x)[i], ".F", seq_len(horizon)) names <- c(names, names_i) @@ -386,8 +409,12 @@ reg_forecast_setup <- function(x, horizon, group) { fcast <- cbind(fcast, fcast_i) # Append group names if the exogenous regressor also has lagged values. 
- group[[colnames(x)[i]]] <- c(group[[colnames(x)[i]]], names_i) + for (h in seq_len(horizon)) { + group[[paste0(colnames(x)[i], ".", h)]] <- c(group[[colnames(x)[i]]], names_i[seq_len(h)]) + horizon_group[[h]] <- c(horizon_group[[h]], paste0(colnames(x)[i], ".", h)) + } + group[[colnames(x)[i]]] <- NULL } colnames(fcast) <- names - return(list(fcast = fcast, group = group)) + return(list(fcast = fcast, group = group, horizon_group = horizon_group)) } diff --git a/R/finalize_explanation.R b/R/finalize_explanation.R index 00a074751f51dca8e6157a1896566db644464711..b820c4297c51ff857afa826b54c07a36dc9a1dff 100644 --- a/R/finalize_explanation.R +++ b/R/finalize_explanation.R @@ -1,106 +1,96 @@ -#' Computes the Shapley values given `v(S)` +#' Gathers the final output to create the explanation object #' -#' @inherit explain -#' @inheritParams default_doc -#' @param vS_list List -#' Output from [compute_vS()] +#' @inheritParams default_doc_explain #' #' @export -finalize_explanation <- function(vS_list, internal) { - MSEv_uniform_comb_weights <- internal$parameters$MSEv_uniform_comb_weights +finalize_explanation <- function(internal) { + MSEv_uniform_comb_weights <- internal$parameters$output_args$MSEv_uniform_comb_weights + output_size <- internal$parameters$output_size + dt_vS <- internal$output$dt_vS - processed_vS_list <- postprocess_vS_list( - vS_list = vS_list, - internal = internal - ) + # Extracting iter (and deleting the last temporary empty list of iter_list) + iter <- length(internal$iter_list) - 1 + internal$iter_list[[iter + 1]] <- NULL - # Extract the predictions we are explaining - p <- get_p(processed_vS_list$dt_vS, internal) + dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est + dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd + + # Setting parameters and objects used in the end from the last iteration + internal$objects$X <- internal$iter_list[[iter]]$X + internal$objects$S <- internal$iter_list[[iter]]$S + internal$objects$W <- 
internal$iter_list[[iter]]$W - # internal$timing$postprocessing <- Sys.time() - # Compute the Shapley values - dt_shapley <- compute_shapley_new(internal, processed_vS_list$dt_vS) - # internal$timing$shapley_computation <- Sys.time() # Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach) internal$tmp <- NULL - internal$output <- processed_vS_list - output <- list( - shapley_values = dt_shapley, - internal = internal, - pred_explain = p - ) - attr(output, "class") <- c("shapr", "list") + + # Extract the predictions we are explaining + p <- get_p(dt_vS, internal) + # Compute the MSEv evaluation criterion if the output of the predictive model is a scalar. # TODO: check if it makes sense for output_size > 1. - if (internal$parameters$output_size == 1) { - output$MSEv <- compute_MSEv_eval_crit( + if (output_size == 1) { + MSEv <- compute_MSEv_eval_crit( internal = internal, - dt_vS = processed_vS_list$dt_vS, + dt_vS = dt_vS, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights ) + } else { + MSEv <- NULL } - return(output) -} - - -#' @keywords internal -postprocess_vS_list <- function(vS_list, internal) { - id_combination <- NULL # due to NSE - - keep_samp_for_vS <- internal$parameters$keep_samp_for_vS - prediction_zero <- internal$parameters$prediction_zero - n_explain <- internal$parameters$n_explain - - # Appending the zero-prediction to the list - dt_vS0 <- as.data.table(rbind(c(1, rep(prediction_zero, n_explain)))) + # Extract iterative results in a simplified format + iterative_results <- get_iter_results(internal$iter_list) - # Extracting/merging the data tables from the batch running - # TODO: Need a memory and speed optimized way to transform the output form dt_vS_list to two different lists, - # I.e. without copying the data more than once. For now I have modified run_batch such that it - # if keep_samp_for_vS=FALSE - # then there is only one copy, but there are two if keep_samp_for_vS=TRUE. 
This might be OK since the - # latter is used rarely - if (keep_samp_for_vS) { - names(dt_vS0) <- names(vS_list[[1]][[1]]) - - vS_list[[length(vS_list) + 1]] <- list(dt_vS0, NULL) - - dt_vS <- rbindlist(lapply(vS_list, `[[`, 1)) + output <- list( + shapley_values_est = dt_shapley_est, + shapley_values_sd = dt_shapley_sd, + pred_explain = p, + MSEv = MSEv, + iterative_results = iterative_results, + saving_path = internal$parameters$output_args$saving_path, + internal = internal + ) + attr(output, "class") <- c("shapr", "list") - dt_samp_for_vS <- rbindlist(lapply(vS_list, `[[`, 2), use.names = TRUE) + return(output) +} - data.table::setorder(dt_samp_for_vS, id_combination) - } else { - names(dt_vS0) <- names(vS_list[[1]]) +get_iter_results <- function(iter_list) { + ret <- list() + ret$dt_iter_shapley_est <- rbindlist(lapply(iter_list, `[[`, "dt_shapley_est"), idcol = "iter") + ret$dt_iter_shapley_sd <- rbindlist(lapply(iter_list, `[[`, "dt_shapley_sd"), idcol = "iter") + ret$iter_info_dt <- iter_list_to_dt(iter_list) + return(ret) +} - vS_list[[length(vS_list) + 1]] <- dt_vS0 +iter_list_to_dt <- function(iter_list, what = c( + "exact", "compute_sd", "n_coal_next_iter_factor", "n_coalitions", "n_batches", + "converged", "converged_exact", "converged_sd", "converged_max_iter", + "est_required_coalitions", "est_remaining_coalitions", "overall_conv_measure" + )) { + extracted <- lapply(iter_list, function(x) x[what]) + ret <- do.call(rbindlist, list(l = lapply(extracted, as.data.table), fill = TRUE)) + return(ret) +} - dt_vS <- rbindlist(vS_list) - dt_samp_for_vS <- NULL - } - data.table::setorder(dt_vS, id_combination) - output <- list( - dt_vS = dt_vS, - dt_samp_for_vS = dt_samp_for_vS - ) - return(output) -} #' @keywords internal get_p <- function(dt_vS, internal) { - id_combination <- NULL # due to NSE + id_coalition <- NULL # due to NSE + + iter <- length(internal$iter_list) + max_id_coalition <- internal$iter_list[[iter]]$n_coalitions - max_id_combination <- 
internal$parameters$n_combinations - p <- unlist(dt_vS[id_combination == max_id_combination, ][, id_combination := NULL]) + + p <- unlist(dt_vS[id_coalition == max_id_coalition, ][, id_coalition := NULL]) if (internal$parameters$type == "forecast") { names(p) <- apply(internal$parameters$output_labels, 1, function(x) paste0("explain_idx_", x[1], "_horizon_", x[2])) @@ -109,89 +99,42 @@ get_p <- function(dt_vS, internal) { return(p) } -#' Compute shapley values -#' @param dt_vS The contribution matrix. -#' -#' @inheritParams default_doc -#' -#' @return A `data.table` with Shapley values for each test observation. -#' @export -#' @keywords internal -compute_shapley_new <- function(internal, dt_vS) { - is_groupwise <- internal$parameters$is_groupwise - feature_names <- internal$parameters$feature_names - W <- internal$objects$W - type <- internal$parameters$type - - if (!is_groupwise) { - shap_names <- feature_names - } else { - shap_names <- names(internal$parameters$group) # TODO: Add group_names (and feature_names) to internal earlier - } - # If multiple horizons with explain_forecast are used, we only distribute value to those used at each horizon - if (type == "forecast") { - id_combination_mapper_dt <- internal$objects$id_combination_mapper_dt - horizon <- internal$parameters$horizon - cols_per_horizon <- internal$objects$cols_per_horizon - W_list <- internal$objects$W_list - kshap_list <- list() - for (i in seq_len(horizon)) { - W0 <- W_list[[i]] - dt_vS0 <- merge(dt_vS, id_combination_mapper_dt[horizon == i], by = "id_combination", all.y = TRUE) - data.table::setorder(dt_vS0, horizon_id_combination) - these_vS0_cols <- grep(paste0("p_hat", i, "_"), names(dt_vS0)) - kshap0 <- t(W0 %*% as.matrix(dt_vS0[, these_vS0_cols, with = FALSE])) - kshap_list[[i]] <- data.table::as.data.table(kshap0) - if (!is_groupwise) { - names(kshap_list[[i]]) <- c("none", cols_per_horizon[[i]]) - } else { - names(kshap_list[[i]]) <- c("none", shap_names) - } - } - dt_kshap <- 
cbind(internal$parameters$output_labels, rbindlist(kshap_list, fill = TRUE)) - } else { - kshap <- t(W %*% as.matrix(dt_vS[, -"id_combination"])) - dt_kshap <- data.table::as.data.table(kshap) - colnames(dt_kshap) <- c("none", shap_names) - } - return(dt_kshap) -} #' Mean Squared Error of the Contribution Function `v(S)` #' #' @inheritParams explain #' @inheritParams default_doc -#' @param dt_vS Data.table of dimension `n_combinations` times `n_explain + 1` containing the contribution function -#' estimates. The first column is assumed to be named `id_combination` and containing the ids of the combinations. -#' The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations +#' @param dt_vS Data.table of dimension `n_coalitions` times `n_explain + 1` containing the contribution function +#' estimates. The first column is assumed to be named `id_coalition` and containing the ids of the coalitions. +#' The last row is assumed to be the full coalition, i.e., it contains the predicted responses for the observations #' which are to be explained. #' @param MSEv_skip_empty_full_comb Logical. If `TRUE` (default), we exclude the empty and grand -#' combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical +#' coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical #' for all methods, i.e., their contribution function is independent of the used method as they are special cases not -#' effected by the used method. If `FALSE`, we include the empty and grand combinations/coalitions. In this situation, +#' effected by the used method. If `FALSE`, we include the empty and grand coalitions. In this situation, #' we also recommend setting `MSEv_uniform_comb_weights = TRUE`, as otherwise the large weights for the empty and -#' grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative. 
+#' grand coalitions will outweigh all other coalitions and make the MSEv criterion uninformative. #' #' @return #' List containing: #' \describe{ #' \item{`MSEv`}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged -#' over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}} -#' also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations) +#' over both the coalitions and observations/explicands. The \code{\link[data.table]{data.table}} +#' also contains the standard deviation of the MSEv values for each explicand (only averaged over the coalitions) #' divided by the square root of the number of explicands.} #' \item{`MSEv_explicand`}{A \code{\link[data.table]{data.table}} with the mean squared error for each -#' explicand, i.e., only averaged over the combinations/coalitions.} -#' \item{`MSEv_combination`}{A \code{\link[data.table]{data.table}} with the mean squared error for each -#' combination/coalition, i.e., only averaged over the explicands/observations. +#' explicand, i.e., only averaged over the coalitions.} +#' \item{`MSEv_coalition`}{A \code{\link[data.table]{data.table}} with the mean squared error for each +#' coalition, i.e., only averaged over the explicands/observations. 
#' The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for -#' each combination divided by the square root of the number of explicands.} +#' each coalition divided by the square root of the number of explicands.} #' } #' #' @description Function that computes the Mean Squared Error (MSEv) of the contribution function @@ -213,24 +156,28 @@ compute_MSEv_eval_crit <- function(internal, dt_vS, MSEv_uniform_comb_weights, MSEv_skip_empty_full_comb = TRUE) { + iter <- length(internal$iter_list) + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + n_explain <- internal$parameters$n_explain - n_combinations <- internal$parameters$n_combinations - id_combination_indices <- if (MSEv_skip_empty_full_comb) seq(2, n_combinations - 1) else seq(1, n_combinations) - n_combinations_used <- length(id_combination_indices) - features <- internal$objects$X$features[id_combination_indices] + id_coalition_indices <- if (MSEv_skip_empty_full_comb) seq(2, n_coalitions - 1) else seq(1, n_coalitions) + n_coalitions_used <- length(id_coalition_indices) + + X <- internal$objects$X + coalitions <- X$coalitions[id_coalition_indices] # Extract the predicted responses f(x) - p <- unlist(dt_vS[id_combination == n_combinations, -"id_combination"]) + p <- unlist(dt_vS[id_coalition == n_coalitions, -"id_coalition"]) # Create contribution matrix - vS <- as.matrix(dt_vS[id_combination_indices, -"id_combination"]) + vS <- as.matrix(dt_vS[id_coalition_indices, -"id_coalition"]) # Square the difference between the v(S) and f(x) dt_squared_diff_original <- sweep(vS, 2, p)^2 # Get the weights - averaging_weights <- if (MSEv_uniform_comb_weights) rep(1, n_combinations) else internal$objects$X$shapley_weight - averaging_weights <- averaging_weights[id_combination_indices] + averaging_weights <- if (MSEv_uniform_comb_weights) rep(1, n_coalitions) else X$shapley_weight + averaging_weights <- averaging_weights[id_coalition_indices] averaging_weights_scaled <- 
averaging_weights / sum(averaging_weights) # Apply the `averaging_weights_scaled` to each column (i.e., each explicand) @@ -241,8 +188,8 @@ compute_MSEv_eval_crit <- function(internal, MSEv_explicand <- colSums(dt_squared_diff) # The MSEv criterion for each coalition, i.e., only averaged over the explicands. - MSEv_combination <- rowMeans(dt_squared_diff * n_combinations_used) - MSEv_combination_sd <- apply(dt_squared_diff * n_combinations_used, 1, sd) / sqrt(n_explain) + MSEv_coalition <- rowMeans(dt_squared_diff * n_coalitions_used) + MSEv_coalition_sd <- apply(dt_squared_diff * n_coalitions_used, 1, sd) / sqrt(n_explain) # The MSEv criterion averaged over both the coalitions and explicands. MSEv <- mean(MSEv_explicand) @@ -250,8 +197,8 @@ compute_MSEv_eval_crit <- function(internal, # Set the name entries in the arrays names(MSEv_explicand) <- paste0("id_", seq(n_explain)) - names(MSEv_combination) <- paste0("id_combination_", id_combination_indices) - names(MSEv_combination_sd) <- paste0("id_combination_", id_combination_indices) + names(MSEv_coalition) <- paste0("id_coalition_", id_coalition_indices) + names(MSEv_coalition_sd) <- paste0("id_coalition_", id_coalition_indices) # Convert the results to data.table MSEv <- data.table( @@ -262,16 +209,67 @@ compute_MSEv_eval_crit <- function(internal, "id" = seq(n_explain), "MSEv" = MSEv_explicand ) - MSEv_combination <- data.table( - "id_combination" = id_combination_indices, - "features" = features, - "MSEv" = MSEv_combination, - "MSEv_sd" = MSEv_combination_sd + MSEv_coalition <- data.table( + "id_coalition" = id_coalition_indices, + "coalitions" = coalitions, + "MSEv" = MSEv_coalition, + "MSEv_sd" = MSEv_coalition_sd ) return(list( MSEv = MSEv, MSEv_explicand = MSEv_explicand, - MSEv_combination = MSEv_combination + MSEv_coalition = MSEv_coalition )) } + + +#' Computes the Shapley values given `v(S)` +#' +#' @inherit explain +#' @inheritParams default_doc +#' @param vS_list List +#' Output from [compute_vS()] 
+#' +#' @export +finalize_explanation_forecast <- function(vS_list, internal) { # Temporary used for forecast only (the old function) + MSEv_uniform_comb_weights <- internal$parameters$output_args$MSEv_uniform_comb_weights + + processed_vS_list <- postprocess_vS_list( + vS_list = vS_list, + internal = internal + ) + + # Extract the predictions we are explaining + p <- get_p(processed_vS_list$dt_vS, internal) + + # Compute the Shapley values + dt_shapley <- compute_shapley_new(internal, processed_vS_list$dt_vS) + + # Clearing out the timing lists as they are added to the output separately + internal$main_timing_list <- internal$iter_timing_list <- internal$timing_list <- NULL + + # Clearing out the tmp list with model and predict_model (only added for AICc-types of empirical approach) + internal$tmp <- NULL + + internal$output <- processed_vS_list + + output <- list( + shapley_values_est = dt_shapley, + internal = internal, + pred_explain = p + ) + attr(output, "class") <- c("shapr", "list") + + # Compute the MSEv evaluation criterion if the output of the predictive model is a scalar. + # TODO: check if it makes sense for output_size > 1. 
+ if (internal$parameters$output_size == 1) { + output$MSEv <- compute_MSEv_eval_crit( + internal = internal, + dt_vS = processed_vS_list$dt_vS, + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights + ) + } + + return(output) +} diff --git a/R/get_predict_model.R b/R/get_predict_model.R index 93577e8b95009c4c200b893be12838a5fc7b8aa8..4ffec6a45667b8094faf0c593e93331d810ae8c8 100644 --- a/R/get_predict_model.R +++ b/R/get_predict_model.R @@ -43,8 +43,11 @@ test_predict_model <- function(x_test, predict_model, model, internal) { if (!is.null(internal$parameters$type) && internal$parameters$type == "forecast") { tmp <- tryCatch(predict_model( x = model, - newdata = x_test[, 1:internal$data$n_endo, drop = FALSE], - newreg = x_test[, -(1:internal$data$n_endo), drop = FALSE], + newdata = x_test[, .SD, .SDcols = seq_len(internal$data$n_endo), drop = FALSE], + newreg = x_test[, .SD, + .SDcols = seq_len(ncol(x_test) - internal$data$n_endo) + internal$data$n_endo, + drop = FALSE + ], horizon = internal$parameters$horizon, explain_idx = rep(internal$parameters$explain_idx[1], 2), y = internal$data$y, diff --git a/R/model_arima.R b/R/model_arima.R index 7f53cd6ccd36fbceeb08d380f56ac0b2de79b78e..2b2a70d469557007985fcd2162a6d4ec43afe6c2 100644 --- a/R/model_arima.R +++ b/R/model_arima.R @@ -5,29 +5,32 @@ predict_model.Arima <- function(x, newdata, newreg, horizon, explain_idx, explai stop("The stats package is required for predicting stats models") } - prediction <- matrix(NA, nrow(newdata), horizon) - newdata <- as.matrix(newdata) + prediction <- matrix(NA, length(explain_idx), horizon) + newdata <- as.matrix(newdata, nrow = length(explain_idx)) newreg <- as.matrix(newreg) newdata_y_cols <- seq_len(explain_lags$y) newdata_xreg_cols_list <- lapply(paste0("xreg", seq_along(explain_lags$xreg)), function(x) grep(x, colnames(newdata))) exp_idx <- -1 - for (i in seq_len(nrow(newdata))) { + for (i in seq_len(length(explain_idx))) { if (explain_idx[i] != exp_idx) { exp_idx <- 
explain_idx[i] y_hist <- y[seq_len(exp_idx)] xreg_hist <- xreg[seq_len(exp_idx), , drop = FALSE] } - y_new <- as.numeric(newdata[i, newdata_y_cols]) - y_hist[seq.int(length.out = length(y_new), to = length(y_hist))] <- rev(y_new) + if (ncol(newdata) > 0) { + y_new <- as.numeric(newdata[i, newdata_y_cols]) + y_hist[seq.int(length.out = length(y_new), to = length(y_hist))] <- rev(y_new) + } if (ncol(xreg) == 0) { x <- forecast::Arima(y = y_hist, model = x) prediction[i, ] <- predict(x, h = horizon)$pred } else { for (j in seq_along(explain_lags$xreg)) { + if (length(newdata_xreg_cols_list[[j]]) == 0) next xreg_new <- as.numeric(newdata[i, newdata_xreg_cols_list[[j]]]) xreg_hist[seq.int(length.out = length(xreg_new), to = nrow(xreg_hist)), j] <- rev(xreg_new) } diff --git a/R/plot.R b/R/plot.R index ef80c9a32f05e816595c4a347989fe02b0d0dfa0..a206d1951fb2254a22cae2629d6aaed53ba6a9e6 100644 --- a/R/plot.R +++ b/R/plot.R @@ -60,8 +60,12 @@ #' character vector, indicating the name(s) of the feature(s) to plot. #' @param scatter_hist Logical. #' Only used for `plot_type = "scatter"`. -#' Whether to include a scatter_hist indicating the distribution of the data when making the scatter plot. Note that the -#' bins are scaled so that when all the bins are stacked they fit the span of the y-axis of the plot. +#' Whether to include a scatter_hist indicating the distribution of the data when making the scatter plot. Note +#' that the bins are scaled so that when all the bins are stacked they fit the span of the y-axis of the plot. +#' @param include_group_feature_means Logical. +#' Whether to include the average feature value in a group on the y-axis or not. +#' If `FALSE` (default), then no value is shown for the groups. If `TRUE`, then `shapr` includes the mean of the +#' features in each group. #' @param ... Currently not used. 
#' #' @details See the examples below, or `vignette("understanding_shapr", package = "shapr")` for an examples of @@ -97,8 +101,8 @@ #' x_explain = x_explain, #' x_train = x_train, #' approach = "empirical", -#' prediction_zero = p, -#' n_samples = 1e2 +#' phi0 = p, +#' n_MC_samples = 1e2 #' ) #' #' if (requireNamespace("ggplot2", quietly = TRUE)) { @@ -147,8 +151,8 @@ #' x_explain = x_explain, #' x_train = x_train, #' approach = "ctree", -#' prediction_zero = p, -#' n_samples = 1e2 +#' phi0 = p, +#' n_MC_samples = 1e2 #' ) #' #' if (requireNamespace("ggplot2", quietly = TRUE)) { @@ -156,7 +160,7 @@ #' plot(x, plot_type = "beeswarm") #' } #' -#' @author Martin Jullum, Vilde Ung +#' @author Martin Jullum, Vilde Ung, Lars Henry Berge Olsen plot.shapr <- function(x, plot_type = "bar", digits = 3, @@ -167,6 +171,7 @@ plot.shapr <- function(x, bar_plot_order = "largest_first", scatter_features = NULL, scatter_hist = TRUE, + include_group_feature_means = FALSE, ...) { if (!requireNamespace("ggplot2", quietly = TRUE)) { stop("ggplot2 is not installed. 
Please run install.packages('ggplot2')") @@ -180,26 +185,63 @@ plot.shapr <- function(x, bar_plot_order='smallest_first' or bar_plot_order='original'.")) } + # Remove the explain_id column + x$shapley_values_est <- x$shapley_values_est[, -"explain_id"] + if (is.null(index_x_explain)) index_x_explain <- seq(x$internal$parameters$n_explain) if (is.null(top_k_features)) top_k_features <- x$internal$parameters$n_features + 1 is_groupwise <- x$internal$parameters$is_groupwise + # For group-wise Shapley values, we check if we are to take the mean over grouped features + if (is_groupwise) { + if (is.na(include_group_feature_means) || + !is.logical(include_group_feature_means) || + length(include_group_feature_means) > 1) { + stop("`include_group_feature_means` must be single logical.") + } + if (!include_group_feature_means && plot_type %in% c("scatter", "beeswarm")) { + stop(paste0( + "`shapr` cannot make a `", plot_type, "` plot for group-wise Shapley values, as the plot needs a ", + "single feature value for the whole group.\n", + "For numerical data, the user can set `include_group_feature_means = TRUE` to use the mean of all ", + "grouped features. The user should use this option cautiously to not misinterpret the explanations." 
+ )) + } + + if (any(x$internal$objects$feature_specs$classes != "numeric")) { + stop("`include_group_feature_means` cannot be `TRUE` for datasets with non-numerical features.") + } + + # Take the mean over the grouped features and update the feature name to the group name + x$internal$data$x_explain <- + x$internal$data$x_explain[, lapply( + x$internal$parameters$group, + function(cols) rowMeans(.SD[, .SD, .SDcols = cols], na.rm = TRUE) + )] + + x$internal$data$x_train <- + x$internal$data$x_train[, lapply( + x$internal$parameters$group, + function(cols) rowMeans(.SD[, .SD, .SDcols = cols], na.rm = TRUE) + )] + } + # melting Kshap - shap_names <- colnames(x$shapley_values)[-1] - dt_shap <- round(data.table::copy(x$shapley_values), digits = digits) + shap_names <- x$internal$parameters$shap_names + dt_shap <- round(data.table::copy(x$shapley_values_est), digits = digits) dt_shap[, id := .I] dt_shap_long <- data.table::melt(dt_shap, id.vars = "id", value.name = "phi") dt_shap_long[, sign := factor(sign(phi), levels = c(1, -1), labels = c("Increases", "Decreases"))] # Converting and melting Xtest - if (!is_groupwise) { + if (!is_groupwise || include_group_feature_means) { desc_mat <- trimws(format(x$internal$data$x_explain, digits = digits)) for (i in seq_len(ncol(desc_mat))) { desc_mat[, i] <- paste0(shap_names[i], " = ", desc_mat[, i]) } } else { - desc_mat <- trimws(format(x$shapley_values[, -1], digits = digits)) + desc_mat <- trimws(format(x$shapley_values_est[, -c("none")], digits = digits)) for (i in seq_len(ncol(desc_mat))) { desc_mat[, i] <- paste0(shap_names[i]) } @@ -257,7 +299,7 @@ plot.shapr <- function(x, # compute start and end values for waterfall rectangles data.table::setorder(dt_plot, rank_waterfall) dt_plot[, end := cumsum(phi), by = id] - expected <- x$internal$parameters$prediction_zero + expected <- x$internal$parameters$phi0 dt_plot[, start := c(expected, head(end, -1)), by = id] dt_plot[, phi_significant := format(phi, digits = digits), by = 
id] @@ -562,8 +604,7 @@ make_beeswarm_plot <- function(dt_plot, col, index_x_explain, x, factor_cols) { gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) + - ggbeeswarm::geom_beeswarm(priority = "random", cex = 0.4) + - # the cex-parameter doesnt generalize well, should use corral but not available yet.... + ggbeeswarm::geom_beeswarm(priority = "random", cex = 1 / length(index_x_explain)^(1 / 4)) + ggplot2::coord_flip() + ggplot2::theme_classic() + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) + @@ -788,8 +829,8 @@ make_waterfall_plot <- function(dt_plot, #' Make plots to visualize and compare the MSEv evaluation criterion for a list of #' [shapr::explain()] objects applied to the same data and model. The function creates #' bar plots and line plots with points to illustrate the overall MSEv evaluation -#' criterion, but also for each observation/explicand and combination by only averaging over -#' the combinations and observations/explicands, respectively. +#' criterion, but also for each observation/explicand and coalition by only averaging over +#' the coalitions and observations/explicands, respectively. #' #' @inheritParams plot.shapr #' @inheritParams default_doc @@ -797,26 +838,26 @@ make_waterfall_plot <- function(dt_plot, #' @param explanation_list A list of [shapr::explain()] objects applied to the same data and model. #' If the entries in the list are named, then the function use these names. Otherwise, they default to #' the approach names (with integer suffix for duplicates) for the explanation objects in `explanation_list`. -#' @param id_combination Integer vector. Which of the combinations (coalitions) to plot. -#' E.g. if you used `n_combinations = 16` in [explain()], you can generate a plot for the -#' first 5 combinations and the 10th by setting `id_combination = c(1:5, 10)`. 
+#' @param id_coalition Integer vector. Which of the coalitions to plot. +#' E.g. if you used `n_coalitions = 16` in [explain()], you can generate a plot for the +#' first 5 coalitions and the 10th by setting `id_coalition = c(1:5, 10)`. #' @param CI_level Positive numeric between zero and one. Default is `0.95` if the number of observations to explain is #' larger than 20, otherwise `CI_level = NULL`, which removes the confidence intervals. The level of the approximate -#' confidence intervals for the overall MSEv and the MSEv_combination. The confidence intervals are based on that +#' confidence intervals for the overall MSEv and the MSEv_coalition. The confidence intervals are based on that #' the MSEv scores are means over the observations/explicands, and that means are approximation normal. Since the #' standard deviations are estimated, we use the quantile t from the T distribution with N_explicands - 1 degrees of #' freedom corresponding to the provided level. Here, N_explicands is the number of observations/explicands. -#' MSEv ± t*SD(MSEv)/sqrt(N_explicands). Note that the `explain()` function already scales the standard deviation by -#' sqrt(N_explicands), thus, the CI are MSEv ± t*MSEv_sd, where the values MSEv and MSEv_sd are extracted from the +#' MSEv +/- t*SD(MSEv)/sqrt(N_explicands). Note that the `explain()` function already scales the standard deviation by +#' sqrt(N_explicands), thus, the CI are MSEv \/- t*MSEv_sd, where the values MSEv and MSEv_sd are extracted from the #' MSEv data.tables in the objects in the `explanation_list`. #' @param geom_col_width Numeric. Bar width. By default, set to 90% of the [ggplot2::resolution()] of the data. #' @param plot_type Character vector. The possible options are "overall" (default), "comb", and "explicand". 
#' If `plot_type = "overall"`, then the plot (one bar plot) associated with the overall MSEv evaluation criterion -#' for each method is created, i.e., when averaging over both the combinations/coalitions and observations/explicands. +#' for each method is created, i.e., when averaging over both the coalitions and observations/explicands. #' If `plot_type = "comb"`, then the plots (one line plot and one bar plot) associated with the MSEv evaluation -#' criterion for each combination/coalition are created, i.e., when we only average over the observations/explicands. +#' criterion for each coalition are created, i.e., when we only average over the observations/explicands. #' If `plot_type = "explicand"`, then the plots (one line plot and one bar plot) associated with the MSEv evaluation -#' criterion for each observations/explicands are created, i.e., when we only average over the combinations/coalitions. +#' criterion for each observations/explicands are created, i.e., when we only average over the coalitions. #' If `plot_type` is a vector of one or several of "overall", "comb", and "explicand", then the associated plots are #' created. #' @@ -854,7 +895,7 @@ make_waterfall_plot <- function(dt_plot, #' ) #' #' # Specifying the phi_0, i.e. 
the expected prediction without any features -#' prediction_zero <- mean(y_train) +#' phi0 <- mean(y_train) #' #' # Independence approach #' explanation_independence <- explain( @@ -862,8 +903,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = "independence", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Gaussian 1e1 approach @@ -872,8 +913,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = "gaussian", -#' prediction_zero = prediction_zero, -#' n_samples = 1e1 +#' phi0 = phi0, +#' n_MC_samples = 1e1 #' ) #' #' # Gaussian 1e2 approach @@ -882,8 +923,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = "gaussian", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # ctree approach @@ -892,8 +933,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = "ctree", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Combined approach @@ -902,8 +943,8 @@ make_waterfall_plot <- function(dt_plot, #' x_explain = x_explain, #' x_train = x_train, #' approach = c("gaussian", "independence", "ctree"), -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Create a list of explanations with names @@ -916,24 +957,24 @@ make_waterfall_plot <- function(dt_plot, #' ) #' #' if (requireNamespace("ggplot2", quietly = TRUE)) { -#' # Create the default MSEv plot where we average over both the combinations and observations +#' # Create the default MSEv plot where we average over both the coalitions and observations #' # with approximate 95% confidence intervals #' plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = 
"overall") #' -#' # Can also create plots of the MSEv criterion averaged only over the combinations or observations. +#' # Can also create plots of the MSEv criterion averaged only over the coalitions or observations. #' MSEv_figures <- plot_MSEv_eval_crit(explanation_list_named, #' CI_level = 0.95, #' plot_type = c("overall", "comb", "explicand") #' ) #' MSEv_figures$MSEv_bar -#' MSEv_figures$MSEv_combination_bar +#' MSEv_figures$MSEv_coalition_bar #' MSEv_figures$MSEv_explicand_bar #' -#' # When there are many combinations or observations, then it can be easier to look at line plots -#' MSEv_figures$MSEv_combination_line_point +#' # When there are many coalitions or observations, then it can be easier to look at line plots +#' MSEv_figures$MSEv_coalition_line_point #' MSEv_figures$MSEv_explicand_line_point #' -#' # We can specify which observations or combinations to plot +#' # We can specify which observations or coalitions to plot #' plot_MSEv_eval_crit(explanation_list_named, #' plot_type = "explicand", #' index_x_explain = c(1, 3:4, 6), @@ -941,9 +982,9 @@ make_waterfall_plot <- function(dt_plot, #' )$MSEv_explicand_bar #' plot_MSEv_eval_crit(explanation_list_named, #' plot_type = "comb", -#' id_combination = c(3, 4, 9, 13:15), +#' id_coalition = c(3, 4, 9, 13:15), #' CI_level = 0.95 -#' )$MSEv_combination_bar +#' )$MSEv_coalition_bar #' #' # We can alter the figures if other palette schemes or design is wanted #' bar_text_n_decimals <- 1 @@ -973,7 +1014,7 @@ make_waterfall_plot <- function(dt_plot, #' @author Lars Henry Berge Olsen plot_MSEv_eval_crit <- function(explanation_list, index_x_explain = NULL, - id_combination = NULL, + id_coalition = NULL, CI_level = if (length(explanation_list[[1]]$pred_explain) < 20) NULL else 0.95, geom_col_width = 0.9, plot_type = "overall") { @@ -1005,20 +1046,22 @@ plot_MSEv_eval_crit <- function(explanation_list, # Check that the explanation objects explain the same observations 
MSEv_check_explanation_list(explanation_list) - # Get the number of observations and combinations and the quantile of the T distribution + # Get the number of observations and coalitions and the quantile of the T distribution + iter <- length(explanation_list[[1]]$internal$iter_list) + n_coalitions <- explanation_list[[1]]$internal$iter_list[[iter]]$n_coalitions + n_explain <- explanation_list[[1]]$internal$parameters$n_explain - n_combinations <- explanation_list[[1]]$internal$parameters$n_combinations tfrac <- if (is.null(CI_level)) NULL else qt((1 + CI_level) / 2, n_explain - 1) # Create data.tables of the MSEv values MSEv_dt_list <- MSEv_extract_MSEv_values( explanation_list = explanation_list, index_x_explain = index_x_explain, - id_combination = id_combination + id_coalition = id_coalition ) MSEv_dt <- MSEv_dt_list$MSEv MSEv_explicand_dt <- MSEv_dt_list$MSEv_explicand - MSEv_combination_dt <- MSEv_dt_list$MSEv_combination + MSEv_coalition_dt <- MSEv_dt_list$MSEv_coalition # Warnings related to the approximate confidence intervals if (!is.null(CI_level)) { @@ -1046,23 +1089,23 @@ plot_MSEv_eval_crit <- function(explanation_list, return_object <- list() if ("explicand" %in% plot_type) { - # MSEv averaged over only the combinations for each observation + # MSEv averaged over only the coalitions for each observation return_object <- c( return_object, make_MSEv_explicand_plots( MSEv_explicand_dt = MSEv_explicand_dt, - n_combinations = n_combinations, + n_coalitions = n_coalitions, geom_col_width = geom_col_width ) ) } if ("comb" %in% plot_type) { - # MSEv averaged over only the observations for each combinations + # MSEv averaged over only the observations for each coalitions return_object <- c( return_object, - make_MSEv_combination_plots( - MSEv_combination_dt = MSEv_combination_dt, + make_MSEv_coalition_plots( + MSEv_coalition_dt = MSEv_coalition_dt, n_explain = n_explain, geom_col_width = geom_col_width, tfrac = tfrac @@ -1071,10 +1114,10 @@ 
plot_MSEv_eval_crit <- function(explanation_list, } if ("overall" %in% plot_type) { - # MSEv averaged over both the combinations and observations + # MSEv averaged over both the coalitions and observations return_object$MSEv_bar <- make_MSEv_bar_plot( MSEv_dt = MSEv_dt, - n_combinations = n_combinations, + n_coalitions = n_coalitions, n_explain = n_explain, geom_col_width = geom_col_width, tfrac = tfrac @@ -1122,7 +1165,7 @@ MSEv_check_explanation_list <- function(explanation_list) { if (any(names(explanation_list) == "")) stop("All the entries in `explanation_list` must be named.") # Check that all explanation objects use the same column names for the Shapley values - if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values)))) != 1) { + if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values_est)))) != 1) { stop("The Shapley value feature names are not identical in all objects in the `explanation_list`.") } @@ -1149,7 +1192,7 @@ MSEv_check_explanation_list <- function(explanation_list) { )) } - # Check that all explanation objects use the same combinations + # Check that all explanation objects use the same coalitions entries_using_diff_combs <- sapply(explanation_list, function(explanation) { !identical(explanation_list[[1]]$internal$objects$X$features, explanation$internal$objects$X$features) }) @@ -1157,7 +1200,7 @@ MSEv_check_explanation_list <- function(explanation_list) { methods_with_diff_comb_str <- paste(names(entries_using_diff_combs)[entries_using_diff_combs], collapse = "', '") stop(paste0( "The object/objects '", methods_with_diff_comb_str, "' in `explanation_list` uses/use different ", - "coaltions than '", names(explanation_list)[1], "'. Cannot compare them." + "coalitions than '", names(explanation_list)[1], "'. Cannot compare them." 
)) } } @@ -1166,9 +1209,9 @@ MSEv_check_explanation_list <- function(explanation_list) { #' @author Lars Henry Berge Olsen MSEv_extract_MSEv_values <- function(explanation_list, index_x_explain = NULL, - id_combination = NULL) { - # Function that extract the MSEv values from the different explanations objects in ´explanation_list´, - # put the values in data.tables, and keep only the desired observations and combinations. + id_coalition = NULL) { + # Function that extract the MSEv values from the different explanations objects in explanation_list, + # put the values in data.tables, and keep only the desired observations and coalitions. # The overall MSEv criterion MSEv <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv), @@ -1183,27 +1226,27 @@ MSEv_extract_MSEv_values <- function(explanation_list, MSEv_explicand$id <- factor(MSEv_explicand$id) MSEv_explicand$Method <- factor(MSEv_explicand$Method, levels = names(explanation_list)) - # The MSEv evaluation criterion for each combination. - MSEv_combination <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv_combination), + # The MSEv evaluation criterion for each coalition. 
+ MSEv_coalition <- rbindlist(lapply(explanation_list, function(explanation) explanation$MSEv$MSEv_coalition), use.names = TRUE, idcol = "Method" ) - MSEv_combination$id_combination <- factor(MSEv_combination$id_combination) - MSEv_combination$Method <- factor(MSEv_combination$Method, levels = names(explanation_list)) + MSEv_coalition$id_coalition <- factor(MSEv_coalition$id_coalition) + MSEv_coalition$Method <- factor(MSEv_coalition$Method, levels = names(explanation_list)) - # Only keep the desired observations and combinations + # Only keep the desired observations and coalitions if (!is.null(index_x_explain)) MSEv_explicand <- MSEv_explicand[id %in% index_x_explain] - if (!is.null(id_combination)) { - id_combination_aux <- id_combination - MSEv_combination <- MSEv_combination[id_combination %in% id_combination_aux] + if (!is.null(id_coalition)) { + id_coalition_aux <- id_coalition + MSEv_coalition <- MSEv_coalition[id_coalition %in% id_coalition_aux] } - return(list(MSEv = MSEv, MSEv_explicand = MSEv_explicand, MSEv_combination = MSEv_combination)) + return(list(MSEv = MSEv, MSEv_explicand = MSEv_explicand, MSEv_coalition = MSEv_coalition)) } #' @keywords internal #' @author Lars Henry Berge Olsen make_MSEv_bar_plot <- function(MSEv_dt, - n_combinations, + n_coalitions, n_explain, tfrac = NULL, geom_col_width = 0.9) { @@ -1216,16 +1259,16 @@ make_MSEv_bar_plot <- function(MSEv_dt, ggplot2::labs( x = "Method", y = bquote(MSE[v]), - title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ - "combinations and" ~ .(n_explain) ~ "explicands") + title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_coalitions) ~ + "coalitions and" ~ .(n_explain) ~ "explicands") ) if (!is.null(tfrac)) { CI_level <- 1 - 2 * (1 - pt(tfrac, n_explain - 1)) MSEv_bar <- MSEv_bar + - ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ - "combinations and" ~ .(n_explain) ~ "explicands with" ~ + ggplot2::labs(title = bquote(MSE[v] 
~ "criterion averaged over the" ~ .(n_coalitions) ~ + "coalitions and" ~ .(n_explain) ~ "explicands with" ~ .(CI_level * 100) * "% CI")) + ggplot2::geom_errorbar( position = ggplot2::position_dodge(geom_col_width), @@ -1244,15 +1287,15 @@ make_MSEv_bar_plot <- function(MSEv_dt, #' @keywords internal #' @author Lars Henry Berge Olsen make_MSEv_explicand_plots <- function(MSEv_explicand_dt, - n_combinations, + n_coalitions, geom_col_width = 0.9) { MSEv_explicand_source <- ggplot2::ggplot(MSEv_explicand_dt, ggplot2::aes(x = id, y = MSEv)) + ggplot2::labs( x = "index_x_explain", y = bquote(MSE[v] ~ "(explicand)"), - title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_combinations) ~ - "combinations for each explicand") + title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_coalitions) ~ + "coalitions for each explicand") ) MSEv_explicand_bar <- @@ -1278,21 +1321,21 @@ make_MSEv_explicand_plots <- function(MSEv_explicand_dt, #' @keywords internal #' @author Lars Henry Berge Olsen -make_MSEv_combination_plots <- function(MSEv_combination_dt, - n_explain, - tfrac = NULL, - geom_col_width = 0.9) { - MSEv_combination_source <- - ggplot2::ggplot(MSEv_combination_dt, ggplot2::aes(x = id_combination, y = MSEv)) + +make_MSEv_coalition_plots <- function(MSEv_coalition_dt, + n_explain, + tfrac = NULL, + geom_col_width = 0.9) { + MSEv_coalition_source <- + ggplot2::ggplot(MSEv_coalition_dt, ggplot2::aes(x = id_coalition, y = MSEv)) + ggplot2::labs( - x = "id_combination", - y = bquote(MSE[v] ~ "(combination)"), + x = "id_coalition", + y = bquote(MSE[v] ~ "(coalition)"), title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_explain) ~ - "explicands for each combination") + "explicands for each coalition") ) - MSEv_combination_bar <- - MSEv_combination_source + + MSEv_coalition_bar <- + MSEv_coalition_source + ggplot2::geom_col( width = geom_col_width, position = ggplot2::position_dodge(geom_col_width), @@ -1302,10 +1345,10 @@ make_MSEv_combination_plots <- 
function(MSEv_combination_dt, if (!is.null(tfrac)) { CI_level <- 1 - 2 * (1 - pt(tfrac, n_explain - 1)) - MSEv_combination_bar <- - MSEv_combination_bar + + MSEv_coalition_bar <- + MSEv_coalition_bar + ggplot2::labs(title = bquote(MSE[v] ~ "criterion averaged over the" ~ .(n_explain) ~ - "explicands for each combination with" ~ .(CI_level * 100) * "% CI")) + + "explicands for each coalition with" ~ .(CI_level * 100) * "% CI")) + ggplot2::geom_errorbar( position = ggplot2::position_dodge(geom_col_width), width = 0.25, @@ -1317,16 +1360,16 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, ) } - MSEv_combination_line_point <- - MSEv_combination_source + - ggplot2::aes(x = as.numeric(id_combination)) + - ggplot2::labs(x = "id_combination") + + MSEv_coalition_line_point <- + MSEv_coalition_source + + ggplot2::aes(x = as.numeric(id_coalition)) + + ggplot2::labs(x = "id_coalition") + ggplot2::geom_point(ggplot2::aes(col = Method)) + ggplot2::geom_line(ggplot2::aes(group = Method, col = Method)) return(list( - MSEv_combination_bar = MSEv_combination_bar, - MSEv_combination_line_point = MSEv_combination_line_point + MSEv_coalition_bar = MSEv_coalition_bar, + MSEv_coalition_line_point = MSEv_coalition_line_point )) } @@ -1334,7 +1377,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' #' @description #' Make plots to visualize and compare the estimated Shapley values for a list of -#' [shapr::explain()] objects applied to the same data and model. +#' [shapr::explain()] objects applied to the same data and model. For group-wise Shapley values, +#' the features values plotted are the mean feature values for all features in each group. #' #' @param explanation_list A list of [shapr::explain()] objects applied to the same data and model. #' If the entries in the list is named, then the function use these names. 
Otherwise, it defaults to @@ -1342,6 +1386,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' @param index_explicands Integer vector. Which of the explicands (test observations) to plot. #' E.g. if you have explained 10 observations using [shapr::explain()], you can generate a plot for the #' first 5 observations/explicands and the 10th by setting `index_x_explain = c(1:5, 10)`. +#' The argument `index_explicands_sort` must be `FALSE` to plot the explicands +#' in the order specified in `index_explicands`. #' @param only_these_features String vector. Containing the names of the features which #' are to be included in the bar plots. #' @param plot_phi0 Boolean. If we are to include the \eqn{\phi_0} in the bar plots or not. @@ -1368,6 +1414,11 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' ("`free_x`", "`free_y`")? The user has to change the latter manually depending on the value of `horizontal_bars`. #' @param facet_ncol Integer. The number of columns in the facet grid. Default is `facet_ncol = 2`. #' @param geom_col_width Numeric. Bar width. By default, set to 85% of the [ggplot2::resolution()] of the data. +#' @param include_group_feature_means Logical. Whether to include the average feature value in a group on the +#' y-axis or not. If `FALSE` (default), then no value is shown for the groups. If `TRUE`, then `shapr` includes +#' the mean of the features in each group. +#' @param index_explicands_sort Boolean. If `FALSE` (default), then `shapr` plots the explicands in the order +#' specified in `index_explicands`. If `TRUE`, then `shapr` sorts the indices in increasing order based on their id. #' #' @return A [ggplot2::ggplot()] object. #' @export @@ -1401,7 +1452,7 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' ) #' #' # Specifying the phi_0, i.e.
the expected prediction without any features -#' prediction_zero <- mean(y_train) +#' phi0 <- mean(y_train) #' #' # Independence approach #' explanation_independence <- explain( @@ -1409,8 +1460,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' x_explain = x_explain, #' x_train = x_train, #' approach = "independence", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Empirical approach @@ -1419,8 +1470,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' x_explain = x_explain, #' x_train = x_train, #' approach = "empirical", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Gaussian 1e1 approach @@ -1429,8 +1480,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' x_explain = x_explain, #' x_train = x_train, #' approach = "gaussian", -#' prediction_zero = prediction_zero, -#' n_samples = 1e1 +#' phi0 = phi0, +#' n_MC_samples = 1e1 #' ) #' #' # Gaussian 1e2 approach @@ -1439,8 +1490,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' x_explain = x_explain, #' x_train = x_train, #' approach = "gaussian", -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Combined approach @@ -1449,8 +1500,8 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' x_explain = x_explain, #' x_train = x_train, #' approach = c("gaussian", "ctree", "empirical"), -#' prediction_zero = prediction_zero, -#' n_samples = 1e2 +#' phi0 = phi0, +#' n_MC_samples = 1e2 #' ) #' #' # Create a list of explanations with names @@ -1506,6 +1557,7 @@ make_MSEv_combination_plots <- function(MSEv_combination_dt, #' @author Lars Henry Berge Olsen plot_SV_several_approaches <- function(explanation_list, index_explicands = NULL, + index_explicands_sort = FALSE, only_these_features = NULL, plot_phi0 = FALSE, digits = 4, @@ -1516,7 +1568,8 @@ 
plot_SV_several_approaches <- function(explanation_list, facet_scales = "free", facet_ncol = 2, geom_col_width = 0.85, - brewer_palette = NULL) { + brewer_palette = NULL, + include_group_feature_means = FALSE) { # Setup and checks ---------------------------------------------------------------------------- # Check that ggplot2 is installed if (!requireNamespace("ggplot2", quietly = TRUE)) { @@ -1533,7 +1586,7 @@ plot_SV_several_approaches <- function(explanation_list, if (any(names(explanation_list) == "")) stop("All the entries in `explanation_list` must be named.") # Check that the column names for the Shapley values are the same for all explanations in the `explanation_list` - if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values)))) != 1) { + if (length(unique(lapply(explanation_list, function(explanation) colnames(explanation$shapley_values_est)))) != 1) { stop("The Shapley value feature names are not identical in all objects in the `explanation_list`.") } @@ -1578,10 +1631,17 @@ plot_SV_several_approaches <- function(explanation_list, only_these_features_wo_none = only_these_features_wo_none, index_explicands = index_explicands, horizontal_bars = horizontal_bars, - digits = digits + digits = digits, + include_group_feature_means = include_group_feature_means ) - # Melt `dt_Shapley_values` and merge with `dt_desc_long` to creat data.table ready to be plotted with ggplot2 + # Set the explicands to the same order as they were given + if (!index_explicands_sort) { + dt_Shapley_values[, .id := factor(.id, levels = index_explicands, ordered = TRUE)] + dt_desc_long[, .id := factor(.id, levels = index_explicands, ordered = TRUE)] + } + + # Melt `dt_Shapley_values` and merge with `dt_desc_long` to create data.table ready to be plotted with ggplot2 dt_Shapley_values_long <- create_Shapley_value_figure_dt( dt_Shapley_values = dt_Shapley_values, dt_desc_long = dt_desc_long, @@ -1648,7 +1708,7 @@ update_only_these_features <- 
function(explanation_list, # Update the `only_these_features` parameter vector based on `plot_phi0` or in case it is NULL # Get the common feature names for all explanation objects (including `none`) and one without `none` - feature_names_with_none <- colnames(explanation_list[[1]]$shapley_values) + feature_names_with_none <- colnames(explanation_list[[1]]$shapley_values_est)[-1] feature_names_without_none <- feature_names_with_none[feature_names_with_none != "none"] # Only keep the desired features/columns @@ -1699,7 +1759,7 @@ extract_Shapley_values_dt <- function(explanation_list, lapply( explanation_list, function(explanation) { - data.table::copy(explanation$shapley_values)[, c(".id", ".pred") := list(.I, explanation$pred_explain)] + data.table::copy(explanation$shapley_values_est)[, c(".id", ".pred") := list(.I, explanation$pred_explain)] } ), use.names = TRUE, @@ -1707,10 +1767,7 @@ extract_Shapley_values_dt <- function(explanation_list, ) # Convert to factors - dt_Shapley_values$.method <- factor(dt_Shapley_values$.method, - levels = names(explanation_list), - ordered = TRUE - ) + dt_Shapley_values$.method <- factor(dt_Shapley_values$.method, levels = names(explanation_list), ordered = TRUE) # Set the keys and change the order of the columns data.table::setkeyv(dt_Shapley_values, c(".id", ".method")) @@ -1782,14 +1839,49 @@ create_feature_descriptions_dt <- function(explanation_list, only_these_features_wo_none, index_explicands, horizontal_bars, - digits) { - # Get the explicands - x_explain <- - explanation_list[[1]]$internal$data$x_explain[index_explicands, only_these_features_wo_none, with = FALSE] + digits, + include_group_feature_means) { + # Check if are dealing with group-wise or feature-wise Shapley values + if (explanation_list[[1]]$internal$parameters$is_groupwise) { + # Group-wise Shapley values + + if (include_group_feature_means && any(explanation_list[[1]]$internal$objects$feature_specs$classes != "numeric")) { + 
stop("`include_group_feature_means` cannot be `TRUE` for datasets with non-numerical features.") + } + + # Get the relevant explicands + x_explain <- explanation_list[[1]]$internal$data$x_explain[index_explicands] + + # Check if we are to compute the mean feature value within each group for each explicand + if (include_group_feature_means) { + feature_groups <- explanation_list[[1]]$internal$parameters$group + x_explain <- + x_explain[, lapply(feature_groups, function(cols) rowMeans(.SD[, .SD, .SDcols = cols], na.rm = TRUE))] + + # Extract only the relevant columns + x_explain <- x_explain[, only_these_features_wo_none, with = FALSE] + + # Create the description matrix + desc_mat <- trimws(format(x_explain, digits = digits)) + for (i in seq_len(ncol(desc_mat))) desc_mat[, i] <- paste0(colnames(desc_mat)[i], " = ", desc_mat[, i]) + } else { + # Create the description matrix + desc_mat <- matrix(rep(only_these_features_wo_none, each = nrow(x_explain)), nrow = nrow(x_explain)) + colnames(desc_mat) <- only_these_features_wo_none + } + } else { + # Feature-wise Shapley values + + # Get the relevant explicands + x_explain <- + explanation_list[[1]]$internal$data$x_explain[index_explicands, only_these_features_wo_none, with = FALSE] + + # Create the description matrix + desc_mat <- trimws(format(x_explain, digits = digits)) + for (i in seq_len(ncol(desc_mat))) desc_mat[, i] <- paste0(colnames(desc_mat)[i], " = ", desc_mat[, i]) + } # Converting and melting the explicands - desc_mat <- trimws(format(x_explain, digits = digits)) - for (i in seq_len(ncol(desc_mat))) desc_mat[, i] <- paste0(colnames(desc_mat)[i], " = ", desc_mat[, i]) dt_desc <- data.table::as.data.table(cbind(none = "None", desc_mat)) dt_desc_long <- data.table::melt(dt_desc[, .id := index_explicands], id.vars = ".id", @@ -1800,10 +1892,7 @@ create_feature_descriptions_dt <- function(explanation_list, # Make the description into an ordered factor such that the features in the # bar plots follow the same 
order of features as in the training data. levels <- if (horizontal_bars) rev(unique(dt_desc_long$.description)) else unique(dt_desc_long$.description) - dt_desc_long$.description <- factor(dt_desc_long$.description, - levels = levels, - ordered = TRUE - ) + dt_desc_long$.description <- factor(dt_desc_long$.description, levels = levels, ordered = TRUE) return(dt_desc_long) } diff --git a/R/prepare_next_iteration.R b/R/prepare_next_iteration.R new file mode 100644 index 0000000000000000000000000000000000000000..13bd231bcd1b3728a13af745028921dd2acc2be7 --- /dev/null +++ b/R/prepare_next_iteration.R @@ -0,0 +1,80 @@ +#' Prepares the next iteration of the iterative sampling algorithm +#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +prepare_next_iteration <- function(internal) { + iter <- length(internal$iter_list) + converged <- internal$iter_list[[iter]]$converged + paired_shap_sampling <- internal$parameters$paired_shap_sampling + + + if (converged == FALSE) { + next_iter_list <- list() + + n_shapley_values <- internal$parameters$n_shapley_values + n_coal_next_iter_factor_vec <- internal$parameters$iterative_args$n_coal_next_iter_factor_vec + fixed_n_coalitions_per_iter <- internal$parameters$iterative_args$fixed_n_coalitions_per_iter + max_n_coalitions <- internal$parameters$iterative_args$max_n_coalitions + + + est_remaining_coalitions <- internal$iter_list[[iter]]$est_remaining_coalitions + n_coal_next_iter_factor <- internal$iter_list[[iter]]$n_coal_next_iter_factor + current_n_coalitions <- internal$iter_list[[iter]]$n_coalitions + current_coal_samples <- internal$iter_list[[iter]]$coal_samples + + if (is.null(fixed_n_coalitions_per_iter)) { + proposal_next_n_coalitions <- current_n_coalitions + ceiling(est_remaining_coalitions * n_coal_next_iter_factor) + } else { + proposal_next_n_coalitions <- current_n_coalitions + fixed_n_coalitions_per_iter + } + + # Thresholding if max_n_coalitions is reached + proposal_next_n_coalitions 
<- min( + max_n_coalitions, + proposal_next_n_coalitions + ) + + if (paired_shap_sampling) { + proposal_next_n_coalitions <- ceiling(proposal_next_n_coalitions * 0.5) * 2 + } + + + if ((proposal_next_n_coalitions) >= 2^n_shapley_values) { + # Use all coalitions in the last iteration as the estimated number of samples is more than what remains + next_iter_list$exact <- TRUE + next_iter_list$n_coalitions <- 2^n_shapley_values + next_iter_list$compute_sd <- FALSE + } else { + # Sample more keeping the current samples + next_iter_list$exact <- FALSE + next_iter_list$n_coalitions <- proposal_next_n_coalitions + next_iter_list$compute_sd <- TRUE + } + + if (!is.null(n_coal_next_iter_factor_vec[1])) { + next_iter_list$n_coal_next_iter_factor <- ifelse( + length(n_coal_next_iter_factor_vec) >= iter, + n_coal_next_iter_factor_vec[iter], + n_coal_next_iter_factor_vec[length(n_coal_next_iter_factor_vec)] + ) + } else { + next_iter_list$n_coal_next_iter_factor <- NULL + } + + next_iter_list$new_n_coalitions <- next_iter_list$n_coalitions - current_n_coalitions + + next_iter_list$n_batches <- set_n_batches(next_iter_list$new_n_coalitions, internal) + + + next_iter_list$prev_coal_samples <- current_coal_samples + } else { + next_iter_list <- list() + } + + internal$iter_list[[iter + 1]] <- next_iter_list + + + return(internal) +} diff --git a/R/print.R b/R/print.R index 4977a9974995b5c3d394f0e9b67895d6070d5cb9..573cc36e6f739c0c3648f2034f04eed7eab72c8a 100644 --- a/R/print.R +++ b/R/print.R @@ -1,4 +1,8 @@ #' @export print.shapr <- function(x, digits = 4, ...) 
{ - print(x$shapley_values, digits = digits) + shap <- copy(x$shapley_values_est) + shap_names <- x$internal$parameters$shap_names + cols <- c("none", shap_names) + shap[, (cols) := lapply(.SD, round, digits = digits + 2), .SDcols = cols] + print(shap, digits = digits) } diff --git a/R/print_iter.R b/R/print_iter.R new file mode 100644 index 0000000000000000000000000000000000000000..174eea7abf7915effb533103ff2b3963865d9ab3 --- /dev/null +++ b/R/print_iter.R @@ -0,0 +1,109 @@ +#' Prints iterative information +#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +print_iter <- function(internal) { + verbose <- internal$parameters$verbose + iter <- length(internal$iter_list) - 1 # This function is called after the preparation of the next iteration + + converged <- internal$iter_list[[iter]]$converged + converged_exact <- internal$iter_list[[iter]]$converged_exact + converged_sd <- internal$iter_list[[iter]]$converged_sd + converged_max_iter <- internal$iter_list[[iter]]$converged_max_iter + converged_max_n_coalitions <- internal$iter_list[[iter]]$converged_max_n_coalitions + overall_conv_measure <- internal$iter_list[[iter]]$overall_conv_measure + n_coal_next_iter_factor <- internal$iter_list[[iter]]$n_coal_next_iter_factor + + saving_path <- internal$parameters$output_args$saving_path + convergence_tol <- internal$parameters$iterative_args$convergence_tol + testing <- internal$parameters$testing + + if ("convergence" %in% verbose) { + convergence_tol <- internal$parameters$iterative_args$convergence_tol + + current_n_coalitions <- internal$iter_list[[iter]]$n_coalitions + est_remaining_coalitions <- internal$iter_list[[iter]]$est_remaining_coalitions + est_required_coalitions <- internal$iter_list[[iter]]$est_required_coalitions + + next_n_coalitions <- internal$iter_list[[iter + 1]]$n_coalitions + next_new_n_coalitions <- internal$iter_list[[iter + 1]]$new_n_coalitions + + cli::cli_h3("Convergence info") + + if (isFALSE(converged)) { + 
msg <- "Not converged after {current_n_coalitions} coalitions:\n" + + if (!is.null(convergence_tol)) { + conv_nice <- signif(overall_conv_measure, 2) + tol_nice <- format(signif(convergence_tol, 2), scientific = FALSE) + n_coal_next_iter_factor_nice <- format(signif(n_coal_next_iter_factor * 100, 2), scientific = FALSE) + msg <- paste0( + msg, + "Current convergence measure: {conv_nice} [needs {tol_nice}]\n", + "Estimated remaining coalitions: {est_remaining_coalitions}\n", + "(Concervatively) adding {n_coal_next_iter_factor_nice}% of that ({next_new_n_coalitions} coalitions) ", + "in the next iteration." + ) + } + cli::cli_alert_info(msg) + } else { + msg <- "Converged after {current_n_coalitions} coalitions:\n" + if (isTRUE(converged_exact)) { + msg <- paste0( + msg, + "All ({current_n_coalitions}) coalitions used.\n" + ) + } + if (isTRUE(converged_sd)) { + msg <- paste0( + msg, + "Convergence tolerance reached!\n" + ) + } + if (isTRUE(converged_max_iter)) { + msg <- paste0( + msg, + "Maximum number of iterations reached!\n" + ) + } + if (isTRUE(converged_max_n_coalitions)) { + msg <- paste0( + msg, + "Maximum number of coalitions reached!\n" + ) + } + cli::cli_alert_success(msg) + } + } + + if ("shapley" %in% verbose) { + n_explain <- internal$parameters$n_explain + + dt_shapley_est <- internal$iter_list[[iter]]$dt_shapley_est[, -1] + dt_shapley_sd <- internal$iter_list[[iter]]$dt_shapley_sd[, -1] + + # Printing the current Shapley values + matrix1 <- format(round(dt_shapley_est, 3), nsmall = 2, justify = "right") + matrix2 <- format(round(dt_shapley_sd, 2), nsmall = 2, justify = "right") + + if (isTRUE(converged)) { + msg <- "Final " + } else { + msg <- "Current " + } + + if (converged_exact) { + msg <- paste0(msg, "estimated Shapley values") + print_dt <- as.data.table(matrix1) + } else { + msg <- paste0(msg, "estimated Shapley values (sd)") + print_dt <- as.data.table(matrix(paste(matrix1, " (", matrix2, ") ", sep = ""), nrow = n_explain)) + } + + 
cli::cli_h3(msg) + names(print_dt) <- names(dt_shapley_est) + print(print_dt) + } +} diff --git a/R/save_results.R b/R/save_results.R new file mode 100644 index 0000000000000000000000000000000000000000..cef0e97b960a8e029fc22c5a28db2c3db1ba268a --- /dev/null +++ b/R/save_results.R @@ -0,0 +1,22 @@ +#' Saves the intermediate results to disk +#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +save_results <- function(internal) { + saving_path <- internal$parameters$output_args$saving_path + + # Modify name for the new file + filename <- basename(saving_path) + dirname <- dirname(saving_path) + filename_copy <- paste0("new_", filename) + saving_path_copy <- file.path(dirname, filename_copy) + + # Save the results to a new location, then delete old and rename for safe code interruption + + # Saving parameters and iter_list + saveRDS(internal[c("parameters", "iter_list")], saving_path_copy) + if (file.exists(saving_path)) file.remove(saving_path) + file.rename(saving_path_copy, saving_path) +} diff --git a/R/setup.R b/R/setup.R index 5f2f2b5489792f9a3d9f40bc65c9d8303ce48bb8..904c7cdec19138675b7ab025f0393f4d3db7056c 100644 --- a/R/setup.R +++ b/R/setup.R @@ -16,20 +16,24 @@ #' @param is_python Logical. Indicates whether the function is called from the Python wrapper. Default is FALSE which is #' never changed when calling the function via `explain()` in R. The parameter is later used to disallow #' running the AICc-versions of the empirical as that requires data based optimization. +#' @param testing Logical. +#' Only use to remove random components like timing from the object output when comparing output with testthat. +#' Defaults to `FALSE`. +#' @param init_time POSIXct object. +#' The time when the `explain()` function was called, as outputted by `Sys.time()`. +#' Used to calculate the time it took to run the full `explain` call.
#' @export setup <- function(x_train, x_explain, approach, - prediction_zero, + paired_shap_sampling = TRUE, + phi0, output_size = 1, - n_combinations, + max_n_coalitions, group, - n_samples, - n_batches, + n_MC_samples, seed, - keep_samp_for_vS, feature_specs, - MSEv_uniform_comb_weights = TRUE, type = "normal", horizon = NULL, y = NULL, @@ -39,22 +43,45 @@ setup <- function(x_train, explain_y_lags = NULL, explain_xreg_lags = NULL, group_lags = NULL, - timing, verbose, + iterative = NULL, + iterative_args = list(), + kernelSHAP_reweighting = "none", is_python = FALSE, + testing = FALSE, + init_time = NULL, + prev_shapr_object = NULL, + asymmetric = FALSE, + causal_ordering = NULL, + confounding = NULL, + output_args = list(), + extra_computation_args = list(), ...) { internal <- list() + # Use parameters and iter_list from a previous shapr object to continue the estimation + if (is.null(prev_shapr_object)) { + prev_iter_list <- NULL + } else { + prev_internal <- get_prev_internal(prev_shapr_object) + + prev_iter_list <- prev_internal$iter_list + + # Overwrite the input arguments set in explain() with those from prev_shapr_object + # except model, x_explain, x_train, max_n_coalitions, iterative_args, seed + list2env(prev_internal$parameters) + } + + internal$parameters <- get_parameters( approach = approach, - prediction_zero = prediction_zero, + paired_shap_sampling = paired_shap_sampling, + phi0 = phi0, output_size = output_size, - n_combinations = n_combinations, + max_n_coalitions = max_n_coalitions, group = group, - n_samples = n_samples, - n_batches = n_batches, + n_MC_samples = n_MC_samples, seed = seed, - keep_samp_for_vS = keep_samp_for_vS, type = type, horizon = horizon, train_idx = train_idx, @@ -62,10 +89,17 @@ setup <- function(x_train, explain_y_lags = explain_y_lags, explain_xreg_lags = explain_xreg_lags, group_lags = group_lags, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights, - timing = timing, verbose = verbose, +
iterative = iterative, + iterative_args = iterative_args, + kernelSHAP_reweighting = kernelSHAP_reweighting, is_python = is_python, + testing = testing, + asymmetric = asymmetric, + causal_ordering = causal_ordering, + confounding = confounding, + output_args = output_args, + extra_computation_args = extra_computation_args, ... ) @@ -77,9 +111,9 @@ setup <- function(x_train, colnames(internal$parameters$output_labels) <- c("explain_idx", "horizon") internal$parameters$explain_idx <- explain_idx internal$parameters$explain_lags <- list(y = explain_y_lags, xreg = explain_xreg_lags) + internal$parameters$group_lags <- group_lags # TODO: Consider handling this parameter update somewhere else (like in get_extra_parameters?) - if (group_lags) internal$parameters$group <- internal$data$group } else { internal$data <- get_data(x_train, x_explain) } @@ -88,152 +122,282 @@ setup <- function(x_train, check_data(internal) - internal <- get_extra_parameters(internal) # This includes both extra parameters and other objects + internal <- get_extra_parameters(internal, type) # This includes both extra parameters and other objects + + internal <- check_and_set_parameters(internal, type) + + internal <- set_iterative_parameters(internal, prev_iter_list) - internal <- check_and_set_parameters(internal) + internal$timing_list <- list( + init_time = init_time, + setup = Sys.time() + ) return(internal) } -#' @keywords internal -check_and_set_parameters <- function(internal) { - # Check groups - feature_names <- internal$parameters$feature_names - group <- internal$parameters$group - n_combinations <- internal$parameters$n_combinations - n_features <- internal$parameters$n_features - n_groups <- internal$parameters$n_groups - is_groupwise <- internal$parameters$is_groupwise - exact <- internal$parameters$exact - - if (!is.null(group)) check_groups(feature_names, group) +get_prev_internal <- function(prev_shapr_object, + exclude_parameters = c("max_n_coalitions", "iterative_args", 
"seed")) { + cl <- class(prev_shapr_object)[1] - if (exact) { - internal$parameters$used_n_combinations <- if (is_groupwise) 2^n_groups else 2^n_features + if (cl == "character") { + internal <- readRDS(file = prev_shapr_object) # Already contains only "parameters" and "iter_list" + } else if (cl == "shapr") { + internal <- prev_shapr_object$internal[c("parameters", "iter_list")] } else { - internal$parameters$used_n_combinations <- - if (is_groupwise) min(2^n_groups, n_combinations) else min(2^n_features, n_combinations) - check_n_combinations(internal) + stop("Invalid `shapr_object` passed to explain(). See ?explain for details.") } - # Check approach - check_approach(internal) - - # Setting default value for n_batches (when NULL) - internal <- set_defaults(internal) + if (length(exclude_parameters) > 0) { + internal$parameters[exclude_parameters] <- NULL + } - # Checking n_batches vs n_combinations etc - check_n_batches(internal) + iter <- length(internal$iter_list) + internal$iter_list[[iter]]$converged <- FALSE # hard setting the convergence parameter - # Check regression if we are doing regression - if (internal$parameters$regression) internal <- regression.check(internal) return(internal) } + #' @keywords internal -#' @author Lars Henry Berge Olsen -regression.check <- function(internal) { - # Check that the model outputs one-dimensional predictions - if (internal$parameters$output_size != 1) { - stop("`regression_separate` and `regression_surrogate` only support models with one-dimensional output") +get_parameters <- function(approach, + paired_shap_sampling, + phi0, + output_size = 1, + max_n_coalitions, + group, + n_MC_samples, + seed, + type, + horizon, + train_idx, + explain_idx, + explain_y_lags, + explain_xreg_lags, + group_lags = NULL, + verbose = "basic", + iterative = FALSE, + iterative_args = list(), + kernelSHAP_reweighting = "none", + asymmetric, + causal_ordering, + confounding, + is_python, + output_args = list(), + extra_computation_args = 
list(), + testing = FALSE, + ...) { + # Check input type for approach + + # approach is checked more comprehensively later + if (!is.logical(paired_shap_sampling) && length(paired_shap_sampling) == 1) { + stop("`paired_shap_sampling` must be a single logical.") } - # Check that we are NOT explaining a forecast model - if (internal$parameters$type == "forecast") { - stop("`regression_separate` and `regression_surrogate` does not support `forecast`.") + if (!is.logical(iterative) && length(iterative) == 1) { + stop("`iterative` must be a single logical.") + } + if (!is.list(iterative_args)) { + stop("`iterative_args` must be a list.") + } + if (!is.list(output_args)) { + stop("`output_args` must be a list.") + } + if (!is.list(extra_computation_args)) { + stop("`extra_computation_args` must be a list.") } - # Check that we are not to keep the Monte Carlo samples - if (internal$parameters$keep_samp_for_vS) { - stop(paste( - "`keep_samp_for_vS` must be `FALSE` for the `regression_separate` and `regression_surrogate`", - "approaches as there are no Monte Carlo samples to keep for these approaches." 
- )) + + + # max_n_coalitions + if (!is.null(max_n_coalitions) && + !(is.wholenumber(max_n_coalitions) && + length(max_n_coalitions) == 1 && + !is.na(max_n_coalitions) && + max_n_coalitions > 0)) { + stop("`max_n_coalitions` must be NULL or a single positive integer.") + } + + # group (checked more thoroughly later) + if (!is.null(group) && + !is.list(group)) { + stop("`group` must be NULL or a list") } - # Remove n_samples if we are doing regression, as we are not doing MC sampling - internal$parameters$n_samples <- NULL + # n_MC_samples + if (!(is.wholenumber(n_MC_samples) && + length(n_MC_samples) == 1 && + !is.na(n_MC_samples) && + n_MC_samples > 0)) { + stop("`n_MC_samples` must be a single positive integer.") + } - return(internal) -} -#' @keywords internal -check_n_combinations <- function(internal) { - is_groupwise <- internal$parameters$is_groupwise - n_combinations <- internal$parameters$n_combinations - n_features <- internal$parameters$n_features - n_groups <- internal$parameters$n_groups + # type + if (!(type %in% c("normal", "forecast"))) { + stop("`type` must be either `normal` or `forecast`.\n") + } - type <- internal$parameters$type + # verbose + check_verbose(verbose) + if (!is.null(verbose) && + (!is.character(verbose) || !(all(verbose %in% c("basic", "progress", "convergence", "shapley", "vS_details")))) + ) { + stop( + paste0( + "`verbose` must be NULL or a string (vector) containing one or more of the strings ", + "`basic`, `progress`, `convergence`, `shapley`, `vS_details`.\n" + ) + ) + } + # parameters only used for type "forecast" if (type == "forecast") { - horizon <- internal$parameters$horizon - explain_y_lags <- internal$parameters$explain_lags$y - explain_xreg_lags <- internal$parameters$explain_lags$xreg - xreg <- internal$data$xreg + if (!(is.wholenumber(horizon) && all(horizon > 0))) { + stop("`horizon` must be a vector (or scalar) of positive integers.\n") + } - if (!is_groupwise) { - if (n_combinations <= n_features) { - 
stop(paste0( - "`n_combinations` (", n_combinations, ") has to be greater than the number of components to decompose ", - " the forecast onto:\n", - "`horizon` (", horizon, ") + `explain_y_lags` (", explain_y_lags, ") ", - "+ sum(`explain_xreg_lags`) (", sum(explain_xreg_lags), ").\n" - )) - } - } else { - if (n_combinations <= n_groups) { - stop(paste0( - "`n_combinations` (", n_combinations, ") has to be greater than the number of components to decompose ", - "the forecast onto:\n", - "ncol(`xreg`) (", ncol(`xreg`), ") + 1" - )) - } + if (any(horizon != output_size)) { + stop(paste0("`horizon` must match the output size of the model (", paste0(output_size, collapse = ", "), ").\n")) } - } else { - if (!is_groupwise) { - if (n_combinations <= n_features) stop("`n_combinations` has to be greater than the number of features.") - } else { - if (n_combinations <= n_groups) stop("`n_combinations` has to be greater than the number of groups.") + + if (!(length(train_idx) > 1 && is.wholenumber(train_idx) && all(train_idx > 0) && all(is.finite(train_idx)))) { + stop("`train_idx` must be a vector of positive finite integers and length > 1.\n") } - } -} + if (!(is.wholenumber(explain_idx) && all(explain_idx > 0) && all(is.finite(explain_idx)))) { + stop("`explain_idx` must be a vector of positive finite integers.\n") + } + if (!(is.wholenumber(explain_y_lags) && all(explain_y_lags >= 0) && all(is.finite(explain_y_lags)))) { + stop("`explain_y_lags` must be a vector of positive finite integers.\n") + } -#' @keywords internal -check_n_batches <- function(internal) { - n_batches <- internal$parameters$n_batches - n_features <- internal$parameters$n_features - n_combinations <- internal$parameters$n_combinations - is_groupwise <- internal$parameters$is_groupwise - n_groups <- internal$parameters$n_groups - n_unique_approaches <- internal$parameters$n_unique_approaches + if (!(is.wholenumber(explain_xreg_lags) && all(explain_xreg_lags >= 0) && all(is.finite(explain_xreg_lags)))) 
{ + stop("`explain_xreg_lags` must be a vector of positive finite integers.\n") + } - if (!is_groupwise) { - actual_n_combinations <- ifelse(is.null(n_combinations), 2^n_features, n_combinations) - } else { - actual_n_combinations <- ifelse(is.null(n_combinations), 2^n_groups, n_combinations) + if (!(is.logical(group_lags) && length(group_lags) == 1)) { + stop("`group_lags` must be a single logical.\n") + } } - if (n_batches >= actual_n_combinations) { + + # Parameter used in asymmetric and causal Shapley values (more in-depth checks later) + if (!is.logical(asymmetric) || length(asymmetric) != 1) stop("`asymmetric` must be a single logical.\n") + if (!is.null(confounding) && !is.logical(confounding)) stop("`confounding` must be a logical (vector).\n") + if (!is.null(causal_ordering) && !is.list(causal_ordering)) stop("`causal_ordering` must be a list.\n") + + #### Tests combining more than one parameter #### + # phi0 vs output_size + if (!all((is.numeric(phi0)) && + all(length(phi0) == output_size) && + all(!is.na(phi0)))) { stop(paste0( - "`n_batches` (", n_batches, ") must be smaller than the number of feature combinations/`n_combinations` (", - actual_n_combinations, ")" + "`phi0` (", paste0(phi0, collapse = ", "), + ") must be numeric and match the output size of the model (", + paste0(output_size, collapse = ", "), ")." )) } - if (n_batches < n_unique_approaches) { - stop(paste0( - "`n_batches` (", n_batches, ") must be larger than the number of unique approaches in `approach` (", - n_unique_approaches, ")." 
- )) + # type + if (!(length(kernelSHAP_reweighting) == 1 && kernelSHAP_reweighting %in% + c("none", "on_N", "on_coal_size", "on_all", "on_N_sum", "on_all_cond", "on_all_cond_paired", "comb"))) { + stop( + "`kernelSHAP_reweighting` must be one of `none`, `on_N`, `on_coal_size`, `on_N_sum`, ", + "`on_all`, `on_all_cond`, `on_all_cond_paired` or `comb`.\n" + ) + } + + + # Getting basic input parameters + parameters <- list( + approach = approach, + paired_shap_sampling = paired_shap_sampling, + phi0 = phi0, + max_n_coalitions = max_n_coalitions, + group = group, + n_MC_samples = n_MC_samples, + seed = seed, + is_python = is_python, + output_size = output_size, + type = type, + horizon = horizon, + group_lags = group_lags, + verbose = verbose, + kernelSHAP_reweighting = kernelSHAP_reweighting, + iterative = iterative, + iterative_args = iterative_args, + output_args = output_args, + extra_computation_args = extra_computation_args, + asymmetric = asymmetric, + causal_ordering = causal_ordering, + confounding = confounding, + testing = testing + ) + + # Getting additional parameters from ... + parameters <- append(parameters, list(...)) + + # Set boolean to represent if a regression approach is used (any in case of several approaches) + parameters$regression <- any(grepl("regression", parameters$approach)) + + return(parameters) +} + +#' Function that checks the verbose parameter +#' +#' @inheritParams explain +#' +#' @return The function does not return anything. 
+#' +#' @keywords internal +#' @author Lars Henry Berge Olsen, Martin Jullum +check_verbose <- function(verbose) { + if (!is.null(verbose) && + (!is.character(verbose) || !(all(verbose %in% c("basic", "progress", "convergence", "shapley", "vS_details")))) + ) { + stop( + paste0( + "`verbose` must be NULL or a string (vector) containing one or more of the strings ", + "`basic`, `progress`, `convergence`, `shapley`, `vS_details`.\n" + ) + ) } } +#' @keywords internal +get_data <- function(x_train, x_explain) { + # Check data object type + stop_message <- "" + if (!is.matrix(x_train) && !is.data.frame(x_train)) { + stop_message <- paste0(stop_message, "x_train should be a matrix or a data.frame/data.table.\n") + } + if (!is.matrix(x_explain) && !is.data.frame(x_explain)) { + stop_message <- paste0(stop_message, "x_explain should be a matrix or a data.frame/data.table.\n") + } + if (stop_message != "") { + stop(stop_message) + } + + # Check column names + if (all(is.null(colnames(x_train)))) { + stop_message <- paste0(stop_message, "x_train misses column names.\n") + } + if (all(is.null(colnames(x_explain)))) { + stop_message <- paste0(stop_message, "x_explain misses column names.\n") + } + if (stop_message != "") { + stop(stop_message) + } + data <- list( + x_train = data.table::as.data.table(x_train), + x_explain = data.table::as.data.table(x_explain) + ) +} #' @keywords internal @@ -292,27 +456,6 @@ check_data <- function(internal) { compare_feature_specs(x_train_feature_specs, x_explain_feature_specs, "x_train", "x_explain") } -compare_vecs <- function(vec1, vec2, vec_type, name1, name2) { - if (!identical(vec1, vec2)) { - if (is.null(names(vec1))) { - text_vec1 <- paste(vec1, collapse = ", ") - } else { - text_vec1 <- paste(names(vec1), vec1, sep = ": ", collapse = ", ") - } - if (is.null(names(vec2))) { - text_vec2 <- paste(vec2, collapse = ", ") - } else { - text_vec2 <- paste(names(vec2), vec1, sep = ": ", collapse = ", ") - } - - stop(paste0( - "Feature ", 
vec_type, " are not identical for ", name1, " and ", name2, ".\n", - name1, " provided: ", text_vec1, ",\n", - name2, " provided: ", text_vec2, ".\n" - )) - } -} - compare_feature_specs <- function(spec1, spec2, name1 = "model", name2 = "x_train", sort_labels = FALSE) { if (sort_labels) { compare_vecs(sort(spec1$labels), sort(spec2$labels), "names", name1, name2) @@ -334,10 +477,19 @@ compare_feature_specs <- function(spec1, spec2, name1 = "model", name2 = "x_trai } } - #' This includes both extra parameters and other objects #' @keywords internal -get_extra_parameters <- function(internal) { +get_extra_parameters <- function(internal, type) { + if (type == "forecast") { + if (internal$parameters$group_lags) { + internal$parameters$group <- internal$data$group + } + internal$parameters$horizon_features <- lapply( + internal$data$horizon_group, + function(x) as.character(unlist(internal$data$group[x])) + ) + } + # get number of features and observations to explain internal$parameters$n_features <- ncol(internal$data$x_explain) internal$parameters$n_explain <- nrow(internal$data$x_explain) @@ -361,18 +513,37 @@ get_extra_parameters <- function(internal) { "\nSuccess with message:\n Group names not provided. Assigning them the default names 'group1', 'group2', 'group3' etc." 
) - names(internal$parameters$group) <- paste0("group", seq_along(group)) + names(group) <- paste0("group", seq_along(group)) } # Make group list with numeric feature indicators - internal$objects$group_num <- lapply(group, FUN = function(x) { + internal$objects$coal_feature_list <- lapply(group, FUN = function(x) { match(x, internal$parameters$feature_names) }) internal$parameters$n_groups <- length(group) + internal$parameters$group_names <- names(group) + internal$parameters$group <- group + internal$parameters$n_shapley_values <- internal$parameters$n_groups + + if (type == "forecast") { + if (internal$parameters$group_lags) { + internal$parameters$horizon_group <- internal$data$horizon_group + internal$parameters$shap_names <- internal$data$shap_names + } else { + internal$parameters$shap_names <- internal$parameters$group_names + } + } else { + # For normal explain + internal$parameters$shap_names <- internal$parameters$group_names + } } else { - internal$objects$group_num <- NULL + internal$objects$coal_feature_list <- as.list(seq_len(internal$parameters$n_features)) + internal$parameters$n_groups <- NULL + internal$parameters$group_names <- NULL + internal$parameters$shap_names <- internal$parameters$feature_names + internal$parameters$n_shapley_values <- internal$parameters$n_features } # Get the number of unique approaches @@ -382,226 +553,823 @@ get_extra_parameters <- function(internal) { return(internal) } +#' Fetches feature information from a given data set +#' +#' @param x matrix, data.frame or data.table The data to extract feature information from. +#' +#' @details This function is used to extract the feature information to be checked against the corresponding +#' information extracted from the model and other data sets. 
The function is called from internally +#' +#' @return A list with the following elements: +#' \describe{ +#' \item{labels}{character vector with the feature names to compute Shapley values for} +#' \item{classes}{a named character vector with the labels as names and the class types as elements} +#' \item{factor_levels}{a named list with the labels as names and character vectors with the factor levels as elements +#' (NULL if the feature is not a factor)} +#' } +#' @author Martin Jullum +#' #' @keywords internal -get_parameters <- function(approach, prediction_zero, output_size = 1, n_combinations, group, n_samples, - n_batches, seed, keep_samp_for_vS, type, horizon, train_idx, explain_idx, explain_y_lags, - explain_xreg_lags, group_lags = NULL, MSEv_uniform_comb_weights, timing, verbose, - is_python, ...) { - # Check input type for approach +#' @export +#' +#' @examples +#' # Load example data +#' data("airquality") +#' airquality <- airquality[complete.cases(airquality), ] +#' # Split data into test- and training data +#' x_train <- head(airquality, -3) +#' x_explain <- tail(airquality, 3) +#' # Split data into test- and training data +#' x_train <- data.table::as.data.table(head(airquality)) +#' x_train[, Temp := as.factor(Temp)] +#' get_data_specs(x_train) +get_data_specs <- function(x) { + feature_specs <- list() + feature_specs$labels <- names(x) + feature_specs$classes <- unlist(lapply(x, class)) + feature_specs$factor_levels <- lapply(x, levels) + + # Defining all integer values as numeric + feature_specs$classes[feature_specs$classes == "integer"] <- "numeric" + + return(feature_specs) +} + - # approach is checked more comprehensively later - # n_combinations - if (!is.null(n_combinations) && - !(is.wholenumber(n_combinations) && - length(n_combinations) == 1 && - !is.na(n_combinations) && - n_combinations > 0)) { - stop("`n_combinations` must be NULL or a single positive integer.") + +#' @keywords internal +check_and_set_parameters <- function(internal, 
type) { + # Check groups + feature_names <- internal$parameters$feature_names + if (type == "forecast") { + group <- internal$parameters$group[internal$parameters$horizon_group[internal$parameters$horizon][[1]]] + } else { + group <- internal$parameters$group } - # group (checked more thoroughly later) - if (!is.null(group) && - !is.list(group)) { - stop("`group` must be NULL or a list") + # Check group + if (!is.null(group)) check_groups(feature_names, group) + + # Check approach + check_approach(internal) + + # Check the arguments related to asymmetric and causal Shapley + # Check the causal_ordering, which must happen before checking the causal sampling + internal <- check_and_set_causal_ordering(internal) + if (!is.null(internal$parameters$confounding)) internal <- check_and_set_confounding(internal) + + # Check the causal sampling + internal <- check_and_set_causal_sampling(internal) + if (internal$parameters$asymmetric) internal <- check_and_set_asymmetric(internal) + + # Adjust max_n_coalitions + internal$parameters$max_n_coalitions <- adjust_max_n_coalitions(internal) + + check_max_n_coalitions_fc(internal) + + internal <- set_output_parameters(internal) + + internal <- check_and_set_iterative(internal) # sets the iterative parameter if it is NULL (default) + + # Set if we are to do exact Shapley value computations or not + internal <- set_exact(internal) + + internal <- set_extra_estimation_params(internal) + + # Give warnings to the user about long computation times + check_computability(internal) + + # Check regression if we are doing regression + if (internal$parameters$regression) internal <- check_regression(internal) + + return(internal) +} + + +#' @keywords internal +#' @author Lars Henry Berge Olsen +check_and_set_causal_ordering <- function(internal) { + # Extract the needed variables/objects from the internal list + n_shapley_values <- internal$parameters$n_shapley_values + causal_ordering <- internal$parameters$causal_ordering + is_groupwise <- 
internal$parameters$is_groupwise + feat_group_txt <- ifelse(is_groupwise, "group", "feature") + group <- internal$parameters$group + feature_names <- internal$parameters$feature_names + group_names <- internal$parameters$group_names + + # Get the labels of the features or groups, and the number of them + labels_now <- if (is_groupwise) group_names else feature_names + + # If `causal_ordering` is NULL, then convert it to a list with a single component containing all features/groups + if (is.null(causal_ordering)) causal_ordering <- list(seq(n_shapley_values)) + + # Ensure that causal_ordering represents the causal ordering using the feature/group index representation + if (is.character(unlist(causal_ordering))) { + causal_ordering <- convert_feature_name_to_idx(causal_ordering, labels_now, feat_group_txt) + } + if (!is.numeric(unlist(causal_ordering))) { + stop(paste0( + "`causal_ordering` must be a list containg either only integers representing the ", feat_group_txt, + " indices or the ", feat_group_txt, " names as strings. See the documentation for more details.\n" + )) } - # n_samples - if (!(is.wholenumber(n_samples) && - length(n_samples) == 1 && - !is.na(n_samples) && - n_samples > 0)) { - stop("`n_samples` must be a single positive integer.") + # Ensure that causal_ordering_names represents the causal ordering using the feature name representation + causal_ordering_names <- relist(labels_now[unlist(causal_ordering)], causal_ordering) + + # Check that the we have n_features elements and that they are 1 through n_features (i.e., no duplicates). + causal_ordering_vec_sort <- sort(unlist(causal_ordering)) + if (length(causal_ordering_vec_sort) != n_shapley_values || any(causal_ordering_vec_sort != seq(n_shapley_values))) { + stop(paste0( + "`causal_ordering` is incomplete/incorrect. 
It must contain all ", + feat_group_txt, " names or indices exactly once.\n" + )) } - # n_batches - if (!is.null(n_batches) && - !(is.wholenumber(n_batches) && - length(n_batches) == 1 && - !is.na(n_batches) && - n_batches > 0)) { - stop("`n_batches` must be NULL or a single positive integer.") + + # For groups we need to convert from group level to feature level + if (is_groupwise) { + group_num <- unname(lapply(group, function(x) match(x, feature_names))) + causal_ordering_features <- lapply(causal_ordering, function(component_i) unlist(group_num[component_i])) + causal_ordering_features_names <- relist(feature_names[unlist(causal_ordering_features)], causal_ordering_features) + internal$parameters$causal_ordering_features <- causal_ordering_features + internal$parameters$causal_ordering_features_names <- causal_ordering_features_names } - # seed is already set, so we know it works - # keep_samp_for_vS - if (!(is.logical(timing) && - length(timing) == 1)) { - stop("`timing` must be single logical.") + # Update the parameters in the internal list + internal$parameters$causal_ordering <- causal_ordering + internal$parameters$causal_ordering_names <- causal_ordering_names + internal$parameters$causal_ordering_names_string <- + paste0("{", paste(sapply(causal_ordering_names, paste, collapse = ", "), collapse = "}, {"), "}") + + return(internal) +} + + +#' @keywords internal +#' @author Lars Henry Berge Olsen +check_and_set_confounding <- function(internal) { + causal_ordering <- internal$parameters$causal_ordering + causal_ordering_names <- internal$parameters$causal_ordering_names + confounding <- internal$parameters$confounding + + # Check that confounding is either specified globally or locally + if (length(confounding) > 1 && length(confounding) != length(causal_ordering)) { + stop(paste0( + "`confounding` must either be a single logical or a vector of logicals of the same length as ", + "the number of components in `causal_ordering` (", length(causal_ordering), 
").\n" + )) } - # keep_samp_for_vS - if (!(is.logical(keep_samp_for_vS) && - length(keep_samp_for_vS) == 1)) { - stop("`keep_samp_for_vS` must be single logical.") + # Replicate the global confounding value across all components + if (length(confounding) == 1) confounding <- rep(confounding, length(causal_ordering)) + + # Update the parameters in the internal list + internal$parameters$confounding <- confounding + + # String with information about which components that are subject to confounding (used by cli) + if (all(!confounding)) { + internal$parameters$confounding_string <- "No component with confounding" + } else { + internal$parameters$confounding_string <- + paste0("{", paste(sapply(causal_ordering_names[confounding], paste, collapse = ", "), collapse = "}, {"), "}") } - # type - if (!(type %in% c("normal", "forecast"))) { - stop("`type` must be either `normal` or `forecast`.\n") + return(internal) +} + + +#' @keywords internal +#' @author Lars Henry Berge Olsen +check_and_set_causal_sampling <- function(internal) { + confounding <- internal$parameters$confounding + causal_ordering <- internal$parameters$causal_ordering + + # The variable `causal_sampling` represents if we are to use the causal step-wise sampling procedure. We only want to + # do that when confounding is specified, and we have a causal ordering that contains more than one component or + # if we have a single component where the features are subject to confounding. We must use `all` to support + # `confounding` being a vector, but then `length(causal_ordering) > 1`, so `causal` will be TRUE no matter what + # `confounding` vector we have. 
+ internal$parameters$causal_sampling <- !is.null(confounding) && (length(causal_ordering) > 1 || all(confounding)) + + # For the causal/step-wise sampling procedure, we do not support multiple approaches and regression is inapplicable + if (internal$parameters$causal_sampling) { + if (internal$parameters$regression) stop("Causal Shapley values is not applicable for regression approaches.\n") + if (internal$parameters$n_approaches > 1) stop("Causal Shapley values is not applicable for combined approaches.\n") } - # verbose - if (!is.numeric(verbose) || !(verbose %in% c(0, 1, 2))) { - stop("`verbose` must be either `0` (no verbosity), `1` (low verbosity), or `2` (high verbosity).") + return(internal) +} + + +#' @keywords internal +#' @author Lars Henry Berge Olsen +check_and_set_asymmetric <- function(internal) { + asymmetric <- internal$parameters$asymmetric + # exact <- internal$parameters$exact + causal_ordering <- internal$parameters$causal_ordering + max_n_coalitions <- internal$parameters$max_n_coalitions + paired_shap_sampling <- internal$parameters$paired_shap_sampling + + # Check that we are not doing paired sampling + if (paired_shap_sampling) { + stop(paste0( + "Set `paired_shap_sampling = FALSE` to compute asymmetric Shapley values.\n", + "Asymmetric Shapley values do not support paired sampling as the paired ", + "coalitions will not necessarily respect the causal ordering." 
+ )) } - # parameters only used for type "forecast" - if (type == "forecast") { - if (!(is.wholenumber(horizon) && all(horizon > 0))) { - stop("`horizon` must be a vector (or scalar) of positive integers.\n") - } - if (any(horizon != output_size)) { - stop(paste0("`horizon` must match the output size of the model (", paste0(output_size, collapse = ", "), ").\n")) - } + # Get the number of coalitions that respects the (partial) causal ordering + max_n_coalitions_causal <- get_max_n_coalitions_causal(causal_ordering = causal_ordering) + internal$parameters$max_n_coalitions_causal <- max_n_coalitions_causal - if (!(length(train_idx) > 1 && is.wholenumber(train_idx) && all(train_idx > 0) && all(is.finite(train_idx)))) { - stop("`train_idx` must be a vector of positive finite integers and length > 1.\n") - } + # Get the coalitions that respects the (partial) causal ordering + internal$objects$dt_valid_causal_coalitions <- exact_coalition_table( + m = internal$parameters$n_shapley_values, + dt_valid_causal_coalitions = data.table(coalitions = get_valid_causal_coalitions(causal_ordering = causal_ordering)) + ) # [, c("coalitions", "shapley_weight")] TODO: TA MED ELLER IKKE? - if (!(is.wholenumber(explain_idx) && all(explain_idx > 0) && all(is.finite(explain_idx)))) { - stop("`explain_idx` must be a vector of positive finite integers.\n") - } + # Normalize the weights. Note that weight of a coalition size is even spread out among the valid coalitions + # of each size. I.e., if there is only one valid coalition of size |S|, then it gets the weight of the + # choose(M, |S|) coalitions of said size. + internal$objects$dt_valid_causal_coalitions[-c(1, .N), shapley_weight_norm := shapley_weight / sum(shapley_weight)] - if (!(is.wholenumber(explain_y_lags) && all(explain_y_lags >= 0) && all(is.finite(explain_y_lags)))) { - stop("`explain_y_lags` must be a vector of positive finite integers.\n") + # Convert the coalitions to strings. 
Needed when sampling the coalitions in `sample_coalition_table()`. + internal$objects$dt_valid_causal_coalitions[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")] + + return(internal) +} + + +#' @keywords internal +adjust_max_n_coalitions <- function(internal) { + is_groupwise <- internal$parameters$is_groupwise + max_n_coalitions <- internal$parameters$max_n_coalitions + n_features <- internal$parameters$n_features + n_groups <- internal$parameters$n_groups + n_shapley_values <- internal$parameters$n_shapley_values + asymmetric <- internal$parameters$asymmetric # NULL if regular/symmetric Shapley values + max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal # NULL if regular/symmetric Shapley values + + + # Adjust max_n_coalitions + if (isTRUE(asymmetric)) { + # Asymmetric Shapley values + + # Set max_n_coalitions to upper bound + if (is.null(max_n_coalitions) || max_n_coalitions > max_n_coalitions_causal) { + max_n_coalitions <- max_n_coalitions_causal + message( + paste0( + "Success with message:\n", + "max_n_coalitions is NULL or larger than or number of coalitions respecting the causal\n", + "ordering ", max_n_coalitions_causal, ", and is therefore set to ", max_n_coalitions_causal, ".\n" + ) + ) } - if (!(is.wholenumber(explain_xreg_lags) && all(explain_xreg_lags >= 0) && all(is.finite(explain_xreg_lags)))) { - stop("`explain_xreg_lags` must be a vector of positive finite integers.\n") + # Set max_n_coalitions to lower bound + if (isFALSE(is.null(max_n_coalitions)) && + max_n_coalitions < min(10, n_shapley_values + 1, max_n_coalitions_causal)) { + if (max_n_coalitions_causal <= 10) { + max_n_coalitions <- max_n_coalitions_causal + message( + paste0( + "Success with message:\n", + "max_n_coalitions_causal is smaller than or equal to 10, meaning there are\n", + "so few unique causal coalitions that we should use all to get reliable results.\n", + "max_n_coalitions is therefore set to ", max_n_coalitions_causal, ".\n" + ) + ) + } 
else { + max_n_coalitions <- min(10, n_shapley_values + 1, max_n_coalitions_causal) + message( + paste0( + "Success with message:\n", + "max_n_coalitions is smaller than max(10, n_shapley_values + 1 = ", n_shapley_values + 1, + " max_n_coalitions_causal = ", max_n_coalitions_causal, "),", + "which will result in unreliable results.\n", + "It is therefore set to ", min(10, n_shapley_values + 1, max_n_coalitions_causal), ".\n" + ) + ) + } } + } else { + # Symmetric/regular Shapley values + + if (isFALSE(is_groupwise)) { # feature wise + # Set max_n_coalitions to upper bound + if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_features) { + max_n_coalitions <- 2^n_features + message( + paste0( + "Success with message:\n", + "max_n_coalitions is NULL or larger than or 2^n_features = ", 2^n_features, ", \n", + "and is therefore set to 2^n_features = ", 2^n_features, ".\n" + ) + ) + } + # Set max_n_coalitions to lower bound + if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_features + 1)) { + if (n_features <= 3) { + max_n_coalitions <- 2^n_features + message( + paste0( + "Success with message:\n", + "n_features is smaller than or equal to 3, meaning there are so few unique coalitions (", + 2^n_features, ") that we should use all to get reliable results.\n", + "max_n_coalitions is therefore set to 2^n_features = ", 2^n_features, ".\n" + ) + ) + } else { + max_n_coalitions <- min(10, n_features + 1) + message( + paste0( + "Success with message:\n", + "max_n_coalitions is smaller than max(10, n_features + 1 = ", n_features + 1, "),", + "which will result in unreliable results.\n", + "It is therefore set to ", max(10, n_features + 1), ".\n" + ) + ) + } + } + } else { # group wise + # Set max_n_coalitions to upper bound + if (is.null(max_n_coalitions) || max_n_coalitions > 2^n_groups) { + max_n_coalitions <- 2^n_groups + message( + paste0( + "Success with message:\n", + "max_n_coalitions is NULL or larger than or 2^n_groups = ", 2^n_groups, ", \n", 
+ "and is therefore set to 2^n_groups = ", 2^n_groups, ".\n" + ) + ) + } + # Set max_n_coalitions to lower bound + if (isFALSE(is.null(max_n_coalitions)) && max_n_coalitions < min(10, n_groups + 1)) { + if (n_groups <= 3) { + max_n_coalitions <- 2^n_groups + message( + paste0( + "Success with message:\n", + "n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (", 2^n_groups, ") ", + "that we should use all to get reliable results.\n", + "max_n_coalitions is therefore set to 2^n_groups = ", 2^n_groups, ".\n" + ) + ) + } else { + max_n_coalitions <- min(10, n_groups + 1) + message( + paste0( + "Success with message:\n", + "max_n_coalitions is smaller than max(10, n_groups + 1 = ", n_groups + 1, "),", + "which will result in unreliable results.\n", + "It is therefore set to ", max(10, n_groups + 1), ".\n" + ) + ) + } + } + } + } - if (!(is.logical(group_lags) && length(group_lags) == 1)) { - stop("`group_lags` must be a single logical.\n") + return(max_n_coalitions) +} + +check_max_n_coalitions_fc <- function(internal) { + is_groupwise <- internal$parameters$is_groupwise + max_n_coalitions <- internal$parameters$max_n_coalitions + n_features <- internal$parameters$n_features + n_groups <- internal$parameters$n_groups + + type <- internal$parameters$type + + if (type == "forecast") { + horizon <- internal$parameters$horizon + explain_y_lags <- internal$parameters$explain_lags$y + explain_xreg_lags <- internal$parameters$explain_lags$xreg + xreg <- internal$data$xreg + + if (!is_groupwise) { + if (max_n_coalitions <= n_features) { + stop(paste0( + "`max_n_coalitions` (", max_n_coalitions, ") has to be greater than the number of ", + "components to decompose the forecast onto:\n", + "`horizon` (", horizon, ") + `explain_y_lags` (", explain_y_lags, ") ", + "+ sum(`explain_xreg_lags`) (", sum(explain_xreg_lags), ").\n" + )) + } + } else { + if (max_n_coalitions <= n_groups) { + stop(paste0( + "`max_n_coalitions` (", max_n_coalitions, ") has 
to be greater than the number of ", + "components to decompose the forecast onto:\n", + "ncol(`xreg`) (", ncol(`xreg`), ") + 1" + )) + } } } +} + +#' @author Martin Jullum +#' @keywords internal +set_output_parameters <- function(internal) { + output_args <- internal$parameters$output_args + + # Get defaults + output_args <- utils::modifyList(get_output_args_default(), + output_args, + keep.null = TRUE + ) + + check_output_args(output_args) + + internal$parameters$output_args <- output_args + + return(internal) +} + +#' Gets the default values for the output arguments +#' +#' @param keep_samp_for_vS Logical. +#' Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned (in `internal$output`). +#' Not used for `approach="regression_separate"` or `approach="regression_surrogate"`. +#' @param MSEv_uniform_comb_weights Logical. +#' If `TRUE` (default), then the function weights the coalitions uniformly when computing the MSEv criterion. +#' If `FALSE`, then the function use the Shapley kernel weights to weight the coalitions when computing the MSEv +#' criterion. +#' Note that the Shapley kernel weights are replaced by the sampling frequency when not all coalitions are considered. +#' @param saving_path String. +#' The path to the directory where the results of the iterative estimation procedure should be saved. +#' Defaults to a temporary directory. 
+#' @export +#' @author Martin Jullum +get_output_args_default <- function(keep_samp_for_vS = FALSE, + MSEv_uniform_comb_weights = TRUE, + saving_path = tempfile("shapr_obj_", fileext = ".rds")) { + return(mget(methods::formalArgs(get_output_args_default))) +} + +check_output_args <- function(output_args) { + list2env(output_args, envir = environment()) # Make accessible in the environment + + # Check the output_args elements + + # keep_samp_for_vS + if (!(is.logical(keep_samp_for_vS) && + length(keep_samp_for_vS) == 1)) { + stop("`output_args$keep_samp_for_vS` must be single logical.") + } # Parameter used in the MSEv evaluation criterion if (!(is.logical(MSEv_uniform_comb_weights) && length(MSEv_uniform_comb_weights) == 1)) { - stop("`MSEv_uniform_comb_weights` must be single logical.") + stop("`output_args$MSEv_uniform_comb_weights` must be single logical.") } - #### Tests combining more than one parameter #### - # prediction_zero vs output_size - if (!all((is.numeric(prediction_zero)) && - all(length(prediction_zero) == output_size) && - all(!is.na(prediction_zero)))) { - stop(paste0( - "`prediction_zero` (", paste0(prediction_zero, collapse = ", "), - ") must be numeric and match the output size of the model (", - paste0(output_size, collapse = ", "), ")." - )) + # saving_path + if (!(is.character(saving_path) && + length(saving_path) == 1)) { + stop("`output_args$saving_path` must be a single character.") } - # Getting basic input parameters - parameters <- list( - approach = approach, - prediction_zero = prediction_zero, - n_combinations = n_combinations, - group = group, - n_samples = n_samples, - n_batches = n_batches, - seed = seed, - keep_samp_for_vS = keep_samp_for_vS, - is_python = is_python, - output_size = output_size, - type = type, - horizon = horizon, - group_lags = group_lags, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights, - timing = timing, - verbose = verbose + # Also check that saving_path exists, and abort if not... 
+ if (!dir.exists(dirname(saving_path))) { + stop( + paste0( + "Directory ", dirname(saving_path), " in the output_args$saving_path does not exists.\n", + "Please create the directory with `dir.create('", dirname(saving_path), "')` or use another directory." + ) + ) + } +} + + +#' @author Martin Jullum +#' @keywords internal +set_extra_estimation_params <- function(internal) { + extra_computation_args <- internal$parameters$extra_computation_args + + # Get defaults + extra_computation_args <- utils::modifyList(get_extra_est_args_default(internal), + extra_computation_args, + keep.null = TRUE ) - # Getting additional parameters from ... - parameters <- append(parameters, list(...)) + # Check the output_args elements + check_extra_computation_args(extra_computation_args) - # Setting exact based on n_combinations (TRUE if NULL) - parameters$exact <- ifelse(is.null(parameters$n_combinations), TRUE, FALSE) + extra_computation_args <- trans_null_extra_est_args(extra_computation_args) - # Setting that we are using regression based the approach name (any in case several approaches) - parameters$regression <- any(grepl("regression", parameters$approach)) + internal$parameters$extra_computation_args <- extra_computation_args - return(parameters) + return(internal) } -#' @keywords internal -get_data <- function(x_train, x_explain) { - # Check data object type - stop_message <- "" - if (!is.matrix(x_train) && !is.data.frame(x_train)) { - stop_message <- paste0(stop_message, "x_train should be a matrix or a data.frame/data.table.\n") +#' Gets the default values for the extra estimation arguments +#' +#' @param compute_sd Logical. Whether to estimate the standard deviations of the Shapley value estimates. This is TRUE +#' whenever sampling based kernelSHAP is applied (either iteratively or with a fixed number of coalitions). +#' @param n_boot_samps Integer. The number of bootstrapped samples (i.e. 
samples with replacement) from the set of all +#' coalitions used to estimate the standard deviations of the Shapley value estimates. +#' @param max_batch_size Integer. The maximum number of coalitions to estimate simultaneously within each iteration. +#' A larger numbers requires more memory, but may have a slight computational advantage. +#' @param min_n_batches Integer. The minimum number of batches to split the computation into within each iteration. +#' Larger numbers gives more frequent progress updates. If parallelization is applied, this should be set no smaller +#' than the number of parallel workers. +#' @inheritParams default_doc_explain +#' @export +#' @author Martin Jullum +get_extra_est_args_default <- function(internal, # Only used to get the default value of compute_sd + compute_sd = isFALSE(internal$parameters$exact), + n_boot_samps = 100, + max_batch_size = 10, + min_n_batches = 10) { + return(mget(methods::formalArgs(get_extra_est_args_default)[-1])) # [-1] to exclude internal +} + +check_extra_computation_args <- function(extra_computation_args) { + list2env(extra_computation_args, envir = environment()) # Make accessible in the environment + + # compute_sd + if (!(is.logical(compute_sd) && + length(compute_sd) == 1)) { + stop("`extra_computation_args$compute_sd` must be single logical.") } - if (!is.matrix(x_explain) && !is.data.frame(x_explain)) { - stop_message <- paste0(stop_message, "x_explain should be a matrix or a data.frame/data.table.\n") + + # n_boot_samps + if (!(is.wholenumber(n_boot_samps) && + length(n_boot_samps) == 1 && + !is.na(n_boot_samps) && + n_boot_samps > 0)) { + stop("`extra_computation_args$n_boot_samps` must be a single positive integer.") } - if (stop_message != "") { - stop(stop_message) + + # max_batch_size + if (!is.null(max_batch_size) && + !((is.wholenumber(max_batch_size) || is.infinite(max_batch_size)) && + length(max_batch_size) == 1 && + !is.na(max_batch_size) && + max_batch_size > 0)) { + 
stop("`extra_computation_args$max_batch_size` must be NULL, Inf or a single positive integer.") } - # Check column names - if (all(is.null(colnames(x_train)))) { - stop_message <- paste0(stop_message, "x_train misses column names.\n") + # min_n_batches + if (!is.null(min_n_batches) && + !(is.wholenumber(min_n_batches) && + length(min_n_batches) == 1 && + !is.na(min_n_batches) && + min_n_batches > 0)) { + stop("`extra_computation_args$min_n_batches` must be NULL or a single positive integer.") } - if (all(is.null(colnames(x_explain)))) { - stop_message <- paste0(stop_message, "x_explain misses column names.\n") +} + +trans_null_extra_est_args <- function(extra_computation_args) { + list2env(extra_computation_args, envir = environment()) + + # Translating NULL to always return n_batches = 1 (if just one approach) + extra_computation_args$min_n_batches <- ifelse(is.null(min_n_batches), 1, min_n_batches) + extra_computation_args$max_batch_size <- ifelse(is.null(max_batch_size), Inf, max_batch_size) + + return(extra_computation_args) +} + + +check_and_set_iterative <- function(internal) { + iterative <- internal$parameters$iterative + approach <- internal$parameters$approach + + # Always iterative = FALSE for vaeac and regression_surrogate + if (any(approach %in% c("vaeac", "regression_surrogate"))) { + unsupported <- approach[approach %in% c("vaeac", "regression_surrogate")] + + if (isTRUE(iterative)) { + warning( + paste0( + "Iterative estimation of Shapley values are not supported for approach = ", + paste0(unsupported, collapse = ", "), ". Setting iterative = FALSE." 
+ ) + ) + } + + internal$parameters$iterative <- FALSE + } else { + # Sets the default value of iterative to TRUE if computing more than 5 Shapley values for all other approaches + if (is.null(iterative)) { + n_shapley_values <- internal$parameters$n_shapley_values # n_features if feature-wise and n_groups if group-wise + internal$parameters$iterative <- isTRUE(n_shapley_values > 5) + } } - if (stop_message != "") { - stop(stop_message) + + return(internal) +} + + +set_exact <- function(internal) { + max_n_coalitions <- internal$parameters$max_n_coalitions + n_features <- internal$parameters$n_features + n_groups <- internal$parameters$n_groups + is_groupwise <- internal$parameters$is_groupwise + iterative <- internal$parameters$iterative + asymmetric <- internal$parameters$asymmetric + max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal + + if (isFALSE(iterative) && + ( + (isTRUE(asymmetric) && max_n_coalitions == max_n_coalitions_causal) || + (isFALSE(is_groupwise) && max_n_coalitions == 2^n_features) || + (isTRUE(is_groupwise) && max_n_coalitions == 2^n_groups) + ) + ) { + exact <- TRUE + } else { + exact <- FALSE } + internal$parameters$exact <- exact - data <- list( - x_train = data.table::as.data.table(x_train), - x_explain = data.table::as.data.table(x_explain) - ) + return(internal) } +#' @keywords internal +check_computability <- function(internal) { + is_groupwise <- internal$parameters$is_groupwise + max_n_coalitions <- internal$parameters$max_n_coalitions + n_features <- internal$parameters$n_features + n_groups <- internal$parameters$n_groups + exact <- internal$parameters$exact + causal_sampling <- internal$parameters$causal_sampling # NULL if regular/symmetric Shapley values + asymmetric <- internal$parameters$asymmetric # NULL if regular/symmetric Shapley values + max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal # NULL if regular/symmetric Shapley values + + if (asymmetric) { + if (isTRUE(exact)) { + if 
(max_n_coalitions_causal > 5000 && max_n_coalitions > 5000) { # TODO check + warning( + paste0( + "Due to computation time, we recommend not computing asymmetric Shapley values exactly \n", + "with all valid causal coalitions (", max_n_coalitions_causal, ") when larger than 5000.\n", + "Consider reducing max_n_coalitions and enabling iterative estimation with iterative = TRUE.\n" + ) + ) + } + } + } + # Force user to use a natural number for n_coalitions if m > 13 + if (isTRUE(exact)) { + if (isFALSE(is_groupwise) && n_features > 13) { + warning( + paste0( + "Due to computation time, we recommend not computing Shapley values exactly \n", + "with all 2^n_features (", 2^n_features, ") coalitions for n_features > 13.\n", + "Consider reducing max_n_coalitions and enabling iterative estimation with iterative = TRUE.\n" + ) + ) + } + if (isTRUE(is_groupwise) && n_groups > 13) { + warning( + paste0( + "Due to computation time, we recommend not computing Shapley values exactly \n", + "with all 2^n_groups (", 2^n_groups, ") coalitions for n_groups > 13.\n", + "Consider reducing max_n_coalitions and enabling iterative estimation with iterative = TRUE.\n" + ) + ) + } + if (isTRUE(causal_sampling) && !is.null(max_n_coalitions_causal) && max_n_coalitions_causal > 1000) { + paste0( + "Due to computation time, we recommend not computing causal Shapley values exactly \n", + "with all valid causal coalitions when there are more than 1000 due to the long causal sampling time. 
\n", + "Consider reducing max_n_coalitions and enabling iterative estimation with iterative = TRUE.\n" + ) + } + } else { + if (isFALSE(is_groupwise) && n_features > 30) { + warning( + "Due to computation time, we strongly recommend enabling iterative estimation with iterative = TRUE", + " when n_features > 30.\n", + ) + } + if (isTRUE(is_groupwise) && n_groups > 30) { + warning( + "Due to computation time, we strongly recommend enabling iterative estimation with iterative = TRUE", + " when n_groups > 30.\n", + ) + } + if (isTRUE(causal_sampling) && !is.null(max_n_coalitions_causal) && max_n_coalitions_causal > 1000) { + warning( + paste0( + "Due to computation time, we strongly recommend enabling iterative estimation with iterative = TRUE ", + "when the number of valid causal coalitions are more than 1000 due to the long causal sampling time. \n" + ) + ) + } + } +} -#' Fetches feature information from a given data set -#' -#' @param x matrix, data.frame or data.table The data to extract feature information from. -#' -#' @details This function is used to extract the feature information to be checked against the corresponding -#' information extracted from the model and other data sets. The function is called from internally + + + +#' @keywords internal +check_approach <- function(internal) { + # Check length of approach + + approach <- internal$parameters$approach + n_features <- internal$parameters$n_features + supported_approaches <- get_supported_approaches() + + if (!(is.character(approach) && + (length(approach) == 1 || length(approach) == n_features - 1) && + all(is.element(approach, supported_approaches))) + ) { + stop( + paste0( + "`approach` must be one of the following: '", paste0(supported_approaches, collapse = "', '"), "'.\n", + "These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector ", + "of length one less than the number of features (", n_features - 1, ")." 
+ ) + ) + } + + if (length(approach) > 1 && any(grepl("regression", approach))) { + stop("The `regression_separate` and `regression_surrogate` approaches cannot be combined with other approaches.") + } +} + +#' Gets the implemented approaches #' -#' @return A list with the following elements: -#' \describe{ -#' \item{labels}{character vector with the feature names to compute Shapley values for} -#' \item{classes}{a named character vector with the labels as names and the class types as elements} -#' \item{factor_levels}{a named list with the labels as names and character vectors with the factor levels as elements -#' (NULL if the feature is not a factor)} -#' } -#' @author Martin Jullum +#' @return Character vector. +#' The names of the implemented approaches that can be passed to argument `approach` in [explain()]. #' -#' @keywords internal #' @export -#' -#' @examples -#' # Load example data -#' data("airquality") -#' airquality <- airquality[complete.cases(airquality), ] -#' # Split data into test- and training data -#' x_train <- head(airquality, -3) -#' x_explain <- tail(airquality, 3) -#' # Split data into test- and training data -#' x_train <- data.table::as.data.table(head(airquality)) -#' x_train[, Temp := as.factor(Temp)] -#' get_data_specs(x_train) -get_data_specs <- function(x) { - feature_specs <- list() - feature_specs$labels <- names(x) - feature_specs$classes <- unlist(lapply(x, class)) - feature_specs$factor_levels <- lapply(x, levels) +get_supported_approaches <- function() { + substring(rownames(attr(methods(prepare_data), "info")), first = 14) +} - # Defining all integer values as numeric - feature_specs$classes[feature_specs$classes == "integer"] <- "numeric" - return(feature_specs) + + +#' @keywords internal +#' @author Lars Henry Berge Olsen +check_regression <- function(internal) { + # Check that the model outputs one-dimensional predictions + if (internal$parameters$output_size != 1) { + stop("`regression_separate` and `regression_surrogate` 
only support models with one-dimensional output") + } + + # Check that we are NOT explaining a forecast model + if (internal$parameters$type == "forecast") { + stop("`regression_separate` and `regression_surrogate` does not support `forecast`.") + } + + # Check that we are not to keep the Monte Carlo samples + if (internal$parameters$output_args$keep_samp_for_vS) { + stop(paste( + "`keep_samp_for_vS` must be `FALSE` for the `regression_separate` and `regression_surrogate`", + "approaches as there are no Monte Carlo samples to keep for these approaches." + )) + } + + # Remove n_MC_samples if we are doing regression, as we are not doing MC sampling + internal$parameters$n_MC_samples <- NULL + + return(internal) } + + + + + + + + + +compare_vecs <- function(vec1, vec2, vec_type, name1, name2) { + if (!identical(vec1, vec2)) { + if (is.null(names(vec1))) { + text_vec1 <- paste(vec1, collapse = ", ") + } else { + text_vec1 <- paste(names(vec1), vec1, sep = ": ", collapse = ", ") + } + if (is.null(names(vec2))) { + text_vec2 <- paste(vec2, collapse = ", ") + } else { + text_vec2 <- paste(names(vec2), vec1, sep = ": ", collapse = ", ") + } + + stop(paste0( + "Feature ", vec_type, " are not identical for ", name1, " and ", name2, ".\n", + name1, " provided: ", text_vec1, ",\n", + name2, " provided: ", text_vec2, ".\n" + )) + } +} + + + #' Check that the group parameter has the right form and content #' #' @@ -668,81 +1436,262 @@ check_groups <- function(feature_names, group) { } } + + #' @keywords internal -check_approach <- function(internal) { - # Check length of approach +set_iterative_parameters <- function(internal, prev_iter_list = NULL) { + iterative <- internal$parameters$iterative - approach <- internal$parameters$approach - n_features <- internal$parameters$n_features - supported_approaches <- get_supported_approaches() + iterative_args <- internal$parameters$iterative_args - if (!(is.character(approach) && - (length(approach) == 1 || length(approach) == 
n_features - 1) && - all(is.element(approach, supported_approaches))) - ) { + iterative_args <- utils::modifyList(get_iterative_args_default(internal), + iterative_args, + keep.null = TRUE + ) + + # Force setting the number of coalitions and iterations for non-iterative method + if (isFALSE(iterative)) { + iterative_args$max_iter <- 1 + iterative_args$initial_n_coalitions <- iterative_args$max_n_coalitions + } + + check_iterative_args(iterative_args) + + # Translate any null input + iterative_args <- trans_null_iterative_args(iterative_args) + + internal$parameters$iterative_args <- iterative_args + + if (!is.null(prev_iter_list)) { + # Update internal with the iter_list from prev_shapr_object + internal$iter_list <- prev_iter_list + + # Conveniently allow running non-iterative estimation one step further + if (isFALSE(internal$parameters$iterative)) { + internal$parameters$iterative_args$max_iter <- length(internal$iter_list) + 1 + internal$parameters$iterative_args$n_coal_next_iter_factor_vec <- NULL + } + + # Update convergence data with NEW iterative arguments + internal <- check_convergence(internal) + + # Check for convergence based on last iter_list with new iterative arguments + check_vs_prev_shapr_object(internal) + + # Prepare next iteration + internal <- prepare_next_iteration(internal) + } else { + internal$iter_list <- list() + internal$iter_list[[1]] <- list( + n_coalitions = iterative_args$initial_n_coalitions, + new_n_coalitions = iterative_args$initial_n_coalitions, + exact = internal$parameters$exact, + compute_sd = internal$parameters$extra_computation_args$compute_sd, + n_coal_next_iter_factor = iterative_args$n_coal_next_iter_factor_vec[1], + n_batches = set_n_batches(iterative_args$initial_n_coalitions, internal) + ) + } + + return(internal) +} + +check_iterative_args <- function(iterative_args) { + list2env(iterative_args, envir = environment()) + + + # initial_n_coalitions + if (!(is.wholenumber(initial_n_coalitions) && + 
length(initial_n_coalitions) == 1 && + !is.na(initial_n_coalitions) && + initial_n_coalitions <= max_n_coalitions && + initial_n_coalitions > 2)) { + stop("`iterative_args$initial_n_coalitions` must be a single integer between 2 and `max_n_coalitions`.") + } + + # fixed_n_coalitions + if (!is.null(fixed_n_coalitions_per_iter) && + !(is.wholenumber(fixed_n_coalitions_per_iter) && + length(fixed_n_coalitions_per_iter) == 1 && + !is.na(fixed_n_coalitions_per_iter) && + fixed_n_coalitions_per_iter <= max_n_coalitions && + fixed_n_coalitions_per_iter > 0)) { stop( - paste0( - "`approach` must be one of the following: '", paste0(supported_approaches, collapse = "', '"), "'.\n", - "These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector ", - "of length one less than the number of features (", n_features - 1, ")." - ) + "`iterative_args$fixed_n_coalitions_per_iter` must be NULL or a single positive integer no larger than", + "`max_n_coalitions`." ) } - if (length(approach) > 1 && any(grepl("regression", approach))) { - stop("The `regression_separate` and `regression_surrogate` approaches cannot be combined with other approaches.") + # max_iter + if (!is.null(max_iter) && + !((is.wholenumber(max_iter) || is.infinite(max_iter)) && + length(max_iter) == 1 && + !is.na(max_iter) && + max_iter > 0)) { + stop("`iterative_args$max_iter` must be NULL, Inf or a single positive integer.") + } + + # convergence_tol + if (!is.null(convergence_tol) && + !(length(convergence_tol) == 1 && + !is.na(convergence_tol) && + convergence_tol >= 0)) { + stop("`iterative_args$convergence_tol` must be NULL, 0, or a positive numeric.") + } + + # n_coal_next_iter_factor_vec + if (!is.null(n_coal_next_iter_factor_vec) && + !(all(!is.na(n_coal_next_iter_factor_vec)) && + all(n_coal_next_iter_factor_vec <= 1) && + all(n_coal_next_iter_factor_vec >= 0))) { + stop("`iterative_args$n_coal_next_iter_factor_vec` must be NULL or a vector or numerics between 0 
and 1.") } } -#' @keywords internal -set_defaults <- function(internal) { - # Set defaults for certain arguments (based on other input) +trans_null_iterative_args <- function(iterative_args) { + list2env(iterative_args, envir = environment()) - approach <- internal$parameters$approach + # Translating NULL to always return n_batches = 1 (if just one approach) + iterative_args$max_iter <- ifelse(is.null(max_iter), Inf, max_iter) + + return(iterative_args) +} + + +set_n_batches <- function(n_coalitions, internal) { + min_n_batches <- internal$parameters$extra_computation_args$min_n_batches + max_batch_size <- internal$parameters$extra_computation_args$max_batch_size n_unique_approaches <- internal$parameters$n_unique_approaches - used_n_combinations <- internal$parameters$used_n_combinations - n_batches <- internal$parameters$n_batches - # n_batches - if (is.null(n_batches)) { - internal$parameters$n_batches <- get_default_n_batches(approach, n_unique_approaches, used_n_combinations) - } - return(internal) + # Restrict the sizes of the batches to max_batch_size, but require at least min_n_batches and n_unique_approaches + suggested_n_batches <- max(min_n_batches, n_unique_approaches, ceiling(n_coalitions / max_batch_size)) + + # Set n_batches to no less than n_coalitions + n_batches <- min(n_coalitions, suggested_n_batches) + + return(n_batches) } -#' @keywords internal -get_default_n_batches <- function(approach, n_unique_approaches, n_combinations) { - used_approach <- names(sort(table(approach), decreasing = TRUE))[1] # Most frequent used approach (when more present) +check_vs_prev_shapr_object <- function(internal) { + iter <- length(internal$iter_list) + + converged <- internal$iter_list[[iter]]$converged + converged_exact <- internal$iter_list[[iter]]$converged_exact + converged_sd <- internal$iter_list[[iter]]$converged_sd + converged_max_iter <- internal$iter_list[[iter]]$converged_max_iter + converged_max_n_coalitions <- 
internal$iter_list[[iter]]$converged_max_n_coalitions + + if (isTRUE(converged)) { + message0 <- "Convergence reached before estimation start.\n" + if (isTRUE(converged_exact)) { + message0 <- c( + message0, + "All coalitions estimated. No need for further estimation.\n" + ) + } + if (isTRUE(converged_sd)) { + message0 <- c( + message0, + "Convergence tolerance reached. Consider decreasing `iterative_args$tolerance`.\n" + ) + } + if (isTRUE(converged_max_iter)) { + message0 <- c( + message0, + "Maximum number of iterations reached. Consider increasing `iterative_args$max_iter`.\n" + ) + } + if (isTRUE(converged_max_n_coalitions)) { + message0 <- c( + message0, + "Maximum number of coalitions reached. Consider increasing `max_n_coalitions`.\n" + ) + } + stop(message0) + } +} - if (used_approach %in% c("ctree", "gaussian", "copula")) { - suggestion <- ceiling(n_combinations / 10) - this_min <- 10 - this_max <- 1000 +# Get functions ======================================================================================================== +#' Function to specify arguments of the iterative estimation procedure +#' +#' @details The functions sets default values for the iterative estimation procedure, according to the function +#' defaults. +#' If the argument `iterative` of [shapr::explain()] is FALSE, it sets parameters corresponding to the use of a +#' non-iterative estimation procedure +#' +#' @param max_iter Integer. Maximum number of estimation iterations +#' @param initial_n_coalitions Integer. Number of coalitions to use in the first estimation iteration. +#' @param fixed_n_coalitions_per_iter Integer. Number of `n_coalitions` to use in each iteration. +#' `NULL` (default) means setting it based on estimates based on a set convergence threshold. +#' @param convergence_tol Numeric. 
The t variable in the convergence threshold formula on page 6 in the paper +#' Covert and Lee (2021), 'Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression' +#' https://arxiv.org/pdf/2012.01536. Smaller values requires more coalitions before convergence is reached. +#' @param n_coal_next_iter_factor_vec Numeric vector. The number of `n_coalitions` that must be used to reach +#' convergence in the next iteration is estimated. +#' The number of `n_coalitions` actually used in the next iteration is set to this estimate multiplied by +#' `n_coal_next_iter_factor_vec[i]` for iteration `i`. +#' It is wise to start with smaller numbers to avoid using too many `n_coalitions` due to uncertain estimates in +#' the first iterations. +#' @inheritParams default_doc_explain +#' +#' @export +#' @author Martin Jullum +get_iterative_args_default <- function(internal, + initial_n_coalitions = ceiling( + min( + 200, + max( + 5, + internal$parameters$n_features, + (2^internal$parameters$n_features) / 10 + ) + ) + ), + fixed_n_coalitions_per_iter = NULL, + max_iter = 20, + convergence_tol = 0.02, + n_coal_next_iter_factor_vec = c(seq(0.1, 1, by = 0.1), rep(1, max_iter - 10))) { + iterative <- internal$parameters$iterative + max_n_coalitions <- internal$parameters$max_n_coalitions + + if (isTRUE(iterative)) { + ret_list <- mget( + c( + "initial_n_coalitions", + "fixed_n_coalitions_per_iter", + "max_n_coalitions", + "max_iter", + "convergence_tol", + "n_coal_next_iter_factor_vec" + ) + ) } else { - suggestion <- ceiling(n_combinations / 100) - this_min <- 2 - this_max <- 100 - } - min_checked <- max(c(this_min, suggestion, n_unique_approaches)) - ret <- min(c(this_max, min_checked, n_combinations - 1)) - message( - paste0( - "Setting parameter 'n_batches' to ", ret, " as a fair trade-off between memory consumption and ", - "computation time.\n", - "Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption.\n" + 
ret_list <- list( + initial_n_coalitions = max_n_coalitions, + fixed_n_coalitions_per_iter = NULL, + max_n_coalitions = max_n_coalitions, + max_iter = 1, + convergence_tol = NULL, + n_coal_next_iter_factor_vec = NULL ) - ) - return(ret) + } + return(ret_list) } - -#' Gets the implemented approaches +#' Additional setup for regression-based methods #' -#' @return Character vector. -#' The names of the implemented approaches that can be passed to argument `approach` in [explain()]. +#' @inheritParams default_doc_explain #' #' @export -get_supported_approaches <- function() { - substring(rownames(attr(methods(prepare_data), "info")), first = 14) +#' @keywords internal +additional_regression_setup <- function(internal, model, predict_model) { + # This step needs to be called after predict_model is set, and therefore arrives at a later stage in explain() + + # Add the predicted response of the training and explain data to the internal list for regression-based methods. + # Use isTRUE as `regression` is not present (NULL) for non-regression methods (i.e., Monte Carlo-based methods). 
+ if (isTRUE(internal$parameters$regression)) { + internal <- regression.get_y_hat(internal = internal, model = model, predict_model = predict_model) + } + + return(internal) } diff --git a/R/setup_computation.R b/R/setup_computation.R deleted file mode 100644 index dad9b6240444f22d01a9d26ba552cf9323a7c174..0000000000000000000000000000000000000000 --- a/R/setup_computation.R +++ /dev/null @@ -1,689 +0,0 @@ -#' Sets up everything for the Shapley values computation in [shapr::explain()] -#' -#' @inheritParams default_doc -#' @inheritParams explain -#' @inherit default_doc -#' @export -setup_computation <- function(internal, model, predict_model) { - # model and predict_model are only needed for type AICc of approach empirical, otherwise ignored - type <- internal$parameters$type - - # setup the Shapley framework - internal <- if (type == "forecast") shapley_setup_forecast(internal) else shapley_setup(internal) - - # Setup for approach - internal <- setup_approach(internal, model = model, predict_model = predict_model) - - return(internal) -} - -#' @keywords internal -shapley_setup_forecast <- function(internal) { - exact <- internal$parameters$exact - n_features0 <- internal$parameters$n_features - n_combinations <- internal$parameters$n_combinations - is_groupwise <- internal$parameters$is_groupwise - group_num <- internal$objects$group_num - horizon <- internal$parameters$horizon - feature_names <- internal$parameters$feature_names - - X_list <- W_list <- list() - - # Find columns/features to be included in each of the different horizons - col_del_list <- list() - col_del_list[[1]] <- numeric() - if (horizon > 1) { - k <- 2 - for (i in rev(seq_len(horizon)[-1])) { - col_del_list[[k]] <- c(unlist(col_del_list[[k - 1]]), grep(paste0(".F", i), feature_names)) - k <- k + 1 - } - } - - cols_per_horizon <- lapply(rev(col_del_list), function(x) if (length(x) > 0) feature_names[-x] else feature_names) - - horizon_features <- lapply(cols_per_horizon, function(x) 
which(internal$parameters$feature_names %in% x)) - - # Apply feature_combination, weigth_matrix and feature_matrix_cpp to each of the different horizons - for (i in seq_along(horizon_features)) { - this_featcomb <- horizon_features[[i]] - n_this_featcomb <- length(this_featcomb) - - this_group_num <- lapply(group_num, function(x) x[x %in% this_featcomb]) - - X_list[[i]] <- feature_combinations( - m = n_this_featcomb, - exact = exact, - n_combinations = n_combinations, - weight_zero_m = 10^6, - group_num = this_group_num - ) - - W_list[[i]] <- weight_matrix( - X = X_list[[i]], - normalize_W_weights = TRUE, - is_groupwise = is_groupwise - ) - } - - # Merge the feature combination data.table to single one to use for computing conditional expectations later on - X <- rbindlist(X_list, idcol = "horizon") - X[, N := NA] - X[, shapley_weight := NA] - data.table::setorderv(X, c("n_features", "horizon"), order = c(1, -1)) - X[, horizon_id_combination := id_combination] - X[, id_combination := 0] - X[!duplicated(features), id_combination := .I] - X[, tmp_features := as.character(features)] - X[, id_combination := max(id_combination), by = tmp_features] - X[, tmp_features := NULL] - - # Extracts a data.table allowing mapping from X to X_list/W_list to be used in the compute_shapley function - id_combination_mapper_dt <- X[, .(horizon, horizon_id_combination, id_combination)] - - X[, horizon := NULL] - X[, horizon_id_combination := NULL] - data.table::setorder(X, n_features) - X <- X[!duplicated(id_combination)] - - W <- NULL # Included for consistency. Necessary weights are in W_list instead - - ## Get feature matrix --------- - S <- feature_matrix_cpp( - features = X[["features"]], - m = n_features0 - ) - - - #### Updating parameters #### - - # Updating parameters$exact as done in feature_combinations - if (!exact && n_combinations >= 2^n_features0) { - internal$parameters$exact <- TRUE # Note that this is exact only if all horizons use the exact method. 
- } - - internal$parameters$n_combinations <- nrow(S) # Updating this parameter in the end based on what is actually used. - - # This will be obsolete later - internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed - # instead of storing it - - internal$objects$X <- X - internal$objects$W <- W - internal$objects$S <- S - internal$objects$S_batch <- create_S_batch_new(internal) - - internal$objects$id_combination_mapper_dt <- id_combination_mapper_dt - internal$objects$cols_per_horizon <- cols_per_horizon - internal$objects$W_list <- W_list - internal$objects$X_list <- X_list - - - return(internal) -} - - -#' @keywords internal -shapley_setup <- function(internal) { - exact <- internal$parameters$exact - n_features0 <- internal$parameters$n_features - n_combinations <- internal$parameters$n_combinations - is_groupwise <- internal$parameters$is_groupwise - - group_num <- internal$objects$group_num - - X <- feature_combinations( - m = n_features0, - exact = exact, - n_combinations = n_combinations, - weight_zero_m = 10^6, - group_num = group_num - ) - - # Get weighted matrix ---------------- - W <- weight_matrix( - X = X, - normalize_W_weights = TRUE, - is_groupwise = is_groupwise - ) - - ## Get feature matrix --------- - S <- feature_matrix_cpp( - features = X[["features"]], - m = n_features0 - ) - - #### Updating parameters #### - - # Updating parameters$exact as done in feature_combinations - if (!exact && n_combinations >= 2^n_features0) { - internal$parameters$exact <- TRUE - } - - internal$parameters$n_combinations <- nrow(S) # Updating this parameter in the end based on what is actually used. 
- - # This will be obsolete later - internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed - # instead of storing it - - internal$objects$X <- X - internal$objects$W <- W - internal$objects$S <- S - internal$objects$S_batch <- create_S_batch_new(internal) - - - return(internal) -} - -#' Define feature combinations, and fetch additional information about each unique combination -#' -#' @param m Positive integer. Total number of features. -#' @param exact Logical. If `TRUE` all `2^m` combinations are generated, otherwise a -#' subsample of the combinations is used. -#' @param n_combinations Positive integer. Note that if `exact = TRUE`, -#' `n_combinations` is ignored. However, if `m > 12` you'll need to add a positive integer -#' value for `n_combinations`. -#' @param weight_zero_m Numeric. The value to use as a replacement for infinite combination -#' weights when doing numerical operations. -#' @param group_num List. Contains vector of integers indicating the feature numbers for the -#' different groups. -#' -#' @return A data.table that contains the following columns: -#' \describe{ -#' \item{id_combination}{Positive integer. Represents a unique key for each combination. Note that the table -#' is sorted by `id_combination`, so that is always equal to `x[["id_combination"]] = 1:nrow(x)`.} -#' \item{features}{List. Each item of the list is an integer vector where `features[[i]]` -#' represents the indices of the features included in combination `i`. Note that all the items -#' are sorted such that `features[[i]] == sort(features[[i]])` is always true.} -#' \item{n_features}{Vector of positive integers. `n_features[i]` equals the number of features in combination -#' `i`, i.e. `n_features[i] = length(features[[i]])`.}. -#' \item{N}{Positive integer. 
The number of unique ways to sample `n_features[i]` features -#' from `m` different features, without replacement.} -#' } -#' -#' @export -#' -#' @author Nikolai Sellereite, Martin Jullum -#' -#' @examples -#' # All combinations -#' x <- feature_combinations(m = 3) -#' nrow(x) # Equals 2^3 = 8 -#' -#' # Subsample of combinations -#' x <- feature_combinations(exact = FALSE, m = 10, n_combinations = 1e2) -feature_combinations <- function(m, exact = TRUE, n_combinations = 200, weight_zero_m = 10^6, group_num = NULL) { - m_group <- length(group_num) # The number of groups - - # Force user to use a natural number for n_combinations if m > 13 - if (m > 13 && is.null(n_combinations) && m_group == 0) { - stop( - paste0( - "Due to computational complexity, we recommend setting n_combinations = 10 000\n", - "if the number of features is larger than 13 for feature-wise Shapley values.\n", - "Note that you can force the use of the exact method (i.e. n_combinations = NULL)\n", - "by setting n_combinations equal to 2^m where m is the number of features.\n" - ) - ) - } - - # Not supported for m > 30 - if (m > 30 && m_group == 0) { - stop( - paste0( - "Currently we are not supporting cases where the number of features is greater than 30\n", - "for feature-wise Shapley values.\n" - ) - ) - } - if (m_group > 30) { - stop( - paste0( - "For computational reasons, we are currently not supporting group-wise Shapley values \n", - "for more than 30 groups. Please reduce the number of groups.\n" - ) - ) - } - - if (!exact) { - if (m_group == 0) { - # Switch to exact for feature-wise method - if (n_combinations >= 2^m) { - n_combinations <- 2^m - exact <- TRUE - message( - paste0( - "Success with message:\n", - "n_combinations is larger than or equal to 2^m = ", 2^m, ". 
\n", - "Using exact instead.\n" - ) - ) - } - } else { - # Switch to exact for feature-wise method - if (n_combinations >= (2^m_group)) { - n_combinations <- 2^m_group - exact <- TRUE - message( - paste0( - "Success with message:\n", - "n_combinations is larger than or equal to 2^group_num = ", 2^m_group, ". \n", - "Using exact instead.\n" - ) - ) - } - } - } - - if (m_group == 0) { - # Here if feature-wise Shapley values - if (exact) { - dt <- feature_exact(m, weight_zero_m) - } else { - dt <- feature_not_exact(m, n_combinations, weight_zero_m) - stopifnot( - data.table::is.data.table(dt), - !is.null(dt[["p"]]) - ) - p <- NULL # due to NSE notes in R CMD check - dt[, p := NULL] - } - } else { - # Here if group-wise Shapley values - if (exact) { - dt <- feature_group(group_num, weight_zero_m) - } else { - dt <- feature_group_not_exact(group_num, n_combinations, weight_zero_m) - stopifnot( - data.table::is.data.table(dt), - !is.null(dt[["p"]]) - ) - p <- NULL # due to NSE notes in R CMD check - dt[, p := NULL] - } - } - return(dt) -} - -#' @keywords internal -feature_exact <- function(m, weight_zero_m = 10^6) { - dt <- data.table::data.table(id_combination = seq(2^m)) - combinations <- lapply(0:m, utils::combn, x = m, simplify = FALSE) - dt[, features := unlist(combinations, recursive = FALSE)] - dt[, n_features := length(features[[1]]), id_combination] - dt[, N := .N, n_features] - dt[, shapley_weight := shapley_weights(m = m, N = N, n_components = n_features, weight_zero_m)] - - return(dt) -} - -#' @keywords internal -feature_not_exact <- function(m, n_combinations = 200, weight_zero_m = 10^6, unique_sampling = TRUE) { - # Find weights for given number of features ---------- - n_features <- seq(m - 1) - n <- sapply(n_features, choose, n = m) - w <- shapley_weights(m = m, N = n, n_features) * n - p <- w / sum(w) - - feature_sample_all <- list() - unique_samples <- 0 - - - if (unique_sampling) { - while (unique_samples < n_combinations - 2) { - # Sample number of 
chosen features ---------- - n_features_sample <- sample( - x = n_features, - size = n_combinations - unique_samples - 2, # Sample -2 as we add zero and m samples below - replace = TRUE, - prob = p - ) - - # Sample specific set of features ------- - feature_sample <- sample_features_cpp(m, n_features_sample) - feature_sample_all <- c(feature_sample_all, feature_sample) - unique_samples <- length(unique(feature_sample_all)) - } - } else { - n_features_sample <- sample( - x = n_features, - size = n_combinations - 2, # Sample -2 as we add zero and m samples below - replace = TRUE, - prob = p - ) - feature_sample_all <- sample_features_cpp(m, n_features_sample) - } - - # Add zero and m features - feature_sample_all <- c(list(integer(0)), feature_sample_all, list(c(1:m))) - X <- data.table(n_features = sapply(feature_sample_all, length)) - X[, n_features := as.integer(n_features)] - - # Get number of occurences and duplicated rows------- - is_duplicate <- NULL # due to NSE notes in R CMD check - r <- helper_feature(m, feature_sample_all) - X[, is_duplicate := r[["is_duplicate"]]] - - # When we sample combinations the Shapley weight is equal - # to the frequency of the given combination - X[, shapley_weight := r[["sample_frequence"]]] - - # Populate table and remove duplicated rows ------- - X[, features := feature_sample_all] - if (any(X[["is_duplicate"]])) { - X <- X[is_duplicate == FALSE] - } - X[, is_duplicate := NULL] - data.table::setkeyv(X, "n_features") - - # Make feature list into character - X[, features_tmp := sapply(features, paste, collapse = " ")] - - # Aggregate weights by how many samples of a combination we observe - X <- X[, .( - n_features = data.table::first(n_features), - shapley_weight = sum(shapley_weight), - features = features[1] - ), features_tmp] - - X[, features_tmp := NULL] - data.table::setorder(X, n_features) - - # Add shapley weight and number of combinations - X[c(1, .N), shapley_weight := weight_zero_m] - X[, N := 1] - ind <- X[, 
.I[data.table::between(n_features, 1, m - 1)]] - X[ind, p := p[n_features]] - X[ind, N := n[n_features]] - - # Set column order and key table - data.table::setkeyv(X, "n_features") - X[, id_combination := .I] - X[, N := as.integer(N)] - nms <- c("id_combination", "features", "n_features", "N", "shapley_weight", "p") - data.table::setcolorder(X, nms) - - return(X) -} - -#' Calculate Shapley weight -#' -#' @param m Positive integer. Total number of features/feature groups. -#' @param n_components Positive integer. Represents the number of features/feature groups you want to sample from -#' a feature space consisting of `m` unique features/feature groups. Note that ` 0 < = n_components <= m`. -#' @param N Positive integer. The number of unique combinations when sampling `n_components` features/feature -#' groups, without replacement, from a sample space consisting of `m` different features/feature groups. -#' @param weight_zero_m Positive integer. Represents the Shapley weight for two special -#' cases, i.e. the case where you have either `0` or `m` features/feature groups. -#' -#' @return Numeric -#' @keywords internal -#' -#' @author Nikolai Sellereite -shapley_weights <- function(m, N, n_components, weight_zero_m = 10^6) { - x <- (m - 1) / (N * n_components * (m - n_components)) - x[!is.finite(x)] <- weight_zero_m - x -} - - -#' @keywords internal -helper_feature <- function(m, feature_sample) { - x <- feature_matrix_cpp(feature_sample, m) - dt <- data.table::data.table(x) - cnms <- paste0("V", seq(m)) - data.table::setnames(dt, cnms) - dt[, sample_frequence := as.integer(.N), by = cnms] - dt[, is_duplicate := duplicated(dt)] - dt[, (cnms) := NULL] - - return(dt) -} - - -#' Analogue to feature_exact, but for groups instead. -#' -#' @inheritParams shapley_weights -#' @param group_num List. Contains vector of integers indicating the feature numbers for the -#' different groups. -#' -#' @return data.table with all feature group combinations, shapley weights etc. 
-#' -#' @keywords internal -feature_group <- function(group_num, weight_zero_m = 10^6) { - m <- length(group_num) - dt <- data.table::data.table(id_combination = seq(2^m)) - combinations <- lapply(0:m, utils::combn, x = m, simplify = FALSE) - - dt[, groups := unlist(combinations, recursive = FALSE)] - dt[, features := lapply(groups, FUN = group_fun, group_num = group_num)] - dt[, n_groups := length(groups[[1]]), id_combination] - dt[, n_features := length(features[[1]]), id_combination] - dt[, N := .N, n_groups] - dt[, shapley_weight := shapley_weights(m = m, N = N, n_components = n_groups, weight_zero_m)] - - return(dt) -} - -#' @keywords internal -group_fun <- function(x, group_num) { - if (length(x) != 0) { - unlist(group_num[x]) - } else { - integer(0) - } -} - - -#' Analogue to feature_not_exact, but for groups instead. -#' -#' Analogue to feature_not_exact, but for groups instead. -#' -#' @inheritParams shapley_weights -#' @inheritParams feature_group -#' -#' @return data.table with all feature group combinations, shapley weights etc. 
-#' -#' @keywords internal -feature_group_not_exact <- function(group_num, n_combinations = 200, weight_zero_m = 10^6) { - # Find weights for given number of features ---------- - m <- length(group_num) - n_groups <- seq(m - 1) - n <- sapply(n_groups, choose, n = m) - w <- shapley_weights(m = m, N = n, n_groups) * n - p <- w / sum(w) - - # Sample number of chosen features ---------- - feature_sample_all <- list() - unique_samples <- 0 - - while (unique_samples < n_combinations - 2) { - # Sample number of chosen features ---------- - n_features_sample <- sample( - x = n_groups, - size = n_combinations - unique_samples - 2, # Sample -2 as we add zero and m samples below - replace = TRUE, - prob = p - ) - - # Sample specific set of features ------- - feature_sample <- sample_features_cpp(m, n_features_sample) - feature_sample_all <- c(feature_sample_all, feature_sample) - unique_samples <- length(unique(feature_sample_all)) - } - - # Add zero and m features - feature_sample_all <- c(list(integer(0)), feature_sample_all, list(c(1:m))) - X <- data.table(n_groups = sapply(feature_sample_all, length)) - X[, n_groups := as.integer(n_groups)] - - - # Get number of occurences and duplicated rows------- - is_duplicate <- NULL # due to NSE notes in R CMD check - r <- helper_feature(m, feature_sample_all) - X[, is_duplicate := r[["is_duplicate"]]] - - # When we sample combinations the Shapley weight is equal - # to the frequency of the given combination - X[, shapley_weight := r[["sample_frequence"]]] - - # Populate table and remove duplicated rows ------- - X[, groups := feature_sample_all] - if (any(X[["is_duplicate"]])) { - X <- X[is_duplicate == FALSE] - } - X[, is_duplicate := NULL] - - # Make group list into character - X[, groups_tmp := sapply(groups, paste, collapse = " ")] - - # Aggregate weights by how many samples of a combination we have - X <- X[, .( - n_groups = data.table::first(n_groups), - shapley_weight = sum(shapley_weight), - groups = groups[1] - ), 
groups_tmp] - - X[, groups_tmp := NULL] - data.table::setorder(X, n_groups) - - - # Add shapley weight and number of combinations - X[c(1, .N), shapley_weight := weight_zero_m] - X[, N := 1] - ind <- X[, .I[data.table::between(n_groups, 1, m - 1)]] - X[ind, p := p[n_groups]] - X[ind, N := n[n_groups]] - - # Adding feature info - X[, features := lapply(groups, FUN = group_fun, group_num = group_num)] - X[, n_features := sapply(X$features, length)] - - # Set column order and key table - data.table::setkeyv(X, "n_groups") - X[, id_combination := .I] - X[, N := as.integer(N)] - nms <- c("id_combination", "groups", "features", "n_groups", "n_features", "N", "shapley_weight", "p") - data.table::setcolorder(X, nms) - - return(X) -} - -#' Calculate weighted matrix -#' -#' @param X data.table -#' @param normalize_W_weights Logical. Whether to normalize the weights for the combinations to sum to 1 for -#' increased numerical stability before solving the WLS (weighted least squares). Applies to all combinations -#' except combination `1` and `2^m`. -#' @param is_groupwise Logical. Indicating whether group wise Shapley values are to be computed. -#' -#' @return Numeric matrix. See [weight_matrix_cpp()] for more information. 
-#' @keywords internal -#' -#' @author Nikolai Sellereite, Martin Jullum -weight_matrix <- function(X, normalize_W_weights = TRUE, is_groupwise = FALSE) { - # Fetch weights - w <- X[["shapley_weight"]] - - if (normalize_W_weights) { - w[-c(1, length(w))] <- w[-c(1, length(w))] / sum(w[-c(1, length(w))]) - } - - if (!is_groupwise) { - W <- weight_matrix_cpp( - subsets = X[["features"]], - m = X[.N][["n_features"]], - n = X[, .N], - w = w - ) - } else { - W <- weight_matrix_cpp( - subsets = X[["groups"]], - m = X[.N][["n_groups"]], - n = X[, .N], - w = w - ) - } - - return(W) -} - -#' @keywords internal -create_S_batch_new <- function(internal, seed = NULL) { - n_features0 <- internal$parameters$n_features - approach0 <- internal$parameters$approach - n_combinations <- internal$parameters$n_combinations - n_batches <- internal$parameters$n_batches - - X <- internal$objects$X - - if (!is.null(seed)) set.seed(seed) - - if (length(approach0) > 1) { - X[!(n_features %in% c(0, n_features0)), approach := approach0[n_features]] - - # Finding the number of batches per approach - batch_count_dt <- X[!is.na(approach), list( - n_batches_per_approach = - pmax(1, round(.N / (n_combinations - 2) * n_batches)), - n_S_per_approach = .N - ), by = approach] - - # Ensures that the number of batches corresponds to `n_batches` - if (sum(batch_count_dt$n_batches_per_approach) != n_batches) { - # Ensure that the number of batches is not larger than `n_batches`. - # Remove one batch from the approach with the most batches. - while (sum(batch_count_dt$n_batches_per_approach) > n_batches) { - batch_count_dt[ - which.max(n_batches_per_approach), - n_batches_per_approach := n_batches_per_approach - 1 - ] - } - - # Ensure that the number of batches is not lower than `n_batches`. 
- # Add one batch to the approach with most coalitions per batch - while (sum(batch_count_dt$n_batches_per_approach) < n_batches) { - batch_count_dt[ - which.max(n_S_per_approach / n_batches_per_approach), - n_batches_per_approach := n_batches_per_approach + 1 - ] - } - } - - batch_count_dt[, n_leftover_first_batch := n_S_per_approach %% n_batches_per_approach] - data.table::setorder(batch_count_dt, -n_leftover_first_batch) - - approach_vec <- batch_count_dt[, approach] - n_batch_vec <- batch_count_dt[, n_batches_per_approach] - - # Randomize order before ordering spreading the batches on the different approaches as evenly as possible - # with respect to shapley_weight - X[, randomorder := sample(.N)] - data.table::setorder(X, randomorder) # To avoid smaller id_combinations always proceeding large ones - data.table::setorder(X, shapley_weight) - - batch_counter <- 0 - for (i in seq_along(approach_vec)) { - X[approach == approach_vec[i], batch := ceiling(.I / .N * n_batch_vec[i]) + batch_counter] - batch_counter <- X[approach == approach_vec[i], max(batch)] - } - } else { - X[!(n_features %in% c(0, n_features0)), approach := approach0] - - # Spreading the batches - X[, randomorder := sample(.N)] - data.table::setorder(X, randomorder) - data.table::setorder(X, shapley_weight) - X[!(n_features %in% c(0, n_features0)), batch := ceiling(.I / .N * n_batches)] - } - - # Assigning batch 1 (which always is the smallest) to the full prediction. 
- X[, randomorder := NULL] - X[id_combination == max(id_combination), batch := 1] - setkey(X, id_combination) - - # Create a list of the batch splits - S_groups <- split(X[id_combination != 1, id_combination], X[id_combination != 1, batch]) - - return(S_groups) -} diff --git a/R/shapley_setup.R b/R/shapley_setup.R new file mode 100644 index 0000000000000000000000000000000000000000..bfd6e7c8fd03846a26a47ee17ef6ca4faf65a451 --- /dev/null +++ b/R/shapley_setup.R @@ -0,0 +1,777 @@ +#' Set up the kernelSHAP framework +#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +shapley_setup <- function(internal) { + verbose <- internal$parameters$verbose + n_shapley_values <- internal$parameters$n_shapley_values + n_features <- internal$parameters$n_features + approach <- internal$parameters$approach + is_groupwise <- internal$parameters$is_groupwise + paired_shap_sampling <- internal$parameters$paired_shap_sampling + kernelSHAP_reweighting <- internal$parameters$kernelSHAP_reweighting + coal_feature_list <- internal$objects$coal_feature_list + causal_sampling <- internal$parameters$causal_sampling + causal_ordering <- internal$parameters$causal_ordering + causal_ordering_features <- internal$parameters$causal_ordering_features + confounding <- internal$parameters$confounding + dt_valid_causal_coalitions <- internal$objects$dt_valid_causal_coalitions # NULL if asymmetric is FALSE + max_n_coalitions_causal <- internal$parameters$max_n_coalitions_causal # NULL if asymmetric is FALSE + + + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + exact <- internal$iter_list[[iter]]$exact + prev_coal_samples <- internal$iter_list[[iter]]$prev_coal_samples + + if ("progress" %in% verbose) { + cli::cli_progress_step("Sampling coalitions") + } + + + # dt_valid_causal_coalitions is only relevant for asymmetric Shapley values + X <- create_coalition_table( + m = n_shapley_values, + exact = exact, + n_coalitions = 
n_coalitions, + weight_zero_m = 10^6, + paired_shap_sampling = paired_shap_sampling, + prev_coal_samples = prev_coal_samples, + coal_feature_list = coal_feature_list, + approach0 = approach, + kernelSHAP_reweighting = kernelSHAP_reweighting, + dt_valid_causal_coalitions = dt_valid_causal_coalitions + ) + + + + coalition_map <- X[, .(id_coalition, + coalitions_str = sapply(coalitions, paste, collapse = " ") + )] + + + # Get weighted matrix ---------------- + W <- weight_matrix( + X = X, + normalize_W_weights = TRUE + ) + + + ## Get feature matrix --------- + S <- coalition_matrix_cpp( + coalitions = X[["features"]], + m = n_features + ) + + #### Updating parameters #### + + # Updating parameters$exact as done in create_coalition_table. I don't think this is necessary now. TODO: Check. + # Moreover, it does not apply to grouping, so must be adjusted anyway. + if (!exact && n_coalitions >= min(2^n_shapley_values, max_n_coalitions_causal)) { + internal$iter_list[[iter]]$exact <- TRUE + internal$parameters$exact <- TRUE # Since this means that all coalitions have been sampled + } + + # Updating n_coalitions in the end based on what is actually used. I don't think this is necessary now. TODO: Check. 
+ internal$iter_list[[iter]]$n_coalitions <- nrow(S) + + # This will be obsolete later + internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed + # instead of storing it + + + if (isFALSE(exact)) { + # Storing the feature samples + repetitions <- X[-c(1, .N), sample_freq] + + unique_coal_samples <- X[-c(1, .N), coalitions] + + coal_samples <- unlist( + lapply( + seq_along(unique_coal_samples), + function(i) { + rep( + list(unique_coal_samples[[i]]), + repetitions[i] + ) + } + ), + recursive = FALSE + ) + } else { + coal_samples <- NA + } + + internal$iter_list[[iter]]$X <- X + internal$iter_list[[iter]]$W <- W + internal$iter_list[[iter]]$S <- S + internal$iter_list[[iter]]$coalition_map <- coalition_map + internal$iter_list[[iter]]$S_batch <- create_S_batch(internal) + internal$iter_list[[iter]]$coal_samples <- coal_samples + + # If we are doing causal Shapley values, then get the step-wise data generating process for each coalition + if (causal_sampling) { + # Convert causal_ordering to be on the feature level also for group-wise Shapley values, + # as shapr must know the features to include in each causal sampling step and not the group. 
+ causal_ordering <- if (is_groupwise) causal_ordering_features else causal_ordering + S_causal_steps <- get_S_causal_steps(S = S, causal_ordering = causal_ordering, confounding = confounding) + S_causal_steps_strings <- + get_S_causal_steps(S = S, causal_ordering = causal_ordering, confounding = confounding, as_string = TRUE) + + # Find all unique set of features to condition on + S_causal_unlist <- do.call(c, unlist(S_causal_steps, recursive = FALSE)) + S_causal_steps_unique <- unique(S_causal_unlist[grepl("\\.S(?!bar)", names(S_causal_unlist), perl = TRUE)]) # Get S + S_causal_steps_unique <- S_causal_steps_unique[!sapply(S_causal_steps_unique, is.null)] # Remove NULLs + S_causal_steps_unique <- S_causal_steps_unique[lengths(S_causal_steps_unique) > 0] # Remove extra integer(0) + S_causal_steps_unique <- c(list(integer(0)), S_causal_steps_unique, list(seq(n_shapley_values))) + S_causal_steps_unique_S <- coalition_matrix_cpp(coalitions = S_causal_steps_unique, m = n_shapley_values) + + # Insert into the internal list + internal$iter_list[[iter]]$S_causal_steps <- S_causal_steps + internal$iter_list[[iter]]$S_causal_steps_strings <- S_causal_steps_strings + internal$iter_list[[iter]]$S_causal_steps_unique <- S_causal_steps_unique + internal$iter_list[[iter]]$S_causal_steps_unique_S <- S_causal_steps_unique_S + } + + return(internal) +} + +#' Define coalitions, and fetch additional information about each unique coalition +#' +#' @param m Positive integer. +#' Total number of features/groups. +#' @param exact Logical. +#' If `TRUE` all `2^m` coalitions are generated, otherwise a subsample of the coalitions is used. +#' @param n_coalitions Positive integer. +#' Note that if `exact = TRUE`, `n_coalitions` is ignored. +#' @param weight_zero_m Numeric. +#' The value to use as a replacement for infinite coalition weights when doing numerical operations. +#' @param paired_shap_sampling Logical. +#' Whether to do paired sampling of coalitions. 
+#' @param prev_coal_samples List. +#' A list of previously sampled coalitions. +#' @param approach0 Character vector. +#' Contains the approach to be used for estimation of each coalition size. Same as `approach` in `explain()`. +#' @param coal_feature_list List. +#' A list mapping each coalition to the features it contains. +#' @param dt_valid_causal_coalitions data.table. Only applicable for asymmetric Shapley +#' values explanations, and is `NULL` for symmetric Shapley values. +#' The data.table contains information about the coalitions that respect the causal ordering. +#' @inheritParams explain +#' @return A data.table with one row per coalition, holding its id, members, size, Shapley weight and features. +#' +#' @export +#' +#' @author Nikolai Sellereite, Martin Jullum +#' +#' @examples +#' # All coalitions +#' x <- create_coalition_table(m = 3) +#' nrow(x) # Equals 2^3 = 8 +#' +#' # Subsample of coalitions +#' x <- create_coalition_table(exact = FALSE, m = 10, n_coalitions = 1e2) +create_coalition_table <- function(m, + exact = TRUE, + n_coalitions = 200, + weight_zero_m = 10^6, + paired_shap_sampling = TRUE, + prev_coal_samples = NULL, + coal_feature_list = as.list(seq_len(m)), + approach0 = "gaussian", + kernelSHAP_reweighting = "none", + dt_valid_causal_coalitions = NULL) { + if (exact) { + dt <- exact_coalition_table( + m = m, + weight_zero_m = weight_zero_m, + dt_valid_causal_coalitions = dt_valid_causal_coalitions + ) + } else { + dt <- sample_coalition_table( + m = m, + n_coalitions = n_coalitions, + weight_zero_m = weight_zero_m, + paired_shap_sampling = paired_shap_sampling, + prev_coal_samples = prev_coal_samples, + kernelSHAP_reweighting = kernelSHAP_reweighting, + dt_valid_causal_coalitions = dt_valid_causal_coalitions + ) + stopifnot( + data.table::is.data.table(dt), + !is.null(dt[["p"]]) + ) + p <- NULL # due to NSE notes in R CMD check + dt[, p := NULL] + } + + dt[, features := lapply(coalitions, FUN = coal_feature_mapper, coal_feature_list = coal_feature_list)] + + # 
Adding approach to X (needed for the combined approaches) + if (length(approach0) > 1) { + dt[!(coalition_size %in% c(0, m)), approach := approach0[coalition_size]] + } else { + dt[, approach := approach0] + } + + return(dt) +} + +#' @keywords internal +kernelSHAP_reweighting <- function(X, reweight = "on_N") { + # Updates the shapley weights in X based on the reweighting strategy BY REFERENCE + + + if (reweight == "on_N") { + X[-c(1, .N), shapley_weight := mean(shapley_weight), by = N] + } else if (reweight == "on_coal_size") { + X[-c(1, .N), shapley_weight := mean(shapley_weight), by = coalition_size] + } else if (reweight == "on_all") { + m <- X[.N, coalition_size] + X[-c(1, .N), shapley_weight := shapley_weights( + m = m, + N = N, + n_components = coalition_size, + weight_zero_m = 10^6 + ) / sum_shapley_weights(m)] + } else if (reweight == "on_N_sum") { + X[-c(1, .N), shapley_weight := sum(shapley_weight), by = N] + } else if (reweight == "on_all_cond") { + m <- X[.N, coalition_size] + K <- X[, sum(sample_freq)] + X[-c(1, .N), shapley_weight := shapley_weights( + m = m, + N = N, + n_components = coalition_size, + weight_zero_m = 10^6 + ) / sum_shapley_weights(m)] + X[-c(1, .N), cond := 1 - (1 - shapley_weight)^K] + X[-c(1, .N), shapley_weight := shapley_weight / cond] + } else if (reweight == "on_all_cond_paired") { + m <- X[.N, coalition_size] + K <- X[, sum(sample_freq)] + X[-c(1, .N), shapley_weight := shapley_weights( + m = m, + N = N, + n_components = coalition_size, + weight_zero_m = 10^6 + ) / sum_shapley_weights(m)] + X[-c(1, .N), cond := 1 - (1 - 2 * shapley_weight)^(K / 2)] + X[-c(1, .N), shapley_weight := 2 * shapley_weight / cond] + } + # strategy= "none" or something else do nothing + return(NULL) +} + + +#' @keywords internal +exact_coalition_table <- function(m, dt_valid_causal_coalitions = NULL, weight_zero_m = 10^6) { + # Create all valid coalitions for regular/symmetric or asymmetric Shapley values + if (is.null(dt_valid_causal_coalitions)) { 
+ # Regular/symmetric Shapley values: use all 2^m coalitions + coalitions0 <- unlist(lapply(0:m, utils::combn, x = m, simplify = FALSE), recursive = FALSE) + } else { + # Asymmetric Shapley values: use only the coalitions that respect the causal ordering + coalitions0 <- dt_valid_causal_coalitions[, coalitions] + } + + dt <- data.table::data.table(id_coalition = seq_along(coalitions0)) + dt[, coalitions := coalitions0] + dt[, coalition_size := length(coalitions[[1]]), id_coalition] + dt[, N := .N, coalition_size] + dt[, shapley_weight := shapley_weights(m = m, N = N, n_components = coalition_size, weight_zero_m)] + dt[, sample_freq := NA] + return(dt) +} + +#' @keywords internal +sample_coalition_table <- function(m, + n_coalitions = 200, + weight_zero_m = 10^6, + paired_shap_sampling = TRUE, + prev_coal_samples = NULL, + kernelSHAP_reweighting, + valid_causal_coalitions = NULL, + dt_valid_causal_coalitions = NULL) { + # Setup + coal_samp_vec <- seq(m - 1) + n <- choose(m, coal_samp_vec) + w <- shapley_weights(m = m, N = n, coal_samp_vec) * n + p <- w / sum(w) + + if (!is.null(prev_coal_samples)) { + coal_sample_all <- prev_coal_samples + unique_samples <- length(unique(prev_coal_samples)) + n_coalitions <- min(2^m, n_coalitions) + # Adjusts for the unique samples, zero and m samples + } else { + coal_sample_all <- list() + unique_samples <- 0 + } + + # Split in whether we do asymmetric or symmetric/regular Shapley values + if (!is.null(dt_valid_causal_coalitions)) { + # Asymmetric Shapley values + while (unique_samples < n_coalitions - 2) { # Sample until we have the right number of unique coalitions + + # Get the number of causal coalitions to sample + n_samps <- n_coalitions - unique_samples - 2 # Sample -2 as we add zero and m samples below + + # Sample the causal coalitions from the valid causal coalitions with the Shapley weight as the probability + # The weight of each coalition size is split evenly among the members of each coalition size, such that + 
# all.equal(p, dt_valid_causal_coalitions[-c(1,.N), sum(shapley_weight), by = coalition_size][, V1]) + coal_sample <- + dt_valid_causal_coalitions[-c(1, .N)][sample(.N, n_samps, replace = TRUE, prob = shapley_weight), coalitions] + + # Add the samples + coal_sample_all <- c(coal_sample_all, coal_sample) + + # Find the number of unique samples + unique_samples <- length(unique(coal_sample_all)) + } + } else { + # Symmetric/regular Shapley values + while (unique_samples < n_coalitions - 2) { # Sample until we have the right number of unique coalitions + if (paired_shap_sampling == TRUE) { + n_samps <- ceiling((n_coalitions - unique_samples - 2) / 2) # Sample -2 as we add zero and m samples below + } else { + n_samps <- n_coalitions - unique_samples - 2 # Sample -2 as we add zero and m samples below + } + + # Sample the coalition size ---------- + coal_size_sample <- sample( + x = coal_samp_vec, + size = n_samps, + replace = TRUE, + prob = p + ) + + # Sample specific coalitions ------- + coal_sample <- sample_features_cpp(m, coal_size_sample) + if (paired_shap_sampling == TRUE) { + coal_sample_paired <- lapply(coal_sample, function(x) seq(m)[-x]) + coal_sample_all <- c(coal_sample_all, coal_sample, coal_sample_paired) + } else { + coal_sample_all <- c(coal_sample_all, coal_sample) + } + unique_samples <- length(unique(coal_sample_all)) + } + } + + # Add zero and full prediction + coal_sample_all <- c(list(integer(0)), coal_sample_all, list(c(1:m))) + X <- data.table(coalition_size = sapply(coal_sample_all, length)) + X[, coalition_size := as.integer(coalition_size)] + + # Get number of occurences and duplicated rows------- + is_duplicate <- NULL # due to NSE notes in R CMD check + r <- helper_feature(m, coal_sample_all) + X[, is_duplicate := r[["is_duplicate"]]] + + # When we sample coalitions the Shapley weight is equal + # to the frequency of the given coalition + X[, sample_freq := r[["sample_frequence"]]] # We keep an unscaled version of the sampling frequency for 
bootstrapping + X[, shapley_weight := as.numeric(sample_freq)] # Convert to double for later calculations + + # Populate table and remove duplicated rows ------- + X[, coalitions := coal_sample_all] + if (any(X[["is_duplicate"]])) { + X <- X[is_duplicate == FALSE] + } + X[, is_duplicate := NULL] + data.table::setkeyv(X, "coalition_size") + + + #### TODO: Check if this could be removed: #### + ### Start of possible removal ### + # Make feature list into character + X[, coalitions_tmp := sapply(coalitions, paste, collapse = " ")] + + # Aggregate weights by how many samples of a coalition we observe + X <- X[, .( + coalition_size = data.table::first(coalition_size), + shapley_weight = sum(shapley_weight), + sample_freq = sum(sample_freq), + coalitions = coalitions[1] + ), coalitions_tmp] + + #### End of possible removal #### + + data.table::setorder(X, coalition_size) + + # Add shapley weight and number of coalitions + X[c(1, .N), shapley_weight := weight_zero_m] + X[, N := 1] + ind <- X[, .I[data.table::between(coalition_size, 1, m - 1)]] + X[ind, p := p[coalition_size]] + + if (!is.null(dt_valid_causal_coalitions)) { + # Asymmetric Shapley values + # Get the number of coalitions of each coalition size from the `dt_valid_causal_coalitions` data table + X[dt_valid_causal_coalitions, on = "coalitions_tmp", N := i.N] + } else { + # Symmetric/regular Shapley values + X[ind, N := n[coalition_size]] + } + + X[, coalitions_tmp := NULL] + + # Set column order and key table + data.table::setkeyv(X, "coalition_size") + X[, id_coalition := .I] + X[, N := as.integer(N)] + nms <- c("id_coalition", "coalitions", "coalition_size", "N", "shapley_weight", "p", "sample_freq") + data.table::setcolorder(X, nms) + + kernelSHAP_reweighting(X, reweight = kernelSHAP_reweighting) # Reweights the shapley weights in X by reference + + return(X) +} + + +#' Calculate Shapley weight +#' +#' @param m Positive integer. Total number of features/feature groups. 
+#' @param n_components Positive integer. Represents the number of features/feature groups you want to sample from +#' a feature space consisting of `m` unique features/feature groups. Note that ` 0 < = n_components <= m`. +#' @param N Positive integer. The number of unique coalitions when sampling `n_components` features/feature +#' groups, without replacement, from a sample space consisting of `m` different features/feature groups. +#' @param weight_zero_m Positive integer. Represents the Shapley weight for two special +#' cases, i.e. the case where you have either `0` or `m` features/feature groups. +#' +#' @return Numeric +#' @keywords internal +#' +#' @author Nikolai Sellereite +shapley_weights <- function(m, N, n_components, weight_zero_m = 10^6) { + x <- (m - 1) / (N * n_components * (m - n_components)) + x[!is.finite(x)] <- weight_zero_m + x +} + +#' @keywords internal +sum_shapley_weights <- function(m) { + coal_samp_vec <- seq(m - 1) + n <- sapply(coal_samp_vec, choose, n = m) + w <- shapley_weights(m = m, N = n, coal_samp_vec) * n + return(sum(w)) +} + + +#' @keywords internal +helper_feature <- function(m, coal_sample) { + x <- coalition_matrix_cpp(coal_sample, m) + dt <- data.table::data.table(x) + cnms <- paste0("V", seq(m)) + data.table::setnames(dt, cnms) + dt[, sample_frequence := as.integer(.N), by = cnms] + dt[, is_duplicate := duplicated(dt)] + dt[, (cnms) := NULL] + + return(dt) +} + + + + +#' @keywords internal +coal_feature_mapper <- function(x, coal_feature_list) { + if (length(x) != 0) { + unlist(coal_feature_list[x]) + } else { + integer(0) + } +} + +#' Calculate weighted matrix +#' +#' @param X data.table +#' @param normalize_W_weights Logical. Whether to normalize the weights for the coalitions to sum to 1 for +#' increased numerical stability before solving the WLS (weighted least squares). Applies to all coalitions +#' except coalition `1` and `2^m`. +#' +#' @return Numeric matrix. See [weight_matrix_cpp()] for more information. 
+#' @keywords internal +#' +#' @export +#' @author Nikolai Sellereite, Martin Jullum +weight_matrix <- function(X, normalize_W_weights = TRUE) { + # Fetch weights + w <- X[["shapley_weight"]] + + if (normalize_W_weights) { + w[-c(1, length(w))] <- w[-c(1, length(w))] / sum(w[-c(1, length(w))]) + } + + W <- weight_matrix_cpp( + coalitions = X[["coalitions"]], + m = X[.N][["coalition_size"]], + n = X[, .N], + w = w + ) + return(W) +} + +#' @keywords internal +create_S_batch <- function(internal, seed = NULL) { + n_shapley_values <- internal$parameters$n_shapley_values + approach0 <- internal$parameters$approach + type <- internal$parameters$type + + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + n_batches <- internal$iter_list[[iter]]$n_batches + + exact <- internal$iter_list[[iter]]$exact + + + coalition_map <- internal$iter_list[[iter]]$coalition_map + + if (type == "forecast") { + id_coalition_mapper_dt <- internal$iter_list[[iter]]$id_coalition_mapper_dt + full_ids <- id_coalition_mapper_dt$id_coalition[id_coalition_mapper_dt$full] + } + + X0 <- copy(internal$iter_list[[iter]]$X) + + if (iter > 1) { + prev_coalition_map <- internal$iter_list[[iter - 1]]$coalition_map + new_id_coalitions <- coalition_map[ + !(coalitions_str %in% prev_coalition_map[-c(1, .N), coalitions_str, ]), + id_coalition + ] + X0 <- X0[id_coalition %in% new_id_coalitions] + } + + # Reduces n_batches if it is larger than the number of new_id_coalitions + n_batches <- min(n_batches, X0[, .N] - 2) + + + if (!is.null(seed)) set.seed(seed) + + if (length(approach0) > 1) { + if (type == "forecast") { + X0[!(coalition_size == 0 | id_coalition %in% full_ids), approach := approach0[coalition_size]] + } else { + X0[!(coalition_size %in% c(0, n_shapley_values)), approach := approach0[coalition_size]] + } + + # Finding the number of batches per approach + batch_count_dt <- X0[!is.na(approach), list( + n_batches_per_approach = + pmax(1, round(.N / 
(n_coalitions - 2) * n_batches)), + n_S_per_approach = .N + ), by = approach] + + # Ensures that the number of batches corresponds to `n_batches` + if (sum(batch_count_dt$n_batches_per_approach) != n_batches) { + # Ensure that the number of batches is not larger than `n_batches`. + # Remove one batch from the approach with the most batches. + while (sum(batch_count_dt$n_batches_per_approach) > n_batches) { + batch_count_dt[ + which.max(n_batches_per_approach), + n_batches_per_approach := n_batches_per_approach - 1 + ] + } + + # Ensure that the number of batches is not lower than `n_batches`. + # Add one batch to the approach with most coalitions per batch + while (sum(batch_count_dt$n_batches_per_approach) < n_batches) { + batch_count_dt[ + which.max(n_S_per_approach / n_batches_per_approach), + n_batches_per_approach := n_batches_per_approach + 1 + ] + } + } + + batch_count_dt[, n_leftover_first_batch := n_S_per_approach %% n_batches_per_approach] + data.table::setorder(batch_count_dt, -n_leftover_first_batch) + + approach_vec <- batch_count_dt[, approach] + n_batch_vec <- batch_count_dt[, n_batches_per_approach] + + # Randomize order before ordering spreading the batches on the different approaches as evenly as possible + # with respect to shapley_weight + X0[, randomorder := sample(.N)] + data.table::setorder(X0, randomorder) # To avoid smaller id_coalitions always proceeding large ones + data.table::setorder(X0, shapley_weight) + + batch_counter <- 0 + for (i in seq_along(approach_vec)) { + X0[approach == approach_vec[i], batch := ceiling(.I / .N * n_batch_vec[i]) + batch_counter] + batch_counter <- X0[approach == approach_vec[i], max(batch)] + } + } else { + if (type == "forecast") { + X0[!(coalition_size == 0 | id_coalition %in% full_ids), approach := approach0] + } else { + X0[!(coalition_size %in% c(0, n_shapley_values)), approach := approach0] + } + + # Spreading the batches + X0[, randomorder := sample(.N)] + data.table::setorder(X0, randomorder) + 
data.table::setorder(X0, shapley_weight) + if (type == "forecast") { + X0[!(coalition_size == 0 | id_coalition %in% full_ids), batch := ceiling(.I / .N * n_batches)] + } else { + X0[!(coalition_size %in% c(0, n_shapley_values)), batch := ceiling(.I / .N * n_batches)] + } + } + + # Assigning batch 1 (which always is the smallest) to the full prediction. + X0[, randomorder := NULL] + if (type == "forecast") { + X0[id_coalition %in% full_ids, batch := 1] + } else { + X0[id_coalition == max(id_coalition), batch := 1] + } + setkey(X0, id_coalition) + + # Create a list of the batch splits + S_groups <- split(X0[id_coalition != 1, id_coalition], X0[id_coalition != 1, batch]) + + return(S_groups) +} + + +#' Sets up everything for the Shapley values computation in [shapr::explain()] +#' +#' @inheritParams default_doc +#' @inheritParams explain +#' @inherit default_doc +#' @export +setup_computation <- function(internal, model, predict_model) { # Can this function be removed? /Jon + # model and predict_model are only needed for type AICc of approach empirical, otherwise ignored + type <- internal$parameters$type + + # setup the Shapley framework + internal <- if (type == "forecast") shapley_setup_forecast(internal) else shapley_setup(internal) + + # Setup for approach + internal <- setup_approach(internal, model = model, predict_model = predict_model) + + return(internal) +} + +#' @keywords internal +shapley_setup_forecast <- function(internal) { + n_shapley_values <- internal$parameters$n_shapley_values + n_features <- internal$parameters$n_features + approach <- internal$parameters$approach + is_groupwise <- internal$parameters$is_groupwise + paired_shap_sampling <- internal$parameters$paired_shap_sampling + kernelSHAP_reweighting <- internal$parameters$kernelSHAP_reweighting + + coal_feature_list <- internal$objects$coal_feature_list + horizon <- internal$parameters$horizon + horizon_group <- internal$parameters$horizon_group + feature_names <- 
internal$parameters$feature_names + + iter <- length(internal$iter_list) + + n_coalitions <- internal$iter_list[[iter]]$n_coalitions + exact <- internal$iter_list[[iter]]$exact + prev_coal_samples <- internal$iter_list[[iter]]$prev_coal_samples + + X_list <- W_list <- list() + + cols_per_horizon <- internal$parameters$horizon_features + horizon_features <- lapply(cols_per_horizon, function(x) which(internal$parameters$feature_names %in% x)) + + # Apply create_coalition_table, weigth_matrix and coalition_matrix_cpp to each of the different horizons + for (i in seq_along(horizon_features)) { + if (is_groupwise && !is.null(horizon_group)) { + this_coal_feature_list <- coal_feature_list[sapply( + names(coal_feature_list), + function(x) x %in% horizon_group[[i]] + )] + } else { + this_coal_feature_list <- lapply(coal_feature_list, function(x) x[x %in% horizon_features[[i]]]) + this_coal_feature_list <- this_coal_feature_list[sapply(this_coal_feature_list, function(x) length(x) != 0)] + } + + n_this_featcomb <- length(this_coal_feature_list) + n_coalitions_here <- min(2^n_this_featcomb, n_coalitions) + + X_list[[i]] <- create_coalition_table( + m = n_this_featcomb, + exact = exact, + n_coalitions = n_coalitions_here, + weight_zero_m = 10^6, + paired_shap_sampling = paired_shap_sampling, + prev_coal_samples = prev_coal_samples, + coal_feature_list = this_coal_feature_list, + approach0 = approach, + kernelSHAP_reweighting = kernelSHAP_reweighting + ) + + W_list[[i]] <- weight_matrix( + X = X_list[[i]], + normalize_W_weights = TRUE + ) + } + + # Merge the coalition data.table to single one to use for computing conditional expectations later on + X <- rbindlist(X_list, idcol = "horizon") + X[, N := NA] + data.table::setorderv(X, c("coalition_size", "horizon"), order = c(1, -1)) + X[, horizon_id_coalition := id_coalition] + X[, id_coalition := 0] + X[!duplicated(features), id_coalition := .I] + X[, tmp_coalitions := as.character(features)] + X[, id_coalition := 
max(id_coalition), by = tmp_coalitions] + X[, tmp_coalitions := NULL] + + # Extracts a data.table allowing mapping from X to X_list/W_list to be used in the compute_shapley function + id_coalition_mapper_dt <- X[, .(horizon, horizon_id_coalition, id_coalition, full = features %in% horizon_features)] + + X[, horizon := NULL] + X[, horizon_id_coalition := NULL] + data.table::setorder(X, coalition_size) + X <- X[!duplicated(id_coalition)] + + W <- NULL # Included for consistency. Necessary weights are in W_list instead + + coalition_map <- X[, .(id_coalition, + coalitions_str = sapply(features, paste, collapse = " ") + )] + + ## Get feature matrix --------- + S <- coalition_matrix_cpp( + coalitions = X[["features"]], + m = n_features + ) + + + #### Updating parameters #### + + # Updating parameters$exact as done in create_coalition_table + if (!exact && n_coalitions >= 2^n_shapley_values) { + internal$iter_list[[iter]]$exact <- TRUE + internal$parameters$exact <- TRUE # Note that this is exact only if all horizons use the exact method. + } + + internal$iter_list[[iter]]$n_coalitions <- nrow(S) # Updating this parameter in the end based on what is used. 
+ + # This will be obsolete later + internal$parameters$group_num <- NULL # TODO: Checking whether I could just do this processing where needed + # instead of storing it + + internal$iter_list[[iter]]$X <- X + internal$iter_list[[iter]]$W <- W + internal$iter_list[[iter]]$S <- S + internal$iter_list[[iter]]$id_coalition_mapper_dt <- id_coalition_mapper_dt + internal$iter_list[[iter]]$X_list <- X_list + internal$iter_list[[iter]]$coalition_map <- coalition_map + internal$iter_list[[iter]]$S_batch <- create_S_batch(internal) + + internal$objects$cols_per_horizon <- cols_per_horizon + internal$objects$W_list <- W_list + + return(internal) +} diff --git a/R/shapr-package.R b/R/shapr-package.R index fd368e8b4ab46e519853faca83f39d40d2b8a8c7..ee3875f80e57c5d53af25d398515ee33b1195494 100644 --- a/R/shapr-package.R +++ b/R/shapr-package.R @@ -25,8 +25,14 @@ #' #' @importFrom stats rnorm #' +#' @importFrom stats median +#' #' @importFrom Rcpp sourceCpp #' +#' @importFrom utils capture.output +#' +#' @importFrom utils relist +#' #' @keywords internal #' #' @useDynLib shapr, .registration = TRUE diff --git a/R/timing.R b/R/timing.R index b5ac27c9554dc33746de3f04f7f5a329ccc91bbf..2b3188a284a27d70962cdd7463d5ccb4a0185090 100644 --- a/R/timing.R +++ b/R/timing.R @@ -1,16 +1,51 @@ -compute_time <- function(timing_list) { - timing_secs <- mapply( +#' Gathers and computes the timing of the different parts of the explain function. 
+#' +#' @inheritParams default_doc_explain +#' +#' @export +#' @keywords internal +compute_time <- function(internal) { + verbose <- internal$parameters$verbose + + main_timing_list <- internal$main_timing_list + iter_timing_list <- internal$iter_timing_list + + + main_timing_secs <- mapply( FUN = difftime, - timing_list[-1], - timing_list[-length(timing_list)], + main_timing_list[-1], + main_timing_list[-length(main_timing_list)], units = "secs" ) + iter_timing_secs_list <- list() + for (i in seq_along(iter_timing_list)) { + iter_timing_secs_list[[i]] <- as.list(mapply( + FUN = difftime, + iter_timing_list[[i]][-1], + iter_timing_list[[i]][-length(iter_timing_list[[i]])], + units = "secs" + )) + } + iter_timing_secs_dt <- data.table::rbindlist(iter_timing_secs_list) + iter_timing_secs_dt[, total := rowSums(.SD)] + iter_timing_secs_dt[, iter := .I] + data.table::setcolorder(iter_timing_secs_dt, "iter") + + total_time_secs <- main_timing_list[[length(main_timing_list)]] - main_timing_list[[1]] + total_time_secs <- as.double(total_time_secs, units = "secs") + + timing_output <- list( - init_time = timing_list$init, - total_time_secs = sum(timing_secs), - timing_secs = timing_secs + init_time = main_timing_list[[1]], + end_time = main_timing_list[[length(main_timing_list)]], + total_time_secs = total_time_secs, + overall_timing_secs = main_timing_secs, + main_computation_timing_secs = iter_timing_secs_dt[] ) + internal$main_timing_list <- internal$iter_timing_list <- NULL + + return(timing_output) } diff --git a/R/zzz.R b/R/zzz.R index f540d5c86c8a55293b11011cc2d2d67f093aaa1a..a1458be87f803e06d8084b5c8d7e1fc59af5c2a4 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -11,7 +11,7 @@ "N", "id_all", "id", - "id_combination", + "id_coalition", "w", "id_all", "joint_prob", @@ -77,7 +77,7 @@ "batch", "type", "feature_value_factor", - "horizon_id_combination", + "horizon_id_coalition", "tmp_features", "Method", "MSEv", @@ -107,9 +107,55 @@ "x_train_torch", "self", "..current_comb", - 
"..regression.response_var" + "..regression.response_var", + "sample_freq", + "features_dup", + "features_dup_tmp", + "maxval", + "minval", + "req_samples", + "explain_id", + "id_coalition_new", + "features_str", + "boot_id", + "iter", + "total", + "coalitions", + "coalition_size", + "coalitions_tmp", + "initial_n_coalitions", + "max_n_coalitions", + "fixed_n_coalitions_per_iter", + "n_coal_next_iter_factor_vec", + "n_boot_samps", + "compute_sd", + "min_n_batches", + "max_batch_size", + "saving_path", + "coalitions_str", + "cond", + "tmp_coalitions", + "max_iter", + "convergence_tol", + "conv_measure", + "verbose", + "MSEv_uniform_comb_weights", + "keep_samp_for_vS", + "S_original_names_with_id", + "Sbar_features", + "Sbar_now_names", + "cond_cols_with_id", + "dt_factor_names", + "feature_conditioned", + "feature_conditioned_id", + "feature_names", + "relevant_features", + "i.N", + "prob", + "shapley_weight_norm" ) ) + invisible() } diff --git a/README.Rmd b/README.Rmd index aff32ad859105a70875d7c527369e644d08d9078..c61b972bad1e0deb0cf6305a6dfe19dbe51848b0 100644 --- a/README.Rmd +++ b/README.Rmd @@ -28,11 +28,13 @@ knitr::opts_chunk$set( ## Brief NEWS +This is `shapr` version 1.0.0, which provides a full suit of new functionality. +See the [NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md) for details + ### Breaking change (June 2023) As of version 0.2.3.9000, the development version of shapr (master branch on GitHub from June 2023) has been severely restructured, introducing a new syntax for explaining models, and thereby introducing a range of breaking changes. This essentially amounts to using a single function (`explain()`) instead of two functions (`shapr()` and `explain()`). The CRAN version of `shapr` (v0.2.2) still uses the old syntax. -See the [NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md) for details. The examples below uses the new syntax. 
[Here](https://github.com/NorskRegnesentral/shapr/blob/cranversion_0.2.2/README.md) is a version of this README with the syntax of the CRAN version (v0.2.2). @@ -41,63 +43,22 @@ The examples below uses the new syntax. As of version 0.2.3.9100 (master branch on GitHub from June 2023), we provide a Python wrapper (`shaprpy`) which allows explaining python models with the methodology implemented in `shapr`, directly from Python. The wrapper is available [here](https://github.com/NorskRegnesentral/shapr/tree/master/python). See also details in the [NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md). -## Introduction - -The most common machine learning task is to train a model which is able to predict an unknown outcome (response variable) based on a set of known input variables/features. -When using such models for real life applications, it is often crucial to understand why a certain set of features lead to exactly that prediction. -However, explaining predictions from complex, or seemingly simple, machine learning models is a practical and ethical question, as well as a legal issue. Can I trust the model? Is it biased? Can I explain it to others? We want to explain individual predictions from a complex machine learning model by learning simple, interpretable explanations. - -Shapley values is the only prediction explanation framework with a solid theoretical foundation (@lundberg2017unified). Unless the true distribution of the features are known, and there are less than say 10-15 features, these Shapley values needs to be estimated/approximated. -Popular methods like Shapley Sampling Values (@vstrumbelj2014explaining), SHAP/Kernel SHAP (@lundberg2017unified), and to some extent TreeSHAP (@lundberg2018consistent), assume that the features are independent when approximating the Shapley values for prediction explanation. This may lead to very inaccurate Shapley values, and consequently wrong interpretations of the predictions. 
@aas2019explaining extends and improves the Kernel SHAP method of @lundberg2017unified to account for the dependence between the features, resulting in significantly more accurate approximations to the Shapley values. -[See the paper for details](https://arxiv.org/abs/1903.10464). - -This package implements the methodology of @aas2019explaining. - -The following methodology/features are currently implemented: - -- Native support of explanation of predictions from models fitted with the following functions -`stats::glm`, `stats::lm`,`ranger::ranger`, `xgboost::xgboost`/`xgboost::xgb.train` and `mgcv::gam`. -- Accounting for feature dependence - * assuming the features are Gaussian (`approach = 'gaussian'`, @aas2019explaining) - * with a Gaussian copula (`approach = 'copula'`, @aas2019explaining) - * using the Mahalanobis distance based empirical (conditional) distribution approach (`approach = 'empirical'`, @aas2019explaining) - * using conditional inference trees (`approach = 'ctree'`, @redelmeier2020explaining). - * using the endpoint match method for time series (`approach = 'timeseries'`, @jullum2021efficient) - * using the joint distribution approach for models with purely cateogrical data (`approach = 'categorical'`, @redelmeier2020explaining) - * assuming all features are independent (`approach = 'independence'`, mainly for benchmarking) -- Combining any of the above methods. -- Explain *forecasts* from time series models at different horizons with `explain_forecast()` (R only) -- Batch computation to reduce memory consumption significantly -- Parallelized computation using the [future](https://future.futureverse.org/) framework. (R only) -- Progress bar showing computation progress, using the [`progressr`](https://progressr.futureverse.org/) package. Must be activated by the user. -- Optional use of the AICc criterion of @hurvich1998smoothing when optimizing the bandwidth parameter in the empirical (conditional) approach of @aas2019explaining. 
-- Functionality for visualizing the explanations. (R only) -- Support for models not supported natively. - - +## The package +The shapr R package implements an enhanced version of the KernelSHAP method, for approximating Shapley values, +with a strong focus on conditional Shapley values. +The core idea is to remain completely model-agnostic while offering a variety of methods for estimating contribution +functions, enabling accurate computation of conditional Shapley values across different feature types, dependencies, +and distributions. +The package also includes evaluation metrics to compare various approaches. +With features like parallelized computations, convergence detection, progress updates, and extensive plotting options, +shapr is as a highly efficient and user-friendly tool, delivering precise estimates of conditional Shapley values, +which are critical for understanding how features truly contribute to predictions. -Note the prediction outcome must be numeric. -All approaches except `approach = 'categorical'` works for numeric features, but unless the models are very gaussian-like, we recommend `approach = 'ctree'` or `approach = 'empirical'`, especially if there are discretely distributed features. -When the models contains both numeric and categorical features, we recommend `approach = 'ctree'`. -For models with a smaller number of categorical features (without many levels) and a decent training set, we recommend `approach = 'categorical'`. -For (binary) classification based on time series models, we suggest using `approach = 'timeseries'`. -To explain forecasts of time series models (at different horizons), we recommend using `explain_forecast()` instead of `explain()`. -The former has a more suitable input syntax for explaining those kinds of forecasts. -See the [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) for details and further examples. +A basic example is provided below. 
+Otherwise we refer to the [pkgdown website](https://norskregnesentral.github.io/shapr/) and the vignettes there +for details and further examples. -Unlike SHAP and TreeSHAP, we decompose probability predictions directly to ease the interpretability, i.e. not via log odds transformations. ## Installation @@ -171,18 +132,19 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0 ) # Printing the Shapley values for the test data. # For more information about the interpretation of the values in the table, see ?shapr::explain. -print(explanation$shapley_values) +print(explanation$shapley_values_est) # Finally we plot the resulting explanations plot(explanation) ``` -See the [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) for further examples. +See the [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) for further basic usage +examples. ## Contribution diff --git a/README.md b/README.md index eb18d9d998536a09b4c5282b79de97b51a60eed3..6644223af5c09928186397aa0570773248a0efa5 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,11 @@ MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.or ## Brief NEWS +This is `shapr` version 1.0.0, which provides a full suit of new +functionality. See the +[NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md) +for details + ### Breaking change (June 2023) As of version 0.2.3.9000, the development version of shapr (master @@ -26,9 +31,7 @@ introducing a new syntax for explaining models, and thereby introducing a range of breaking changes. This essentially amounts to using a single function (`explain()`) instead of two functions (`shapr()` and `explain()`). The CRAN version of `shapr` (v0.2.2) still uses the old -syntax. See the -[NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md) -for details. The examples below uses the new syntax. 
+syntax. The examples below uses the new syntax. [Here](https://github.com/NorskRegnesentral/shapr/blob/cranversion_0.2.2/README.md) is a version of this README with the syntax of the CRAN version (v0.2.2). @@ -43,106 +46,26 @@ Python. The wrapper is available See also details in the [NEWS](https://github.com/NorskRegnesentral/shapr/blob/master/NEWS.md). -## Introduction - -The most common machine learning task is to train a model which is able -to predict an unknown outcome (response variable) based on a set of -known input variables/features. When using such models for real life -applications, it is often crucial to understand why a certain set of -features lead to exactly that prediction. However, explaining -predictions from complex, or seemingly simple, machine learning models -is a practical and ethical question, as well as a legal issue. Can I -trust the model? Is it biased? Can I explain it to others? We want to -explain individual predictions from a complex machine learning model by -learning simple, interpretable explanations. - -Shapley values is the only prediction explanation framework with a solid -theoretical foundation (Lundberg and Lee (2017)). Unless the true -distribution of the features are known, and there are less than say -10-15 features, these Shapley values needs to be estimated/approximated. -Popular methods like Shapley Sampling Values (Štrumbelj and Kononenko -(2014)), SHAP/Kernel SHAP (Lundberg and Lee (2017)), and to some extent -TreeSHAP (Lundberg, Erion, and Lee (2018)), assume that the features are -independent when approximating the Shapley values for prediction -explanation. This may lead to very inaccurate Shapley values, and -consequently wrong interpretations of the predictions. Aas, Jullum, and -Løland (2021) extends and improves the Kernel SHAP method of Lundberg -and Lee (2017) to account for the dependence between the features, -resulting in significantly more accurate approximations to the Shapley -values. 
[See the paper for details](https://arxiv.org/abs/1903.10464). - -This package implements the methodology of Aas, Jullum, and Løland -(2021). - -The following methodology/features are currently implemented: - -- Native support of explanation of predictions from models fitted with - the following functions `stats::glm`, `stats::lm`,`ranger::ranger`, - `xgboost::xgboost`/`xgboost::xgb.train` and `mgcv::gam`. -- Accounting for feature dependence - - assuming the features are Gaussian (`approach = 'gaussian'`, - Aas, Jullum, and Løland (2021)) - - with a Gaussian copula (`approach = 'copula'`, Aas, Jullum, and - Løland (2021)) - - using the Mahalanobis distance based empirical (conditional) - distribution approach (`approach = 'empirical'`, Aas, Jullum, - and Løland (2021)) - - using conditional inference trees (`approach = 'ctree'`, - Redelmeier, Jullum, and Aas (2020)). - - using the endpoint match method for time series - (`approach = 'timeseries'`, Jullum, Redelmeier, and Aas (2021)) - - using the joint distribution approach for models with purely - cateogrical data (`approach = 'categorical'`, Redelmeier, - Jullum, and Aas (2020)) - - assuming all features are independent - (`approach = 'independence'`, mainly for benchmarking) -- Combining any of the above methods. -- Explain *forecasts* from time series models at different horizons - with `explain_forecast()` (R only) -- Batch computation to reduce memory consumption significantly -- Parallelized computation using the - [future](https://future.futureverse.org/) framework. (R only) -- Progress bar showing computation progress, using the - [`progressr`](https://progressr.futureverse.org/) package. Must be - activated by the user. -- Optional use of the AICc criterion of Hurvich, Simonoff, and - Tsai (1998) when optimizing the bandwidth parameter in the empirical - (conditional) approach of Aas, Jullum, and Løland (2021). -- Functionality for visualizing the explanations. 
(R only) -- Support for models not supported natively. - - - -Note the prediction outcome must be numeric. All approaches except -`approach = 'categorical'` works for numeric features, but unless the -models are very gaussian-like, we recommend `approach = 'ctree'` or -`approach = 'empirical'`, especially if there are discretely distributed -features. When the models contains both numeric and categorical -features, we recommend `approach = 'ctree'`. For models with a smaller -number of categorical features (without many levels) and a decent -training set, we recommend `approach = 'categorical'`. For (binary) -classification based on time series models, we suggest using -`approach = 'timeseries'`. To explain forecasts of time series models -(at different horizons), we recommend using `explain_forecast()` instead -of `explain()`. The former has a more suitable input syntax for -explaining those kinds of forecasts. See the -[vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) +## The package + +The shapr R package implements an enhanced version of the KernelSHAP +method, for approximating Shapley values, with a strong focus on +conditional Shapley values. The core idea is to remain completely +model-agnostic while offering a variety of methods for estimating +contribution functions, enabling accurate computation of conditional +Shapley values across different feature types, dependencies, and +distributions. The package also includes evaluation metrics to compare +various approaches. With features like parallelized computations, +convergence detection, progress updates, and extensive plotting options, +shapr is as a highly efficient and user-friendly tool, delivering +precise estimates of conditional Shapley values, which are critical for +understanding how features truly contribute to predictions. + +A basic example is provided below. 
Otherwise we refer to the [pkgdown +website](https://norskregnesentral.github.io/shapr/) and the vignettes +there for details and further examples. -Unlike SHAP and TreeSHAP, we decompose probability predictions directly -to ease the interpretability, i.e. not via log odds transformations. - ## Installation To install the current stable release from CRAN (note, using the old @@ -227,23 +150,38 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-23 19:31:59 ────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: +#> '/tmp/Rtmp6d4Iza/shapr_obj_3be21200fd9e8.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Printing the Shapley values for the test data. # For more information about the interpretation of the values in the table, see ?shapr::explain. 
-print(explanation$shapley_values) -#> none Solar.R Wind Temp Month -#> 1: 43.08571 13.2117337 4.785645 -25.57222 -5.599230 -#> 2: 43.08571 -9.9727747 5.830694 -11.03873 -7.829954 -#> 3: 43.08571 -2.2916185 -7.053393 -10.15035 -4.452481 -#> 4: 43.08571 3.3254595 -3.240879 -10.22492 -6.663488 -#> 5: 43.08571 4.3039571 -2.627764 -14.15166 -12.266855 -#> 6: 43.08571 0.4786417 -5.248686 -12.55344 -6.645738 +print(explanation$shapley_values_est) +#> explain_id none Solar.R Wind Temp Month +#> +#> 1: 1 43.08571 13.2117337 4.785645 -25.57222 -5.599230 +#> 2: 2 43.08571 -9.9727747 5.830694 -11.03873 -7.829954 +#> 3: 3 43.08571 -2.2916185 -7.053393 -10.15035 -4.452481 +#> 4: 4 43.08571 3.3254595 -3.240879 -10.22492 -6.663488 +#> 5: 5 43.08571 4.3039571 -2.627764 -14.15166 -12.266855 +#> 6: 6 43.08571 0.4786417 -5.248686 -12.55344 -6.645738 # Finally we plot the resulting explanations plot(explanation) @@ -253,7 +191,7 @@ plot(explanation) See the [vignette](https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html) -for further examples. +for further basic usage examples. ## Contribution @@ -269,66 +207,3 @@ Conduct](https://norskregnesentral.github.io/shapr/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms. ## References - -
- -
- -Aas, Kjersti, Martin Jullum, and Anders Løland. 2021. “Explaining -Individual Predictions When Features Are Dependent: More Accurate -Approximations to Shapley Values.” *Artificial Intelligence* 298. - -
- -
- -Hurvich, Clifford M, Jeffrey S Simonoff, and Chih-Ling Tsai. 1998. -“Smoothing Parameter Selection in Nonparametric Regression Using an -Improved Akaike Information Criterion.” *Journal of the Royal -Statistical Society: Series B (Statistical Methodology)* 60 (2): 271–93. - -
- -
- -Jullum, Martin, Annabelle Redelmeier, and Kjersti Aas. 2021. “Efficient -and Simple Prediction Explanations with groupShapley: A Practical -Perspective.” In *Proceedings of the 2nd Italian Workshop on Explainable -Artificial Intelligence*, 28–43. CEUR Workshop Proceedings. - -
- -
- -Lundberg, Scott M, Gabriel G Erion, and Su-In Lee. 2018. “Consistent -Individualized Feature Attribution for Tree Ensembles.” *arXiv Preprint -arXiv:1802.03888*. - -
- -
- -Lundberg, Scott M, and Su-In Lee. 2017. “A Unified Approach to -Interpreting Model Predictions.” In *Advances in Neural Information -Processing Systems*, 4765–74. - -
- -
- -Redelmeier, Annabelle, Martin Jullum, and Kjersti Aas. 2020. “Explaining -Predictive Models with Mixed Features Using Shapley Values and -Conditional Inference Trees.” In *International Cross-Domain Conference -for Machine Learning and Knowledge Extraction*, 117–37. Springer. - -
- -
- -Štrumbelj, Erik, and Igor Kononenko. 2014. “Explaining Prediction Models -and Individual Predictions with Feature Contributions.” *Knowledge and -Information Systems* 41 (3): 647–65. - -
- -
diff --git a/inst/REFERENCES.bib b/inst/REFERENCES.bib index ccce694e24e93f6d254fbd3842f15003cef9a13d..f0b00ec2d3e173f0f3e3ab0cb600e259257a8bff 100644 --- a/inst/REFERENCES.bib +++ b/inst/REFERENCES.bib @@ -176,17 +176,44 @@ year = 1956 } @Manual{torch, - title = {torch: Tensors and Neural Networks with 'GPU' Acceleration}, - author = {Daniel Falbel and Javier Luraschi}, - year = {2023}, - note = {R package version 0.11.0}, - url = {https://CRAN.R-project.org/package=torch}, - } + title = {torch: Tensors and Neural Networks with 'GPU' Acceleration}, + author = {Daniel Falbel and Javier Luraschi}, + year = {2023}, + note = {R package version 0.11.0}, + url = {https://CRAN.R-project.org/package=torch} +} @Manual{tidymodels, - title = {Tidymodels: a collection of packages for modeling and machine learning using tidyverse principles.}, - author = {Max Kuhn and Hadley Wickham}, - url = {https://www.tidymodels.org}, - year = {2020}, - } + title = {Tidymodels: a collection of packages for modeling and machine learning using tidyverse principles.}, + author = {Max Kuhn and Hadley Wickham}, + url = {https://www.tidymodels.org}, + year = {2020} +} + +@article{heskes2020causal, + title={Causal shapley values: Exploiting causal knowledge to explain individual predictions of complex models}, + author={Heskes, Tom and Sijben, Evi and Bucur, Ioan Gabriel and Claassen, Tom}, + journal={Advances in neural information processing systems}, + volume={33}, + pages={4778--4789}, + year={2020} +} + +@article{frye2020asymmetric, + title={Asymmetric shapley values: incorporating causal knowledge into model-agnostic explainability}, + author={Frye, Christopher and Rowat, Colin and Feige, Ilya}, + journal={Advances in Neural Information Processing Systems}, + volume={33}, + pages={1229--1239}, + year={2020} +} + +@inproceedings{covert2021improving, + title={Improving kernelshap: Practical shapley value estimation using linear regression}, + author={Covert, Ian and Lee, Su-In}, + 
booktitle={International Conference on Artificial Intelligence and Statistics}, + pages={3457--3465}, + year={2021}, + organization={PMLR} +} diff --git a/inst/code_paper/code_sec_3.R b/inst/code_paper/code_sec_3.R new file mode 100644 index 0000000000000000000000000000000000000000..55c3668269239b84b54920c45313805f3bc0a3ef --- /dev/null +++ b/inst/code_paper/code_sec_3.R @@ -0,0 +1,137 @@ +library(xgboost) +library(data.table) +library(shapr) + +path <- "inst/code_paper/" +x_explain <- fread(paste0(path, "x_explain.csv")) +x_train <- fread(paste0(path, "x_train.csv")) +y_train <- unlist(fread(paste0(path, "y_train.csv"))) +model <- readRDS(paste0(path, "model.rds")) + + +# We compute the SHAP values for the test data. +library(future) +library(progressr) +future::plan(multisession, workers = 4) +progressr::handlers(global = TRUE) + + +# 20 indep +exp_20_indep <- explain(model = model, + x_explain = x_explain, + x_train = x_train, + max_n_coalitions = 20, + approach = "independence", + phi0 = mean(y_train), + verbose = NULL) + + +# 20 ctree +exp_20_ctree <- explain(model = model, + x_explain = x_explain, + x_train = x_train, + max_n_coalitions = 20, + approach = "ctree", + phi0 = mean(y_train), + verbose = NULL, + ctree.sample = FALSE) + + + +exp_20_indep$MSEv$MSEv +exp_20_ctree$MSEv$MSEv + +##### OUTPUT #### +#> exp_20_indep$MSEv$MSEv +#MSEv MSEv_sd +# +# 1: 1805368 123213.6 +#> exp_20_ctree$MSEv$MSEv +#MSEv MSEv_sd +# +# 1: 1224818 101680.4 + +exp_20_ctree + +### Continued estimation + +exp_iter_ctree <- explain(model = model, + x_explain = x_explain, + x_train = x_train, + approach = "ctree", + phi0 = mean(y_train), + prev_shapr_object = exp_20_ctree, + ctree.sample = FALSE, + verbose = c("basic","convergence")) + + +### Plotting #### + +library(ggplot2) + +plot(exp_iter_ctree, plot_type = "scatter",scatter_features = c("atemp","windspeed")) + +ggplot2::ggsave("inst/code_paper/scatter_ctree.pdf",width = 7, height = 4) + +### Grouping + + +group <- list(temp = 
c("temp", "atemp"), + time = c("trend", "cosyear", "sinyear"), + weather = c("hum","windspeed")) + +exp_g_reg <- explain(model = model, + x_explain = x_explain, + x_train = x_train, + phi0 = mean(y_train), + group = group, + approach = "regression_separate", + regression.model = parsnip::boost_tree( + engine = "xgboost", + mode = "regression" + ), + verbose = NULL) + +tree_vals <- c(10, 15, 25, 50, 100, 500) +exp_g_reg_tuned <- explain(model = model, + x_explain = x_explain, + x_train = x_train, + phi0 = mean(y_train), + group = group, + approach = "regression_separate", + regression.model = + parsnip::boost_tree( + trees = hardhat::tune(), + engine = "xgboost", mode = "regression" + ), + regression.tune_values = expand.grid( + trees = tree_vals + ), + regression.vfold_cv_para = list(v = 5), + verbose = NULL) + + +exp_g_reg$MSEv$MSEv +exp_g_reg_tuned$MSEv$MSEv + +#> exp_g_reg$MSEv$MSEv +#MSEv MSEv_sd +# +# 1: 1547240 142123.2 +#> exp_g_reg_tuned$MSEv$MSEv +#MSEv MSEv_sd +# +# 1: 1534033 142277.4 + +# Plot the best one (the tuned model has the lowest MSEv above) + +plot(exp_g_reg_tuned,index_x_explain = 6,plot_type="waterfall") + +ggplot2::ggsave("inst/code_paper/waterfall_group.pdf",width = 7, height = 4) + +# Print the Shapley values for the best one + +head(exp_g_reg_tuned$shapley_values_est) + + + diff --git a/inst/code_paper/code_sec_4.R b/inst/code_paper/code_sec_4.R new file mode 100644 index 0000000000000000000000000000000000000000..121453c8958dfbdb2afb17daf2a3fd7a80f37e5f --- /dev/null +++ b/inst/code_paper/code_sec_4.R @@ -0,0 +1,202 @@ +# Libraries +# library(ggplot2) +# require(GGally) +# library(ggpubr) +# library(gridExtra) + +# Libraries +library(xgboost) +library(shapr) + +# Download and set up the data as done in Heskes et al. 
(2020) +temp <- tempfile() +download.file("https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip", temp) +bike <- read.csv(unz(temp, "day.csv")) +unlink(temp) +# Difference in days, which takes DST into account +bike$trend <- as.numeric(difftime(bike$dteday, bike$dteday[1], units = "days")) +bike$cosyear <- cospi(bike$trend / 365 * 2) +bike$sinyear <- sinpi(bike$trend / 365 * 2) +# Unnormalize variables (see data set information in link above) +bike$temp <- bike$temp * (39 - (-8)) + (-8) +bike$atemp <- bike$atemp * (50 - (-16)) + (-16) +bike$windspeed <- 67 * bike$windspeed +bike$hum <- 100 * bike$hum + +# Define the features and the response variable +x_var <- c("trend", "cosyear", "sinyear", "temp", "atemp", "windspeed", "hum") +y_var <- "cnt" + +# Training-test split. 80% training and 20% test +set.seed(123) +train_index <- sample(x = nrow(bike), size = round(0.8*nrow(bike))) + +# Training data +x_train <- as.matrix(bike[train_index, x_var]) +y_train_nc <- as.matrix(bike[train_index, y_var]) # not centered +y_train <- y_train_nc - mean(y_train_nc) + +# Test/explicand data +x_explain <- as.matrix(bike[-train_index, x_var]) +y_explain_nc <- as.matrix(bike[-train_index, y_var]) # not centered +y_explain <- y_explain_nc - mean(y_train_nc) + +# Fit an XGBoost model to the training data +model <- xgboost::xgboost(data = x_train, label = y_train, nrounds = 100, verbose = FALSE) + +# Compute phi0 as the mean of the (centered) training response +prediction_zero <- mean(y_train) + +# Specify the causal ordering and confounding +causal_ordering <- list("trend", c("cosyear", "sinyear"), c("temp", "atemp", "windspeed", "hum")) +confounding <- c(FALSE, TRUE, FALSE) + +# Symmetric causal Shapley values: change asymmetric, causal_ordering, and confounding for other versions +explanation_sym_cau <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = prediction_zero, + n_MC_samples = 100, # Just for speed + approach = "gaussian", + asymmetric = FALSE, + 
paired_shap_sampling = TRUE, # Paired sampling is default, but must be FALSE for asymmetric SV + causal_ordering = causal_ordering, + confounding = confounding +) + + +# Symmetric Shapley values +explanation_sym_con <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = prediction_zero, + n_MC_samples = 1000, + verbose = c("basic", "progress", "convergence", "shapley", "vS_details") + # asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`) + # causal_ordering = NULL, # Default value + # confounding = NULL # Default value +) + +# Asymmetric Shapley values +explanation_asym_con <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = prediction_zero, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL, # Default value, + verbose = c("basic", "progress", "convergence", "shapley", "vS_details") +) + +# Asymmetric causal Shapley values +explanation_asym_cau <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = prediction_zero, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = confounding +) + +# Symmetric marginal Shapley values +explanation_sym_marg <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = prediction_zero, + n_MC_samples = 1000, + approach = "gaussian", + asymmetric = FALSE, + causal_ordering = list(1:7), + confounding = TRUE +) + +# Asymmetric marginal Shapley values +explanation_asym_marg <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = prediction_zero, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1:7), 
+ confounding = TRUE +) + + +# Combine the explanations +explanation_list <- list("Symmetric conditional" = explanation_sym_con, + "Asymmetric conditional" = explanation_asym_con, + "Symmetric causal" = explanation_sym_cau, + "Asymmetric causal" = explanation_asym_cau, + "Symmetric marginal" = explanation_sym_marg, + "Asymmetric marginal" = explanation_asym_marg) + +# Make the beeswarm plots +grobs <- lapply(seq_along(explanation_list), function(explanation_idx) { + gg <- plot(explanation_list[[explanation_idx]], plot_type = "beeswarm") + + ggplot2::ggtitle(gsub("_", " ", names(explanation_list)[[explanation_idx]])) + # ggplot2::ggtitle(tools::toTitleCase(gsub("_", " ", names(explanation_list)[[explanation_idx]]))) + + # Flip the order such that the features come in the right order + gg <- gg + + ggplot2::scale_x_discrete(limits = rev(levels(gg$data$variable)[levels(gg$data$variable) != "none"])) +}) + +# Get the limits +ylim <- sapply(grobs, function(grob) ggplot2::ggplot_build(grob)$layout$panel_scales_y[[1]]$range$range) +ylim <- c(min(ylim), max(ylim)) + +# Update the limits +grobs <- suppressMessages(lapply(grobs, function(grob) grob + ggplot2::coord_flip(ylim = ylim))) + +# THE PLOT IN THE PAPER +fig_few2 <- ggpubr::ggarrange(grobs[[2]], grobs[[3]], grobs[[5]], + ncol=3, nrow=1, common.legend = TRUE, legend="right") +ggplot2::ggsave(filename = "inst/code_paper/Paper5_example_fig_fewer_other.png", + plot = fig_few2, + scale = 0.85, + width = 14, + height = 4) + + +# OTHER PLOTS +# All 6 versions +fig <- ggpubr::ggarrange(grobs[[1]], grobs[[3]], grobs[[5]], grobs[[2]], grobs[[4]], grobs[[6]], + ncol=3, nrow=2, common.legend = TRUE, legend="right") + +ggplot2::ggsave(filename = "inst/code_paper/Paper5_example_fig.png", plot = fig, + scale = 0.85, + width = 14, + height = 6) + +# Only 3 of them +fig_few <- ggpubr::ggarrange(grobs[[1]], grobs[[2]], grobs[[3]], + ncol=3, nrow=1, common.legend = TRUE, legend="right") +ggplot2::ggsave(filename = 
"inst/code_paper/Paper5_example_fig_fewer.png", + plot = fig_few, + scale = 0.85, + width = 14, + height = 4) + +# Other four +fig_few3 <- ggpubr::ggarrange(grobs[[1]], grobs[[3]], grobs[[2]], grobs[[5]], + ncol=2, nrow=2, common.legend = TRUE, legend="right") +ggplot2::ggsave(filename = "inst/code_paper/Paper5_example_fig_fewer_other_other_2.png", + plot = fig_few3, + scale = 0.85, + width = 15, + height = 6) diff --git a/inst/code_paper/code_sec_5.py b/inst/code_paper/code_sec_5.py new file mode 100644 index 0000000000000000000000000000000000000000..6244f592927fc86e07057c2e4f8eed0f2fdc31bd --- /dev/null +++ b/inst/code_paper/code_sec_5.py @@ -0,0 +1,33 @@ +import xgboost as xgb +import pandas as pd +from shaprpy import explain + +path = "inst/code_paper/" + +# Read data +x_train = pd.read_csv(path + "x_train.csv") +x_explain = pd.read_csv(path + "x_explain.csv") +y_train = pd.read_csv(path + "y_train.csv") + +# Load the XGBoost model from the raw format and add feature names +model = xgb.Booster() +model.load_model(path + "xgb.model") +model.feature_names = x_train.columns.tolist() + +exp_20_ctree = explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = 'ctree', + phi0 = y_train.mean().item(), + max_n_coalitions=20, + ctree_sample = False) + + +# Print the Shapley values +print(exp_20_ctree['shapley_values_est'].iloc[:, 1:].round(1)) + + + + + diff --git a/inst/code_paper/code_sec_6.R b/inst/code_paper/code_sec_6.R new file mode 100644 index 0000000000000000000000000000000000000000..f71ba58d1f8b1b8bcdc29a6e6ae6f31090edc615 --- /dev/null +++ b/inst/code_paper/code_sec_6.R @@ -0,0 +1,43 @@ + +library(xgboost) +library(data.table) +library(shapr) + +path <- "inst/code_paper/" +x_full <- fread(paste0(path, "x_full.csv")) + + +model_ar <- ar(x_full$temp, order.max = 2) + +phi0_ar <- rep(mean(x_full$temp), 3) + +explain_forecast( + model = model_ar, + y = x_full[, "temp"], + train_idx = 2:729, + explain_idx = 730:731, + 
+ explain_y_lags = 2, + horizon = 3, + approach = "empirical", + phi0 = phi0_ar, + group_lags = FALSE +) + + +data_fit <- x_full[seq_len(729), ] +model_arimax <- arima(data_fit$temp, order = c(2, 0, 0), xreg = data_fit$windspeed) +phi0_arimax <- rep(mean(data_fit$temp), 2) + +explain_forecast( + model = model_arimax, + y = data_fit[, "temp"], + xreg = x_full[, "windspeed"], + train_idx = 2:728, + explain_idx = 729, + explain_y_lags = 2, + explain_xreg_lags = 1, + horizon = 2, + approach = "empirical", + phi0 = phi0_arimax, + group_lags = TRUE +) diff --git a/inst/code_paper/day.csv b/inst/code_paper/day.csv new file mode 100644 index 0000000000000000000000000000000000000000..7498062a459fff5db9254f5e1a100defa6a961b5 --- /dev/null +++ b/inst/code_paper/day.csv @@ -0,0 +1,732 @@ +instant,dteday,season,yr,mnth,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,casual,registered,cnt +1,2011-01-01,1,0,1,0,6,0,2,0.344167,0.363625,0.805833,0.160446,331,654,985 +2,2011-01-02,1,0,1,0,0,0,2,0.363478,0.353739,0.696087,0.248539,131,670,801 +3,2011-01-03,1,0,1,0,1,1,1,0.196364,0.189405,0.437273,0.248309,120,1229,1349 +4,2011-01-04,1,0,1,0,2,1,1,0.2,0.212122,0.590435,0.160296,108,1454,1562 +5,2011-01-05,1,0,1,0,3,1,1,0.226957,0.22927,0.436957,0.1869,82,1518,1600 +6,2011-01-06,1,0,1,0,4,1,1,0.204348,0.233209,0.518261,0.0895652,88,1518,1606 +7,2011-01-07,1,0,1,0,5,1,2,0.196522,0.208839,0.498696,0.168726,148,1362,1510 +8,2011-01-08,1,0,1,0,6,0,2,0.165,0.162254,0.535833,0.266804,68,891,959 +9,2011-01-09,1,0,1,0,0,0,1,0.138333,0.116175,0.434167,0.36195,54,768,822 +10,2011-01-10,1,0,1,0,1,1,1,0.150833,0.150888,0.482917,0.223267,41,1280,1321 +11,2011-01-11,1,0,1,0,2,1,2,0.169091,0.191464,0.686364,0.122132,43,1220,1263 +12,2011-01-12,1,0,1,0,3,1,1,0.172727,0.160473,0.599545,0.304627,25,1137,1162 +13,2011-01-13,1,0,1,0,4,1,1,0.165,0.150883,0.470417,0.301,38,1368,1406 +14,2011-01-14,1,0,1,0,5,1,1,0.16087,0.188413,0.537826,0.126548,54,1367,1421 
+15,2011-01-15,1,0,1,0,6,0,2,0.233333,0.248112,0.49875,0.157963,222,1026,1248 +16,2011-01-16,1,0,1,0,0,0,1,0.231667,0.234217,0.48375,0.188433,251,953,1204 +17,2011-01-17,1,0,1,1,1,0,2,0.175833,0.176771,0.5375,0.194017,117,883,1000 +18,2011-01-18,1,0,1,0,2,1,2,0.216667,0.232333,0.861667,0.146775,9,674,683 +19,2011-01-19,1,0,1,0,3,1,2,0.292174,0.298422,0.741739,0.208317,78,1572,1650 +20,2011-01-20,1,0,1,0,4,1,2,0.261667,0.25505,0.538333,0.195904,83,1844,1927 +21,2011-01-21,1,0,1,0,5,1,1,0.1775,0.157833,0.457083,0.353242,75,1468,1543 +22,2011-01-22,1,0,1,0,6,0,1,0.0591304,0.0790696,0.4,0.17197,93,888,981 +23,2011-01-23,1,0,1,0,0,0,1,0.0965217,0.0988391,0.436522,0.2466,150,836,986 +24,2011-01-24,1,0,1,0,1,1,1,0.0973913,0.11793,0.491739,0.15833,86,1330,1416 +25,2011-01-25,1,0,1,0,2,1,2,0.223478,0.234526,0.616957,0.129796,186,1799,1985 +26,2011-01-26,1,0,1,0,3,1,3,0.2175,0.2036,0.8625,0.29385,34,472,506 +27,2011-01-27,1,0,1,0,4,1,1,0.195,0.2197,0.6875,0.113837,15,416,431 +28,2011-01-28,1,0,1,0,5,1,2,0.203478,0.223317,0.793043,0.1233,38,1129,1167 +29,2011-01-29,1,0,1,0,6,0,1,0.196522,0.212126,0.651739,0.145365,123,975,1098 +30,2011-01-30,1,0,1,0,0,0,1,0.216522,0.250322,0.722174,0.0739826,140,956,1096 +31,2011-01-31,1,0,1,0,1,1,2,0.180833,0.18625,0.60375,0.187192,42,1459,1501 +32,2011-02-01,1,0,2,0,2,1,2,0.192174,0.23453,0.829565,0.053213,47,1313,1360 +33,2011-02-02,1,0,2,0,3,1,2,0.26,0.254417,0.775417,0.264308,72,1454,1526 +34,2011-02-03,1,0,2,0,4,1,1,0.186957,0.177878,0.437826,0.277752,61,1489,1550 +35,2011-02-04,1,0,2,0,5,1,2,0.211304,0.228587,0.585217,0.127839,88,1620,1708 +36,2011-02-05,1,0,2,0,6,0,2,0.233333,0.243058,0.929167,0.161079,100,905,1005 +37,2011-02-06,1,0,2,0,0,0,1,0.285833,0.291671,0.568333,0.1418,354,1269,1623 +38,2011-02-07,1,0,2,0,1,1,1,0.271667,0.303658,0.738333,0.0454083,120,1592,1712 +39,2011-02-08,1,0,2,0,2,1,1,0.220833,0.198246,0.537917,0.36195,64,1466,1530 +40,2011-02-09,1,0,2,0,3,1,2,0.134783,0.144283,0.494783,0.188839,53,1552,1605 
+41,2011-02-10,1,0,2,0,4,1,1,0.144348,0.149548,0.437391,0.221935,47,1491,1538 +42,2011-02-11,1,0,2,0,5,1,1,0.189091,0.213509,0.506364,0.10855,149,1597,1746 +43,2011-02-12,1,0,2,0,6,0,1,0.2225,0.232954,0.544167,0.203367,288,1184,1472 +44,2011-02-13,1,0,2,0,0,0,1,0.316522,0.324113,0.457391,0.260883,397,1192,1589 +45,2011-02-14,1,0,2,0,1,1,1,0.415,0.39835,0.375833,0.417908,208,1705,1913 +46,2011-02-15,1,0,2,0,2,1,1,0.266087,0.254274,0.314348,0.291374,140,1675,1815 +47,2011-02-16,1,0,2,0,3,1,1,0.318261,0.3162,0.423478,0.251791,218,1897,2115 +48,2011-02-17,1,0,2,0,4,1,1,0.435833,0.428658,0.505,0.230104,259,2216,2475 +49,2011-02-18,1,0,2,0,5,1,1,0.521667,0.511983,0.516667,0.264925,579,2348,2927 +50,2011-02-19,1,0,2,0,6,0,1,0.399167,0.391404,0.187917,0.507463,532,1103,1635 +51,2011-02-20,1,0,2,0,0,0,1,0.285217,0.27733,0.407826,0.223235,639,1173,1812 +52,2011-02-21,1,0,2,1,1,0,2,0.303333,0.284075,0.605,0.307846,195,912,1107 +53,2011-02-22,1,0,2,0,2,1,1,0.182222,0.186033,0.577778,0.195683,74,1376,1450 +54,2011-02-23,1,0,2,0,3,1,1,0.221739,0.245717,0.423043,0.094113,139,1778,1917 +55,2011-02-24,1,0,2,0,4,1,2,0.295652,0.289191,0.697391,0.250496,100,1707,1807 +56,2011-02-25,1,0,2,0,5,1,2,0.364348,0.350461,0.712174,0.346539,120,1341,1461 +57,2011-02-26,1,0,2,0,6,0,1,0.2825,0.282192,0.537917,0.186571,424,1545,1969 +58,2011-02-27,1,0,2,0,0,0,1,0.343478,0.351109,0.68,0.125248,694,1708,2402 +59,2011-02-28,1,0,2,0,1,1,2,0.407273,0.400118,0.876364,0.289686,81,1365,1446 +60,2011-03-01,1,0,3,0,2,1,1,0.266667,0.263879,0.535,0.216425,137,1714,1851 +61,2011-03-02,1,0,3,0,3,1,1,0.335,0.320071,0.449583,0.307833,231,1903,2134 +62,2011-03-03,1,0,3,0,4,1,1,0.198333,0.200133,0.318333,0.225754,123,1562,1685 +63,2011-03-04,1,0,3,0,5,1,2,0.261667,0.255679,0.610417,0.203346,214,1730,1944 +64,2011-03-05,1,0,3,0,6,0,2,0.384167,0.378779,0.789167,0.251871,640,1437,2077 +65,2011-03-06,1,0,3,0,0,0,2,0.376522,0.366252,0.948261,0.343287,114,491,605 
+66,2011-03-07,1,0,3,0,1,1,1,0.261739,0.238461,0.551304,0.341352,244,1628,1872 +67,2011-03-08,1,0,3,0,2,1,1,0.2925,0.3024,0.420833,0.12065,316,1817,2133 +68,2011-03-09,1,0,3,0,3,1,2,0.295833,0.286608,0.775417,0.22015,191,1700,1891 +69,2011-03-10,1,0,3,0,4,1,3,0.389091,0.385668,0,0.261877,46,577,623 +70,2011-03-11,1,0,3,0,5,1,2,0.316522,0.305,0.649565,0.23297,247,1730,1977 +71,2011-03-12,1,0,3,0,6,0,1,0.329167,0.32575,0.594583,0.220775,724,1408,2132 +72,2011-03-13,1,0,3,0,0,0,1,0.384348,0.380091,0.527391,0.270604,982,1435,2417 +73,2011-03-14,1,0,3,0,1,1,1,0.325217,0.332,0.496957,0.136926,359,1687,2046 +74,2011-03-15,1,0,3,0,2,1,2,0.317391,0.318178,0.655652,0.184309,289,1767,2056 +75,2011-03-16,1,0,3,0,3,1,2,0.365217,0.36693,0.776522,0.203117,321,1871,2192 +76,2011-03-17,1,0,3,0,4,1,1,0.415,0.410333,0.602917,0.209579,424,2320,2744 +77,2011-03-18,1,0,3,0,5,1,1,0.54,0.527009,0.525217,0.231017,884,2355,3239 +78,2011-03-19,1,0,3,0,6,0,1,0.4725,0.466525,0.379167,0.368167,1424,1693,3117 +79,2011-03-20,1,0,3,0,0,0,1,0.3325,0.32575,0.47375,0.207721,1047,1424,2471 +80,2011-03-21,2,0,3,0,1,1,2,0.430435,0.409735,0.737391,0.288783,401,1676,2077 +81,2011-03-22,2,0,3,0,2,1,1,0.441667,0.440642,0.624583,0.22575,460,2243,2703 +82,2011-03-23,2,0,3,0,3,1,2,0.346957,0.337939,0.839565,0.234261,203,1918,2121 +83,2011-03-24,2,0,3,0,4,1,2,0.285,0.270833,0.805833,0.243787,166,1699,1865 +84,2011-03-25,2,0,3,0,5,1,1,0.264167,0.256312,0.495,0.230725,300,1910,2210 +85,2011-03-26,2,0,3,0,6,0,1,0.265833,0.257571,0.394167,0.209571,981,1515,2496 +86,2011-03-27,2,0,3,0,0,0,2,0.253043,0.250339,0.493913,0.1843,472,1221,1693 +87,2011-03-28,2,0,3,0,1,1,1,0.264348,0.257574,0.302174,0.212204,222,1806,2028 +88,2011-03-29,2,0,3,0,2,1,1,0.3025,0.292908,0.314167,0.226996,317,2108,2425 +89,2011-03-30,2,0,3,0,3,1,2,0.3,0.29735,0.646667,0.172888,168,1368,1536 +90,2011-03-31,2,0,3,0,4,1,3,0.268333,0.257575,0.918333,0.217646,179,1506,1685 +91,2011-04-01,2,0,4,0,5,1,2,0.3,0.283454,0.68625,0.258708,307,1920,2227 
+92,2011-04-02,2,0,4,0,6,0,2,0.315,0.315637,0.65375,0.197146,898,1354,2252 +93,2011-04-03,2,0,4,0,0,0,1,0.378333,0.378767,0.48,0.182213,1651,1598,3249 +94,2011-04-04,2,0,4,0,1,1,1,0.573333,0.542929,0.42625,0.385571,734,2381,3115 +95,2011-04-05,2,0,4,0,2,1,2,0.414167,0.39835,0.642083,0.388067,167,1628,1795 +96,2011-04-06,2,0,4,0,3,1,1,0.390833,0.387608,0.470833,0.263063,413,2395,2808 +97,2011-04-07,2,0,4,0,4,1,1,0.4375,0.433696,0.602917,0.162312,571,2570,3141 +98,2011-04-08,2,0,4,0,5,1,2,0.335833,0.324479,0.83625,0.226992,172,1299,1471 +99,2011-04-09,2,0,4,0,6,0,2,0.3425,0.341529,0.8775,0.133083,879,1576,2455 +100,2011-04-10,2,0,4,0,0,0,2,0.426667,0.426737,0.8575,0.146767,1188,1707,2895 +101,2011-04-11,2,0,4,0,1,1,2,0.595652,0.565217,0.716956,0.324474,855,2493,3348 +102,2011-04-12,2,0,4,0,2,1,2,0.5025,0.493054,0.739167,0.274879,257,1777,2034 +103,2011-04-13,2,0,4,0,3,1,2,0.4125,0.417283,0.819167,0.250617,209,1953,2162 +104,2011-04-14,2,0,4,0,4,1,1,0.4675,0.462742,0.540417,0.1107,529,2738,3267 +105,2011-04-15,2,0,4,1,5,0,1,0.446667,0.441913,0.67125,0.226375,642,2484,3126 +106,2011-04-16,2,0,4,0,6,0,3,0.430833,0.425492,0.888333,0.340808,121,674,795 +107,2011-04-17,2,0,4,0,0,0,1,0.456667,0.445696,0.479583,0.303496,1558,2186,3744 +108,2011-04-18,2,0,4,0,1,1,1,0.5125,0.503146,0.5425,0.163567,669,2760,3429 +109,2011-04-19,2,0,4,0,2,1,2,0.505833,0.489258,0.665833,0.157971,409,2795,3204 +110,2011-04-20,2,0,4,0,3,1,1,0.595,0.564392,0.614167,0.241925,613,3331,3944 +111,2011-04-21,2,0,4,0,4,1,1,0.459167,0.453892,0.407083,0.325258,745,3444,4189 +112,2011-04-22,2,0,4,0,5,1,2,0.336667,0.321954,0.729583,0.219521,177,1506,1683 +113,2011-04-23,2,0,4,0,6,0,2,0.46,0.450121,0.887917,0.230725,1462,2574,4036 +114,2011-04-24,2,0,4,0,0,0,2,0.581667,0.551763,0.810833,0.192175,1710,2481,4191 +115,2011-04-25,2,0,4,0,1,1,1,0.606667,0.5745,0.776667,0.185333,773,3300,4073 +116,2011-04-26,2,0,4,0,2,1,1,0.631667,0.594083,0.729167,0.3265,678,3722,4400 
+117,2011-04-27,2,0,4,0,3,1,2,0.62,0.575142,0.835417,0.3122,547,3325,3872 +118,2011-04-28,2,0,4,0,4,1,2,0.6175,0.578929,0.700833,0.320908,569,3489,4058 +119,2011-04-29,2,0,4,0,5,1,1,0.51,0.497463,0.457083,0.240063,878,3717,4595 +120,2011-04-30,2,0,4,0,6,0,1,0.4725,0.464021,0.503333,0.235075,1965,3347,5312 +121,2011-05-01,2,0,5,0,0,0,2,0.451667,0.448204,0.762083,0.106354,1138,2213,3351 +122,2011-05-02,2,0,5,0,1,1,2,0.549167,0.532833,0.73,0.183454,847,3554,4401 +123,2011-05-03,2,0,5,0,2,1,2,0.616667,0.582079,0.697083,0.342667,603,3848,4451 +124,2011-05-04,2,0,5,0,3,1,2,0.414167,0.40465,0.737083,0.328996,255,2378,2633 +125,2011-05-05,2,0,5,0,4,1,1,0.459167,0.441917,0.444167,0.295392,614,3819,4433 +126,2011-05-06,2,0,5,0,5,1,1,0.479167,0.474117,0.59,0.228246,894,3714,4608 +127,2011-05-07,2,0,5,0,6,0,1,0.52,0.512621,0.54125,0.16045,1612,3102,4714 +128,2011-05-08,2,0,5,0,0,0,1,0.528333,0.518933,0.631667,0.0746375,1401,2932,4333 +129,2011-05-09,2,0,5,0,1,1,1,0.5325,0.525246,0.58875,0.176,664,3698,4362 +130,2011-05-10,2,0,5,0,2,1,1,0.5325,0.522721,0.489167,0.115671,694,4109,4803 +131,2011-05-11,2,0,5,0,3,1,1,0.5425,0.5284,0.632917,0.120642,550,3632,4182 +132,2011-05-12,2,0,5,0,4,1,1,0.535,0.523363,0.7475,0.189667,695,4169,4864 +133,2011-05-13,2,0,5,0,5,1,2,0.5125,0.4943,0.863333,0.179725,692,3413,4105 +134,2011-05-14,2,0,5,0,6,0,2,0.520833,0.500629,0.9225,0.13495,902,2507,3409 +135,2011-05-15,2,0,5,0,0,0,2,0.5625,0.536,0.867083,0.152979,1582,2971,4553 +136,2011-05-16,2,0,5,0,1,1,1,0.5775,0.550512,0.787917,0.126871,773,3185,3958 +137,2011-05-17,2,0,5,0,2,1,2,0.561667,0.538529,0.837917,0.277354,678,3445,4123 +138,2011-05-18,2,0,5,0,3,1,2,0.55,0.527158,0.87,0.201492,536,3319,3855 +139,2011-05-19,2,0,5,0,4,1,2,0.530833,0.510742,0.829583,0.108213,735,3840,4575 +140,2011-05-20,2,0,5,0,5,1,1,0.536667,0.529042,0.719583,0.125013,909,4008,4917 +141,2011-05-21,2,0,5,0,6,0,1,0.6025,0.571975,0.626667,0.12065,2258,3547,5805 
+142,2011-05-22,2,0,5,0,0,0,1,0.604167,0.5745,0.749583,0.148008,1576,3084,4660 +143,2011-05-23,2,0,5,0,1,1,2,0.631667,0.590296,0.81,0.233842,836,3438,4274 +144,2011-05-24,2,0,5,0,2,1,2,0.66,0.604813,0.740833,0.207092,659,3833,4492 +145,2011-05-25,2,0,5,0,3,1,1,0.660833,0.615542,0.69625,0.154233,740,4238,4978 +146,2011-05-26,2,0,5,0,4,1,1,0.708333,0.654688,0.6775,0.199642,758,3919,4677 +147,2011-05-27,2,0,5,0,5,1,1,0.681667,0.637008,0.65375,0.240679,871,3808,4679 +148,2011-05-28,2,0,5,0,6,0,1,0.655833,0.612379,0.729583,0.230092,2001,2757,4758 +149,2011-05-29,2,0,5,0,0,0,1,0.6675,0.61555,0.81875,0.213938,2355,2433,4788 +150,2011-05-30,2,0,5,1,1,0,1,0.733333,0.671092,0.685,0.131225,1549,2549,4098 +151,2011-05-31,2,0,5,0,2,1,1,0.775,0.725383,0.636667,0.111329,673,3309,3982 +152,2011-06-01,2,0,6,0,3,1,2,0.764167,0.720967,0.677083,0.207092,513,3461,3974 +153,2011-06-02,2,0,6,0,4,1,1,0.715,0.643942,0.305,0.292287,736,4232,4968 +154,2011-06-03,2,0,6,0,5,1,1,0.62,0.587133,0.354167,0.253121,898,4414,5312 +155,2011-06-04,2,0,6,0,6,0,1,0.635,0.594696,0.45625,0.123142,1869,3473,5342 +156,2011-06-05,2,0,6,0,0,0,2,0.648333,0.616804,0.6525,0.138692,1685,3221,4906 +157,2011-06-06,2,0,6,0,1,1,1,0.678333,0.621858,0.6,0.121896,673,3875,4548 +158,2011-06-07,2,0,6,0,2,1,1,0.7075,0.65595,0.597917,0.187808,763,4070,4833 +159,2011-06-08,2,0,6,0,3,1,1,0.775833,0.727279,0.622083,0.136817,676,3725,4401 +160,2011-06-09,2,0,6,0,4,1,2,0.808333,0.757579,0.568333,0.149883,563,3352,3915 +161,2011-06-10,2,0,6,0,5,1,1,0.755,0.703292,0.605,0.140554,815,3771,4586 +162,2011-06-11,2,0,6,0,6,0,1,0.725,0.678038,0.654583,0.15485,1729,3237,4966 +163,2011-06-12,2,0,6,0,0,0,1,0.6925,0.643325,0.747917,0.163567,1467,2993,4460 +164,2011-06-13,2,0,6,0,1,1,1,0.635,0.601654,0.494583,0.30535,863,4157,5020 +165,2011-06-14,2,0,6,0,2,1,1,0.604167,0.591546,0.507083,0.269283,727,4164,4891 +166,2011-06-15,2,0,6,0,3,1,1,0.626667,0.587754,0.471667,0.167912,769,4411,5180 
+167,2011-06-16,2,0,6,0,4,1,2,0.628333,0.595346,0.688333,0.206471,545,3222,3767 +168,2011-06-17,2,0,6,0,5,1,1,0.649167,0.600383,0.735833,0.143029,863,3981,4844 +169,2011-06-18,2,0,6,0,6,0,1,0.696667,0.643954,0.670417,0.119408,1807,3312,5119 +170,2011-06-19,2,0,6,0,0,0,2,0.699167,0.645846,0.666667,0.102,1639,3105,4744 +171,2011-06-20,2,0,6,0,1,1,2,0.635,0.595346,0.74625,0.155475,699,3311,4010 +172,2011-06-21,3,0,6,0,2,1,2,0.680833,0.637646,0.770417,0.171025,774,4061,4835 +173,2011-06-22,3,0,6,0,3,1,1,0.733333,0.693829,0.7075,0.172262,661,3846,4507 +174,2011-06-23,3,0,6,0,4,1,2,0.728333,0.693833,0.703333,0.238804,746,4044,4790 +175,2011-06-24,3,0,6,0,5,1,1,0.724167,0.656583,0.573333,0.222025,969,4022,4991 +176,2011-06-25,3,0,6,0,6,0,1,0.695,0.643313,0.483333,0.209571,1782,3420,5202 +177,2011-06-26,3,0,6,0,0,0,1,0.68,0.637629,0.513333,0.0945333,1920,3385,5305 +178,2011-06-27,3,0,6,0,1,1,2,0.6825,0.637004,0.658333,0.107588,854,3854,4708 +179,2011-06-28,3,0,6,0,2,1,1,0.744167,0.692558,0.634167,0.144283,732,3916,4648 +180,2011-06-29,3,0,6,0,3,1,1,0.728333,0.654688,0.497917,0.261821,848,4377,5225 +181,2011-06-30,3,0,6,0,4,1,1,0.696667,0.637008,0.434167,0.185312,1027,4488,5515 +182,2011-07-01,3,0,7,0,5,1,1,0.7225,0.652162,0.39625,0.102608,1246,4116,5362 +183,2011-07-02,3,0,7,0,6,0,1,0.738333,0.667308,0.444583,0.115062,2204,2915,5119 +184,2011-07-03,3,0,7,0,0,0,2,0.716667,0.668575,0.6825,0.228858,2282,2367,4649 +185,2011-07-04,3,0,7,1,1,0,2,0.726667,0.665417,0.637917,0.0814792,3065,2978,6043 +186,2011-07-05,3,0,7,0,2,1,1,0.746667,0.696338,0.590417,0.126258,1031,3634,4665 +187,2011-07-06,3,0,7,0,3,1,1,0.72,0.685633,0.743333,0.149883,784,3845,4629 +188,2011-07-07,3,0,7,0,4,1,1,0.75,0.686871,0.65125,0.1592,754,3838,4592 +189,2011-07-08,3,0,7,0,5,1,2,0.709167,0.670483,0.757917,0.225129,692,3348,4040 +190,2011-07-09,3,0,7,0,6,0,1,0.733333,0.664158,0.609167,0.167912,1988,3348,5336 +191,2011-07-10,3,0,7,0,0,0,1,0.7475,0.690025,0.578333,0.183471,1743,3138,4881 
+192,2011-07-11,3,0,7,0,1,1,1,0.7625,0.729804,0.635833,0.282337,723,3363,4086 +193,2011-07-12,3,0,7,0,2,1,1,0.794167,0.739275,0.559167,0.200254,662,3596,4258 +194,2011-07-13,3,0,7,0,3,1,1,0.746667,0.689404,0.631667,0.146133,748,3594,4342 +195,2011-07-14,3,0,7,0,4,1,1,0.680833,0.635104,0.47625,0.240667,888,4196,5084 +196,2011-07-15,3,0,7,0,5,1,1,0.663333,0.624371,0.59125,0.182833,1318,4220,5538 +197,2011-07-16,3,0,7,0,6,0,1,0.686667,0.638263,0.585,0.208342,2418,3505,5923 +198,2011-07-17,3,0,7,0,0,0,1,0.719167,0.669833,0.604167,0.245033,2006,3296,5302 +199,2011-07-18,3,0,7,0,1,1,1,0.746667,0.703925,0.65125,0.215804,841,3617,4458 +200,2011-07-19,3,0,7,0,2,1,1,0.776667,0.747479,0.650417,0.1306,752,3789,4541 +201,2011-07-20,3,0,7,0,3,1,1,0.768333,0.74685,0.707083,0.113817,644,3688,4332 +202,2011-07-21,3,0,7,0,4,1,2,0.815,0.826371,0.69125,0.222021,632,3152,3784 +203,2011-07-22,3,0,7,0,5,1,1,0.848333,0.840896,0.580417,0.1331,562,2825,3387 +204,2011-07-23,3,0,7,0,6,0,1,0.849167,0.804287,0.5,0.131221,987,2298,3285 +205,2011-07-24,3,0,7,0,0,0,1,0.83,0.794829,0.550833,0.169171,1050,2556,3606 +206,2011-07-25,3,0,7,0,1,1,1,0.743333,0.720958,0.757083,0.0908083,568,3272,3840 +207,2011-07-26,3,0,7,0,2,1,1,0.771667,0.696979,0.540833,0.200258,750,3840,4590 +208,2011-07-27,3,0,7,0,3,1,1,0.775,0.690667,0.402917,0.183463,755,3901,4656 +209,2011-07-28,3,0,7,0,4,1,1,0.779167,0.7399,0.583333,0.178479,606,3784,4390 +210,2011-07-29,3,0,7,0,5,1,1,0.838333,0.785967,0.5425,0.174138,670,3176,3846 +211,2011-07-30,3,0,7,0,6,0,1,0.804167,0.728537,0.465833,0.168537,1559,2916,4475 +212,2011-07-31,3,0,7,0,0,0,1,0.805833,0.729796,0.480833,0.164813,1524,2778,4302 +213,2011-08-01,3,0,8,0,1,1,1,0.771667,0.703292,0.550833,0.156717,729,3537,4266 +214,2011-08-02,3,0,8,0,2,1,1,0.783333,0.707071,0.49125,0.20585,801,4044,4845 +215,2011-08-03,3,0,8,0,3,1,2,0.731667,0.679937,0.6575,0.135583,467,3107,3574 +216,2011-08-04,3,0,8,0,4,1,2,0.71,0.664788,0.7575,0.19715,799,3777,4576 
+217,2011-08-05,3,0,8,0,5,1,1,0.710833,0.656567,0.630833,0.184696,1023,3843,4866 +218,2011-08-06,3,0,8,0,6,0,2,0.716667,0.676154,0.755,0.22825,1521,2773,4294 +219,2011-08-07,3,0,8,0,0,0,1,0.7425,0.715292,0.752917,0.201487,1298,2487,3785 +220,2011-08-08,3,0,8,0,1,1,1,0.765,0.703283,0.592083,0.192175,846,3480,4326 +221,2011-08-09,3,0,8,0,2,1,1,0.775,0.724121,0.570417,0.151121,907,3695,4602 +222,2011-08-10,3,0,8,0,3,1,1,0.766667,0.684983,0.424167,0.200258,884,3896,4780 +223,2011-08-11,3,0,8,0,4,1,1,0.7175,0.651521,0.42375,0.164796,812,3980,4792 +224,2011-08-12,3,0,8,0,5,1,1,0.708333,0.654042,0.415,0.125621,1051,3854,4905 +225,2011-08-13,3,0,8,0,6,0,2,0.685833,0.645858,0.729583,0.211454,1504,2646,4150 +226,2011-08-14,3,0,8,0,0,0,2,0.676667,0.624388,0.8175,0.222633,1338,2482,3820 +227,2011-08-15,3,0,8,0,1,1,1,0.665833,0.616167,0.712083,0.208954,775,3563,4338 +228,2011-08-16,3,0,8,0,2,1,1,0.700833,0.645837,0.578333,0.236329,721,4004,4725 +229,2011-08-17,3,0,8,0,3,1,1,0.723333,0.666671,0.575417,0.143667,668,4026,4694 +230,2011-08-18,3,0,8,0,4,1,1,0.711667,0.662258,0.654583,0.233208,639,3166,3805 +231,2011-08-19,3,0,8,0,5,1,2,0.685,0.633221,0.722917,0.139308,797,3356,4153 +232,2011-08-20,3,0,8,0,6,0,1,0.6975,0.648996,0.674167,0.104467,1914,3277,5191 +233,2011-08-21,3,0,8,0,0,0,1,0.710833,0.675525,0.77,0.248754,1249,2624,3873 +234,2011-08-22,3,0,8,0,1,1,1,0.691667,0.638254,0.47,0.27675,833,3925,4758 +235,2011-08-23,3,0,8,0,2,1,1,0.640833,0.606067,0.455417,0.146763,1281,4614,5895 +236,2011-08-24,3,0,8,0,3,1,1,0.673333,0.630692,0.605,0.253108,949,4181,5130 +237,2011-08-25,3,0,8,0,4,1,2,0.684167,0.645854,0.771667,0.210833,435,3107,3542 +238,2011-08-26,3,0,8,0,5,1,1,0.7,0.659733,0.76125,0.0839625,768,3893,4661 +239,2011-08-27,3,0,8,0,6,0,2,0.68,0.635556,0.85,0.375617,226,889,1115 +240,2011-08-28,3,0,8,0,0,0,1,0.707059,0.647959,0.561765,0.304659,1415,2919,4334 +241,2011-08-29,3,0,8,0,1,1,1,0.636667,0.607958,0.554583,0.159825,729,3905,4634 
+242,2011-08-30,3,0,8,0,2,1,1,0.639167,0.594704,0.548333,0.125008,775,4429,5204 +243,2011-08-31,3,0,8,0,3,1,1,0.656667,0.611121,0.597917,0.0833333,688,4370,5058 +244,2011-09-01,3,0,9,0,4,1,1,0.655,0.614921,0.639167,0.141796,783,4332,5115 +245,2011-09-02,3,0,9,0,5,1,2,0.643333,0.604808,0.727083,0.139929,875,3852,4727 +246,2011-09-03,3,0,9,0,6,0,1,0.669167,0.633213,0.716667,0.185325,1935,2549,4484 +247,2011-09-04,3,0,9,0,0,0,1,0.709167,0.665429,0.742083,0.206467,2521,2419,4940 +248,2011-09-05,3,0,9,1,1,0,2,0.673333,0.625646,0.790417,0.212696,1236,2115,3351 +249,2011-09-06,3,0,9,0,2,1,3,0.54,0.5152,0.886957,0.343943,204,2506,2710 +250,2011-09-07,3,0,9,0,3,1,3,0.599167,0.544229,0.917083,0.0970208,118,1878,1996 +251,2011-09-08,3,0,9,0,4,1,3,0.633913,0.555361,0.939565,0.192748,153,1689,1842 +252,2011-09-09,3,0,9,0,5,1,2,0.65,0.578946,0.897917,0.124379,417,3127,3544 +253,2011-09-10,3,0,9,0,6,0,1,0.66,0.607962,0.75375,0.153608,1750,3595,5345 +254,2011-09-11,3,0,9,0,0,0,1,0.653333,0.609229,0.71375,0.115054,1633,3413,5046 +255,2011-09-12,3,0,9,0,1,1,1,0.644348,0.60213,0.692174,0.088913,690,4023,4713 +256,2011-09-13,3,0,9,0,2,1,1,0.650833,0.603554,0.7125,0.141804,701,4062,4763 +257,2011-09-14,3,0,9,0,3,1,1,0.673333,0.6269,0.697083,0.1673,647,4138,4785 +258,2011-09-15,3,0,9,0,4,1,2,0.5775,0.553671,0.709167,0.271146,428,3231,3659 +259,2011-09-16,3,0,9,0,5,1,2,0.469167,0.461475,0.590417,0.164183,742,4018,4760 +260,2011-09-17,3,0,9,0,6,0,2,0.491667,0.478512,0.718333,0.189675,1434,3077,4511 +261,2011-09-18,3,0,9,0,0,0,1,0.5075,0.490537,0.695,0.178483,1353,2921,4274 +262,2011-09-19,3,0,9,0,1,1,2,0.549167,0.529675,0.69,0.151742,691,3848,4539 +263,2011-09-20,3,0,9,0,2,1,2,0.561667,0.532217,0.88125,0.134954,438,3203,3641 +264,2011-09-21,3,0,9,0,3,1,2,0.595,0.550533,0.9,0.0964042,539,3813,4352 +265,2011-09-22,3,0,9,0,4,1,2,0.628333,0.554963,0.902083,0.128125,555,4240,4795 +266,2011-09-23,4,0,9,0,5,1,2,0.609167,0.522125,0.9725,0.0783667,258,2137,2395 
+267,2011-09-24,4,0,9,0,6,0,2,0.606667,0.564412,0.8625,0.0783833,1776,3647,5423 +268,2011-09-25,4,0,9,0,0,0,2,0.634167,0.572637,0.845,0.0503792,1544,3466,5010 +269,2011-09-26,4,0,9,0,1,1,2,0.649167,0.589042,0.848333,0.1107,684,3946,4630 +270,2011-09-27,4,0,9,0,2,1,2,0.636667,0.574525,0.885417,0.118171,477,3643,4120 +271,2011-09-28,4,0,9,0,3,1,2,0.635,0.575158,0.84875,0.148629,480,3427,3907 +272,2011-09-29,4,0,9,0,4,1,1,0.616667,0.574512,0.699167,0.172883,653,4186,4839 +273,2011-09-30,4,0,9,0,5,1,1,0.564167,0.544829,0.6475,0.206475,830,4372,5202 +274,2011-10-01,4,0,10,0,6,0,2,0.41,0.412863,0.75375,0.292296,480,1949,2429 +275,2011-10-02,4,0,10,0,0,0,2,0.356667,0.345317,0.791667,0.222013,616,2302,2918 +276,2011-10-03,4,0,10,0,1,1,2,0.384167,0.392046,0.760833,0.0833458,330,3240,3570 +277,2011-10-04,4,0,10,0,2,1,1,0.484167,0.472858,0.71,0.205854,486,3970,4456 +278,2011-10-05,4,0,10,0,3,1,1,0.538333,0.527138,0.647917,0.17725,559,4267,4826 +279,2011-10-06,4,0,10,0,4,1,1,0.494167,0.480425,0.620833,0.134954,639,4126,4765 +280,2011-10-07,4,0,10,0,5,1,1,0.510833,0.504404,0.684167,0.0223917,949,4036,4985 +281,2011-10-08,4,0,10,0,6,0,1,0.521667,0.513242,0.70125,0.0454042,2235,3174,5409 +282,2011-10-09,4,0,10,0,0,0,1,0.540833,0.523983,0.7275,0.06345,2397,3114,5511 +283,2011-10-10,4,0,10,1,1,0,1,0.570833,0.542925,0.73375,0.0423042,1514,3603,5117 +284,2011-10-11,4,0,10,0,2,1,2,0.566667,0.546096,0.80875,0.143042,667,3896,4563 +285,2011-10-12,4,0,10,0,3,1,3,0.543333,0.517717,0.90625,0.24815,217,2199,2416 +286,2011-10-13,4,0,10,0,4,1,2,0.589167,0.551804,0.896667,0.141787,290,2623,2913 +287,2011-10-14,4,0,10,0,5,1,2,0.550833,0.529675,0.71625,0.223883,529,3115,3644 +288,2011-10-15,4,0,10,0,6,0,1,0.506667,0.498725,0.483333,0.258083,1899,3318,5217 +289,2011-10-16,4,0,10,0,0,0,1,0.511667,0.503154,0.486667,0.281717,1748,3293,5041 +290,2011-10-17,4,0,10,0,1,1,1,0.534167,0.510725,0.579583,0.175379,713,3857,4570 +291,2011-10-18,4,0,10,0,2,1,2,0.5325,0.522721,0.701667,0.110087,637,4111,4748 
+292,2011-10-19,4,0,10,0,3,1,3,0.541739,0.513848,0.895217,0.243339,254,2170,2424 +293,2011-10-20,4,0,10,0,4,1,1,0.475833,0.466525,0.63625,0.422275,471,3724,4195 +294,2011-10-21,4,0,10,0,5,1,1,0.4275,0.423596,0.574167,0.221396,676,3628,4304 +295,2011-10-22,4,0,10,0,6,0,1,0.4225,0.425492,0.629167,0.0926667,1499,2809,4308 +296,2011-10-23,4,0,10,0,0,0,1,0.421667,0.422333,0.74125,0.0995125,1619,2762,4381 +297,2011-10-24,4,0,10,0,1,1,1,0.463333,0.457067,0.772083,0.118792,699,3488,4187 +298,2011-10-25,4,0,10,0,2,1,1,0.471667,0.463375,0.622917,0.166658,695,3992,4687 +299,2011-10-26,4,0,10,0,3,1,2,0.484167,0.472846,0.720417,0.148642,404,3490,3894 +300,2011-10-27,4,0,10,0,4,1,2,0.47,0.457046,0.812917,0.197763,240,2419,2659 +301,2011-10-28,4,0,10,0,5,1,2,0.330833,0.318812,0.585833,0.229479,456,3291,3747 +302,2011-10-29,4,0,10,0,6,0,3,0.254167,0.227913,0.8825,0.351371,57,570,627 +303,2011-10-30,4,0,10,0,0,0,1,0.319167,0.321329,0.62375,0.176617,885,2446,3331 +304,2011-10-31,4,0,10,0,1,1,1,0.34,0.356063,0.703333,0.10635,362,3307,3669 +305,2011-11-01,4,0,11,0,2,1,1,0.400833,0.397088,0.68375,0.135571,410,3658,4068 +306,2011-11-02,4,0,11,0,3,1,1,0.3775,0.390133,0.71875,0.0820917,370,3816,4186 +307,2011-11-03,4,0,11,0,4,1,1,0.408333,0.405921,0.702083,0.136817,318,3656,3974 +308,2011-11-04,4,0,11,0,5,1,2,0.403333,0.403392,0.6225,0.271779,470,3576,4046 +309,2011-11-05,4,0,11,0,6,0,1,0.326667,0.323854,0.519167,0.189062,1156,2770,3926 +310,2011-11-06,4,0,11,0,0,0,1,0.348333,0.362358,0.734583,0.0920542,952,2697,3649 +311,2011-11-07,4,0,11,0,1,1,1,0.395,0.400871,0.75875,0.057225,373,3662,4035 +312,2011-11-08,4,0,11,0,2,1,1,0.408333,0.412246,0.721667,0.0690375,376,3829,4205 +313,2011-11-09,4,0,11,0,3,1,1,0.4,0.409079,0.758333,0.0621958,305,3804,4109 +314,2011-11-10,4,0,11,0,4,1,2,0.38,0.373721,0.813333,0.189067,190,2743,2933 +315,2011-11-11,4,0,11,1,5,0,1,0.324167,0.306817,0.44625,0.314675,440,2928,3368 +316,2011-11-12,4,0,11,0,6,0,1,0.356667,0.357942,0.552917,0.212062,1275,2792,4067 
+317,2011-11-13,4,0,11,0,0,0,1,0.440833,0.43055,0.458333,0.281721,1004,2713,3717 +318,2011-11-14,4,0,11,0,1,1,1,0.53,0.524612,0.587083,0.306596,595,3891,4486 +319,2011-11-15,4,0,11,0,2,1,2,0.53,0.507579,0.68875,0.199633,449,3746,4195 +320,2011-11-16,4,0,11,0,3,1,3,0.456667,0.451988,0.93,0.136829,145,1672,1817 +321,2011-11-17,4,0,11,0,4,1,2,0.341667,0.323221,0.575833,0.305362,139,2914,3053 +322,2011-11-18,4,0,11,0,5,1,1,0.274167,0.272721,0.41,0.168533,245,3147,3392 +323,2011-11-19,4,0,11,0,6,0,1,0.329167,0.324483,0.502083,0.224496,943,2720,3663 +324,2011-11-20,4,0,11,0,0,0,2,0.463333,0.457058,0.684583,0.18595,787,2733,3520 +325,2011-11-21,4,0,11,0,1,1,3,0.4475,0.445062,0.91,0.138054,220,2545,2765 +326,2011-11-22,4,0,11,0,2,1,3,0.416667,0.421696,0.9625,0.118792,69,1538,1607 +327,2011-11-23,4,0,11,0,3,1,2,0.440833,0.430537,0.757917,0.335825,112,2454,2566 +328,2011-11-24,4,0,11,1,4,0,1,0.373333,0.372471,0.549167,0.167304,560,935,1495 +329,2011-11-25,4,0,11,0,5,1,1,0.375,0.380671,0.64375,0.0988958,1095,1697,2792 +330,2011-11-26,4,0,11,0,6,0,1,0.375833,0.385087,0.681667,0.0684208,1249,1819,3068 +331,2011-11-27,4,0,11,0,0,0,1,0.459167,0.4558,0.698333,0.208954,810,2261,3071 +332,2011-11-28,4,0,11,0,1,1,1,0.503478,0.490122,0.743043,0.142122,253,3614,3867 +333,2011-11-29,4,0,11,0,2,1,2,0.458333,0.451375,0.830833,0.258092,96,2818,2914 +334,2011-11-30,4,0,11,0,3,1,1,0.325,0.311221,0.613333,0.271158,188,3425,3613 +335,2011-12-01,4,0,12,0,4,1,1,0.3125,0.305554,0.524583,0.220158,182,3545,3727 +336,2011-12-02,4,0,12,0,5,1,1,0.314167,0.331433,0.625833,0.100754,268,3672,3940 +337,2011-12-03,4,0,12,0,6,0,1,0.299167,0.310604,0.612917,0.0957833,706,2908,3614 +338,2011-12-04,4,0,12,0,0,0,1,0.330833,0.3491,0.775833,0.0839583,634,2851,3485 +339,2011-12-05,4,0,12,0,1,1,2,0.385833,0.393925,0.827083,0.0622083,233,3578,3811 +340,2011-12-06,4,0,12,0,2,1,3,0.4625,0.4564,0.949583,0.232583,126,2468,2594 +341,2011-12-07,4,0,12,0,3,1,3,0.41,0.400246,0.970417,0.266175,50,655,705 
+342,2011-12-08,4,0,12,0,4,1,1,0.265833,0.256938,0.58,0.240058,150,3172,3322 +343,2011-12-09,4,0,12,0,5,1,1,0.290833,0.317542,0.695833,0.0827167,261,3359,3620 +344,2011-12-10,4,0,12,0,6,0,1,0.275,0.266412,0.5075,0.233221,502,2688,3190 +345,2011-12-11,4,0,12,0,0,0,1,0.220833,0.253154,0.49,0.0665417,377,2366,2743 +346,2011-12-12,4,0,12,0,1,1,1,0.238333,0.270196,0.670833,0.06345,143,3167,3310 +347,2011-12-13,4,0,12,0,2,1,1,0.2825,0.301138,0.59,0.14055,155,3368,3523 +348,2011-12-14,4,0,12,0,3,1,2,0.3175,0.338362,0.66375,0.0609583,178,3562,3740 +349,2011-12-15,4,0,12,0,4,1,2,0.4225,0.412237,0.634167,0.268042,181,3528,3709 +350,2011-12-16,4,0,12,0,5,1,2,0.375,0.359825,0.500417,0.260575,178,3399,3577 +351,2011-12-17,4,0,12,0,6,0,2,0.258333,0.249371,0.560833,0.243167,275,2464,2739 +352,2011-12-18,4,0,12,0,0,0,1,0.238333,0.245579,0.58625,0.169779,220,2211,2431 +353,2011-12-19,4,0,12,0,1,1,1,0.276667,0.280933,0.6375,0.172896,260,3143,3403 +354,2011-12-20,4,0,12,0,2,1,2,0.385833,0.396454,0.595417,0.0615708,216,3534,3750 +355,2011-12-21,1,0,12,0,3,1,2,0.428333,0.428017,0.858333,0.2214,107,2553,2660 +356,2011-12-22,1,0,12,0,4,1,2,0.423333,0.426121,0.7575,0.047275,227,2841,3068 +357,2011-12-23,1,0,12,0,5,1,1,0.373333,0.377513,0.68625,0.274246,163,2046,2209 +358,2011-12-24,1,0,12,0,6,0,1,0.3025,0.299242,0.5425,0.190304,155,856,1011 +359,2011-12-25,1,0,12,0,0,0,1,0.274783,0.279961,0.681304,0.155091,303,451,754 +360,2011-12-26,1,0,12,1,1,0,1,0.321739,0.315535,0.506957,0.239465,430,887,1317 +361,2011-12-27,1,0,12,0,2,1,2,0.325,0.327633,0.7625,0.18845,103,1059,1162 +362,2011-12-28,1,0,12,0,3,1,1,0.29913,0.279974,0.503913,0.293961,255,2047,2302 +363,2011-12-29,1,0,12,0,4,1,1,0.248333,0.263892,0.574167,0.119412,254,2169,2423 +364,2011-12-30,1,0,12,0,5,1,1,0.311667,0.318812,0.636667,0.134337,491,2508,2999 +365,2011-12-31,1,0,12,0,6,0,1,0.41,0.414121,0.615833,0.220154,665,1820,2485 +366,2012-01-01,1,1,1,0,0,0,1,0.37,0.375621,0.6925,0.192167,686,1608,2294 
+367,2012-01-02,1,1,1,1,1,0,1,0.273043,0.252304,0.381304,0.329665,244,1707,1951 +368,2012-01-03,1,1,1,0,2,1,1,0.15,0.126275,0.44125,0.365671,89,2147,2236 +369,2012-01-04,1,1,1,0,3,1,2,0.1075,0.119337,0.414583,0.1847,95,2273,2368 +370,2012-01-05,1,1,1,0,4,1,1,0.265833,0.278412,0.524167,0.129987,140,3132,3272 +371,2012-01-06,1,1,1,0,5,1,1,0.334167,0.340267,0.542083,0.167908,307,3791,4098 +372,2012-01-07,1,1,1,0,6,0,1,0.393333,0.390779,0.531667,0.174758,1070,3451,4521 +373,2012-01-08,1,1,1,0,0,0,1,0.3375,0.340258,0.465,0.191542,599,2826,3425 +374,2012-01-09,1,1,1,0,1,1,2,0.224167,0.247479,0.701667,0.0989,106,2270,2376 +375,2012-01-10,1,1,1,0,2,1,1,0.308696,0.318826,0.646522,0.187552,173,3425,3598 +376,2012-01-11,1,1,1,0,3,1,2,0.274167,0.282821,0.8475,0.131221,92,2085,2177 +377,2012-01-12,1,1,1,0,4,1,2,0.3825,0.381938,0.802917,0.180967,269,3828,4097 +378,2012-01-13,1,1,1,0,5,1,1,0.274167,0.249362,0.5075,0.378108,174,3040,3214 +379,2012-01-14,1,1,1,0,6,0,1,0.18,0.183087,0.4575,0.187183,333,2160,2493 +380,2012-01-15,1,1,1,0,0,0,1,0.166667,0.161625,0.419167,0.251258,284,2027,2311 +381,2012-01-16,1,1,1,1,1,0,1,0.19,0.190663,0.5225,0.231358,217,2081,2298 +382,2012-01-17,1,1,1,0,2,1,2,0.373043,0.364278,0.716087,0.34913,127,2808,2935 +383,2012-01-18,1,1,1,0,3,1,1,0.303333,0.275254,0.443333,0.415429,109,3267,3376 +384,2012-01-19,1,1,1,0,4,1,1,0.19,0.190038,0.4975,0.220158,130,3162,3292 +385,2012-01-20,1,1,1,0,5,1,2,0.2175,0.220958,0.45,0.20275,115,3048,3163 +386,2012-01-21,1,1,1,0,6,0,2,0.173333,0.174875,0.83125,0.222642,67,1234,1301 +387,2012-01-22,1,1,1,0,0,0,2,0.1625,0.16225,0.79625,0.199638,196,1781,1977 +388,2012-01-23,1,1,1,0,1,1,2,0.218333,0.243058,0.91125,0.110708,145,2287,2432 +389,2012-01-24,1,1,1,0,2,1,1,0.3425,0.349108,0.835833,0.123767,439,3900,4339 +390,2012-01-25,1,1,1,0,3,1,1,0.294167,0.294821,0.64375,0.161071,467,3803,4270 +391,2012-01-26,1,1,1,0,4,1,2,0.341667,0.35605,0.769583,0.0733958,244,3831,4075 
+392,2012-01-27,1,1,1,0,5,1,2,0.425,0.415383,0.74125,0.342667,269,3187,3456 +393,2012-01-28,1,1,1,0,6,0,1,0.315833,0.326379,0.543333,0.210829,775,3248,4023 +394,2012-01-29,1,1,1,0,0,0,1,0.2825,0.272721,0.31125,0.24005,558,2685,3243 +395,2012-01-30,1,1,1,0,1,1,1,0.269167,0.262625,0.400833,0.215792,126,3498,3624 +396,2012-01-31,1,1,1,0,2,1,1,0.39,0.381317,0.416667,0.261817,324,4185,4509 +397,2012-02-01,1,1,2,0,3,1,1,0.469167,0.466538,0.507917,0.189067,304,4275,4579 +398,2012-02-02,1,1,2,0,4,1,2,0.399167,0.398971,0.672917,0.187187,190,3571,3761 +399,2012-02-03,1,1,2,0,5,1,1,0.313333,0.309346,0.526667,0.178496,310,3841,4151 +400,2012-02-04,1,1,2,0,6,0,2,0.264167,0.272725,0.779583,0.121896,384,2448,2832 +401,2012-02-05,1,1,2,0,0,0,2,0.265833,0.264521,0.687917,0.175996,318,2629,2947 +402,2012-02-06,1,1,2,0,1,1,1,0.282609,0.296426,0.622174,0.1538,206,3578,3784 +403,2012-02-07,1,1,2,0,2,1,1,0.354167,0.361104,0.49625,0.147379,199,4176,4375 +404,2012-02-08,1,1,2,0,3,1,2,0.256667,0.266421,0.722917,0.133721,109,2693,2802 +405,2012-02-09,1,1,2,0,4,1,1,0.265,0.261988,0.562083,0.194037,163,3667,3830 +406,2012-02-10,1,1,2,0,5,1,2,0.280833,0.293558,0.54,0.116929,227,3604,3831 +407,2012-02-11,1,1,2,0,6,0,3,0.224167,0.210867,0.73125,0.289796,192,1977,2169 +408,2012-02-12,1,1,2,0,0,0,1,0.1275,0.101658,0.464583,0.409212,73,1456,1529 +409,2012-02-13,1,1,2,0,1,1,1,0.2225,0.227913,0.41125,0.167283,94,3328,3422 +410,2012-02-14,1,1,2,0,2,1,2,0.319167,0.333946,0.50875,0.141179,135,3787,3922 +411,2012-02-15,1,1,2,0,3,1,1,0.348333,0.351629,0.53125,0.1816,141,4028,4169 +412,2012-02-16,1,1,2,0,4,1,2,0.316667,0.330162,0.752917,0.091425,74,2931,3005 +413,2012-02-17,1,1,2,0,5,1,1,0.343333,0.351629,0.634583,0.205846,349,3805,4154 +414,2012-02-18,1,1,2,0,6,0,1,0.346667,0.355425,0.534583,0.190929,1435,2883,4318 +415,2012-02-19,1,1,2,0,0,0,2,0.28,0.265788,0.515833,0.253112,618,2071,2689 +416,2012-02-20,1,1,2,1,1,0,1,0.28,0.273391,0.507826,0.229083,502,2627,3129 
+417,2012-02-21,1,1,2,0,2,1,1,0.287826,0.295113,0.594348,0.205717,163,3614,3777 +418,2012-02-22,1,1,2,0,3,1,1,0.395833,0.392667,0.567917,0.234471,394,4379,4773 +419,2012-02-23,1,1,2,0,4,1,1,0.454167,0.444446,0.554583,0.190913,516,4546,5062 +420,2012-02-24,1,1,2,0,5,1,2,0.4075,0.410971,0.7375,0.237567,246,3241,3487 +421,2012-02-25,1,1,2,0,6,0,1,0.290833,0.255675,0.395833,0.421642,317,2415,2732 +422,2012-02-26,1,1,2,0,0,0,1,0.279167,0.268308,0.41,0.205229,515,2874,3389 +423,2012-02-27,1,1,2,0,1,1,1,0.366667,0.357954,0.490833,0.268033,253,4069,4322 +424,2012-02-28,1,1,2,0,2,1,1,0.359167,0.353525,0.395833,0.193417,229,4134,4363 +425,2012-02-29,1,1,2,0,3,1,2,0.344348,0.34847,0.804783,0.179117,65,1769,1834 +426,2012-03-01,1,1,3,0,4,1,1,0.485833,0.475371,0.615417,0.226987,325,4665,4990 +427,2012-03-02,1,1,3,0,5,1,2,0.353333,0.359842,0.657083,0.144904,246,2948,3194 +428,2012-03-03,1,1,3,0,6,0,2,0.414167,0.413492,0.62125,0.161079,956,3110,4066 +429,2012-03-04,1,1,3,0,0,0,1,0.325833,0.303021,0.403333,0.334571,710,2713,3423 +430,2012-03-05,1,1,3,0,1,1,1,0.243333,0.241171,0.50625,0.228858,203,3130,3333 +431,2012-03-06,1,1,3,0,2,1,1,0.258333,0.255042,0.456667,0.200875,221,3735,3956 +432,2012-03-07,1,1,3,0,3,1,1,0.404167,0.3851,0.513333,0.345779,432,4484,4916 +433,2012-03-08,1,1,3,0,4,1,1,0.5275,0.524604,0.5675,0.441563,486,4896,5382 +434,2012-03-09,1,1,3,0,5,1,2,0.410833,0.397083,0.407083,0.4148,447,4122,4569 +435,2012-03-10,1,1,3,0,6,0,1,0.2875,0.277767,0.350417,0.22575,968,3150,4118 +436,2012-03-11,1,1,3,0,0,0,1,0.361739,0.35967,0.476957,0.222587,1658,3253,4911 +437,2012-03-12,1,1,3,0,1,1,1,0.466667,0.459592,0.489167,0.207713,838,4460,5298 +438,2012-03-13,1,1,3,0,2,1,1,0.565,0.542929,0.6175,0.23695,762,5085,5847 +439,2012-03-14,1,1,3,0,3,1,1,0.5725,0.548617,0.507083,0.115062,997,5315,6312 +440,2012-03-15,1,1,3,0,4,1,1,0.5575,0.532825,0.579583,0.149883,1005,5187,6192 +441,2012-03-16,1,1,3,0,5,1,2,0.435833,0.436229,0.842083,0.113192,548,3830,4378 
+442,2012-03-17,1,1,3,0,6,0,2,0.514167,0.505046,0.755833,0.110704,3155,4681,7836 +443,2012-03-18,1,1,3,0,0,0,2,0.4725,0.464,0.81,0.126883,2207,3685,5892 +444,2012-03-19,1,1,3,0,1,1,1,0.545,0.532821,0.72875,0.162317,982,5171,6153 +445,2012-03-20,1,1,3,0,2,1,1,0.560833,0.538533,0.807917,0.121271,1051,5042,6093 +446,2012-03-21,2,1,3,0,3,1,2,0.531667,0.513258,0.82125,0.0895583,1122,5108,6230 +447,2012-03-22,2,1,3,0,4,1,1,0.554167,0.531567,0.83125,0.117562,1334,5537,6871 +448,2012-03-23,2,1,3,0,5,1,2,0.601667,0.570067,0.694167,0.1163,2469,5893,8362 +449,2012-03-24,2,1,3,0,6,0,2,0.5025,0.486733,0.885417,0.192783,1033,2339,3372 +450,2012-03-25,2,1,3,0,0,0,2,0.4375,0.437488,0.880833,0.220775,1532,3464,4996 +451,2012-03-26,2,1,3,0,1,1,1,0.445833,0.43875,0.477917,0.386821,795,4763,5558 +452,2012-03-27,2,1,3,0,2,1,1,0.323333,0.315654,0.29,0.187192,531,4571,5102 +453,2012-03-28,2,1,3,0,3,1,1,0.484167,0.47095,0.48125,0.291671,674,5024,5698 +454,2012-03-29,2,1,3,0,4,1,1,0.494167,0.482304,0.439167,0.31965,834,5299,6133 +455,2012-03-30,2,1,3,0,5,1,2,0.37,0.375621,0.580833,0.138067,796,4663,5459 +456,2012-03-31,2,1,3,0,6,0,2,0.424167,0.421708,0.738333,0.250617,2301,3934,6235 +457,2012-04-01,2,1,4,0,0,0,2,0.425833,0.417287,0.67625,0.172267,2347,3694,6041 +458,2012-04-02,2,1,4,0,1,1,1,0.433913,0.427513,0.504348,0.312139,1208,4728,5936 +459,2012-04-03,2,1,4,0,2,1,1,0.466667,0.461483,0.396667,0.100133,1348,5424,6772 +460,2012-04-04,2,1,4,0,3,1,1,0.541667,0.53345,0.469583,0.180975,1058,5378,6436 +461,2012-04-05,2,1,4,0,4,1,1,0.435,0.431163,0.374167,0.219529,1192,5265,6457 +462,2012-04-06,2,1,4,0,5,1,1,0.403333,0.390767,0.377083,0.300388,1807,4653,6460 +463,2012-04-07,2,1,4,0,6,0,1,0.4375,0.426129,0.254167,0.274871,3252,3605,6857 +464,2012-04-08,2,1,4,0,0,0,1,0.5,0.492425,0.275833,0.232596,2230,2939,5169 +465,2012-04-09,2,1,4,0,1,1,1,0.489167,0.476638,0.3175,0.358196,905,4680,5585 +466,2012-04-10,2,1,4,0,2,1,1,0.446667,0.436233,0.435,0.249375,819,5099,5918 
+467,2012-04-11,2,1,4,0,3,1,1,0.348696,0.337274,0.469565,0.295274,482,4380,4862 +468,2012-04-12,2,1,4,0,4,1,1,0.3975,0.387604,0.46625,0.290429,663,4746,5409 +469,2012-04-13,2,1,4,0,5,1,1,0.4425,0.431808,0.408333,0.155471,1252,5146,6398 +470,2012-04-14,2,1,4,0,6,0,1,0.495,0.487996,0.502917,0.190917,2795,4665,7460 +471,2012-04-15,2,1,4,0,0,0,1,0.606667,0.573875,0.507917,0.225129,2846,4286,7132 +472,2012-04-16,2,1,4,1,1,0,1,0.664167,0.614925,0.561667,0.284829,1198,5172,6370 +473,2012-04-17,2,1,4,0,2,1,1,0.608333,0.598487,0.390417,0.273629,989,5702,6691 +474,2012-04-18,2,1,4,0,3,1,2,0.463333,0.457038,0.569167,0.167912,347,4020,4367 +475,2012-04-19,2,1,4,0,4,1,1,0.498333,0.493046,0.6125,0.0659292,846,5719,6565 +476,2012-04-20,2,1,4,0,5,1,1,0.526667,0.515775,0.694583,0.149871,1340,5950,7290 +477,2012-04-21,2,1,4,0,6,0,1,0.57,0.542921,0.682917,0.283587,2541,4083,6624 +478,2012-04-22,2,1,4,0,0,0,3,0.396667,0.389504,0.835417,0.344546,120,907,1027 +479,2012-04-23,2,1,4,0,1,1,2,0.321667,0.301125,0.766667,0.303496,195,3019,3214 +480,2012-04-24,2,1,4,0,2,1,1,0.413333,0.405283,0.454167,0.249383,518,5115,5633 +481,2012-04-25,2,1,4,0,3,1,1,0.476667,0.470317,0.427917,0.118792,655,5541,6196 +482,2012-04-26,2,1,4,0,4,1,2,0.498333,0.483583,0.756667,0.176625,475,4551,5026 +483,2012-04-27,2,1,4,0,5,1,1,0.4575,0.452637,0.400833,0.347633,1014,5219,6233 +484,2012-04-28,2,1,4,0,6,0,2,0.376667,0.377504,0.489583,0.129975,1120,3100,4220 +485,2012-04-29,2,1,4,0,0,0,1,0.458333,0.450121,0.587083,0.116908,2229,4075,6304 +486,2012-04-30,2,1,4,0,1,1,2,0.464167,0.457696,0.57,0.171638,665,4907,5572 +487,2012-05-01,2,1,5,0,2,1,2,0.613333,0.577021,0.659583,0.156096,653,5087,5740 +488,2012-05-02,2,1,5,0,3,1,1,0.564167,0.537896,0.797083,0.138058,667,5502,6169 +489,2012-05-03,2,1,5,0,4,1,2,0.56,0.537242,0.768333,0.133696,764,5657,6421 +490,2012-05-04,2,1,5,0,5,1,1,0.6275,0.590917,0.735417,0.162938,1069,5227,6296 +491,2012-05-05,2,1,5,0,6,0,2,0.621667,0.584608,0.756667,0.152992,2496,4387,6883 
+492,2012-05-06,2,1,5,0,0,0,2,0.5625,0.546737,0.74,0.149879,2135,4224,6359 +493,2012-05-07,2,1,5,0,1,1,2,0.5375,0.527142,0.664167,0.230721,1008,5265,6273 +494,2012-05-08,2,1,5,0,2,1,2,0.581667,0.557471,0.685833,0.296029,738,4990,5728 +495,2012-05-09,2,1,5,0,3,1,2,0.575,0.553025,0.744167,0.216412,620,4097,4717 +496,2012-05-10,2,1,5,0,4,1,1,0.505833,0.491783,0.552083,0.314063,1026,5546,6572 +497,2012-05-11,2,1,5,0,5,1,1,0.533333,0.520833,0.360417,0.236937,1319,5711,7030 +498,2012-05-12,2,1,5,0,6,0,1,0.564167,0.544817,0.480417,0.123133,2622,4807,7429 +499,2012-05-13,2,1,5,0,0,0,1,0.6125,0.585238,0.57625,0.225117,2172,3946,6118 +500,2012-05-14,2,1,5,0,1,1,2,0.573333,0.5499,0.789583,0.212692,342,2501,2843 +501,2012-05-15,2,1,5,0,2,1,2,0.611667,0.576404,0.794583,0.147392,625,4490,5115 +502,2012-05-16,2,1,5,0,3,1,1,0.636667,0.595975,0.697917,0.122512,991,6433,7424 +503,2012-05-17,2,1,5,0,4,1,1,0.593333,0.572613,0.52,0.229475,1242,6142,7384 +504,2012-05-18,2,1,5,0,5,1,1,0.564167,0.551121,0.523333,0.136817,1521,6118,7639 +505,2012-05-19,2,1,5,0,6,0,1,0.6,0.566908,0.45625,0.083975,3410,4884,8294 +506,2012-05-20,2,1,5,0,0,0,1,0.620833,0.583967,0.530417,0.254367,2704,4425,7129 +507,2012-05-21,2,1,5,0,1,1,2,0.598333,0.565667,0.81125,0.233204,630,3729,4359 +508,2012-05-22,2,1,5,0,2,1,2,0.615,0.580825,0.765833,0.118167,819,5254,6073 +509,2012-05-23,2,1,5,0,3,1,2,0.621667,0.584612,0.774583,0.102,766,4494,5260 +510,2012-05-24,2,1,5,0,4,1,1,0.655,0.6067,0.716667,0.172896,1059,5711,6770 +511,2012-05-25,2,1,5,0,5,1,1,0.68,0.627529,0.747083,0.14055,1417,5317,6734 +512,2012-05-26,2,1,5,0,6,0,1,0.6925,0.642696,0.7325,0.198992,2855,3681,6536 +513,2012-05-27,2,1,5,0,0,0,1,0.69,0.641425,0.697083,0.215171,3283,3308,6591 +514,2012-05-28,2,1,5,1,1,0,1,0.7125,0.6793,0.67625,0.196521,2557,3486,6043 +515,2012-05-29,2,1,5,0,2,1,1,0.7225,0.672992,0.684583,0.2954,880,4863,5743 +516,2012-05-30,2,1,5,0,3,1,2,0.656667,0.611129,0.67,0.134329,745,6110,6855 
+517,2012-05-31,2,1,5,0,4,1,1,0.68,0.631329,0.492917,0.195279,1100,6238,7338 +518,2012-06-01,2,1,6,0,5,1,2,0.654167,0.607962,0.755417,0.237563,533,3594,4127 +519,2012-06-02,2,1,6,0,6,0,1,0.583333,0.566288,0.549167,0.186562,2795,5325,8120 +520,2012-06-03,2,1,6,0,0,0,1,0.6025,0.575133,0.493333,0.184087,2494,5147,7641 +521,2012-06-04,2,1,6,0,1,1,1,0.5975,0.578283,0.487083,0.284833,1071,5927,6998 +522,2012-06-05,2,1,6,0,2,1,2,0.540833,0.525892,0.613333,0.209575,968,6033,7001 +523,2012-06-06,2,1,6,0,3,1,1,0.554167,0.542292,0.61125,0.077125,1027,6028,7055 +524,2012-06-07,2,1,6,0,4,1,1,0.6025,0.569442,0.567083,0.15735,1038,6456,7494 +525,2012-06-08,2,1,6,0,5,1,1,0.649167,0.597862,0.467917,0.175383,1488,6248,7736 +526,2012-06-09,2,1,6,0,6,0,1,0.710833,0.648367,0.437083,0.144287,2708,4790,7498 +527,2012-06-10,2,1,6,0,0,0,1,0.726667,0.663517,0.538333,0.133721,2224,4374,6598 +528,2012-06-11,2,1,6,0,1,1,2,0.720833,0.659721,0.587917,0.207713,1017,5647,6664 +529,2012-06-12,2,1,6,0,2,1,2,0.653333,0.597875,0.833333,0.214546,477,4495,4972 +530,2012-06-13,2,1,6,0,3,1,1,0.655833,0.611117,0.582083,0.343279,1173,6248,7421 +531,2012-06-14,2,1,6,0,4,1,1,0.648333,0.624383,0.569583,0.253733,1180,6183,7363 +532,2012-06-15,2,1,6,0,5,1,1,0.639167,0.599754,0.589583,0.176617,1563,6102,7665 +533,2012-06-16,2,1,6,0,6,0,1,0.631667,0.594708,0.504167,0.166667,2963,4739,7702 +534,2012-06-17,2,1,6,0,0,0,1,0.5925,0.571975,0.59875,0.144904,2634,4344,6978 +535,2012-06-18,2,1,6,0,1,1,2,0.568333,0.544842,0.777917,0.174746,653,4446,5099 +536,2012-06-19,2,1,6,0,2,1,1,0.688333,0.654692,0.69,0.148017,968,5857,6825 +537,2012-06-20,2,1,6,0,3,1,1,0.7825,0.720975,0.592083,0.113812,872,5339,6211 +538,2012-06-21,3,1,6,0,4,1,1,0.805833,0.752542,0.567917,0.118787,778,5127,5905 +539,2012-06-22,3,1,6,0,5,1,1,0.7775,0.724121,0.57375,0.182842,964,4859,5823 +540,2012-06-23,3,1,6,0,6,0,1,0.731667,0.652792,0.534583,0.179721,2657,4801,7458 +541,2012-06-24,3,1,6,0,0,0,1,0.743333,0.674254,0.479167,0.145525,2551,4340,6891 
+542,2012-06-25,3,1,6,0,1,1,1,0.715833,0.654042,0.504167,0.300383,1139,5640,6779 +543,2012-06-26,3,1,6,0,2,1,1,0.630833,0.594704,0.373333,0.347642,1077,6365,7442 +544,2012-06-27,3,1,6,0,3,1,1,0.6975,0.640792,0.36,0.271775,1077,6258,7335 +545,2012-06-28,3,1,6,0,4,1,1,0.749167,0.675512,0.4225,0.17165,921,5958,6879 +546,2012-06-29,3,1,6,0,5,1,1,0.834167,0.786613,0.48875,0.165417,829,4634,5463 +547,2012-06-30,3,1,6,0,6,0,1,0.765,0.687508,0.60125,0.161071,1455,4232,5687 +548,2012-07-01,3,1,7,0,0,0,1,0.815833,0.750629,0.51875,0.168529,1421,4110,5531 +549,2012-07-02,3,1,7,0,1,1,1,0.781667,0.702038,0.447083,0.195267,904,5323,6227 +550,2012-07-03,3,1,7,0,2,1,1,0.780833,0.70265,0.492083,0.126237,1052,5608,6660 +551,2012-07-04,3,1,7,1,3,0,1,0.789167,0.732337,0.53875,0.13495,2562,4841,7403 +552,2012-07-05,3,1,7,0,4,1,1,0.8275,0.761367,0.457917,0.194029,1405,4836,6241 +553,2012-07-06,3,1,7,0,5,1,1,0.828333,0.752533,0.450833,0.146142,1366,4841,6207 +554,2012-07-07,3,1,7,0,6,0,1,0.861667,0.804913,0.492083,0.163554,1448,3392,4840 +555,2012-07-08,3,1,7,0,0,0,1,0.8225,0.790396,0.57375,0.125629,1203,3469,4672 +556,2012-07-09,3,1,7,0,1,1,2,0.710833,0.654054,0.683333,0.180975,998,5571,6569 +557,2012-07-10,3,1,7,0,2,1,2,0.720833,0.664796,0.6675,0.151737,954,5336,6290 +558,2012-07-11,3,1,7,0,3,1,1,0.716667,0.650271,0.633333,0.151733,975,6289,7264 +559,2012-07-12,3,1,7,0,4,1,1,0.715833,0.654683,0.529583,0.146775,1032,6414,7446 +560,2012-07-13,3,1,7,0,5,1,2,0.731667,0.667933,0.485833,0.08085,1511,5988,7499 +561,2012-07-14,3,1,7,0,6,0,2,0.703333,0.666042,0.699167,0.143679,2355,4614,6969 +562,2012-07-15,3,1,7,0,0,0,1,0.745833,0.705196,0.717917,0.166667,1920,4111,6031 +563,2012-07-16,3,1,7,0,1,1,1,0.763333,0.724125,0.645,0.164187,1088,5742,6830 +564,2012-07-17,3,1,7,0,2,1,1,0.818333,0.755683,0.505833,0.114429,921,5865,6786 +565,2012-07-18,3,1,7,0,3,1,1,0.793333,0.745583,0.577083,0.137442,799,4914,5713 +566,2012-07-19,3,1,7,0,4,1,1,0.77,0.714642,0.600417,0.165429,888,5703,6591 
+567,2012-07-20,3,1,7,0,5,1,2,0.665833,0.613025,0.844167,0.208967,747,5123,5870 +568,2012-07-21,3,1,7,0,6,0,3,0.595833,0.549912,0.865417,0.2133,1264,3195,4459 +569,2012-07-22,3,1,7,0,0,0,2,0.6675,0.623125,0.7625,0.0939208,2544,4866,7410 +570,2012-07-23,3,1,7,0,1,1,1,0.741667,0.690017,0.694167,0.138683,1135,5831,6966 +571,2012-07-24,3,1,7,0,2,1,1,0.750833,0.70645,0.655,0.211454,1140,6452,7592 +572,2012-07-25,3,1,7,0,3,1,1,0.724167,0.654054,0.45,0.1648,1383,6790,8173 +573,2012-07-26,3,1,7,0,4,1,1,0.776667,0.739263,0.596667,0.284813,1036,5825,6861 +574,2012-07-27,3,1,7,0,5,1,1,0.781667,0.734217,0.594583,0.152992,1259,5645,6904 +575,2012-07-28,3,1,7,0,6,0,1,0.755833,0.697604,0.613333,0.15735,2234,4451,6685 +576,2012-07-29,3,1,7,0,0,0,1,0.721667,0.667933,0.62375,0.170396,2153,4444,6597 +577,2012-07-30,3,1,7,0,1,1,1,0.730833,0.684987,0.66875,0.153617,1040,6065,7105 +578,2012-07-31,3,1,7,0,2,1,1,0.713333,0.662896,0.704167,0.165425,968,6248,7216 +579,2012-08-01,3,1,8,0,3,1,1,0.7175,0.667308,0.6775,0.141179,1074,6506,7580 +580,2012-08-02,3,1,8,0,4,1,1,0.7525,0.707088,0.659583,0.129354,983,6278,7261 +581,2012-08-03,3,1,8,0,5,1,2,0.765833,0.722867,0.6425,0.215792,1328,5847,7175 +582,2012-08-04,3,1,8,0,6,0,1,0.793333,0.751267,0.613333,0.257458,2345,4479,6824 +583,2012-08-05,3,1,8,0,0,0,1,0.769167,0.731079,0.6525,0.290421,1707,3757,5464 +584,2012-08-06,3,1,8,0,1,1,2,0.7525,0.710246,0.654167,0.129354,1233,5780,7013 +585,2012-08-07,3,1,8,0,2,1,2,0.735833,0.697621,0.70375,0.116908,1278,5995,7273 +586,2012-08-08,3,1,8,0,3,1,2,0.75,0.707717,0.672917,0.1107,1263,6271,7534 +587,2012-08-09,3,1,8,0,4,1,1,0.755833,0.699508,0.620417,0.1561,1196,6090,7286 +588,2012-08-10,3,1,8,0,5,1,2,0.715833,0.667942,0.715833,0.238813,1065,4721,5786 +589,2012-08-11,3,1,8,0,6,0,2,0.6925,0.638267,0.732917,0.206479,2247,4052,6299 +590,2012-08-12,3,1,8,0,0,0,1,0.700833,0.644579,0.530417,0.122512,2182,4362,6544 +591,2012-08-13,3,1,8,0,1,1,1,0.720833,0.662254,0.545417,0.136212,1207,5676,6883 
+592,2012-08-14,3,1,8,0,2,1,1,0.726667,0.676779,0.686667,0.169158,1128,5656,6784 +593,2012-08-15,3,1,8,0,3,1,1,0.706667,0.654037,0.619583,0.169771,1198,6149,7347 +594,2012-08-16,3,1,8,0,4,1,1,0.719167,0.654688,0.519167,0.141796,1338,6267,7605 +595,2012-08-17,3,1,8,0,5,1,1,0.723333,0.2424,0.570833,0.231354,1483,5665,7148 +596,2012-08-18,3,1,8,0,6,0,1,0.678333,0.618071,0.603333,0.177867,2827,5038,7865 +597,2012-08-19,3,1,8,0,0,0,2,0.635833,0.603554,0.711667,0.08645,1208,3341,4549 +598,2012-08-20,3,1,8,0,1,1,2,0.635833,0.595967,0.734167,0.129979,1026,5504,6530 +599,2012-08-21,3,1,8,0,2,1,1,0.649167,0.601025,0.67375,0.0727708,1081,5925,7006 +600,2012-08-22,3,1,8,0,3,1,1,0.6675,0.621854,0.677083,0.0702833,1094,6281,7375 +601,2012-08-23,3,1,8,0,4,1,1,0.695833,0.637008,0.635833,0.0845958,1363,6402,7765 +602,2012-08-24,3,1,8,0,5,1,2,0.7025,0.6471,0.615,0.0721458,1325,6257,7582 +603,2012-08-25,3,1,8,0,6,0,2,0.661667,0.618696,0.712917,0.244408,1829,4224,6053 +604,2012-08-26,3,1,8,0,0,0,2,0.653333,0.595996,0.845833,0.228858,1483,3772,5255 +605,2012-08-27,3,1,8,0,1,1,1,0.703333,0.654688,0.730417,0.128733,989,5928,6917 +606,2012-08-28,3,1,8,0,2,1,1,0.728333,0.66605,0.62,0.190925,935,6105,7040 +607,2012-08-29,3,1,8,0,3,1,1,0.685,0.635733,0.552083,0.112562,1177,6520,7697 +608,2012-08-30,3,1,8,0,4,1,1,0.706667,0.652779,0.590417,0.0771167,1172,6541,7713 +609,2012-08-31,3,1,8,0,5,1,1,0.764167,0.6894,0.5875,0.168533,1433,5917,7350 +610,2012-09-01,3,1,9,0,6,0,2,0.753333,0.702654,0.638333,0.113187,2352,3788,6140 +611,2012-09-02,3,1,9,0,0,0,2,0.696667,0.649,0.815,0.0640708,2613,3197,5810 +612,2012-09-03,3,1,9,1,1,0,1,0.7075,0.661629,0.790833,0.151121,1965,4069,6034 +613,2012-09-04,3,1,9,0,2,1,1,0.725833,0.686888,0.755,0.236321,867,5997,6864 +614,2012-09-05,3,1,9,0,3,1,1,0.736667,0.708983,0.74125,0.187808,832,6280,7112 +615,2012-09-06,3,1,9,0,4,1,2,0.696667,0.655329,0.810417,0.142421,611,5592,6203 +616,2012-09-07,3,1,9,0,5,1,1,0.703333,0.657204,0.73625,0.171646,1045,6459,7504 
+617,2012-09-08,3,1,9,0,6,0,2,0.659167,0.611121,0.799167,0.281104,1557,4419,5976 +618,2012-09-09,3,1,9,0,0,0,1,0.61,0.578925,0.5475,0.224496,2570,5657,8227 +619,2012-09-10,3,1,9,0,1,1,1,0.583333,0.565654,0.50375,0.258713,1118,6407,7525 +620,2012-09-11,3,1,9,0,2,1,1,0.5775,0.554292,0.52,0.0920542,1070,6697,7767 +621,2012-09-12,3,1,9,0,3,1,1,0.599167,0.570075,0.577083,0.131846,1050,6820,7870 +622,2012-09-13,3,1,9,0,4,1,1,0.6125,0.579558,0.637083,0.0827208,1054,6750,7804 +623,2012-09-14,3,1,9,0,5,1,1,0.633333,0.594083,0.6725,0.103863,1379,6630,8009 +624,2012-09-15,3,1,9,0,6,0,1,0.608333,0.585867,0.501667,0.247521,3160,5554,8714 +625,2012-09-16,3,1,9,0,0,0,1,0.58,0.563125,0.57,0.0901833,2166,5167,7333 +626,2012-09-17,3,1,9,0,1,1,2,0.580833,0.55305,0.734583,0.151742,1022,5847,6869 +627,2012-09-18,3,1,9,0,2,1,2,0.623333,0.565067,0.8725,0.357587,371,3702,4073 +628,2012-09-19,3,1,9,0,3,1,1,0.5525,0.540404,0.536667,0.215175,788,6803,7591 +629,2012-09-20,3,1,9,0,4,1,1,0.546667,0.532192,0.618333,0.118167,939,6781,7720 +630,2012-09-21,3,1,9,0,5,1,1,0.599167,0.571971,0.66875,0.154229,1250,6917,8167 +631,2012-09-22,3,1,9,0,6,0,1,0.65,0.610488,0.646667,0.283583,2512,5883,8395 +632,2012-09-23,4,1,9,0,0,0,1,0.529167,0.518933,0.467083,0.223258,2454,5453,7907 +633,2012-09-24,4,1,9,0,1,1,1,0.514167,0.502513,0.492917,0.142404,1001,6435,7436 +634,2012-09-25,4,1,9,0,2,1,1,0.55,0.544179,0.57,0.236321,845,6693,7538 +635,2012-09-26,4,1,9,0,3,1,1,0.635,0.596613,0.630833,0.2444,787,6946,7733 +636,2012-09-27,4,1,9,0,4,1,2,0.65,0.607975,0.690833,0.134342,751,6642,7393 +637,2012-09-28,4,1,9,0,5,1,2,0.619167,0.585863,0.69,0.164179,1045,6370,7415 +638,2012-09-29,4,1,9,0,6,0,1,0.5425,0.530296,0.542917,0.227604,2589,5966,8555 +639,2012-09-30,4,1,9,0,0,0,1,0.526667,0.517663,0.583333,0.134958,2015,4874,6889 +640,2012-10-01,4,1,10,0,1,1,2,0.520833,0.512,0.649167,0.0908042,763,6015,6778 +641,2012-10-02,4,1,10,0,2,1,3,0.590833,0.542333,0.871667,0.104475,315,4324,4639 
+642,2012-10-03,4,1,10,0,3,1,2,0.6575,0.599133,0.79375,0.0665458,728,6844,7572 +643,2012-10-04,4,1,10,0,4,1,2,0.6575,0.607975,0.722917,0.117546,891,6437,7328 +644,2012-10-05,4,1,10,0,5,1,1,0.615,0.580187,0.6275,0.10635,1516,6640,8156 +645,2012-10-06,4,1,10,0,6,0,1,0.554167,0.538521,0.664167,0.268025,3031,4934,7965 +646,2012-10-07,4,1,10,0,0,0,2,0.415833,0.419813,0.708333,0.141162,781,2729,3510 +647,2012-10-08,4,1,10,1,1,0,2,0.383333,0.387608,0.709583,0.189679,874,4604,5478 +648,2012-10-09,4,1,10,0,2,1,2,0.446667,0.438112,0.761667,0.1903,601,5791,6392 +649,2012-10-10,4,1,10,0,3,1,1,0.514167,0.503142,0.630833,0.187821,780,6911,7691 +650,2012-10-11,4,1,10,0,4,1,1,0.435,0.431167,0.463333,0.181596,834,6736,7570 +651,2012-10-12,4,1,10,0,5,1,1,0.4375,0.433071,0.539167,0.235092,1060,6222,7282 +652,2012-10-13,4,1,10,0,6,0,1,0.393333,0.391396,0.494583,0.146142,2252,4857,7109 +653,2012-10-14,4,1,10,0,0,0,1,0.521667,0.508204,0.640417,0.278612,2080,4559,6639 +654,2012-10-15,4,1,10,0,1,1,2,0.561667,0.53915,0.7075,0.296037,760,5115,5875 +655,2012-10-16,4,1,10,0,2,1,1,0.468333,0.460846,0.558333,0.182221,922,6612,7534 +656,2012-10-17,4,1,10,0,3,1,1,0.455833,0.450108,0.692917,0.101371,979,6482,7461 +657,2012-10-18,4,1,10,0,4,1,2,0.5225,0.512625,0.728333,0.236937,1008,6501,7509 +658,2012-10-19,4,1,10,0,5,1,2,0.563333,0.537896,0.815,0.134954,753,4671,5424 +659,2012-10-20,4,1,10,0,6,0,1,0.484167,0.472842,0.572917,0.117537,2806,5284,8090 +660,2012-10-21,4,1,10,0,0,0,1,0.464167,0.456429,0.51,0.166054,2132,4692,6824 +661,2012-10-22,4,1,10,0,1,1,1,0.4875,0.482942,0.568333,0.0814833,830,6228,7058 +662,2012-10-23,4,1,10,0,2,1,1,0.544167,0.530304,0.641667,0.0945458,841,6625,7466 +663,2012-10-24,4,1,10,0,3,1,1,0.5875,0.558721,0.63625,0.0727792,795,6898,7693 +664,2012-10-25,4,1,10,0,4,1,2,0.55,0.529688,0.800417,0.124375,875,6484,7359 +665,2012-10-26,4,1,10,0,5,1,2,0.545833,0.52275,0.807083,0.132467,1182,6262,7444 +666,2012-10-27,4,1,10,0,6,0,2,0.53,0.515133,0.72,0.235692,2643,5209,7852 
+667,2012-10-28,4,1,10,0,0,0,2,0.4775,0.467771,0.694583,0.398008,998,3461,4459 +668,2012-10-29,4,1,10,0,1,1,3,0.44,0.4394,0.88,0.3582,2,20,22 +669,2012-10-30,4,1,10,0,2,1,2,0.318182,0.309909,0.825455,0.213009,87,1009,1096 +670,2012-10-31,4,1,10,0,3,1,2,0.3575,0.3611,0.666667,0.166667,419,5147,5566 +671,2012-11-01,4,1,11,0,4,1,2,0.365833,0.369942,0.581667,0.157346,466,5520,5986 +672,2012-11-02,4,1,11,0,5,1,1,0.355,0.356042,0.522083,0.266175,618,5229,5847 +673,2012-11-03,4,1,11,0,6,0,2,0.343333,0.323846,0.49125,0.270529,1029,4109,5138 +674,2012-11-04,4,1,11,0,0,0,1,0.325833,0.329538,0.532917,0.179108,1201,3906,5107 +675,2012-11-05,4,1,11,0,1,1,1,0.319167,0.308075,0.494167,0.236325,378,4881,5259 +676,2012-11-06,4,1,11,0,2,1,1,0.280833,0.281567,0.567083,0.173513,466,5220,5686 +677,2012-11-07,4,1,11,0,3,1,2,0.295833,0.274621,0.5475,0.304108,326,4709,5035 +678,2012-11-08,4,1,11,0,4,1,1,0.352174,0.341891,0.333478,0.347835,340,4975,5315 +679,2012-11-09,4,1,11,0,5,1,1,0.361667,0.355413,0.540833,0.214558,709,5283,5992 +680,2012-11-10,4,1,11,0,6,0,1,0.389167,0.393937,0.645417,0.0578458,2090,4446,6536 +681,2012-11-11,4,1,11,0,0,0,1,0.420833,0.421713,0.659167,0.1275,2290,4562,6852 +682,2012-11-12,4,1,11,1,1,0,1,0.485,0.475383,0.741667,0.173517,1097,5172,6269 +683,2012-11-13,4,1,11,0,2,1,2,0.343333,0.323225,0.662917,0.342046,327,3767,4094 +684,2012-11-14,4,1,11,0,3,1,1,0.289167,0.281563,0.552083,0.199625,373,5122,5495 +685,2012-11-15,4,1,11,0,4,1,2,0.321667,0.324492,0.620417,0.152987,320,5125,5445 +686,2012-11-16,4,1,11,0,5,1,1,0.345,0.347204,0.524583,0.171025,484,5214,5698 +687,2012-11-17,4,1,11,0,6,0,1,0.325,0.326383,0.545417,0.179729,1313,4316,5629 +688,2012-11-18,4,1,11,0,0,0,1,0.3425,0.337746,0.692917,0.227612,922,3747,4669 +689,2012-11-19,4,1,11,0,1,1,2,0.380833,0.375621,0.623333,0.235067,449,5050,5499 +690,2012-11-20,4,1,11,0,2,1,2,0.374167,0.380667,0.685,0.082725,534,5100,5634 +691,2012-11-21,4,1,11,0,3,1,1,0.353333,0.364892,0.61375,0.103246,615,4531,5146 
+692,2012-11-22,4,1,11,1,4,0,1,0.34,0.350371,0.580417,0.0528708,955,1470,2425 +693,2012-11-23,4,1,11,0,5,1,1,0.368333,0.378779,0.56875,0.148021,1603,2307,3910 +694,2012-11-24,4,1,11,0,6,0,1,0.278333,0.248742,0.404583,0.376871,532,1745,2277 +695,2012-11-25,4,1,11,0,0,0,1,0.245833,0.257583,0.468333,0.1505,309,2115,2424 +696,2012-11-26,4,1,11,0,1,1,1,0.313333,0.339004,0.535417,0.04665,337,4750,5087 +697,2012-11-27,4,1,11,0,2,1,2,0.291667,0.281558,0.786667,0.237562,123,3836,3959 +698,2012-11-28,4,1,11,0,3,1,1,0.296667,0.289762,0.50625,0.210821,198,5062,5260 +699,2012-11-29,4,1,11,0,4,1,1,0.28087,0.298422,0.555652,0.115522,243,5080,5323 +700,2012-11-30,4,1,11,0,5,1,1,0.298333,0.323867,0.649583,0.0584708,362,5306,5668 +701,2012-12-01,4,1,12,0,6,0,2,0.298333,0.316904,0.806667,0.0597042,951,4240,5191 +702,2012-12-02,4,1,12,0,0,0,2,0.3475,0.359208,0.823333,0.124379,892,3757,4649 +703,2012-12-03,4,1,12,0,1,1,1,0.4525,0.455796,0.7675,0.0827208,555,5679,6234 +704,2012-12-04,4,1,12,0,2,1,1,0.475833,0.469054,0.73375,0.174129,551,6055,6606 +705,2012-12-05,4,1,12,0,3,1,1,0.438333,0.428012,0.485,0.324021,331,5398,5729 +706,2012-12-06,4,1,12,0,4,1,1,0.255833,0.258204,0.50875,0.174754,340,5035,5375 +707,2012-12-07,4,1,12,0,5,1,2,0.320833,0.321958,0.764167,0.1306,349,4659,5008 +708,2012-12-08,4,1,12,0,6,0,2,0.381667,0.389508,0.91125,0.101379,1153,4429,5582 +709,2012-12-09,4,1,12,0,0,0,2,0.384167,0.390146,0.905417,0.157975,441,2787,3228 +710,2012-12-10,4,1,12,0,1,1,2,0.435833,0.435575,0.925,0.190308,329,4841,5170 +711,2012-12-11,4,1,12,0,2,1,2,0.353333,0.338363,0.596667,0.296037,282,5219,5501 +712,2012-12-12,4,1,12,0,3,1,2,0.2975,0.297338,0.538333,0.162937,310,5009,5319 +713,2012-12-13,4,1,12,0,4,1,1,0.295833,0.294188,0.485833,0.174129,425,5107,5532 +714,2012-12-14,4,1,12,0,5,1,1,0.281667,0.294192,0.642917,0.131229,429,5182,5611 +715,2012-12-15,4,1,12,0,6,0,1,0.324167,0.338383,0.650417,0.10635,767,4280,5047 +716,2012-12-16,4,1,12,0,0,0,2,0.3625,0.369938,0.83875,0.100742,538,3248,3786 
+717,2012-12-17,4,1,12,0,1,1,2,0.393333,0.4015,0.907083,0.0982583,212,4373,4585 +718,2012-12-18,4,1,12,0,2,1,1,0.410833,0.409708,0.66625,0.221404,433,5124,5557 +719,2012-12-19,4,1,12,0,3,1,1,0.3325,0.342162,0.625417,0.184092,333,4934,5267 +720,2012-12-20,4,1,12,0,4,1,2,0.33,0.335217,0.667917,0.132463,314,3814,4128 +721,2012-12-21,1,1,12,0,5,1,2,0.326667,0.301767,0.556667,0.374383,221,3402,3623 +722,2012-12-22,1,1,12,0,6,0,1,0.265833,0.236113,0.44125,0.407346,205,1544,1749 +723,2012-12-23,1,1,12,0,0,0,1,0.245833,0.259471,0.515417,0.133083,408,1379,1787 +724,2012-12-24,1,1,12,0,1,1,2,0.231304,0.2589,0.791304,0.0772304,174,746,920 +725,2012-12-25,1,1,12,1,2,0,2,0.291304,0.294465,0.734783,0.168726,440,573,1013 +726,2012-12-26,1,1,12,0,3,1,3,0.243333,0.220333,0.823333,0.316546,9,432,441 +727,2012-12-27,1,1,12,0,4,1,2,0.254167,0.226642,0.652917,0.350133,247,1867,2114 +728,2012-12-28,1,1,12,0,5,1,2,0.253333,0.255046,0.59,0.155471,644,2451,3095 +729,2012-12-29,1,1,12,0,6,0,2,0.253333,0.2424,0.752917,0.124383,159,1182,1341 +730,2012-12-30,1,1,12,0,0,0,1,0.255833,0.2317,0.483333,0.350754,364,1432,1796 +731,2012-12-31,1,1,12,0,1,1,2,0.215833,0.223487,0.5775,0.154846,439,2290,2729 diff --git a/inst/code_paper/model.rds b/inst/code_paper/model.rds new file mode 100644 index 0000000000000000000000000000000000000000..7f33ba4e625d006014a822b7975e7184ad2822e6 Binary files /dev/null and b/inst/code_paper/model.rds differ diff --git a/inst/code_paper/prep_data_and_model.R b/inst/code_paper/prep_data_and_model.R new file mode 100644 index 0000000000000000000000000000000000000000..291339518deb4de75f0a98fd084ad1bc563f2d65 --- /dev/null +++ b/inst/code_paper/prep_data_and_model.R @@ -0,0 +1,67 @@ +library(xgboost) +library(data.table) + +# Bike sharing data from http://archive.ics.uci.edu/dataset/275/bike+sharing+dataset +# with license https://creativecommons.org/licenses/by/4.0/ + +temp <- tempfile() +url <- "https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip" 
# Download the archive and read the daily table.
# Binary mode is requested explicitly instead of relying on download.file()'s
# extension-based default, and the csv is extracted into the session temp
# directory so no stray "day.csv" is left in the working directory.
download.file(url, temp, mode = "wb")
day_csv <- unzip(temp, files = "day.csv", exdir = tempdir())
bike <- fread(day_csv)
unlink(c(temp, day_csv))

# Following the data preparation done by
# Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020).
# Causal shapley values: Exploiting causal knowledge to explain individual predictions of complex models.
# Advances in neural information processing systems, 33, 4778-4789.
# (See supplement: https://proceedings.neurips.cc/paper_files/paper/2020/file/32e54441e6382a7fbacbbbaf3c450059-Supplemental.zip)

# Days elapsed since the first observation (0 for the first row).
bike[, trend := as.numeric(difftime(dteday, dteday[1], units = "days"))]

# Yearly seasonality encoded as cos/sin of 2 * pi * trend / 365.
bike[, cosyear := cospi(trend / 365 * 2)]
bike[, sinyear := sinpi(trend / 365 * 2)]

# Linear rescaling of the weather columns: temp maps 0 -> -8 and 1 -> 39,
# atemp maps 0 -> -16 and 1 -> 50, windspeed is multiplied by 67 and hum by 100.
# NOTE(review): presumably this undoes the normalisation applied in the raw
# data set (degrees Celsius, wind speed, percent humidity) -- confirm against
# the data set documentation.
bike[, temp := temp * (39 - (-8)) + (-8)]
bike[, atemp := atemp * (50 - (-16)) + (-16)]
bike[, windspeed := 67 * windspeed]
bike[, hum := 100 * hum]

# We specify the features and the response variable.
x_var <- c(
  "trend", "cosyear", "sinyear",
  "temp", "atemp", "windspeed", "hum"
)
y_var <- "cnt"

# We split the data into a training (80%) and an explanation/test (20%) set.
# The seed and the sampling call are kept exactly as they were so that the
# same rows end up in the training and explanation files written below.
set.seed(123)
train_index <- sample(x = nrow(bike), size = round(0.8 * nrow(bike)))

# Columns are selected with .SDcols / [[ ]] rather than mget() / get(), which
# could silently pick up a global variable (e.g. `temp`) if a column were
# missing.
x_full <- bike[, .SD, .SDcols = x_var]

x_train <- bike[train_index, .SD, .SDcols = x_var]
y_train <- bike[[y_var]][train_index]

x_explain <- bike[-train_index, .SD, .SDcols = x_var]
y_explain <- bike[[y_var]][-train_index]

# We fit a basic xgboost model to the training data.
+model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 100, + verbose = FALSE +) + +#### Writing training and explanation data to csv files +fwrite(x_full, file="inst/code_paper/x_full.csv") +fwrite(x_train, file="inst/code_paper/x_train.csv") +fwrite(as.data.table(y_train), file="inst/code_paper/y_train.csv") +fwrite(x_explain, file="inst/code_paper/x_explain.csv") +fwrite(as.data.table(y_explain), file="inst/code_paper/y_explain.csv") + +# We save the xgboost model object +xgb.save(model, "inst/code_paper/xgb.model") +saveRDS(model, "inst/code_paper/model.rds") + + diff --git a/inst/code_paper/scatter_ctree.pdf b/inst/code_paper/scatter_ctree.pdf new file mode 100644 index 0000000000000000000000000000000000000000..bd1f6de3be0ec1e9d84b44735bd36d20a3f90b6e Binary files /dev/null and b/inst/code_paper/scatter_ctree.pdf differ diff --git a/inst/code_paper/waterfall_group.pdf b/inst/code_paper/waterfall_group.pdf new file mode 100644 index 0000000000000000000000000000000000000000..46add1adf3455a3bcece72422116f032d59c7df8 Binary files /dev/null and b/inst/code_paper/waterfall_group.pdf differ diff --git a/inst/code_paper/x_explain.csv b/inst/code_paper/x_explain.csv new file mode 100644 index 0000000000000000000000000000000000000000..be88f29dcbf2e77537ca9c36461efd165a63ad30 --- /dev/null +++ b/inst/code_paper/x_explain.csv @@ -0,0 +1,147 @@ +trend,cosyear,sinyear,temp,atemp,windspeed,hum +0,1,0,8.175849,7.99925,10.749882,80.5833 +2,0.999407400739705,0.0344216116227457,1.229108,-3.49927,16.636703,43.7273 +8,0.990532452132223,0.137278772113265,-1.498349,-8.33245,24.25065,43.4167 +17,0.957485188355039,0.288482432880609,2.183349,-0.666022,9.833925,86.1667 +21,0.935367949313148,0.353676122176372,-5.2208712,-10.7814064,11.52199,40 +26,0.901501684131884,0.432775592550431,1.165,-1.4998,7.627079,68.75 +31,0.860961015888994,0.508670943852104,1.032178,-0.52102,3.565271,82.9565 
+41,0.761104258660775,0.648629561034982,0.887277000000001,-1.908406,7.27285,50.6364 +42,0.749826401204569,0.661634618242278,2.4575,-0.625036,13.625589,54.4167 +43,0.738326354003107,0.674443618832946,6.876534,5.391458,17.479161,45.7391 +52,0.625410572985246,0.780295851070776,0.564434,-3.721822,13.110761,57.7778 +55,0.584298173628369,0.811539059007361,9.124356,7.130426,23.218113,71.2174 +57,0.556017436657045,0.831170626365808,8.143466,7.173194,8.391616,68 +59,0.527077708642372,0.849817091527528,4.533349,1.416014,14.500475,53.5 +61,0.497513288907181,0.867456354729597,1.321651,-2.791222,15.125518,31.8333 +62,0.482507741761219,0.875891705144243,4.298349,0.874814000000001,13.624182,61.0417 +69,0.373719714790469,0.927541683579197,6.876534,4.13,15.60899,64.9565 +76,0.2595117970698,0.965739937654855,17.38,18.782594,15.478139,52.5217 +81,0.175531490421428,0.984473816752092,8.306979,6.303974,15.695487,83.9565 +85,0.107381346664163,0.994217906893952,3.893021,0.522373999999999,12.3481,49.3913 +91,0.00430353829624429,0.99999073973619,6.805,4.832042,13.208782,65.375 +96,-0.0816763953304224,0.99665890175417,12.5625,12.623936,10.874904,60.2917 +98,-0.1159345995955,0.993256849267414,8.0975,6.540914,8.916561,87.75 +100,-0.150055398344653,0.988677590232341,19.995644,21.304322,21.739758,71.6956 +122,-0.50496105472152,0.863142128049912,20.983349,22.417214,22.958689,69.7083 +125,-0.548842958284719,0.835925479418637,14.520849,15.291722,15.292482,59 +141,-0.75549331407268,0.655156357209085,20.395849,21.917,9.916536,74.9583 +143,-0.777597146973627,0.628762814595835,23.02,23.917658,13.875164,74.0833 +144,-0.788305055830525,0.615284599963328,23.059151,24.625772,10.333611,69.625 +145,-0.798779372886365,0.601624063224923,25.291651,27.209408,13.376014,67.75 +146,-0.809016994374947,0.587785252292473,24.038349,26.042528,16.125493,65.375 +148,-0.828770087174503,0.559589262410177,23.3725,24.6263,14.333846,81.875 +149,-0.838279705217774,0.545240438540651,26.466651,28.292072,8.792075,68.5 
+153,-0.873807103611081,0.48627270710869,21.14,22.750778,16.959107,35.4167 +156,-0.897743393534234,0.440518784350495,23.881651,25.042628,8.167032,60 +172,-0.98370929377361,0.179766585725562,26.466651,29.792714,11.541554,70.75 +182,-0.999962959116266,0.0086069968886887,26.701651,28.042328,7.709154,44.4583 +191,-0.989314203970366,-0.145799196919875,27.8375,32.167064,18.916579,63.5833 +205,-0.92592477719385,-0.377707965203965,26.936651,31.583228,6.0841561,75.7083 +207,-0.912374757970727,-0.409355958815622,28.425,29.584022,12.292021,40.2917 +212,-0.873807103611081,-0.48627270710869,28.268349,30.417272,10.500039,55.0833 +214,-0.856550995901004,-0.516062391015853,26.388349,28.875842,9.084061,65.75 +219,-0.809016994374948,-0.587785252292473,27.955,30.416678,12.875725,59.2083 +232,-0.658401584698049,-0.752666827532008,25.409151,28.58465,16.666518,77 +246,-0.459732739452105,-0.888057322629493,25.330849,27.918314,13.833289,74.2083 +254,-0.333468778918187,-0.942761143390421,22.284356,23.74058,5.957171,69.2174 +262,-0.200890555130635,-0.97961369164549,18.398349,19.126322,9.041918,88.125 +263,-0.18399835165768,-0.982926551979982,19.965,20.335178,6.4590814,90 +265,-0.150055398344653,-0.988677590232341,20.630849,18.46025,5.2505689,97.25 +270,-0.0645084494493171,-0.997917160865392,21.845,21.960428,9.958143,84.875 +271,-0.0473213883224323,-0.998879715585034,20.983349,21.917792,11.583161,69.9167 +272,-0.0301203048469084,-0.999546280687357,18.515849,19.958714,13.833825,64.75 +273,-0.0129102960750088,-0.999916658654738,11.27,11.248958,19.583832,75.375 +275,0.0215160974362222,-0.999768501979891,10.055849,9.875036,5.5841686,76.0833 +282,0.141540295217043,-0.989932495087353,18.829151,19.83305,2.8343814,73.375 +283,0.158559385103135,-0.987349442393986,18.633349,20.042336,9.583814,80.875 +292,0.309016994374947,-0.951056516295153,14.364151,14.79065,28.292425,63.625 +295,0.357698238833125,-0.933837228822925,11.818349,11.873978,6.6673375,74.125 
+297,0.389630449530788,-0.920971287716635,14.168349,14.58275,11.166086,62.2917 +299,0.421100870796089,-0.907013812802636,14.09,14.165036,13.250121,81.2917 +304,0.49751328890718,-0.867456354729597,10.839151,10.207808,9.083257,68.375 +317,0.677614789046689,-0.735417022963986,16.91,18.624392,20.541932,58.7083 +319,0.702527474169157,-0.711656622281775,13.463349,13.831208,9.167543,93 +320,0.714673386042961,-0.699458327051647,8.058349,5.332586,20.459254,57.5833 +321,0.726607524768566,-0.687052767223667,4.885849,1.999586,11.291711,41 +324,0.761104258660774,-0.648629561034981,13.0325,13.374092,9.249618,91 +326,0.782980103677063,-0.622046748440868,12.719151,12.415442,22.500275,75.7917 +328,0.803927961832821,-0.594726686960764,9.625,9.124286,6.6260186,64.375 +332,0.842941537354783,-0.538005171538299,13.541651,13.79075,17.292164,83.0833 +333,0.852077521101309,-0.52341560736555,7.275,4.540586,18.167586,61.3333 +334,0.860961015888994,-0.508670943852104,6.6875,4.166564,14.750586,52.4583 +339,0.901501684131884,-0.432775592550431,13.7375,14.1224,15.583061,94.9583 +350,0.966847813605277,-0.255353295116187,4.141651,0.458486000000001,16.292189,56.0833 +355,0.985220106756061,-0.171293144181478,11.896651,12.123986,3.167425,75.75 +359,0.994670819911521,-0.103101697447435,7.121733,4.82531,16.044155,50.6957 +365,1,0,9.39,8.790986,12.875189,69.25 +366,0.999851839209116,0.0172133561558353,4.833021,0.652063999999999,22.087555,38.1304 +368,0.998666816288476,0.0516196672232542,-2.9475,-8.123758,12.3749,41.4583 +369,0.997630305306586,0.0688024268023196,4.494151,2.375192,8.709129,52.4167 +370,0.996298174934608,0.0859647987374468,7.705849,6.457622,11.249836,54.2083 +375,0.985220106756061,0.171293144181478,4.885849,2.666186,8.791807,84.75 +378,0.975064532257195,0.221921513004165,0.459999999999999,-3.916258,12.541261,45.75 +381,0.962309077454148,0.271958157534106,9.533021,8.042348,23.39171,71.6087 +386,0.935367949313148,0.353676122176372,-0.3625,-5.2915,13.375746,79.625 
+410,0.714673386042961,0.699458327051647,8.371651,7.207514,12.1672,53.125 +419,0.598180914405916,0.801361088174677,11.1525,11.124086,15.916989,73.75 +431,0.42110087079609,0.907013812802636,10.995849,9.4166,23.167193,51.3333 +432,0.405425728359997,0.914127988185334,16.7925,18.623864,29.584721,56.75 +436,0.341570769167856,0.939856057941895,13.933349,14.333072,13.916771,48.9167 +447,0.158559385103135,0.987349442393986,20.278349,21.624422,7.7921,69.4167 +448,0.141540295217043,0.989932495087353,15.6175,16.124378,12.916461,88.5417 +452,0.0730951298980776,0.997324973108156,14.755849,15.0827,19.541957,48.125 +460,-0.0645084494493158,0.997917160865392,12.445,12.456758,14.708443,37.4167 +461,-0.0816763953304226,0.99665890175417,10.956651,9.790622,20.125996,37.7083 +477,-0.349647455251229,0.936881346295431,10.643349,9.707264,23.084582,83.5417 +480,-0.397542814282555,0.917583626059394,14.403349,15.040922,7.959064,42.7917 +481,-0.413278607782904,0.910604630094216,15.421651,15.916478,11.833875,75.6667 +485,-0.47495107206705,0.880012203973536,13.815849,14.207936,11.499746,57 +487,-0.50496105472152,0.863142128049912,18.515849,19.501136,9.249886,79.7083 +490,-0.548842958284719,0.835925479418637,21.218349,22.584128,10.250464,75.6667 +491,-0.563150724274918,0.82635419872391,18.4375,20.084642,10.041893,74 +494,-0.605056069648849,0.796182863782616,19.025,20.49965,14.499604,74.4167 +495,-0.618671403262504,0.785649855078714,15.774151,16.457678,21.042221,55.2083 +496,-0.632103411187348,0.774884041367041,17.066651,18.374978,15.874779,36.0417 +498,-0.658401584698049,0.752666827532008,20.7875,22.625708,15.082839,57.625 +501,-0.696376225596872,0.717676913675962,21.923349,23.33435,8.208304,69.7917 +523,-0.912374757970727,0.409355958815622,20.3175,21.583172,10.54245,56.7083 +528,-0.944187508834199,0.32940848222453,22.706651,23.45975,14.374582,83.3333 +534,-0.973118337233262,0.230305670230612,18.711651,19.959572,11.707982,77.7917 
+539,-0.989314203970366,0.145799196919875,26.388349,27.084272,12.041307,53.4583 +560,-0.976938492777182,-0.213520915439796,25.056651,27.958772,9.626493,69.9167 +562,-0.969009825724406,-0.247022180480935,27.876651,31.79225,11.000529,64.5 +564,-0.959932689659744,-0.280230675199217,29.286651,33.208478,9.208614,57.7083 +565,-0.954966754855255,-0.29671281927349,28.19,31.166372,11.083743,60.0417 +570,-0.92592477719385,-0.377707965203965,27.289151,30.6257,14.167418,65.5 +573,-0.905193189891397,-0.425000339969555,28.738349,32.458322,10.250464,59.4583 +576,-0.882048024955854,-0.471159507673864,26.349151,29.209142,10.292339,66.875 +577,-0.873807103611081,-0.48627270710869,25.526651,27.751136,11.083475,70.4167 +578,-0.865307254363206,-0.501241813445775,25.7225,28.042328,9.458993,67.75 +581,-0.838279705217774,-0.545240438540651,29.286651,33.583622,17.249686,61.3333 +584,-0.809016994374947,-0.587785252292474,26.584151,30.042986,7.832836,70.375 +586,-0.788305055830525,-0.615284599963328,27.524151,30.167528,10.4587,62.0417 +594,-0.696376225596872,-0.717676913675962,25.996651,-0.00159999999999982,15.500718,57.0833 +596,-0.671259957567532,-0.741222010848596,21.884151,23.834564,5.79215,71.1667 +607,-0.519743812155516,-0.854322169749827,25.213349,27.083414,5.1668189,59.0417 +608,-0.50496105472152,-0.863142128049912,27.915849,29.5004,11.291711,58.75 +611,-0.459732739452105,-0.888057322629493,25.2525,27.667514,10.125107,79.0833 +623,-0.267814305162175,-0.963470548564149,20.591651,22.667222,16.583907,50.1667 +626,-0.217723230396531,-0.976010550632368,21.296651,21.294422,23.958329,87.25 +631,-0.133014706534196,-0.991114063993455,16.870849,18.249578,14.958286,46.7083 +643,0.0730951298980769,-0.997324973108156,20.905,22.292342,7.12545,62.75 +646,0.12447926388679,-0.992222209417932,10.016651,9.582128,12.708493,70.9583 +656,0.292600335633348,-0.956234826591906,16.5575,17.83325,15.874779,72.8333 +661,0.373719714790468,-0.927541683579197,17.575849,19.000064,6.3345686,64.1667 
+663,0.405425728359997,-0.914127988185334,17.85,18.959408,8.333125,80.0417 +666,0.452072203932305,-0.891981346459548,14.4425,14.872886,26.666536,69.4583 +668,0.482507741761218,-0.875891705144243,6.954554,4.453994,14.271603,82.5455 +674,0.570242292691787,-0.821476553302414,7.000849,4.33295,15.833775,49.4167 +679,0.638749422051527,-0.769414826883938,10.290849,9.999842,3.8756686,64.5417 +681,0.664855397964286,-0.746972087696555,14.795,15.375278,11.625639,74.1667 +683,0.690173388242971,-0.723644038295913,5.590849,2.583158,13.374875,55.2083 +685,0.714673386042961,-0.699458327051647,8.215,6.915464,11.458675,52.4583 +692,0.793571608952147,-0.608476870115126,9.311651,8.999414,9.917407,56.875 +702,0.886070621534138,-0.46355027090285,13.2675,14.082536,5.5422936,76.75 +705,0.908817637339503,-0.417193602612317,4.024151,1.041464,11.708518,50.875 +728,0.999407400739705,-0.0344216116227456,3.906651,-0.00159999999999982,8.333661,75.2917 diff --git a/inst/code_paper/x_full.csv b/inst/code_paper/x_full.csv new file mode 100644 index 0000000000000000000000000000000000000000..baf54bc2f9a888fdda3c2414b60e59188ae4f254 --- /dev/null +++ b/inst/code_paper/x_full.csv @@ -0,0 +1,732 @@ +trend,cosyear,sinyear,temp,atemp,windspeed,hum +0,1,0,8.175849,7.99925,10.749882,80.5833 +1,0.999851839209116,0.0172133561558347,9.083466,7.346774,16.652113,69.6087 +2,0.999407400739705,0.0344216116227457,1.229108,-3.49927,16.636703,43.7273 +3,0.998666816288476,0.0516196672232538,1.4,-1.999948,10.739832,59.0435 +4,0.997630305306586,0.0688024268023199,2.666979,-0.868180000000001,12.5223,43.6957 +5,0.996298174934608,0.0859647987374465,1.604356,-0.608205999999999,6.0008684,51.8261 +6,0.994670819911521,0.103101697447435,1.236534,-2.216626,11.304642,49.8696 +7,0.99274872245774,0.120208044899353,-0.244999999999999,-5.291236,17.875868,53.5833 +8,0.990532452132223,0.137278772113265,-1.498349,-8.33245,24.25065,43.4167 +9,0.988022665663698,0.154308820664281,-0.910849000000001,-6.041392,14.958889,48.2917 
+10,0.985220106756061,0.171293144181478,-0.0527230000000003,-3.363376,8.182844,68.6364 +11,0.982125605868001,0.188226709843244,0.118169,-5.408782,20.410009,59.9545 +12,0.978740079966915,0.205104499868619,-0.244999999999999,-6.041722,20.167,47.0417 +13,0.975064532257195,0.221921513004166,-0.439109999999999,-3.564742,8.478716,53.7826 +14,0.97110005188295,0.23867276600595,2.966651,0.375392000000002,10.583521,49.875 +15,0.966847813605277,0.255353295116187,2.888349,-0.541677999999999,12.625011,48.375 +16,0.962309077454149,0.271958157534106,0.264151,-4.333114,12.999139,53.75 +17,0.957485188355039,0.288482432880609,2.183349,-0.666022,9.833925,86.1667 +18,0.952377575730397,0.304921224656289,5.732178,3.695852,13.957239,74.1739 +19,0.946987753076075,0.321269661692364,4.298349,0.833300000000001,13.125568,53.8333 +20,0.941317317512847,0.337522899594113,0.342499999999999,-5.583022,23.667214,45.7083 +21,0.935367949313148,0.353676122176372,-5.2208712,-10.7814064,11.52199,40 +22,0.929141411403174,0.369724542890673,-3.4634801,-9.4766194,16.5222,43.6522 +23,0.922639548840488,0.385663406243607,-3.4226089,-8.21662,10.60811,49.1739 +24,0.915864288267287,0.401487989205973,2.503466,-0.521284,8.696332,61.6957 +25,0.908817637339503,0.417193602612317,2.2225,-2.5624,19.68795,86.25 +26,0.901501684131884,0.432775592550431,1.165,-1.4998,7.627079,68.75 +27,0.893918596519257,0.448229341740411,1.563466,-1.261078,8.2611,79.3043 +28,0.886070621534138,0.463550270902851,1.236534,-1.999684,9.739455,65.1739 +29,0.877960084700888,0.478733840115789,2.176534,0.521252000000001,4.9568342,72.2174 +30,0.869589389346611,0.493775550159977,0.499151,-3.7075,12.541864,60.375 +31,0.860961015888994,0.508670943852104,1.032178,-0.52102,3.565271,82.9565 +32,0.852077521101309,0.52341560736555,4.22,0.791522000000001,17.708636,77.5417 +33,0.842941537354783,0.5380051715383,0.786979000000001,-4.260052,18.609384,43.7826 +34,0.83355577183857,0.55243531316762,1.931288,-0.913257999999999,8.565213,58.5217 
+35,0.823923005757554,0.566701756291118,2.966651,0.0418279999999989,10.792293,92.9167 +36,0.814046093508218,0.580800273453801,5.434151,3.250286,9.5006,56.8333 +37,0.803927961832821,0.594726686960763,4.768349,4.041428,3.0423561,73.8333 +38,0.793571608952147,0.608476870115126,2.379151,-2.915764,24.25065,53.7917 +39,0.782980103677063,0.622046748440868,-1.665199,-6.477322,12.652213,49.4783 +40,0.772156584499164,0.635432300890177,-1.215644,-6.129832,14.869645,43.7391 +41,0.761104258660775,0.648629561034982,0.887277000000001,-1.908406,7.27285,50.6364 +42,0.749826401204569,0.661634618242278,2.4575,-0.625036,13.625589,54.4167 +43,0.738326354003107,0.674443618832946,6.876534,5.391458,17.479161,45.7391 +44,0.726607524768566,0.687052767223667,11.505,10.2911,27.999836,37.5833 +45,0.714673386042961,0.699458327051647,4.506089,0.782084000000001,19.522058,31.4348 +46,0.702527474169157,0.711656622281775,6.958267,4.8692,16.869997,42.3478 +47,0.690173388242972,0.723644038295912,12.484151,12.291428,15.416968,50.5 +48,0.677614789046689,0.735417022963986,16.518349,17.790878,17.749975,51.6667 +49,0.664855397964287,0.746972087696555,10.760849,9.832664,34.000021,18.7917 +50,0.651898995878713,0.758305808478562,5.405199,2.30378,14.956745,40.7826 +51,0.638749422051527,0.769414826883938,6.256651,2.74895,20.625682,60.5 +52,0.625410572985246,0.780295851070776,0.564434,-3.721822,13.110761,57.7778 +53,0.611886401268724,0.790945656756777,2.421733,0.217321999999999,6.305571,42.3043 +54,0.598180914405916,0.801361088174677,5.895644,3.086606,16.783232,69.7391 +55,0.584298173628369,0.811539059007361,9.124356,7.130426,23.218113,71.2174 +56,0.570242292691787,0.821476553302414,5.2775,2.624672,12.500257,53.7917 +57,0.556017436657045,0.831170626365808,8.143466,7.173194,8.391616,68 +58,0.541627820655981,0.840618405634478,11.141831,10.407788,19.408962,87.6364 +59,0.527077708642372,0.849817091527528,4.533349,1.416014,14.500475,53.5 +60,0.512371412128424,0.858763958275803,7.745,5.124686,20.624811,44.9583 
+61,0.497513288907181,0.867456354729597,1.321651,-2.791222,15.125518,31.8333 +62,0.482507741761219,0.875891705144243,4.298349,0.874814000000001,13.624182,61.0417 +63,0.467359217158002,0.884067509943364,10.055849,8.999414,16.875357,78.9167 +64,0.452072203932304,0.891981346459549,9.696534,8.172632,23.000229,94.8261 +65,0.436651231956064,0.899630869652243,4.301733,-0.261574,22.870584,55.1304 +66,0.42110087079609,0.907013812802636,5.7475,3.9584,8.08355,42.0833 +67,0.405425728359997,0.914127988185334,5.904151,2.916128,14.75005,77.5417 +68,0.389630449530788,0.920971287716635,10.287277,9.454088,17.545759,0 +69,0.373719714790469,0.927541683579197,6.876534,4.13,15.60899,64.9565 +70,0.357698238833126,0.933837228822925,7.470849,5.4995,14.791925,59.4583 +71,0.341570769167856,0.939856057941895,10.064356,9.086006,18.130468,52.7391 +72,0.32534208471198,0.945596387427143,7.285199,5.912,9.174042,49.6957 +73,0.309016994374947,0.951056516295153,6.917377,4.999748,12.348703,65.5652 +74,0.292600335633349,0.956234826591906,9.165199,8.21738,13.608839,77.6522 +75,0.276096973097469,0.961129783872301,11.505,11.081978,14.041793,60.2917 +76,0.2595117970698,0.965739937654855,17.38,18.782594,15.478139,52.5217 +77,0.242849722095936,0.970063921851507,14.2075,14.79065,24.667189,37.9167 +78,0.226115685508288,0.97410045517242,7.6275,5.4995,13.917307,47.375 +79,0.209314645963049,0.977848341505657,12.230445,11.04251,19.348461,73.7391 +80,0.19245158197083,0.981306470271609,12.758349,13.082372,15.12525,62.4583 +81,0.175531490421428,0.984473816752092,8.306979,6.303974,15.695487,83.9565 +82,0.158559385103135,0.987349442393986,5.395,1.874978,16.333729,80.5833 +83,0.141540295217043,0.989932495087353,4.415849,0.916591999999998,15.458575,49.5 +84,0.124479263886789,0.992222209417932,4.494151,0.999686000000001,14.041257,39.4167 +85,0.107381346664163,0.994217906893952,3.893021,0.522373999999999,12.3481,49.3913 +86,0.0902516100310412,0.995918996147179,4.424356,0.999884000000002,14.217668,30.2174 
+87,0.0730951298980776,0.997324973108156,6.2175,3.331928,15.208732,31.4167 +88,0.0559169901006033,0.998435421155564,6.1,3.6251,11.583496,64.6667 +89,0.0387222808921745,0.999250011239683,4.611651,0.999949999999998,14.582282,91.8333 +90,0.0215160974362223,0.999768501979891,6.1,2.707964,17.333436,68.625 +91,0.00430353829624429,0.99999073973619,6.805,4.832042,13.208782,65.375 +92,-0.012910296075009,0.999916658654738,9.781651,8.998622,12.208271,48 +93,-0.0301203048469081,0.999546280687357,18.946651,19.833314,25.833257,42.625 +94,-0.0473213883224319,0.998879715585034,11.465849,10.2911,26.000489,64.2083 +95,-0.0645084494493162,0.997917160865392,10.369151,9.582128,17.625221,47.0833 +96,-0.0816763953304224,0.99665890175417,12.5625,12.623936,10.874904,60.2917 +97,-0.0988201387328714,0.995105311100698,7.784151,5.415614,15.208464,83.625 +98,-0.1159345995955,0.993256849267414,8.0975,6.540914,8.916561,87.75 +99,-0.133014706534196,0.991114063993455,12.053349,12.164642,9.833389,85.75 +100,-0.150055398344653,0.988677590232341,19.995644,21.304322,21.739758,71.6956 +101,-0.167051625502119,0.98594814996383,15.6175,16.541564,18.416893,73.9167 +102,-0.18399835165768,0.982926551979982,11.3875,11.540678,16.791339,81.9167 +103,-0.200890555130635,0.97961369164549,13.9725,14.540972,7.4169,54.0417 +104,-0.217723230396532,0.976010550632368,12.993349,13.166258,15.167125,67.125 +105,-0.23449138957041,0.972118196629061,12.249151,12.082472,22.834136,88.8333 +106,-0.251190063884819,0.967937783024064,13.463349,13.415936,20.334232,47.9583 +107,-0.267814305162174,0.963470548564149,16.0875,17.207636,10.958989,54.25 +108,-0.284359187281004,0.958717816987296,15.774151,16.291028,10.584057,66.5833 +109,-0.300819807635668,0.953680996630446,19.965,21.249872,16.208975,61.4167 +110,-0.317191288589106,0.948361580012172,13.580849,13.956872,21.792286,40.7083 +111,-0.333468778918187,0.942761143390421,7.823349,5.248964,14.707907,72.9583 +112,-0.349647455251228,0.936881346295431,13.62,13.707986,15.458575,88.7917 
+113,-0.365722523497269,0.930723931037979,19.338349,20.416358,12.875725,81.0833 +114,-0.381689220266659,0.924290722193093,20.513349,21.917,12.417311,77.6667 +115,-0.397542814282556,0.917583626059394,21.688349,23.209478,21.8755,72.9167 +116,-0.413278607782904,0.910604630094216,21.14,21.959372,20.9174,83.5417 +117,-0.428891937912483,0.903355802324685,21.0225,22.209314,21.500836,70.0833 +118,-0.444378178104613,0.895839290734909,15.97,16.832558,16.084221,45.7083 +119,-0.459732739452104,0.888057322629493,14.2075,14.625386,15.750025,50.3333 +120,-0.47495107206705,0.880012203973536,13.228349,13.581464,7.125718,76.2083 +121,-0.490028666429059,0.871706318709322,17.810849,19.166978,12.291418,73 +122,-0.50496105472152,0.863142128049912,20.983349,22.417214,22.958689,69.7083 +123,-0.519743812155515,0.854322169749827,11.465849,10.7069,22.042732,73.7083 +124,-0.534372558280979,0.845249057353063,13.580849,13.166522,19.791264,44.4167 +125,-0.548842958284719,0.835925479418637,14.520849,15.291722,15.292482,59 +126,-0.563150724274919,0.82635419872391,16.44,17.832986,10.75015,54.125 +127,-0.577291616551727,0.816538051445916,16.831651,18.249578,5.0007125,63.1667 +128,-0.591261444863578,0.806479946320945,17.0275,18.666236,11.792,58.875 +129,-0.605056069648849,0.796182863782616,17.0275,18.499586,7.749957,48.9167 +130,-0.618671403262503,0.785649855078715,17.4975,18.8744,8.083014,63.2917 +131,-0.632103411187349,0.774884041367041,17.145,18.541958,12.707689,74.75 +132,-0.64534811322955,0.763888612790543,16.0875,16.6238,12.041575,86.3333 +133,-0.658401584698049,0.752666827532008,16.479151,17.041514,9.04165,92.25 +134,-0.671259957567531,0.741222010848596,18.4375,19.376,10.249593,86.7083 +135,-0.68391942162461,0.729557554086488,19.1425,20.333792,8.500357,78.7917 +136,-0.696376225596872,0.717676913675962,18.398349,19.542914,18.582718,83.7917 +137,-0.70862667826446,0.705583610107178,17.85,18.792428,13.499964,87 +138,-0.720667149553861,0.693281226886978,16.949151,17.708972,7.250271,82.9583 
+139,-0.732494071613579,0.680773409477017,17.223349,18.916772,8.375871,71.9583 +140,-0.74410393987136,0.668063864213534,20.3175,21.75035,8.08355,62.6667 +141,-0.75549331407268,0.655156357209085,20.395849,21.917,9.916536,74.9583 +142,-0.766658819300159,0.642054713236564,21.688349,22.959536,15.667414,81 +143,-0.777597146973627,0.628762814595835,23.02,23.917658,13.875164,74.0833 +144,-0.788305055830525,0.615284599963328,23.059151,24.625772,10.333611,69.625 +145,-0.798779372886365,0.601624063224923,25.291651,27.209408,13.376014,67.75 +146,-0.809016994374947,0.587785252292473,24.038349,26.042528,16.125493,65.375 +147,-0.81901488666808,0.573772267904325,22.824151,24.417014,15.416164,72.9583 +148,-0.828770087174503,0.559589262410177,23.3725,24.6263,14.333846,81.875 +149,-0.838279705217774,0.545240438540651,26.466651,28.292072,8.792075,68.5 +150,-0.847540922892831,0.530730048161934,28.425,31.875278,7.459043,63.6667 +151,-0.856550995901004,0.516062391015853,27.915849,31.583822,13.875164,67.7083 +152,-0.865307254363206,0.501241813445776,25.605,26.500172,19.583229,30.5 +153,-0.873807103611081,0.48627270710869,21.14,22.750778,16.959107,35.4167 +154,-0.882048024955853,0.471159507673864,21.845,23.249936,8.250514,45.625 +155,-0.890027576434677,0.455906693508459,22.471651,24.709064,9.292364,65.25 +156,-0.897743393534234,0.440518784350495,23.881651,25.042628,8.167032,60 +157,-0.905193189891397,0.425000339969554,25.2525,27.2927,12.583136,59.7917 +158,-0.912374757970727,0.409355958815622,28.464151,32.000414,9.166739,62.2083 +159,-0.91928596971861,0.393590276656467,29.991651,34.000214,10.042161,56.8333 +160,-0.92592477719385,0.377707965203965,27.485,30.417272,9.417118,60.5 +161,-0.932289213174513,0.361713730729768,26.075,28.750508,10.37495,65.4583 +162,-0.938377391740864,0.345612312670734,24.5475,26.45945,10.958989,74.7917 +163,-0.944187508834199,0.32940848222453,21.845,23.709164,20.45845,49.4583 +164,-0.949717842791432,0.313107040935827,20.395849,23.042036,18.041961,50.7083 
+165,-0.954966754855255,0.29671281927349,21.453349,22.791764,11.250104,47.1667 +166,-0.959932689659744,0.280230675199216,21.531651,23.292836,13.833557,68.8333 +167,-0.964614175691244,0.263665492728008,22.510849,23.625278,9.582943,73.5833 +168,-0.969009825724406,0.247022180480936,24.743349,26.500964,8.000336,67.0417 +169,-0.973118337233262,0.230305670230612,24.860849,26.625836,6.834,66.6667 +170,-0.976938492777182,0.213520915439796,21.845,23.292836,10.416825,74.625 +171,-0.980469160361632,0.196672889793576,23.999151,26.084636,11.458675,77.0417 +172,-0.98370929377361,0.179766585725562,26.466651,29.792714,11.541554,70.75 +173,-0.986657932891657,0.162807012938517,26.231651,29.792978,15.999868,70.3333 +174,-0.989314203970366,0.145799196919875,26.035849,27.334478,14.875675,57.3333 +175,-0.99167731989929,0.128748177452581,24.665,26.458658,14.041257,48.3333 +176,-0.993746580436178,0.111659007121695,23.96,26.083514,6.3337311,51.3333 +177,-0.995521372414475,0.0945367498171996,24.0775,26.042264,7.208396,65.8333 +178,-0.997001169925015,0.077386479233463,26.975849,29.708828,9.666961,63.4167 +179,-0.998185534471859,0.060213277365793,26.231651,27.209408,17.542007,49.7917 +180,-0.99907411510223,0.0430222330045306,24.743349,26.042528,12.415904,43.4167 +181,-0.999666648510511,0.0258184402271331,25.9575,27.042692,6.874736,39.625 +182,-0.999962959116266,0.0086069968886887,26.701651,28.042328,7.709154,44.4583 +183,-0.999962959116266,-0.0086069968886887,25.683349,28.12595,15.333486,68.25 +184,-0.999666648510511,-0.0258184402271331,26.153349,27.917522,5.4591064,63.7917 +185,-0.99907411510223,-0.0430222330045306,27.093349,29.958308,8.459286,59.0417 +186,-0.998185534471859,-0.060213277365793,25.84,29.251778,10.042161,74.3333 +187,-0.997001169925015,-0.077386479233463,27.25,29.333486,10.6664,65.125 +188,-0.995521372414475,-0.0945367498171996,25.330849,28.251878,15.083643,75.7917 +189,-0.993746580436178,-0.111659007121695,26.466651,27.834428,11.250104,60.9167 
+190,-0.99167731989929,-0.128748177452581,27.1325,29.54165,12.292557,57.8333 +191,-0.989314203970366,-0.145799196919875,27.8375,32.167064,18.916579,63.5833 +192,-0.986657932891657,-0.162807012938517,29.325849,32.79215,13.417018,55.9167 +193,-0.98370929377361,-0.179766585725562,27.093349,29.500664,9.790911,63.1667 +194,-0.980469160361632,-0.196672889793576,23.999151,25.916864,16.124689,47.625 +195,-0.976938492777182,-0.213520915439796,23.176651,25.208486,12.249811,59.125 +196,-0.973118337233262,-0.230305670230612,24.273349,26.125358,13.958914,58.5 +197,-0.969009825724406,-0.247022180480936,25.800849,28.208978,16.417211,60.4167 +198,-0.964614175691244,-0.263665492728008,27.093349,30.45905,14.458868,65.125 +199,-0.959932689659744,-0.280230675199216,28.503349,33.333614,8.7502,65.0417 +200,-0.954966754855255,-0.29671281927349,28.111651,33.2921,7.625739,70.7083 +201,-0.949717842791432,-0.313107040935827,30.305,38.540486,14.875407,69.125 +202,-0.944187508834199,-0.32940848222453,31.871651,39.499136,8.9177,58.0417 +203,-0.938377391740864,-0.345612312670734,31.910849,37.082942,8.791807,50 +204,-0.932289213174513,-0.361713730729768,31.01,36.458714,11.334457,55.0833 +205,-0.92592477719385,-0.377707965203965,26.936651,31.583228,6.0841561,75.7083 +206,-0.91928596971861,-0.393590276656467,28.268349,30.000614,13.417286,54.0833 +207,-0.912374757970727,-0.409355958815622,28.425,29.584022,12.292021,40.2917 +208,-0.905193189891398,-0.425000339969554,28.620849,32.8334,11.958093,58.3333 +209,-0.897743393534234,-0.440518784350495,31.401651,35.873822,11.667246,54.25 +210,-0.890027576434677,-0.455906693508459,29.795849,32.083442,11.291979,46.5833 +211,-0.882048024955854,-0.471159507673864,29.874151,32.166536,11.042471,48.0833 +212,-0.873807103611081,-0.48627270710869,28.268349,30.417272,10.500039,55.0833 +213,-0.865307254363206,-0.501241813445776,28.816651,30.666686,13.79195,49.125 +214,-0.856550995901004,-0.516062391015853,26.388349,28.875842,9.084061,65.75 
+215,-0.847540922892831,-0.530730048161934,25.37,27.876008,13.20905,75.75 +216,-0.838279705217774,-0.545240438540651,25.409151,27.333422,12.374632,63.0833 +217,-0.828770087174504,-0.559589262410177,25.683349,28.626164,15.29275,75.5 +218,-0.81901488666808,-0.573772267904325,26.8975,31.209272,13.499629,75.2917 +219,-0.809016994374948,-0.587785252292473,27.955,30.416678,12.875725,59.2083 +220,-0.798779372886365,-0.601624063224923,28.425,31.791986,10.125107,57.0417 +221,-0.788305055830526,-0.615284599963328,28.033349,29.208878,13.417286,42.4167 +222,-0.777597146973627,-0.628762814595835,25.7225,27.000386,11.041332,42.375 +223,-0.766658819300159,-0.642054713236564,25.291651,27.166772,8.416607,41.5 +224,-0.75549331407268,-0.655156357209085,24.234151,26.626628,14.167418,72.9583 +225,-0.744103939871361,-0.668063864213534,23.803349,25.209608,14.916411,81.75 +226,-0.732494071613579,-0.680773409477017,23.294151,24.667022,13.999918,71.2083 +227,-0.720667149553861,-0.693281226886978,24.939151,26.625242,15.834043,57.8333 +228,-0.70862667826446,-0.705583610107178,25.996651,28.000286,9.625689,57.5417 +229,-0.696376225596872,-0.717676913675962,25.448349,27.709028,15.624936,65.4583 +230,-0.683919421624611,-0.729557554086488,24.195,25.792586,9.333636,72.2917 +231,-0.671259957567532,-0.741222010848596,24.7825,26.833736,6.999289,67.4167 +232,-0.658401584698049,-0.752666827532008,25.409151,28.58465,16.666518,77 +233,-0.64534811322955,-0.763888612790543,24.508349,26.124764,18.54225,47 +234,-0.632103411187349,-0.774884041367041,22.119151,24.000422,9.833121,45.5417 +235,-0.618671403262503,-0.785649855078715,23.646651,25.625672,16.958236,60.5 +236,-0.605056069648849,-0.796182863782616,24.155849,26.626364,14.125811,77.1667 +237,-0.591261444863578,-0.806479946320945,24.9,27.542378,5.6254875,76.125 +238,-0.577291616551728,-0.816538051445916,23.96,25.946696,25.166339,85 +239,-0.563150724274919,-0.82635419872391,25.231773,26.765294,20.412153,56.1765 
+240,-0.54884295828472,-0.835925479418637,21.923349,24.125228,10.708275,55.4583 +241,-0.534372558280979,-0.845249057353063,22.040849,23.250464,8.375536,54.8333 +242,-0.519743812155516,-0.854322169749827,22.863349,24.333986,5.5833311,59.7917 +243,-0.50496105472152,-0.863142128049912,22.785,24.584786,9.500332,63.9167 +244,-0.490028666429059,-0.871706318709322,22.236651,23.917328,9.375243,72.7083 +245,-0.47495107206705,-0.880012203973536,23.450849,25.792058,12.416775,71.6667 +246,-0.459732739452105,-0.888057322629493,25.330849,27.918314,13.833289,74.2083 +247,-0.444378178104613,-0.895839290734909,23.646651,25.292636,14.250632,79.0417 +248,-0.428891937912484,-0.903355802324685,17.38,18.0032,23.044181,88.6957 +249,-0.413278607782904,-0.910604630094216,20.160849,19.919114,6.5003936,91.7083 +250,-0.397542814282557,-0.917583626059394,21.793911,20.653826,12.914116,93.9565 +251,-0.381689220266659,-0.924290722193093,22.55,22.210436,8.333393,89.7917 +252,-0.365722523497269,-0.93072393103798,23.02,24.125492,10.291736,75.375 +253,-0.349647455251228,-0.936881346295431,22.706651,24.209114,7.708618,71.375 +254,-0.333468778918187,-0.942761143390421,22.284356,23.74058,5.957171,69.2174 +255,-0.317191288589106,-0.948361580012172,22.589151,23.834564,9.500868,71.25 +256,-0.300819807635668,-0.953680996630446,23.646651,25.3754,11.2091,69.7083 +257,-0.284359187281004,-0.958717816987296,19.1425,20.542286,18.166782,70.9167 +258,-0.267814305162175,-0.963470548564149,14.050849,14.45735,11.000261,59.0417 +259,-0.25119006388482,-0.967937783024064,15.108349,15.581792,12.708225,71.8333 +260,-0.234491389570411,-0.972118196629061,15.8525,16.375442,11.958361,69.5 +261,-0.217723230396532,-0.976010550632368,17.810849,18.95855,10.166714,69 +262,-0.200890555130635,-0.97961369164549,18.398349,19.126322,9.041918,88.125 +263,-0.18399835165768,-0.982926551979982,19.965,20.335178,6.4590814,90 +264,-0.16705162550212,-0.98594814996383,21.531651,20.627558,8.584375,90.2083 
+265,-0.150055398344653,-0.988677590232341,20.630849,18.46025,5.2505689,97.25 +266,-0.133014706534196,-0.991114063993455,20.513349,21.251192,5.2516811,86.25 +267,-0.115934599595501,-0.993256849267414,21.805849,21.794042,3.3754064,84.5 +268,-0.0988201387328721,-0.995105311100698,22.510849,22.876772,7.4169,84.8333 +269,-0.0816763953304229,-0.99665890175417,21.923349,21.91865,7.917457,88.5417 +270,-0.0645084494493171,-0.997917160865392,21.845,21.960428,9.958143,84.875 +271,-0.0473213883224323,-0.998879715585034,20.983349,21.917792,11.583161,69.9167 +272,-0.0301203048469084,-0.999546280687357,18.515849,19.958714,13.833825,64.75 +273,-0.0129102960750088,-0.999916658654738,11.27,11.248958,19.583832,75.375 +274,0.00430353829624382,-0.99999073973619,8.763349,6.790922,14.874871,79.1667 +275,0.0215160974362222,-0.999768501979891,10.055849,9.875036,5.5841686,76.0833 +276,0.038722280892174,-0.999250011239683,14.755849,15.208628,13.792218,71 +277,0.055916990100603,-0.998435421155564,17.301651,18.791108,11.87575,64.7917 +278,0.0730951298980769,-0.997324973108156,15.225849,15.70805,9.041918,62.0833 +279,0.0902516100310407,-0.995918996147179,16.009151,17.290664,1.5002439,68.4167 +280,0.107381346664162,-0.994217906893952,16.518349,17.873972,3.0420814,70.125 +281,0.124479263886789,-0.992222209417932,17.419151,18.582878,4.25115,72.75 +282,0.141540295217043,-0.989932495087353,18.829151,19.83305,2.8343814,73.375 +283,0.158559385103135,-0.987349442393986,18.633349,20.042336,9.583814,80.875 +284,0.175531490421428,-0.984473816752092,17.536651,18.169322,16.62605,90.625 +285,0.19245158197083,-0.981306470271609,19.690849,20.419064,9.499729,89.6667 +286,0.209314645963048,-0.977848341505657,17.889151,18.95855,15.000161,71.625 +287,0.226115685508288,-0.97410045517242,15.813349,16.91585,17.291561,48.3333 +288,0.242849722095935,-0.970063921851507,16.048349,17.208164,18.875039,48.6667 +289,0.259511797069799,-0.965739937654855,17.105849,17.70785,11.750393,57.9583 
+290,0.276096973097468,-0.961129783872301,17.0275,18.499586,7.375829,70.1667 +291,0.292600335633348,-0.956234826591906,17.461733,17.913968,16.303713,89.5217 +292,0.309016994374947,-0.951056516295153,14.364151,14.79065,28.292425,63.625 +293,0.32534208471198,-0.945596387427143,12.0925,11.957336,14.833532,57.4167 +294,0.341570769167855,-0.939856057941895,11.8575,12.082472,6.2086689,62.9167 +295,0.357698238833125,-0.933837228822925,11.818349,11.873978,6.6673375,74.125 +296,0.373719714790468,-0.927541683579197,13.776651,14.166422,7.959064,77.2083 +297,0.389630449530788,-0.920971287716635,14.168349,14.58275,11.166086,62.2917 +298,0.405425728359997,-0.914127988185334,14.755849,15.207836,9.959014,72.0417 +299,0.421100870796089,-0.907013812802636,14.09,14.165036,13.250121,81.2917 +300,0.436651231956063,-0.899630869652244,7.549151,5.041592,15.375093,58.5833 +301,0.452072203932305,-0.891981346459548,3.945849,-0.957742,23.541857,88.25 +302,0.467359217158002,-0.884067509943364,7.000849,5.207714,11.833339,62.375 +303,0.482507741761218,-0.875891705144243,7.98,7.500158,7.12545,70.3333 +304,0.49751328890718,-0.867456354729597,10.839151,10.207808,9.083257,68.375 +305,0.512371412128424,-0.858763958275803,9.7425,9.748778,5.5001439,71.875 +306,0.527077708642372,-0.849817091527528,11.191651,10.790786,9.166739,70.2083 +307,0.541627820655981,-0.840618405634478,10.956651,10.623872,18.209193,62.25 +308,0.556017436657044,-0.831170626365808,7.353349,5.374364,12.667154,51.9167 +309,0.570242292691787,-0.821476553302414,8.371651,7.915628,6.1676314,73.4583 +310,0.584298173628368,-0.811539059007361,10.565,10.457486,3.834075,75.875 +311,0.598180914405917,-0.801361088174676,11.191651,11.208236,4.6255125,72.1667 +312,0.611886401268724,-0.790945656756777,10.8,10.999214,4.1671186,75.8333 +313,0.625410572985246,-0.780295851070775,9.86,8.665586,12.667489,81.3333 +314,0.638749422051527,-0.769414826883938,7.235849,4.249922,21.083225,44.625 
+315,0.651898995878713,-0.758305808478562,8.763349,7.624172,14.208154,55.2917 +316,0.664855397964286,-0.746972087696555,12.719151,12.4163,18.875307,45.8333 +317,0.677614789046689,-0.735417022963986,16.91,18.624392,20.541932,58.7083 +318,0.690173388242971,-0.723644038295913,16.91,17.500214,13.375411,68.875 +319,0.702527474169157,-0.711656622281775,13.463349,13.831208,9.167543,93 +320,0.714673386042961,-0.699458327051647,8.058349,5.332586,20.459254,57.5833 +321,0.726607524768566,-0.687052767223667,4.885849,1.999586,11.291711,41 +322,0.738326354003106,-0.674443618832945,7.470849,5.415878,15.041232,50.2083 +323,0.749826401204569,-0.661634618242278,13.776651,14.165828,12.45865,68.4583 +324,0.761104258660774,-0.648629561034981,13.0325,13.374092,9.249618,91 +325,0.772156584499164,-0.635432300890177,11.583349,11.831936,7.959064,96.25 +326,0.782980103677063,-0.622046748440868,12.719151,12.415442,22.500275,75.7917 +327,0.793571608952147,-0.608476870115126,9.546651,8.583086,11.209368,54.9167 +328,0.803927961832821,-0.594726686960764,9.625,9.124286,6.6260186,64.375 +329,0.814046093508218,-0.580800273453801,9.664151,9.415742,4.5841936,68.1667 +330,0.823923005757554,-0.566701756291118,13.580849,14.0828,13.999918,69.8333 +331,0.83355577183857,-0.552435313167619,15.663466,16.348052,9.522174,74.3043 +332,0.842941537354783,-0.538005171538299,13.541651,13.79075,17.292164,83.0833 +333,0.852077521101309,-0.52341560736555,7.275,4.540586,18.167586,61.3333 +334,0.860961015888994,-0.508670943852104,6.6875,4.166564,14.750586,52.4583 +335,0.869589389346611,-0.493775550159977,6.765849,5.874578,6.750518,62.5833 +336,0.877960084700888,-0.478733840115789,6.060849,4.499864,6.4174811,61.2917 +337,0.886070621534138,-0.463550270902851,7.549151,7.0406,5.6252061,77.5833 +338,0.893918596519257,-0.448229341740411,10.134151,9.99905,4.1679561,82.7083 +339,0.901501684131884,-0.432775592550431,13.7375,14.1224,15.583061,94.9583 +340,0.908817637339503,-0.417193602612317,11.27,10.416236,17.833725,97.0417 
+341,0.915864288267287,-0.401487989205973,4.494151,0.957908,16.083886,58 +342,0.922639548840488,-0.385663406243607,5.669151,4.957772,5.5420189,69.5833 +343,0.929141411403174,-0.369724542890673,4.925,1.583192,15.625807,50.75 +344,0.935367949313148,-0.353676122176372,2.379151,0.708164,4.4582939,49 +345,0.941317317512847,-0.337522899594113,3.201651,1.832936,4.25115,67.0833 +346,0.946987753076075,-0.321269661692365,5.2775,3.875108,9.41685,59 +347,0.952377575730397,-0.304921224656289,6.9225,6.331892,4.0842061,66.375 +348,0.957485188355039,-0.288482432880609,11.8575,11.207642,17.958814,63.4167 +349,0.962309077454148,-0.271958157534106,9.625,7.74845,17.458525,50.0417 +350,0.966847813605277,-0.255353295116187,4.141651,0.458486000000001,16.292189,56.0833 +351,0.97110005188295,-0.23867276600595,3.201651,0.208213999999998,11.375193,58.625 +352,0.975064532257195,-0.221921513004165,5.003349,2.541578,11.584032,63.75 +353,0.978740079966915,-0.205104499868619,10.134151,10.165964,4.1252436,59.5417 +354,0.982125605868,-0.188226709843244,12.131651,12.249122,14.8338,85.8333 +355,0.985220106756061,-0.171293144181478,11.896651,12.123986,3.167425,75.75 +356,0.988022665663698,-0.154308820664281,9.546651,8.915858,18.374482,68.625 +357,0.990532452132223,-0.137278772113265,6.2175,3.749972,12.750368,54.25 +358,0.99274872245774,-0.120208044899353,4.914801,2.477426,10.391097,68.1304 +359,0.994670819911521,-0.103101697447435,7.121733,4.82531,16.044155,50.6957 +360,0.996298174934608,-0.0859647987374468,7.275,5.623778,12.62615,76.25 +361,0.997630305306586,-0.0688024268023196,6.05911,2.478284,19.695387,50.3913 +362,0.998666816288476,-0.0516196672232536,3.671651,1.416872,8.000604,57.4167 +363,0.999407400739705,-0.0344216116227456,6.648349,5.041592,9.000579,63.6667 +364,0.999851839209116,-0.0172133561558346,11.27,11.331986,14.750318,61.5833 +365,1,0,9.39,8.790986,12.875189,69.25 +366,0.999851839209116,0.0172133561558353,4.833021,0.652063999999999,22.087555,38.1304 
+367,0.999407400739705,0.0344216116227456,-0.95,-7.66585,24.499957,44.125 +368,0.998666816288476,0.0516196672232542,-2.9475,-8.123758,12.3749,41.4583 +369,0.997630305306586,0.0688024268023196,4.494151,2.375192,8.709129,52.4167 +370,0.996298174934608,0.0859647987374468,7.705849,6.457622,11.249836,54.2083 +371,0.994670819911521,0.103101697447434,10.486651,9.791414,11.708786,53.1667 +372,0.99274872245774,0.120208044899353,7.8625,6.457028,12.833314,46.5 +373,0.990532452132223,0.137278772113264,2.535849,0.333614000000001,6.6263,70.1667 +374,0.988022665663698,0.154308820664281,6.508712,5.042516,12.565984,64.6522 +375,0.985220106756061,0.171293144181478,4.885849,2.666186,8.791807,84.75 +376,0.982125605868001,0.188226709843244,9.9775,9.207908,12.124789,80.2917 +377,0.978740079966915,0.20510449986862,4.885849,0.457892000000001,25.333236,50.75 +378,0.975064532257195,0.221921513004165,0.459999999999999,-3.916258,12.541261,45.75 +379,0.97110005188295,0.238672766005951,-0.166651,-5.33275,16.834286,41.9167 +380,0.966847813605278,0.255353295116187,0.93,-3.416242,15.500986,52.25 +381,0.962309077454148,0.271958157534106,9.533021,8.042348,23.39171,71.6087 +382,0.957485188355039,0.288482432880608,6.256651,2.166764,27.833743,44.3333 +383,0.952377575730397,0.304921224656289,0.93,-3.457492,14.750586,49.75 +384,0.946987753076075,0.321269661692364,2.2225,-1.416772,13.58425,45 +385,0.941317317512847,0.337522899594113,0.146650999999999,-4.45825,14.917014,83.125 +386,0.935367949313148,0.353676122176372,-0.3625,-5.2915,13.375746,79.625 +387,0.929141411403174,0.369724542890673,2.261651,0.0418279999999989,7.417436,91.125 +388,0.922639548840487,0.385663406243608,8.0975,7.041128,8.292389,83.5833 +389,0.915864288267287,0.401487989205973,5.825849,3.458186,10.791757,64.375 +390,0.908817637339503,0.417193602612317,8.058349,7.4993,4.9175186,76.9583 +391,0.901501684131884,0.432775592550431,11.975,11.415278,22.958689,74.125 +392,0.893918596519257,0.448229341740411,6.844151,5.541014,14.125543,54.3333 
+393,0.886070621534138,0.46355027090285,5.2775,1.999586,16.08335,31.125 +394,0.877960084700888,0.478733840115789,4.650849,1.33325,14.458064,40.0833 +395,0.869589389346611,0.493775550159978,10.33,9.166922,17.541739,41.6667 +396,0.860961015888994,0.508670943852104,14.050849,14.791508,12.667489,50.7917 +397,0.852077521101309,0.523415607365551,10.760849,10.332086,12.541529,67.2917 +398,0.842941537354783,0.538005171538299,6.726651,4.416836,11.959232,52.6667 +399,0.83355577183857,0.55243531316762,4.415849,1.99985,8.167032,77.9583 +400,0.823923005757555,0.566701756291117,4.494151,1.458386,11.791732,68.7917 +401,0.814046093508218,0.580800273453801,5.282623,3.564116,10.3046,62.2174 +402,0.803927961832822,0.594726686960763,8.645849,7.832864,9.874393,49.625 +403,0.793571608952147,0.608476870115126,4.063349,1.583786,8.959307,72.2917 +404,0.782980103677063,0.622046748440867,4.455,1.291208,13.000479,56.2083 +405,0.772156584499164,0.635432300890177,5.199151,3.374828,7.834243,54 +406,0.761104258660774,0.648629561034982,2.535849,-2.082778,19.416332,73.125 +407,0.749826401204569,0.661634618242278,-2.0075,-9.290572,27.417204,46.4583 +408,0.738326354003106,0.674443618832946,2.4575,-0.957742,11.207961,41.125 +409,0.726607524768566,0.687052767223667,7.000849,6.040436,9.458993,50.875 +410,0.714673386042961,0.699458327051647,8.371651,7.207514,12.1672,53.125 +411,0.702527474169157,0.711656622281774,6.883349,5.790692,6.125475,75.2917 +412,0.690173388242972,0.723644038295913,8.136651,7.207514,13.791682,63.4583 +413,0.677614789046689,0.735417022963985,8.293349,7.45805,12.792243,53.4583 +414,0.664855397964287,0.746972087696555,5.16,1.542008,16.958504,51.5833 +415,0.651898995878712,0.758305808478563,5.16,2.043806,15.348561,50.7826 +416,0.638749422051527,0.769414826883938,5.527822,3.477458,13.783039,59.4348 +417,0.625410572985246,0.780295851070776,10.604151,9.916022,15.709557,56.7917 +418,0.611886401268725,0.790945656756777,13.345849,13.333436,12.791171,55.4583 
+419,0.598180914405916,0.801361088174677,11.1525,11.124086,15.916989,73.75 +420,0.584298173628369,0.811539059007361,5.669151,0.874549999999999,28.250014,39.5833 +421,0.570242292691787,0.821476553302414,5.120849,1.708328,13.750343,41 +422,0.556017436657045,0.831170626365808,9.233349,7.624964,17.958211,49.0833 +423,0.541627820655981,0.840618405634478,8.880849,7.33265,12.958939,39.5833 +424,0.527077708642373,0.849817091527527,8.184356,6.99902,12.000839,80.4783 +425,0.512371412128424,0.858763958275803,14.834151,15.374486,15.208129,61.5417 +426,0.49751328890718,0.867456354729597,8.606651,7.749572,9.708568,65.7083 +427,0.482507741761219,0.875891705144243,11.465849,11.290472,10.792293,62.125 +428,0.467359217158002,0.884067509943364,7.314151,3.999386,22.416257,40.3333 +429,0.452072203932305,0.891981346459548,3.436651,-0.0827140000000011,15.333486,50.625 +430,0.436651231956064,0.899630869652244,4.141651,0.832771999999999,13.458625,45.6667 +431,0.42110087079609,0.907013812802636,10.995849,9.4166,23.167193,51.3333 +432,0.405425728359997,0.914127988185334,16.7925,18.623864,29.584721,56.75 +433,0.389630449530789,0.920971287716634,11.309151,10.207478,27.7916,40.7083 +434,0.373719714790469,0.927541683579197,5.5125,2.332622,15.12525,35.0417 +435,0.357698238833125,0.933837228822925,9.001733,7.73822,14.913329,47.6957 +436,0.341570769167856,0.939856057941895,13.933349,14.333072,13.916771,48.9167 +437,0.32534208471198,0.945596387427143,18.555,19.833314,15.87565,61.75 +438,0.309016994374948,0.951056516295153,18.9075,20.208722,7.709154,50.7083 +439,0.292600335633348,0.956234826591906,18.2025,19.16645,10.042161,57.9583 +440,0.276096973097469,0.961129783872301,12.484151,12.791114,7.583864,84.2083 +441,0.2595117970698,0.965739937654855,16.165849,17.333036,7.417168,75.5833 +442,0.242849722095936,0.970063921851507,14.2075,14.624,8.501161,81 +443,0.226115685508288,0.97410045517242,17.615,19.166186,10.875239,72.875 +444,0.209314645963048,0.977848341505657,18.359151,19.543178,8.125157,80.7917 
+445,0.19245158197083,0.981306470271609,16.988349,17.875028,6.0004061,82.125 +446,0.175531490421428,0.984473816752092,18.045849,19.083422,7.876654,83.125 +447,0.158559385103135,0.987349442393986,20.278349,21.624422,7.7921,69.4167 +448,0.141540295217043,0.989932495087353,15.6175,16.124378,12.916461,88.5417 +449,0.12447926388679,0.992222209417932,12.5625,12.874208,14.791925,88.0833 +450,0.107381346664163,0.994217906893952,12.954151,12.9575,25.917007,47.7917 +451,0.0902516100310416,0.995918996147179,7.196651,4.833164,12.541864,29 +452,0.0730951298980776,0.997324973108156,14.755849,15.0827,19.541957,48.125 +453,0.0559169901006039,0.998435421155564,15.225849,15.832064,21.41655,43.9167 +454,0.0387222808921745,0.999250011239683,9.39,8.790986,9.250489,58.0833 +455,0.0215160974362216,0.999768501979891,11.935849,11.832728,16.791339,73.8333 +456,0.00430353829624429,0.99999073973619,12.014151,11.540942,11.541889,67.625 +457,-0.0129102960750095,0.999916658654738,12.393911,12.215858,20.913313,50.4348 +458,-0.0301203048469079,0.999546280687357,13.933349,14.457878,6.708911,39.6667 +459,-0.0473213883224321,0.998879715585034,17.458349,19.2077,12.125325,46.9583 +460,-0.0645084494493158,0.997917160865392,12.445,12.456758,14.708443,37.4167 +461,-0.0816763953304226,0.99665890175417,10.956651,9.790622,20.125996,37.7083 +462,-0.0988201387328708,0.995105311100698,12.5625,12.124514,18.416357,25.4167 +463,-0.1159345995955,0.993256849267414,15.5,16.50005,15.583932,27.5833 +464,-0.133014706534197,0.991114063993455,14.990849,15.458108,23.999132,31.75 +465,-0.150055398344653,0.988677590232341,12.993349,12.791378,16.708125,43.5 +466,-0.16705162550212,0.98594814996383,8.388712,6.260084,19.783358,46.9565 +467,-0.18399835165768,0.982926551979982,10.6825,9.581864,19.458743,46.625 +468,-0.200890555130635,0.97961369164549,12.7975,12.499328,10.416557,40.8333 +469,-0.217723230396531,0.976010550632368,15.265,16.207736,12.791439,50.2917 
+470,-0.23449138957041,0.972118196629061,20.513349,21.87575,15.083643,50.7917 +471,-0.251190063884819,0.967937783024064,23.215849,24.58505,19.083543,56.1667 +472,-0.267814305162174,0.963470548564149,20.591651,23.500142,18.333143,39.0417 +473,-0.284359187281003,0.958717816987297,13.776651,14.164508,11.250104,56.9167 +474,-0.300819807635668,0.953680996630446,15.421651,16.541036,4.4172564,61.25 +475,-0.317191288589107,0.948361580012172,16.753349,18.04115,10.041357,69.4583 +476,-0.333468778918187,0.942761143390421,18.79,19.832786,19.000329,68.2917 +477,-0.349647455251229,0.936881346295431,10.643349,9.707264,23.084582,83.5417 +478,-0.365722523497269,0.93072393103798,7.118349,3.87425,20.334232,76.6667 +479,-0.381689220266659,0.924290722193093,11.426651,10.748678,16.708661,45.4167 +480,-0.397542814282555,0.917583626059394,14.403349,15.040922,7.959064,42.7917 +481,-0.413278607782904,0.910604630094216,15.421651,15.916478,11.833875,75.6667 +482,-0.428891937912483,0.903355802324685,13.5025,13.874042,23.291411,40.0833 +483,-0.444378178104613,0.895839290734909,9.703349,8.915264,8.708325,48.9583 +484,-0.459732739452105,0.888057322629493,13.541651,13.707986,7.832836,58.7083 +485,-0.47495107206705,0.880012203973536,13.815849,14.207936,11.499746,57 +486,-0.49002866642906,0.871706318709322,20.826651,22.083386,10.458432,65.9583 +487,-0.50496105472152,0.863142128049912,18.515849,19.501136,9.249886,79.7083 +488,-0.519743812155516,0.854322169749827,18.32,19.457972,8.957632,76.8333 +489,-0.534372558280979,0.845249057353063,21.4925,23.000522,10.916846,73.5417 +490,-0.548842958284719,0.835925479418637,21.218349,22.584128,10.250464,75.6667 +491,-0.563150724274918,0.82635419872391,18.4375,20.084642,10.041893,74 +492,-0.577291616551727,0.816538051445916,17.2625,18.791372,15.458307,66.4167 +493,-0.591261444863578,0.806479946320945,19.338349,20.793086,19.833943,68.5833 +494,-0.605056069648849,0.796182863782616,19.025,20.49965,14.499604,74.4167 
+495,-0.618671403262504,0.785649855078714,15.774151,16.457678,21.042221,55.2083 +496,-0.632103411187348,0.774884041367041,17.066651,18.374978,15.874779,36.0417 +497,-0.64534811322955,0.763888612790542,18.515849,19.957922,8.249911,48.0417 +498,-0.658401584698049,0.752666827532008,20.7875,22.625708,15.082839,57.625 +499,-0.671259957567532,0.741222010848596,18.946651,20.2934,14.250364,78.9583 +500,-0.68391942162461,0.729557554086488,20.748349,22.042664,9.875264,79.4583 +501,-0.696376225596872,0.717676913675962,21.923349,23.33435,8.208304,69.7917 +502,-0.708626678264459,0.705583610107178,19.886651,21.792458,15.374825,52 +503,-0.720667149553861,0.693281226886978,18.515849,20.373986,9.166739,52.3333 +504,-0.732494071613579,0.680773409477016,20.2,21.415928,5.626325,45.625 +505,-0.74410393987136,0.668063864213534,21.179151,22.541822,17.042589,53.0417 +506,-0.755493314072681,0.655156357209085,20.121651,21.334022,15.624668,81.125 +507,-0.766658819300159,0.642054713236564,20.905,22.33445,7.917189,76.5833 +508,-0.777597146973627,0.628762814595834,21.218349,22.584392,6.834,77.4583 +509,-0.788305055830525,0.615284599963328,22.785,24.0422,11.584032,71.6667 +510,-0.798779372886365,0.601624063224923,23.96,25.416914,9.41685,74.7083 +511,-0.809016994374947,0.587785252292474,24.5475,26.417936,13.332464,73.25 +512,-0.81901488666808,0.573772267904325,24.43,26.33405,14.416457,69.7083 +513,-0.828770087174504,0.559589262410176,25.4875,28.8338,13.166907,67.625 +514,-0.838279705217774,0.545240438540651,25.9575,28.417472,19.7918,68.4583 +515,-0.847540922892831,0.530730048161933,22.863349,24.334514,9.000043,67 +516,-0.856550995901004,0.516062391015853,23.96,25.667714,13.083693,49.2917 +517,-0.865307254363206,0.501241813445775,22.745849,24.125492,15.916721,75.5417 +518,-0.873807103611081,0.48627270710869,19.416651,21.375008,12.499654,54.9167 +519,-0.882048024955854,0.471159507673864,20.3175,21.958778,12.333829,49.3333 +520,-0.890027576434676,0.455906693508459,20.0825,22.166678,19.083811,48.7083 
+521,-0.897743393534234,0.440518784350495,17.419151,18.708872,14.041525,61.3333 +522,-0.905193189891397,0.425000339969555,18.045849,19.791272,5.167375,61.125 +523,-0.912374757970727,0.409355958815622,20.3175,21.583172,10.54245,56.7083 +524,-0.919285969718611,0.393590276656466,22.510849,23.458892,11.750661,46.7917 +525,-0.92592477719385,0.377707965203965,25.409151,26.792222,9.667229,43.7083 +526,-0.932289213174514,0.361713730729767,26.153349,27.792122,8.959307,53.8333 +527,-0.938377391740864,0.345612312670734,25.879151,27.541586,13.916771,58.7917 +528,-0.944187508834199,0.32940848222453,22.706651,23.45975,14.374582,83.3333 +529,-0.949717842791432,0.313107040935827,22.824151,24.333722,22.999693,58.2083 +530,-0.954966754855255,0.29671281927349,22.471651,25.209278,17.000111,56.9583 +531,-0.959932689659744,0.280230675199217,22.040849,23.583764,11.833339,58.9583 +532,-0.964614175691244,0.263665492728008,21.688349,23.250728,11.166689,50.4167 +533,-0.969009825724406,0.247022180480935,19.8475,21.75035,9.708568,59.875 +534,-0.973118337233262,0.230305670230612,18.711651,19.959572,11.707982,77.7917 +535,-0.976938492777182,0.213520915439796,24.351651,27.209672,9.917139,69 +536,-0.980469160361632,0.196672889793576,28.7775,31.58435,7.625404,59.2083 +537,-0.98370929377361,0.179766585725562,29.874151,33.667772,7.958729,56.7917 +538,-0.986657932891657,0.162807012938517,28.5425,31.791986,12.250414,57.375 +539,-0.989314203970366,0.145799196919875,26.388349,27.084272,12.041307,53.4583 +540,-0.99167731989929,0.128748177452581,26.936651,28.500764,9.750175,47.9167 +541,-0.993746580436178,0.111659007121695,25.644151,27.166772,20.125661,50.4167 +542,-0.995521372414475,0.0945367498172,21.649151,23.250464,23.292014,37.3333 +543,-0.997001169925015,0.077386479233463,24.7825,26.292272,18.208925,36 +544,-0.998185534471859,0.0602132773657926,27.210849,28.583792,11.50055,42.25 +545,-0.99907411510223,0.0430222330045306,31.205849,35.916458,11.082939,48.875 
+546,-0.999666648510511,0.0258184402271326,27.955,29.375528,10.791757,60.125 +547,-0.999962959116266,0.0086069968886887,30.344151,33.541514,11.291443,51.875 +548,-0.999962959116266,-0.0086069968886887,28.738349,30.334508,13.082889,44.7083 +549,-0.999666648510511,-0.0258184402271326,28.699151,30.3749,8.457879,49.2083 +550,-0.99907411510223,-0.0430222330045306,29.090849,32.334242,9.04165,53.875 +551,-0.998185534471859,-0.0602132773657926,30.8925,34.250222,12.999943,45.7917 +552,-0.997001169925015,-0.077386479233463,30.931651,33.667178,9.791514,45.0833 +553,-0.995521372414475,-0.0945367498172,32.498349,37.124258,10.958118,49.2083 +554,-0.993746580436178,-0.111659007121695,30.6575,36.166136,8.417143,57.375 +555,-0.99167731989929,-0.128748177452581,25.409151,27.167564,12.125325,68.3333 +556,-0.989314203970366,-0.145799196919875,25.879151,27.876536,10.166379,66.75 +557,-0.986657932891657,-0.162807012938517,25.683349,26.917886,10.166111,63.3333 +558,-0.98370929377361,-0.179766585725562,25.644151,27.209078,9.833925,52.9583 +559,-0.980469160361632,-0.196672889793576,26.388349,28.083578,5.41695,48.5833 +560,-0.976938492777182,-0.213520915439796,25.056651,27.958772,9.626493,69.9167 +561,-0.973118337233262,-0.230305670230612,27.054151,30.542936,11.166689,71.7917 +562,-0.969009825724406,-0.247022180480935,27.876651,31.79225,11.000529,64.5 +563,-0.964614175691244,-0.263665492728008,30.461651,33.875078,7.666743,50.5833 +564,-0.959932689659744,-0.280230675199217,29.286651,33.208478,9.208614,57.7083 +565,-0.954966754855255,-0.29671281927349,28.19,31.166372,11.083743,60.0417 +566,-0.949717842791432,-0.313107040935827,23.294151,24.45965,14.000789,84.4167 +567,-0.944187508834199,-0.32940848222453,20.004151,20.294192,14.2911,86.5417 +568,-0.938377391740864,-0.345612312670734,23.3725,25.12625,6.2926936,76.25 +569,-0.932289213174514,-0.361713730729767,26.858349,29.541122,9.291761,69.4167 +570,-0.92592477719385,-0.377707965203965,27.289151,30.6257,14.167418,65.5 
+571,-0.919285969718611,-0.393590276656466,26.035849,27.167564,11.0416,45 +572,-0.912374757970727,-0.409355958815622,28.503349,32.791358,19.082471,59.6667 +573,-0.905193189891397,-0.425000339969555,28.738349,32.458322,10.250464,59.4583 +574,-0.897743393534234,-0.440518784350495,27.524151,30.041864,10.54245,61.3333 +575,-0.890027576434677,-0.455906693508459,25.918349,28.083578,11.416532,62.375 +576,-0.882048024955854,-0.471159507673864,26.349151,29.209142,10.292339,66.875 +577,-0.873807103611081,-0.48627270710869,25.526651,27.751136,11.083475,70.4167 +578,-0.865307254363206,-0.501241813445775,25.7225,28.042328,9.458993,67.75 +579,-0.856550995901004,-0.516062391015853,27.3675,30.667808,8.666718,65.9583 +580,-0.847540922892831,-0.530730048161933,27.994151,31.709222,14.458064,64.25 +581,-0.838279705217774,-0.545240438540651,29.286651,33.583622,17.249686,61.3333 +582,-0.828770087174504,-0.559589262410176,28.150849,32.251214,19.458207,65.25 +583,-0.81901488666808,-0.573772267904325,27.3675,30.876236,8.666718,65.4167 +584,-0.809016994374947,-0.587785252292474,26.584151,30.042986,7.832836,70.375 +585,-0.798779372886365,-0.601624063224923,27.25,30.709322,7.4169,67.2917 +586,-0.788305055830525,-0.615284599963328,27.524151,30.167528,10.4587,62.0417 +587,-0.777597146973627,-0.628762814595834,25.644151,28.084172,16.000471,71.5833 +588,-0.766658819300159,-0.642054713236564,24.5475,26.125622,13.834093,73.2917 +589,-0.755493314072681,-0.655156357209085,24.939151,26.542214,8.208304,53.0417 +590,-0.744103939871361,-0.668063864213534,25.879151,27.708764,9.126204,54.5417 +591,-0.732494071613579,-0.680773409477016,26.153349,28.667414,11.333586,68.6667 +592,-0.720667149553861,-0.693281226886978,25.213349,27.166442,11.374657,61.9583 +593,-0.70862667826446,-0.705583610107178,25.800849,27.209408,9.500332,51.9167 +594,-0.696376225596872,-0.717676913675962,25.996651,-0.00159999999999982,15.500718,57.0833 +595,-0.683919421624611,-0.729557554086488,23.881651,24.792686,11.917089,60.3333 
+596,-0.671259957567532,-0.741222010848596,21.884151,23.834564,5.79215,71.1667 +597,-0.658401584698049,-0.752666827532008,21.884151,23.333822,8.708593,73.4167 +598,-0.645348113229551,-0.763888612790542,22.510849,23.66765,4.8756436,67.375 +599,-0.632103411187349,-0.774884041367041,23.3725,25.042364,4.7089811,67.7083 +600,-0.618671403262504,-0.785649855078714,24.704151,26.042528,5.6679186,63.5833 +601,-0.605056069648849,-0.796182863782616,25.0175,26.7086,4.8337686,61.5 +602,-0.591261444863578,-0.806479946320945,23.098349,24.833936,16.375336,71.2917 +603,-0.577291616551728,-0.816538051445916,22.706651,23.335736,15.333486,84.5833 +604,-0.563150724274918,-0.82635419872391,25.056651,27.209408,8.625111,73.0417 +605,-0.54884295828472,-0.835925479418637,26.231651,27.9593,12.791975,62 +606,-0.534372558280979,-0.845249057353063,24.195,25.958378,7.541654,55.2083 +607,-0.519743812155516,-0.854322169749827,25.213349,27.083414,5.1668189,59.0417 +608,-0.50496105472152,-0.863142128049912,27.915849,29.5004,11.291711,58.75 +609,-0.49002866642906,-0.871706318709322,27.406651,30.375164,7.583529,63.8333 +610,-0.47495107206705,-0.880012203973536,24.743349,26.834,4.2927436,81.5 +611,-0.459732739452105,-0.888057322629493,25.2525,27.667514,10.125107,79.0833 +612,-0.444378178104613,-0.895839290734909,26.114151,29.334608,15.833507,75.5 +613,-0.428891937912483,-0.903355802324685,26.623349,30.792878,12.583136,74.125 +614,-0.413278607782904,-0.910604630094216,24.743349,27.251714,9.542207,81.0417 +615,-0.397542814282556,-0.917583626059394,25.056651,27.375464,11.500282,73.625 +616,-0.381689220266659,-0.924290722193093,22.980849,24.333986,18.833968,79.9167 +617,-0.365722523497269,-0.93072393103798,20.67,22.20905,15.041232,54.75 +618,-0.349647455251229,-0.936881346295431,19.416651,21.333164,17.333771,50.375 +619,-0.333468778918187,-0.942761143390421,19.1425,20.583272,6.1676314,52 +620,-0.317191288589107,-0.948361580012172,20.160849,21.62495,8.833682,57.7083 
+621,-0.300819807635668,-0.953680996630446,20.7875,22.250828,5.5422936,63.7083 +622,-0.284359187281003,-0.958717816987297,21.766651,23.209478,6.958821,67.25 +623,-0.267814305162175,-0.963470548564149,20.591651,22.667222,16.583907,50.1667 +624,-0.251190063884819,-0.967937783024064,19.26,21.16625,6.0422811,57 +625,-0.234491389570411,-0.972118196629061,19.299151,20.5013,10.166714,73.4583 +626,-0.217723230396531,-0.976010550632368,21.296651,21.294422,23.958329,87.25 +627,-0.200890555130636,-0.97961369164549,17.9675,19.666664,14.416725,53.6667 +628,-0.18399835165768,-0.982926551979982,17.693349,19.124672,7.917189,61.8333 +629,-0.16705162550212,-0.98594814996383,20.160849,21.750086,10.333343,66.875 +630,-0.150055398344653,-0.988677590232341,22.55,24.292208,19.000061,64.6667 +631,-0.133014706534196,-0.991114063993455,16.870849,18.249578,14.958286,46.7083 +632,-0.115934599595501,-0.993256849267414,16.165849,17.165858,9.541068,49.2917 +633,-0.0988201387328712,-0.995105311100698,17.85,19.915814,15.833507,57 +634,-0.0816763953304229,-0.99665890175417,21.845,23.376458,16.3748,63.0833 +635,-0.0645084494493163,-0.997917160865392,22.55,24.12635,9.000914,69.0833 +636,-0.0473213883224323,-0.998879715585034,21.100849,22.666958,10.999993,69 +637,-0.0301203048469084,-0.999546280687357,17.4975,18.999536,15.249468,54.2917 +638,-0.0129102960750097,-0.999916658654738,16.753349,18.165758,9.042186,58.3333 +639,0.00430353829624382,-0.99999073973619,16.479151,17.792,6.0838814,64.9167 +640,0.0215160974362213,-0.999768501979891,19.769151,19.793978,6.999825,87.1667 +641,0.038722280892174,-0.999250011239683,22.9025,23.542778,4.4585686,79.375 +642,0.0559169901006039,-0.998435421155564,22.9025,24.12635,7.875582,72.2917 +643,0.0730951298980769,-0.997324973108156,20.905,22.292342,7.12545,62.75 +644,0.0902516100310416,-0.995918996147179,18.045849,19.542386,17.957675,66.4167 +645,0.107381346664162,-0.994217906893952,11.544151,11.707658,9.457854,70.8333 
+646,0.12447926388679,-0.992222209417932,10.016651,9.582128,12.708493,70.9583 +647,0.141540295217043,-0.989932495087353,12.993349,12.915392,12.7501,76.1667 +648,0.158559385103135,-0.987349442393986,16.165849,17.207372,12.584007,63.0833 +649,0.175531490421428,-0.984473816752092,12.445,12.457022,12.166932,46.3333 +650,0.19245158197083,-0.981306470271609,12.5625,12.582686,15.751164,53.9167 +651,0.209314645963048,-0.977848341505657,10.486651,9.832136,9.791514,49.4583 +652,0.226115685508288,-0.97410045517242,16.518349,17.541464,18.667004,64.0417 +653,0.242849722095936,-0.970063921851507,18.398349,19.5839,19.834479,70.75 +654,0.259511797069799,-0.965739937654855,14.011651,14.415836,12.208807,55.8333 +655,0.276096973097469,-0.961129783872301,13.424151,13.707128,6.791857,69.2917 +656,0.292600335633348,-0.956234826591906,16.5575,17.83325,15.874779,72.8333 +657,0.309016994374947,-0.951056516295153,18.476651,19.501136,9.041918,81.5 +658,0.32534208471198,-0.945596387427143,14.755849,15.207572,7.874979,57.2917 +659,0.341570769167855,-0.939856057941895,13.815849,14.124314,11.125618,51 +660,0.357698238833125,-0.933837228822925,14.9125,15.874172,5.4593811,56.8333 +661,0.373719714790468,-0.927541683579197,17.575849,19.000064,6.3345686,64.1667 +662,0.389630449530789,-0.920971287716634,19.6125,20.875586,4.8762064,63.625 +663,0.405425728359997,-0.914127988185334,17.85,18.959408,8.333125,80.0417 +664,0.42110087079609,-0.907013812802636,17.654151,18.5015,8.875289,80.7083 +665,0.436651231956063,-0.899630869652244,16.91,17.998778,15.791364,72 +666,0.452072203932305,-0.891981346459548,14.4425,14.872886,26.666536,69.4583 +667,0.467359217158002,-0.884067509943364,12.68,13.0004,23.9994,88 +668,0.482507741761218,-0.875891705144243,6.954554,4.453994,14.271603,82.5455 +669,0.49751328890718,-0.867456354729597,8.8025,7.8326,11.166689,66.6667 +670,0.512371412128424,-0.858763958275803,9.194151,8.416172,10.542182,58.1667 +671,0.527077708642372,-0.849817091527527,8.685,7.498772,17.833725,52.2083 
+672,0.541627820655981,-0.840618405634478,8.136651,5.373836,18.125443,49.125 +673,0.556017436657045,-0.831170626365808,7.314151,5.749508,12.000236,53.2917 +674,0.570242292691787,-0.821476553302414,7.000849,4.33295,15.833775,49.4167 +675,0.584298173628368,-0.811539059007361,5.199151,2.583422,11.625371,56.7083 +676,0.598180914405916,-0.801361088174677,5.904151,2.124986,20.375236,54.75 +677,0.611886401268724,-0.790945656756777,8.552178,6.564806,23.304945,33.3478 +678,0.625410572985246,-0.780295851070776,8.998349,7.457258,14.375386,54.0833 +679,0.638749422051527,-0.769414826883938,10.290849,9.999842,3.8756686,64.5417 +680,0.651898995878712,-0.758305808478563,11.779151,11.833058,8.5425,65.9167 +681,0.664855397964286,-0.746972087696555,14.795,15.375278,11.625639,74.1667 +682,0.677614789046689,-0.735417022963985,8.136651,5.33285,22.917082,66.2917 +683,0.690173388242971,-0.723644038295913,5.590849,2.583158,13.374875,55.2083 +684,0.702527474169157,-0.711656622281774,7.118349,5.416472,10.250129,62.0417 +685,0.714673386042961,-0.699458327051647,8.215,6.915464,11.458675,52.4583 +686,0.726607524768566,-0.687052767223667,7.275,5.541278,12.041843,54.5417 +687,0.738326354003106,-0.674443618832946,8.0975,6.291236,15.250004,69.2917 +688,0.749826401204569,-0.661634618242278,9.899151,8.790986,15.749489,62.3333 +689,0.761104258660774,-0.648629561034982,9.585849,9.124022,5.542575,68.5 +690,0.772156584499164,-0.635432300890177,8.606651,8.082872,6.917482,61.375 +691,0.782980103677063,-0.622046748440867,7.98,7.124486,3.5423436,58.0417 +692,0.793571608952147,-0.608476870115126,9.311651,8.999414,9.917407,56.875 +693,0.803927961832821,-0.594726686960763,5.081651,0.416971999999998,25.250357,40.4583 +694,0.814046093508218,-0.580800273453801,3.554151,1.000478,10.0835,46.8333 +695,0.823923005757554,-0.566701756291117,6.726651,6.374264,3.12555,53.5417 +696,0.83355577183857,-0.55243531316762,5.708349,2.582828,15.916654,78.6667 
+697,0.842941537354783,-0.538005171538299,5.943349,3.124292,14.125007,50.625 +698,0.852077521101309,-0.523415607365551,5.20089,3.695852,7.739974,55.5652 +699,0.860961015888994,-0.508670943852104,6.021651,5.375222,3.9175436,64.9583 +700,0.869589389346611,-0.493775550159978,6.021651,4.915664,4.0001814,80.6667 +701,0.877960084700888,-0.478733840115789,8.3325,7.707728,8.333393,82.3333 +702,0.886070621534138,-0.46355027090285,13.2675,14.082536,5.5422936,76.75 +703,0.893918596519257,-0.448229341740411,14.364151,14.957564,11.666643,73.375 +704,0.901501684131884,-0.432775592550431,12.601651,12.248792,21.709407,48.5 +705,0.908817637339503,-0.417193602612317,4.024151,1.041464,11.708518,50.875 +706,0.915864288267287,-0.401487989205973,7.079151,5.249228,8.7502,76.4167 +707,0.922639548840488,-0.385663406243608,9.938349,9.707528,6.792393,91.125 +708,0.929141411403174,-0.369724542890673,10.055849,9.749636,10.584325,90.5417 +709,0.935367949313148,-0.353676122176372,12.484151,12.74795,12.750636,92.5 +710,0.941317317512847,-0.337522899594113,8.606651,6.331958,19.834479,59.6667 +711,0.946987753076075,-0.321269661692364,5.9825,3.624308,10.916779,53.8333 +712,0.952377575730397,-0.304921224656289,5.904151,3.416408,11.666643,48.5833 +713,0.957485188355039,-0.288482432880608,5.238349,3.416672,8.792343,64.2917 +714,0.962309077454148,-0.271958157534106,7.235849,6.333278,7.12545,65.0417 +715,0.966847813605277,-0.255353295116187,9.0375,8.415908,6.749714,83.875 +716,0.97110005188295,-0.238672766005951,10.486651,10.499,6.5833061,90.7083 +717,0.975064532257195,-0.221921513004165,11.309151,11.040728,14.834068,66.625 +718,0.978740079966915,-0.20510449986862,7.6275,6.582692,12.334164,62.5417 +719,0.982125605868,-0.188226709843244,7.51,6.124322,8.875021,66.7917 +720,0.98522010675606,-0.171293144181478,7.353349,3.916622,25.083661,55.6667 +721,0.988022665663698,-0.154308820664281,4.494151,-0.416542000000002,27.292182,44.125 +722,0.990532452132223,-0.137278772113264,3.554151,1.125086,8.916561,51.5417 
+723,0.99274872245774,-0.120208044899353,2.871288,1.0874,5.1744368,79.1304 +724,0.994670819911521,-0.103101697447434,5.691288,3.43469,11.304642,73.4783 +725,0.996298174934608,-0.0859647987374468,3.436651,-1.458022,21.208582,82.3333 +726,0.997630305306586,-0.0688024268023196,3.945849,-1.041628,23.458911,65.2917 +727,0.998666816288476,-0.0516196672232542,3.906651,0.833036,10.416557,59 +728,0.999407400739705,-0.0344216116227456,3.906651,-0.00159999999999982,8.333661,75.2917 +729,0.999851839209116,-0.0172133561558353,4.024151,-0.707800000000001,23.500518,48.3333 +730,1,0,2.144151,-1.249858,10.374682,57.75 diff --git a/inst/code_paper/x_train.csv b/inst/code_paper/x_train.csv new file mode 100644 index 0000000000000000000000000000000000000000..2cb46a76262d086e61c95fa988968bc965a5b8fd --- /dev/null +++ b/inst/code_paper/x_train.csv @@ -0,0 +1,586 @@ +trend,cosyear,sinyear,temp,atemp,windspeed,hum +414,0.664855397964287,0.746972087696555,5.16,1.542008,16.958504,51.5833 +462,-0.0988201387328708,0.995105311100698,12.5625,12.124514,18.416357,25.4167 +178,-0.997001169925015,0.077386479233463,26.975849,29.708828,9.666961,63.4167 +525,-0.92592477719385,0.377707965203965,25.409151,26.792222,9.667229,43.7083 +194,-0.980469160361632,-0.196672889793576,23.999151,25.916864,16.124689,47.625 +117,-0.428891937912483,0.903355802324685,21.0225,22.209314,21.500836,70.0833 +298,0.405425728359997,-0.914127988185334,14.755849,15.207836,9.959014,72.0417 +228,-0.70862667826446,-0.705583610107178,25.996651,28.000286,9.625689,57.5417 +243,-0.50496105472152,-0.863142128049912,22.785,24.584786,9.500332,63.9167 +13,0.975064532257195,0.221921513004166,-0.439109999999999,-3.564742,8.478716,53.7826 +373,0.990532452132223,0.137278772113264,2.535849,0.333614000000001,6.6263,70.1667 +664,0.42110087079609,-0.907013812802636,17.654151,18.5015,8.875289,80.7083 +601,-0.605056069648849,-0.796182863782616,25.0175,26.7086,4.8337686,61.5 
+602,-0.591261444863578,-0.806479946320945,23.098349,24.833936,16.375336,71.2917 +708,0.929141411403174,-0.369724542890673,10.055849,9.749636,10.584325,90.5417 +90,0.0215160974362223,0.999768501979891,6.1,2.707964,17.333436,68.625 +347,0.952377575730397,-0.304921224656289,6.9225,6.331892,4.0842061,66.375 +648,0.158559385103135,-0.987349442393986,16.165849,17.207372,12.584007,63.0833 +354,0.982125605868,-0.188226709843244,12.131651,12.249122,14.8338,85.8333 +25,0.908817637339503,0.417193602612317,2.2225,-2.5624,19.68795,86.25 +518,-0.873807103611081,0.48627270710869,19.416651,21.375008,12.499654,54.9167 +425,0.512371412128424,0.858763958275803,14.834151,15.374486,15.208129,61.5417 +713,0.957485188355039,-0.288482432880608,5.238349,3.416672,8.792343,64.2917 +210,-0.890027576434677,-0.455906693508459,29.795849,32.083442,11.291979,46.5833 +589,-0.755493314072681,-0.655156357209085,24.939151,26.542214,8.208304,53.0417 +592,-0.720667149553861,-0.693281226886978,25.213349,27.166442,11.374657,61.9583 +554,-0.993746580436178,-0.111659007121695,30.6575,36.166136,8.417143,57.375 +372,0.99274872245774,0.120208044899353,7.8625,6.457028,12.833314,46.5 +142,-0.766658819300159,0.642054713236564,21.688349,22.959536,15.667414,81 +543,-0.997001169925015,0.077386479233463,24.7825,26.292272,18.208925,36 +489,-0.534372558280979,0.845249057353063,21.4925,23.000522,10.916846,73.5417 +620,-0.317191288589107,-0.948361580012172,20.160849,21.62495,8.833682,57.7083 +22,0.929141411403174,0.369724542890673,-3.4634801,-9.4766194,16.5222,43.6522 +308,0.556017436657044,-0.831170626365808,7.353349,5.374364,12.667154,51.9167 +134,-0.671259957567531,0.741222010848596,18.4375,19.376,10.249593,86.7083 +223,-0.766658819300159,-0.642054713236564,25.291651,27.166772,8.416607,41.5 +165,-0.954966754855255,0.29671281927349,21.453349,22.791764,11.250104,47.1667 +216,-0.838279705217774,-0.545240438540651,25.409151,27.333422,12.374632,63.0833 
+289,0.259511797069799,-0.965739937654855,17.105849,17.70785,11.750393,57.9583 +580,-0.847540922892831,-0.530730048161933,27.994151,31.709222,14.458064,64.25 +71,0.341570769167856,0.939856057941895,10.064356,9.086006,18.130468,52.7391 +587,-0.777597146973627,-0.628762814595834,25.644151,28.084172,16.000471,71.5833 +574,-0.897743393534234,-0.440518784350495,27.524151,30.041864,10.54245,61.3333 +140,-0.74410393987136,0.668063864213534,20.3175,21.75035,8.08355,62.6667 +152,-0.865307254363206,0.501241813445776,25.605,26.500172,19.583229,30.5 +293,0.32534208471198,-0.945596387427143,12.0925,11.957336,14.833532,57.4167 +276,0.038722280892174,-0.999250011239683,14.755849,15.208628,13.792218,71 +729,0.999851839209116,-0.0172133561558353,4.024151,-0.707800000000001,23.500518,48.3333 +40,0.772156584499164,0.635432300890177,-1.215644,-6.129832,14.869645,43.7391 +430,0.436651231956064,0.899630869652244,4.141651,0.832771999999999,13.458625,45.6667 +89,0.0387222808921745,0.999250011239683,4.611651,0.999949999999998,14.582282,91.8333 +315,0.651898995878713,-0.758305808478562,8.763349,7.624172,14.208154,55.2917 +222,-0.777597146973627,-0.628762814595835,25.7225,27.000386,11.041332,42.375 +527,-0.938377391740864,0.345612312670734,25.879151,27.541586,13.916771,58.7917 +115,-0.397542814282556,0.917583626059394,21.688349,23.209478,21.8755,72.9167 +605,-0.54884295828472,-0.835925479418637,26.231651,27.9593,12.791975,62 +455,0.0215160974362216,0.999768501979891,11.935849,11.832728,16.791339,73.8333 +597,-0.658401584698049,-0.752666827532008,21.884151,23.333822,8.708593,73.4167 +38,0.793571608952147,0.608476870115126,2.379151,-2.915764,24.25065,53.7917 +158,-0.912374757970727,0.409355958815622,28.464151,32.000414,9.166739,62.2083 +208,-0.905193189891398,-0.425000339969554,28.620849,32.8334,11.958093,58.3333 +720,0.98522010675606,-0.171293144181478,7.353349,3.916622,25.083661,55.6667 +33,0.842941537354783,0.5380051715383,0.786979000000001,-4.260052,18.609384,43.7826 
+515,-0.847540922892831,0.530730048161933,22.863349,24.334514,9.000043,67 +12,0.978740079966915,0.205104499868619,-0.244999999999999,-6.041722,20.167,47.0417 +68,0.389630449530788,0.920971287716635,10.287277,9.454088,17.545759,0 +408,0.738326354003106,0.674443618832946,2.4575,-0.957742,11.207961,41.125 +307,0.541627820655981,-0.840618405634478,10.956651,10.623872,18.209193,62.25 +277,0.055916990100603,-0.998435421155564,17.301651,18.791108,11.87575,64.7917 +88,0.0559169901006033,0.998435421155564,6.1,3.6251,11.583496,64.6667 +536,-0.980469160361632,0.196672889793576,28.7775,31.58435,7.625404,59.2083 +290,0.276096973097468,-0.961129783872301,17.0275,18.499586,7.375829,70.1667 +423,0.541627820655981,0.840618405634478,8.880849,7.33265,12.958939,39.5833 +285,0.19245158197083,-0.981306470271609,19.690849,20.419064,9.499729,89.6667 +120,-0.47495107206705,0.880012203973536,13.228349,13.581464,7.125718,76.2083 +109,-0.300819807635668,0.953680996630446,19.965,21.249872,16.208975,61.4167 +157,-0.905193189891397,0.425000339969554,25.2525,27.2927,12.583136,59.7917 +63,0.467359217158002,0.884067509943364,10.055849,8.999414,16.875357,78.9167 +482,-0.428891937912483,0.903355802324685,13.5025,13.874042,23.291411,40.0833 +476,-0.333468778918187,0.942761143390421,18.79,19.832786,19.000329,68.2917 +479,-0.381689220266659,0.924290722193093,11.426651,10.748678,16.708661,45.4167 +66,0.42110087079609,0.907013812802636,5.7475,3.9584,8.08355,42.0833 +84,0.124479263886789,0.992222209417932,4.494151,0.999686000000001,14.041257,39.4167 +164,-0.949717842791432,0.313107040935827,20.395849,23.042036,18.041961,50.7083 +50,0.651898995878713,0.758305808478562,5.405199,2.30378,14.956745,40.7826 +73,0.309016994374947,0.951056516295153,6.917377,4.999748,12.348703,65.5652 +177,-0.995521372414475,0.0945367498171996,24.0775,26.042264,7.208396,65.8333 +361,0.997630305306586,-0.0688024268023196,6.05911,2.478284,19.695387,50.3913 +235,-0.618671403262503,-0.785649855078715,23.646651,25.625672,16.958236,60.5 
+609,-0.49002866642906,-0.871706318709322,27.406651,30.375164,7.583529,63.8333 +329,0.814046093508218,-0.580800273453801,9.664151,9.415742,4.5841936,68.1667 +126,-0.563150724274919,0.82635419872391,16.44,17.832986,10.75015,54.125 +211,-0.882048024955854,-0.471159507673864,29.874151,32.166536,11.042471,48.0833 +309,0.570242292691787,-0.821476553302414,8.371651,7.915628,6.1676314,73.4583 +242,-0.519743812155516,-0.854322169749827,22.863349,24.333986,5.5833311,59.7917 +112,-0.349647455251228,0.936881346295431,13.62,13.707986,15.458575,88.7917 +618,-0.349647455251229,-0.936881346295431,19.416651,21.333164,17.333771,50.375 +651,0.209314645963048,-0.977848341505657,10.486651,9.832136,9.791514,49.4583 +150,-0.847540922892831,0.530730048161934,28.425,31.875278,7.459043,63.6667 +613,-0.428891937912483,-0.903355802324685,26.623349,30.792878,12.583136,74.125 +159,-0.91928596971861,0.393590276656467,29.991651,34.000214,10.042161,56.8333 +390,0.908817637339503,0.417193602612317,8.058349,7.4993,4.9175186,76.9583 +154,-0.882048024955853,0.471159507673864,21.845,23.249936,8.250514,45.625 +709,0.935367949313148,-0.353676122176372,12.484151,12.74795,12.750636,92.5 +4,0.997630305306586,0.0688024268023199,2.666979,-0.868180000000001,12.5223,43.6957 +325,0.772156584499164,-0.635432300890177,11.583349,11.831936,7.959064,96.25 +279,0.0902516100310407,-0.995918996147179,16.009151,17.290664,1.5002439,68.4167 +566,-0.949717842791432,-0.313107040935827,23.294151,24.45965,14.000789,84.4167 +237,-0.591261444863578,-0.806479946320945,24.9,27.542378,5.6254875,76.125 +338,0.893918596519257,-0.448229341740411,10.134151,9.99905,4.1679561,82.7083 +672,0.541627820655981,-0.840618405634478,8.136651,5.373836,18.125443,49.125 +136,-0.696376225596872,0.717676913675962,18.398349,19.542914,18.582718,83.7917 +454,0.0387222808921745,0.999250011239683,9.39,8.790986,9.250489,58.0833 +559,-0.980469160361632,-0.196672889793576,26.388349,28.083578,5.41695,48.5833 
+588,-0.766658819300159,-0.642054713236564,24.5475,26.125622,13.834093,73.2917 +82,0.158559385103135,0.987349442393986,5.395,1.874978,16.333729,80.5833 +699,0.860961015888994,-0.508670943852104,6.021651,5.375222,3.9175436,64.9583 +195,-0.976938492777182,-0.213520915439796,23.176651,25.208486,12.249811,59.125 +657,0.309016994374947,-0.951056516295153,18.476651,19.501136,9.041918,81.5 +675,0.584298173628368,-0.811539059007361,5.199151,2.583422,11.625371,56.7083 +499,-0.671259957567532,0.741222010848596,18.946651,20.2934,14.250364,78.9583 +343,0.929141411403174,-0.369724542890673,4.925,1.583192,15.625807,50.75 +637,-0.0301203048469084,-0.999546280687357,17.4975,18.999536,15.249468,54.2917 +458,-0.0301203048469079,0.999546280687357,13.933349,14.457878,6.708911,39.6667 +19,0.946987753076075,0.321269661692364,4.298349,0.833300000000001,13.125568,53.8333 +726,0.997630305306586,-0.0688024268023196,3.945849,-1.041628,23.458911,65.2917 +163,-0.944187508834199,0.32940848222453,21.845,23.709164,20.45845,49.4583 +51,0.638749422051527,0.769414826883938,6.256651,2.74895,20.625682,60.5 +533,-0.969009825724406,0.247022180480935,19.8475,21.75035,9.708568,59.875 +176,-0.993746580436178,0.111659007121695,23.96,26.083514,6.3337311,51.3333 +553,-0.995521372414475,-0.0945367498172,32.498349,37.124258,10.958118,49.2083 +83,0.141540295217043,0.989932495087353,4.415849,0.916591999999998,15.458575,49.5 +522,-0.905193189891397,0.425000339969555,18.045849,19.791272,5.167375,61.125 +391,0.901501684131884,0.432775592550431,11.975,11.415278,22.958689,74.125 +301,0.452072203932305,-0.891981346459548,3.945849,-0.957742,23.541857,88.25 +616,-0.381689220266659,-0.924290722193093,22.980849,24.333986,18.833968,79.9167 +429,0.452072203932305,0.891981346459548,3.436651,-0.0827140000000011,15.333486,50.625 +427,0.482507741761219,0.875891705144243,11.465849,11.290472,10.792293,62.125 +249,-0.413278607782904,-0.910604630094216,20.160849,19.919114,6.5003936,91.7083 
+428,0.467359217158002,0.884067509943364,7.314151,3.999386,22.416257,40.3333 +397,0.852077521101309,0.523415607365551,10.760849,10.332086,12.541529,67.2917 +677,0.611886401268724,-0.790945656756777,8.552178,6.564806,23.304945,33.3478 +380,0.966847813605278,0.255353295116187,0.93,-3.416242,15.500986,52.25 +544,-0.998185534471859,0.0602132773657926,27.210849,28.583792,11.50055,42.25 +39,0.782980103677063,0.622046748440868,-1.665199,-6.477322,12.652213,49.4783 +521,-0.897743393534234,0.440518784350495,17.419151,18.708872,14.041525,61.3333 +472,-0.267814305162174,0.963470548564149,20.591651,23.500142,18.333143,39.0417 +199,-0.959932689659744,-0.280230675199216,28.503349,33.333614,8.7502,65.0417 +124,-0.534372558280979,0.845249057353063,13.580849,13.166522,19.791264,44.4167 +264,-0.16705162550212,-0.98594814996383,21.531651,20.627558,8.584375,90.2083 +185,-0.99907411510223,-0.0430222330045306,27.093349,29.958308,8.459286,59.0417 +572,-0.912374757970727,-0.409355958815622,28.503349,32.791358,19.082471,59.6667 +251,-0.381689220266659,-0.924290722193093,22.55,22.210436,8.333393,89.7917 +457,-0.0129102960750095,0.999916658654738,12.393911,12.215858,20.913313,50.4348 +151,-0.856550995901004,0.516062391015853,27.915849,31.583822,13.875164,67.7083 +53,0.611886401268724,0.790945656756777,2.421733,0.217321999999999,6.305571,42.3043 +537,-0.98370929377361,0.179766585725562,29.874151,33.667772,7.958729,56.7917 +234,-0.632103411187349,-0.774884041367041,22.119151,24.000422,9.833121,45.5417 +288,0.242849722095935,-0.970063921851507,16.048349,17.208164,18.875039,48.6667 +184,-0.999666648510511,-0.0258184402271331,26.153349,27.917522,5.4591064,63.7917 +412,0.690173388242972,0.723644038295913,8.136651,7.207514,13.791682,63.4583 +585,-0.798779372886365,-0.601624063224923,27.25,30.709322,7.4169,67.2917 +697,0.842941537354783,-0.538005171538299,5.943349,3.124292,14.125007,50.625 +575,-0.890027576434677,-0.455906693508459,25.918349,28.083578,11.416532,62.375 
+204,-0.932289213174513,-0.361713730729768,31.01,36.458714,11.334457,55.0833 +660,0.357698238833125,-0.933837228822925,14.9125,15.874172,5.4593811,56.8333 +563,-0.964614175691244,-0.263665492728008,30.461651,33.875078,7.666743,50.5833 +629,-0.16705162550212,-0.98594814996383,20.160849,21.750086,10.333343,66.875 +719,0.982125605868,-0.188226709843244,7.51,6.124322,8.875021,66.7917 +345,0.941317317512847,-0.337522899594113,3.201651,1.832936,4.25115,67.0833 +630,-0.150055398344653,-0.988677590232341,22.55,24.292208,19.000061,64.6667 +467,-0.18399835165768,0.982926551979982,10.6825,9.581864,19.458743,46.625 +508,-0.777597146973627,0.628762814595834,21.218349,22.584392,6.834,77.4583 +56,0.570242292691787,0.821476553302414,5.2775,2.624672,12.500257,53.7917 +456,0.00430353829624429,0.99999073973619,12.014151,11.540942,11.541889,67.625 +356,0.988022665663698,-0.154308820664281,9.546651,8.915858,18.374482,68.625 +278,0.0730951298980769,-0.997324973108156,15.225849,15.70805,9.041918,62.0833 +269,-0.0816763953304229,-0.99665890175417,21.923349,21.91865,7.917457,88.5417 +346,0.946987753076075,-0.321269661692365,5.2775,3.875108,9.41685,59 +128,-0.591261444863578,0.806479946320945,17.0275,18.666236,11.792,58.875 +217,-0.828770087174504,-0.559589262410177,25.683349,28.626164,15.29275,75.5 +336,0.877960084700888,-0.478733840115789,6.060849,4.499864,6.4174811,61.2917 +711,0.946987753076075,-0.321269661692364,5.9825,3.624308,10.916779,53.8333 +538,-0.986657932891657,0.162807012938517,28.5425,31.791986,12.250414,57.375 +710,0.941317317512847,-0.337522899594113,8.606651,6.331958,19.834479,59.6667 +389,0.915864288267287,0.401487989205973,5.825849,3.458186,10.791757,64.375 +497,-0.64534811322955,0.763888612790542,18.515849,19.957922,8.249911,48.0417 +221,-0.788305055830526,-0.615284599963328,28.033349,29.208878,13.417286,42.4167 +420,0.584298173628369,0.811539059007361,5.669151,0.874549999999999,28.250014,39.5833 
+557,-0.986657932891657,-0.162807012938517,25.683349,26.917886,10.166111,63.3333 +162,-0.938377391740864,0.345612312670734,24.5475,26.45945,10.958989,74.7917 +622,-0.284359187281003,-0.958717816987297,21.766651,23.209478,6.958821,67.25 +667,0.467359217158002,-0.884067509943364,12.68,13.0004,23.9994,88 +640,0.0215160974362213,-0.999768501979891,19.769151,19.793978,6.999825,87.1667 +224,-0.75549331407268,-0.655156357209085,24.234151,26.626628,14.167418,72.9583 +388,0.922639548840487,0.385663406243608,8.0975,7.041128,8.292389,83.5833 +116,-0.413278607782904,0.910604630094216,21.14,21.959372,20.9174,83.5417 +54,0.598180914405916,0.801361088174677,5.895644,3.086606,16.783232,69.7391 +693,0.803927961832821,-0.594726686960763,5.081651,0.416971999999998,25.250357,40.4583 +730,1,0,2.144151,-1.249858,10.374682,57.75 +133,-0.658401584698049,0.752666827532008,16.479151,17.041514,9.04165,92.25 +446,0.175531490421428,0.984473816752092,18.045849,19.083422,7.876654,83.125 +103,-0.200890555130635,0.97961369164549,13.9725,14.540972,7.4169,54.0417 +617,-0.365722523497269,-0.93072393103798,20.67,22.20905,15.041232,54.75 +209,-0.897743393534234,-0.440518784350495,31.401651,35.873822,11.667246,54.25 +348,0.957485188355039,-0.288482432880609,11.8575,11.207642,17.958814,63.4167 +400,0.823923005757555,0.566701756291117,4.494151,1.458386,11.791732,68.7917 +257,-0.284359187281004,-0.958717816987296,19.1425,20.542286,18.166782,70.9167 +718,0.978740079966915,-0.20510449986862,7.6275,6.582692,12.334164,62.5417 +385,0.941317317512847,0.337522899594113,0.146650999999999,-4.45825,14.917014,83.125 +687,0.738326354003106,-0.674443618832946,8.0975,6.291236,15.250004,69.2917 +23,0.922639548840488,0.385663406243607,-3.4226089,-8.21662,10.60811,49.1739 +465,-0.150055398344653,0.988677590232341,12.993349,12.791378,16.708125,43.5 +129,-0.605056069648849,0.796182863782616,17.0275,18.499586,7.749957,48.9167 +647,0.141540295217043,-0.989932495087353,12.993349,12.915392,12.7501,76.1667 
+376,0.982125605868001,0.188226709843244,9.9775,9.207908,12.124789,80.2917 +169,-0.973118337233262,0.230305670230612,24.860849,26.625836,6.834,66.6667 +444,0.209314645963048,0.977848341505657,18.359151,19.543178,8.125157,80.7917 +233,-0.64534811322955,-0.763888612790543,24.508349,26.124764,18.54225,47 +421,0.570242292691787,0.821476553302414,5.120849,1.708328,13.750343,41 +507,-0.766658819300159,0.642054713236564,20.905,22.33445,7.917189,76.5833 +367,0.999407400739705,0.0344216116227456,-0.95,-7.66585,24.499957,44.125 +653,0.242849722095936,-0.970063921851507,18.398349,19.5839,19.834479,70.75 +79,0.209314645963049,0.977848341505657,12.230445,11.04251,19.348461,73.7391 +652,0.226115685508288,-0.97410045517242,16.518349,17.541464,18.667004,64.0417 +35,0.823923005757554,0.566701756291118,2.966651,0.0418279999999989,10.792293,92.9167 +474,-0.300819807635668,0.953680996630446,15.421651,16.541036,4.4172564,61.25 +504,-0.732494071613579,0.680773409477016,20.2,21.415928,5.626325,45.625 +659,0.341570769167855,-0.939856057941895,13.815849,14.124314,11.125618,51 +252,-0.365722523497269,-0.93072393103798,23.02,24.125492,10.291736,75.375 +342,0.922639548840488,-0.385663406243607,5.669151,4.957772,5.5420189,69.5833 +322,0.738326354003106,-0.674443618832945,7.470849,5.415878,15.041232,50.2083 +478,-0.365722523497269,0.93072393103798,7.118349,3.87425,20.334232,76.6667 +47,0.690173388242972,0.723644038295912,12.484151,12.291428,15.416968,50.5 +449,0.12447926388679,0.992222209417932,12.5625,12.874208,14.791925,88.0833 +110,-0.317191288589106,0.948361580012172,13.580849,13.956872,21.792286,40.7083 +704,0.901501684131884,-0.432775592550431,12.601651,12.248792,21.709407,48.5 +450,0.107381346664163,0.994217906893952,12.954151,12.9575,25.917007,47.7917 +392,0.893918596519257,0.448229341740411,6.844151,5.541014,14.125543,54.3333 +316,0.664855397964286,-0.746972087696555,12.719151,12.4163,18.875307,45.8333 +294,0.341570769167855,-0.939856057941895,11.8575,12.082472,6.2086689,62.9167 
+701,0.877960084700888,-0.478733840115789,8.3325,7.707728,8.333393,82.3333 +286,0.209314645963048,-0.977848341505657,17.889151,18.95855,15.000161,71.625 +700,0.869589389346611,-0.493775550159978,6.021651,4.915664,4.0001814,80.6667 +72,0.32534208471198,0.945596387427143,7.285199,5.912,9.174042,49.6957 +291,0.292600335633348,-0.956234826591906,17.461733,17.913968,16.303713,89.5217 +225,-0.744103939871361,-0.668063864213534,23.803349,25.209608,14.916411,81.75 +662,0.389630449530789,-0.920971287716634,19.6125,20.875586,4.8762064,63.625 +377,0.978740079966915,0.20510449986862,4.885849,0.457892000000001,25.333236,50.75 +171,-0.980469160361632,0.196672889793576,23.999151,26.084636,11.458675,77.0417 +296,0.373719714790468,-0.927541683579197,13.776651,14.166422,7.959064,77.2083 +714,0.962309077454148,-0.271958157534106,7.235849,6.333278,7.12545,65.0417 +92,-0.012910296075009,0.999916658654738,9.781651,8.998622,12.208271,48 +582,-0.828770087174504,-0.559589262410176,28.150849,32.251214,19.458207,65.25 +724,0.994670819911521,-0.103101697447434,5.691288,3.43469,11.304642,73.4783 +614,-0.413278607782904,-0.910604630094216,24.743349,27.251714,9.542207,81.0417 +236,-0.605056069648849,-0.796182863782616,24.155849,26.626364,14.125811,77.1667 +516,-0.856550995901004,0.516062391015853,23.96,25.667714,13.083693,49.2917 +106,-0.251190063884819,0.967937783024064,13.463349,13.415936,20.334232,47.9583 +32,0.852077521101309,0.52341560736555,4.22,0.791522000000001,17.708636,77.5417 +615,-0.397542814282556,-0.917583626059394,25.056651,27.375464,11.500282,73.625 +395,0.869589389346611,0.493775550159978,10.33,9.166922,17.541739,41.6667 +353,0.978740079966915,-0.205104499868619,10.134151,10.165964,4.1252436,59.5417 +684,0.702527474169157,-0.711656622281774,7.118349,5.416472,10.250129,62.0417 +670,0.512371412128424,-0.858763958275803,9.194151,8.416172,10.542182,58.1667 +75,0.276096973097469,0.961129783872301,11.505,11.081978,14.041793,60.2917 
+93,-0.0301203048469081,0.999546280687357,18.946651,19.833314,25.833257,42.625 +502,-0.708626678264459,0.705583610107178,19.886651,21.792458,15.374825,52 +29,0.877960084700888,0.478733840115789,2.176534,0.521252000000001,4.9568342,72.2174 +532,-0.964614175691244,0.263665492728008,21.688349,23.250728,11.166689,50.4167 +433,0.389630449530789,0.920971287716634,11.309151,10.207478,27.7916,40.7083 +174,-0.989314203970366,0.145799196919875,26.035849,27.334478,14.875675,57.3333 +669,0.49751328890718,-0.867456354729597,8.8025,7.8326,11.166689,66.6667 +610,-0.47495107206705,-0.880012203973536,24.743349,26.834,4.2927436,81.5 +114,-0.381689220266659,0.924290722193093,20.513349,21.917,12.417311,77.6667 +547,-0.999962959116266,0.0086069968886887,30.344151,33.541514,11.291443,51.875 +337,0.886070621534138,-0.463550270902851,7.549151,7.0406,5.6252061,77.5833 +95,-0.0645084494493162,0.997917160865392,10.369151,9.582128,17.625221,47.0833 +357,0.990532452132223,-0.137278772113265,6.2175,3.749972,12.750368,54.25 +514,-0.838279705217774,0.545240438540651,25.9575,28.417472,19.7918,68.4583 +658,0.32534208471198,-0.945596387427143,14.755849,15.207572,7.874979,57.2917 +627,-0.200890555130636,-0.97961369164549,17.9675,19.666664,14.416725,53.6667 +453,0.0559169901006039,0.998435421155564,15.225849,15.832064,21.41655,43.9167 +548,-0.999962959116266,-0.0086069968886887,28.738349,30.334508,13.082889,44.7083 +396,0.860961015888994,0.508670943852104,14.050849,14.791508,12.667489,50.7917 +403,0.793571608952147,0.608476870115126,4.063349,1.583786,8.959307,72.2917 +229,-0.696376225596872,-0.717676913675962,25.448349,27.709028,15.624936,65.4583 +147,-0.81901488666808,0.573772267904325,22.824151,24.417014,15.416164,72.9583 +349,0.962309077454148,-0.271958157534106,9.625,7.74845,17.458525,50.0417 +424,0.527077708642373,0.849817091527527,8.184356,6.99902,12.000839,80.4783 +673,0.556017436657045,-0.831170626365808,7.314151,5.749508,12.000236,53.2917 
+422,0.556017436657045,0.831170626365808,9.233349,7.624964,17.958211,49.0833 +201,-0.949717842791432,-0.313107040935827,30.305,38.540486,14.875407,69.125 +80,0.19245158197083,0.981306470271609,12.758349,13.082372,15.12525,62.4583 +634,-0.0816763953304229,-0.99665890175417,21.845,23.376458,16.3748,63.0833 +231,-0.671259957567532,-0.741222010848596,24.7825,26.833736,6.999289,67.4167 +636,-0.0473213883224323,-0.998879715585034,21.100849,22.666958,10.999993,69 +105,-0.23449138957041,0.972118196629061,12.249151,12.082472,22.834136,88.8333 +374,0.988022665663698,0.154308820664281,6.508712,5.042516,12.565984,64.6522 +10,0.985220106756061,0.171293144181478,-0.0527230000000003,-3.363376,8.182844,68.6364 +635,-0.0645084494493163,-0.997917160865392,22.55,24.12635,9.000914,69.0833 +363,0.999407400739705,-0.0344216116227456,6.648349,5.041592,9.000579,63.6667 +569,-0.932289213174514,-0.361713730729767,26.858349,29.541122,9.291761,69.4167 +402,0.803927961832822,0.594726686960763,8.645849,7.832864,9.874393,49.625 +520,-0.890027576434676,0.455906693508459,20.0825,22.166678,19.083811,48.7083 +30,0.869589389346611,0.493775550159977,0.499151,-3.7075,12.541864,60.375 +413,0.677614789046689,0.735417022963985,8.293349,7.45805,12.792243,53.4583 +556,-0.989314203970366,-0.145799196919875,25.879151,27.876536,10.166379,66.75 +483,-0.444378178104613,0.895839290734909,9.703349,8.915264,8.708325,48.9583 +464,-0.133014706534197,0.991114063993455,14.990849,15.458108,23.999132,31.75 +438,0.309016994374948,0.951056516295153,18.9075,20.208722,7.709154,50.7083 +15,0.966847813605277,0.255353295116187,2.888349,-0.541677999999999,12.625011,48.375 +196,-0.973118337233262,-0.230305670230612,24.273349,26.125358,13.958914,58.5 +644,0.0902516100310416,-0.995918996147179,18.045849,19.542386,17.957675,66.4167 +416,0.638749422051527,0.769414826883938,5.527822,3.477458,13.783039,59.4348 +411,0.702527474169157,0.711656622281774,6.883349,5.790692,6.125475,75.2917 
+598,-0.645348113229551,-0.763888612790542,22.510849,23.66765,4.8756436,67.375 +11,0.982125605868001,0.188226709843244,0.118169,-5.408782,20.410009,59.9545 +415,0.651898995878712,0.758305808478563,5.16,2.043806,15.348561,50.7826 +65,0.436651231956064,0.899630869652243,4.301733,-0.261574,22.870584,55.1304 +49,0.664855397964287,0.746972087696555,10.760849,9.832664,34.000021,18.7917 +203,-0.938377391740864,-0.345612312670734,31.910849,37.082942,8.791807,50 +459,-0.0473213883224321,0.998879715585034,17.458349,19.2077,12.125325,46.9583 +703,0.893918596519257,-0.448229341740411,14.364151,14.957564,11.666643,73.375 +530,-0.954966754855255,0.29671281927349,22.471651,25.209278,17.000111,56.9583 +383,0.952377575730397,0.304921224656289,0.93,-3.457492,14.750586,49.75 +121,-0.490028666429059,0.871706318709322,17.810849,19.166978,12.291418,73 +398,0.842941537354783,0.538005171538299,6.726651,4.416836,11.959232,52.6667 +593,-0.70862667826446,-0.705583610107178,25.800849,27.209408,9.500332,51.9167 +314,0.638749422051527,-0.769414826883938,7.235849,4.249922,21.083225,44.625 +258,-0.267814305162175,-0.963470548564149,14.050849,14.45735,11.000261,59.0417 +352,0.975064532257195,-0.221921513004165,5.003349,2.541578,11.584032,63.75 +247,-0.444378178104613,-0.895839290734909,23.646651,25.292636,14.250632,79.0417 +579,-0.856550995901004,-0.516062391015853,27.3675,30.667808,8.666718,65.9583 +689,0.761104258660774,-0.648629561034982,9.585849,9.124022,5.542575,68.5 +330,0.823923005757554,-0.566701756291118,13.580849,14.0828,13.999918,69.8333 +99,-0.133014706534196,0.991114063993455,12.053349,12.164642,9.833389,85.75 +107,-0.267814305162174,0.963470548564149,16.0875,17.207636,10.958989,54.25 +300,0.436651231956063,-0.899630869652244,7.549151,5.041592,15.375093,58.5833 +9,0.988022665663698,0.154308820664281,-0.910849000000001,-6.041392,14.958889,48.2917 +451,0.0902516100310416,0.995918996147179,7.196651,4.833164,12.541864,29 
+624,-0.251190063884819,-0.967937783024064,19.26,21.16625,6.0422811,57 +650,0.19245158197083,-0.981306470271609,12.5625,12.582686,15.751164,53.9167 +466,-0.16705162550212,0.98594814996383,8.388712,6.260084,19.783358,46.9565 +401,0.814046093508218,0.580800273453801,5.282623,3.564116,10.3046,62.2174 +619,-0.333468778918187,-0.942761143390421,19.1425,20.583272,6.1676314,52 +568,-0.938377391740864,-0.345612312670734,23.3725,25.12625,6.2926936,76.25 +393,0.886070621534138,0.46355027090285,5.2775,1.999586,16.08335,31.125 +7,0.99274872245774,0.120208044899353,-0.244999999999999,-5.291236,17.875868,53.5833 +113,-0.365722523497269,0.930723931037979,19.338349,20.416358,12.875725,81.0833 +260,-0.234491389570411,-0.972118196629061,15.8525,16.375442,11.958361,69.5 +28,0.886070621534138,0.463550270902851,1.236534,-1.999684,9.739455,65.1739 +305,0.512371412128424,-0.858763958275803,9.7425,9.748778,5.5001439,71.875 +625,-0.234491389570411,-0.972118196629061,19.299151,20.5013,10.166714,73.4583 +645,0.107381346664162,-0.994217906893952,11.544151,11.707658,9.457854,70.8333 +281,0.124479263886789,-0.992222209417932,17.419151,18.582878,4.25115,72.75 +486,-0.49002866642906,0.871706318709322,20.826651,22.083386,10.458432,65.9583 +266,-0.133014706534196,-0.991114063993455,20.513349,21.251192,5.2516811,86.25 +261,-0.217723230396532,-0.976010550632368,17.810849,18.95855,10.166714,69 +695,0.823923005757554,-0.566701756291117,6.726651,6.374264,3.12555,53.5417 +409,0.726607524768566,0.687052767223667,7.000849,6.040436,9.458993,50.875 +707,0.922639548840488,-0.385663406243608,9.938349,9.707528,6.792393,91.125 +218,-0.81901488666808,-0.573772267904325,26.8975,31.209272,13.499629,75.2917 +183,-0.999962959116266,-0.0086069968886887,25.683349,28.12595,15.333486,68.25 +351,0.97110005188295,-0.23867276600595,3.201651,0.208213999999998,11.375193,58.625 +628,-0.18399835165768,-0.982926551979982,17.693349,19.124672,7.917189,61.8333 
+118,-0.444378178104613,0.895839290734909,15.97,16.832558,16.084221,45.7083 +641,0.038722280892174,-0.999250011239683,22.9025,23.542778,4.4585686,79.375 +649,0.175531490421428,-0.984473816752092,12.445,12.457022,12.166932,46.3333 +655,0.276096973097469,-0.961129783872301,13.424151,13.707128,6.791857,69.2917 +406,0.761104258660774,0.648629561034982,2.535849,-2.082778,19.416332,73.125 +505,-0.74410393987136,0.668063864213534,21.179151,22.541822,17.042589,53.0417 +717,0.975064532257195,-0.221921513004165,11.309151,11.040728,14.834068,66.625 +239,-0.563150724274919,-0.82635419872391,25.231773,26.765294,20.412153,56.1765 +119,-0.459732739452104,0.888057322629493,14.2075,14.625386,15.750025,50.3333 +442,0.242849722095936,0.970063921851507,14.2075,14.624,8.501161,81 +303,0.482507741761218,-0.875891705144243,7.98,7.500158,7.12545,70.3333 +440,0.276096973097469,0.961129783872301,12.484151,12.791114,7.583864,84.2083 +686,0.726607524768566,-0.687052767223667,7.275,5.541278,12.041843,54.5417 +394,0.877960084700888,0.478733840115789,4.650849,1.33325,14.458064,40.0833 +104,-0.217723230396532,0.976010550632368,12.993349,13.166258,15.167125,67.125 +280,0.107381346664162,-0.994217906893952,16.518349,17.873972,3.0420814,70.125 +179,-0.998185534471859,0.060213277365793,26.231651,27.209408,17.542007,49.7917 +439,0.292600335633348,0.956234826591906,18.2025,19.16645,10.042161,57.9583 +240,-0.54884295828472,-0.835925479418637,21.923349,24.125228,10.708275,55.4583 +519,-0.882048024955854,0.471159507673864,20.3175,21.958778,12.333829,49.3333 +166,-0.959932689659744,0.280230675199216,21.531651,23.292836,13.833557,68.8333 +46,0.702527474169157,0.711656622281775,6.958267,4.8692,16.869997,42.3478 +190,-0.99167731989929,-0.128748177452581,27.1325,29.54165,12.292557,57.8333 +36,0.814046093508218,0.580800273453801,5.434151,3.250286,9.5006,56.8333 +173,-0.986657932891657,0.162807012938517,26.231651,29.792978,15.999868,70.3333 
+567,-0.944187508834199,-0.32940848222453,20.004151,20.294192,14.2911,86.5417 +302,0.467359217158002,-0.884067509943364,7.000849,5.207714,11.833339,62.375 +206,-0.91928596971861,-0.393590276656467,28.268349,30.000614,13.417286,54.0833 +18,0.952377575730397,0.304921224656289,5.732178,3.695852,13.957239,74.1739 +583,-0.81901488666808,-0.573772267904325,27.3675,30.876236,8.666718,65.4167 +671,0.527077708642372,-0.849817091527527,8.685,7.498772,17.833725,52.2083 +341,0.915864288267287,-0.401487989205973,4.494151,0.957908,16.083886,58 +102,-0.18399835165768,0.982926551979982,11.3875,11.540678,16.791339,81.9167 +722,0.990532452132223,-0.137278772113264,3.554151,1.125086,8.916561,51.5417 +529,-0.949717842791432,0.313107040935827,22.824151,24.333722,22.999693,58.2083 +187,-0.997001169925015,-0.077386479233463,27.25,29.333486,10.6664,65.125 +138,-0.720667149553861,0.693281226886978,16.949151,17.708972,7.250271,82.9583 +633,-0.0988201387328712,-0.995105311100698,17.85,19.915814,15.833507,57 +654,0.259511797069799,-0.965739937654855,14.011651,14.415836,12.208807,55.8333 +188,-0.995521372414475,-0.0945367498171996,25.330849,28.251878,15.083643,75.7917 +310,0.584298173628368,-0.811539059007361,10.565,10.457486,3.834075,75.875 +506,-0.755493314072681,0.655156357209085,20.121651,21.334022,15.624668,81.125 +541,-0.993746580436178,0.111659007121695,25.644151,27.166772,20.125661,50.4167 +37,0.803927961832821,0.594726686960763,4.768349,4.041428,3.0423561,73.8333 +599,-0.632103411187349,-0.774884041367041,23.3725,25.042364,4.7089811,67.7083 +318,0.690173388242971,-0.723644038295913,16.91,17.500214,13.375411,68.875 +340,0.908817637339503,-0.417193602612317,11.27,10.416236,17.833725,97.0417 +517,-0.865307254363206,0.501241813445775,22.745849,24.125492,15.916721,75.5417 +555,-0.99167731989929,-0.128748177452581,25.409151,27.167564,12.125325,68.3333 +335,0.869589389346611,-0.493775550159977,6.765849,5.874578,6.750518,62.5833 
+20,0.941317317512847,0.337522899594113,0.342499999999999,-5.583022,23.667214,45.7083 +198,-0.964614175691244,-0.263665492728008,27.093349,30.45905,14.458868,65.125 +86,0.0902516100310412,0.995918996147179,4.424356,0.999884000000002,14.217668,30.2174 +690,0.772156584499164,-0.635432300890177,8.606651,8.082872,6.917482,61.375 +542,-0.995521372414475,0.0945367498172,21.649151,23.250464,23.292014,37.3333 +473,-0.284359187281003,0.958717816987297,13.776651,14.164508,11.250104,56.9167 +437,0.32534208471198,0.945596387427143,18.555,19.833314,15.87565,61.75 +358,0.99274872245774,-0.120208044899353,4.914801,2.477426,10.391097,68.1304 +360,0.996298174934608,-0.0859647987374468,7.275,5.623778,12.62615,76.25 +331,0.83355577183857,-0.552435313167619,15.663466,16.348052,9.522174,74.3043 +5,0.996298174934608,0.0859647987374465,1.604356,-0.608205999999999,6.0008684,51.8261 +127,-0.577291616551727,0.816538051445916,16.831651,18.249578,5.0007125,63.1667 +155,-0.890027576434677,0.455906693508459,22.471651,24.709064,9.292364,65.25 +287,0.226115685508288,-0.97410045517242,15.813349,16.91585,17.291561,48.3333 +48,0.677614789046689,0.735417022963986,16.518349,17.790878,17.749975,51.6667 +226,-0.732494071613579,-0.680773409477017,23.294151,24.667022,13.999918,71.2083 +238,-0.577291616551728,-0.816538051445916,23.96,25.946696,25.166339,85 +192,-0.986657932891657,-0.162807012938517,29.325849,32.79215,13.417018,55.9167 +418,0.611886401268725,0.790945656756777,13.345849,13.333436,12.791171,55.4583 +443,0.226115685508288,0.97410045517242,17.615,19.166186,10.875239,72.875 +189,-0.993746580436178,-0.111659007121695,26.466651,27.834428,11.250104,60.9167 +111,-0.333468778918187,0.942761143390421,7.823349,5.248964,14.707907,72.9583 +500,-0.68391942162461,0.729557554086488,20.748349,22.042664,9.875264,79.4583 +364,0.999851839209116,-0.0172133561558346,11.27,11.331986,14.750318,61.5833 +493,-0.591261444863578,0.806479946320945,19.338349,20.793086,19.833943,68.5833 
+468,-0.200890555130635,0.97961369164549,12.7975,12.499328,10.416557,40.8333 +463,-0.1159345995955,0.993256849267414,15.5,16.50005,15.583932,27.5833 +58,0.541627820655981,0.840618405634478,11.141831,10.407788,19.408962,87.6364 +60,0.512371412128424,0.858763958275803,7.745,5.124686,20.624811,44.9583 +405,0.772156584499164,0.635432300890177,5.199151,3.374828,7.834243,54 +698,0.852077521101309,-0.523415607365551,5.20089,3.695852,7.739974,55.5652 +590,-0.744103939871361,-0.668063864213534,25.879151,27.708764,9.126204,54.5417 +87,0.0730951298980776,0.997324973108156,6.2175,3.331928,15.208732,31.4167 +131,-0.632103411187349,0.774884041367041,17.145,18.541958,12.707689,74.75 +691,0.782980103677063,-0.622046748440867,7.98,7.124486,3.5423436,58.0417 +250,-0.397542814282557,-0.917583626059394,21.793911,20.653826,12.914116,93.9565 +202,-0.944187508834199,-0.32940848222453,31.871651,39.499136,8.9177,58.0417 +245,-0.47495107206705,-0.880012203973536,23.450849,25.792058,12.416775,71.6667 +535,-0.976938492777182,0.213520915439796,24.351651,27.209672,9.917139,69 +417,0.625410572985246,0.780295851070776,10.604151,9.916022,15.709557,56.7917 +545,-0.99907411510223,0.0430222330045306,31.205849,35.916458,11.082939,48.875 +469,-0.217723230396531,0.976010550632368,15.265,16.207736,12.791439,50.2917 +130,-0.618671403262503,0.785649855078715,17.4975,18.8744,8.083014,63.2917 +471,-0.251190063884819,0.967937783024064,23.215849,24.58505,19.083543,56.1667 +161,-0.932289213174513,0.361713730729768,26.075,28.750508,10.37495,65.4583 +558,-0.98370929377361,-0.179766585725562,25.644151,27.209078,9.833925,52.9583 +167,-0.964614175691244,0.263665492728008,22.510849,23.625278,9.582943,73.5833 +77,0.242849722095936,0.970063921851507,14.2075,14.79065,24.667189,37.9167 +399,0.83355577183857,0.55243531316762,4.415849,1.99985,8.167032,77.9583 +313,0.625410572985246,-0.780295851070775,9.86,8.665586,12.667489,81.3333 +94,-0.0473213883224319,0.998879715585034,11.465849,10.2911,26.000489,64.2083 
+220,-0.798779372886365,-0.601624063224923,28.425,31.791986,10.125107,57.0417 +509,-0.788305055830525,0.615284599963328,22.785,24.0422,11.584032,71.6667 +160,-0.92592477719385,0.377707965203965,27.485,30.417272,9.417118,60.5 +612,-0.444378178104613,-0.895839290734909,26.114151,29.334608,15.833507,75.5 +241,-0.534372558280979,-0.845249057353063,22.040849,23.250464,8.375536,54.8333 +180,-0.99907411510223,0.0430222330045306,24.743349,26.042528,12.415904,43.4167 +561,-0.973118337233262,-0.230305670230612,27.054151,30.542936,11.166689,71.7917 +723,0.99274872245774,-0.120208044899353,2.871288,1.0874,5.1744368,79.1304 +371,0.994670819911521,0.103101697447434,10.486651,9.791414,11.708786,53.1667 +16,0.962309077454149,0.271958157534106,0.264151,-4.333114,12.999139,53.75 +550,-0.99907411510223,-0.0430222330045306,29.090849,32.334242,9.04165,53.875 +186,-0.998185534471859,-0.060213277365793,25.84,29.251778,10.042161,74.3333 +248,-0.428891937912484,-0.903355802324685,17.38,18.0032,23.044181,88.6957 +571,-0.919285969718611,-0.393590276656466,26.035849,27.167564,11.0416,45 +170,-0.976938492777182,0.213520915439796,21.845,23.292836,10.416825,74.625 +284,0.175531490421428,-0.984473816752092,17.536651,18.169322,16.62605,90.625 +253,-0.349647455251228,-0.936881346295431,22.706651,24.209114,7.708618,71.375 +227,-0.720667149553861,-0.693281226886978,24.939151,26.625242,15.834043,57.8333 +44,0.726607524768566,0.687052767223667,11.505,10.2911,27.999836,37.5833 +135,-0.68391942162461,0.729557554086488,19.1425,20.333792,8.500357,78.7917 +78,0.226115685508288,0.97410045517242,7.6275,5.4995,13.917307,47.375 +604,-0.563150724274918,-0.82635419872391,25.056651,27.209408,8.625111,73.0417 +503,-0.720667149553861,0.693281226886978,18.515849,20.373986,9.166739,52.3333 +362,0.998666816288476,-0.0516196672232536,3.671651,1.416872,8.000604,57.4167 +475,-0.317191288589107,0.948361580012172,16.753349,18.04115,10.041357,69.4583 
+407,0.749826401204569,0.661634618242278,-2.0075,-9.290572,27.417204,46.4583 +694,0.814046093508218,-0.580800273453801,3.554151,1.000478,10.0835,46.8333 +259,-0.25119006388482,-0.967937783024064,15.108349,15.581792,12.708225,71.8333 +445,0.19245158197083,0.981306470271609,16.988349,17.875028,6.0004061,82.125 +27,0.893918596519257,0.448229341740411,1.563466,-1.261078,8.2611,79.3043 +642,0.0559169901006039,-0.998435421155564,22.9025,24.12635,7.875582,72.2917 +387,0.929141411403174,0.369724542890673,2.261651,0.0418279999999989,7.417436,91.125 +312,0.611886401268724,-0.790945656756777,10.8,10.999214,4.1671186,75.8333 +524,-0.919285969718611,0.393590276656466,22.510849,23.458892,11.750661,46.7917 +101,-0.167051625502119,0.98594814996383,15.6175,16.541564,18.416893,73.9167 +137,-0.70862667826446,0.705583610107178,17.85,18.792428,13.499964,87 +108,-0.284359187281004,0.958717816987296,15.774151,16.291028,10.584057,66.5833 +513,-0.828770087174504,0.559589262410176,25.4875,28.8338,13.166907,67.625 +682,0.677614789046689,-0.735417022963985,8.136651,5.33285,22.917082,66.2917 +244,-0.490028666429059,-0.871706318709322,22.236651,23.917328,9.375243,72.7083 +551,-0.998185534471859,-0.0602132773657926,30.8925,34.250222,12.999943,45.7917 +510,-0.798779372886365,0.601624063224923,23.96,25.416914,9.41685,74.7083 +725,0.996298174934608,-0.0859647987374468,3.436651,-1.458022,21.208582,82.3333 +193,-0.98370929377361,-0.179766585725562,27.093349,29.500664,9.790911,63.1667 +267,-0.115934599595501,-0.993256849267414,21.805849,21.794042,3.3754064,84.5 +139,-0.732494071613579,0.680773409477017,17.223349,18.916772,8.375871,71.9583 +512,-0.81901488666808,0.573772267904325,24.43,26.33405,14.416457,69.7083 +311,0.598180914405917,-0.801361088174676,11.191651,11.208236,4.6255125,72.1667 +639,0.00430353829624382,-0.99999073973619,16.479151,17.792,6.0838814,64.9167 +484,-0.459732739452105,0.888057322629493,13.541651,13.707986,7.832836,58.7083 
+382,0.957485188355039,0.288482432880608,6.256651,2.166764,27.833743,44.3333 +274,0.00430353829624382,-0.99999073973619,8.763349,6.790922,14.874871,79.1667 +200,-0.954966754855255,-0.29671281927349,28.111651,33.2921,7.625739,70.7083 +603,-0.577291616551728,-0.816538051445916,22.706651,23.335736,15.333486,84.5833 +552,-0.997001169925015,-0.077386479233463,30.931651,33.667178,9.791514,45.0833 +268,-0.0988201387328721,-0.995105311100698,22.510849,22.876772,7.4169,84.8333 +1,0.999851839209116,0.0172133561558347,9.083466,7.346774,16.652113,69.6087 +64,0.452072203932304,0.891981346459549,9.696534,8.172632,23.000229,94.8261 +638,-0.0129102960750097,-0.999916658654738,16.753349,18.165758,9.042186,58.3333 +696,0.83355577183857,-0.55243531316762,5.708349,2.582828,15.916654,78.6667 +379,0.97110005188295,0.238672766005951,-0.166651,-5.33275,16.834286,41.9167 +606,-0.534372558280979,-0.845249057353063,24.195,25.958378,7.541654,55.2083 +123,-0.519743812155515,0.854322169749827,11.465849,10.7069,22.042732,73.7083 +678,0.625410572985246,-0.780295851070776,8.998349,7.457258,14.375386,54.0833 +6,0.994670819911521,0.103101697447435,1.236534,-2.216626,11.304642,49.8696 +706,0.915864288267287,-0.401487989205973,7.079151,5.249228,8.7502,76.4167 +546,-0.999666648510511,0.0258184402271326,27.955,29.375528,10.791757,60.125 +24,0.915864288267287,0.401487989205973,2.503466,-0.521284,8.696332,61.6957 +715,0.966847813605277,-0.255353295116187,9.0375,8.415908,6.749714,83.875 +426,0.49751328890718,0.867456354729597,8.606651,7.749572,9.708568,65.7083 +256,-0.300819807635668,-0.953680996630446,23.646651,25.3754,11.2091,69.7083 +511,-0.809016994374947,0.587785252292474,24.5475,26.417936,13.332464,73.25 +215,-0.847540922892831,-0.530730048161934,25.37,27.876008,13.20905,75.75 +168,-0.969009825724406,0.247022180480936,24.743349,26.500964,8.000336,67.0417 +441,0.2595117970698,0.965739937654855,16.165849,17.333036,7.417168,75.5833 
+213,-0.865307254363206,-0.501241813445776,28.816651,30.666686,13.79195,49.125 +70,0.357698238833126,0.933837228822925,7.470849,5.4995,14.791925,59.4583 +14,0.97110005188295,0.23867276600595,2.966651,0.375392000000002,10.583521,49.875 +470,-0.23449138957041,0.972118196629061,20.513349,21.87575,15.083643,50.7917 +531,-0.959932689659744,0.280230675199217,22.040849,23.583764,11.833339,58.9583 +344,0.935367949313148,-0.353676122176372,2.379151,0.708164,4.4582939,49 +435,0.357698238833125,0.933837228822925,9.001733,7.73822,14.913329,47.6957 +404,0.782980103677063,0.622046748440867,4.455,1.291208,13.000479,56.2083 +540,-0.99167731989929,0.128748177452581,26.936651,28.500764,9.750175,47.9167 +306,0.527077708642372,-0.849817091527528,11.191651,10.790786,9.166739,70.2083 +688,0.749826401204569,-0.661634618242278,9.899151,8.790986,15.749489,62.3333 +3,0.998666816288476,0.0516196672232538,1.4,-1.999948,10.739832,59.0435 +384,0.946987753076075,0.321269661692364,2.2225,-1.416772,13.58425,45 +175,-0.99167731989929,0.128748177452581,24.665,26.458658,14.041257,48.3333 +323,0.749826401204569,-0.661634618242278,13.776651,14.165828,12.45865,68.4583 +526,-0.932289213174514,0.361713730729767,26.153349,27.792122,8.959307,53.8333 +595,-0.683919421624611,-0.729557554086488,23.881651,24.792686,11.917089,60.3333 +712,0.952377575730397,-0.304921224656289,5.904151,3.416408,11.666643,48.5833 +721,0.988022665663698,-0.154308820664281,4.494151,-0.416542000000002,27.292182,44.125 +621,-0.300819807635668,-0.953680996630446,20.7875,22.250828,5.5422936,63.7083 +727,0.998666816288476,-0.0516196672232542,3.906651,0.833036,10.416557,59 +591,-0.732494071613579,-0.680773409477016,26.153349,28.667414,11.333586,68.6667 +327,0.793571608952147,-0.608476870115126,9.546651,8.583086,11.209368,54.9167 +676,0.598180914405916,-0.801361088174677,5.904151,2.124986,20.375236,54.75 +45,0.714673386042961,0.699458327051647,4.506089,0.782084000000001,19.522058,31.4348 
+600,-0.618671403262504,-0.785649855078714,24.704151,26.042528,5.6679186,63.5833 +549,-0.999666648510511,-0.0258184402271326,28.699151,30.3749,8.457879,49.2083 +97,-0.0988201387328714,0.995105311100698,7.784151,5.415614,15.208464,83.625 +255,-0.317191288589106,-0.948361580012172,22.589151,23.834564,9.500868,71.25 +67,0.405425728359997,0.914127988185334,5.904151,2.916128,14.75005,77.5417 +680,0.651898995878712,-0.758305808478563,11.779151,11.833058,8.5425,65.9167 +181,-0.999666648510511,0.0258184402271331,25.9575,27.042692,6.874736,39.625 +74,0.292600335633349,0.956234826591906,9.165199,8.21738,13.608839,77.6522 +132,-0.64534811322955,0.763888612790543,16.0875,16.6238,12.041575,86.3333 +230,-0.683919421624611,-0.729557554086488,24.195,25.792586,9.333636,72.2917 +34,0.83355577183857,0.55243531316762,1.931288,-0.913257999999999,8.565213,58.5217 +488,-0.519743812155516,0.854322169749827,18.32,19.457972,8.957632,76.8333 +632,-0.115934599595501,-0.993256849267414,16.165849,17.165858,9.541068,49.2917 +492,-0.577291616551727,0.816538051445916,17.2625,18.791372,15.458307,66.4167 +716,0.97110005188295,-0.238672766005951,10.486651,10.499,6.5833061,90.7083 +665,0.436651231956063,-0.899630869652244,16.91,17.998778,15.791364,72 +434,0.373719714790469,0.927541683579197,5.5125,2.332622,15.12525,35.0417 +197,-0.969009825724406,-0.247022180480936,25.800849,28.208978,16.417211,60.4167 diff --git a/inst/code_paper/xgb.model b/inst/code_paper/xgb.model new file mode 100644 index 0000000000000000000000000000000000000000..1282be78643c7277e171e82a4b00924cc1ec588c Binary files /dev/null and b/inst/code_paper/xgb.model differ diff --git a/inst/code_paper/y_explain.csv b/inst/code_paper/y_explain.csv new file mode 100644 index 0000000000000000000000000000000000000000..75cf372299324170d3764ba75c1990362b050506 --- /dev/null +++ b/inst/code_paper/y_explain.csv @@ -0,0 +1,147 @@ +y_explain +985 +1349 +822 +683 +981 +431 +1360 +1746 +1472 +1589 +1450 +1461 +2402 +1851 +1685 +1944 +1977 +3239 
+2121 +1693 +2252 +3141 +2455 +3348 +4451 +4608 +4660 +4492 +4978 +4677 +4679 +4788 +4098 +5312 +4548 +4507 +5119 +4086 +3840 +4656 +4266 +3574 +4326 +3873 +4940 +4713 +3641 +4352 +2395 +3907 +4839 +5202 +2429 +3570 +5117 +4563 +4195 +4381 +4687 +2659 +4068 +4486 +1817 +3053 +3392 +2765 +2566 +2792 +2914 +3613 +3727 +2594 +2739 +3068 +1317 +2294 +1951 +2368 +3272 +4098 +2177 +2493 +2935 +1977 +4169 +3487 +4916 +5382 +5298 +8362 +3372 +5698 +6457 +6460 +1027 +6196 +5026 +5572 +6169 +6883 +6359 +4717 +6572 +7030 +6118 +7424 +7494 +4972 +5099 +7458 +6969 +6830 +5713 +6591 +7592 +6904 +7105 +7216 +7580 +6824 +7273 +7286 +7148 +4549 +7713 +7350 +6034 +8714 +4073 +7907 +8156 +5478 +7509 +7466 +7359 +4459 +1096 +5259 +6536 +6269 +5495 +5698 +3910 +6234 +5375 +1341 diff --git a/inst/code_paper/y_train.csv b/inst/code_paper/y_train.csv new file mode 100644 index 0000000000000000000000000000000000000000..e61d7b1183137d7b66c63a456314520e0f29d635 --- /dev/null +++ b/inst/code_paper/y_train.csv @@ -0,0 +1,586 @@ +y_train +2689 +6857 +4648 +7498 +5084 +4058 +3894 +4694 +5115 +1421 +2376 +7444 +7582 +6053 +3228 +2227 +3740 +7691 +2660 +506 +8120 +4990 +5611 +4475 +6544 +7347 +4672 +3425 +4274 +7335 +6296 +7870 +986 +3926 +4553 +4905 +5180 +4866 +4570 +7175 +2417 +5786 +6685 +5805 +4968 +4304 +4456 +1796 +1538 +3956 +1685 +4067 +4792 +6664 +4400 +7040 +6235 +6530 +1530 +4401 +4390 +3623 +1550 +6855 +1406 +623 +3422 +4046 +4826 +1536 +6211 +4748 +4363 +2913 +3351 +3944 +4833 +2077 +6233 +6624 +5633 +2133 +2496 +4891 +1812 +2056 +4708 +2302 +5130 +6140 +3068 +4714 +4302 +3649 +5058 +4036 +7525 +7109 +3982 +7112 +3915 +4075 +5342 +5170 +1600 +1607 +4985 +5870 +4661 +3811 +5138 +4123 +5459 +7499 +6299 +1865 +5668 +5538 +5424 +5686 +2843 +3190 +8555 +6772 +1927 +2114 +5020 +1107 +6978 +5305 +4840 +2210 +7055 +3456 +627 +5976 +3333 +4066 +1996 +3423 +3761 +5315 +2298 +6879 +1605 +7001 +6691 +4541 +4433 +4795 +4665 +6861 +3544 +5936 +3974 +1917 +5905 +5895 +5041 +6043 +4154 +7534 +5260 
+6597 +3606 +7058 +6786 +8167 +4128 +3310 +8395 +5409 +5260 +1969 +6041 +2209 +4765 +4120 +3523 +4362 +4294 +3614 +5319 +5823 +5501 +4270 +7429 +4780 +2732 +7264 +4460 +8009 +22 +4639 +4150 +4339 +3872 +1807 +2277 +2729 +3409 +6871 +3267 +8227 +3846 +3709 +2947 +3659 +5267 +1301 +4669 +1416 +5918 +4803 +6392 +4097 +4744 +6093 +4758 +3389 +6073 +2236 +5875 +2077 +6639 +1005 +6565 +8294 +6824 +5345 +3620 +3663 +3214 +2475 +4996 +4189 +5729 +5558 +4023 +3717 +4308 +4649 +3644 +5191 +2046 +2424 +3820 +7693 +3214 +4835 +4187 +5047 +3249 +5464 +1013 +6203 +3542 +7338 +3744 +1526 +7504 +4509 +3750 +5445 +5986 +2744 +3115 +7384 +1096 +7702 +4569 +4991 +5566 +5810 +4073 +5531 +3485 +2808 +1011 +5743 +8090 +7591 +6133 +6227 +4579 +2802 +3805 +4758 +3577 +1834 +5107 +4322 +3784 +2703 +7733 +5191 +7415 +795 +3598 +1263 +7393 +2999 +6966 +4375 +6998 +1501 +4318 +6290 +4220 +5585 +6312 +1204 +5923 +7965 +3777 +3005 +7006 +1162 +3129 +1872 +1635 +3285 +6436 +6606 +7363 +3292 +4401 +4151 +7605 +3368 +4760 +3403 +3351 +7261 +5634 +3071 +2895 +3429 +3747 +1321 +5102 +7333 +7282 +4862 +3784 +7767 +7410 +3243 +959 +4191 +4274 +1098 +4186 +6869 +3510 +5511 +5740 +5423 +4539 +5087 +3922 +5582 +3785 +4649 +2431 +7720 +4595 +7572 +7570 +7461 +2169 +7129 +5557 +4334 +5312 +5892 +3669 +4378 +5629 +3624 +3126 +5409 +5225 +6192 +4634 +7641 +3767 +2115 +4881 +1623 +4790 +4459 +3331 +4590 +1650 +7013 +5847 +3322 +2162 +1787 +7421 +4592 +4575 +7538 +7534 +4040 +4035 +4359 +6779 +1712 +7375 +4195 +705 +4127 +6569 +3940 +1543 +4458 +2028 +5146 +7442 +4367 +5847 +754 +1162 +3867 +1606 +4333 +4906 +5217 +2927 +4338 +1115 +4258 +5062 +6153 +5336 +1683 +5115 +2485 +5728 +6398 +5169 +1446 +2134 +3831 +5323 +6883 +2425 +4864 +2425 +1842 +3387 +4484 +6825 +4773 +5463 +7460 +4182 +6370 +4966 +7446 +4844 +3117 +2832 +2933 +1795 +4602 +6770 +4586 +6864 +5204 +5515 +6031 +920 +4521 +1000 +7403 +4629 +2710 +8173 +4010 +2416 +5046 +4725 +1913 +3958 +2471 +6917 +7639 +2423 +7290 +1529 +2424 +4511 +6230 +1167 
+7328 +2432 +4109 +7736 +2034 +3855 +3204 +6043 +4094 +4727 +6241 +6734 +441 +4342 +5010 +4917 +6591 +4205 +6778 +6304 +3376 +2918 +4332 +5255 +6207 +4630 +801 +605 +6889 +3959 +2311 +7697 +2633 +5992 +1510 +5008 +5687 +1985 +3786 +3194 +4785 +6536 +4576 +5119 +7836 +4845 +2132 +1248 +7132 +7665 +2743 +4911 +3830 +6891 +3974 +5499 +1562 +3163 +5202 +3520 +6598 +7865 +5532 +1749 +7804 +3095 +6784 +1495 +5035 +1815 +7765 +6660 +1471 +4763 +1891 +6852 +5362 +2192 +4105 +4153 +1708 +6421 +7436 +6273 +4585 +7852 +4118 +5302 diff --git a/inst/extdata/day.csv b/inst/extdata/day.csv new file mode 100644 index 0000000000000000000000000000000000000000..7498062a459fff5db9254f5e1a100defa6a961b5 --- /dev/null +++ b/inst/extdata/day.csv @@ -0,0 +1,732 @@ +instant,dteday,season,yr,mnth,holiday,weekday,workingday,weathersit,temp,atemp,hum,windspeed,casual,registered,cnt +1,2011-01-01,1,0,1,0,6,0,2,0.344167,0.363625,0.805833,0.160446,331,654,985 +2,2011-01-02,1,0,1,0,0,0,2,0.363478,0.353739,0.696087,0.248539,131,670,801 +3,2011-01-03,1,0,1,0,1,1,1,0.196364,0.189405,0.437273,0.248309,120,1229,1349 +4,2011-01-04,1,0,1,0,2,1,1,0.2,0.212122,0.590435,0.160296,108,1454,1562 +5,2011-01-05,1,0,1,0,3,1,1,0.226957,0.22927,0.436957,0.1869,82,1518,1600 +6,2011-01-06,1,0,1,0,4,1,1,0.204348,0.233209,0.518261,0.0895652,88,1518,1606 +7,2011-01-07,1,0,1,0,5,1,2,0.196522,0.208839,0.498696,0.168726,148,1362,1510 +8,2011-01-08,1,0,1,0,6,0,2,0.165,0.162254,0.535833,0.266804,68,891,959 +9,2011-01-09,1,0,1,0,0,0,1,0.138333,0.116175,0.434167,0.36195,54,768,822 +10,2011-01-10,1,0,1,0,1,1,1,0.150833,0.150888,0.482917,0.223267,41,1280,1321 +11,2011-01-11,1,0,1,0,2,1,2,0.169091,0.191464,0.686364,0.122132,43,1220,1263 +12,2011-01-12,1,0,1,0,3,1,1,0.172727,0.160473,0.599545,0.304627,25,1137,1162 +13,2011-01-13,1,0,1,0,4,1,1,0.165,0.150883,0.470417,0.301,38,1368,1406 +14,2011-01-14,1,0,1,0,5,1,1,0.16087,0.188413,0.537826,0.126548,54,1367,1421 
+15,2011-01-15,1,0,1,0,6,0,2,0.233333,0.248112,0.49875,0.157963,222,1026,1248 +16,2011-01-16,1,0,1,0,0,0,1,0.231667,0.234217,0.48375,0.188433,251,953,1204 +17,2011-01-17,1,0,1,1,1,0,2,0.175833,0.176771,0.5375,0.194017,117,883,1000 +18,2011-01-18,1,0,1,0,2,1,2,0.216667,0.232333,0.861667,0.146775,9,674,683 +19,2011-01-19,1,0,1,0,3,1,2,0.292174,0.298422,0.741739,0.208317,78,1572,1650 +20,2011-01-20,1,0,1,0,4,1,2,0.261667,0.25505,0.538333,0.195904,83,1844,1927 +21,2011-01-21,1,0,1,0,5,1,1,0.1775,0.157833,0.457083,0.353242,75,1468,1543 +22,2011-01-22,1,0,1,0,6,0,1,0.0591304,0.0790696,0.4,0.17197,93,888,981 +23,2011-01-23,1,0,1,0,0,0,1,0.0965217,0.0988391,0.436522,0.2466,150,836,986 +24,2011-01-24,1,0,1,0,1,1,1,0.0973913,0.11793,0.491739,0.15833,86,1330,1416 +25,2011-01-25,1,0,1,0,2,1,2,0.223478,0.234526,0.616957,0.129796,186,1799,1985 +26,2011-01-26,1,0,1,0,3,1,3,0.2175,0.2036,0.8625,0.29385,34,472,506 +27,2011-01-27,1,0,1,0,4,1,1,0.195,0.2197,0.6875,0.113837,15,416,431 +28,2011-01-28,1,0,1,0,5,1,2,0.203478,0.223317,0.793043,0.1233,38,1129,1167 +29,2011-01-29,1,0,1,0,6,0,1,0.196522,0.212126,0.651739,0.145365,123,975,1098 +30,2011-01-30,1,0,1,0,0,0,1,0.216522,0.250322,0.722174,0.0739826,140,956,1096 +31,2011-01-31,1,0,1,0,1,1,2,0.180833,0.18625,0.60375,0.187192,42,1459,1501 +32,2011-02-01,1,0,2,0,2,1,2,0.192174,0.23453,0.829565,0.053213,47,1313,1360 +33,2011-02-02,1,0,2,0,3,1,2,0.26,0.254417,0.775417,0.264308,72,1454,1526 +34,2011-02-03,1,0,2,0,4,1,1,0.186957,0.177878,0.437826,0.277752,61,1489,1550 +35,2011-02-04,1,0,2,0,5,1,2,0.211304,0.228587,0.585217,0.127839,88,1620,1708 +36,2011-02-05,1,0,2,0,6,0,2,0.233333,0.243058,0.929167,0.161079,100,905,1005 +37,2011-02-06,1,0,2,0,0,0,1,0.285833,0.291671,0.568333,0.1418,354,1269,1623 +38,2011-02-07,1,0,2,0,1,1,1,0.271667,0.303658,0.738333,0.0454083,120,1592,1712 +39,2011-02-08,1,0,2,0,2,1,1,0.220833,0.198246,0.537917,0.36195,64,1466,1530 +40,2011-02-09,1,0,2,0,3,1,2,0.134783,0.144283,0.494783,0.188839,53,1552,1605 
+41,2011-02-10,1,0,2,0,4,1,1,0.144348,0.149548,0.437391,0.221935,47,1491,1538 +42,2011-02-11,1,0,2,0,5,1,1,0.189091,0.213509,0.506364,0.10855,149,1597,1746 +43,2011-02-12,1,0,2,0,6,0,1,0.2225,0.232954,0.544167,0.203367,288,1184,1472 +44,2011-02-13,1,0,2,0,0,0,1,0.316522,0.324113,0.457391,0.260883,397,1192,1589 +45,2011-02-14,1,0,2,0,1,1,1,0.415,0.39835,0.375833,0.417908,208,1705,1913 +46,2011-02-15,1,0,2,0,2,1,1,0.266087,0.254274,0.314348,0.291374,140,1675,1815 +47,2011-02-16,1,0,2,0,3,1,1,0.318261,0.3162,0.423478,0.251791,218,1897,2115 +48,2011-02-17,1,0,2,0,4,1,1,0.435833,0.428658,0.505,0.230104,259,2216,2475 +49,2011-02-18,1,0,2,0,5,1,1,0.521667,0.511983,0.516667,0.264925,579,2348,2927 +50,2011-02-19,1,0,2,0,6,0,1,0.399167,0.391404,0.187917,0.507463,532,1103,1635 +51,2011-02-20,1,0,2,0,0,0,1,0.285217,0.27733,0.407826,0.223235,639,1173,1812 +52,2011-02-21,1,0,2,1,1,0,2,0.303333,0.284075,0.605,0.307846,195,912,1107 +53,2011-02-22,1,0,2,0,2,1,1,0.182222,0.186033,0.577778,0.195683,74,1376,1450 +54,2011-02-23,1,0,2,0,3,1,1,0.221739,0.245717,0.423043,0.094113,139,1778,1917 +55,2011-02-24,1,0,2,0,4,1,2,0.295652,0.289191,0.697391,0.250496,100,1707,1807 +56,2011-02-25,1,0,2,0,5,1,2,0.364348,0.350461,0.712174,0.346539,120,1341,1461 +57,2011-02-26,1,0,2,0,6,0,1,0.2825,0.282192,0.537917,0.186571,424,1545,1969 +58,2011-02-27,1,0,2,0,0,0,1,0.343478,0.351109,0.68,0.125248,694,1708,2402 +59,2011-02-28,1,0,2,0,1,1,2,0.407273,0.400118,0.876364,0.289686,81,1365,1446 +60,2011-03-01,1,0,3,0,2,1,1,0.266667,0.263879,0.535,0.216425,137,1714,1851 +61,2011-03-02,1,0,3,0,3,1,1,0.335,0.320071,0.449583,0.307833,231,1903,2134 +62,2011-03-03,1,0,3,0,4,1,1,0.198333,0.200133,0.318333,0.225754,123,1562,1685 +63,2011-03-04,1,0,3,0,5,1,2,0.261667,0.255679,0.610417,0.203346,214,1730,1944 +64,2011-03-05,1,0,3,0,6,0,2,0.384167,0.378779,0.789167,0.251871,640,1437,2077 +65,2011-03-06,1,0,3,0,0,0,2,0.376522,0.366252,0.948261,0.343287,114,491,605 
+66,2011-03-07,1,0,3,0,1,1,1,0.261739,0.238461,0.551304,0.341352,244,1628,1872 +67,2011-03-08,1,0,3,0,2,1,1,0.2925,0.3024,0.420833,0.12065,316,1817,2133 +68,2011-03-09,1,0,3,0,3,1,2,0.295833,0.286608,0.775417,0.22015,191,1700,1891 +69,2011-03-10,1,0,3,0,4,1,3,0.389091,0.385668,0,0.261877,46,577,623 +70,2011-03-11,1,0,3,0,5,1,2,0.316522,0.305,0.649565,0.23297,247,1730,1977 +71,2011-03-12,1,0,3,0,6,0,1,0.329167,0.32575,0.594583,0.220775,724,1408,2132 +72,2011-03-13,1,0,3,0,0,0,1,0.384348,0.380091,0.527391,0.270604,982,1435,2417 +73,2011-03-14,1,0,3,0,1,1,1,0.325217,0.332,0.496957,0.136926,359,1687,2046 +74,2011-03-15,1,0,3,0,2,1,2,0.317391,0.318178,0.655652,0.184309,289,1767,2056 +75,2011-03-16,1,0,3,0,3,1,2,0.365217,0.36693,0.776522,0.203117,321,1871,2192 +76,2011-03-17,1,0,3,0,4,1,1,0.415,0.410333,0.602917,0.209579,424,2320,2744 +77,2011-03-18,1,0,3,0,5,1,1,0.54,0.527009,0.525217,0.231017,884,2355,3239 +78,2011-03-19,1,0,3,0,6,0,1,0.4725,0.466525,0.379167,0.368167,1424,1693,3117 +79,2011-03-20,1,0,3,0,0,0,1,0.3325,0.32575,0.47375,0.207721,1047,1424,2471 +80,2011-03-21,2,0,3,0,1,1,2,0.430435,0.409735,0.737391,0.288783,401,1676,2077 +81,2011-03-22,2,0,3,0,2,1,1,0.441667,0.440642,0.624583,0.22575,460,2243,2703 +82,2011-03-23,2,0,3,0,3,1,2,0.346957,0.337939,0.839565,0.234261,203,1918,2121 +83,2011-03-24,2,0,3,0,4,1,2,0.285,0.270833,0.805833,0.243787,166,1699,1865 +84,2011-03-25,2,0,3,0,5,1,1,0.264167,0.256312,0.495,0.230725,300,1910,2210 +85,2011-03-26,2,0,3,0,6,0,1,0.265833,0.257571,0.394167,0.209571,981,1515,2496 +86,2011-03-27,2,0,3,0,0,0,2,0.253043,0.250339,0.493913,0.1843,472,1221,1693 +87,2011-03-28,2,0,3,0,1,1,1,0.264348,0.257574,0.302174,0.212204,222,1806,2028 +88,2011-03-29,2,0,3,0,2,1,1,0.3025,0.292908,0.314167,0.226996,317,2108,2425 +89,2011-03-30,2,0,3,0,3,1,2,0.3,0.29735,0.646667,0.172888,168,1368,1536 +90,2011-03-31,2,0,3,0,4,1,3,0.268333,0.257575,0.918333,0.217646,179,1506,1685 +91,2011-04-01,2,0,4,0,5,1,2,0.3,0.283454,0.68625,0.258708,307,1920,2227 
+92,2011-04-02,2,0,4,0,6,0,2,0.315,0.315637,0.65375,0.197146,898,1354,2252 +93,2011-04-03,2,0,4,0,0,0,1,0.378333,0.378767,0.48,0.182213,1651,1598,3249 +94,2011-04-04,2,0,4,0,1,1,1,0.573333,0.542929,0.42625,0.385571,734,2381,3115 +95,2011-04-05,2,0,4,0,2,1,2,0.414167,0.39835,0.642083,0.388067,167,1628,1795 +96,2011-04-06,2,0,4,0,3,1,1,0.390833,0.387608,0.470833,0.263063,413,2395,2808 +97,2011-04-07,2,0,4,0,4,1,1,0.4375,0.433696,0.602917,0.162312,571,2570,3141 +98,2011-04-08,2,0,4,0,5,1,2,0.335833,0.324479,0.83625,0.226992,172,1299,1471 +99,2011-04-09,2,0,4,0,6,0,2,0.3425,0.341529,0.8775,0.133083,879,1576,2455 +100,2011-04-10,2,0,4,0,0,0,2,0.426667,0.426737,0.8575,0.146767,1188,1707,2895 +101,2011-04-11,2,0,4,0,1,1,2,0.595652,0.565217,0.716956,0.324474,855,2493,3348 +102,2011-04-12,2,0,4,0,2,1,2,0.5025,0.493054,0.739167,0.274879,257,1777,2034 +103,2011-04-13,2,0,4,0,3,1,2,0.4125,0.417283,0.819167,0.250617,209,1953,2162 +104,2011-04-14,2,0,4,0,4,1,1,0.4675,0.462742,0.540417,0.1107,529,2738,3267 +105,2011-04-15,2,0,4,1,5,0,1,0.446667,0.441913,0.67125,0.226375,642,2484,3126 +106,2011-04-16,2,0,4,0,6,0,3,0.430833,0.425492,0.888333,0.340808,121,674,795 +107,2011-04-17,2,0,4,0,0,0,1,0.456667,0.445696,0.479583,0.303496,1558,2186,3744 +108,2011-04-18,2,0,4,0,1,1,1,0.5125,0.503146,0.5425,0.163567,669,2760,3429 +109,2011-04-19,2,0,4,0,2,1,2,0.505833,0.489258,0.665833,0.157971,409,2795,3204 +110,2011-04-20,2,0,4,0,3,1,1,0.595,0.564392,0.614167,0.241925,613,3331,3944 +111,2011-04-21,2,0,4,0,4,1,1,0.459167,0.453892,0.407083,0.325258,745,3444,4189 +112,2011-04-22,2,0,4,0,5,1,2,0.336667,0.321954,0.729583,0.219521,177,1506,1683 +113,2011-04-23,2,0,4,0,6,0,2,0.46,0.450121,0.887917,0.230725,1462,2574,4036 +114,2011-04-24,2,0,4,0,0,0,2,0.581667,0.551763,0.810833,0.192175,1710,2481,4191 +115,2011-04-25,2,0,4,0,1,1,1,0.606667,0.5745,0.776667,0.185333,773,3300,4073 +116,2011-04-26,2,0,4,0,2,1,1,0.631667,0.594083,0.729167,0.3265,678,3722,4400 
+117,2011-04-27,2,0,4,0,3,1,2,0.62,0.575142,0.835417,0.3122,547,3325,3872 +118,2011-04-28,2,0,4,0,4,1,2,0.6175,0.578929,0.700833,0.320908,569,3489,4058 +119,2011-04-29,2,0,4,0,5,1,1,0.51,0.497463,0.457083,0.240063,878,3717,4595 +120,2011-04-30,2,0,4,0,6,0,1,0.4725,0.464021,0.503333,0.235075,1965,3347,5312 +121,2011-05-01,2,0,5,0,0,0,2,0.451667,0.448204,0.762083,0.106354,1138,2213,3351 +122,2011-05-02,2,0,5,0,1,1,2,0.549167,0.532833,0.73,0.183454,847,3554,4401 +123,2011-05-03,2,0,5,0,2,1,2,0.616667,0.582079,0.697083,0.342667,603,3848,4451 +124,2011-05-04,2,0,5,0,3,1,2,0.414167,0.40465,0.737083,0.328996,255,2378,2633 +125,2011-05-05,2,0,5,0,4,1,1,0.459167,0.441917,0.444167,0.295392,614,3819,4433 +126,2011-05-06,2,0,5,0,5,1,1,0.479167,0.474117,0.59,0.228246,894,3714,4608 +127,2011-05-07,2,0,5,0,6,0,1,0.52,0.512621,0.54125,0.16045,1612,3102,4714 +128,2011-05-08,2,0,5,0,0,0,1,0.528333,0.518933,0.631667,0.0746375,1401,2932,4333 +129,2011-05-09,2,0,5,0,1,1,1,0.5325,0.525246,0.58875,0.176,664,3698,4362 +130,2011-05-10,2,0,5,0,2,1,1,0.5325,0.522721,0.489167,0.115671,694,4109,4803 +131,2011-05-11,2,0,5,0,3,1,1,0.5425,0.5284,0.632917,0.120642,550,3632,4182 +132,2011-05-12,2,0,5,0,4,1,1,0.535,0.523363,0.7475,0.189667,695,4169,4864 +133,2011-05-13,2,0,5,0,5,1,2,0.5125,0.4943,0.863333,0.179725,692,3413,4105 +134,2011-05-14,2,0,5,0,6,0,2,0.520833,0.500629,0.9225,0.13495,902,2507,3409 +135,2011-05-15,2,0,5,0,0,0,2,0.5625,0.536,0.867083,0.152979,1582,2971,4553 +136,2011-05-16,2,0,5,0,1,1,1,0.5775,0.550512,0.787917,0.126871,773,3185,3958 +137,2011-05-17,2,0,5,0,2,1,2,0.561667,0.538529,0.837917,0.277354,678,3445,4123 +138,2011-05-18,2,0,5,0,3,1,2,0.55,0.527158,0.87,0.201492,536,3319,3855 +139,2011-05-19,2,0,5,0,4,1,2,0.530833,0.510742,0.829583,0.108213,735,3840,4575 +140,2011-05-20,2,0,5,0,5,1,1,0.536667,0.529042,0.719583,0.125013,909,4008,4917 +141,2011-05-21,2,0,5,0,6,0,1,0.6025,0.571975,0.626667,0.12065,2258,3547,5805 
+142,2011-05-22,2,0,5,0,0,0,1,0.604167,0.5745,0.749583,0.148008,1576,3084,4660 +143,2011-05-23,2,0,5,0,1,1,2,0.631667,0.590296,0.81,0.233842,836,3438,4274 +144,2011-05-24,2,0,5,0,2,1,2,0.66,0.604813,0.740833,0.207092,659,3833,4492 +145,2011-05-25,2,0,5,0,3,1,1,0.660833,0.615542,0.69625,0.154233,740,4238,4978 +146,2011-05-26,2,0,5,0,4,1,1,0.708333,0.654688,0.6775,0.199642,758,3919,4677 +147,2011-05-27,2,0,5,0,5,1,1,0.681667,0.637008,0.65375,0.240679,871,3808,4679 +148,2011-05-28,2,0,5,0,6,0,1,0.655833,0.612379,0.729583,0.230092,2001,2757,4758 +149,2011-05-29,2,0,5,0,0,0,1,0.6675,0.61555,0.81875,0.213938,2355,2433,4788 +150,2011-05-30,2,0,5,1,1,0,1,0.733333,0.671092,0.685,0.131225,1549,2549,4098 +151,2011-05-31,2,0,5,0,2,1,1,0.775,0.725383,0.636667,0.111329,673,3309,3982 +152,2011-06-01,2,0,6,0,3,1,2,0.764167,0.720967,0.677083,0.207092,513,3461,3974 +153,2011-06-02,2,0,6,0,4,1,1,0.715,0.643942,0.305,0.292287,736,4232,4968 +154,2011-06-03,2,0,6,0,5,1,1,0.62,0.587133,0.354167,0.253121,898,4414,5312 +155,2011-06-04,2,0,6,0,6,0,1,0.635,0.594696,0.45625,0.123142,1869,3473,5342 +156,2011-06-05,2,0,6,0,0,0,2,0.648333,0.616804,0.6525,0.138692,1685,3221,4906 +157,2011-06-06,2,0,6,0,1,1,1,0.678333,0.621858,0.6,0.121896,673,3875,4548 +158,2011-06-07,2,0,6,0,2,1,1,0.7075,0.65595,0.597917,0.187808,763,4070,4833 +159,2011-06-08,2,0,6,0,3,1,1,0.775833,0.727279,0.622083,0.136817,676,3725,4401 +160,2011-06-09,2,0,6,0,4,1,2,0.808333,0.757579,0.568333,0.149883,563,3352,3915 +161,2011-06-10,2,0,6,0,5,1,1,0.755,0.703292,0.605,0.140554,815,3771,4586 +162,2011-06-11,2,0,6,0,6,0,1,0.725,0.678038,0.654583,0.15485,1729,3237,4966 +163,2011-06-12,2,0,6,0,0,0,1,0.6925,0.643325,0.747917,0.163567,1467,2993,4460 +164,2011-06-13,2,0,6,0,1,1,1,0.635,0.601654,0.494583,0.30535,863,4157,5020 +165,2011-06-14,2,0,6,0,2,1,1,0.604167,0.591546,0.507083,0.269283,727,4164,4891 +166,2011-06-15,2,0,6,0,3,1,1,0.626667,0.587754,0.471667,0.167912,769,4411,5180 
+167,2011-06-16,2,0,6,0,4,1,2,0.628333,0.595346,0.688333,0.206471,545,3222,3767 +168,2011-06-17,2,0,6,0,5,1,1,0.649167,0.600383,0.735833,0.143029,863,3981,4844 +169,2011-06-18,2,0,6,0,6,0,1,0.696667,0.643954,0.670417,0.119408,1807,3312,5119 +170,2011-06-19,2,0,6,0,0,0,2,0.699167,0.645846,0.666667,0.102,1639,3105,4744 +171,2011-06-20,2,0,6,0,1,1,2,0.635,0.595346,0.74625,0.155475,699,3311,4010 +172,2011-06-21,3,0,6,0,2,1,2,0.680833,0.637646,0.770417,0.171025,774,4061,4835 +173,2011-06-22,3,0,6,0,3,1,1,0.733333,0.693829,0.7075,0.172262,661,3846,4507 +174,2011-06-23,3,0,6,0,4,1,2,0.728333,0.693833,0.703333,0.238804,746,4044,4790 +175,2011-06-24,3,0,6,0,5,1,1,0.724167,0.656583,0.573333,0.222025,969,4022,4991 +176,2011-06-25,3,0,6,0,6,0,1,0.695,0.643313,0.483333,0.209571,1782,3420,5202 +177,2011-06-26,3,0,6,0,0,0,1,0.68,0.637629,0.513333,0.0945333,1920,3385,5305 +178,2011-06-27,3,0,6,0,1,1,2,0.6825,0.637004,0.658333,0.107588,854,3854,4708 +179,2011-06-28,3,0,6,0,2,1,1,0.744167,0.692558,0.634167,0.144283,732,3916,4648 +180,2011-06-29,3,0,6,0,3,1,1,0.728333,0.654688,0.497917,0.261821,848,4377,5225 +181,2011-06-30,3,0,6,0,4,1,1,0.696667,0.637008,0.434167,0.185312,1027,4488,5515 +182,2011-07-01,3,0,7,0,5,1,1,0.7225,0.652162,0.39625,0.102608,1246,4116,5362 +183,2011-07-02,3,0,7,0,6,0,1,0.738333,0.667308,0.444583,0.115062,2204,2915,5119 +184,2011-07-03,3,0,7,0,0,0,2,0.716667,0.668575,0.6825,0.228858,2282,2367,4649 +185,2011-07-04,3,0,7,1,1,0,2,0.726667,0.665417,0.637917,0.0814792,3065,2978,6043 +186,2011-07-05,3,0,7,0,2,1,1,0.746667,0.696338,0.590417,0.126258,1031,3634,4665 +187,2011-07-06,3,0,7,0,3,1,1,0.72,0.685633,0.743333,0.149883,784,3845,4629 +188,2011-07-07,3,0,7,0,4,1,1,0.75,0.686871,0.65125,0.1592,754,3838,4592 +189,2011-07-08,3,0,7,0,5,1,2,0.709167,0.670483,0.757917,0.225129,692,3348,4040 +190,2011-07-09,3,0,7,0,6,0,1,0.733333,0.664158,0.609167,0.167912,1988,3348,5336 +191,2011-07-10,3,0,7,0,0,0,1,0.7475,0.690025,0.578333,0.183471,1743,3138,4881 
+192,2011-07-11,3,0,7,0,1,1,1,0.7625,0.729804,0.635833,0.282337,723,3363,4086 +193,2011-07-12,3,0,7,0,2,1,1,0.794167,0.739275,0.559167,0.200254,662,3596,4258 +194,2011-07-13,3,0,7,0,3,1,1,0.746667,0.689404,0.631667,0.146133,748,3594,4342 +195,2011-07-14,3,0,7,0,4,1,1,0.680833,0.635104,0.47625,0.240667,888,4196,5084 +196,2011-07-15,3,0,7,0,5,1,1,0.663333,0.624371,0.59125,0.182833,1318,4220,5538 +197,2011-07-16,3,0,7,0,6,0,1,0.686667,0.638263,0.585,0.208342,2418,3505,5923 +198,2011-07-17,3,0,7,0,0,0,1,0.719167,0.669833,0.604167,0.245033,2006,3296,5302 +199,2011-07-18,3,0,7,0,1,1,1,0.746667,0.703925,0.65125,0.215804,841,3617,4458 +200,2011-07-19,3,0,7,0,2,1,1,0.776667,0.747479,0.650417,0.1306,752,3789,4541 +201,2011-07-20,3,0,7,0,3,1,1,0.768333,0.74685,0.707083,0.113817,644,3688,4332 +202,2011-07-21,3,0,7,0,4,1,2,0.815,0.826371,0.69125,0.222021,632,3152,3784 +203,2011-07-22,3,0,7,0,5,1,1,0.848333,0.840896,0.580417,0.1331,562,2825,3387 +204,2011-07-23,3,0,7,0,6,0,1,0.849167,0.804287,0.5,0.131221,987,2298,3285 +205,2011-07-24,3,0,7,0,0,0,1,0.83,0.794829,0.550833,0.169171,1050,2556,3606 +206,2011-07-25,3,0,7,0,1,1,1,0.743333,0.720958,0.757083,0.0908083,568,3272,3840 +207,2011-07-26,3,0,7,0,2,1,1,0.771667,0.696979,0.540833,0.200258,750,3840,4590 +208,2011-07-27,3,0,7,0,3,1,1,0.775,0.690667,0.402917,0.183463,755,3901,4656 +209,2011-07-28,3,0,7,0,4,1,1,0.779167,0.7399,0.583333,0.178479,606,3784,4390 +210,2011-07-29,3,0,7,0,5,1,1,0.838333,0.785967,0.5425,0.174138,670,3176,3846 +211,2011-07-30,3,0,7,0,6,0,1,0.804167,0.728537,0.465833,0.168537,1559,2916,4475 +212,2011-07-31,3,0,7,0,0,0,1,0.805833,0.729796,0.480833,0.164813,1524,2778,4302 +213,2011-08-01,3,0,8,0,1,1,1,0.771667,0.703292,0.550833,0.156717,729,3537,4266 +214,2011-08-02,3,0,8,0,2,1,1,0.783333,0.707071,0.49125,0.20585,801,4044,4845 +215,2011-08-03,3,0,8,0,3,1,2,0.731667,0.679937,0.6575,0.135583,467,3107,3574 +216,2011-08-04,3,0,8,0,4,1,2,0.71,0.664788,0.7575,0.19715,799,3777,4576 
+217,2011-08-05,3,0,8,0,5,1,1,0.710833,0.656567,0.630833,0.184696,1023,3843,4866 +218,2011-08-06,3,0,8,0,6,0,2,0.716667,0.676154,0.755,0.22825,1521,2773,4294 +219,2011-08-07,3,0,8,0,0,0,1,0.7425,0.715292,0.752917,0.201487,1298,2487,3785 +220,2011-08-08,3,0,8,0,1,1,1,0.765,0.703283,0.592083,0.192175,846,3480,4326 +221,2011-08-09,3,0,8,0,2,1,1,0.775,0.724121,0.570417,0.151121,907,3695,4602 +222,2011-08-10,3,0,8,0,3,1,1,0.766667,0.684983,0.424167,0.200258,884,3896,4780 +223,2011-08-11,3,0,8,0,4,1,1,0.7175,0.651521,0.42375,0.164796,812,3980,4792 +224,2011-08-12,3,0,8,0,5,1,1,0.708333,0.654042,0.415,0.125621,1051,3854,4905 +225,2011-08-13,3,0,8,0,6,0,2,0.685833,0.645858,0.729583,0.211454,1504,2646,4150 +226,2011-08-14,3,0,8,0,0,0,2,0.676667,0.624388,0.8175,0.222633,1338,2482,3820 +227,2011-08-15,3,0,8,0,1,1,1,0.665833,0.616167,0.712083,0.208954,775,3563,4338 +228,2011-08-16,3,0,8,0,2,1,1,0.700833,0.645837,0.578333,0.236329,721,4004,4725 +229,2011-08-17,3,0,8,0,3,1,1,0.723333,0.666671,0.575417,0.143667,668,4026,4694 +230,2011-08-18,3,0,8,0,4,1,1,0.711667,0.662258,0.654583,0.233208,639,3166,3805 +231,2011-08-19,3,0,8,0,5,1,2,0.685,0.633221,0.722917,0.139308,797,3356,4153 +232,2011-08-20,3,0,8,0,6,0,1,0.6975,0.648996,0.674167,0.104467,1914,3277,5191 +233,2011-08-21,3,0,8,0,0,0,1,0.710833,0.675525,0.77,0.248754,1249,2624,3873 +234,2011-08-22,3,0,8,0,1,1,1,0.691667,0.638254,0.47,0.27675,833,3925,4758 +235,2011-08-23,3,0,8,0,2,1,1,0.640833,0.606067,0.455417,0.146763,1281,4614,5895 +236,2011-08-24,3,0,8,0,3,1,1,0.673333,0.630692,0.605,0.253108,949,4181,5130 +237,2011-08-25,3,0,8,0,4,1,2,0.684167,0.645854,0.771667,0.210833,435,3107,3542 +238,2011-08-26,3,0,8,0,5,1,1,0.7,0.659733,0.76125,0.0839625,768,3893,4661 +239,2011-08-27,3,0,8,0,6,0,2,0.68,0.635556,0.85,0.375617,226,889,1115 +240,2011-08-28,3,0,8,0,0,0,1,0.707059,0.647959,0.561765,0.304659,1415,2919,4334 +241,2011-08-29,3,0,8,0,1,1,1,0.636667,0.607958,0.554583,0.159825,729,3905,4634 
+242,2011-08-30,3,0,8,0,2,1,1,0.639167,0.594704,0.548333,0.125008,775,4429,5204 +243,2011-08-31,3,0,8,0,3,1,1,0.656667,0.611121,0.597917,0.0833333,688,4370,5058 +244,2011-09-01,3,0,9,0,4,1,1,0.655,0.614921,0.639167,0.141796,783,4332,5115 +245,2011-09-02,3,0,9,0,5,1,2,0.643333,0.604808,0.727083,0.139929,875,3852,4727 +246,2011-09-03,3,0,9,0,6,0,1,0.669167,0.633213,0.716667,0.185325,1935,2549,4484 +247,2011-09-04,3,0,9,0,0,0,1,0.709167,0.665429,0.742083,0.206467,2521,2419,4940 +248,2011-09-05,3,0,9,1,1,0,2,0.673333,0.625646,0.790417,0.212696,1236,2115,3351 +249,2011-09-06,3,0,9,0,2,1,3,0.54,0.5152,0.886957,0.343943,204,2506,2710 +250,2011-09-07,3,0,9,0,3,1,3,0.599167,0.544229,0.917083,0.0970208,118,1878,1996 +251,2011-09-08,3,0,9,0,4,1,3,0.633913,0.555361,0.939565,0.192748,153,1689,1842 +252,2011-09-09,3,0,9,0,5,1,2,0.65,0.578946,0.897917,0.124379,417,3127,3544 +253,2011-09-10,3,0,9,0,6,0,1,0.66,0.607962,0.75375,0.153608,1750,3595,5345 +254,2011-09-11,3,0,9,0,0,0,1,0.653333,0.609229,0.71375,0.115054,1633,3413,5046 +255,2011-09-12,3,0,9,0,1,1,1,0.644348,0.60213,0.692174,0.088913,690,4023,4713 +256,2011-09-13,3,0,9,0,2,1,1,0.650833,0.603554,0.7125,0.141804,701,4062,4763 +257,2011-09-14,3,0,9,0,3,1,1,0.673333,0.6269,0.697083,0.1673,647,4138,4785 +258,2011-09-15,3,0,9,0,4,1,2,0.5775,0.553671,0.709167,0.271146,428,3231,3659 +259,2011-09-16,3,0,9,0,5,1,2,0.469167,0.461475,0.590417,0.164183,742,4018,4760 +260,2011-09-17,3,0,9,0,6,0,2,0.491667,0.478512,0.718333,0.189675,1434,3077,4511 +261,2011-09-18,3,0,9,0,0,0,1,0.5075,0.490537,0.695,0.178483,1353,2921,4274 +262,2011-09-19,3,0,9,0,1,1,2,0.549167,0.529675,0.69,0.151742,691,3848,4539 +263,2011-09-20,3,0,9,0,2,1,2,0.561667,0.532217,0.88125,0.134954,438,3203,3641 +264,2011-09-21,3,0,9,0,3,1,2,0.595,0.550533,0.9,0.0964042,539,3813,4352 +265,2011-09-22,3,0,9,0,4,1,2,0.628333,0.554963,0.902083,0.128125,555,4240,4795 +266,2011-09-23,4,0,9,0,5,1,2,0.609167,0.522125,0.9725,0.0783667,258,2137,2395 
+267,2011-09-24,4,0,9,0,6,0,2,0.606667,0.564412,0.8625,0.0783833,1776,3647,5423 +268,2011-09-25,4,0,9,0,0,0,2,0.634167,0.572637,0.845,0.0503792,1544,3466,5010 +269,2011-09-26,4,0,9,0,1,1,2,0.649167,0.589042,0.848333,0.1107,684,3946,4630 +270,2011-09-27,4,0,9,0,2,1,2,0.636667,0.574525,0.885417,0.118171,477,3643,4120 +271,2011-09-28,4,0,9,0,3,1,2,0.635,0.575158,0.84875,0.148629,480,3427,3907 +272,2011-09-29,4,0,9,0,4,1,1,0.616667,0.574512,0.699167,0.172883,653,4186,4839 +273,2011-09-30,4,0,9,0,5,1,1,0.564167,0.544829,0.6475,0.206475,830,4372,5202 +274,2011-10-01,4,0,10,0,6,0,2,0.41,0.412863,0.75375,0.292296,480,1949,2429 +275,2011-10-02,4,0,10,0,0,0,2,0.356667,0.345317,0.791667,0.222013,616,2302,2918 +276,2011-10-03,4,0,10,0,1,1,2,0.384167,0.392046,0.760833,0.0833458,330,3240,3570 +277,2011-10-04,4,0,10,0,2,1,1,0.484167,0.472858,0.71,0.205854,486,3970,4456 +278,2011-10-05,4,0,10,0,3,1,1,0.538333,0.527138,0.647917,0.17725,559,4267,4826 +279,2011-10-06,4,0,10,0,4,1,1,0.494167,0.480425,0.620833,0.134954,639,4126,4765 +280,2011-10-07,4,0,10,0,5,1,1,0.510833,0.504404,0.684167,0.0223917,949,4036,4985 +281,2011-10-08,4,0,10,0,6,0,1,0.521667,0.513242,0.70125,0.0454042,2235,3174,5409 +282,2011-10-09,4,0,10,0,0,0,1,0.540833,0.523983,0.7275,0.06345,2397,3114,5511 +283,2011-10-10,4,0,10,1,1,0,1,0.570833,0.542925,0.73375,0.0423042,1514,3603,5117 +284,2011-10-11,4,0,10,0,2,1,2,0.566667,0.546096,0.80875,0.143042,667,3896,4563 +285,2011-10-12,4,0,10,0,3,1,3,0.543333,0.517717,0.90625,0.24815,217,2199,2416 +286,2011-10-13,4,0,10,0,4,1,2,0.589167,0.551804,0.896667,0.141787,290,2623,2913 +287,2011-10-14,4,0,10,0,5,1,2,0.550833,0.529675,0.71625,0.223883,529,3115,3644 +288,2011-10-15,4,0,10,0,6,0,1,0.506667,0.498725,0.483333,0.258083,1899,3318,5217 +289,2011-10-16,4,0,10,0,0,0,1,0.511667,0.503154,0.486667,0.281717,1748,3293,5041 +290,2011-10-17,4,0,10,0,1,1,1,0.534167,0.510725,0.579583,0.175379,713,3857,4570 +291,2011-10-18,4,0,10,0,2,1,2,0.5325,0.522721,0.701667,0.110087,637,4111,4748 
+292,2011-10-19,4,0,10,0,3,1,3,0.541739,0.513848,0.895217,0.243339,254,2170,2424 +293,2011-10-20,4,0,10,0,4,1,1,0.475833,0.466525,0.63625,0.422275,471,3724,4195 +294,2011-10-21,4,0,10,0,5,1,1,0.4275,0.423596,0.574167,0.221396,676,3628,4304 +295,2011-10-22,4,0,10,0,6,0,1,0.4225,0.425492,0.629167,0.0926667,1499,2809,4308 +296,2011-10-23,4,0,10,0,0,0,1,0.421667,0.422333,0.74125,0.0995125,1619,2762,4381 +297,2011-10-24,4,0,10,0,1,1,1,0.463333,0.457067,0.772083,0.118792,699,3488,4187 +298,2011-10-25,4,0,10,0,2,1,1,0.471667,0.463375,0.622917,0.166658,695,3992,4687 +299,2011-10-26,4,0,10,0,3,1,2,0.484167,0.472846,0.720417,0.148642,404,3490,3894 +300,2011-10-27,4,0,10,0,4,1,2,0.47,0.457046,0.812917,0.197763,240,2419,2659 +301,2011-10-28,4,0,10,0,5,1,2,0.330833,0.318812,0.585833,0.229479,456,3291,3747 +302,2011-10-29,4,0,10,0,6,0,3,0.254167,0.227913,0.8825,0.351371,57,570,627 +303,2011-10-30,4,0,10,0,0,0,1,0.319167,0.321329,0.62375,0.176617,885,2446,3331 +304,2011-10-31,4,0,10,0,1,1,1,0.34,0.356063,0.703333,0.10635,362,3307,3669 +305,2011-11-01,4,0,11,0,2,1,1,0.400833,0.397088,0.68375,0.135571,410,3658,4068 +306,2011-11-02,4,0,11,0,3,1,1,0.3775,0.390133,0.71875,0.0820917,370,3816,4186 +307,2011-11-03,4,0,11,0,4,1,1,0.408333,0.405921,0.702083,0.136817,318,3656,3974 +308,2011-11-04,4,0,11,0,5,1,2,0.403333,0.403392,0.6225,0.271779,470,3576,4046 +309,2011-11-05,4,0,11,0,6,0,1,0.326667,0.323854,0.519167,0.189062,1156,2770,3926 +310,2011-11-06,4,0,11,0,0,0,1,0.348333,0.362358,0.734583,0.0920542,952,2697,3649 +311,2011-11-07,4,0,11,0,1,1,1,0.395,0.400871,0.75875,0.057225,373,3662,4035 +312,2011-11-08,4,0,11,0,2,1,1,0.408333,0.412246,0.721667,0.0690375,376,3829,4205 +313,2011-11-09,4,0,11,0,3,1,1,0.4,0.409079,0.758333,0.0621958,305,3804,4109 +314,2011-11-10,4,0,11,0,4,1,2,0.38,0.373721,0.813333,0.189067,190,2743,2933 +315,2011-11-11,4,0,11,1,5,0,1,0.324167,0.306817,0.44625,0.314675,440,2928,3368 +316,2011-11-12,4,0,11,0,6,0,1,0.356667,0.357942,0.552917,0.212062,1275,2792,4067 
+317,2011-11-13,4,0,11,0,0,0,1,0.440833,0.43055,0.458333,0.281721,1004,2713,3717 +318,2011-11-14,4,0,11,0,1,1,1,0.53,0.524612,0.587083,0.306596,595,3891,4486 +319,2011-11-15,4,0,11,0,2,1,2,0.53,0.507579,0.68875,0.199633,449,3746,4195 +320,2011-11-16,4,0,11,0,3,1,3,0.456667,0.451988,0.93,0.136829,145,1672,1817 +321,2011-11-17,4,0,11,0,4,1,2,0.341667,0.323221,0.575833,0.305362,139,2914,3053 +322,2011-11-18,4,0,11,0,5,1,1,0.274167,0.272721,0.41,0.168533,245,3147,3392 +323,2011-11-19,4,0,11,0,6,0,1,0.329167,0.324483,0.502083,0.224496,943,2720,3663 +324,2011-11-20,4,0,11,0,0,0,2,0.463333,0.457058,0.684583,0.18595,787,2733,3520 +325,2011-11-21,4,0,11,0,1,1,3,0.4475,0.445062,0.91,0.138054,220,2545,2765 +326,2011-11-22,4,0,11,0,2,1,3,0.416667,0.421696,0.9625,0.118792,69,1538,1607 +327,2011-11-23,4,0,11,0,3,1,2,0.440833,0.430537,0.757917,0.335825,112,2454,2566 +328,2011-11-24,4,0,11,1,4,0,1,0.373333,0.372471,0.549167,0.167304,560,935,1495 +329,2011-11-25,4,0,11,0,5,1,1,0.375,0.380671,0.64375,0.0988958,1095,1697,2792 +330,2011-11-26,4,0,11,0,6,0,1,0.375833,0.385087,0.681667,0.0684208,1249,1819,3068 +331,2011-11-27,4,0,11,0,0,0,1,0.459167,0.4558,0.698333,0.208954,810,2261,3071 +332,2011-11-28,4,0,11,0,1,1,1,0.503478,0.490122,0.743043,0.142122,253,3614,3867 +333,2011-11-29,4,0,11,0,2,1,2,0.458333,0.451375,0.830833,0.258092,96,2818,2914 +334,2011-11-30,4,0,11,0,3,1,1,0.325,0.311221,0.613333,0.271158,188,3425,3613 +335,2011-12-01,4,0,12,0,4,1,1,0.3125,0.305554,0.524583,0.220158,182,3545,3727 +336,2011-12-02,4,0,12,0,5,1,1,0.314167,0.331433,0.625833,0.100754,268,3672,3940 +337,2011-12-03,4,0,12,0,6,0,1,0.299167,0.310604,0.612917,0.0957833,706,2908,3614 +338,2011-12-04,4,0,12,0,0,0,1,0.330833,0.3491,0.775833,0.0839583,634,2851,3485 +339,2011-12-05,4,0,12,0,1,1,2,0.385833,0.393925,0.827083,0.0622083,233,3578,3811 +340,2011-12-06,4,0,12,0,2,1,3,0.4625,0.4564,0.949583,0.232583,126,2468,2594 +341,2011-12-07,4,0,12,0,3,1,3,0.41,0.400246,0.970417,0.266175,50,655,705 
+342,2011-12-08,4,0,12,0,4,1,1,0.265833,0.256938,0.58,0.240058,150,3172,3322 +343,2011-12-09,4,0,12,0,5,1,1,0.290833,0.317542,0.695833,0.0827167,261,3359,3620 +344,2011-12-10,4,0,12,0,6,0,1,0.275,0.266412,0.5075,0.233221,502,2688,3190 +345,2011-12-11,4,0,12,0,0,0,1,0.220833,0.253154,0.49,0.0665417,377,2366,2743 +346,2011-12-12,4,0,12,0,1,1,1,0.238333,0.270196,0.670833,0.06345,143,3167,3310 +347,2011-12-13,4,0,12,0,2,1,1,0.2825,0.301138,0.59,0.14055,155,3368,3523 +348,2011-12-14,4,0,12,0,3,1,2,0.3175,0.338362,0.66375,0.0609583,178,3562,3740 +349,2011-12-15,4,0,12,0,4,1,2,0.4225,0.412237,0.634167,0.268042,181,3528,3709 +350,2011-12-16,4,0,12,0,5,1,2,0.375,0.359825,0.500417,0.260575,178,3399,3577 +351,2011-12-17,4,0,12,0,6,0,2,0.258333,0.249371,0.560833,0.243167,275,2464,2739 +352,2011-12-18,4,0,12,0,0,0,1,0.238333,0.245579,0.58625,0.169779,220,2211,2431 +353,2011-12-19,4,0,12,0,1,1,1,0.276667,0.280933,0.6375,0.172896,260,3143,3403 +354,2011-12-20,4,0,12,0,2,1,2,0.385833,0.396454,0.595417,0.0615708,216,3534,3750 +355,2011-12-21,1,0,12,0,3,1,2,0.428333,0.428017,0.858333,0.2214,107,2553,2660 +356,2011-12-22,1,0,12,0,4,1,2,0.423333,0.426121,0.7575,0.047275,227,2841,3068 +357,2011-12-23,1,0,12,0,5,1,1,0.373333,0.377513,0.68625,0.274246,163,2046,2209 +358,2011-12-24,1,0,12,0,6,0,1,0.3025,0.299242,0.5425,0.190304,155,856,1011 +359,2011-12-25,1,0,12,0,0,0,1,0.274783,0.279961,0.681304,0.155091,303,451,754 +360,2011-12-26,1,0,12,1,1,0,1,0.321739,0.315535,0.506957,0.239465,430,887,1317 +361,2011-12-27,1,0,12,0,2,1,2,0.325,0.327633,0.7625,0.18845,103,1059,1162 +362,2011-12-28,1,0,12,0,3,1,1,0.29913,0.279974,0.503913,0.293961,255,2047,2302 +363,2011-12-29,1,0,12,0,4,1,1,0.248333,0.263892,0.574167,0.119412,254,2169,2423 +364,2011-12-30,1,0,12,0,5,1,1,0.311667,0.318812,0.636667,0.134337,491,2508,2999 +365,2011-12-31,1,0,12,0,6,0,1,0.41,0.414121,0.615833,0.220154,665,1820,2485 +366,2012-01-01,1,1,1,0,0,0,1,0.37,0.375621,0.6925,0.192167,686,1608,2294 
+367,2012-01-02,1,1,1,1,1,0,1,0.273043,0.252304,0.381304,0.329665,244,1707,1951 +368,2012-01-03,1,1,1,0,2,1,1,0.15,0.126275,0.44125,0.365671,89,2147,2236 +369,2012-01-04,1,1,1,0,3,1,2,0.1075,0.119337,0.414583,0.1847,95,2273,2368 +370,2012-01-05,1,1,1,0,4,1,1,0.265833,0.278412,0.524167,0.129987,140,3132,3272 +371,2012-01-06,1,1,1,0,5,1,1,0.334167,0.340267,0.542083,0.167908,307,3791,4098 +372,2012-01-07,1,1,1,0,6,0,1,0.393333,0.390779,0.531667,0.174758,1070,3451,4521 +373,2012-01-08,1,1,1,0,0,0,1,0.3375,0.340258,0.465,0.191542,599,2826,3425 +374,2012-01-09,1,1,1,0,1,1,2,0.224167,0.247479,0.701667,0.0989,106,2270,2376 +375,2012-01-10,1,1,1,0,2,1,1,0.308696,0.318826,0.646522,0.187552,173,3425,3598 +376,2012-01-11,1,1,1,0,3,1,2,0.274167,0.282821,0.8475,0.131221,92,2085,2177 +377,2012-01-12,1,1,1,0,4,1,2,0.3825,0.381938,0.802917,0.180967,269,3828,4097 +378,2012-01-13,1,1,1,0,5,1,1,0.274167,0.249362,0.5075,0.378108,174,3040,3214 +379,2012-01-14,1,1,1,0,6,0,1,0.18,0.183087,0.4575,0.187183,333,2160,2493 +380,2012-01-15,1,1,1,0,0,0,1,0.166667,0.161625,0.419167,0.251258,284,2027,2311 +381,2012-01-16,1,1,1,1,1,0,1,0.19,0.190663,0.5225,0.231358,217,2081,2298 +382,2012-01-17,1,1,1,0,2,1,2,0.373043,0.364278,0.716087,0.34913,127,2808,2935 +383,2012-01-18,1,1,1,0,3,1,1,0.303333,0.275254,0.443333,0.415429,109,3267,3376 +384,2012-01-19,1,1,1,0,4,1,1,0.19,0.190038,0.4975,0.220158,130,3162,3292 +385,2012-01-20,1,1,1,0,5,1,2,0.2175,0.220958,0.45,0.20275,115,3048,3163 +386,2012-01-21,1,1,1,0,6,0,2,0.173333,0.174875,0.83125,0.222642,67,1234,1301 +387,2012-01-22,1,1,1,0,0,0,2,0.1625,0.16225,0.79625,0.199638,196,1781,1977 +388,2012-01-23,1,1,1,0,1,1,2,0.218333,0.243058,0.91125,0.110708,145,2287,2432 +389,2012-01-24,1,1,1,0,2,1,1,0.3425,0.349108,0.835833,0.123767,439,3900,4339 +390,2012-01-25,1,1,1,0,3,1,1,0.294167,0.294821,0.64375,0.161071,467,3803,4270 +391,2012-01-26,1,1,1,0,4,1,2,0.341667,0.35605,0.769583,0.0733958,244,3831,4075 
+392,2012-01-27,1,1,1,0,5,1,2,0.425,0.415383,0.74125,0.342667,269,3187,3456 +393,2012-01-28,1,1,1,0,6,0,1,0.315833,0.326379,0.543333,0.210829,775,3248,4023 +394,2012-01-29,1,1,1,0,0,0,1,0.2825,0.272721,0.31125,0.24005,558,2685,3243 +395,2012-01-30,1,1,1,0,1,1,1,0.269167,0.262625,0.400833,0.215792,126,3498,3624 +396,2012-01-31,1,1,1,0,2,1,1,0.39,0.381317,0.416667,0.261817,324,4185,4509 +397,2012-02-01,1,1,2,0,3,1,1,0.469167,0.466538,0.507917,0.189067,304,4275,4579 +398,2012-02-02,1,1,2,0,4,1,2,0.399167,0.398971,0.672917,0.187187,190,3571,3761 +399,2012-02-03,1,1,2,0,5,1,1,0.313333,0.309346,0.526667,0.178496,310,3841,4151 +400,2012-02-04,1,1,2,0,6,0,2,0.264167,0.272725,0.779583,0.121896,384,2448,2832 +401,2012-02-05,1,1,2,0,0,0,2,0.265833,0.264521,0.687917,0.175996,318,2629,2947 +402,2012-02-06,1,1,2,0,1,1,1,0.282609,0.296426,0.622174,0.1538,206,3578,3784 +403,2012-02-07,1,1,2,0,2,1,1,0.354167,0.361104,0.49625,0.147379,199,4176,4375 +404,2012-02-08,1,1,2,0,3,1,2,0.256667,0.266421,0.722917,0.133721,109,2693,2802 +405,2012-02-09,1,1,2,0,4,1,1,0.265,0.261988,0.562083,0.194037,163,3667,3830 +406,2012-02-10,1,1,2,0,5,1,2,0.280833,0.293558,0.54,0.116929,227,3604,3831 +407,2012-02-11,1,1,2,0,6,0,3,0.224167,0.210867,0.73125,0.289796,192,1977,2169 +408,2012-02-12,1,1,2,0,0,0,1,0.1275,0.101658,0.464583,0.409212,73,1456,1529 +409,2012-02-13,1,1,2,0,1,1,1,0.2225,0.227913,0.41125,0.167283,94,3328,3422 +410,2012-02-14,1,1,2,0,2,1,2,0.319167,0.333946,0.50875,0.141179,135,3787,3922 +411,2012-02-15,1,1,2,0,3,1,1,0.348333,0.351629,0.53125,0.1816,141,4028,4169 +412,2012-02-16,1,1,2,0,4,1,2,0.316667,0.330162,0.752917,0.091425,74,2931,3005 +413,2012-02-17,1,1,2,0,5,1,1,0.343333,0.351629,0.634583,0.205846,349,3805,4154 +414,2012-02-18,1,1,2,0,6,0,1,0.346667,0.355425,0.534583,0.190929,1435,2883,4318 +415,2012-02-19,1,1,2,0,0,0,2,0.28,0.265788,0.515833,0.253112,618,2071,2689 +416,2012-02-20,1,1,2,1,1,0,1,0.28,0.273391,0.507826,0.229083,502,2627,3129 
+417,2012-02-21,1,1,2,0,2,1,1,0.287826,0.295113,0.594348,0.205717,163,3614,3777 +418,2012-02-22,1,1,2,0,3,1,1,0.395833,0.392667,0.567917,0.234471,394,4379,4773 +419,2012-02-23,1,1,2,0,4,1,1,0.454167,0.444446,0.554583,0.190913,516,4546,5062 +420,2012-02-24,1,1,2,0,5,1,2,0.4075,0.410971,0.7375,0.237567,246,3241,3487 +421,2012-02-25,1,1,2,0,6,0,1,0.290833,0.255675,0.395833,0.421642,317,2415,2732 +422,2012-02-26,1,1,2,0,0,0,1,0.279167,0.268308,0.41,0.205229,515,2874,3389 +423,2012-02-27,1,1,2,0,1,1,1,0.366667,0.357954,0.490833,0.268033,253,4069,4322 +424,2012-02-28,1,1,2,0,2,1,1,0.359167,0.353525,0.395833,0.193417,229,4134,4363 +425,2012-02-29,1,1,2,0,3,1,2,0.344348,0.34847,0.804783,0.179117,65,1769,1834 +426,2012-03-01,1,1,3,0,4,1,1,0.485833,0.475371,0.615417,0.226987,325,4665,4990 +427,2012-03-02,1,1,3,0,5,1,2,0.353333,0.359842,0.657083,0.144904,246,2948,3194 +428,2012-03-03,1,1,3,0,6,0,2,0.414167,0.413492,0.62125,0.161079,956,3110,4066 +429,2012-03-04,1,1,3,0,0,0,1,0.325833,0.303021,0.403333,0.334571,710,2713,3423 +430,2012-03-05,1,1,3,0,1,1,1,0.243333,0.241171,0.50625,0.228858,203,3130,3333 +431,2012-03-06,1,1,3,0,2,1,1,0.258333,0.255042,0.456667,0.200875,221,3735,3956 +432,2012-03-07,1,1,3,0,3,1,1,0.404167,0.3851,0.513333,0.345779,432,4484,4916 +433,2012-03-08,1,1,3,0,4,1,1,0.5275,0.524604,0.5675,0.441563,486,4896,5382 +434,2012-03-09,1,1,3,0,5,1,2,0.410833,0.397083,0.407083,0.4148,447,4122,4569 +435,2012-03-10,1,1,3,0,6,0,1,0.2875,0.277767,0.350417,0.22575,968,3150,4118 +436,2012-03-11,1,1,3,0,0,0,1,0.361739,0.35967,0.476957,0.222587,1658,3253,4911 +437,2012-03-12,1,1,3,0,1,1,1,0.466667,0.459592,0.489167,0.207713,838,4460,5298 +438,2012-03-13,1,1,3,0,2,1,1,0.565,0.542929,0.6175,0.23695,762,5085,5847 +439,2012-03-14,1,1,3,0,3,1,1,0.5725,0.548617,0.507083,0.115062,997,5315,6312 +440,2012-03-15,1,1,3,0,4,1,1,0.5575,0.532825,0.579583,0.149883,1005,5187,6192 +441,2012-03-16,1,1,3,0,5,1,2,0.435833,0.436229,0.842083,0.113192,548,3830,4378 
+442,2012-03-17,1,1,3,0,6,0,2,0.514167,0.505046,0.755833,0.110704,3155,4681,7836 +443,2012-03-18,1,1,3,0,0,0,2,0.4725,0.464,0.81,0.126883,2207,3685,5892 +444,2012-03-19,1,1,3,0,1,1,1,0.545,0.532821,0.72875,0.162317,982,5171,6153 +445,2012-03-20,1,1,3,0,2,1,1,0.560833,0.538533,0.807917,0.121271,1051,5042,6093 +446,2012-03-21,2,1,3,0,3,1,2,0.531667,0.513258,0.82125,0.0895583,1122,5108,6230 +447,2012-03-22,2,1,3,0,4,1,1,0.554167,0.531567,0.83125,0.117562,1334,5537,6871 +448,2012-03-23,2,1,3,0,5,1,2,0.601667,0.570067,0.694167,0.1163,2469,5893,8362 +449,2012-03-24,2,1,3,0,6,0,2,0.5025,0.486733,0.885417,0.192783,1033,2339,3372 +450,2012-03-25,2,1,3,0,0,0,2,0.4375,0.437488,0.880833,0.220775,1532,3464,4996 +451,2012-03-26,2,1,3,0,1,1,1,0.445833,0.43875,0.477917,0.386821,795,4763,5558 +452,2012-03-27,2,1,3,0,2,1,1,0.323333,0.315654,0.29,0.187192,531,4571,5102 +453,2012-03-28,2,1,3,0,3,1,1,0.484167,0.47095,0.48125,0.291671,674,5024,5698 +454,2012-03-29,2,1,3,0,4,1,1,0.494167,0.482304,0.439167,0.31965,834,5299,6133 +455,2012-03-30,2,1,3,0,5,1,2,0.37,0.375621,0.580833,0.138067,796,4663,5459 +456,2012-03-31,2,1,3,0,6,0,2,0.424167,0.421708,0.738333,0.250617,2301,3934,6235 +457,2012-04-01,2,1,4,0,0,0,2,0.425833,0.417287,0.67625,0.172267,2347,3694,6041 +458,2012-04-02,2,1,4,0,1,1,1,0.433913,0.427513,0.504348,0.312139,1208,4728,5936 +459,2012-04-03,2,1,4,0,2,1,1,0.466667,0.461483,0.396667,0.100133,1348,5424,6772 +460,2012-04-04,2,1,4,0,3,1,1,0.541667,0.53345,0.469583,0.180975,1058,5378,6436 +461,2012-04-05,2,1,4,0,4,1,1,0.435,0.431163,0.374167,0.219529,1192,5265,6457 +462,2012-04-06,2,1,4,0,5,1,1,0.403333,0.390767,0.377083,0.300388,1807,4653,6460 +463,2012-04-07,2,1,4,0,6,0,1,0.4375,0.426129,0.254167,0.274871,3252,3605,6857 +464,2012-04-08,2,1,4,0,0,0,1,0.5,0.492425,0.275833,0.232596,2230,2939,5169 +465,2012-04-09,2,1,4,0,1,1,1,0.489167,0.476638,0.3175,0.358196,905,4680,5585 +466,2012-04-10,2,1,4,0,2,1,1,0.446667,0.436233,0.435,0.249375,819,5099,5918 
+467,2012-04-11,2,1,4,0,3,1,1,0.348696,0.337274,0.469565,0.295274,482,4380,4862 +468,2012-04-12,2,1,4,0,4,1,1,0.3975,0.387604,0.46625,0.290429,663,4746,5409 +469,2012-04-13,2,1,4,0,5,1,1,0.4425,0.431808,0.408333,0.155471,1252,5146,6398 +470,2012-04-14,2,1,4,0,6,0,1,0.495,0.487996,0.502917,0.190917,2795,4665,7460 +471,2012-04-15,2,1,4,0,0,0,1,0.606667,0.573875,0.507917,0.225129,2846,4286,7132 +472,2012-04-16,2,1,4,1,1,0,1,0.664167,0.614925,0.561667,0.284829,1198,5172,6370 +473,2012-04-17,2,1,4,0,2,1,1,0.608333,0.598487,0.390417,0.273629,989,5702,6691 +474,2012-04-18,2,1,4,0,3,1,2,0.463333,0.457038,0.569167,0.167912,347,4020,4367 +475,2012-04-19,2,1,4,0,4,1,1,0.498333,0.493046,0.6125,0.0659292,846,5719,6565 +476,2012-04-20,2,1,4,0,5,1,1,0.526667,0.515775,0.694583,0.149871,1340,5950,7290 +477,2012-04-21,2,1,4,0,6,0,1,0.57,0.542921,0.682917,0.283587,2541,4083,6624 +478,2012-04-22,2,1,4,0,0,0,3,0.396667,0.389504,0.835417,0.344546,120,907,1027 +479,2012-04-23,2,1,4,0,1,1,2,0.321667,0.301125,0.766667,0.303496,195,3019,3214 +480,2012-04-24,2,1,4,0,2,1,1,0.413333,0.405283,0.454167,0.249383,518,5115,5633 +481,2012-04-25,2,1,4,0,3,1,1,0.476667,0.470317,0.427917,0.118792,655,5541,6196 +482,2012-04-26,2,1,4,0,4,1,2,0.498333,0.483583,0.756667,0.176625,475,4551,5026 +483,2012-04-27,2,1,4,0,5,1,1,0.4575,0.452637,0.400833,0.347633,1014,5219,6233 +484,2012-04-28,2,1,4,0,6,0,2,0.376667,0.377504,0.489583,0.129975,1120,3100,4220 +485,2012-04-29,2,1,4,0,0,0,1,0.458333,0.450121,0.587083,0.116908,2229,4075,6304 +486,2012-04-30,2,1,4,0,1,1,2,0.464167,0.457696,0.57,0.171638,665,4907,5572 +487,2012-05-01,2,1,5,0,2,1,2,0.613333,0.577021,0.659583,0.156096,653,5087,5740 +488,2012-05-02,2,1,5,0,3,1,1,0.564167,0.537896,0.797083,0.138058,667,5502,6169 +489,2012-05-03,2,1,5,0,4,1,2,0.56,0.537242,0.768333,0.133696,764,5657,6421 +490,2012-05-04,2,1,5,0,5,1,1,0.6275,0.590917,0.735417,0.162938,1069,5227,6296 +491,2012-05-05,2,1,5,0,6,0,2,0.621667,0.584608,0.756667,0.152992,2496,4387,6883 
+492,2012-05-06,2,1,5,0,0,0,2,0.5625,0.546737,0.74,0.149879,2135,4224,6359 +493,2012-05-07,2,1,5,0,1,1,2,0.5375,0.527142,0.664167,0.230721,1008,5265,6273 +494,2012-05-08,2,1,5,0,2,1,2,0.581667,0.557471,0.685833,0.296029,738,4990,5728 +495,2012-05-09,2,1,5,0,3,1,2,0.575,0.553025,0.744167,0.216412,620,4097,4717 +496,2012-05-10,2,1,5,0,4,1,1,0.505833,0.491783,0.552083,0.314063,1026,5546,6572 +497,2012-05-11,2,1,5,0,5,1,1,0.533333,0.520833,0.360417,0.236937,1319,5711,7030 +498,2012-05-12,2,1,5,0,6,0,1,0.564167,0.544817,0.480417,0.123133,2622,4807,7429 +499,2012-05-13,2,1,5,0,0,0,1,0.6125,0.585238,0.57625,0.225117,2172,3946,6118 +500,2012-05-14,2,1,5,0,1,1,2,0.573333,0.5499,0.789583,0.212692,342,2501,2843 +501,2012-05-15,2,1,5,0,2,1,2,0.611667,0.576404,0.794583,0.147392,625,4490,5115 +502,2012-05-16,2,1,5,0,3,1,1,0.636667,0.595975,0.697917,0.122512,991,6433,7424 +503,2012-05-17,2,1,5,0,4,1,1,0.593333,0.572613,0.52,0.229475,1242,6142,7384 +504,2012-05-18,2,1,5,0,5,1,1,0.564167,0.551121,0.523333,0.136817,1521,6118,7639 +505,2012-05-19,2,1,5,0,6,0,1,0.6,0.566908,0.45625,0.083975,3410,4884,8294 +506,2012-05-20,2,1,5,0,0,0,1,0.620833,0.583967,0.530417,0.254367,2704,4425,7129 +507,2012-05-21,2,1,5,0,1,1,2,0.598333,0.565667,0.81125,0.233204,630,3729,4359 +508,2012-05-22,2,1,5,0,2,1,2,0.615,0.580825,0.765833,0.118167,819,5254,6073 +509,2012-05-23,2,1,5,0,3,1,2,0.621667,0.584612,0.774583,0.102,766,4494,5260 +510,2012-05-24,2,1,5,0,4,1,1,0.655,0.6067,0.716667,0.172896,1059,5711,6770 +511,2012-05-25,2,1,5,0,5,1,1,0.68,0.627529,0.747083,0.14055,1417,5317,6734 +512,2012-05-26,2,1,5,0,6,0,1,0.6925,0.642696,0.7325,0.198992,2855,3681,6536 +513,2012-05-27,2,1,5,0,0,0,1,0.69,0.641425,0.697083,0.215171,3283,3308,6591 +514,2012-05-28,2,1,5,1,1,0,1,0.7125,0.6793,0.67625,0.196521,2557,3486,6043 +515,2012-05-29,2,1,5,0,2,1,1,0.7225,0.672992,0.684583,0.2954,880,4863,5743 +516,2012-05-30,2,1,5,0,3,1,2,0.656667,0.611129,0.67,0.134329,745,6110,6855 
+517,2012-05-31,2,1,5,0,4,1,1,0.68,0.631329,0.492917,0.195279,1100,6238,7338 +518,2012-06-01,2,1,6,0,5,1,2,0.654167,0.607962,0.755417,0.237563,533,3594,4127 +519,2012-06-02,2,1,6,0,6,0,1,0.583333,0.566288,0.549167,0.186562,2795,5325,8120 +520,2012-06-03,2,1,6,0,0,0,1,0.6025,0.575133,0.493333,0.184087,2494,5147,7641 +521,2012-06-04,2,1,6,0,1,1,1,0.5975,0.578283,0.487083,0.284833,1071,5927,6998 +522,2012-06-05,2,1,6,0,2,1,2,0.540833,0.525892,0.613333,0.209575,968,6033,7001 +523,2012-06-06,2,1,6,0,3,1,1,0.554167,0.542292,0.61125,0.077125,1027,6028,7055 +524,2012-06-07,2,1,6,0,4,1,1,0.6025,0.569442,0.567083,0.15735,1038,6456,7494 +525,2012-06-08,2,1,6,0,5,1,1,0.649167,0.597862,0.467917,0.175383,1488,6248,7736 +526,2012-06-09,2,1,6,0,6,0,1,0.710833,0.648367,0.437083,0.144287,2708,4790,7498 +527,2012-06-10,2,1,6,0,0,0,1,0.726667,0.663517,0.538333,0.133721,2224,4374,6598 +528,2012-06-11,2,1,6,0,1,1,2,0.720833,0.659721,0.587917,0.207713,1017,5647,6664 +529,2012-06-12,2,1,6,0,2,1,2,0.653333,0.597875,0.833333,0.214546,477,4495,4972 +530,2012-06-13,2,1,6,0,3,1,1,0.655833,0.611117,0.582083,0.343279,1173,6248,7421 +531,2012-06-14,2,1,6,0,4,1,1,0.648333,0.624383,0.569583,0.253733,1180,6183,7363 +532,2012-06-15,2,1,6,0,5,1,1,0.639167,0.599754,0.589583,0.176617,1563,6102,7665 +533,2012-06-16,2,1,6,0,6,0,1,0.631667,0.594708,0.504167,0.166667,2963,4739,7702 +534,2012-06-17,2,1,6,0,0,0,1,0.5925,0.571975,0.59875,0.144904,2634,4344,6978 +535,2012-06-18,2,1,6,0,1,1,2,0.568333,0.544842,0.777917,0.174746,653,4446,5099 +536,2012-06-19,2,1,6,0,2,1,1,0.688333,0.654692,0.69,0.148017,968,5857,6825 +537,2012-06-20,2,1,6,0,3,1,1,0.7825,0.720975,0.592083,0.113812,872,5339,6211 +538,2012-06-21,3,1,6,0,4,1,1,0.805833,0.752542,0.567917,0.118787,778,5127,5905 +539,2012-06-22,3,1,6,0,5,1,1,0.7775,0.724121,0.57375,0.182842,964,4859,5823 +540,2012-06-23,3,1,6,0,6,0,1,0.731667,0.652792,0.534583,0.179721,2657,4801,7458 +541,2012-06-24,3,1,6,0,0,0,1,0.743333,0.674254,0.479167,0.145525,2551,4340,6891 
+542,2012-06-25,3,1,6,0,1,1,1,0.715833,0.654042,0.504167,0.300383,1139,5640,6779 +543,2012-06-26,3,1,6,0,2,1,1,0.630833,0.594704,0.373333,0.347642,1077,6365,7442 +544,2012-06-27,3,1,6,0,3,1,1,0.6975,0.640792,0.36,0.271775,1077,6258,7335 +545,2012-06-28,3,1,6,0,4,1,1,0.749167,0.675512,0.4225,0.17165,921,5958,6879 +546,2012-06-29,3,1,6,0,5,1,1,0.834167,0.786613,0.48875,0.165417,829,4634,5463 +547,2012-06-30,3,1,6,0,6,0,1,0.765,0.687508,0.60125,0.161071,1455,4232,5687 +548,2012-07-01,3,1,7,0,0,0,1,0.815833,0.750629,0.51875,0.168529,1421,4110,5531 +549,2012-07-02,3,1,7,0,1,1,1,0.781667,0.702038,0.447083,0.195267,904,5323,6227 +550,2012-07-03,3,1,7,0,2,1,1,0.780833,0.70265,0.492083,0.126237,1052,5608,6660 +551,2012-07-04,3,1,7,1,3,0,1,0.789167,0.732337,0.53875,0.13495,2562,4841,7403 +552,2012-07-05,3,1,7,0,4,1,1,0.8275,0.761367,0.457917,0.194029,1405,4836,6241 +553,2012-07-06,3,1,7,0,5,1,1,0.828333,0.752533,0.450833,0.146142,1366,4841,6207 +554,2012-07-07,3,1,7,0,6,0,1,0.861667,0.804913,0.492083,0.163554,1448,3392,4840 +555,2012-07-08,3,1,7,0,0,0,1,0.8225,0.790396,0.57375,0.125629,1203,3469,4672 +556,2012-07-09,3,1,7,0,1,1,2,0.710833,0.654054,0.683333,0.180975,998,5571,6569 +557,2012-07-10,3,1,7,0,2,1,2,0.720833,0.664796,0.6675,0.151737,954,5336,6290 +558,2012-07-11,3,1,7,0,3,1,1,0.716667,0.650271,0.633333,0.151733,975,6289,7264 +559,2012-07-12,3,1,7,0,4,1,1,0.715833,0.654683,0.529583,0.146775,1032,6414,7446 +560,2012-07-13,3,1,7,0,5,1,2,0.731667,0.667933,0.485833,0.08085,1511,5988,7499 +561,2012-07-14,3,1,7,0,6,0,2,0.703333,0.666042,0.699167,0.143679,2355,4614,6969 +562,2012-07-15,3,1,7,0,0,0,1,0.745833,0.705196,0.717917,0.166667,1920,4111,6031 +563,2012-07-16,3,1,7,0,1,1,1,0.763333,0.724125,0.645,0.164187,1088,5742,6830 +564,2012-07-17,3,1,7,0,2,1,1,0.818333,0.755683,0.505833,0.114429,921,5865,6786 +565,2012-07-18,3,1,7,0,3,1,1,0.793333,0.745583,0.577083,0.137442,799,4914,5713 +566,2012-07-19,3,1,7,0,4,1,1,0.77,0.714642,0.600417,0.165429,888,5703,6591 
+567,2012-07-20,3,1,7,0,5,1,2,0.665833,0.613025,0.844167,0.208967,747,5123,5870 +568,2012-07-21,3,1,7,0,6,0,3,0.595833,0.549912,0.865417,0.2133,1264,3195,4459 +569,2012-07-22,3,1,7,0,0,0,2,0.6675,0.623125,0.7625,0.0939208,2544,4866,7410 +570,2012-07-23,3,1,7,0,1,1,1,0.741667,0.690017,0.694167,0.138683,1135,5831,6966 +571,2012-07-24,3,1,7,0,2,1,1,0.750833,0.70645,0.655,0.211454,1140,6452,7592 +572,2012-07-25,3,1,7,0,3,1,1,0.724167,0.654054,0.45,0.1648,1383,6790,8173 +573,2012-07-26,3,1,7,0,4,1,1,0.776667,0.739263,0.596667,0.284813,1036,5825,6861 +574,2012-07-27,3,1,7,0,5,1,1,0.781667,0.734217,0.594583,0.152992,1259,5645,6904 +575,2012-07-28,3,1,7,0,6,0,1,0.755833,0.697604,0.613333,0.15735,2234,4451,6685 +576,2012-07-29,3,1,7,0,0,0,1,0.721667,0.667933,0.62375,0.170396,2153,4444,6597 +577,2012-07-30,3,1,7,0,1,1,1,0.730833,0.684987,0.66875,0.153617,1040,6065,7105 +578,2012-07-31,3,1,7,0,2,1,1,0.713333,0.662896,0.704167,0.165425,968,6248,7216 +579,2012-08-01,3,1,8,0,3,1,1,0.7175,0.667308,0.6775,0.141179,1074,6506,7580 +580,2012-08-02,3,1,8,0,4,1,1,0.7525,0.707088,0.659583,0.129354,983,6278,7261 +581,2012-08-03,3,1,8,0,5,1,2,0.765833,0.722867,0.6425,0.215792,1328,5847,7175 +582,2012-08-04,3,1,8,0,6,0,1,0.793333,0.751267,0.613333,0.257458,2345,4479,6824 +583,2012-08-05,3,1,8,0,0,0,1,0.769167,0.731079,0.6525,0.290421,1707,3757,5464 +584,2012-08-06,3,1,8,0,1,1,2,0.7525,0.710246,0.654167,0.129354,1233,5780,7013 +585,2012-08-07,3,1,8,0,2,1,2,0.735833,0.697621,0.70375,0.116908,1278,5995,7273 +586,2012-08-08,3,1,8,0,3,1,2,0.75,0.707717,0.672917,0.1107,1263,6271,7534 +587,2012-08-09,3,1,8,0,4,1,1,0.755833,0.699508,0.620417,0.1561,1196,6090,7286 +588,2012-08-10,3,1,8,0,5,1,2,0.715833,0.667942,0.715833,0.238813,1065,4721,5786 +589,2012-08-11,3,1,8,0,6,0,2,0.6925,0.638267,0.732917,0.206479,2247,4052,6299 +590,2012-08-12,3,1,8,0,0,0,1,0.700833,0.644579,0.530417,0.122512,2182,4362,6544 +591,2012-08-13,3,1,8,0,1,1,1,0.720833,0.662254,0.545417,0.136212,1207,5676,6883 
+592,2012-08-14,3,1,8,0,2,1,1,0.726667,0.676779,0.686667,0.169158,1128,5656,6784 +593,2012-08-15,3,1,8,0,3,1,1,0.706667,0.654037,0.619583,0.169771,1198,6149,7347 +594,2012-08-16,3,1,8,0,4,1,1,0.719167,0.654688,0.519167,0.141796,1338,6267,7605 +595,2012-08-17,3,1,8,0,5,1,1,0.723333,0.2424,0.570833,0.231354,1483,5665,7148 +596,2012-08-18,3,1,8,0,6,0,1,0.678333,0.618071,0.603333,0.177867,2827,5038,7865 +597,2012-08-19,3,1,8,0,0,0,2,0.635833,0.603554,0.711667,0.08645,1208,3341,4549 +598,2012-08-20,3,1,8,0,1,1,2,0.635833,0.595967,0.734167,0.129979,1026,5504,6530 +599,2012-08-21,3,1,8,0,2,1,1,0.649167,0.601025,0.67375,0.0727708,1081,5925,7006 +600,2012-08-22,3,1,8,0,3,1,1,0.6675,0.621854,0.677083,0.0702833,1094,6281,7375 +601,2012-08-23,3,1,8,0,4,1,1,0.695833,0.637008,0.635833,0.0845958,1363,6402,7765 +602,2012-08-24,3,1,8,0,5,1,2,0.7025,0.6471,0.615,0.0721458,1325,6257,7582 +603,2012-08-25,3,1,8,0,6,0,2,0.661667,0.618696,0.712917,0.244408,1829,4224,6053 +604,2012-08-26,3,1,8,0,0,0,2,0.653333,0.595996,0.845833,0.228858,1483,3772,5255 +605,2012-08-27,3,1,8,0,1,1,1,0.703333,0.654688,0.730417,0.128733,989,5928,6917 +606,2012-08-28,3,1,8,0,2,1,1,0.728333,0.66605,0.62,0.190925,935,6105,7040 +607,2012-08-29,3,1,8,0,3,1,1,0.685,0.635733,0.552083,0.112562,1177,6520,7697 +608,2012-08-30,3,1,8,0,4,1,1,0.706667,0.652779,0.590417,0.0771167,1172,6541,7713 +609,2012-08-31,3,1,8,0,5,1,1,0.764167,0.6894,0.5875,0.168533,1433,5917,7350 +610,2012-09-01,3,1,9,0,6,0,2,0.753333,0.702654,0.638333,0.113187,2352,3788,6140 +611,2012-09-02,3,1,9,0,0,0,2,0.696667,0.649,0.815,0.0640708,2613,3197,5810 +612,2012-09-03,3,1,9,1,1,0,1,0.7075,0.661629,0.790833,0.151121,1965,4069,6034 +613,2012-09-04,3,1,9,0,2,1,1,0.725833,0.686888,0.755,0.236321,867,5997,6864 +614,2012-09-05,3,1,9,0,3,1,1,0.736667,0.708983,0.74125,0.187808,832,6280,7112 +615,2012-09-06,3,1,9,0,4,1,2,0.696667,0.655329,0.810417,0.142421,611,5592,6203 +616,2012-09-07,3,1,9,0,5,1,1,0.703333,0.657204,0.73625,0.171646,1045,6459,7504 
+617,2012-09-08,3,1,9,0,6,0,2,0.659167,0.611121,0.799167,0.281104,1557,4419,5976 +618,2012-09-09,3,1,9,0,0,0,1,0.61,0.578925,0.5475,0.224496,2570,5657,8227 +619,2012-09-10,3,1,9,0,1,1,1,0.583333,0.565654,0.50375,0.258713,1118,6407,7525 +620,2012-09-11,3,1,9,0,2,1,1,0.5775,0.554292,0.52,0.0920542,1070,6697,7767 +621,2012-09-12,3,1,9,0,3,1,1,0.599167,0.570075,0.577083,0.131846,1050,6820,7870 +622,2012-09-13,3,1,9,0,4,1,1,0.6125,0.579558,0.637083,0.0827208,1054,6750,7804 +623,2012-09-14,3,1,9,0,5,1,1,0.633333,0.594083,0.6725,0.103863,1379,6630,8009 +624,2012-09-15,3,1,9,0,6,0,1,0.608333,0.585867,0.501667,0.247521,3160,5554,8714 +625,2012-09-16,3,1,9,0,0,0,1,0.58,0.563125,0.57,0.0901833,2166,5167,7333 +626,2012-09-17,3,1,9,0,1,1,2,0.580833,0.55305,0.734583,0.151742,1022,5847,6869 +627,2012-09-18,3,1,9,0,2,1,2,0.623333,0.565067,0.8725,0.357587,371,3702,4073 +628,2012-09-19,3,1,9,0,3,1,1,0.5525,0.540404,0.536667,0.215175,788,6803,7591 +629,2012-09-20,3,1,9,0,4,1,1,0.546667,0.532192,0.618333,0.118167,939,6781,7720 +630,2012-09-21,3,1,9,0,5,1,1,0.599167,0.571971,0.66875,0.154229,1250,6917,8167 +631,2012-09-22,3,1,9,0,6,0,1,0.65,0.610488,0.646667,0.283583,2512,5883,8395 +632,2012-09-23,4,1,9,0,0,0,1,0.529167,0.518933,0.467083,0.223258,2454,5453,7907 +633,2012-09-24,4,1,9,0,1,1,1,0.514167,0.502513,0.492917,0.142404,1001,6435,7436 +634,2012-09-25,4,1,9,0,2,1,1,0.55,0.544179,0.57,0.236321,845,6693,7538 +635,2012-09-26,4,1,9,0,3,1,1,0.635,0.596613,0.630833,0.2444,787,6946,7733 +636,2012-09-27,4,1,9,0,4,1,2,0.65,0.607975,0.690833,0.134342,751,6642,7393 +637,2012-09-28,4,1,9,0,5,1,2,0.619167,0.585863,0.69,0.164179,1045,6370,7415 +638,2012-09-29,4,1,9,0,6,0,1,0.5425,0.530296,0.542917,0.227604,2589,5966,8555 +639,2012-09-30,4,1,9,0,0,0,1,0.526667,0.517663,0.583333,0.134958,2015,4874,6889 +640,2012-10-01,4,1,10,0,1,1,2,0.520833,0.512,0.649167,0.0908042,763,6015,6778 +641,2012-10-02,4,1,10,0,2,1,3,0.590833,0.542333,0.871667,0.104475,315,4324,4639 
+642,2012-10-03,4,1,10,0,3,1,2,0.6575,0.599133,0.79375,0.0665458,728,6844,7572 +643,2012-10-04,4,1,10,0,4,1,2,0.6575,0.607975,0.722917,0.117546,891,6437,7328 +644,2012-10-05,4,1,10,0,5,1,1,0.615,0.580187,0.6275,0.10635,1516,6640,8156 +645,2012-10-06,4,1,10,0,6,0,1,0.554167,0.538521,0.664167,0.268025,3031,4934,7965 +646,2012-10-07,4,1,10,0,0,0,2,0.415833,0.419813,0.708333,0.141162,781,2729,3510 +647,2012-10-08,4,1,10,1,1,0,2,0.383333,0.387608,0.709583,0.189679,874,4604,5478 +648,2012-10-09,4,1,10,0,2,1,2,0.446667,0.438112,0.761667,0.1903,601,5791,6392 +649,2012-10-10,4,1,10,0,3,1,1,0.514167,0.503142,0.630833,0.187821,780,6911,7691 +650,2012-10-11,4,1,10,0,4,1,1,0.435,0.431167,0.463333,0.181596,834,6736,7570 +651,2012-10-12,4,1,10,0,5,1,1,0.4375,0.433071,0.539167,0.235092,1060,6222,7282 +652,2012-10-13,4,1,10,0,6,0,1,0.393333,0.391396,0.494583,0.146142,2252,4857,7109 +653,2012-10-14,4,1,10,0,0,0,1,0.521667,0.508204,0.640417,0.278612,2080,4559,6639 +654,2012-10-15,4,1,10,0,1,1,2,0.561667,0.53915,0.7075,0.296037,760,5115,5875 +655,2012-10-16,4,1,10,0,2,1,1,0.468333,0.460846,0.558333,0.182221,922,6612,7534 +656,2012-10-17,4,1,10,0,3,1,1,0.455833,0.450108,0.692917,0.101371,979,6482,7461 +657,2012-10-18,4,1,10,0,4,1,2,0.5225,0.512625,0.728333,0.236937,1008,6501,7509 +658,2012-10-19,4,1,10,0,5,1,2,0.563333,0.537896,0.815,0.134954,753,4671,5424 +659,2012-10-20,4,1,10,0,6,0,1,0.484167,0.472842,0.572917,0.117537,2806,5284,8090 +660,2012-10-21,4,1,10,0,0,0,1,0.464167,0.456429,0.51,0.166054,2132,4692,6824 +661,2012-10-22,4,1,10,0,1,1,1,0.4875,0.482942,0.568333,0.0814833,830,6228,7058 +662,2012-10-23,4,1,10,0,2,1,1,0.544167,0.530304,0.641667,0.0945458,841,6625,7466 +663,2012-10-24,4,1,10,0,3,1,1,0.5875,0.558721,0.63625,0.0727792,795,6898,7693 +664,2012-10-25,4,1,10,0,4,1,2,0.55,0.529688,0.800417,0.124375,875,6484,7359 +665,2012-10-26,4,1,10,0,5,1,2,0.545833,0.52275,0.807083,0.132467,1182,6262,7444 +666,2012-10-27,4,1,10,0,6,0,2,0.53,0.515133,0.72,0.235692,2643,5209,7852 
+667,2012-10-28,4,1,10,0,0,0,2,0.4775,0.467771,0.694583,0.398008,998,3461,4459 +668,2012-10-29,4,1,10,0,1,1,3,0.44,0.4394,0.88,0.3582,2,20,22 +669,2012-10-30,4,1,10,0,2,1,2,0.318182,0.309909,0.825455,0.213009,87,1009,1096 +670,2012-10-31,4,1,10,0,3,1,2,0.3575,0.3611,0.666667,0.166667,419,5147,5566 +671,2012-11-01,4,1,11,0,4,1,2,0.365833,0.369942,0.581667,0.157346,466,5520,5986 +672,2012-11-02,4,1,11,0,5,1,1,0.355,0.356042,0.522083,0.266175,618,5229,5847 +673,2012-11-03,4,1,11,0,6,0,2,0.343333,0.323846,0.49125,0.270529,1029,4109,5138 +674,2012-11-04,4,1,11,0,0,0,1,0.325833,0.329538,0.532917,0.179108,1201,3906,5107 +675,2012-11-05,4,1,11,0,1,1,1,0.319167,0.308075,0.494167,0.236325,378,4881,5259 +676,2012-11-06,4,1,11,0,2,1,1,0.280833,0.281567,0.567083,0.173513,466,5220,5686 +677,2012-11-07,4,1,11,0,3,1,2,0.295833,0.274621,0.5475,0.304108,326,4709,5035 +678,2012-11-08,4,1,11,0,4,1,1,0.352174,0.341891,0.333478,0.347835,340,4975,5315 +679,2012-11-09,4,1,11,0,5,1,1,0.361667,0.355413,0.540833,0.214558,709,5283,5992 +680,2012-11-10,4,1,11,0,6,0,1,0.389167,0.393937,0.645417,0.0578458,2090,4446,6536 +681,2012-11-11,4,1,11,0,0,0,1,0.420833,0.421713,0.659167,0.1275,2290,4562,6852 +682,2012-11-12,4,1,11,1,1,0,1,0.485,0.475383,0.741667,0.173517,1097,5172,6269 +683,2012-11-13,4,1,11,0,2,1,2,0.343333,0.323225,0.662917,0.342046,327,3767,4094 +684,2012-11-14,4,1,11,0,3,1,1,0.289167,0.281563,0.552083,0.199625,373,5122,5495 +685,2012-11-15,4,1,11,0,4,1,2,0.321667,0.324492,0.620417,0.152987,320,5125,5445 +686,2012-11-16,4,1,11,0,5,1,1,0.345,0.347204,0.524583,0.171025,484,5214,5698 +687,2012-11-17,4,1,11,0,6,0,1,0.325,0.326383,0.545417,0.179729,1313,4316,5629 +688,2012-11-18,4,1,11,0,0,0,1,0.3425,0.337746,0.692917,0.227612,922,3747,4669 +689,2012-11-19,4,1,11,0,1,1,2,0.380833,0.375621,0.623333,0.235067,449,5050,5499 +690,2012-11-20,4,1,11,0,2,1,2,0.374167,0.380667,0.685,0.082725,534,5100,5634 +691,2012-11-21,4,1,11,0,3,1,1,0.353333,0.364892,0.61375,0.103246,615,4531,5146 
+692,2012-11-22,4,1,11,1,4,0,1,0.34,0.350371,0.580417,0.0528708,955,1470,2425 +693,2012-11-23,4,1,11,0,5,1,1,0.368333,0.378779,0.56875,0.148021,1603,2307,3910 +694,2012-11-24,4,1,11,0,6,0,1,0.278333,0.248742,0.404583,0.376871,532,1745,2277 +695,2012-11-25,4,1,11,0,0,0,1,0.245833,0.257583,0.468333,0.1505,309,2115,2424 +696,2012-11-26,4,1,11,0,1,1,1,0.313333,0.339004,0.535417,0.04665,337,4750,5087 +697,2012-11-27,4,1,11,0,2,1,2,0.291667,0.281558,0.786667,0.237562,123,3836,3959 +698,2012-11-28,4,1,11,0,3,1,1,0.296667,0.289762,0.50625,0.210821,198,5062,5260 +699,2012-11-29,4,1,11,0,4,1,1,0.28087,0.298422,0.555652,0.115522,243,5080,5323 +700,2012-11-30,4,1,11,0,5,1,1,0.298333,0.323867,0.649583,0.0584708,362,5306,5668 +701,2012-12-01,4,1,12,0,6,0,2,0.298333,0.316904,0.806667,0.0597042,951,4240,5191 +702,2012-12-02,4,1,12,0,0,0,2,0.3475,0.359208,0.823333,0.124379,892,3757,4649 +703,2012-12-03,4,1,12,0,1,1,1,0.4525,0.455796,0.7675,0.0827208,555,5679,6234 +704,2012-12-04,4,1,12,0,2,1,1,0.475833,0.469054,0.73375,0.174129,551,6055,6606 +705,2012-12-05,4,1,12,0,3,1,1,0.438333,0.428012,0.485,0.324021,331,5398,5729 +706,2012-12-06,4,1,12,0,4,1,1,0.255833,0.258204,0.50875,0.174754,340,5035,5375 +707,2012-12-07,4,1,12,0,5,1,2,0.320833,0.321958,0.764167,0.1306,349,4659,5008 +708,2012-12-08,4,1,12,0,6,0,2,0.381667,0.389508,0.91125,0.101379,1153,4429,5582 +709,2012-12-09,4,1,12,0,0,0,2,0.384167,0.390146,0.905417,0.157975,441,2787,3228 +710,2012-12-10,4,1,12,0,1,1,2,0.435833,0.435575,0.925,0.190308,329,4841,5170 +711,2012-12-11,4,1,12,0,2,1,2,0.353333,0.338363,0.596667,0.296037,282,5219,5501 +712,2012-12-12,4,1,12,0,3,1,2,0.2975,0.297338,0.538333,0.162937,310,5009,5319 +713,2012-12-13,4,1,12,0,4,1,1,0.295833,0.294188,0.485833,0.174129,425,5107,5532 +714,2012-12-14,4,1,12,0,5,1,1,0.281667,0.294192,0.642917,0.131229,429,5182,5611 +715,2012-12-15,4,1,12,0,6,0,1,0.324167,0.338383,0.650417,0.10635,767,4280,5047 +716,2012-12-16,4,1,12,0,0,0,2,0.3625,0.369938,0.83875,0.100742,538,3248,3786 
+717,2012-12-17,4,1,12,0,1,1,2,0.393333,0.4015,0.907083,0.0982583,212,4373,4585 +718,2012-12-18,4,1,12,0,2,1,1,0.410833,0.409708,0.66625,0.221404,433,5124,5557 +719,2012-12-19,4,1,12,0,3,1,1,0.3325,0.342162,0.625417,0.184092,333,4934,5267 +720,2012-12-20,4,1,12,0,4,1,2,0.33,0.335217,0.667917,0.132463,314,3814,4128 +721,2012-12-21,1,1,12,0,5,1,2,0.326667,0.301767,0.556667,0.374383,221,3402,3623 +722,2012-12-22,1,1,12,0,6,0,1,0.265833,0.236113,0.44125,0.407346,205,1544,1749 +723,2012-12-23,1,1,12,0,0,0,1,0.245833,0.259471,0.515417,0.133083,408,1379,1787 +724,2012-12-24,1,1,12,0,1,1,2,0.231304,0.2589,0.791304,0.0772304,174,746,920 +725,2012-12-25,1,1,12,1,2,0,2,0.291304,0.294465,0.734783,0.168726,440,573,1013 +726,2012-12-26,1,1,12,0,3,1,3,0.243333,0.220333,0.823333,0.316546,9,432,441 +727,2012-12-27,1,1,12,0,4,1,2,0.254167,0.226642,0.652917,0.350133,247,1867,2114 +728,2012-12-28,1,1,12,0,5,1,2,0.253333,0.255046,0.59,0.155471,644,2451,3095 +729,2012-12-29,1,1,12,0,6,0,2,0.253333,0.2424,0.752917,0.124383,159,1182,1341 +730,2012-12-30,1,1,12,0,0,0,1,0.255833,0.2317,0.483333,0.350754,364,1432,1796 +731,2012-12-31,1,1,12,0,1,1,2,0.215833,0.223487,0.5775,0.154846,439,2290,2729 diff --git a/inst/extdata/train_index.rds b/inst/extdata/train_index.rds new file mode 100644 index 0000000000000000000000000000000000000000..04cf1aac9c7166eb6900a4ea7eeb75ab4473f042 Binary files /dev/null and b/inst/extdata/train_index.rds differ diff --git a/inst/scripts/Beeswarm_illustration.R b/inst/scripts/Beeswarm_illustration.R new file mode 100644 index 0000000000000000000000000000000000000000..72b61cce73dca0e1a493805530d80c159a3c902b --- /dev/null +++ b/inst/scripts/Beeswarm_illustration.R @@ -0,0 +1,559 @@ +# Functions ------------------------------------------------------------------------------------------------------- +plot_shapr <- function(x, + plot_type = "bar", + digits = 3, + index_x_explain = NULL, + top_k_features = NULL, + col = NULL, # first increasing color, then decreasing 
color + bar_plot_phi0 = TRUE, + bar_plot_order = "largest_first", + scatter_features = NULL, + scatter_hist = TRUE, + ...) { + if (!requireNamespace("ggplot2", quietly = TRUE)) { + stop("ggplot2 is not installed. Please run install.packages('ggplot2')") + } + if (!(plot_type %in% c("bar", "waterfall", "scatter", "beeswarm"))) { + stop(paste(plot_type, "is an invalid plot type. Try plot_type='bar', plot_type='waterfall', + plot_type='scatter', or plot_type='beeswarm'.")) + } + if (!(bar_plot_order %in% c("largest_first", "smallest_first", "original"))) { + stop(paste(bar_plot_order, "is an invalid plot order. Try bar_plot_order='largest_first', + bar_plot_order='smallest_first' or bar_plot_order='original'.")) + } + + if (is.null(index_x_explain)) index_x_explain <- seq(x$internal$parameters$n_explain) + if (is.null(top_k_features)) top_k_features <- x$internal$parameters$n_features + 1 + + is_groupwise <- x$internal$parameters$is_groupwise + + # melting Kshap + shap_names <- colnames(x$shapley_values_est)[-1] + dt_shap <- round(data.table::copy(x$shapley_values_est), digits = digits) + dt_shap[, id := .I] + dt_shap_long <- data.table::melt(dt_shap, id.vars = "id", value.name = "phi") + dt_shap_long[, sign := factor(sign(phi), levels = c(1, -1), labels = c("Increases", "Decreases"))] + + # Converting and melting Xtest + if (!is_groupwise) { + desc_mat <- trimws(format(x$internal$data$x_explain, digits = digits)) + for (i in seq_len(ncol(desc_mat))) { + desc_mat[, i] <- paste0(shap_names[i], " = ", desc_mat[, i]) + } + } else { + desc_mat <- trimws(format(x$shapley_values_est[, -1], digits = digits)) + for (i in seq_len(ncol(desc_mat))) { + desc_mat[, i] <- paste0(shap_names[i]) + } + } + + dt_desc <- data.table::as.data.table(cbind(none = "none", desc_mat)) + dt_desc_long <- data.table::melt(dt_desc[, id := .I], id.vars = "id", value.name = "description") + + # Data table for plotting + dt_plot <- merge(dt_shap_long, dt_desc_long) + + # Adding the predictions + 
dt_pred <- data.table::data.table(id = dt_shap$id, pred = x$pred_explain) + dt_plot <- merge(dt_plot, dt_pred, by = "id") + + # Adding header for each individual plot + dt_plot[, header := paste0("id: ", id, ", pred = ", format(pred, digits = digits + 1))] + + if (plot_type == "scatter" || plot_type == "beeswarm") { + # Add feature values to data table + dt_feature_vals <- data.table::copy(x$internal$data$x_explain) + dt_feature_vals <- as.data.table(cbind(none = NA, dt_feature_vals)) + dt_feature_vals[, id := .I] + + # Deal with numeric and factor variables separately + factor_features <- dt_feature_vals[, sapply(.SD, function(x) is.factor(x) | is.character(x)), .SDcols = shap_names] + factor_features <- shap_names[factor_features] + + dt_feature_vals_long <- suppressWarnings(data.table::melt(dt_feature_vals, + id.vars = "id", + value.name = "feature_value" + )) + # this gives a warning because none-values are NA... + dt_plot <- merge(dt_plot, dt_feature_vals_long, by = c("id", "variable")) + } + + return(list(dt_plot = dt_plot, + col = col, + index_x_explain = index_x_explain, x = x, factor_features = factor_features)) +} + + +make_beeswarm_plot_old <- function(dt_plot, col, index_x_explain, x, factor_cols) { + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("geom_beeswarm is not installed. 
Please run install.packages('ggbeeswarm')") + } + + if (is.null(col)) { + col <- c("#F8766D", "yellow", "#00BA38") + } + if (!(length(col) %in% c(2, 3))) { + stop("'col' must be of length 2 or 3 when making beeswarm plot.") + } + + dt_plot <- dt_plot[variable != "none"] + + # Deal with factor variables + process_data <- shapr:::process_factor_data(dt_plot, factor_cols) + dt_plot <- process_data$dt_plot + + dt_train <- data.table::copy(x$internal$data$x_train) + dt_train <- suppressWarnings( # suppress warnings for coercion from int to double or to factor + data.table::melt(dt_train[, id := .I], id.vars = "id", value.name = "feature_value") + ) + dt_train <- shapr:::process_factor_data(dt_train, factor_cols)$dt_plot + dt_train[, `:=`(max = max(feature_value), min = min(feature_value)), by = variable] + dt_train <- dt_train[, .(variable, max, min)] + dt_train <- unique(dt_train) + dt_plot <- merge(dt_plot, dt_train, by = "variable") + + # scale obs. features value to their distance from min. feature value relative to the distance + # between min. and max. feature value in order to have a global color bar indicating magnitude + # of obs. feature value. 
+ # The feature values are scaled wrt the training data + dt_plot[feature_value <= max & feature_value >= min, + feature_value_scaled := (feature_value - min) / (max - min), + by = variable + ] + dt_plot[feature_value > max, feature_value_scaled := 1] + dt_plot[feature_value < min, feature_value_scaled := 0] + + # make sure features with only one value are also scaled + dt_plot[is.nan(feature_value_scaled), feature_value_scaled := 0.5, by = variable] + + # Only plot the desired observations + dt_plot <- dt_plot[id %in% index_x_explain] + + # For factor variables, we want one line per factor level + # Give them a NA feature value to make the color grey + dt_plot[type == "factor", variable := description] + dt_plot[type == "factor", feature_value_scaled := NA] + + gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + + ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) + + ggbeeswarm::geom_beeswarm(priority = "random", cex = 0.4) + + # the cex-parameter doesnt generalize well, should use corral but not available yet.... 
+ ggplot2::coord_flip() + + ggplot2::theme_classic() + + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) + + ggplot2::labs(x = "", y = "Shapley value") + + ggplot2::guides(color = ggplot2::guide_colourbar( + ticks = FALSE, + barwidth = 0.5, barheight = 10 + )) + + if (length(col) == 3) { # check is col-parameter is the default + gg <- gg + + ggplot2::scale_color_gradient2( + low = col[3], mid = col[2], high = col[1], + midpoint = 0.5, + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } else if (length(col) == 2) { # allow user to specify three colors + gg <- gg + + ggplot2::scale_color_gradient( + low = col[2], + high = col[1], + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } + + return(gg) +} + +make_beeswarm_plot_new_cex <- function(dt_plot, col, index_x_explain, x, factor_cols) { + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("geom_beeswarm is not installed. 
Please run install.packages('ggbeeswarm')") + } + + if (is.null(col)) { + col <- c("#F8766D", "yellow", "#00BA38") + } + if (!(length(col) %in% c(2, 3))) { + stop("'col' must be of length 2 or 3 when making beeswarm plot.") + } + + dt_plot <- dt_plot[variable != "none"] + + # Deal with factor variables + process_data <- shapr:::process_factor_data(dt_plot, factor_cols) + dt_plot <- process_data$dt_plot + + dt_train <- data.table::copy(x$internal$data$x_train) + dt_train <- suppressWarnings( # suppress warnings for coercion from int to double or to factor + data.table::melt(dt_train[, id := .I], id.vars = "id", value.name = "feature_value") + ) + dt_train <- shapr:::process_factor_data(dt_train, factor_cols)$dt_plot + dt_train[, `:=`(max = max(feature_value), min = min(feature_value)), by = variable] + dt_train <- dt_train[, .(variable, max, min)] + dt_train <- unique(dt_train) + dt_plot <- merge(dt_plot, dt_train, by = "variable") + + # scale obs. features value to their distance from min. feature value relative to the distance + # between min. and max. feature value in order to have a global color bar indicating magnitude + # of obs. feature value. 
+ # The feature values are scaled wrt the training data + dt_plot[feature_value <= max & feature_value >= min, + feature_value_scaled := (feature_value - min) / (max - min), + by = variable + ] + dt_plot[feature_value > max, feature_value_scaled := 1] + dt_plot[feature_value < min, feature_value_scaled := 0] + + # make sure features with only one value are also scaled + dt_plot[is.nan(feature_value_scaled), feature_value_scaled := 0.5, by = variable] + + # Only plot the desired observations + dt_plot <- dt_plot[id %in% index_x_explain] + + # For factor variables, we want one line per factor level + # Give them a NA feature value to make the color grey + dt_plot[type == "factor", variable := description] + dt_plot[type == "factor", feature_value_scaled := NA] + + gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + + ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) + + ggbeeswarm::geom_beeswarm(priority = "random", cex = 1 / length(index_x_explain)^(1/4)) + + # the cex-parameter doesnt generalize well, should use corral but not available yet.... 
+ ggplot2::coord_flip() + + ggplot2::theme_classic() + + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) + + ggplot2::labs(x = "", y = "Shapley value") + + ggplot2::guides(color = ggplot2::guide_colourbar( + ticks = FALSE, + barwidth = 0.5, barheight = 10 + )) + + if (length(col) == 3) { # check is col-parameter is the default + gg <- gg + + ggplot2::scale_color_gradient2( + low = col[3], mid = col[2], high = col[1], + midpoint = 0.5, + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } else if (length(col) == 2) { # allow user to specify three colors + gg <- gg + + ggplot2::scale_color_gradient( + low = col[2], + high = col[1], + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } + + return(gg) +} + +make_beeswarm_plot_new <- function(dt_plot, col, index_x_explain, x, factor_cols, + corral.method = "swarm", + corral.corral = "wrap", + corral.priority = "random", + corral.width = 0.75, + corral.cex = 0.75) { + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("geom_beeswarm is not installed. 
Please run install.packages('ggbeeswarm')") + } + + if (is.null(col)) { + col <- c("#F8766D", "yellow", "#00BA38") + } + if (!(length(col) %in% c(2, 3))) { + stop("'col' must be of length 2 or 3 when making beeswarm plot.") + } + + dt_plot <- dt_plot[variable != "none"] + + # Deal with factor variables + process_data <- shapr:::process_factor_data(dt_plot, factor_cols) + dt_plot <- process_data$dt_plot + + dt_train <- data.table::copy(x$internal$data$x_train) + dt_train <- suppressWarnings( # suppress warnings for coercion from int to double or to factor + data.table::melt(dt_train[, id := .I], id.vars = "id", value.name = "feature_value") + ) + dt_train <- shapr:::process_factor_data(dt_train, factor_cols)$dt_plot + dt_train[, `:=`(max = max(feature_value), min = min(feature_value)), by = variable] + dt_train <- dt_train[, .(variable, max, min)] + dt_train <- unique(dt_train) + dt_plot <- merge(dt_plot, dt_train, by = "variable") + + # scale obs. features value to their distance from min. feature value relative to the distance + # between min. and max. feature value in order to have a global color bar indicating magnitude + # of obs. feature value. 
+ # The feature values are scaled wrt the training data + dt_plot[feature_value <= max & feature_value >= min, + feature_value_scaled := (feature_value - min) / (max - min), + by = variable + ] + dt_plot[feature_value > max, feature_value_scaled := 1] + dt_plot[feature_value < min, feature_value_scaled := 0] + + # make sure features with only one value are also scaled + dt_plot[is.nan(feature_value_scaled), feature_value_scaled := 0.5, by = variable] + + # Only plot the desired observations + dt_plot <- dt_plot[id %in% index_x_explain] + + # For factor variables, we want one line per factor level + # Give them a NA feature value to make the color grey + dt_plot[type == "factor", variable := description] + dt_plot[type == "factor", feature_value_scaled := NA] + + gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + + ggplot2::geom_hline(yintercept = 0, color = "grey70", linewidth = 0.5) + + ggbeeswarm::geom_beeswarm(method = corral.method, + corral = corral.corral, + priority = corral.priority, + corral.width = corral.width, + cex = corral.cex) + + ggplot2::coord_flip() + + ggplot2::theme_classic() + + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey90", linetype = "dashed")) + + ggplot2::labs(x = "", y = "Shapley value") + + ggplot2::guides(color = ggplot2::guide_colourbar( + ticks = FALSE, + barwidth = 0.5, barheight = 10 + )) + + if (length(col) == 3) { # check is col-parameter is the default + gg <- gg + + ggplot2::scale_color_gradient2( + low = col[3], mid = col[2], high = col[1], + midpoint = 0.5, + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } else if (length(col) == 2) { # allow user to specify three colors + gg <- gg + + ggplot2::scale_color_gradient( + low = col[2], + high = col[1], + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } + + return(gg) +} + 
+make_beeswarm_plot_paper3 <- function(dt_plot, col, index_x_explain, x, factor_cols) { + if (!requireNamespace("ggbeeswarm", quietly = TRUE)) { + stop("geom_beeswarm is not installed. Please run install.packages('ggbeeswarm')") + } + + if (is.null(col)) { + col <- c("#F8766D", "yellow", "#00BA38") + } + if (!(length(col) %in% c(2, 3))) { + stop("'col' must be of length 2 or 3 when making beeswarm plot.") + } + + dt_plot <- dt_plot[variable != "none"] + + # Deal with factor variables + process_data <- shapr:::process_factor_data(dt_plot, factor_cols) + dt_plot <- process_data$dt_plot + + dt_train <- data.table::copy(x$internal$data$x_train) + dt_train <- suppressWarnings( # suppress warnings for coercion from int to double or to factor + data.table::melt(dt_train[, id := .I], id.vars = "id", value.name = "feature_value") + ) + dt_train <- shapr:::process_factor_data(dt_train, factor_cols)$dt_plot + dt_train[, `:=`(max = max(feature_value), min = min(feature_value)), by = variable] + dt_train <- dt_train[, .(variable, max, min)] + dt_train <- unique(dt_train) + dt_plot <- merge(dt_plot, dt_train, by = "variable") + + # scale obs. features value to their distance from min. feature value relative to the distance + # between min. and max. feature value in order to have a global color bar indicating magnitude + # of obs. feature value. 
+ # The feature values are scaled wrt the training data + dt_plot[feature_value <= max & feature_value >= min, + feature_value_scaled := (feature_value - min) / (max - min), + by = variable + ] + dt_plot[feature_value > max, feature_value_scaled := 1] + dt_plot[feature_value < min, feature_value_scaled := 0] + + # make sure features with only one value are also scaled + dt_plot[is.nan(feature_value_scaled), feature_value_scaled := 0.5, by = variable] + + # Only plot the desired observations + dt_plot <- dt_plot[id %in% index_x_explain] + + # For factor variables, we want one line per factor level + # Give them a NA feature value to make the color grey + dt_plot[type == "factor", variable := description] + dt_plot[type == "factor", feature_value_scaled := NA] + + gg <- ggplot2::ggplot(dt_plot, ggplot2::aes(x = variable, y = phi, color = feature_value_scaled)) + + ggplot2::geom_hline(yintercept = 0, color = "grey60", linewidth = 0.5) + + #ggbeeswarm::geom_beeswarm(priority = "random", cex = 0.1) + + ggbeeswarm::geom_beeswarm(corral = "wrap", priority = "random", corral.width = 0.75) + + # the cex-parameter doesnt generalize well, should use corral but not available yet.... 
+ ggplot2::coord_flip() + + #ggplot2::theme_classic() + + ggplot2::theme(panel.grid.major.y = ggplot2::element_line(colour = "grey75", linetype = "dashed")) + + ggplot2::labs(x = "", y = "Shapley value") + + ggplot2::guides(color = ggplot2::guide_colourbar( + ticks = FALSE, + #barwidth = 0.5, barheight = 10 + barwidth = 10, barheight = 0.5 + )) + + if (length(col) == 3) { # check is col-parameter is the default + gg <- gg + + ggplot2::scale_color_gradient2( + low = col[3], mid = col[2], high = col[1], + midpoint = 0.5, + breaks = c(0, 1), + limits = c(0, 1), + labels = c(" Low", "High "), + name = "Feature value: " + ) + + theme(legend.position = 'bottom') + + guides(fill = guide_legend(nrow = 1)) + } else if (length(col) == 2) { # allow user to specify three colors + gg <- gg + + ggplot2::scale_color_gradient( + low = col[2], + high = col[1], + breaks = c(0, 1), + limits = c(0, 1), + labels = c("Low", "High"), + name = "Feature \n value" + ) + } + + return(gg) +} + +# Run code from here ---------------------------------------------------------------------------------------------- +# Load necessary library +library(shapr) +library(xgboost) +library(data.table) +library(MASS) +library(ggplot2) +library(ggpubr) + +# Parameters +M <- 10 # Number of dimensions +N_train <- 1000 # Number of training observations +N_explain <- 5000 # Number of test observations +mu <- rep(0, M) # Mean vector, for example, a zero vector +rho <- 0.5 # Correlation coefficient (must be between -1 and 1) +beta = matrix(c(1, -2, 2, 0.5, 1.5, 0.25, 0.75, -0.5, 1, -2)[1:M]) + +# Construct the equi-correlation matrix +cov_matrix <- matrix(rho, nrow = M, ncol = M) +diag(cov_matrix) <- 1 # Set diagonal to 1 + +# Generate N observations from the multivariate normal distribution +set.seed(123) # Set seed for reproducibility +x_train <- mvrnorm(N_train, mu, cov_matrix) +x_explain <- mvrnorm(N_explain, mu, cov_matrix) + +y_train <- x_train %*% beta + rnorm(N_train, sd = 1) +y_explain <- x_explain %*% 
beta + rnorm(N_explain, sd = 1) + +x_train = as.data.table(x_train) +x_explain = as.data.table(x_explain) + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) + +# Computing the actual Shapley values with kernelSHAP accounting for feature dependence using +# the empirical (conditional) distribution approach with bandwidth parameter sigma = 0.1 (default) +explanation <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 10, # Do not need precise Shapley values to illustrate the behaviour of beeswarm plot + n_MC_samples = 10 # Do not need precise Shapley values to illustrate the behaviour of beeswarm plot +) + +# Get the objects needed to make the beeswarm plot +tmp_list = plot_shapr(explanation, plot_type = "beeswarm") + + +## Plots ----------------------------------------------------------------------------------------------------------- +# Make the old and new beeswarm plot +list_figures = lapply(c(50, 100, 1000, 5000), function(N_explain_plot) { + # Old version have problem with runaway points: see https://github.com/eclarke/ggbeeswarm?tab=readme-ov-file#corral-runaway-points + gg_old <- make_beeswarm_plot_old(dt_plot = tmp_list$dt_plot, + col = tmp_list$col, + index_x_explain = tmp_list$index_x_explain[seq(N_explain_plot)], + x = tmp_list$x, + factor_cols = tmp_list$factor_features) + + gg_new_cex <- make_beeswarm_plot_new_cex(dt_plot = tmp_list$dt_plot, + col = tmp_list$col, + index_x_explain = tmp_list$index_x_explain[seq(N_explain_plot)], + x = tmp_list$x, + factor_cols = tmp_list$factor_features) + + gg_new <- make_beeswarm_plot_new(dt_plot = tmp_list$dt_plot, + col = tmp_list$col, + index_x_explain = tmp_list$index_x_explain[seq(N_explain_plot)], + x = tmp_list$x, + 
factor_cols = tmp_list$factor_features, + corral.corral = "wrap", # Default. Other options: "none" (default in geom_beeswarm), "gutter", "random", "omit" + corral.method = "swarm", # Default (and default in geom_beeswarm). Other options: "compactswarm", "hex", "square", "center + corral.priority = "random", # Default . Other options: "ascending" (default in geom_beeswarm), "descending", "density" + corral.width = 0.75, # Default. 0.9 is default in geom_beeswarm + corral.cex = 0.75) # Default. 1 is default in geom_beeswarm + + gg_paper3 <- make_beeswarm_plot_paper3(dt_plot = tmp_list$dt_plot, + col = tmp_list$col, + index_x_explain = tmp_list$index_x_explain[seq(N_explain_plot)], + x = tmp_list$x, + factor_cols = tmp_list$factor_features) + return(ggpubr::ggarrange(gg_old, gg_new_cex, gg_new, gg_paper3, labels = c("Old", "New_cex", "New", "Paper3"), nrow = 1, vjust = 2)) +}) + + +# 50 +list_figures[[1]] + +# 100 +list_figures[[2]] + +# 1000 +list_figures[[3]] + +# 5000 +list_figures[[4]] + +# Plot them together +ggpubr::ggarrange(list_figures[[1]], list_figures[[2]], list_figures[[3]], list_figures[[4]], labels = c(50, 100, 1000, 5000), ncol = 1, vjust = 1) diff --git a/inst/scripts/Compare_Conditional_and_Causal_Categorical.R b/inst/scripts/Compare_Conditional_and_Causal_Categorical.R new file mode 100644 index 0000000000000000000000000000000000000000..f30efa475c53e78d4f52f71e2f67ee9e3ef76b9b --- /dev/null +++ b/inst/scripts/Compare_Conditional_and_Causal_Categorical.R @@ -0,0 +1,167 @@ +# In this file, we compare the causal and conditional Shapley values for a categorical dataset. +# We see that "categorical" approach sometimes produce Shapley values of the opposite sign than +# the other approaches, but this happens for both causal and conditional Shapley values. +# I.e., there is likely no mistake in the cateogical causal Shapley value code. 
+{ + options(digits = 5) # To avoid round off errors when printing output on different systems + + set.seed(12345) + + data <- data.table::as.data.table(airquality) + data[, Month_factor := as.factor(Month)] + data[, Ozone_sub30 := (Ozone < 30) * 1] + data[, Ozone_sub30_factor := as.factor(Ozone_sub30)] + data[, Solar.R_factor := as.factor(cut(Solar.R, 10))] + data[, Wind_factor := as.factor(round(Wind))] + + data_complete <- data[complete.cases(airquality), ] + data_complete <- data_complete[sample(seq_len(.N))] + y_var_numeric <- "Ozone" + x_var_categorical <- c("Month_factor", "Ozone_sub30_factor", "Solar.R_factor", "Wind_factor") + data_train <- head(data_complete, -10) + data_explain <- tail(data_complete, 10) + x_train_categorical <- data_train[, ..x_var_categorical] + x_explain_categorical <- data_explain[, ..x_var_categorical] + lm_formula_categorical <- as.formula(paste0(y_var_numeric, " ~ ", paste0(x_var_categorical, collapse = " + "))) + model_lm_categorical <- lm(lm_formula_categorical, data = data_complete) + p0 <- data_train[, mean(get(y_var_numeric))] +} + +# Causal Shapley values ----- +causal_independence <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "independence", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE +) + +causal_categorical <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +# Warning CTREE is the slowest approach by 
far +causal_ctree <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "ctree", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +causal_vaeac <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "vaeac", + vaeac.epochs = 20, + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +shapr::plot_SV_several_approaches(list( + ind = causal_independence, + cat = causal_categorical, + ctree = causal_ctree, + vaeac = causal_vaeac +)) + +# Conditional Shapley values ------ +conditional_independence <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "independence", + phi0 = p0, + # asymmetric = FALSE, + # causal_ordering = list(3:4, 2, 1), + # confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +conditional_categorical <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + # asymmetric = FALSE, + # causal_ordering = list(3:4, 2, 1), + # confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +# Warning CTREE is the slowest approach by far +conditional_ctree <- explain( 
+ model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "ctree", + phi0 = p0, + # asymmetric = FALSE, + # causal_ordering = list(3:4, 2, 1), + # confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +conditional_vaeac <- explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "vaeac", + vaeac.epochs = 20, + phi0 = p0, + # asymmetric = FALSE, + # causal_ordering = list(3:4, 2, 1), + # confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 50, # Just for speed + verbose = c("basic", "convergence", "shapley", "vS_details"), + keep_samp_for_vS = TRUE, + iterative = FALSE +) + +shapr::plot_SV_several_approaches(list( + ind = conditional_independence, + cat = conditional_categorical, + ctree = conditional_ctree, + vaeac = conditional_vaeac +)) diff --git a/inst/scripts/Compare_categorical_prepare_data.R b/inst/scripts/Compare_categorical_prepare_data.R new file mode 100644 index 0000000000000000000000000000000000000000..dd913ee4d1a736d5647eb80937b3997799202c54 --- /dev/null +++ b/inst/scripts/Compare_categorical_prepare_data.R @@ -0,0 +1,563 @@ +# File with several proposals for new versions of the `compute_conditional_prob` function used by +# the categorical approach, which are much faster. +# The `compute_conditional_prob_shapr_old` computed a lot of unnecessary things, e.g., it compute the conditional +# prob for all colaitions and then threw away all results not relevant to the coalitions in the batch at the end. +# The `compute_conditional_prob_shapr_new` computes only the relevant stuff for the applicable coalitions in the batch. 
+ +# The versions ---------------------------------------------------------------------------------------------------- +compute_conditional_prob <- function(S, index_features, x_explain, joint_probability_dt) { + + # Extract the feature names and add an id column to x_explain (copy as this changes `x_explain` outside the function) + feature_names = names(x_explain) + x_explain_copy = data.table::copy(x_explain)[,id := .I] + + # Loop over the combinations and convert to a single data table containing all the conditional probabilities + results = data.table::rbindlist(lapply(index_features, function(index_feature) { + + # Extract the feature names of the features we are to condition on + cond_cols <- feature_names[S[index_feature,] == 1] + cond_cols_with_id = c("id", cond_cols) + + # Extract the feature values to condition and including the id column + dt_conditional_feature_values = x_explain_copy[, cond_cols_with_id, with = FALSE] + + # Merge (right outer join) the joint_probability_dt data with the conditional feature values + results_id_combination = joint_probability_dt[dt_conditional_feature_values, on = cond_cols, allow.cartesian = TRUE] + + # Get the weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands + results_id_combination[, w := joint_prob / sum(joint_prob), by = id] + results_id_combination[, c("id_all", "joint_prob") := NULL] + + # If we have a combination not in the joint prob, then we delete it + # TODO: or should we do something else? + # TODO: Comment out the printouts. 
Only used to debug + results_not_valid = results_id_combination[is.na(w)] + str_tmp = paste(sapply(results_not_valid$id, function(i) { + paste0("(id = ", i, ", ", paste(cond_cols, "=", results_not_valid[id == i,..cond_cols], collapse = ", "), ")") + }), collapse = ", ") + paste0("The following explicands where removed as they are not in `joint_probability_dt`: ", str_tmp, ".") + + # Return the data table where we remove the NA entries + return(results_id_combination[!is.na(w)]) + }), idcol = "id_combination", use.names = TRUE) + + # Update the index_features to their correct value + results[, id_combination := index_features[id_combination]] + + # Set id_combination and id to be the keys and the two first columns for consistency with other methods + data.table::setkeyv(results, c("id_combination", "id")) + data.table::setcolorder(results, c("id_combination", "id", feature_names)) + + return(results) +} + +compute_conditional_prob_merge <- function(S, index_features, x_explain, joint_probability_dt) { + + # Extract the feature names and add an id column to x_explain (copy as this changes `x_explain` outside the function) + feature_names = names(x_explain) + x_explain = data.table::copy(x_explain)[,id := .I] + + # Loop over the combinations and convert to a single data table containing all the conditional probabilities + results = data.table::rbindlist(lapply(index_features, function(index_feature) { + + # Extract the feature names of the features we are to condition on + cond_cols <- feature_names[S[index_feature,] == 1] + cond_cols_with_id = c("id", cond_cols) + + # Extract the feature values to condition and including the id column + dt_conditional_feature_values = x_explain[, cond_cols_with_id, with = FALSE] + + # Merge (right outer join) the joint_probability_dt data with the conditional feature values + results_id_combination <- data.table::merge.data.table(joint_probability_dt, dt_conditional_feature_values, by = cond_cols, allow.cartesian = TRUE) + + # Get 
the weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands + results_id_combination[, w := joint_prob / sum(joint_prob), by = id] + results_id_combination[, c("id_all", "joint_prob") := NULL] + + # Return the data table + return(results_id_combination) + }), idcol = "id_combination", use.names = TRUE) + + # Update the index_features to their correct value + results[, id_combination := index_features[id_combination]] + + # Set id_combination and id to be the keys and the two first columns for consistency with other methods + data.table::setkeyv(results, c("id_combination", "id")) + data.table::setcolorder(results, c("id_combination", "id", feature_names)) + + return(results) +} + +compute_conditional_prob_merge_one_coalition <- function(S, index_features, x_explain, joint_probability_dt) { + if (length(index_features) != 1) stop("`index_features` must be single integer.") + + # Extract the feature names and add an id column to x_explain (copy as this changes `x_explain` outside the function) + feature_names = names(x_explain) + x_explain = data.table::copy(x_explain)[,id := .I] + + # Extract the feature names of the features we are to condition on + cond_cols <- feature_names[S[index_features,] == 1] + cond_cols_with_id = c("id", cond_cols) + + # Extract the feature values to condition and including the id column + dt_conditional_feature_values = x_explain[, cond_cols_with_id, with = FALSE] + + # Merge (right outer join) the joint_probability_dt data with the conditional feature values + results_id_combination <- data.table::merge.data.table(joint_probability_dt, dt_conditional_feature_values, by = cond_cols, allow.cartesian = TRUE) + + # Get the weights/conditional probabilities for each valid X_sbar conditioned on X_s for all explicands + results_id_combination[, w := joint_prob / sum(joint_prob), by = id] + results_id_combination[, c("id_all", "joint_prob") := NULL] + + # Set the index_features to their correct value + 
results_id_combination[, id_combination := index_features] + + # Set id_combination and id to be the keys and the two first columns for consistency with other methods + data.table::setkeyv(results_id_combination, c("id_combination", "id")) + data.table::setcolorder(results_id_combination, c("id_combination", "id", feature_names)) + + return(results_id_combination) +} + +compute_conditional_prob_shapr_old = function(S, index_features, x_explain, joint_probability_dt) { + + # Extract the needed objects/variables + #x_train <- internal$data$x_train + #x_explain <- internal$data$x_explain + #joint_probability_dt <- internal$parameters$categorical.joint_prob_dt + #X <- internal$objects$X + #S <- internal$objects$S + + # if (is.null(index_features)) { # 2,3 + # features <- X$features # list of [1], [2], [2, 3] + # } else { + # features <- X$features[index_features] # list of [1], + # } + feature_names <- names(x_explain) + + # 3 id columns: id, id_combination, and id_all + # id: for each x_explain observation + # id_combination: the rows of the S matrix + # id_all: identifies the unique combinations of feature values from + # the training data (not necessarily the ones in the explain data) + + + feature_conditioned <- paste0(feature_names, "_conditioned") + feature_conditioned_id <- c(feature_conditioned, "id") + + S_dt <- data.table::data.table(S) + S_dt[S_dt == 0] <- NA + S_dt[, id_combination := seq_len(nrow(S_dt))] + + data.table::setnames(S_dt, c(feature_conditioned, "id_combination")) + + # (1) Compute marginal probabilities + + # multiply table of probabilities nrow(S) times + joint_probability_mult <- joint_probability_dt[rep(id_all, nrow(S))] + + data.table::setkeyv(joint_probability_mult, "id_all") + j_S_dt <- cbind(joint_probability_mult, S_dt) # combine joint probability and S matrix + + j_S_feat <- as.matrix(j_S_dt[, feature_names, with = FALSE]) # with zeros + j_S_feat_cond <- as.matrix(j_S_dt[, feature_conditioned, with = FALSE]) + + 
j_S_feat[which(is.na(j_S_feat_cond))] <- NA # with NAs + j_S_feat_with_NA <- data.table::as.data.table(j_S_feat) + + # now we have a data.table with the conditioned + # features and the feature value but no ids + data.table::setnames(j_S_feat_with_NA, feature_conditioned) + + j_S_no_conditioned_features <- data.table::copy(j_S_dt) + j_S_no_conditioned_features[, (feature_conditioned) := NULL] + + # dt with conditioned features (correct values) + ids + joint_prob + j_S_all_feat <- cbind(j_S_no_conditioned_features, j_S_feat_with_NA) # features match id_all + + # compute all marginal probabilities + marg_dt <- j_S_all_feat[, .(marg_prob = sum(joint_prob)), by = feature_conditioned] + + # (2) Compute conditional probabilities + + cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned] + cond_dt[, cond_prob := joint_prob / marg_prob] + cond_dt[id_combination == 1, marg_prob := 0] + cond_dt[id_combination == 1, cond_prob := 1] + + # check marginal probabilities + cond_dt_unique <- unique(cond_dt, by = feature_conditioned) + check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)), + by = "id_combination" + ][["sum_prob"]] + if (!all(round(check) == 1)) { + print("Warning - not all marginal probabilities sum to 1. There could be a problem + with the joint probabilities. 
Consider checking.") + } + + # make x_explain + data.table::setkeyv(cond_dt, c("id_combination", "id_all")) + x_explain_with_id <- data.table::copy(x_explain)[, id := .I] + dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names] + + # this is a really important step to get the proper "w" which will be used in compute_preds() + dt_explain_just_conditioned <- dt_just_explain[, feature_conditioned_id, with = FALSE] + + cond_dt[, id_all := NULL] + dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE] + + # check conditional probabilities + check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)), + by = c("id_combination", "id") + ][["sum_prob"]] + if (!all(round(check) == 1)) { + print("Warning - not all conditional probabilities sum to 1. There could be a problem + with the joint probabilities. Consider checking.") + } + + setnames(dt, "cond_prob", "w") + data.table::setkeyv(dt, c("id_combination", "id")) + + # here we merge so that we only return the combintations found in our actual explain data + # this merge does not change the number of rows in dt + # dt <- merge(dt, x$X[, .(id_combination, n_features)], by = "id_combination") + # dt[n_features %in% c(0, ncol(x_explain)), w := 1.0] + dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0] + ret_col <- c("id_combination", "id", feature_names, "w") + dt_temp = dt[id_combination %in% index_features, mget(ret_col)] + + + return(dt_temp) +} + +compute_conditional_prob_shapr_new <- function(S, index_features, x_explain, joint_probability_dt) { + + # Extract the needed objects/variables + #x_train <- internal$data$x_train + #x_explain <- internal$data$x_explain + #joint_probability_dt <- internal$parameters$categorical.joint_prob_dt + #X <- internal$objects$X + #S <- internal$objects$S + + # if (is.null(index_features)) { # 2,3 + # features <- X$features # list of [1], [2], [2, 3] + # } else { + # features <- X$features[index_features] # list of [1], + # } + 
feature_names <- names(x_explain) + + # TODO: add + # For causal sampling, we use + # if (causal_sampling) + + # 3 id columns: id, id_combination, and id_all + # id: for each x_explain observation + # id_combination: the rows of the S matrix + # id_all: identifies the unique combinations of feature values from + # the training data (not necessarily the ones in the explain data) + + + feature_conditioned <- paste0(feature_names, "_conditioned") + feature_conditioned_id <- c(feature_conditioned, "id") + + S_dt <- data.table::data.table(S[index_features, , drop = FALSE]) + S_dt[S_dt == 0] <- NA + S_dt[, id_combination := index_features] + + data.table::setnames(S_dt, c(feature_conditioned, "id_combination")) + + # (1) Compute marginal probabilities + + # multiply table of probabilities length(index_features) times + joint_probability_mult <- joint_probability_dt[rep(id_all, length(index_features))] + + data.table::setkeyv(joint_probability_mult, "id_all") + j_S_dt <- cbind(joint_probability_mult, S_dt) # combine joint probability and S matrix + + j_S_feat <- as.matrix(j_S_dt[, feature_names, with = FALSE]) # with zeros + j_S_feat_cond <- as.matrix(j_S_dt[, feature_conditioned, with = FALSE]) + + j_S_feat[which(is.na(j_S_feat_cond))] <- NA # with NAs + j_S_feat_with_NA <- data.table::as.data.table(j_S_feat) + + # now we have a data.table with the conditioned + # features and the feature value but no ids + data.table::setnames(j_S_feat_with_NA, feature_conditioned) + + j_S_no_conditioned_features <- data.table::copy(j_S_dt) + j_S_no_conditioned_features[, (feature_conditioned) := NULL] + + # dt with conditioned features (correct values) + ids + joint_prob + j_S_all_feat <- cbind(j_S_no_conditioned_features, j_S_feat_with_NA) # features match id_all + + # compute all marginal probabilities + marg_dt <- j_S_all_feat[, .(marg_prob = sum(joint_prob)), by = feature_conditioned] + + # (2) Compute conditional probabilities + + cond_dt <- j_S_all_feat[marg_dt, on = 
feature_conditioned] + cond_dt[, cond_prob := joint_prob / marg_prob] + #cond_dt[id_combination == 1, marg_prob := 0] + #cond_dt[id_combination == 1, cond_prob := 1] + + # check marginal probabilities + cond_dt_unique <- unique(cond_dt, by = feature_conditioned) + check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)), + by = "id_combination" + ][["sum_prob"]] + if (!all(round(check) == 1)) { + print("Warning - not all marginal probabilities sum to 1. There could be a problem + with the joint probabilities. Consider checking.") + } + + # make x_explain + data.table::setkeyv(cond_dt, c("id_combination", "id_all")) + x_explain_with_id <- data.table::copy(x_explain)[, id := .I] + + # dt_just_explain <- rbindlist(lapply(seq(length(index_features)), function(index_features_i) { + # feature_names_now = feature_names[S[index_features[index_features_i],] == 1] + # cond_dt[x_explain_with_id, on = feature_names_now] + # }), use.names = TRUE, fill = TRUE) + + dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names] + + # TODO: bare legge til at cond prob er veldig veldig lav? 
+ + + # this is a really important step to get the proper "w" which will be used in compute_preds() + dt_explain_just_conditioned <- dt_just_explain[, feature_conditioned_id, with = FALSE] + + cond_dt[, id_all := NULL] + + # dt <- rbindlist(lapply(seq(length(index_features)), function(index_features_i) { + # feature_conditioned_now = paste0(feature_names[S[index_features[index_features_i],] == 0], "_conditioned") + # cond_dt[dt_explain_just_conditioned, on = feature_conditioned_now, allow.cartesian = TRUE] + # }), use.names = TRUE, fill = TRUE) + + dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE] + + + # check conditional probabilities + check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)), + by = c("id_combination", "id") + ][["sum_prob"]] + if (!all(round(check) == 1)) { + print("Warning - not all conditional probabilities sum to 1. There could be a problem + with the joint probabilities. Consider checking.") + } + + setnames(dt, "cond_prob", "w") + data.table::setkeyv(dt, c("id_combination", "id")) + + # here we merge so that we only return the combintations found in our actual explain data + # this merge does not change the number of rows in dt + # dt <- merge(dt, x$X[, .(id_combination, n_features)], by = "id_combination") + # dt[n_features %in% c(0, ncol(x_explain)), w := 1.0] + # dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0] + dt_temp = dt[, mget(c("id_combination", "id", feature_names, "w"))] + + return(dt_temp) +} + + +# compute_conditional_prob_shapr2 <- function(S, index_features, x_explain, joint_probability_dt) { +# # Extract the feature names +# feature_names <- names(x_explain) +# +# # Add an id column to x_explain +# x_explain = copy(x_explain)[, id := .I] +# +# # Filter the S matrix and create a data table with only relevant id_combinations +# relevant_S <- S[index_features, , drop = FALSE] +# S_dt <- data.table(relevant_S) +# S_dt[S_dt == 0] <- NA +# S_dt[, id_combination := 
index_features] +# +# # Define feature names with "_conditioned" +# feature_conditioned <- paste0(feature_names, "_conditioned") +# feature_conditioned_id <- c(feature_conditioned, "id") +# +# # Set column names for S_dt +# setnames(S_dt, c(feature_conditioned, "id_combination")) +# +# # Replicate the joint_probability_dt for the number of relevant id_combinations +# joint_probability_mult <- joint_probability_dt[rep(id_all, each = nrow(S_dt))] +# joint_probability_mult[, id_combination := rep(S_dt$id_combination, each = nrow(joint_probability_dt))] +# +# # Combine joint_probability_mult with S_dt +# j_S_dt <- cbind(joint_probability_mult, S_dt) +# +# # Convert features to matrix and condition them with NAs +# j_S_feat <- as.matrix(j_S_dt[, feature_names, with = FALSE]) +# j_S_feat_cond <- as.matrix(j_S_dt[, feature_conditioned, with = FALSE]) +# j_S_feat[is.na(j_S_feat_cond)] <- NA +# j_S_feat_with_NA <- as.data.table(j_S_feat) +# setnames(j_S_feat_with_NA, feature_conditioned) +# +# # Combine conditioned features with joint probabilities +# j_S_no_conditioned_features <- copy(j_S_dt) +# j_S_no_conditioned_features[, (feature_conditioned) := NULL] +# j_S_all_feat <- cbind(j_S_no_conditioned_features, j_S_feat_with_NA) +# +# # Compute marginal probabilities +# marg_dt <- j_S_all_feat[, .(marg_prob = sum(joint_prob)), by = feature_conditioned] +# +# # Compute conditional probabilities +# cond_dt <- j_S_all_feat[marg_dt, on = feature_conditioned] +# cond_dt[, cond_prob := joint_prob / marg_prob] +# cond_dt[id_combination == 1, marg_prob := 0] +# cond_dt[id_combination == 1, cond_prob := 1] +# +# # Check marginal probabilities +# cond_dt_unique <- unique(cond_dt, by = feature_conditioned) +# check <- cond_dt_unique[id_combination != 1][, .(sum_prob = sum(marg_prob)), by = "id_combination"][["sum_prob"]] +# if (!all(round(check) == 1)) { +# warning("Not all marginal probabilities sum to 1. There could be a problem with the joint probabilities. 
Consider checking.") +# } +# +# # Merge with x_explain +# setkeyv(cond_dt, c("id_combination", "id_all")) +# x_explain_with_id <- copy(x_explain)[, id := .I] +# dt_just_explain <- cond_dt[x_explain_with_id, on = feature_names] +# +# # Prepare the explain data +# dt_explain_just_conditioned <- dt_just_explain[, feature_conditioned_id, with = FALSE] +# cond_dt[, id_all := NULL] +# dt <- cond_dt[dt_explain_just_conditioned, on = feature_conditioned, allow.cartesian = TRUE] +# +# # Check conditional probabilities +# check <- dt[id_combination != 1][, .(sum_prob = sum(cond_prob)), by = c("id_combination", "id")][["sum_prob"]] +# if (!all(round(check) == 1)) { +# warning("Not all conditional probabilities sum to 1. There could be a problem with the joint probabilities. Consider checking.") +# } +# +# # Rename and reorder columns +# setnames(dt, "cond_prob", "w") +# setkeyv(dt, c("id_combination", "id")) +# +# # Filter and return relevant combinations +# dt[id_combination %in% c(1, 2^ncol(x_explain)), w := 1.0] +# ret_col <- c("id_combination", "id", feature_names, "w") +# dt_temp <- dt[id_combination %in% index_features, ..ret_col] +# +# return(dt_temp) +# } + +# Comparing ------------------------------------------------------------------------------------------------------- +library(data.table) + +# Need to have loaded shapr for this to work (`devtools::load_all(".")`) +explanation = explain( + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + n_batches = 1, + timing = FALSE +) + +S = explanation$internal$objects$S +joint_probability_dt = explanation$internal$parameters$categorical.joint_prob_dt +x_explain = x_explain_categorical + +# Chose any values between 2 and 15 +index_features = 2:15 + +dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +merge = compute_conditional_prob_merge(S = S, 
index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +all.equal(dt, shapr_new) +all.equal(merge, shapr_new) +all.equal(shapr_old, shapr_new) + +# Compare with only 1 combination (dt and merge are equally fast, shapr_old is 6 times slower) +index_features = 5 +rbenchmark::benchmark(dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge_one_coalition = compute_conditional_prob_merge_one_coalition(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + replications = 500) +# FOR index_features = 2 +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 500 1.596 1.136 1.535 0.028 0 0 +# 2 merge 500 1.640 1.167 1.527 0.035 0 0 +# 3 merge_one_coalition 500 1.405 1.000 1.324 0.024 0 0 +# 5 shapr_new 500 6.200 4.413 6.014 0.103 0 0 +# 4 shapr_old 500 11.203 7.974 10.032 0.267 0 0 + +# FOR index_features = 5 +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 500 1.529 1.374 1.463 0.045 0 0 +# 2 merge 500 1.193 1.072 1.180 0.010 0 0 +# 3 
merge_one_coalition 500 1.113 1.000 1.098 0.013 0 0 +# 5 shapr_new 500 5.705 5.126 5.599 0.068 0 0 +# 4 shapr_old 500 8.105 7.282 7.964 0.121 0 0 + +# FOR index_features = 12 +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 500 1.679 1.119 1.623 0.031 0 0 +# 2 merge 500 1.553 1.035 1.520 0.020 0 0 +# 3 merge_one_coalition 500 1.501 1.000 1.463 0.019 0 0 +# 5 shapr_new 500 5.783 3.853 5.619 0.058 0 0 +# 4 shapr_old 500 9.833 6.551 9.389 0.269 0 0 + +# FOR index_features = 12 +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 500 2.561 1.891 1.996 0.094 0 0 +# 2 merge 500 1.599 1.181 1.520 0.026 0 0 +# 3 merge_one_coalition 500 1.354 1.000 1.337 0.013 0 0 +# 5 shapr_new 500 5.323 3.931 5.246 0.065 0 0 +# 4 shapr_old 500 8.170 6.034 8.019 0.131 0 0 + + +merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +merge_one_coalition = compute_conditional_prob_merge_one_coalition(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt) +all.equal(merge, merge_one_coalition) + + +# Compare with only 4 combination +index_features = c(2,6,9,12) +rbenchmark::benchmark(dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + replications = 100) +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 
100 0.961 1.016 0.940 0.013 0 0 +# 2 merge 100 0.946 1.000 0.919 0.013 0 0 +# 4 shapr_new 100 1.368 1.446 1.316 0.025 0 0 +# 3 shapr_old 100 2.046 2.163 1.950 0.051 0 0 + + +# Compare with half of the combinations +index_features = seq(2, 15, 2) +rbenchmark::benchmark(dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + replications = 100) + +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 100 1.614 1.075 1.559 0.028 0 0 +# 2 merge 100 1.758 1.171 1.623 0.042 0 0 +# 4 shapr_new 100 1.501 1.000 1.437 0.033 0 0 +# 3 shapr_old 100 2.001 1.333 1.920 0.038 0 0 + +# Compare with all the combinations +index_features = seq(2, 15) +rbenchmark::benchmark(dt = compute_conditional_prob(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + merge = compute_conditional_prob_merge(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_old = compute_conditional_prob_shapr_old(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + shapr_new = compute_conditional_prob_shapr_new(S = S, index_features = index_features, x_explain = x_explain, joint_probability_dt = joint_probability_dt), + replications = 100) + +# test replications elapsed relative user.self sys.self user.child sys.child +# 1 dt 100 3.435 2.426 3.286 0.077 0 0 +# 2 
merge 100 3.511 2.480 3.373 0.070 0 0 +# 4 shapr_new 100 1.416 1.000 1.363 0.026 0 0 +# 3 shapr_old 100 2.153 1.520 2.006 0.045 0 0 + + diff --git a/inst/scripts/Heskes_bike_rental_illustration.R b/inst/scripts/Heskes_bike_rental_illustration.R new file mode 100644 index 0000000000000000000000000000000000000000..e74e48ce9d530d9503ce93f709cabe217d62c8c2 --- /dev/null +++ b/inst/scripts/Heskes_bike_rental_illustration.R @@ -0,0 +1,1087 @@ +# This file build on Pull Request https://github.com/NorskRegnesentral/shapr/pull/273 +# This file does not run on the iterative version. +# The point of the file was to replicate the plot values that Heskes obtained in their implementation +# to validate my implementation. + +# Set to true in order to save plots in the main folder +save_plots <- FALSE + + +# Sina plot ------------------------------------------------------------------------------------------------------- +#' Make a sina plot of the Shapley values computed using shapr. +#' +#' @param explanation shapr list containing an explanation produced by shapr::explain. +#' +#' @return ggplot2 object containing the sina plot. +#' @export +#' +#' @import tidyr +#' @import shapr +#' @import ggplot2 +#' @import ggforce +#' +#' @importFrom dplyr `%>%` +#' +#' @examples +#' # set parameters and random seed +#' set.seed(2020) +#' N <- 1000 +#' m <- 4 +#' sds <- runif(4, 0.5, 1.5) +#' pars <- runif(7, -1, 1) +#' +#' # Create data from a structural equation model +#' X_1 <- rnorm(N, sd = sds[1]) +#' Z <- rnorm(N, 1) +#' X_2 <- X_1 * pars[1] + Z * pars[2] + rnorm(N, sd = sds[2]) +#' X_3 <- X_1 * pars[3] + Z * pars[4] + rnorm(N, sd = sds[3]) +#' Y <- X_1 * pars[5] + X_2 * pars[6] + X_3 * pars[7] + rnorm(N, sd = sds[4]) +#' +#' # collecting data +#' mu_A <- rep(0, m) +#' X_A <- cbind(X_1, X_2, X_3) +#' dat_A <- cbind(X_A, Y) +#' cov_A <- cov(dat_A) +#' +#' model <- lm(Y ~ . 
+ 0 , data = as.data.frame(dat_A))
#' explainer <- shapr::shapr(X_A, model)
#' y_mean <- mean(Y)
#'
#' explanation_classic <- shapr::explain(
#'   dat_A,
#'   approach = "gaussian",
#'   explainer = explainer,
#'   phi0 = y_mean
#' )
#' sina_plot(explanation_classic)
#'
#' explanation_causal <- shapr::explain(
#'   dat_A,
#'   approach = "causal",
#'   explainer = explainer,
#'   phi0 = y_mean,
#'   ordering = list(1, c(2, 3))
#' )
#' sina_plot(explanation_causal)
#'
#' @seealso \link[SHAPforxgboost]{shap.plot.summary}
#'
#' @details Function adapted from \link[SHAPforxgboost]{shap.plot.summary}.
#' Copyright © 2020 - Yang Liu & Allan Just
#'
#' @param seed Integer seed. Used both to reset the global RNG (legacy behaviour) and as the jitter seed of
#' the sina layer, so that the point positions are the same every time the plot is rendered.
#'
sina_plot <- function(explanation, seed = 123) {
  # Legacy side effect, kept so that existing callers observe the same RNG state after the call.
  set.seed(seed)

  # Map a vector onto [0, 1]. A constant vector has zero range and would give 0 / 0 = NaN (losing the colour
  # of every point of that feature), so it is sent to the mid-point of the colour scale instead.
  rescale_unit <- function(x) {
    span <- max(x) - min(x)
    if (!is.na(span) && span == 0) {
      return(rep(0.5, length(x)))
    }
    (x - min(x)) / span
  }

  # Shapley values without the baseline ("none") column
  shapley_values_est <- explanation$shapley_values_est[, -"none", drop = FALSE]
  X_values <- explanation$internal$data$x_explain

  # If we are doing group Shapley, then we compute the mean feature value for each group for each explicand
  if (explanation$internal$parameters$is_groupwise) {
    feature_groups <- explanation$internal$parameters$group
    X_values <- X_values[, lapply(feature_groups, function(cols) rowMeans(.SD[, .SD, .SDcols = cols], na.rm = TRUE))]
  }

  # One row per (explicand, feature) pair; same row order as the long feature table built below
  shap_long <- shapley_values_est %>%
    tidyr::pivot_longer(dplyr::everything()) %>%
    dplyr::select(-name) %>%
    dplyr::rename(shap = value)

  data_long <- X_values %>%
    tidyr::pivot_longer(dplyr::everything()) %>%
    dplyr::bind_cols(shap_long) %>%
    dplyr::mutate(name = factor(name, levels = rev(names(shapley_values_est)))) %>%
    dplyr::group_by(name) %>%
    dplyr::arrange(name) %>%
    dplyr::mutate(
      mean_value = mean(value),
      std_value = rescale_unit(value) # feature value scaled to [0, 1] within each feature
    ) %>%
    dplyr::ungroup()

  # Symmetric axis limits around zero
  x_bound <- max(abs(data_long$shap))

  ggplot2::ggplot(data = data_long) +
    ggplot2::coord_flip(ylim = c(-x_bound, x_bound)) +
    ggplot2::geom_hline(yintercept = 0) +
    ggforce::geom_sina(
      ggplot2::aes(x = name, y = shap, color = std_value),
      method = "counts", maxwidth = 0.7, alpha = 0.7,
      seed = seed # the jitter is drawn when the plot is built, not when it is constructed
    ) +
    ggplot2::theme_minimal() +
    ggplot2::theme(
      axis.line.y = ggplot2::element_blank(), axis.ticks.y = ggplot2::element_blank(),
      legend.position = "top",
      legend.title = ggplot2::element_text(size = 16), legend.text = ggplot2::element_text(size = 14),
      axis.title.y = ggplot2::element_text(size = 16), axis.text.y = ggplot2::element_text(size = 14),
      axis.title.x = ggplot2::element_text(size = 16, vjust = -1), axis.text.x = ggplot2::element_text(size = 14)
    ) +
    ggplot2::scale_color_gradient(
      low = "dark green", high = "sandybrown",
      breaks = c(0, 1), labels = c(" Low", "High "),
      guide = ggplot2::guide_colorbar(barwidth = 12, barheight = 0.3)
    ) +
    ggplot2::labs(
      y = "Causal Shapley value (impact on model output)",
      x = "", color = "Scaled feature value "
    )
}


# 0 - Load Packages and Source Files --------------------------------------
library(tidyverse)
library(data.table)
library(xgboost)
library(ggpubr)
library(shapr)
library(ggplot2)
library(grid)
library(gridExtra)

if (save_plots && !dir.exists("figures")) dir.create("figures")

# 1 - Prepare and Plot Data -----------------------------------------------
# Can also download the data set from the source https://archive.ics.uci.edu/dataset/275/bike+sharing+dataset
# temp <- tempfile()
# download.file("https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip", temp)
# bike <- read.csv(unz(temp, "day.csv"))
# unlink(temp)


bike <- read.csv("inst/extdata/day.csv")
# Difference in days, which takes DST into account
bike$trend <- as.numeric(difftime(bike$dteday, bike$dteday[1], units = "days"))
# bike$trend <- as.integer(difftime(bike$dteday, min(as.Date(bike$dteday)))+1)/24
bike$cosyear <- cospi(bike$trend/365*2)
bike$sinyear <- sinpi(bike$trend/365*2)
# Unnormalize variables 
(see data set information in link above)
bike$temp <- bike$temp * (39 - (-8)) + (-8)
bike$atemp <- bike$atemp * (50 - (-16)) + (-16)
bike$windspeed <- 67 * bike$windspeed
bike$hum <- 100 * bike$hum

bike_plot <- ggplot(bike, aes(x = trend, y = cnt, color = temp)) +
  geom_point(size = 0.75) + scale_color_gradient(low = "blue", high = "red") +
  labs(colour = "temp") +
  xlab( "Days since 1 January 2011") + ylab("Number of bikes rented") +
  theme_minimal() +
  theme(legend.position = "right", legend.title = element_text(size = 10))

if (save_plots) {
  ggsave("figures/bike_rental_plot.pdf", bike_plot, width = 4.5, height = 2)
} else {
  print(bike_plot)
}

x_var <- c("trend", "cosyear", "sinyear", "temp", "atemp", "windspeed", "hum")
y_var <- "cnt"

# NOTE: Encountered RNG reproducibility issues across different systems,
# so we saved the training-test split.
# set.seed(2013)
# train_index <- caret::createDataPartition(bike$cnt, p = .8, list = FALSE, times = 1)
train_index <- readRDS("inst/extdata/train_index.rds")

# Training data
x_train <- as.matrix(bike[train_index, x_var])
y_train_nc <- as.matrix(bike[train_index, y_var]) # not centered
y_train <- y_train_nc - mean(y_train_nc)

# Test data (centered with the training mean)
x_explain <- as.matrix(bike[-train_index, x_var])
y_explain_nc <- as.matrix(bike[-train_index, y_var]) # not centered
y_explain <- y_explain_nc - mean(y_train_nc)

# Fit an XGBoost model to the training data
model <- xgboost(
  data = x_train,
  label = y_train,
  nrounds = 100, # spelled out; `nround` only worked through partial argument matching
  verbose = FALSE
)
# caret::RMSE(y_explain, predict(model, x_explain))
sqrt(mean((predict(model, x_explain) - y_explain)^2))
phi0 <- mean(y_train)

message("1. Prepared and plotted data, trained XGBoost model")

# 2 - Compute Shapley Values ----------------------------------------------
progressr::handlers("cli")
explanation_gaussian_time <- system.time({
  explanation_gaussian <-
    progressr::with_progress({
      explain(
        model = model,
        x_train = x_train,
        x_explain = x_explain,
        approach = "gaussian",
        phi0 = phi0,
        asymmetric = FALSE,
        causal_ordering = list(1:7),
        confounding = FALSE,
        seed = 2020,
        n_samples = 50,
        keep_samp_for_vS = FALSE
      )
    })
})

# Store the result that was just computed. The previous code referred to `explanation_asymmetric` and
# `explanation_asymmetric_time`, which are not created until section 2c, so the script stopped here.
saveRDS(list(explanation = explanation_gaussian,
             time = explanation_gaussian_time),
        "~/CauSHAPley/inst/extdata/explanation_gaussian_Olsen.RDS")


## a. We compute the causal symmetric Shapley values on a given partial order (see paper) ####
message("2a. Computing and plotting causal Shapley values")
progressr::handlers("cli")
explanation_causal_time <- system.time({
  explanation_causal <-
    progressr::with_progress({
      explain(
        model = model,
        x_train = x_train,
        x_explain = x_explain,
        approach = "gaussian",
        phi0 = phi0,
        asymmetric = FALSE,
        causal_ordering = list(1, c(2, 3), c(4:7)),
        confounding = c(FALSE, TRUE, FALSE),
        seed = 2020,
        n_samples = 50,
        keep_samp_for_vS = FALSE,
        verbose = 2
      )
    })
})

set.seed(123)
sina_causal <- sina_plot(explanation_causal)
sina_causal

# save limits of sina_causal plot for comparing against marginal and asymmetric
ylim_causal <- sina_causal$coordinates$limits$y

sina_causal <- sina_causal +
  coord_flip(ylim = ylim_causal) +
  ylab("Causal Shapley value (impact on model output)")

sina_causal

# Same directory as the other *_Olsen.RDS files; section 2.5 reads the causal result back from this path.
saveRDS(list(explanation = explanation_causal,
             time = explanation_causal_time,
             plot = sina_causal,
             version = "Causal Shapley values"),
        "~/CauSHAPley/inst/extdata/explanation_causal_Olsen.RDS")

if (save_plots) {
  ggsave("figures/sina_plot_causal.pdf", sina_causal, height = 6.5, width = 6.5)
} else {
  print(sina_causal)
}


## b. 
For computing marginal Shapley values, we assume one component with confounding ####
message("2b. Computing and plotting marginal Shapley values")

# Marginal Shapley values: independence approach, all seven features in a single component
explanation_marginal_time <- system.time({
  explanation_marginal <- progressr::with_progress({
    explain(
      model = model,
      x_train = x_train,
      x_explain = x_explain,
      approach = "independence",
      phi0 = phi0,
      asymmetric = FALSE,
      causal_ordering = list(1:7),
      confounding = FALSE,
      seed = 2020,
      n_samples = 5000,
      keep_samp_for_vS = FALSE
    )
  })
})

# Sina plot on the same axis range as the causal plot
set.seed(123)
sina_marginal <- sina_plot(explanation_marginal) +
  coord_flip(ylim = ylim_causal) +
  ylab("Marginal Shapley value (impact on model output)")

sina_marginal

saveRDS(
  list(
    explanation = explanation_marginal,
    time = explanation_marginal_time,
    plot = sina_marginal,
    version = "Marginal Shapley values"
  ),
  "~/CauSHAPley/inst/extdata/explanation_marginal_Olsen.RDS"
)

if (save_plots) {
  ggsave("figures/sina_plot_marginal.pdf", sina_marginal, height = 6.5, width = 6.5)
} else {
  print(sina_marginal)
}


## c. Finally, we compute the asymmetric Shapley values for the same partial order ####
message("2c. Computing and plotting asymmetric conditional Shapley values")

progressr::handlers("cli")

# Asymmetric conditional Shapley values: Gaussian approach, partial order, no confounding
explanation_asymmetric_time <- system.time({
  explanation_asymmetric <- progressr::with_progress({
    explain(
      model = model,
      x_train = x_train,
      x_explain = x_explain,
      approach = "gaussian",
      phi0 = phi0,
      asymmetric = TRUE,
      causal_ordering = list(1, c(2, 3), c(4:7)),
      confounding = FALSE,
      seed = 2020,
      n_samples = 10000,
      keep_samp_for_vS = FALSE
    )
  })
})

set.seed(123)
sina_asymmetric <- sina_plot(explanation_asymmetric) +
  coord_flip(ylim = ylim_causal) +
  ylab("Asymmetric conditional Shapley value (impact on model output)")

sina_asymmetric

saveRDS(
  list(
    explanation = explanation_asymmetric,
    time = explanation_asymmetric_time,
    plot = sina_asymmetric,
    version = "Asymmetric conditional Shapley values"
  ),
  "~/CauSHAPley/inst/extdata/explanation_asymmetric_Olsen.RDS"
)

if (save_plots) {
  ggsave("figures/sina_plot_asymmetric.pdf", sina_asymmetric, height = 6.5, width = 6.5)
} else {
  print(sina_asymmetric)
}


## d. Asymmetric causal Shapley values (very similar to the conditional ones) ####
message("2d. 
Computing and plotting asymmetric causal Shapley values")

progressr::handlers("cli")
explanation_asymmetric_causal_time <- system.time({
  explanation_asymmetric_causal <-
    progressr::with_progress({
      explain(
        model = model,
        x_train = x_train,
        x_explain = x_explain,
        approach = "gaussian",
        phi0 = phi0,
        asymmetric = TRUE,
        causal_ordering = list(1, c(2, 3), c(4:7)),
        confounding = c(FALSE, TRUE, FALSE),
        seed = 2020,
        n_samples = 10000,
        keep_samp_for_vS = FALSE
      )
    })
})

set.seed(123)
sina_asymmetric_causal <- sina_plot(explanation_asymmetric_causal) +
  coord_flip(ylim = ylim_causal) +
  ylab("Asymmetric causal Shapley value (impact on model output)")

sina_asymmetric_causal

saveRDS(list(explanation = explanation_asymmetric_causal,
             time = explanation_asymmetric_causal_time,
             plot = sina_asymmetric_causal,
             version = "Asymmetric causal Shapley values"),
        "~/CauSHAPley/inst/extdata/explanation_asymmetric_causal_Olsen.RDS")


if (save_plots) {
  ggsave("figures/sina_plot_asymmetric_causal.pdf", sina_asymmetric_causal, height = 6.5, width = 6.5)
} else {
  print(sina_asymmetric_causal)
}


# 2.5 Compare with old implementation ----
# Results stored from the implementation of Heskes et al. (2020)
save_explanation_causal <- readRDS("~/CauSHAPley/inst/extdata/explanation_causal.RDS")
save_explanation_marginal <- readRDS("~/CauSHAPley/inst/extdata/explanation_marginal.RDS")
save_explanation_asymmetric <- readRDS("~/CauSHAPley/inst/extdata/explanation_asymmetric.RDS")
save_explanation_asymmetric_causal <- readRDS("~/CauSHAPley/inst/extdata/explanation_asymmetric_causal.RDS")

# Results stored by this script
save_explanation_causal_Olsen <- readRDS("~/CauSHAPley/inst/extdata/explanation_causal_Olsen.RDS")
save_explanation_marginal_Olsen <- readRDS("~/CauSHAPley/inst/extdata/explanation_marginal_Olsen.RDS")
save_explanation_asymmetric_Olsen <- readRDS("~/CauSHAPley/inst/extdata/explanation_asymmetric_Olsen.RDS")
save_explanation_asymmetric_causal_Olsen <-
  readRDS("~/CauSHAPley/inst/extdata/explanation_asymmetric_causal_Olsen.RDS")

explanation_causal <- save_explanation_causal_Olsen$explanation
explanation_marginal <- save_explanation_marginal_Olsen$explanation
explanation_asymmetric <- save_explanation_asymmetric_Olsen$explanation
explanation_asymmetric_causal <- save_explanation_asymmetric_causal_Olsen$explanation

# Draw the stored Heskes et al. plot and the shapr plot in one figure under a common title.
compare_with_heskes <- function(heskes_plot, shapr_plot, title) {
  gridExtra::grid.arrange(
    heskes_plot + ggplot2::ggtitle("Heskes et al. (2020):"),
    shapr_plot + ggplot2::ggtitle("SHAPR:"),
    top = grid::textGrob(title, gp = grid::gpar(fontsize = 18, font = 8))
  )
}

compare_with_heskes(save_explanation_causal$plot, save_explanation_causal_Olsen$plot, "Causal Shapley values")

# There will be a difference, as we use marginal independence and they use a marginal Gaussian
compare_with_heskes(save_explanation_marginal$plot, save_explanation_marginal_Olsen$plot, "Marginal Shapley values")

compare_with_heskes(save_explanation_asymmetric$plot, save_explanation_asymmetric_Olsen$plot,
                    "Asymmetric conditional Shapley values")

compare_with_heskes(save_explanation_asymmetric_causal$plot, save_explanation_asymmetric_causal_Olsen$plot,
                    "Asymmetric causal Shapley values")


# 3 - Shapley value scatter plots (Figure 3) ------------------------------
message("3. 
Producing scatter plots comparing marginal and causal Shapley values on the test set")

# Marginal and causal Shapley values of cosyear and temp for every explicand, plus the observed temperature
sv_correlation_df <- data.frame(
  temp = x_explain[, "temp"],
  sv_marg_cosyear = explanation_marginal$shapley_values_est$cosyear,
  sv_caus_cosyear = explanation_causal$shapley_values_est$cosyear,
  sv_marg_temp = explanation_marginal$shapley_values_est$temp,
  sv_caus_temp = explanation_causal$shapley_values_est$temp
)

scatterplot_topleft <-
  ggplot(sv_correlation_df, aes(x = sv_marg_temp, y = sv_marg_cosyear, color = temp)) +
  geom_point(size = 1) + xlab("MargSV temp") + ylab("MargSV cosyear") +
  scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) +
  scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) +
  scale_color_gradient(low = "blue", high = "red") +
  theme_minimal() +
  theme(text = element_text(size = 12),
        axis.text.x = element_blank(), axis.text.y = element_text(size = 12),
        axis.ticks.x = element_blank(), axis.title.x = element_blank())

scatterplot_topright <-
  ggplot(sv_correlation_df, aes(x = sv_caus_cosyear, y = sv_marg_cosyear, color = temp)) +
  geom_point(size = 1) + scale_color_gradient(low = "blue", high = "red") +
  xlab("CauSV cosyear") + ylab("MargSV cosyear") +
  scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) +
  scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) +
  theme_minimal() +
  theme(text = element_text(size = 12), axis.title.x = element_blank(), axis.title.y = element_blank(),
        axis.text.x = element_blank(), axis.ticks.x = element_blank(),
        axis.text.y = element_blank(), axis.ticks.y = element_blank())

scatterplot_bottomleft <-
  ggplot(sv_correlation_df, aes(x = sv_marg_temp, y = sv_caus_temp, color = temp)) +
  geom_point(size = 1) + scale_color_gradient(low = "blue", high = "red") +
  ylab("CauSV temp") + xlab("MargSV temp") +
  scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) +
  scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) +
  theme_minimal() +
  theme(text = element_text(size = 12),
        axis.text.x = element_text(size = 12), axis.text.y = element_text(size = 12))

scatterplot_bottomright <-
  ggplot(sv_correlation_df, aes(x = sv_caus_cosyear, y = sv_caus_temp, color = temp)) +
  geom_point(size = 1) + ylab("CauSV temp") + xlab("CauSV cosyear") +
  scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) +
  scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) +
  scale_color_gradient(low = "blue", high = "red") +
  theme_minimal() +
  theme(text = element_text(size = 12), axis.text.x = element_text(size = 12),
        axis.title.y = element_blank(), axis.text.y = element_blank(), axis.ticks.y = element_blank())

grid_top <- gridExtra::grid.arrange(scatterplot_topleft, scatterplot_topright, ncol = 2)
# `legend` is an argument of ggpubr::ggarrange, not of gridExtra::grid.arrange, where the string would be
# treated as a third (non-grob) element to lay out. Use the same two-column layout as the top row.
grid_bottom <- gridExtra::grid.arrange(scatterplot_bottomleft, scatterplot_bottomright, ncol = 2)

# The ggpubr versions (without legends) are the ones used below
grid_top <- ggpubr::ggarrange(scatterplot_topleft, scatterplot_topright, legend = "none")
grid_bottom <- ggpubr::ggarrange(scatterplot_bottomleft, scatterplot_bottomright, legend = "none")

bike_plot <- ggplot(bike, aes(x = trend, y = cnt, color = temp)) +
  geom_point(size = 0.75) + scale_color_gradient(low = "blue", high = "red") +
  labs(colour = "temp") +
  xlab( "Days since 1 January 2011") + ylab("Number of bikes rented") +
  theme_minimal() +
  theme(legend.position = "right", legend.title = element_text(size = 10))

p1 <- ggpubr::ggarrange(scatterplot_topleft,
                        scatterplot_topright,
                        scatterplot_bottomleft,
                        scatterplot_bottomright,
                        legend = "none")

ggpubr::ggarrange(bike_plot, p1, nrow = 2, heights = c(1,2))

if (save_plots) {
  ggsave("figures/scatter_plots_top.pdf", grid_top, width = 5, height = 1)
  ggsave("figures/scatter_plots_bottom.pdf", grid_bottom, width = 5, height = 2)
} else {
  print(ggpubr::ggarrange(grid_top, grid_bottom, nrow = 2))
}


# 4 - Shapley value bar plots (Figure 4) 
----------------------------------
message("4. Producing bar plots comparing marginal, causal, and asymmetric conditional Shapley values")

# Get test set index for two data points with similar temperature
# 1. 2012-10-09 (October)
# 2. 2012-12-03 (December)
features <- c("cosyear", "temp")
dates <- c("2012-10-09", "2012-12-03")
# Position in x_explain of each date. vapply() fails loudly if a date is missing from (or duplicated in)
# the test set, instead of silently returning a list as sapply() would.
dates_idx <- vapply(
  dates,
  function(date) which(as.integer(row.names(x_explain)) == which(bike$dteday == date)),
  integer(1)
)
# predicted values for the two points
# predict(model, x_explain)[dates_idx] + mean(y_train_nc)

explanations <- list("Marginal" = explanation_marginal, "Causal" = explanation_causal)
# One row per (date, method) with the Shapley values of the selected features
explanations_extracted <- data.table::rbindlist(lapply(seq_along(explanations), function(idx) {
  explanations[[idx]]$shapley_values_est[dates_idx, ..features][, `:=` (Date = dates, type = names(explanations)[idx])]
}))

dt_all <- data.table::melt(explanations_extracted, id.vars = c("Date", "type"), variable.name = "feature")
bar_plots <- ggplot(dt_all, aes(x = feature, y = value, group = interaction(Date, feature),
                                fill = Date, label = round(value, 2))) +
  geom_col(position = "dodge") +
  theme_classic() + ylab("Shapley value") +
  facet_wrap(vars(type)) + theme(axis.title.x = element_blank()) +
  scale_fill_manual(values = c('indianred4', 'ivory4')) +
  # `legend.position.inside` only takes effect together with `legend.position = "inside"`
  theme(legend.position = "inside", legend.position.inside = c(0.75, 0.25),
        axis.title = element_text(size = 20),
        legend.title = element_text(size = 16), legend.text = element_text(size = 14),
        axis.text.x = element_text(size = 12), axis.text.y = element_text(size = 12),
        strip.text.x = element_text(size = 14))


if (save_plots) {
  ggsave("figures/bar_plots.pdf", bar_plots, width = 6, height = 3)
} else {
  print(bar_plots)
}


plot_SV_several_approaches(explanations, index_explicands = dates_idx, only_these_features = features, facet_ncol = 1,
                           facet_scales = "free_y")



# 5 - Other approaches -------------------------------------------------------------------------------------------
+approaches = c("independence", "empirical", "gaussian", "copula", "ctree", "vaeac") +n_samples_list = list("independence" = 1000, + "empirical" = 1000, + "gaussian" = 1000, + "copula" = 1000, + "ctree" = 1000, + "vaeac" = 1000) +explanation_list = list() + +for (approach_idx in seq_along(approaches)) { + +} + + + + + +progressr::handlers("cli") +explanation_asymmetric_causal_gaussian_time = system.time({ + explanation_asymmetric_causal_gaussian <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE + ) + }) +}) + +progressr::handlers("cli") +explanation_asymmetric_causal_copula_time = system.time({ + explanation_asymmetric_causal_copula <- + #progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "copula", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE + ) + #}) +}) + +progressr::handlers("cli") +explanation_asymmetric_causal_ctree_time = system.time({ + explanation_asymmetric_causal_ctree <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "ctree", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 500, + keep_samp_for_vS = FALSE + ) + }) +}) + + +progressr::handlers("cli") +explanation_asymmetric_causal_independence_time = system.time({ + explanation_asymmetric_causal_independence <- + #progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "independence", + phi0 = phi0, + asymmetric = TRUE, + 
causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE + ) + #}) +}) + +progressr::handlers("cli") +explanation_asymmetric_causal_empirical_time = system.time({ + explanation_asymmetric_causal_empirical <- + #progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "empirical", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE + ) + #}) +}) + +progressr::handlers("cli") +explanation_asymmetric_causal_vaeac_time = system.time({ + explanation_asymmetric_causal_vaeac <- + #progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "vaeac", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE, + verbose = 2 + ) + #}) +}) + +sina_plot(explanation_asymmetric_causal_independence) +sina_plot(explanation_asymmetric_causal_empirical) +sina_plot(explanation_asymmetric_causal_gaussian) +sina_plot(explanation_asymmetric_causal_copula) +sina_plot(explanation_asymmetric_causal_ctree) +sina_plot(explanation_asymmetric_causal_vaeac) + + + + + + + + + + + + + + + +# 6 - Sampled n_combinations -------------------------------------------------------------------------------------- +explanation_asymmetric_all_gaussian2 <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = FALSE, + seed = 2020, + n_samples = 1000, + n_combinations = 10, + keep_samp_for_vS = FALSE, + n_batches = 1 + ) + }) + +explanation_asymmetric_all_gaussian$shapley_values_est - 
explanation_asymmetric_all_gaussian2$shapley_values_est + + +explanation_asymmetric_all_gaussian$MSEv +explanation_asymmetric_all_gaussian2$MSEv + +sina_plot(explanation_asymmetric_all_gaussian) +sina_plot(explanation_asymmetric_all_gaussian2) + + +explanation_asymmetric_gaussian <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = FALSE, + seed = 2020, + n_samples = 1000, + keep_samp_for_vS = FALSE, + n_combinations = 10 + ) + }) + + + +explanation_asymmetric_causal_gaussian +explanation_asymmetric_causal_gaussian + + + + +explanation_causal_time = system.time({ + explanation_causal <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 5000, + keep_samp_for_vS = FALSE, + verbose = 2, + ) + }) +}) + + +explanation_causal_time_sampled = system.time({ + explanation_causal_sampled <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, c(2, 3), c(4:7)), + confounding = c(FALSE, TRUE, FALSE), + seed = 2020, + n_samples = 5000, + n_combinations = 10, + keep_samp_for_vS = FALSE + ) + }) +}) + +explanation_causal_time +explanation_causal_time_sampled + +sina_plot(explanation_causal) +sina_plot(explanation_causal_sampled) + + + + +# 7 - Group ------------------------------------------------------------------------------------------------------- +# It makes sense to group the "temp" and "atemp" due to their high correlation +cor(x_train[,4], x_train[,5]) +plot(x_train[,4], x_train[,5]) +pairs(x_train) + +group_list <- list( + trend = "trend", 
+ cosyear = "cosyear", + sinyear = "sinyear", + temp_group = c("temp", "atemp"), + windspeed = "windspeed", + hum = "hum") +causal_ordering = list("trend", c("cosyear", "sinyear"), c("temp_group", "windspeed", "hum")) +causal_ordering = list(1, 2:3, 4:6) # Equivalent to using the names (verified) +confounding = c(FALSE, TRUE, FALSE) +asymmetric = TRUE + +progressr::handlers("cli") +explanation_group_asymmetric_causal_time = system.time({ + explanation_group_asymmetric_causal <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:6), + confounding = c(FALSE, TRUE, FALSE), + group = group_list, + seed = 2020, + n_samples = 1000 + ) + }) +}) + +explanation_group_asymmetric_causal$shapley_values_est +sina_plot(explanation_group_asymmetric_causal) + +# Now we compute the group Shapley values based on only half of the coalitions +explanation_group_asymmetric_causal_sampled_time = system.time({ + explanation_group_asymmetric_causal_sampled <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:6), + confounding = confounding, + group = group_list, + n_combinations = explanation_group_asymmetric_causal$internal$parameters$n_combinations_causal_max/2 + 1, + seed = 2020, + n_samples = 1000 + ) + }) +}) + + +# Now we compute the group symmetric causal Shapley values +explanation_group_symmetric_causal_time = system.time({ + explanation_group_symmetric_causal <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:6), #FORTSETT HER MED Å ENDRE OG SE HVA SOM KRÆSJER + confounding = confounding, + group = group_list, + seed = 2020, + n_samples = 
1000 + ) + }) +}) + +explanation_group_symmetric_causal_sampled_time = system.time({ + explanation_group_symmetric_causal_sampled <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = causal_ordering, + confounding = confounding, + group = group_list, + n_combinations = 30, + seed = 2020, + n_samples = 1000 + ) + }) +}) + +# Symmetric Conditional +progressr::handlers("cli") +explanation_group_symmetric_conditional_time = system.time({ + explanation_group_symmetric_conditional <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = NULL, + confounding = FALSE, + group = group_list, + seed = 2020, + n_samples = 1000 + ) + }) +}) + +explanation_group_symmetric_conditional_sampled_time = system.time({ + explanation_group_symmetric_conditional_sampled <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = NULL, + confounding = FALSE, + group = group_list, + n_combinations = 30, + seed = 2020, + n_samples = 1000 + ) + }) +}) + +explanation_group_asymmetric_conditional_time = system.time({ + explanation_group_asymmetric_conditional <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(seq_along(group_list)), + confounding = FALSE, + group = group_list, + seed = 2020, + n_samples = 1000 + ) + }) +}) +explanation_group_asymmetric_conditional$internal$objects$X + +explanation_group_asymmetric_causal_time = system.time({ + explanation_group_asymmetric_causal <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain 
= x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = c(FALSE, TRUE, FALSE), + group = group_list, + seed = 2020, + n_samples = 1000 + ) + }) +}) +explanation_group_asymmetric_causal$internal$objects$X + +explanation_group_asymmetric_conditional$internal$objects$S_causal_strings +explanation_group_asymmetric_causal$internal$objects$S_causal_strings +all.equal(explanation_group_asymmetric_causal$internal$objects$S_causal_strings, + explanation_group_asymmetric_conditional$internal$objects$S_causal_strings) + +explanation_group_asymmetric_conditional_sampled_time = system.time({ + explanation_group_asymmetric_conditional_sampled <- + progressr::with_progress({ + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = FALSE, + n_combinations = 7, + group = group_list, + seed = 2020, + n_samples = 1000 + ) + }) +}) + + +sina_plot(explanation_asymmetric_causal) +sina_plot(explanation_group_asymmetric_causal) +sina_plot(explanation_group_asymmetric_causal_sampled) + +n_index_x_explain = 6 +index_x_explain = order(y_explain)[seq(1, length(y_explain), length.out = n_index_x_explain)] +plot(explanation_group_asymmetric_causal, index_x_explain = index_x_explain) +plot(explanation_group_asymmetric_causal_sampled, index_x_explain = index_x_explain) + +plot(explanation_asymmetric_causal, plot_type = "beeswarm") + + +plot_SV_several_approaches(list(feature = explanation_asymmetric_causal), + index_explicands = index_x_explain) +plot_SV_several_approaches(list(exact = explanation_group_asymmetric_causal, + non_exact = explanation_group_asymmetric_causal_sampled), + index_explicands = index_x_explain, + include_group_feature_means = TRUE) + +plot_SV_several_approaches( + list( + GrAsymCau_exact = explanation_group_asymmetric_causal, + GrAsymCau_non_exact = 
explanation_group_asymmetric_causal_sampled, + GrSymCau_exact = explanation_group_symmetric_causal, + GrSymCau_non_exact = explanation_group_symmetric_causal_sampled, + GrAsymCon_exact = explanation_group_asymmetric_conditional, + GrAsymCon_non_exact = explanation_group_asymmetric_conditional_sampled, + GrSymCon_exact = explanation_group_symmetric_conditional, + GrSymCon_non_exact = explanation_group_symmetric_conditional_sampled + ), + index_explicands = index_x_explain, + brewer_palette = "Paired", + include_group_feature_means = FALSE) diff --git a/inst/scripts/analyze_bash_test_data.R b/inst/scripts/analyze_bash_test_data.R index 519801de3600e3723031e5d70551ecbf37bc13c5..3cd9435e4ef2023b7499509d8f993bf6d1bc296b 100644 --- a/inst/scripts/analyze_bash_test_data.R +++ b/inst/scripts/analyze_bash_test_data.R @@ -52,10 +52,10 @@ dt_time0 <- fread("inst/scripts/timing_test_2023_new2.csv") dt_time0[,n_batches_real:=pmin(2^p-2,n_batches)] -dt_time <- dt_time0[,.(time,secs_explain,timing_setup,timing_test_prediction, timing_setup_computation ,timing_compute_vS ,timing_postprocessing ,timing_shapley_computation, rep,p,n_train,n_explain,n_batches_real,approach,n_combinations)] +dt_time <- dt_time0[,.(time,secs_explain,timing_setup,timing_test_prediction, timing_setup_computation ,timing_compute_vS ,timing_postprocessing ,timing_shapley_computation, rep,p,n_train,n_explain,n_batches_real,approach,n_coalitions)] dt_time[n_batches_real==1,secs_explain_singlebatch :=secs_explain] -dt_time[,secs_explain_singlebatch:=mean(secs_explain_singlebatch,na.rm=T),by=.(p,n_train,n_explain,approach,n_combinations)] +dt_time[,secs_explain_singlebatch:=mean(secs_explain_singlebatch,na.rm=T),by=.(p,n_train,n_explain,approach,n_coalitions)] dt_time[,secs_explain_prop_singlebatch:=secs_explain/secs_explain_singlebatch] ggplot(dt_time[p<14],aes(x=n_batches_real,y=secs_explain,col=as.factor(n_explain),linetype=as.factor(n_train)))+ @@ -101,14 +101,14 @@ ggplot(dt_time[p<16& p>2 & 
approach=="empirical"],aes(x=n_batches_real,y=secs_ex # max 100, min 10 n_batches_fun <- function(approach,p){ - n_combinations <- 2^p-2 + n_coalitions <- 2^p-2 if(approach %in% c("ctree","gaussian","copula")){ - init <- ceiling(n_combinations/10) + init <- ceiling(n_coalitions/10) floor <- max(c(10,init)) ret <- min(c(1000,floor)) } else { - init <- ceiling(n_combinations/100) + init <- ceiling(n_coalitions/100) floor <- max(c(2,init)) ret <- min(c(100,floor)) } diff --git a/inst/scripts/check_model_workflow.R b/inst/scripts/check_model_workflow.R index 01799eae15d2a865f70931f32c30846afd34bd72..296c090aed79ad2d16a3f099ac1d3545b4d7db99 100644 --- a/inst/scripts/check_model_workflow.R +++ b/inst/scripts/check_model_workflow.R @@ -50,7 +50,7 @@ explain_workflow = explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) @@ -59,12 +59,12 @@ explain_xgboost = explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) # See that the shapley values are identical -all.equal(explain_workflow$shapley_values, explain_xgboost$shapley_values) +all.equal(explain_workflow$shapley_values_est, explain_xgboost$shapley_values_est) # Other models in workflow --------------------------------------------------------------------------------------------- set.seed(1) @@ -103,7 +103,7 @@ explain_decision_tree_ctree = explain( x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "ctree", - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) @@ -113,7 +113,7 @@ explain_decision_tree_lm = explain( x_train = x_train_mixed, approach = "regression_separate", regression.model = parsnip::linear_reg(), - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) @@ -149,7 +149,7 @@ explain_decision_model_rf_cv_rf = explain( x_train = x_train_mixed, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = 
"regression"), - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) @@ -159,7 +159,7 @@ explain_decision_model_rf_cv_ctree = explain( x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "ctree", - prediction_zero = p0, + phi0 = p0, n_batches = 4 ) diff --git a/inst/scripts/compare_copula_in_R_and_C++.R b/inst/scripts/compare_copula_in_R_and_C++.R index fd6b1cfb4ae7ddcd0b20f6157f5b6c5e2636a050..f3811c7fa583b0355082adf4ec608c7205562ddb 100644 --- a/inst/scripts/compare_copula_in_R_and_C++.R +++ b/inst/scripts/compare_copula_in_R_and_C++.R @@ -41,10 +41,10 @@ prepare_data.copula_old <- function(internal, index_features = NULL, ...) { x_train = as.matrix(x_train), x_explain_gaussian = as.matrix(copula.x_explain_gaussian)[i, , drop = FALSE] ) - dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination") + dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_coalition") dt_l[[i]][, w := 1 / n_samples] dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE) @@ -171,7 +171,7 @@ prepare_data.copula_cpp_arma <- function(internal, index_features, ...) { n_explain <- internal$parameters$n_explain n_samples <- internal$parameters$n_samples n_features <- internal$parameters$n_features - n_combinations_now <- length(index_features) + n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu @@ -199,16 +199,16 @@ prepare_data.copula_cpp_arma <- function(internal, index_features, ...) { ) # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). 
- dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + dim(dt) <- c(n_coalitions_now * n_explain * n_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_samples * n_explain)] dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } @@ -229,7 +229,7 @@ prepare_data.copula_cpp_and_R <- function(internal, index_features, ...) { n_explain <- internal$parameters$n_explain n_samples <- internal$parameters$n_samples n_features <- internal$parameters$n_features - n_combinations_now <- length(index_features) + n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu @@ -257,16 +257,16 @@ prepare_data.copula_cpp_and_R <- function(internal, index_features, ...) { ) # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). 
- dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + dim(dt) <- c(n_coalitions_now * n_explain * n_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_samples * n_explain)] dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } @@ -327,7 +327,7 @@ prepare_data.copula_sourceCpp <- function(internal, index_features, ...) { n_explain <- internal$parameters$n_explain n_samples <- internal$parameters$n_samples n_features <- internal$parameters$n_features - n_combinations_now <- length(index_features) + n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu @@ -351,16 +351,16 @@ prepare_data.copula_sourceCpp <- function(internal, index_features, ...) { ) # Reshape `dt` to a 2D array of dimension (n_samples * n_explain * n_coalitions, n_features). 
- dim(dt) <- c(n_combinations_now * n_explain * n_samples, n_features) + dim(dt) <- c(n_coalitions_now * n_explain * n_samples, n_features) # Convert to a data.table and add extra identification columns dt <- data.table::as.data.table(dt) data.table::setnames(dt, feature_names) - dt[, id_combination := rep(seq_len(nrow(S)), each = n_samples * n_explain)] + dt[, id_coalition := rep(seq_len(nrow(S)), each = n_samples * n_explain)] dt[, id := rep(seq(n_explain), each = n_samples, times = nrow(S))] dt[, w := 1 / n_samples] - dt[, id_combination := index_features[id_combination]] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + dt[, id_coalition := index_features[id_coalition]] + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) return(dt) } @@ -444,7 +444,7 @@ using namespace Rcpp; // observations to explain after being transformed using the Gaussian transform, i.e., the samples have been // transformed to a standardized normal distribution. // @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -// @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +// @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of // the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. // This is not a problem internally in shapr as the empty and grand coalitions treated differently. // @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -642,7 +642,7 @@ arma::mat inv_gaussian_transform_cpp_arma(arma::mat z, arma::mat x) { // observations to explain after being transformed using the Gaussian transform, i.e., the samples have been // transformed to a standardized normal distribution. // @param x_train_mat arma::mat. 
Matrix of dimension (`n_train`, `n_features`) containing the training observations. -// @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +// @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of // the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. // This is not a problem internally in shapr as the empty and grand coalitions treated differently. // @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -747,7 +747,7 @@ arma::cube prepare_data_copula_cpp_arma(arma::mat MC_samples_mat, // observations to explain after being transformed using the Gaussian transform, i.e., the samples have been // transformed to a standardized normal distribution. // @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -// @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +// @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of // the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. // This is not a problem internally in shapr as the empty and grand coalitions treated differently. // @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed @@ -906,7 +906,7 @@ arma::cube prepare_data_copula_cpp_and_R(arma::mat MC_samples_mat, predictive_model <- lm(y ~ ., data = data_train_with_response) # Get the prediction zero, i.e., the phi0 Shapley value. 
- prediction_zero <- mean(response_train) + phi0 <- mean(response_train) model <- predictive_model x_explain <- data_test @@ -915,7 +915,7 @@ arma::cube prepare_data_copula_cpp_and_R(arma::mat MC_samples_mat, predict_model <- NULL get_model_specs <- NULL timing <- TRUE - n_combinations <- NULL + n_coalitions <- NULL group <- NULL feature_specs <- shapr:::get_feature_specs(get_model_specs, model) n_batches <- 1 @@ -925,8 +925,8 @@ arma::cube prepare_data_copula_cpp_and_R(arma::mat MC_samples_mat, x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = n_combinations, + phi0 = phi0, + n_coalitions = n_coalitions, group = group, n_samples = n_samples, n_batches = n_batches, @@ -959,7 +959,7 @@ feature_names <- internal$parameters$feature_names n_explain <- internal$parameters$n_explain n_samples <- internal$parameters$n_samples n_features <- internal$parameters$n_features -n_combinations_now <- length(index_features) +n_coalitions_now <- length(index_features) x_train_mat <- as.matrix(internal$data$x_train) x_explain_mat <- as.matrix(internal$data$x_explain) copula.mu <- internal$parameters$copula.mu @@ -1060,7 +1060,7 @@ time_only_cpp <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp, c("id", "id_combination")) +data.table::setorderv(res_only_cpp, c("id", "id_coalition")) time_only_cpp # The C++ code with my own quantile function @@ -1070,7 +1070,7 @@ time_only_cpp_sourceCpp <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp_sourceCpp, c("id", "id_combination")) +data.table::setorderv(res_only_cpp_sourceCpp, c("id", "id_coalition")) time_only_cpp_sourceCpp # The C++ code with quantile functions from arma @@ -1080,7 +1080,7 @@ time_only_cpp_arma <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) 
-data.table::setorderv(res_only_cpp_arma, c("id", "id_combination")) +data.table::setorderv(res_only_cpp_arma, c("id", "id_coalition")) time_only_cpp_arma # The new C++ code with quantile from R @@ -1090,7 +1090,7 @@ time_cpp_and_R <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_cpp_and_R, c("id", "id_combination")) +data.table::setorderv(res_cpp_and_R, c("id", "id_coalition")) time_cpp_and_R # Create a table of the times. Less is better @@ -1131,11 +1131,11 @@ res_only_cpp <- res_only_cpp[, w := NULL] res_only_cpp_sourceCpp <- res_only_cpp_sourceCpp[, w := NULL] res_only_cpp_arma <- res_only_cpp_arma[, w := NULL] res_cpp_and_R <- res_cpp_and_R[, w := NULL] -res_only_R_agr <- res_only_R[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_agr <- res_only_cpp[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_sourceCpp_agr <- res_only_cpp_sourceCpp[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_arma_agr <- res_only_cpp_arma[, lapply(.SD, mean), by = c("id", "id_combination")] -res_cpp_and_R_agr <- res_cpp_and_R[, lapply(.SD, mean), by = c("id", "id_combination")] +res_only_R_agr <- res_only_R[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_agr <- res_only_cpp[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_sourceCpp_agr <- res_only_cpp_sourceCpp[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_arma_agr <- res_only_cpp_arma[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_cpp_and_R_agr <- res_cpp_and_R[, lapply(.SD, mean), by = c("id", "id_coalition")] # Difference res_only_R_agr - res_only_cpp_agr @@ -1400,7 +1400,7 @@ all.equal(shapr_mat_arma_res, sourceCpp_mat_arma_res) predictive_model <- lm(y ~ ., data = data_train_with_response) # Get the prediction zero, i.e., the phi0 Shapley value. 
- prediction_zero <- mean(response_train) + phi0 <- mean(response_train) model <- predictive_model x_explain <- data_test @@ -1409,7 +1409,7 @@ all.equal(shapr_mat_arma_res, sourceCpp_mat_arma_res) predict_model <- NULL get_model_specs <- NULL timing <- TRUE - n_combinations <- NULL + n_coalitions <- NULL group <- NULL feature_specs <- shapr:::get_feature_specs(get_model_specs, model) n_batches <- 1 @@ -1419,8 +1419,8 @@ all.equal(shapr_mat_arma_res, sourceCpp_mat_arma_res) x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = n_combinations, + phi0 = phi0, + n_coalitions = n_coalitions, group = group, n_samples = n_samples, n_batches = n_batches, @@ -1464,7 +1464,7 @@ time_only_cpp <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp, c("id", "id_combination")) +data.table::setorderv(res_only_cpp, c("id", "id_coalition")) time_only_cpp # The C++ code with my own quantile function @@ -1474,7 +1474,7 @@ time_only_cpp_sourceCpp <- system.time({ index_features = internal$objects$S_batch$`1`[look_at_coalitions] ) }) -data.table::setorderv(res_only_cpp_sourceCpp, c("id", "id_combination")) +data.table::setorderv(res_only_cpp_sourceCpp, c("id", "id_coalition")) time_only_cpp_sourceCpp # Look at the differences @@ -1482,9 +1482,9 @@ time_only_cpp_sourceCpp # res_only_R <- res_only_R[, w := NULL] # res_only_cpp <- res_only_cpp[, w := NULL] # res_only_cpp_sourceCpp <- res_only_cpp_sourceCpp[, w := NULL] -res_only_R_agr <- res_only_R[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_agr <- res_only_cpp[, lapply(.SD, mean), by = c("id", "id_combination")] -res_only_cpp_sourceCpp_agr <- res_only_cpp_sourceCpp[, lapply(.SD, mean), by = c("id", "id_combination")] +res_only_R_agr <- res_only_R[, lapply(.SD, mean), by = c("id", "id_coalition")] +res_only_cpp_agr <- res_only_cpp[, lapply(.SD, mean), by = c("id", "id_coalition")] 
+res_only_cpp_sourceCpp_agr <- res_only_cpp_sourceCpp[, lapply(.SD, mean), by = c("id", "id_coalition")] # Difference res_only_R_agr - res_only_cpp_agr @@ -1511,7 +1511,7 @@ temp_shapley_value_func = function(dt, internal, model, predict_model) { xreg = internal$data$xreg ) dt_vS2 <- compute_MCint(dt, paste0("p_hat", seq_len(internal$parameters$output_size))) - dt_vS <- rbind(t(as.matrix(c(1, rep(prediction_zero, n_test)))), dt_vS2, t(as.matrix(c(2^M, response_test))), + dt_vS <- rbind(t(as.matrix(c(1, rep(phi0, n_test)))), dt_vS2, t(as.matrix(c(2^M, response_test))), use.names = FALSE) colnames(dt_vS) = colnames(dt_vS2) compute_shapley_new(internal, dt_vS) diff --git a/inst/scripts/compare_gaussian_in_R_and_C++.R b/inst/scripts/compare_gaussian_in_R_and_C++.R index b9ca398aae466c48c47a0eba5756cc32e06a0a00..b358c9127cdfc611496105b7740cea5169e52d84 100644 --- a/inst/scripts/compare_gaussian_in_R_and_C++.R +++ b/inst/scripts/compare_gaussian_in_R_and_C++.R @@ -63,7 +63,7 @@ sample_gaussian <- function(index_given, n_samples, mu, cov_mat, m, x_explain) { # //' univariate standard normal. # //' @param x_explain_mat matrix. Matrix of dimension `n_explain` times `n_features` containing the observations # //' to explain. -# //' @param S matrix. Matrix of dimension `n_combinations` times `n_features` containing binary representations of +# //' @param S matrix. Matrix of dimension `n_coalitions` times `n_features` containing binary representations of # //' the used coalitions. # //' @param mu vector. Vector of length `n_features` containing the mean of each feature. # //' @param cov_mat mat. 
Matrix of dimension `n_features` times `n_features` containing the pariwise covariance between @@ -72,7 +72,7 @@ sample_gaussian <- function(index_given, n_samples, mu, cov_mat, m, x_explain) { # //' @export # //' @keywords internal # //' -# //' @return List of length `n_combinations`*`n_samples`, where each entry is a matrix of dimension `n_samples` times +# //' @return List of length `n_coalitions`*`n_samples`, where each entry is a matrix of dimension `n_samples` times # //' `n_features` containing the conditional MC samples for each coalition and explicand. # //' @author Lars Henry Berge Olsen # // [[Rcpp::export]] @@ -728,10 +728,10 @@ prepare_data_gaussian_old <- function(internal, index_features = NULL, ...) { x_explain = x_explain0[i, , drop = FALSE] ) - dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_combination") + dt_l[[i]] <- data.table::rbindlist(l, idcol = "id_coalition") dt_l[[i]][, w := 1 / n_samples] dt_l[[i]][, id := i] - if (!is.null(index_features)) dt_l[[i]][, id_combination := index_features[id_combination]] + if (!is.null(index_features)) dt_l[[i]][, id_coalition := index_features[id_coalition]] } dt <- data.table::rbindlist(dt_l, use.names = TRUE, fill = TRUE) @@ -756,7 +756,7 @@ prepare_data_gaussian_new_v1 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -840,18 +840,18 @@ prepare_data_gaussian_new_v1 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. 
- if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -875,7 +875,7 @@ prepare_data_gaussian_new_v2 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -962,18 +962,18 @@ prepare_data_gaussian_new_v2 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -997,7 +997,7 @@ prepare_data_gaussian_new_v3 <- function(internal, index_features, ...) 
{ n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1090,18 +1090,18 @@ prepare_data_gaussian_new_v3 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1124,7 +1124,7 @@ prepare_data_gaussian_new_v4 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1213,18 +1213,18 @@ prepare_data_gaussian_new_v4 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. 
- if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1248,7 +1248,7 @@ prepare_data_gaussian_new_v5 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1338,18 +1338,18 @@ prepare_data_gaussian_new_v5 <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1371,7 +1371,7 @@ prepare_data_gaussian_new_v5_rnorm <- function(internal, index_features, ...) 
{ n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1467,18 +1467,18 @@ prepare_data_gaussian_new_v5_rnorm <- function(internal, index_features, ...) { ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1500,7 +1500,7 @@ prepare_data_gaussian_new_v5_rnorm_v2 <- function(internal, index_features, ...) n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1593,18 +1593,18 @@ prepare_data_gaussian_new_v5_rnorm_v2 <- function(internal, index_features, ...) ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. 
- if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1628,7 +1628,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp <- function(internal, index_features, ... n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1647,19 +1647,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp <- function(internal, index_features, ... dt = as.data.table(do.call(rbind, result_list)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1681,7 +1681,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_with_wrap <- function(internal, index_fea n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1700,19 +1700,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_with_wrap <- function(internal, index_fea dt = as.data.table(do.call(rbind, result_list)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1735,7 +1735,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_v2 <- function(internal, index_features, n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1754,19 +1754,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_v2 <- function(internal, index_features, dt = as.data.table(do.call(rbind, result_list)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1788,7 +1788,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_large_mat <- function(internal, index n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1807,19 +1807,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_large_mat <- function(internal, index cov_mat = cov_mat) ) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1841,7 +1841,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_large_mat_v2 <- function(internal, in n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1860,19 +1860,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_large_mat_v2 <- function(internal, in cov_mat = cov_mat) ) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1894,7 +1894,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_list_of_lists_of_matrices <- function n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1914,19 +1914,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_list_of_lists_of_matrices <- function # Here we first put the inner list together and then the whole thing. Maybe exist another faster way! dt = as.data.table(do.call(rbind, lapply(result_list, function(inner_list) do.call(rbind, inner_list)))) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -1948,7 +1948,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_cube <- function(internal, index_feat n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -1975,19 +1975,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_cube <- function(internal, index_feat dim(result_cube) <- c(prod(dims[-2]), dims[2]) dt = as.data.table(result_cube) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -2009,8 +2009,8 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_cube_v2 <- function(internal, index_f n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations - n_combinations_now <- length(index_features) + n_coalitions <- internal$parameters$n_coalitions + n_coalitions_now <- length(index_features) # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -2028,22 +2028,22 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_cube_v2 <- function(internal, index_f cov_mat = cov_mat) # Reshape and convert to data.table - dim(dt) = c(n_combinations_now*n_explain*n_samples, n_features) + dim(dt) = c(n_coalitions_now*n_explain*n_samples, n_features) print(system.time({dt = as.data.table(dt)}, gcFirst = FALSE)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -2065,7 +2065,7 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_std_list <- function(internal, index_ n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. @@ -2090,19 +2090,19 @@ prepare_data_gaussian_new_v5_rnorm_cpp_fix_std_list <- function(internal, index_ # Here we first put the inner list together and then the whole thing. Maybe exist another faster way! dt = as.data.table(do.call(rbind, result_list)) setnames(dt, feature_names) - dt[, "id_combination" := rep(seq(nrow(S)), each = n_samples * n_explain)] + dt[, "id_coalition" := rep(seq(nrow(S)), each = n_samples * n_explain)] dt[, "id" := rep(seq(n_explain), each = n_samples, times = nrow(S))] - data.table::setcolorder(dt, c("id_combination", "id", feature_names)) + data.table::setcolorder(dt, c("id_coalition", "id", feature_names)) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. 
+ if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -2126,19 +2126,19 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) { n_features <- internal$parameters$n_features n_samples <- internal$parameters$n_samples feature_names <- internal$parameters$feature_names - n_combinations <- internal$parameters$n_combinations + n_coalitions <- internal$parameters$n_coalitions # Extract the relevant coalitions specified in `index_features` from `S`. # This will always be called as `index_features` is never NULL. S <- if (!is.null(index_features)) S[index_features, , drop = FALSE] - n_combinations_in_this_batch <- nrow(S) + n_coalitions_in_this_batch <- nrow(S) # Allocate an empty matrix used in mvnfast:::rmvnCpp to store the generated MC samples. - B <- matrix(nrow = n_samples * n_combinations_in_this_batch, ncol = n_features) + B <- matrix(nrow = n_samples * n_coalitions_in_this_batch, ncol = n_features) class(B) <- "numeric" .Call("rmvnCpp", - n_ = n_samples * n_combinations_in_this_batch, + n_ = n_samples * n_coalitions_in_this_batch, mu_ = rep(0, n_features), sigma_ = diag(n_features), ncores_ = 1, @@ -2148,7 +2148,7 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) { ) # Indices of the start for the combinations - B_indices <- n_samples * (seq(0, n_combinations_in_this_batch)) + 1 + B_indices <- n_samples * (seq(0, n_coalitions_in_this_batch)) + 1 # Generate a data table containing all Monte Carlo samples for all test observations and coalitions dt <- data.table::rbindlist( @@ -2221,18 +2221,18 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) 
{ ) } ), - idcol = "id_combination" + idcol = "id_coalition" ) - # Update the id_combination. This will always be called as `index_features` is never NULL. - if (!is.null(index_features)) dt[, id_combination := index_features[id_combination]] + # Update the id_coalition. This will always be called as `index_features` is never NULL. + if (!is.null(index_features)) dt[, id_coalition := index_features[id_coalition]] # Add uniform weights dt[, w := 1 / n_samples] # Remove: # This is not needed when we assume that the empty and grand coalitions will never be present - # dt[id_combination %in% c(1, n_combinations), w := 1] + # dt[id_coalition %in% c(1, n_coalitions), w := 1] # Return the MC samples return(dt) @@ -2289,7 +2289,7 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) { predictive_model <- lm(y ~ ., data = data_train_with_response) # Get the prediction zero, i.e., the phi0 Shapley value. - prediction_zero <- mean(response_train) + phi0 <- mean(response_train) model <- predictive_model x_explain <- data_test @@ -2298,7 +2298,7 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) { predict_model <- NULL get_model_specs <- NULL timing <- TRUE - n_combinations <- NULL + n_coalitions <- NULL group <- NULL feature_specs <- get_feature_specs(get_model_specs, model) n_batches <- 1 @@ -2308,8 +2308,8 @@ prepare_data_gaussian_new_v6 <- function(internal, index_features, ...) 
{ x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = n_combinations, + phi0 = phi0, + n_coalitions = n_coalitions, group = group, n_samples = n_samples, n_batches = n_batches, @@ -2688,25 +2688,25 @@ rbind(one_coalition_time_old, one_coalition_time_new_v6) internal$objects$S[internal$objects$S_batch$`1`[look_at_coalition], , drop = FALSE] -means_old <- one_coalition_res_old[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_old2 <- one_coalition_res_old2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v1 <- one_coalition_res_new_v1[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v2 <- one_coalition_res_new_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v3 <- one_coalition_res_new_v3[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v4 <- one_coalition_res_new_v4[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5 <- one_coalition_res_new_v5[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm <- one_coalition_res_new_v5_rnorm[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_v2 <- one_coalition_res_new_v5_rnorm_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp <- one_coalition_res_new_v5_rnorm_cpp[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_with_wrap <- one_coalition_res_new_v5_rnorm_cpp_with_wrap[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_v2 <- one_coalition_res_new_v5_rnorm_cpp_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] 
-means_v5_rnorm_cpp_fix_large_mat <- one_coalition_res_new_v5_rnorm_cpp_fix_large_mat[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_large_mat_v2 <- one_coalition_res_new_v5_rnorm_cpp_fix_large_mat_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_cube <- one_coalition_res_new_v5_rnorm_cpp_fix_cube[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_cube_v2 <- one_coalition_res_new_v5_rnorm_cpp_fix_cube_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_list_of_lists_of_matrices <- one_coalition_res_new_v5_rnorm_cpp_fix_list_of_lists_of_matrices[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v5_rnorm_cpp_fix_std_list <- one_coalition_res_new_v5_rnorm_cpp_fix_std_list[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] -means_v6 <- one_coalition_res_new_v6[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_combination, id)] +means_old <- one_coalition_res_old[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_old2 <- one_coalition_res_old2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v1 <- one_coalition_res_new_v1[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v2 <- one_coalition_res_new_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v3 <- one_coalition_res_new_v3[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v4 <- one_coalition_res_new_v4[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5 <- one_coalition_res_new_v5[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] 
+means_v5_rnorm <- one_coalition_res_new_v5_rnorm[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_v2 <- one_coalition_res_new_v5_rnorm_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp <- one_coalition_res_new_v5_rnorm_cpp[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_with_wrap <- one_coalition_res_new_v5_rnorm_cpp_with_wrap[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_v2 <- one_coalition_res_new_v5_rnorm_cpp_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_large_mat <- one_coalition_res_new_v5_rnorm_cpp_fix_large_mat[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_large_mat_v2 <- one_coalition_res_new_v5_rnorm_cpp_fix_large_mat_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_cube <- one_coalition_res_new_v5_rnorm_cpp_fix_cube[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_cube_v2 <- one_coalition_res_new_v5_rnorm_cpp_fix_cube_v2[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_list_of_lists_of_matrices <- one_coalition_res_new_v5_rnorm_cpp_fix_list_of_lists_of_matrices[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v5_rnorm_cpp_fix_std_list <- one_coalition_res_new_v5_rnorm_cpp_fix_std_list[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] +means_v6 <- one_coalition_res_new_v6[, lapply(.SD, mean), .SDcols = paste0("X", seq(M)), by = list(id_coalition, id)] # They are all in the same ballpark, so the differences are due to sampling. 
# This is supported by the fact that mean_old and mean_old2 use the same old code, and the difference there is the diff --git a/inst/scripts/compare_shap_python.R b/inst/scripts/compare_shap_python.R index 6a4ed7787ef2334ce77a9355e62daee2ab2a02d9..ebc39e2c32593ff82c0b1fb5e85f7ccddbce0a38 100644 --- a/inst/scripts/compare_shap_python.R +++ b/inst/scripts/compare_shap_python.R @@ -47,12 +47,12 @@ time_R_prepare <- proc.time() # Computing the actual Shapley values with kernelSHAP accounting for feature dependence using # the empirical (conditional) distribution approach with bandwidth parameter sigma = 0.1 (default) -explanation_independence <- explain(x_test, explainer, approach = "independence", prediction_zero = p0) +explanation_independence <- explain(x_test, explainer, approach = "independence", phi0 = p0) time_R_indep0 <- proc.time() explanation_largesigma <- explain(x_test, explainer, approach = "empirical", type = "fixed_sigma", - fixed_sigma_vec = 10000, w_threshold = 1, prediction_zero = p0) + fixed_sigma_vec = 10000, w_threshold = 1, phi0 = p0) time_R_largesigma0 <- proc.time() diff --git a/inst/scripts/compare_shap_python_new.R b/inst/scripts/compare_shap_python_new.R index c15fed9d69c0807815853181752843b24dd5c6f8..5e51120f462ecd2f40ae05ea9dc0236c24b81ca9 100644 --- a/inst/scripts/compare_shap_python_new.R +++ b/inst/scripts/compare_shap_python_new.R @@ -40,14 +40,14 @@ time_R_start <- proc.time() # Computing the actual Shapley values with kernelSHAP accounting for feature dependence using # the empirical (conditional) distribution approach with bandwidth parameter sigma = 0.1 (default) explanation_independence <- explain(model = model,x_explain = x_test,x_train=x_train, - approach = "independence", prediction_zero = p0,n_batches = 1) + approach = "independence", phi0 = p0,n_batches = 1) time_R_indep0 <- proc.time() explanation_largesigma <- explain(model = model,x_explain = x_test,x_train=x_train, approach = 
"empirical",empirical.type="fixed_sigma",empirical.fixed_sigma=10000,empirical.eta=1, - prediction_zero = p0,n_batches=1) + phi0 = p0,n_batches=1) time_R_largesigma0 <- proc.time() @@ -56,8 +56,8 @@ time_R_largesigma0 <- proc.time() (time_R_largesigma <- time_R_largesigma0 - time_R_indep0) # Printing the Shapley values for the test data -Kshap_indep <- explanation_independence$shapley_values -Kshap_largesigma <- explanation_largesigma$shapley_values +Kshap_indep <- explanation_independence$shapley_values_est +Kshap_largesigma <- explanation_largesigma$shapley_values_est Kshap_indep Kshap_largesigma diff --git a/inst/scripts/devel/Rscript_test_shapr.R b/inst/scripts/devel/Rscript_test_shapr.R index 8f8b5a504e475a64ea97df0efb6f4d41fb7a08fb..03380a6ed3cded3a5706c498d8d31ca8c1d85ef8 100644 --- a/inst/scripts/devel/Rscript_test_shapr.R +++ b/inst/scripts/devel/Rscript_test_shapr.R @@ -62,7 +62,7 @@ sys_time_start_shapr <- Sys.time() explainer <- shapr(x_train, model) sys_time_end_shapr <- Sys.time() -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) n_batches_use <- min(nrow(explainer$S),n_batches) @@ -73,7 +73,7 @@ explanation <- explain( x_test, approach = approach, explainer = explainer, - prediction_zero = prediction_zero, + phi0 = phi0, n_batches = n_batches_use ) sys_time_end_explain <- Sys.time() diff --git a/inst/scripts/devel/compare_explain_batch.R b/inst/scripts/devel/compare_explain_batch.R index cedf257fbe5bfabf762ee03a94d99913283bb141..48544bd8036d1b6ebf7468d97c162fee21b44531 100644 --- a/inst/scripts/devel/compare_explain_batch.R +++ b/inst/scripts/devel/compare_explain_batch.R @@ -23,15 +23,15 @@ model <- xgboost( # THIS IS GENERATED FROM MASTER BRANCH # Prepare the data for explanation library(shapr) -explainer <- shapr(x_train, model,n_combinations = 100) +explainer <- shapr(x_train, model,n_coalitions = 100) p = mean(y_train) -gauss = explain(x_test, explainer, "gaussian", prediction_zero = p, n_samples = 10000) -emp = explain(x_test, explainer, 
"empirical", prediction_zero = p, n_samples = 10000) -copula = explain(x_test, explainer, "copula", prediction_zero = p, n_samples = 10000) -indep = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000) -comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), prediction_zero = p, n_samples = 10000) -ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, prediction_zero = p, n_samples = 10000) -ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), prediction_zero = p, n_samples = 10000) +gauss = explain(x_test, explainer, "gaussian", phi0 = p, n_samples = 10000) +emp = explain(x_test, explainer, "empirical", phi0 = p, n_samples = 10000) +copula = explain(x_test, explainer, "copula", phi0 = p, n_samples = 10000) +indep = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000) +comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), phi0 = p, n_samples = 10000) +ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, phi0 = p, n_samples = 10000) +ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), phi0 = p, n_samples = 10000) #saveRDS(list(gauss = gauss, empirical = emp, copula = copula, indep = indep, comb = comb, ctree = ctree, ctree_comb = ctree2), file = "inst/scripts/devel/master_res2.rds") # saveRDS(list(ctree = ctree, ctree_comb = ctree2), file = "inst/scripts/devel/master_res_ctree.rds") @@ -40,15 +40,15 @@ detach("package:shapr", unload = TRUE) devtools::load_all() nobs = 6 x_test <- as.matrix(Boston[1:nobs, x_var]) -explainer <- shapr(x_train, model,n_combinations = 100) +explainer <- shapr(x_train, model,n_coalitions = 100) p = mean(y_train) -gauss = explain(x_test, explainer, "gaussian", prediction_zero = p, n_samples = 10000, n_batches = 1) -emp = explain(x_test, explainer, "empirical", prediction_zero = p, n_samples = 10000, n_batches = 1) -copula = 
explain(x_test, explainer, "copula", prediction_zero = p, n_samples = 10000, n_batches = 1) -indep = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000, n_batches = 1) -comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), prediction_zero = p, n_samples = 10000, n_batches = 1) -ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, prediction_zero = p, n_samples = 10000, n_batches = 1) -ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), prediction_zero = p, n_samples = 10000, n_batches = 1) +gauss = explain(x_test, explainer, "gaussian", phi0 = p, n_samples = 10000, n_batches = 1) +emp = explain(x_test, explainer, "empirical", phi0 = p, n_samples = 10000, n_batches = 1) +copula = explain(x_test, explainer, "copula", phi0 = p, n_samples = 10000, n_batches = 1) +indep = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000, n_batches = 1) +comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), phi0 = p, n_samples = 10000, n_batches = 1) +ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, phi0 = p, n_samples = 10000, n_batches = 1) +ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), phi0 = p, n_samples = 10000, n_batches = 1) res = readRDS("inst/scripts/devel/master_res2.rds") @@ -60,8 +60,8 @@ res$comb$dt comb$dt # With batches -gauss_b = explain(x_test, explainer, "gaussian", prediction_zero = p, n_samples = 10000, n_batches = 3) -emp_b = explain(x_test, explainer, "empirical", prediction_zero = p, n_samples = 10000, n_batches = 3) +gauss_b = explain(x_test, explainer, "gaussian", phi0 = p, n_samples = 10000, n_batches = 3) +emp_b = explain(x_test, explainer, "empirical", phi0 = p, n_samples = 10000, n_batches = 3) gauss_b$dt res$gauss$dt @@ -71,7 +71,7 @@ res$empirical$dt #### MJ stuff here: -explain.independence2 <- function(x, explainer, approach, 
prediction_zero, +explain.independence2 <- function(x, explainer, approach, phi0, n_samples = 1e3, n_batches = 1, seed = 1, only_return_contrib_dt = FALSE, ...) { @@ -82,12 +82,12 @@ explain.independence2 <- function(x, explainer, approach, prediction_zero, explainer$approach <- approach explainer$n_samples <- n_samples - r <- prepare_and_predict(explainer, n_batches, prediction_zero, only_return_contrib_dt, ...) + r <- prepare_and_predict(explainer, n_batches, phi0, only_return_contrib_dt, ...) } prepare_data.independence2 <- function(x, index_features = NULL, ...) { - id <- id_combination <- w <- NULL # due to NSE notes in R CMD check + id <- id_coalition <- w <- NULL # due to NSE notes in R CMD check if (is.null(index_features)) { index_features <- x$X[, .I] @@ -122,7 +122,7 @@ prepare_data.independence2 <- function(x, index_features = NULL, ...) { # Add keys dt_l[[i]] <- data.table::as.data.table(dt_p) data.table::setnames(dt_l[[i]], colnames(x_train)) - dt_l[[i]][, id_combination := index_s] + dt_l[[i]][, id_coalition := index_s] dt_l[[i]][, w := w] # IS THIS NECESSARY? dt_l[[i]][, id := i] } @@ -137,36 +137,36 @@ prepare_data.independence2 <- function(x, index_features = NULL, ...) 
{ # Using independence with n_samples > nrow(x_train) such that no sampling is performed -indep1 = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000, n_batches = 1) -indep2 = explain(x_test, explainer, "independence2", prediction_zero = p, n_samples = 10000, n_batches = 1) +indep1 = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000, n_batches = 1) +indep2 = explain(x_test, explainer, "independence2", phi0 = p, n_samples = 10000, n_batches = 1) all.equal(indep1,indep2) # TRUE -indep1_batch_2 = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000, n_batches = 2) +indep1_batch_2 = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000, n_batches = 2) all.equal(indep1,indep1_batch_2) # TRUE -indep1_batch_5 = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000, n_batches = 5) +indep1_batch_5 = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000, n_batches = 5) all.equal(indep1,indep1_batch_5) # TRUE -comb_indep_1_batch_1 = explain(x_test, explainer, c("independence", "independence", "independence", "independence"), prediction_zero = p, n_samples = 10000, n_batches = 1) +comb_indep_1_batch_1 = explain(x_test, explainer, c("independence", "independence", "independence", "independence"), phi0 = p, n_samples = 10000, n_batches = 1) all.equal(indep1,comb_indep_1_batch_1) # TRUE -comb_indep_1_batch_2 = explain(x_test, explainer, c("independence", "independence", "independence", "independence"), prediction_zero = p, n_samples = 10000, n_batches = 2) +comb_indep_1_batch_2 = explain(x_test, explainer, c("independence", "independence", "independence", "independence"), phi0 = p, n_samples = 10000, n_batches = 2) all.equal(indep1,comb_indep_1_batch_2) # TRUE -comb_indep_1_2_batch_1 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), prediction_zero = p, n_samples = 10000, n_batches = 1) 
+comb_indep_1_2_batch_1 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), phi0 = p, n_samples = 10000, n_batches = 1) all.equal(indep1,comb_indep_1_2_batch_1) #TRUE -comb_indep_1_2_batch_2 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), prediction_zero = p, n_samples = 10000, n_batches = 2) +comb_indep_1_2_batch_2 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), phi0 = p, n_samples = 10000, n_batches = 2) all.equal(indep1,comb_indep_1_2_batch_2) #TRUE -comb_indep_1_2_batch_5 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), prediction_zero = p, n_samples = 10000, n_batches = 5) +comb_indep_1_2_batch_5 = explain(x_test, explainer, c("independence", "independence", "independence2", "independence2"), phi0 = p, n_samples = 10000, n_batches = 5) all.equal(indep1,comb_indep_1_2_batch_5) #TRUE diff --git a/inst/scripts/devel/compare_indep_implementations.R b/inst/scripts/devel/compare_indep_implementations.R index a508e2d1e12aafc5deadc2678b26fdf0a53bc290..ae035b492aaeba4ff72692b05aa1447bb53c52b8 100644 --- a/inst/scripts/devel/compare_indep_implementations.R +++ b/inst/scripts/devel/compare_indep_implementations.R @@ -37,7 +37,7 @@ explanation_old <- explain( approach = "empirical", type = "independence", explainer = explainer, - prediction_zero = p, seed=111,n_samples = 100 + phi0 = p, seed=111,n_samples = 100 ) print(proc.time()-t_old) #user system elapsed @@ -48,7 +48,7 @@ explanation_new <- explain( x_test, approach = "independence", explainer = explainer, - prediction_zero = p,seed = 111,n_samples = 100 + phi0 = p,seed = 111,n_samples = 100 ) print(proc.time()-t_new) #user system elapsed @@ -69,7 +69,7 @@ explanation_full_old <- explain( approach = "empirical", type = "independence", explainer = explainer, - prediction_zero = p, seed=111 + phi0 = p, seed=111 ) 
print(proc.time()-t_old) #user system elapsed @@ -80,7 +80,7 @@ explanation_full_new <- explain( x_test, approach = "independence", explainer = explainer, - prediction_zero = p,seed = 111 + phi0 = p,seed = 111 ) print(proc.time()-t_new) #user system elapsed diff --git a/inst/scripts/devel/demonstrate_combined_approaches_bugs.R b/inst/scripts/devel/demonstrate_combined_approaches_bugs.R index 57e5b9f44d5db3dc7a3a14d4d5e537cadc367a94..bafa0bab34c57ae411381c6b27e0a2a884b5bdac 100644 --- a/inst/scripts/devel/demonstrate_combined_approaches_bugs.R +++ b/inst/scripts/devel/demonstrate_combined_approaches_bugs.R @@ -10,7 +10,7 @@ explanation_1 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "empirical"), - prediction_zero = p0, + phi0 = p0, n_batches = 3, timing = FALSE, seed = 1) @@ -42,7 +42,7 @@ explanation_2 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "ctree", "ctree", "ctree" ,"ctree"), - prediction_zero = p0, + phi0 = p0, n_batches = 2, timing = FALSE, seed = 1) @@ -62,7 +62,7 @@ explanation_3 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "ctree", "ctree", "ctree" ,"ctree"), - prediction_zero = p0, + phi0 = p0, n_batches = 15, timing = FALSE, seed = 1) @@ -93,7 +93,7 @@ explanation_combined_1 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "empirical"), - prediction_zero = p0, + phi0 = p0, timing = FALSE, seed = 1) @@ -102,7 +102,7 @@ explanation_combined_2 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "empirical"), - prediction_zero = p0, + phi0 = p0, timing = FALSE, seed = 1) @@ -117,7 +117,7 @@ explanation_combined_3 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = 
c("independence", "empirical", "gaussian", "copula", "ctree"), - prediction_zero = p0, + phi0 = p0, timing = FALSE, seed = 1) @@ -126,7 +126,7 @@ explanation_combined_4 = explain( x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula", "ctree"), - prediction_zero = p0, + phi0 = p0, timing = FALSE, seed = 1) diff --git a/inst/scripts/devel/devel_batch_testing.R b/inst/scripts/devel/devel_batch_testing.R new file mode 100644 index 0000000000000000000000000000000000000000..20c1063f35700da27785381d4ed2ddee470ff8a1 --- /dev/null +++ b/inst/scripts/devel/devel_batch_testing.R @@ -0,0 +1,67 @@ + +#remotes::install_github("NorskRegnesentral/shapr") # Installs GitHub version of shapr + +library(shapr) +library(data.table) +library(MASS) +library(Matrix) + +# Just sample some data to work with +m <- 9 +n_train <- 10000 +n_explain <- 10 +rho_1 <- 0.5 +rho_2 <- 0 +rho_3 <- 0.4 +Sigma_1 <- matrix(rho_1, m/3, m/3) + diag(m/3) * (1 - rho_1) +Sigma_2 <- matrix(rho_2, m/3, m/3) + diag(m/3) * (1 - rho_2) +Sigma_3 <- matrix(rho_3, m/3, m/3) + diag(m/3) * (1 - rho_3) +Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3)) +mu <- rep(0,m) + +set.seed(123) + + + +x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma)) +x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma)) + +names(x_train) <- paste0("VV",1:m) +names(x_explain) <- paste0("VV",1:m) + +beta <- c(4:1, rep(0, m - 4)) +alpha <- 1 +y_train <- as.vector(alpha + as.matrix(x_train) %*% beta + rnorm(n_train, 0, 1)) +y_explain <- alpha + as.matrix(x_explain) %*% beta + rnorm(n_explain, 0, 1) + +xy_train <- cbind(y_train, x_train) + +p0 <- mean(y_train) + +# We need to pass a model object and a proper prediction function to shapr for it to work, but it can be anything as we don't use it +model <- lm(y_train ~ ., data = x_train) + +### First run proper shapr call on this +library(progressr) +library(future) +# Not necessary, and only apply to the explain() call 
below +progressr::handlers(global = TRUE) # For progress bars +#future::plan(multisession, workers = 2) # Parallized computations +#future::plan(sequential) + +expl <- explain(model = model, + x_explain= x_explain, + x_train = x_train, + approach = "ctree", + phi0 = p0, + n_batches = 100, + n_samples = 1000, + iterative = TRUE, + print_iter_info = TRUE, + print_shapleyres = TRUE) + + +n_combinations <- 5 +max_batch_size <- 10 +min_n_batches <- 10 + diff --git a/inst/scripts/devel/devel_convergence_branch.R b/inst/scripts/devel/devel_convergence_branch.R new file mode 100644 index 0000000000000000000000000000000000000000..313a28698be7af373450daac7684082c6c641060 --- /dev/null +++ b/inst/scripts/devel/devel_convergence_branch.R @@ -0,0 +1,148 @@ +library(xgboost) +#library(shapr) + +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] +data[,new1 :=sqrt(Wind*Ozone)] +data[,new2 :=sqrt(Wind*Temp)] +data[,new3 :=sqrt(Wind*Day)] +data[,new4 :=sqrt(Wind*Solar.R)] +data[,new5 :=rnorm(.N)] +data[,new6 :=rnorm(.N)] +data[,new7 :=rnorm(.N)] + + +x_var <- c("Solar.R", "Wind", "Temp", "Month","Day","new1","new2","new3","new4","new5")#"new6","new7") +y_var <- "Ozone" + +ind_x_explain <- 1:20 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] + +# Looking at the dependence between the features +cor(x_train) + +# Fitting a basic xgboost model to the training data +model <- xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. 
the expected prediction without any features +p0 <- mean(y_train) + +# Computing the actual Shapley values with kernelSHAP accounting for feature dependence using +# the empirical (conditional) distribution approach with bandwidth parameter sigma = 0.1 (default) +explanation_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + max_n_coalitions = 500, + phi0 = p0, + iterative = TRUE, + print_shapleyres = TRUE, # tmp + print_iter_info = TRUE, # tmp + kernelSHAP_reweighting = "on_N" +) + +explanation_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "ctree", + n_coalitions = 500, + phi0 = p0, + iterative = TRUE, + print_shapleyres = TRUE, # tmp + print_iter_info = TRUE, # tmp + kernelSHAP_reweighting = "on_N" +) + + +explanation_noniterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + n_coalitions = 400, + phi0 = p0, + iterative = FALSE +) + + +explanation_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + n_coalitions = 500, + phi0 = p0, + iterative = TRUE, + iterative_args = list(initial_n_coalitions=10,convergence_tol=0.0001), + print_shapleyres = TRUE, # tmp + print_iter_info = TRUE, # tmp + kernelSHAP_reweighting = "on_N" +) + + + + + + + + + + +plot(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==1,Solar.R],type="l") +sd_full <- explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==1][.N,Solar.R] +n_samples_full <- explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res[.N,n_current_samples] +sd_full0 <- sd_full*sqrt(n_samples_full) +lines(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + 
sd_full0/sqrt(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples),type="l",col=2) + + +plot(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$estimated_required_samples,type="l",ylim=c(0,4000),lwd=4) +for(i in 1:20){ + lines(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res[[5+i]],type="l",col=1+i) +} + + +plot(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==1,Solar.R],type="l",ylim=c(0,2)) +sd_full <- explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==1][.N,Solar.R] +n_samples_full <- explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res[.N,n_current_samples] +sd_full0 <- sd_full*sqrt(n_samples_full) +lines(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + sd_full0/sqrt(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples),type="l",col=2,lwd=3) + +for(i in 1:20){ + lines(explanation_iterative$internal$output$iter_objects$dt_iter_convergence_res$n_current_samples, + explanation_iterative$internal$output$iter_objects$dt_iter_shapley_sd[explain_id==i,Solar.R],type="l",col=1+i) +} + + + +lines(explanation_iterative$internal$output$dt_iter_convergence_res$n_current_samples, + sd_full0/sqrt(explanation_iterative$internal$output$dt_iter_convergence_res$n_current_samples),type="l",col=2) + + +plot(explanation_iterative$internal$output$dt_iter_convergence_res$estimated_required_samples) + +explanation_regular <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + n_coalitions = NULL, + 
phi0 = p0, + iterative = FALSE +) + diff --git a/inst/scripts/devel/devel_non_exact_grouping.R b/inst/scripts/devel/devel_non_exact_grouping.R index 02f3196da4fcb5e656f1314934205e7e2b8ed7cb..d5e29e3b0a11bc6f83a278499f2f8609e4179ff6 100644 --- a/inst/scripts/devel/devel_non_exact_grouping.R +++ b/inst/scripts/devel/devel_non_exact_grouping.R @@ -1,5 +1,5 @@ -### NOTE: THIS DOES NO LONGER WORK AS WE SWITCH TO exact when a large n_combinations is used, but the checks +### NOTE: THIS DOES NO LONGER WORK AS WE SWITCH TO exact when a large n_coalitions is used, but the checks ### confirms the code works as intended. library(xgboost) @@ -30,7 +30,7 @@ model <- xgboost( group <- list(A=x_var[1:3],B=x_var[4:5],C=x_var[7],D=x_var[c(6,8)],E=x_var[9]) -explainer1 <- shapr(x_train, model,group = group,n_combinations=10^ 6) +explainer1 <- shapr(x_train, model,group = group,n_coalitions=10^ 6) explainer2 <- shapr(x_train, model,group = group) @@ -38,14 +38,14 @@ explanation1 <- explain( x_test, approach = "independence", explainer = explainer1, - prediction_zero = p + phi0 = p ) explanation2 <- explain( x_test, approach = "independence", explainer = explainer2, - prediction_zero = p + phi0 = p ) diff --git a/inst/scripts/devel/devel_parallelization.R b/inst/scripts/devel/devel_parallelization.R index 6dd6d10bd8846cc8a507820c940be00d01ec7c14..21aa964ccab2c9086e0ca5a3fecb7f37f0d026ec 100644 --- a/inst/scripts/devel/devel_parallelization.R +++ b/inst/scripts/devel/devel_parallelization.R @@ -35,7 +35,7 @@ explanation0 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time0 <- stop-start @@ -48,7 +48,7 @@ explanation1 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time1 <- stop-start @@ -60,7 +60,7 @@ explanation2 <- explain( x_test, approach = "gaussian", explainer = explainer, - 
prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time2 <- stop-start @@ -72,7 +72,7 @@ explanation3 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time3 <- stop-start @@ -84,7 +84,7 @@ explanation4 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time4 <- stop-start @@ -96,7 +96,7 @@ explanation5 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time5 <- stop-start @@ -108,7 +108,7 @@ explanation6 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() time6 <- stop-start @@ -123,7 +123,7 @@ explanation7 <- explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 32 + phi0 = p,n_batches = 32 ) stop <- proc.time() parallel::stopCluster(cl) diff --git a/inst/scripts/devel/devel_tmp_new_batch.R b/inst/scripts/devel/devel_tmp_new_batch.R index 290d5c009f622a19cdf31375fe84dd7f5166ef89..37950b3a3f76ff154594aedac2f35a47666dabb1 100644 --- a/inst/scripts/devel/devel_tmp_new_batch.R +++ b/inst/scripts/devel/devel_tmp_new_batch.R @@ -5,7 +5,7 @@ explainer <- explain_setup( x_test, approach = c("empirical","empirical","gaussian","copula"), explainer = explainer, - prediction_zero = p, + phi0 = p, n_batches = 4 ) diff --git a/inst/scripts/devel/devel_verbose.R b/inst/scripts/devel/devel_verbose.R new file mode 100644 index 0000000000000000000000000000000000000000..ad4a2fb7da0dbecd69304267fb6920d936d56128 --- /dev/null +++ b/inst/scripts/devel/devel_verbose.R @@ -0,0 +1,135 @@ +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = 
"independence", + phi0 = p0, + max_n_coalitions = 30, + iterative_args = list( + initial_n_coalitions = 6, + convergence_tol = 0.0005, + n_coal_next_iter_factor_vec = rep(10^(-6), 10), + max_iter = 8 + ), + iterative = TRUE,verbose=c("basic","progress") +) + +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + phi0 = p0, + max_n_coalitions = 30, + iterative = TRUE,verbose=c("vS_details") +) +ex <- explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + phi0 = p0, + max_n_coalitions = 30, + iterative = TRUE,verbose=c("basic","progress","vS_details"), + regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), + regression.tune_values = dials::grid_regular(dials::tree_depth(), levels = 4), + regression.vfold_cv_para = list(v = 5) +) + +ex <- explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_surrogate", + phi0 = p0, + max_n_coalitions = 30, + iterative = FALSE,verbose=c("basic","vS_details"), + regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), + regression.tune_values = dials::grid_regular(dials::tree_depth(), levels = 4), + regression.vfold_cv_para = list(v = 5) +) + + +future::plan("multisession", workers = 4) +progressr::handlers(global = TRUE) + + +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "vaeac", + phi0 = p0, + max_n_coalitions = 30, + iterative = FALSE,verbose=c("basic","progress","vS_details"), + n_MC_samples = 100, + vaeac.epochs = 3 +) + +ex2 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "vaeac", + phi0 = 
p0, + max_n_coalitions = 30, + iterative = FALSE,verbose=c("basic","progress","vS_details"), + n_MC_samples = 100, + vaeac.extra_parameters = list( + vaeac.pretrained_vaeac_model = ex$internal$parameters$vaeac + ) +) + + + +vaeac.extra_parameters = list( + vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac +) + + +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + phi0 = p0, + max_n_coalitions = 30, + iterative = FALSE,verbose=c("basic") +) + + +ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "empirical", + phi0 = p0, + max_n_coalitions = 30, + iterative_args = list( + initial_n_coalitions = 6, + convergence_tol = 0.0005, + n_coal_next_iter_factor_vec = rep(10^(-6), 10), + max_iter = 8 + ), + iterative = TRUE,verbose=c("basic","convergence","shapley") +) + + +explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative = TRUE, + iterative_args <- list(n_initial_) + verbose = c("basic"), + paired_shap_sampling = TRUE +) diff --git a/inst/scripts/devel/explain_new.R b/inst/scripts/devel/explain_new.R index b6a1e2af75fd1d0a83fff250f8339ec4023fac19..1e86d12760e56a12b8ab1e6d0768659141344c62 100644 --- a/inst/scripts/devel/explain_new.R +++ b/inst/scripts/devel/explain_new.R @@ -39,7 +39,7 @@ explanation_new <- explain_new( x_test, approach = "gaussian", explainer = explainer1, - prediction_zero = p, + phi0 = p, n_samples = 5*10^5,n_batches = 1 ) @@ -56,7 +56,7 @@ explanation_new <- explain_new( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p, + phi0 = p, n_samples = 10^5,n_batches = 4 ) @@ -73,7 +73,7 @@ explanation_new <- explain_new( x_test, approach = "empirical", explainer = explainer, - prediction_zero = p, + phi0 = 
p, n_samples = 10^5,n_batches = 1 ) @@ -90,7 +90,7 @@ explanation_new <- explain_new( x_test, approach = "empirical", explainer = explainer, - prediction_zero = p, + phi0 = p, n_samples = 10^5,n_batches = 4 ) @@ -112,7 +112,7 @@ explanation_new$dt_shapley # x_test, # approach = "gaussian", # explainer = explainer, -# prediction_zero = p +# phi0 = p # ) # # str(explainer,max.level = 1) @@ -122,7 +122,7 @@ explainer <- explain_setup( x_test, approach = "empirical", explainer = explainer, - prediction_zero = p, + phi0 = p, n_batches = 4 ) @@ -130,7 +130,7 @@ explainer0 <- explain_setup( x_test, approach = c("empirical","copula","ctree","gaussian"), explainer = explainer, - prediction_zero = p, + phi0 = p, n_batches = 7 ) @@ -149,7 +149,7 @@ explainer0$X # x_test, # approach = "gaussian", # explainer = explainer, -# prediction_zero = p, +# phi0 = p, # n_samples = 10^5 # ) diff --git a/inst/scripts/devel/future_testing.R b/inst/scripts/devel/future_testing.R new file mode 100644 index 0000000000000000000000000000000000000000..6d6734f76a36b1f5fe040b0042fbb196bf50c04f --- /dev/null +++ b/inst/scripts/devel/future_testing.R @@ -0,0 +1,56 @@ + +plan(multisession, workers = 5) # Adjust the number of workers as needed +plan(sequential) # Adjust the number of workers as needed + +fun <- function(x) { + print(x) + if(z==0){ + if(x==5){ + Sys.sleep(1) + z <<- 100 + } + return(x+z) + } else { + return(NA) + } +} + +z <- 0 + + + + +plan(multisession, workers = 5) +plan(multicore, workers = 5) + +plan(sequential) + +fun2 <- function(x){ + x^2 +} + + +start <- proc.time() +for(i in 1:100){ + future.apply::future_lapply(1:10, fun2) +} +print(proc.time()-start) +#user system elapsed +#14.985 0.045 20.323 + +start <- proc.time() +for(i in 1:10){ + future.apply::future_lapply(rep(1:10,10), fun2) +} +print(proc.time()-start) +#user system elapsed +#1.504 0.005 2.009 + +start <- proc.time() +aa=future.apply::future_lapply(rep(1:10,100), fun2) +print(proc.time()-start) +#user system 
elapsed +#0.146 0.000 0.202 + + + diff --git a/inst/scripts/devel/real_data_iterative_kernelshap.R b/inst/scripts/devel/real_data_iterative_kernelshap.R new file mode 100644 index 0000000000000000000000000000000000000000..0e33ae141177ee2f876a6c4f159b5d3ce0b88890 --- /dev/null +++ b/inst/scripts/devel/real_data_iterative_kernelshap.R @@ -0,0 +1,276 @@ + +### Upcoming generalization: + +#1. Use non-linear truth (xgboost or so) +#2. Even more features + +print(Sys.time()) +library(data.table) +library(shapr) +library(ranger) + +# Give me some credit data set +gmc <- read.table("/nr/project/stat//BigInsight//Projects//Explanations//Counterfactual_kode//Carla_datasets//GiveMeSomeCredit-training.csv",header=TRUE, sep=",") +foo <- apply(gmc,1,sum) +ind <- which(is.na(foo)) +gmc <- gmc[!is.na(foo),] # logical mask; gmc[-ind,] drops every row when ind is empty + + +nobs <- dim(gmc)[1] +ind <- sample(x=nobs, size=round(0.75*nobs)) +gmcTrain <- gmc[ind,-1] +gmcTest <- gmc[-ind,-1] +gmcTrain <- as.data.table(gmcTrain) +gmcTest <- as.data.table(gmcTest) + +integer_columns <- sapply(gmcTrain, is.integer) # Identify integer columns +integer_columns = integer_columns[2:length(integer_columns)] +gmcTrain[, c("RevolvingUtilizationOfUnsecuredLines", "age", +"NumberOfTime30.59DaysPastDueNotWorse", "DebtRatio", "MonthlyIncome", +"NumberOfOpenCreditLinesAndLoans", "NumberOfTimes90DaysLate", +"NumberRealEstateLoansOrLines", "NumberOfTime60.89DaysPastDueNotWorse", "NumberOfDependents"):= +lapply(.SD, as.numeric), .SDcols = c("RevolvingUtilizationOfUnsecuredLines", "age", +"NumberOfTime30.59DaysPastDueNotWorse", "DebtRatio", "MonthlyIncome", +"NumberOfOpenCreditLinesAndLoans", "NumberOfTimes90DaysLate", +"NumberRealEstateLoansOrLines", "NumberOfTime60.89DaysPastDueNotWorse", "NumberOfDependents")] +integer_columns <- sapply(gmcTest, is.integer) # Identify integer columns +integer_columns = integer_columns[2:length(integer_columns)] +gmcTest[, c("RevolvingUtilizationOfUnsecuredLines", "age", +"NumberOfTime30.59DaysPastDueNotWorse", "DebtRatio", 
"MonthlyIncome", +"NumberOfOpenCreditLinesAndLoans", "NumberOfTimes90DaysLate", +"NumberRealEstateLoansOrLines", "NumberOfTime60.89DaysPastDueNotWorse", "NumberOfDependents"):= +lapply(.SD, as.numeric), .SDcols = c("RevolvingUtilizationOfUnsecuredLines", "age", +"NumberOfTime30.59DaysPastDueNotWorse", "DebtRatio", "MonthlyIncome", +"NumberOfOpenCreditLinesAndLoans", "NumberOfTimes90DaysLate", +"NumberRealEstateLoansOrLines", "NumberOfTime60.89DaysPastDueNotWorse", "NumberOfDependents")] + +# model <- ranger(SeriousDlqin2yrs ~ ., data = gmcTrain, num.trees = 500, num.threads = 6, +# verbose = TRUE, +# probability = FALSE, +# importance = "impurity", +# mtry = sqrt(11), +# seed = 3045) +library(hmeasure) +#pred.rf <- predict(model, data = gmcTest) +#results <- HMeasure(unlist(as.vector(gmcTest[,1])),pred.rf$predictions,threshold=0.15) +#results$metrics$AUC + +y_train = gmcTrain$SeriousDlqin2yrs +x_train = gmcTrain[,-1] +y_explain = gmcTest$SeriousDlqin2yrs +x_explain = gmcTest[,-1] + +set.seed(123) +model <- xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 50, + verbose = FALSE,params = list(objective = "binary:logistic") +) +pred.xgb <- predict(model, newdata = as.matrix(x_explain)) +results <- HMeasure(as.vector(y_explain),pred.xgb,threshold=0.15) +results$metrics$AUC + + +set.seed(123) + +inds_train = sample(1:nrow(x_train), 9000) +x_train = x_train[inds_train,] +y_train = y_train[inds_train] + +m = ncol(x_train) + + +p0 <- mean(y_train) +mu = colMeans(x_train) +Sigma = cov(x_train) + +### First run proper shapr call on this + +sim_results_saving_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/gmc_data_v3/"#"../effektiv_shapley_output/" +kernelSHAP_reweighting_strategy = "none" + +predict_model_xgb <- function(object,newdata){ + xgboost:::predict.xgb.Booster(object,as.matrix(newdata)) +} + + +preds_explain <- predict_model_xgb(model,x_explain) +head(order(-preds_explain),50) +inds_1 <- 
head(order(-preds_explain),50) +set.seed(123) +inds_2 <- sample(which(preds_explain>quantile(preds_explain,0.9) & preds_explain 0.05 +shapley_threshold_val <- 0.02 +shapley_threshold_prob <- 0.2 + +source("inst/scripts/devel/iterative_kernelshap_sourcefuncs.R") + +testObs_computed_vec <- inds# seq_len(n_explain) +runres_list <- runcomps_list <- list() + +cutoff_feats = colnames(x_train) + +run_obj_list <- list() +for(kk in seq_along(testObs_computed_vec)){ + testObs_computed <- testObs_computed_vec[kk] + full_pred <- predict_model_xgb(model,x_explain)[testObs_computed] + shapsum_other_features <- 0 + + + run <- iterative_kshap_func(model,x_explain,x_train, + testObs_computed = testObs_computed, + cutoff_feats = cutoff_feats, + initial_n_coalitions = 50, + full_pred = full_pred, + shapsum_other_features = shapsum_other_features, + p0 = p0, + predict_model = predict_model_xgb, + shapley_threshold_val = shapley_threshold_val, + shapley_threshold_prob = shapley_threshold_prob, + approach = approach, + kernelSHAP_reweighting_strategy = kernelSHAP_reweighting_strategy) + + runres_list[[kk]] <- run$kshap_final + runcomps_list[[kk]] <- sum(sapply(run$keep_list,"[[","no_computed_combinations")) + run_obj_list[[kk]] <- run + print(kk) + print(Sys.time()) +} + +est <- rbindlist(runres_list) +est[,other_features:=NULL] +fwrite(est,paste0(sim_results_saving_folder,"iterative_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) + +expl_approx <- matrix(0, nrow = length(inds), ncol = m+1) +expl_approx_obj_list <- list() +for (i in seq_along(testObs_computed_vec)){ + expl_approx_obj <- shapr::explain(model = model, + x_explain= x_explain[testObs_computed_vec[i],], + x_train = x_train, + approach = approach, + phi0 = p0, + n_coalitions = runcomps_list[[i]]) + expl_approx[i,] = unlist(expl_approx_obj$shapley_values_est) + expl_approx_obj_list[[i]] <- expl_approx_obj +} +expl_approx <- as.data.table(expl_approx) +truth <- expl$shapley_values_est + +colnames(expl_approx) <- 
colnames(truth) +fwrite(expl_approx,paste0(sim_results_saving_folder,"approx_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(est-truth) +rmse_vec <- sqrt(colMeans((est-truth)^2)) +mae_vec <- colMeans(abs(est-truth)) + +bias_vec_approx <- colMeans(expl_approx-truth) +rmse_vec_approx <- sqrt(colMeans((expl_approx-truth)^2)) +mae_vec_approx <- colMeans(abs(expl_approx-truth)) + +save.image(paste0(sim_results_saving_folder, "iterative_kernelshap_lingauss_p12_", kernelSHAP_reweighting_strategy, ".RData")) + +hist(unlist(runcomps_list),breaks = 20) + +summary(unlist(runcomps_list)) + + +run$kshap_final +sum(unlist(run$kshap_final)) +full_pred + +print(Sys.time()) + +# TODO: Må finne ut av hvorfor det ikke gir korrekt sum her... +# Hvis det er noen variabler som ble ekskludert, så må jeg legge til disse i summen for å få prediksjonen til modellen. +# for(i in 1:18){ +# print(sum(unlist(run$keep_list[[i]]$kshap_est_dt[,-1]))+run$keep_list[[i]]$shap_it_excluded_features) +# #print(run$keep_list[[i]]$shap_it_excluded_features) +# } + +# run$kshap_it_est_dt + + + +# run$kshap_final +# expl$shapley_values_est + + + + +# kshap_final <- copy(run$kshap_est_dt_list[,-1]) +# setnafill(kshap_final,"locf") +# kshap_final[.N,] # final estimate + +# sum(unlist(kshap_final[.N,])) + +# sum(unlist(expl$shapley_values_est[testObs_computed,])) + + + + + + + + + + +# cutoff_feats <- paste0("VV",1:6) +# testObs_computed <- 5 + +# full_pred <- predict(model,x_explain)[5] +# p0 <- mean(y_train) +# pred_not_to_decompose <- sum(expl$shapley_values_est[5,VV7:VV9]) + + +# run_minor <- iterative_kshap_func(model,x_explain,x_train, +# testObs_computed = 5, +# cutoff_feats = cutoff_feats, +# full_pred = full_pred, +# pred_not_to_decompose = pred_not_to_decompose, +# p0 = p0, +# predict_model = predict.lm,shapley_threshold_val = 0) + + +# aa=run$keep_list[[8]]$dt_vS + +# bb=run_minor$keep_list[[6]]$dt_vS +# setnames(bb,"p_hat_1","p_hat_1_approx") + +# cc=merge(aa,bb) 
+# cc[,diff:=p_hat_1-p_hat_1_approx] + + +# TODO: + +# 1. Run example with gaussian features where the truth is known in advance in a large setting, with e.g. 12 features or so. I want the estimate +# both for the full 12 features, and for subsets where one is removed. +# 2. + +# Utfordringer: +# 1. Hvordan justere vekter og samplingrutine fra subset S når man allerede har et utvalg sampler (som også er noe biased). +# 2. Bruker altså E[f(x1=x1*,x2,x3=x3*,x4)|x1=x1*] som proxy for E[f(x1=x1*,x2,x3=x3*,x4)|x1=x1*,x3=x3*], +#men hva med E[f(x1=x1*,x2,x3,x4=x4*)|x1=x1*,x4=x4*]? Burde jeg bruke den for +#E[f(x1=x1*,x2,x3=x3*,x4=x4*)|x1=x1*,x4=x4*]? +# 3. Når jeg fjerner en variabel (som har lite å si), så settes shapley-verdien til det den har per da. MEN den verdien vil trolig være noe biased fordi den fjernes første gangen den går over terskelverdiene +# jeg har satt for ekskludering. + diff --git a/inst/scripts/devel/real_data_iterative_kernelshap_analyze_results.R b/inst/scripts/devel/real_data_iterative_kernelshap_analyze_results.R new file mode 100644 index 0000000000000000000000000000000000000000..866d28bf9ecc5424b54e69f3108b7fc8082b83c9 --- /dev/null +++ b/inst/scripts/devel/real_data_iterative_kernelshap_analyze_results.R @@ -0,0 +1,135 @@ +library(data.table) +kernelSHAP_reweighting_strategy = "none" +shapley_threshold_val <- 0.2 + + + +sim_results_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/gmc_data_v3/" + +load(paste0("/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/gmc_data_v3/iterative_kernelshap_lingauss_p12_", kernelSHAP_reweighting_strategy, ".RData")) + + + +exact_vals = fread(paste0(sim_results_folder,"exact_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) +# names(exact_vals) <- c("phi0", paste0("VV",1:12)) +iterative_vals = fread(paste0(sim_results_folder,"iterative_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) 
+approx_vals = fread(paste0(sim_results_folder,"approx_shapley_values_", kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(exact_vals - iterative_vals) +rmse_vec <- sqrt(colMeans((exact_vals - iterative_vals)^2)) +mae_vec <- colMeans(abs(exact_vals - iterative_vals)) + +bias_vec_approx <- colMeans(exact_vals - approx_vals) +rmse_vec_approx <- sqrt(colMeans((exact_vals - approx_vals)^2)) +mae_vec_approx <- colMeans(abs(exact_vals - approx_vals)) + +treeshap_vals <- as.data.table(predict(model,newdata=as.matrix(x_explain),predcontrib = TRUE)) +setnames(treeshap_vals,"BIAS","none") +setcolorder(treeshap_vals,"none") +head(treeshap_vals) +mae_vec_treeshap <- colMeans(abs(exact_vals - treeshap_vals)) +mean(mae_vec_treeshap[-1]) + + +library(ggplot2) + +# Create a data frame for the bar plot + +# MAE +df <- data.frame(matrix(0, length(mae_vec)*2, 3)) +colnames(df) <- c("MAE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- mae_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- mae_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) +df <- as.data.table(df) +df[,features0:=.GRP,by="features"] +df[,features1:=paste0("VV",features0)] +df[,features1:=factor(features1,levels=c(paste0("VV",1:11)))] + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features1, y = MAE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "mae_comparison.png"), plot = p, width = 10, height = 5) # paste0: paste() would put a space before the file name + + +runcomps_list + +df = data.frame(matrix(0, length(runcomps_list), 1)) +colnames(df) <- c("n_rows") +df$n_rows <- as.numeric(runcomps_list) + +p <- ggplot(df, aes(n_rows)) + + geom_histogram() +ggsave(paste0(sim_results_folder, "n_rows.png"), plot = p,width = 10, height = 5) + + + + + + + + + +# RMSE +df <- data.frame(matrix(0, length(rmse_vec)*2, 3)) 
+colnames(df) <- c("RMSE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- rmse_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- rmse_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = RMSE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "rmse_comparison.png"), plot = p) # paste0: paste() would put a space before the file name + + +# Bias +df <- data.frame(matrix(0, length(bias_vec)*2, 3)) +colnames(df) <- c("abs_bias", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- abs(bias_vec_approx) +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- abs(bias_vec) +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = abs_bias, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "bias_comparison.png"), plot = p) # paste0: paste() would put a space before the file name + +# Number of samples used +runcomps_list + +df = data.frame(matrix(0, length(runcomps_list), 1)) +colnames(df) <- c("n_rows") +df$n_rows <- as.numeric(runcomps_list) + +p <- ggplot(df, aes(n_rows)) + + geom_histogram() +ggsave(paste0(sim_results_folder, "n_rows.png"), plot = p) + +#### Just looking at the largest predictions + +preds <- rowSums(exact_vals) + +these <- head(order(-preds),10) + +preds[these]-rowSums(iterative_vals)[these] + +bias_vec <- colMeans(exact_vals[these] - iterative_vals[these]) +rmse_vec <- sqrt(colMeans((exact_vals[these] - iterative_vals[these])^2)) +mae_vec <- colMeans(abs(exact_vals[these] - iterative_vals[these])) + +bias_vec_approx <- colMeans(exact_vals[these] - approx_vals[these]) +rmse_vec_approx <- 
sqrt(colMeans((exact_vals[these] - approx_vals[these])^2)) +mae_vec_approx <- colMeans(abs(exact_vals[these] - approx_vals[these])) + + + diff --git a/inst/scripts/devel/same_seed_as_master.R b/inst/scripts/devel/same_seed_as_master.R index a06469cb2930ba5b934553b00c04788774f1f07b..4460b7e62b08d9fecf46369d053911488c29c0ec 100644 --- a/inst/scripts/devel/same_seed_as_master.R +++ b/inst/scripts/devel/same_seed_as_master.R @@ -20,15 +20,15 @@ model <- xgboost( ) # THIS IS GENERATED FROM MASTER BRANCH # Prepare the data for explanation -explainer <- shapr(x_train, model,n_combinations = 100) +explainer <- shapr(x_train, model,n_coalitions = 100) p = mean(y_train) -gauss = explain(x_test, explainer, "gaussian", prediction_zero = p, n_samples = 10000) -emp = explain(x_test, explainer, "empirical", prediction_zero = p, n_samples = 10000) -copula = explain(x_test, explainer, "copula", prediction_zero = p, n_samples = 10000) -indep = explain(x_test, explainer, "independence", prediction_zero = p, n_samples = 10000) -comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), prediction_zero = p, n_samples = 10000) -ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, prediction_zero = p, n_samples = 10000) -ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), prediction_zero = p, n_samples = 10000) +gauss = explain(x_test, explainer, "gaussian", phi0 = p, n_samples = 10000) +emp = explain(x_test, explainer, "empirical", phi0 = p, n_samples = 10000) +copula = explain(x_test, explainer, "copula", phi0 = p, n_samples = 10000) +indep = explain(x_test, explainer, "independence", phi0 = p, n_samples = 10000) +comb = explain(x_test, explainer, c("gaussian", "gaussian", "empirical", "empirical"), phi0 = p, n_samples = 10000) +ctree = explain(x_test, explainer, "ctree", mincriterion = 0.95, phi0 = p, n_samples = 10000) +ctree2 = explain(x_test, explainer, "ctree", mincriterion = c(0.95, 0.95, 0.95, 0.95), phi0 
= p, n_samples = 10000) # results from master diff --git a/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_analyze_results.R b/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_analyze_results.R new file mode 100644 index 0000000000000000000000000000000000000000..ac35df40cb3c6b522bd703fcff4e39d873ec84db --- /dev/null +++ b/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_analyze_results.R @@ -0,0 +1,88 @@ +library(data.table) +kernelSHAP_reweighting_strategy = "none" +shapley_threshold_val <- 0.2 + + + +sim_results_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/sim_lingauss_v2/" + +load(paste0(sim_results_folder,"iterative_kernelshap_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".RData")) + + +exact_vals = fread(paste0(sim_results_folder,"exact_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) +names(exact_vals) <- c("phi0", paste0("VV",1:12)) +iterative_vals = fread(paste0(sim_results_folder,"iterative_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) +approx_vals = fread(paste0(sim_results_folder,"approx_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(exact_vals - iterative_vals) +rmse_vec <- sqrt(colMeans((exact_vals - iterative_vals)^2)) +mae_vec <- colMeans(abs(exact_vals - iterative_vals)) + +bias_vec_approx <- colMeans(exact_vals - approx_vals) +rmse_vec_approx <- sqrt(colMeans((exact_vals - approx_vals)^2)) +mae_vec_approx <- colMeans(abs(exact_vals - approx_vals)) + +library(ggplot2) + +# Create a data frame for the bar plot + +# MAE +df <- data.frame(matrix(0, length(mae_vec)*2, 3)) +colnames(df) <- c("MAE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- mae_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- mae_vec 
+df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) +df <- as.data.table(df) +df[,features:=factor(features,levels=c("phi0",paste0("VV",1:12)))] + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = MAE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "mae_comparison.png"), plot = p, width = 10, height = 5) # paste0: paste() would put a space before the file name + + +# RMSE +df <- data.frame(matrix(0, length(rmse_vec)*2, 3)) +colnames(df) <- c("RMSE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- rmse_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- rmse_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = RMSE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "rmse_comparison.png"), plot = p) # paste0: paste() would put a space before the file name + + +# Bias +df <- data.frame(matrix(0, length(bias_vec)*2, 3)) +colnames(df) <- c("abs_bias", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- abs(bias_vec_approx) +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- abs(bias_vec) +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = abs_bias, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "bias_comparison.png"), plot = p) # paste0: paste() would put a space before the file name + +# Number of samples used +runcomps_list + +df = data.frame(matrix(0, length(runcomps_list), 1)) +colnames(df) <- c("n_rows") +df$n_rows <- as.numeric(runcomps_list) + +p <- ggplot(df, aes(n_rows)) + + geom_histogram() +ggsave(paste0(sim_results_folder, "n_rows.png"), 
plot = p) + diff --git a/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_v2.R b/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_v2.R new file mode 100644 index 0000000000000000000000000000000000000000..afbd2467f2535d87003b6a76c99d9f927d3ea3b9 --- /dev/null +++ b/inst/scripts/devel/simtest_iterative_kernelshap_lingauss_v2.R @@ -0,0 +1,261 @@ + +### Upcoming generalization: + +#1. Use non-linear truth (xgboost or so) +#2. Even more features + + +library(data.table) +library(MASS) +library(Matrix) +library(shapr) +library(future) +library(xgboost) + +shapley_threshold_prob <- 0.2 +shapley_threshold_val <- 0.1 + +m <- 12 +n_train <- 5000 +n_explain <- 100 +rho_1 <- 0.5 +rho_2 <- 0.5 +rho_3 <- 0.5 +rho_4 <- 0 +Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1) +Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2) +Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3) +Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4) + +Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4)) +mu <- rep(0,m) + +library(corrplot) +corrplot(Sigma) +set.seed(123) + + +x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma)) +x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma)) + +names(x_train) <- paste0("VV",1:m) +names(x_explain) <- paste0("VV",1:m) + + +beta <- c(5:1, rep(0, m - 5)) +alpha <- 1 +y_train <- as.vector(alpha + as.matrix(x_train) %*% beta + rnorm(n_train, 0, 1)) +y_explain <- alpha + as.matrix(x_explain) %*% beta + rnorm(n_explain, 0, 1) + +xy_train <- cbind(y_train, x_train) + +set.seed(123) + +model <- lm(y_train ~ .,data = xy_train) + +pred_train <- predict(model, x_train) +plot(unlist(x_train[,1]),pred_train) +plot(unlist(x_train[,2]),pred_train) +plot(unlist(x_train[,3]),pred_train) +plot(unlist(x_train[,4]),pred_train) +plot(unlist(x_train[,5]),pred_train) +plot(unlist(x_train[,6]),pred_train) + +this_order <- order(unlist(x_train[,1])) + 
+plot(unlist(x_train[this_order,1]),pred_train[this_order],type="l") + +p0 <- mean(y_train) + + +### First run proper shapr call on this + +sim_results_saving_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/sim_lingauss_v2/"#"../effektiv_shapley_output/" +kernelSHAP_reweighting_strategy = "none" + +set.seed(465132) +inds = 1:n_explain +progressr::handlers(global = TRUE) +expl <- shapr::explain(model = model, + x_explain= x_explain[inds,], + x_train = x_train, + approach = "gaussian", + phi0 = p0,Sigma=Sigma,mu=mu) + +fwrite(expl$shapley_values_est,paste0(sim_results_saving_folder,"exact_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + + +cutoff_feats <- paste0("VV",1:12) + + +### Need to create an lm analogue to pred_mod_xgb here + + +set.seed(123) + + + +# These are the parameters for iterative_kshap_func +n_samples <- 1000 +approach = "gaussian" + +# Reduce if < 10% prob of shapval > 0.2 + +source("inst/scripts/devel/iterative_kernelshap_sourcefuncs.R") + +testObs_computed_vec <- inds# seq_len(n_explain) + +# Using threshold: 0.1 +runres_list <- runcomps_list <- list() +for(kk in seq_along(testObs_computed_vec)){ # kk is a position; looping over the values only works while inds is 1:n_explain + testObs_computed <- testObs_computed_vec[kk] + full_pred <- predict(model,x_explain)[testObs_computed] + shapsum_other_features <- 0 + + + run <- iterative_kshap_func(model,x_explain,x_train, + testObs_computed = testObs_computed, + cutoff_feats = cutoff_feats, + initial_n_coalitions = 50, + full_pred = full_pred, + shapsum_other_features = shapsum_other_features, + p0 = p0, + predict_model = predict.lm, + shapley_threshold_val = shapley_threshold_val, + shapley_threshold_prob = shapley_threshold_prob, + approach = approach, + n_samples = n_samples, + gaussian.mu = mu, + gaussian.cov_mat = Sigma, + kernelSHAP_reweighting_strategy = kernelSHAP_reweighting_strategy) + runres_list[[kk]] <- run$kshap_final + runcomps_list[[kk]] <- 
sum(sapply(run$keep_list,"[[","no_computed_combinations")) + print(kk) +} + +est <- rbindlist(runres_list) +est[,other_features:=NULL] +fwrite(est,paste0(sim_results_saving_folder,"iterative_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + + + + +truth <- expl$shapley_values_est + +expl_approx <- matrix(0, nrow = length(inds), ncol = m+1) +expl_approx_obj_list <- list() +for (i in seq_along(testObs_computed_vec)){ # i is a position into inds/runcomps_list, not an observation id + expl_approx_obj <- shapr::explain(model = model, + x_explain= x_explain[inds[i],], + x_train = x_train, + approach = "gaussian", + phi0 = p0, + n_coalitions = runcomps_list[[i]], + Sigma=Sigma,mu=mu) + expl_approx[i,] = unlist(expl_approx_obj$shapley_values_est) + expl_approx_obj_list[[i]] <- expl_approx_obj +} +expl_approx <- as.data.table(expl_approx) +colnames(expl_approx) <- colnames(truth) +fwrite(expl_approx,paste0(sim_results_saving_folder,"approx_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(est-truth) +rmse_vec <- sqrt(colMeans((est-truth)^2)) +mae_vec <- colMeans(abs(est-truth)) + +bias_vec_approx <- colMeans(expl_approx-truth) +rmse_vec_approx <- sqrt(colMeans((expl_approx-truth)^2)) +mae_vec_approx <- colMeans(abs(expl_approx-truth)) + +save.image(paste0(sim_results_saving_folder, "iterative_kernelshap_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".RData")) + +hist(unlist(runcomps_list),breaks = 20) + +summary(unlist(runcomps_list)) + + +run$kshap_final +sum(unlist(run$kshap_final)) +full_pred + + + + + + + + +# TODO: Need to figure out why this does not give the correct sum here... +# If some variables were excluded, I have to add them to the sum to get the model's prediction. 
+# for(i in 1:18){ +# print(sum(unlist(run$keep_list[[i]]$kshap_est_dt[,-1]))+run$keep_list[[i]]$shap_it_excluded_features) +# #print(run$keep_list[[i]]$shap_it_excluded_features) +# } + +# run$kshap_it_est_dt + + + +# run$kshap_final +# expl$shapley_values_est + + + + +# kshap_final <- copy(run$kshap_est_dt_list[,-1]) +# setnafill(kshap_final,"locf") +# kshap_final[.N,] # final estimate + +# sum(unlist(kshap_final[.N,])) + +# sum(unlist(expl$shapley_values_est[testObs_computed,])) + + + + + + + + + + +# cutoff_feats <- paste0("VV",1:6) +# testObs_computed <- 5 + +# full_pred <- predict(model,x_explain)[5] +# p0 <- mean(y_train) +# pred_not_to_decompose <- sum(expl$shapley_values_est[5,VV7:VV9]) + + +# run_minor <- iterative_kshap_func(model,x_explain,x_train, +# testObs_computed = 5, +# cutoff_feats = cutoff_feats, +# full_pred = full_pred, +# pred_not_to_decompose = pred_not_to_decompose, +# p0 = p0, +# predict_model = predict.lm,shapley_threshold_val = 0) + + +# aa=run$keep_list[[8]]$dt_vS + +# bb=run_minor$keep_list[[6]]$dt_vS +# setnames(bb,"p_hat_1","p_hat_1_approx") + +# cc=merge(aa,bb) +# cc[,diff:=p_hat_1-p_hat_1_approx] + + +# TODO: + +# 1. Run example with gaussian features where the truth is known in advance in a large setting, with e.g. 12 features or so. I want the estimate +# both for the full 12 features, and for subsets where one is removed. +# 2. + +# Utfordringer: +# 1. Hvordan justere vekter og samplingrutine fra subset S når man allerede har et utvalg sampler (som også er noe biased). +# 2. Bruker altså E[f(x1=x1*,x2,x3=x3*,x4)|x1=x1*] som proxy for E[f(x1=x1*,x2,x3=x3*,x4)|x1=x1*,x3=x3*], +#men hva med E[f(x1=x1*,x2,x3,x4=x4*)|x1=x1*,x4=x4*]? Burde jeg bruke den for +#E[f(x1=x1*,x2,x3=x3*,x4=x4*)|x1=x1*,x4=x4*]? +# 3. Når jeg fjerner en variabel (som har lite å si), så settes shapley-verdien til det den har per da. 
MEN den verdien vil trolig være noe biased fordi den fjernes første gangen den går over terskelverdiene +# jeg har satt for ekskludering. + diff --git a/inst/scripts/devel/simtest_iterative_kernelshap_nonlingauss_analyze_results.R b/inst/scripts/devel/simtest_iterative_kernelshap_nonlingauss_analyze_results.R new file mode 100644 index 0000000000000000000000000000000000000000..9888f57f1d761710a2acb0c67a9e953d5098d71e --- /dev/null +++ b/inst/scripts/devel/simtest_iterative_kernelshap_nonlingauss_analyze_results.R @@ -0,0 +1,122 @@ +library(data.table) +kernelSHAP_reweighting_strategy = "none" +shapley_threshold_val <- 0.2 + + + +sim_results_folder = "/nr/project/stat/BigInsight/Projects/Explanations/EffektivShapley/Frida/simuleringsresultater/sim_nonlingauss_v2/" + +load(paste0(sim_results_folder,"iterative_kernelshap_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".RData")) + + +exact_vals = fread(paste0(sim_results_folder,"exact_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) +names(exact_vals) <- c("phi0", paste0("VV",1:12)) +iterative_vals = fread(paste0(sim_results_folder,"iterative_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) +approx_vals = fread(paste0(sim_results_folder,"approx_shapley_values_", shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(exact_vals - iterative_vals) +rmse_vec <- sqrt(colMeans((exact_vals - iterative_vals)^2)) +mae_vec <- colMeans(abs(exact_vals - iterative_vals)) + +bias_vec_approx <- colMeans(exact_vals - approx_vals) +rmse_vec_approx <- sqrt(colMeans((exact_vals - approx_vals)^2)) +mae_vec_approx <- colMeans(abs(exact_vals - approx_vals)) + +mean(mae_vec[-1]) +mean(mae_vec_approx[-1]) + +treeshap_vals <- as.data.table(predict(model,newdata=as.matrix(x_explain),predcontrib = TRUE)) +setnames(treeshap_vals,"BIAS","none") +setcolorder(treeshap_vals,"none") +head(treeshap_vals) +mae_vec_treeshap <- 
colMeans(abs(exact_vals - treeshap_vals)) +mean(mae_vec_treeshap[-1]) + +library(ggplot2) + +# Create a data frame for the bar plot + +# MAE +df <- data.frame(matrix(0, length(mae_vec)*2, 3)) +colnames(df) <- c("MAE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- mae_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- mae_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) +df <- as.data.table(df) +dt_treeshap <- data.frame(MAE=mae_vec_treeshap,approach="TreeSHAP",features=names(mae_vec_treeshap)) +df <- rbind(df,dt_treeshap) + +df[,features:=factor(features,levels=c("phi0",paste0("VV",1:12)))] + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = MAE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "mae_comparison.png"), plot = p, width = 10, height = 5) # paste0: paste() would put a space before the file name + + + + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = MAE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "mae_comparison.png"), plot = p) + + +# RMSE +df <- data.frame(matrix(0, length(rmse_vec)*2, 3)) +colnames(df) <- c("RMSE", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- rmse_vec_approx +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- rmse_vec +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) +df$features <- factor(df$features,levels=c("phi0",paste0("VV",1:12))) # df is a plain data.frame here, so := cannot be used + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = RMSE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "rmse_comparison.png"), plot = p, width = 10, height = 5) # RMSE section: plot the RMSE column and do not overwrite the MAE figure + + + + + + +# Create the bar plot using ggplot +p <- 
ggplot(df, aes(x = features, y = RMSE, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "rmse_comparison.png"), plot = p) # paste0: paste() would put a space before the file name + + +# Bias +df <- data.frame(matrix(0, length(bias_vec)*2, 3)) +colnames(df) <- c("abs_bias", "approach", "features") +# rownames(df) <- names(exact_vals) +df[1:length(exact_vals), 1] <- abs(bias_vec_approx) +df[1:length(exact_vals), 2] <- rep("approx", length(exact_vals)) +df[(length(exact_vals)+1):nrow(df), 1] <- abs(bias_vec) +df[(length(exact_vals)+1):nrow(df), 2] <- rep("iterative", length(exact_vals)) +df[, 3] <- rep(names(exact_vals), 2) +df$features <- factor(df$features,levels=c("phi0",paste0("VV",1:12))) # df is a plain data.frame here, so := cannot be used + +# Create the bar plot using ggplot +p <- ggplot(df, aes(x = features, y = abs_bias, fill = approach)) + + geom_col(position = "dodge") +ggsave(paste0(sim_results_folder, "bias_comparison.png"), plot = p, width = 10, height = 5) # bias section: plot abs_bias and do not overwrite the MAE figure + + +# Number of samples used +runcomps_list + +df = data.frame(matrix(0, length(runcomps_list), 1)) +colnames(df) <- c("n_rows") +df$n_rows <- as.numeric(runcomps_list) + +p <- ggplot(df, aes(n_rows)) + + geom_histogram() +ggsave(paste0(sim_results_folder, "n_rows.png"), plot = p) + diff --git a/inst/scripts/devel/simtest_reweighting_strategies.R b/inst/scripts/devel/simtest_reweighting_strategies.R new file mode 100644 index 0000000000000000000000000000000000000000..3f6a1e3dfe13526f7f85e937d6e7b10824114570 --- /dev/null +++ b/inst/scripts/devel/simtest_reweighting_strategies.R @@ -0,0 +1,263 @@ +### Upcoming generalization: + +#1. Use non-linear truth (xgboost or so) +#2. 
Even more features + + +library(data.table) +library(MASS) +library(Matrix) +library(shapr) +library(future) +library(xgboost) + +m <- 12 +n_train <- 5000 +n_explain <- 5 +rho_1 <- 0.9 +rho_2 <- 0.6 +rho_3 <- 0.3 +rho_4 <- 0.1 +Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1) +Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2) +Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3) +Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4) + +Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4)) +mu <- rep(0,m) + +set.seed(123) + + +x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma)) +x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma)) + +names(x_train) <- paste0("VV",1:m) +names(x_explain) <- paste0("VV",1:m) + + +beta <- rnorm(m) +alpha <- 1 +y_train <- as.vector(alpha + as.matrix(x_train) %*% beta + rnorm(n_train, 0, 1)) +y_explain <- alpha + as.matrix(x_explain) %*% beta + rnorm(n_explain, 0, 1) + +xy_train <- cbind(y_train, x_train) + +set.seed(123) + +model <- lm(y_train ~ .,data = xy_train) + +p0 <- mean(y_train) + + +### First run proper shapr call on this + +kernelSHAP_reweighting_strategy = "none" + +set.seed(465132) +progressr::handlers(global = TRUE) +expl <- shapr::explain(model = model, + x_explain= x_explain, + x_train = x_train, + approach = "gaussian", + n_batches=100,n_samples = 10000, + phi0 = p0,Sigma=Sigma,mu=mu) + +dt_vS_map <- merge(expl$internal$iter_list[[1]]$coalition_map,expl$internal$output$dt_vS,by="id_coalition")[,-"id_coalition"] + + +kernelSHAP_reweighting_strategy_vec <- c("none","on_N","on_coal_size","on_all","on_all_cond") + +n_coalitions_vec <- c(50,100,200,400,800,1200,1600,2000,2400,2800,3200,3600,4000) + +reps <- 100 + +paired_shap_sampling_vec <- c(FALSE,TRUE) + +res_list <- list() + +for(i0 in seq_along(paired_shap_sampling_vec)){ + + for(i in seq_len(reps)){ + + for(ii in seq_along(n_coalitions_vec)){ + + this_seed <- 1+i + this_n_coalitions <- n_coalitions_vec[ii] + 
this_paired_shap_sampling <- paired_shap_sampling_vec[i0] + + this <- shapr::explain(model = model, + x_explain= x_explain, + x_train = x_train, + approach = "gaussian", + n_samples = 10, # Never used + n_batches=10, + phi0 = p0, + Sigma=Sigma, + mu=mu, + seed = this_seed, + max_n_coalitions = this_n_coalitions, + kernelSHAP_reweighting = "none", + paired_shap_sampling = this_paired_shap_sampling) + + this0_X <- this$internal$objects$X + + + exact_dt_vS <- merge(this$internal$iter_list[[1]]$coalition_map,dt_vS_map,by="coalitions_str") + setorder(exact_dt_vS,id_coalition) + + + for(iii in seq_along(kernelSHAP_reweighting_strategy_vec)){ + this_kernelSHAP_reweighting_strategy <- kernelSHAP_reweighting_strategy_vec[iii] + + this_X <- copy(this0_X) + + shapr:::kernelSHAP_reweighting(this_X,reweight=this_kernelSHAP_reweighting_strategy) + + this_W <- weight_matrix( + X = this_X, + normalize_W_weights = TRUE + ) + + shap_dt0 <- as.data.table(cbind(seq_len(n_explain),t(this_W%*%as.matrix(exact_dt_vS[,-c("coalitions_str","id_coalition")])))) + names(shap_dt0) <- names(this$shapley_values_est) + + this_diff <- unlist(shap_dt0[,-c(1,2)]-expl$shapley_values_est[,-c(1,2)]) + this_bias <- mean(this_diff) + this_var <- var(this_diff) + this_MAE <- mean(abs(this_diff)) + this_RMSE <- sqrt(mean(this_diff^2)) + + res_vec <- data.table(n_coalitions = this_n_coalitions, + paired_shap_sampling = this_paired_shap_sampling, + kernelSHAP_reweighting_strategy = this_kernelSHAP_reweighting_strategy, + seed = this_seed, + bias=this_bias, + var = this_var, + MAE = this_MAE, + RMSE = this_RMSE) + + res_list[[length(res_list)+1]] <- copy(res_vec) + + } + + } + + print(i) + + } + +} + + +res_dt <- rbindlist(res_list) + +fwrite(res_dt,file = "../../Div/extra_shapr_scripts_etc/res_dt_reweighting_sims_lingaus.csv") + +resres <- res_dt[,lapply(.SD,mean),.SDcols=c("bias","var","MAE","RMSE"),by=.(paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)] + +library(ggplot2) + 
+ggplot(resres[paired_shap_sampling==TRUE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+ + geom_line() + + + +#### OLD #### + +### Need to create an lm analogoue to pred_mod_xgb here + + +set.seed(123) + + + +# These are the parameters for for interative_kshap_func +n_samples <- 1000 +approach = "gaussian" + +# Reduce if < 10% prob of shapval > 0.2 + +source("inst/scripts/devel/iterative_kernelshap_sourcefuncs.R") + +testObs_computed_vec <- inds# seq_len(n_explain) + +# Using threshold: 0.1 +runres_list <- runcomps_list <- list() +for(kk in testObs_computed_vec){ + testObs_computed <- testObs_computed_vec[kk] + full_pred <- predict(model,x_explain)[testObs_computed] + shapsum_other_features <- 0 + + + run <- iterative_kshap_func(model,x_explain,x_train, + testObs_computed = testObs_computed, + cutoff_feats = cutoff_feats, + initial_n_combinations = 50, + full_pred = full_pred, + shapsum_other_features = shapsum_other_features, + p0 = p0, + predict_model = predict.lm, + shapley_threshold_val = shapley_threshold_val, + shapley_threshold_prob = shapley_threshold_prob, + approach = approach, + n_samples = n_samples, + gaussian.mu = mu, + gaussian.cov_mat = Sigma, + kernelSHAP_reweighting_strategy = kernelSHAP_reweighting_strategy) + runres_list[[kk]] <- run$kshap_final + runcomps_list[[kk]] <- sum(sapply(run$keep_list,"[[","no_computed_combinations")) + print(kk) +} + +est <- rbindlist(runres_list) +est[,other_features:=NULL] +fwrite(est,paste0(sim_results_saving_folder,"iterative_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + + + + +truth <- expl$shapley_values_est + +expl_approx <- matrix(0, nrow = length(inds), ncol = m+1) +expl_approx_obj_list <- list() +for (i in testObs_computed_vec){ + expl_approx_obj <- shapr::explain(model = model, + x_explain= x_explain[inds[i],], + x_train = x_train, + approach = "gaussian", + phi0 = p0, + n_combinations = runcomps_list[[i]], + 
Sigma=Sigma,mu=mu) + expl_approx[i,] = unlist(expl_approx_obj$shapley_values_est) + expl_approx_obj_list[[i]] <- expl_approx_obj +} +expl_approx <- as.data.table(expl_approx) +colnames(expl_approx) <- colnames(truth) +fwrite(expl_approx,paste0(sim_results_saving_folder,"approx_shapley_values_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".csv")) + +bias_vec <- colMeans(est-truth) +rmse_vec <- sqrt(colMeans((est-truth)^2)) +mae_vec <- colMeans(abs(est-truth)) + +bias_vec_approx <- colMeans(expl_approx-truth) +rmse_vec_approx <- sqrt(colMeans((expl_approx-truth)^2)) +mae_vec_approx <- colMeans(abs(expl_approx-truth)) + +save.image(paste0(sim_results_saving_folder, "iterative_kernelshap_",shapley_threshold_val,"_",kernelSHAP_reweighting_strategy, ".RData")) + +hist(unlist(runcomps_list),breaks = 20) + +summary(unlist(runcomps_list)) + + +run$kshap_final +sum(unlist(run$kshap_final)) +full_pred + + + + + + + diff --git a/inst/scripts/devel/simtest_reweighting_strategies_nonlinear.R b/inst/scripts/devel/simtest_reweighting_strategies_nonlinear.R new file mode 100644 index 0000000000000000000000000000000000000000..c7ab347b74797d92f09e640b069ed3eb8ac5764c --- /dev/null +++ b/inst/scripts/devel/simtest_reweighting_strategies_nonlinear.R @@ -0,0 +1,182 @@ +### Upcoming generalization: + +#1. Use non-linear truth (xgboost or so) +#2. 
Even more features + + +library(data.table) +library(MASS) +library(Matrix) +library(shapr) +library(future) +library(xgboost) + +m <- 12 +n_train <- 5000 +n_explain <- 5 +rho_1 <- 0.9 +rho_2 <- 0.6 +rho_3 <- 0.3 +rho_4 <- 0.1 +Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1) +Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2) +Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3) +Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4) + +Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4)) +mu <- rep(0,m) + +set.seed(123) + + +x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma)) +x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma)) + +names(x_train) <- paste0("VV",1:m) +names(x_explain) <- paste0("VV",1:m) + + +g <- function(a,b){ + a*b+a*b^2+a^2*b +} + +beta <- c(0.2, -0.8, 1.0, 0.5, -0.8, rep(0, m - 5)) +gamma <- c(0.8,-1) +alpha <- 1 +y_train <- alpha + + as.vector(as.matrix(cos(x_train))%*%beta) + + unlist(gamma[1]*g(x_train[,1],x_train[,2])) + + unlist(gamma[1]*g(x_train[,3],x_train[,4])) + + rnorm(n_train, 0, 1) +y_explain <- alpha + + as.vector(as.matrix(cos(x_explain))%*%beta) + + unlist(gamma[1]*g(x_explain[,1],x_explain[,2])) + + unlist(gamma[1]*g(x_explain[,3],x_explain[,4])) + + rnorm(n_explain, 0, 1) # one noise draw per row of x_explain (n_train draws would be recycled into a length-n_train vector) + +xy_train <- cbind(y_train, x_train) + +set.seed(123) +model <- xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 50, + verbose = FALSE +) + +p0 <- mean(y_train) + + +### First run proper shapr call on this + +kernelSHAP_reweighting_strategy = "none" + +set.seed(465132) +progressr::handlers(global = TRUE) +expl <- shapr::explain(model = model, + x_explain= x_explain, + x_train = x_train, + approach = "gaussian", + n_batches=100,n_samples = 10000, + phi0 = p0,Sigma=Sigma,mu=mu) + +dt_vS_map <- merge(expl$internal$iter_list[[1]]$coalition_map,expl$internal$output$dt_vS,by="id_coalition")[,-"id_coalition"] + + +kernelSHAP_reweighting_strategy_vec <- 
c("none","on_N","on_coal_size","on_all","on_all_cond") + +n_coalitions_vec <- c(50,100,200,400,800,1200,1600,2000,2400,2800,3200,3600,4000) + +reps <- 100 + +paired_shap_sampling_vec <- c(FALSE,TRUE) + +res_list <- list() + +for(i0 in seq_along(paired_shap_sampling_vec)){ + + for(i in seq_len(reps)){ + + for(ii in seq_along(n_coalitions_vec)){ + + this_seed <- 1+i + this_n_coalitions <- n_coalitions_vec[ii] + this_paired_shap_sampling <- paired_shap_sampling_vec[i0] + + this <- shapr::explain(model = model, + x_explain= x_explain, + x_train = x_train, + approach = "gaussian", + n_samples = 10, # Never used + n_batches=10, + phi0 = p0, + Sigma=Sigma, + mu=mu, + seed = this_seed, + max_n_coalitions = this_n_coalitions, + kernelSHAP_reweighting = "none", + paired_shap_sampling = this_paired_shap_sampling) + + this0_X <- this$internal$objects$X + + + exact_dt_vS <- merge(this$internal$iter_list[[1]]$coalition_map,dt_vS_map,by="coalitions_str") + setorder(exact_dt_vS,id_coalition) + + + for(iii in seq_along(kernelSHAP_reweighting_strategy_vec)){ + this_kernelSHAP_reweighting_strategy <- kernelSHAP_reweighting_strategy_vec[iii] + + this_X <- copy(this0_X) + + shapr:::kernelSHAP_reweighting(this_X,reweight=this_kernelSHAP_reweighting_strategy) + + this_W <- weight_matrix( + X = this_X, + normalize_W_weights = TRUE + ) + + shap_dt0 <- as.data.table(cbind(seq_len(n_explain),t(this_W%*%as.matrix(exact_dt_vS[,-c("coalitions_str","id_coalition")])))) + names(shap_dt0) <- names(this$shapley_values_est) + + this_diff <- unlist(shap_dt0[,-c(1,2)]-expl$shapley_values_est[,-c(1,2)]) + this_bias <- mean(this_diff) + this_var <- var(this_diff) + this_MAE <- mean(abs(this_diff)) + this_RMSE <- sqrt(mean(this_diff^2)) + + res_vec <- data.table(n_coalitions = this_n_coalitions, + paired_shap_sampling = this_paired_shap_sampling, + kernelSHAP_reweighting_strategy = this_kernelSHAP_reweighting_strategy, + seed = this_seed, + bias=this_bias, + var = this_var, + MAE = this_MAE, + RMSE = 
this_RMSE) + + res_list[[length(res_list)+1]] <- copy(res_vec) + + } + + } + + print(i) + + } + +} + + +res_dt <- rbindlist(res_list) + +fwrite(res_dt,file = "../../Div/extra_shapr_scripts_etc/res_dt_reweighting_sims_nonlingaus.csv") + +resres <- res_dt[,lapply(.SD,mean),.SDcols=c("bias","var","MAE","RMSE"),by=.(paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)] + +library(ggplot2) + +ggplot(resres[paired_shap_sampling==TRUE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+ + geom_line() + +ggplot(resres[paired_shap_sampling==FALSE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+ + geom_line() diff --git a/inst/scripts/devel/simtest_reweighting_strategies_nonlinear_nonunique_sampling.R b/inst/scripts/devel/simtest_reweighting_strategies_nonlinear_nonunique_sampling.R new file mode 100644 index 0000000000000000000000000000000000000000..84f9e71c4e721a03284589c7e4c4cc791c40c053 --- /dev/null +++ b/inst/scripts/devel/simtest_reweighting_strategies_nonlinear_nonunique_sampling.R @@ -0,0 +1,217 @@ +### Upcoming generalization: + +#1. Use non-linear truth (xgboost or so) +#2. 
Even more features + + +library(data.table) +library(MASS) +library(Matrix) +library(shapr) +library(future) +library(xgboost) + +m <- 12 +n_train <- 5000 +n_explain <- 5 +rho_1 <- 0.9 +rho_2 <- 0.6 +rho_3 <- 0.3 +rho_4 <- 0.1 +Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1) +Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2) +Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3) +Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4) + +Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4)) +mu <- rep(0,m) + +set.seed(123) + + +x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma)) +x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma)) + +names(x_train) <- paste0("VV",1:m) +names(x_explain) <- paste0("VV",1:m) + + +g <- function(a,b){ + a*b+a*b^2+a^2*b +} + +beta <- c(0.2, -0.8, 1.0, 0.5, -0.8, rep(0, m - 5)) +gamma <- c(0.8,-1) +alpha <- 1 +y_train <- alpha + + as.vector(as.matrix(cos(x_train))%*%beta) + + unlist(gamma[1]*g(x_train[,1],x_train[,2])) + + unlist(gamma[1]*g(x_train[,3],x_train[,4])) + + rnorm(n_train, 0, 1) +y_explain <- alpha + + as.vector(as.matrix(cos(x_explain))%*%beta) + + unlist(gamma[1]*g(x_explain[,1],x_explain[,2])) + + unlist(gamma[1]*g(x_explain[,3],x_explain[,4])) + + rnorm(n_explain, 0, 1) # one noise draw per row of x_explain (n_train draws would be recycled into a length-n_train vector) + +xy_train <- cbind(y_train, x_train) + +set.seed(123) +model <- xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 50, + verbose = FALSE +) + +p0 <- mean(y_train) + + +### First run proper shapr call on this + +kernelSHAP_reweighting_strategy = "none" + +set.seed(465132) +progressr::handlers(global = TRUE) +expl <- shapr::explain(model = model, + x_explain= x_explain, + x_train = x_train, + approach = "gaussian", + n_batches=100,n_samples = 10000, + phi0 = p0,Sigma=Sigma,mu=mu) + +dt_vS_map <- merge(expl$internal$iter_list[[1]]$coalition_map,expl$internal$output$dt_vS,by="id_coalition")[,-"id_coalition"] + + +kernelSHAP_reweighting_strategy_vec <- 
c("none","on_N","on_coal_size","on_all","on_all_cond","on_all_cond_paired","comb") + +n_coalitions_vec <- c(50,100,200,400,800,1200,1600,2000,2400,2800,3200,3600,4000) + +reps <- 200 + +paired_shap_sampling_vec <- c(FALSE,TRUE) + +res_list <- weight_list <- list() + +for(ii in seq_along(n_coalitions_vec)){ + + for(i0 in seq_along(paired_shap_sampling_vec)){ + + for(i in seq_len(reps)){ + + this_seed <- 10000+1+i + this_n_coalitions <- n_coalitions_vec[ii] + this_paired_shap_sampling <- paired_shap_sampling_vec[i0] + + this <- shapr::explain(model = model, + x_explain= x_explain, + x_train = x_train, + approach = "gaussian", + n_samples = 10, # Never used + n_batches=10, + phi0 = p0, + Sigma=Sigma, + mu=mu, + seed = this_seed, + max_n_coalitions = this_n_coalitions, + kernelSHAP_reweighting = "none", + unique_sampling = TRUE, + paired_shap_sampling = this_paired_shap_sampling) + + this0_X <- this$internal$objects$X + + + exact_dt_vS <- merge(this$internal$iter_list[[1]]$coalition_map,dt_vS_map,by="coalitions_str") + setorder(exact_dt_vS,id_coalition) + + + for(iii in seq_along(kernelSHAP_reweighting_strategy_vec)){ + this_kernelSHAP_reweighting_strategy <- kernelSHAP_reweighting_strategy_vec[iii] + + this_X <- copy(this0_X) + + shapr:::kernelSHAP_reweighting(this_X,reweight=this_kernelSHAP_reweighting_strategy) + + this_W <- weight_matrix( + X = this_X, + normalize_W_weights = TRUE + ) + + shap_dt0 <- as.data.table(cbind(seq_len(n_explain),t(this_W%*%as.matrix(exact_dt_vS[,-c("coalitions_str","id_coalition")])))) + names(shap_dt0) <- names(this$shapley_values_est) + + this_diff <- unlist(shap_dt0[,-c(1,2)]-expl$shapley_values_est[,-c(1,2)]) + this_bias <- mean(this_diff) + this_var <- var(this_diff) + this_MAE <- mean(abs(this_diff)) + this_RMSE <- sqrt(mean(this_diff^2)) + + res_vec <- data.table(n_coalitions = this_n_coalitions, + paired_shap_sampling = this_paired_shap_sampling, + kernelSHAP_reweighting_strategy = this_kernelSHAP_reweighting_strategy, + seed = 
this_seed, + bias=this_bias, + var = this_var, + MAE = this_MAE, + RMSE = this_RMSE) + + res_list[[length(res_list)+1]] <- copy(res_vec) + + # weight_dt <- unique(this_X[,.(coalition_size,shapley_weight)][,shapley_weight:=mean(shapley_weight),by=coalition_size][]) + weight_dt <- this_X[,.(coalition_size,shapley_weight)][,head(.SD,1),by=coalition_size] + + weight_dt[,n_coalitions:=this_n_coalitions] + weight_dt[,paired_shap_sampling:=this_paired_shap_sampling] + weight_dt[,kernelSHAP_reweighting_strategy:=this_kernelSHAP_reweighting_strategy] + weight_dt[,seed:=this_seed] + + weight_list[[length(weight_list)+1]] <- copy(weight_dt) + + + } + + } + + print(i) + + } + + print(n_coalitions_vec[ii]) +} + + +res_dt <- rbindlist(res_list) + +fwrite(res_dt,file = "../../Div/extra_shapr_scripts_etc/res_dt_reweighting_sims_nonlingaus_nonunique_sampling_new.csv") + +resres <- res_dt[,lapply(.SD,mean),.SDcols=c("bias","var","MAE","RMSE"),by=.(paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)] +resres_sd <- res_dt[,lapply(.SD,sd),.SDcols=c("bias","var","MAE","RMSE"),by=.(paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)] + + +library(ggplot2) + +ggplot(resres,aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+ + geom_line() + +ggplot(resres[paired_shap_sampling==FALSE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+ + geom_line()+scale_y_log10() + +ggplot(resres[paired_shap_sampling==TRUE],aes(x=n_coalitions,y=MAE,col=kernelSHAP_reweighting_strategy,linetype= paired_shap_sampling))+ + geom_line()+scale_y_log10() + + + + +weight_dt <- rbindlist(weight_list) + + +weight_dt[!(coalition_size%in%c(0,12)),sum_shapley_weight:=sum(shapley_weight),by=.(seed,paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)] + +weight_dt[!(coalition_size%in%c(0,12)),shapley_weight:=shapley_weight/sum_shapley_weight] 
+weight_dt[!(coalition_size%in%c(0,12)),mean(shapley_weight),by=.(seed,paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)] + + +ww_dt <- weight_dt[!(coalition_size%in%c(0,12)),list(mean_weight=mean(shapley_weight)),by=.(coalition_size,paired_shap_sampling,n_coalitions,kernelSHAP_reweighting_strategy)] + +ggplot(ww_dt[paired_shap_sampling==TRUE & kernelSHAP_reweighting_strategy %in% c("none","on_all_cond_paired","on_N")],aes(x=coalition_size,y=mean_weight,col=kernelSHAP_reweighting_strategy))+ + geom_point()+facet_grid(~n_coalitions) diff --git a/inst/scripts/devel/simtest_timing_to_Frida.R b/inst/scripts/devel/simtest_timing_to_Frida.R new file mode 100644 index 0000000000000000000000000000000000000000..acc7e3e2a058f88d6c2c04308cd5d21e86c3bd1d --- /dev/null +++ b/inst/scripts/devel/simtest_timing_to_Frida.R @@ -0,0 +1,107 @@ +library(data.table) +library(MASS) +library(Matrix) +library(shapr) +library(future) +library(xgboost) + +shapley_threshold_prob <- 0.2 +shapley_threshold_val <- 0.1 + +m <- 12 +n_train <- 5000 +n_explain <- 100 +rho_1 <- 0.5 +rho_2 <- 0.5 +rho_3 <- 0.5 +rho_4 <- 0 +Sigma_1 <- matrix(rho_1, m/4, m/4) + diag(m/4) * (1 - rho_1) +Sigma_2 <- matrix(rho_2, m/4, m/4) + diag(m/4) * (1 - rho_2) +Sigma_3 <- matrix(rho_3, m/4, m/4) + diag(m/4) * (1 - rho_3) +Sigma_4 <- matrix(rho_4, m/4, m/4) + diag(m/4) * (1 - rho_4) + +Sigma <- as.matrix(bdiag(Sigma_1, Sigma_2, Sigma_3, Sigma_4)) +mu <- rep(0,m) + +set.seed(123) + + +x_train <- as.data.table(MASS::mvrnorm(n_train,mu,Sigma)) +x_explain <- as.data.table(MASS::mvrnorm(n_explain,mu,Sigma)) + +names(x_train) <- paste0("VV",1:m) +names(x_explain) <- paste0("VV",1:m) + + +g <- function(a,b){ + a*b+a*b^2+a^2*b +} + +beta <- c(0.2, -0.8, 1.0, 0.5, -0.8, rep(0, m - 5)) +gamma <- c(0.8,-1) +alpha <- 1 +y_train <- alpha + + as.vector(as.matrix(cos(x_train))%*%beta) + + unlist(gamma[1]*g(x_train[,1],x_train[,2])) + + unlist(gamma[1]*g(x_train[,3],x_train[,4])) + + rnorm(n_train, 0, 1) +y_explain <- 
alpha + + as.vector(as.matrix(cos(x_explain))%*%beta) + + unlist(gamma[1]*g(x_explain[,1],x_explain[,2])) + + unlist(gamma[1]*g(x_explain[,3],x_explain[,4])) + + rnorm(n_explain, 0, 1) # one noise draw per row of x_explain (n_train draws would be recycled into a length-n_train vector) + +xy_train <- cbind(y_train, x_train) + +set.seed(123) +model <- xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 50, + verbose = FALSE +) + +pred_train <- predict(model, as.matrix(x_train)) + +this_order <- order(unlist(x_train[,1])) + +plot(unlist(x_train[this_order,1]),pred_train[this_order],type="l") + +p0 <- mean(y_train) + + +### First run proper shapr call on this + + +set.seed(465132) +inds = 1:5#1:n_explain + +expl <- explain( + model = model, + x_explain= x_explain[inds,], + x_train = x_train, + approach = "gaussian", + phi0 = p0, + n_coalitions = 100, + Sigma=Sigma, + mu=mu, + iterative = TRUE, + unique_sampling = FALSE, + iterative_args = list(initial_n_coalitions = 50, + fixed_n_coalitions_per_iter = 50, + max_iter = 10, + convergence_tol = 10^(-10), + compute_sd = TRUE), + kernelSHAP_reweighting = "none", + print_iter_info = TRUE +) + +# Number of (non-unique) coalitions per iteration +sapply(expl$internal$iter_list,function(dt) dt$X[,sum(sample_freq)]) + +# Timing of main function call +expl$timing$main_timing_secs + +# Timings per iteration +expl$timing$iter_timing_secs_dt[] + diff --git a/inst/scripts/devel/testing_explain_forevast_n_comb.R b/inst/scripts/devel/testing_explain_forevast_n_comb.R index 48784a6cf8d025ac64b235859fc115b9a7684c01..03bea2181f38180f4fc3705de56ed055baa7fa00 100644 --- a/inst/scripts/devel/testing_explain_forevast_n_comb.R +++ b/inst/scripts/devel/testing_explain_forevast_n_comb.R @@ -9,12 +9,12 @@ h3test <- explain_forecast(model = model_arima_temp, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar[1:3], + phi0 = p0_ar[1:3], group_lags = FALSE, n_batches = 1, timing = FALSE, seed = i, - n_combinations = 300 + n_coalitions = 300 ) h2test <- explain_forecast(model = model_arima_temp, @@ 
-26,12 +26,12 @@ h2test <- explain_forecast(model = model_arima_temp, explain_xreg_lags = 2, horizon = 2, approach = "empirical", - prediction_zero = p0_ar[1:2], + phi0 = p0_ar[1:2], group_lags = FALSE, n_batches = 1, timing = FALSE, seed = i, - n_combinations = 10^7 + n_coalitions = 10^7 ) h1test <- explain_forecast(model = model_arima_temp, @@ -43,12 +43,12 @@ h1test <- explain_forecast(model = model_arima_temp, explain_xreg_lags = 2, horizon = 1, approach = "empirical", - prediction_zero = p0_ar[1], + phi0 = p0_ar[1], group_lags = FALSE, n_batches = 1, timing = FALSE, seed = i, - n_combinations = 10^7 + n_coalitions = 10^7 ) w <- h3test$internal$objects$X_list[[1]][["shapley_weight"]] @@ -87,7 +87,7 @@ h3full <- explain_forecast(model = model_arima_temp, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar[1:3], + phi0 = p0_ar[1:3], group_lags = FALSE, n_batches = 1, timing = FALSE, @@ -103,7 +103,7 @@ h1full <- explain_forecast(model = model_arima_temp, explain_xreg_lags = 2, horizon = 1, approach = "empirical", - prediction_zero = p0_ar[1], + phi0 = p0_ar[1], group_lags = FALSE, n_batches = 1, timing = FALSE, @@ -122,12 +122,12 @@ for (i in 1:reps){ explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar[1:3], + phi0 = p0_ar[1:3], group_lags = FALSE, n_batches = 1, timing = FALSE, seed = i, - n_combinations = ncomb + n_coalitions = ncomb ) h2list[[i]] <- explain_forecast(model = model_arima_temp, @@ -139,12 +139,12 @@ for (i in 1:reps){ explain_xreg_lags = 2, horizon = 2, approach = "empirical", - prediction_zero = p0_ar[1:2], + phi0 = p0_ar[1:2], group_lags = FALSE, n_batches = 1, timing = FALSE, seed = i, - n_combinations = ncomb + n_coalitions = ncomb ) h1list[[i]] <- explain_forecast(model = model_arima_temp, @@ -156,12 +156,12 @@ for (i in 1:reps){ explain_xreg_lags = 2, horizon = 1, approach = "empirical", - prediction_zero = p0_ar[1], + phi0 = p0_ar[1], group_lags = FALSE, n_batches = 1, 
timing = FALSE, seed = i, - n_combinations = min(ncomb,31) + n_coalitions = min(ncomb,31) ) print(i) @@ -175,14 +175,14 @@ cols_horizon3 <- h3full$internal$objects$cols_per_horizon[[3]] h1mean1 <- h2mean1 <- h2mean2 <- h3mean1 <- h3mean2 <- h3mean3 <- list() for(i in 1:reps){ - h1mean1[[i]] <- as.matrix(h1list[[i]]$shapley_values[horizon==1, ..cols_horizon1]) + h1mean1[[i]] <- as.matrix(h1list[[i]]$shapley_values_est[horizon==1, ..cols_horizon1]) - h2mean1[[i]] <- as.matrix(h2list[[i]]$shapley_values[horizon==1, ..cols_horizon1]) - h2mean2[[i]] <- as.matrix(h2list[[i]]$shapley_values[horizon==2, ..cols_horizon2]) + h2mean1[[i]] <- as.matrix(h2list[[i]]$shapley_values_est[horizon==1, ..cols_horizon1]) + h2mean2[[i]] <- as.matrix(h2list[[i]]$shapley_values_est[horizon==2, ..cols_horizon2]) - h3mean1[[i]] <- as.matrix(h3list[[i]]$shapley_values[horizon==1, ..cols_horizon1]) - h3mean2[[i]] <- as.matrix(h3list[[i]]$shapley_values[horizon==2, ..cols_horizon2]) - h3mean3[[i]] <- as.matrix(h3list[[i]]$shapley_values[horizon==3, ..cols_horizon3]) + h3mean1[[i]] <- as.matrix(h3list[[i]]$shapley_values_est[horizon==1, ..cols_horizon1]) + h3mean2[[i]] <- as.matrix(h3list[[i]]$shapley_values_est[horizon==2, ..cols_horizon2]) + h3mean3[[i]] <- as.matrix(h3list[[i]]$shapley_values_est[horizon==3, ..cols_horizon3]) } @@ -190,25 +190,25 @@ for(i in 1:reps){ Reduce("+", h1mean1) / reps Reduce("+", h2mean1) / reps Reduce("+", h3mean1) / reps -h3full$shapley_values[horizon==1,..cols_horizon1] +h3full$shapley_values_est[horizon==1,..cols_horizon1] # Horizon 2 Reduce("+", h2mean2) / reps Reduce("+", h3mean2) / reps -h3full$shapley_values[horizon==2,..cols_horizon2] +h3full$shapley_values_est[horizon==2,..cols_horizon2] # Horizon 3 Reduce("+", h3mean3) / reps -h3full$shapley_values[horizon==3,..cols_horizon3] +h3full$shapley_values_est[horizon==3,..cols_horizon3] -expect_equal(h2$shapley_values[horizon==1, ..cols_horizon1], - h1$shapley_values[horizon==1,..cols_horizon1]) 
+expect_equal(h2$shapley_values_est[horizon==1, ..cols_horizon1], + h1$shapley_values_est[horizon==1,..cols_horizon1]) -expect_equal(h3$shapley_values[horizon==1, ..cols_horizon1], - h1$shapley_values[horizon==1,..cols_horizon1]) +expect_equal(h3$shapley_values_est[horizon==1, ..cols_horizon1], + h1$shapley_values_est[horizon==1,..cols_horizon1]) cols_horizon2 <- h2$internal$objects$cols_per_horizon[[2]] -expect_equal(h3$shapley_values[horizon==2, ..cols_horizon2], - h2$shapley_values[horizon==2,..cols_horizon2]) +expect_equal(h3$shapley_values_est[horizon==2, ..cols_horizon2], + h2$shapley_values_est[horizon==2,..cols_horizon2]) diff --git a/inst/scripts/devel/testing_for_valid_defualt_n_batches.R b/inst/scripts/devel/testing_for_valid_defualt_n_batches.R index 2c5f3ef098b5dcdadc1ee1137f49c0c9299ced4a..a097fe73c5d624bafb65f4f1c8964a8374089b86 100644 --- a/inst/scripts/devel/testing_for_valid_defualt_n_batches.R +++ b/inst/scripts/devel/testing_for_valid_defualt_n_batches.R @@ -1,10 +1,10 @@ # In this code we demonstrate that (before the bugfix) the `explain()` function -# does not enter the exact mode when n_combinations is larger than or equal to 2^m. -# The mode is only changed if n_combinations is strictly larger than 2^m. -# This means that we end up with using all coalitions when n_combinations is 2^m, +# does not enter the exact mode when n_coalitions is larger than or equal to 2^m. +# The mode is only changed if n_coalitions is strictly larger than 2^m. +# This means that we end up with using all coalitions when n_coalitions is 2^m, # but use not the exact Shapley kernel weights. # Bugfix replaces `>` with `=>`in the places where the code tests if -# n_combinations is larger than or equal to 2^m. Then the text/messages printed by +# n_coalitions is larger than or equal to 2^m. Then the text/messages printed by # shapr and the code correspond. 
library(xgboost) @@ -34,13 +34,13 @@ model <- xgboost::xgboost( p0 <- mean(y_train) # Shapr sets the default number of batches to be 10 for this dataset for the -# "ctree", "gaussian", and "copula" approaches. Thus, setting `n_combinations` +# "ctree", "gaussian", and "copula" approaches. Thus, setting `n_coalitions` # to any value lower of equal to 10 causes the error. any_number_equal_or_below_10 = 8 # Before the bugfix, shapr:::check_n_batches() throws the error: # Error in check_n_batches(internal) : -# `n_batches` (10) must be smaller than the number feature combinations/`n_combinations` (8) +# `n_batches` (10) must be smaller than the number feature combinations/`n_coalitions` (8) # Bug only occures for "ctree", "gaussian", and "copula" as they are treated different in # `get_default_n_batches()`, I am not certain why. Ask Martin about the logic behind that. explanation <- explain( @@ -49,6 +49,6 @@ explanation <- explain( x_train = x_train, n_samples = 2, # Low value for fast computations approach = "gaussian", - prediction_zero = p0, - n_combinations = any_number_equal_or_below_10 + phi0 = p0, + n_coalitions = any_number_equal_or_below_10 ) diff --git a/inst/scripts/devel/testing_intermediate_saving.R b/inst/scripts/devel/testing_intermediate_saving.R new file mode 100644 index 0000000000000000000000000000000000000000..85981c381aa709143395dae1b6d925c95a47c79a --- /dev/null +++ b/inst/scripts/devel/testing_intermediate_saving.R @@ -0,0 +1,132 @@ + + +aa = explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.01, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 30 + ), + iterative = TRUE, + print_shapleyres = TRUE, + print_iter_info = TRUE,kernelSHAP_reweighting = "on_N" +) + +bb = explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + 
approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 30 + ), + iterative = TRUE, + print_shapleyres = TRUE, + print_iter_info = TRUE,kernelSHAP_reweighting = "on_N",prev_shapr_object = aa +) + + + + +##### Reproducable results setting seed outside, and not setting it inside of explain (+ an seed-independent approach) +# Add something like this + + +set.seed(123) +full = explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 7 + ), + iterative = TRUE, + print_shapleyres = TRUE, + print_iter_info = TRUE, + kernelSHAP_reweighting = "on_N", + seed=NULL +) + +set.seed(123) +first = explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 4 + ), + iterative = TRUE, + print_shapleyres = TRUE, + print_iter_info = TRUE, + kernelSHAP_reweighting = "on_N", + seed=NULL +) + + +second = explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 7 + ), + iterative = TRUE, + print_shapleyres = TRUE, + print_iter_info = TRUE, + kernelSHAP_reweighting = "on_N", + seed=NULL, + prev_shapr_object = first +) + + + +# This cannot be tested, I think. 
+second_path = explain( + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 5 + ), + iterative = TRUE, + print_shapleyres = TRUE, + print_iter_info = TRUE, + kernelSHAP_reweighting = "on_N", + seed=NULL, + prev_shapr_object = first$internal$parameters$output_args$saving_path +) + + +# Identical results +all.equal(full$shapley_values_est,second$shapley_values_est) # TRUE +all.equal(full$shapley_values_est,second2$shapley_values_est) # TRUE +all.equal(full$shapley_values_est,second_path$shapley_values_est) # TRUE diff --git a/inst/scripts/devel/testing_memory_monitoring.R b/inst/scripts/devel/testing_memory_monitoring.R index a372c6cf33f79d5ade10dab4dfd4b9934a2be6ea..f161c3d2e7541ad782ba8e8bd598deb2089e3730 100644 --- a/inst/scripts/devel/testing_memory_monitoring.R +++ b/inst/scripts/devel/testing_memory_monitoring.R @@ -44,7 +44,7 @@ xy_train <- cbind(x_train,y=y_train) model <- lm(formula = y~.,data=xy_train) -explainer <- shapr(x_train, model,n_combinations = 1000) +explainer <- shapr(x_train, model,n_coalitions = 1000) p <- mean(y_train) @@ -56,7 +56,7 @@ peakRAM(explain( x_test, approach = "gaussian", explainer = explainer, - prediction_zero = p,n_batches = 4) + phi0 = p,n_batches = 4) ) # , @@ -64,28 +64,28 @@ peakRAM(explain( # x_test, # approach = "empirical", # explainer = explainer, -# prediction_zero = p,n_batches = 2), +# phi0 = p,n_batches = 2), # explain( # x_test, # approach = "empirical", # explainer = explainer, -# prediction_zero = p,n_batches = 4)) +# phi0 = p,n_batches = 4)) # explain( # x_test, # approach = "empirical", # explainer = explainer, -# prediction_zero = p,n_batches = 8), +# phi0 = p,n_batches = 8), # explain( # x_test, # approach = "empirical", # explainer = explainer, -# prediction_zero = p,n_batches = 16), +# 
phi0 = p,n_batches = 16), # explain( # x_test, # approach = "empirical", # explainer = explainer, -# prediction_zero = p,n_batches = 32) +# phi0 = p,n_batches = 32) # ) # s <- proc.time() @@ -93,6 +93,6 @@ peakRAM(explain( # x_test, # approach = "empirical", # explainer = explainer, -# prediction_zero = p,n_batches = 32) +# phi0 = p,n_batches = 32) # print(proc.time()-s) # diff --git a/inst/scripts/devel/testing_n_cobinations_equal_2_power_m.R b/inst/scripts/devel/testing_n_cobinations_equal_2_power_m.R index 56e447dee554f260b6ada57af617af3f998328e6..ee8a01e3f87b25251000d30c670c1e38d96c7154 100644 --- a/inst/scripts/devel/testing_n_cobinations_equal_2_power_m.R +++ b/inst/scripts/devel/testing_n_cobinations_equal_2_power_m.R @@ -1,10 +1,10 @@ # In this code we demonstrate that (before the bugfix) the `explain()` function -# does not enter the exact mode when n_combinations is larger than or equal to 2^m. -# The mode is only changed if n_combinations is strictly larger than 2^m. -# This means that we end up with using all coalitions when n_combinations is 2^m, +# does not enter the exact mode when n_coalitions is larger than or equal to 2^m. +# The mode is only changed if n_coalitions is strictly larger than 2^m. +# This means that we end up with using all coalitions when n_coalitions is 2^m, # but use not the exact Shapley kernel weights. # Bugfix replaces `>` with `=>`in the places where the code tests if -# n_combinations is larger than or equal to 2^m. Then the text/messages printed by +# n_coalitions is larger than or equal to 2^m. Then the text/messages printed by # shapr and the code correspond. 
library(xgboost) @@ -41,8 +41,8 @@ explanation_exact <- explain( n_samples = 2, # Low value for fast computations n_batches = 1, # Not related to the bug approach = "gaussian", - prediction_zero = p0, - n_combinations = NULL + phi0 = p0, + n_coalitions = NULL ) # Computing the conditional Shapley values using the gaussian approach @@ -53,13 +53,13 @@ explanation_should_also_be_exact <- explain( n_samples = 2, # Low value for fast computations n_batches = 1, # Not related to the bug approach = "gaussian", - prediction_zero = p0, - n_combinations = 2^ncol(x_explain) + phi0 = p0, + n_coalitions = 2^ncol(x_explain) ) # see that both `explain()` objects have the same number of combinations -explanation_exact$internal$parameters$n_combinations -explanation_should_also_be_exact$internal$parameters$n_combinations +explanation_exact$internal$parameters$n_coalitions +explanation_should_also_be_exact$internal$parameters$n_coalitions # But the first one of them is exact and the other not. explanation_exact$internal$parameters$exact diff --git a/inst/scripts/devel/testing_parallelization.R b/inst/scripts/devel/testing_parallelization.R index 24cacc1a7cc9082d4f011cdf32caddb9c13ef8e5..3f82541f28daa4b0241b5fc4c12e9293e356f03f 100644 --- a/inst/scripts/devel/testing_parallelization.R +++ b/inst/scripts/devel/testing_parallelization.R @@ -78,7 +78,7 @@ for(i in seq_len(nrow(res_dt))){ x_test, approach = approach_use, explainer = explainer, - prediction_zero = p,n_batches = n_batches_use + phi0 = p,n_batches = n_batches_use )},iterations = reps,time_unit ='s',memory = F, min_time = Inf ) diff --git a/inst/scripts/devel/testing_verification_ar_model.R b/inst/scripts/devel/testing_verification_ar_model.R index ab5c43d6ae3cad676bc148b5a03a900483b750df..6cf50f894ab73ac6707433f450dff2ce0fb851b8 100644 --- a/inst/scripts/devel/testing_verification_ar_model.R +++ b/inst/scripts/devel/testing_verification_ar_model.R @@ -28,11 +28,11 @@ exp <- explain_forecast(model = model_arima_temp, 
explain_xreg_lags = c(0,0), horizon = 2, approach = "empirical", - prediction_zero = c(0,0), + phi0 = c(0,0), group_lags = FALSE, n_batches = 1, timing = FALSE, - n_combinations = 50 + n_coalitions = 50 ) diff --git a/inst/scripts/devel/time_series_annabelle.R b/inst/scripts/devel/time_series_annabelle.R index 26e1f8b38328e7a8d0fffa58a64666890d71e056..62fdffd7bd05f158bd28a1ddf620c49cd52ae514 100644 --- a/inst/scripts/devel/time_series_annabelle.R +++ b/inst/scripts/devel/time_series_annabelle.R @@ -71,7 +71,7 @@ explanation_group <- explain( x_explain = x_explain, x_train = x_train, approach = "timeseries", - prediction_zero = p0, + phi0 = p0, group = group, timeseries.fixed_sigma_vec = 2 # timeseries.bounds = c(-1, 2) diff --git a/inst/scripts/devel/verifying_arima_model_output.R b/inst/scripts/devel/verifying_arima_model_output.R index 7a63bcbf51538550d59a874b9f8e5238a67b51b0..47ce0641dddf3050be2647c190f46f24814d4ae9 100644 --- a/inst/scripts/devel/verifying_arima_model_output.R +++ b/inst/scripts/devel/verifying_arima_model_output.R @@ -45,26 +45,26 @@ exp <- explain_forecast(model = model_arima_temp, explain_xreg_lags = c(0,1), horizon = 1, approach = "empirical", - prediction_zero = rep(mean(y),1), + phi0 = rep(mean(y),1), group_lags = FALSE, n_batches = 1) # These two should be approximately equal # For y -exp$shapley_values$Y1.1 +exp$shapley_values_est$Y1.1 model_arima_temp$coef[1]*(y[explain_idx]-mean(y)) #[1] -0.13500 0.20643 #[1] -0.079164 0.208118 # for xreg1 -exp$shapley_values$var1.F1 +exp$shapley_values_est$var1.F1 model_arima_temp$coef[3]*(xreg[explain_idx+1,1]-mean(xreg[,1])) #[1] -0.030901 1.179386 #[1] -0.12034 1.19589 # for xreg2 -exp$shapley_values$var2.F1 +exp$shapley_values_est$var2.F1 0 #[1] 0.011555 0.031911 #[1] 0 diff --git a/inst/scripts/devel/visual_bug_in_Shapley_bar_plot.R b/inst/scripts/devel/visual_bug_in_Shapley_bar_plot.R index 0c57fe6c1fe9028eaf7f4a73f80ada7019f7d9df..f9189e480350373a4beefc54863052047fe1f46b 100644 --- 
a/inst/scripts/devel/visual_bug_in_Shapley_bar_plot.R +++ b/inst/scripts/devel/visual_bug_in_Shapley_bar_plot.R @@ -41,7 +41,7 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, + phi0 = p0, n_samples = 10, keep_samp_for_vS = TRUE ) diff --git a/inst/scripts/empirical_memory_testing2.R b/inst/scripts/empirical_memory_testing2.R index ca57a8d5fdd293a2aa669a356cc61a105b3a2440..84e1f863fcb0dcb3b1f55b25618cd5400b7acd5c 100644 --- a/inst/scripts/empirical_memory_testing2.R +++ b/inst/scripts/empirical_memory_testing2.R @@ -60,7 +60,7 @@ xy_train <- cbind(x_train,y=y_train) model <- lm(formula = y~.,data=xy_train) -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) n_batches_use <- min(2^p-2,n_batches) @@ -71,7 +71,7 @@ explanation_many <- explain( x_train = x_train, approach = approach, n_batches = n_batches_use, - prediction_zero = prediction_zero + phi0 = phi0 ) @@ -81,7 +81,7 @@ explanation_many <- explain( # x_train = x_train, # approach = approach, # n_batches = 1, -# prediction_zero = prediction_zero +# phi0 = phi0 #) @@ -99,8 +99,8 @@ internal <- setup( x_train = x_train, x_explain = x_explain, approach = approach, - prediction_zero = prediction_zero, - n_combinations = 2^p, + phi0 = phi0, + n_coalitions = 2^p, group = NULL, n_samples = 1e3, n_batches = n_batches_use, diff --git a/inst/scripts/example_annabelle.R b/inst/scripts/example_annabelle.R index feede50bbf5d9244f8a3d0f86807b48fc98ac02a..b2cad403166f303f53684de6bbc8d217ff19cfa7 100644 --- a/inst/scripts/example_annabelle.R +++ b/inst/scripts/example_annabelle.R @@ -46,7 +46,7 @@ temp = explain( x_explain = x_test, model = model, approach = "categorical", - prediction_zero = p, + phi0 = p, joint_probability_dt = joint_prob_dt ) print(temp) diff --git a/inst/scripts/example_ctree_method.R b/inst/scripts/example_ctree_method.R index 6f0d26f12828a0f4dfa95a06a948bccb40f94786..6765a989c52362196ca9a4432a9e36729f902494 100644 --- 
a/inst/scripts/example_ctree_method.R +++ b/inst/scripts/example_ctree_method.R @@ -33,7 +33,7 @@ p0 <- mean(y_train) # and sample = TRUE explanation <- explain(x_test, explainer, approach = "ctree", - prediction_zero = p0) + phi0 = p0) # Printing the Shapley values for the test data explanation$dt @@ -91,7 +91,7 @@ explanation_cat <- explain( dummylist$testdata_new, approach = "ctree", explainer = explainer_cat, - prediction_zero = p0 + phi0 = p0 ) # Plot the resulting explanations for observations 1 and 6, excluding diff --git a/inst/scripts/example_custom_model.R b/inst/scripts/example_custom_model.R index 34a6377a4d63791085b94b40f7a2e221eabb63ac..c2a476a318ffaa88b1b1164bb8e03f2e6455b5a7 100644 --- a/inst/scripts/example_custom_model.R +++ b/inst/scripts/example_custom_model.R @@ -65,7 +65,7 @@ get_model_specs.gbm <- function(x){ set.seed(123) explainer <- shapr(xy_train, model) p0 <- mean(xy_train[,y_var]) -explanation <- explain(x_test, explainer, approach = "empirical", prediction_zero = p0) +explanation <- explain(x_test, explainer, approach = "empirical", phi0 = p0) # Plot results plot(explanation) @@ -89,6 +89,6 @@ predict_model.gbm <- function(x, newdata) { set.seed(123) explainer <- shapr(x_train, model) p0 <- mean(xy_train[,y_var]) -explanation <- explain(x_test, explainer, approach = "empirical", prediction_zero = p0) +explanation <- explain(x_test, explainer, approach = "empirical", phi0 = p0) # Plot results plot(explanation) diff --git a/inst/scripts/example_plot_MSEv.R b/inst/scripts/example_plot_MSEv.R index 42587ccbdc33162d1cbf2673d80b5f3d5014e1c7..725b1d89648aaf1319a899f95f1d5d5426ffc868 100644 --- a/inst/scripts/example_plot_MSEv.R +++ b/inst/scripts/example_plot_MSEv.R @@ -29,7 +29,7 @@ model <- xgboost::xgboost( ) # Specifying the phi_0, i.e. 
the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Independence approach explanation_independence <- explain( @@ -37,7 +37,7 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -47,7 +47,7 @@ explanation_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -57,7 +57,7 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e1 ) @@ -67,7 +67,7 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -77,7 +77,7 @@ explanation_ctree <- explain( x_explain = x_explain, x_train = x_train, approach = "ctree", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -87,7 +87,7 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "independence", "ctree"), - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -228,7 +228,7 @@ plot_MSEv_eval_crit(explanation_list_named, )$MSEv_explicand_bar plot_MSEv_eval_crit(explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15) + id_coalition = c(3, 4, 9, 13:15) )$MSEv_combination_bar @@ -236,11 +236,11 @@ plot_MSEv_eval_crit(explanation_list_named, MSEv_combination <- plot_MSEv_eval_crit( explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15) + id_coalition = c(3, 4, 9, 13:15) )$MSEv_combination_bar MSEv_combination$data$Method <- factor(MSEv_combination$data$Method, levels = rev(levels(MSEv_combination$data$Method))) MSEv_combination + - ggplot2::scale_x_discrete(limits = 
rev(unique(MSEv_combination$data$id_combination))) + + ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination$data$id_coalition))) + ggplot2::scale_fill_discrete(breaks = rev(levels(MSEv_combination$data$Method)), direction = -1) + ggplot2::coord_flip() @@ -249,14 +249,14 @@ MSEv_combination + MSEv_combination_wo_CI <- plot_MSEv_eval_crit( explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15), + id_coalition = c(3, 4, 9, 13:15), CI_level = NULL )$MSEv_combination_bar MSEv_combination_wo_CI$data$Method <- factor(MSEv_combination_wo_CI$data$Method, levels = rev(levels(MSEv_combination_wo_CI$data$Method)) ) MSEv_combination_wo_CI + - ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination_wo_CI$data$id_combination))) + + ggplot2::scale_x_discrete(limits = rev(unique(MSEv_combination_wo_CI$data$id_coalition))) + ggplot2::scale_fill_brewer( breaks = rev(levels(MSEv_combination_wo_CI$data$Method)), palette = "Paired", @@ -290,9 +290,9 @@ explanation_gaussian_seed_1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10, - n_combinations = 10, + n_coalitions = 10, seed = 1 ) @@ -301,9 +301,9 @@ explanation_gaussian_seed_1_V2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10, - n_combinations = 10, + n_coalitions = 10, seed = 1 ) @@ -312,9 +312,9 @@ explanation_gaussian_seed_2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10, - n_combinations = 10, + n_coalitions = 10, seed = 2 ) @@ -323,9 +323,9 @@ explanation_gaussian_seed_3 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10, - n_combinations = 10, + n_coalitions = 10, seed = 3 ) @@ -350,7 +350,7 @@ 
explanation_gaussian_all <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10 ) @@ -359,7 +359,7 @@ explanation_gaussian_only_5 <- explain( x_explain = x_explain[1:5, ], x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10 ) @@ -376,12 +376,12 @@ explanation_gaussian <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10 ) explanation_gaussian_copy <- copy(explanation_gaussian_all) -colnames(explanation_gaussian_copy$shapley_values) <- rev(colnames(explanation_gaussian_copy$shapley_values)) +colnames(explanation_gaussian_copy$shapley_values_est) <- rev(colnames(explanation_gaussian_copy$shapley_values_est)) # Will give an error due to different feature names plot_MSEv_eval_crit(list( @@ -397,7 +397,7 @@ explanation_gaussian <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 10 ) diff --git a/inst/scripts/example_plot_SV_several_approaches.R b/inst/scripts/example_plot_SV_several_approaches.R index a25c66b3622aae8d53888e04eb69fbe43c79f3b7..564e4c133626151968b80bc31e62c4c5ebfb89be 100644 --- a/inst/scripts/example_plot_SV_several_approaches.R +++ b/inst/scripts/example_plot_SV_several_approaches.R @@ -27,7 +27,7 @@ model = xgboost::xgboost( ) # Specifying the phi_0, i.e. 
the expected prediction without any features -prediction_zero = mean(y_train) +phi0 = mean(y_train) # Independence approach explanation_independence = explain( @@ -35,7 +35,7 @@ explanation_independence = explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -45,7 +45,7 @@ explanation_empirical = explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -55,7 +55,7 @@ explanation_gaussian_1e1 = explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e1 ) @@ -65,7 +65,7 @@ explanation_gaussian_1e2 = explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) @@ -75,7 +75,7 @@ explanation_combined = explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "ctree", "empirical"), - prediction_zero = prediction_zero, + phi0 = phi0, n_samples = 1e2 ) diff --git a/inst/scripts/example_plot_several_vaeacs_VLB_IWAE.R b/inst/scripts/example_plot_several_vaeacs_VLB_IWAE.R index a364a9ce4fc950c0e64ec64c6ea820b5c0f1113d..85e9e39145592d60e4299bf68f37db3162609dd9 100644 --- a/inst/scripts/example_plot_several_vaeacs_VLB_IWAE.R +++ b/inst/scripts/example_plot_several_vaeacs_VLB_IWAE.R @@ -28,7 +28,7 @@ explanation_paired_sampling_TRUE <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, + phi0 = p0, n_batches = 2, n_samples = 1, #' As we are only interested in the training of the vaeac vaeac.epochs = 25, #' Should be higher in applications. 
@@ -44,7 +44,7 @@ explanation_paired_sampling_FALSE <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, + phi0 = p0, n_batches = 2, n_samples = 1, #' As we are only interested in the training of the vaeac vaeac.epochs = 25, #' Should be higher in applications. @@ -61,7 +61,7 @@ explanation_paired_sampling_FALSE_small <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, + phi0 = p0, n_batches = 2, n_samples = 1, #' As we are only interested in the training of the vaeac vaeac.epochs = 25, #' Should be higher in applications. @@ -80,7 +80,7 @@ explanation_paired_sampling_TRUE_small <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, + phi0 = p0, n_batches = 2, n_samples = 1, #' As we are only interested in the training of the vaeac vaeac.epochs = 25, #' Should be higher in applications. diff --git a/inst/scripts/explain_memory_testing.R b/inst/scripts/explain_memory_testing.R index 7c3030ffc1d5cff94b118acb0bfe0acc4c6cf958..d9e35e7eb475bf7de1ff9db1f815feb96cfb785c 100644 --- a/inst/scripts/explain_memory_testing.R +++ b/inst/scripts/explain_memory_testing.R @@ -60,7 +60,7 @@ xy_train <- cbind(x_train,y=y_train) model <- lm(formula = y~.,data=xy_train) -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) n_batches_use <- min(2^p-2,n_batches) @@ -74,7 +74,7 @@ explanation <- explain( x_train = x_train, approach = approach, n_batches = n_batches_use, - prediction_zero = prediction_zero + phi0 = phi0 ) },threshold=10^4) diff --git a/inst/scripts/problematic_plots_jens.R b/inst/scripts/problematic_plots_jens.R index 2aa26c89682e2c9ae80ac6f236f489243a1f176d..176af6a9f5a10b24e7055fd3f3cb479a655aa1b0 100644 --- a/inst/scripts/problematic_plots_jens.R +++ b/inst/scripts/problematic_plots_jens.R @@ -41,7 +41,7 @@ explanation_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + 
phi0 = p0 ) @@ -62,7 +62,7 @@ explanation_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) # Works fine @@ -85,7 +85,7 @@ explanation_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) # Only 4 ticks in the x-axis for the factor @@ -107,7 +107,7 @@ explanation_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) # Duplicated labels on the x-axis diff --git a/inst/scripts/readme_example.R b/inst/scripts/readme_example.R index 480f599d7d7c807572e329b330e247545d5f5c58..9d63bc1a1e3f577a54c4f4d64a75ea61f21d2eb6 100644 --- a/inst/scripts/readme_example.R +++ b/inst/scripts/readme_example.R @@ -34,12 +34,12 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0 ) # Printing the Shapley values for the test data. # For more information about the interpretation of the values in the table, see ?shapr::explain. 
-print(explanation$shapley_values) +print(explanation$shapley_values_est) # Finally we plot the resulting explanations plot(explanation) diff --git a/inst/scripts/testing_samling_ncombinations.R b/inst/scripts/testing_samling_ncombinations.R index 65e066d98757f799690894e05be85b0e28da0aca..d11220a4f855a9b2f80c433e3c711b512eb14f33 100644 --- a/inst/scripts/testing_samling_ncombinations.R +++ b/inst/scripts/testing_samling_ncombinations.R @@ -5,12 +5,12 @@ library(shapr) library(data.table) n = c(100, 1000, 2000) p = c(5, 10, 10) -n_combinations = c(20, 800, 800) +n_coalitions = c(20, 800, 800) res = list() for (i in seq_along(n)) { set.seed(123) - cat("n =", n[i], "p =", p[i], "n_combinations =", n_combinations[i], "\n") + cat("n =", n[i], "p =", p[i], "n_coalitions =", n_coalitions[i], "\n") x_train = data.table(matrix(rnorm(n[i]*p[i]), nrow = n[i], ncol = p[i])) x_test = data.table(matrix(rnorm(10*p[i]), nrow = 10, ncol = p[i])) beta = rnorm(p[i]) @@ -25,8 +25,8 @@ for (i in seq_along(n)) { x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = n_combinations[i] + phi0 = p_mean, + n_coalitions = n_coalitions[i] ) ) } @@ -37,7 +37,7 @@ for (i in seq_along(n)) { set.seed(123) - cat("n =", n[i], "p =", p[i], "n_combinations =", n_combinations[i], "\n") + cat("n =", n[i], "p =", p[i], "n_coalitions =", n_coalitions[i], "\n") x_train = data.table(matrix(rnorm(n[i] * p[i]), nrow = n[i], ncol = p[i])) x_test = data.table(matrix(rnorm(10 * p[i]), nrow = 10, ncol = p[i])) beta = rnorm(p[i]) @@ -52,8 +52,8 @@ for (i in seq_along(n)) { x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = n_combinations[i] + phi0 = p_mean, + n_coalitions = n_coalitions[i] ) ) } @@ -65,7 +65,7 @@ saveRDS(res2, "inst/scripts/testing_samling_ncombinations2.rds") i = 2 set.seed(123) -cat("n =", n[i], "p =", p[i], "n_combinations =", n_combinations[i], "\n") +cat("n =", n[i], "p =", p[i], "n_coalitions =", 
n_coalitions[i], "\n") x_train = data.table(matrix(rnorm(n[i] * p[i]), nrow = n[i], ncol = p[i])) x_test = data.table(matrix(rnorm(10 * p[i]), nrow = 10, ncol = p[i])) beta = rnorm(p[i]) @@ -79,8 +79,8 @@ system.time({res = explain( x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = 1000 + phi0 = p_mean, + n_coalitions = 1000 )}) devtools::load_all() @@ -89,8 +89,8 @@ system.time({res2 = explain( x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = 800 + phi0 = p_mean, + n_coalitions = 800 )}) @@ -100,8 +100,8 @@ system.time({res3 = explain( x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = NULL + phi0 = p_mean, + n_coalitions = NULL )}) x2 = Sys.time() @@ -117,8 +117,8 @@ res = profvis({res = explain( x_test, model = model, approach = "empirical", - prediction_zero = p_mean, - n_combinations = n_combinations[i] + phi0 = p_mean, + n_coalitions = n_coalitions[i] )}) res diff --git a/inst/scripts/time_series_annabelle.R b/inst/scripts/time_series_annabelle.R index 26e1f8b38328e7a8d0fffa58a64666890d71e056..62fdffd7bd05f158bd28a1ddf620c49cd52ae514 100644 --- a/inst/scripts/time_series_annabelle.R +++ b/inst/scripts/time_series_annabelle.R @@ -71,7 +71,7 @@ explanation_group <- explain( x_explain = x_explain, x_train = x_train, approach = "timeseries", - prediction_zero = p0, + phi0 = p0, group = group, timeseries.fixed_sigma_vec = 2 # timeseries.bounds = c(-1, 2) diff --git a/inst/scripts/timing_script_2023.R b/inst/scripts/timing_script_2023.R index d43db74f62445609ab11b87a722f4e05f331dbdd..31c258d9806777eca41d00394f8d0d0b2796ba88 100644 --- a/inst/scripts/timing_script_2023.R +++ b/inst/scripts/timing_script_2023.R @@ -59,7 +59,7 @@ xy_train <- cbind(x_train,y=y_train) model <- lm(formula = y~.,data=xy_train) -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) n_batches_use <- min(2^p-2,n_batches) @@ -72,8 +72,8 @@ explanation <- 
explain( x_train = x_train, approach = approach, n_batches = n_batches_use, - prediction_zero = prediction_zero, - n_combinations = 10^4 + phi0 = phi0, + n_coalitions = 10^4 ) sys_time_end_explain <- Sys.time() @@ -89,7 +89,7 @@ timing <- list(p = p, n_batches = n_batches, n_cores = n_cores, approach = approach, - n_combinations = explanation$internal$parameters$used_n_combinations, + n_coalitions = explanation$internal$parameters$used_n_coalitions, sys_time_initial = as.character(sys_time_initial), sys_time_start_explain = as.character(sys_time_start_explain), sys_time_end_explain = as.character(sys_time_end_explain), diff --git a/inst/scripts/vilde/airquality_example.R b/inst/scripts/vilde/airquality_example.R index 9c162bfe279f15e212a6b2cd41239b2f051751b8..59d2e225afca142db2815d6ebc2ccc286e748beb 100644 --- a/inst/scripts/vilde/airquality_example.R +++ b/inst/scripts/vilde/airquality_example.R @@ -15,7 +15,7 @@ x <- explain( test, model = model, approach = "empirical", - prediction_zero = p + phi0 = p ) if (requireNamespace("ggplot2", quietly = TRUE)) { diff --git a/inst/scripts/vilde/check_progress.R b/inst/scripts/vilde/check_progress.R index aee0f765cd4db3ebe3fc51fbb231ecf5890a5113..ec3da4887f7731db176b44ef4bc271c9154f575c 100644 --- a/inst/scripts/vilde/check_progress.R +++ b/inst/scripts/vilde/check_progress.R @@ -25,34 +25,34 @@ p <- mean(y_train) plan(multisession, workers=3) # when we simply call explain(), no progress bar is shown -x <- explain(x_train, x_test, model, approach="gaussian", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="gaussian", phi0=p, n_batches = 4) # the handler specifies what kind of progress bar is shown # Wrapping explain() in with_progress() gives a progress bar when calling explain() handlers("txtprogressbar") x <- with_progress( - explain(x_train, x_test, model, approach="empirical", prediction_zero=p, n_batches = 5) + explain(x_train, x_test, model, approach="empirical", phi0=p, n_batches = 
5) ) # with global=TRUE the progress bar is displayed whenever the explain-function is called, and there is no need to use with_progress() handlers(global = TRUE) -x <- explain(x_train, x_test, model, approach="gaussian", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="gaussian", phi0=p, n_batches = 4) # there are different options for what kind of progress bar should be displayed handlers("txtprogressbar") #this is the default -x <- explain(x_train, x_test, model, approach="independence", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="independence", phi0=p, n_batches = 4) handlers("progress") -x <- explain(x_train, x_test, model, approach="independence", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="independence", phi0=p, n_batches = 4) # you can edit the symbol used to draw completed progress in the progress bar (as well as other features) with handler_progress() handlers(handler_progress(complete = "#")) -x <- explain(x_train, x_test, model, approach="copula", prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach="copula", phi0=p, n_batches = 4) plan("sequential") handlers("progress") -x <- explain(x_train, x_test, model, approach=c(rep("ctree",4),"independence","independence"), prediction_zero=p, n_batches = 4) +x <- explain(x_train, x_test, model, approach=c(rep("ctree",4),"independence","independence"), phi0=p, n_batches = 4) diff --git a/inst/scripts/vilde/sketch_for_waterfall_plot.R b/inst/scripts/vilde/sketch_for_waterfall_plot.R index dc9e9278f67d3a6e1e47aa4a808eae7b1f778681..e31971a1a3b040ebf5b9b4e84bb9e4ef9fbe043e 100644 --- a/inst/scripts/vilde/sketch_for_waterfall_plot.R +++ b/inst/scripts/vilde/sketch_for_waterfall_plot.R @@ -25,15 +25,15 @@ model <- xgboost( p <- mean(y_train) # Prepare the data for explanation -res <- explain_final(x_train,x_test,model,approach="independence",prediction_zero=p,n_batches = 4) +res <- 
explain_final(x_train,x_test,model,approach="independence",phi0=p,n_batches = 4) plot(res) i<- 1 # index for observation we want to plot -dt <- data.table(feat_name = paste0(colnames(res$shapley_values[,-1]), " = ", format(res$internal$data$x_explain[i,], 2) ), - shapley_value = as.numeric(res$shapley_values[i,-1]) +dt <- data.table(feat_name = paste0(colnames(res$shapley_values_est[,-1]), " = ", format(res$internal$data$x_explain[i,], 2) ), + shapley_value = as.numeric(res$shapley_values_est[i,-1]) ) dt -expected <- as.numeric(res$shapley_values[i,])[1] +expected <- as.numeric(res$shapley_values_est[i,])[1] observed <- res$pred_explain[i] dt[, sign := ifelse(shapley_value > 0, "Increases", "Decreases")] diff --git a/inst/scripts/vilde/waterfall_plot.R b/inst/scripts/vilde/waterfall_plot.R index 531f1e4c14c33a6fed674faa1857672d235c3c6e..5035d252845756610e2e0aa56861ba96efb87662 100644 --- a/inst/scripts/vilde/waterfall_plot.R +++ b/inst/scripts/vilde/waterfall_plot.R @@ -19,7 +19,7 @@ model <- xgboost( verbose = FALSE ) p <- mean(y_train) -x <- explain_final(x_train,x_test,model,approach="independence",prediction_zero=p,n_batches = 4) +x <- explain_final(x_train,x_test,model,approach="independence",phi0=p,n_batches = 4) plot.shapr(x, plot_type = "bar", digits = 3, diff --git a/man/additional_regression_setup.Rd b/man/additional_regression_setup.Rd new file mode 100644 index 0000000000000000000000000000000000000000..9aebdd0355123324e1db523b596fcb8391898c6c --- /dev/null +++ b/man/additional_regression_setup.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{additional_regression_setup} +\alias{additional_regression_setup} +\title{Additional setup for regression-based methods} +\usage{ +additional_regression_setup(internal, model, predict_model) +} +\arguments{ +\item{internal}{List. 
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Additional setup for regression-based methods +} +\keyword{internal} diff --git a/man/append_vS_list.Rd b/man/append_vS_list.Rd new file mode 100644 index 0000000000000000000000000000000000000000..ceb1db0888b4be990ab31c82829d8efbce23447e --- /dev/null +++ b/man/append_vS_list.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/compute_vS.R +\name{append_vS_list} +\alias{append_vS_list} +\title{Appends the new vS_list to the prev vS_list} +\usage{ +append_vS_list(vS_list, internal) +} +\arguments{ +\item{vS_list}{List +Output from \code{\link[=compute_vS]{compute_vS()}}} + +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Appends the new vS_list to the prev vS_list +} +\keyword{internal} diff --git a/man/check_categorical_valid_MCsamp.Rd b/man/check_categorical_valid_MCsamp.Rd new file mode 100644 index 0000000000000000000000000000000000000000..65515d63bf2ebaffa4a857b44ec7ed1408e48f9e --- /dev/null +++ b/man/check_categorical_valid_MCsamp.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{check_categorical_valid_MCsamp} +\alias{check_categorical_valid_MCsamp} +\title{Check that all explicands has at least one valid MC sample in causal Shapley values} +\usage{ +check_categorical_valid_MCsamp( + dt, + n_explain, + n_MC_samples, + joint_probability_dt +) +} +\arguments{ +\item{dt}{Data.table containing the generated MC samples (and conditional values) after each sampling step} + +\item{n_explain}{Integer. The number of explicands/observations to explain.} + +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. 
+For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} +} +\description{ +Check that all explicands has at least one valid MC sample in causal Shapley values +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/check_convergence.Rd b/man/check_convergence.Rd new file mode 100644 index 0000000000000000000000000000000000000000..8d727207adaaa64acaff5232fd09fd634e445dab --- /dev/null +++ b/man/check_convergence.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/check_convergence.R +\name{check_convergence} +\alias{check_convergence} +\title{Checks the convergence according to the convergence threshold} +\usage{ +check_convergence(internal) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Checks the convergence according to the convergence threshold +} +\keyword{internal} diff --git a/man/check_verbose.Rd b/man/check_verbose.Rd new file mode 100644 index 0000000000000000000000000000000000000000..5af03b591d10007c4db7a0e1b4bcfb29fe85c16f --- /dev/null +++ b/man/check_verbose.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{check_verbose} +\alias{check_verbose} +\title{Function that checks the verbose parameter} +\usage{ +check_verbose(verbose) +} +\arguments{ +\item{verbose}{String vector or NULL. 
+Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. +#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} +} +\value{ +The function does not return anything. +} +\description{ +Function that checks the verbose parameter +} +\author{ +Lars Henry Berge Olsen, Martin Jullum +} +\keyword{internal} diff --git a/man/cli_compute_vS.Rd b/man/cli_compute_vS.Rd new file mode 100644 index 0000000000000000000000000000000000000000..5fcf73210d369963cf40e0eeca8b4e5e24c538d3 --- /dev/null +++ b/man/cli_compute_vS.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cli.R +\name{cli_compute_vS} +\alias{cli_compute_vS} +\title{Printing messages in compute_vS with cli} +\usage{ +cli_compute_vS(internal) +} +\arguments{ +\item{internal}{List. 
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Printing messages in compute_vS with cli +} +\keyword{internal} diff --git a/man/cli_iter.Rd b/man/cli_iter.Rd new file mode 100644 index 0000000000000000000000000000000000000000..6426af8c9813a9c23c42e0d1726039c24b5a56a5 --- /dev/null +++ b/man/cli_iter.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cli.R +\name{cli_iter} +\alias{cli_iter} +\title{Printing messages in iterative procedure with cli} +\usage{ +cli_iter(verbose, internal, iter) +} +\arguments{ +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. +#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} + +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{iter}{Integer. +The iteration number. 
Only used internally.} +} +\description{ +Printing messages in iterative procedure with cli +} +\keyword{internal} diff --git a/man/cli_startup.Rd b/man/cli_startup.Rd new file mode 100644 index 0000000000000000000000000000000000000000..afd5aa3a804cbb01f14d458b0820f9bae1acc080 --- /dev/null +++ b/man/cli_startup.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/cli.R +\name{cli_startup} +\alias{cli_startup} +\title{Printing startup messages with cli} +\usage{ +cli_startup(internal, model_class, verbose) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{model_class}{String. +Class of the model as a string} + +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. +#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. 
+}} +} +\description{ +Printing startup messages with cli +} +\keyword{internal} diff --git a/man/coalition_matrix_cpp.Rd b/man/coalition_matrix_cpp.Rd new file mode 100644 index 0000000000000000000000000000000000000000..5f5956e11ae803a132fd62284da0a7a62d83a2e2 --- /dev/null +++ b/man/coalition_matrix_cpp.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/RcppExports.R +\name{coalition_matrix_cpp} +\alias{coalition_matrix_cpp} +\title{Get coalition matrix} +\usage{ +coalition_matrix_cpp(coalitions, m) +} +\arguments{ +\item{coalitions}{List} + +\item{m}{Positive integer. Total number of coalitions} +} +\value{ +Matrix +} +\description{ +Get coalition matrix +} +\author{ +Nikolai Sellereite, Martin Jullum +} +\keyword{internal} diff --git a/man/compute_MSEv_eval_crit.Rd b/man/compute_MSEv_eval_crit.Rd index c6e3e0549095fb347bfe7f9110820d067e7ae22d..27643a769850be566715010a7c01d6074796aef5 100644 --- a/man/compute_MSEv_eval_crit.Rd +++ b/man/compute_MSEv_eval_crit.Rd @@ -14,38 +14,34 @@ compute_MSEv_eval_crit( \arguments{ \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} -\item{dt_vS}{Data.table of dimension \code{n_combinations} times \code{n_explain + 1} containing the contribution function -estimates. The first column is assumed to be named \code{id_combination} and containing the ids of the combinations. 
-The last row is assumed to be the full combination, i.e., it contains the predicted responses for the observations +\item{dt_vS}{Data.table of dimension \code{n_coalitions} times \code{n_explain + 1} containing the contribution function +estimates. The first column is assumed to be named \code{id_coalition} and containing the ids of the coalitions. +The last row is assumed to be the full coalition, i.e., it contains the predicted responses for the observations which are to be explained.} -\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations -uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to -weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the -sampling frequency when not all combinations are considered.} - \item{MSEv_skip_empty_full_comb}{Logical. If \code{TRUE} (default), we exclude the empty and grand -combinations/coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical +coalitions when computing the MSEv evaluation criterion. This is reasonable as they are identical for all methods, i.e., their contribution function is independent of the used method as they are special cases not -effected by the used method. If \code{FALSE}, we include the empty and grand combinations/coalitions. In this situation, +effected by the used method. If \code{FALSE}, we include the empty and grand coalitions. 
In this situation, we also recommend setting \code{MSEv_uniform_comb_weights = TRUE}, as otherwise the large weights for the empty and -grand combinations/coalitions will outweigh all other combinations and make the MSEv criterion uninformative.} +grand coalitions will outweigh all other coalitions and make the MSEv criterion uninformative.} } \value{ List containing: \describe{ \item{\code{MSEv}}{A \code{\link[data.table]{data.table}} with the overall MSEv evaluation criterion averaged -over both the combinations/coalitions and observations/explicands. The \code{\link[data.table]{data.table}} -also contains the standard deviation of the MSEv values for each explicand (only averaged over the combinations) +over both the coalitions and observations/explicands. The \code{\link[data.table]{data.table}} +also contains the standard deviation of the MSEv values for each explicand (only averaged over the coalitions) divided by the square root of the number of explicands.} \item{\code{MSEv_explicand}}{A \code{\link[data.table]{data.table}} with the mean squared error for each -explicand, i.e., only averaged over the combinations/coalitions.} -\item{\code{MSEv_combination}}{A \code{\link[data.table]{data.table}} with the mean squared error for each -combination/coalition, i.e., only averaged over the explicands/observations. +explicand, i.e., only averaged over the coalitions.} +\item{\code{MSEv_coalition}}{A \code{\link[data.table]{data.table}} with the mean squared error for each +coalition, i.e., only averaged over the explicands/observations. 
The \code{\link[data.table]{data.table}} also contains the standard deviation of the MSEv values for -each combination divided by the square root of the number of explicands.} +each coalition divided by the square root of the number of explicands.} } } \description{ diff --git a/man/compute_estimates.Rd b/man/compute_estimates.Rd new file mode 100644 index 0000000000000000000000000000000000000000..9d708f738ce9fd8f2a356931f4198b51c21bfe15 --- /dev/null +++ b/man/compute_estimates.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/compute_estimates.R +\name{compute_estimates} +\alias{compute_estimates} +\title{Computes the Shapley values and their standard deviation given the \code{v(S)}} +\usage{ +compute_estimates(internal, vS_list) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{vS_list}{List +Output from \code{\link[=compute_vS]{compute_vS()}}} +} +\description{ +Computes the Shapley values and their standard deviation given the \code{v(S)} +} +\keyword{internal} diff --git a/man/compute_shapley_new.Rd b/man/compute_shapley_new.Rd index 14e77306db9385848c9050a9c6b42e1d9ec0eabd..3c1d249f200e72d89040ecea51238e3e9f960f09 100644 --- a/man/compute_shapley_new.Rd +++ b/man/compute_shapley_new.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/finalize_explanation.R +% Please edit documentation in R/compute_estimates.R \name{compute_shapley_new} \alias{compute_shapley_new} \title{Compute shapley values} @@ -9,7 +9,8 @@ compute_shapley_new(internal, dt_vS) \arguments{ \item{internal}{List.
Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} \item{dt_vS}{The contribution matrix.} } diff --git a/man/compute_time.Rd b/man/compute_time.Rd new file mode 100644 index 0000000000000000000000000000000000000000..e6539e5e44f14049534024c8adcb728723972eed --- /dev/null +++ b/man/compute_time.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/timing.R +\name{compute_time} +\alias{compute_time} +\title{Gathers and computes the timing of the different parts of the explain function.} +\usage{ +compute_time(internal) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Gathers and computes the timing of the different parts of the explain function. +} +\keyword{internal} diff --git a/man/compute_vS.Rd b/man/compute_vS.Rd index 1988ef5c549f3a8e0a05908b5cb153e4e40303c8..1f8a69e0df6eace46caca10531442dd0c3f77d41 100644 --- a/man/compute_vS.Rd +++ b/man/compute_vS.Rd @@ -8,8 +8,7 @@ compute_vS(internal, model, predict_model, method = "future") } \arguments{ \item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} \item{model}{Objects. The model object that ought to be explained. @@ -20,8 +19,10 @@ The prediction function used when \code{model} is not natively supported. 
See the documentation of \code{\link[=explain]{explain()}} for details.} \item{method}{Character -Indicates whether the lappy method (default) or loop method should be used.} +Indicates whether the lappy method (default) or loop method should be used. +This is only used for testing purposes.} } \description{ Computes \code{v(S)} for all features subsets \code{S}. } +\keyword{internal} diff --git a/man/convert_feature_name_to_idx.Rd b/man/convert_feature_name_to_idx.Rd new file mode 100644 index 0000000000000000000000000000000000000000..1629b930d36ed7e032e442dde237dd98cf40546a --- /dev/null +++ b/man/convert_feature_name_to_idx.Rd @@ -0,0 +1,38 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{convert_feature_name_to_idx} +\alias{convert_feature_name_to_idx} +\title{Convert feature names into feature indices} +\usage{ +convert_feature_name_to_idx(causal_ordering, labels, feat_group_txt) +} +\arguments{ +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. 
+Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} + +\item{labels}{Vector of strings containing (the order of) the feature names.} + +\item{feat_group_txt}{String that is either "feature" or "group" based on +if \code{shapr} is computing feature- or group-wise Shapley values} +} +\value{ +The \code{causal_ordering} list, but with feature indices (w.r.t. \code{labels}) instead of feature names. +} +\description{ +Functions that takes a \code{causal_ordering} specified using strings and convert these strings to feature indices. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/create_coalition_table.Rd b/man/create_coalition_table.Rd new file mode 100644 index 0000000000000000000000000000000000000000..1b340f20729067e2b4507325cdb1c39dceb6ac98 --- /dev/null +++ b/man/create_coalition_table.Rd @@ -0,0 +1,78 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/shapley_setup.R +\name{create_coalition_table} +\alias{create_coalition_table} +\title{Define coalitions, and fetch additional information about each unique coalition} +\usage{ +create_coalition_table( + m, + exact = TRUE, + n_coalitions = 200, + weight_zero_m = 10^6, + paired_shap_sampling = TRUE, + prev_coal_samples = NULL, + coal_feature_list = as.list(seq_len(m)), + approach0 = "gaussian", + kernelSHAP_reweighting = "none", + dt_valid_causal_coalitions = NULL +) +} +\arguments{ +\item{m}{Positive integer. +Total number of features/groups.} + +\item{exact}{Logical. +If \code{TRUE} all \code{2^m} coalitions are generated, otherwise a subsample of the coalitions is used.} + +\item{n_coalitions}{Positive integer. +Note that if \code{exact = TRUE}, \code{n_coalitions} is ignored.} + +\item{weight_zero_m}{Numeric. +The value to use as a replacement for infinite coalition weights when doing numerical operations.} + +\item{paired_shap_sampling}{Logical. 
+Whether to do paired sampling of coalitions.} + +\item{prev_coal_samples}{List. +A list of previously sampled coalitions.} + +\item{coal_feature_list}{List. +A list mapping each coalition to the features it contains.} + +\item{approach0}{Character vector. +Contains the approach to be used for estimation of each coalition size. Same as \code{approach} in \code{explain()}.} + +\item{kernelSHAP_reweighting}{String. +How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing +the randomness and thereby the variance of the Shapley value estimates. +One of \code{'none'}, \code{'on_N'}, \code{'on_all'}, \code{'on_all_cond'} (default). +\code{'none'} means no reweighting, i.e. the sampling frequency weights are used as is. +\code{'on_coal_size'} means the sampling frequencies are averaged over all coalitions of the same size. +\code{'on_N'} means the sampling frequencies are averaged over all coalitions with the same original sampling +probabilities. +\code{'on_all'} means the original sampling probabilities are used for all coalitions. +\code{'on_all_cond'} means the original sampling probabilities are used for all coalitions, while adjusting for the +probability that they are sampled at least once. +This method is preferred as it has performed the best in simulation studies.} + +\item{dt_valid_causal_coalitions}{data.table. Only applicable for asymmetric Shapley +values explanations, and is \code{NULL} for symmetric Shapley values.
+The data.table contains information about the coalitions that respects the causal ordering.} +} +\value{ +A data.table with columns about the that contains the following columns: +} +\description{ +Define coalitions, and fetch additional information about each unique coalition +} +\examples{ +# All coalitions +x <- create_coalition_table(m = 3) +nrow(x) # Equals 2^3 = 8 + +# Subsample of coalitions +x <- create_coalition_table(exact = FALSE, m = 10, n_coalitions = 1e2) +} +\author{ +Nikolai Sellereite, Martin Jullum +} diff --git a/man/create_ctree.Rd b/man/create_ctree.Rd index 3c3db21f6ed82dd6ec2726266e07bcdb09962fd5..a85d8871f26314d6705bf32649c7ec9eebeab346 100644 --- a/man/create_ctree.Rd +++ b/man/create_ctree.Rd @@ -27,7 +27,7 @@ Determines minimum value that the sum of the left and right daughter nodes requi \item{minbucket}{Numeric scalar. (default = 7) Determines the minimum sum of weights in a terminal node required for a split} -\item{use_partykit}{String. In some semi-rare cases \code{partyk::ctree} runs into an error related to the LINPACK +\item{use_partykit}{String. In some semi-rare cases \code{partykit::ctree} runs into an error related to the LINPACK used by R. To get around this problem, one may fall back to using the newer (but slower) \code{partykit::ctree} function, which is a reimplementation of the same method. Setting this parameter to \code{"on_error"} (default) falls back to \code{partykit::ctree}, if \code{party::ctree} fails. 
Other options are \code{"never"}, which always diff --git a/man/create_marginal_data_categoric.Rd b/man/create_marginal_data_categoric.Rd new file mode 100644 index 0000000000000000000000000000000000000000..6dfb185e23aa23d0809428cef50c8918dc6d0769 --- /dev/null +++ b/man/create_marginal_data_categoric.Rd @@ -0,0 +1,59 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{create_marginal_data_categoric} +\alias{create_marginal_data_categoric} +\title{Create marginal categorical data for causal Shapley values} +\usage{ +create_marginal_data_categoric( + n_MC_samples, + x_explain, + Sbar_features, + S_original, + joint_prob_dt +) +} +\arguments{ +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} + +\item{x_explain}{A matrix or data.frame/data.table. +Contains the features whose predictions ought to be explained.} + +\item{Sbar_features}{Vector of integers containing the features indices to generate marginal observations for. +That is, if \code{Sbar_features} is \code{c(1,4)}, then we sample \code{n_MC_samples} observations from \eqn{P(X_1, X_4)}.
+That is, we sample the first and fourth feature values from the same valid feature coalition using +the marginal probability, so we do not break the dependence between them.} + +\item{S_original}{Vector of integers containing the features indices of the original coalition \code{S}. I.e., not the +features in the current sampling step, but the features that are known to us before starting the chain of sampling steps.} + +\item{joint_prob_dt}{Data.table containing the joint probability distribution for each coalition of feature values.} +} +\value{ +Data table of dimension \eqn{(`n_MC_samples` * `nrow(x_explain)`) \times `length(Sbar_features)`} with the +sampled observations. +} +\description{ +This function is used when we generate marginal data for the categorical approach when we have several sampling +steps. We need to treat this separately, as we here in the marginal step CANNOT make feature values such +that the combination of those and the feature values we condition in S are NOT in +\code{categorical.joint_prob_dt}. If we do this, then we cannot progress further in the chain of sampling +steps. E.g., X1 in (1,2,3), X2 in (1,2,3), and X3 in (1,2,3). +We know X2 = 2, and let causal structure be X1 -> X2 -> X3. Assume that +P(X1 = 1, X2 = 2, X3 = 3) = P(X1 = 2, X2 = 2, X3 = 3) = 1/2. Then there is no point +generating X1 = 3, as we then cannot generate X3. +The solution is only to generate the values which can proceed through the whole +chain of sampling steps. To do that, we have to ensure that the marginal sampling +respects the valid feature coalitions for all sets of conditional features, i.e., +the features in \code{features_steps_cond_on}. +We sample from the valid coalitions using the MARGINAL probabilities.
+} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/create_marginal_data_gaussian.Rd b/man/create_marginal_data_gaussian.Rd new file mode 100644 index 0000000000000000000000000000000000000000..31d54467cf17d156fb8265e21549ee134f30e5ef --- /dev/null +++ b/man/create_marginal_data_gaussian.Rd @@ -0,0 +1,25 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/approach_gaussian.R +\name{create_marginal_data_gaussian} +\alias{create_marginal_data_gaussian} +\title{Generate marginal Gaussian data using Cholesky decomposition} +\usage{ +create_marginal_data_gaussian(n_MC_samples, Sbar_features, mu, cov_mat) +} +\arguments{ +\item{n_MC_samples}{Integer. The number of samples to generate.} + +\item{Sbar_features}{Vector of integers indicating which marginals to sample from.} + +\item{mu}{Numeric vector containing the expected values for all features in the multivariate Gaussian distribution.} + +\item{cov_mat}{Numeric matrix containing the covariance between all features +in the multivariate Gaussian distribution.} +} +\description{ +Given a multivariate Gaussian distribution, this function creates data from specified marginals of said distribution. 
+} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/create_marginal_data_training.Rd b/man/create_marginal_data_training.Rd new file mode 100644 index 0000000000000000000000000000000000000000..e86985e8e3c1141e4b2fb533b7149c85c2e742ac --- /dev/null +++ b/man/create_marginal_data_training.Rd @@ -0,0 +1,65 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{create_marginal_data_training} +\alias{create_marginal_data_training} +\title{Function that samples data from the empirical marginal training distribution} +\usage{ +create_marginal_data_training( + x_train, + n_explain, + Sbar_features, + n_MC_samples = 1000, + stable_version = TRUE +) +} +\arguments{ +\item{x_train}{Matrix or data.frame/data.table. +Contains the data used to estimate the (conditional) distributions for the features +needed to properly estimate the conditional expectations in the Shapley formula.} + +\item{n_explain}{Integer. The number of explicands/observations to explain.} + +\item{Sbar_features}{Vector of integers containing the features indices to generate marginal observations for. +That is, if \code{Sbar_features} is \code{c(1,4)}, then we sample \code{n_MC_samples} observations from \eqn{P(X_1, X_4)} using the +empirical training observations (with replacements). That is, we sample the first and fourth feature values from +the same training observation, so we do not break the dependence between them.} + +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. 
(2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} + +\item{stable_version}{Logical. If \code{TRUE} and \code{n_MC_samples} > \code{n_train}, then we include each training observation +\code{n_MC_samples \%/\% n_train} times and then sample the remaining \verb{n_MC_samples \%\% n_train samples}. Only the latter is +done when \code{n_MC_samples < n_train}. This is done separately for each explicand. If \code{FALSE}, we randomly sample +from the observations.} +} +\value{ +Data table of dimension \eqn{`n_MC_samples` \times `length(Sbar_features)`} with the sampled observations. +} +\description{ +Sample observations from the empirical distribution P(X) using the training dataset. +} +\examples{ +\dontrun{ +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] + +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +x_train +create_marginal_data_training(x_train = x_train, Sbar_features = c(1, 4), n_MC_samples = 10) +} + +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/default_doc.Rd b/man/default_doc.Rd index cca1358e3f0a6e3a267b0c5a42b2b1422db3a5c6..1da47ca1ebfcb7d0797d8f4bc2c2656c53b8089f 100644 --- a/man/default_doc.Rd +++ b/man/default_doc.Rd @@ -9,7 +9,8 @@ default_doc(internal, model, predict_model, output_size, extra, ...) \arguments{ \item{internal}{List.
Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} \item{model}{Objects. The model object that ought to be explained. diff --git a/man/default_doc_explain.Rd b/man/default_doc_explain.Rd index 6adafa2d72872108bb3ea54bf881e243f7c15813..a33882c5ed5b9b20a1d5225bc707f7eb9a65b711 100644 --- a/man/default_doc_explain.Rd +++ b/man/default_doc_explain.Rd @@ -4,13 +4,17 @@ \alias{default_doc_explain} \title{Exported documentation helper function.} \usage{ -default_doc_explain(internal, index_features) +default_doc_explain(internal, iter, index_features) } \arguments{ -\item{internal}{Not used.} +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} +\item{iter}{Integer. +The iteration number. Only used internally.} + +\item{index_features}{Positive integer vector. Specifies the id_coalition to +apply to the present method. \code{NULL} means all coalitions. Only used internally.} } \description{ Exported documentation helper function. 
diff --git a/man/explain.Rd b/man/explain.Rd index 2b121b12dd1aaf09d465d9ba5a532d4ff3479799..f7fd8694ea631653fe6064135d92404b0e157204 100644 --- a/man/explain.Rd +++ b/man/explain.Rd @@ -9,18 +9,24 @@ explain( x_explain, x_train, approach, - prediction_zero, - n_combinations = NULL, + phi0, + iterative = NULL, + max_n_coalitions = NULL, group = NULL, - n_samples = 1000, - n_batches = NULL, + paired_shap_sampling = TRUE, + n_MC_samples = 1000, + kernelSHAP_reweighting = "on_all_cond", seed = 1, - keep_samp_for_vS = FALSE, + verbose = "basic", predict_model = NULL, get_model_specs = NULL, - MSEv_uniform_comb_weights = TRUE, - timing = TRUE, - verbose = 0, + prev_shapr_object = NULL, + asymmetric = FALSE, + causal_ordering = NULL, + confounding = NULL, + extra_computation_args = list(), + iterative_args = list(), + output_args = list(), ... ) } @@ -43,17 +49,31 @@ All elements should, either be \code{"gaussian"}, \code{"copula"}, \code{"empiri \code{"categorical"}, \code{"timeseries"}, \code{"independence"}, \code{"regression_separate"}, or \code{"regression_surrogate"}. The two regression approaches can not be combined with any other approach. See details for more information.} -\item{prediction_zero}{Numeric. +\item{phi0}{Numeric. The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any features. Typically we set this value equal to the mean of the response variable in our training data, but other choices such as the mean of the predictions in the training data are also reasonable.} -\item{n_combinations}{Integer. -If \code{group = NULL}, \code{n_combinations} represents the number of unique feature combinations to sample. -If \code{group != NULL}, \code{n_combinations} represents the number of unique group combinations to sample. -If \code{n_combinations = NULL}, the exact method is used and all combinations are considered. 
-The maximum number of combinations equals \code{2^m}, where \code{m} is the number of features.} +\item{iterative}{Logical or NULL. +If \code{NULL} (default), the argument is set to \code{TRUE} if there are more than 5 features/groups, and \code{FALSE} otherwise. +If eventually \code{TRUE}, the Shapley values are estimated in an iterative manner. +This provides sufficiently accurate Shapley value estimates faster. +First an initial number of coalitions is sampled, then bootstrapping is used to estimate the variance of the Shapley +values. +A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small. +If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more +coalitions. +The process is repeated until the variances are below the threshold. +Specifics related to the iterative process and convergence criterion are set through \code{iterative_args}.} + +\item{max_n_coalitions}{Integer. +The upper limit on the number of unique feature/group coalitions to use in the iterative procedure +(if \code{iterative = TRUE}). +If \code{iterative = FALSE} it represents the number of feature/group coalitions to use directly. +The quantity refers to the number of unique feature coalitions if \code{group = NULL}, +and group coalitions if \code{group != NULL}. +\code{max_n_coalitions = NULL} corresponds to \code{max_n_coalitions=2^n_features}.} \item{group}{List. If \code{NULL} regular feature wise Shapley values are computed. @@ -61,39 +81,65 @@ If provided, group wise Shapley values are computed. \code{group} then has lengt the number of groups. The list element contains character vectors with the features included in each of the different groups.} -\item{n_samples}{Positive integer. -Indicating the maximum number of samples to use in the -Monte Carlo integration for every conditional expectation. See also details.} +\item{paired_shap_sampling}{Logical. 
+If \code{TRUE} (default), paired versions of all sampled coalitions are also included in the computation. +That is, if there are 5 features and e.g. the coalition (1,3,5) is sampled, then the coalition (2,4) is also used for +computing the Shapley values. This is done to reduce the variance of the Shapley value estimates.} -\item{n_batches}{Positive integer (or NULL). -Specifies how many batches the total number of feature combinations should be split into when calculating the -contribution function for each test observation. -The default value is NULL which uses a reasonable trade-off between RAM allocation and computation speed, -which depends on \code{approach} and \code{n_combinations}. -For models with many features, increasing the number of batches reduces the RAM allocation significantly. -This typically comes with a small increase in computation time.} +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument of \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument of \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} + +\item{kernelSHAP_reweighting}{String. +How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing +the randomness and thereby the variance of the Shapley value estimates. +One of \code{'none'}, \code{'on_coal_size'}, \code{'on_N'}, \code{'on_all'}, \code{'on_all_cond'} (default). +\code{'none'} means no reweighting, i.e. the sampling frequency weights are used as is. 
+\code{'on_coal_size'} means the sampling frequencies are averaged over all coalitions of the same size. +\code{'on_N'} means the sampling frequencies are averaged over all coalitions with the same original sampling +probabilities. +\code{'on_all'} means the original sampling probabilities are used for all coalitions. +\code{'on_all_cond'} means the original sampling probabilities are used for all coalitions, while adjusting for the +probability that they are sampled at least once. +This method is preferred as it has performed the best in simulation studies.} \item{seed}{Positive integer. Specifies the seed before any randomness based code is being run. -If \code{NULL} the seed will be inherited from the calling environment.} - -\item{keep_samp_for_vS}{Logical. -Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned -(in \code{internal$output})} +If \code{NULL} no seed is set in the calling environment.} + +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of the five strings can be used. +E.g. 
\code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} \item{predict_model}{Function. The prediction function used when \code{model} is not natively supported. -(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported -models.) +(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported models.) The function must have two arguments, \code{model} and \code{newdata} which specify, respectively, the model -and a data.frame/data.table to compute predictions for. The function must give the prediction as a numeric vector. +and a data.frame/data.table to compute predictions for. +The function must give the prediction as a numeric vector. \code{NULL} (the default) uses functions specified internally. Can also be used to override the default function for natively supported model classes.} \item{get_model_specs}{Function. An optional function for checking model/data consistency when \code{model} is not natively supported. -(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported -models.) +(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported models.) The function takes \code{model} as argument and provides a list with 3 elements: \describe{ \item{labels}{Character vector with the names of each feature.} @@ -104,18 +150,59 @@ If \code{NULL} (the default) internal functions are used for natively supported disabled for unsupported model classes. Can also be used to override the default function for natively supported model classes.} -\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations -uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to -weight the combinations when computing the MSEv criterion. 
Note that the Shapley kernel weights are replaced by the -sampling frequency when not all combinations are considered.} - -\item{timing}{Logical. -Whether the timing of the different parts of the \code{explain()} should saved in the model object.} - -\item{verbose}{An integer specifying the level of verbosity. If \code{0}, \code{shapr} will stay silent. -If \code{1}, it will print information about performance. If \code{2}, some additional information will be printed out. -Use \code{0} (default) for no verbosity, \code{1} for low verbose, and \code{2} for high verbose. -TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.} +\item{prev_shapr_object}{\code{shapr} object or string. +If an object of class \code{shapr} is provided, or a string with a path to where intermediate results are stored, +then the function will use the previous object to continue the computation. +This is useful if the computation is interrupted or you want higher accuracy than already obtained, and therefore +want to continue the iterative estimation. See the vignette for examples.} + +\item{asymmetric}{Logical. +Not applicable for (regular) non-causal or asymmetric explanations. +If \code{FALSE} (default), \code{explain} computes regular symmetric Shapley values. +If \code{TRUE}, then \code{explain} computes asymmetric Shapley values based on the (partial) causal ordering +given by \code{causal_ordering}. That is, \code{explain} only uses the feature combinations/coalitions that +respect the causal ordering when computing the asymmetric Shapley values. If \code{asymmetric} is \code{TRUE} and +\code{confounding} is \code{NULL} (default), then \code{explain} computes asymmetric conditional Shapley values as specified in +Frye et al. (2020). If \code{confounding} is provided, i.e., not \code{NULL}, then \code{explain} computes asymmetric causal +Shapley values as specified in Heskes et al. (2020).} + +\item{causal_ordering}{List. 
+Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} + +\item{confounding}{Logical vector. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{confounding} is a vector of logicals specifying whether confounding is assumed or not for each component in the +\code{causal_ordering}. If \code{NULL} (default), then no assumption about the confounding structure is made and \code{explain} +computes asymmetric/symmetric conditional Shapley values, depending on the value of \code{asymmetric}. +If \code{confounding} is a single logical, i.e., \code{FALSE} or \code{TRUE}, then this assumption is set globally +for all components in the causal ordering. Otherwise, \code{confounding} must be a vector of logicals of the same +length as \code{causal_ordering}, indicating the confounding assumption for each component. When \code{confounding} is +specified, then \code{explain} computes asymmetric/symmetric causal Shapley values, depending on the value of +\code{asymmetric}. 
The \code{approach} cannot be \code{regression_separate} or \code{regression_surrogate} as the +regression-based approaches are not applicable to the causal Shapley value methodology.} + +\item{extra_computation_args}{Named list. +Specifies extra arguments related to the computation of the Shapley values. +See \code{\link[=get_extra_est_args_default]{get_extra_est_args_default()}} for description of the arguments and their default values.} + +\item{iterative_args}{Named list. +Specifies the arguments for the iterative procedure. +See \code{\link[=get_iterative_args_default]{get_iterative_args_default()}} for description of the arguments and their default values.} + +\item{output_args}{Named list. +Specifies certain arguments related to the output of the function. +See \code{\link[=get_output_args_default]{get_output_args_default()}} for description of the arguments and their default values.} \item{...}{ Arguments passed on to \code{\link[=setup_approach.empirical]{setup_approach.empirical}}, \code{\link[=setup_approach.independence]{setup_approach.independence}}, \code{\link[=setup_approach.gaussian]{setup_approach.gaussian}}, \code{\link[=setup_approach.copula]{setup_approach.copula}}, \code{\link[=setup_approach.ctree]{setup_approach.ctree}}, \code{\link[=setup_approach.vaeac]{setup_approach.vaeac}}, \code{\link[=setup_approach.categorical]{setup_approach.categorical}}, \code{\link[=setup_approach.regression_separate]{setup_approach.regression_separate}}, \code{\link[=setup_approach.regression_surrogate]{setup_approach.regression_surrogate}}, \code{\link[=setup_approach.timeseries]{setup_approach.timeseries}} @@ -130,7 +217,7 @@ If e.g. \code{eta = .8} we will choose the \code{K} samples with the largest wei accounts for 80\\% of the total weight. \code{eta} is the \eqn{\eta} parameter in equation (15) of Aas et al (2021).} \item{\code{empirical.fixed_sigma}}{Positive numeric scalar. 
(default = 0.1) -Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations. +Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions. Only used when \code{empirical.type = "fixed_sigma"}} \item{\code{empirical.n_samples_aicc}}{Positive integer. (default = 1000) Number of samples to consider in AICc optimization. @@ -144,7 +231,8 @@ Only used for \code{empirical.type} is either \code{"AICc_each_k"} or \code{"AIC \item{\code{empirical.cov_mat}}{Numeric matrix. (Optional, default = NULL) Containing the covariance matrix of the data generating distribution used to define the Mahalanobis distance. \code{NULL} means it is estimated from \code{x_train}.} - \item{\code{internal}}{Not used.} + \item{\code{internal}}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} \item{\code{gaussian.mu}}{Numeric vector. (Optional) Containing the mean of the data generating distribution. \code{NULL} means it is estimated from the \code{x_train}.} @@ -160,13 +248,13 @@ Determines minimum value that the sum of the left and right daughter nodes requi \item{\code{ctree.minbucket}}{Numeric scalar. (default = 7) Determines the minimum sum of weights in a terminal node required for a split} \item{\code{ctree.sample}}{Boolean. (default = TRUE) -If TRUE, then the method always samples \code{n_samples} observations from the leaf nodes (with replacement). -If FALSE and the number of observations in the leaf node is less than \code{n_samples}, +If TRUE, then the method always samples \code{n_MC_samples} observations from the leaf nodes (with replacement). +If FALSE and the number of observations in the leaf node is less than \code{n_MC_samples}, the method will take all observations in the leaf. -If FALSE and the number of observations in the leaf node is more than \code{n_samples}, -the method will sample \code{n_samples} observations (with replacement). 
+If FALSE and the number of observations in the leaf node is more than \code{n_MC_samples}, +the method will sample \code{n_MC_samples} observations (with replacement). This means that there will always be sampling in the leaf unless -\code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.} +\code{sample} = FALSE AND the number of obs in the node is less than \code{n_MC_samples}.} \item{\code{vaeac.depth}}{Positive integer (default is \code{3}). The number of hidden layers in the neural networks of the masked encoder, full encoder, and decoder.} \item{\code{vaeac.width}}{Positive integer (default is \code{32}). The number of neurons in each @@ -188,7 +276,7 @@ values. \code{NULL} means it is estimated from the \code{x_train} and \code{x_explain}.} \item{\code{categorical.epsilon}}{Numeric value. (Optional) If \code{joint_probability_dt} is not supplied, probabilities/frequencies are -estimated using \code{x_train}. If certain observations occur in \code{x_train} and NOT in \code{x_explain}, +estimated using \code{x_train}. If certain observations occur in \code{x_explain} and NOT in \code{x_train}, then epsilon is used as the proportion of times that these observations occurs in the training data. In theory, this proportion should be zero, but this causes an error later in the Shapley computation.} \item{\code{regression.model}}{A \code{tidymodels} object of class \code{model_specs}. Default is a linear regression model, i.e., @@ -201,8 +289,8 @@ is also a valid input. It is essential to include the package prefix if the pack The data.frame must contain the possible hyperparameter value combinations to try. The column names must match the names of the tuneable parameters specified in \code{regression.model}. 
If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data -for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above. -Using a function allows the hyperparameter values to change based on the size of the combination. See the regression +for the current coalition and returns a data.frame/data.table/tibble with the properties described above. +Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression vignette for several examples. Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string containing an R function. For example, @@ -218,13 +306,17 @@ Note, to make it easier to call \code{explain()} from Python, the \code{regressi containing an R function. For example, \code{"function(recipe) return(recipes::step_ns(recipe, recipes::all_numeric_predictors(), deg_free = 2))"} is also a valid input. It is essential to include the package prefix if the package is not loaded.} - \item{\code{regression.surrogate_n_comb}}{Integer (default is \code{internal$parameters$used_n_combinations}) specifying the -number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is -"\code{internal$parameters$used_n_combinations} - 2". By default, we use all coalitions, but this can take a lot of memory -in larger dimensions. Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. This will be all -\eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in the exact mode. If the -user sets a lower value than \code{internal$parameters$used_n_combinations}, then we sample this amount of unique -coalitions separately for each training observations. That is, on average, all coalitions should be equally trained.} + \item{\code{regression.surrogate_n_comb}}{Integer. 
+(default is \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}) specifying the +number of unique coalitions to apply to each training observation. Maximum allowed value is +"\code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions} - 2". +By default, we use all coalitions, but this can take a lot of memory in larger dimensions. +Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. +This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in +the exact mode. +If the user sets a lower value than \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}, +then we sample this number of unique coalitions separately for each training observation. +That is, on average, all coalitions should be equally trained.} \item{\code{timeseries.fixed_sigma_vec}}{Numeric. (Default = 2) Represents the kernel bandwidth in the distance computation. TODO: What length should it have? 1?} \item{\code{timeseries.bounds}}{Numeric vector of length two. (Default = c(NULL, NULL)) @@ -236,58 +328,52 @@ This is useful if the underlying time series are scaled between 0 and 1, for exa \value{ Object of class \code{c("shapr", "list")}. Contains the following items: \describe{ -\item{shapley_values}{data.table with the estimated Shapley values} -\item{internal}{List with the different parameters, data and functions used internally} +\item{shapley_values_est}{data.table with the estimated Shapley values with the explained observations in the rows and +features along the columns. +The column \code{none} is the prediction not devoted to any of the features (given by the argument \code{phi0})} +\item{shapley_values_sd}{data.table with the standard deviation of the Shapley values reflecting the uncertainty. +Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by +definition 0 when all coalitions are used. 
+Only present when \code{extra_computation_args$compute_sd=TRUE}.} +\item{internal}{List with the different parameters, data, functions and other output used internally.} \item{pred_explain}{Numeric vector with the predictions for the explained observations} -\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.} +\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach. See the +\href{https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#msev-evaluation-criterion +}{MSEv evaluation section in the vignette for details}.} +\item{timing}{List containing timing information for the different parts of the computation. +\code{init_time} and \code{end_time} gives the time stamps for the start and end of the computation. +\code{total_time_secs} gives the total time in seconds for the complete execution of \code{explain()}. +\code{main_timing_secs} gives the time in seconds for the main computations. +\code{iter_timing_secs} gives for each iteration of the iterative estimation, the time spent on the different parts +iterative estimation routine.} } - -\code{shapley_values} is a data.table where the number of rows equals -the number of observations you'd like to explain, and the number of columns equals \code{m +1}, -where \code{m} equals the total number of features in your model. - -If \code{shapley_values[i, j + 1] > 0} it indicates that the j-th feature increased the prediction for -the i-th observation. Likewise, if \code{shapley_values[i, j + 1] < 0} it indicates that the j-th feature -decreased the prediction for the i-th observation. -The magnitude of the value is also important to notice. E.g. 
if \code{shapley_values[i, k + 1]} and -\code{shapley_values[i, j + 1]} are greater than \code{0}, where \code{j != k}, and -\code{shapley_values[i, k + 1]} > \code{shapley_values[i, j + 1]} this indicates that feature -\code{j} and \code{k} both increased the value of the prediction, but that the effect of the k-th -feature was larger than the j-th feature. - -The first column in \code{dt}, called \code{none}, is the prediction value not assigned to any of the features -(\ifelse{html}{\eqn{\phi}\out{0}}{\eqn{\phi_0}}). -It's equal for all observations and set by the user through the argument \code{prediction_zero}. -The difference between the prediction and \code{none} is distributed among the other features. -In theory this value should be the expected prediction without conditioning on any features. -Typically we set this value equal to the mean of the response variable in our training data, but other choices -such as the mean of the predictions in the training data are also reasonable. } \description{ Computes dependence-aware Shapley values for observations in \code{x_explain} from the specified \code{model} by using the method specified in \code{approach} to estimate the conditional expectation. } \details{ -The most important thing to notice is that \code{shapr} has implemented eight different -Monte Carlo-based approaches for estimating the conditional distributions of the data, namely \code{"empirical"}, -\code{"gaussian"}, \code{"copula"}, \code{"ctree"}, \code{"vaeac"}, \code{"categorical"}, \code{"timeseries"}, and \code{"independence"}. -\code{shapr} has also implemented two regression-based approaches \code{"regression_separate"} and \code{"regression_surrogate"}, -and see the separate vignette on the regression-based approaches for more information. -In addition, the user also has the option of combining the different Monte Carlo-based approaches. 
-E.g., if you're in a situation where you have trained a model that consists of 10 features, -and you'd like to use the \code{"gaussian"} approach when you condition on a single feature, -the \code{"empirical"} approach if you condition on 2-5 features, and \code{"copula"} version -if you condition on more than 5 features this can be done by simply passing -\code{approach = c("gaussian", rep("empirical", 4), rep("copula", 4))}. If -\code{"approach[i]" = "gaussian"} means that you'd like to use the \code{"gaussian"} approach -when conditioning on \code{i} features. Conditioning on all features needs no approach as that is given -by the complete prediction itself, and should thus not be part of the vector. - -For \code{approach="ctree"}, \code{n_samples} corresponds to the number of samples -from the leaf node (see an exception related to the \code{sample} argument). -For \code{approach="empirical"}, \code{n_samples} is the \eqn{K} parameter in equations (14-15) of -Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the -\code{empirical.eta} argument. +The \code{shapr} package implements kernelSHAP estimation of dependence-aware Shapley values with +eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely +\code{"empirical"}, \code{"gaussian"}, \code{"copula"}, \code{"ctree"}, \code{"vaeac"}, \code{"categorical"}, \code{"timeseries"}, and \code{"independence"}. +\code{shapr} has also implemented two regression-based approaches \code{"regression_separate"} and \code{"regression_surrogate"}. +It is also possible to combine the different approaches, see the vignettes for more information. + +The package also supports the computation of causal and asymmetric Shapley values as introduced by +Heskes et al. (2020) and Frye et al. (2020). Asymmetric Shapley values were proposed by Frye et al. 
(2020) +as a way to incorporate causal knowledge in the real world by restricting the possible feature +combinations/coalitions when computing the Shapley values to those consistent with a (partial) causal ordering. +Causal Shapley values were proposed by Heskes et al. (2020) as a way to explain the total effect of features +on the prediction, taking into account their causal relationships, by adapting the sampling procedure in \code{shapr}. + +The package allows for parallelized computation with progress updates through the tightly connected +\link[future:future]{future::future} and \link[progressr:progressr]{progressr::progressr} packages. See the examples below. +For iterative estimation (\code{iterative=TRUE}), intermediate results may also be printed to the console +(according to the \code{verbose} argument). +Moreover, the intermediate results are written to disk. +This combined with iterative estimation with (optional) intermediate results printed to the console (and temporarily +written to disk), and batch computing of the v(S) values, enables fast and accurate estimation of the Shapley values +in a memory friendly manner. 
} \examples{ @@ -311,14 +397,26 @@ model <- lm(lm_formula, data = data_train) # Explain predictions p <- mean(data_train[, y_var]) +\dontrun{ +# (Optionally) enable parallelization via the future package +if (requireNamespace("future", quietly = TRUE)) { + future::plan("multisession", workers = 2) +} +} + +# (Optionally) enable progress updates within every iteration via the progressr package +if (requireNamespace("progressr", quietly = TRUE)) { + progressr::handlers(global = TRUE) +} + # Empirical approach explain1 <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) # Gaussian approach @@ -327,8 +425,8 @@ explain2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) # Gaussian copula approach @@ -337,8 +435,8 @@ explain3 <- explain( x_explain = x_explain, x_train = x_train, approach = "copula", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) # ctree approach @@ -347,8 +445,8 @@ explain4 <- explain( x_explain = x_explain, x_train = x_train, approach = "ctree", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) # Combined approach @@ -358,12 +456,12 @@ explain5 <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) # Print the Shapley values -print(explain1$shapley_values) +print(explain1$shapley_values_est) # Plot the results if (requireNamespace("ggplot2", quietly = TRUE)) { @@ -380,10 +478,10 @@ explain_groups <- explain( x_train = x_train, group = group_list, approach = "empirical", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) -print(explain_groups$shapley_values) +print(explain_groups$shapley_values_est) # Separate and surrogate regression approaches with linear 
regression models. # More complex regression models can be used, and we can use CV to @@ -395,7 +493,7 @@ explain_separate_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p, + phi0 = p, approach = "regression_separate", regression.model = parsnip::linear_reg() ) @@ -404,15 +502,40 @@ explain_surrogate_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p, + phi0 = p, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) +## iterative estimation +# For illustration purposes only. By default not used for such small dimensions as here + +# Gaussian approach +explain_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p, + n_MC_samples = 1e2, + iterative = TRUE, + iterative_args = list(initial_n_coalitions = 10) +) + } \references{ -Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: +\itemize{ +\item Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. +\item Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values: +incorporating causal knowledge into model-agnostic explainability. +Advances in neural information processing systems, 33, 1229-1239. +\item Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal shapley values: +Exploiting causal knowledge to explain individual predictions of complex models. +Advances in neural information processing systems, 33, 4778-4789. +\item Olsen, L. H. B., Glad, I. K., Jullum, M., & Aas, K. (2024). A comparative study of methods for estimating +model-agnostic Shapley value explanations. Data Mining and Knowledge Discovery, 1-48.
+} } \author{ Martin Jullum, Lars Henry Berge Olsen diff --git a/man/explain_forecast.Rd b/man/explain_forecast.Rd index 91565d96d81fba863c3dc0f8e299907b25da7dc2..df2d3176ce81078873fa407165ef6685e1e87874 100644 --- a/man/explain_forecast.Rd +++ b/man/explain_forecast.Rd @@ -14,18 +14,18 @@ explain_forecast( explain_xreg_lags = explain_y_lags, horizon, approach, - prediction_zero, - n_combinations = NULL, + phi0, + max_n_coalitions = NULL, + iterative = NULL, + iterative_args = list(), + kernelSHAP_reweighting = "on_all_cond", group_lags = TRUE, group = NULL, - n_samples = 1000, - n_batches = NULL, + n_MC_samples = 1000, seed = 1, - keep_samp_for_vS = FALSE, predict_model = NULL, get_model_specs = NULL, - timing = TRUE, - verbose = 0, + verbose = "basic", ... ) } @@ -70,17 +70,48 @@ All elements should, either be \code{"gaussian"}, \code{"copula"}, \code{"empiri \code{"categorical"}, \code{"timeseries"}, \code{"independence"}, \code{"regression_separate"}, or \code{"regression_surrogate"}. The two regression approaches can not be combined with any other approach. See details for more information.} -\item{prediction_zero}{Numeric. +\item{phi0}{Numeric. The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any features. Typically we set this value equal to the mean of the response variable in our training data, but other choices such as the mean of the predictions in the training data are also reasonable.} -\item{n_combinations}{Integer. -If \code{group = NULL}, \code{n_combinations} represents the number of unique feature combinations to sample. -If \code{group != NULL}, \code{n_combinations} represents the number of unique group combinations to sample. -If \code{n_combinations = NULL}, the exact method is used and all combinations are considered. -The maximum number of combinations equals \code{2^m}, where \code{m} is the number of features.} +\item{max_n_coalitions}{Integer. 
+The upper limit on the number of unique feature/group coalitions to use in the iterative procedure +(if \code{iterative = TRUE}). +If \code{iterative = FALSE} it represents the number of feature/group coalitions to use directly. +The quantity refers to the number of unique feature coalitions if \code{group = NULL}, +and group coalitions if \code{group != NULL}. +\code{max_n_coalitions = NULL} corresponds to \code{max_n_coalitions=2^n_features}.} + +\item{iterative}{Logical or NULL. +If \code{NULL} (default), the argument is set to \code{TRUE} if there are more than 5 features/groups, and \code{FALSE} otherwise. +If eventually \code{TRUE}, the Shapley values are estimated iteratively. +This provides sufficiently accurate Shapley value estimates faster. +First an initial number of coalitions is sampled, then bootstrapping is used to estimate the variance of the Shapley +values. +A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small. +If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more +coalitions. +The process is repeated until the variances are below the threshold. +Specifics related to the iterative process and convergence criterion are set through \code{iterative_args}.} + +\item{iterative_args}{Named list. +Specifies the arguments for the iterative procedure. +See \code{\link[=get_iterative_args_default]{get_iterative_args_default()}} for description of the arguments and their default values.} + +\item{kernelSHAP_reweighting}{String. +How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing +the randomness and thereby the variance of the Shapley value estimates. +One of \code{'none'}, \code{'on_N'}, \code{'on_all'}, \code{'on_all_cond'} (default). +\code{'none'} means no reweighting, i.e. the sampling frequency weights are used as is.
+\code{'on_coal_size'} means the sampling frequencies are averaged over all coalitions of the same size. +\code{'on_N'} means the sampling frequencies are averaged over all coalitions with the same original sampling +probabilities. +\code{'on_all'} means the original sampling probabilities are used for all coalitions. +\code{'on_all_cond'} means the original sampling probabilities are used for all coalitions, while adjusting for the +probability that they are sampled at least once. +This method is preferred as it has performed the best in simulation studies.} \item{group_lags}{Logical. If \code{TRUE} all lags of each variable are grouped together and explained as a group. @@ -92,39 +123,30 @@ If provided, group wise Shapley values are computed. \code{group} then has lengt the number of groups. The list element contains character vectors with the features included in each of the different groups.} -\item{n_samples}{Positive integer. -Indicating the maximum number of samples to use in the -Monte Carlo integration for every conditional expectation. See also details.} - -\item{n_batches}{Positive integer (or NULL). -Specifies how many batches the total number of feature combinations should be split into when calculating the -contribution function for each test observation. -The default value is NULL which uses a reasonable trade-off between RAM allocation and computation speed, -which depends on \code{approach} and \code{n_combinations}. -For models with many features, increasing the number of batches reduces the RAM allocation significantly. -This typically comes with a small increase in computation time.} +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. 
+For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} \item{seed}{Positive integer. Specifies the seed before any randomness based code is being run. -If \code{NULL} the seed will be inherited from the calling environment.} - -\item{keep_samp_for_vS}{Logical. -Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned -(in \code{internal$output})} +If \code{NULL} no seed is set in the calling environment.} \item{predict_model}{Function. The prediction function used when \code{model} is not natively supported. -(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported -models.) +(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported models.) The function must have two arguments, \code{model} and \code{newdata} which specify, respectively, the model -and a data.frame/data.table to compute predictions for. The function must give the prediction as a numeric vector. +and a data.frame/data.table to compute predictions for. +The function must give the prediction as a numeric vector. \code{NULL} (the default) uses functions specified internally. Can also be used to override the default function for natively supported model classes.} \item{get_model_specs}{Function. An optional function for checking model/data consistency when \code{model} is not natively supported. -(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported -models.) 
+(Run \code{\link[=get_supported_models]{get_supported_models()}} for a list of natively supported models.) The function takes \code{model} as argument and provides a list with 3 elements: \describe{ \item{labels}{Character vector with the names of each feature.} @@ -135,13 +157,22 @@ If \code{NULL} (the default) internal functions are used for natively supported disabled for unsupported model classes. Can also be used to override the default function for natively supported model classes.} -\item{timing}{Logical. -Whether the timing of the different parts of the \code{explain()} should saved in the model object.} - -\item{verbose}{An integer specifying the level of verbosity. If \code{0}, \code{shapr} will stay silent. -If \code{1}, it will print information about performance. If \code{2}, some additional information will be printed out. -Use \code{0} (default) for no verbosity, \code{1} for low verbose, and \code{2} for high verbose. -TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of the strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\code{"progress"} displays information about where in the calculation process the function currently is. +\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations +(only when \code{iterative = TRUE}), +and the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout.
+Note that any combination of the five strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +} \item{...}{ Arguments passed on to \code{\link[=setup_approach.empirical]{setup_approach.empirical}}, \code{\link[=setup_approach.independence]{setup_approach.independence}}, \code{\link[=setup_approach.gaussian]{setup_approach.gaussian}}, \code{\link[=setup_approach.copula]{setup_approach.copula}}, \code{\link[=setup_approach.ctree]{setup_approach.ctree}}, \code{\link[=setup_approach.vaeac]{setup_approach.vaeac}}, \code{\link[=setup_approach.categorical]{setup_approach.categorical}}, \code{\link[=setup_approach.timeseries]{setup_approach.timeseries}} @@ -156,7 +187,7 @@ If e.g. \code{eta = .8} we will choose the \code{K} samples with the largest wei accounts for 80\\% of the total weight. \code{eta} is the \eqn{\eta} parameter in equation (15) of Aas et al (2021).} \item{\code{empirical.fixed_sigma}}{Positive numeric scalar. (default = 0.1) -Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations. +Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions. Only used when \code{empirical.type = "fixed_sigma"}} \item{\code{empirical.n_samples_aicc}}{Positive integer. (default = 1000) Number of samples to consider in AICc optimization. @@ -170,7 +201,8 @@ Only used for \code{empirical.type} is either \code{"AICc_each_k"} or \code{"AIC \item{\code{empirical.cov_mat}}{Numeric matrix. (Optional, default = NULL) Containing the covariance matrix of the data generating distribution used to define the Mahalanobis distance. \code{NULL} means it is estimated from \code{x_train}.} - \item{\code{internal}}{Not used.} + \item{\code{internal}}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} \item{\code{gaussian.mu}}{Numeric vector.
(Optional) Containing the mean of the data generating distribution. \code{NULL} means it is estimated from the \code{x_train}.} @@ -186,13 +218,13 @@ Determines minimum value that the sum of the left and right daughter nodes requi \item{\code{ctree.minbucket}}{Numeric scalar. (default = 7) Determines the minimum sum of weights in a terminal node required for a split} \item{\code{ctree.sample}}{Boolean. (default = TRUE) -If TRUE, then the method always samples \code{n_samples} observations from the leaf nodes (with replacement). -If FALSE and the number of observations in the leaf node is less than \code{n_samples}, +If TRUE, then the method always samples \code{n_MC_samples} observations from the leaf nodes (with replacement). +If FALSE and the number of observations in the leaf node is less than \code{n_MC_samples}, the method will take all observations in the leaf. -If FALSE and the number of observations in the leaf node is more than \code{n_samples}, -the method will sample \code{n_samples} observations (with replacement). +If FALSE and the number of observations in the leaf node is more than \code{n_MC_samples}, +the method will sample \code{n_MC_samples} observations (with replacement). This means that there will always be sampling in the leaf unless -\code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.} +\code{sample} = FALSE AND the number of obs in the node is less than \code{n_MC_samples}.} \item{\code{vaeac.depth}}{Positive integer (default is \code{3}). The number of hidden layers in the neural networks of the masked encoder, full encoder, and decoder.} \item{\code{vaeac.width}}{Positive integer (default is \code{32}). The number of neurons in each @@ -214,7 +246,7 @@ values. \code{NULL} means it is estimated from the \code{x_train} and \code{x_explain}.} \item{\code{categorical.epsilon}}{Numeric value. 
(Optional) If \code{joint_probability_dt} is not supplied, probabilities/frequencies are -estimated using \code{x_train}. If certain observations occur in \code{x_train} and NOT in \code{x_explain}, +estimated using \code{x_train}. If certain observations occur in \code{x_explain} and NOT in \code{x_train}, then epsilon is used as the proportion of times that these observations occurs in the training data. In theory, this proportion should be zero, but this causes an error later in the Shapley computation.} \item{\code{timeseries.fixed_sigma_vec}}{Numeric. (Default = 2) @@ -228,32 +260,25 @@ This is useful if the underlying time series are scaled between 0 and 1, for exa \value{ Object of class \code{c("shapr", "list")}. Contains the following items: \describe{ -\item{shapley_values}{data.table with the estimated Shapley values} -\item{internal}{List with the different parameters, data and functions used internally} +\item{shapley_values_est}{data.table with the estimated Shapley values with the explained observations in the rows and +features along the columns. +The column \code{none} is the prediction not devoted to any of the features (given by the argument \code{phi0})} +\item{shapley_values_sd}{data.table with the standard deviation of the Shapley values reflecting the uncertainty. +Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by +definition 0 when all coalitions are used. +Only present when \code{extra_computation_args$compute_sd=TRUE}.} +\item{internal}{List with the different parameters, data, functions and other output used internally.} \item{pred_explain}{Numeric vector with the predictions for the explained observations} -\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.} +\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.
See the +\href{https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#msev-evaluation-criterion +}{MSEv evaluation section in the vignette for details}.} +\item{timing}{List containing timing information for the different parts of the computation. +\code{init_time} and \code{end_time} give the time stamps for the start and end of the computation. +\code{total_time_secs} gives the total time in seconds for the complete execution of \code{explain()}. +\code{main_timing_secs} gives the time in seconds for the main computations. +\code{iter_timing_secs} gives, for each iteration of the iterative estimation, the time spent on the different parts of the +iterative estimation routine.} } - -\code{shapley_values} is a data.table where the number of rows equals -the number of observations you'd like to explain, and the number of columns equals \code{m +1}, -where \code{m} equals the total number of features in your model. - -If \code{shapley_values[i, j + 1] > 0} it indicates that the j-th feature increased the prediction for -the i-th observation. Likewise, if \code{shapley_values[i, j + 1] < 0} it indicates that the j-th feature -decreased the prediction for the i-th observation. -The magnitude of the value is also important to notice. E.g. if \code{shapley_values[i, k + 1]} and -\code{shapley_values[i, j + 1]} are greater than \code{0}, where \code{j != k}, and -\code{shapley_values[i, k + 1]} > \code{shapley_values[i, j + 1]} this indicates that feature -\code{j} and \code{k} both increased the value of the prediction, but that the effect of the k-th -feature was larger than the j-th feature. - -The first column in \code{dt}, called \code{none}, is the prediction value not assigned to any of the features -(\ifelse{html}{\eqn{\phi}\out{0}}{\eqn{\phi_0}}). -It's equal for all observations and set by the user through the argument \code{prediction_zero}. -The difference between the prediction and \code{none} is distributed among the other features.
-In theory this value should be the expected prediction without conditioning on any features. -Typically we set this value equal to the mean of the response variable in our training data, but other choices -such as the mean of the predictions in the training data are also reasonable. } \description{ Computes dependence-aware Shapley values for observations in \code{explain_idx} from the specified @@ -291,14 +316,24 @@ explain_forecast( explain_y_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = FALSE ) } \references{ -Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: +\itemize{ +\item Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. +\item Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values: +incorporating causal knowledge into model-agnostic explainability. +Advances in neural information processing systems, 33, 1229-1239. +\item Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal shapley values: +Exploiting causal knowledge to explain individual predictions of complex models. +Advances in neural information processing systems, 33, 4778-4789. +\item Olsen, L. H. B., Glad, I. K., Jullum, M., & Aas, K. (2024). A comparative study of methods for estimating +model-agnostic Shapley value explanations. Data Mining and Knowledge Discovery, 1-48. +} } \author{ Martin Jullum, Lars Henry Berge Olsen diff --git a/man/explain_tripledot_docs.Rd b/man/explain_tripledot_docs.Rd index a739b97b554090cb730992830c71b6528da6c093..bd52859e73b787ee54721273e09b45f02c78f4f4 100644 --- a/man/explain_tripledot_docs.Rd +++ b/man/explain_tripledot_docs.Rd @@ -20,7 +20,7 @@ If e.g. \code{eta = .8} we will choose the \code{K} samples with the largest wei accounts for 80\\% of the total weight.
\code{eta} is the \eqn{\eta} parameter in equation (15) of Aas et al (2021).} \item{\code{empirical.fixed_sigma}}{Positive numeric scalar. (default = 0.1) -Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations. +Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions. Only used when \code{empirical.type = "fixed_sigma"}} \item{\code{empirical.n_samples_aicc}}{Positive integer. (default = 1000) Number of samples to consider in AICc optimization. @@ -40,7 +40,7 @@ values. \code{NULL} means it is estimated from the \code{x_train} and \code{x_explain}.} \item{\code{categorical.epsilon}}{Numeric value. (Optional) If \code{joint_probability_dt} is not supplied, probabilities/frequencies are -estimated using \code{x_train}. If certain observations occur in \code{x_train} and NOT in \code{x_explain}, +estimated using \code{x_train}. If certain observations occur in \code{x_explain} and NOT in \code{x_train}, then epsilon is used as the proportion of times that these observations occurs in the training data. In theory, this proportion should be zero, but this causes an error later in the Shapley computation.} \item{\code{ctree.mincriterion}}{Numeric scalar or vector. (default = 0.95) @@ -52,13 +52,13 @@ Determines minimum value that the sum of the left and right daughter nodes requi \item{\code{ctree.minbucket}}{Numeric scalar. (default = 7) Determines the minimum sum of weights in a terminal node required for a split} \item{\code{ctree.sample}}{Boolean. (default = TRUE) -If TRUE, then the method always samples \code{n_samples} observations from the leaf nodes (with replacement). -If FALSE and the number of observations in the leaf node is less than \code{n_samples}, +If TRUE, then the method always samples \code{n_MC_samples} observations from the leaf nodes (with replacement). 
+If FALSE and the number of observations in the leaf node is less than \code{n_MC_samples}, the method will take all observations in the leaf. -If FALSE and the number of observations in the leaf node is more than \code{n_samples}, -the method will sample \code{n_samples} observations (with replacement). +If FALSE and the number of observations in the leaf node is more than \code{n_MC_samples}, +the method will sample \code{n_MC_samples} observations (with replacement). This means that there will always be sampling in the leaf unless -\code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.} +\code{sample} = FALSE AND the number of obs in the node is less than \code{n_MC_samples}.} \item{\code{gaussian.mu}}{Numeric vector. (Optional) Containing the mean of the data generating distribution. \code{NULL} means it is estimated from the \code{x_train}.} @@ -75,8 +75,8 @@ is also a valid input. It is essential to include the package prefix if the pack The data.frame must contain the possible hyperparameter value combinations to try. The column names must match the names of the tuneable parameters specified in \code{regression.model}. If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data -for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above. -Using a function allows the hyperparameter values to change based on the size of the combination. See the regression +for the current coalition and returns a data.frame/data.table/tibble with the properties described above. +Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression vignette for several examples. Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string containing an R function.
For example, @@ -92,13 +92,17 @@ Note, to make it easier to call \code{explain()} from Python, the \code{regressi containing an R function. For example, \code{"function(recipe) return(recipes::step_ns(recipe, recipes::all_numeric_predictors(), deg_free = 2))"} is also a valid input. It is essential to include the package prefix if the package is not loaded.} - \item{\code{regression.surrogate_n_comb}}{Integer (default is \code{internal$parameters$used_n_combinations}) specifying the -number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is -"\code{internal$parameters$used_n_combinations} - 2". By default, we use all coalitions, but this can take a lot of memory -in larger dimensions. Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. This will be all -\eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in the exact mode. If the -user sets a lower value than \code{internal$parameters$used_n_combinations}, then we sample this amount of unique -coalitions separately for each training observations. That is, on average, all coalitions should be equally trained.} + \item{\code{regression.surrogate_n_comb}}{Integer. +(default is \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}) specifying the +number of unique coalitions to apply to each training observation. Maximum allowed value is +"\code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions} - 2". +By default, we use all coalitions, but this can take a lot of memory in larger dimensions. +Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. +This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in +the exact mode. 
+If the user sets a lower value than \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}, +then we sample this amount of unique coalitions separately for each training observation. +That is, on average, all coalitions should be equally trained.} \item{\code{timeseries.fixed_sigma_vec}}{Numeric. (Default = 2) Represents the kernel bandwidth in the distance computation. TODO: What length should it have? 1?} \item{\code{timeseries.bounds}}{Numeric vector of length two. (Default = c(NULL, NULL)) @@ -125,7 +129,7 @@ This includes \code{vaeac.extra_parameters$epochs_initiation_phase}, where the d \description{ This helper function displays the specific arguments applicable to the different approaches. Note that when calling \code{\link[=explain]{explain()}} from Python, the parameters -are renamed from the form \code{approach.parameter_name} to \code{approach_parameter_name}. +are renamed from the form \code{approach.parameter_name} to \code{approach_parameter_name}. That is, an underscore has replaced the dot as the dot is reserved in Python. } \author{ diff --git a/man/feature_combinations.Rd b/man/feature_combinations.Rd deleted file mode 100644 index f6b6c42208a2d243effb16eb54fa6d311a6c40a5..0000000000000000000000000000000000000000 --- a/man/feature_combinations.Rd +++ /dev/null @@ -1,58 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R -\name{feature_combinations} -\alias{feature_combinations} -\title{Define feature combinations, and fetch additional information about each unique combination} -\usage{ -feature_combinations( - m, - exact = TRUE, - n_combinations = 200, - weight_zero_m = 10^6, - group_num = NULL -) -} -\arguments{ -\item{m}{Positive integer. Total number of features.} - -\item{exact}{Logical. If \code{TRUE} all \code{2^m} combinations are generated, otherwise a -subsample of the combinations is used.} - -\item{n_combinations}{Positive integer.
Note that if \code{exact = TRUE}, -\code{n_combinations} is ignored. However, if \code{m > 12} you'll need to add a positive integer -value for \code{n_combinations}.} - -\item{weight_zero_m}{Numeric. The value to use as a replacement for infinite combination -weights when doing numerical operations.} - -\item{group_num}{List. Contains vector of integers indicating the feature numbers for the -different groups.} -} -\value{ -A data.table that contains the following columns: -\describe{ -\item{id_combination}{Positive integer. Represents a unique key for each combination. Note that the table -is sorted by \code{id_combination}, so that is always equal to \code{x[["id_combination"]] = 1:nrow(x)}.} -\item{features}{List. Each item of the list is an integer vector where \code{features[[i]]} -represents the indices of the features included in combination \code{i}. Note that all the items -are sorted such that \code{features[[i]] == sort(features[[i]])} is always true.} -\item{n_features}{Vector of positive integers. \code{n_features[i]} equals the number of features in combination -\code{i}, i.e. \code{n_features[i] = length(features[[i]])}.}. -\item{N}{Positive integer. 
The number of unique ways to sample \code{n_features[i]} features -from \code{m} different features, without replacement.} -} -} -\description{ -Define feature combinations, and fetch additional information about each unique combination -} -\examples{ -# All combinations -x <- feature_combinations(m = 3) -nrow(x) # Equals 2^3 = 8 - -# Subsample of combinations -x <- feature_combinations(exact = FALSE, m = 10, n_combinations = 1e2) -} -\author{ -Nikolai Sellereite, Martin Jullum -} diff --git a/man/feature_group.Rd b/man/feature_group.Rd deleted file mode 100644 index ce67752458173550e2945385a9106e86b3e4125c..0000000000000000000000000000000000000000 --- a/man/feature_group.Rd +++ /dev/null @@ -1,22 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R -\name{feature_group} -\alias{feature_group} -\title{Analogue to feature_exact, but for groups instead.} -\usage{ -feature_group(group_num, weight_zero_m = 10^6) -} -\arguments{ -\item{group_num}{List. Contains vector of integers indicating the feature numbers for the -different groups.} - -\item{weight_zero_m}{Positive integer. Represents the Shapley weight for two special -cases, i.e. the case where you have either \code{0} or \code{m} features/feature groups.} -} -\value{ -data.table with all feature group combinations, shapley weights etc. -} -\description{ -Analogue to feature_exact, but for groups instead. 
-} -\keyword{internal} diff --git a/man/feature_group_not_exact.Rd b/man/feature_group_not_exact.Rd deleted file mode 100644 index da4d90d6651017768f4c7390a9c78159fa7aeabf..0000000000000000000000000000000000000000 --- a/man/feature_group_not_exact.Rd +++ /dev/null @@ -1,22 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R -\name{feature_group_not_exact} -\alias{feature_group_not_exact} -\title{Analogue to feature_not_exact, but for groups instead.} -\usage{ -feature_group_not_exact(group_num, n_combinations = 200, weight_zero_m = 10^6) -} -\arguments{ -\item{group_num}{List. Contains vector of integers indicating the feature numbers for the -different groups.} - -\item{weight_zero_m}{Positive integer. Represents the Shapley weight for two special -cases, i.e. the case where you have either \code{0} or \code{m} features/feature groups.} -} -\value{ -data.table with all feature group combinations, shapley weights etc. -} -\description{ -Analogue to feature_not_exact, but for groups instead. -} -\keyword{internal} diff --git a/man/feature_matrix_cpp.Rd b/man/feature_matrix_cpp.Rd deleted file mode 100644 index 8282cf1f238681d206e3ef0897060a905068e284..0000000000000000000000000000000000000000 --- a/man/feature_matrix_cpp.Rd +++ /dev/null @@ -1,23 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/RcppExports.R -\name{feature_matrix_cpp} -\alias{feature_matrix_cpp} -\title{Get feature matrix} -\usage{ -feature_matrix_cpp(features, m) -} -\arguments{ -\item{features}{List} - -\item{m}{Positive integer. 
Total number of features} -} -\value{ -Matrix -} -\description{ -Get feature matrix -} -\author{ -Nikolai Sellereite -} -\keyword{internal} diff --git a/man/figures/README-basic_example-1.png b/man/figures/README-basic_example-1.png index 95378c7c3c76472652b716c6ee2c234f2f790ce9..7c3f4ee4ac3595ed845cfda63a3d09bca8b2cb94 100644 Binary files a/man/figures/README-basic_example-1.png and b/man/figures/README-basic_example-1.png differ diff --git a/man/finalize_explanation.Rd b/man/finalize_explanation.Rd index ee74c8903d32d6e1c8574a47c17874c175139860..cb92dcfdd189990348fa2f6724412ad42e4c8ab5 100644 --- a/man/finalize_explanation.Rd +++ b/man/finalize_explanation.Rd @@ -2,199 +2,14 @@ % Please edit documentation in R/finalize_explanation.R \name{finalize_explanation} \alias{finalize_explanation} -\title{Computes the Shapley values given \code{v(S)}} +\title{Gathers the final output to create the explanation object} \usage{ -finalize_explanation(vS_list, internal) +finalize_explanation(internal) } \arguments{ -\item{vS_list}{List -Output from \code{\link[=compute_vS]{compute_vS()}}} - \item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} -} -\value{ -Object of class \code{c("shapr", "list")}. 
Contains the following items: -\describe{ -\item{shapley_values}{data.table with the estimated Shapley values} -\item{internal}{List with the different parameters, data and functions used internally} -\item{pred_explain}{Numeric vector with the predictions for the explained observations} -\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach.} -} - -\code{shapley_values} is a data.table where the number of rows equals -the number of observations you'd like to explain, and the number of columns equals \code{m +1}, -where \code{m} equals the total number of features in your model. - -If \code{shapley_values[i, j + 1] > 0} it indicates that the j-th feature increased the prediction for -the i-th observation. Likewise, if \code{shapley_values[i, j + 1] < 0} it indicates that the j-th feature -decreased the prediction for the i-th observation. -The magnitude of the value is also important to notice. E.g. if \code{shapley_values[i, k + 1]} and -\code{shapley_values[i, j + 1]} are greater than \code{0}, where \code{j != k}, and -\code{shapley_values[i, k + 1]} > \code{shapley_values[i, j + 1]} this indicates that feature -\code{j} and \code{k} both increased the value of the prediction, but that the effect of the k-th -feature was larger than the j-th feature. - -The first column in \code{dt}, called \code{none}, is the prediction value not assigned to any of the features -(\ifelse{html}{\eqn{\phi}\out{0}}{\eqn{\phi_0}}). -It's equal for all observations and set by the user through the argument \code{prediction_zero}. -The difference between the prediction and \code{none} is distributed among the other features. -In theory this value should be the expected prediction without conditioning on any features. -Typically we set this value equal to the mean of the response variable in our training data, but other choices -such as the mean of the predictions in the training data are also reasonable. 
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.} } \description{ -Computes dependence-aware Shapley values for observations in \code{x_explain} from the specified -\code{model} by using the method specified in \code{approach} to estimate the conditional expectation. -} -\details{ -The most important thing to notice is that \code{shapr} has implemented eight different -Monte Carlo-based approaches for estimating the conditional distributions of the data, namely \code{"empirical"}, -\code{"gaussian"}, \code{"copula"}, \code{"ctree"}, \code{"vaeac"}, \code{"categorical"}, \code{"timeseries"}, and \code{"independence"}. -\code{shapr} has also implemented two regression-based approaches \code{"regression_separate"} and \code{"regression_surrogate"}, -and see the separate vignette on the regression-based approaches for more information. -In addition, the user also has the option of combining the different Monte Carlo-based approaches. -E.g., if you're in a situation where you have trained a model that consists of 10 features, -and you'd like to use the \code{"gaussian"} approach when you condition on a single feature, -the \code{"empirical"} approach if you condition on 2-5 features, and \code{"copula"} version -if you condition on more than 5 features this can be done by simply passing -\code{approach = c("gaussian", rep("empirical", 4), rep("copula", 4))}. If -\code{"approach[i]" = "gaussian"} means that you'd like to use the \code{"gaussian"} approach -when conditioning on \code{i} features. Conditioning on all features needs no approach as that is given -by the complete prediction itself, and should thus not be part of the vector. - -For \code{approach="ctree"}, \code{n_samples} corresponds to the number of samples -from the leaf node (see an exception related to the \code{sample} argument). -For \code{approach="empirical"}, \code{n_samples} is the \eqn{K} parameter in equations (14-15) of -Aas et al. (2021), i.e. 
the maximum number of observations (with largest weights) that is used, see also the -\code{empirical.eta} argument. -} -\examples{ - -# Load example data -data("airquality") -airquality <- airquality[complete.cases(airquality), ] -x_var <- c("Solar.R", "Wind", "Temp", "Month") -y_var <- "Ozone" - -# Split data into test- and training data -data_train <- head(airquality, -3) -data_explain <- tail(airquality, 3) - -x_train <- data_train[, x_var] -x_explain <- data_explain[, x_var] - -# Fit a linear model -lm_formula <- as.formula(paste0(y_var, " ~ ", paste0(x_var, collapse = " + "))) -model <- lm(lm_formula, data = data_train) - -# Explain predictions -p <- mean(data_train[, y_var]) - -# Empirical approach -explain1 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = "empirical", - prediction_zero = p, - n_samples = 1e2 -) - -# Gaussian approach -explain2 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = "gaussian", - prediction_zero = p, - n_samples = 1e2 -) - -# Gaussian copula approach -explain3 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = "copula", - prediction_zero = p, - n_samples = 1e2 -) - -# ctree approach -explain4 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = "ctree", - prediction_zero = p, - n_samples = 1e2 -) - -# Combined approach -approach <- c("gaussian", "gaussian", "empirical") -explain5 <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - approach = approach, - prediction_zero = p, - n_samples = 1e2 -) - -# Print the Shapley values -print(explain1$shapley_values) - -# Plot the results -if (requireNamespace("ggplot2", quietly = TRUE)) { - plot(explain1) - plot(explain1, plot_type = "waterfall") -} - -# Group-wise explanations -group_list <- list(A = c("Temp", "Month"), B = c("Wind", "Solar.R")) - -explain_groups <- explain( - model = model, - x_explain = x_explain, - 
x_train = x_train, - group = group_list, - approach = "empirical", - prediction_zero = p, - n_samples = 1e2 -) -print(explain_groups$shapley_values) - -# Separate and surrogate regression approaches with linear regression models. -# More complex regression models can be used, and we can use CV to -# tune the hyperparameters of the regression models and preprocess -# the data before sending it to the model. See the regression vignette -# (Shapley value explanations using the regression paradigm) for more -# details about the `regression_separate` and `regression_surrogate` approaches. -explain_separate_lm <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - prediction_zero = p, - approach = "regression_separate", - regression.model = parsnip::linear_reg() -) - -explain_surrogate_lm <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, - prediction_zero = p, - approach = "regression_surrogate", - regression.model = parsnip::linear_reg() -) - -} -\references{ -Aas, K., Jullum, M., & Lland, A. (2021). Explaining individual predictions when features are dependent: -More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. 
-} -\author{ -Martin Jullum, Lars Henry Berge Olsen +Gathers the final output to create the explanation object } diff --git a/man/finalize_explanation_forecast.Rd b/man/finalize_explanation_forecast.Rd new file mode 100644 index 0000000000000000000000000000000000000000..6911de4a94ce906a0afc7d7ed1b405f9cd14d563 --- /dev/null +++ b/man/finalize_explanation_forecast.Rd @@ -0,0 +1,232 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/finalize_explanation.R +\name{finalize_explanation_forecast} +\alias{finalize_explanation_forecast} +\title{Computes the Shapley values given \code{v(S)}} +\usage{ +finalize_explanation_forecast(vS_list, internal) +} +\arguments{ +\item{vS_list}{List +Output from \code{\link[=compute_vS]{compute_vS()}}} + +\item{internal}{List. +Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} +} +\value{ +Object of class \code{c("shapr", "list")}. Contains the following items: +\describe{ +\item{shapley_values_est}{data.table with the estimated Shapley values with explained observation in the rows and +features along the columns. +The column \code{none} is the prediction not devoted to any of the features (given by the argument \code{phi0})} +\item{shapley_values_sd}{data.table with the standard deviation of the Shapley values reflecting the uncertainty. +Note that this only reflects the coalition sampling part of the kernelSHAP procedure, and is therefore by +definition 0 when all coalitions are used. 
+Only present when \code{extra_computation_args$compute_sd=TRUE}.} +\item{internal}{List with the different parameters, data, functions and other output used internally.} +\item{pred_explain}{Numeric vector with the predictions for the explained observations} +\item{MSEv}{List with the values of the MSEv evaluation criterion for the approach. See the +\href{https://norskregnesentral.github.io/shapr/articles/understanding_shapr.html#msev-evaluation-criterion +}{MSEv evaluation section in the vignette for details}.} +\item{timing}{List containing timing information for the different parts of the computation. +\code{init_time} and \code{end_time} give the time stamps for the start and end of the computation. +\code{total_time_secs} gives the total time in seconds for the complete execution of \code{explain()}. +\code{main_timing_secs} gives the time in seconds for the main computations. +\code{iter_timing_secs} gives, for each iteration of the iterative estimation, the time spent on the different parts +of the iterative estimation routine.} +} +} +\description{ +Computes dependence-aware Shapley values for observations in \code{x_explain} from the specified +\code{model} by using the method specified in \code{approach} to estimate the conditional expectation. +} +\details{ +The \code{shapr} package implements kernelSHAP estimation of dependence-aware Shapley values with +eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely +\code{"empirical"}, \code{"gaussian"}, \code{"copula"}, \code{"ctree"}, \code{"vaeac"}, \code{"categorical"}, \code{"timeseries"}, and \code{"independence"}. +\code{shapr} has also implemented two regression-based approaches \code{"regression_separate"} and \code{"regression_surrogate"}. +It is also possible to combine the different approaches, see the vignettes for more information. + +The package also supports the computation of causal and asymmetric Shapley values as introduced by +Heskes et al. 
(2020) and Frye et al. (2020). Asymmetric Shapley values were proposed by Frye et al. (2020) +as a way to incorporate causal knowledge in the real world by restricting the possible feature +combinations/coalitions when computing the Shapley values to those consistent with a (partial) causal ordering. +Causal Shapley values were proposed by Heskes et al. (2020) as a way to explain the total effect of features +on the prediction, taking into account their causal relationships, by adapting the sampling procedure in \code{shapr}. + +The package allows for parallelized computation with progress updates through the tightly connected +\link[future:future]{future::future} and \link[progressr:progressr]{progressr::progressr} packages. See the examples below. +For iterative estimation (\code{iterative=TRUE}), intermediate results may also be printed to the console +(according to the \code{verbose} argument). +Moreover, the intermediate results are written to disk. +This combined with iterative estimation with (optional) intermediate results printed to the console (and temporarily +written to disk), and batch computing of the v(S) values, enables fast and accurate estimation of the Shapley values +in a memory friendly manner. 
+} +\examples{ + +# Load example data +data("airquality") +airquality <- airquality[complete.cases(airquality), ] +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +# Split data into test- and training data +data_train <- head(airquality, -3) +data_explain <- tail(airquality, 3) + +x_train <- data_train[, x_var] +x_explain <- data_explain[, x_var] + +# Fit a linear model +lm_formula <- as.formula(paste0(y_var, " ~ ", paste0(x_var, collapse = " + "))) +model <- lm(lm_formula, data = data_train) + +# Explain predictions +p <- mean(data_train[, y_var]) + +\dontrun{ +# (Optionally) enable parallelization via the future package +if (requireNamespace("future", quietly = TRUE)) { + future::plan("multisession", workers = 2) +} +} + +# (Optionally) enable progress updates within every iteration via the progressr package +if (requireNamespace("progressr", quietly = TRUE)) { + progressr::handlers(global = TRUE) +} + +# Empirical approach +explain1 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "empirical", + phi0 = p, + n_MC_samples = 1e2 +) + +# Gaussian approach +explain2 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p, + n_MC_samples = 1e2 +) + +# Gaussian copula approach +explain3 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "copula", + phi0 = p, + n_MC_samples = 1e2 +) + +# ctree approach +explain4 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "ctree", + phi0 = p, + n_MC_samples = 1e2 +) + +# Combined approach +approach <- c("gaussian", "gaussian", "empirical") +explain5 <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = approach, + phi0 = p, + n_MC_samples = 1e2 +) + +# Print the Shapley values +print(explain1$shapley_values_est) + +# Plot the results +if (requireNamespace("ggplot2", quietly = TRUE)) { + plot(explain1) + 
plot(explain1, plot_type = "waterfall") +} + +# Group-wise explanations +group_list <- list(A = c("Temp", "Month"), B = c("Wind", "Solar.R")) + +explain_groups <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + group = group_list, + approach = "empirical", + phi0 = p, + n_MC_samples = 1e2 +) +print(explain_groups$shapley_values_est) + +# Separate and surrogate regression approaches with linear regression models. +# More complex regression models can be used, and we can use CV to +# tune the hyperparameters of the regression models and preprocess +# the data before sending it to the model. See the regression vignette +# (Shapley value explanations using the regression paradigm) for more +# details about the `regression_separate` and `regression_surrogate` approaches. +explain_separate_lm <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + phi0 = p, + approach = "regression_separate", + regression.model = parsnip::linear_reg() +) + +explain_surrogate_lm <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + phi0 = p, + approach = "regression_surrogate", + regression.model = parsnip::linear_reg() +) + +## iterative estimation +# For illustration purposes only. By default not used for such small dimensions as here + +# Gaussian approach +explain_iterative <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p, + n_MC_samples = 1e2, + iterative = TRUE, + iterative_args = list(initial_n_coalitions = 10) +) + +} +\references{ +\itemize{ +\item Aas, K., Jullum, M., & Løland, A. (2021). Explaining individual predictions when features are dependent: +More accurate approximations to Shapley values. Artificial Intelligence, 298, 103502. +\item Frye, C., Rowat, C., & Feige, I. (2020). Asymmetric Shapley values: +incorporating causal knowledge into model-agnostic explainability. +Advances in neural information processing systems, 33, 1229-1239. 
+\item Heskes, T., Sijben, E., Bucur, I. G., & Claassen, T. (2020). Causal shapley values: +Exploiting causal knowledge to explain individual predictions of complex models. +Advances in neural information processing systems, 33, 4778-4789. +\item Olsen, L. H. B., Glad, I. K., Jullum, M., & Aas, K. (2024). A comparative study of methods for estimating +model-agnostic Shapley value explanations. Data Mining and Knowledge Discovery, 1-48. +} +} +\author{ +Martin Jullum, Lars Henry Berge Olsen +} diff --git a/man/get_S_causal_steps.Rd b/man/get_S_causal_steps.Rd new file mode 100644 index 0000000000000000000000000000000000000000..33059af5eb339b2a7919a69d5b4f8465f16e4a91 --- /dev/null +++ b/man/get_S_causal_steps.Rd @@ -0,0 +1,99 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{get_S_causal_steps} +\alias{get_S_causal_steps} +\title{Get the steps for generating MC samples for coalitions following a causal ordering} +\usage{ +get_S_causal_steps(S, causal_ordering, confounding, as_string = FALSE) +} +\arguments{ +\item{S}{Integer matrix of dimension \code{n_coalitions_valid x m}, where \code{n_coalitions_valid} equals +the total number of valid coalitions that respect the causal ordering given in \code{causal_ordering} and \code{m} equals +the total number of features.} + +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. 
No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} + +\item{confounding}{Logical vector. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{confounding} is a vector of logicals specifying whether confounding is assumed or not for each component in the +\code{causal_ordering}. If \code{NULL} (default), then no assumption about the confounding structure is made and \code{explain} +computes asymmetric/symmetric conditional Shapley values, depending on the value of \code{asymmetric}. +If \code{confounding} is a single logical, i.e., \code{FALSE} or \code{TRUE}, then this assumption is set globally +for all components in the causal ordering. Otherwise, \code{confounding} must be a vector of logicals of the same +length as \code{causal_ordering}, indicating the confounding assumption for each component. When \code{confounding} is +specified, then \code{explain} computes asymmetric/symmetric causal Shapley values, depending on the value of +\code{asymmetric}. The \code{approach} cannot be \code{regression_separate} and \code{regression_surrogate} as the +regression-based approaches are not applicable to the causal Shapley value methodology.} + +\item{as_string}{Boolean. +If the returned object is to be a list of lists of integers or a list of vectors of strings.} +} +\value{ +Depends on the value of the parameter \code{as_string}. If a string, then \code{results[j]} is a vector specifying +the process of generating the samples for coalition \code{j}. 
The length of \code{results[j]} is the number of steps, and +\code{results[j][i]} is a string of the form \code{features_to_sample|features_to_condition_on}. If the +\code{features_to_condition_on} part is blank, then we are to sample from the marginal distribution. +For \code{as_string == FALSE}, then we rather return a vector where \code{results[[j]][[i]]} contains the elements +\code{Sbar} and \code{S} representing the features to sample and condition on, respectively. +} +\description{ +Get the steps for generating MC samples for coalitions following a causal ordering +} +\examples{ +\dontrun{ +m <- 5 +causal_ordering <- list(1:2, 3:4, 5) +S <- shapr::feature_matrix_cpp(get_valid_causal_coalitions(causal_ordering = causal_ordering), + m = m +) +confounding <- c(TRUE, TRUE, FALSE) +get_S_causal_steps(S, causal_ordering, confounding, as_string = TRUE) + +# Look at the effect of changing the confounding assumptions +SS1 <- get_S_causal_steps(S, causal_ordering, + confounding = c(FALSE, FALSE, FALSE), + as_string = TRUE +) +SS2 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, FALSE, FALSE), as_string = TRUE) +SS3 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, TRUE, FALSE), as_string = TRUE) +SS4 <- get_S_causal_steps(S, causal_ordering, confounding = c(TRUE, TRUE, TRUE), as_string = TRUE) + +all.equal(SS1, SS2) +SS1[[2]] # Condition on 1 as there is no confounding in the first component +SS2[[2]] # Do NOT condition on 1 as there is confounding in the first component +SS1[[3]] +SS2[[3]] + +all.equal(SS1, SS3) +SS1[[2]] # Condition on 1 as there is no confounding in the first component +SS3[[2]] # Do NOT condition on 1 as there is confounding in the first component +SS1[[5]] # Condition on 3 as there is no confounding in the second component +SS3[[5]] # Do NOT condition on 3 as there is confounding in the second component +SS1[[6]] +SS3[[6]] + +all.equal(SS2, SS3) +SS2[[5]] +SS3[[5]] +SS2[[6]] +SS3[[6]] + +all.equal(SS3, SS4) # No 
difference as the last component is a singleton +} +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/get_extra_est_args_default.Rd b/man/get_extra_est_args_default.Rd new file mode 100644 index 0000000000000000000000000000000000000000..4f77725324a611db7defe3afad4ef9d4a78aeb16 --- /dev/null +++ b/man/get_extra_est_args_default.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{get_extra_est_args_default} +\alias{get_extra_est_args_default} +\title{Gets the default values for the extra estimation arguments} +\usage{ +get_extra_est_args_default( + internal, + compute_sd = isFALSE(internal$parameters$exact), + n_boot_samps = 100, + max_batch_size = 10, + min_n_batches = 10 +) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{compute_sd}{Logical. Whether to estimate the standard deviations of the Shapley value estimates. This is TRUE +whenever sampling based kernelSHAP is applied (either iteratively or with a fixed number of coalitions).} + +\item{n_boot_samps}{Integer. The number of bootstrapped samples (i.e. samples with replacement) from the set of all +coalitions used to estimate the standard deviations of the Shapley value estimates.} + +\item{max_batch_size}{Integer. The maximum number of coalitions to estimate simultaneously within each iteration. +A larger number requires more memory, but may have a slight computational advantage.} + +\item{min_n_batches}{Integer. The minimum number of batches to split the computation into within each iteration. +Larger numbers give more frequent progress updates. 
If parallelization is applied, this should be set no smaller +than the number of parallel workers.} +} +\description{ +Gets the default values for the extra estimation arguments +} +\author{ +Martin Jullum +} diff --git a/man/get_extra_parameters.Rd b/man/get_extra_parameters.Rd index de1acfa3582320865f8d0f18c4dcb81bde8242ef..7168e74bd298dc118fe2d60257229ed7e6b5ccc4 100644 --- a/man/get_extra_parameters.Rd +++ b/man/get_extra_parameters.Rd @@ -4,7 +4,7 @@ \alias{get_extra_parameters} \title{This includes both extra parameters and other objects} \usage{ -get_extra_parameters(internal) +get_extra_parameters(internal, type) } \description{ This includes both extra parameters and other objects diff --git a/man/get_iterative_args_default.Rd b/man/get_iterative_args_default.Rd new file mode 100644 index 0000000000000000000000000000000000000000..ca995d454b1fc99a0be8f37858ec8c589ed03c2b --- /dev/null +++ b/man/get_iterative_args_default.Rd @@ -0,0 +1,50 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{get_iterative_args_default} +\alias{get_iterative_args_default} +\title{Function to specify arguments of the iterative estimation procedure} +\usage{ +get_iterative_args_default( + internal, + initial_n_coalitions = ceiling(min(200, max(5, internal$parameters$n_features, + (2^internal$parameters$n_features)/10))), + fixed_n_coalitions_per_iter = NULL, + max_iter = 20, + convergence_tol = 0.02, + n_coal_next_iter_factor_vec = c(seq(0.1, 1, by = 0.1), rep(1, max_iter - 10)) +) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{initial_n_coalitions}{Integer. Number of coalitions to use in the first estimation iteration.} + +\item{fixed_n_coalitions_per_iter}{Integer. Number of \code{n_coalitions} to use in each iteration. +\code{NULL} (default) means setting it based on estimates based on a set convergence threshold.} + +\item{max_iter}{Integer. 
Maximum number of estimation iterations} + +\item{convergence_tol}{Numeric. The t variable in the convergence threshold formula on page 6 in the paper +Covert and Lee (2021), 'Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression' +https://arxiv.org/pdf/2012.01536. Smaller values require more coalitions before convergence is reached.} + +\item{n_coal_next_iter_factor_vec}{Numeric vector. The number of \code{n_coalitions} that must be used to reach +convergence in the next iteration is estimated. +The number of \code{n_coalitions} actually used in the next iteration is set to this estimate multiplied by +\code{n_coal_next_iter_factor_vec[i]} for iteration \code{i}. +It is wise to start with smaller numbers to avoid using too many \code{n_coalitions} due to uncertain estimates in +the first iterations.} +} +\description{ +Function to specify arguments of the iterative estimation procedure +} +\details{ +The function sets default values for the iterative estimation procedure, according to the function +defaults. +If the argument \code{iterative} of \code{\link[=explain]{explain()}} is FALSE, it sets parameters corresponding to the use of a +non-iterative estimation procedure +} +\author{ +Martin Jullum +} diff --git a/man/get_max_n_coalitions_causal.Rd b/man/get_max_n_coalitions_causal.Rd new file mode 100644 index 0000000000000000000000000000000000000000..dfcd1e7ec09a56b0cb02481c2dbceccbc19cb474 --- /dev/null +++ b/man/get_max_n_coalitions_causal.Rd @@ -0,0 +1,52 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{get_max_n_coalitions_causal} +\alias{get_max_n_coalitions_causal} +\title{Get the number of coalitions that respect the causal ordering} +\usage{ +get_max_n_coalitions_causal(causal_ordering) +} +\arguments{ +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. 
+\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} +} +\value{ +Integer. The (maximum) number of coalitions that respect the causal ordering. +} +\description{ +Get the number of coalitions that respect the causal ordering +} +\details{ +The function computes the number of coalitions that respect the causal ordering by computing the number +of coalitions in each partial causal component and then summing these. We compute +the number of coalitions in the \eqn{i}th partial causal component by \eqn{2^n - 1}, +where \eqn{n} is the number of features in the \eqn{i}th partial causal component +and we subtract one as we do not want to include the situation where no features in +the \eqn{i}th partial causal component are present. In the end, we add 1 for the +empty coalition. 
+} +\examples{ +\dontrun{ +get_max_n_coalitions_causal(list(1:10)) # 2^10 = 1024 (no causal order) +get_max_n_coalitions_causal(list(1:3, 4:7, 8:10)) # 30 +get_max_n_coalitions_causal(list(1:3, 4:5, 6:7, 8, 9:10)) # 18 +get_max_n_coalitions_causal(list(1:3, c(4, 8), c(5, 7), 6, 9:10)) # 18 +get_max_n_coalitions_causal(list(1, 2, 3, 4, 5, 6, 7, 8, 9, 10)) # 11 +} + +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/get_output_args_default.Rd b/man/get_output_args_default.Rd new file mode 100644 index 0000000000000000000000000000000000000000..4365118bc354f904f9c9f6262983c71eccbac5f4 --- /dev/null +++ b/man/get_output_args_default.Rd @@ -0,0 +1,33 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/setup.R +\name{get_output_args_default} +\alias{get_output_args_default} +\title{Gets the default values for the output arguments} +\usage{ +get_output_args_default( + keep_samp_for_vS = FALSE, + MSEv_uniform_comb_weights = TRUE, + saving_path = tempfile("shapr_obj_", fileext = ".rds") +) +} +\arguments{ +\item{keep_samp_for_vS}{Logical. +Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned (in \code{internal$output}). +Not used for \code{approach="regression_separate"} or \code{approach="regression_surrogate"}.} + +\item{MSEv_uniform_comb_weights}{Logical. +If \code{TRUE} (default), then the function weights the coalitions uniformly when computing the MSEv criterion. +If \code{FALSE}, then the function use the Shapley kernel weights to weight the coalitions when computing the MSEv +criterion. +Note that the Shapley kernel weights are replaced by the sampling frequency when not all coalitions are considered.} + +\item{saving_path}{String. +The path to the directory where the results of the iterative estimation procedure should be saved. 
+Defaults to a temporary directory.} +} +\description{ +Gets the default values for the output arguments +} +\author{ +Martin Jullum +} diff --git a/man/get_valid_causal_coalitions.Rd b/man/get_valid_causal_coalitions.Rd new file mode 100644 index 0000000000000000000000000000000000000000..537857745e53c2315bc7a63d225b547f973ec3d9 --- /dev/null +++ b/man/get_valid_causal_coalitions.Rd @@ -0,0 +1,41 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{get_valid_causal_coalitions} +\alias{get_valid_causal_coalitions} +\title{Get all coalitions satisfying the causal ordering} +\usage{ +get_valid_causal_coalitions( + causal_ordering, + sort_features_in_coalitions = TRUE +) +} +\arguments{ +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} + +\item{sort_features_in_coalitions}{Boolean. If \code{TRUE}, then the feature indices in the +coalitions are sorted in increasing order. 
If \code{FALSE}, then the function maintains the +order of features within each group given in \code{causal_ordering}.} +} +\value{ +List of vectors containing all coalitions that respects the causal ordering. +} +\description{ +This function is only relevant when we are computing asymmetric Shapley values. +For symmetric Shapley values (both regular and causal), all coalitions are allowed. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/observation_impute.Rd b/man/observation_impute.Rd index 813869b284aa60182b8c1e1471298a8d25f43927..690879315d483184b112d6c69d38ec50ff0fb464 100644 --- a/man/observation_impute.Rd +++ b/man/observation_impute.Rd @@ -10,24 +10,28 @@ observation_impute( x_train, x_explain, empirical.eta = 0.7, - n_samples = 1000 + n_MC_samples = 1000 ) } \arguments{ \item{W_kernel}{Numeric matrix. Contains all nonscaled weights between training and test -observations for all feature combinations. The dimension equals \verb{n_train x m}.} +observations for all coalitions. The dimension equals \verb{n_train x m}.} -\item{S}{Integer matrix of dimension \verb{n_combinations x m}, where \code{n_combinations} -and \code{m} equals the total number of sampled/non-sampled feature combinations and +\item{S}{Integer matrix of dimension \verb{n_coalitions x m}, where \code{n_coalitions} +and \code{m} equals the total number of sampled/non-sampled coalitions and the total number of unique features, respectively. Note that \code{m = ncol(x_train)}.} \item{x_train}{Numeric matrix} \item{x_explain}{Numeric matrix} -\item{n_samples}{Positive integer. -Indicating the maximum number of samples to use in the -Monte Carlo integration for every conditional expectation. See also details.} +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. 
+For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} } \value{ data.table diff --git a/man/observation_impute_cpp.Rd b/man/observation_impute_cpp.Rd index 077b419abd670eb19081e831fc62637607f81f41..ffd4838d30fd300d02357af1cd7342de48abc3f0 100644 --- a/man/observation_impute_cpp.Rd +++ b/man/observation_impute_cpp.Rd @@ -17,7 +17,7 @@ i.e. \code{min(index_s) >= 1} and \code{max(index_s) <= nrow(S)}.} \item{xtest}{Numeric matrix. Represents a single test observation.} -\item{S}{Integer matrix of dimension \code{n_combinations x m}, where \code{n_combinations} equals +\item{S}{Integer matrix of dimension \code{n_coalitions x m}, where \code{n_coalitions} equals the total number of sampled/non-sampled feature combinations and \code{m} equals the total number of unique features. Note that \code{m = ncol(xtrain)}. See details for more information.} diff --git a/man/plot.shapr.Rd b/man/plot.shapr.Rd index f45485d4edd64469d5230a8a4d4b6e68d3135084..e2e8564029e81f45ab2b8a1678e8e2256008544b 100644 --- a/man/plot.shapr.Rd +++ b/man/plot.shapr.Rd @@ -15,6 +15,7 @@ bar_plot_order = "largest_first", scatter_features = NULL, scatter_hist = TRUE, + include_group_feature_means = FALSE, ... ) } @@ -86,8 +87,13 @@ character vector, indicating the name(s) of the feature(s) to plot.} \item{scatter_hist}{Logical. Only used for \code{plot_type = "scatter"}. -Whether to include a scatter_hist indicating the distribution of the data when making the scatter plot. 
Note that the -bins are scaled so that when all the bins are stacked they fit the span of the y-axis of the plot.} +Whether to include a scatter_hist indicating the distribution of the data when making the scatter plot. Note +that the bins are scaled so that when all the bins are stacked they fit the span of the y-axis of the plot.} + +\item{include_group_feature_means}{Logical. +Whether to include the average feature value in a group on the y-axis or not. +If \code{FALSE} (default), then no value is shown for the groups. If \code{TRUE}, then \code{shapr} includes the mean of the +features in each group.} \item{...}{Currently not used.} } @@ -128,8 +134,8 @@ x <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) if (requireNamespace("ggplot2", quietly = TRUE)) { @@ -178,8 +184,8 @@ x <- explain( x_explain = x_explain, x_train = x_train, approach = "ctree", - prediction_zero = p, - n_samples = 1e2 + phi0 = p, + n_MC_samples = 1e2 ) if (requireNamespace("ggplot2", quietly = TRUE)) { @@ -189,5 +195,5 @@ if (requireNamespace("ggplot2", quietly = TRUE)) { } \author{ -Martin Jullum, Vilde Ung +Martin Jullum, Vilde Ung, Lars Henry Berge Olsen } diff --git a/man/plot_MSEv_eval_crit.Rd b/man/plot_MSEv_eval_crit.Rd index 24c3fc2d0a20ad1f4ae509ac5ed2da65fc54101b..c7d569feec2cbe34c3eff06b08d68ba5f3f89e7c 100644 --- a/man/plot_MSEv_eval_crit.Rd +++ b/man/plot_MSEv_eval_crit.Rd @@ -7,7 +7,7 @@ plot_MSEv_eval_crit( explanation_list, index_x_explain = NULL, - id_combination = NULL, + id_coalition = NULL, CI_level = if (length(explanation_list[[1]]$pred_explain) < 20) NULL else 0.95, geom_col_width = 0.9, plot_type = "overall" @@ -23,29 +23,29 @@ Which of the test observations to plot. E.g. 
if you have explained 10 observations using \code{\link[=explain]{explain()}}, you can generate a plot for the first 5 observations by setting \code{index_x_explain = 1:5}.} -\item{id_combination}{Integer vector. Which of the combinations (coalitions) to plot. -E.g. if you used \code{n_combinations = 16} in \code{\link[=explain]{explain()}}, you can generate a plot for the -first 5 combinations and the 10th by setting \code{id_combination = c(1:5, 10)}.} +\item{id_coalition}{Integer vector. Which of the coalitions to plot. +E.g. if you used \code{n_coalitions = 16} in \code{\link[=explain]{explain()}}, you can generate a plot for the +first 5 coalitions and the 10th by setting \code{id_coalition = c(1:5, 10)}.} \item{CI_level}{Positive numeric between zero and one. Default is \code{0.95} if the number of observations to explain is larger than 20, otherwise \code{CI_level = NULL}, which removes the confidence intervals. The level of the approximate -confidence intervals for the overall MSEv and the MSEv_combination. The confidence intervals are based on that +confidence intervals for the overall MSEv and the MSEv_coalition. The confidence intervals are based on that the MSEv scores are means over the observations/explicands, and that means are approximation normal. Since the standard deviations are estimated, we use the quantile t from the T distribution with N_explicands - 1 degrees of freedom corresponding to the provided level. Here, N_explicands is the number of observations/explicands. -MSEv ± t\emph{SD(MSEv)/sqrt(N_explicands). Note that the \code{explain()} function already scales the standard deviation by -sqrt(N_explicands), thus, the CI are MSEv ± t}MSEv_sd, where the values MSEv and MSEv_sd are extracted from the +MSEv +/- t*SD(MSEv)/sqrt(N_explicands). 
Note that the \code{explain()} function already scales the standard deviation by +sqrt(N_explicands), thus, the CI are MSEv +/- t*MSEv_sd, where the values MSEv and MSEv_sd are extracted from the MSEv data.tables in the objects in the \code{explanation_list}.} \item{geom_col_width}{Numeric. Bar width. By default, set to 90\% of the \code{\link[ggplot2:resolution]{ggplot2::resolution()}} of the data.} \item{plot_type}{Character vector. The possible options are "overall" (default), "comb", and "explicand". If \code{plot_type = "overall"}, then the plot (one bar plot) associated with the overall MSEv evaluation criterion -for each method is created, i.e., when averaging over both the combinations/coalitions and observations/explicands. +for each method is created, i.e., when averaging over both the coalitions and observations/explicands. If \code{plot_type = "comb"}, then the plots (one line plot and one bar plot) associated with the MSEv evaluation -criterion for each combination/coalition are created, i.e., when we only average over the observations/explicands. +criterion for each coalition are created, i.e., when we only average over the observations/explicands. If \code{plot_type = "explicand"}, then the plots (one line plot and one bar plot) associated with the MSEv evaluation -criterion for each observations/explicands are created, i.e., when we only average over the combinations/coalitions. +criterion for each observation/explicand are created, i.e., when we only average over the coalitions. If \code{plot_type} is a vector of one or several of "overall", "comb", and "explicand", then the associated plots are created.} } @@ -57,8 +57,8 @@ of \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}} objects based on the \code{pl Make plots to visualize and compare the MSEv evaluation criterion for a list of \code{\link[=explain]{explain()}} objects applied to the same data and model. 
The function creates bar plots and line plots with points to illustrate the overall MSEv evaluation -criterion, but also for each observation/explicand and combination by only averaging over -the combinations and observations/explicands, respectively. +criterion, but also for each observation/explicand and coalition by only averaging over +the coalitions and observations/explicands, respectively. } \examples{ # Load necessary librarieslibrary(xgboost) @@ -90,7 +90,7 @@ model <- xgboost::xgboost( ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Independence approach explanation_independence <- explain( @@ -98,8 +98,8 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Gaussian 1e1 approach @@ -108,8 +108,8 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, - n_samples = 1e1 + phi0 = phi0, + n_MC_samples = 1e1 ) # Gaussian 1e2 approach @@ -118,8 +118,8 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # ctree approach @@ -128,8 +128,8 @@ explanation_ctree <- explain( x_explain = x_explain, x_train = x_train, approach = "ctree", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Combined approach @@ -138,8 +138,8 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "independence", "ctree"), - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Create a list of explanations with names @@ -152,24 +152,24 @@ explanation_list_named <- list( ) if 
(requireNamespace("ggplot2", quietly = TRUE)) { - # Create the default MSEv plot where we average over both the combinations and observations + # Create the default MSEv plot where we average over both the coalitions and observations # with approximate 95\% confidence intervals plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = "overall") - # Can also create plots of the MSEv criterion averaged only over the combinations or observations. + # Can also create plots of the MSEv criterion averaged only over the coalitions or observations. MSEv_figures <- plot_MSEv_eval_crit(explanation_list_named, CI_level = 0.95, plot_type = c("overall", "comb", "explicand") ) MSEv_figures$MSEv_bar - MSEv_figures$MSEv_combination_bar + MSEv_figures$MSEv_coalition_bar MSEv_figures$MSEv_explicand_bar - # When there are many combinations or observations, then it can be easier to look at line plots - MSEv_figures$MSEv_combination_line_point + # When there are many coalitions or observations, then it can be easier to look at line plots + MSEv_figures$MSEv_coalition_line_point MSEv_figures$MSEv_explicand_line_point - # We can specify which observations or combinations to plot + # We can specify which observations or coalitions to plot plot_MSEv_eval_crit(explanation_list_named, plot_type = "explicand", index_x_explain = c(1, 3:4, 6), @@ -177,9 +177,9 @@ if (requireNamespace("ggplot2", quietly = TRUE)) { )$MSEv_explicand_bar plot_MSEv_eval_crit(explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15), + id_coalition = c(3, 4, 9, 13:15), CI_level = 0.95 - )$MSEv_combination_bar + )$MSEv_coalition_bar # We can alter the figures if other palette schemes or design is wanted bar_text_n_decimals <- 1 diff --git a/man/plot_SV_several_approaches.Rd b/man/plot_SV_several_approaches.Rd index 274b1a608594cb402e88e87dc3e5e810d497e82a..2fcfd1111093f88c170851e032a7f69c263c3105 100644 --- a/man/plot_SV_several_approaches.Rd +++ 
b/man/plot_SV_several_approaches.Rd @@ -7,6 +7,7 @@ plot_SV_several_approaches( explanation_list, index_explicands = NULL, + index_explicands_sort = FALSE, only_these_features = NULL, plot_phi0 = FALSE, digits = 4, @@ -17,7 +18,8 @@ plot_SV_several_approaches( facet_scales = "free", facet_ncol = 2, geom_col_width = 0.85, - brewer_palette = NULL + brewer_palette = NULL, + include_group_feature_means = FALSE ) } \arguments{ @@ -27,7 +29,12 @@ the approach names (with integer suffix for duplicates) for the explanation obje \item{index_explicands}{Integer vector. Which of the explicands (test observations) to plot. E.g. if you have explained 10 observations using \code{\link[=explain]{explain()}}, you can generate a plot for the -first 5 observations/explicands and the 10th by setting \code{index_x_explain = c(1:5, 10)}.} +first 5 observations/explicands and the 10th by setting \code{index_explicands = c(1:5, 10)}. +The argument \code{index_explicands_sort} must be \code{FALSE} to plot the explicands +in the order specified in \code{index_explicands}.} + +\item{index_explicands_sort}{Boolean. If \code{FALSE} (default), then \code{shapr} plots the explicands in the order +specified in \code{index_explicands}. If \code{TRUE}, then \code{shapr} sorts the indices in increasing order based on their id.} \item{only_these_features}{String vector. Containing the names of the features which are to be included in the bar plots.} @@ -65,13 +72,18 @@ The following palettes are available for use with these scales: \item{Sequential}{Blues, BuGn, BuPu, GnBu, Greens, Greys, Oranges, OrRd, PuBu, PuBuGn, PuRd, Purples, RdPu, Reds, YlGn, YlGnBu, YlOrBr, YlOrRd} }} + +\item{include_group_feature_means}{Logical. Whether to include the average feature value in a group on the +y-axis or not. If \code{FALSE} (default), then no value is shown for the groups. 
If \code{TRUE}, then \code{shapr} includes +the mean of the features in each group.} } \value{ A \code{\link[ggplot2:ggplot]{ggplot2::ggplot()}} object. } \description{ Make plots to visualize and compare the estimated Shapley values for a list of -\code{\link[=explain]{explain()}} objects applied to the same data and model. +\code{\link[=explain]{explain()}} objects applied to the same data and model. For group-wise Shapley values, +the features values plotted are the mean feature values for all features in each group. } \examples{ # Load necessary libraries @@ -102,7 +114,7 @@ model <- xgboost::xgboost( ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Independence approach explanation_independence <- explain( @@ -110,8 +122,8 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Empirical approach @@ -120,8 +132,8 @@ explanation_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Gaussian 1e1 approach @@ -130,8 +142,8 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, - n_samples = 1e1 + phi0 = phi0, + n_MC_samples = 1e1 ) # Gaussian 1e2 approach @@ -140,8 +152,8 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # Combined approach @@ -150,8 +162,8 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "ctree", "empirical"), - prediction_zero = prediction_zero, - n_samples = 1e2 + phi0 = phi0, + n_MC_samples = 1e2 ) # 
Create a list of explanations with names diff --git a/man/prepare_data.Rd b/man/prepare_data.Rd index d7d6d7f39aa0a2ff05a593e2208387239d97936c..827e3cee58a5df37578270a0ba7bd8dbd461be10 100644 --- a/man/prepare_data.Rd +++ b/man/prepare_data.Rd @@ -41,10 +41,11 @@ prepare_data(internal, index_features = NULL, ...) \method{prepare_data}{vaeac}(internal, index_features = NULL, ...) } \arguments{ -\item{internal}{Not used.} +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} +\item{index_features}{Positive integer vector. Specifies the id_coalition to +apply to the present method. \code{NULL} means all coalitions. Only used internally.} \item{...}{Currently not used.} } @@ -56,6 +57,8 @@ the contribution function by Monte Carlo integration. Generate data used for predictions and Monte Carlo integration } \author{ +Annabelle Redelmeier and Lars Henry Berge Olsen + Lars Henry Berge Olsen } \keyword{internal} diff --git a/man/prepare_data_causal.Rd b/man/prepare_data_causal.Rd new file mode 100644 index 0000000000000000000000000000000000000000..47c62b1400ede845fb12674e8fbf9ab67251c2ba --- /dev/null +++ b/man/prepare_data_causal.Rd @@ -0,0 +1,38 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/asymmetric_and_casual_Shapley.R +\name{prepare_data_causal} +\alias{prepare_data_causal} +\title{Generate data used for predictions and Monte Carlo integration for causal Shapley values} +\usage{ +prepare_data_causal(internal, index_features = NULL, ...) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} + +\item{index_features}{Positive integer vector. Specifies the id_coalition to +apply to the present method. \code{NULL} means all coalitions. 
Only used internally.} + +\item{...}{Currently not used.} +} +\value{ +A data.table containing simulated data that respects the (partial) causal ordering and the +confounding assumptions. The data is used to estimate the contribution function by Monte Carlo integration. +} +\description{ +This function loops over the given coalitions, and for each coalition it extracts the +chain of relevant sampling steps provided in \code{internal$object$S_causal}. This chain +can contain sampling from marginal and conditional distributions. We use the approach given by +\code{internal$parameters$approach} to generate the samples from the conditional distributions, and +we iteratively call \code{prepare_data()} with a modified \code{internal_copy} list to reuse code. +However, this also means that chains with the same conditional distributions will retrain a +model of said conditional distributions several times. +For the marginal distribution, we sample from the Gaussian marginals when the approach is +\code{gaussian} and from the marginals of the training data for all other approaches. Note that +we could extend the code to sample from the marginal (gaussian) copula, too, when \code{approach} is +\code{copula}. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/prepare_data_copula_cpp.Rd b/man/prepare_data_copula_cpp.Rd index ca901031d03f47960ce9a6f94285c25b669c2490..ce3aafeb3e79f0f48cf335fd507acf13fe4767de 100644 --- a/man/prepare_data_copula_cpp.Rd +++ b/man/prepare_data_copula_cpp.Rd @@ -15,7 +15,7 @@ prepare_data_copula_cpp( ) } \arguments{ -\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_samples}, \code{n_features}) containing samples from the +\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_MC_samples}, \code{n_features}) containing samples from the univariate standard normal.} \item{x_explain_mat}{arma::mat. 
Matrix of dimension (\code{n_explain}, \code{n_features}) containing the observations @@ -27,7 +27,7 @@ transformed to a standardized normal distribution.} \item{x_train_mat}{arma::mat. Matrix of dimension (\code{n_train}, \code{n_features}) containing the training observations.} -\item{S}{arma::mat. Matrix of dimension (\code{n_combinations}, \code{n_features}) containing binary representations of +\item{S}{arma::mat. Matrix of dimension (\code{n_coalitions}, \code{n_features}) containing binary representations of the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. This is not a problem internally in shapr as the empty and grand coalitions treated differently.} @@ -39,8 +39,8 @@ between all pairs of features after being transformed using the Gaussian transfo transformed to a standardized normal distribution.} } \value{ -An arma::cube/3D array of dimension (\code{n_samples}, \code{n_explain} * \code{n_coalitions}, \code{n_features}), where -the columns (\emph{,j,}) are matrices of dimension (\code{n_samples}, \code{n_features}) containing the conditional Gaussian +An arma::cube/3D array of dimension (\code{n_MC_samples}, \code{n_explain} * \code{n_coalitions}, \code{n_features}), where +the columns (\emph{,j,}) are matrices of dimension (\code{n_MC_samples}, \code{n_features}) containing the conditional Gaussian copula MC samples for each explicand and coalition on the original scale. 
} \description{ diff --git a/man/prepare_data_copula_cpp_caus.Rd b/man/prepare_data_copula_cpp_caus.Rd new file mode 100644 index 0000000000000000000000000000000000000000..d70b3ad78800554b37d2b6a25ee83d879c492c11 --- /dev/null +++ b/man/prepare_data_copula_cpp_caus.Rd @@ -0,0 +1,52 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/RcppExports.R +\name{prepare_data_copula_cpp_caus} +\alias{prepare_data_copula_cpp_caus} +\title{Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand} +\usage{ +prepare_data_copula_cpp_caus( + MC_samples_mat, + x_explain_mat, + x_explain_gaussian_mat, + x_train_mat, + S, + mu, + cov_mat +) +} +\arguments{ +\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing samples from the +univariate standard normal. The i'th row will be applied to the i'th row in \code{x_explain_mat}.} + +\item{x_explain_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing the observations to +explain on the original scale. The MC sample for the i'th explicand is based on the i'th row in \code{MC_samples_mat}.} + +\item{x_explain_gaussian_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing the +observations to explain after being transformed using the Gaussian transform, i.e., the samples have been +transformed to a standardized normal distribution.} + +\item{x_train_mat}{arma::mat. Matrix of dimension (\code{n_train}, \code{n_features}) containing the training observations.} + +\item{S}{arma::mat. Matrix of dimension (\code{n_coalitions}, \code{n_features}) containing binary representations of +the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. +This is not a problem internally in shapr as the empty and grand coalitions treated differently.} + +\item{mu}{arma::vec. 
Vector of length \code{n_features} containing the mean of each feature after being transformed +using the Gaussian transform, i.e., the samples have been transformed to a standardized normal distribution.} + +\item{cov_mat}{arma::mat. Matrix of dimension (\code{n_features}, \code{n_features}) containing the pairwise covariance +between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been +transformed to a standardized normal distribution.} +} +\value{ +An arma::mat/2D array of dimension (\code{n_explain} * \code{n_coalitions}, \code{n_features}), +where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contains the single +conditional Gaussian MC samples for each explicand and \code{S_ind} coalition. +} +\description{ +Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/prepare_data_gaussian_cpp.Rd b/man/prepare_data_gaussian_cpp.Rd index b24b431e6a7e793d5d3ebb7842cef52bcd3f9176..095769cf0f4e0335b4b1314135adc9706712bba1 100644 --- a/man/prepare_data_gaussian_cpp.Rd +++ b/man/prepare_data_gaussian_cpp.Rd @@ -7,13 +7,13 @@ prepare_data_gaussian_cpp(MC_samples_mat, x_explain_mat, S, mu, cov_mat) } \arguments{ -\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_samples}, \code{n_features}) containing samples from the +\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_MC_samples}, \code{n_features}) containing samples from the univariate standard normal.} \item{x_explain_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing the observations to explain.} -\item{S}{arma::mat. Matrix of dimension (\code{n_combinations}, \code{n_features}) containing binary representations of +\item{S}{arma::mat. Matrix of dimension (\code{n_coalitions}, \code{n_features}) containing binary representations of the used coalitions. 
S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. This is not a problem internally in shapr as the empty and grand coalitions treated differently.} @@ -23,8 +23,8 @@ This is not a problem internally in shapr as the empty and grand coalitions trea between all pairs of features.} } \value{ -An arma::cube/3D array of dimension (\code{n_samples}, \code{n_explain} * \code{n_coalitions}, \code{n_features}), where -the columns (\emph{,j,}) are matrices of dimension (\code{n_samples}, \code{n_features}) containing the conditional Gaussian +An arma::cube/3D array of dimension (\code{n_MC_samples}, \code{n_explain} * \code{n_coalitions}, \code{n_features}), where +the columns (\emph{,j,}) are matrices of dimension (\code{n_MC_samples}, \code{n_features}) containing the conditional Gaussian MC samples for each explicand and coalition. } \description{ diff --git a/man/prepare_data_gaussian_cpp_caus.Rd b/man/prepare_data_gaussian_cpp_caus.Rd new file mode 100644 index 0000000000000000000000000000000000000000..33cc1835f59cca847e3243a595416742af9e14f2 --- /dev/null +++ b/man/prepare_data_gaussian_cpp_caus.Rd @@ -0,0 +1,36 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/RcppExports.R +\name{prepare_data_gaussian_cpp_caus} +\alias{prepare_data_gaussian_cpp_caus} +\title{Generate Gaussian MC samples for the causal setup with a single MC sample for each explicand} +\usage{ +prepare_data_gaussian_cpp_caus(MC_samples_mat, x_explain_mat, S, mu, cov_mat) +} +\arguments{ +\item{MC_samples_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing samples from the +univariate standard normal. The i'th row will be applied to the i'th row in \code{x_explain_mat}.} + +\item{x_explain_mat}{arma::mat. Matrix of dimension (\code{n_explain}, \code{n_features}) containing the observations +to explain. 
The MC sample for the i'th explicand is based on the i'th row in \code{MC_samples_mat}} + +\item{S}{arma::mat. Matrix of dimension (\code{n_combinations}, \code{n_features}) containing binary representations of +the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. +This is not a problem internally in shapr as the empty and grand coalitions treated differently.} + +\item{mu}{arma::vec. Vector of length \code{n_features} containing the mean of each feature.} + +\item{cov_mat}{arma::mat. Matrix of dimension (\code{n_features}, \code{n_features}) containing the pairwise covariance +between all pairs of features.} +} +\value{ +An arma::mat/2D array of dimension (\code{n_explain} * \code{n_coalitions}, \code{n_features}), +where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contains the single +conditional Gaussian MC samples for each explicand and \code{S_ind} coalition. +} +\description{ +Generate Gaussian MC samples for the causal setup with a single MC sample for each explicand +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/prepare_data_single_coalition.Rd b/man/prepare_data_single_coalition.Rd new file mode 100644 index 0000000000000000000000000000000000000000..9bd170b2b771de172819f6c0200d129d24edfe54 --- /dev/null +++ b/man/prepare_data_single_coalition.Rd @@ -0,0 +1,22 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/approach_categorical.R +\name{prepare_data_single_coalition} +\alias{prepare_data_single_coalition} +\title{Compute the conditional probabilities for a single coalition for the categorical approach} +\usage{ +prepare_data_single_coalition(internal, index_features) +} +\arguments{ +\item{internal}{List. 
+Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} +} +\description{ +The \code{\link[=prepare_data.categorical]{prepare_data.categorical()}} function is slow when evaluated for a single coalition. +This is a bottleneck for Causal Shapley values which call said function a lot with single coalitions. +} +\author{ +Lars Henry Berge Olsen +} +\keyword{internal} diff --git a/man/prepare_next_iteration.Rd b/man/prepare_next_iteration.Rd new file mode 100644 index 0000000000000000000000000000000000000000..996a7330d9c7109966268966a05cc5119c7e6651 --- /dev/null +++ b/man/prepare_next_iteration.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/prepare_next_iteration.R +\name{prepare_next_iteration} +\alias{prepare_next_iteration} +\title{Prepares the next iteration of the iterative sampling algorithm} +\usage{ +prepare_next_iteration(internal) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Prepares the next iteration of the iterative sampling algorithm +} +\keyword{internal} diff --git a/man/print_iter.Rd b/man/print_iter.Rd new file mode 100644 index 0000000000000000000000000000000000000000..abab85a3b447fffab4aeb6db2495c63dc71589dd --- /dev/null +++ b/man/print_iter.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/print_iter.R +\name{print_iter} +\alias{print_iter} +\title{Prints iterative information} +\usage{ +print_iter(internal) +} +\arguments{ +\item{internal}{List. 
+Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Prints iterative information +} +\keyword{internal} diff --git a/man/regression.check_parameters.Rd b/man/regression.check_parameters.Rd index fbe747374bf784af9731330d5a444ee6880d5741..55c2f3e22e2127c07390eb2dee34569d57681c58 100644 --- a/man/regression.check_parameters.Rd +++ b/man/regression.check_parameters.Rd @@ -9,7 +9,8 @@ regression.check_parameters(internal) \arguments{ \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} } \value{ The same \code{internal} list, but added logical indicator \code{internal$parameters$regression.tune} diff --git a/man/regression.check_sur_n_comb.Rd b/man/regression.check_sur_n_comb.Rd index 1ede6d346015ffbb6ab9842a29f3de1eec2c82f9..3160bdae957165b7096dc5cb0b3aa6d3874c390c 100644 --- a/man/regression.check_sur_n_comb.Rd +++ b/man/regression.check_sur_n_comb.Rd @@ -4,18 +4,22 @@ \alias{regression.check_sur_n_comb} \title{Check the \code{regression.surrogate_n_comb} parameter} \usage{ -regression.check_sur_n_comb(regression.surrogate_n_comb, used_n_combinations) +regression.check_sur_n_comb(regression.surrogate_n_comb, n_coalitions) } \arguments{ -\item{regression.surrogate_n_comb}{Integer (default is \code{internal$parameters$used_n_combinations}) specifying the -number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is -"\code{internal$parameters$used_n_combinations} - 2". By default, we use all coalitions, but this can take a lot of memory -in larger dimensions. 
Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. This will be all -\eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in the exact mode. If the -user sets a lower value than \code{internal$parameters$used_n_combinations}, then we sample this amount of unique -coalitions separately for each training observations. That is, on average, all coalitions should be equally trained.} +\item{regression.surrogate_n_comb}{Integer. +(default is \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}) specifying the +number of unique coalitions to apply to each training observation. Maximum allowed value is +"\code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions} - 2". +By default, we use all coalitions, but this can take a lot of memory in larger dimensions. +Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. +This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in +the exact mode. +If the user sets a lower value than \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}, +then we sample this amount of unique coalitions separately for each training observation. +That is, on average, all coalitions should be equally trained.} -\item{used_n_combinations}{Integer. The number of used combinations (including the empty and grand coalitions).} +\item{n_coalitions}{Integer. The number of used coalitions (including the empty and grand coalition).} } \description{ Check that \code{regression.surrogate_n_comb} is either NULL or a valid integer. 
diff --git a/man/regression.cv_message.Rd b/man/regression.cv_message.Rd index 145e514a070190e83fbbb13c286426ab9942bb45..2826b11bbb656e4e90be482a04342fc43533ec0f 100644 --- a/man/regression.cv_message.Rd +++ b/man/regression.cv_message.Rd @@ -4,7 +4,12 @@ \alias{regression.cv_message} \title{Produce message about which batch prepare_data is working on} \usage{ -regression.cv_message(regression.results, regression.grid, n_cv = 10) +regression.cv_message( + regression.results, + regression.grid, + n_cv = 10, + current_comb +) } \arguments{ \item{regression.results}{The results of the CV procedures.} diff --git a/man/regression.get_tune.Rd b/man/regression.get_tune.Rd index 7c5440741c9216f08c2689a26d4bdb6930e9fb3d..148c36a93ef863e738fb365ee3275e854814da75 100644 --- a/man/regression.get_tune.Rd +++ b/man/regression.get_tune.Rd @@ -18,8 +18,8 @@ is also a valid input. It is essential to include the package prefix if the pack The data.frame must contain the possible hyperparameter value combinations to try. The column names must match the names of the tuneable parameters specified in \code{regression.model}. If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data -for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above. -Using a function allows the hyperparameter values to change based on the size of the combination. See the regression +for the current coalition and returns a data.frame/data.table/tibble with the properties described above. +Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression vignette for several examples. Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string containing an R function. 
For example, diff --git a/man/regression.get_y_hat.Rd b/man/regression.get_y_hat.Rd index 6b03d3d49d868c69657ff586200c6fe37ab53d94..9eff9cbd5d70d9bf4ad71f0fbd96e8c564a80683 100644 --- a/man/regression.get_y_hat.Rd +++ b/man/regression.get_y_hat.Rd @@ -9,7 +9,8 @@ regression.get_y_hat(internal, model, predict_model) \arguments{ \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} \item{model}{Objects. The model object that ought to be explained. diff --git a/man/regression.prep_message_batch.Rd b/man/regression.prep_message_batch.Rd deleted file mode 100644 index 9b8a942e20144b470ce8fd391a10f24c3de9a8ca..0000000000000000000000000000000000000000 --- a/man/regression.prep_message_batch.Rd +++ /dev/null @@ -1,23 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/approach_regression_separate.R -\name{regression.prep_message_batch} -\alias{regression.prep_message_batch} -\title{Produce message about which batch prepare_data is working on} -\usage{ -regression.prep_message_batch(internal, index_features) -} -\arguments{ -\item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} - -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. 
Only used internally.} -} -\description{ -Produce message about which batch prepare_data is working on -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/regression.prep_message_comb.Rd b/man/regression.prep_message_comb.Rd deleted file mode 100644 index 84739b82a78f0459d1f33a1fb3872b738c91981a..0000000000000000000000000000000000000000 --- a/man/regression.prep_message_comb.Rd +++ /dev/null @@ -1,25 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/approach_regression_separate.R -\name{regression.prep_message_comb} -\alias{regression.prep_message_comb} -\title{Produce message about which combination prepare_data is working on} -\usage{ -regression.prep_message_comb(internal, index_features, comb_idx) -} -\arguments{ -\item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} - -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} - -\item{comb_idx}{Integer. 
The index of the combination in a specific batch.} -} -\description{ -Produce message about which combination prepare_data is working on -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/regression.separate_time_mess.Rd b/man/regression.separate_time_mess.Rd deleted file mode 100644 index cf0438000d6e919fb862d1bcb74840bf2dfc7ee9..0000000000000000000000000000000000000000 --- a/man/regression.separate_time_mess.Rd +++ /dev/null @@ -1,15 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/approach_regression_separate.R -\name{regression.separate_time_mess} -\alias{regression.separate_time_mess} -\title{Produce time message for separate regression} -\usage{ -regression.separate_time_mess() -} -\description{ -Produce time message for separate regression -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/regression.surrogate_aug_data.Rd b/man/regression.surrogate_aug_data.Rd index 8ebd0ccbd9ca5e8d9e050fbaa0c7a6bec044e857..2acb6c41976bff7ccc1e4293c0215b28d9a51f37 100644 --- a/man/regression.surrogate_aug_data.Rd +++ b/man/regression.surrogate_aug_data.Rd @@ -11,7 +11,7 @@ regression.surrogate_aug_data( index_features = NULL, augment_masks_as_factor = FALSE, augment_include_grand = FALSE, - augment_add_id_comb = FALSE, + augment_add_id_coal = FALSE, augment_comb_prob = NULL, augment_weights = NULL ) @@ -19,7 +19,8 @@ regression.surrogate_aug_data( \arguments{ \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} \item{x}{Data.table containing the data. Either the training data or the explicands. 
If \code{x} is the explicands, then \code{index_features} must be provided.} @@ -34,20 +35,20 @@ to factors. If \code{FALSE}, then the binary masks are numerics.} \item{augment_include_grand}{Logical (default is \code{FALSE}). If \code{TRUE}, then the grand coalition is included. If \code{index_features} are provided, then \code{augment_include_grand} has no effect. Note that if we sample the -combinations then the grand coalition is equally likely to be samples as the other coalitions (or weighted if +coalitions then the grand coalition is equally likely to be sampled as the other coalitions (or weighted if \code{augment_comb_prob} is provided).} -\item{augment_add_id_comb}{Logical (default is \code{FALSE}). If \code{TRUE}, an additional column is adding containing +\item{augment_add_id_coal}{Logical (default is \code{FALSE}). If \code{TRUE}, an additional column is added containing which coalition was applied.} \item{augment_comb_prob}{Array of numerics (default is \code{NULL}). The length of the array must match the number of -combinations being considered, where each entry specifies the probability of sampling the corresponding coalition. +coalitions being considered, where each entry specifies the probability of sampling the corresponding coalition. This is useful if we want to generate more training data for some specific coalitions. One possible choice would be -\code{augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_combinations] else NULL}.} +\code{augment_comb_prob = if (use_Shapley_weights) internal$objects$X$shapley_weight[2:actual_n_coalitions] else NULL}.} \item{augment_weights}{String (optional). Specifying which type of weights to add to the observations. If \code{NULL} (default), then no weights are added. If \code{"Shapley"}, then the Shapley weights for the different -combinations are added to corresponding observations where the coalitions was applied. 
If \code{uniform}, then +coalitions are added to corresponding observations where the coalitions was applied. If \code{uniform}, then all observations get an equal weight of one.} } \value{ diff --git a/man/regression.train_model.Rd b/man/regression.train_model.Rd index 8ee6b669a7379cc8b8b8f0ac5a457f0f40cd09fb..6d5c0807eedaf161d4c1243fa3c601223d6aae95 100644 --- a/man/regression.train_model.Rd +++ b/man/regression.train_model.Rd @@ -7,14 +7,15 @@ regression.train_model( x, seed = 1, - verbose = 0, + verbose = NULL, regression.model = parsnip::linear_reg(), regression.tune = FALSE, regression.tune_values = NULL, regression.vfold_cv_para = NULL, regression.recipe_func = NULL, regression.response_var = "y_hat", - regression.surrogate_n_comb = NULL + regression.surrogate_n_comb = NULL, + current_comb = NULL ) } \arguments{ @@ -23,12 +24,24 @@ then \code{index_features} must be provided.} \item{seed}{Positive integer. Specifies the seed before any randomness based code is being run. -If \code{NULL} the seed will be inherited from the calling environment.} +If \code{NULL} no seed is set in the calling environment.} -\item{verbose}{An integer specifying the level of verbosity. If \code{0}, \code{shapr} will stay silent. -If \code{1}, it will print information about performance. If \code{2}, some additional information will be printed out. -Use \code{0} (default) for no verbosity, \code{1} for low verbose, and \code{2} for high verbose. -TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. 
+\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} \item{regression.model}{A \code{tidymodels} object of class \code{model_specs}. Default is a linear regression model, i.e., \code{\link[parsnip:linear_reg]{parsnip::linear_reg()}}. See \href{https://www.tidymodels.org/find/parsnip/}{tidymodels} for all possible models, @@ -45,8 +58,8 @@ the values provided in \code{regression.tune_values}. Note that no checks are co The data.frame must contain the possible hyperparameter value combinations to try. The column names must match the names of the tuneable parameters specified in \code{regression.model}. If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data -for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above. -Using a function allows the hyperparameter values to change based on the size of the combination. See the regression +for the current coalition and returns a data.frame/data.table/tibble with the properties described above. +Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression vignette for several examples. 
Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string containing an R function. For example, diff --git a/man/sample_ctree.Rd b/man/sample_ctree.Rd index f95f743836442bb9e9d2bea527ddea1c448167b8..4bd10b52d958fb017d885f50e2014b5957486efc 100644 --- a/man/sample_ctree.Rd +++ b/man/sample_ctree.Rd @@ -4,13 +4,13 @@ \alias{sample_ctree} \title{Sample ctree variables from a given conditional inference tree} \usage{ -sample_ctree(tree, n_samples, x_explain, x_train, n_features, sample) +sample_ctree(tree, n_MC_samples, x_explain, x_train, n_features, sample) } \arguments{ \item{tree}{List. Contains tree which is an object of type ctree built from the party package. Also contains given_ind, the features to condition upon.} -\item{n_samples}{Numeric. Indicates how many samples to use for MCMC.} +\item{n_MC_samples}{Numeric. Indicates how many samples to use for MCMC.} \item{x_explain}{Matrix, data.frame or data.table with the features of the observation whose predictions ought to be explained (test data). Dimension \verb{1\\timesp} or \verb{p\\times1}.} @@ -21,10 +21,10 @@ predictions ought to be explained (test data). Dimension \verb{1\\timesp} or \ve \item{sample}{Boolean. 
True indicates that the method samples from the terminal node of the tree whereas False indicates that the method takes all the observations if it is -less than n_samples.} +less than n_MC_samples.} } \value{ -data.table with \code{n_samples} (conditional) Gaussian samples +data.table with \code{n_MC_samples} (conditional) Gaussian samples } \description{ Sample ctree variables from a given conditional inference tree diff --git a/man/save_results.Rd b/man/save_results.Rd new file mode 100644 index 0000000000000000000000000000000000000000..fa15361720b714e7be868eaae740541ca4ae25c0 --- /dev/null +++ b/man/save_results.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/save_results.R +\name{save_results} +\alias{save_results} +\title{Saves the intermediate results to disk} +\usage{ +save_results(internal) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Saves the intermediate results to disk +} +\keyword{internal} diff --git a/man/setup.Rd b/man/setup.Rd index fce91a6b078134bebfcf61642db8ad334ea84f61..dec833a20c223e592d0dadff8a204ce7a239f354 100644 --- a/man/setup.Rd +++ b/man/setup.Rd @@ -8,16 +8,14 @@ setup( x_train, x_explain, approach, - prediction_zero, + paired_shap_sampling = TRUE, + phi0, output_size = 1, - n_combinations, + max_n_coalitions, group, - n_samples, - n_batches, + n_MC_samples, seed, - keep_samp_for_vS, feature_specs, - MSEv_uniform_comb_weights = TRUE, type = "normal", horizon = NULL, y = NULL, @@ -27,9 +25,19 @@ setup( explain_y_lags = NULL, explain_xreg_lags = NULL, group_lags = NULL, - timing, verbose, + iterative = NULL, + iterative_args = list(), + kernelSHAP_reweighting = "none", is_python = FALSE, + testing = FALSE, + init_time = NULL, + prev_shapr_object = NULL, + asymmetric = FALSE, + causal_ordering = NULL, + confounding = NULL, + output_args = list(), + extra_computation_args = list(), ... 
) } @@ -46,7 +54,12 @@ All elements should, either be \code{"gaussian"}, \code{"copula"}, \code{"empiri \code{"categorical"}, \code{"timeseries"}, \code{"independence"}, \code{"regression_separate"}, or \code{"regression_surrogate"}. The two regression approaches can not be combined with any other approach. See details for more information.} -\item{prediction_zero}{Numeric. +\item{paired_shap_sampling}{Logical. +If \code{TRUE} (default), paired versions of all sampled coalitions are also included in the computation. +That is, if there are 5 features and e.g. coalitions (1,3,5) are sampled, then also coalition (2,4) is used for +computing the Shapley values. This is done to reduce the variance of the Shapley value estimates.} + +\item{phi0}{Numeric. The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any features. Typically we set this value equal to the mean of the response variable in our training data, but other choices @@ -54,11 +67,13 @@ such as the mean of the predictions in the training data are also reasonable.} \item{output_size}{TODO: Document} -\item{n_combinations}{Integer. -If \code{group = NULL}, \code{n_combinations} represents the number of unique feature combinations to sample. -If \code{group != NULL}, \code{n_combinations} represents the number of unique group combinations to sample. -If \code{n_combinations = NULL}, the exact method is used and all combinations are considered. -The maximum number of combinations equals \code{2^m}, where \code{m} is the number of features.} +\item{max_n_coalitions}{Integer. +The upper limit on the number of unique feature/group coalitions to use in the iterative procedure +(if \code{iterative = TRUE}). +If \code{iterative = FALSE} it represents the number of feature/group coalitions to use directly. +The quantity refers to the number of unique feature coalitions if \code{group = NULL}, +and group coalitions if \code{group != NULL}. 
+\code{max_n_coalitions = NULL} corresponds to \code{max_n_coalitions=2^n_features}.} \item{group}{List. If \code{NULL} regular feature wise Shapley values are computed. @@ -66,25 +81,17 @@ If provided, group wise Shapley values are computed. \code{group} then has lengt the number of groups. The list element contains character vectors with the features included in each of the different groups.} -\item{n_samples}{Positive integer. -Indicating the maximum number of samples to use in the -Monte Carlo integration for every conditional expectation. See also details.} - -\item{n_batches}{Positive integer (or NULL). -Specifies how many batches the total number of feature combinations should be split into when calculating the -contribution function for each test observation. -The default value is NULL which uses a reasonable trade-off between RAM allocation and computation speed, -which depends on \code{approach} and \code{n_combinations}. -For models with many features, increasing the number of batches reduces the RAM allocation significantly. -This typically comes with a small increase in computation time.} +\item{n_MC_samples}{Positive integer. +Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. +For \code{approach="ctree"}, \code{n_MC_samples} corresponds to the number of samples +from the leaf node (see an exception related to the \code{ctree.sample} argument \code{\link[=setup_approach.ctree]{setup_approach.ctree()}}). +For \code{approach="empirical"}, \code{n_MC_samples} is the \eqn{K} parameter in equations (14-15) of +Aas et al. (2021), i.e. the maximum number of observations (with largest weights) that is used, see also the +\code{empirical.eta} argument \code{\link[=setup_approach.empirical]{setup_approach.empirical()}}.} \item{seed}{Positive integer. Specifies the seed before any randomness based code is being run. 
-If \code{NULL} the seed will be inherited from the calling environment.} - -\item{keep_samp_for_vS}{Logical. -Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned -(in \code{internal$output})} +If \code{NULL} no seed is set in the calling environment.} \item{feature_specs}{List. The output from \code{\link[=get_model_specs]{get_model_specs()}} or \code{\link[=get_data_specs]{get_data_specs()}}. Contains the 3 elements: @@ -94,11 +101,6 @@ Contains the 3 elements: \item{factor_levels}{Character vector with the levels for any categorical features.} }} -\item{MSEv_uniform_comb_weights}{Logical. If \code{TRUE} (default), then the function weights the combinations -uniformly when computing the MSEv criterion. If \code{FALSE}, then the function use the Shapley kernel weights to -weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by the -sampling frequency when not all combinations are considered.} - \item{type}{Character. Either "normal" or "forecast" corresponding to function \code{setup()} is called from, correspondingly the type of explanation that should be generated.} @@ -136,18 +138,114 @@ If \code{xreg != NULL}, denotes the number of lags that should be used for each If \code{TRUE} all lags of each variable are grouped together and explained as a group. If \code{FALSE} all lags of each variable are explained individually.} -\item{timing}{Logical. -Whether the timing of the different parts of the \code{explain()} should saved in the model object.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. 
+\code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}). +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac")}. +\code{NULL} means no printout. +Note that any combination of these strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} -\item{verbose}{An integer specifying the level of verbosity. If \code{0}, \code{shapr} will stay silent. -If \code{1}, it will print information about performance. If \code{2}, some additional information will be printed out. -Use \code{0} (default) for no verbosity, \code{1} for low verbose, and \code{2} for high verbose. -TODO: Make this clearer when we end up fixing this and if they should force a progressr bar.} +\item{iterative}{Logical or NULL. +If \code{NULL} (default), the argument is set to \code{TRUE} if there are more than 5 features/groups, and \code{FALSE} otherwise. +If eventually \code{TRUE}, the Shapley values are estimated in an iterative manner. +This provides sufficiently accurate Shapley value estimates faster. +First an initial number of coalitions is sampled, then bootstrapping is used to estimate the variance of the Shapley +values. +A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small. +If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more +coalitions. +The process is repeated until the variances are below the threshold. 
+Specifics related to the iterative process and convergence criterion are set through \code{iterative_args}.} + +\item{iterative_args}{Named list. +Specifies the arguments for the iterative procedure. +See \code{\link[=get_iterative_args_default]{get_iterative_args_default()}} for a description of the arguments and their default values.} + +\item{kernelSHAP_reweighting}{String. +How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing +the randomness and thereby the variance of the Shapley value estimates. +One of \code{'none'}, \code{'on_coal_size'}, \code{'on_N'}, \code{'on_all'}, \code{'on_all_cond'} (default). +\code{'none'} means no reweighting, i.e. the sampling frequency weights are used as is. +\code{'on_coal_size'} means the sampling frequencies are averaged over all coalitions of the same size. +\code{'on_N'} means the sampling frequencies are averaged over all coalitions with the same original sampling +probabilities. +\code{'on_all'} means the original sampling probabilities are used for all coalitions. +\code{'on_all_cond'} means the original sampling probabilities are used for all coalitions, while adjusting for the +probability that they are sampled at least once. +This method is preferred as it has performed the best in simulation studies.} \item{is_python}{Logical. Indicates whether the function is called from the Python wrapper. Default is FALSE which is never changed when calling the function via \code{explain()} in R. The parameter is later used to disallow running the AICc-versions of the empirical as that requires data based optimization.} +\item{testing}{Logical. +Only used to remove random components like timing from the object output when comparing output with testthat. +Defaults to \code{FALSE}.} + +\item{init_time}{POSIXct object. +The time when the \code{explain()} function was called, as outputted by \code{Sys.time()}. 
+Used to calculate the time it took to run the full \code{explain} call.} + +\item{prev_shapr_object}{\code{shapr} object or string. +If an object of class \code{shapr} is provided or a string with a path to where intermediate results are stored, +then the function will use the previous object to continue the computation. +This is useful if the computation is interrupted or you want higher accuracy than already obtained, and therefore +want to continue the iterative estimation. See the vignette for examples.} + +\item{asymmetric}{Logical. +Not applicable for (regular) non-causal or asymmetric explanations. +If \code{FALSE} (default), \code{explain} computes regular symmetric Shapley values. +If \code{TRUE}, then \code{explain} computes asymmetric Shapley values based on the (partial) causal ordering +given by \code{causal_ordering}. That is, \code{explain} only uses the feature combinations/coalitions that +respect the causal ordering when computing the asymmetric Shapley values. If \code{asymmetric} is \code{TRUE} and +\code{confounding} is \code{NULL} (default), then \code{explain} computes asymmetric conditional Shapley values as specified in +Frye et al. (2020). If \code{confounding} is provided, i.e., not \code{NULL}, then \code{explain} computes asymmetric causal +Shapley values as specified in Heskes et al. (2020).} + +\item{causal_ordering}{List. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{causal_ordering} is an unnamed list of vectors specifying the components of the +partial causal ordering that the coalitions must respect. Each vector represents +a component and contains one or more features/groups identified by their names +(strings) or indices (integers). If \code{causal_ordering} is \code{NULL} (default), no causal +ordering is assumed and all possible coalitions are allowed. 
No causal ordering is +equivalent to a causal ordering with a single component that includes all features +(\code{list(1:n_features)}) or groups (\code{list(1:n_groups)}) for feature-wise and group-wise +Shapley values, respectively. For feature-wise Shapley values and +\code{causal_ordering = list(c(1, 2), c(3, 4))}, the interpretation is that features 1 and 2 +are the ancestors of features 3 and 4, while features 3 and 4 are on the same level. +Note: All features/groups must be included in the \code{causal_ordering} without any duplicates.} + +\item{confounding}{Logical vector. +Not applicable for (regular) non-causal or asymmetric explanations. +\code{confounding} is a vector of logicals specifying whether confounding is assumed or not for each component in the +\code{causal_ordering}. If \code{NULL} (default), then no assumption about the confounding structure is made and \code{explain} +computes asymmetric/symmetric conditional Shapley values, depending on the value of \code{asymmetric}. +If \code{confounding} is a single logical, i.e., \code{FALSE} or \code{TRUE}, then this assumption is set globally +for all components in the causal ordering. Otherwise, \code{confounding} must be a vector of logicals of the same +length as \code{causal_ordering}, indicating the confounding assumption for each component. When \code{confounding} is +specified, then \code{explain} computes asymmetric/symmetric causal Shapley values, depending on the value of +\code{asymmetric}. The \code{approach} cannot be \code{regression_separate} or \code{regression_surrogate} as the +regression-based approaches are not applicable to the causal Shapley value methodology.} + +\item{output_args}{Named list. +Specifies certain arguments related to the output of the function. +See \code{\link[=get_output_args_default]{get_output_args_default()}} for description of the arguments and their default values.} + +\item{extra_computation_args}{Named list.
+Specifices extra arguments related to the computation of the Shapley values. +See \code{\link[=get_extra_est_args_default]{get_extra_est_args_default()}} for description of the arguments and their default values.} + \item{...}{Further arguments passed to specific approaches} } \description{ diff --git a/man/setup_approach.Rd b/man/setup_approach.Rd index cf1ee8d0df999d8e4f86164e5da4e77ebc31a686..e7781040faf3d0d746dcd711439e4da1cc5ffa8a 100644 --- a/man/setup_approach.Rd +++ b/man/setup_approach.Rd @@ -71,7 +71,8 @@ setup_approach(internal, ...) regression.tune_values = NULL, regression.vfold_cv_para = NULL, regression.recipe_func = NULL, - regression.surrogate_n_comb = internal$parameters$used_n_combinations - 2, + regression.surrogate_n_comb = + internal$iter_list[[length(internal$iter_list)]]$n_coalitions - 2, ... ) @@ -96,7 +97,8 @@ setup_approach(internal, ...) ) } \arguments{ -\item{internal}{Not used.} +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} \item{...}{\code{approach}-specific arguments. See below.} @@ -107,7 +109,7 @@ values. \item{categorical.epsilon}{Numeric value. (Optional) If \code{joint_probability_dt} is not supplied, probabilities/frequencies are -estimated using \code{x_train}. If certain observations occur in \code{x_train} and NOT in \code{x_explain}, +estimated using \code{x_train}. If certain observations occur in \code{x_explain} and NOT in \code{x_train}, then epsilon is used as the proportion of times that these observations occurs in the training data. In theory, this proportion should be zero, but this causes an error later in the Shapley computation.} @@ -123,13 +125,13 @@ Determines minimum value that the sum of the left and right daughter nodes requi Determines the minimum sum of weights in a terminal node required for a split} \item{ctree.sample}{Boolean. 
(default = TRUE) -If TRUE, then the method always samples \code{n_samples} observations from the leaf nodes (with replacement). -If FALSE and the number of observations in the leaf node is less than \code{n_samples}, +If TRUE, then the method always samples \code{n_MC_samples} observations from the leaf nodes (with replacement). +If FALSE and the number of observations in the leaf node is less than \code{n_MC_samples}, the method will take all observations in the leaf. -If FALSE and the number of observations in the leaf node is more than \code{n_samples}, -the method will sample \code{n_samples} observations (with replacement). +If FALSE and the number of observations in the leaf node is more than \code{n_MC_samples}, +the method will sample \code{n_MC_samples} observations (with replacement). This means that there will always be sampling in the leaf unless -\code{sample} = FALSE AND the number of obs in the node is less than \code{n_samples}.} +\code{sample} = FALSE AND the number of obs in the node is less than \code{n_MC_samples}.} \item{empirical.type}{Character. (default = \code{"fixed_sigma"}) Should be equal to either \code{"independence"},\code{"fixed_sigma"}, \code{"AICc_each_k"} \code{"AICc_full"}. @@ -143,7 +145,7 @@ accounts for 80\\% of the total weight. \code{eta} is the \eqn{\eta} parameter in equation (15) of Aas et al (2021).} \item{empirical.fixed_sigma}{Positive numeric scalar. (default = 0.1) -Represents the kernel bandwidth in the distance computation used when conditioning on all different combinations. +Represents the kernel bandwidth in the distance computation used when conditioning on all different coalitions. Only used when \code{empirical.type = "fixed_sigma"}} \item{empirical.n_samples_aicc}{Positive integer. (default = 1000) @@ -189,8 +191,8 @@ is also a valid input. It is essential to include the package prefix if the pack The data.frame must contain the possible hyperparameter value combinations to try. 
The column names must match the names of the tuneable parameters specified in \code{regression.model}. If \code{regression.tune_values} is a function, then it should take one argument \code{x} which is the training data -for the current combination/coalition and returns a data.frame/data.table/tibble with the properties described above. -Using a function allows the hyperparameter values to change based on the size of the combination. See the regression +for the current coalition and returns a data.frame/data.table/tibble with the properties described above. +Using a function allows the hyperparameter values to change based on the size of the coalition. See the regression vignette for several examples. Note, to make it easier to call \code{explain()} from Python, the \code{regression.tune_values} can also be a string containing an R function. For example, @@ -209,13 +211,17 @@ containing an R function. For example, \code{"function(recipe) return(recipes::step_ns(recipe, recipes::all_numeric_predictors(), deg_free = 2))"} is also a valid input. It is essential to include the package prefix if the package is not loaded.} -\item{regression.surrogate_n_comb}{Integer (default is \code{internal$parameters$used_n_combinations}) specifying the -number of unique combinations/coalitions to apply to each training observation. Maximum allowed value is -"\code{internal$parameters$used_n_combinations} - 2". By default, we use all coalitions, but this can take a lot of memory -in larger dimensions. Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. This will be all -\eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in the exact mode. If the -user sets a lower value than \code{internal$parameters$used_n_combinations}, then we sample this amount of unique -coalitions separately for each training observations.
That is, on average, all coalitions should be equally trained.} +\item{regression.surrogate_n_comb}{Integer. +(default is \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}) specifying the +number of unique coalitions to apply to each training observation. Maximum allowed value is +"\code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions} - 2". +By default, we use all coalitions, but this can take a lot of memory in larger dimensions. +Note that by "all", we mean all coalitions chosen by \code{shapr} to be used. +This will be all \eqn{2^{n_{\text{features}}}} coalitions (minus empty and grand coalition) if \code{shapr} is in +the exact mode. +If the user sets a lower value than \code{internal$iter_list[[length(internal$iter_list)]]$n_coalitions}, +then we sample this amount of unique coalitions separately for each training observations. +That is, on average, all coalitions should be equally trained.} \item{timeseries.fixed_sigma_vec}{Numeric. (Default = 2) Represents the kernel bandwidth in the distance computation. TODO: What length should it have? 1?} diff --git a/man/setup_computation.Rd b/man/setup_computation.Rd index f731787e5a16c380313e2684ff2d2a1852099121..afd255e0029b5d91b78725a55e5c671c0a0ef8f0 100644 --- a/man/setup_computation.Rd +++ b/man/setup_computation.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R +% Please edit documentation in R/shapley_setup.R \name{setup_computation} \alias{setup_computation} \title{Sets up everything for the Shapley values computation in \code{\link[=explain]{explain()}}} @@ -9,7 +9,8 @@ setup_computation(internal, model, predict_model) \arguments{ \item{internal}{List. 
Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} \item{model}{Objects. The model object that ought to be explained. diff --git a/man/shapley_setup.Rd b/man/shapley_setup.Rd new file mode 100644 index 0000000000000000000000000000000000000000..0b96d78717be848e44f7fe10a8f13d14bb0464e7 --- /dev/null +++ b/man/shapley_setup.Rd @@ -0,0 +1,16 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/shapley_setup.R +\name{shapley_setup} +\alias{shapley_setup} +\title{Set up the kernelSHAP framework} +\usage{ +shapley_setup(internal) +} +\arguments{ +\item{internal}{List. +Not used directly, but passed through from \code{\link[=explain]{explain()}}.} +} +\description{ +Set up the kernelSHAP framework +} +\keyword{internal} diff --git a/man/shapley_weights.Rd b/man/shapley_weights.Rd index 109e68de39fcf41ab1e1f53fe68c8c77270a12ca..572955a88cb76b5bc180e69ed3b381224dabe4ff 100644 --- a/man/shapley_weights.Rd +++ b/man/shapley_weights.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R +% Please edit documentation in R/shapley_setup.R \name{shapley_weights} \alias{shapley_weights} \title{Calculate Shapley weight} @@ -9,7 +9,7 @@ shapley_weights(m, N, n_components, weight_zero_m = 10^6) \arguments{ \item{m}{Positive integer. Total number of features/feature groups.} -\item{N}{Positive integer. The number of unique combinations when sampling \code{n_components} features/feature +\item{N}{Positive integer. 
The number of unique coalitions when sampling \code{n_components} features/feature groups, without replacement, from a sample space consisting of \code{m} different features/feature groups.} \item{n_components}{Positive integer. Represents the number of features/feature groups you want to sample from diff --git a/man/shapr-package.Rd b/man/shapr-package.Rd index 1041460afa615323af01fe9b349ce52c87b22c2c..825c92be00247dda8f64977dda45de70fba57b0f 100644 --- a/man/shapr-package.Rd +++ b/man/shapr-package.Rd @@ -6,7 +6,7 @@ \alias{shapr-package} \title{shapr: Prediction Explanation with Dependence-Aware Shapley Values} \description{ -Complex machine learning models are often hard to interpret. However, in many situations it is crucial to understand and explain why a model made a specific prediction. Shapley values is the only method for such prediction explanation framework with a solid theoretical foundation. Previously known methods for estimating the Shapley values do, however, assume feature independence. This package implements the method described in Aas, Jullum and Løland (2019) \href{https://arxiv.org/abs/1903.10464}{arXiv:1903.10464}, which accounts for any feature dependence, and thereby produces more accurate estimates of the true Shapley values. An accompanying Python wrapper (shaprpy) is available on GitHub. +Complex machine learning models are often hard to interpret. However, in many situations it is crucial to understand and explain why a model made a specific prediction. Shapley values is the only method for such prediction explanation framework with a solid theoretical foundation. Previously known methods for estimating the Shapley values do, however, assume feature independence. This package implements methods which accounts for any feature dependence, and thereby produces more accurate estimates of the true Shapley values. An accompanying Python wrapper (shaprpy) is available on GitHub. 
} \seealso{ Useful links: @@ -22,10 +22,10 @@ Useful links: Authors: \itemize{ - \item Nikolai Sellereite \email{nikolaisellereite@gmail.com} (\href{https://orcid.org/0000-0002-4671-0337}{ORCID}) \item Lars Henry Berge Olsen \email{lholsen@math.uio.no} (\href{https://orcid.org/0009-0006-9360-6993}{ORCID}) \item Annabelle Redelmeier \email{Annabelle.Redelmeier@nr.no} - \item Jon Lachmann \email{Jon@lachmann.nu} + \item Jon Lachmann \email{Jon@lachmann.nu} (\href{https://orcid.org/0000-0001-8396-5673}{ORCID}) + \item Nikolai Sellereite \email{nikolaisellereite@gmail.com} (\href{https://orcid.org/0000-0002-4671-0337}{ORCID}) } Other contributors: diff --git a/man/test_predict_model.Rd b/man/test_predict_model.Rd index f428150e09ab4085e92f5a0632118976388b12ae..b43d1f6ec6a4969f427a99659754cd84aef02ba0 100644 --- a/man/test_predict_model.Rd +++ b/man/test_predict_model.Rd @@ -17,7 +17,8 @@ See the documentation of \code{\link[=explain]{explain()}} for details.} \item{internal}{List. Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} +The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{iter_list}, \code{timing_list}, +\code{main_timing_list}, \code{output}, and \code{iter_timing_list}.} } \description{ Model testing function diff --git a/man/testing_cleanup.Rd b/man/testing_cleanup.Rd new file mode 100644 index 0000000000000000000000000000000000000000..3c590807f727119c149a55c23c0e1386aa437c7f --- /dev/null +++ b/man/testing_cleanup.Rd @@ -0,0 +1,15 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/explain.R +\name{testing_cleanup} +\alias{testing_cleanup} +\title{Cleans out certain output arguments to allow perfect reproducibility of the output} +\usage{ +testing_cleanup(output) +} +\description{ +Cleans out certain output
arguments to allow perfect reproducibility of the output +} +\author{ +Lars Henry Berge Olsen, Martin Jullum +} +\keyword{internal} diff --git a/man/vaeac_check_mask_gen.Rd b/man/vaeac_check_mask_gen.Rd index 89b9af1dbb1e202220007396aac4e580d1fba94c..92cfa921fe4cd6fdc99aac010c0441682dee1535 100644 --- a/man/vaeac_check_mask_gen.Rd +++ b/man/vaeac_check_mask_gen.Rd @@ -9,8 +9,8 @@ vaeac_check_mask_gen(mask_gen_coalitions, mask_gen_coalitions_prob, x_train) \arguments{ \item{mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the \code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally -in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when -\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +in \code{shapr} when we only consider a subset of coalitions, i.e., when +\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., when \code{group} is specified in \code{\link[=explain]{explain()}}.} \item{mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height diff --git a/man/vaeac_check_parameters.Rd b/man/vaeac_check_parameters.Rd index faeb6b8c83260ff0a5e91f13fa4a299e479dcf11..70b539ee60aa2751cdb5feb1cd2bfe0f2b85eba0 100644 --- a/man/vaeac_check_parameters.Rd +++ b/man/vaeac_check_parameters.Rd @@ -130,8 +130,8 @@ model can do arbitrary conditioning as all coalitions will be trained. \code{mas \item{mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the \code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}.
This parameter is used internally -in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when -\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +in \code{shapr} when we only consider a subset of coalitions, i.e., when +\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., when \code{group} is specified in \code{\link[=explain]{explain()}}.} \item{mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height @@ -163,8 +163,22 @@ Note that additional choices are available if \code{vaeac.save_every_nth_epoch} \code{vaeac.save_every_nth_epoch = 5}, then \code{vaeac.which_vaeac_model} can also take the values \code{"epoch_5"}, \code{"epoch_10"}, \code{"epoch_15"}, and so on.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. +#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. 
\code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} \item{seed}{Positive integer (default is \code{1}). Seed for reproducibility. Specifies the seed before any randomness based code is being run.} diff --git a/man/vaeac_check_verbose.Rd b/man/vaeac_check_verbose.Rd deleted file mode 100644 index 73ab85049f7a91ce47dd64196f752d20453a692c..0000000000000000000000000000000000000000 --- a/man/vaeac_check_verbose.Rd +++ /dev/null @@ -1,22 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/approach_vaeac.R -\name{vaeac_check_verbose} -\alias{vaeac_check_verbose} -\title{Function that checks the verbose parameter} -\usage{ -vaeac_check_verbose(verbose) -} -\arguments{ -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} -} -\value{ -The function does not return anything. -} -\description{ -Function that checks the verbose parameter -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/vaeac_get_extra_para_default.Rd b/man/vaeac_get_extra_para_default.Rd index f2229c3b33883b336d73621c255f761da246a3ec..8b54f75f942c85faf5f9f451088d5e76570b1b1c 100644 --- a/man/vaeac_get_extra_para_default.Rd +++ b/man/vaeac_get_extra_para_default.Rd @@ -78,9 +78,10 @@ during the training of the vaeac model. Used in \code{\link[torch:dataloader]{to \item{vaeac.batch_size_sampling}{Positive integer (default is \code{NULL}) The number of samples to include in each batch when generating the Monte Carlo samples. If \code{NULL}, then the function generates the Monte Carlo samples -for the provided coalitions/combinations and all explicands sent to \code{\link[=explain]{explain()}} at the time. -The number of coalitions are determined by \code{n_batches} in \code{\link[=explain]{explain()}}. 
We recommend to tweak \code{n_batches} -rather than \code{vaeac.batch_size_sampling}. Larger batch sizes are often much faster provided sufficient memory.} +for the provided coalitions and all explicands sent to \code{\link[=explain]{explain()}} at the time. +The number of coalitions is determined by the \code{n_batches} used by \code{\link[=explain]{explain()}}. We recommend tweaking +\code{extra_computation_args$max_batch_size} and \code{extra_computation_args$min_n_batches} +rather than \code{vaeac.batch_size_sampling}. Larger batch sizes are often much faster provided sufficient memory.} \item{vaeac.running_avg_n_values}{Positive integer (default is \code{5}). The number of previous IWAE values to include when we compute the running means of the IWAE criterion.} @@ -112,8 +113,8 @@ model can do arbitrary conditioning as all coalitions will be trained. \code{vae \item{vaeac.mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the \code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally -in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when -\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +in \code{shapr} when we only consider a subset of coalitions, i.e., when +\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., when \code{group} is specified in \code{\link[=explain]{explain()}}.} \item{vaeac.mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}).
Array of length equal to the height diff --git a/man/vaeac_get_mask_generator_name.Rd b/man/vaeac_get_mask_generator_name.Rd index 8ea86c356671a32cd7bdcbf0ce9d6c52453e7704..00601f6d7b172f0b0a67f19fe5e65eee9c98f651 100644 --- a/man/vaeac_get_mask_generator_name.Rd +++ b/man/vaeac_get_mask_generator_name.Rd @@ -14,8 +14,8 @@ vaeac_get_mask_generator_name( \arguments{ \item{mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the \code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally -in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when -\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +in \code{shapr} when we only consider a subset of coalitions, i.e., when +\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., when \code{group} is specified in \code{\link[=explain]{explain()}}.} \item{mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height @@ -27,8 +27,22 @@ of \code{mask_gen_coalitions} containing the probabilities of sampling the corre model can do arbitrary conditioning as all coalitions will be trained. \code{masking_ratio} will be overruled if \code{mask_gen_coalitions} is specified.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. 
+#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} } \value{ The function does not return anything. diff --git a/man/vaeac_get_x_explain_extended.Rd b/man/vaeac_get_x_explain_extended.Rd index 91b76a56b1b5a7611da7000d882d0fd1d0b829f4..7f9bb1a107228d7aed7609c47d0629498798283c 100644 --- a/man/vaeac_get_x_explain_extended.Rd +++ b/man/vaeac_get_x_explain_extended.Rd @@ -12,8 +12,8 @@ Contains the the features, whose predictions ought to be explained.} \item{S}{The \code{internal$objects$S} matrix containing the possible coalitions.} -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. Only used internally.} +\item{index_features}{Positive integer vector. Specifies the id_coalition to +apply to the present method. \code{NULL} means all coalitions. Only used internally.} } \value{ The extended version of \code{x_explain} where the masks from \code{S} with indices \code{index_features} have been applied. 
diff --git a/man/vaeac_impute_missing_entries.Rd b/man/vaeac_impute_missing_entries.Rd index a3dda74f4df602e60cac36857c5df28cd1da304f..e1f36ce8357157f41b7f57025e2843c7a43beab3 100644 --- a/man/vaeac_impute_missing_entries.Rd +++ b/man/vaeac_impute_missing_entries.Rd @@ -6,12 +6,12 @@ \usage{ vaeac_impute_missing_entries( x_explain_with_NaNs, - n_samples, + n_MC_samples, vaeac_model, checkpoint, sampler, batch_size, - verbose = 0, + verbose = NULL, seed = NULL, n_explain = NULL, index_features = NULL @@ -20,7 +20,7 @@ vaeac_impute_missing_entries( \arguments{ \item{x_explain_with_NaNs}{A 2D matrix, where the missing entries to impute are represented by \code{NaN}.} -\item{n_samples}{Integer. The number of imputed versions we create for each row in \code{x_explain_with_NaNs}.} +\item{n_MC_samples}{Integer. The number of imputed versions we create for each row in \code{x_explain_with_NaNs}.} \item{vaeac_model}{An initialized \code{vaeac} model that we are going to use to generate the MC samples.} @@ -31,8 +31,22 @@ vaeac_impute_missing_entries( \item{batch_size}{Positive integer (default is \code{64}). The number of samples to include in each batch during the training of the vaeac model. Used in \code{\link[torch:dataloader]{torch::dataloader()}}.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. 
+#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} \item{seed}{Positive integer (default is \code{1}). Seed for reproducibility. Specifies the seed before any randomness based code is being run.} @@ -42,7 +56,8 @@ based code is being run.} \item{index_features}{Optional integer vector. Used internally in shapr package to index the coalitions.} } \value{ -A data.table where the missing values (\code{NaN}) in \code{x_explain_with_NaNs} have been imputed \code{n_samples} times. +A data.table where the missing values (\code{NaN}) in \code{x_explain_with_NaNs} have been imputed \code{n_MC_samples} +times. The data table will contain extra id columns if \code{index_features} and \code{n_explain} are provided. } \description{ diff --git a/man/vaeac_plot_eval_crit.Rd b/man/vaeac_plot_eval_crit.Rd index c94895d1b89488dcdf631598caa251d9225b993f..fc8e2865b256b0c08b02efdbb4d92d42338c5f7b 100644 --- a/man/vaeac_plot_eval_crit.Rd +++ b/man/vaeac_plot_eval_crit.Rd @@ -79,8 +79,8 @@ explanation_paired <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, - n_samples = 1, # As we are only interested in the training of the vaeac + phi0 = p0, + n_MC_samples = 1, # As we are only interested in the training of the vaeac vaeac.epochs = 10, # Should be higher in applications. 
vaeac.n_vaeacs_initialize = 1, vaeac.width = 16, @@ -93,8 +93,8 @@ explanation_regular <- explain( x_explain = x_explain, x_train = x_train, approach = approach, - prediction_zero = p0, - n_samples = 1, # As we are only interested in the training of the vaeac + phi0 = p0, + n_MC_samples = 1, # As we are only interested in the training of the vaeac vaeac.epochs = 10, # Should be higher in applications. vaeac.width = 16, vaeac.depth = 2, diff --git a/man/vaeac_plot_imputed_ggpairs.Rd b/man/vaeac_plot_imputed_ggpairs.Rd index b667281f6dad38624b34dd2ff924d05afcdb6182..6b4b1a75b1e4437e49e1bed26e5fbf5d0da2415d 100644 --- a/man/vaeac_plot_imputed_ggpairs.Rd +++ b/man/vaeac_plot_imputed_ggpairs.Rd @@ -108,8 +108,8 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = mean(y_train), - n_samples = 1, + phi0 = mean(y_train), + n_MC_samples = 1, vaeac.epochs = 10, vaeac.n_vaeacs_initialize = 1 ) diff --git a/man/vaeac_prep_message_batch.Rd b/man/vaeac_prep_message_batch.Rd deleted file mode 100644 index 7dd4d773ae67f7188b3a446772bb5a188f59c086..0000000000000000000000000000000000000000 --- a/man/vaeac_prep_message_batch.Rd +++ /dev/null @@ -1,23 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/approach_vaeac.R -\name{vaeac_prep_message_batch} -\alias{vaeac_prep_message_batch} -\title{Produce message about which batch prepare_data is working on} -\usage{ -vaeac_prep_message_batch(internal, index_features) -} -\arguments{ -\item{internal}{List. -Holds all parameters, data, functions and computed objects used within \code{\link[=explain]{explain()}} -The list contains one or more of the elements \code{parameters}, \code{data}, \code{objects}, \code{output}.} - -\item{index_features}{Positive integer vector. Specifies the indices of combinations to -apply to the present method. \code{NULL} means all combinations. 
Only used internally.} -} -\description{ -Produce message about which batch prepare_data is working on -} -\author{ -Lars Henry Berge Olsen -} -\keyword{internal} diff --git a/man/vaeac_train_model.Rd b/man/vaeac_train_model.Rd index f21fbb6f8979290e9ecc1e67fa2f86ad92cd710c..4d1314f5ed0f7ceb026bd34ab1d7b5566ae94d9f 100644 --- a/man/vaeac_train_model.Rd +++ b/man/vaeac_train_model.Rd @@ -130,8 +130,8 @@ model can do arbitrary conditioning as all coalitions will be trained. \code{mas \item{mask_gen_coalitions}{Matrix (default is \code{NULL}). Matrix containing the coalitions that the \code{vaeac} model will be trained on, see \code{\link[=specified_masks_mask_generator]{specified_masks_mask_generator()}}. This parameter is used internally -in \code{shapr} when we only consider a subset of coalitions/combinations, i.e., when -\code{n_combinations} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., +in \code{shapr} when we only consider a subset of coalitions, i.e., when +\code{n_coalitions} \eqn{< 2^{n_{\text{features}}}}, and for group Shapley, i.e., when \code{group} is specified in \code{\link[=explain]{explain()}}.} \item{mask_gen_coalitions_prob}{Numeric array (default is \code{NULL}). Array of length equal to the height @@ -163,8 +163,22 @@ Note that additional choices are available if \code{vaeac.save_every_nth_epoch} \code{vaeac.save_every_nth_epoch = 5}, then \code{vaeac.which_vaeac_model} can also take the values \code{"epoch_5"}, \code{"epoch_10"}, \code{"epoch_15"}, and so on.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. 
+\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. +#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} \item{seed}{Positive integer (default is \code{1}). Seed for reproducibility. Specifies the seed before any randomness based code is being run.} diff --git a/man/vaeac_train_model_auxiliary.Rd b/man/vaeac_train_model_auxiliary.Rd index 8aec551543cba307197f869d260668f4615f6b18..65f1fb617c22813ac19148113200b4b83543a898 100644 --- a/man/vaeac_train_model_auxiliary.Rd +++ b/man/vaeac_train_model_auxiliary.Rd @@ -43,8 +43,22 @@ to compute the IWAE criterion when validating the vaeac model on the validation The number of previous IWAE values to include when we compute the running means of the IWAE criterion.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. 
+\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. +#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} \item{cuda}{Logical (default is \code{FALSE}). If \code{TRUE}, then the \code{vaeac} model will be trained using cuda/GPU. If \code{\link[torch:cuda_is_available]{torch::cuda_is_available()}} is \code{FALSE}, the we fall back to use CPU. If \code{FALSE}, we use the CPU. Using a GPU diff --git a/man/vaeac_train_model_continue.Rd b/man/vaeac_train_model_continue.Rd index 36b946e7397cc62f805bfa4d3038e7d419bd20cd..552c561c79da80dbbac1d064c738a53a5c7d6441 100644 --- a/man/vaeac_train_model_continue.Rd +++ b/man/vaeac_train_model_continue.Rd @@ -10,7 +10,7 @@ vaeac_train_model_continue( lr_new = NULL, x_train = NULL, save_data = FALSE, - verbose = 0, + verbose = NULL, seed = 1 ) } @@ -26,8 +26,22 @@ vaeac_train_model_continue( \item{save_data}{Logical (default is \code{FALSE}). If \code{TRUE}, then the data is stored together with the model. Useful if one are to continue to train the model later using \code{\link[=vaeac_train_model_continue]{vaeac_train_model_continue()}}.} -\item{verbose}{Boolean. An integer specifying the level of verbosity. 
Use \code{0} (default) for no verbosity, -\code{1} for low verbose, and \code{2} for high verbose.} +\item{verbose}{String vector or NULL. +Specifies the verbosity (printout detail level) through one or more of strings \code{"basic"}, \code{"progress"}, +\code{"convergence"}, \code{"shapley"} and \code{"vS_details"}. +\code{"basic"} (default) displays basic information about the computation which is being performed. +\verb{"progress} displays information about where in the calculation process the function currently is. +#' \code{"convergence"} displays information on how close to convergence the Shapley value estimates are +(only when \code{iterative = TRUE}) . +\code{"shapley"} displays intermediate Shapley value estimates and standard deviations (only when \code{iterative = TRUE}) +\itemize{ +\item the final estimates. +\code{"vS_details"} displays information about the v_S estimates. +This is most relevant for \verb{approach \%in\% c("regression_separate", "regression_surrogate", "vaeac"}). +\code{NULL} means no printout. +Note that any combination of four strings can be used. +E.g. \code{verbose = c("basic", "vS_details")} will display basic information + details about the vS estimation process. +}} \item{seed}{Positive integer (default is \code{1}). Seed for reproducibility. 
Specifies the seed before any randomness based code is being run.} diff --git a/man/weight_matrix.Rd b/man/weight_matrix.Rd index 734160661809f10368a8b4fece31fa5b093275e6..043a46c4aefd433bcefef8929b2844c399f8c445 100644 --- a/man/weight_matrix.Rd +++ b/man/weight_matrix.Rd @@ -1,19 +1,17 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/setup_computation.R +% Please edit documentation in R/shapley_setup.R \name{weight_matrix} \alias{weight_matrix} \title{Calculate weighted matrix} \usage{ -weight_matrix(X, normalize_W_weights = TRUE, is_groupwise = FALSE) +weight_matrix(X, normalize_W_weights = TRUE) } \arguments{ \item{X}{data.table} -\item{normalize_W_weights}{Logical. Whether to normalize the weights for the combinations to sum to 1 for -increased numerical stability before solving the WLS (weighted least squares). Applies to all combinations -except combination \code{1} and \code{2^m}.} - -\item{is_groupwise}{Logical. Indicating whether group wise Shapley values are to be computed.} +\item{normalize_W_weights}{Logical. Whether to normalize the weights for the coalitions to sum to 1 for +increased numerical stability before solving the WLS (weighted least squares). Applies to all coalitions +except coalition \code{1} and \code{2^m}.} } \value{ Numeric matrix. See \code{\link[=weight_matrix_cpp]{weight_matrix_cpp()}} for more information. diff --git a/man/weight_matrix_cpp.Rd b/man/weight_matrix_cpp.Rd index 054764afe8d3ebc4c6cf70e54050cd34dedcb189..0a6505b9f2c795ef79b3829604ba2e646bbdf7dd 100644 --- a/man/weight_matrix_cpp.Rd +++ b/man/weight_matrix_cpp.Rd @@ -4,10 +4,10 @@ \alias{weight_matrix_cpp} \title{Calculate weight matrix} \usage{ -weight_matrix_cpp(subsets, m, n, w) +weight_matrix_cpp(coalitions, m, n, w) } \arguments{ -\item{subsets}{List. Each of the elements equals an integer +\item{coalitions}{List. 
Each of the elements equals an integer vector representing a valid combination of features/feature groups.} \item{m}{Integer. Number of features/feature groups} @@ -16,7 +16,7 @@ vector representing a valid combination of features/feature groups.} \item{w}{Numeric vector of length \code{n}, i.e. \code{w[i]} equals the Shapley weight of feature/feature group combination \code{i}, represented by -\code{subsets[[i]]}.} +\code{coalitions[[i]]}.} } \value{ Matrix of dimension n x m + 1 @@ -25,6 +25,6 @@ Matrix of dimension n x m + 1 Calculate weight matrix } \author{ -Nikolai Sellereite +Nikolai Sellereite, Martin Jullum } \keyword{internal} diff --git a/python/README.md b/python/README.md index b010fec77e1dae2af703a66391f26a7c32d29367..512ce3c39ccb71266216f198ce4d41ca19312b93 100644 --- a/python/README.md +++ b/python/README.md @@ -51,7 +51,7 @@ df_shapley, pred_explain, internal, timing = explain( x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) ``` diff --git a/python/examples/code_paper/code_sec_5.py b/python/examples/code_paper/code_sec_5.py new file mode 100644 index 0000000000000000000000000000000000000000..6244f592927fc86e07057c2e4f8eed0f2fdc31bd --- /dev/null +++ b/python/examples/code_paper/code_sec_5.py @@ -0,0 +1,33 @@ +import xgboost as xgb +import pandas as pd +from shaprpy import explain + +path = "inst/code_paper/" + +# Read data +x_train = pd.read_csv(path + "x_train.csv") +x_explain = pd.read_csv(path + "x_explain.csv") +y_train = pd.read_csv(path + "y_train.csv") + +# Load the XGBoost model from the raw format and add feature names +model = xgb.Booster() +model.load_model(path +"xgb.model") +model.feature_names = x_train.columns.tolist() + +exp_20_ctree = explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = 'ctree', + phi0 = y_train.mean().item(), + max_n_coalitions=20, + ctree_sample = False) + + +# 
Print the Shapley values +print(exp_20_ctree['shapley_values_est'].iloc[:, 1:].round(1)) + + + + + diff --git a/python/examples/devel_new_explain.py b/python/examples/devel_new_explain.py new file mode 100644 index 0000000000000000000000000000000000000000..cfa4e5d5257f6fc60e932b1e91be02c0701f7d78 --- /dev/null +++ b/python/examples/devel_new_explain.py @@ -0,0 +1,94 @@ +import xgboost as xgb +import warnings +import numpy as np +import pandas as pd +from typing import Callable +from datetime import datetime +import rpy2.robjects as ro +from rpy2.robjects.packages import importr +from rpy2.rinterface import NULL, NA +from shaprpy.utils import r2py, py2r, recurse_r_tree +from rpy2.robjects.vectors import StrVector, ListVector +from shaprpy import explain +from shaprpy.datasets import load_california_housing + +dfx_train, dfx_test, dfy_train, dfy_test = load_california_housing() + +## Fit model +model = xgb.XGBRegressor() +model.fit(dfx_train, dfy_train.values.flatten()) + +from shaprpy import explain +from shaprpy.utils import r2py, py2r, recurse_r_tree + + +## Shapr +output = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = 'gaussian', + phi0 = dfy_train.mean().item(), + max_n_coalitions=30 +) + +output["shapley_values_est"] + +saving_path + + +shapley_values_est +shapley_values_sd +pred_explain +MSEv +iterative_results["dt_iter_shapley_sd"] +saving_path +rinternal + +recurse_r_tree(rinternal) + + +### Testing different approaches and settings + +shapley_values, shapley_values_sd, pred_explain, MSEv, iterative_results, saving_path, rinternal = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = 'gaussian', + phi0 = dfy_train.mean().item(), + max_n_coalitions=100, + iterative = False +) + +shapley_values, shapley_values_sd, pred_explain, MSEv, iterative_results, saving_path, rinternal = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = ['gaussian', 
'empirical',"gaussian","empirical","gaussian","gaussian","empirical"], + phi0 = dfy_train.mean().item(), + max_n_coalitions=100, + iterative = True, + verbose = ["basic", "progress"] +) + +shapley_values, shapley_values_sd, pred_explain, MSEv, iterative_results, saving_path, rinternal = explain( + model = model, + x_train = dfx_train, + x_explain = dfx_test, + approach = 'vaeac', + phi0 = dfy_train.mean().item(), + max_n_coalitions=100, + iterative = False, + verbose = ["basic", "progress","vS_details","shapley"] +) + + +regtest = explain( + model=model, + x_train=dfx_train, + x_explain=dfx_test, + approach='regression_separate', + phi0=dfy_train.mean().item(), + regression_model='parsnip::linear_reg()' +) + diff --git a/python/examples/keras_classifier.py b/python/examples/keras_classifier.py index d7b31e70ffaef37e3e9f397a65f3825ea3efbb2f..60138165f3272774d45e6eef9e6f6619bb4c81e0 100644 --- a/python/examples/keras_classifier.py +++ b/python/examples/keras_classifier.py @@ -30,7 +30,7 @@ df_shapley, pred_explain, internal, timing, MSEv = explain( x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) @@ -57,4 +57,4 @@ MSEv["MSEv"] """ MSEv MSEv_sd 1 0.000312 0.00014 -""" \ No newline at end of file +""" diff --git a/python/examples/pytorch_custom.py b/python/examples/pytorch_custom.py index eac345337e3d52ea926322021386f1dc05f82e90..d58fa53376b4f6e16b72097cc5922e5904f0a8c5 100644 --- a/python/examples/pytorch_custom.py +++ b/python/examples/pytorch_custom.py @@ -42,7 +42,7 @@ df_shapley, pred_explain, internal, timing, MSEv = explain( x_explain = dfx_test, approach = 'empirical', predict_model = lambda m, x: m(torch.from_numpy(x.values).float()).cpu().detach().numpy(), - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) """ @@ -65,4 +65,4 @@ MSEv["MSEv"] """ MSEv MSEv_sd 1 27.046126 7.253933 -""" \ No newline 
at end of file +""" diff --git a/python/examples/regression_paradigm.py b/python/examples/regression_paradigm.py index c5daab4c4a02fa6a0ed837ec41a6bbf29c86d943..bf53b77fed826543f0e00bf30a439e9954ad467d 100644 --- a/python/examples/regression_paradigm.py +++ b/python/examples/regression_paradigm.py @@ -27,7 +27,7 @@ explanation_list["empirical"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='empirical', - prediction_zero=dfy_train.mean().item() + phi0=dfy_train.mean().item() ) # Explain the model using several separate regression methods @@ -37,7 +37,7 @@ explanation_list["sep_lm"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model='parsnip::linear_reg()' @@ -49,7 +49,7 @@ explanation_list["sep_pca"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model='parsnip::linear_reg()', @@ -64,7 +64,7 @@ explanation_list["sep_splines"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model='parsnip::linear_reg()', @@ -79,7 +79,7 @@ explanation_list["sep_tree_cv"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="parsnip::decision_tree(tree_depth = hardhat::tune(), engine = 'rpart', mode = 'regression')", @@ -93,7 +93,7 @@ explanation_list["sep_xgboost"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="parsnip::boost_tree(engine = 'xgboost', 
mode = 'regression')" @@ -105,7 +105,7 @@ explanation_list["sep_xgboost_cv"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_separate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="parsnip::boost_tree(trees = hardhat::tune(), engine = 'xgboost', mode = 'regression')", @@ -120,7 +120,7 @@ explanation_list["sur_lm"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_surrogate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model='parsnip::linear_reg()' @@ -132,7 +132,7 @@ explanation_list["sur_rf"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_surrogate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="parsnip::rand_forest(engine = 'ranger', mode = 'regression')" @@ -144,7 +144,7 @@ explanation_list["sur_rf_cv"] = explain( x_train=dfx_train, x_explain=dfx_test, approach='regression_surrogate', - prediction_zero=dfy_train.mean().item(), + phi0=dfy_train.mean().item(), verbose=2, n_batches=1, regression_model="""parsnip::rand_forest( @@ -191,4 +191,4 @@ explanation_list["sep_xgboost"][0] 3 0.276002 0.957242 4 0.028560 0.049815 5 -0.242943 0.006815 -""" \ No newline at end of file +""" diff --git a/python/examples/sklearn_classifier.py b/python/examples/sklearn_classifier.py index 418f88016d427f802187c96d6f554f2600c9708a..14e7dc263a24197cc3da06c353d329c8914d1f6b 100644 --- a/python/examples/sklearn_classifier.py +++ b/python/examples/sklearn_classifier.py @@ -14,7 +14,7 @@ df_shapley, pred_explain, internal, timing, MSEv = explain( x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) diff --git a/python/examples/sklearn_regressor.py b/python/examples/sklearn_regressor.py index 
6f7d590673410cb26b22c12719d67e5f72246c0e..3c7e87ac0bbc05cb0e586fa721fcac6ebf9ebd16 100644 --- a/python/examples/sklearn_regressor.py +++ b/python/examples/sklearn_regressor.py @@ -14,7 +14,7 @@ df_shapley, pred_explain, internal, timing, MSEv = explain( x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item() + phi0 = dfy_train.mean().item() ) print(df_shapley) @@ -51,7 +51,7 @@ df_shapley_g, pred_explain_g, internal_g, timing_g, MSEv_g = explain( x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), group = group ) print(df_shapley_g) diff --git a/python/examples/xgboost_booster.py b/python/examples/xgboost_booster.py index b89044344a8450df2454ea53e57e18dede23d6b4..d000ea06ba7750cfef06a562b47754f76aa59262 100644 --- a/python/examples/xgboost_booster.py +++ b/python/examples/xgboost_booster.py @@ -14,7 +14,7 @@ df_shapley, pred_explain, internal, timing, MSEv = explain( x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) diff --git a/python/examples/xgboost_regressor.py b/python/examples/xgboost_regressor.py index 7183a2dd825afa2bd76162bd3301ad977a1c3878..da9a36389e04c89cb13105facba4c4b1dbbf9655 100644 --- a/python/examples/xgboost_regressor.py +++ b/python/examples/xgboost_regressor.py @@ -14,7 +14,7 @@ df_shapley, pred_explain, internal, timing, MSEv = explain( x_train = dfx_train, x_explain = dfx_test, approach = 'empirical', - prediction_zero = dfy_train.mean().item(), + phi0 = dfy_train.mean().item(), ) print(df_shapley) diff --git a/python/install_r_packages.R b/python/install_r_packages.R index 598886243e4b257a539a75fe139cba210c9cc170..71a6fd07120955c1a20aa2b2a5dbef6b6773448d 100644 --- a/python/install_r_packages.R +++ b/python/install_r_packages.R @@ -1,4 +1,4 @@ # Installs the required R-packages 
install.packages("remotes", repos = "https://cloud.r-project.org") -remotes::install_github("NorskRegnesentral/shapr") +remotes::install_github("NorskRegnesentral/shapr", ref = "py_iter") # Installs the development version of shapr from the master branch on CRAN diff --git a/python/shaprpy/explain.py b/python/shaprpy/explain.py index 1e06422273aa564c80c461d2788871ccaf23b134..4ca9c3ad6fe3554d39713cb2e5107b420db4438e 100644 --- a/python/shaprpy/explain.py +++ b/python/shaprpy/explain.py @@ -24,21 +24,27 @@ def explain( x_explain: pd.DataFrame, x_train: pd.DataFrame, approach: str, - prediction_zero: float, - n_combinations: int | None = None, + phi0: float, + iterative: bool | None = None, + max_n_coalitions: int | None = None, group: dict | None = None, - n_samples: int = 1e3, - n_batches: int | None = None, + paired_shap_sampling: bool = True, + n_MC_samples: int = 1e3, + kernelSHAP_reweighting: str = "on_all_cond", seed: int | None = 1, - keep_samp_for_vS: bool = False, + verbose: str = "basic", predict_model: Callable = None, get_model_specs: Callable = None, - MSEv_uniform_comb_weights: bool = True, - timing: bool = True, - verbose: int | None = 0, + asymmetric: bool = False, + causal_ordering: dict | None = None, + confounding: bool | None = None, + extra_computation_args: dict | None = None, + iterative_args: dict | None = None, + output_args: dict | None = None, **kwargs, ): - '''Explain the output of machine learning models with more accurately estimated Shapley values. + """ + Explain the output of machine learning models with more accurately estimated Shapley values. Computes dependence-aware Shapley values for observations in `x_explain` from the specified `model` by using the method specified in `approach` to estimate the conditional expectation. @@ -48,76 +54,83 @@ def explain( model: The model whose predictions we want to explain. `shaprpy` natively supports `sklearn`, `xgboost` and `keras` models. 
Unsupported models can still be explained by passing `predict_model` and (optionally) `get_model_specs`. - x_explain: Contains the features whose predictions ought to be explained. - x_train: Contains the data used to estimate the (conditional) distributions for the features + x_explain: pd.DataFrame + Contains the features whose predictions ought to be explained. + x_train: pd.DataFrame + Contains the data used to estimate the (conditional) distributions for the features needed to properly estimate the conditional expectations in the Shapley formula. - approach: str or list[str] of length `n_features`. - `n_features` equals the total number of features in the model. All elements should, - either be `"gaussian"`, `"copula"`, `"empirical"`, `"ctree"`, `"categorical"`, `"timeseries"`, or `"independence"`. - prediction_zero: The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any + approach: str or list[str] + The method(s) to estimate the conditional expectation. All elements should, + either be `"gaussian"`, `"copula"`, `"empirical"`, `"ctree"`, `"categorical"`, `"timeseries"`, `"independence"`, + `"regression_separate"`, or `"regression_surrogate"`. + phi0: float + The prediction value for unseen data, i.e. an estimate of the expected prediction without conditioning on any features. Typically we set this value equal to the mean of the response variable in our training data, but other choices such as the mean of the predictions in the training data are also reasonable. - n_combinations: If `group = None`, `n_combinations` represents the number of unique feature combinations to sample. - If `group != None`, `n_combinations` represents the number of unique group combinations to sample. - If `n_combinations = None`, the exact method is used and all combinations are considered. - The maximum number of combinations equals `2^m`, where `m` is the number of features. 
- group: If `None` regular feature wise Shapley values are computed. - If a dict is provided, group wise Shapley values are computed. `group` then contains lists of unique feature names with the - features included in each of the different groups. The length of the dict equals the number of groups. - n_samples: Indicating the maximum number of samples to use in the - Monte Carlo integration for every conditional expectation. - n_batches: Specifies how many batches the total number of feature combinations should be split into when calculating the - contribution function for each test observation. - The default value is 1. - Increasing the number of batches may significantly reduce the RAM allocation for models with many features. - This typically comes with a small increase in computation time. - seed: Specifies the seed before any randomness based code is being run. - If `None` the seed will be inherited from the calling environment. - keep_samp_for_vS: Indicates whether the samples used in the Monte Carlo estimation of v_S should be returned (in `internal['output']`) - predict_model: The prediction function used when `model` is not natively supported. - The function must have two arguments, `model` and `newdata` which specify, respectively, the model - and a pandas.DataFrame to compute predictions for. The function must give the prediction as a numpy.Array. - `None` (the default) uses functions specified internally. - Can also be used to override the default function for natively supported model classes. - get_model_specs: An optional function for checking model/data consistency when `model` is not natively supported. - This method has yet to be implemented for keras models. - The function takes `model` as argument and provides a `dict with 3 elements: - - labels: list[str] with the names of each feature. - - classes: list[str] with the classes of each features. - - factor_levels: dict[str, list[str]] with the levels for any categorical features. 
- If `None` (the default) internal functions are used for natively supported model classes, and the checking is - disabled for unsupported model classes. - Can also be used to override the default function for natively supported model classes. - MSEv_uniform_comb_weights: Logical. If `True` (default), then the function weights the combinations - uniformly when computing the MSEv criterion. If `False`, then the function use the Shapley kernel weights to - weight the combinations when computing the MSEv criterion. Note that the Shapley kernel weights are replaced by - the sampling frequency when not all combinations are considered. - timing: Indicates whether the timing of the different parts of the explain call should be saved and returned. - verbose: An integer specifying the level of verbosity. If `0` (default), `shapr` will stay silent. - If `1`, it will print information about performance. If `2`, some additional information will be printed out. - kwargs: Further arguments passed to specific approaches. See R-documentation of the function - `explain_tripledot_docs` for more information about the approach specific arguments - (https://norskregnesentral.github.io/shapr/reference/explain_tripledot_docs.html). Note that the parameters - in R are called 'approach.parameter_name', but in Python the equivalent would be 'approach_parameter_name'. + iterative: bool or None, optional + If `None` (default), the argument is set to `True` if there are more than 5 features/groups, and `False` otherwise. + If `True`, the Shapley values are estimated iteratively in an iterative manner. + max_n_coalitions: int or None, optional + The upper limit on the number of unique feature/group coalitions to use in the iterative procedure + (if `iterative = True`). If `iterative = False` it represents the number of feature/group coalitions to use directly. + `max_n_coalitions = None` corresponds to `max_n_coalitions=2^n_features`. 
+ group: dict or None, optional + If `None` regular feature wise Shapley values are computed. + If provided, group wise Shapley values are computed. `group` then contains lists of unique feature names with the + features included in each of the different groups. + paired_shap_sampling: bool, optional + If `True` (default), paired versions of all sampled coalitions are also included in the computation. + n_MC_samples: int, optional + Indicating the maximum number of samples to use in the Monte Carlo integration for every conditional expectation. + kernelSHAP_reweighting: str, optional + How to reweight the sampling frequency weights in the kernelSHAP solution after sampling, with the aim of reducing + the randomness and thereby the variance of the Shapley value estimates. One of `'none'`, `'on_N'`, `'on_all'`, + `'on_all_cond'` (default). + seed: int or None, optional + Specifies the seed before any randomness based code is being run. If `None` the seed will be inherited from the calling environment. + verbose: str or list[str], optional + Specifies the verbosity (printout detail level) through one or more of strings `"basic"`, `"progress"`, + `"convergence"`, `"shapley"` and `"vS_details"`. `None` means no printout. + predict_model: Callable, optional + The prediction function used when `model` is not natively supported. The function must have two arguments, `model` and `newdata` + which specify, respectively, the model and a pandas.DataFrame to compute predictions for. The function must give the prediction as a numpy.Array. + get_model_specs: Callable, optional + An optional function for checking model/data consistency when `model` is not natively supported. The function takes `model` as argument + and provides a `dict` with 3 elements: `labels`, `classes`, and `factor_levels`. + asymmetric: bool, optional + If `False` (default), `explain` computes regular symmetric Shapley values. 
If `True`, then `explain` computes asymmetric Shapley values + based on the (partial) causal ordering given by `causal_ordering`. + causal_ordering: dict or None, optional + An unnamed list of vectors specifying the components of the partial causal ordering that the coalitions must respect. + confounding: bool or None, optional + A vector of logicals specifying whether confounding is assumed or not for each component in the `causal_ordering`. + extra_computation_args: dict or None, optional + Specifies extra arguments related to the computation of the Shapley values. + iterative_args: dict or None, optional + Specifies the arguments for the iterative procedure. + output_args: dict or None, optional + Specifies certain arguments related to the output of the function. + **kwargs: Further arguments passed to specific approaches. Returns ------- - pandas.DataFrame - A pandas.DataFrame with the Shapley values. - numpy.Array - A numpy.Array with the predictions on `x_explain`. dict - A dictionary of additional information. - dict - A dictionary of elapsed time information if `timing` is set to `True`. - dict - A dictionary of the MSEv evaluation criterion scores: averaged over both the explicands and coalitions, - only over the explicands, and only over the coalitions. - ''' + A dictionary containing the following items: + - "shapley_values_est": pd.DataFrame with the estimated Shapley values. + - "shapley_values_sd": pd.DataFrame with the standard deviation of the Shapley values. + - "pred_explain": numpy.Array with the predictions for the explained observations. + - "MSEv": dict with the values of the MSEv evaluation criterion. + - "iterative_results": dict with the results of the iterative estimation. + - "saving_path": str with the path where intermediate results are stored. + - "internal": dict with the different parameters, data, functions and other output used internally. + - "timing": dict containing timing information for the different parts of the computation. 
+ """ - timing_list = {"init_time": datetime.now()} + init_time = base.Sys_time() # datetime.now() - base.set_seed(seed) + + if seed is not None: + base.set_seed(seed) # Gets and check feature specs from the model rfeature_specs = get_feature_specs(get_model_specs, model) @@ -133,82 +146,183 @@ def explain( if 'regression.vfold_cv_para' in kwargs: kwargs['regression.vfold_cv_para'] = ListVector(kwargs['regression.vfold_cv_para']) + # Convert from None or dict to a named list in R + if iterative_args is None: + iterative_args = ro.ListVector({}) + else: + iterative_args = ListVector(iterative_args) + + if output_args is None: + output_args = ro.ListVector({}) + else: + output_args = ListVector(output_args) + + if extra_computation_args is None: + extra_computation_args = ro.ListVector({}) + else: + extra_computation_args = ListVector(extra_computation_args) + # Sets up and organizes input parameters # Checks the input parameters and their compatability # Checks data/model compatability + + if type(approach) == str: + approach = [approach] + + if type(verbose) == str: + verbose = [verbose] + + rinternal = shapr.setup( - x_train = py2r(x_train), - x_explain = py2r(x_explain), - approach = approach, - prediction_zero = prediction_zero, - n_combinations = maybe_null(n_combinations), - group = r_group, - n_samples = n_samples, - n_batches = maybe_null(n_batches), - seed = seed, - keep_samp_for_vS = keep_samp_for_vS, - feature_specs = rfeature_specs, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights, - timing = timing, - verbose = verbose, - is_python=True, - **kwargs + x_train = py2r(x_train), + x_explain = py2r(x_explain), + approach = StrVector(approach), + paired_shap_sampling = paired_shap_sampling, + phi0 = phi0, + max_n_coalitions = maybe_null(max_n_coalitions), + group = r_group, + n_MC_samples = n_MC_samples, + seed = maybe_null(seed), + feature_specs = rfeature_specs, + verbose = StrVector(verbose), + iterative = maybe_null(iterative), + iterative_args = 
iterative_args, # Might do some conversion here + kernelSHAP_reweighting = kernelSHAP_reweighting, + asymmetric = asymmetric, + causal_ordering = maybe_null(causal_ordering), # Might do some conversion here + confounding = maybe_null(confounding), # Might do some conversion here + output_args = output_args, # Might do some conversion here + extra_computation_args = extra_computation_args, # Might do some conversion here + init_time = init_time, + is_python=True, + **kwargs ) - timing_list["setup"] = datetime.now() - # Gets predict_model (if not passed to explain) and checks that predict_model gives correct format predict_model = get_predict_model(x_test=x_train.head(2), predict_model=predict_model, model=model) - timing_list["test_prediction"] = datetime.now() + rinternal.rx2['timing_list'].rx2['test_prediction'] = base.Sys_time() + + rinternal = additional_regression_setup( + rinternal, + model, + predict_model, + x_train, + x_explain) + + # Not called for approach %in% c("regression_surrogate","vaeac") + rinternal = shapr.setup_approach(internal = rinternal) # model and predict_model are not supported in Python + + rinternal.rx2['main_timing_list'] = rinternal.rx2['timing_list'] + + converged = False + iter = len(rinternal.rx2('iter_list')) + + if seed is not None: + base.set_seed(seed) + + model_class = f"{type(model).__module__}.{type(model).__name__}" + shapr.cli_startup(rinternal, model_class, verbose) + + rinternal.rx2['iter_timing_list'] = ro.ListVector({}) + + while not converged: + shapr.cli_iter(verbose, rinternal, iter) + + rinternal.rx2['timing_list'] = ro.ListVector({'init': base.Sys_time()}) + + # Setup the Shapley framework + rinternal = shapr.shapley_setup(rinternal) + + # Only actually called for approach in ["regression_surrogate", "vaeac"] + rinternal = shapr.setup_approach(rinternal) + + # Compute the vS + vS_list = compute_vS(rinternal, model, predict_model) + + # Compute Shapley value estimates and bootstrapped standard deviations + rinternal 
= shapr.compute_estimates(rinternal, vS_list) + + # Check convergence based on estimates and standard deviations (and thresholds) + rinternal = shapr.check_convergence(rinternal) + + # Save intermediate results + shapr.save_results(rinternal) - # Add the predicted response of the training and explain data to the internal list for regression-based methods - using_regression_paradigm = rinternal.rx2("parameters").rx2("regression")[0] - if using_regression_paradigm: - rinternal = regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain) + # Preparing parameters for next iteration (does not do anything if already converged) + rinternal = shapr.prepare_next_iteration(rinternal) - # Sets up the Shapley framework and prepares the conditional expectation computation for the chosen approach - rinternal = shapr.setup_computation(rinternal, NULL, NULL) + # Printing iteration information + shapr.print_iter(rinternal) - # Compute the v(S): - # MC: - # 1. Get the samples for the conditional distributions with the specified approach - # 2. Predict with these samples - # 3. Perform MC integration on these to estimate the conditional expectation (v(S)) - # Regression: - # 1. 
Directly estimate the conditional expectation (v(S)) using the fitted regression model(s) - rvS_list = compute_vS(rinternal, model, predict_model) + # Setting globals to simplify the loop + converged = rinternal.rx2('iter_list')[iter-1].rx2('converged')[0] - timing_list["compute_vS"] = datetime.now() + rinternal.rx2['timing_list'].rx2['postprocess_res'] = base.Sys_time() - # Compute Shapley values based on conditional expectations (v(S)) - # Organize function output - routput = shapr.finalize_explanation(vS_list=rvS_list, internal=rinternal) + # Add the current timing_list to the iter_timing_list + #iter_timing_list = list(rinternal.rx2['iter_timing_list']) + #iter_timing_list.append(rinternal.rx2['timing_list']) + #rinternal.rx2['iter_timing_list'] = ro.ListVector(iter_timing_list) - timing_list["shapley_computation"] = datetime.now() + rinternal.rx2['iter_timing_list'].rx2[iter] = rinternal.rx2['timing_list'] + iter += 1 - # Compute the elapsed time for the different steps - timing = compute_time(timing_list) if timing else None + rinternal.rx2['main_timing_list'].rx2['main_computation'] = base.Sys_time() - # If regression, then delete the regression/tidymodels objects in routput as they cannot be converted to python - if using_regression_paradigm: - routput = regression_remove_objects(routput) + # Rerun after convergence to get the same output format as for the non-iterative approach + routput = shapr.finalize_explanation(rinternal) + + rinternal.rx2['main_timing_list'].rx2['finalize_explanation'] = base.Sys_time() + + routput.rx2['timing'] = shapr.compute_time(rinternal) + + # Some cleanup when doing testing + #testing = rinternal.rx2('parameters').rx2('testing')[0] + #if base.isTRUE(testing): + # routput = shapr.testing_cleanup(routput) # Convert R objects to Python objects - df_shapley = r2py(base.as_data_frame(routput.rx2('shapley_values'))) + shapley_values_est = r2py(base.as_data_frame(routput.rx2('shapley_values_est'))) + shapley_values_sd = 
r2py(base.as_data_frame(routput.rx2('shapley_values_sd'))) pred_explain = r2py(routput.rx2('pred_explain')) - internal = recurse_r_tree(routput.rx2('internal')) MSEv = recurse_r_tree(routput.rx2('MSEv')) - - return df_shapley, pred_explain, internal, timing, MSEv + iterative_results = recurse_r_tree(routput.rx2('iterative_results')) + #saving_path = StrVector(routput.rx2['saving_path']) # NOt sure why this is not working + saving_path = StrVector(rinternal.rx2['parameters'].rx2['output_args'].rx2['saving_path'])[0] + #internal = recurse_r_tree(routput.rx2('rinternal')) # Currently get an error with NULL elements here + rtiming = routput.rx2['timing'] + + return { + "shapley_values_est": shapley_values_est, + "shapley_values_sd": shapley_values_sd, + "pred_explain": pred_explain, + "MSEv": MSEv, + "iterative_results": iterative_results, + "saving_path": saving_path, + "internal": rinternal, + "timing": rtiming + } def compute_vS(rinternal, model, predict_model): - S_batch = rinternal.rx2('objects').rx2('S_batch') - ret = ro.ListVector({}) + + iter = len(rinternal.rx2('iter_list')) + + # S_batch <- internal$iter_list[[iter]]$S_batch + S_batch = rinternal.rx2('iter_list')[iter-1].rx2('S_batch') + + # verbose + shapr.cli_compute_vS(rinternal) + + vS_list = ro.ListVector({}) for i, S in enumerate(S_batch): - ret.rx2[i+1] = batch_compute_vS(S=S, rinternal=rinternal, model=model, predict_model=predict_model) - return ret + vS_list.rx2[i+1] = batch_compute_vS(S=S, rinternal=rinternal, model=model, predict_model=predict_model) + + #### Adds v_S output above to any vS_list already computed #### + vS_list = shapr.append_vS_list(vS_list,rinternal) + + return vS_list def batch_compute_vS(S, rinternal, model, predict_model): @@ -218,17 +332,20 @@ def batch_compute_vS(S, rinternal, model, predict_model): if regression: dt_vS = shapr.batch_prepare_vS_regression(S=S, internal=rinternal) else: - # dt_vS is either only dt_vS or a list containing dt_vS and dt if 
internal$parameters$keep_samp_for_vS = TRUE + # dt_vS is either only dt_vS or a list containing dt_vS and dt if internal$parameters$output_args$keep_samp_for_vS = TRUE dt_vS = batch_prepare_vS_MC(S=S, rinternal=rinternal, model=model, predict_model=predict_model) return dt_vS -def batch_prepare_vS_MC(S, rinternal, model, predict_model): +def batch_prepare_vS_MC_old(S, rinternal, model, predict_model): keep_samp_for_vS = rinternal.rx2('parameters').rx2('keep_samp_for_vS')[0] feature_names = list(rinternal.rx2('parameters').rx2('feature_names')) + dt = shapr.batch_prepare_vS_MC_auxiliary(S=S, internal=rinternal) + dt = compute_preds(dt=dt, feature_names=feature_names, predict_model=predict_model, model=model) + dt_vS = shapr.compute_MCint(dt) if keep_samp_for_vS: @@ -236,8 +353,91 @@ def batch_prepare_vS_MC(S, rinternal, model, predict_model): else: return dt_vS +def batch_prepare_vS_MC(S, rinternal, model, predict_model): + feature_names = list(rinternal.rx2('parameters').rx2('feature_names')) + keep_samp_for_vS = rinternal.rx2('parameters').rx2('output_args').rx2('keep_samp_for_vS')[0] + causal_sampling = rinternal.rx2('parameters').rx2('causal_sampling')[0] + output_size = int(rinternal.rx2('parameters').rx2('output_size')[0]) + + dt = shapr.batch_prepare_vS_MC_auxiliary(S=S, internal=rinternal, causal_sampling=causal_sampling) + + pred_cols = [f"p_hat{i+1}" for i in range(output_size)] + type_ = rinternal.rx2('parameters').rx2('type')[0] + + if type_ == "forecast": + horizon = rinternal.rx2('parameters').rx2('horizon')[0] + n_endo = rinternal.rx2('data').rx2('n_endo')[0] + explain_idx = rinternal.rx2('parameters').rx2('explain_idx')[0] + explain_lags = rinternal.rx2('parameters').rx2('explain_lags')[0] + y = rinternal.rx2('data').rx2('y') + xreg = rinternal.rx2('data').rx2('xreg') + dt = compute_preds( + dt=dt, + feature_names=feature_names, + predict_model=predict_model, + model=model, + type_=type_, + horizon=horizon, + n_endo=n_endo, + explain_idx=explain_idx, 
+ explain_lags=explain_lags, + y=y, + xreg=xreg + ) + else: + dt = compute_preds( + dt=dt, + feature_names=feature_names, + predict_model=predict_model, + model=model, + type_=type_ + ) + + dt_vS = shapr.compute_MCint(dt) -def compute_preds(dt, feature_names, predict_model, model): + if keep_samp_for_vS: + return ro.ListVector({'dt_vS': dt_vS, 'dt_samp_for_vS': dt}) + else: + return dt_vS + +def compute_preds( + dt, + feature_names, + predict_model, + model, + type_, + horizon=None, + n_endo=None, + explain_idx=None, + explain_lags=None, + y=None, + xreg=None +): + # Predictions + if type_ == "forecast": + # TODO: I actually dont't think this works + preds = predict_model( + model, + r2py(dt).loc[:,:n_endo], + r2py(dt).loc[:,n_endo:], + horizon, + explain_idx, + explain_lags, + y, + xreg + ) + + else: + preds = predict_model( + model, + r2py(dt).loc[:,feature_names] + ) + + return ro.r.cbind(dt, p_hat=ro.FloatVector(preds.tolist())) + + + +def compute_preds_old(dt, feature_names, predict_model, model): preds = predict_model(model, r2py(dt).loc[:,feature_names]) return ro.r.cbind(dt, p_hat=ro.FloatVector(preds.tolist())) @@ -272,7 +472,7 @@ def get_feature_specs(get_model_specs, model): py2r_or_na = lambda v: py2r(v) if v is not None else NA def strvec_or_na(v): if v is None: return NA - strvec = ro.StrVector(list(v.values())) + strvec = StrVector(list(v.values())) strvec.names = list(v.keys()) return strvec def listvec_or_na(v): @@ -386,6 +586,15 @@ def compute_time(timing_list): return timing_output +def additional_regression_setup(rinternal, model, predict_model, x_train, x_explain): + # Add the predicted response of the training and explain data to the internal list for regression-based methods + regression = rinternal.rx2("parameters").rx2("regression")[0] + if regression: + rinternal = regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain) + + return rinternal + + def regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain): 
x_train_y_hat = predict_model(model, x_train) x_explain_y_hat = predict_model(model, x_explain) @@ -402,7 +611,7 @@ def regression_get_y_hat(rinternal, model, predict_model, x_train, x_explain): def regression_remove_objects(routput): tmp_internal = routput.rx2("internal") tmp_parameters = tmp_internal.rx2("parameters") - objects = ro.StrVector(("regression", "regression.model", "regression.tune_values", "regression.vfold_cv_para", + objects = StrVector(("regression", "regression.model", "regression.tune_values", "regression.vfold_cv_para", "regression.recipe_func", "regression.tune", "regression.surrogate_n_comb")) tmp_parameters.rx[objects] = NULL tmp_internal.rx2["parameters"] = tmp_parameters @@ -418,4 +627,5 @@ def change_first_underscore_to_dot(kwargs): kwargs_tmp = {} for k, v in kwargs.items(): kwargs_tmp[k.replace('_', '.', 1)] = v - return kwargs_tmp \ No newline at end of file + return kwargs_tmp + diff --git a/rebuild_long_running_vignette.R b/rebuild_long_running_vignette.R index a75a3a7a4ff936dfcd6e5bf4dadebd7400515202..ca9ad1eacb25850d7658dd0f6764889940415913 100644 --- a/rebuild_long_running_vignette.R +++ b/rebuild_long_running_vignette.R @@ -15,4 +15,7 @@ knitr::knit("understanding_shapr_vaeac.Rmd.orig", output = "understanding_shapr_ knitr::knit("understanding_shapr_regression.Rmd.orig", output = "understanding_shapr_regression.Rmd") # knitr::purl("understanding_shapr_regression.Rmd.orig", output = "understanding_shapr_regression.R") # Don't need this +knitr::knit("understanding_shapr_asymmetric_causal.Rmd.orig", output = "understanding_shapr_asymmetric_causal.Rmd") +# knitr::purl("understanding_shapr_asymmetric_causal.Rmd.orig", output = "understanding_shapr_asymmetric_causal.R") + setwd(old_wd) diff --git a/src/Copula.cpp b/src/Copula.cpp index 732ed3a4f75d59424319295a54225133963680d6..9ae9666b12dff770d9ca44e166920da2e6c41488 100644 --- a/src/Copula.cpp +++ b/src/Copula.cpp @@ -54,8 +54,8 @@ arma::vec quantile_type7_cpp(const arma::vec& x, 
const arma::vec& probs) { // [[Rcpp::export]] arma::mat inv_gaussian_transform_cpp(const arma::mat& z, const arma::mat& x) { int n_features = z.n_cols; - int n_samples = z.n_rows; - arma::mat z_new(n_samples, n_features); + int n_MC_samples = z.n_rows; + arma::mat z_new(n_MC_samples, n_features); arma::mat u = arma::normcdf(z); for (int feature_idx = 0; feature_idx < n_features; feature_idx++) { z_new.col(feature_idx) = quantile_type7_cpp(x.col(feature_idx), u.col(feature_idx)); @@ -65,7 +65,7 @@ arma::mat inv_gaussian_transform_cpp(const arma::mat& z, const arma::mat& x) { //' Generate (Gaussian) Copula MC samples //' -//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the +//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the //' univariate standard normal. //' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations //' to explain on the original scale. @@ -73,7 +73,7 @@ arma::mat inv_gaussian_transform_cpp(const arma::mat& z, const arma::mat& x) { //' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been //' transformed to a standardized normal distribution. //' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. -//' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +//' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of //' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. //' This is not a problem internally in shapr as the empty and grand coalitions treated differently. //' @param mu arma::vec. 
Vector of length `n_features` containing the mean of each feature after being transformed @@ -82,8 +82,8 @@ arma::mat inv_gaussian_transform_cpp(const arma::mat& z, const arma::mat& x) { //' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been //' transformed to a standardized normal distribution. //' -//' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where -//' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian +//' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where +//' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian //' copula MC samples for each explicand and coalition on the original scale. //' //' @export @@ -99,13 +99,13 @@ arma::cube prepare_data_copula_cpp(const arma::mat& MC_samples_mat, const arma::mat& cov_mat) { int n_explain = x_explain_mat.n_rows; - int n_samples = MC_samples_mat.n_rows; + int n_MC_samples = MC_samples_mat.n_rows; int n_features = MC_samples_mat.n_cols; int n_coalitions = S.n_rows; // Initialize auxiliary matrix and result cube - arma::mat aux_mat(n_samples, n_features); - arma::cube result_cube(n_samples, n_explain*n_coalitions, n_features); + arma::mat aux_mat(n_MC_samples, n_features); + arma::cube result_cube(n_MC_samples, n_explain*n_coalitions, n_features); // Iterate over the coalitions for (int S_ind = 0; S_ind < n_coalitions; S_ind++) { @@ -150,7 +150,7 @@ arma::cube prepare_data_copula_cpp(const arma::mat& MC_samples_mat, // Transform the MC samples to be from N(mu_{Sbar|S}, Sigma_{Sbar|S}) for one coalition and one explicand arma::mat MC_samples_mat_now_now = - MC_samples_mat_now + repmat(trans(x_Sbar_gaussian_mean.col(idx_now)), n_samples, 1); + MC_samples_mat_now + repmat(trans(x_Sbar_gaussian_mean.col(idx_now)), 
n_MC_samples, 1); // Transform the MC to the original scale using the inverse Gaussian transform arma::mat MC_samples_mat_now_now_trans = @@ -158,7 +158,7 @@ arma::cube prepare_data_copula_cpp(const arma::mat& MC_samples_mat, // Insert the generate Gaussian copula MC samples and the feature values we condition on into an auxiliary matrix aux_mat.cols(Sbar_now_idx) = MC_samples_mat_now_now_trans; - aux_mat.cols(S_now_idx) = repmat(x_S_star.row(idx_now), n_samples, 1); + aux_mat.cols(S_now_idx) = repmat(x_S_star.row(idx_now), n_MC_samples, 1); // Insert the auxiliary matrix into the result cube result_cube.col(S_ind*n_explain + idx_now) = aux_mat; @@ -167,3 +167,101 @@ arma::cube prepare_data_copula_cpp(const arma::mat& MC_samples_mat, return result_cube; } + +//' Generate (Gaussian) Copula MC samples for the causal setup with a single MC sample for each explicand +//' +//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the +//' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`. +//' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations to +//' explain on the original scale. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat`. +//' @param x_explain_gaussian_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the +//' observations to explain after being transformed using the Gaussian transform, i.e., the samples have been +//' transformed to a standardized normal distribution. +//' @param x_train_mat arma::mat. Matrix of dimension (`n_train`, `n_features`) containing the training observations. +//' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of +//' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. 
+//' This is not a problem internally in shapr as the empty and grand coalitions treated differently. +//' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature after being transformed +//' using the Gaussian transform, i.e., the samples have been transformed to a standardized normal distribution. +//' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance +//' between all pairs of features after being transformed using the Gaussian transform, i.e., the samples have been +//' transformed to a standardized normal distribution. +//' +//' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`), +//' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contains the single +//' conditional Gaussian MC samples for each explicand and `S_ind` coalition. +//' +//' @export +//' @keywords internal +//' @author Lars Henry Berge Olsen +// [[Rcpp::export]] +arma::mat prepare_data_copula_cpp_caus(const arma::mat& MC_samples_mat, + const arma::mat& x_explain_mat, + const arma::mat& x_explain_gaussian_mat, + const arma::mat& x_train_mat, + const arma::mat& S, + const arma::vec& mu, + const arma::mat& cov_mat) { + + int n_explain = x_explain_mat.n_rows; + int n_features = MC_samples_mat.n_cols; + int n_coalitions = S.n_rows; + + // Initialize auxiliary matrix and result cube + arma::mat result_mat(n_explain * n_coalitions, n_features); + + // Iterate over the coalitions + for (int S_ind = 0; S_ind < n_coalitions; S_ind++) { + + // Get the row_indices in the result_mat for the current coalition + arma::uvec row_vec = arma::linspace(n_explain * S_ind, n_explain * (S_ind + 1) - 1, n_explain); + + // Get current coalition S and the indices of the features in coalition S and mask Sbar + arma::mat S_now = S.row(S_ind); + arma::uvec S_now_idx = arma::find(S_now > 0.5); + arma::uvec Sbar_now_idx = arma::find(S_now < 0.5); + + // Extract the features we condition 
on, both on the original scale and the Gaussian transformed values. + arma::mat x_S_star = x_explain_mat.cols(S_now_idx); + arma::mat x_S_star_gaussian = x_explain_gaussian_mat.cols(S_now_idx); + + // Extract the mean values of the Gaussian transformed features in the two sets + arma::vec mu_S = mu.elem(S_now_idx); + arma::vec mu_Sbar = mu.elem(Sbar_now_idx); + + // Extract the relevant parts of the Gaussian transformed covariance matrix + arma::mat cov_mat_SS = cov_mat.submat(S_now_idx, S_now_idx); + arma::mat cov_mat_SSbar = cov_mat.submat(S_now_idx, Sbar_now_idx); + arma::mat cov_mat_SbarS = cov_mat.submat(Sbar_now_idx, S_now_idx); + arma::mat cov_mat_SbarSbar = cov_mat.submat(Sbar_now_idx, Sbar_now_idx); + + // Compute the covariance matrix multiplication factors/terms and the conditional covariance matrix + arma::mat cov_mat_SbarS_cov_mat_SS_inv = cov_mat_SbarS * inv(cov_mat_SS); + arma::mat cond_cov_mat_Sbar_given_S = cov_mat_SbarSbar - cov_mat_SbarS_cov_mat_SS_inv * cov_mat_SSbar; + + // Ensure that the conditional covariance matrix is symmetric + if (!cond_cov_mat_Sbar_given_S.is_symmetric()) { + cond_cov_mat_Sbar_given_S = arma::symmatl(cond_cov_mat_Sbar_given_S); + } + + // Compute the conditional mean of Xsbar given Xs = Xs_star_gaussian, i.e., of the Gaussian transformed features + arma::mat x_Sbar_gaussian_mean = cov_mat_SbarS_cov_mat_SS_inv * (x_S_star_gaussian.each_row() - mu_S.t()).t(); + x_Sbar_gaussian_mean.each_col() += mu_Sbar; + + // Transform the samples to be from N(O, Sigma_{Sbar|S}) + arma::mat MC_samples_mat_now = MC_samples_mat.cols(Sbar_now_idx) * arma::chol(cond_cov_mat_Sbar_given_S); + + // Transform the MC samples to be from N(mu_{Sbar|S}, Sigma_{Sbar|S}) for one coalition + arma::mat MC_samples_mat_now_now = MC_samples_mat_now + trans(x_Sbar_gaussian_mean); + + // Transform the MC to the original scale using the inverse Gaussian transform + arma::mat MC_samples_mat_now_now_trans = + inv_gaussian_transform_cpp(MC_samples_mat_now_now, 
x_train_mat.cols(Sbar_now_idx)); + + // Combine the generated values with the values we conditioned on to generate the final MC samples and save them + result_mat.submat(row_vec, S_now_idx) = x_S_star; + result_mat.submat(row_vec, Sbar_now_idx) = MC_samples_mat_now_now_trans; + } + + return result_mat; +} diff --git a/src/Gaussian.cpp b/src/Gaussian.cpp index c375ed5106d21dc5cb1d79df5977850fef9de95c..07fcf97064e7be048d2294d88427ee1b467eefb6 100644 --- a/src/Gaussian.cpp +++ b/src/Gaussian.cpp @@ -5,19 +5,19 @@ using namespace Rcpp; //' Generate Gaussian MC samples //' -//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_samples`, `n_features`) containing samples from the +//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_MC_samples`, `n_features`) containing samples from the //' univariate standard normal. //' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations //' to explain. -//' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +//' @param S arma::mat. Matrix of dimension (`n_coalitions`, `n_features`) containing binary representations of //' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. //' This is not a problem internally in shapr as the empty and grand coalitions treated differently. //' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature. //' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance //' between all pairs of features. 
//' -//' @return An arma::cube/3D array of dimension (`n_samples`, `n_explain` * `n_coalitions`, `n_features`), where -//' the columns (_,j,_) are matrices of dimension (`n_samples`, `n_features`) containing the conditional Gaussian +//' @return An arma::cube/3D array of dimension (`n_MC_samples`, `n_explain` * `n_coalitions`, `n_features`), where +//' the columns (_,j,_) are matrices of dimension (`n_MC_samples`, `n_features`) containing the conditional Gaussian //' MC samples for each explicand and coalition. //' //' @export @@ -31,13 +31,13 @@ arma::cube prepare_data_gaussian_cpp(const arma::mat& MC_samples_mat, const arma::mat& cov_mat) { int n_explain = x_explain_mat.n_rows; - int n_samples = MC_samples_mat.n_rows; + int n_MC_samples = MC_samples_mat.n_rows; int n_features = MC_samples_mat.n_cols; int n_coalitions = S.n_rows; // Initialize auxiliary matrix and result cube - arma::mat aux_mat(n_samples, n_features); - arma::cube result_cube(n_samples, n_explain*n_coalitions, n_features); + arma::mat aux_mat(n_MC_samples, n_features); + arma::cube result_cube(n_MC_samples, n_explain * n_coalitions, n_features); // Iterate over the coalitions for (int S_ind = 0; S_ind < n_coalitions; S_ind++) { @@ -78,11 +78,93 @@ arma::cube prepare_data_gaussian_cpp(const arma::mat& MC_samples_mat, // Loop over the different explicands and combine the generated values with the values we conditioned on for (int idx_now = 0; idx_now < n_explain; idx_now++) { - aux_mat.cols(S_now_idx) = repmat(x_S_star.row(idx_now), n_samples, 1); - aux_mat.cols(Sbar_now_idx) = MC_samples_mat_now + repmat(trans(x_Sbar_mean.col(idx_now)), n_samples, 1); - result_cube.col(S_ind*n_explain + idx_now) = aux_mat; + aux_mat.cols(S_now_idx) = repmat(x_S_star.row(idx_now), n_MC_samples, 1); + aux_mat.cols(Sbar_now_idx) = MC_samples_mat_now + repmat(trans(x_Sbar_mean.col(idx_now)), n_MC_samples, 1); + result_cube.col(S_ind * n_explain + idx_now) = aux_mat; } } return result_cube; } + +//' Generate Gaussian 
MC samples for the causal setup with a single MC sample for each explicand +//' +//' @param MC_samples_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing samples from the +//' univariate standard normal. The i'th row will be applied to the i'th row in `x_explain_mat`. +//' @param x_explain_mat arma::mat. Matrix of dimension (`n_explain`, `n_features`) containing the observations +//' to explain. The MC sample for the i'th explicand is based on the i'th row in `MC_samples_mat` +//' @param S arma::mat. Matrix of dimension (`n_combinations`, `n_features`) containing binary representations of +//' the used coalitions. S cannot contain the empty or grand coalition, i.e., a row containing only zeros or ones. +//' This is not a problem internally in shapr as the empty and grand coalitions treated differently. +//' @param mu arma::vec. Vector of length `n_features` containing the mean of each feature. +//' @param cov_mat arma::mat. Matrix of dimension (`n_features`, `n_features`) containing the pairwise covariance +//' between all pairs of features. +//' +//' @return An arma::mat/2D array of dimension (`n_explain` * `n_coalitions`, `n_features`), +//' where the rows (n_explain * S_ind, n_explain * (S_ind + 1) - 1) contains the single +//' conditional Gaussian MC samples for each explicand and `S_ind` coalition. 
+//' +//' @export +//' @keywords internal +//' @author Lars Henry Berge Olsen +// [[Rcpp::export]] +arma::mat prepare_data_gaussian_cpp_caus(const arma::mat& MC_samples_mat, + const arma::mat& x_explain_mat, + const arma::mat& S, + const arma::vec& mu, + const arma::mat& cov_mat) { + + int n_explain = x_explain_mat.n_rows; + int n_features = MC_samples_mat.n_cols; + int n_coalitions = S.n_rows; + + // Initialize the result matrix + arma::mat result_mat(n_explain * n_coalitions, n_features); + + // Iterate over the coalitions + for (int S_ind = 0; S_ind < n_coalitions; S_ind++) { + + // Get the row_indices in the result_mat for the current coalition + arma::uvec row_vec = arma::linspace(n_explain * S_ind, n_explain * (S_ind + 1) - 1, n_explain); + + // Get current coalition S and the indices of the features in coalition S and mask Sbar + arma::mat S_now = S.row(S_ind); + arma::uvec S_now_idx = arma::find(S_now > 0.5); + arma::uvec Sbar_now_idx = arma::find(S_now < 0.5); + + // Extract the features we condition on + arma::mat x_S_star = x_explain_mat.cols(S_now_idx); + + // Extract the mean values of the features in the two sets + arma::vec mu_S = mu.elem(S_now_idx); + arma::vec mu_Sbar = mu.elem(Sbar_now_idx); + + // Extract the relevant parts of the covariance matrix + arma::mat cov_mat_SS = cov_mat.submat(S_now_idx, S_now_idx); + arma::mat cov_mat_SSbar = cov_mat.submat(S_now_idx, Sbar_now_idx); + arma::mat cov_mat_SbarS = cov_mat.submat(Sbar_now_idx, S_now_idx); + arma::mat cov_mat_SbarSbar = cov_mat.submat(Sbar_now_idx, Sbar_now_idx); + + // Compute the covariance matrix multiplication factors/terms and the conditional covariance matrix + arma::mat cov_mat_SbarS_cov_mat_SS_inv = cov_mat_SbarS * inv(cov_mat_SS); + arma::mat cond_cov_mat_Sbar_given_S = cov_mat_SbarSbar - cov_mat_SbarS_cov_mat_SS_inv * cov_mat_SSbar; + + // Ensure that the conditional covariance matrix is symmetric + if (!cond_cov_mat_Sbar_given_S.is_symmetric()) { + cond_cov_mat_Sbar_given_S = 
arma::symmatl(cond_cov_mat_Sbar_given_S); + } + + // Compute the conditional mean of Xsbar given Xs = Xs_star + arma::mat x_Sbar_mean = cov_mat_SbarS_cov_mat_SS_inv * (x_S_star.each_row() - mu_S.t()).t(); + x_Sbar_mean.each_col() += mu_Sbar; + + // Transform the samples to be from N(O, Sigma_{Sbar|S}) + arma::mat MC_samples_mat_now = MC_samples_mat.cols(Sbar_now_idx) * arma::chol(cond_cov_mat_Sbar_given_S); + + // Combine the generated values with the values we conditioned on to generate the final MC samples and save them + result_mat.submat(row_vec, S_now_idx) = x_S_star; + result_mat.submat(row_vec, Sbar_now_idx) = MC_samples_mat_now + trans(x_Sbar_mean); + } + + return result_mat; +} diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index c95d55541af9643a658ab2c735659539cdf13a79..3ed8157eb4b7bea59e5e42134fe76d1e42e5f76e 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -121,6 +121,23 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// prepare_data_copula_cpp_caus +arma::mat prepare_data_copula_cpp_caus(const arma::mat& MC_samples_mat, const arma::mat& x_explain_mat, const arma::mat& x_explain_gaussian_mat, const arma::mat& x_train_mat, const arma::mat& S, const arma::vec& mu, const arma::mat& cov_mat); +RcppExport SEXP _shapr_prepare_data_copula_cpp_caus(SEXP MC_samples_matSEXP, SEXP x_explain_matSEXP, SEXP x_explain_gaussian_matSEXP, SEXP x_train_matSEXP, SEXP SSEXP, SEXP muSEXP, SEXP cov_matSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::mat& >::type MC_samples_mat(MC_samples_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type x_explain_mat(x_explain_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type x_explain_gaussian_mat(x_explain_gaussian_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type x_train_mat(x_train_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type S(SSEXP); + 
Rcpp::traits::input_parameter< const arma::vec& >::type mu(muSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type cov_mat(cov_matSEXP); + rcpp_result_gen = Rcpp::wrap(prepare_data_copula_cpp_caus(MC_samples_mat, x_explain_mat, x_explain_gaussian_mat, x_train_mat, S, mu, cov_mat)); + return rcpp_result_gen; +END_RCPP +} // prepare_data_gaussian_cpp arma::cube prepare_data_gaussian_cpp(const arma::mat& MC_samples_mat, const arma::mat& x_explain_mat, const arma::mat& S, const arma::vec& mu, const arma::mat& cov_mat); RcppExport SEXP _shapr_prepare_data_gaussian_cpp(SEXP MC_samples_matSEXP, SEXP x_explain_matSEXP, SEXP SSEXP, SEXP muSEXP, SEXP cov_matSEXP) { @@ -136,6 +153,21 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// prepare_data_gaussian_cpp_caus +arma::mat prepare_data_gaussian_cpp_caus(const arma::mat& MC_samples_mat, const arma::mat& x_explain_mat, const arma::mat& S, const arma::vec& mu, const arma::mat& cov_mat); +RcppExport SEXP _shapr_prepare_data_gaussian_cpp_caus(SEXP MC_samples_matSEXP, SEXP x_explain_matSEXP, SEXP SSEXP, SEXP muSEXP, SEXP cov_matSEXP) { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + Rcpp::traits::input_parameter< const arma::mat& >::type MC_samples_mat(MC_samples_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type x_explain_mat(x_explain_matSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type S(SSEXP); + Rcpp::traits::input_parameter< const arma::vec& >::type mu(muSEXP); + Rcpp::traits::input_parameter< const arma::mat& >::type cov_mat(cov_matSEXP); + rcpp_result_gen = Rcpp::wrap(prepare_data_gaussian_cpp_caus(MC_samples_mat, x_explain_mat, S, mu, cov_mat)); + return rcpp_result_gen; +END_RCPP +} // mahalanobis_distance_cpp arma::cube mahalanobis_distance_cpp(Rcpp::List featureList, arma::mat Xtrain_mat, arma::mat Xtest_mat, arma::mat mcov, bool S_scale_dist); RcppExport SEXP _shapr_mahalanobis_distance_cpp(SEXP featureListSEXP, SEXP Xtrain_matSEXP, SEXP 
Xtest_matSEXP, SEXP mcovSEXP, SEXP S_scale_distSEXP) { @@ -179,28 +211,28 @@ BEGIN_RCPP END_RCPP } // weight_matrix_cpp -arma::mat weight_matrix_cpp(List subsets, int m, int n, NumericVector w); -RcppExport SEXP _shapr_weight_matrix_cpp(SEXP subsetsSEXP, SEXP mSEXP, SEXP nSEXP, SEXP wSEXP) { +arma::mat weight_matrix_cpp(List coalitions, int m, int n, NumericVector w); +RcppExport SEXP _shapr_weight_matrix_cpp(SEXP coalitionsSEXP, SEXP mSEXP, SEXP nSEXP, SEXP wSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< List >::type subsets(subsetsSEXP); + Rcpp::traits::input_parameter< List >::type coalitions(coalitionsSEXP); Rcpp::traits::input_parameter< int >::type m(mSEXP); Rcpp::traits::input_parameter< int >::type n(nSEXP); Rcpp::traits::input_parameter< NumericVector >::type w(wSEXP); - rcpp_result_gen = Rcpp::wrap(weight_matrix_cpp(subsets, m, n, w)); + rcpp_result_gen = Rcpp::wrap(weight_matrix_cpp(coalitions, m, n, w)); return rcpp_result_gen; END_RCPP } -// feature_matrix_cpp -NumericMatrix feature_matrix_cpp(List features, int m); -RcppExport SEXP _shapr_feature_matrix_cpp(SEXP featuresSEXP, SEXP mSEXP) { +// coalition_matrix_cpp +NumericMatrix coalition_matrix_cpp(List coalitions, int m); +RcppExport SEXP _shapr_coalition_matrix_cpp(SEXP coalitionsSEXP, SEXP mSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; - Rcpp::traits::input_parameter< List >::type features(featuresSEXP); + Rcpp::traits::input_parameter< List >::type coalitions(coalitionsSEXP); Rcpp::traits::input_parameter< int >::type m(mSEXP); - rcpp_result_gen = Rcpp::wrap(feature_matrix_cpp(features, m)); + rcpp_result_gen = Rcpp::wrap(coalition_matrix_cpp(coalitions, m)); return rcpp_result_gen; END_RCPP } @@ -214,12 +246,14 @@ static const R_CallMethodDef CallEntries[] = { {"_shapr_quantile_type7_cpp", (DL_FUNC) &_shapr_quantile_type7_cpp, 2}, {"_shapr_inv_gaussian_transform_cpp", (DL_FUNC) 
&_shapr_inv_gaussian_transform_cpp, 2}, {"_shapr_prepare_data_copula_cpp", (DL_FUNC) &_shapr_prepare_data_copula_cpp, 7}, + {"_shapr_prepare_data_copula_cpp_caus", (DL_FUNC) &_shapr_prepare_data_copula_cpp_caus, 7}, {"_shapr_prepare_data_gaussian_cpp", (DL_FUNC) &_shapr_prepare_data_gaussian_cpp, 5}, + {"_shapr_prepare_data_gaussian_cpp_caus", (DL_FUNC) &_shapr_prepare_data_gaussian_cpp_caus, 5}, {"_shapr_mahalanobis_distance_cpp", (DL_FUNC) &_shapr_mahalanobis_distance_cpp, 5}, {"_shapr_sample_features_cpp", (DL_FUNC) &_shapr_sample_features_cpp, 2}, {"_shapr_observation_impute_cpp", (DL_FUNC) &_shapr_observation_impute_cpp, 5}, {"_shapr_weight_matrix_cpp", (DL_FUNC) &_shapr_weight_matrix_cpp, 4}, - {"_shapr_feature_matrix_cpp", (DL_FUNC) &_shapr_feature_matrix_cpp, 2}, + {"_shapr_coalition_matrix_cpp", (DL_FUNC) &_shapr_coalition_matrix_cpp, 2}, {NULL, NULL, 0} }; diff --git a/src/impute_data.cpp b/src/impute_data.cpp index cced8fa517c4682f245d8fd8d0857bde2cd50232..2c6f4d4da1a7c45d15852bf858f1663338fc2b86 100644 --- a/src/impute_data.cpp +++ b/src/impute_data.cpp @@ -13,7 +13,7 @@ using namespace Rcpp; //' //' @param xtest Numeric matrix. Represents a single test observation. //' -//' @param S Integer matrix of dimension \code{n_combinations x m}, where \code{n_combinations} equals +//' @param S Integer matrix of dimension \code{n_coalitions x m}, where \code{n_coalitions} equals //' the total number of sampled/non-sampled feature combinations and \code{m} equals //' the total number of unique features. Note that \code{m = ncol(xtrain)}. See details //' for more information. 
diff --git a/src/weighted_matrix.cpp b/src/weighted_matrix.cpp index 8b71520ad628b8f8a6fe798b0d0b2792303994cf..79eaa87621a04aae2184981d9510496bd1ac4676 100644 --- a/src/weighted_matrix.cpp +++ b/src/weighted_matrix.cpp @@ -1,29 +1,32 @@ +#define ARMA_WARN_LEVEL 1 // Disables the warning regarding approximate solution for small n_coalitions #include using namespace Rcpp; + + //' Calculate weight matrix //' -//' @param subsets List. Each of the elements equals an integer +//' @param coalitions List. Each of the elements equals an integer //' vector representing a valid combination of features/feature groups. //' @param m Integer. Number of features/feature groups //' @param n Integer. Number of combinations //' @param w Numeric vector of length \code{n}, i.e. \code{w[i]} equals //' the Shapley weight of feature/feature group combination \code{i}, represented by -//' \code{subsets[[i]]}. +//' \code{coalitions[[i]]}. //' //' @export //' @keywords internal //' //' @return Matrix of dimension n x m + 1 -//' @author Nikolai Sellereite +//' @author Nikolai Sellereite, Martin Jullum // [[Rcpp::export]] -arma::mat weight_matrix_cpp(List subsets, int m, int n, NumericVector w){ +arma::mat weight_matrix_cpp(List coalitions, int m, int n, NumericVector w){ // Note that Z is a n x (m + 1) matrix, where m is the number - // of unique subsets. All elements in the first column are equal to 1. + // of unique coalitions. All elements in the first column are equal to 1. // For j > 0, Z(i, j) = 1 if and only if feature/feature group j is present in - // the ith combination of subsets. In example, if Z(i, j) = 1 we know that - // j is present in subsets[i]. + // the ith combination of coalitions. In example, if Z(i, j) = 1 we know that + // j is present in coalitions[i]. // Note that w represents the diagonal in W, where W is a diagoanl // n x n matrix. 
@@ -51,8 +54,8 @@ arma::mat weight_matrix_cpp(List subsets, int m, int n, NumericVector w){ // Set all elements in the first column equal to 1 Z(i, 0) = 1; - // Extract subsets - subset_vec = subsets[i]; + // Extract coalitions + subset_vec = coalitions[i]; n_elements = subset_vec.length(); if (n_elements > 0) { for (int j = 0; j < n_elements; j++) @@ -74,32 +77,32 @@ arma::mat weight_matrix_cpp(List subsets, int m, int n, NumericVector w){ return R; } -//' Get feature matrix +//' Get coalition matrix //' -//' @param features List -//' @param m Positive integer. Total number of features +//' @param coalitions List +//' @param m Positive integer. Total number of coalitions //' //' @export //' @keywords internal //' //' @return Matrix -//' @author Nikolai Sellereite +//' @author Nikolai Sellereite, Martin Jullum // [[Rcpp::export]] -NumericMatrix feature_matrix_cpp(List features, int m) { +NumericMatrix coalition_matrix_cpp(List coalitions, int m) { // Define variables - int n_combinations; - n_combinations = features.length(); - NumericMatrix A(n_combinations, m); + int n_coalitions; + n_coalitions = coalitions.length(); + NumericMatrix A(n_coalitions, m); // Error-check - IntegerVector features_zero = features[0]; + IntegerVector features_zero = coalitions[0]; if (features_zero.length() > 0) - Rcpp::stop("The first element of features should be an empty vector, i.e. integer(0)"); + Rcpp::stop("Internal error: The first element of coalitions should be an empty vector, i.e. 
integer(0)"); - for (int i = 1; i < n_combinations; ++i) { + for (int i = 1; i < n_coalitions; ++i) { - NumericVector feature_vec = features[i]; + NumericVector feature_vec = coalitions[i]; for (int j = 0; j < feature_vec.length(); ++j) { diff --git a/tests/testthat/_snaps/adaptive-output.md b/tests/testthat/_snaps/adaptive-output.md new file mode 100644 index 0000000000000000000000000000000000000000..72239e8d086195785e4b587869e510d389714f5e --- /dev/null +++ b/tests/testthat/_snaps/adaptive-output.md @@ -0,0 +1,984 @@ +# output_lm_numeric_independence_reach_exact + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.31 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) 0.258 (2.14) 0.258 (2.14) 17.463 (5.62) -5.635 (1.84) + 2: 42.444 (0.00) -0.986 (0.56) -0.986 (0.56) -5.286 (1.40) -5.635 (1.45) + 3: 42.444 (0.00) -4.493 (0.33) -4.493 (0.33) -1.495 (0.98) -2.595 (0.59) + Day + + 1: 0.258 (2.14) + 2: -0.986 (0.56) + 3: -4.493 (0.33) + Message + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. 
+ + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.18 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.411 (3.37) 8.305 (3.82) 17.463 (3.50) -5.635 (0.19) + 2: 42.444 (0.00) 2.376 (1.47) -3.309 (1.07) -5.286 (1.24) -5.635 (1.02) + 3: 42.444 (0.00) 3.834 (3.22) -18.574 (5.10) -1.495 (2.37) -2.595 (0.83) + Day + + 1: -3.121 (3.24) + 2: -2.025 (1.13) + 3: 1.261 (4.44) + Message + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.079 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.467 (0.21) 8.284 (0.98) 17.485 (0.01) -5.635 (0.12) + 2: 42.444 (0.00) 2.320 (0.75) -3.331 (0.11) -5.264 (0.01) -5.635 (0.39) + 3: 42.444 (0.00) 3.778 (0.47) -18.596 (1.70) -1.473 (0.01) -2.595 (0.34) + Day + + 1: -3.065 (1.02) + 2: -1.969 (0.67) + 3: 1.317 (1.77) + Message + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Convergence info + v Converged after 16 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.541 (0.05) 8.330 (0.80) 17.491 (0.02) -5.585 (0.02) + 2: 42.444 (0.00) 2.246 (0.05) -3.285 (0.10) -5.258 (0.02) -5.585 (0.02) + 3: 42.444 (0.00) 3.704 (0.05) -18.549 (1.40) -1.467 (0.02) -2.545 (0.02) + Day + + 1: -3.093 (0.80) + 2: -1.997 (0.10) + 3: 1.289 (1.40) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.541 8.330 17.491 -5.585 -3.093 + 2: 2 42.44 2.246 -3.285 -5.258 -5.585 -1.997 + 3: 3 42.44 3.704 -18.549 -1.467 -2.545 1.289 + +# output_lm_numeric_independence_converges_tol + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.19 [needs 0.1] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (2.23) 8.215 (3.14) 17.463 (5.65) -5.545 (3.30) + 2: 42.444 (0.00) 2.196 (1.45) -3.399 (0.45) -5.286 (1.14) -5.545 (1.04) + 3: 42.444 (0.00) 3.654 (0.94) -18.664 (4.32) -1.495 (1.14) -2.505 (3.75) + Day + + 1: -2.940 (4.17) + 2: -1.845 (1.51) + 3: 1.442 (2.14) + Message + + -- Iteration 2 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.14 [needs 0.1] + Estimated remaining coalitions: 8 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (0.76) 8.215 (2.20) 17.463 (4.64) -5.545 (2.14) + 2: 42.444 (0.00) 2.196 (0.98) -3.399 (0.47) -5.286 (0.76) -5.545 (0.98) + 3: 42.444 (0.00) 3.654 (1.12) -18.664 (3.06) -1.495 (0.82) -2.505 (2.55) + Day + + 1: -2.940 (4.54) + 2: -1.845 (1.11) + 3: 1.442 (1.96) + Message + + -- Iteration 3 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 14 coalitions: + Current convergence measure: 0.14 [needs 0.1] + Estimated remaining coalitions: 10 + (Concervatively) adding 20% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.570 (0.87) 8.236 (1.92) 17.463 (4.97) -5.593 (1.32) + 2: 42.444 (0.00) 2.217 (0.66) -3.378 (0.33) -5.286 (0.86) -5.593 (0.26) + 3: 42.444 (0.00) 3.675 (0.52) -18.643 (3.19) -1.495 (0.72) -2.553 (1.19) + Day + + 1: -2.934 (4.68) + 2: -1.839 (1.06) + 3: 1.448 (3.00) + Message + + -- Iteration 4 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 16 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.538 (1.90) 8.268 (0.56) 17.523 (3.29) -5.589 (0.04) + 2: 42.444 (0.00) 2.249 (0.66) -3.347 (0.09) -5.227 (0.77) -5.589 (0.04) + 3: 42.444 (0.00) 3.707 (0.45) -18.611 (1.01) -1.435 (0.58) -2.549 (0.04) + Day + + 1: -3.061 (2.86) + 2: -1.966 (0.50) + 3: 1.321 (1.06) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.538 8.268 17.523 -5.589 -3.061 + 2: 2 42.44 2.249 -3.347 -5.227 -5.589 -1.966 + 3: 3 42.44 3.707 -18.611 -1.435 -2.549 1.321 + +# output_lm_numeric_independence_converges_maxit + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.19 [needs 0.001] + Estimated remaining coalitions: 20 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (2.23) 8.215 (3.14) 17.463 (5.65) -5.545 (3.30) + 2: 42.444 (0.00) 2.196 (1.45) -3.399 (0.45) -5.286 (1.14) -5.545 (1.04) + 3: 42.444 (0.00) 3.654 (0.94) -18.664 (4.32) -1.495 (1.14) -2.505 (3.75) + Day + + 1: -2.940 (4.17) + 2: -1.845 (1.51) + 3: 1.442 (2.14) + Message + + -- Iteration 2 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.14 [needs 0.001] + Estimated remaining coalitions: 18 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (0.76) 8.215 (2.20) 17.463 (4.64) -5.545 (2.14) + 2: 42.444 (0.00) 2.196 (0.98) -3.399 (0.47) -5.286 (0.76) -5.545 (0.98) + 3: 42.444 (0.00) 3.654 (1.12) -18.664 (3.06) -1.495 (0.82) -2.505 (2.55) + Day + + 1: -2.940 (4.54) + 2: -1.845 (1.11) + 3: 1.442 (1.96) + Message + + -- Iteration 3 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 14 coalitions: + Current convergence measure: 0.14 [needs 0.001] + Estimated remaining coalitions: 16 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.570 (0.87) 8.236 (1.92) 17.463 (4.97) -5.593 (1.32) + 2: 42.444 (0.00) 2.217 (0.66) -3.378 (0.33) -5.286 (0.86) -5.593 (0.26) + 3: 42.444 (0.00) 3.675 (0.52) -18.643 (3.19) -1.495 (0.72) -2.553 (1.19) + Day + + 1: -2.934 (4.68) + 2: -1.839 (1.06) + 3: 1.448 (3.00) + Message + + -- Iteration 4 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 16 coalitions: + Current convergence measure: 0.099 [needs 0.001] + Estimated remaining coalitions: 14 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.538 (1.90) 8.268 (0.56) 17.523 (3.29) -5.589 (0.04) + 2: 42.444 (0.00) 2.249 (0.66) -3.347 (0.09) -5.227 (0.77) -5.589 (0.04) + 3: 42.444 (0.00) 3.707 (0.45) -18.611 (1.01) -1.435 (0.58) -2.549 (0.04) + Day + + 1: -3.061 (2.86) + 2: -1.966 (0.50) + 3: 1.321 (1.06) + Message + + -- Iteration 5 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 18 coalitions: + Current convergence measure: 0.06 [needs 0.001] + Estimated remaining coalitions: 12 + (Concervatively) adding 0.001% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.536 (1.11) 8.270 (0.03) 17.519 (2.34) -5.592 (1.16) + 2: 42.444 (0.00) 2.251 (0.47) -3.344 (0.03) -5.231 (0.47) -5.592 (0.03) + 3: 42.444 (0.00) 3.709 (0.30) -18.609 (0.03) -1.439 (0.36) -2.552 (0.06) + Day + + 1: -3.059 (1.77) + 2: -1.964 (0.42) + 3: 1.323 (0.30) + Message + + -- Iteration 6 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 20 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.534 (0.01) 8.272 (0.01) 17.520 (0.01) -5.592 (0.01) + 2: 42.444 (0.00) 2.253 (0.01) -3.342 (0.01) -5.229 (0.01) -5.592 (0.01) + 3: 42.444 (0.00) 3.711 (0.01) -18.607 (0.01) -1.438 (0.01) -2.553 (0.01) + Day + + 1: -3.064 (0.01) + 2: -1.968 (0.01) + 3: 1.318 (0.01) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.534 8.272 17.520 -5.592 -3.064 + 2: 2 42.44 2.253 -3.342 -5.229 -5.592 -1.968 + 3: 3 42.44 3.711 -18.607 -1.438 -2.553 1.318 + +# output_lm_numeric_indep_conv_max_n_coalitions + + Code + (out <- code) + Message + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.31 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) 0.258 (2.14) 0.258 (2.14) 17.463 (5.62) -5.635 (1.84) + 2: 42.444 (0.00) -0.986 (0.56) -0.986 (0.56) -5.286 (1.40) -5.635 (1.45) + 3: 42.444 (0.00) -4.493 (0.33) -4.493 (0.33) -1.495 (0.98) -2.595 (0.59) + Day + + 1: 0.258 (2.14) + 2: -0.986 (0.56) + 3: -4.493 (0.33) + Message + + -- Iteration 2 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.18 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.411 (3.37) 8.305 (3.82) 17.463 (3.50) -5.635 (0.19) + 2: 42.444 (0.00) 2.376 (1.47) -3.309 (1.07) -5.286 (1.24) -5.635 (1.02) + 3: 42.444 (0.00) 3.834 (3.22) -18.574 (5.10) -1.495 (2.37) -2.595 (0.83) + Day + + 1: -3.121 (3.24) + 2: -2.025 (1.13) + 3: 1.261 (4.44) + Message + + -- Iteration 3 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.079 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.467 (0.21) 8.284 (0.98) 17.485 (0.01) -5.635 (0.12) + 2: 42.444 (0.00) 2.320 (0.75) -3.331 (0.11) -5.264 (0.01) -5.635 (0.39) + 3: 42.444 (0.00) 3.778 (0.47) -18.596 (1.70) -1.473 (0.01) -2.595 (0.34) + Day + + 1: -3.065 (1.02) + 2: -1.969 (0.67) + 3: 1.317 (1.77) + Message + + -- Iteration 4 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 16 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.541 (0.05) 8.330 (0.80) 17.491 (0.02) -5.585 (0.02) + 2: 42.444 (0.00) 2.246 (0.05) -3.285 (0.10) -5.258 (0.02) -5.585 (0.02) + 3: 42.444 (0.00) 3.704 (0.05) -18.549 (1.40) -1.467 (0.02) -2.545 (0.02) + Day + + 1: -3.093 (0.80) + 2: -1.997 (0.10) + 3: 1.289 (1.40) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.541 8.330 17.491 -5.585 -3.093 + 2: 2 42.44 2.246 -3.285 -5.258 -5.585 -1.997 + 3: 3 42.44 3.704 -18.549 -1.467 -2.545 1.289 + +# output_lm_numeric_gaussian_group_converges_tol + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 8, + and is therefore set to 2^n_groups = 8. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 6 coalitions: + Convergence tolerance reached! + + -- Final estimated Shapley values (sd) + Output + none A B C + + 1: 42.444 (0.00) 0.772 (2.66) 13.337 (3.49) -1.507 (3.31) + 2: 42.444 (0.00) 0.601 (2.97) -13.440 (3.32) -1.040 (2.77) + 3: 42.444 (0.00) -18.368 (3.91) 0.127 (3.95) 0.673 (0.12) + explain_id none A B C + + 1: 1 42.44 0.7716 13.3373 -1.5069 + 2: 2 42.44 0.6006 -13.4404 -1.0396 + 3: 3 42.44 -18.3678 0.1268 0.6728 + +# output_lm_numeric_independence_converges_tol_paired + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.19 [needs 0.1] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (2.23) 8.215 (3.14) 17.463 (5.65) -5.545 (3.30) + 2: 42.444 (0.00) 2.196 (1.45) -3.399 (0.45) -5.286 (1.14) -5.545 (1.04) + 3: 42.444 (0.00) 3.654 (0.94) -18.664 (4.32) -1.495 (1.14) -2.505 (3.75) + Day + + 1: -2.940 (4.17) + 2: -1.845 (1.51) + 3: 1.442 (2.14) + Message + + -- Iteration 2 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.14 [needs 0.1] + Estimated remaining coalitions: 8 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.591 (0.76) 8.215 (2.20) 17.463 (4.64) -5.545 (2.14) + 2: 42.444 (0.00) 2.196 (0.98) -3.399 (0.47) -5.286 (0.76) -5.545 (0.98) + 3: 42.444 (0.00) 3.654 (1.12) -18.664 (3.06) -1.495 (0.82) -2.505 (2.55) + Day + + 1: -2.940 (4.54) + 2: -1.845 (1.11) + 3: 1.442 (1.96) + Message + + -- Iteration 3 ----------------------------------------------------------------- + + -- Convergence info + i Not converged after 14 coalitions: + Current convergence measure: 0.14 [needs 0.1] + Estimated remaining coalitions: 10 + (Concervatively) adding 20% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.570 (0.87) 8.236 (1.92) 17.463 (4.97) -5.593 (1.32) + 2: 42.444 (0.00) 2.217 (0.66) -3.378 (0.33) -5.286 (0.86) -5.593 (0.26) + 3: 42.444 (0.00) 3.675 (0.52) -18.643 (3.19) -1.495 (0.72) -2.553 (1.19) + Day + + 1: -2.934 (4.68) + 2: -1.839 (1.06) + 3: 1.448 (3.00) + Message + + -- Iteration 4 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 16 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -4.538 (1.90) 8.268 (0.56) 17.523 (3.29) -5.589 (0.04) + 2: 42.444 (0.00) 2.249 (0.66) -3.347 (0.09) -5.227 (0.77) -5.589 (0.04) + 3: 42.444 (0.00) 3.707 (0.45) -18.611 (1.01) -1.435 (0.58) -2.549 (0.04) + Day + + 1: -3.061 (2.86) + 2: -1.966 (0.50) + 3: 1.321 (1.06) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.538 8.268 17.523 -5.589 -3.061 + 2: 2 42.44 2.249 -3.347 -5.227 -5.589 -1.966 + 3: 3 42.44 3.707 -18.611 -1.435 -2.549 1.321 + +# output_lm_numeric_independence_saving_and_cont_est + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.531 8.202 17.504 -5.549 -3.024 + 2: 2 42.44 2.256 -3.412 -5.246 -5.549 -1.928 + 3: 3 42.44 3.714 -18.677 -1.454 -2.509 1.358 + +--- + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.531 8.202 17.504 -5.549 -3.024 + 2: 2 42.44 2.256 -3.412 -5.246 -5.549 -1.928 + 3: 3 42.44 3.714 -18.677 -1.454 -2.509 1.358 + +# output_verbose_1 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. 
+ + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.534 7.868 14.3146 0.8504 -1.8969 + 2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714 + 3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978 + +# output_verbose_1_3 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.33 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.2 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. 
+ + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.077 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 16 coalitions: + Current convergence measure: 0.046 [needs 0.02] + Estimated remaining coalitions: 14 + (Concervatively) adding 30% of that (6 coalitions) in the next iteration. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + + -- Convergence info + v Converged after 22 coalitions: + Convergence tolerance reached! + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.534 7.868 14.3146 0.8504 -1.8969 + 2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714 + 3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978 + +# output_verbose_1_3_4 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.33 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -1.428 (1.74) -1.428 (1.74) 15.197 (5.43) 1.688 (0.97) + 2: 42.444 (0.00) -0.914 (1.10) -0.914 (1.10) -10.815 (3.23) -0.321 (0.19) + 3: 42.444 (0.00) -5.807 (0.72) -5.807 (0.72) 0.168 (1.95) -0.316 (1.71) + Day + + 1: -1.428 (1.74) + 2: -0.914 (1.10) + 3: -5.807 (0.72) + Message + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.2 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp + + 1: 42.444 (0.00) -10.984 (4.19) 6.696 (3.77) 15.197 (4.21) + 2: 42.444 (0.00) 2.151 (2.02) -6.851 (2.61) -10.815 (2.04) + 3: 42.444 (0.00) 6.820 (4.76) -26.009 (7.25) 0.168 (3.47) + Month Day + + 1: 1.688 (1.57) 0.006 (3.61) + 2: -0.321 (0.33) 1.957 (2.22) + 3: -0.316 (0.90) 1.769 (6.40) + Message + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.077 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -9.803 (1.62) 7.155 (0.72) 14.738 (0.31) 1.688 (0.48) + 2: 42.444 (0.00) 4.188 (1.34) -6.060 (0.82) -11.606 (0.54) -0.321 (0.16) + 3: 42.444 (0.00) 7.531 (1.13) -25.733 (2.34) -0.109 (0.19) -0.316 (0.31) + Day + + 1: -1.175 (1.69) + 2: -0.080 (1.41) + 3: 1.057 (2.57) + Message + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. 
+ + -- Convergence info + i Not converged after 16 coalitions: + Current convergence measure: 0.046 [needs 0.02] + Estimated remaining coalitions: 14 + (Concervatively) adding 30% of that (6 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -8.850 (0.50) 7.165 (0.77) 14.627 (0.34) 1.200 (0.24) + 2: 42.444 (0.00) 4.909 (0.49) -5.670 (0.76) -11.676 (0.54) -0.592 (0.19) + 3: 42.444 (0.00) 7.453 (0.17) -25.529 (1.87) -0.083 (0.18) -0.223 (0.09) + Day + + 1: -1.541 (0.65) + 2: -0.851 (0.60) + 3: 0.814 (1.89) + Message + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + + -- Convergence info + v Converged after 22 coalitions: + Convergence tolerance reached! + + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -8.534 (0.45) 7.868 (0.36) 14.315 (0.27) 0.850 (0.37) + 2: 42.444 (0.00) 4.919 (0.36) -4.878 (0.53) -11.909 (0.38) -0.841 (0.23) + 3: 42.444 (0.00) 7.447 (0.16) -25.748 (0.16) 0.032 (0.13) -0.198 (0.07) + Day + + 1: -1.897 (0.19) + 2: -1.171 (0.25) + 3: 0.898 (0.12) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.534 7.868 14.3146 0.8504 -1.8969 + 2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714 + 3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978 + +# output_verbose_1_3_4_5 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. 
+ + -- Convergence info + i Not converged after 6 coalitions: + Current convergence measure: 0.33 [needs 0.02] + Estimated remaining coalitions: 24 + (Concervatively) adding 10% of that (4 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -1.428 (1.74) -1.428 (1.74) 15.197 (5.43) 1.688 (0.97) + 2: 42.444 (0.00) -0.914 (1.10) -0.914 (1.10) -10.815 (3.23) -0.321 (0.19) + 3: 42.444 (0.00) -5.807 (0.72) -5.807 (0.72) 0.168 (1.95) -0.316 (1.71) + Day + + 1: -1.428 (1.74) + 2: -0.914 (1.10) + 3: -5.807 (0.72) + Message + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 10 coalitions: + Current convergence measure: 0.2 [needs 0.02] + Estimated remaining coalitions: 20 + (Concervatively) adding 10% of that (2 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp + + 1: 42.444 (0.00) -10.984 (4.19) 6.696 (3.77) 15.197 (4.21) + 2: 42.444 (0.00) 2.151 (2.02) -6.851 (2.61) -10.815 (2.04) + 3: 42.444 (0.00) 6.820 (4.76) -26.009 (7.25) 0.168 (3.47) + Month Day + + 1: 1.688 (1.57) 0.006 (3.61) + 2: -0.321 (0.33) 1.957 (2.22) + 3: -0.316 (0.90) 1.769 (6.40) + Message + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Convergence info + i Not converged after 12 coalitions: + Current convergence measure: 0.077 [needs 0.02] + Estimated remaining coalitions: 18 + (Concervatively) adding 20% of that (4 coalitions) in the next iteration. 
+ + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -9.803 (1.62) 7.155 (0.72) 14.738 (0.31) 1.688 (0.48) + 2: 42.444 (0.00) 4.188 (1.34) -6.060 (0.82) -11.606 (0.54) -0.321 (0.16) + 3: 42.444 (0.00) 7.531 (1.13) -25.733 (2.34) -0.109 (0.19) -0.316 (0.31) + Day + + 1: -1.175 (1.69) + 2: -0.080 (1.41) + 3: 1.057 (2.57) + Message + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Convergence info + i Not converged after 16 coalitions: + Current convergence measure: 0.046 [needs 0.02] + Estimated remaining coalitions: 14 + (Concervatively) adding 30% of that (6 coalitions) in the next iteration. + + -- Current estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -8.850 (0.50) 7.165 (0.77) 14.627 (0.34) 1.200 (0.24) + 2: 42.444 (0.00) 4.909 (0.49) -5.670 (0.76) -11.676 (0.54) -0.592 (0.19) + 3: 42.444 (0.00) 7.453 (0.17) -25.529 (1.87) -0.083 (0.18) -0.223 (0.09) + Day + + 1: -1.541 (0.65) + 2: -0.851 (0.60) + 3: 0.814 (1.89) + Message + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + + -- Convergence info + v Converged after 22 coalitions: + Convergence tolerance reached! 
+ + -- Final estimated Shapley values (sd) + Output + none Solar.R Wind Temp Month + + 1: 42.444 (0.00) -8.534 (0.45) 7.868 (0.36) 14.315 (0.27) 0.850 (0.37) + 2: 42.444 (0.00) 4.919 (0.36) -4.878 (0.53) -11.909 (0.38) -0.841 (0.23) + 3: 42.444 (0.00) 7.447 (0.16) -25.748 (0.16) 0.032 (0.13) -0.198 (0.07) + Day + + 1: -1.897 (0.19) + 2: -1.171 (0.25) + 3: 0.898 (0.12) + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.534 7.868 14.3146 0.8504 -1.8969 + 2: 2 42.44 4.919 -4.878 -11.9086 -0.8405 -1.1714 + 3: 3 42.44 7.447 -25.748 0.0324 -0.1976 0.8978 + diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_gaussian_group_converges_tol.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_gaussian_group_converges_tol.rds new file mode 100644 index 0000000000000000000000000000000000000000..ed6d05c26d72c8e9001e3ded0a881268887f3f26 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_gaussian_group_converges_tol.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_indep_conv_max_n_coalitions.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_indep_conv_max_n_coalitions.rds new file mode 100644 index 0000000000000000000000000000000000000000..0a0f7379e99794ef591c3a8c7b093f423fb66174 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_indep_conv_max_n_coalitions.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_object.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_object.rds new file mode 100644 index 0000000000000000000000000000000000000000..0507b05cd97c5ee52179cf2b36fec688b1ce8812 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_object.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_path.rds 
b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_path.rds new file mode 100644 index 0000000000000000000000000000000000000000..0507b05cd97c5ee52179cf2b36fec688b1ce8812 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_cont_est_path.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_maxit.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_maxit.rds new file mode 100644 index 0000000000000000000000000000000000000000..752207bf6600c8ee0b448108992f8b3937f89867 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_maxit.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol.rds new file mode 100644 index 0000000000000000000000000000000000000000..02f2a785c42dad66fbcd71d0c7a8828a87f92c4e Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol_paired.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol_paired.rds new file mode 100644 index 0000000000000000000000000000000000000000..02f2a785c42dad66fbcd71d0c7a8828a87f92c4e Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_converges_tol_paired.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_reach_exact.rds b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_reach_exact.rds new file mode 100644 index 0000000000000000000000000000000000000000..4c5089cbb13e0023e712c5dc9ef399768d5b1428 Binary files /dev/null and 
b/tests/testthat/_snaps/adaptive-output/output_lm_numeric_independence_reach_exact.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_verbose_1.rds b/tests/testthat/_snaps/adaptive-output/output_verbose_1.rds new file mode 100644 index 0000000000000000000000000000000000000000..876bd1a66565d911f89bb365bd4a4837a5f81e50 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_verbose_1.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_verbose_1_3.rds b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3.rds new file mode 100644 index 0000000000000000000000000000000000000000..fd125a547ccf10968929f287379033bd7feb9492 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4.rds b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4.rds new file mode 100644 index 0000000000000000000000000000000000000000..583867cacc243ff285116111acaeb34f4065cc35 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4.rds differ diff --git a/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4_5.rds b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4_5.rds new file mode 100644 index 0000000000000000000000000000000000000000..f0b55a8cf7c271ea9ba33c1ed31c95289014ddc5 Binary files /dev/null and b/tests/testthat/_snaps/adaptive-output/output_verbose_1_3_4_5.rds differ diff --git a/tests/testthat/_snaps/adaptive-setup.md b/tests/testthat/_snaps/adaptive-setup.md new file mode 100644 index 0000000000000000000000000000000000000000..326a03d446795f0821a54e5e1d0c1f338889a305 --- /dev/null +++ b/tests/testthat/_snaps/adaptive-setup.md @@ -0,0 +1,96 @@ +# erroneous input: `min_n_batches` + + Code + n_batches_non_numeric_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + 
extra_computation_args = list(min_n_batches = n_batches_non_numeric_1)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_non_numeric_2 <- TRUE + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_numeric_2)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_non_integer <- 10.5 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_integer)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_too_long <- c(1, 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_too_long)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! 
`extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_is_NA <- as.numeric(NA) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_is_NA)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + +--- + + Code + n_batches_non_positive <- 0 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_positive)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_extra_computation_args()`: + ! `extra_computation_args$min_n_batches` must be NULL or a single positive integer. + diff --git a/tests/testthat/_snaps/asymmetric-causal-output.md b/tests/testthat/_snaps/asymmetric-causal-output.md new file mode 100644 index 0000000000000000000000000000000000000000..0177b8b4dcc8dc42b18f067bb6d9c57c4c2ac427 --- /dev/null +++ b/tests/testthat/_snaps/asymmetric-causal-output.md @@ -0,0 +1,744 @@ +# output_asymmetric_conditional + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. 
+ + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -24.516 29.347 11.557 -0.626 -3.161 + 2: 2 42.44 -7.632 8.053 -7.467 -4.634 -2.200 + 3: 3 42.44 -3.458 -18.240 4.321 -1.347 1.156 + +# output_asym_cond_reg + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -11.337 15.032 14.506 -2.656 -2.943 + 2: 2 42.44 5.546 -6.262 -4.518 -6.664 -1.982 + 3: 3 42.44 9.720 -32.555 7.270 -3.377 1.374 + +# output_asym_cond_reg_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: regression_separate + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 8 coalitions, 5 new. 
+ + -- Iteration 2 ----------------------------------------------------------------- + i Using 8 of 8 coalitions, 3 new. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -11.352 15.017 14.540 -2.658 -2.945 + 2: 2 42.44 5.552 -6.256 -4.526 -6.666 -1.984 + 3: 3 42.44 9.720 -32.556 7.270 -3.377 1.374 + +# output_symmetric_conditional + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -11.395 7.610 15.278 1.3845 -0.2755 + 2: 2 42.44 2.001 -5.047 -10.833 -0.2829 0.2824 + 3: 3 42.44 4.589 -25.823 1.138 0.2876 2.2401 + +# output_symmetric_marginal_independence + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind, Temp, Month, Day} + * Components with confounding: {Solar.R, Wind, Temp, Month, Day} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -2.644 6.870 16.5974 -0.5859 -7.636 + 2: 2 42.44 -1.315 -3.251 -6.6438 -5.9780 3.308 + 3: 3 42.44 -1.114 -10.549 -0.8839 -7.0244 2.004 + +# output_symmetric_marginal_gaussian + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. 
+ + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind, Temp, Month, Day} + * Components with confounding: {Solar.R, Wind, Temp, Month, Day} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.1241 6.631 15.251 -2.3173 1.161 + 2: 2 42.44 0.8798 -2.652 -6.971 -1.2012 -3.935 + 3: 3 42.44 3.3391 -14.550 -3.145 -0.4127 -2.800 + +# output_asym_caus_conf_TRUE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -12.804 11.755 17.3723 -0.499 -3.222 + 2: 2 42.44 1.471 -2.609 -5.9820 -4.592 -2.168 + 3: 3 42.44 14.736 -31.711 -0.3884 -1.430 1.225 + +# output_asym_caus_conf_FALSE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: No component with confounding + + -- Main computation started -- + + i Using 8 of 8 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -15.4362 17.9420 13.883 -0.626 -3.161 + 2: 2 42.44 -0.8741 -0.4898 -5.682 -4.634 -2.200 + 3: 3 42.44 7.2517 -30.3922 5.763 -1.347 1.156 + +# output_asym_caus_conf_mix + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -12.804 11.755 17.4378 -0.626 -3.161 + 2: 2 42.44 1.471 -2.609 -5.9087 -4.634 -2.200 + 3: 3 42.44 14.736 -31.711 -0.4028 -1.347 1.156 + +# output_asym_caus_conf_mix_n_coal + + Code + (out <- code) + Message + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 6 of 6 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -17.410 26.305 13.958 -7.146 -3.105 + 2: 2 42.44 -2.592 5.563 -3.561 -11.136 -2.154 + 3: 3 42.44 21.260 -43.085 10.992 -8.054 1.319 + +# output_asym_caus_conf_mix_empirical + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. 
+ + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.609 9.859 17.410 -4.136 -0.9212 + 2: 2 42.44 14.220 -17.195 -7.333 -1.904 -1.6682 + 3: 3 42.44 0.661 -20.737 7.258 -5.048 0.2978 + +# output_asym_caus_conf_mix_ctree + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -17.734 20.45 19.217 -5.820 -3.5086 + 2: 2 42.44 19.188 -15.28 -9.429 -8.159 -0.1952 + 3: 3 42.44 5.409 -29.78 8.986 -1.464 -0.7140 + +# output_sym_caus_conf_TRUE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind}, {Temp}, {Month, Day} + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -10.586 9.603 14.085 -2.429 1.9293 + 2: 2 42.44 1.626 -3.712 -2.724 -7.310 -1.7595 + 3: 3 42.44 9.581 -25.344 1.892 -4.089 0.3918 + +# output_sym_caus_conf_FALSE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: No component with confounding + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -7.978 10.871 12.1981 -2.188 -0.3003 + 2: 2 42.44 3.637 -6.474 -9.6711 -1.850 0.4779 + 3: 3 42.44 1.926 -27.039 0.7298 1.404 5.4112 + +# output_sym_caus_conf_mix + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Month, Day} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -10.60 9.600 14.068 -2.464 1.9983 + 2: 2 42.44 1.62 -3.719 -2.722 -7.284 -1.7747 + 3: 3 42.44 9.58 -25.345 1.893 -4.005 0.3084 + +# output_sym_caus_conf_TRUE_group + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 8, + and is therefore set to 2^n_groups = 8. 
+ + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 3 + * Number of observations to explain: 3 + * Causal ordering: {A, B}, {C} + * Components with confounding: {A, B}, {C} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none A B C + + 1: 1 42.44 11.547 16.725 -15.67 + 2: 2 42.44 7.269 -10.685 -10.46 + 3: 3 42.44 -5.058 1.578 -14.09 + +# output_sym_caus_conf_mix_group + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 8, + and is therefore set to 2^n_groups = 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 3 + * Number of observations to explain: 3 + * Causal ordering: {A}, {B}, {C} + * Components with confounding: {A}, {B} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none A B C + + 1: 1 42.44 -13.728 31.822 -5.493 + 2: 2 42.44 3.126 -6.343 -10.662 + 3: 3 42.44 5.310 -17.036 -5.842 + +# output_sym_caus_conf_mix_group_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 8, + and is therefore set to 2^n_groups = 8. + + + -- Iteration 1 ----------------------------------------------------------------- + + -- Convergence info + v Converged after 6 coalitions: + Convergence tolerance reached! + Output + explain_id none A B C + + 1: 1 42.44 -17.921 39.86 -9.334 + 2: 2 42.44 -2.802 -5.92 -5.157 + 3: 3 42.44 -2.233 -20.16 4.828 + +# output_mixed_sym_caus_conf_TRUE + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. 
+ + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + * Components with confounding: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -1.065 18.16 8.030 -0.1478 -14.394 + 2: 2 42.44 4.729 -11.40 -7.837 1.6971 -2.570 + 3: 3 42.44 3.010 -23.62 3.218 4.8728 1.922 + +# output_mixed_sym_caus_conf_TRUE_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + * Components with confounding: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. + + -- Iteration 6 ----------------------------------------------------------------- + i Using 26 of 32 coalitions, 4 new. + + -- Iteration 7 ----------------------------------------------------------------- + i Using 28 of 32 coalitions, 2 new. 
+ + -- Iteration 8 ----------------------------------------------------------------- + i Using 30 of 32 coalitions, 2 new. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -2.13189 8.867 9.390 -1.137 -4.404 + 2: 2 42.44 0.07794 -7.916 -3.340 -1.378 -2.828 + 3: 3 42.44 -2.32289 -13.512 4.116 -1.343 2.462 + +# output_mixed_asym_caus_conf_mixed + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + * Components with confounding: {Solar.R, Wind} + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -2.8521 17.231 5.46662 -6.018 -3.243 + 2: 2 42.44 0.6492 -4.826 -0.02641 -5.053 -6.127 + 3: 3 42.44 -10.7232 -14.690 8.32742 1.080 5.406 + +# output_mixed_asym_caus_conf_mixed_2 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + * Components with confounding: {Temp}, {Day, Month_factor} + + -- Main computation started -- + + i Using 8 of 8 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 1.656 17.903 0.2668 -3.7786 -5.463 + 2: 2 42.44 -2.941 -6.389 4.8876 -4.4941 -6.446 + 3: 3 42.44 4.715 -34.627 13.1031 0.4327 5.776 + +# output_mixed_asym_cond_reg + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or number of coalitions respecting the causal + ordering 8, and is therefore set to 8. + + * Model class: + * Approach: regression_separate + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + * Number of asymmetric coalitions: 8 + * Causal ordering: {Solar.R, Wind}, {Temp}, {Day, Month_factor} + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 8 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 8 of 8 coalitions, 3 new. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -11.300 15.085 14.281 -2.2816 -5.201 + 2: 2 42.44 5.495 -6.312 -4.640 -1.6405 -8.286 + 3: 3 42.44 9.635 -32.764 7.451 0.8945 4.184 + +# output_categorical_asym_causal_mixed_cat + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: categorical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 2 + * Causal ordering: {Solar.R_factor, Wind_factor}, {Ozone_sub30_factor}, + {Month_factor} + * Components with confounding: {Solar.R_factor, Wind_factor} + + -- Main computation started -- + + i Using 16 of 16 coalitions. 
+ Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -10.128 15.35 -10.26 4.526 + 2: 2 42.44 -4.316 -10.80 21.06 -20.769 + +# output_cat_asym_causal_mixed_cat_ad + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: categorical + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R_factor, Wind_factor}, {Ozone_sub30_factor}, + {Month_factor} + * Components with confounding: {Solar.R_factor, Wind_factor} + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 16 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 8 of 16 coalitions, 2 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 10 of 16 coalitions, 2 new. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 12 of 16 coalitions, 2 new. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 14 of 16 coalitions, 2 new. + Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -3.774 8.585 -10.6692 5.35 + 2: 2 42.44 -1.083 -14.855 19.0929 -17.99 + 3: 3 42.44 15.582 -17.251 -0.1388 -16.56 + +# output_categorical_asym_causal_mixed_ctree + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. 
+ + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + * Causal ordering: {Solar.R_factor, Wind_factor}, {Ozone_sub30_factor}, + {Month_factor} + * Components with confounding: {Solar.R_factor, Wind_factor} + + -- Main computation started -- + + i Using 16 of 16 coalitions. + Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -7.113 11.37 -6.100 1.336 + 2: 2 42.44 -2.421 -21.49 23.445 -14.366 + 3: 3 42.44 11.296 -16.94 2.581 -15.297 + diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_FALSE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_FALSE.rds new file mode 100644 index 0000000000000000000000000000000000000000..2c86f38882b122dd2c59401453817df2f2b00635 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_FALSE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_TRUE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_TRUE.rds new file mode 100644 index 0000000000000000000000000000000000000000..7efc59830f424f7ce1f0af73a5d805e9cf7ee423 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_TRUE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix.rds new file mode 100644 index 0000000000000000000000000000000000000000..2f53b0bb045ba7c577c2c5b89db4bf18334496eb Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_ctree.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_ctree.rds new file mode 100644 index 
0000000000000000000000000000000000000000..317c795b6d14ea8fe08218128c2e43b7e74f5c76 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_ctree.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_empirical.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_empirical.rds new file mode 100644 index 0000000000000000000000000000000000000000..7ff35f3b16ed7338050718ef0034c196b2b77e92 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_empirical.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_n_coal.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_n_coal.rds new file mode 100644 index 0000000000000000000000000000000000000000..0d5093ecbe9bf4e2adf0b209276c889687da55b9 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_caus_conf_mix_n_coal.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg.rds new file mode 100644 index 0000000000000000000000000000000000000000..90492aadc7570f41001f9c04d83fa06bb9c71ad0 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg_iterative.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg_iterative.rds new file mode 100644 index 0000000000000000000000000000000000000000..0feb8fa183c24764e61cd73019821b066f4b8162 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asym_cond_reg_iterative.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_asymmetric_conditional.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_asymmetric_conditional.rds 
new file mode 100644 index 0000000000000000000000000000000000000000..b14832cead838939789c39cc15677a039adef341 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_asymmetric_conditional.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_cat_asym_causal_mixed_cat_ad.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_cat_asym_causal_mixed_cat_ad.rds new file mode 100644 index 0000000000000000000000000000000000000000..a1f893d639e5f48411e704c300a76698f403f746 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_cat_asym_causal_mixed_cat_ad.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_cat.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_cat.rds new file mode 100644 index 0000000000000000000000000000000000000000..9382e91c43ba5c878adb880765be30b6622173a5 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_cat.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_ctree.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_ctree.rds new file mode 100644 index 0000000000000000000000000000000000000000..dfa242805c7a4afc02f8d4045d7d1ddf8475eab4 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_categorical_asym_causal_mixed_ctree.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_asym_cond_reg.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_asym_cond_reg.rds new file mode 100644 index 0000000000000000000000000000000000000000..43e42e350238be00a54e49f359f39ffee864f4bc Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_asym_cond_reg.rds differ diff --git 
a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE.rds new file mode 100644 index 0000000000000000000000000000000000000000..acc0b0c51979e2d441375d935803dabc836b49b3 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE_iterative.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE_iterative.rds new file mode 100644 index 0000000000000000000000000000000000000000..dca60b54b314ace3abbecce19e00e0418adad846 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_TRUE_iterative.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed.rds new file mode 100644 index 0000000000000000000000000000000000000000..013be31f6798e38e8ec6d473c8be087506e20bdc Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed_2.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed_2.rds new file mode 100644 index 0000000000000000000000000000000000000000..7282711544cfa773b3439a5f89f91e18835aaf2b Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_mixed_sym_caus_conf_mixed_2.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_FALSE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_FALSE.rds new file mode 100644 index 0000000000000000000000000000000000000000..eb3acdebbfd8cc73a037e9f55bb0e273c92e3335 Binary files /dev/null and 
b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_FALSE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE.rds new file mode 100644 index 0000000000000000000000000000000000000000..254104338ec12ab6af754b89f641e2ef8cc8af76 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE_group.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE_group.rds new file mode 100644 index 0000000000000000000000000000000000000000..ed45a6d4f93a26ddcb2e4ec8a64d6626e02c1564 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_TRUE_group.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix.rds new file mode 100644 index 0000000000000000000000000000000000000000..6af3b48cec60aef74d77b9f68a46bd85ae657cf0 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group.rds new file mode 100644 index 0000000000000000000000000000000000000000..bf5adcd67fe1510eec84d00a7e5b2a61ee40b9f6 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group_iterative.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group_iterative.rds new file mode 100644 index 0000000000000000000000000000000000000000..f35ab24073b96a78d3c0fe14a6355e6574005e8a 
Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_sym_caus_conf_mix_group_iterative.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_conditional.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_conditional.rds new file mode 100644 index 0000000000000000000000000000000000000000..17d3b04e1973c84c4b46859a08423fed57eb6b7b Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_conditional.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_gaussian.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_gaussian.rds new file mode 100644 index 0000000000000000000000000000000000000000..39833ddbaefab2fb51642d68dcd2252aa8689c95 Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_gaussian.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_independence.rds b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_independence.rds new file mode 100644 index 0000000000000000000000000000000000000000..4f2cb1c495f875a824b13b994c312d994dceef2b Binary files /dev/null and b/tests/testthat/_snaps/asymmetric-causal-output/output_symmetric_marginal_independence.rds differ diff --git a/tests/testthat/_snaps/asymmetric-causal-setup.md b/tests/testthat/_snaps/asymmetric-causal-setup.md new file mode 100644 index 0000000000000000000000000000000000000000..984aa7b47e370f4fc71a1de0bdfd0909a3f16f55 --- /dev/null +++ b/tests/testthat/_snaps/asymmetric-causal-setup.md @@ -0,0 +1,183 @@ +# asymmetric erroneous input: `causal_ordering` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + 1:6), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + 
Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + 1:5, 5), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + 2:5, 5), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + 1:2, 4), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + "Solar.R", "Wind", "Temp", "Month", "Day", "Invalid feature name"), + confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `convert_feature_name_to_idx()`: + ! `causal_ordering` contains feature names (`Invalid feature name`) that are not in the data (`Solar.R`, `Wind`, `Temp`, `Month`, `Day`). 
+ +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + "Solar.R", "Wind", "Temp", "Month", "Day", "Day"), confounding = NULL, + approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + "Solar.R", "Wind", "Temp", "Day", "Day"), confounding = NULL, approach = "gaussian", + iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + "Solar.R", "Wind"), confounding = NULL, approach = "gaussian", iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all feature names or indices exactly once. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + c("Solar.R", "Wind", "Temp", "Month"), "Day"), confounding = NULL, + approach = "gaussian", group = list(A = c("Solar.R", "Wind"), B = "Temp", C = c( + "Month", "Day")), iterative = FALSE) + Condition + Error in `convert_feature_name_to_idx()`: + ! `causal_ordering` contains group names (`Solar.R`, `Wind`, `Temp`, `Month`, `Day`) that are not in the data (`A`, `B`, `C`). 
+ +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + c("A", "C"), "Wrong name"), confounding = NULL, approach = "gaussian", + group = list(A = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + iterative = FALSE) + Condition + Error in `convert_feature_name_to_idx()`: + ! `causal_ordering` contains group names (`Wrong name`) that are not in the data (`A`, `B`, `C`). + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = TRUE, causal_ordering = list( + c("A"), "B"), confounding = NULL, approach = "gaussian", group = list(A = c( + "Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), iterative = FALSE) + Condition + Error in `check_and_set_causal_ordering()`: + ! `causal_ordering` is incomplete/incorrect. It must contain all group names or indices exactly once. + +# asymmetric erroneous input: `approach` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = FALSE, causal_ordering = list( + 1:2, 3:4, 5), confounding = TRUE, approach = c("gaussian", "independence", + "empirical", "gaussian"), iterative = FALSE) + Condition + Error in `check_and_set_causal_sampling()`: + ! Causal Shapley values is not applicable for combined approaches. + +# asymmetric erroneous input: `asymmetric` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = c(FALSE, FALSE), + causal_ordering = list(1:2, 3:4, 5), confounding = TRUE, approach = "gaussian", + iterative = FALSE) + Condition + Error in `get_parameters()`: + ! `asymmetric` must be a single logical. 
+ +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = "Must be a single logical", + causal_ordering = list(1:2, 3:4, 5), confounding = TRUE, approach = "gaussian", + iterative = FALSE) + Condition + Error in `get_parameters()`: + ! `asymmetric` must be a single logical. + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = 1L, causal_ordering = list( + 1:2, 3:4, 5), confounding = TRUE, approach = "gaussian", iterative = FALSE) + Condition + Error in `get_parameters()`: + ! `asymmetric` must be a single logical. + +# asymmetric erroneous input: `confounding` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = FALSE, causal_ordering = list( + 1:2, 3:4, 5), confounding = c("A", "B", "C"), approach = "gaussian", + iterative = FALSE) + Condition + Error in `get_parameters()`: + ! `confounding` must be a logical (vector). + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, asymmetric = FALSE, causal_ordering = list( + 1:2, 3:4, 5), confounding = c(TRUE, FALSE), approach = "gaussian", + iterative = FALSE) + Condition + Error in `check_and_set_confounding()`: + ! `confounding` must either be a single logical or a vector of logicals of the same length as the number of components in `causal_ordering` (3). + diff --git a/tests/testthat/_snaps/forecast-output.md b/tests/testthat/_snaps/forecast-output.md index dbc55f06fa91bfb95ec0fb8478a86f8d32c6c29e..e2fae1c194ae50e3cc5fdd42916d0f9fdb3af9c4 100644 --- a/tests/testthat/_snaps/forecast-output.md +++ b/tests/testthat/_snaps/forecast-output.md @@ -6,6 +6,19 @@ Note: Feature names extracted from the model contains NA. 
Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 4, + and is therefore set to 2^n_features = 4. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 2 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 4 of 4 coalitions. Output explain_idx horizon none Temp.1 Temp.2 @@ -24,6 +37,19 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 128, + and is therefore set to 2^n_features = 128. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 7 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 128 of 128 coalitions. Output explain_idx horizon none Temp.1 Temp.2 Wind.1 Wind.2 Wind.F1 Wind.F2 @@ -42,6 +68,82 @@ 5: 0.5630 6: -0.7615 +# forecast_output_arima_numeric_iterative + + Code + (out <- code) + Message + Note: Feature names extracted from the model contains NA. + Consistency checks between model and data is therefore disabled. + + * Model class: + * Approach: empirical + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 9 + * Number of observations to explain: 2 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 10 of 512 coalitions, 10 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 30 of 512 coalitions, 4 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 78 of 512 coalitions, 6 new. 
+ Output + explain_idx horizon none Temp.1 Temp.2 Temp.3 Wind.1 Wind.2 Wind.3 + + 1: 149 1 77.88 -2.795 -4.5597 -1.114 1.564 -1.8995 0.2087 + 2: 150 1 77.88 4.024 -0.5774 -4.589 -2.234 0.1985 -2.2827 + 3: 149 2 77.88 -3.701 -4.2427 -1.326 1.465 -1.9227 0.7060 + 4: 150 2 77.88 3.460 -0.9158 -5.264 -2.452 0.7709 -1.7864 + 5: 149 3 77.88 -4.721 -3.4208 -1.503 1.172 -0.4564 -0.6058 + 6: 150 3 77.88 2.811 0.4206 -5.361 -1.388 0.0752 -0.2130 + Wind.F1 Wind.F2 Wind.F3 + + 1: -1.9118 NA NA + 2: -0.1747 NA NA + 3: -1.1883 -0.6744 NA + 4: 0.7128 1.9982 NA + 5: -1.5436 -0.5418 2.8952 + 6: -0.6202 -0.8545 0.4549 + +# forecast_output_arima_numeric_iterative_groups + + Code + (out <- code) + Message + Note: Feature names extracted from the model contains NA. + Consistency checks between model and data is therefore disabled. + + * Model class: + * Approach: empirical + * Iterative estimation: TRUE + * Number of group-wise Shapley values: 10 + * Number of observations to explain: 2 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 10 of 1024 coalitions, 10 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 28 of 1024 coalitions, 2 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 56 of 1024 coalitions, 12 new. + Output + explain_idx horizon none Temp Wind Solar.R Ozone + + 1: 149 1 77.88 -4.680 -3.6712 0.3230 -1.253 + 2: 150 1 77.88 -2.487 -3.6317 1.8415 -0.891 + 3: 149 2 77.88 -6.032 -4.1973 2.5973 -2.402 + 4: 150 2 77.88 -3.124 0.1986 0.8258 -2.245 + 5: 149 3 77.88 -7.777 1.1382 0.6962 -3.267 + 6: 150 3 77.88 -3.142 -1.6674 2.9047 -2.024 + # forecast_output_arima_numeric_no_xreg Code @@ -50,6 +152,19 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. 
+ Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 4, + and is therefore set to 2^n_features = 4. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 2 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 4 of 4 coalitions. Output explain_idx horizon none Temp.1 Temp.2 @@ -68,6 +183,19 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 16, + and is therefore set to 2^n_groups = 16. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 4 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 16 of 16 coalitions. Output explain_idx horizon none Temp Wind @@ -86,3550 +214,26 @@ Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. 
- Condition - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - 
data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] 
is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the 
number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in 
`matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a 
sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of 
rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data 
length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or 
multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - 
Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] - Warning in `matrix()`: - data length [2] is not a sub-multiple or multiple of the number of rows [3] + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 8, + and is therefore set to 2^n_features = 8. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 3 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 8 of 8 coalitions. Output explain_idx horizon none Wind.F1 Wind.F2 Wind.F3 - 1: 149 1 77.88 -9.391 NA NA - 2: 150 1 77.88 -4.142 NA NA - 3: 149 2 77.88 -4.699 -4.6989 NA - 4: 150 2 77.88 -2.074 -2.0745 NA - 5: 149 3 77.88 -3.130 -4.6234 -3.130 - 6: 150 3 77.88 -1.381 -0.7147 -1.381 + 1: 149 1 77.88 -10.507 NA NA + 2: 150 1 77.88 -5.635 NA NA + 3: 149 2 77.88 -4.696 -6.189 NA + 4: 150 2 77.88 -2.071 -1.405 NA + 5: 149 3 77.88 -3.133 -3.133 -2.46 + 6: 150 3 77.88 -1.383 -1.383 -1.91 diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds index ca1606114745c2c2e706cf500db286269da43fcf..f7ed9883499bb93956c45811f1a69dc3da8c7d1a 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_ar_numeric.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds index 
bc7ca40af114d3c9a013d12bc930bd3cc28b943a..3b42ce6c3fddfce0104f25c3d1056b537dc5f1d0 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative.rds new file mode 100644 index 0000000000000000000000000000000000000000..6a691f19b3f83e46c07ee78eb862a9cf724e8a8a Binary files /dev/null and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative_groups.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative_groups.rds new file mode 100644 index 0000000000000000000000000000000000000000..0061a2ab86c9461e279cd065be709494602349ca Binary files /dev/null and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_iterative_groups.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds index f0974f44958f0604e03dc9fafe89ff9cdeb65b75..c51ff8268c263c308029da6d5271c2998dbed73b 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_lags.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds index 8fecb3578ad936235064eb00db70c0a8a98b169e..08001992ae7be3546bb1679696b432dde590e2d3 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds and 
b/tests/testthat/_snaps/forecast-output/forecast_output_arima_numeric_no_xreg.rds differ diff --git a/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds b/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds index 5dfbb93b20b46ae285f0b5cb65d1770b2ef2fbea..1ab8c99a16ed4b9a9a43dadffbfdbe2c8563a62d 100644 Binary files a/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds and b/tests/testthat/_snaps/forecast-output/forecast_output_forecast_ARIMA_group_numeric.rds differ diff --git a/tests/testthat/_snaps/forecast-setup.md b/tests/testthat/_snaps/forecast-setup.md index b3b968b2333562c9822c029be2d18c4250373de1..fdf2616a9b53a1241479f8c9df9a780fdd3164f9 100644 --- a/tests/testthat/_snaps/forecast-setup.md +++ b/tests/testthat/_snaps/forecast-setup.md @@ -3,14 +3,18 @@ Code model_custom_arima_temp <- model_arima_temp class(model_custom_arima_temp) <- "whatever" - explain_forecast(model = model_custom_arima_temp, y = data[1:150, "Temp"], - xreg = data[, "Wind"], train_idx = 2:148, explain_idx = 149:150, - explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_custom_arima_temp, y = data_arima[ + 1:150, "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149: + 150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", + phi0 = p0_ar) Message Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 16, + and is therefore set to 2^n_groups = 16. + Condition Error in `get_predict_model()`: ! 
You passed a model to explain() which is not natively supported, and did not supply the 'predict_model' function to explain(). @@ -19,11 +23,11 @@ # erroneous input: `x_train/x_explain` Code - y_wrong_format <- data[, c("Temp", "Wind")] - explain_forecast(model = model_arima_temp, y = y_wrong_format, xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + y_wrong_format <- data_arima[, c("Temp", "Wind")] + explain_forecast(testing = TRUE, model = model_arima_temp, y = y_wrong_format, + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", + phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! `y` has 2 columns (Temp,Wind). @@ -33,11 +37,11 @@ --- Code - xreg_wrong_format <- data[, c("Temp", "Wind")] - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = xreg_wrong_format, - train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + xreg_wrong_format <- data_arima[, c("Temp", "Wind")] + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = xreg_wrong_format, train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", + phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! `xreg` has 2 columns (Temp,Wind). 
@@ -47,12 +51,12 @@ --- Code - xreg_no_column_names <- data[, "Wind"] + xreg_no_column_names <- data_arima[, "Wind"] names(xreg_no_column_names) <- NULL - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = xreg_no_column_names, - train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = xreg_no_column_names, train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", + phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! `xreg` misses column names. @@ -60,45 +64,48 @@ # erroneous input: `model` Code - explain_forecast(y = data[1:150, "Temp"], xreg = data[, "Wind"], train_idx = 2: - 148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, - horizon = 3, approach = "independence", prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, y = data_arima[1:150, "Temp"], xreg = data_arima[ + , "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, + explain_xreg_lags = 2, horizon = 3, approach = "independence", phi0 = p0_ar) Condition Error in `explain_forecast()`: ! 
argument "model" is missing, with no default -# erroneous input: `prediction_zero` +# erroneous input: `phi0` Code p0_wrong_length <- p0_ar[1:2] - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_wrong_length, - n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", + phi0 = p0_wrong_length) Condition Error in `get_parameters()`: - ! `prediction_zero` (77.8823529411765, 77.8823529411765) must be numeric and match the output size of the model (3). + ! `phi0` (77.8823529411765, 77.8823529411765) must be numeric and match the output size of the model (3). -# erroneous input: `n_combinations` +# erroneous input: `max_n_coalitions` Code horizon <- 3 explain_y_lags <- 2 explain_xreg_lags <- 2 - n_combinations <- horizon + explain_y_lags + explain_xreg_lags - 1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags, - explain_xreg_lags = explain_xreg_lags, horizon = horizon, approach = "independence", - prediction_zero = p0_ar, n_batches = 1, n_combinations = n_combinations, + n_coalitions <- horizon + explain_y_lags + explain_xreg_lags - 1 + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags, explain_xreg_lags = explain_xreg_lags, + horizon = horizon, approach = "independence", phi0 = p0_ar, max_n_coalitions = n_coalitions, group_lags = FALSE) Message Note: Feature names extracted from the model contains NA. 
Consistency checks between model and data is therefore disabled. + Success with message: + max_n_coalitions is smaller than max(10, n_features + 1 = 8),which will result in unreliable results. + It is therefore set to 10. + Condition - Error in `check_n_combinations()`: - ! `n_combinations` (6) has to be greater than the number of components to decompose the forecast onto: - `horizon` (3) + `explain_y_lags` (2) + sum(`explain_xreg_lags`) (2). + Error in `check_iterative_args()`: + ! `iterative_args$initial_n_coalitions` must be a single integer between 2 and `max_n_coalitions`. --- @@ -106,29 +113,47 @@ horizon <- 3 explain_y_lags <- 2 explain_xreg_lags <- 2 - n_combinations <- 1 + 1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags, - explain_xreg_lags = explain_xreg_lags, horizon = horizon, approach = "independence", - prediction_zero = p0_ar, n_batches = 1, n_combinations = n_combinations, + n_coalitions <- 1 + 1 + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags, explain_xreg_lags = explain_xreg_lags, + horizon = horizon, approach = "independence", phi0 = p0_ar, max_n_coalitions = n_coalitions, group_lags = TRUE) Message Note: Feature names extracted from the model contains NA. Consistency checks between model and data is therefore disabled. - Condition - Error in `check_n_combinations()`: - ! `n_combinations` (2) has to be greater than the number of components to decompose the forecast onto: - ncol(`xreg`) (1) + 1 + Success with message: + max_n_coalitions is smaller than max(10, n_groups + 1 = 5),which will result in unreliable results. + It is therefore set to 10. 
+ + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 4 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 5 of 16 coalitions. + Output + explain_idx horizon none Temp Wind + + 1: 149 1 77.88 -8.252 -2.2557 + 2: 150 1 77.88 -2.977 -2.6587 + 3: 149 2 77.88 -8.252 -2.6320 + 4: 150 2 77.88 -2.977 -0.4990 + 5: 149 3 77.88 -8.256 -0.4697 + 6: 150 3 77.88 -2.981 -1.6952 # erroneous input: `train_idx` Code train_idx_too_short <- 2 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = train_idx_too_short, explain_idx = 149:150, - explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = train_idx_too_short, + explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `train_idx` must be a vector of positive finite integers and length > 1. @@ -137,10 +162,10 @@ Code train_idx_not_integer <- c(3:5) + 0.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = train_idx_not_integer, explain_idx = 149:150, - explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = train_idx_not_integer, + explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `train_idx` must be a vector of positive finite integers and length > 1. 
@@ -149,10 +174,10 @@ Code train_idx_out_of_range <- 1:5 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = train_idx_out_of_range, explain_idx = 149:150, - explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = train_idx_out_of_range, + explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! The train (`train_idx`) and explain (`explain_idx`) indices must fit in the lagged data. @@ -162,10 +187,10 @@ Code explain_idx_not_integer <- c(3:5) + 0.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = explain_idx_not_integer, + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = explain_idx_not_integer, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_idx` must be a vector of positive finite integers. @@ -174,10 +199,10 @@ Code explain_idx_out_of_range <- 1:5 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = explain_idx_out_of_range, + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = explain_idx_out_of_range, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! 
The train (`train_idx`) and explain (`explain_idx`) indices must fit in the lagged data. @@ -187,10 +212,10 @@ Code explain_y_lags_negative <- -1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_negative, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags_negative, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_y_lags` must be a vector of positive finite integers. @@ -199,10 +224,10 @@ Code explain_y_lags_not_integer <- 2.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_not_integer, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags_not_integer, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_y_lags` must be a vector of positive finite integers. 
@@ -211,10 +236,10 @@ Code explain_y_lags_more_than_one <- c(1, 2) - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_more_than_one, - explain_xreg_lags = 2, horizon = 3, approach = "independence", prediction_zero = p0_ar, - n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = explain_y_lags_more_than_one, explain_xreg_lags = 2, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! `y` has 1 columns (Temp). @@ -225,9 +250,9 @@ Code explain_y_lags_zero <- 0 - explain_forecast(model = model_arima_temp_noxreg, y = data[1:150, "Temp"], - train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 0, horizon = 3, - approach = "independence", prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp_noxreg, y = data_arima[ + 1:150, "Temp"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 0, + horizon = 3, approach = "independence", phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! 
`explain_y_lags=0` is not allowed for models without exogeneous variables @@ -236,10 +261,10 @@ Code explain_xreg_lags_negative <- -2 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = explain_xreg_lags_negative, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = explain_xreg_lags_negative, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_xreg_lags` must be a vector of positive finite integers. @@ -248,10 +273,10 @@ Code explain_xreg_lags_not_integer <- 2.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = explain_xreg_lags_not_integer, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = explain_xreg_lags_not_integer, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `explain_xreg_lags` must be a vector of positive finite integers. 
@@ -260,10 +285,10 @@ Code explain_x_lags_wrong_length <- c(1, 2) - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = explain_x_lags_wrong_length, horizon = 3, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = explain_x_lags_wrong_length, horizon = 3, + approach = "independence", phi0 = p0_ar) Condition Error in `get_data_forecast()`: ! `xreg` has 1 columns (Wind). @@ -274,10 +299,10 @@ Code horizon_negative <- -2 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = 2, horizon = horizon_negative, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = 2, horizon = horizon_negative, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `horizon` must be a vector (or scalar) of positive integers. 
@@ -286,10 +311,10 @@ Code horizon_not_integer <- 2.1 - explain_forecast(model = model_arima_temp, y = data[1:150, "Temp"], xreg = data[, - "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, - explain_xreg_lags = 2, horizon = horizon_not_integer, approach = "independence", - prediction_zero = p0_ar, n_batches = 1) + explain_forecast(testing = TRUE, model = model_arima_temp, y = data_arima[1:150, + "Temp"], xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, + explain_y_lags = 2, explain_xreg_lags = 2, horizon = horizon_not_integer, + approach = "independence", phi0 = p0_ar) Condition Error in `get_parameters()`: ! `horizon` must be a vector (or scalar) of positive integers. diff --git a/tests/testthat/_snaps/output.md b/tests/testthat/_snaps/output.md deleted file mode 100644 index d241853e10a5a1ce7115d269bfa1729fee9885dd..0000000000000000000000000000000000000000 --- a/tests/testthat/_snaps/output.md +++ /dev/null @@ -1,356 +0,0 @@ -# output_lm_numeric_independence - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_lm_numeric_independence_MSEv_Shapley_weights - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_lm_numeric_empirical - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -13.252 15.541 12.826 -5.77179 3.259 - 2: 42.44 2.758 -3.325 -7.992 -7.12800 1.808 - 3: 42.44 6.805 -22.126 3.730 -0.09235 -5.885 - -# output_lm_numeric_empirical_n_combinations - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -5.795 15.320 8.557 -7.547 2.066 - 2: 42.44 3.266 -3.252 -7.693 -7.663 1.462 - 3: 42.44 4.290 -24.395 6.739 -1.006 -3.197 - -# 
output_lm_numeric_empirical_independence - - Code - (out <- code) - Condition - Warning in `setup_approach.empirical()`: - Using empirical.type = 'independence' for approach = 'empirical' is deprecated. - Please use approach = 'independence' instead. - Message - - Success with message: - empirical.eta force set to 1 for empirical.type = 'independence' - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_lm_numeric_empirical_AICc_each - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -15.66 6.823 17.5092 0.2463 3.6847 - 2: 42.44 10.70 -1.063 -10.6804 -13.0305 0.1983 - 3: 42.44 14.65 -19.946 0.9675 -7.3433 -5.8946 - -# output_lm_numeric_empirical_AICc_full - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -14.98 6.3170 17.4103 0.2876 3.5623 - 2: 42.44 12.42 0.1482 -10.2338 -16.4096 0.1967 - 3: 42.44 15.74 -19.7250 0.9992 -8.6950 -5.8886 - -# output_lm_numeric_gaussian - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -8.117 7.438 14.0026 0.8602 -1.5813 - 2: 42.44 5.278 -5.219 -12.1079 -0.8073 -1.0235 - 3: 42.44 7.867 -25.995 -0.1377 -0.2368 0.9342 - -# output_lm_numeric_copula - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -5.960 7.046 13.863 -0.274 -2.074 - 2: 42.44 4.482 -4.892 -10.491 -1.659 -1.319 - 3: 42.44 6.587 -25.533 1.279 -1.043 1.142 - -# output_lm_numeric_ctree - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.124 9.509 17.139 -1.4711 -3.451 - 2: 42.44 5.342 -6.097 -8.232 -2.8129 -2.079 - 3: 42.44 6.901 -21.079 -4.687 0.1494 1.146 - -# output_lm_numeric_vaeac - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -6.534 9.146 18.8166 -5.238 -3.5884 - 2: 42.44 1.421 -5.329 -6.8472 -3.668 0.5436 - 3: 42.44 7.073 -18.914 -0.6391 
-6.038 0.9493 - -# output_lm_categorical_ctree - - Code - (out <- code) - Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 -6.206 15.38 -6.705 -2.973 - 2: 42.44 -5.764 -17.71 21.866 -13.219 - 3: 42.44 7.101 -21.78 1.730 -5.413 - -# output_lm_categorical_vaeac - - Code - (out <- code) - Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 1.795 10.32 -6.919 -5.704 - 2: 42.44 -2.438 -18.15 20.755 -14.999 - 3: 42.44 8.299 -23.71 8.751 -11.708 - -# output_lm_categorical_categorical - - Code - (out <- code) - Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 13.656 -19.73 4.369 -16.659 - 2: 42.44 -5.448 11.31 -11.445 5.078 - 3: 42.44 -7.493 -12.27 19.672 -14.744 - -# output_lm_categorical_independence - - Code - (out <- code) - Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 -5.252 13.95 -7.041 -2.167 - 2: 42.44 -5.252 -15.61 20.086 -14.050 - 3: 42.44 4.833 -15.61 0.596 -8.178 - -# output_lm_ts_timeseries - - Code - (out <- code) - Output - none S1 S2 S3 S4 - - 1: 4.895 -0.5261 0.7831 -0.21023 -0.3885 - 2: 4.895 -0.6310 1.6288 -0.04498 -2.9298 - -# output_lm_numeric_comb1 - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -8.746 9.03 15.366 -2.619 -0.4293 - 2: 42.44 3.126 -4.50 -7.789 -4.401 -0.3161 - 3: 42.44 7.037 -22.86 -1.837 0.607 -0.5181 - -# output_lm_numeric_comb2 - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.294 9.327 17.31641 -1.754 -2.9935 - 2: 42.44 5.194 -5.506 -8.45049 -2.935 -2.1810 - 3: 42.44 6.452 -22.967 -0.09553 -1.310 0.3519 - -# output_lm_numeric_comb3 - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -6.952 10.777 12.160 -3.641 0.25767 - 2: 42.44 2.538 -2.586 -8.503 -5.376 0.04789 - 3: 42.44 5.803 -22.122 3.362 -2.926 -1.68514 - -# output_lm_mixed_independence - - Code - (out <- code) - Output - none 
Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - -# output_lm_mixed_ctree - - Code - (out <- code) - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -9.165 11.815 13.184 -0.4473 -4.802 - 2: 42.44 3.652 -5.782 -6.524 -0.4349 -6.295 - 3: 42.44 6.268 -21.441 -7.323 1.6330 10.262 - -# output_lm_mixed_vaeac - - Code - (out <- code) - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -3.629 8.898 17.330 -2.5409 -9.4742 - 2: 42.44 3.938 -3.933 -8.190 0.6284 -7.8259 - 3: 42.44 5.711 -15.928 -3.216 2.2431 0.5899 - -# output_lm_mixed_comb - - Code - (out <- code) - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -7.886 10.511 16.292 -0.9519 -7.382 - 2: 42.44 5.001 -4.925 -7.015 -1.0954 -7.349 - 3: 42.44 5.505 -20.583 -4.328 0.7825 8.023 - -# output_custom_lm_numeric_independence_1 - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_custom_lm_numeric_independence_2 - - Code - (out <- code) - Message - Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). - Consistency checks between model and data is therefore disabled. - - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_custom_xgboost_mixed_dummy_ctree - - Code - (out <- code) - Message - Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). - Consistency checks between model and data is therefore disabled. 
- - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -5.603 13.05 20.43 0.08508 -0.2664 - 2: 42.44 4.645 -12.57 -16.65 1.29133 -2.1574 - 3: 42.44 5.451 -14.01 -19.72 1.32503 6.3851 - -# output_lm_numeric_interaction - - Code - (out <- code) - Output - none Solar.R Wind - - 1: 42.44 -13.818 10.579 - 2: 42.44 4.642 -6.287 - 3: 42.44 4.452 -34.602 - -# output_lm_numeric_ctree_parallelized - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.124 9.509 17.139 -1.4711 -3.451 - 2: 42.44 5.342 -6.097 -8.232 -2.8129 -2.079 - 3: 42.44 6.901 -21.079 -4.687 0.1494 1.146 - -# output_lm_numeric_independence_more_batches - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - -# output_lm_numeric_empirical_progress - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -13.252 15.541 12.826 -5.77179 3.259 - 2: 42.44 2.758 -3.325 -7.992 -7.12800 1.808 - 3: 42.44 6.805 -22.126 3.730 -0.09235 -5.885 - -# output_lm_numeric_independence_keep_samp_for_vS - - Code - (out <- code) - Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -4.537 8.269 17.517 -5.581 -3.066 - 2: 42.44 2.250 -3.345 -5.232 -5.581 -1.971 - 3: 42.44 3.708 -18.610 -1.440 -2.541 1.316 - diff --git a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds deleted file mode 100644 index faa720cd357948e864839e8e066f3382661905bd..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_1.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds b/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds deleted file mode 100644 index 
faa720cd357948e864839e8e066f3382661905bd..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_custom_lm_numeric_independence_2.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds b/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds deleted file mode 100644 index f6b3d80cad358b1953121d106512f4138d6ad1f4..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_custom_xgboost_mixed_dummy_ctree.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds b/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds deleted file mode 100644 index eddfb6733bc46da2526117855f573f4f7c9bb50b..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_ctree.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_independence.rds b/tests/testthat/_snaps/output/output_lm_categorical_independence.rds deleted file mode 100644 index 140ceb5d0a3cf437a45b60c60b4c51eb94118397..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_independence.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_method.rds b/tests/testthat/_snaps/output/output_lm_categorical_method.rds deleted file mode 100644 index e5c62746f9cdfc76678491b3449c54fa2868183e..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_method.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_categorical_vaeac.rds b/tests/testthat/_snaps/output/output_lm_categorical_vaeac.rds deleted file mode 100644 index 94b04392ce6e49d0364d0497df62460f8bc7fc43..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_categorical_vaeac.rds and /dev/null differ diff 
--git a/tests/testthat/_snaps/output/output_lm_mixed_comb.rds b/tests/testthat/_snaps/output/output_lm_mixed_comb.rds deleted file mode 100644 index 8300a78bc7586d33bca72a0f46c5c30dc4a1aadf..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_comb.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds b/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds deleted file mode 100644 index 429c7837a594cea1778ee6ee867fe9064407aa81..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_ctree.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_independence.rds b/tests/testthat/_snaps/output/output_lm_mixed_independence.rds deleted file mode 100644 index 14024d680a8aba20f6404fae35051ef8e7d9c3a0..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_independence.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_mixed_vaeac.rds b/tests/testthat/_snaps/output/output_lm_mixed_vaeac.rds deleted file mode 100644 index ab0abc134f899060e75bae0112d0ee84cdbd6887..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_mixed_vaeac.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds deleted file mode 100644 index 67e8ca9820a852bb59ae19078bf41b09b51c5b9b..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb1.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds deleted file mode 100644 index aebe607e89909671b104996cb5424e1cba512d6d..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb2.rds and /dev/null 
differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds b/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds deleted file mode 100644 index 8dfecc3ebf224852f31524d504886b4632c7047f..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_comb3.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_copula.rds b/tests/testthat/_snaps/output/output_lm_numeric_copula.rds deleted file mode 100644 index f0ce11bc296b2e46da060914852c7573b2f17b20..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_copula.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds deleted file mode 100644 index cd92b5926b609bb839d3a945f4466ce66d19d286..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_ctree.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds b/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds deleted file mode 100644 index cd92b5926b609bb839d3a945f4466ce66d19d286..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_ctree_parallelized.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds deleted file mode 100644 index cc396937fc707d09c2314b96be2d69f40bc7e9f7..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds deleted file mode 100644 index 
f3dd31c55f5c7bfa47596f967275d4491dd6b59f..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_each.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds deleted file mode 100644 index 45c1baa528581960b0193647677b1f1a3ad26d0c..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_AICc_full.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds deleted file mode 100644 index 873268bc81e9d81978a6eff13184f1d6719a03c5..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_independence.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds deleted file mode 100644 index acf7f5e78c3c15745458d26b75072ce15f0bb134..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_n_combinations.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds b/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds deleted file mode 100644 index b7311ca0a75bccd84293586f6826978154bd565f..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_empirical_progress.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds b/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds deleted file mode 100644 index 628f63a1cfd526d271640f68438b1ba1b4837b52..0000000000000000000000000000000000000000 Binary files 
a/tests/testthat/_snaps/output/output_lm_numeric_gaussian.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence.rds deleted file mode 100644 index 46cdda26bf5ddbfa8cd35c503c8a481eaea44c33..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds deleted file mode 100644 index 5273db3651bd5d397d545257c3f35cf1bd30582c..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_MSEv_Shapley_weights.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds deleted file mode 100644 index b9142b8571feeb51b84391c6a61f9f918fda02ca..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_keep_samp_for_vS.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds b/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds deleted file mode 100644 index e05527ffc239d4d830dfa867e3bab627414b3fd5..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_independence_n_batches_10.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds b/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds deleted file mode 100644 index 4696060f750eb1d25f0f17eee1c2559fa7528bfc..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_interaction.rds 
and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_numeric_vaeac.rds b/tests/testthat/_snaps/output/output_lm_numeric_vaeac.rds deleted file mode 100644 index e68838c69ef22f82ced81adf3571e80ea5bd3659..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_numeric_vaeac.rds and /dev/null differ diff --git a/tests/testthat/_snaps/output/output_lm_timeseries_method.rds b/tests/testthat/_snaps/output/output_lm_timeseries_method.rds deleted file mode 100644 index cf15a0fb06404165ea8e600ddcdad68d5b3baa99..0000000000000000000000000000000000000000 Binary files a/tests/testthat/_snaps/output/output_lm_timeseries_method.rds and /dev/null differ diff --git a/tests/testthat/_snaps/plot/beeswarm-plot-default.svg b/tests/testthat/_snaps/plot/beeswarm-plot-default.svg index d51801e9dc8e4e5adfdf98f2dfbe847a4ddb5cd6..32b3989d083dd6fec9563f10569f2813424451d2 100644 --- a/tests/testthat/_snaps/plot/beeswarm-plot-default.svg +++ b/tests/testthat/_snaps/plot/beeswarm-plot-default.svg @@ -47,7 +47,7 @@ - + diff --git a/tests/testthat/_snaps/plot/beeswarm-plot-index-x-explain-1-2.svg b/tests/testthat/_snaps/plot/beeswarm-plot-index-x-explain-1-2.svg index 0e8d2fc048ccc4cb7cb79bdec69cbce35970450b..e3afa735d645f7b35dab871c4e192ff87b995c58 100644 --- a/tests/testthat/_snaps/plot/beeswarm-plot-index-x-explain-1-2.svg +++ b/tests/testthat/_snaps/plot/beeswarm-plot-index-x-explain-1-2.svg @@ -41,7 +41,7 @@ - + diff --git a/tests/testthat/_snaps/plot/beeswarm-plot-new-colors.svg b/tests/testthat/_snaps/plot/beeswarm-plot-new-colors.svg index 81f75d489f3ad2b2e8e7b20adc547edcd3481ac0..a177c19ff0095a3d63a6b0a84015d32655eb8174 100644 --- a/tests/testthat/_snaps/plot/beeswarm-plot-new-colors.svg +++ b/tests/testthat/_snaps/plot/beeswarm-plot-new-colors.svg @@ -47,7 +47,7 @@ - + diff --git a/tests/testthat/_snaps/plot/msev-bar-50-ci.svg b/tests/testthat/_snaps/plot/msev-bar-50-ci.svg index 
20d6fe5e48392dbaf9ad9c3a6331b8106695d61a..2b246ec31121c8c3f9001877fde8e419fcd220b9 100644 --- a/tests/testthat/_snaps/plot/msev-bar-50-ci.svg +++ b/tests/testthat/_snaps/plot/msev-bar-50-ci.svg @@ -28,18 +28,18 @@ - - + + - - - - - - + + + + + + @@ -94,13 +94,13 @@ 32 -combinations and - -3 - -explicands with - -50 -% CI +coalitions and + +3 + +explicands with + +50 +% CI diff --git a/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg b/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg index c6d232890d8ae1f51be4e2995792175982d2a37c..ffd02e4dd7276fe2ee0caf7735ed3e22b0a6dad9 100644 --- a/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg +++ b/tests/testthat/_snaps/plot/msev-bar-with-ci-different-width.svg @@ -28,8 +28,8 @@ - - + + @@ -80,10 +80,10 @@ 32 -combinations and - -3 - -explicands +coalitions and + +3 + +explicands diff --git a/tests/testthat/_snaps/plot/msev-bar-without-ci.svg b/tests/testthat/_snaps/plot/msev-bar-without-ci.svg index 053323a672cf1ae33eeae3995e1da8bc12daac84..a59ba4a7254560ff812d41110f0a8e4895451486 100644 --- a/tests/testthat/_snaps/plot/msev-bar-without-ci.svg +++ b/tests/testthat/_snaps/plot/msev-bar-without-ci.svg @@ -28,8 +28,8 @@ - - + + @@ -80,10 +80,10 @@ 32 -combinations and - -3 - -explicands +coalitions and + +3 + +explicands diff --git a/tests/testthat/_snaps/plot/msev-bar.svg b/tests/testthat/_snaps/plot/msev-bar.svg index 57d503c7051e44f1e3d0310a251fda1322955947..e847f531ba0839acd85b6a78af0f2aca863dbac6 100644 --- a/tests/testthat/_snaps/plot/msev-bar.svg +++ b/tests/testthat/_snaps/plot/msev-bar.svg @@ -28,18 +28,18 @@ - - + + - - - - - - + + + + + + @@ -92,13 +92,13 @@ 32 -combinations and - -3 - -explicands with - -95 -% CI +coalitions and + +3 + +explicands with + +95 +% CI diff --git a/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg b/tests/testthat/_snaps/plot/msev-coalition-bar-specified-width.svg similarity index 74% rename from 
tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg rename to tests/testthat/_snaps/plot/msev-coalition-bar-specified-width.svg index f6752c74c1738bf8f5ef9d3cae58a0cd8918bf64..50d7b01d79c5f69d4f657bef71051a3d35c4b61d 100644 --- a/tests/testthat/_snaps/plot/msev-combination-bar-specified-width.svg +++ b/tests/testthat/_snaps/plot/msev-coalition-bar-specified-width.svg @@ -27,139 +27,139 @@ - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 -250 -500 -750 -1000 +250 +500 +750 +1000 - - - - + + + + @@ -220,13 +220,13 @@ 29 30 31 -id_combination -M -S -E -v - -(combination) +id_coalition +M +S +E +v + +(coalition) Method @@ -250,6 +250,6 @@ 3 -explicands for each combination +explicands for each coalition diff --git a/tests/testthat/_snaps/plot/msev-combination-bar.svg b/tests/testthat/_snaps/plot/msev-coalition-bar.svg similarity index 78% rename from tests/testthat/_snaps/plot/msev-combination-bar.svg rename to tests/testthat/_snaps/plot/msev-coalition-bar.svg index 1ce6b589182a41868881838374eb4fa2a73599ac..9387f05cb1df611de6dbb656fe745839cd0708d7 100644 --- a/tests/testthat/_snaps/plot/msev-combination-bar.svg +++ b/tests/testthat/_snaps/plot/msev-coalition-bar.svg @@ -57,66 +57,66 @@ - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + - + - - - + + + - - + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -237,185 +237,185 @@ - - - - - - - - + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - + + + - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -580,13 +580,13 @@ 29 30 31 -id_combination -M -S -E -v - -(combination) +id_coalition +M +S +E +v + +(coalition) Method @@ -610,9 +610,9 @@ 3 -explicands for each combination with - -95 -% CI +explicands for each coalition with + +95 +% CI diff --git a/tests/testthat/_snaps/plot/msev-combination-line-point.svg b/tests/testthat/_snaps/plot/msev-coalition-line-point.svg similarity index 54% rename from tests/testthat/_snaps/plot/msev-combination-line-point.svg rename to tests/testthat/_snaps/plot/msev-coalition-line-point.svg index c971fffdf85305aa7c92db4579c10a72de1e499b..3065d2b0111e75d34b365464d956e0a7d1c315c0 100644 --- a/tests/testthat/_snaps/plot/msev-combination-line-point.svg +++ b/tests/testthat/_snaps/plot/msev-coalition-line-point.svg @@ -27,143 +27,143 @@ - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + 0 -250 -500 -750 -1000 +250 +500 +750 +1000 - - - - + + + + @@ -172,13 +172,13 @@ 10 20 30 -id_combination -M -S -E -v - -(combination) +id_coalition +M +S 
+E +v + +(coalition) Method @@ -206,6 +206,6 @@ 3 -explicands for each combination +explicands for each coalition diff --git a/tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg b/tests/testthat/_snaps/plot/msev-coalitions-for-specified-coalitions.svg similarity index 80% rename from tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg rename to tests/testthat/_snaps/plot/msev-coalitions-for-specified-coalitions.svg index 6c46c88974f6bf41f8218f66d2527c59651e6d7d..8622872a86f6ad005fe8e16bfddcbedff6856ecc 100644 --- a/tests/testthat/_snaps/plot/msev-combinations-for-specified-combinations.svg +++ b/tests/testthat/_snaps/plot/msev-coalitions-for-specified-coalitions.svg @@ -33,18 +33,18 @@ - - - - - - - - - - - - + + + + + + + + + + + + @@ -69,42 +69,42 @@ - - - - - + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -148,13 +148,13 @@ 13 14 15 -id_combination -M -S -E -v - -(combination) +id_coalition +M +S +E +v + +(coalition) Method @@ -178,9 +178,9 @@ 3 -explicands for each combination with - -95 -% CI +explicands for each coalition with + +95 +% CI diff --git a/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg b/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg index c3be8d1e8266ebec41b73e9cee12ac78f2d5e383..e5a67dd6bfe27f63649ecf3cd333254442d8f137 100644 --- a/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg +++ b/tests/testthat/_snaps/plot/msev-explicand-bar-specified-width.svg @@ -27,27 +27,27 @@ - - - - - + + + + + - - - - - - + + + + + + 0 -100 -200 +100 +200 - - + + @@ -84,6 +84,6 @@ 32 -combinations for each explicand +coalitions for each explicand diff --git a/tests/testthat/_snaps/plot/msev-explicand-bar.svg b/tests/testthat/_snaps/plot/msev-explicand-bar.svg index 02cdd7d64c7f4e55d8a02fd5dafee7b3a7bd369b..04dc5b6b6b3a58f949a844822d9da2d4197c123f 100644 --- 
a/tests/testthat/_snaps/plot/msev-explicand-bar.svg +++ b/tests/testthat/_snaps/plot/msev-explicand-bar.svg @@ -27,27 +27,27 @@ - - - - - + + + + + - - - - - - + + + + + + 0 -100 -200 +100 +200 - - + + @@ -84,6 +84,6 @@ 32 -combinations for each explicand +coalitions for each explicand diff --git a/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg b/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg index 87b0706fdb3bd961ef9c2cf156bfacac80b8c2a8..4e271336b7697e90df862cbe682dc4a9cf2b43d5 100644 --- a/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg +++ b/tests/testthat/_snaps/plot/msev-explicand-for-specified-observations.svg @@ -27,23 +27,23 @@ - - - + + + - - - - + + + + 0 -100 -200 +100 +200 - - + + 1 @@ -78,6 +78,6 @@ 32 -combinations for each explicand +coalitions for each explicand diff --git a/tests/testthat/_snaps/plot/msev-explicand-line-point.svg b/tests/testthat/_snaps/plot/msev-explicand-line-point.svg index 13aa02cfb26d79b735877f07ad734425b9a06087..332e3f9ab9f9ab001f4fcaf2d42b1983569cba00 100644 --- a/tests/testthat/_snaps/plot/msev-explicand-line-point.svg +++ b/tests/testthat/_snaps/plot/msev-explicand-line-point.svg @@ -27,33 +27,33 @@ - + - - - + + + - - - - - - - - - - + + + + + + + + + + -100 -150 -200 -250 - - - - +100 +150 +200 +250 + + + + @@ -98,6 +98,6 @@ 32 -combinations for each explicand +coalitions for each explicand diff --git a/tests/testthat/_snaps/plot/plot-sv-several-approaches-default.svg b/tests/testthat/_snaps/plot/plot-sv-several-approaches-default.svg index a8a727adcdc1b6ef8eef0eade8bb686b6a6a2bff..c74310ec600cdf41608760281a642a26df52a660 100644 --- a/tests/testthat/_snaps/plot/plot-sv-several-approaches-default.svg +++ b/tests/testthat/_snaps/plot/plot-sv-several-approaches-default.svg @@ -28,24 +28,24 @@ - - + + - - + + - - + + - - + + - - + + @@ -58,26 +58,26 @@ - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + @@ -89,26 +89,26 
@@ - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + @@ -163,26 +163,26 @@ id: 2, pred = 28.57 - - - --20 --10 -0 + + + +-20 +-10 +0 -10 0 10 - - - - --10 --5 -0 -5 + + + + +-10 +-5 +0 +5 Day = 9 Month = 9 Temp = 75 diff --git a/tests/testthat/_snaps/plot/plot-sv-several-div-input-1.svg b/tests/testthat/_snaps/plot/plot-sv-several-div-input-1.svg index 379b1661f93d24d1ef64c4a0225208cb1b67411c..d49a48ee50d9d0ecb65061c2c6e9219dc2671c4f 100644 --- a/tests/testthat/_snaps/plot/plot-sv-several-div-input-1.svg +++ b/tests/testthat/_snaps/plot/plot-sv-several-div-input-1.svg @@ -27,31 +27,31 @@ - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + @@ -63,31 +63,31 @@ - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + @@ -99,31 +99,31 @@ - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + @@ -160,24 +160,30 @@ id: 3, pred = 24.88 - - - --25 -0 -25 - - - --25 -0 -25 - - - --25 -0 -25 + + + + +-20 +0 +20 +40 + + + + +-20 +0 +20 +40 + + + + +-20 +0 +20 +40 Day = 21 Month = 8 Temp = 77 diff --git a/tests/testthat/_snaps/plot/plot-sv-several-div-input-2.svg b/tests/testthat/_snaps/plot/plot-sv-several-div-input-2.svg index ef5bde04a50497075563bcaea7de040939418dbb..ef79bde818d166973c4066f1ce147816f7655ef7 100644 --- a/tests/testthat/_snaps/plot/plot-sv-several-div-input-2.svg +++ b/tests/testthat/_snaps/plot/plot-sv-several-div-input-2.svg @@ -27,27 +27,27 @@ - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + @@ -59,27 +59,27 @@ - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + @@ -125,22 +125,22 @@ Temp = 87 Month = 9 Day = 5 --20 --10 -0 -10 - - - - --20 --10 -0 -10 - - - - +-20 +-10 +0 +10 + + + + +-20 +-10 +0 +10 + + + + Feature and value Feature contribution (Shapley value diff --git 
a/tests/testthat/_snaps/plot/plot-sv-several-div-input-3.svg b/tests/testthat/_snaps/plot/plot-sv-several-div-input-3.svg index 3f00d9fef080789b4f0ed0fdd173dff5bda24e02..c3b42791a06f199daada70ce5477bc3e7cc60165 100644 --- a/tests/testthat/_snaps/plot/plot-sv-several-div-input-3.svg +++ b/tests/testthat/_snaps/plot/plot-sv-several-div-input-3.svg @@ -32,16 +32,16 @@ - - + + - - + + - - + + @@ -59,16 +59,16 @@ - - + + - - + + - - + + @@ -86,16 +86,16 @@ - - + + - - + + - - + + diff --git a/tests/testthat/_snaps/regression-output.md b/tests/testthat/_snaps/regression-output.md index 4b8f56c25457ab71bcc1e22b98db2de7a3380e31..73230c66448d41a6e83389bd3b3e5495a707bde1 100644 --- a/tests/testthat/_snaps/regression-output.md +++ b/tests/testthat/_snaps/regression-output.md @@ -1,165 +1,430 @@ +# output_lm_numeric_lm_separate_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: TRUE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- iterative computation started -- + + -- Iteration 1 ----------------------------------------------------------------- + i Using 5 of 32 coalitions, 5 new. + + -- Iteration 2 ----------------------------------------------------------------- + i Using 10 of 32 coalitions, 4 new. + + -- Iteration 3 ----------------------------------------------------------------- + i Using 12 of 32 coalitions, 2 new. + + -- Iteration 4 ----------------------------------------------------------------- + i Using 16 of 32 coalitions, 4 new. + + -- Iteration 5 ----------------------------------------------------------------- + i Using 22 of 32 coalitions, 6 new. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.727 8.110 14.4650 0.7756 -2.0211 + 2: 2 42.44 4.725 -4.636 -11.7582 -0.9153 -1.2956 + 3: 3 42.44 7.253 -25.505 0.1828 -0.2723 0.7736 + # output_lm_numeric_lm_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -8.577 7.845 14.4756 0.6251 -1.7664 - 2: 42.44 4.818 -4.811 -11.6350 -1.0423 -1.2086 - 3: 42.44 7.406 -25.587 0.3353 -0.4718 0.7491 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.577 7.845 14.4756 0.6251 -1.7664 + 2: 2 42.44 4.818 -4.811 -11.6350 -1.0423 -1.2086 + 3: 3 42.44 7.406 -25.587 0.3353 -0.4718 0.7491 # output_lm_numeric_lm_separate_n_comb Code (out <- code) + Message + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 10 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -7.806 14.811 5.751 4.056 -4.2111 - 2: 42.44 5.056 -7.055 -16.887 5.976 -0.9692 - 3: 42.44 7.020 -33.059 2.395 3.782 2.2943 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.593 8.491 15.3573 -0.9151 -1.739 + 2: 2 42.44 4.948 -3.745 -10.6547 -2.8369 -1.591 + 3: 3 42.44 7.129 -25.351 0.3282 -1.3110 1.637 # output_lm_categorical_lm_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. 
+ + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 16 of 16 coalitions. Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 -9.806 18.60 -11.788 2.489 - 2: 42.44 -7.256 -18.88 24.751 -13.445 - 3: 42.44 15.594 -26.01 5.887 -13.834 + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -9.806 18.60 -11.788 2.489 + 2: 2 42.44 -7.256 -18.88 24.751 -13.445 + 3: 3 42.44 15.594 -26.01 5.887 -13.834 # output_lm_mixed_lm_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -8.782 8.165 20.389 -1.2383 -7.950 - 2: 42.44 4.623 -3.551 -6.199 -0.9110 -9.345 - 3: 42.44 8.029 -25.200 -4.821 0.4172 10.975 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -8.782 8.165 20.389 -1.2383 -7.950 + 2: 2 42.44 4.623 -3.551 -6.199 -0.9110 -9.345 + 3: 3 42.44 8.029 -25.200 -4.821 0.4172 10.975 # output_lm_mixed_splines_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -8.083 7.102 18.732 1.483 -8.651 - 2: 42.44 6.147 -4.314 -6.445 -2.136 -8.635 - 3: 42.44 7.536 -22.504 -5.081 -2.170 11.619 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -8.083 7.102 18.732 1.483 -8.651 + 2: 2 42.44 6.147 -4.314 -6.445 -2.136 -8.635 + 3: 3 42.44 7.536 -22.504 -5.081 -2.170 11.619 # output_lm_mixed_decision_tree_cv_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -8.131 12.303 9.935 1.6221 -5.145 - 2: 42.44 2.907 -5.119 -7.128 1.7841 -7.827 - 3: 42.44 6.237 -9.010 -17.927 -0.6915 10.791 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -7.742 12.044 9.676 2.0107 -5.405 + 2: 2 42.44 2.688 -4.973 -6.982 1.5650 -7.681 + 3: 3 42.44 6.018 -8.864 -17.781 -0.9106 10.937 # output_lm_mixed_decision_tree_cv_separate_parallel Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -8.131 12.303 9.935 1.6221 -5.145 - 2: 42.44 2.907 -5.119 -7.128 1.7841 -7.827 - 3: 42.44 6.237 -9.010 -17.927 -0.6915 10.791 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -7.742 12.044 9.676 2.0107 -5.405 + 2: 2 42.44 2.688 -4.973 -6.982 1.5650 -7.681 + 3: 3 42.44 6.018 -8.864 -17.781 -0.9106 10.937 # output_lm_mixed_xgboost_separate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_separate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -13.991 14.352 16.490 1.82 -8.088 + 2: 2 42.44 8.183 -1.463 -16.499 3.63 -9.233 + 3: 3 42.44 3.364 -14.946 0.401 -11.32 11.905 + +# output_lm_numeric_lm_surrogate_iterative + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Warning in `check_and_set_iterative()`: + Iterative estimation of Shapley values are not supported for approach = regression_surrogate. Setting iterative = FALSE. + Message + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -13.991 14.352 16.490 1.82 -8.088 - 2: 42.44 8.183 -1.463 -16.499 3.63 -9.233 - 3: 42.44 3.364 -14.946 0.401 -11.32 11.905 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.273 9.578 16.536 -1.2690 -2.9707 + 2: 2 42.44 2.623 -5.766 -6.717 -1.4694 -2.5496 + 3: 3 42.44 6.801 -24.090 -1.295 0.1202 0.8953 # output_lm_numeric_lm_surrogate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.273 9.578 16.536 -1.2690 -2.9707 - 2: 42.44 2.623 -5.766 -6.717 -1.4694 -2.5496 - 3: 42.44 6.801 -24.090 -1.295 0.1202 0.8953 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.273 9.578 16.536 -1.2690 -2.9707 + 2: 2 42.44 2.623 -5.766 -6.717 -1.4694 -2.5496 + 3: 3 42.44 6.801 -24.090 -1.295 0.1202 0.8953 # output_lm_numeric_lm_surrogate_n_comb Code (out <- code) + Message + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 10 of 32 coalitions. 
Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.6804 12.2171 11.4871 0.74529 -2.1671 - 2: 42.44 0.6882 0.3332 -12.8835 1.93235 -3.9496 - 3: 42.44 7.8022 -26.0731 -0.2148 0.04831 0.8691 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.946 9.182 16.2078 -2.630 -0.2120 + 2: 2 42.44 2.239 -6.194 -7.0743 -2.630 -0.2199 + 3: 3 42.44 8.127 -24.230 0.4572 -1.188 -0.7344 # output_lm_numeric_lm_surrogate_reg_surr_n_comb Code (out <- code) + Message + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 10 of 32 coalitions. Output - none Solar.R Wind Temp Month Day - - 1: 42.44 -9.6804 12.2171 11.4871 0.74529 -2.1671 - 2: 42.44 0.6882 0.3332 -12.8835 1.93235 -3.9496 - 3: 42.44 7.8022 -26.0731 -0.2148 0.04831 0.8691 + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.946 9.182 16.2078 -2.630 -0.2120 + 2: 2 42.44 2.239 -6.194 -7.0743 -2.630 -0.2199 + 3: 3 42.44 8.127 -24.230 0.4572 -1.188 -0.7344 # output_lm_categorical_lm_surrogate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 16 of 16 coalitions. 
Output - none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor - - 1: 42.44 -7.137 16.29 -9.895 0.2304 - 2: 42.44 -6.018 -16.28 23.091 -15.6258 - 3: 42.44 10.042 -18.58 2.415 -12.2431 + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -7.137 16.29 -9.895 0.2304 + 2: 2 42.44 -6.018 -16.28 23.091 -15.6258 + 3: 3 42.44 10.042 -18.58 2.415 -12.2431 # output_lm_mixed_lm_surrogate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -7.427 10.831 16.477 -0.6280 -8.669 - 2: 42.44 3.916 -4.232 -4.849 -0.8776 -9.341 - 3: 42.44 5.629 -24.012 -2.274 -0.4774 10.534 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -7.427 10.831 16.477 -0.6280 -8.669 + 2: 2 42.44 3.916 -4.232 -4.849 -0.8776 -9.341 + 3: 3 42.44 5.629 -24.012 -2.274 -0.4774 10.534 # output_lm_mixed_decision_tree_cv_surrogate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.219 -4.219 27.460 -4.219 -4.219 - 2: 42.44 -3.077 -3.077 -3.077 -3.077 -3.077 - 3: 42.44 -6.716 -6.716 -6.716 -6.716 16.262 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -4.219 -4.219 27.460 -4.219 -4.219 + 2: 2 42.44 -3.077 -3.077 -3.077 -3.077 -3.077 + 3: 3 42.44 -6.716 -6.716 -6.716 -6.716 16.262 # output_lm_mixed_xgboost_surrogate Code (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -11.165 8.002 20.61 2.030 -8.896 - 2: 42.44 4.143 -1.515 -11.23 2.025 -8.806 - 3: 42.44 6.515 -18.268 -4.06 -3.992 9.204 + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -11.165 8.002 20.61 2.030 -8.896 + 2: 2 42.44 4.143 -1.515 -11.23 2.025 -8.806 + 3: 3 42.44 6.515 -18.268 -4.06 -3.992 9.204 diff --git a/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_separate.rds index 0bf5e6e52978492a237aa6c8cab7e9e74e25a561..d5bc1b7efd10cc698f373ddc673b6d31667d588e 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_separate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_surrogate.rds index f859e3d75470029e2660c74999ad7766351ac2c5..5287a9c1f514ccd9d64783a4a9bc34ce05999e3b 100644 Binary files 
a/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_categorical_lm_surrogate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate.rds index 54e491a349f70e16754d638fe5ca0efeb7b66d22..374301c6f3151127f3d56690e8aa7187c97e3aef 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate_parallel.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate_parallel.rds index 959f84115acc2bb8e4d6e98e7c6e57f3beda49db..f0b5651a3734fb23c0b3704db2683a9cbc86860b 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate_parallel.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_separate_parallel.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_surrogate.rds index fb0af97ebe70713218eaae72f0806c381c545123..f3e2341ec3f0ab46cf77a8dfbdfbdfb5bc3b4775 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_decision_tree_cv_surrogate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_separate.rds index b45d28996444847248976de842129fa22e727237..afd7c3d30cefd55d0dae92d94ef9eba02966e4e5 100644 Binary files 
a/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_separate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_surrogate.rds index 46e511c5812b114fcbaaadebba54670ed8b2fd1b..33203d03ac763c577a77a39d0f3f03bd201a62b6 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_lm_surrogate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_splines_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_splines_separate.rds index 2a7766305c37e191a6a3ca66d7d43d0826edb0ea..77bedf0adb15e419d6c5b87df62cde9108a481fb 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_splines_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_splines_separate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_separate.rds index c187df49e987c44c69a3e14bfac1f7f99c1954d1..6582ba9bf7252f6b6689f0bde5ae4f2b63ae9b07 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_separate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_surrogate.rds index 4fc7f83c5f509140d487d7ab3548b0698e0afa7a..194dea76122688f8fc0abd5b7c47530075f4519b 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_mixed_xgboost_surrogate.rds differ diff --git 
a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate.rds index 365fd2c69c54cade2b574a1a29f53ac0681c2f63..dc65779ef9392fdc7d9253273732e01cb7b1680e 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_iterative.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_iterative.rds new file mode 100644 index 0000000000000000000000000000000000000000..c2ff52301ca4cfbda5831465845cf89150aa269c Binary files /dev/null and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_iterative.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_n_comb.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_n_comb.rds index f14c13b359de264b8c48712485be0fdc8da8cbc1..e15da382d691200239edfe398dc3cd6c8815e2a2 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_n_comb.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_separate_n_comb.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate.rds index 373f99a3d58bf3c2688d416602822f0a02ab14ee..0e18b9b3e11ce919295a3f795754724f27c6cffc 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_iterative.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_iterative.rds new file mode 100644 index 
0000000000000000000000000000000000000000..0e18b9b3e11ce919295a3f795754724f27c6cffc Binary files /dev/null and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_iterative.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_n_comb.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_n_comb.rds index d5bf3bb59a91569f4249e7606316a91267c919a3..6539290216db9f8704378981d97ee0e6870d3980 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_n_comb.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_n_comb.rds differ diff --git a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_reg_surr_n_comb.rds b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_reg_surr_n_comb.rds index 96fcf7828ab7c3d0f47aba5a8299f276f1e69225..114fa6707bf2f93e70952e21ba8301dedd9de323 100644 Binary files a/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_reg_surr_n_comb.rds and b/tests/testthat/_snaps/regression-output/output_lm_numeric_lm_surrogate_reg_surr_n_comb.rds differ diff --git a/tests/testthat/_snaps/regression-setup.md b/tests/testthat/_snaps/regression-setup.md index 6cf8babcf96b8b5059b51b4886761f8b0acda0f2..754236c2e5513431490a777d9d823f93195bb9f4 100644 --- a/tests/testthat/_snaps/regression-setup.md +++ b/tests/testthat/_snaps/regression-setup.md @@ -1,9 +1,9 @@ # regression erroneous input: `approach` Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = c( - "regression_surrogate", "gaussian", "independence", "empirical"), ) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = c("regression_surrogate", + "gaussian", "independence", "empirical"), iterative = FALSE) Condition Error 
in `check_approach()`: ! The `regression_separate` and `regression_surrogate` approaches cannot be combined with other approaches. @@ -11,9 +11,9 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = c( - "regression_separate", "gaussian", "independence", "empirical"), ) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = c("regression_separate", + "gaussian", "independence", "empirical"), iterative = FALSE) Condition Error in `check_approach()`: ! The `regression_separate` and `regression_surrogate` approaches cannot be combined with other approaches. @@ -21,9 +21,14 @@ # regression erroneous input: `regression.model` Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = NULL) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! `regression.model` must be a tidymodels object with class 'model_spec'. See documentation. 
@@ -31,9 +36,14 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = lm) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! `regression.model` must be a tidymodels object with class 'model_spec'. See documentation. @@ -41,10 +51,15 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression")) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! `regression.tune_values` must be provided when `regression.model` contains hyperparameters to tune. 
@@ -52,11 +67,16 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(num_terms = c(1, 2, 3))) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! The tunable parameters in `regression.model` ('tree_depth') and `regression.tune_values` ('num_terms') must match. @@ -64,11 +84,16 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3), num_terms = c(1, 2, 3))) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! The tunable parameters in `regression.model` ('tree_depth') and `regression.tune_values` ('tree_depth', 'num_terms') must match. 
@@ -76,11 +101,16 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = 2, engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3))) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! The tunable parameters in `regression.model` ('') and `regression.tune_values` ('tree_depth') must match. @@ -88,9 +118,23 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_surrogate", - regression.tune_values = data.frame(tree_depth = c(1, 2, 3))) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_surrogate", + regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), iterative = FALSE) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Condition Error in `regression.get_tune()`: ! The tunable parameters in `regression.model` ('') and `regression.tune_values` ('tree_depth') must match. 
@@ -98,11 +142,16 @@ # regression erroneous input: `regression.tune_values` Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = 2, engine = "rpart", mode = "regression"), regression.tune_values = as.matrix(data.frame( tree_depth = c(1, 2, 3)))) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! `regression.tune_values` must be of either class `data.frame` or `function`. See documentation. @@ -110,10 +159,15 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = function(x) c(1, 2, 3)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! The output of the user provided `regression.tune_values` function must be of class `data.frame`. 
@@ -121,11 +175,16 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = function(x) data.frame( wrong_name = c(1, 2, 3))) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.get_tune()`: ! The tunable parameters in `regression.model` ('tree_depth') and `regression.tune_values` ('wrong_name') must match. @@ -133,11 +192,16 @@ # regression erroneous input: `regression.vfold_cv_para` Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), regression.vfold_cv_para = 10) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.check_vfold_cv_para()`: ! `regression.vfold_cv_para` must be a named list. See documentation using '?shapr::explain()'. 
@@ -145,11 +209,16 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), regression.vfold_cv_para = list(10)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.check_vfold_cv_para()`: ! `regression.vfold_cv_para` must be a named list. See documentation using '?shapr::explain()'. @@ -157,11 +226,16 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), regression.vfold_cv_para = list(hey = 10)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.check_vfold_cv_para()`: ! The following parameters in `regression.vfold_cv_para` are not supported by `rsample::vfold_cv()`: 'hey'. 
@@ -169,9 +243,14 @@ # regression erroneous input: `regression.recipe_func` Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_separate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_separate", regression.recipe_func = 3) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + Condition Error in `regression.check_recipe_func()`: ! `regression.recipe_func` must be a function. See documentation. @@ -179,11 +258,25 @@ --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_surrogate", + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_surrogate", regression.recipe_func = function(x) { return(2) - }) + }, iterative = FALSE) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Condition Error in `regression.check_recipe_func()`: ! The output of the `regression.recipe_func` must be of class `recipe`. 
@@ -191,20 +284,48 @@ # regression erroneous input: `regression.surrogate_n_comb` Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_surrogate", - regression.surrogate_n_comb = 2^ncol(x_explain_numeric) - 1) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_surrogate", + regression.surrogate_n_comb = 2^ncol(x_explain_numeric) - 1, iterative = FALSE) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Condition Error in `regression.check_sur_n_comb()`: - ! `regression.surrogate_n_comb` (31) must be a positive integer less than or equal to `used_n_combinations` minus two (30). + ! `regression.surrogate_n_comb` (31) must be a positive integer less than or equal to `n_coalitions` minus two (30). --- Code - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, n_batches = 1, timing = FALSE, approach = "regression_surrogate", - regression.surrogate_n_comb = 0) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "regression_surrogate", + regression.surrogate_n_comb = 0, iterative = FALSE) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. 
+ + * Model class: + * Approach: regression_surrogate + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. Condition Error in `regression.check_sur_n_comb()`: - ! `regression.surrogate_n_comb` (0) must be a positive integer less than or equal to `used_n_combinations` minus two (30). + ! `regression.surrogate_n_comb` (0) must be a positive integer less than or equal to `n_coalitions` minus two (30). diff --git a/tests/testthat/_snaps/regular-output.md b/tests/testthat/_snaps/regular-output.md new file mode 100644 index 0000000000000000000000000000000000000000..632383c8fc1f5e0b3e7d88ae5c826e8cf5777a85 --- /dev/null +++ b/tests/testthat/_snaps/regular-output.md @@ -0,0 +1,778 @@ +# output_lm_numeric_independence + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066 + 2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971 + 3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316 + +# output_lm_numeric_independence_MSEv_Shapley_weights + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066 + 2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971 + 3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316 + +# output_lm_numeric_empirical + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -13.252 15.541 12.826 -5.77179 3.259 + 2: 2 42.44 2.758 -3.325 -7.992 -7.12800 1.808 + 3: 3 42.44 6.805 -22.126 3.730 -0.09234 -5.885 + +# output_lm_numeric_empirical_n_coalitions + + Code + (out <- code) + Message + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 20 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -14.030 18.711 9.718 -6.1533 4.356 + 2: 2 42.44 3.015 -3.442 -7.095 -7.8174 1.459 + 3: 3 42.44 8.566 -24.310 3.208 0.6956 -5.728 + +# output_lm_numeric_empirical_independence + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Warning in `setup_approach.empirical()`: + Using empirical.type = 'independence' for approach = 'empirical' is deprecated. + Please use approach = 'independence' instead. + Message + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + + Success with message: + empirical.eta force set to 1 for empirical.type = 'independence' + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066 + 2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971 + 3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316 + +# output_lm_numeric_empirical_AICc_each + + Code + (out <- code) + Message + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 8 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.778 9.084 5.4596 5.4596 2.37679 + 2: 2 42.44 6.833 -4.912 -7.9095 -7.9095 0.01837 + 3: 3 42.44 6.895 -21.308 0.6281 0.6281 -4.41122 + +# output_lm_numeric_empirical_AICc_full + + Code + (out <- code) + Message + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 8 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.778 9.084 5.4596 5.4596 2.37679 + 2: 2 42.44 6.833 -4.912 -7.9095 -7.9095 0.01837 + 3: 3 42.44 6.895 -21.308 0.6281 0.6281 -4.41122 + +# output_lm_numeric_gaussian + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.645 7.842 14.4120 0.535 -1.5427 + 2: 2 42.44 4.751 -4.814 -11.6985 -1.132 -0.9848 + 3: 3 42.44 7.339 -25.590 0.2717 -0.562 0.9729 + +# output_lm_numeric_copula + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: copula + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -6.512 7.341 14.357 -0.5201 -2.064 + 2: 2 42.44 3.983 -4.656 -10.001 -1.8813 -1.324 + 3: 3 42.44 6.076 -25.219 1.754 -1.3488 1.169 + +# output_lm_numeric_ctree + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.198 9.679 16.925 -1.3310 -3.473 + 2: 2 42.44 5.283 -6.046 -8.095 -2.7998 -2.222 + 3: 3 42.44 6.984 -20.837 -4.762 -0.1545 1.201 + +# output_lm_numeric_vaeac + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: vaeac + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.941 7.495 17.471 -4.35451 -3.0686 + 2: 2 42.44 1.824 -5.193 -8.943 0.07104 -1.6383 + 3: 3 42.44 4.530 -20.285 3.170 -4.28496 -0.6978 + +# output_lm_categorical_ctree + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 16 of 16 coalitions. + Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -5.719 15.22 -6.220 -3.791 + 2: 2 42.44 -5.687 -17.48 22.095 -13.755 + 3: 3 42.44 6.839 -21.90 1.997 -5.301 + +# output_lm_categorical_vaeac + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: vaeac + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 16 of 16 coalitions. 
+ Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -1.966 12.55 -4.716 -6.38 + 2: 2 42.44 -2.405 -14.39 14.433 -12.47 + 3: 3 42.44 2.755 -14.24 3.222 -10.10 + +# output_lm_categorical_categorical + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: categorical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 16 of 16 coalitions. + Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -5.448 11.31 -11.445 5.078 + 2: 2 42.44 -7.493 -12.27 19.672 -14.744 + 3: 3 42.44 13.656 -19.73 4.369 -16.659 + +# output_lm_categorical_independence + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 16, + and is therefore set to 2^n_features = 16. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 4 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 16 of 16 coalitions. + Output + explain_id none Month_factor Ozone_sub30_factor Solar.R_factor Wind_factor + + 1: 1 42.44 -5.252 13.95 -7.041 -2.167 + 2: 2 42.44 -5.252 -15.61 20.086 -14.050 + 3: 3 42.44 4.833 -15.61 0.596 -8.178 + +# output_lm_ts_timeseries + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_groups = 16, + and is therefore set to 2^n_groups = 16. + + * Model class: + * Approach: timeseries + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 4 + * Number of observations to explain: 2 + + -- Main computation started -- + + i Using 16 of 16 coalitions. 
+ Output + explain_id none S1 S2 S3 S4 + + 1: 1 4.895 -0.5261 0.7831 -0.21023 -0.3885 + 2: 2 4.895 -0.6310 1.6288 -0.04498 -2.9297 + +# output_lm_numeric_comb1 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian, empirical, ctree, and independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -8.987 9.070 15.511 -2.5647 -0.4281 + 2: 2 42.44 2.916 -4.516 -7.845 -4.1649 -0.2686 + 3: 3 42.44 6.968 -22.988 -1.717 0.6776 -0.5085 + +# output_lm_numeric_comb2 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree, copula, independence, and copula + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.394 9.435 17.0084 -1.700 -2.7465 + 2: 2 42.44 5.227 -5.209 -8.5226 -2.968 -2.4068 + 3: 3 42.44 6.186 -22.904 -0.3273 -1.132 0.6081 + +# output_lm_numeric_comb3 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence, empirical, gaussian, and empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -6.887 10.715 12.199 -3.670 0.24393 + 2: 2 42.44 2.603 -2.648 -8.464 -5.405 0.03415 + 3: 3 42.44 5.868 -22.184 3.401 -2.955 -1.69888 + +# output_lm_mixed_independence + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588 + 2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588 + 3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379 + +# output_lm_mixed_ctree + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -9.150 12.057 13.162 -0.8269 -4.658 + 2: 2 42.44 4.425 -6.006 -6.260 -0.3910 -7.151 + 3: 3 42.44 6.941 -21.427 -7.518 1.3987 10.006 + +# output_lm_mixed_vaeac + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: vaeac + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -5.050 6.861 15.73013 -0.2083 -6.749 + 2: 2 42.44 2.600 -4.636 -2.26409 -3.1294 -7.954 + 3: 3 42.44 5.139 -17.878 -0.01372 0.5855 1.567 + +# output_lm_mixed_comb + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree, independence, ctree, and independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -7.677 10.757 16.247 -1.446 -7.297 + 2: 2 42.44 5.049 -5.028 -6.965 -1.265 -7.174 + 3: 3 42.44 5.895 -20.744 -4.468 0.775 7.943 + +# output_custom_lm_numeric_independence_1 + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066 + 2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971 + 3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316 + +# output_custom_lm_numeric_independence_2 + + Code + (out <- code) + Message + Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). + Consistency checks between model and data is therefore disabled. + + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. 
+ + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066 + 2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971 + 3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316 + +# output_custom_xgboost_mixed_dummy_ctree + + Code + (out <- code) + Message + Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). + Consistency checks between model and data is therefore disabled. + + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -5.639 13.31 20.93 -0.4716 -0.425 + 2: 2 42.44 5.709 -13.30 -16.52 1.4006 -2.738 + 3: 3 42.44 6.319 -14.07 -19.77 1.0831 5.870 + +# output_lm_numeric_interaction + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 4, + and is therefore set to 2^n_features = 4. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 2 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 4 of 4 coalitions. 
+ Output + explain_id none Solar.R Wind + + 1: 1 42.44 -13.818 10.579 + 2: 2 42.44 4.642 -6.287 + 3: 3 42.44 4.452 -34.602 + +# output_lm_numeric_ctree_parallelized + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: ctree + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -9.198 9.679 16.925 -1.3310 -3.473 + 2: 2 42.44 5.283 -6.046 -8.095 -2.7998 -2.222 + 3: 3 42.44 6.984 -20.837 -4.762 -0.1545 1.201 + +# output_lm_numeric_empirical_progress + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: empirical + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -13.252 15.541 12.826 -5.77179 3.259 + 2: 2 42.44 2.758 -3.325 -7.992 -7.12800 1.808 + 3: 3 42.44 6.805 -22.126 3.730 -0.09234 -5.885 + +# output_lm_numeric_independence_keep_samp_for_vS + + Code + (out <- code) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -4.537 8.269 17.517 -5.581 -3.066 + 2: 2 42.44 2.250 -3.345 -5.232 -5.581 -1.971 + 3: 3 42.44 3.708 -18.610 -1.440 -2.541 1.316 + diff --git a/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_1.rds b/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_1.rds new file mode 100644 index 0000000000000000000000000000000000000000..5485560c013d5413983af0403da631c7f4737687 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_1.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_2.rds b/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_2.rds new file mode 100644 index 0000000000000000000000000000000000000000..5485560c013d5413983af0403da631c7f4737687 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_custom_lm_numeric_independence_2.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_custom_xgboost_mixed_dummy_ctree.rds b/tests/testthat/_snaps/regular-output/output_custom_xgboost_mixed_dummy_ctree.rds new file mode 100644 index 0000000000000000000000000000000000000000..2f103a3d374907cc062b53b1da8ee34e57a6a452 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_custom_xgboost_mixed_dummy_ctree.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_categorical_ctree.rds b/tests/testthat/_snaps/regular-output/output_lm_categorical_ctree.rds new file mode 100644 index 0000000000000000000000000000000000000000..59124c1b9b8c633058ccb2f43ccb92ef6ad89932 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_categorical_ctree.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_categorical_independence.rds b/tests/testthat/_snaps/regular-output/output_lm_categorical_independence.rds new file mode 100644 index 
0000000000000000000000000000000000000000..4ea1ead8fdbc8f31e40672b6581a0c97ca088012 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_categorical_independence.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_categorical_method.rds b/tests/testthat/_snaps/regular-output/output_lm_categorical_method.rds new file mode 100644 index 0000000000000000000000000000000000000000..cde306c3fd828dcdbb6c4e1a630c432c709b3173 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_categorical_method.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_categorical_vaeac.rds b/tests/testthat/_snaps/regular-output/output_lm_categorical_vaeac.rds new file mode 100644 index 0000000000000000000000000000000000000000..95aaddf73dd5a5b83301c817170950504953af17 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_categorical_vaeac.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_mixed_comb.rds b/tests/testthat/_snaps/regular-output/output_lm_mixed_comb.rds new file mode 100644 index 0000000000000000000000000000000000000000..3e7b804e0393ab1ba3c47a5899c449ac4f9ca4b0 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_mixed_comb.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_mixed_ctree.rds b/tests/testthat/_snaps/regular-output/output_lm_mixed_ctree.rds new file mode 100644 index 0000000000000000000000000000000000000000..5f307017216febaec3a77e015da2ffee0069f1c9 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_mixed_ctree.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_mixed_independence.rds b/tests/testthat/_snaps/regular-output/output_lm_mixed_independence.rds new file mode 100644 index 0000000000000000000000000000000000000000..b18091acc127944b37e6b0a1011f15e31a503700 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_mixed_independence.rds 
differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_mixed_vaeac.rds b/tests/testthat/_snaps/regular-output/output_lm_mixed_vaeac.rds new file mode 100644 index 0000000000000000000000000000000000000000..846bbb00fc67fcb32604e437205ea49962785c08 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_mixed_vaeac.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_comb1.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb1.rds new file mode 100644 index 0000000000000000000000000000000000000000..fc58e7f798b640dade8ed16880b108cc45ed2ec2 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb1.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_comb2.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb2.rds new file mode 100644 index 0000000000000000000000000000000000000000..382deb5fea76aac37ab1b984123bfdd14e5692a9 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb2.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_comb3.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb3.rds new file mode 100644 index 0000000000000000000000000000000000000000..c256b378b8f1e61f4530da8414441618790c70e4 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_comb3.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_copula.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_copula.rds new file mode 100644 index 0000000000000000000000000000000000000000..30842c349b562a2d62563f3c2ab3c3f6811a67eb Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_copula.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree.rds new file mode 100644 index 
0000000000000000000000000000000000000000..accea429f0d26ad843a8cb8b027560170b90c135 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree_parallelized.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree_parallelized.rds new file mode 100644 index 0000000000000000000000000000000000000000..accea429f0d26ad843a8cb8b027560170b90c135 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_ctree_parallelized.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical.rds new file mode 100644 index 0000000000000000000000000000000000000000..aaf9e052f0944502a0c7819951fc564ef24ff06c Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_each.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_each.rds new file mode 100644 index 0000000000000000000000000000000000000000..0b11904ba217ce5d18660c46ae7dd9a5e3894ac5 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_each.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_full.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_full.rds new file mode 100644 index 0000000000000000000000000000000000000000..57fed3dad411e43b05f1c7a5cb8ece8f739135b8 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_AICc_full.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_independence.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_independence.rds new file mode 100644 index 
0000000000000000000000000000000000000000..c21420f3114fbf8e7150a52844ae6e3307ac361e Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_independence.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_n_coalitions.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_n_coalitions.rds new file mode 100644 index 0000000000000000000000000000000000000000..4b240bd440833602fe5dc64ff262d3629f8b962f Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_n_coalitions.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_progress.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_progress.rds new file mode 100644 index 0000000000000000000000000000000000000000..aaf9e052f0944502a0c7819951fc564ef24ff06c Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_empirical_progress.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_gaussian.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_gaussian.rds new file mode 100644 index 0000000000000000000000000000000000000000..9a197ced8d1581e04139d38f73ae3f25e8c617b5 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_gaussian.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_independence.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence.rds new file mode 100644 index 0000000000000000000000000000000000000000..b23b244ebc5364573f780cc4d7e19873b72787ee Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_MSEv_Shapley_weights.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_MSEv_Shapley_weights.rds new file mode 100644 index 
0000000000000000000000000000000000000000..3793591696d1f6c4e72c972721df7149f6fce7be Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_MSEv_Shapley_weights.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_keep_samp_for_vS.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_keep_samp_for_vS.rds new file mode 100644 index 0000000000000000000000000000000000000000..f9f4575a1caa4ff79c7ae358f563d76642b2db4d Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_independence_keep_samp_for_vS.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_interaction.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_interaction.rds new file mode 100644 index 0000000000000000000000000000000000000000..e7a21d7363c833198021c5e98d0a53501ea81169 Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_interaction.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_numeric_vaeac.rds b/tests/testthat/_snaps/regular-output/output_lm_numeric_vaeac.rds new file mode 100644 index 0000000000000000000000000000000000000000..edea23233bd5fadd2615853fd4482b407271116a Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_numeric_vaeac.rds differ diff --git a/tests/testthat/_snaps/regular-output/output_lm_timeseries_method.rds b/tests/testthat/_snaps/regular-output/output_lm_timeseries_method.rds new file mode 100644 index 0000000000000000000000000000000000000000..5aa38d8f8c03094a37a25005b68ee492c713381d Binary files /dev/null and b/tests/testthat/_snaps/regular-output/output_lm_timeseries_method.rds differ diff --git a/tests/testthat/_snaps/regular-setup.md b/tests/testthat/_snaps/regular-setup.md new file mode 100644 index 0000000000000000000000000000000000000000..12ca26adfd26a3f3e5e38fe89e1237f78f5f8e20 --- /dev/null +++ b/tests/testthat/_snaps/regular-setup.md 
@@ -0,0 +1,1019 @@ +# error with custom model without providing predict_model + + Code + model_custom_lm_mixed <- model_lm_mixed + class(model_custom_lm_mixed) <- "whatever" + explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, + x_explain = x_explain_mixed, approach = "independence", phi0 = p0) + Message + Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). + Consistency checks between model and data is therefore disabled. + + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `get_predict_model()`: + ! You passed a model to explain() which is not natively supported, and did not supply the 'predict_model' function to explain(). + See ?shapr::explain or the vignette for more information on how to run shapr with custom models. + +# messages with missing detail in get_model_specs + + Code + explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, + x_explain = x_explain_mixed, approach = "independence", phi0 = p0, + predict_model = custom_predict_model, get_model_specs = NA) + Message + Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). + Consistency checks between model and data is therefore disabled. + + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588 + 2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588 + 3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379 + +--- + + Code + custom_get_model_specs_no_lab <- (function(x) { + feature_specs <- list(labels = NA, classes = NA, factor_levels = NA) + }) + explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, + x_explain = x_explain_mixed, approach = "independence", phi0 = p0, + predict_model = custom_predict_model, get_model_specs = custom_get_model_specs_no_lab) + Message + Note: Feature names extracted from the model contains NA. + Consistency checks between model and data is therefore disabled. + + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588 + 2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588 + 3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379 + +--- + + Code + custom_gms_no_classes <- (function(x) { + feature_specs <- list(labels = labels(x$terms), classes = NA, factor_levels = NA) + }) + explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, + x_explain = x_explain_mixed, approach = "independence", phi0 = p0, + predict_model = custom_predict_model, get_model_specs = custom_gms_no_classes) + Message + Note: Feature classes extracted from the model contains NA. + Assuming feature classes from the data are correct. + + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. 
+ + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588 + 2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588 + 3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379 + +--- + + Code + custom_gms_no_factor_levels <- (function(x) { + feature_specs <- list(labels = labels(x$terms), classes = attr(x$terms, + "dataClasses")[-1], factor_levels = NA) + }) + explain(testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, + x_explain = x_explain_mixed, approach = "independence", phi0 = p0, + predict_model = custom_predict_model, get_model_specs = custom_gms_no_factor_levels) + Message + Note: Feature factor levels extracted from the model contains NA. + Assuming feature factor levels from the data are correct. + + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: independence + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Day Month_factor + + 1: 1 42.44 -4.730 7.750 17.753 -2.601 -7.588 + 2: 2 42.44 2.338 -3.147 -5.310 -1.676 -7.588 + 3: 3 42.44 3.857 -17.469 -1.466 1.099 3.379 + +# erroneous input: `x_train/x_explain` + + Code + x_train_wrong_format <- c(a = 1, b = 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_wrong_format, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_train should be a matrix or a data.frame/data.table. 
+ +--- + + Code + x_explain_wrong_format <- c(a = 1, b = 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_wrong_format, + x_train = x_train_numeric, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_explain should be a matrix or a data.frame/data.table. + +--- + + Code + x_train_wrong_format <- c(a = 1, b = 2) + x_explain_wrong_format <- c(a = 3, b = 4) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_wrong_format, + x_train = x_train_wrong_format, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_train should be a matrix or a data.frame/data.table. + x_explain should be a matrix or a data.frame/data.table. + +--- + + Code + x_train_no_column_names <- as.data.frame(x_train_numeric) + names(x_train_no_column_names) <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_no_column_names, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_train misses column names. + +--- + + Code + x_explain_no_column_names <- as.data.frame(x_explain_numeric) + names(x_explain_no_column_names) <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_no_column_names, + x_train = x_train_numeric, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_explain misses column names. + +--- + + Code + x_train_no_column_names <- as.data.frame(x_train_numeric) + x_explain_no_column_names <- as.data.frame(x_explain_numeric) + names(x_explain_no_column_names) <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_no_column_names, + x_train = x_train_no_column_names, approach = "independence", phi0 = p0) + Condition + Error in `get_data()`: + ! x_explain misses column names. 
+ +# erroneous input: `model` + + Code + explain(testing = TRUE, x_explain = x_explain_numeric, x_train = x_train_numeric, + approach = "independence", phi0 = p0) + Condition + Error in `explain()`: + ! argument "model" is missing, with no default + +# erroneous input: `approach` + + Code + approach_non_character <- 1 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = approach_non_character, phi0 = p0) + Condition + Error in `check_approach()`: + ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. + These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). + +--- + + Code + approach_incorrect_length <- c("empirical", "gaussian") + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = approach_incorrect_length, phi0 = p0) + Condition + Error in `check_approach()`: + ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. + These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). + +--- + + Code + approach_incorrect_character <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = approach_incorrect_character, phi0 = p0) + Condition + Error in `check_approach()`: + ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. 
+ These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). + +# erroneous input: `phi0` + + Code + p0_non_numeric_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0_non_numeric_1) + Condition + Error in `get_parameters()`: + ! `phi0` (bla) must be numeric and match the output size of the model (1). + +--- + + Code + p0_non_numeric_2 <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0_non_numeric_2) + Condition + Error in `get_parameters()`: + ! `phi0` () must be numeric and match the output size of the model (1). + +--- + + Code + p0_too_long <- c(1, 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0_too_long) + Condition + Error in `get_parameters()`: + ! `phi0` (1, 2) must be numeric and match the output size of the model (1). + +--- + + Code + p0_is_NA <- as.numeric(NA) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0_is_NA) + Condition + Error in `get_parameters()`: + ! `phi0` (NA) must be numeric and match the output size of the model (1). + +# erroneous input: `max_n_coalitions` + + Code + max_n_comb_non_numeric_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_comb_non_numeric_1) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. 
+ +--- + + Code + max_n_comb_non_numeric_2 <- TRUE + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_comb_non_numeric_2) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_coalitions_non_integer <- 10.5 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_coalitions_non_integer) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_coalitions_too_long <- c(1, 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_coalitions_too_long) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_coalitions_is_NA <- as.numeric(NA) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_coalitions_is_NA) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. + +--- + + Code + max_n_comb_non_positive <- 0 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + max_n_coalitions = max_n_comb_non_positive) + Condition + Error in `get_parameters()`: + ! `max_n_coalitions` must be NULL or a single positive integer. 
+ +--- + + Code + max_n_coalitions <- ncol(x_explain_numeric) - 1 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "gaussian", + max_n_coalitions = max_n_coalitions) + Message + Success with message: + max_n_coalitions is smaller than max(10, n_features + 1 = 6),which will result in unreliable results. + It is therefore set to 10. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 6 of 32 coalitions. + Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 -1.4276 -1.4276 15.1967 1.6879 -1.4276 + 2: 2 42.44 -0.9143 -0.9143 -10.8152 -0.3212 -0.9143 + 3: 3 42.44 -5.8068 -5.8068 0.1677 -0.3155 -5.8068 + +--- + + Code + groups <- list(A = c("Solar.R", "Wind"), B = c("Temp", "Month"), C = "Day") + max_n_coalitions <- length(groups) - 1 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, phi0 = p0, approach = "gaussian", group = groups, + max_n_coalitions = max_n_coalitions) + Message + Success with message: + n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (8) that we should use all to get reliable results. + max_n_coalitions is therefore set to 2^n_groups = 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 3 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 8 of 8 coalitions. 
+ Output + explain_id none A B C + + 1: 1 42.44 0.2636 13.7991 -1.4606 + 2: 2 42.44 0.1788 -13.1512 -0.9071 + 3: 3 42.44 -18.4998 -0.1635 1.0951 + +# erroneous input: `group` + + Code + group_non_list <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_non_list) + Condition + Error in `get_parameters()`: + ! `group` must be NULL or a list + +--- + + Code + group_with_non_characters <- list(A = 1, B = 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_with_non_characters) + Condition + Error in `check_groups()`: + ! All components of group should be a character. + +--- + + Code + group_with_non_data_features <- list(A = c("Solar.R", "Wind", + "not_a_data_feature"), B = c("Temp", "Month", "Day")) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_with_non_data_features) + Condition + Error in `check_groups()`: + ! The group feature(s) not_a_data_feature are not + among the features in the data: Solar.R, Wind, Temp, Month, Day. Delete from group. + +--- + + Code + group_missing_data_features <- list(A = c("Solar.R"), B = c("Temp", "Month", + "Day")) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_missing_data_features) + Condition + Error in `check_groups()`: + ! The data feature(s) Wind do not + belong to one of the groups. Add to a group. 
+ +--- + + Code + group_dup_data_features <- list(A = c("Solar.R", "Solar.R", "Wind"), B = c( + "Temp", "Month", "Day")) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = group_dup_data_features) + Condition + Error in `check_groups()`: + ! Feature(s) Solar.R are found in more than one group or multiple times per group. + Make sure each feature is only represented in one group, and only once. + +--- + + Code + single_group <- list(A = c("Solar.R", "Wind", "Temp", "Month", "Day")) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, group = single_group) + Condition + Error in `check_groups()`: + ! You have specified only a single group named A, containing the features: Solar.R, Wind, Temp, Month, Day. + The predictions must be decomposed in at least two groups to be meaningful. + +# erroneous input: `n_MC_samples` + + Code + n_samples_non_numeric_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_non_numeric_1) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_non_numeric_2 <- TRUE + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_non_numeric_2) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_non_integer <- 10.5 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_non_integer) + Condition + Error in `get_parameters()`: + ! 
`n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_too_long <- c(1, 2) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_too_long) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_is_NA <- as.numeric(NA) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_is_NA) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +--- + + Code + n_samples_non_positive <- 0 + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + n_MC_samples = n_samples_non_positive) + Condition + Error in `get_parameters()`: + ! `n_MC_samples` must be a single positive integer. + +# erroneous input: `seed` + + Code + seed_not_integer_interpretable <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, seed = seed_not_integer_interpretable) + Condition + Warning in `set.seed()`: + NAs introduced by coercion + Error in `set.seed()`: + ! supplied seed is not a valid integer + +# erroneous input: `keep_samp_for_vS` + + Code + keep_samp_for_vS_non_logical_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + keep_samp_for_vS = keep_samp_for_vS_non_logical_1)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! 
`output_args$keep_samp_for_vS` must be single logical. + +--- + + Code + keep_samp_for_vS_non_logical_2 <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + keep_samp_for_vS = keep_samp_for_vS_non_logical_2)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$keep_samp_for_vS` must be single logical. + +--- + + Code + keep_samp_for_vS_too_long <- c(TRUE, FALSE) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + keep_samp_for_vS = keep_samp_for_vS_too_long)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$keep_samp_for_vS` must be single logical. + +# erroneous input: `MSEv_uniform_comb_weights` + + Code + MSEv_uniform_comb_weights_nl_1 <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$MSEv_uniform_comb_weights` must be single logical. 
+ +--- + + Code + MSEv_uniform_comb_weights_nl_2 <- NULL + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$MSEv_uniform_comb_weights` must be single logical. + +--- + + Code + MSEv_uniform_comb_weights_long <- c(TRUE, FALSE) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, output_args = list( + MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long)) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `check_output_args()`: + ! `output_args$MSEv_uniform_comb_weights` must be single logical. + +# erroneous input: `predict_model` + + Code + predict_model_nonfunction <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_nonfunction) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `get_predict_model()`: + ! `predict_model` must be NULL or a function. 
+ +--- + + Code + predict_model_non_num_output <- (function(model, x) { + "bla" + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_non_num_output) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `test_predict_model()`: + ! The predict_model function of class `lm` does not return a numeric output of the desired length + for single output models or a data.table of the correct + dimensions for a multiple output model. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + + for more information on running shapr with custom models. + +--- + + Code + predict_model_wrong_output_len <- (function(model, x) { + rep(1, nrow(x) + 1) + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_wrong_output_len) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `test_predict_model()`: + ! The predict_model function of class `lm` does not return a numeric output of the desired length + for single output models or a data.table of the correct + dimensions for a multiple output model. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + + for more information on running shapr with custom models. 
+ +--- + + Code + predict_model_invalid_argument <- (function(model) { + rep(1, nrow(x)) + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_invalid_argument) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `test_predict_model()`: + ! The predict_model function of class `lm` is invalid. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models. + A basic function test threw the following error: + Error in predict_model(model, x_test): unused argument (x_test) + +--- + + Code + predict_model_error <- (function(model, x) { + 1 + "bla" + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + predict_model = predict_model_error) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `test_predict_model()`: + ! The predict_model function of class `lm` is invalid. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models. + A basic function test threw the following error: + Error in 1 + "bla": non-numeric argument to binary operator + +# erroneous input: `get_model_specs` + + Code + get_model_specs_nonfunction <- "bla" + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_model_specs_nonfunction) + Condition + Error in `get_feature_specs()`: + ! 
`get_model_specs` must be NULL, NA or a function. + +--- + + Code + get_ms_output_not_list <- (function(x) { + "bla" + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_ms_output_not_list) + Condition + Error in `get_feature_specs()`: + ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models and the required output format of get_model_specs. + +--- + + Code + get_ms_output_too_long <- (function(x) { + list(1, 2, 3, 4) + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_ms_output_too_long) + Condition + Error in `get_feature_specs()`: + ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models and the required output format of get_model_specs. + +--- + + Code + get_ms_output_wrong_names <- (function(x) { + list(labels = 1, classes = 2, not_a_name = 3) + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_ms_output_wrong_names) + Condition + Error in `get_feature_specs()`: + ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". 
+ See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models and the required output format of get_model_specs. + +--- + + Code + get_model_specs_error <- (function(x) { + 1 + "bla" + }) + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_train_numeric, approach = "independence", phi0 = p0, + get_model_specs = get_model_specs_error) + Condition + Error in `get_feature_specs()`: + ! The get_model_specs function of class `lm` is invalid. + See the 'Advanced usage' section of the vignette: + vignette('understanding_shapr', package = 'shapr') + for more information on running shapr with custom models. + Note that `get_model_specs` is not required (can be set to NULL) + unless you require consistency checks between model and data. + A basic function test threw the following error: + Error in 1 + "bla": non-numeric argument to binary operator + +# incompatible input: `data/approach` + + Code + non_factor_approach_1 <- "gaussian" + explain(testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, + x_train = x_explain_mixed, approach = non_factor_approach_1, phi0 = p0) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `setup_approach.gaussian()`: + ! The following feature(s) are factor(s): Month_factor. + approach = 'gaussian' does not support factor features. + Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. 
+ +--- + + Code + non_factor_approach_2 <- "empirical" + explain(testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, + x_train = x_explain_mixed, approach = non_factor_approach_2, phi0 = p0) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `setup_approach.empirical()`: + ! The following feature(s) are factor(s): Month_factor. + approach = 'empirical' does not support factor features. + Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. + +--- + + Code + non_factor_approach_3 <- "copula" + explain(testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, + x_train = x_explain_mixed, approach = non_factor_approach_3, phi0 = p0) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + Condition + Error in `setup_approach.copula()`: + ! The following feature(s) are factor(s): Month_factor. + approach = 'copula' does not support factor features. + Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. + +# Message with too low `max_n_coalitions` + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_explain_numeric, phi0 = p0, approach = "gaussian", + max_n_coalitions = max_n_coalitions) + Message + Success with message: + max_n_coalitions is smaller than max(10, n_features + 1 = 6),which will result in unreliable results. + It is therefore set to 10. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 6 of 32 coalitions. 
+ Output + explain_id none Solar.R Wind Temp Month Day + + 1: 1 42.44 2.3585 2.3585 5.900 -0.3739 2.3585 + 2: 2 42.44 -1.5323 -1.5323 -8.909 -0.3739 -1.5323 + 3: 3 42.44 -0.7635 -0.7635 -6.441 -8.8373 -0.7635 + +--- + + Code + explain(testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, + x_train = x_explain_numeric, phi0 = p0, approach = "gaussian", group = groups, + max_n_coalitions = max_n_coalitions) + Message + Success with message: + n_groups is smaller than or equal to 3, meaning there are so few unique coalitions (8) that we should use all to get reliable results. + max_n_coalitions is therefore set to 2^n_groups = 8. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of group-wise Shapley values: 3 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 8 of 8 coalitions. + Output + explain_id none A B C + + 1: 1 42.44 5.589 5.591 1.4213 + 2: 2 42.44 -6.637 -6.636 -0.6071 + 3: 3 42.44 -5.439 -5.436 -6.6932 + +# Shapr with `max_n_coalitions` >= 2^m uses exact Shapley kernel weights + + Code + explanation_exact <- explain(testing = TRUE, model = model_lm_numeric, + x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", + phi0 = p0, n_MC_samples = 2, seed = 123, max_n_coalitions = NULL, iterative = FALSE) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ +--- + + Code + explanation_equal <- explain(testing = TRUE, model = model_lm_numeric, + x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", + phi0 = p0, n_MC_samples = 2, seed = 123, extra_computation_args = list( + compute_sd = FALSE), max_n_coalitions = 2^ncol(x_explain_numeric), + iterative = FALSE) + Message + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. + +--- + + Code + explanation_larger <- explain(testing = TRUE, model = model_lm_numeric, + x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", + phi0 = p0, n_MC_samples = 2, seed = 123, extra_computation_args = list( + compute_sd = FALSE), max_n_coalitions = 2^ncol(x_explain_numeric) + 1, + iterative = FALSE) + Message + Success with message: + max_n_coalitions is NULL or larger than or 2^n_features = 32, + and is therefore set to 2^n_features = 32. + + * Model class: + * Approach: gaussian + * Iterative estimation: FALSE + * Number of feature-wise Shapley values: 5 + * Number of observations to explain: 3 + + -- Main computation started -- + + i Using 32 of 32 coalitions. 
+ diff --git a/tests/testthat/_snaps/setup.md b/tests/testthat/_snaps/setup.md deleted file mode 100644 index 72c2883153e6e382fe65c4bdecd13048b6805958..0000000000000000000000000000000000000000 --- a/tests/testthat/_snaps/setup.md +++ /dev/null @@ -1,849 +0,0 @@ -# error with custom model without providing predict_model - - Code - model_custom_lm_mixed <- model_lm_mixed - class(model_custom_lm_mixed) <- "whatever" - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Message - Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). - Consistency checks between model and data is therefore disabled. - - Condition - Error in `get_predict_model()`: - ! You passed a model to explain() which is not natively supported, and did not supply the 'predict_model' function to explain(). - See ?shapr::explain or the vignette for more information on how to run shapr with custom models. - -# messages with missing detail in get_model_specs - - Code - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, predict_model = custom_predict_model, - get_model_specs = NA, n_batches = 1, timing = FALSE) - Message - Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). - Consistency checks between model and data is therefore disabled. 
- - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - ---- - - Code - custom_get_model_specs_no_lab <- (function(x) { - feature_specs <- list(labels = NA, classes = NA, factor_levels = NA) - }) - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, predict_model = custom_predict_model, - get_model_specs = custom_get_model_specs_no_lab, n_batches = 1, timing = FALSE) - Message - Note: Feature names extracted from the model contains NA. - Consistency checks between model and data is therefore disabled. - - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - ---- - - Code - custom_gms_no_classes <- (function(x) { - feature_specs <- list(labels = labels(x$terms), classes = NA, factor_levels = NA) - }) - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, predict_model = custom_predict_model, - get_model_specs = custom_gms_no_classes, n_batches = 1, timing = FALSE) - Message - Note: Feature classes extracted from the model contains NA. - Assuming feature classes from the data are correct. 
- - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - ---- - - Code - custom_gms_no_factor_levels <- (function(x) { - feature_specs <- list(labels = labels(x$terms), classes = attr(x$terms, - "dataClasses")[-1], factor_levels = NA) - }) - explain(model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, - approach = "independence", prediction_zero = p0, predict_model = custom_predict_model, - get_model_specs = custom_gms_no_factor_levels, n_batches = 1, timing = FALSE) - Message - Note: Feature factor levels extracted from the model contains NA. - Assuming feature factor levels from the data are correct. - - Output - none Solar.R Wind Temp Day Month_factor - - 1: 42.44 -4.730 7.750 17.753 -2.601 -7.588 - 2: 42.44 2.338 -3.147 -5.310 -1.676 -7.588 - 3: 42.44 3.857 -17.469 -1.466 1.099 3.379 - -# erroneous input: `x_train/x_explain` - - Code - x_train_wrong_format <- c(a = 1, b = 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_wrong_format, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_train should be a matrix or a data.frame/data.table. - ---- - - Code - x_explain_wrong_format <- c(a = 1, b = 2) - explain(model = model_lm_numeric, x_explain = x_explain_wrong_format, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_explain should be a matrix or a data.frame/data.table. 
- ---- - - Code - x_train_wrong_format <- c(a = 1, b = 2) - x_explain_wrong_format <- c(a = 3, b = 4) - explain(model = model_lm_numeric, x_explain = x_explain_wrong_format, x_train = x_train_wrong_format, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_train should be a matrix or a data.frame/data.table. - x_explain should be a matrix or a data.frame/data.table. - ---- - - Code - x_train_no_column_names <- as.data.frame(x_train_numeric) - names(x_train_no_column_names) <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_no_column_names, - approach = "independence", prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_train misses column names. - ---- - - Code - x_explain_no_column_names <- as.data.frame(x_explain_numeric) - names(x_explain_no_column_names) <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_no_column_names, - x_train = x_train_numeric, approach = "independence", prediction_zero = p0, - n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_explain misses column names. - ---- - - Code - x_train_no_column_names <- as.data.frame(x_train_numeric) - x_explain_no_column_names <- as.data.frame(x_explain_numeric) - names(x_explain_no_column_names) <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_no_column_names, - x_train = x_train_no_column_names, approach = "independence", - prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `get_data()`: - ! x_explain misses column names. - -# erroneous input: `model` - - Code - explain(x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, n_batches = 1, timing = FALSE) - Condition - Error in `explain()`: - ! 
argument "model" is missing, with no default - -# erroneous input: `approach` - - Code - approach_non_character <- 1 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = approach_non_character, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `check_approach()`: - ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. - These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). - ---- - - Code - approach_incorrect_length <- c("empirical", "gaussian") - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = approach_incorrect_length, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `check_approach()`: - ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. - These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). - ---- - - Code - approach_incorrect_character <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = approach_incorrect_character, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `check_approach()`: - ! `approach` must be one of the following: 'categorical', 'copula', 'ctree', 'empirical', 'gaussian', 'independence', 'regression_separate', 'regression_surrogate', 'timeseries', 'vaeac'. - These can also be combined (except 'regression_surrogate' and 'regression_separate') by passing a vector of length one less than the number of features (4). 
- -# erroneous input: `prediction_zero` - - Code - p0_non_numeric_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0_non_numeric_1, n_batches = 1, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `prediction_zero` (bla) must be numeric and match the output size of the model (1). - ---- - - Code - p0_non_numeric_2 <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0_non_numeric_2, n_batches = 1, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `prediction_zero` () must be numeric and match the output size of the model (1). - ---- - - Code - p0_too_long <- c(1, 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0_too_long, n_batches = 1, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `prediction_zero` (1, 2) must be numeric and match the output size of the model (1). - ---- - - Code - p0_is_NA <- as.numeric(NA) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0_is_NA, n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `prediction_zero` (NA) must be numeric and match the output size of the model (1). - -# erroneous input: `n_combinations` - - Code - n_combinations_non_numeric_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_non_numeric_1, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. 
- ---- - - Code - n_combinations_non_numeric_2 <- TRUE - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_non_numeric_2, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations_non_integer <- 10.5 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_non_integer, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations_too_long <- c(1, 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_too_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations_is_NA <- as.numeric(NA) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_is_NA, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. - ---- - - Code - n_combinations_non_positive <- 0 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations_non_positive, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_combinations` must be NULL or a single positive integer. 
- ---- - - Code - n_combinations <- ncol(x_explain_numeric) - 1 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, approach = "gaussian", n_combinations = n_combinations, - n_batches = 1, timing = FALSE) - Condition - Error in `check_n_combinations()`: - ! `n_combinations` has to be greater than the number of features. - ---- - - Code - groups <- list(A = c("Solar.R", "Wind"), B = c("Temp", "Month"), C = "Day") - n_combinations <- length(groups) - 1 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, approach = "gaussian", group = groups, n_combinations = n_combinations, - n_batches = 1, timing = FALSE) - Condition - Error in `check_n_combinations()`: - ! `n_combinations` has to be greater than the number of groups. - -# erroneous input: `group` - - Code - group_non_list <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_non_list, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `group` must be NULL or a list - ---- - - Code - group_with_non_characters <- list(A = 1, B = 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_with_non_characters, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! All components of group should be a character. - ---- - - Code - group_with_non_data_features <- list(A = c("Solar.R", "Wind", - "not_a_data_feature"), B = c("Temp", "Month", "Day")) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_with_non_data_features, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! 
The group feature(s) not_a_data_feature are not - among the features in the data: Solar.R, Wind, Temp, Month, Day. Delete from group. - ---- - - Code - group_missing_data_features <- list(A = c("Solar.R"), B = c("Temp", "Month", - "Day")) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_missing_data_features, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! The data feature(s) Wind do not - belong to one of the groups. Add to a group. - ---- - - Code - group_dup_data_features <- list(A = c("Solar.R", "Solar.R", "Wind"), B = c( - "Temp", "Month", "Day")) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = group_dup_data_features, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! Feature(s) Solar.R are found in more than one group or multiple times per group. - Make sure each feature is only represented in one group, and only once. - ---- - - Code - single_group <- list(A = c("Solar.R", "Wind", "Temp", "Month", "Day")) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, group = single_group, - n_batches = 1, timing = FALSE) - Condition - Error in `check_groups()`: - ! You have specified only a single group named A, containing the features: Solar.R, Wind, Temp, Month, Day. - The predictions must be decomposed in at least two groups to be meaningful. - -# erroneous input: `n_samples` - - Code - n_samples_non_numeric_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_non_numeric_1, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! 
`n_samples` must be a single positive integer. - ---- - - Code - n_samples_non_numeric_2 <- TRUE - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_non_numeric_2, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - ---- - - Code - n_samples_non_integer <- 10.5 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_non_integer, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - ---- - - Code - n_samples_too_long <- c(1, 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_too_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - ---- - - Code - n_samples_is_NA <- as.numeric(NA) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_is_NA, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. - ---- - - Code - n_samples_non_positive <- 0 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_samples = n_samples_non_positive, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_samples` must be a single positive integer. 
- -# erroneous input: `n_batches` - - Code - n_batches_non_numeric_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_non_numeric_1, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_non_numeric_2 <- TRUE - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_non_numeric_2, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_non_integer <- 10.5 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_non_integer, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_too_long <- c(1, 2) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_too_long, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_batches_is_NA <- as.numeric(NA) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_is_NA, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. 
- ---- - - Code - n_batches_non_positive <- 0 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_non_positive, - timing = FALSE) - Condition - Error in `get_parameters()`: - ! `n_batches` must be NULL or a single positive integer. - ---- - - Code - n_combinations <- 10 - n_batches_too_large <- 11 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_combinations = n_combinations, - n_batches = n_batches_too_large, timing = FALSE) - Condition - Error in `check_n_batches()`: - ! `n_batches` (11) must be smaller than the number of feature combinations/`n_combinations` (10) - ---- - - Code - n_batches_too_large_2 <- 32 - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, n_batches = n_batches_too_large_2, - timing = FALSE) - Condition - Error in `check_n_batches()`: - ! `n_batches` (32) must be smaller than the number of feature combinations/`n_combinations` (32) - -# erroneous input: `seed` - - Code - seed_not_integer_interpretable <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, seed = seed_not_integer_interpretable, - n_batches = 1, timing = FALSE) - Condition - Warning in `set.seed()`: - NAs introduced by coercion - Error in `set.seed()`: - ! supplied seed is not a valid integer - -# erroneous input: `keep_samp_for_vS` - - Code - keep_samp_for_vS_non_logical_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, keep_samp_for_vS = keep_samp_for_vS_non_logical_1, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! 
`keep_samp_for_vS` must be single logical. - ---- - - Code - keep_samp_for_vS_non_logical_2 <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, keep_samp_for_vS = keep_samp_for_vS_non_logical_2, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `keep_samp_for_vS` must be single logical. - ---- - - Code - keep_samp_for_vS_too_long <- c(TRUE, FALSE) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, keep_samp_for_vS = keep_samp_for_vS_too_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `keep_samp_for_vS` must be single logical. - -# erroneous input: `MSEv_uniform_comb_weights` - - Code - MSEv_uniform_comb_weights_nl_1 <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `MSEv_uniform_comb_weights` must be single logical. - ---- - - Code - MSEv_uniform_comb_weights_nl_2 <- NULL - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! `MSEv_uniform_comb_weights` must be single logical. - ---- - - Code - MSEv_uniform_comb_weights_long <- c(TRUE, FALSE) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_parameters()`: - ! 
`MSEv_uniform_comb_weights` must be single logical. - -# erroneous input: `predict_model` - - Code - predict_model_nonfunction <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_nonfunction, - n_batches = 1, timing = FALSE) - Condition - Error in `get_predict_model()`: - ! `predict_model` must be NULL or a function. - ---- - - Code - predict_model_non_num_output <- (function(model, x) { - "bla" - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_non_num_output, - n_batches = 1, timing = FALSE) - Condition - Error in `test_predict_model()`: - ! The predict_model function of class `lm` does not return a numeric output of the desired length - for single output models or a data.table of the correct - dimensions for a multiple output model. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - - for more information on running shapr with custom models. - ---- - - Code - predict_model_wrong_output_len <- (function(model, x) { - rep(1, nrow(x) + 1) - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_wrong_output_len, - n_batches = 1, timing = FALSE) - Condition - Error in `test_predict_model()`: - ! The predict_model function of class `lm` does not return a numeric output of the desired length - for single output models or a data.table of the correct - dimensions for a multiple output model. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - - for more information on running shapr with custom models. 
- ---- - - Code - predict_model_invalid_argument <- (function(model) { - rep(1, nrow(x)) - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_invalid_argument, - n_batches = 1, timing = FALSE) - Condition - Error in `test_predict_model()`: - ! The predict_model function of class `lm` is invalid. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models. - A basic function test threw the following error: - Error in predict_model(model, x_test): unused argument (x_test) - ---- - - Code - predict_model_error <- (function(model, x) { - 1 + "bla" - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, predict_model = predict_model_error, - n_batches = 1, timing = FALSE) - Condition - Error in `test_predict_model()`: - ! The predict_model function of class `lm` is invalid. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models. - A basic function test threw the following error: - Error in 1 + "bla": non-numeric argument to binary operator - -# erroneous input: `get_model_specs` - - Code - get_model_specs_nonfunction <- "bla" - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_model_specs_nonfunction, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! `get_model_specs` must be NULL, NA or a function. 
- ---- - - Code - get_ms_output_not_list <- (function(x) { - "bla" - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_ms_output_not_list, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models and the required output format of get_model_specs. - ---- - - Code - get_ms_output_too_long <- (function(x) { - list(1, 2, 3, 4) - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_ms_output_too_long, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models and the required output format of get_model_specs. - ---- - - Code - get_ms_output_wrong_names <- (function(x) { - list(labels = 1, classes = 2, not_a_name = 3) - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_ms_output_wrong_names, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! The `get_model_specs` function of class `lm` does not return a list of length 3 with elements "labels","classes","factor_levels". 
- See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models and the required output format of get_model_specs. - ---- - - Code - get_model_specs_error <- (function(x) { - 1 + "bla" - }) - explain(model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - approach = "independence", prediction_zero = p0, get_model_specs = get_model_specs_error, - n_batches = 1, timing = FALSE) - Condition - Error in `get_feature_specs()`: - ! The get_model_specs function of class `lm` is invalid. - See the 'Advanced usage' section of the vignette: - vignette('understanding_shapr', package = 'shapr') - for more information on running shapr with custom models. - Note that `get_model_specs` is not required (can be set to NULL) - unless you require consistency checks between model and data. - A basic function test threw the following error: - Error in 1 + "bla": non-numeric argument to binary operator - -# incompatible input: `data/approach` - - Code - non_factor_approach_1 <- "gaussian" - explain(model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, - approach = non_factor_approach_1, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `setup_approach.gaussian()`: - ! The following feature(s) are factor(s): Month_factor. - approach = 'gaussian' does not support factor features. - Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. - ---- - - Code - non_factor_approach_2 <- "empirical" - explain(model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, - approach = non_factor_approach_2, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `setup_approach.empirical()`: - ! The following feature(s) are factor(s): Month_factor. - approach = 'empirical' does not support factor features. 
- Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. - ---- - - Code - non_factor_approach_3 <- "copula" - explain(model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, - approach = non_factor_approach_3, prediction_zero = p0, n_batches = 1, - timing = FALSE) - Condition - Error in `setup_approach.copula()`: - ! The following feature(s) are factor(s): Month_factor. - approach = 'copula' does not support factor features. - Please change approach to one of 'independence' (not recommended), 'ctree', 'vaeac', 'categorical'. - diff --git a/tests/testthat/helper-ar-arima.R b/tests/testthat/helper-ar-arima.R index 47944e87b0f5a1d3b00f6d52629754b62afab893..9ac21641b12b47c075f80aeb7ce7269c7ec2d032 100644 --- a/tests/testthat/helper-ar-arima.R +++ b/tests/testthat/helper-ar-arima.R @@ -1,17 +1,18 @@ options(digits = 5) # To avoid round off errors when printing output on different systems +data_arima <- data.table::as.data.table(airquality) +data_arima[, Solar.R := ifelse(is.na(Solar.R), mean(Solar.R, na.rm = TRUE), Solar.R)] +data_arima[, Ozone := ifelse(is.na(Ozone), mean(Ozone, na.rm = TRUE), Ozone)] - -data <- data.table::as.data.table(airquality) - -model_ar_temp <- ar(data$Temp, order = 2) +model_ar_temp <- ar(data_arima$Temp, order = 2) model_ar_temp$n.ahead <- 3 -p0_ar <- rep(mean(data$Temp), 3) +p0_ar <- rep(mean(data_arima$Temp), 3) -model_arima_temp <- arima(data$Temp[1:150], c(2, 1, 0), xreg = data$Wind[1:150]) +model_arima_temp <- arima(data_arima$Temp[1:150], c(2, 1, 0), xreg = data_arima$Wind[1:150]) +model_arima_temp2 <- arima(data_arima$Temp[1:150], c(2, 1, 0), xreg = data_arima[1:150, c("Wind", "Solar.R", "Ozone")]) -model_arima_temp_noxreg <- arima(data$Temp[1:150], c(2, 1, 0)) +model_arima_temp_noxreg <- arima(data_arima$Temp[1:150], c(2, 1, 0)) # When loading this here we avoid the "Registered S3 method overwritten" when calling forecast -model_forecast_ARIMA_temp <- 
forecast::Arima(data$Temp[1:150], order = c(2, 1, 0), xreg = data$Wind[1:150]) +model_forecast_ARIMA_temp <- forecast::Arima(data_arima$Temp[1:150], order = c(2, 1, 0), xreg = data_arima$Wind[1:150]) diff --git a/tests/testthat/test-adaptive-output.R b/tests/testthat/test-adaptive-output.R new file mode 100644 index 0000000000000000000000000000000000000000..0be8006c01725bdc76ea854ee8f3806f2a4c6bec --- /dev/null +++ b/tests/testthat/test-adaptive-output.R @@ -0,0 +1,312 @@ +# lm_numeric with different approaches + +test_that("output_lm_numeric_independence_reach_exact", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative = TRUE, + verbose = c("basic", "convergence", "shapley"), + paired_shap_sampling = TRUE + ), + "output_lm_numeric_independence_reach_exact" + ) +}) + +test_that("output_lm_numeric_independence_converges_tol", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.1 + ), + iterative = TRUE, + verbose = c("convergence", "shapley") + ), + "output_lm_numeric_independence_converges_tol" + ) +}) + +test_that("output_lm_numeric_independence_converges_maxit", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 8 + ), + iterative = TRUE, + verbose = c("convergence", "shapley") + ), + "output_lm_numeric_independence_converges_maxit" + ) +}) + +test_that("output_lm_numeric_indep_conv_max_n_coalitions", { + expect_snapshot_rds( + explain( + 
testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + max_n_coalitions = 20, + iterative = TRUE, + verbose = c("convergence", "shapley") + ), + "output_lm_numeric_indep_conv_max_n_coalitions" + ) +}) + + +test_that("output_lm_numeric_gaussian_group_converges_tol", { + groups <- list( + A = c("Solar.R", "Wind"), + B = c("Temp", "Month"), + C = "Day" + ) + + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + group = groups, + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 5, + convergence_tol = 0.1 + ), + iterative = TRUE, + verbose = c("convergence", "shapley") + ), + "output_lm_numeric_gaussian_group_converges_tol" + ) +}) + +test_that("output_lm_numeric_independence_converges_tol_paired", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.1 + ), + iterative = TRUE, + verbose = c("convergence", "shapley"), + paired_shap_sampling = TRUE + ), + "output_lm_numeric_independence_converges_tol_paired" + ) +}) + +test_that("output_lm_numeric_independence_saving_and_cont_est", { + # Full 8 iteration estimation to compare against + # Sets seed on the outside + seed = NULL for reproducibility in two-step estimation + set.seed(123) + full <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 8 + ), + iterative = TRUE, + seed = NULL, + verbose = NULL + ) + + # Testing 
saving and continuation estimation + # By setting the seed outside (+ seed= NULL), we should get identical objects when calling explain twice this way + set.seed(123) + e_init_object <- explain( + testing = FALSE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 5 + ), + iterative = TRUE, + seed = NULL, + verbose = NULL + ) + + # Continue estimation from the init object + expect_snapshot_rds( + e_cont_est_object <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 8 + ), + iterative = TRUE, + verbose = NULL, + prev_shapr_object = e_init_object, + seed = NULL, + ), + "output_lm_numeric_independence_cont_est_object" + ) + + # Testing equality with the object being run in one go + expect_equal(e_cont_est_object, full) + + + # Same as above but using the saving_path instead of the shapr object itself # + set.seed(123) + e_init_path <- explain( + testing = FALSE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 5 + ), + iterative = TRUE, + seed = NULL, + verbose = NULL + ) + + # Continue estimation from the saving_path of the init object + expect_snapshot_rds( + e_cont_est_path <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, +
approach = "independence", + phi0 = p0, + paired_shap_sampling = FALSE, + iterative_args = list( + initial_n_coalitions = 10, + convergence_tol = 0.001, + n_coal_next_iter_factor_vec = rep(10^(-5), 10), + max_iter = 8 + ), + iterative = TRUE, + verbose = NULL, + prev_shapr_object = e_init_path$saving_path, + seed = NULL + ), + "output_lm_numeric_independence_cont_est_path" + ) + + # Testing equality with the object being run in one go + expect_equal(e_cont_est_path, full) +}) + +test_that("output_verbose_1", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + verbose = c("basic") + ), + "output_verbose_1" + ) +}) + +test_that("output_verbose_1_3", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + verbose = c("basic", "convergence") + ), + "output_verbose_1_3" + ) +}) + +test_that("output_verbose_1_3_4", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + verbose = c("basic", "convergence", "shapley") + ), + "output_verbose_1_3_4" + ) +}) + +test_that("output_verbose_1_3_4_5", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + verbose = c("basic", "convergence", "shapley", "vS_details") + ), + "output_verbose_1_3_4_5" + ) +}) diff --git a/tests/testthat/test-adaptive-setup.R b/tests/testthat/test-adaptive-setup.R new file mode 100644 index 0000000000000000000000000000000000000000..45f132e6b7a4122906dde2ec06d982d47e6ca0cc --- /dev/null +++ 
b/tests/testthat/test-adaptive-setup.R @@ -0,0 +1,242 @@ +test_that("iterative_args are respected", { + ex <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + max_n_coalitions = 30, + iterative_args = list( + initial_n_coalitions = 6, + convergence_tol = 0.0005, + n_coal_next_iter_factor_vec = rep(10^(-6), 10), + max_iter = 8 + ), + iterative = TRUE + ) + + # Check that initial_n_coalitions is respected + expect_equal(ex$internal$iter_list[[1]]$X[, .N], 6) + + # Check that max_iter is respected + expect_equal(length(ex$internal$iter_list), 8) + expect_true(ex$iterative_results$iter_info_dt[.N, converged_max_iter]) +}) + + +test_that("iterative feature wise and groupwise computations identical", { + groups <- list( + Solar.R = "Solar.R", + Wind = "Wind", + Temp = "Temp", + Month = "Month", + Day = "Day" + ) + + expl_feat <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 5, + convergence_tol = 0.1 + ), + iterative = TRUE + ) + + + expl_group <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + group = groups, + phi0 = p0, + iterative_args = list( + initial_n_coalitions = 5, + convergence_tol = 0.1 + ), + iterative = TRUE + ) + + + # Checking equality in the list with all final and intermediate results + expect_equal(expl_feat$iter_results, expl_group$iter_results) +}) + +test_that("erroneous input: `min_n_batches`", { + set.seed(123) + + # non-numeric 1 + expect_snapshot( + { + n_batches_non_numeric_1 <- "bla" + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = 
list(min_n_batches = n_batches_non_numeric_1) + ) + }, + error = TRUE + ) + + # non-numeric 2 + expect_snapshot( + { + n_batches_non_numeric_2 <- TRUE + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_numeric_2) + ) + }, + error = TRUE + ) + + # non-integer + expect_snapshot( + { + n_batches_non_integer <- 10.5 + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_integer) + ) + }, + error = TRUE + ) + + # length > 1 + expect_snapshot( + { + n_batches_too_long <- c(1, 2) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_too_long) + ) + }, + error = TRUE + ) + + # NA-numeric + expect_snapshot( + { + n_batches_is_NA <- as.numeric(NA) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_is_NA) + ) + }, + error = TRUE + ) + + # Non-positive + expect_snapshot( + { + n_batches_non_positive <- 0 + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + extra_computation_args = list(min_n_batches = n_batches_non_positive) + ) + }, + error = TRUE + ) +}) + +test_that("different n_batches gives same/different shapley values for different approaches", { + # approach "empirical" is seed independent + explain.empirical_n_batches_5 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = 
x_explain_numeric, + x_train = x_train_numeric, + approach = "empirical", + phi0 = p0, + extra_computation_args = list(min_n_batches = 5, max_batch_size = 10) + ) + + explain.empirical_n_batches_10 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "empirical", + phi0 = p0, + extra_computation_args = list(min_n_batches = 10, max_batch_size = 10) + ) + + # Difference in the objects (n_batches and related) + expect_false(identical( + explain.empirical_n_batches_5, + explain.empirical_n_batches_10 + )) + # Same Shapley values + expect_equal( + explain.empirical_n_batches_5$shapley_values_est, + explain.empirical_n_batches_10$shapley_values_est + ) + + # approach "ctree" is seed dependent + explain.ctree_n_batches_5 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "ctree", + phi0 = p0, + extra_computation_args = list(min_n_batches = 5, max_batch_size = 10) + ) + + explain.ctree_n_batches_10 <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "ctree", + phi0 = p0, + extra_computation_args = list(min_n_batches = 10, max_batch_size = 10) + ) + + # Difference in the objects (n_batches and related) + expect_false(identical( + explain.ctree_n_batches_5, + explain.ctree_n_batches_10 + )) + # NEITHER same Shapley values + expect_false(identical( + explain.ctree_n_batches_5$shapley_values_est, + explain.ctree_n_batches_10$shapley_values_est + )) +}) diff --git a/tests/testthat/test-asymmetric-causal-output.R b/tests/testthat/test-asymmetric-causal-output.R new file mode 100644 index 0000000000000000000000000000000000000000..bc8f0f0179493e6c91459537014ac495a791655b --- /dev/null +++ b/tests/testthat/test-asymmetric-causal-output.R @@ -0,0 +1,507 @@ +# Continuous data 
------------------------------------------------------------------------------------------------- +test_that("output_asymmetric_conditional", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = NULL, + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asymmetric_conditional" + ) +}) + +test_that("output_asym_cond_reg", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + regression.model = parsnip::linear_reg(), + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = NULL, + paired_shap_sampling = FALSE + ), + "output_asym_cond_reg" + ) +}) + +test_that("output_asym_cond_reg_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + regression.model = parsnip::linear_reg(), + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = NULL, + paired_shap_sampling = FALSE, + iterative = TRUE + ), + "output_asym_cond_reg_iterative" + ) +}) + +test_that("output_symmetric_conditional", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), # Does not matter when asymmetric = FALSE and confounding = NULL + confounding = NULL, + n_MC_samples = 5 # Just for speed + ), + "output_symmetric_conditional" + ) +}) + +test_that("output_symmetric_marginal_independence", { + expect_snapshot_rds( + explain( + testing = TRUE, + model =
model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "independence", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:5), + confounding = TRUE, + n_MC_samples = 5 # Just for speed + ), + "output_symmetric_marginal_independence" + ) +}) + +test_that("output_symmetric_marginal_gaussian", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:5), + confounding = TRUE, + n_MC_samples = 5 # Just for speed + ), + "output_symmetric_marginal_gaussian" + ) +}) + +test_that("output_asym_caus_conf_TRUE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = TRUE, + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_TRUE" + ) +}) + + + +test_that("output_asym_caus_conf_FALSE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = FALSE, + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_FALSE" + ) +}) + +test_that("output_asym_caus_conf_mix", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_mix" + ) +}) + 
+test_that("output_asym_caus_conf_mix_n_coal", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5, # Just for speed + paired_shap_sampling = FALSE, + max_n_coalitions = 6 + ), + "output_asym_caus_conf_mix_n_coal" + ) +}) + +test_that("output_asym_caus_conf_mix_empirical", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "empirical", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_mix_empirical" + ) +}) + +test_that("output_asym_caus_conf_mix_ctree", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "ctree", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_asym_caus_conf_mix_ctree" + ) +}) + +test_that("output_sym_caus_conf_TRUE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = TRUE, + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_TRUE" + ) +}) + +test_that("output_sym_caus_conf_FALSE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = 
FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_FALSE" + ) +}) + +test_that("output_sym_caus_conf_mix", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_mix" + ) +}) + + +## Group-wise ----------------------------------------------------------------------------------------------------- +test_that("output_sym_caus_conf_TRUE_group", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3), + confounding = TRUE, + group = list("A" = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_TRUE_group" + ) +}) + + +test_that("output_sym_caus_conf_mix_group", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1, 2, 3), + confounding = c(TRUE, TRUE, FALSE), + group = list("A" = c("Solar.R"), B = c("Wind", "Temp"), C = c("Month", "Day")), + n_MC_samples = 5 # Just for speed + ), + "output_sym_caus_conf_mix_group" + ) +}) + +test_that("output_sym_caus_conf_mix_group_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1, 2, 3), + confounding = c(TRUE, TRUE, FALSE), + group = list("A" = 
c("Solar.R"), B = c("Wind", "Temp"), C = c("Month", "Day")), + n_MC_samples = 5, # Just for speed, + verbose = c("convergence"), + iterative = TRUE + ), + "output_sym_caus_conf_mix_group_iterative" + ) +}) + + + + + +# Mixed data ------------------------------------------------------------------------------------------------------ +test_that("output_mixed_sym_caus_conf_TRUE", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "ctree", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = TRUE, + n_MC_samples = 5 # Just for speed + ), + "output_mixed_sym_caus_conf_TRUE" + ) +}) + +test_that("output_mixed_sym_caus_conf_TRUE_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "ctree", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3, 4:5), + confounding = TRUE, + n_MC_samples = 5, # Just for speed + iterative = TRUE + ), + "output_mixed_sym_caus_conf_TRUE_iterative" + ) +}) + +test_that("output_mixed_asym_caus_conf_mixed", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "ctree", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(TRUE, FALSE, FALSE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + "output_mixed_sym_caus_conf_mixed" + ) +}) + +test_that("output_mixed_asym_caus_conf_mixed_2", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "ctree", + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + confounding = c(FALSE, TRUE, TRUE), + paired_shap_sampling = FALSE, + n_MC_samples = 5 # Just for speed + ), + 
"output_mixed_sym_caus_conf_mixed_2" + ) +}) + + +test_that("output_mixed_asym_cond_reg", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_mixed, + x_explain = x_explain_mixed, + x_train = x_train_mixed, + approach = "regression_separate", + regression.model = parsnip::linear_reg(), + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 3, 4:5), + paired_shap_sampling = FALSE, + confounding = NULL, + iterative = TRUE + ), + "output_mixed_asym_cond_reg" + ) +}) + + + +# Categorical data ------------------------------------------------------------------------------------------------ +test_that("output_categorical_asym_causal_mixed_cat", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_categorical, + x_explain = x_explain_categorical[1:2], # Temp [1:2] as [1:3] give different sample on GHA-macOS (unknown reason) + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5, # Just for speed + output_args = list(keep_samp_for_vS = TRUE) + ), + "output_categorical_asym_causal_mixed_cat" + ) +}) + + + +test_that("output_cat_asym_causal_mixed_cat_ad", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "categorical", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), + n_MC_samples = 5, # Just for speed + iterative = TRUE + ), + "output_cat_asym_causal_mixed_cat_ad" + ) +}) + +test_that("output_categorical_asym_causal_mixed_ctree", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_categorical, + x_explain = x_explain_categorical, + x_train = x_train_categorical, + approach = "ctree", + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(3:4, 2, 1), + confounding = c(TRUE, FALSE, FALSE), 
+ n_MC_samples = 5 # Just for speed + ), + "output_categorical_asym_causal_mixed_ctree" + ) +}) diff --git a/tests/testthat/test-asymmetric-causal-setup.R b/tests/testthat/test-asymmetric-causal-setup.R new file mode 100644 index 0000000000000000000000000000000000000000..75f03cb9830829a6da6efaf614f19aa8030d6406 --- /dev/null +++ b/tests/testthat/test-asymmetric-causal-setup.R @@ -0,0 +1,343 @@ +test_that("asymmetric erroneous input: `causal_ordering`", { + set.seed(123) + + expect_snapshot( + { + # Too many variables (6 does not exist) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:6), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Too many variables (5 duplicate) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:5, 5), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Correct number of variables, but 5 duplicate + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(2:5, 5), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # To few variables + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(1:2, 4), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Too many variables (not valid feature name) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = 
x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list("Solar.R", "Wind", "Temp", "Month", "Day", "Invalid feature name"), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Too many variables (duplicate) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list("Solar.R", "Wind", "Temp", "Month", "Day", "Day"), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Duplicate and missing "Month", but right number of variables + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list("Solar.R", "Wind", "Temp", "Day", "Day"), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Too few variables + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list("Solar.R", "Wind"), + confounding = NULL, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Group Shapley: not giving the group names + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(c("Solar.R", "Wind", "Temp", "Month"), "Day"), + confounding = NULL, + approach = "gaussian", + group = list("A" = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Group Shapley: not giving all the group names correctly + explain( + testing = TRUE, + model = 
model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(c("A", "C"), "Wrong name"), + confounding = NULL, + approach = "gaussian", + group = list("A" = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Group Shapley: missing a group names + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = TRUE, + causal_ordering = list(c("A"), "B"), + confounding = NULL, + approach = "gaussian", + group = list("A" = c("Solar.R", "Wind"), B = "Temp", C = c("Month", "Day")), + iterative = FALSE + ) + }, + error = TRUE + ) +}) + + +test_that("asymmetric erroneous input: `approach`", { + set.seed(123) + + expect_snapshot( + { + # Causal Shapley values is not applicable for combined approaches. + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3:4, 5), + confounding = TRUE, + approach = c("gaussian", "independence", "empirical", "gaussian"), + iterative = FALSE + ) + }, + error = TRUE + ) +}) + +test_that("asymmetric erroneous input: `asymmetric`", { + set.seed(123) + + expect_snapshot( + { + # Vector + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = c(FALSE, FALSE), + causal_ordering = list(1:2, 3:4, 5), + confounding = TRUE, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # String + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = "Must be a single logical", + causal_ordering = list(1:2, 3:4, 5), + confounding = TRUE, + approach = "gaussian", + iterative 
= FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # Integer + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = 1L, + causal_ordering = list(1:2, 3:4, 5), + confounding = TRUE, + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) +}) + + +test_that("asymmetric erroneous input: `confounding`", { + set.seed(123) + + expect_snapshot( + { + # confounding not logical vector + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3:4, 5), + confounding = c("A", "B", "C"), + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) + + expect_snapshot( + { + # logical vector of incorrect length + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + asymmetric = FALSE, + causal_ordering = list(1:2, 3:4, 5), + confounding = c(TRUE, FALSE), + approach = "gaussian", + iterative = FALSE + ) + }, + error = TRUE + ) +}) diff --git a/tests/testthat/test-bugfixes.R b/tests/testthat/test-bugfixes.R deleted file mode 100644 index e1fffd9ebd683b0b12bdf95ecae563478e1178d1..0000000000000000000000000000000000000000 --- a/tests/testthat/test-bugfixes.R +++ /dev/null @@ -1,27 +0,0 @@ -test_that("bug with column name ordering in edge case is fixed", { - # Before the bugfix, data.table throw the warning: - # Column 2 ['Solar.R'] of item 2 appears in position 1 in item 1. Set use.names=TRUE to match by column name, - # or use.names=FALSE to ignore column names. use.names='check' (default from v1.12.2) emits this message and - # proceeds as if use.names=FALSE for backwards compatibility. - # See news item 5 in v1.12.2 for options to control this message. 
- expect_silent({ # Apparently, expect_no_message() does not react to the data.table message/warning - e.one_subset_per_batch <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "gaussian", - prediction_zero = p0, - n_samples = 2, - n_batches = 2^5 - 1, # Bug happens when n_batches = n_combinations - 1 - keep_samp_for_vS = TRUE, - seed = 123 - ) - }) - - # The bug causes id_combination to suddenly not be integer. - expect_true( - is.integer( - e.one_subset_per_batch$internal$output$dt_samp_for_vS$id_combination[1] - ) - ) -}) diff --git a/tests/testthat/test-forecast-output.R b/tests/testthat/test-forecast-output.R index c2bcc000bb7dff5370385b65d6ed0e6dcb94a737..803f028e4550595b14c828b6dbd81f869edf7cef 100644 --- a/tests/testthat/test-forecast-output.R +++ b/tests/testthat/test-forecast-output.R @@ -1,17 +1,17 @@ test_that("forecast_output_ar_numeric", { expect_snapshot_rds( explain_forecast( + testing = TRUE, model = model_ar_temp, - y = data[, "Temp"], + y = data_arima[, "Temp"], train_idx = 2:151, explain_idx = 152:153, explain_y_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = FALSE, - n_batches = 1, - timing = FALSE + n_batches = 1 ), "forecast_output_ar_numeric" ) @@ -20,217 +20,275 @@ test_that("forecast_output_ar_numeric", { test_that("forecast_output_arima_numeric", { expect_snapshot_rds( explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = FALSE, - n_batches = 1, - timing = FALSE + max_n_coalitions = 150, + iterative = FALSE ), "forecast_output_arima_numeric" ) }) +test_that("forecast_output_arima_numeric_iterative", { + expect_snapshot_rds( + 
explain_forecast( + testing = TRUE, + model = model_arima_temp, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], + train_idx = 3:148, + explain_idx = 149:150, + explain_y_lags = 3, + explain_xreg_lags = 3, + horizon = 3, + approach = "empirical", + phi0 = p0_ar, + group_lags = FALSE, + max_n_coalitions = 150, + iterative = TRUE, + iterative_args = list(initial_n_coalitions = 10) + ), + "forecast_output_arima_numeric_iterative" + ) +}) + +test_that("forecast_output_arima_numeric_iterative_groups", { + expect_snapshot_rds( + explain_forecast( + testing = TRUE, + model = model_arima_temp2, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, c("Wind", "Solar.R", "Ozone")], + train_idx = 3:148, + explain_idx = 149:150, + explain_y_lags = 3, + explain_xreg_lags = c(3, 3, 3), + horizon = 3, + approach = "empirical", + phi0 = p0_ar, + group_lags = TRUE, + max_n_coalitions = 150, + iterative = TRUE, + iterative_args = list(initial_n_coalitions = 10, convergence_tol = 7e-3) + ), + "forecast_output_arima_numeric_iterative_groups" + ) +}) + test_that("forecast_output_arima_numeric_no_xreg", { expect_snapshot_rds( explain_forecast( + testing = TRUE, model = model_arima_temp_noxreg, - y = data[1:150, "Temp"], + y = data_arima[1:150, "Temp"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = FALSE, - n_batches = 1, - timing = FALSE + n_batches = 1 ), "forecast_output_arima_numeric_no_xreg" ) }) +# Old snap does not correspond to the results from the master branch, why is unclear. 
test_that("forecast_output_forecast_ARIMA_group_numeric", { expect_snapshot_rds( explain_forecast( + testing = TRUE, model = model_forecast_ARIMA_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar, + phi0 = p0_ar, group_lags = TRUE, - n_batches = 1, - timing = FALSE + n_batches = 1 ), "forecast_output_forecast_ARIMA_group_numeric" ) }) +test_that("forecast_output_arima_numeric_no_lags", { + expect_snapshot_rds( + explain_forecast( + testing = TRUE, + model = model_arima_temp, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], + train_idx = 2:148, + explain_idx = 149:150, + explain_y_lags = 0, + explain_xreg_lags = 0, + horizon = 3, + approach = "independence", + phi0 = p0_ar, + group_lags = FALSE, + n_batches = 1 + ), + "forecast_output_arima_numeric_no_lags" + ) +}) test_that("ARIMA gives the same output with different horizons", { h3 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar[1:3], + phi0 = p0_ar[1:3], group_lags = FALSE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 200, + iterative = FALSE ) h2 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 2, approach = "empirical", - prediction_zero = p0_ar[1:2], + phi0 = p0_ar[1:2], group_lags = FALSE, n_batches = 1, - timing = FALSE, n_combinations = 50 + 
max_n_coalitions = 100, + iterative = FALSE ) h1 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 1, approach = "empirical", - prediction_zero = p0_ar[1], + phi0 = p0_ar[1], group_lags = FALSE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 50, + iterative = FALSE ) cols_horizon1 <- h2$internal$objects$cols_per_horizon[[1]] expect_equal( - h2$shapley_values[horizon == 1, ..cols_horizon1], - h1$shapley_values[horizon == 1, ..cols_horizon1] + h2$shapley_values_est[horizon == 1, ..cols_horizon1], + h1$shapley_values_est[horizon == 1, ..cols_horizon1] ) expect_equal( - h3$shapley_values[horizon == 1, ..cols_horizon1], - h1$shapley_values[horizon == 1, ..cols_horizon1] + h3$shapley_values_est[horizon == 1, ..cols_horizon1], + h1$shapley_values_est[horizon == 1, ..cols_horizon1] ) cols_horizon2 <- h2$internal$objects$cols_per_horizon[[2]] expect_equal( - h3$shapley_values[horizon == 2, ..cols_horizon2], - h2$shapley_values[horizon == 2, ..cols_horizon2] + h3$shapley_values_est[horizon == 2, ..cols_horizon2], + h2$shapley_values_est[horizon == 2, ..cols_horizon2] ) }) test_that("ARIMA gives the same output with different horizons with grouping", { h3 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "empirical", - prediction_zero = p0_ar[1:3], + phi0 = p0_ar[1:3], group_lags = TRUE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 50, + iterative = FALSE ) h2 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = 
data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 2, approach = "empirical", - prediction_zero = p0_ar[1:2], + phi0 = p0_ar[1:2], group_lags = TRUE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 50, + iterative = FALSE ) h1 <- explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 1, approach = "empirical", - prediction_zero = p0_ar[1], + phi0 = p0_ar[1], group_lags = TRUE, n_batches = 1, - timing = FALSE, n_combinations = 50 + max_n_coalitions = 50, + iterative = FALSE ) expect_equal( - h2$shapley_values[horizon == 1], - h1$shapley_values[horizon == 1] + h2$shapley_values_est[horizon == 1], + h1$shapley_values_est[horizon == 1] ) expect_equal( - h3$shapley_values[horizon == 1], - h1$shapley_values[horizon == 1] + h3$shapley_values_est[horizon == 1], + h1$shapley_values_est[horizon == 1] ) expect_equal( - h3$shapley_values[horizon == 2], - h2$shapley_values[horizon == 2] - ) -}) - -test_that("forecast_output_arima_numeric_no_lags", { - # TODO: Need to check out this output. It gives lots of warnings, which indicates something might be wrong. 
- expect_snapshot_rds( - explain_forecast( - model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], - train_idx = 2:148, - explain_idx = 149:150, - explain_y_lags = 0, - explain_xreg_lags = 0, - horizon = 3, - approach = "independence", - prediction_zero = p0_ar, - group_lags = FALSE, - n_batches = 1, - timing = FALSE - ), - "forecast_output_arima_numeric_no_lags" + h3$shapley_values_est[horizon == 2], + h2$shapley_values_est[horizon == 2] ) }) diff --git a/tests/testthat/test-forecast-setup.R b/tests/testthat/test-forecast-setup.R index 70a49eafbd8bfd0d3cdde3785f23c202e94cf47e..cd211d39236d8543d22ba4516ec133dfb3594c6b 100644 --- a/tests/testthat/test-forecast-setup.R +++ b/tests/testthat/test-forecast-setup.R @@ -9,17 +9,17 @@ test_that("error with custom model without providing predict_model", { class(model_custom_arima_temp) <- "whatever" explain_forecast( + testing = TRUE, model = model_custom_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -33,20 +33,20 @@ test_that("erroneous input: `x_train/x_explain`", { expect_snapshot( { # not vector or one-column data.table/matrix - y_wrong_format <- data[, c("Temp", "Wind")] + y_wrong_format <- data_arima[, c("Temp", "Wind")] explain_forecast( + testing = TRUE, model = model_arima_temp, y = y_wrong_format, - xreg = data[, "Wind"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -55,11 +55,12 @@ test_that("erroneous input: `x_train/x_explain`", { expect_snapshot( { # not correct dimension - xreg_wrong_format <- data[, c("Temp", "Wind")] 
+ xreg_wrong_format <- data_arima[, c("Temp", "Wind")] explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], + y = data_arima[1:150, "Temp"], xreg = xreg_wrong_format, train_idx = 2:148, explain_idx = 149:150, @@ -67,8 +68,7 @@ test_that("erroneous input: `x_train/x_explain`", { explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -77,12 +77,13 @@ test_that("erroneous input: `x_train/x_explain`", { expect_snapshot( { # missing column names x_train - xreg_no_column_names <- data[, "Wind"] + xreg_no_column_names <- data_arima[, "Wind"] names(xreg_no_column_names) <- NULL explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], + y = data_arima[1:150, "Temp"], xreg = xreg_no_column_names, train_idx = 2:148, explain_idx = 149:150, @@ -90,8 +91,7 @@ test_that("erroneous input: `x_train/x_explain`", { explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -105,16 +105,16 @@ test_that("erroneous input: `model`", { { # no model passed explain_forecast( - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + testing = TRUE, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -122,7 +122,7 @@ test_that("erroneous input: `model`", { }) -test_that("erroneous input: `prediction_zero`", { +test_that("erroneous input: `phi0`", { set.seed(123) expect_snapshot( @@ -131,48 +131,48 @@ test_that("erroneous input: `prediction_zero`", { p0_wrong_length <- p0_ar[1:2] explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, 
"Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_wrong_length, - n_batches = 1 + phi0 = p0_wrong_length ) }, error = TRUE ) }) -test_that("erroneous input: `n_combinations`", { +test_that("erroneous input: `max_n_coalitions`", { set.seed(123) expect_snapshot( { - # Too low n_combinations (smaller than # features) + # Too low max_n_coalitions (smaller than # features) horizon <- 3 explain_y_lags <- 2 explain_xreg_lags <- 2 - n_combinations <- horizon + explain_y_lags + explain_xreg_lags - 1 + n_coalitions <- horizon + explain_y_lags + explain_xreg_lags - 1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags, explain_xreg_lags = explain_xreg_lags, horizon = horizon, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1, - n_combinations = n_combinations, + phi0 = p0_ar, + max_n_coalitions = n_coalitions, group_lags = FALSE ) }, @@ -180,33 +180,30 @@ test_that("erroneous input: `n_combinations`", { ) - expect_snapshot( - { - # Too low n_combinations (smaller than # groups) - horizon <- 3 - explain_y_lags <- 2 - explain_xreg_lags <- 2 - - n_combinations <- 1 + 1 - - explain_forecast( - model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], - train_idx = 2:148, - explain_idx = 149:150, - explain_y_lags = explain_y_lags, - explain_xreg_lags = explain_xreg_lags, - horizon = horizon, - approach = "independence", - prediction_zero = p0_ar, - n_batches = 1, - n_combinations = n_combinations, - group_lags = TRUE - ) - }, - error = TRUE - ) + expect_snapshot({ + # Too low n_coalitions (smaller than # groups) + horizon <- 3 + explain_y_lags <- 2 + explain_xreg_lags <- 2 + + n_coalitions <- 1 + 1 + + explain_forecast( + testing = TRUE, + 
model = model_arima_temp, + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], + train_idx = 2:148, + explain_idx = 149:150, + explain_y_lags = explain_y_lags, + explain_xreg_lags = explain_xreg_lags, + horizon = horizon, + approach = "independence", + phi0 = p0_ar, + max_n_coalitions = n_coalitions, + group_lags = TRUE + ) + }) }) @@ -219,17 +216,17 @@ test_that("erroneous input: `train_idx`", { train_idx_too_short <- 2 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = train_idx_too_short, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -242,17 +239,17 @@ test_that("erroneous input: `train_idx`", { train_idx_not_integer <- c(3:5) + 0.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = train_idx_not_integer, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -264,17 +261,17 @@ test_that("erroneous input: `train_idx`", { train_idx_out_of_range <- 1:5 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = train_idx_out_of_range, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -290,17 +287,17 @@ test_that("erroneous input: `explain_idx`", { explain_idx_not_integer <- c(3:5) + 0.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = 
data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = explain_idx_not_integer, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -312,17 +309,17 @@ test_that("erroneous input: `explain_idx`", { explain_idx_out_of_range <- 1:5 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = explain_idx_out_of_range, explain_y_lags = 2, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -338,17 +335,17 @@ test_that("erroneous input: `explain_y_lags`", { explain_y_lags_negative <- -1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_negative, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -360,17 +357,17 @@ test_that("erroneous input: `explain_y_lags`", { explain_y_lags_not_integer <- 2.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_not_integer, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -382,17 +379,17 @@ test_that("erroneous input: `explain_y_lags`", { explain_y_lags_more_than_one <- c(1, 2) explain_forecast( + testing = TRUE, model = 
model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = explain_y_lags_more_than_one, explain_xreg_lags = 2, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -405,15 +402,15 @@ test_that("erroneous input: `explain_y_lags`", { explain_y_lags_zero <- 0 explain_forecast( + testing = TRUE, model = model_arima_temp_noxreg, - y = data[1:150, "Temp"], + y = data_arima[1:150, "Temp"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 0, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -430,17 +427,17 @@ test_that("erroneous input: `explain_x_lags`", { explain_xreg_lags_negative <- -2 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = explain_xreg_lags_negative, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -452,17 +449,17 @@ test_that("erroneous input: `explain_x_lags`", { explain_xreg_lags_not_integer <- 2.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = explain_xreg_lags_not_integer, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -474,17 +471,17 @@ test_that("erroneous input: `explain_x_lags`", { explain_x_lags_wrong_length <- c(1, 2) # only 1 xreg variable defined explain_forecast( + testing = TRUE, model = model_arima_temp, - y = 
data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = explain_x_lags_wrong_length, horizon = 3, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -500,17 +497,17 @@ test_that("erroneous input: `horizon`", { horizon_negative <- -2 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = horizon_negative, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE @@ -522,17 +519,17 @@ test_that("erroneous input: `horizon`", { horizon_not_integer <- 2.1 explain_forecast( + testing = TRUE, model = model_arima_temp, - y = data[1:150, "Temp"], - xreg = data[, "Wind"], + y = data_arima[1:150, "Temp"], + xreg = data_arima[, "Wind"], train_idx = 2:148, explain_idx = 149:150, explain_y_lags = 2, explain_xreg_lags = 2, horizon = horizon_not_integer, approach = "independence", - prediction_zero = p0_ar, - n_batches = 1 + phi0 = p0_ar ) }, error = TRUE diff --git a/tests/testthat/test-plot.R b/tests/testthat/test-plot.R index 5cb291cda263646af431e8deb3a7c7e0a75c7978..6fe9d69de6de05fcf9f710f9cf9cb1fe9d4ba1ce 100644 --- a/tests/testthat/test-plot.R +++ b/tests/testthat/test-plot.R @@ -1,53 +1,48 @@ set.seed(123) # explain_mixed <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain_numeric_empirical <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 1, - 
timing = FALSE + phi0 = p0 ) explain_numeric_gaussian <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain_numeric_ctree <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain_numeric_combined <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("empirical", "ctree", "gaussian", "ctree"), - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) # Create a list of explanations with names @@ -237,18 +232,18 @@ test_that("MSEv evaluation criterion plots", { ) vdiffr::expect_doppelganger( - title = "MSEv_combination_bar", - fig = MSEv_plots$MSEv_combination_bar + title = "MSEv_coalition_bar", + fig = MSEv_plots$MSEv_coalition_bar ) vdiffr::expect_doppelganger( - title = "MSEv_combination_bar specified width", - fig = MSEv_plots_specified_width$MSEv_combination_bar + title = "MSEv_coalition_bar specified width", + fig = MSEv_plots_specified_width$MSEv_coalition_bar ) vdiffr::expect_doppelganger( - title = "MSEv_combination_line_point", - fig = MSEv_plots$MSEv_combination_line_point + title = "MSEv_coalition_line_point", + fig = MSEv_plots$MSEv_coalition_line_point ) vdiffr::expect_doppelganger( @@ -261,13 +256,13 @@ test_that("MSEv evaluation criterion plots", { ) vdiffr::expect_doppelganger( - title = "MSEv_combinations for specified combinations", + title = "MSEv_coalitions for specified coalitions", fig = plot_MSEv_eval_crit( explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15), + id_coalition = c(3, 4, 9, 13:15), CI_level = 0.95 - )$MSEv_combination_bar + )$MSEv_coalition_bar ) }) diff --git a/tests/testthat/test-regression-output.R 
b/tests/testthat/test-regression-output.R index 5b97e46e8a70333bed6a88ce4e906559abac72a2..d43acc7013ade35a2a68427b3e35b4558dd7b8bc 100644 --- a/tests/testthat/test-regression-output.R +++ b/tests/testthat/test-regression-output.R @@ -1,15 +1,32 @@ # Separate regression ================================================================================================== +test_that("output_lm_numeric_lm_separate_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_separate", + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = TRUE + ), + "output_lm_numeric_lm_separate_iterative" + ) +}) + + test_that("output_lm_numeric_lm_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_separate", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_numeric_lm_separate" ) @@ -17,16 +34,16 @@ test_that("output_lm_numeric_lm_separate", { test_that("output_lm_numeric_lm_separate_n_comb", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_separate", - prediction_zero = p0, - n_batches = 4, - n_combinations = 10, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + max_n_coalitions = 10, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_numeric_lm_separate_n_comb" ) @@ -34,15 +51,15 @@ test_that("output_lm_numeric_lm_separate_n_comb", { test_that("output_lm_categorical_lm_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_categorical, x_explain = 
x_explain_categorical, x_train = x_train_categorical, approach = "regression_separate", - prediction_zero = p0, - n_batches = 4, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_categorical_lm_separate" ) @@ -50,15 +67,15 @@ test_that("output_lm_categorical_lm_separate", { test_that("output_lm_mixed_lm_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "regression_separate", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_mixed_lm_separate" ) @@ -66,18 +83,18 @@ test_that("output_lm_mixed_lm_separate", { test_that("output_lm_mixed_splines_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "regression_separate", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression.recipe) { recipes::step_ns(regression.recipe, recipes::all_numeric_predictors(), deg_free = 2) - } + }, + iterative = FALSE ), "output_lm_mixed_splines_separate" ) @@ -85,17 +102,17 @@ test_that("output_lm_mixed_splines_separate", { test_that("output_lm_mixed_decision_tree_cv_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2)), - regression.vfold_cv_para = list(v = 2) + regression.vfold_cv_para 
= list(v = 2), + iterative = FALSE ), "output_lm_mixed_decision_tree_cv_separate" ) @@ -104,17 +121,17 @@ test_that("output_lm_mixed_decision_tree_cv_separate", { test_that("output_lm_mixed_decision_tree_cv_separate_parallel", { future::plan("multisession", workers = 2) expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2)), - regression.vfold_cv_para = list(v = 2) + regression.vfold_cv_para = list(v = 2), + iterative = FALSE ), "output_lm_mixed_decision_tree_cv_separate_parallel" ) @@ -123,35 +140,52 @@ test_that("output_lm_mixed_decision_tree_cv_separate_parallel", { test_that("output_lm_mixed_xgboost_separate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression.recipe) { return(recipes::step_dummy(regression.recipe, recipes::all_factor_predictors())) - } + }, + iterative = FALSE ), "output_lm_mixed_xgboost_separate" ) }) # Surrogate regression ================================================================================================= +test_that("output_lm_numeric_lm_surrogate_iterative", { + expect_snapshot_rds( + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "regression_surrogate", + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = TRUE + ), + 
"output_lm_numeric_lm_surrogate_iterative" + ) +}) + + test_that("output_lm_numeric_lm_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_numeric_lm_surrogate" ) @@ -159,16 +193,16 @@ test_that("output_lm_numeric_lm_surrogate", { test_that("output_lm_numeric_lm_surrogate_n_comb", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 4, - n_combinations = 10, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + max_n_coalitions = 10, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_numeric_lm_surrogate_n_comb" ) @@ -176,17 +210,17 @@ test_that("output_lm_numeric_lm_surrogate_n_comb", { test_that("output_lm_numeric_lm_surrogate_reg_surr_n_comb", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 4, - n_combinations = 10, - timing = FALSE, + phi0 = p0, + max_n_coalitions = 10, regression.model = parsnip::linear_reg(), - regression.surrogate_n_comb = 8 + regression.surrogate_n_comb = 8, + iterative = FALSE ), "output_lm_numeric_lm_surrogate_reg_surr_n_comb" ) @@ -194,15 +228,15 @@ test_that("output_lm_numeric_lm_surrogate_reg_surr_n_comb", { test_that("output_lm_categorical_lm_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = 
x_train_categorical, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 2, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_categorical_lm_surrogate" ) @@ -210,15 +244,15 @@ test_that("output_lm_categorical_lm_surrogate", { test_that("output_lm_mixed_lm_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "regression_surrogate", - prediction_zero = p0, - n_batches = 4, - timing = FALSE, - regression.model = parsnip::linear_reg() + phi0 = p0, + regression.model = parsnip::linear_reg(), + iterative = FALSE ), "output_lm_mixed_lm_surrogate" ) @@ -226,17 +260,17 @@ test_that("output_lm_mixed_lm_surrogate", { test_that("output_lm_mixed_decision_tree_cv_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::decision_tree(tree_depth = hardhat::tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2)), - regression.vfold_cv_para = list(v = 2) + regression.vfold_cv_para = list(v = 2), + iterative = FALSE ), "output_lm_mixed_decision_tree_cv_surrogate" ) @@ -244,18 +278,18 @@ test_that("output_lm_mixed_decision_tree_cv_surrogate", { test_that("output_lm_mixed_xgboost_surrogate", { expect_snapshot_rds( - shapr::explain( + explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, - prediction_zero = p0, - n_batches = 4, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = 
function(regression.recipe) { recipes::step_dummy(regression.recipe, recipes::all_factor_predictors()) - } + }, + iterative = FALSE ), "output_lm_mixed_xgboost_surrogate" ) diff --git a/tests/testthat/test-regression-setup.R b/tests/testthat/test-regression-setup.R index 43f4b3fc431871d1ac33e12eeace758601df931e..f88c3692fa8a16caefd75811f72fd9909e6bfa64 100644 --- a/tests/testthat/test-regression-setup.R +++ b/tests/testthat/test-regression-setup.R @@ -5,13 +5,13 @@ test_that("regression erroneous input: `approach`", { { # Include regression_surrogate explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = c("regression_surrogate", "gaussian", "independence", "empirical"), + iterative = FALSE ) }, error = TRUE @@ -21,13 +21,13 @@ test_that("regression erroneous input: `approach`", { { # Include regression_separate explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = c("regression_separate", "gaussian", "independence", "empirical"), + iterative = FALSE ) }, error = TRUE @@ -41,12 +41,11 @@ test_that("regression erroneous input: `regression.model`", { { # no regression model passed explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = NULL ) @@ -58,12 +57,11 @@ test_that("regression erroneous input: `regression.model`", { { # not a tidymodels object of class model_spec explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = lm ) @@ -75,12 
+73,11 @@ test_that("regression erroneous input: `regression.model`", { { # regression.tune_values` must be provided when `regression.model` contains hyperparameters to tune. explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression") ) @@ -92,12 +89,11 @@ test_that("regression erroneous input: `regression.model`", { { # The tunable parameters and the parameters value do not match explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(num_terms = c(1, 2, 3)) @@ -110,12 +106,11 @@ test_that("regression erroneous input: `regression.model`", { { # The tunable parameters and the parameters value do not match explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3), num_terms = c(1, 2, 3)) @@ -128,12 +123,11 @@ test_that("regression erroneous input: `regression.model`", { { # Provide regression.tune_values but the parameter has allready been specified in the regression.model explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model 
= parsnip::decision_tree(tree_depth = 2, engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)) @@ -146,14 +140,14 @@ test_that("regression erroneous input: `regression.model`", { { # Provide regression.tune_values but not a model where these are to be used explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", - regression.tune_values = data.frame(tree_depth = c(1, 2, 3)) + regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), + iterative = FALSE ) }, error = TRUE @@ -168,12 +162,11 @@ test_that("regression erroneous input: `regression.tune_values`", { { # Provide hyperparameter values, but hyperparameter has not been declared as a tunable parameter explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = 2, engine = "rpart", mode = "regression"), regression.tune_values = as.matrix(data.frame(tree_depth = c(1, 2, 3))) @@ -186,12 +179,11 @@ test_that("regression erroneous input: `regression.tune_values`", { { # The regression.tune_values function must return a data.frame explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = function(x) c(1, 2, 3) @@ -204,12 +196,11 @@ test_that("regression erroneous input: `regression.tune_values`", { { # The regression.tune_values function must return a data.frame with correct names explain( + testing = TRUE, model = 
model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = function(x) data.frame(wrong_name = c(1, 2, 3)) @@ -226,12 +217,11 @@ test_that("regression erroneous input: `regression.vfold_cv_para`", { { # `regression.vfold_cv_para` is not a list explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), @@ -245,12 +235,11 @@ test_that("regression erroneous input: `regression.vfold_cv_para`", { { # `regression.vfold_cv_para` is not a named list explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), @@ -264,12 +253,11 @@ test_that("regression erroneous input: `regression.vfold_cv_para`", { { # Unrecognized parameter explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(tree_depth = tune(), engine = "rpart", mode = "regression"), regression.tune_values = data.frame(tree_depth = c(1, 2, 3)), @@ -288,12 +276,11 @@ test_that("regression erroneous input: `regression.recipe_func`", { { # 
regression.recipe_func is not a function explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_separate", regression.recipe_func = 3 ) @@ -305,16 +292,16 @@ test_that("regression erroneous input: `regression.recipe_func`", { { # regression.recipe_func must output a recipe explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", regression.recipe_func = function(x) { return(2) - } + }, + iterative = FALSE ) }, error = TRUE @@ -328,14 +315,14 @@ test_that("regression erroneous input: `regression.surrogate_n_comb`", { { # regression.surrogate_n_comb must be between 1 and 2^n_features - 2 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", - regression.surrogate_n_comb = 2^ncol(x_explain_numeric) - 1 + regression.surrogate_n_comb = 2^ncol(x_explain_numeric) - 1, + iterative = FALSE ) }, error = TRUE @@ -345,14 +332,14 @@ test_that("regression erroneous input: `regression.surrogate_n_comb`", { { # regression.surrogate_n_comb must be between 1 and 2^n_features - 2 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, - prediction_zero = p0, - n_batches = 1, - timing = FALSE, + phi0 = p0, approach = "regression_surrogate", - regression.surrogate_n_comb = 0 + regression.surrogate_n_comb = 0, + iterative = FALSE ) }, error = TRUE diff --git a/tests/testthat/test-output.R b/tests/testthat/test-regular-output.R similarity index 79% rename from tests/testthat/test-output.R rename to tests/testthat/test-regular-output.R index 
cea7c696a41d054c5fe4f565f4cfb9e20094e280..7470806077f2ad4799d7e1db6dc6d85a4309c201 100644 --- a/tests/testthat/test-output.R +++ b/tests/testthat/test-regular-output.R @@ -3,13 +3,13 @@ test_that("output_lm_numeric_independence", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_independence" ) @@ -18,14 +18,14 @@ test_that("output_lm_numeric_independence", { test_that("output_lm_numeric_independence_MSEv_Shapley_weights", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - MSEv_uniform_comb_weights = FALSE + phi0 = p0, + output_args = list(MSEv_uniform_comb_weights = FALSE), + iterative = FALSE ), "output_lm_numeric_independence_MSEv_Shapley_weights" ) @@ -34,31 +34,31 @@ test_that("output_lm_numeric_independence_MSEv_Shapley_weights", { test_that("output_lm_numeric_empirical", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_empirical" ) }) -test_that("output_lm_numeric_empirical_n_combinations", { +test_that("output_lm_numeric_empirical_n_coalitions", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_combinations = 20, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = 20, + iterative = FALSE ), - "output_lm_numeric_empirical_n_combinations" + "output_lm_numeric_empirical_n_coalitions" ) }) @@ -66,14 +66,14 
@@ test_that("output_lm_numeric_empirical_independence", { set.seed(123) expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, + phi0 = p0, empirical.type = "independence", - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_lm_numeric_empirical_independence" ) @@ -83,15 +83,15 @@ test_that("output_lm_numeric_empirical_AICc_each", { set.seed(123) expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_combinations = 8, + phi0 = p0, + max_n_coalitions = 8, empirical.type = "AICc_each_k", - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_lm_numeric_empirical_AICc_each" ) @@ -101,15 +101,15 @@ test_that("output_lm_numeric_empirical_AICc_full", { set.seed(123) expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_combinations = 8, + phi0 = p0, + max_n_coalitions = 8, empirical.type = "AICc_full", - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_lm_numeric_empirical_AICc_full" ) @@ -118,13 +118,13 @@ test_that("output_lm_numeric_empirical_AICc_full", { test_that("output_lm_numeric_gaussian", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_gaussian" ) @@ -133,13 +133,13 @@ test_that("output_lm_numeric_gaussian", { test_that("output_lm_numeric_copula", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "copula", - prediction_zero = p0, - 
n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_copula" ) @@ -148,35 +148,36 @@ test_that("output_lm_numeric_copula", { test_that("output_lm_numeric_ctree", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_ctree" ) }) test_that("output_lm_numeric_vaeac", { + skip_on_os("mac") # The code runs on macOS, but it gives different Shapley values due to inconsistencies in torch seed expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "vaeac", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - n_samples = 10, # Low value here to speed up the time + phi0 = p0, + n_MC_samples = 10, # Low value here to speed up the time vaeac.epochs = 4, # Low value here to speed up the time vaeac.n_vaeacs_initialize = 2, # Low value here to speed up the time vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2, # Low value here to speed up the time vaeac.save_model = FALSE # Removes names and objects such as tmpdir and tmpfile - ) + ), + iterative = FALSE ), "output_lm_numeric_vaeac" ) @@ -185,35 +186,36 @@ test_that("output_lm_numeric_vaeac", { test_that("output_lm_categorical_ctree", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_categorical_ctree" ) }) test_that("output_lm_categorical_vaeac", { + skip_on_os("mac") # The code runs on macOS, but it gives different Shapley values due to inconsistencies in torch seed expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_categorical, x_explain = 
x_explain_categorical, x_train = x_train_categorical, approach = "vaeac", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - n_samples = 10, # Low value here to speed up the time + phi0 = p0, + n_MC_samples = 10, # Low value here to speed up the time vaeac.epochs = 4, # Low value here to speed up the time vaeac.n_vaeacs_initialize = 2, # Low value here to speed up the time vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2, # Low value here to speed up the time vaeac.save_model = FALSE # Removes tmpdir and tmpfiles - ) + ), + iterative = FALSE ), "output_lm_categorical_vaeac" ) @@ -222,13 +224,13 @@ test_that("output_lm_categorical_vaeac", { test_that("output_lm_categorical_categorical", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "categorical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_categorical_method" ) @@ -237,13 +239,13 @@ test_that("output_lm_categorical_categorical", { test_that("output_lm_categorical_independence", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_categorical, x_explain = x_explain_categorical, x_train = x_train_categorical, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_categorical_independence" ) @@ -252,14 +254,14 @@ test_that("output_lm_categorical_independence", { test_that("output_lm_ts_timeseries", { expect_snapshot_rds( explanation_timeseries <- explain( + testing = TRUE, model = model_lm_ts, x_explain = x_explain_ts, x_train = x_train_ts, approach = "timeseries", - prediction_zero = p0_ts, + phi0 = p0_ts, group = group_ts, - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_lm_timeseries_method" ) @@ -268,13 +270,13 @@ test_that("output_lm_ts_timeseries", { test_that("output_lm_numeric_comb1", { expect_snapshot_rds( 
explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("gaussian", "empirical", "ctree", "independence"), - prediction_zero = p0, - n_batches = 4, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_comb1" ) @@ -283,13 +285,13 @@ test_that("output_lm_numeric_comb1", { test_that("output_lm_numeric_comb2", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("ctree", "copula", "independence", "copula"), - prediction_zero = p0, - n_batches = 3, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_comb2" ) @@ -298,13 +300,13 @@ test_that("output_lm_numeric_comb2", { test_that("output_lm_numeric_comb3", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "empirical"), - prediction_zero = p0, - n_batches = 3, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_comb3" ) @@ -316,13 +318,13 @@ test_that("output_lm_numeric_comb3", { test_that("output_lm_mixed_independence", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_mixed_independence" ) @@ -331,35 +333,36 @@ test_that("output_lm_mixed_independence", { test_that("output_lm_mixed_ctree", { expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_mixed_ctree" ) }) test_that("output_lm_mixed_vaeac", { + skip_on_os("mac") # The code runs on macOS, but it gives different 
Shapley values due to inconsistencies in torch seed expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - n_samples = 10, # Low value here to speed up the time + phi0 = p0, + n_MC_samples = 10, # Low value here to speed up the time vaeac.epochs = 4, # Low value here to speed up the time vaeac.n_vaeacs_initialize = 2, # Low value here to speed up the time vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2, # Low value here to speed up the time vaeac.save_model = FALSE # Removes tmpdir and tmpfiles - ) + ), + iterative = FALSE ), "output_lm_mixed_vaeac" ) @@ -369,13 +372,13 @@ test_that("output_lm_mixed_comb", { set.seed(123) expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = c("ctree", "independence", "ctree", "independence"), - prediction_zero = p0, - n_batches = 2, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_mixed_comb" ) @@ -396,14 +399,14 @@ test_that("output_custom_lm_numeric_independence_1", { expect_snapshot_rds( explain( + testing = TRUE, model = model_custom_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_pred_func, - n_batches = 1, - timing = FALSE + iterative = FALSE ), "output_custom_lm_numeric_independence_1" ) @@ -423,32 +426,32 @@ test_that("output_custom_lm_numeric_independence_2", { expect_snapshot_rds( (custom <- explain( + testing = TRUE, model = model_custom_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_pred_func, - n_batches = 1, - timing = FALSE + iterative = FALSE )), "output_custom_lm_numeric_independence_2" ) native <- explain( + testing = TRUE, model = 
model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ) # Check that the printed Shapley values are identical expect_equal( - custom$shapley_values, - native$shapley_values + custom$shapley_values_est, + native$shapley_values_est ) }) @@ -486,15 +489,15 @@ test_that("output_custom_xgboost_mixed_dummy_ctree", { expect_snapshot_rds( { custom <- explain( + testing = TRUE, model = model_xgboost_mixed_dummy, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "ctree", - prediction_zero = p0, + phi0 = p0, predict_model = predict_model.xgboost_dummy, get_model_specs = NA, - n_batches = 1, - timing = FALSE + iterative = FALSE ) # custom$internal$objects$predict_model <- "Del on purpose" # Avoids issues with xgboost package updates custom @@ -509,13 +512,13 @@ test_that("output_lm_numeric_interaction", { x_explain_interaction <- x_explain_numeric[, mget(all.vars(formula(model_lm_interaction))[-1])] expect_snapshot_rds( explain( + testing = TRUE, model = model_lm_interaction, x_explain = x_explain_interaction, x_train = x_train_interaction, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ), "output_lm_numeric_interaction" ) @@ -526,13 +529,13 @@ test_that("output_lm_numeric_ctree_parallelized", { expect_snapshot_rds( { explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0, + iterative = FALSE ) }, "output_lm_numeric_ctree_parallelized" @@ -540,23 +543,6 @@ test_that("output_lm_numeric_ctree_parallelized", { future::plan("sequential") }) -test_that("output_lm_numeric_independence_more_batches", { - expect_snapshot_rds( - { - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = 
x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = 10, - timing = FALSE - ) - }, - "output_lm_numeric_independence_n_batches_10" - ) -}) - # Nothing special here, as the test does not record the actual progress output. # It just checks whether calling on progressr does not produce an error or unexpected output. test_that("output_lm_numeric_empirical_progress", { @@ -565,13 +551,13 @@ test_that("output_lm_numeric_empirical_progress", { { progressr::with_progress({ explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0, + iterative = FALSE ) }) }, @@ -580,18 +566,18 @@ test_that("output_lm_numeric_empirical_progress", { }) -# Just checking that internal$output$dt_samp_for_vS keep_samp_for_vS +# Just checking that internal$output$dt_samp_for_vS works test_that("output_lm_numeric_independence_keep_samp_for_vS", { expect_snapshot_rds( (out <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE, - keep_samp_for_vS = TRUE + phi0 = p0, + output_args = list(keep_samp_for_vS = TRUE), + iterative = FALSE )), "output_lm_numeric_independence_keep_samp_for_vS" ) diff --git a/tests/testthat/test-setup.R b/tests/testthat/test-regular-setup.R similarity index 59% rename from tests/testthat/test-setup.R rename to tests/testthat/test-regular-setup.R index 6fdb0b9e080a5192c7f72354362743f7ee1d4b6f..ba610ad33df435dccbdcfc54438a2f82a503c475 100644 --- a/tests/testthat/test-setup.R +++ b/tests/testthat/test-regular-setup.R @@ -10,13 +10,12 @@ test_that("error with custom model without providing predict_model", { class(model_custom_lm_mixed) <- "whatever" explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, 
approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -38,15 +37,14 @@ test_that("messages with missing detail in get_model_specs", { expect_snapshot({ # Custom model with no get_model_specs explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_predict_model, - get_model_specs = NA, - n_batches = 1, - timing = FALSE + get_model_specs = NA ) }) @@ -58,15 +56,14 @@ test_that("messages with missing detail in get_model_specs", { } explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_predict_model, - get_model_specs = custom_get_model_specs_no_lab, - n_batches = 1, - timing = FALSE + get_model_specs = custom_get_model_specs_no_lab ) }) @@ -78,15 +75,14 @@ test_that("messages with missing detail in get_model_specs", { } explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_predict_model, - get_model_specs = custom_gms_no_classes, - n_batches = 1, - timing = FALSE + get_model_specs = custom_gms_no_classes ) }) @@ -102,15 +98,14 @@ test_that("messages with missing detail in get_model_specs", { } explain( + testing = TRUE, model = model_custom_lm_mixed, x_train = x_train_mixed, x_explain = x_explain_mixed, approach = "independence", - prediction_zero = p0, + phi0 = p0, predict_model = custom_predict_model, - get_model_specs = custom_gms_no_factor_levels, - n_batches = 1, - timing = FALSE + get_model_specs = custom_gms_no_factor_levels ) }) }) @@ -124,13 +119,12 @@ test_that("erroneous input: `x_train/x_explain`", { x_train_wrong_format <- c(a = 1, b = 2) explain( + testing = TRUE, 
model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_wrong_format, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -142,13 +136,12 @@ test_that("erroneous input: `x_train/x_explain`", { x_explain_wrong_format <- c(a = 1, b = 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_wrong_format, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -161,13 +154,12 @@ test_that("erroneous input: `x_train/x_explain`", { x_explain_wrong_format <- c(a = 3, b = 4) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_wrong_format, x_train = x_train_wrong_format, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -181,13 +173,12 @@ test_that("erroneous input: `x_train/x_explain`", { names(x_train_no_column_names) <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_no_column_names, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -200,13 +191,12 @@ test_that("erroneous input: `x_train/x_explain`", { names(x_explain_no_column_names) <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_no_column_names, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -220,13 +210,12 @@ test_that("erroneous input: `x_train/x_explain`", { names(x_explain_no_column_names) <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_no_column_names, x_train = x_train_no_column_names, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -240,12 +229,11 @@ test_that("erroneous 
input: `model`", { { # no model passed explain( + testing = TRUE, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -261,13 +249,12 @@ test_that("erroneous input: `approach`", { approach_non_character <- 1 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = approach_non_character, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -279,13 +266,12 @@ test_that("erroneous input: `approach`", { approach_incorrect_length <- c("empirical", "gaussian") explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = approach_incorrect_length, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -297,20 +283,19 @@ test_that("erroneous input: `approach`", { approach_incorrect_character <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = approach_incorrect_character, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE ) }) -test_that("erroneous input: `prediction_zero`", { +test_that("erroneous input: `phi0`", { set.seed(123) expect_snapshot( @@ -319,13 +304,12 @@ test_that("erroneous input: `prediction_zero`", { p0_non_numeric_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0_non_numeric_1, - n_batches = 1, - timing = FALSE + phi0 = p0_non_numeric_1 ) }, error = TRUE @@ -337,13 +321,12 @@ test_that("erroneous input: `prediction_zero`", { p0_non_numeric_2 <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero 
= p0_non_numeric_2, - n_batches = 1, - timing = FALSE + phi0 = p0_non_numeric_2 ) }, error = TRUE @@ -356,13 +339,12 @@ test_that("erroneous input: `prediction_zero`", { p0_too_long <- c(1, 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0_too_long ) }, error = TRUE @@ -374,36 +356,34 @@ test_that("erroneous input: `prediction_zero`", { p0_is_NA <- as.numeric(NA) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0_is_NA, - n_batches = 1, - timing = FALSE + phi0 = p0_is_NA ) }, error = TRUE ) }) -test_that("erroneous input: `n_combinations`", { +test_that("erroneous input: `max_n_coalitions`", { set.seed(123) expect_snapshot( { # non-numeric 1 - n_combinations_non_numeric_1 <- "bla" + max_n_comb_non_numeric_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_non_numeric_1, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_comb_non_numeric_1 ) }, error = TRUE @@ -412,17 +392,16 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # non-numeric 2 - n_combinations_non_numeric_2 <- TRUE + max_n_comb_non_numeric_2 <- TRUE explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_non_numeric_2, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_comb_non_numeric_2 ) }, error = TRUE @@ -432,17 +411,16 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # non-integer - n_combinations_non_integer <- 10.5 + 
max_n_coalitions_non_integer <- 10.5 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_non_integer, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_coalitions_non_integer ) }, error = TRUE @@ -453,17 +431,16 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # length > 1 - n_combinations_too_long <- c(1, 2) + max_n_coalitions_too_long <- c(1, 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_coalitions_too_long ) }, error = TRUE @@ -472,17 +449,16 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # NA-numeric - n_combinations_is_NA <- as.numeric(NA) + max_n_coalitions_is_NA <- as.numeric(NA) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_is_NA, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_coalitions_is_NA ) }, error = TRUE @@ -491,67 +467,58 @@ test_that("erroneous input: `n_combinations`", { expect_snapshot( { # Non-positive - n_combinations_non_positive <- 0 + max_n_comb_non_positive <- 0 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations_non_positive, - n_batches = 1, - timing = FALSE + phi0 = p0, + max_n_coalitions = max_n_comb_non_positive ) }, error = TRUE ) - expect_snapshot( - { - # Too low n_combinations (smaller than # features - n_combinations <- ncol(x_explain_numeric) - 1 + expect_snapshot({ + # 
Too low max_n_coalitions (smaller than # features + max_n_coalitions <- ncol(x_explain_numeric) - 1 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - prediction_zero = p0, - approach = "gaussian", - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE - ) - }, - error = TRUE - ) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + approach = "gaussian", + max_n_coalitions = max_n_coalitions + ) + }) - expect_snapshot( - { - # Too low n_combinations (smaller than # groups - groups <- list( - A = c("Solar.R", "Wind"), - B = c("Temp", "Month"), - C = "Day" - ) + expect_snapshot({ + # Too low max_n_coalitions (smaller than # groups + groups <- list( + A = c("Solar.R", "Wind"), + B = c("Temp", "Month"), + C = "Day" + ) - n_combinations <- length(groups) - 1 + max_n_coalitions <- length(groups) - 1 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - prediction_zero = p0, - approach = "gaussian", - group = groups, - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE - ) - }, - error = TRUE - ) + explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + phi0 = p0, + approach = "gaussian", + group = groups, + max_n_coalitions = max_n_coalitions + ) + }) }) test_that("erroneous input: `group`", { @@ -563,14 +530,13 @@ test_that("erroneous input: `group`", { group_non_list <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_non_list, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_non_list ) }, error = TRUE @@ -582,14 +548,13 @@ test_that("erroneous input: `group`", { group_with_non_characters <- list(A = 1, B = 2) explain( + testing = TRUE, model = 
model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_with_non_characters, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_with_non_characters ) }, error = TRUE @@ -603,14 +568,13 @@ test_that("erroneous input: `group`", { B = c("Temp", "Month", "Day") ) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_with_non_data_features, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_with_non_data_features ) }, error = TRUE @@ -624,14 +588,13 @@ test_that("erroneous input: `group`", { B = c("Temp", "Month", "Day") ) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_missing_data_features, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_missing_data_features ) }, error = TRUE @@ -645,14 +608,13 @@ test_that("erroneous input: `group`", { B = c("Temp", "Month", "Day") ) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = group_dup_data_features, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = group_dup_data_features ) }, error = TRUE @@ -663,21 +625,20 @@ test_that("erroneous input: `group`", { # a single group only single_group <- list(A = c("Solar.R", "Wind", "Temp", "Month", "Day")) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - group = single_group, - n_batches = 1, - timing = FALSE + phi0 = p0, + group = single_group ) }, error = TRUE ) }) -test_that("erroneous input: `n_samples`", { +test_that("erroneous input: `n_MC_samples`", { set.seed(123) 
expect_snapshot( @@ -686,14 +647,13 @@ test_that("erroneous input: `n_samples`", { n_samples_non_numeric_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_non_numeric_1, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_non_numeric_1 ) }, error = TRUE @@ -705,14 +665,13 @@ test_that("erroneous input: `n_samples`", { n_samples_non_numeric_2 <- TRUE explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_non_numeric_2, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_non_numeric_2 ) }, error = TRUE @@ -723,14 +682,13 @@ test_that("erroneous input: `n_samples`", { # non-integer n_samples_non_integer <- 10.5 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_non_integer, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_non_integer ) }, error = TRUE @@ -741,14 +699,13 @@ test_that("erroneous input: `n_samples`", { { n_samples_too_long <- c(1, 2) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_too_long ) }, error = TRUE @@ -759,14 +716,13 @@ test_that("erroneous input: `n_samples`", { { n_samples_is_NA <- as.numeric(NA) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_is_NA, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = 
n_samples_is_NA ) }, error = TRUE @@ -777,161 +733,19 @@ test_that("erroneous input: `n_samples`", { { n_samples_non_positive <- 0 explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - n_samples = n_samples_non_positive, - n_batches = 1, - timing = FALSE + phi0 = p0, + n_MC_samples = n_samples_non_positive ) }, error = TRUE ) }) -test_that("erroneous input: `n_batches`", { - set.seed(123) - - # non-numeric 1 - expect_snapshot( - { - n_batches_non_numeric_1 <- "bla" - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_non_numeric_1, - timing = FALSE - ) - }, - error = TRUE - ) - - # non-numeric 2 - expect_snapshot( - { - n_batches_non_numeric_2 <- TRUE - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_non_numeric_2, - timing = FALSE - ) - }, - error = TRUE - ) - - # non-integer - expect_snapshot( - { - n_batches_non_integer <- 10.5 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_non_integer, - timing = FALSE - ) - }, - error = TRUE - ) - - # length > 1 - expect_snapshot( - { - n_batches_too_long <- c(1, 2) - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_too_long, - timing = FALSE - ) - }, - error = TRUE - ) - - # NA-numeric - expect_snapshot( - { - n_batches_is_NA <- as.numeric(NA) - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = 
n_batches_is_NA, - timing = FALSE - ) - }, - error = TRUE - ) - - # Non-positive - expect_snapshot( - { - n_batches_non_positive <- 0 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_non_positive, - timing = FALSE - ) - }, - error = TRUE - ) - - # Larger than number of n_combinations - expect_snapshot( - { - n_combinations <- 10 - n_batches_too_large <- 11 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_combinations = n_combinations, - n_batches = n_batches_too_large, - timing = FALSE - ) - }, - error = TRUE - ) - - # Larger than number of n_combinations without specification - expect_snapshot( - { - n_batches_too_large_2 <- 32 - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "independence", - prediction_zero = p0, - n_batches = n_batches_too_large_2, - timing = FALSE - ) - }, - error = TRUE - ) -}) test_that("erroneous input: `seed`", { set.seed(123) @@ -941,14 +755,13 @@ test_that("erroneous input: `seed`", { { seed_not_integer_interpretable <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - seed = seed_not_integer_interpretable, - n_batches = 1, - timing = FALSE + phi0 = p0, + seed = seed_not_integer_interpretable ) }, error = TRUE @@ -963,14 +776,13 @@ test_that("erroneous input: `keep_samp_for_vS`", { { keep_samp_for_vS_non_logical_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - keep_samp_for_vS = keep_samp_for_vS_non_logical_1, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(keep_samp_for_vS = 
keep_samp_for_vS_non_logical_1) ) }, error = TRUE @@ -981,14 +793,13 @@ test_that("erroneous input: `keep_samp_for_vS`", { { keep_samp_for_vS_non_logical_2 <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - keep_samp_for_vS = keep_samp_for_vS_non_logical_2, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(keep_samp_for_vS = keep_samp_for_vS_non_logical_2) ) }, error = TRUE @@ -999,14 +810,13 @@ test_that("erroneous input: `keep_samp_for_vS`", { { keep_samp_for_vS_too_long <- c(TRUE, FALSE) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - keep_samp_for_vS = keep_samp_for_vS_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(keep_samp_for_vS = keep_samp_for_vS_too_long) ) }, error = TRUE @@ -1021,14 +831,13 @@ test_that("erroneous input: `MSEv_uniform_comb_weights`", { { MSEv_uniform_comb_weights_nl_1 <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_1) ) }, error = TRUE @@ -1039,14 +848,13 @@ test_that("erroneous input: `MSEv_uniform_comb_weights`", { { MSEv_uniform_comb_weights_nl_2 <- NULL explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_nl_2) ) }, error = TRUE @@ -1057,14 +865,13 @@ 
test_that("erroneous input: `MSEv_uniform_comb_weights`", { { MSEv_uniform_comb_weights_long <- c(TRUE, FALSE) explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + output_args = list(MSEv_uniform_comb_weights = MSEv_uniform_comb_weights_long) ) }, error = TRUE @@ -1080,14 +887,13 @@ test_that("erroneous input: `predict_model`", { predict_model_nonfunction <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_nonfunction, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_nonfunction ) }, error = TRUE @@ -1101,14 +907,13 @@ test_that("erroneous input: `predict_model`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_non_num_output, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_non_num_output ) }, error = TRUE @@ -1122,14 +927,13 @@ test_that("erroneous input: `predict_model`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_wrong_output_len, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_wrong_output_len ) }, error = TRUE @@ -1143,14 +947,13 @@ test_that("erroneous input: `predict_model`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_invalid_argument, - n_batches = 1, 
- timing = FALSE + phi0 = p0, + predict_model = predict_model_invalid_argument ) }, error = TRUE @@ -1164,14 +967,13 @@ test_that("erroneous input: `predict_model`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - predict_model = predict_model_error, - n_batches = 1, - timing = FALSE + phi0 = p0, + predict_model = predict_model_error ) }, error = TRUE @@ -1187,14 +989,13 @@ test_that("erroneous input: `get_model_specs`", { get_model_specs_nonfunction <- "bla" explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_model_specs_nonfunction, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_model_specs_nonfunction ) }, error = TRUE @@ -1209,14 +1010,13 @@ test_that("erroneous input: `get_model_specs`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_ms_output_not_list, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_ms_output_not_list ) }, error = TRUE @@ -1230,14 +1030,13 @@ test_that("erroneous input: `get_model_specs`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_ms_output_too_long, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_ms_output_too_long ) }, error = TRUE @@ -1255,14 +1054,13 @@ test_that("erroneous input: `get_model_specs`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_ms_output_wrong_names, - n_batches = 1, - timing 
= FALSE + phi0 = p0, + get_model_specs = get_ms_output_wrong_names ) }, error = TRUE @@ -1276,14 +1074,13 @@ test_that("erroneous input: `get_model_specs`", { } explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "independence", - prediction_zero = p0, - get_model_specs = get_model_specs_error, - n_batches = 1, - timing = FALSE + phi0 = p0, + get_model_specs = get_model_specs_error ) }, error = TRUE @@ -1298,13 +1095,12 @@ test_that("incompatible input: `data/approach`", { # factor model/data with approach gaussian non_factor_approach_1 <- "gaussian" explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, approach = non_factor_approach_1, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -1315,13 +1111,12 @@ test_that("incompatible input: `data/approach`", { # factor model/data with approach empirical non_factor_approach_2 <- "empirical" explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, approach = non_factor_approach_2, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -1332,13 +1127,12 @@ test_that("incompatible input: `data/approach`", { # factor model/data with approach copula non_factor_approach_3 <- "copula" explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, approach = non_factor_approach_3, - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) }, error = TRUE @@ -1346,35 +1140,33 @@ test_that("incompatible input: `data/approach`", { }) test_that("Correct dimension of S when sampling combinations", { - n_combinations <- 10 + max_n_coalitions <- 10 res <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, - prediction_zero = p0, + phi0 = p0, approach = "ctree", - n_combinations 
= n_combinations, - n_batches = 1, - timing = FALSE + max_n_coalitions = max_n_coalitions ) - expect_equal(nrow(res$internal$objects$S), n_combinations) + expect_equal(nrow(res$internal$objects$S), max_n_coalitions) }) -test_that("Error with too low `n_combinations`", { - n_combinations <- ncol(x_explain_numeric) - 1 +test_that("Message with too low `max_n_coalitions`", { + max_n_coalitions <- ncol(x_explain_numeric) - 1 - expect_error( + expect_snapshot( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_explain_numeric, - prediction_zero = p0, + phi0 = p0, approach = "gaussian", - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE + max_n_coalitions = max_n_coalitions ) ) @@ -1385,85 +1177,76 @@ test_that("Error with too low `n_combinations`", { C = "Day" ) - n_combinations <- length(groups) - 1 + max_n_coalitions <- length(groups) - 1 - expect_error( + expect_snapshot( explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_explain_numeric, - prediction_zero = p0, + phi0 = p0, approach = "gaussian", group = groups, - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE + max_n_coalitions = max_n_coalitions ) ) }) -test_that("Shapr with `n_combinations` >= 2^m uses exact Shapley kernel weights", { - # Check that the `explain()` function enters the exact mode when n_combinations +test_that("Shapr with `max_n_coalitions` >= 2^m uses exact Shapley kernel weights", { + # Check that the `explain()` function enters the exact mode when max_n_coalitions # is larger than or equal to 2^m. 
# Create three explainer object: one with exact mode, one with - # `n_combinations` = 2^m, and one with `n_combinations` > 2^m + # `max_n_coalitions` = 2^m, and one with `max_n_coalitions` > 2^m # No message as n_combination = NULL sets exact mode - expect_no_message( - object = { - explanation_exact <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "gaussian", - prediction_zero = p0, - n_samples = 2, # Low value for fast computations - n_batches = 1, # Not related to the bug - seed = 123, - n_combinations = NULL, - timing = FALSE - ) - } + expect_snapshot( + explanation_exact <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + n_MC_samples = 2, # Low value for fast computations + seed = 123, + max_n_coalitions = NULL, + iterative = FALSE + ) ) - # We should get a message saying that we are using the exact mode. - # The `regexp` format match the one written in `feature_combinations()`. - expect_message( - object = { - explanation_equal <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "gaussian", - prediction_zero = p0, - n_samples = 2, # Low value for fast computations - n_batches = 1, # Not related to the bug - seed = 123, - n_combinations = 2^ncol(x_explain_numeric), - timing = FALSE - ) - }, - regexp = "Success with message:\nn_combinations is larger than or equal to 2\\^m = 32. \nUsing exact instead." 
+ expect_snapshot( + explanation_equal <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + n_MC_samples = 2, # Low value for fast computations + seed = 123, + extra_computation_args = list(compute_sd = FALSE), + max_n_coalitions = 2^ncol(x_explain_numeric), + iterative = FALSE + ) ) # We should get a message saying that we are using the exact mode. - # The `regexp` format match the one written in `feature_combinations()`. - expect_message( - object = { - explanation_larger <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "gaussian", - prediction_zero = p0, - n_samples = 2, # Low value for fast computations - n_batches = 1, # Not related to the bug - seed = 123, - n_combinations = 2^ncol(x_explain_numeric) + 1, - timing = FALSE - ) - }, - regexp = "Success with message:\nn_combinations is larger than or equal to 2\\^m = 32. \nUsing exact instead." + # The `regexp` format match the one written in `create_coalition_table()`. 
+ expect_snapshot( + explanation_larger <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0, + n_MC_samples = 2, # Low value for fast computations + seed = 123, + extra_computation_args = list(compute_sd = FALSE), + max_n_coalitions = 2^ncol(x_explain_numeric) + 1, + iterative = FALSE + ) ) # Test that returned objects are identical (including all using the exact option and having the same Shapley weights) @@ -1476,19 +1259,19 @@ test_that("Shapr with `n_combinations` >= 2^m uses exact Shapley kernel weights" explanation_larger ) - # Explicitly check that exact mode is set and that n_combinations equals 2^ncol(x_explain_numeric) (32) + # Explicitly check that exact mode is set and that max_n_coalitions equals 2^ncol(x_explain_numeric) (32) # Since all 3 explanation objects are equal (per the above test) it suffices to do this for explanation_exact expect_true( explanation_exact$internal$parameters$exact ) expect_equal( - explanation_exact$internal$parameters$n_combinations, + explanation_exact$internal$objects$X[, .N], 2^ncol(x_explain_numeric) ) }) test_that("Correct dimension of S when sampling combinations with groups", { - n_combinations <- 5 + max_n_coalitions <- 6 groups <- list( A = c("Solar.R", "Wind"), @@ -1497,59 +1280,55 @@ test_that("Correct dimension of S when sampling combinations with groups", { ) res <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_explain_mixed, - prediction_zero = p0, + phi0 = p0, approach = "ctree", group = groups, - n_combinations = n_combinations, - n_batches = 1, - timing = FALSE + max_n_coalitions = max_n_coalitions ) - expect_equal(nrow(res$internal$objects$S), n_combinations) + expect_equal(nrow(res$internal$objects$S), max_n_coalitions) }) test_that("data feature ordering is output_lm_numeric_column_order", { explain.original <- explain( + testing = TRUE, model = 
model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) - explain.new_data_feature_order <- explain( + ex.new_data_feature_order <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = rev(x_explain_numeric), x_train = rev(x_train_numeric), approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) explain.new_model_feat_order <- explain( + testing = TRUE, model = model_lm_numeric_col_order, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 1, - timing = FALSE + phi0 = p0 ) # Same Shapley values, but different order expect_false(identical( - explain.original$shapley_values, - explain.new_data_feature_order$shapley_values + explain.original$shapley_values_est, + ex.new_data_feature_order$shapley_values_est )) expect_equal( - explain.original$shapley_values[, mget(sort(names(explain.original$shapley_values)))], - explain.new_data_feature_order$shapley_values[, mget(sort(names(explain.new_data_feature_order$shapley_values)))] + explain.original$shapley_values_est[, mget(sort(names(explain.original$shapley_values_est)))], + ex.new_data_feature_order$shapley_values_est[, mget(sort(names(ex.new_data_feature_order$shapley_values_est)))] ) # Same Shapley values in same order @@ -1559,24 +1338,22 @@ test_that("data feature ordering is output_lm_numeric_column_order", { test_that("parallelization gives same output for any approach", { # Empirical is seed independent explain.empirical_sequential <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) future::plan("multisession", workers = 2) # Parallelized with 2 cores explain.empirical_multisession <- explain( + testing = TRUE, model = 
model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "empirical", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) future::plan("sequential") # Resetting to sequential computation @@ -1590,24 +1367,22 @@ test_that("parallelization gives same output for any approach", { # ctree is seed NOT independent explain.ctree_sequential <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) future::plan("multisession", workers = 2) # Parallelized with 2 cores explain.ctree_multisession <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "ctree", - prediction_zero = p0, - n_batches = 10, - timing = FALSE + phi0 = p0 ) future::plan("sequential") # Resetting to sequential computation @@ -1619,81 +1394,16 @@ test_that("parallelization gives same output for any approach", { ) }) -test_that("different n_batches gives same/different shapley values for different approaches", { - # approach "empirical" is seed independent - explain.empirical_n_batches_5 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "empirical", - prediction_zero = p0, - n_batches = 5, - timing = FALSE - ) - - explain.empirical_n_batches_10 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "empirical", - prediction_zero = p0, - n_batches = 10, - timing = FALSE - ) - - # Difference in the objects (n_batches and related) - expect_false(identical( - explain.empirical_n_batches_5, - explain.empirical_n_batches_10 - )) - # Same Shapley values - expect_equal( - explain.empirical_n_batches_5$shapley_values, - explain.empirical_n_batches_10$shapley_values - ) - - # approach "ctree" is seed dependent - 
explain.ctree_n_batches_5 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "ctree", - prediction_zero = p0, - n_batches = 5, - timing = FALSE - ) - - explain.ctree_n_batches_10 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = "ctree", - prediction_zero = p0, - n_batches = 10, - timing = FALSE - ) - - # Difference in the objects (n_batches and related) - expect_false(identical( - explain.ctree_n_batches_5, - explain.ctree_n_batches_10 - )) - # NEITHER same Shapley values - expect_false(identical( - explain.ctree_n_batches_5$shapley_values, - explain.ctree_n_batches_10$shapley_values - )) -}) test_that("gaussian approach use the user provided parameters", { # approach "gaussian" with default parameter estimation, i.e., sample mean and covariance e.gaussian_samp_mean_cov <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", - prediction_zero = p0, - timing = FALSE + phi0 = p0, ) # Expect that gaussian.mu is the sample mean when no values are provided @@ -1718,8 +1428,7 @@ test_that("gaussian approach use the user provided parameters", { x_explain = x_explain_numeric, x_train = x_train_numeric, approach = "gaussian", - prediction_zero = p0, - timing = FALSE, + phi0 = p0, gaussian.mu = gaussian.provided_mu, gaussian.cov_mat = gaussian.provided_cov_mat ) @@ -1737,166 +1446,31 @@ test_that("gaussian approach use the user provided parameters", { ) }) -test_that("Shapr sets a valid default value for `n_batches`", { - # Shapr sets the default number of batches to be 10 for this dataset and the - # "ctree", "gaussian", and "copula" approaches. Thus, setting `n_combinations` - # to any value lower of equal to 10 causes the error. 
- any_number_equal_or_below_10 <- 8 - - # Before the bugfix, shapr:::check_n_batches() throws the error: - # Error in check_n_batches(internal) : - # `n_batches` (10) must be smaller than the number feature combinations/`n_combinations` (8) - # Bug only occures for "ctree", "gaussian", and "copula" as they are treated different in - # `get_default_n_batches()`, I am not certain why. Ask Martin about the logic behind that. - expect_no_error( - explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - n_samples = 2, # Low value for fast computations - approach = "gaussian", - prediction_zero = p0, - n_combinations = any_number_equal_or_below_10 - ) - ) -}) - -test_that("Error with to low `n_batches` compared to the number of unique approaches", { - # Expect to get the following error: - # `n_batches` (3) must be larger than the number of unique approaches in `approach` (4). - expect_error( - object = explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - n_batches = 3, - timing = FALSE, - seed = 1 - ) - ) - - # Except that shapr sets a valid `n_batches` and get no errors - expect_no_error( - object = explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - n_batches = NULL, - timing = FALSE, - seed = 1 - ) - ) -}) - -test_that("the used number of batches mathces the provided `n_batches` for combined approaches", { - explanation_1 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "ctree", "ctree", "ctree"), - prediction_zero = p0, - n_batches = 2, - timing = FALSE, - seed = 1 - ) - - # Check that the used number of batches corresponds with the provided `n_batches` - 
expect_equal( - explanation_1$internal$parameters$n_batches, - length(explanation_1$internal$objects$S_batch) - ) - - explanation_2 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "ctree", "ctree", "ctree"), - prediction_zero = p0, - n_batches = 15, - timing = FALSE, - seed = 1 - ) - - # Check that the used number of batches corresponds with the provided `n_batches` - expect_equal( - explanation_2$internal$parameters$n_batches, - length(explanation_2$internal$objects$S_batch) - ) - - # Check for the default value for `n_batch` - explanation_3 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "ctree", "ctree", "ctree"), - prediction_zero = p0, - n_batches = NULL, - timing = FALSE, - seed = 1 - ) - - # Check that the used number of batches corresponds with the `n_batches` - expect_equal( - explanation_3$internal$parameters$n_batches, - length(explanation_3$internal$objects$S_batch) - ) -}) test_that("setting the seed for combined approaches works", { # Check that setting the seed works for a combination of approaches - # Here `n_batches` is set to `4`, so one batch for each method, - # i.e., no randomness. 
explanation_combined_1 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) explanation_combined_2 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) # Check that they are equal expect_equal(explanation_combined_1, explanation_combined_2) - - # Here `n_batches` is set to `10`, so NOT one batch for each method, - # i.e., randomness in assigning the batches. - explanation_combined_3 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, - seed = 1 - ) - - explanation_combined_4 <- explain( - model = model_lm_numeric, - x_explain = x_explain_numeric, - x_train = x_train_numeric, - approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, - seed = 1 - ) - - # Check that they are equal - expect_equal(explanation_combined_3, explanation_combined_4) }) test_that("counting the number of unique approaches", { @@ -1905,48 +1479,48 @@ test_that("counting the number of unique approaches", { # Recall that the last approach is not counted in `n_unique_approaches` as # we do not use it as we then condition on all features. 
explanation_combined_1 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "gaussian", "copula"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_1$internal$parameters$n_approaches, 4) expect_equal(explanation_combined_1$internal$parameters$n_unique_approaches, 4) explanation_combined_2 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("empirical"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_2$internal$parameters$n_approaches, 1) expect_equal(explanation_combined_2$internal$parameters$n_unique_approaches, 1) explanation_combined_3 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("gaussian", "gaussian", "gaussian", "gaussian"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_3$internal$parameters$n_approaches, 4) expect_equal(explanation_combined_3$internal$parameters$n_unique_approaches, 1) explanation_combined_4 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "independence", "empirical"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) expect_equal(explanation_combined_4$internal$parameters$n_approaches, 4) @@ -1954,12 +1528,12 @@ test_that("counting the number of unique approaches", { # Check that the last one is not counted explanation_combined_5 <- explain( + testing = TRUE, model = model_lm_numeric, x_explain = x_explain_numeric, x_train = x_train_numeric, approach = c("independence", "empirical", "independence", "empirical"), - prediction_zero = p0, - timing = FALSE, + phi0 = p0, seed = 1 ) 
expect_equal(explanation_combined_5$internal$parameters$n_approaches, 4) @@ -1971,39 +1545,41 @@ test_that("counting the number of unique approaches", { test_that("vaeac_set_seed_works", { # Train two vaeac models with the same seed explanation_vaeac_1 <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2 - ) + ), + iterative = FALSE ) explanation_vaeac_2 <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2 - ) + ), + iterative = FALSE ) # Check for equal Shapley values - expect_equal(explanation_vaeac_1$shapley_values, explanation_vaeac_2$shapley_values) + expect_equal(explanation_vaeac_1$shapley_values_est, explanation_vaeac_2$shapley_values_est) }) test_that("vaeac_pretreained_vaeac_model", { @@ -2011,19 +1587,20 @@ test_that("vaeac_pretreained_vaeac_model", { # have trained it in a previous shapr::explain object. 
explanation_vaeac_1 <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list( vaeac.epochs_initiation_phase = 2 - ) + ), + iterative = FALSE ) #### We can do this by reusing the vaeac model OBJECT @@ -2032,21 +1609,22 @@ test_that("vaeac_pretreained_vaeac_model", { # send the pre-trained vaeac model to the explain function explanation_pretrained_vaeac <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = vaeac.pretrained_vaeac_model - ) + ), + iterative = FALSE ) # Check for equal Shapley values - expect_equal(explanation_vaeac_1$shapley_values, explanation_pretrained_vaeac$shapley_values) + expect_equal(explanation_vaeac_1$shapley_values_est, explanation_pretrained_vaeac$shapley_values_est) #### We can also do this by reusing the vaeac model PATH # Get the pre-trained vaeac model path @@ -2054,19 +1632,55 @@ test_that("vaeac_pretreained_vaeac_model", { # send the pre-trained vaeac model to the explain function explanation_pretrained_vaeac <- explain( + testing = TRUE, model = model_lm_mixed, x_explain = x_explain_mixed, x_train = x_train_mixed, approach = "vaeac", - prediction_zero = p0, - n_samples = 10, - n_batches = 2, + phi0 = p0, + n_MC_samples = 10, seed = 1, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = vaeac.pretrained_vaeac_path - ) + ), + iterative = FALSE ) # Check for equal Shapley values - expect_equal(explanation_vaeac_1$shapley_values, explanation_pretrained_vaeac$shapley_values) + expect_equal(explanation_vaeac_1$shapley_values_est, 
explanation_pretrained_vaeac$shapley_values_est) +}) + + +test_that("feature wise and groupwise computations are identical", { + groups <- list( + Solar.R = "Solar.R", + Wind = "Wind", + Temp = "Temp", + Month = "Month", + Day = "Day" + ) + + expl_feat <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + phi0 = p0 + ) + + + expl_group <- explain( + testing = TRUE, + model = model_lm_numeric, + x_explain = x_explain_numeric, + x_train = x_train_numeric, + approach = "gaussian", + group = groups, + phi0 = p0 + ) + + + # Checking equality in the list with all final and intermediate results + expect_equal(expl_feat$shapley_values_est, expl_group$shapley_values_est) }) diff --git a/vignettes/.gitignore b/vignettes/.gitignore index f48855dd4c216dc290e4e5cde49b86d2d8a3d0d0..aead93cbc9d8b2a4ea996e362cecd2878c89cdb8 100644 --- a/vignettes/.gitignore +++ b/vignettes/.gitignore @@ -3,3 +3,4 @@ cache_main/ cache_vaeac/ cache_regression/ +cache_asymmetric_causal/ diff --git a/vignettes/cache_main/__packages b/vignettes/cache_main/__packages index ab530a49310db637b6e3eebd45e501fb33590d35..de6a8a59dbe05e0b80883246e77d24f232b936a3 100644 --- a/vignettes/cache_main/__packages +++ b/vignettes/cache_main/__packages @@ -2,3 +2,4 @@ shapr xgboost data.table gbm +future diff --git a/vignettes/figure_asymmetric_causal/Asymmetric_ordering.png b/vignettes/figure_asymmetric_causal/Asymmetric_ordering.png new file mode 100644 index 0000000000000000000000000000000000000000..4b222e6e4395fa3573f3598293ee1ac0d7609b65 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/Asymmetric_ordering.png differ diff --git a/vignettes/figure_asymmetric_causal/Causal_ordering.png b/vignettes/figure_asymmetric_causal/Causal_ordering.png new file mode 100644 index 0000000000000000000000000000000000000000..e4faada4621caa00bc1f8b1c08663777e8562be1 Binary files /dev/null and 
b/vignettes/figure_asymmetric_causal/Causal_ordering.png differ diff --git a/vignettes/figure_asymmetric_causal/compare_plots-1.png b/vignettes/figure_asymmetric_causal/compare_plots-1.png new file mode 100644 index 0000000000000000000000000000000000000000..bef50c2372bdc5b6c86cf565a62bb466b5db690d Binary files /dev/null and b/vignettes/figure_asymmetric_causal/compare_plots-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_asym_cau_SV-1.png b/vignettes/figure_asymmetric_causal/explanation_asym_cau_SV-1.png new file mode 100644 index 0000000000000000000000000000000000000000..14006afc3b99d41469d9b8cdf6729744690ce2c4 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_asym_cau_SV-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_asym_cau_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_asym_cau_beeswarm-1.png new file mode 100644 index 0000000000000000000000000000000000000000..7b5e7e390b1f1bd555acb9bf15c19ad5d3a1d89c Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_asym_cau_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_asym_con_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_asym_con_beeswarm-1.png new file mode 100644 index 0000000000000000000000000000000000000000..7e13225c5b68ceadba2ad224bc68cb7ca67a5822 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_asym_con_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_sym_cau_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_sym_cau_beeswarm-1.png new file mode 100644 index 0000000000000000000000000000000000000000..14b7892f838c26ecae707a2143b20b474b91a5f7 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_sym_cau_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_sym_con_SV-1.png 
b/vignettes/figure_asymmetric_causal/explanation_sym_con_SV-1.png new file mode 100644 index 0000000000000000000000000000000000000000..4d59ab481dcf72b0c3e59e2feb649a02524e14d1 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_sym_con_SV-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_sym_con_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_sym_con_beeswarm-1.png new file mode 100644 index 0000000000000000000000000000000000000000..dfef7de80b00867cf779a85261cdeb1f81cdbcaf Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_sym_con_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/explanation_sym_mar_beeswarm-1.png b/vignettes/figure_asymmetric_causal/explanation_sym_mar_beeswarm-1.png new file mode 100644 index 0000000000000000000000000000000000000000..ad345d8ab8404df23f530445594af33db4b26c6e Binary files /dev/null and b/vignettes/figure_asymmetric_causal/explanation_sym_mar_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/group_cor-1.png b/vignettes/figure_asymmetric_causal/group_cor-1.png new file mode 100644 index 0000000000000000000000000000000000000000..47a84bd0fbec5e4cac2f8a5e61d5e5ba829982d5 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/group_cor-1.png differ diff --git a/vignettes/figure_asymmetric_causal/group_gaussian_plot_SV-1.png b/vignettes/figure_asymmetric_causal/group_gaussian_plot_SV-1.png new file mode 100644 index 0000000000000000000000000000000000000000..c6218f861798f24525190c7d9c5619bda8c2ebae Binary files /dev/null and b/vignettes/figure_asymmetric_causal/group_gaussian_plot_SV-1.png differ diff --git a/vignettes/figure_asymmetric_causal/group_gaussian_plot_beeswarm-1.png b/vignettes/figure_asymmetric_causal/group_gaussian_plot_beeswarm-1.png new file mode 100644 index 0000000000000000000000000000000000000000..5251de2d84961c439c72ad3ac7e701a1441ac2b8 Binary files /dev/null and 
b/vignettes/figure_asymmetric_causal/group_gaussian_plot_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/n_coalitions_plot_SV-1.png b/vignettes/figure_asymmetric_causal/n_coalitions_plot_SV-1.png new file mode 100644 index 0000000000000000000000000000000000000000..c4a04809a02b707c4b4db2e5a118ff56e6028579 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/n_coalitions_plot_SV-1.png differ diff --git a/vignettes/figure_asymmetric_causal/n_coalitions_plot_beeswarm-1.png b/vignettes/figure_asymmetric_causal/n_coalitions_plot_beeswarm-1.png new file mode 100644 index 0000000000000000000000000000000000000000..5da78afe34cf79651f8af8a3ebeaafcc1425b1d6 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/n_coalitions_plot_beeswarm-1.png differ diff --git a/vignettes/figure_asymmetric_causal/scatter_plots-1.png b/vignettes/figure_asymmetric_causal/scatter_plots-1.png new file mode 100644 index 0000000000000000000000000000000000000000..58273ac0de0ad9035694cade35f88147c55d66c6 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/scatter_plots-1.png differ diff --git a/vignettes/figure_asymmetric_causal/setup_1-1.png b/vignettes/figure_asymmetric_causal/setup_1-1.png new file mode 100644 index 0000000000000000000000000000000000000000..ff2e6977b22a46ad916f02b138c990299670892a Binary files /dev/null and b/vignettes/figure_asymmetric_causal/setup_1-1.png differ diff --git a/vignettes/figure_asymmetric_causal/setup_2-1.png b/vignettes/figure_asymmetric_causal/setup_2-1.png new file mode 100644 index 0000000000000000000000000000000000000000..8a3c9320e2d56ee43107e3af379942053b9f1507 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/setup_2-1.png differ diff --git a/vignettes/figure_asymmetric_causal/setup_3-1.png b/vignettes/figure_asymmetric_causal/setup_3-1.png new file mode 100644 index 0000000000000000000000000000000000000000..fcce487024ea52c284c79a830c57c2ec10c647de Binary files /dev/null and 
b/vignettes/figure_asymmetric_causal/setup_3-1.png differ diff --git a/vignettes/figure_asymmetric_causal/sym_and_asym_Shapley_values-1.png b/vignettes/figure_asymmetric_causal/sym_and_asym_Shapley_values-1.png new file mode 100644 index 0000000000000000000000000000000000000000..400cc85aa4b1b44afbebae74868d014bca5dd695 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/sym_and_asym_Shapley_values-1.png differ diff --git a/vignettes/figure_asymmetric_causal/two_dates_1-1.png b/vignettes/figure_asymmetric_causal/two_dates_1-1.png new file mode 100644 index 0000000000000000000000000000000000000000..f642e0de18780622f8352a5045f571c762b3c0b7 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/two_dates_1-1.png differ diff --git a/vignettes/figure_asymmetric_causal/two_dates_2-1.png b/vignettes/figure_asymmetric_causal/two_dates_2-1.png new file mode 100644 index 0000000000000000000000000000000000000000..ce9e6694b5b06ff7fc45567f8e358755aa105532 Binary files /dev/null and b/vignettes/figure_asymmetric_causal/two_dates_2-1.png differ diff --git a/vignettes/figure_asymmetric_causal/two_dates_3-1.png b/vignettes/figure_asymmetric_causal/two_dates_3-1.png new file mode 100644 index 0000000000000000000000000000000000000000..7dac8bcda231c080930045831829aba970b0556d Binary files /dev/null and b/vignettes/figure_asymmetric_causal/two_dates_3-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-12-1.png b/vignettes/figure_main/unnamed-chunk-12-1.png index f39f175bbcd1d7579ff277286b4a3b3f2f2d7298..fb616327f2cb19fa24e2a92ca1dfa4770379e193 100644 Binary files a/vignettes/figure_main/unnamed-chunk-12-1.png and b/vignettes/figure_main/unnamed-chunk-12-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-12-2.png b/vignettes/figure_main/unnamed-chunk-12-2.png index 48ef4a1b4427c06fe997201384e64a1e87d8af4c..dd2c35d9a53bbb82ce5b565a98aa23cb674a8368 100644 Binary files a/vignettes/figure_main/unnamed-chunk-12-2.png and 
b/vignettes/figure_main/unnamed-chunk-12-2.png differ diff --git a/vignettes/figure_main/unnamed-chunk-13-1.png b/vignettes/figure_main/unnamed-chunk-13-1.png index 4dde3b845e9a3eb80c72a2392119658d39315d8f..f4cb0bc2cd3d31c26979d591ad834b7609c91a8b 100644 Binary files a/vignettes/figure_main/unnamed-chunk-13-1.png and b/vignettes/figure_main/unnamed-chunk-13-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-14-1.png b/vignettes/figure_main/unnamed-chunk-14-1.png index c3e047ece014c8e40a368c7ba563ad9d54efa0b7..6bdadcf45dc0ab2d6bc86ad70b99461b2ed63c1b 100644 Binary files a/vignettes/figure_main/unnamed-chunk-14-1.png and b/vignettes/figure_main/unnamed-chunk-14-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-18-1.png b/vignettes/figure_main/unnamed-chunk-18-1.png new file mode 100644 index 0000000000000000000000000000000000000000..026223eae6e28e5b0cb4146261fbb0f44532f780 Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-18-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-19-1.png b/vignettes/figure_main/unnamed-chunk-19-1.png new file mode 100644 index 0000000000000000000000000000000000000000..026223eae6e28e5b0cb4146261fbb0f44532f780 Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-19-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-2-1.png b/vignettes/figure_main/unnamed-chunk-2-1.png index ac95b58185f10e5b062669a27c6c1ae4dd5a1af4..b8a19b268985fec7c2a2ea604062495ac00039c2 100644 Binary files a/vignettes/figure_main/unnamed-chunk-2-1.png and b/vignettes/figure_main/unnamed-chunk-2-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-20-1.png b/vignettes/figure_main/unnamed-chunk-20-1.png index f915a79610edb57fb9ac48f532e555b4e945c7ce..5d60aa4d37aacd3345f5f6fb41a78d00a66a9110 100644 Binary files a/vignettes/figure_main/unnamed-chunk-20-1.png and b/vignettes/figure_main/unnamed-chunk-20-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-21-1.png 
b/vignettes/figure_main/unnamed-chunk-21-1.png new file mode 100644 index 0000000000000000000000000000000000000000..577b5318473e9053eebf61d5110baea9402daeee Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-21-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-21-2.png b/vignettes/figure_main/unnamed-chunk-21-2.png new file mode 100644 index 0000000000000000000000000000000000000000..577b5318473e9053eebf61d5110baea9402daeee Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-21-2.png differ diff --git a/vignettes/figure_main/unnamed-chunk-22-1.png b/vignettes/figure_main/unnamed-chunk-22-1.png index dd32ab9fec05e92c168afb7102fc4f7b120ee655..577b5318473e9053eebf61d5110baea9402daeee 100644 Binary files a/vignettes/figure_main/unnamed-chunk-22-1.png and b/vignettes/figure_main/unnamed-chunk-22-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-22-2.png b/vignettes/figure_main/unnamed-chunk-22-2.png new file mode 100644 index 0000000000000000000000000000000000000000..577b5318473e9053eebf61d5110baea9402daeee Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-22-2.png differ diff --git a/vignettes/figure_main/unnamed-chunk-3-1.png b/vignettes/figure_main/unnamed-chunk-3-1.png index 90868c1fb2ffc80e1245b4e3b7058d4ce54cdc64..148f14fa967d6af797f5e1a142a3668e2dac7b8c 100644 Binary files a/vignettes/figure_main/unnamed-chunk-3-1.png and b/vignettes/figure_main/unnamed-chunk-3-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-4-1.png b/vignettes/figure_main/unnamed-chunk-4-1.png index df0fde47148ecaf47ef08e362dbd55da4826b1a5..00cd020fd7847594ad9e57172e0b14b78906ced6 100644 Binary files a/vignettes/figure_main/unnamed-chunk-4-1.png and b/vignettes/figure_main/unnamed-chunk-4-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-5-1.png b/vignettes/figure_main/unnamed-chunk-5-1.png index 0290ecd84123912c40332c4047a3af4a3103ff11..2a9708a94c366766ecfd8211be36cd320dcc98a7 100644 Binary files 
a/vignettes/figure_main/unnamed-chunk-5-1.png and b/vignettes/figure_main/unnamed-chunk-5-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-58-1.png b/vignettes/figure_main/unnamed-chunk-58-1.png new file mode 100644 index 0000000000000000000000000000000000000000..3f9ea98805f7274326635854e9bf4ff6f2507a78 Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-58-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-59-1.png b/vignettes/figure_main/unnamed-chunk-59-1.png new file mode 100644 index 0000000000000000000000000000000000000000..28f9dacae804125586d7d37c5049ab98362cf6e6 Binary files /dev/null and b/vignettes/figure_main/unnamed-chunk-59-1.png differ diff --git a/vignettes/figure_main/unnamed-chunk-6-1.png b/vignettes/figure_main/unnamed-chunk-6-1.png index 271c82ed9405bb43b9ccef77aa7a974b4ecf71ca..e0187283e4ebb61dfbf56f2c3f64da9891f1398d 100644 Binary files a/vignettes/figure_main/unnamed-chunk-6-1.png and b/vignettes/figure_main/unnamed-chunk-6-1.png differ diff --git a/vignettes/figure_main/vaeac-plot-1-1.png b/vignettes/figure_main/vaeac-plot-1-1.png index c4cd18e88201e298e652e1f341b9515b11f46eca..49369bdb22930cca0d14a75620608eb618f367f7 100644 Binary files a/vignettes/figure_main/vaeac-plot-1-1.png and b/vignettes/figure_main/vaeac-plot-1-1.png differ diff --git a/vignettes/figure_main/vaeac-plot-2-1.png b/vignettes/figure_main/vaeac-plot-2-1.png index 8fc2362bc46838a24495ab4d8bead8df7d06ffb6..9a91b1149d80732c185b9641099c0f8ed2080590 100644 Binary files a/vignettes/figure_main/vaeac-plot-2-1.png and b/vignettes/figure_main/vaeac-plot-2-1.png differ diff --git a/vignettes/figure_main/vaeac-plot-3-1.png b/vignettes/figure_main/vaeac-plot-3-1.png index 92434e4a5560c8bfc6f68242ea6f2cd699ffd6a2..ac26abc57484e0848b3e80eea230c5f327611bb9 100644 Binary files a/vignettes/figure_main/vaeac-plot-3-1.png and b/vignettes/figure_main/vaeac-plot-3-1.png differ diff --git a/vignettes/figure_regression/MSEv-sum-1.png 
b/vignettes/figure_regression/MSEv-sum-1.png index 946229591ff5b77e410eb18e9dc5b9569c9b2bed..3b65ba88d1e3a2264b56dbe5071de97a3bc4168c 100644 Binary files a/vignettes/figure_regression/MSEv-sum-1.png and b/vignettes/figure_regression/MSEv-sum-1.png differ diff --git a/vignettes/figure_regression/MSEv-sum-2-1.png b/vignettes/figure_regression/MSEv-sum-2-1.png index 7c9cff6c6b9956ed33d7cd647516472d363c415e..a73fb581caf787625d6a5f91139f0b43de270020 100644 Binary files a/vignettes/figure_regression/MSEv-sum-2-1.png and b/vignettes/figure_regression/MSEv-sum-2-1.png differ diff --git a/vignettes/figure_regression/SV-sum-1.png b/vignettes/figure_regression/SV-sum-1.png index a5e3156d6ea548d22f01d0501133f32ca99b81fa..72d82cc1784d57aedf84eac5a8f41e1e8f42890d 100644 Binary files a/vignettes/figure_regression/SV-sum-1.png and b/vignettes/figure_regression/SV-sum-1.png differ diff --git a/vignettes/figure_regression/SV-sum-2-1.png b/vignettes/figure_regression/SV-sum-2-1.png index 67a259a0acf0619561e34c69680aa4c307228fd2..c8f781e5e65104de541e60442f72b51bdf60429e 100644 Binary files a/vignettes/figure_regression/SV-sum-2-1.png and b/vignettes/figure_regression/SV-sum-2-1.png differ diff --git a/vignettes/figure_regression/SV-sum-2.png b/vignettes/figure_regression/SV-sum-2.png index b5c6c636057af406cabce907958a2d7f294b6a74..1bfebe7d9fe4bb286cf6ce45ea924fe7e58d1034 100644 Binary files a/vignettes/figure_regression/SV-sum-2.png and b/vignettes/figure_regression/SV-sum-2.png differ diff --git a/vignettes/figure_regression/SV-sum-3.png b/vignettes/figure_regression/SV-sum-3.png index c7a0578debeb42c789b3308322e9fb854f0a0a3f..d5c7c83d3ce137f93c6c4e16c54f888d3f1ad759 100644 Binary files a/vignettes/figure_regression/SV-sum-3.png and b/vignettes/figure_regression/SV-sum-3.png differ diff --git a/vignettes/figure_regression/decision-tree-plot-1.png b/vignettes/figure_regression/decision-tree-plot-1.png index 
c211b764bf3f7391175ded5d40de9458fa08f03c..c387f5f7b547066a37fa29d4a9b612f1def666a4 100644 Binary files a/vignettes/figure_regression/decision-tree-plot-1.png and b/vignettes/figure_regression/decision-tree-plot-1.png differ diff --git a/vignettes/figure_regression/dt-cv-plot-1.png b/vignettes/figure_regression/dt-cv-plot-1.png index e3f0c19015f7784f7114730fe1a151f75f61e7fc..a8749762d7956041b55f97a01a0dfaa6b80d1a8c 100644 Binary files a/vignettes/figure_regression/dt-cv-plot-1.png and b/vignettes/figure_regression/dt-cv-plot-1.png differ diff --git a/vignettes/figure_regression/lm-emp-msev-1.png b/vignettes/figure_regression/lm-emp-msev-1.png index a79ef864eb661810cdf7e9e88f1494517c601e02..4aed9c4c91100b7c6074cd1ce3863575c67e7ec7 100644 Binary files a/vignettes/figure_regression/lm-emp-msev-1.png and b/vignettes/figure_regression/lm-emp-msev-1.png differ diff --git a/vignettes/figure_regression/mixed-plot-1.png b/vignettes/figure_regression/mixed-plot-1.png index def0c68ad2cb97ba2570d456ba1cca5a0b237626..01b6c7ae72399121fe66cd4cbed3185052a9ce20 100644 Binary files a/vignettes/figure_regression/mixed-plot-1.png and b/vignettes/figure_regression/mixed-plot-1.png differ diff --git a/vignettes/figure_regression/mixed-plot-2-1.png b/vignettes/figure_regression/mixed-plot-2-1.png index bbf7975cf6009cf089e144d455effc79bf7e3ae9..5f03624f836508b248def13d139168e4d57a0a32 100644 Binary files a/vignettes/figure_regression/mixed-plot-2-1.png and b/vignettes/figure_regression/mixed-plot-2-1.png differ diff --git a/vignettes/figure_regression/mixed-plot-3-1.png b/vignettes/figure_regression/mixed-plot-3-1.png index a31e191b23d75bdddadc88fcf7bdcbc6d78963ca..3b263fedbac9df59889a56b45192da79eda70726 100644 Binary files a/vignettes/figure_regression/mixed-plot-3-1.png and b/vignettes/figure_regression/mixed-plot-3-1.png differ diff --git a/vignettes/figure_regression/mixed-plot-4-1.png b/vignettes/figure_regression/mixed-plot-4-1.png index 
134ddaf59d0fc1080e5be914440f8104e78ac7a8..70d599b304ed0d0d1cc4e3adee0680ccfd06384d 100644 Binary files a/vignettes/figure_regression/mixed-plot-4-1.png and b/vignettes/figure_regression/mixed-plot-4-1.png differ diff --git a/vignettes/figure_regression/ppr-plot-1.png b/vignettes/figure_regression/ppr-plot-1.png index 80e82d1a73ee7927dffe0ff1be0dbc04a550c653..0d5d0e0d56ec2c6e97fbd2c5bf699bc462d75622 100644 Binary files a/vignettes/figure_regression/ppr-plot-1.png and b/vignettes/figure_regression/ppr-plot-1.png differ diff --git a/vignettes/figure_regression/preproc-plot-1.png b/vignettes/figure_regression/preproc-plot-1.png index d69b210e8f3add5861cca4850dd06035272c5d82..4fa9c9f91f5f0792607412b1c44d7a98199b710f 100644 Binary files a/vignettes/figure_regression/preproc-plot-1.png and b/vignettes/figure_regression/preproc-plot-1.png differ diff --git a/vignettes/figure_regression/surrogate-plot-1.png b/vignettes/figure_regression/surrogate-plot-1.png index ffc2d758428d08583250774d64b72db9b45d3192..323ce8cc3b0807ab0c26957535ef17fa2a85cdc3 100644 Binary files a/vignettes/figure_regression/surrogate-plot-1.png and b/vignettes/figure_regression/surrogate-plot-1.png differ diff --git a/vignettes/figure_vaeac/check-n_coalitions-1.png b/vignettes/figure_vaeac/check-n_coalitions-1.png new file mode 100644 index 0000000000000000000000000000000000000000..84439d62f76aebe71fb2ad15136e125076c8d68e Binary files /dev/null and b/vignettes/figure_vaeac/check-n_coalitions-1.png differ diff --git a/vignettes/figure_vaeac/continue-training-1.png b/vignettes/figure_vaeac/continue-training-1.png index 0e5fb697e31162b8f36ecc213099d7dc2ee3fbe4..22484df8d38bdf327d609998e5c48c833fe5220c 100644 Binary files a/vignettes/figure_vaeac/continue-training-1.png and b/vignettes/figure_vaeac/continue-training-1.png differ diff --git a/vignettes/figure_vaeac/continue-training-2-1.png b/vignettes/figure_vaeac/continue-training-2-1.png index 
7a11d30a50b74cb8c66cba03b2f641e95e687544..70f7a48e8347268b32c6f3f2ab1765ec940ea2a0 100644 Binary files a/vignettes/figure_vaeac/continue-training-2-1.png and b/vignettes/figure_vaeac/continue-training-2-1.png differ diff --git a/vignettes/figure_vaeac/continue-training-2-2.png b/vignettes/figure_vaeac/continue-training-2-2.png index 5f399e592c8dbef0c32258b884ac3e9854ead837..ee9958673bd9f0de0c60d1190a1d02cfe6cb1ae7 100644 Binary files a/vignettes/figure_vaeac/continue-training-2-2.png and b/vignettes/figure_vaeac/continue-training-2-2.png differ diff --git a/vignettes/figure_vaeac/continue-training-2.png b/vignettes/figure_vaeac/continue-training-2.png index 149759141c11c38b81a5d061df4db42b128f44cc..ae3e25a22ab98a64a730267b248871ca988b3965 100644 Binary files a/vignettes/figure_vaeac/continue-training-2.png and b/vignettes/figure_vaeac/continue-training-2.png differ diff --git a/vignettes/figure_vaeac/continue-training-3.png b/vignettes/figure_vaeac/continue-training-3.png index 1120e4ed30f2bb8b1200ab20575ae3ca172623de..b6876b55f6766f377fc08103437f13db3ebf15cf 100644 Binary files a/vignettes/figure_vaeac/continue-training-3.png and b/vignettes/figure_vaeac/continue-training-3.png differ diff --git a/vignettes/figure_vaeac/continue-training-4.png b/vignettes/figure_vaeac/continue-training-4.png index 6ee3224598dfb272b2c67b12dcd4eafa6e082b17..c5c7da264bbaf0b6f50e8bc2c837fde7dcbc4bd0 100644 Binary files a/vignettes/figure_vaeac/continue-training-4.png and b/vignettes/figure_vaeac/continue-training-4.png differ diff --git a/vignettes/figure_vaeac/continue-training-5.png b/vignettes/figure_vaeac/continue-training-5.png index 7ad958b75cffaae8bb042abd03fbb1f3d51397bf..2808f4ac69a6fc7a35818972a7c43a3132971604 100644 Binary files a/vignettes/figure_vaeac/continue-training-5.png and b/vignettes/figure_vaeac/continue-training-5.png differ diff --git a/vignettes/figure_vaeac/early-stopping-1-1.png b/vignettes/figure_vaeac/early-stopping-1-1.png index 
464cc6f9dc6a2bdf505024d7621fb744d342ded2..d2eb50906c902da6aad6183895748adf3f1f1e16 100644 Binary files a/vignettes/figure_vaeac/early-stopping-1-1.png and b/vignettes/figure_vaeac/early-stopping-1-1.png differ diff --git a/vignettes/figure_vaeac/early-stopping-2-1.png b/vignettes/figure_vaeac/early-stopping-2-1.png index 5c7c98dde35fe53685f2ed20b66fe60c3f412c9f..c8bf25482245121080258b445dd76871b5729f44 100644 Binary files a/vignettes/figure_vaeac/early-stopping-2-1.png and b/vignettes/figure_vaeac/early-stopping-2-1.png differ diff --git a/vignettes/figure_vaeac/early-stopping-3-1.png b/vignettes/figure_vaeac/early-stopping-3-1.png index cc3a3fd5e27ebcb8b05b727b5ee46a6760f6220c..961e478ba6b0b42f5060ee2795fd5397085804f7 100644 Binary files a/vignettes/figure_vaeac/early-stopping-3-1.png and b/vignettes/figure_vaeac/early-stopping-3-1.png differ diff --git a/vignettes/figure_vaeac/early-stopping-3-2.png b/vignettes/figure_vaeac/early-stopping-3-2.png index 3c6fe34febd3ab26493bbc20a7f92627e4212753..db627d8e02ebadc8a6637b9668fef142386dade2 100644 Binary files a/vignettes/figure_vaeac/early-stopping-3-2.png and b/vignettes/figure_vaeac/early-stopping-3-2.png differ diff --git a/vignettes/figure_vaeac/first-vaeac-plots-1.png b/vignettes/figure_vaeac/first-vaeac-plots-1.png index dd10cc011f40ebd7fbb1001548c77be20df56182..18edc47c05121ff0aded062e32141fa2ca5f41b0 100644 Binary files a/vignettes/figure_vaeac/first-vaeac-plots-1.png and b/vignettes/figure_vaeac/first-vaeac-plots-1.png differ diff --git a/vignettes/figure_vaeac/paired-sampling-plotting-1.png b/vignettes/figure_vaeac/paired-sampling-plotting-1.png index 4e5f4052ad0805529857f45250fccf5dbab9f595..f22ecde74cb867ce497f3f5ff3a887ea6ea92322 100644 Binary files a/vignettes/figure_vaeac/paired-sampling-plotting-1.png and b/vignettes/figure_vaeac/paired-sampling-plotting-1.png differ diff --git a/vignettes/figure_vaeac/paired-sampling-plotting-2.png b/vignettes/figure_vaeac/paired-sampling-plotting-2.png index 
0117a8e8e2e8822a9eb1b6a146035f27fb12c253..0e6473cd2677babf3b548025cb3edc19962a972d 100644 Binary files a/vignettes/figure_vaeac/paired-sampling-plotting-2.png and b/vignettes/figure_vaeac/paired-sampling-plotting-2.png differ diff --git a/vignettes/figure_vaeac/vaeac-grouping-of-features-1.png b/vignettes/figure_vaeac/vaeac-grouping-of-features-1.png index 0e618dd3ddc2ba407418ce5aa7d0f7cbb151528e..d7af1938311c60a471462346ddbaaa7cbdcc4f33 100644 Binary files a/vignettes/figure_vaeac/vaeac-grouping-of-features-1.png and b/vignettes/figure_vaeac/vaeac-grouping-of-features-1.png differ diff --git a/vignettes/figure_vaeac/vaeac-mixed-data-1.png b/vignettes/figure_vaeac/vaeac-mixed-data-1.png index 81e21d290e53917b8061c2e14f6670f2a703c887..1811f662ff5e4a716645746864941921a70d1fe9 100644 Binary files a/vignettes/figure_vaeac/vaeac-mixed-data-1.png and b/vignettes/figure_vaeac/vaeac-mixed-data-1.png differ diff --git a/vignettes/figure_vaeac/vaeac-mixed-data-2.png b/vignettes/figure_vaeac/vaeac-mixed-data-2.png index 4c4bce005ef9cb426288ae5ecc607d1db42e92a0..999ef0edcfd5bdec320c1611c0dc676d65662bf0 100644 Binary files a/vignettes/figure_vaeac/vaeac-mixed-data-2.png and b/vignettes/figure_vaeac/vaeac-mixed-data-2.png differ diff --git a/vignettes/figure_vaeac/vaeac-mixed-data-3.png b/vignettes/figure_vaeac/vaeac-mixed-data-3.png index ce652685add65dce6e83bd9074180e7a150d40d5..5efe4e6225df4208b3a90d9a0a1af454cad89caa 100644 Binary files a/vignettes/figure_vaeac/vaeac-mixed-data-3.png and b/vignettes/figure_vaeac/vaeac-mixed-data-3.png differ diff --git a/vignettes/understanding_shapr.Rmd b/vignettes/understanding_shapr.Rmd index a0bd1ab0dd32fa41ae5d233728a9fe5297baef1d..c3c953ff8b440f0b510e42ff9c92df485998c563 100644 --- a/vignettes/understanding_shapr.Rmd +++ b/vignettes/understanding_shapr.Rmd @@ -20,15 +20,19 @@ editor_options: > [Overview of Package](#overview) -> [The Kernel SHAP Method](#KSHAP) +> [KernelSHAP and dependence-aware estimators](#KSHAP) -> [Examples](#ex) 
+> [Estimation approaches and plotting functionality](#ex) -> [Advanced usage](#advanced) +> [iterative estimation](#iterative) + +> [Parallelization](#para) -> [Scalability and efficency](#scalability) +> [Verbosity and progress updates](#verbose) + +> [Advanced usage](#advanced) -> [Comparison to Lundberg & Lee's implementation](#compare) +> [Explaining forecasting models](#forecasting) @@ -43,7 +47,7 @@ on interpreting individual predictions, Shapley values is regarded to be the only model-agnostic explanation method with a solid theoretical foundation (@lundberg2017unified). Kernel SHAP is a computationally efficient approximation to Shapley values in higher dimensions, but it -assumes independent features. @aas2019explaining extend the Kernel SHAP +assumes independent features. @aas2019explaining extends the Kernel SHAP method to handle dependent features, resulting in more accurate approximations to the true Shapley values. See the [paper](https://www.sciencedirect.com/sdfe/reader/pii/S0004370221000539/pdf) @@ -55,7 +59,7 @@ approximations to the true Shapley values. See the # Overview of Package -## Functions +## Functionality Here is an overview of the main functions. You can read their documentation and see examples with `?function_name`. @@ -68,11 +72,62 @@ documentation and see examples with `?function_name`. : Main functions in the `shapr` package. +The `shapr` package implements kernelSHAP estimation of dependence-aware Shapley values with +eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely +`"empirical"`, `"gaussian"`, `"copula"`, `"ctree"`, `"vaeac"`, `"categorical"`, `"timeseries"`, and `"independence"`. +`shapr` has also implemented two regression-based approaches `"regression_separate"` and `"regression_surrogate"`. +See [Estimation approaches and plotting functionality](#ex) below for examples. 
+It is also possible to combine the different approaches, see the [combined approach](#combined). + +The package allows for parallelized computation through the `future` package, see [Parallelization](#para) for details. + +The level of detail in the output can be controlled through the `verbose` argument. In addition, progress updates +on the process of estimating the `v(S)`'s (and training the `"vaeac"` model) are available through the +`progressr` package, supporting progress updates also for parallelized computation. +See [Verbosity and progress updates](#verbose) for details. + +Moreover, the default behavior is to estimate the Shapley values iteratively, with an increasing number of +feature coalitions being added, and to stop estimation once the estimated Shapley values have achieved a certain level of +stability. +More information about this is provided in [iterative estimation](#iterative). +The above, combined with batch computation of the `v(S)` values, enables fast and accurate estimation of the +Shapley values in a memory-friendly manner. + +The package also provides functionality for computing Shapley values for groups of features, and custom function explanation, see [Advanced usage](#advanced). +Finally, explanation of multiple output time series forecasting models is discussed in +[Explaining forecasting models](#forecasting). + + +## Default behavior of `explain` + +Below we provide brief descriptions of the most important parts of the default behavior of the `explain` function. + +By default `explain` always computes feature-wise Shapley values. +Groups of features can be explained by providing the feature groups through the `group` argument. + +When there are five or fewer features (or feature groups), iterative estimation is by default disabled. 
+The reason for this is that it is usually faster to estimate the Shapley values for all possible coalitions (`v(S)`), +than to estimate the uncertainty of the Shapley values, and potentially stop estimation earlier. +While iterative estimation is the default starting from six features, it is mainly when there are more than ten features, +that it is most beneficial, and can save a lot of computation time. +The reason for this is that the number of possible coalitions grows exponentially. +These defaults can be overridden by setting the `iterative` argument to `TRUE` or `FALSE`. +When using the `iterative` argument, the estimation for an observation is stopped when all Shapley value +standard deviations are below `t` times the range of the Shapley values. +The `t` value controls the convergence tolerance, defaults to 0.02, and can be set through the `iterative_args$convergence_tol` argument, see [iterative estimation](#iterative) for more details. + +Since the iterativeness default changes based on the number of features (or feature groups), the default is also to have +no upper bound on the number of coalitions considered. +This can be controlled through the `max_n_coalitions` argument. + +
-# The Kernel SHAP Method +# KernelSHAP and dependence-aware estimators + +## The Kernel SHAP Method Assume a predictive model $f(\boldsymbol{x})$ for a response value $y$ with features $\boldsymbol{x}\in \mathbb{R}^M$, trained on a training @@ -237,9 +292,7 @@ AIC known as AICc. As calculation of it is computationally intensive, an approximate version of the selection criterion is also suggested. Details on this is found in @aas2019explaining. - - -
+ ## Conditional Inference Tree Approach @@ -319,6 +372,8 @@ the `explain()` function. For example, we can the change the batch size to 32 by `vaeac.extra_parameters = list(vaeac.batch_size = 32)` as a parameter in the call the `explain()` function. See `?shapr::vaeac_get_extra_para_default` for a description of the possible extra parameters to the `vaeac` approach. We strongly encourage the user to specify the main and extra parameters to the `vaeac` approach at the correct place in the call to the `explain()` function. That is, the main parameters are directly entered to the `explain()` function, while the extra parameters are included in a named list called `vaeac.extra_parameters`. However, the `vaeac` approach will try to correct for misplaced and duplicated parameters and give warnings to the user. + + ## Categorical Approach When the features are all categorical, we can estimate the conditional @@ -365,17 +420,11 @@ paradigm into the separate and surrogate regression method classes. In the separate vignette, we briefly introduce the two method classes. For an in-depth explanation, we refer the reader to Sections 3.5 and 3.6 in @olsen2024comparative. -
-# Examples {#examples} - -`shapr` supports computation of Shapley values with any predictive model -which takes a set of numeric features and produces a numeric outcome. -Note that the ctree method takes both numeric and categorical variables. -Check under "Advanced usage" for an example of how this can be done. +# Estimation approaches and plotting functionality {#ex} The following example shows how a simple `xgboost` model is trained using the `airquality` dataset, and how `shapr` can be used to explain @@ -388,9 +437,10 @@ below. -```r +``` r library(xgboost) library(data.table) +#> data.table 1.15.4 using 16 threads (see ?getDTthreads). Latest news: r-datatable.com data("airquality") data <- data.table::as.data.table(airquality) @@ -425,25 +475,40 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. #> -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:05 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c207abd4b.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Printing the Shapley values for the test data. 
# For more information about the interpretation of the values in the table, see ?shapr::explain. -print(explanation$shapley_values) -#> none Solar.R Wind Temp Month -#> -#> 1: 43.086 13.21173 4.7856 -25.572 -5.5992 -#> 2: 43.086 -9.97277 5.8307 -11.039 -7.8300 -#> 3: 43.086 -2.29162 -7.0534 -10.150 -4.4525 -#> 4: 43.086 3.32546 -3.2409 -10.225 -6.6635 -#> 5: 43.086 4.30396 -2.6278 -14.152 -12.2669 -#> 6: 43.086 0.47864 -5.2487 -12.553 -6.6457 +print(explanation$shapley_values_est) +#> explain_id none Solar.R Wind Temp Month +#> +#> 1: 1 43.086 13.21173 4.7856 -25.572 -5.5992 +#> 2: 2 43.086 -9.97277 5.8307 -11.039 -7.8300 +#> 3: 3 43.086 -2.29162 -7.0534 -10.150 -4.4525 +#> 4: 4 43.086 3.32546 -3.2409 -10.225 -6.6635 +#> 5: 5 43.086 4.30396 -2.6278 -14.152 -12.2669 +#> 6: 6 43.086 0.47864 -5.2487 -12.553 -6.6457 # Plot the resulting explanations for observations 1 and 6 plot(explanation, bar_plot_phi0 = FALSE, index_x_explain = c(1, 6)) @@ -459,7 +524,7 @@ There are multiple plot options specified by the `plot_type` argument in `plot`. The `waterfall` option shows the changes in the prediction score due to each features contribution (their Shapley values): -```r +``` r plot(explanation, plot_type = "waterfall", index_x_explain = c(1, 6)) ``` @@ -475,19 +540,33 @@ Shapley value of a given instance, where the points are colored by the feature value of that instance: -```r +``` r x_explain_many <- data[, ..x_var] explanation_plot <- explain( model = model, x_explain = x_explain_many, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. 
+#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:09 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 111 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c3d5f010f.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. plot(explanation_plot, plot_type = "beeswarm") ``` @@ -498,7 +577,7 @@ Shapley values on the y-axis, as well as (optionally) a background scatter_hist showing the distribution of the feature data: -```r +``` r plot(explanation_plot, plot_type = "scatter", scatter_hist = TRUE) ``` @@ -508,7 +587,7 @@ We can use mixed (i.e continuous, categorical, ordinal) data with `ctree` or `va Use `ctree` with mixed data in the following manner: -```r +``` r # convert the month variable to a factor data[, Month_factor := as.factor(Month)] @@ -532,10 +611,24 @@ explanation_lm_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) -#> Setting parameter 'n_batches' to 10 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:17 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c49d943cf.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Plot the resulting explanations for observations 1 and 6, excluding # the no-covariate effect @@ -549,7 +642,7 @@ in the following manner. Default values are based on @hothorn2006unbiased. -```r +``` r # Use the conditional inference tree approach # We can specify parameters used to building trees by specifying mincriterion, # minsplit, minbucket @@ -558,13 +651,27 @@ explanation_ctree <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0, + phi0 = p0, ctree.mincriterion = 0.80, ctree.minsplit = 20, - ctree.minbucket = 20 + ctree.minbucket = 20, + iterative = FALSE ) -#> Setting parameter 'n_batches' to 10 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:18 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c4dae3760.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Default parameters (based on (Hothorn, 2006)) are: # mincriterion = 0.95 # minsplit = 20 @@ -575,7 +682,7 @@ If **all** features are categorical, one may use the categorical approach as follows: -```r +``` r # For the sake of illustration, convert ALL features to factors data[, Solar.R_factor := as.factor(cut(Solar.R, 10))] data[, Wind_factor := as.factor(cut(Wind, 3))] @@ -601,10 +708,24 @@ explanation_cat_method <- explain( x_explain = x_explain_all_cat, x_train = x_train_all_cat, approach = "categorical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:19 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: categorical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c5dd5485a.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` Shapley values can be used to explain any predictive model. For @@ -619,7 +740,7 @@ achieved through the `group` attribute. Other optional parameters of time series if necessary). 
-```r +``` r # Simulate time series data with AR(1)-structure set.seed(1) data_ts <- data.frame(matrix(NA, ncol = 41, nrow = 4)) @@ -664,11 +785,25 @@ explanation_timeseries <- explain( x_explain = x_explain_ts, x_train = x_train_ts, approach = "timeseries", - prediction_zero = p0_ts, - group = group_ts + phi0 = p0_ts, + group = group_ts, + iterative = FALSE ) -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 16, +#> and is therefore set to 2^n_groups = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:19 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: timeseries +#> • Iterative estimation: FALSE +#> • Number of group-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c2eab32f5.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -752,7 +887,7 @@ each observation, as each combination is a different prediction tasks. Start by explaining the predictions by using different methods and combining them into lists. -```r +``` r # We use more explicands here for more stable confidence intervals ind_x_explain_many <- 1:25 x_train <- data[-ind_x_explain_many, ..x_var] @@ -776,13 +911,27 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. 
+#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:22 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: independence +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c3b3736b2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Empirical approach explanation_empirical <- explain( @@ -790,13 +939,27 @@ explanation_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:22 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c5c83bb13.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Gaussian 1e1 approach explanation_gaussian_1e1 <- explain( @@ -804,13 +967,27 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, - n_samples = 1e1, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e1, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:26 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026cb6ddb92.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Gaussian 1e2 approach explanation_gaussian_1e2 <- explain( @@ -818,13 +995,27 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:26 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c5ef4677c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Combined approach explanation_combined <- explain( @@ -832,13 +1023,27 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "empirical", "independence"), - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:27 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian, empirical, and independence +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 25 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c59227d80.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Create a list of explanations with names explanation_list_named <- list( @@ -854,7 +1059,7 @@ explanation_list_named <- list( We can then compare the different approaches by creating plots of the $\operatorname{MSE}_{v}$ evaluation criterion. 
-```r +``` r # Create the MSEv plots with approximate 95% confidence intervals MSEv_plots <- plot_MSEv_eval_crit(explanation_list_named, plot_type = c("overall", "comb", "explicand"), @@ -863,63 +1068,45 @@ MSEv_plots <- plot_MSEv_eval_crit(explanation_list_named, # 5 plots are made names(MSEv_plots) -#> [1] "MSEv_explicand_bar" "MSEv_explicand_line_point" "MSEv_combination_bar" "MSEv_combination_line_point" "MSEv_bar" +#> [1] "MSEv_explicand_bar" "MSEv_explicand_line_point" "MSEv_coalition_bar" "MSEv_coalition_line_point" +#> [5] "MSEv_bar" ``` The main plot if interest is the `MSEv_bar`, which displays the $\operatorname{MSE}_{v}$ evaluation criterion for each method averaged over both the combinations/coalitions and test observations/explicands. However, we can also look at the other plots where we have only averaged over the observations or the combinations (both as bar and line plots). -```r +``` r # The main plot of the overall MSEv averaged over both the combinations and observations MSEv_plots$MSEv_bar ``` ![](figure_main/unnamed-chunk-12-1.png) -```r +``` r # The MSEv averaged over only the explicands for each combinations MSEv_plots$MSEv_combination_bar -``` - -![](figure_main/unnamed-chunk-12-2.png) - -```r +#> NULL # The MSEv averaged over only the combinations for each observation/explicand MSEv_plots$MSEv_explicand_bar ``` -![](figure_main/unnamed-chunk-12-3.png) +![](figure_main/unnamed-chunk-12-2.png) -```r +``` r # To see which coalition S each of the `id_combination` corresponds to, # i.e., which features that are conditions on. 
explanation_list_named[[1]]$MSEv$MSEv_combination[, c("id_combination", "features")] -#> id_combination features -#> -#> 1: 2 1 -#> 2: 3 2 -#> 3: 4 3 -#> 4: 5 4 -#> 5: 6 1,2 -#> 6: 7 1,3 -#> 7: 8 1,4 -#> 8: 9 2,3 -#> 9: 10 2,4 -#> 10: 11 3,4 -#> 11: 12 1,2,3 -#> 12: 13 1,2,4 -#> 13: 14 1,3,4 -#> 14: 15 2,3,4 +#> NULL ``` We can specify the `index_x_explain` and `id_combination` parameters in `plot_MSEv_eval_crit()` to only plot certain test observations and combinations, respectively. -```r +``` r # We can specify which test observations or combinations to plot plot_MSEv_eval_crit(explanation_list_named, plot_type = "explicand", @@ -930,21 +1117,20 @@ plot_MSEv_eval_crit(explanation_list_named, ![](figure_main/unnamed-chunk-13-1.png) -```r +``` r plot_MSEv_eval_crit(explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15), + id_coalition = c(3, 4, 9, 13:15), CI_level = 0.95 )$MSEv_combination_bar +#> NULL ``` -![](figure_main/unnamed-chunk-13-2.png) - We can also alter the plots design-wise as we do in the code below. -```r +``` r bar_text_n_decimals <- 1 plot_MSEv_eval_crit(explanation_list_named) + ggplot2::scale_x_discrete(limits = rev(levels(MSEv_plots$MSEv_bar$data$Method))) + @@ -970,282 +1156,244 @@ plot_MSEv_eval_crit(explanation_list_named) + ![](figure_main/unnamed-chunk-14-1.png) + -## Main arguments in `explain` - -When using `explain`, the default behavior is to use all feature -combinations in the Shapley formula. Kernel SHAP's sampling based -approach may be used by specifying `n_combinations`, which is the number -of unique feature combinations to sample. If not specified, the exact -method is used. The computation time grows approximately exponentially -with the number of features. The training data and the model whose -predictions we wish to explain must be provided through the arguments -`x_train` and `model`. The data whose predicted values we wish to -explain must be given by the argument `x_explain`. 
Note that both -`x_train` and `x_explain` must be a `data.frame` or a `matrix`, and all -elements must be finite numerical values. Currently we do not support -missing values. The default approach when computing the Shapley values -is the empirical approach (i.e. `approach = "empirical"`). If you'd like -to use a different approach you'll need to set `approach` equal to -either `copula` or `gaussian`, or a vector of them, with length equal to -the number of features. If a vector, a combined approach is used, and -element `i` indicates the approach to use when conditioning on `i` -variables. For more details see [Combined approach](#combined) below. - -When computing the kernel SHAP values by `explain`, the maximum number -of samples to use in the Monte Carlo integration for every conditional -expectation is controlled by the argument `n_samples` (default equals -`1000`). The computation time grows approximately linear with this -number. You will also need to pass a numeric value for the argument -`prediction_zero`, which represents the prediction value when not -conditioning on any features. We recommend setting this equal to the -mean of the response, but other values, like the mean prediction of a -large test data set is also a possibility. If the empirical method is -used, specific settings for that approach, like a vector of fixed -$\sigma$ values can be specified through the argument -`empirical.fixed_sigma`. See `?explain` for more information. If -`approach = "gaussian"`, you may specify the mean vector and covariance -matrix of the data generating distribution by the arguments -`gaussian.mu` and `gaussian.cov_mat`. If not specified, they are -estimated from the training data. - -## Explaining a forecasting model using `explain_forecast` +# iterative estimation -`shapr` provides a specific function, `explain_forecast`, to explain -forecasts from time series models, at one or more steps into the future. 
-The main difference compared to `explain` is that the data is supplied -as (set of) time series, in addition to index arguments (`train_idx` and -`explain_idx`) specifying which time points that represents the train -and explain parts of the data. See `?explain_forecast` for more -information. +Iterative estimation is the default when computing Shapley values with six or more features (or feature groups), and +can always be manually overridden by setting `iterative = FALSE` in the `explain()` function. +The idea behind iterative estimation is to obtain sufficiently accurate Shapley value estimates faster. +First, an initial number of coalitions is sampled, then, bootstrapping is used to estimate the variance of the Shapley +values. +A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small. +If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more +coalitions. +The process is repeated until the variances are below the threshold. +Specifics related to the iterative process and convergence criterion are set through the `iterative_args` argument. -To demonstrate how to use the function, 500 observations are generated -which follow an AR(1) structure, i.e. -$y_t = 0.5 y_{t-1} + \varepsilon_t$. To this data an arima model of -order (2, 0, 0) is fitted, and we therefore would like to explain the -forecasts in terms of the two previous lags of the time series. This is -is specified through the argument `explain_y_lags = 2`. Note that some -models may also put restrictions on the amount of data required to make -a forecast. The AR(2) model we used there, for instance, requires two -previous time point to make a forecast. +The convergence criterion we use is adopted from @covert2021improving, and slightly modified to work for multiple +observations: -In the example, two separate forecasts, each three steps ahead, are -explained.
To set the starting points of the two forecasts, -`explain_idx` is set to `499:500`. This means that one forecast of -$t = (500, 501, 502)$ and another of $t = (501, 502, 503)$, will be -explained. In other words, `explain_idx` tells `shapr` at which points -in time data was available up until, when making the forecast to -explain. +$$ \operatorname{median}_i\left(\frac{\max_j \hat{\text{sd}}(\hat{\phi}_{ij})}{\max_j \hat{\phi}_{ij} - \min_j \hat{\phi}_{ij}}\right) < t $$ -In the same way, `train_idx` denotes the points in time used to estimate -the conditional expectations used to explain the different forecasts. -Note that since we want to explain the forecasts in terms of the two -previous lags (`explain_y_lags = 2`), the smallest value of `train_idx` -must also be 2, because at time $t = 1$ there was only a single -observation available. +where $\hat{\phi}_{ij}$ is the estimated Shapley value of feature $j$ for observation $i$, and $\hat{\text{sd}}(\hat{\phi}_{ij})$ +is its (bootstrap) estimated standard deviation. The default value of $t$ is 0.02. +Below we provide some examples of how to use the iterative estimation procedure. -Since the data is stationary, the mean of the data is used as value of -`prediction_zero` (i.e. $\phi_0$). This can however be chosen -differently depending on the data and application. -For a multivariate model such as a VAR (Vector AutoRegressive model), it -may be of more interesting to explain the impact of each variable, -rather than each lag of each variable. This can be done by setting -`group_lags = TRUE`. -```r -# Simulate time series data with AR(1)-structure. -set.seed(1) -data_ts <- data.frame(Y = arima.sim(list(order = c(1, 0, 0), ar = .5), n = 500)) -data_ts <- data.table::as.data.table(data_ts) +``` r +library(xgboost) +library(data.table) -# Fit an ARIMA(2, 0, 0) model.
-arima_model <- arima(data_ts, order = c(2, 0, 0)) +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] -# Set prediction zero as the mean of the data for each forecast point. -p0_ar <- rep(mean(data_ts$Y), 3) +x_var <- c("Solar.R", "Wind", "Temp", "Month","Day") +y_var <- "Ozone" -# Explain forecasts from points t = 499 and t = 500. -explain_idx <- 499:500 +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] -explanation_forecast <- explain_forecast( - model = arima_model, - y = data_ts, - train_idx = 2:498, - explain_idx = 499:500, - explain_y_lags = 2, - horizon = 3, - approach = "empirical", - prediction_zero = p0_ar, - group_lags = FALSE +# Set seed for reproducibility +set.seed(123) + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE ) -#> Note: Feature names extracted from the model contains NA. -#> Consistency checks between model and data is therefore disabled. -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. -explanation_forecast -#> explain_idx horizon none Y.1 Y.2 -#> -#> 1: 499 1 0.04018 0.5053 -0.07659 -#> 2: 500 1 0.04018 -0.3622 0.02497 -#> 3: 499 2 0.04018 0.5053 -0.07659 -#> 4: 500 2 0.04018 -0.3622 0.02497 -#> 5: 499 3 0.04018 0.5053 -0.07659 -#> 6: 500 3 0.04018 -0.3622 0.02497 -``` -Note that for a multivariate model such as a VAR (Vector AutoRegressive -model), or for models also including several exogenous variables, it may -be of more informative to explain the impact of each variable, rather -than each lag of each variable. This can be done by setting -`group_lags = TRUE`. 
This does not make sense for this model, however, -as that would result in decomposing the forecast into a single group. +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) -We now give a more hands on example of how to use the `explain_forecast` -function. Say that we have an AR(2) model which describes the change -over time of the variable `Temp` in the dataset `airquality`. It seems -reasonable to assume that the temperature today should affect the -temperature tomorrow. To a lesser extent, we may also suggest that the -temperature today should also have an impact on that of the day after -tomorrow. +# Initial explanation computation +ex <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + iterative_args = list(convergence_tol = 0.1) +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 32, +#> and is therefore set to 2^n_features = 32. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:30 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: TRUE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c729d00b9.rds' +#> +#> ── iterative computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 5 of 32 coalitions, 5 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 10 of 32 coalitions, 4 new. 
+#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 12 of 32 coalitions, 2 new. +``` -We start by building our AR(2) model, naming it `model_ar_temp`. This -model is then used to make a forecast of the temperature of the day that -comes after the last day in the data, this forecast starts from index -153. + +# Parallelization -```r -data_ts2 <- data.table::as.data.table(airquality) +The `shapr` package supports parallelization of the Shapley value estimation process through the +`future` package. +The parallelization is conducted over batches of `v(S)`-values. +We therefore start by describing this batch computing. -model_ar_temp <- ar(data_ts2$Temp, order = 2) +## Batch computation -predict(model_ar_temp, n.ahead = 2)$pred -#> Time Series: -#> Start = 154 -#> End = 155 -#> Frequency = 1 -#> [1] 71.081 71.524 -``` +The computational complexity of Shapley value based explanations grows +fast in the number of features, as the number of conditional +expectations one needs to estimate in the Shapley formula grows +exponentially. As outlined [above](#KSHAP), the estimation of each of +these conditional expectations is also computationally expensive, +typically requiring estimation of a conditional probability +distribution, followed by Monte Carlo integration. These computations +are not only heavy for the CPU, they also require a lot of memory (RAM), +which typically is a limited resource. By doing the most resource hungry +computations (the computation of v(S)) in sequential batches with +different feature subsets $S$, the memory usage can be significantly +reduced. +The user can control the number of batches by setting the two arguments +`extra_computation_args$max_batch_size` (defaults to 10) and +`extra_computation_args$min_n_batches` (defaults to 10). -First, we pass the model and the data as `model` and `y`.
Since we have -an AR(2) model, we want to explain the forecasts in terms of the two -previous lags, whihc we specify with `explain_y_lags = 2`. Then, we let -`shapr` know which time indices to use as training data through the -argument `train_idx`. We use `2:152`, meaning that we skip the first -index, as we want to explain the two previous lags. Letting the training -indices go up until 152 means that every point in time except the first -and last will be used as training data. +## Parallelized computation -The last index, 153 is passed as the argument `explain_idx`, which means -that we want to explain a forecast made from time point 153 in the data. -The argument `horizon` is set to 2 in order to explain a forecast of -length 2. +In addition to reducing the memory consumption, the batch computing allows the +computations within each batch to be performed in parallel. +The parallelization in `shapr::explain()` is handled by the +`future.apply` package, which builds on the `future` framework. The `future` +package works on all OS, allows the user to decide the parallelization +backend (multiple R processes or forking), works directly with HPC +clusters, and also supports progress updates for the parallelized task +(see [Verbosity and progress updates](#verbose)). -The argument `prediction_zero` is set to the mean of the time series, -and is repeated two times. Each value of `prediction_zero` is the -baseline for each forecast horizon. In our example, we assume that given -no effect from the two lags, the temperature would just be the average -during the observed period. Finally, we opt to not group the lags by -setting `group_lags` to `FALSE`. This means that lag 1 and 2 will be -explained separately. Grouping lags may be more interesting to do in a -model with multiple variables, as it is then possible to explain each -variable separately.
+Note that, since it takes some time to duplicate data into different +processes/machines when running in parallel, it is not always +preferable to run `shapr::explain()` in parallel, at least not with +many parallel sessions (hereafter called **workers**). Parallelization also +increases the memory consumption proportionally, so you want to limit +the number of workers for that reason too. +Below is a basic example of a parallelization with two workers. -```r -explanation_forecast <- explain_forecast( - model = model_ar_temp, - y = data_ts2[, "Temp"], - train_idx = 2:152, - explain_idx = 153, - explain_y_lags = 2, - horizon = 2, +``` r +library(future) +future::plan(multisession, workers = 2) + +explanation_par <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, approach = "empirical", - prediction_zero = rep(mean(data$Temp), 2), - group_lags = FALSE, - n_batches = 1, - timing = FALSE + phi0 = p0 ) -#> Note: Feature names extracted from the model contains NA. -#> Consistency checks between model and data is therefore disabled. +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 32, +#> and is therefore set to 2^n_features = 32. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:33 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c770548a9.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 32 of 32 coalitions.
-print(explanation_forecast) -#> explain_idx horizon none Temp.1 Temp.2 -#> -#> 1: 153 1 77.79 -6.578 -0.134 -#> 2: 153 2 77.79 -5.980 -0.288 +future::plan(sequential) # To return to non-parallel computation ``` -The results are presented per value of `explain_idx` and forecast -horizon. We can see that the mean temperature was around 77.9 degrees. -At horizon 1, the first lag in the model caused it to be 6.6 degrees -lower, and the second lag had just a minor effect. At horizon 2, the -first lag has a slightly smaller negative impact, and the second lag has -a slightly larger impact. + -It is also possible to explain a forecasting model which uses exogenous -regressors. The previous example is expanded to use an ARIMA(2,0,0) -model with `Wind` as an exogenous regressor. Since the exogenous -regressor must be available for the predicted time points, the model is -just fit on the 151 first observations, leaving two observations of -`Wind` to be used as exogenous values during the prediction phase. +# Verbosity and progress updates +The `verbose` argument controls the verbosity of the output while running `explain()`, +and allows one or more of the strings `"basic"`, `"progress"`, `"convergence"`, `"shapley"` and `"vS_details"`. +`"basic"` (default) displays basic information about the computation which is being performed, +`"progress"` displays information about where in the calculation process the function currently is, +`"convergence"` displays information on how close to convergence the Shapley value estimates are +(for iterative estimation), +`"shapley"` displays (intermediate) Shapley value estimates and standard deviations in addition to the final estimates, +while `"vS_details"` displays information about the `v(S)` estimates for some of the approaches. +If the user wants no printout, the argument can be set to `NULL`.
-```r -data_ts3 <- data.table::as.data.table(airquality) +In addition, progress updates for the computation of the `v(S)` values are available through the R-package `progressr`. +This gives the user full control over the visual appearance of these progress updates. +The main reason for providing this separate progress update feature is that it +integrates seamlessly with the parallelization framework `future` used by `shapr` (see [Parallelization](#para)), +and apparently is the only framework allowing progress updates also for parallelized tasks. +These progress updates can be used in combination with, or independently of, the `verbose` argument. -data_fit <- data_ts3[seq_len(151), ] - -model_arimax_temp <- arima(data_fit$Temp, order = c(2, 0, 0), xreg = data_fit$Wind) - -newxreg <- data_ts3[-seq_len(151), "Wind", drop = FALSE] - -predict(model_arimax_temp, n.ahead = 2, newxreg = newxreg)$pred -#> Time Series: -#> Start = 152 -#> End = 153 -#> Frequency = 1 -#> [1] 77.500 76.381 -``` +These progress updates via `progressr` are enabled for the current R-session by running the +command `progressr::handlers(global = TRUE)` before calling +`explain()`. To use progress updates for only a single call to +`explain()`, one can wrap the call using +`progressr::with_progress` as follows: +`progressr::with_progress({ shapr::explain() })`. The default appearance +of the progress updates is a basic ASCII-based horizontal progress bar. +Other variants can be chosen by passing different strings to +`progressr::handlers()`, some of which require additional packages. If +you are using RStudio, the progress can be displayed directly in the GUI +with `progressr::handlers('rstudio')` (requires the `rstudioapi` +package). If you are running Windows, you may use the pop-up GUI +progress bar `progressr::handlers('handler_winprogressbar')`. +A wrapper for the progress bar of the flexible `cli` package is also available: +`progressr::handlers('cli')`.
-The `shapr` package can then explain not only the two autoregressive -lags, but also the single lag of the exogenous regressor. In order to do -so, the `Wind` variable is passed as the argument `xreg`, and -`explain_xreg_lags` is set to 1. Notice how only the first 151 -observations are used for `y` and all 153 are used for `xreg`. This -makes it possible for `shapr` to not only explain the effect of the -first lag of the exogenous variable, but also the contemporary effect -during the forecasting period. +For a full list of all progression handlers and the customization +options available with `progressr`, see the `progressr` +[vignette](https://cran.r-project.org/web/packages/progressr/vignettes/progressr-intro.html). +A full code example of using `progressr` with `shapr` is shown below: -```r -explanation_forecast <- explain_forecast( - model = model_ar_temp, - y = data_fit[, "Temp"], - xreg = data_ts3[, "Wind"], - train_idx = 2:150, - explain_idx = 151, - explain_y_lags = 2, - explain_xreg_lags = 1, - horizon = 2, +``` r +library(progressr) +progressr::handlers(global = TRUE) +# If no progression handler is specified, the txtprogressbar is used +# Other progression handlers: +# progressr::handlers('rstudio') # requires the 'rstudioapi' package +# progressr::handlers('handler_winprogressbar') # Window only +# progressr::handlers('cli') # requires the 'cli' package +ex_progress <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, approach = "empirical", - prediction_zero = rep(mean(data_fit$Temp), 2), - group_lags = FALSE, - n_batches = 1, - timing = FALSE + phi0 = p0 ) -#> Note: Feature names extracted from the model contains NA. -#> Consistency checks between model and data is therefore disabled. 
-print(explanation_forecast$shapley_values) -#> explain_idx horizon none Temp.1 Temp.2 Wind.1 Wind.F1 Wind.F2 -#> -#> 1: 151 1 77.96 -0.67793 -0.67340 -1.2688 0.493408 NA -#> 2: 151 2 77.96 0.39968 -0.50059 -1.4655 0.065913 -0.47422 +handlers("progress") +#| [=================================>----------------------] 60% Estimating v(S) ``` + + +
@@ -1284,43 +1432,99 @@ features, using `"empirical", "copula"` and `"gaussian"` when conditioning on respectively 1, 2 and 3 features. -```r +``` r +library(xgboost) +library(data.table) + +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] + +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] + +# Set seed for reproducibility +set.seed(123) + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) + + # Use the combined approach explanation_combined <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = c("empirical", "copula", "gaussian"), - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 10 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:36 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical, copula, and gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c67f7e50f.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Plot the resulting explanations for observations 1 and 6, excluding # the no-covariate effect plot(explanation_combined, bar_plot_phi0 = FALSE, index_x_explain = c(1, 6)) ``` -![](figure_main/unnamed-chunk-20-1.png) +![](figure_main/unnamed-chunk-18-1.png) As a second example using `"ctree"` to condition on 1 and 2 features, and `"empirical"` when conditioning on 3 features: -```r +``` r # Use the combined approach explanation_combined <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = c("ctree", "ctree", "empirical"), - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 10 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:38 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree, ctree, and empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c68a713f2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` ## Explain groups of features @@ -1332,7 +1536,7 @@ intuition and real world examples. 
Explaining prediction in terms of groups of features is very easy using `shapr`: -```r +``` r # Define the feature groups group_list <- list( A = c("Temp", "Month"), @@ -1345,48 +1549,42 @@ explanation_group <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - group = group_list + phi0 = p0, + group = group_list, + iterative = FALSE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 4, +#> and is therefore set to 2^n_groups = 4. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:39 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of group-wise Shapley values: 2 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c4f1e913.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 4 of 4 coalitions. 
# Prints the group-wise explanations explanation_group -#> none A B -#> -#> 1: 47.27 -29.588 13.1628 -#> 2: 47.27 -11.834 -15.7011 -#> 3: 47.27 -15.976 -17.5729 -#> 4: 47.27 -25.067 -5.1374 -#> 5: 47.27 -35.848 20.2892 -#> 6: 47.27 -27.257 -8.4830 -#> 7: 47.27 -14.960 -21.3995 -#> 8: 47.27 -18.325 7.3791 -#> 9: 47.27 -23.012 9.6591 -#> 10: 47.27 -16.189 -5.6100 -#> 11: 47.27 -25.607 -10.1334 -#> 12: 47.27 -25.065 -5.1394 -#> 13: 47.27 -25.841 -0.7281 -#> 14: 47.27 -21.518 -13.3293 -#> 15: 47.27 -21.248 -1.3199 -#> 16: 47.27 -13.676 -16.9497 -#> 17: 47.27 -13.899 -14.8890 -#> 18: 47.27 -12.276 -8.2472 -#> 19: 47.27 -13.768 -13.5242 -#> 20: 47.27 -24.866 -10.8744 -#> 21: 47.27 -14.486 -22.7674 -#> 22: 47.27 -4.122 -14.2893 -#> 23: 47.27 -11.218 22.4682 -#> 24: 47.27 -33.002 14.2114 -#> 25: 47.27 -16.251 -8.6796 -#> none A B +#> explain_id none A B +#> +#> 1: 1 43.09 -29.25 16.0731 +#> 2: 2 43.09 -15.17 -7.8373 +#> 3: 3 43.09 -13.07 -10.8778 +#> 4: 4 43.09 -17.47 0.6653 +#> 5: 5 43.09 -28.27 3.5289 +#> 6: 6 43.09 -20.59 -3.3793 # Plots the group-wise explanations plot(explanation_group, bar_plot_phi0 = TRUE, index_x_explain = c(1, 6)) ``` -![](figure_main/unnamed-chunk-22-1.png) +![](figure_main/unnamed-chunk-20-1.png) ## Explain custom models @@ -1440,8 +1638,10 @@ this for the `gbm` model class from the `gbm` package, fitted to the same airquality data set as used above. -```r +``` r library(gbm) +#> Loaded gbm 2.2.2 +#> This version of gbm is no longer under development. 
Consider transitioning to gbm3, https://github.com/gbm-developers/gbm3 formula_gbm <- as.formula(paste0(y_var, "~", paste0(x_var, collapse = "+"))) # Fitting a gbm model @@ -1486,20 +1686,33 @@ explanation_custom <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, predict_model = MY_predict_model, get_model_specs = MY_get_model_specs ) -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:41 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c16415c3d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Plot results plot(explanation_custom, index_x_explain = c(1, 6)) ``` -![](figure_main/unnamed-chunk-23-1.png) +![](figure_main/unnamed-chunk-21-1.png) -```r +``` r #### Minimal version of the three required model functions #### @@ -1517,21 +1730,34 @@ explanation_custom_minimal <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, predict_model = MY_MINIMAL_predict_model ) #> Note: You passed a model to explain() which is not natively supported, and did not supply a 'get_model_specs' function to explain(). #> Consistency checks between model and data is therefore disabled. -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. 
-#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:44 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c75618775.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Plot results plot(explanation_custom_minimal, index_x_explain = c(1, 6)) ``` -![](figure_main/unnamed-chunk-23-2.png) +![](figure_main/unnamed-chunk-21-2.png) -### Tidymodels and workflows {#workflow_example} +## Tidymodels and workflows {#workflow_example} In this section, we demonstrate how to use `shapr` to explain `tidymodels` models fitted using `workflows`. In the example [above](#examples), we directly used the `xgboost` package to fit the `xgboost` model. However, we can also fit the `xgboost` model using the `tidymodels` package. These fits will be identical @@ -1539,7 +1765,7 @@ as `tidymodels` calls `xgboost` internally. which we demonstrate in the example `xgboost` (i.e., `parsnip::boost_tree`) with any other fitted `tidymodels` in the `workflows` procedure outlined below. 
-```r +``` r # Fitting a basic xgboost model to the training data using tidymodels set.seed(123) # Set the same seed as above all_var <- c(y_var, x_var) @@ -1566,7 +1792,7 @@ model_tidymodels <- parsnip::fit( # See that the output of the two models are identical all.equal(predict(model_tidymodels, x_train)$.pred, predict(model, as.matrix(x_train))) -#> [1] "Mean relative difference: 0.018699" +#> [1] TRUE # Create the Shapley values for the tidymodels version explanation_tidymodels <- explain( @@ -1574,16 +1800,30 @@ explanation_tidymodels <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_batches = 4 -) + phi0 = p0, + iterative = FALSE + ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:48 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c1d933001.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # See that the Shapley value explanations are identical too -all.equal(explanation$shapley_values, explanation_tidymodels$shapley_values) -#> [1] "Different number of rows" +all.equal(explanation$shapley_values_est, explanation_tidymodels$shapley_values_est) +#> [1] TRUE ``` - ## The parameters of the `vaeac` approach The `vaeac` approach is a very flexible method that supports mixed data. The main @@ -1602,28 +1842,42 @@ extra parameters to the `vaeac` approach. 
We strongly encourage the user to spec -```r +``` r explanation_vaeac <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = p0, - n_samples = 100, + phi0 = p0, + n_MC_samples = 100, vaeac.width = 16, vaeac.depth = 2, vaeac.epochs = 3, - vaeac.n_vaeacs_initialize = 2 + vaeac.n_vaeacs_initialize = 2, + iterative = FALSE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:45:51 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c4ef15f9c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` Can look at the training and validation error for the trained `vaeac` model and see that `vaeac.epochs = 3` is likely to few epochs as it still seems like the `vaeac` model is learning. -```r +``` r # Look at the training and validation errors. vaeac_plot_eval_crit(list("Vaeac 3 epochs" = explanation_vaeac), plot_type = "method") ``` @@ -1645,29 +1899,43 @@ is applied. Furthermore, a value of `2` is too low for real world applications, to make the vignette faster to build. 
-```r +``` r explanation_vaeac_early_stop <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = p0, - n_samples = 100, + phi0 = p0, + n_MC_samples = 100, vaeac.width = 16, vaeac.depth = 2, vaeac.epochs = 1000, # Set it to a large number vaeac.n_vaeacs_initialize = 2, - vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2) + vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2), + iterative = FALSE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting parameter 'n_batches' to 2 as a fair trade-off between memory consumption and computation time. -#> Reducing 'n_batches' typically reduces the computation time at the cost of increased memory consumption. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:07 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c1b83b97d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` Can compare with the previous version and see that the results are more stable now. -```r +``` r # Look at the training and validation errors. vaeac_plot_eval_crit( list("Vaeac 3 epochs" = explanation_vaeac, "Vaeac early stopping" = explanation_vaeac_early_stop), @@ -1680,191 +1948,436 @@ vaeac_plot_eval_crit( Can also compare the $MSE_{v}$ evaluation scores. 
-```r +``` r plot_MSEv_eval_crit(list("Vaeac 3 epochs" = explanation_vaeac, "Vaeac early stopping" = explanation_vaeac_early_stop)) ``` ![](figure_main/vaeac-plot-3-1.png) +## Continued computation {#cont_computation} +In this section, we demonstrate how to continue to improve estimation accuracy with additional coalition samples, +from a previous Shapley value computation based on `shapr::explain()` with the iterative estimation procedure. +This can be done either by passing an existing object of class `shapr`, or by passing a string with the path to +the intermediately saved results. +The latter is found at `SHAPR_OBJ$saving_path`, defaults to a temporary folder, +and is updated after each iteration. +This can be particularly handy for long-running computations. +``` r +# First we run the computation with the iterative estimation procedure for a limited number of coalition samples +library(xgboost) +library(data.table) - +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] -
+x_var <- c("Solar.R", "Wind", "Temp", "Month","Day") +y_var <- "Ozone" -# Scalability and efficency +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] -## Batch computation +# Set seed for reproducibility +set.seed(123) -The computational complexity of Shapley value based explanations grows -fast in the number of features, as the number of conditional -expectations one needs to estimate in the Shapley formula grows -exponentially. As outlined [above](#KSHAP), the estimating of each of -these conditional expectations is also computationally expensive, -typically requiring estimation of a conditional probability -distribution, followed by Monte Carlo integration. These computations -are not only heavy for the CPU, they also require a lot of memory (RAM), -which typically is a limited resource. By doing the most resource hungry -computations (the computation of v(S)) in sequential batches with -different feature subsets $S$, the memory usage can be significantly -reduces. Such batching comes at the cost of an increase in computation -time, which depends on the number of feature subsets (`n_combinations`), -the number of features, the estimation `approach` and so on. When -calling `shapr::explain()`, we allow the user to set the number of -batches with the argument `n_batches`. The default of this argument is -`NULL`, which uses a (hopefully) reasonable trade-off between -computation speed and memory consumption which depends on -`n_combinations` and `approach`. The memory/computation time trade-off -is most apparent for models with more than say 6-7 features. Below we a -basic example where `n_batches=10`: - - -```r -explanation_batch <- explain( +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. 
the expected prediction without any features +p0 <- mean(y_train) + +# Initial explanation computation +ex_init <- explain( model = model, x_explain = x_explain, x_train = x_train, - approach = "empirical", - prediction_zero = p0, - n_batches = 10 + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 20, + iterative = TRUE ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -``` - -## Parallelized computation - -In addition to reducing the memory consumption, the introduction of the -`n_batch` argument allows computation within each batch to be performed in parallel. -The parallelization in `shapr::explain()` is handled by the -`future_apply` which builds on the `future` environment. The `future` -package works on all OS, allows the user to decide the parallelization -backend (mutliple R procesess or forking), works directly with hpc -clusters, and also supports progress updates for the parallelized task -(see below). - -Note that, since it takes some time to duplicate data into different -processes/machines when running in parallel, it is not always -preferrable to run `shapr::explain()` in parallel, at least not with -many parallel sessions (hereby called **workers**). Parallelization also -increases the memory consumption proportionally, so you want to limit -the number of workers for that reason too. In a future version of -`shapr` we will provide experienced based automatic selection of the -number of workers. In the meanwhile, this is all let to the user, and we -advice that `n_batches` equals some positive integer multiplied by the -number of workers. Below is a basic example of a parallelization with -two workers. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:29 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: TRUE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c5251f86b.rds' +#> +#> ── iterative computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 5 of 32 coalitions, 5 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 10 of 32 coalitions, 4 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 12 of 32 coalitions, 2 new. +#> +#> ── Iteration 4 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 16 of 32 coalitions, 4 new. +#> +#> ── Iteration 5 ────────────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 18 of 32 coalitions, 2 new. +# Using the ex_init object to continue the computation with 5 more coalition samples +ex_further <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 25, + iterative_args = list(convergence_tol = 0.005), # Decrease the convergence threshold + prev_shapr_object = ex_init +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:34 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c45a5f9b2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 24 of 32 coalitions. -```r -library(future) -future::plan(multisession, workers = 2) +print(ex_further$saving_path) +#> [1] "/tmp/RtmpGq2OQE/shapr_obj_3026c45a5f9b2.rds" -explanation_par <- explain( +# Using the ex_init object to continue the computation for the remaining coalition samples +# but this time using the path to the saved intermediate estimation object +ex_even_further <- explain( model = model, x_explain = x_explain, x_train = x_train, - approach = "empirical", - prediction_zero = p0, - n_batches = 10 + approach = "gaussian", + phi0 = p0, + max_n_coalitions = NULL, + prev_shapr_object = ex_further$saving_path ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 32, +#> and is therefore set to 2^n_features = 32. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:35 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c7433ff88.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 32 of 32 coalitions. +``` -future::plan(sequential) # To return to non-parallel computation + + +
+ + +# Explaining a forecasting model using `explain_forecast` + +`shapr` provides a specific function, `explain_forecast`, to explain +forecasts from time series models, at one or more steps into the future. +The main difference compared to `explain` is that the data is supplied +as a (set of) time series, in addition to index arguments (`train_idx` and +`explain_idx`) specifying which time points represent the train +and explain parts of the data. See `?explain_forecast` for more +information. + +To demonstrate how to use the function, 500 observations are generated +which follow an AR(1) structure, i.e. +$y_t = 0.5 y_{t-1} + \varepsilon_t$. To this data an arima model of +order (2, 0, 0) is fitted, and we therefore would like to explain the +forecasts in terms of the two previous lags of the time series. This is +specified through the argument `explain_y_lags = 2`. Note that some +models may also put restrictions on the amount of data required to make +a forecast. The AR(2) model we used there, for instance, requires two +previous time points to make a forecast. + +In the example, two separate forecasts, each three steps ahead, are +explained. To set the starting points of the two forecasts, +`explain_idx` is set to `499:500`. This means that one forecast of +$t = (500, 501, 502)$ and another of $t = (501, 502, 503)$, will be +explained. In other words, `explain_idx` tells `shapr` at which points +in time data was available up until, when making the forecast to +explain. + +In the same way, `train_idx` denotes the points in time used to estimate +the conditional expectations used to explain the different forecasts. +Note that since we want to explain the forecasts in terms of the two +previous lags (`explain_y_lags = 2`), the smallest value of `train_idx` +must also be 2, because at time $t = 1$ there was only a single +observation available. + +Since the data is stationary, the mean of the data is used as value of +`phi0` (i.e. $\phi_0$).
This can however be chosen +differently depending on the data and application. + +For a multivariate model such as a VAR (Vector AutoRegressive model), it +may be of more interesting to explain the impact of each variable, +rather than each lag of each variable. This can be done by setting +`group_lags = TRUE`. + + +``` r +# Simulate time series data with AR(1)-structure. +set.seed(1) +data_ts <- data.frame(Y = arima.sim(list(order = c(1, 0, 0), ar = .5), n = 500)) +data_ts <- data.table::as.data.table(data_ts) + +# Fit an ARIMA(2, 0, 0) model. +arima_model <- arima(data_ts, order = c(2, 0, 0)) + +# Set prediction zero as the mean of the data for each forecast point. +p0_ar <- rep(mean(data_ts$Y), 3) + +# Explain forecasts from points t = 499 and t = 500. +explain_idx <- 499:500 + +explanation_forecast <- explain_forecast( + model = arima_model, + y = data_ts, + train_idx = 2:498, + explain_idx = 499:500, + explain_y_lags = 2, + horizon = 3, + approach = "empirical", + phi0 = p0_ar, + group_lags = FALSE +) +#> Note: Feature names extracted from the model contains NA. +#> Consistency checks between model and data is therefore disabled. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 4, +#> and is therefore set to 2^n_features = 4. +#> Registered S3 method overwritten by 'quantmod': +#> method from +#> as.zoo.data.frame zoo +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:36 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 2 +#> • Number of observations to explain: 2 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c6949be4.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 4 of 4 coalitions. 
+explanation_forecast +#> explain_idx horizon none Y.1 Y.2 +#> +#> 1: 499 1 0.04018 0.5053 -0.07659 +#> 2: 500 1 0.04018 -0.3622 0.02497 +#> 3: 499 2 0.04018 0.5053 -0.07659 +#> 4: 500 2 0.04018 -0.3622 0.02497 +#> 5: 499 3 0.04018 0.5053 -0.07659 +#> 6: 500 3 0.04018 -0.3622 0.02497 ``` -## Progress updates +Note that for a multivariate model such as a VAR (Vector AutoRegressive +model), or for models also including several exogenous variables, it may +be more informative to explain the impact of each variable, rather +than each lag of each variable. This can be done by setting +`group_lags = TRUE`. This does not make sense for this model, however, +as that would result in decomposing the forecast into a single group. -`shapr` provides progress updates of the computation of the Shapley -values through the R-package `progressr`. This gives the user full -control over the visual appearance of the progress updates, and also -integrates seamlessly with the parallelization framework `future` used -by `shapr` (see above). Note that the progress is updated as the batches -are completed, meaning that if you have chosen `n_batches=1`, you will -not get intermediate updates, while if you set `n_batches=10` you will -get updates on every 10% of the computation. +We now give a more hands-on example of how to use the `explain_forecast` +function. Say that we have an AR(2) model which describes the change +over time of the variable `Temp` in the dataset `airquality`. It seems +reasonable to assume that the temperature today should affect the +temperature tomorrow. To a lesser extent, we may also suggest that the +temperature today should also have an impact on that of the day after +tomorrow. -Progress updates are enabled for the current R-session by running the -command `progressr::handlers(local=TRUE)`, before calling -`shapr::explain()`.
To use progress updates for only a single call to -`shapr::explain()`, one can wrap the call using -`progressr::with_progress` as follows: -`progressr::with_progress({ shapr::explain() })` The default appearance -of the progress updates is a basic ASCII-based horizontal progress bar. -Other variants can be chosen by passing different strings to -`progressr::handlers()`, some of which require additional packages. If -you are using Rstudio, the progress can be displayed directly in the gui -with `progressr::handlers('rstudio')` (requires the `rstudioapi` -package). If you are running Windows, you may use the pop-up gui -progress bar `progressr::handlers('handler_winprogressbar')`. A wrapper -for progressbar of the flexible `cli` package is also available -`progressr::handlers('cli')` (requires the `cli` package). +We start by building our AR(2) model, naming it `model_ar_temp`. This +model is then used to make a forecast of the temperature of the day that +comes after the last day in the data, this forecast starts from index +153. -For a full list of all progression handlers and the customization -options available with `progressr`, see the `progressr` -[vignette](https://cran.r-project.org/web/packages/progressr/vignettes/progressr-intro.html). 
-A full code example of using `progressr` with `shapr` is shown below: +``` r +data_ts2 <- data.table::as.data.table(airquality) -```r -library(progressr) -progressr::handlers(global = TRUE) -# If no progression handler is specified, the txtprogressbar is used -# Other progression handlers: -# progressr::handlers('rstudio') # requires the 'rstudioapi' package -# progressr::handlers('handler_winprogressbar') # Window only -# progressr::handlers('cli') # requires the 'cli' package -explanation <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, +model_ar_temp <- ar(data_ts2$Temp, order = 2) + +predict(model_ar_temp, n.ahead = 2)$pred +#> Time Series: +#> Start = 154 +#> End = 155 +#> Frequency = 1 +#> [1] 71.081 71.524 +``` + +First, we pass the model and the data as `model` and `y`. Since we have +an AR(2) model, we want to explain the forecasts in terms of the two +previous lags, which we specify with `explain_y_lags = 2`. Then, we let +`shapr` know which time indices to use as training data through the +argument `train_idx`. We use `2:152`, meaning that we skip the first +index, as we want to explain the two previous lags. Letting the training +indices go up until 152 means that every point in time except the first +and last will be used as training data. + +The last index, 153, is passed as the argument `explain_idx`, which means +that we want to explain a forecast made from time point 153 in the data. +The argument `horizon` is set to 2 in order to explain a forecast of +length 2. + +The argument `phi0` is set to the mean of the time series, +and is repeated two times. Each value of `phi0` is the +baseline for each forecast horizon. In our example, we assume that given +no effect from the two lags, the temperature would just be the average +during the observed period. Finally, we opt to not group the lags by +setting `group_lags` to `FALSE`. This means that lag 1 and 2 will be +explained separately.
Grouping lags may be more interesting to do in a +model with multiple variables, as it is then possible to explain each +variable separately. + + +``` r +explanation_forecast <- explain_forecast( + model = model_ar_temp, + y = data_ts2[, "Temp"], + train_idx = 2:152, + explain_idx = 153, + explain_y_lags = 2, + horizon = 2, approach = "empirical", - prediction_zero = p0, - n_batches = 10 + phi0 = rep(mean(data$Temp), 2), + group_lags = FALSE ) +#> Note: Feature names extracted from the model contains NA. +#> Consistency checks between model and data is therefore disabled. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 4, +#> and is therefore set to 2^n_features = 4. +#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:38 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 2 +#> • Number of observations to explain: 1 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c3dcf1900.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 4 of 4 coalitions. -handlers("progress") -#| [=================================>----------------------] 60% Estimating v(S) +print(explanation_forecast) +#> explain_idx horizon none Temp.1 Temp.2 +#> +#> 1: 153 1 77.79 -6.578 -0.134 +#> 2: 153 2 77.79 -5.980 -0.288 ``` +The results are presented per value of `explain_idx` and forecast +horizon. We can see that the mean temperature was around 77.9 degrees. +At horizon 1, the first lag in the model caused it to be 6.6 degrees +lower, and the second lag had just a minor effect. At horizon 2, the +first lag has a slightly smaller negative impact, and the second lag has +a slightly larger impact. +It is also possible to explain a forecasting model which uses exogenous +regressors. The previous example is expanded to use an ARIMA(2,0,0) +model with `Wind` as an exogenous regressor. 
Since the exogenous +regressor must be available for the predicted time points, the model is +just fit on the 151 first observations, leaving two observations of +`Wind` to be used as exogenous values during the prediction phase. - +``` r +data_ts3 <- data.table::as.data.table(airquality) + +data_fit <- data_ts3[seq_len(151), ] + +model_arimax_temp <- arima(data_fit$Temp, order = c(2, 0, 0), xreg = data_fit$Wind) + +newxreg <- data_ts3[-seq_len(151), "Wind", drop = FALSE] + +predict(model_arimax_temp, n.ahead = 2, newxreg = newxreg)$pred +#> Time Series: +#> Start = 152 +#> End = 153 +#> Frequency = 1 +#> [1] 77.500 76.381 +``` + +The `shapr` package can then explain not only the two autoregressive +lags, but also the single lag of the exogenous regressor. In order to do +so, the `Wind` variable is passed as the argument `xreg`, and +`explain_xreg_lags` is set to 1. Notice how only the first 151 +observations are used for `y` and all 153 are used for `xreg`. This +makes it possible for `shapr` to not only explain the effect of the +first lag of the exogenous variable, but also the contemporary effect +during the forecasting period. + + +``` r +explanation_forecast <- explain_forecast( + model = model_ar_temp, + y = data_fit[, "Temp"], + xreg = data_ts3[, "Wind"], + train_idx = 2:150, + explain_idx = 151, + explain_y_lags = 2, + explain_xreg_lags = 1, + horizon = 2, + approach = "empirical", + phi0 = rep(mean(data_fit$Temp), 2), + group_lags = FALSE +) +#> Note: Feature names extracted from the model contains NA. +#> Consistency checks between model and data is therefore disabled. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 32, +#> and is therefore set to 2^n_features = 32. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-22 23:46:39 ─────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • Iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 5 +#> • Number of observations to explain: 1 +#> • Computations (temporary) saved at: '/tmp/RtmpGq2OQE/shapr_obj_3026c621a80f2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 32 of 32 coalitions. + +print(explanation_forecast$shapley_values_est) +#> explain_idx horizon none Temp.1 Temp.2 Wind.1 Wind.F1 Wind.F2 +#> +#> 1: 151 1 77.96 -0.67793 -0.67340 -1.2688 0.493408 NA +#> 2: 151 2 77.96 0.39968 -0.50059 -1.4655 0.065913 -0.47422 +``` + -
-# Comparison to Lundberg & Lee's implementation - -As mentioned above, the original (independence assuming) Kernel SHAP -implementation can be approximated by setting a large $\sigma$ value -using our empirical approach. If we specify that the distances to *all* -training observations should be used (i.e. setting -`approach = "empirical"` and `empirical.eta = 1` when using `explain`, -we can approximate the original method arbitrarily well by increasing -$\sigma$. For completeness of the `shapr` package, we have also -implemented a version of the original method, which samples training -observations independently with respect to their distances to test -observations (i.e. without the large-$\sigma$ approximation). This -method is available by using `approach = "independence"` in `explain`. - -We have compared the results using these two variants with the original -implementation of @lundberg2017unified, available through the Python -library [`shap`](https://github.com/slundberg/shap). As above, we used -the Boston housing data, trained via `xgboost`. We specify that *all* -training observations should be used when explaining all of the 6 test -observations. To run the individual explanation method in the `shap` -Python library we use the `reticulate` `R`-package, allowing Python code -to run within `R`. As this requires installation of Python package, the -comparison code and results is not included in this vignette, but can be -found -[here](https://github.com/NorskRegnesentral/shapr/blob/master/inst/scripts/compare_shap_python.R). -As indicated by the (commented out) results in the file above both -methods in our `R`-package give (up to numerical approximation error) -identical results to the original implementation in the Python `shap` -library.
diff --git a/vignettes/understanding_shapr.Rmd.orig b/vignettes/understanding_shapr.Rmd.orig index 32699e239887a5d890695d413a5f9767cc4a21f2..9e159a00540ea65e58f9aef10c29b4967dc4cba4 100644 --- a/vignettes/understanding_shapr.Rmd.orig +++ b/vignettes/understanding_shapr.Rmd.orig @@ -35,15 +35,19 @@ library(shapr) > [Overview of Package](#overview) -> [The Kernel SHAP Method](#KSHAP) +> [KernelSHAP and dependence-aware estimators](#KSHAP) -> [Examples](#ex) +> [Estimation approaches and plotting functionality](#ex) -> [Advanced usage](#advanced) +> [iterative estimation](#iterative) + +> [Parallelization](#para) -> [Scalability and efficency](#scalability) +> [Verbosity and progress updates](#verbose) + +> [Advanced usage](#advanced) -> [Comparison to Lundberg & Lee's implementation](#compare) +> [Explaining forecasting models](#forecasting) @@ -58,7 +62,7 @@ on interpreting individual predictions, Shapley values is regarded to be the only model-agnostic explanation method with a solid theoretical foundation (@lundberg2017unified). Kernel SHAP is a computationally efficient approximation to Shapley values in higher dimensions, but it -assumes independent features. @aas2019explaining extend the Kernel SHAP +assumes independent features. @aas2019explaining extends the Kernel SHAP method to handle dependent features, resulting in more accurate approximations to the true Shapley values. See the [paper](https://www.sciencedirect.com/sdfe/reader/pii/S0004370221000539/pdf) @@ -70,7 +74,7 @@ approximations to the true Shapley values. See the # Overview of Package -## Functions +## Functionality Here is an overview of the main functions. You can read their documentation and see examples with `?function_name`. @@ -83,11 +87,62 @@ documentation and see examples with `?function_name`. : Main functions in the `shapr` package. 
+The `shapr` package implements kernelSHAP estimation of dependence-aware Shapley values with +eight different Monte Carlo-based approaches for estimating the conditional distributions of the data, namely +`"empirical"`, `"gaussian"`, `"copula"`, `"ctree"`, `"vaeac"`, `"categorical"`, `"timeseries"`, and `"independence"`. +`shapr` has also implemented two regression-based approaches `"regression_separate"` and `"regression_surrogate"`. +See [Estimation approaches and plotting functionality](#ex) below for examples. +It is also possible to combine the different approaches, see the [combined approach](#combined). + +The package allows for parallelized computation through the `future` package, see [Parallelization](#para) for details. + +The level of detail in the output can be controlled through the `verbose` argument. In addition, progress updates +on the process of estimating the `v(S)`'s (and training the `"vaeac"` model) are available through the +`progressr` package, supporting progress updates also for parallelized computation. +See [Verbosity and progress updates](#verbose) for details. + +Moreover, the default behavior is to estimate the Shapley values iteratively, with an increasing number of +feature coalitions being added, and to stop estimation once the estimated Shapley values have achieved a certain level of +stability. +More information about this is provided in [iterative estimation](#iterative). +The above, combined with batch computation of the `v(S)` values, enables fast and accurate estimation of the +Shapley values in a memory-friendly manner. + +The package also provides functionality for computing Shapley values for groups of features, and custom function explanation, see [Advanced usage](#advanced). +Finally, explanation of multiple output time series forecasting models is discussed in +[Explaining forecasting models](#forecasting).
+ + +## Default behavior of `explain` + +Below we provide brief descriptions of the most important parts of the default behavior of the `explain` function. + +By default `explain` always computes feature-wise Shapley values. +Groups of features can be explained by providing the feature groups through the `group` argument. + +When there are five or fewer features (or feature groups), iterative estimation is by default disabled. +The reason for this is that it is usually faster to estimate the Shapley values for all possible coalitions (`v(S)`), +than to estimate the uncertainty of the Shapley values, and potentially stop estimation earlier. +While iterative estimation is the default starting from six features, it is mainly when there are more than ten features, +that it is most beneficial, and can save a lot of computation time. +The reason for this is that the number of possible coalitions grows exponentially. +These defaults can be overridden by setting the `iterative` argument to `TRUE` or `FALSE`. +When using the `iterative` argument, the estimation for an observation is stopped when all Shapley value +standard deviations are below `t` times the range of the Shapley values. +The `t` value controls the convergence tolerance, defaults to 0.02, and can be set through the `iterative_args$convergence_tol` argument, see [iterative estimation](#iterative) for more details. + +Since the iterativeness default changes based on the number of features (or feature groups), the default is also to have +no upper bound on the number of coalitions considered. +This can be controlled through the `max_n_coalitions` argument. + +
-# The Kernel SHAP Method +# KernelSHAP and dependence-aware estimators + +## The Kernel SHAP Method Assume a predictive model $f(\boldsymbol{x})$ for a response value $y$ with features $\boldsymbol{x}\in \mathbb{R}^M$, trained on a training @@ -252,9 +307,7 @@ AIC known as AICc. As calculation of it is computationally intensive, an approximate version of the selection criterion is also suggested. Details on this is found in @aas2019explaining. - - -
+ ## Conditional Inference Tree Approach @@ -334,6 +387,8 @@ the `explain()` function. For example, we can the change the batch size to 32 by `vaeac.extra_parameters = list(vaeac.batch_size = 32)` as a parameter in the call the `explain()` function. See `?shapr::vaeac_get_extra_para_default` for a description of the possible extra parameters to the `vaeac` approach. We strongly encourage the user to specify the main and extra parameters to the `vaeac` approach at the correct place in the call to the `explain()` function. That is, the main parameters are directly entered to the `explain()` function, while the extra parameters are included in a named list called `vaeac.extra_parameters`. However, the `vaeac` approach will try to correct for misplaced and duplicated parameters and give warnings to the user. + + ## Categorical Approach When the features are all categorical, we can estimate the conditional @@ -380,17 +435,11 @@ paradigm into the separate and surrogate regression method classes. In the separate vignette, we briefly introduce the two method classes. For an in-depth explanation, we refer the reader to Sections 3.5 and 3.6 in @olsen2024comparative. -
-# Examples {#examples} - -`shapr` supports computation of Shapley values with any predictive model -which takes a set of numeric features and produces a numeric outcome. -Note that the ctree method takes both numeric and categorical variables. -Check under "Advanced usage" for an example of how this can be done. +# Estimation approaches and plotting functionality {#ex} The following example shows how a simple `xgboost` model is trained using the `airquality` dataset, and how `shapr` can be used to explain @@ -439,12 +488,13 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) # Printing the Shapley values for the test data. # For more information about the interpretation of the values in the table, see ?shapr::explain. -print(explanation$shapley_values) +print(explanation$shapley_values_est) # Plot the resulting explanations for observations 1 and 6 plot(explanation, bar_plot_phi0 = FALSE, index_x_explain = c(1, 6)) @@ -477,7 +527,8 @@ explanation_plot <- explain( x_explain = x_explain_many, x_train = x_train, approach = "empirical", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) plot(explanation_plot, plot_type = "beeswarm") ``` @@ -517,7 +568,8 @@ explanation_lm_cat <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) # Plot the resulting explanations for observations 1 and 6, excluding @@ -538,10 +590,11 @@ explanation_ctree <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = p0, + phi0 = p0, ctree.mincriterion = 0.80, ctree.minsplit = 20, - ctree.minbucket = 20 + ctree.minbucket = 20, + iterative = FALSE ) # Default parameters (based on (Hothorn, 2006)) are: # mincriterion = 0.95 @@ -578,7 +631,8 @@ explanation_cat_method <- explain( x_explain = x_explain_all_cat, x_train = x_train_all_cat, approach = "categorical", - 
prediction_zero = p0 + phi0 = p0, + iterative = FALSE ) ``` @@ -638,8 +692,9 @@ explanation_timeseries <- explain( x_explain = x_explain_ts, x_train = x_train_ts, approach = "timeseries", - prediction_zero = p0_ts, - group = group_ts + phi0 = p0_ts, + group = group_ts, + iterative = FALSE ) ``` @@ -747,9 +802,8 @@ explanation_independence <- explain( x_explain = x_explain, x_train = x_train, approach = "independence", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) @@ -759,9 +813,8 @@ explanation_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) @@ -771,9 +824,8 @@ explanation_gaussian_1e1 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, - n_samples = 1e1, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e1, MSEv_uniform_comb_weights = TRUE ) @@ -783,9 +835,8 @@ explanation_gaussian_1e2 <- explain( x_explain = x_explain, x_train = x_train, approach = "gaussian", - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) @@ -795,9 +846,8 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("gaussian", "empirical", "independence"), - prediction_zero = p0, - n_samples = 1e2, - n_batches = 5, + phi0 = p0, + n_MC_samples = 1e2, MSEv_uniform_comb_weights = TRUE ) @@ -854,7 +904,7 @@ plot_MSEv_eval_crit(explanation_list_named, )$MSEv_explicand_bar plot_MSEv_eval_crit(explanation_list_named, plot_type = "comb", - id_combination = c(3, 4, 9, 13:15), + id_coalition = c(3, 4, 9, 13:15), CI_level = 0.95 )$MSEv_combination_bar ``` @@ -886,243 +936,201 @@ plot_MSEv_eval_crit(explanation_list_named) + ) ``` + -## Main arguments in `explain` - -When using `explain`, the default 
behavior is to use all feature -combinations in the Shapley formula. Kernel SHAP's sampling based -approach may be used by specifying `n_combinations`, which is the number -of unique feature combinations to sample. If not specified, the exact -method is used. The computation time grows approximately exponentially -with the number of features. The training data and the model whose -predictions we wish to explain must be provided through the arguments -`x_train` and `model`. The data whose predicted values we wish to -explain must be given by the argument `x_explain`. Note that both -`x_train` and `x_explain` must be a `data.frame` or a `matrix`, and all -elements must be finite numerical values. Currently we do not support -missing values. The default approach when computing the Shapley values -is the empirical approach (i.e. `approach = "empirical"`). If you'd like -to use a different approach you'll need to set `approach` equal to -either `copula` or `gaussian`, or a vector of them, with length equal to -the number of features. If a vector, a combined approach is used, and -element `i` indicates the approach to use when conditioning on `i` -variables. For more details see [Combined approach](#combined) below. - -When computing the kernel SHAP values by `explain`, the maximum number -of samples to use in the Monte Carlo integration for every conditional -expectation is controlled by the argument `n_samples` (default equals -`1000`). The computation time grows approximately linear with this -number. You will also need to pass a numeric value for the argument -`prediction_zero`, which represents the prediction value when not -conditioning on any features. We recommend setting this equal to the -mean of the response, but other values, like the mean prediction of a -large test data set is also a possibility. 
If the empirical method is -used, specific settings for that approach, like a vector of fixed -$\sigma$ values can be specified through the argument -`empirical.fixed_sigma`. See `?explain` for more information. If -`approach = "gaussian"`, you may specify the mean vector and covariance -matrix of the data generating distribution by the arguments -`gaussian.mu` and `gaussian.cov_mat`. If not specified, they are -estimated from the training data. - -## Explaining a forecasting model using `explain_forecast` +# Iterative estimation -`shapr` provides a specific function, `explain_forecast`, to explain -forecasts from time series models, at one or more steps into the future. -The main difference compared to `explain` is that the data is supplied -as (set of) time series, in addition to index arguments (`train_idx` and -`explain_idx`) specifying which time points that represents the train -and explain parts of the data. See `?explain_forecast` for more -information. +Iterative estimation is the default when computing Shapley values with six or more features (or feature groups), and +can always be manually overridden by setting `iterative = FALSE` in the `explain()` function. +The idea behind iterative estimation is to obtain sufficiently accurate Shapley value estimates faster. +First, an initial number of coalitions is sampled; then, bootstrapping is used to estimate the variance of the Shapley +values. +A convergence criterion is used to determine if the variances of the Shapley values are sufficiently small. +If the variances are too high, we estimate the number of required samples to reach convergence, and thereby add more +coalitions. +The process is repeated until the variances are below the threshold. +Specifics related to the iterative process and convergence criterion are set through the `iterative_args` argument. -To demonstrate how to use the function, 500 observations are generated -which follow an AR(1) structure, i.e. -$y_t = 0.5 y_{t-1} + \varepsilon_t$.
To this data an arima model of -order (2, 0, 0) is fitted, and we therefore would like to explain the -forecasts in terms of the two previous lags of the time series. This is -is specified through the argument `explain_y_lags = 2`. Note that some -models may also put restrictions on the amount of data required to make -a forecast. The AR(2) model we used there, for instance, requires two -previous time point to make a forecast. +The convergence criterion we use is adapted from @covert2021improving, and slightly modified to work for multiple +observations: -In the example, two separate forecasts, each three steps ahead, are -explained. To set the starting points of the two forecasts, -`explain_idx` is set to `499:500`. This means that one forecast of -$t = (500, 501, 502)$ and another of $t = (501, 502, 503)$, will be -explained. In other words, `explain_idx` tells `shapr` at which points -in time data was available up until, when making the forecast to -explain. +$$ \operatorname{median}_i\left(\frac{\max_j \hat{\text{sd}}(\hat{\phi}_{ij})}{\max_j \hat{\phi}_{ij} - \min_j \hat{\phi}_{ij}}\right) < t $$ -In the same way, `train_idx` denotes the points in time used to estimate -the conditional expectations used to explain the different forecasts. -Note that since we want to explain the forecasts in terms of the two -previous lags (`explain_y_lags = 2`), the smallest value of `train_idx` -must also be 2, because at time $t = 1$ there was only a single -observation available. +where $\hat{\phi}_{ij}$ is the estimated Shapley value of feature $j$ for observation $i$, and $\hat{\text{sd}}(\hat{\phi}_{ij})$ +is its (bootstrap) estimated standard deviation. The default value of $t$ is 0.02. +Below we provide some examples of how to use the iterative estimation procedure. -Since the data is stationary, the mean of the data is used as value of -`prediction_zero` (i.e. $\phi_0$). This can however be chosen -differently depending on the data and application.
-For a multivariate model such as a VAR (Vector AutoRegressive model), it -may be of more interesting to explain the impact of each variable, -rather than each lag of each variable. This can be done by setting -`group_lags = TRUE`. ```{r} -# Simulate time series data with AR(1)-structure. -set.seed(1) -data_ts <- data.frame(Y = arima.sim(list(order = c(1, 0, 0), ar = .5), n = 500)) -data_ts <- data.table::as.data.table(data_ts) +library(xgboost) +library(data.table) -# Fit an ARIMA(2, 0, 0) model. -arima_model <- arima(data_ts, order = c(2, 0, 0)) +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] -# Set prediction zero as the mean of the data for each forecast point. -p0_ar <- rep(mean(data_ts$Y), 3) +x_var <- c("Solar.R", "Wind", "Temp", "Month","Day") +y_var <- "Ozone" -# Explain forecasts from points t = 499 and t = 500. -explain_idx <- 499:500 +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] -explanation_forecast <- explain_forecast( - model = arima_model, - y = data_ts, - train_idx = 2:498, - explain_idx = 499:500, - explain_y_lags = 2, - horizon = 3, - approach = "empirical", - prediction_zero = p0_ar, - group_lags = FALSE +# Set seed for reproducibility +set.seed(123) + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE ) -explanation_forecast -``` -Note that for a multivariate model such as a VAR (Vector AutoRegressive -model), or for models also including several exogenous variables, it may -be of more informative to explain the impact of each variable, rather -than each lag of each variable. This can be done by setting -`group_lags = TRUE`. This does not make sense for this model, however, -as that would result in decomposing the forecast into a single group. +# Specifying the phi_0, i.e. 
the expected prediction without any features +p0 <- mean(y_train) -We now give a more hands on example of how to use the `explain_forecast` -function. Say that we have an AR(2) model which describes the change -over time of the variable `Temp` in the dataset `airquality`. It seems -reasonable to assume that the temperature today should affect the -temperature tomorrow. To a lesser extent, we may also suggest that the -temperature today should also have an impact on that of the day after -tomorrow. +# Initial explanation computation +ex <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + iterative = TRUE, + iterative_args = list(convergence_tol = 0.1) +) -We start by building our AR(2) model, naming it `model_ar_temp`. This -model is then used to make a forecast of the temperature of the day that -comes after the last day in the data, this forecast starts from index -153. +``` -```{r} -data_ts2 <- data.table::as.data.table(airquality) + -model_ar_temp <- ar(data_ts2$Temp, order = 2) +# Parallelization -predict(model_ar_temp, n.ahead = 2)$pred -``` +The `shapr` package supports parallelization of the Shapley value estimation process through the +`future` package. +The parallelization is conducted over batches of `v(S)`-values. +We therefore start by describing this batch computing. -First, we pass the model and the data as `model` and `y`. Since we have -an AR(2) model, we want to explain the forecasts in terms of the two -previous lags, whihc we specify with `explain_y_lags = 2`. Then, we let -`shapr` know which time indices to use as training data through the -argument `train_idx`. We use `2:152`, meaning that we skip the first -index, as we want to explain the two previous lags. Letting the training -indices go up until 152 means that every point in time except the first -and last will be used as training data. 
+## Batch computation -The last index, 153 is passed as the argument `explain_idx`, which means -that we want to explain a forecast made from time point 153 in the data. -The argument `horizon` is set to 2 in order to explain a forecast of -length 2. +The computational complexity of Shapley value based explanations grows +fast in the number of features, as the number of conditional +expectations one needs to estimate in the Shapley formula grows +exponentially. As outlined [above](#KSHAP), the estimation of each of +these conditional expectations is also computationally expensive, +typically requiring estimation of a conditional probability +distribution, followed by Monte Carlo integration. These computations +are not only heavy for the CPU, they also require a lot of memory (RAM), +which typically is a limited resource. By doing the most resource hungry +computations (the computation of v(S)) in sequential batches with +different feature subsets $S$, the memory usage can be significantly +reduced. +The user can control the number of batches by setting the two arguments +`extra_computation_args$max_batch_size` (defaults to 10) and +`extra_computation_args$min_n_batches` (defaults to 10). -The argument `prediction_zero` is set to the mean of the time series, -and is repeated two times. Each value of `prediction_zero` is the -baseline for each forecast horizon. In our example, we assume that given -no effect from the two lags, the temperature would just be the average -during the observed period. Finally, we opt to not group the lags by -setting `group_lags` to `FALSE`. This means that lag 1 and 2 will be -explained separately. Grouping lags may be more interesting to do in a -model with multiple variables, as it is then possible to explain each -variable separately. +## Parallelized computation + +In addition to reducing the memory consumption, the batch computing allows the +computations within each batch to be performed in parallel.
+The parallelization in `shapr::explain()` is handled by the +`future.apply` package, which builds on the `future` framework. The `future` +package works on all OS, allows the user to decide the parallelization +backend (multiple R processes or forking), works directly with HPC +clusters, and also supports progress updates for the parallelized task +(see [Verbosity and progress updates](#verbose)). + +Note that, since it takes some time to duplicate data into different +processes/machines when running in parallel, it is not always +preferable to run `shapr::explain()` in parallel, at least not with +many parallel sessions (hereafter called **workers**). Parallelization also +increases the memory consumption proportionally, so you want to limit +the number of workers for that reason too. +Below is a basic example of a parallelization with two workers. ```{r} -explanation_forecast <- explain_forecast( - model = model_ar_temp, - y = data_ts2[, "Temp"], - train_idx = 2:152, - explain_idx = 153, - explain_y_lags = 2, - horizon = 2, +library(future) +future::plan(multisession, workers = 2) + +explanation_par <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, approach = "empirical", - prediction_zero = rep(mean(data$Temp), 2), - group_lags = FALSE, - n_batches = 1, - timing = FALSE + phi0 = p0 ) -print(explanation_forecast) +future::plan(sequential) # To return to non-parallel computation ``` -The results are presented per value of `explain_idx` and forecast -horizon. We can see that the mean temperature was around 77.9 degrees. -At horizon 1, the first lag in the model caused it to be 6.6 degrees -lower, and the second lag had just a minor effect. At horizon 2, the -first lag has a slightly smaller negative impact, and the second lag has -a slightly larger impact. - -It is also possible to explain a forecasting model which uses exogenous -regressors. The previous example is expanded to use an ARIMA(2,0,0) -model with `Wind` as an exogenous regressor.
Since the exogenous -regressor must be available for the predicted time points, the model is -just fit on the 151 first observations, leaving two observations of -`Wind` to be used as exogenous values during the prediction phase. - -```{r} -data_ts3 <- data.table::as.data.table(airquality) + -data_fit <- data_ts3[seq_len(151), ] +# Verbosity and progress updates -model_arimax_temp <- arima(data_fit$Temp, order = c(2, 0, 0), xreg = data_fit$Wind) +The `verbose` argument controls the verbosity of the output while running `explain()`, +and allows one or more of the strings `"basic"`, `"progress"`, `"convergence"`, `"shapley"` and `"vS_details"`. +`"basic"` (default) displays basic information about the computation which is being performed, +`"progress"` displays information about where in the calculation process the function currently is, +`"convergence"` displays information on how close to convergence the Shapley value estimates are +(for iterative estimation), +`"shapley"` displays (intermediate) Shapley value estimates and standard deviations as well as the final estimates, +while `"vS_details"` displays information about the `v(S)` estimates for some of the approaches. +If the user wants no printout, the argument can be set to `NULL`. -newxreg <- data_ts3[-seq_len(151), "Wind", drop = FALSE] +In addition, progress updates on the computation of the `v(S)` values are available through the R-package `progressr`. +This gives the user full control over the visual appearance of these progress updates. +The main reason for providing this separate progress update feature is that it +integrates seamlessly with the parallelization framework `future` used by `shapr` (see [Parallelization](#para)), +and apparently is the only framework allowing progress updates also for parallelized tasks. +These progress updates can be used in combination with, or independently of, the `verbose` argument.
-predict(model_arimax_temp, n.ahead = 2, newxreg = newxreg)$pred -``` +These progress updates via `progressr` are enabled for the current R-session by running the +command `progressr::handlers(global = TRUE)`, before calling +`explain()`. To use progress updates for only a single call to +`explain()`, one can wrap the call using +`progressr::with_progress` as follows: +`progressr::with_progress({ shapr::explain() })`. The default appearance +of the progress updates is a basic ASCII-based horizontal progress bar. +Other variants can be chosen by passing different strings to +`progressr::handlers()`, some of which require additional packages. If +you are using RStudio, the progress can be displayed directly in the GUI +with `progressr::handlers('rstudio')` (requires the `rstudioapi` +package). If you are running Windows, you may use the pop-up GUI +progress bar `progressr::handlers('handler_winprogressbar')`. +A wrapper for the progress bar of the flexible `cli` package is also available via +`progressr::handlers('cli')`. -The `shapr` package can then explain not only the two autoregressive -lags, but also the single lag of the exogenous regressor. In order to do -so, the `Wind` variable is passed as the argument `xreg`, and -`explain_xreg_lags` is set to 1. Notice how only the first 151 -observations are used for `y` and all 153 are used for `xreg`. This -makes it possible for `shapr` to not only explain the effect of the -first lag of the exogenous variable, but also the contemporary effect -during the forecasting period. +For a full list of all progression handlers and the customization +options available with `progressr`, see the `progressr` +[vignette](https://cran.r-project.org/web/packages/progressr/vignettes/progressr-intro.html).
+A full code example of using `progressr` with `shapr` is shown below: -```{r} -explanation_forecast <- explain_forecast( - model = model_ar_temp, - y = data_fit[, "Temp"], - xreg = data_ts3[, "Wind"], - train_idx = 2:150, - explain_idx = 151, - explain_y_lags = 2, - explain_xreg_lags = 1, - horizon = 2, +```{r,eval = FALSE} +library(progressr) +progressr::handlers(global = TRUE) +# If no progression handler is specified, the txtprogressbar is used +# Other progression handlers: +# progressr::handlers('rstudio') # requires the 'rstudioapi' package +# progressr::handlers('handler_winprogressbar') # Window only +# progressr::handlers('cli') # requires the 'cli' package +ex_progress <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, approach = "empirical", - prediction_zero = rep(mean(data_fit$Temp), 2), - group_lags = FALSE, - n_batches = 1, - timing = FALSE + phi0 = p0 ) -print(explanation_forecast$shapley_values) +handlers("progress") +#| [=================================>----------------------] 60% Estimating v(S) ``` + + +
@@ -1161,13 +1169,43 @@ features, using `"empirical", "copula"` and `"gaussian"` when conditioning on respectively 1, 2 and 3 features. ```{r} +library(xgboost) +library(data.table) + +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] + +x_var <- c("Solar.R", "Wind", "Temp", "Month") +y_var <- "Ozone" + +ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] + +# Set seed for reproducibility +set.seed(123) + +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) + +# Specifying the phi_0, i.e. the expected prediction without any features +p0 <- mean(y_train) + + # Use the combined approach explanation_combined <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = c("empirical", "copula", "gaussian"), - prediction_zero = p0 + phi0 = p0 ) # Plot the resulting explanations for observations 1 and 6, excluding # the no-covariate effect @@ -1184,7 +1222,7 @@ explanation_combined <- explain( x_explain = x_explain, x_train = x_train, approach = c("ctree", "ctree", "empirical"), - prediction_zero = p0 + phi0 = p0 ) ``` @@ -1209,8 +1247,9 @@ explanation_group <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - group = group_list + phi0 = p0, + group = group_list, + iterative = FALSE ) # Prints the group-wise explanations explanation_group @@ -1315,7 +1354,7 @@ explanation_custom <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, predict_model = MY_predict_model, get_model_specs = MY_get_model_specs ) @@ -1339,7 +1378,7 @@ explanation_custom_minimal <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, + phi0 = p0, predict_model = 
MY_MINIMAL_predict_model ) @@ -1347,7 +1386,7 @@ explanation_custom_minimal <- explain( plot(explanation_custom_minimal, index_x_explain = c(1, 6)) ``` -### Tidymodels and workflows {#workflow_example} +## Tidymodels and workflows {#workflow_example} In this section, we demonstrate how to use `shapr` to explain `tidymodels` models fitted using `workflows`. In the example [above](#examples), we directly used the `xgboost` package to fit the `xgboost` model. However, we can also fit the `xgboost` model using the `tidymodels` package. These fits will be identical @@ -1388,15 +1427,14 @@ explanation_tidymodels <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_batches = 4 -) + phi0 = p0, + iterative = FALSE + ) # See that the Shapley value explanations are identical too -all.equal(explanation$shapley_values, explanation_tidymodels$shapley_values) +all.equal(explanation$shapley_values_est, explanation_tidymodels$shapley_values_est) ``` - ## The parameters of the `vaeac` approach The `vaeac` approach is a very flexible method that supports mixed data. 
The main @@ -1420,12 +1458,13 @@ explanation_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = p0, - n_samples = 100, + phi0 = p0, + n_MC_samples = 100, vaeac.width = 16, vaeac.depth = 2, vaeac.epochs = 3, - vaeac.n_vaeacs_initialize = 2 + vaeac.n_vaeacs_initialize = 2, + iterative = FALSE ) ``` @@ -1455,13 +1494,14 @@ explanation_vaeac_early_stop <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = p0, - n_samples = 100, + phi0 = p0, + n_MC_samples = 100, vaeac.width = 16, vaeac.depth = 2, vaeac.epochs = 1000, # Set it to a large number vaeac.n_vaeacs_initialize = 2, - vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2) + vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2), + iterative = FALSE ) ``` @@ -1480,178 +1520,287 @@ Can also compare the $MSE_{v}$ evaluation scores. plot_MSEv_eval_crit(list("Vaeac 3 epochs" = explanation_vaeac, "Vaeac early stopping" = explanation_vaeac_early_stop)) ``` +## Continued computation {#cont_computation} +In this section, we demonstrate how to continue to improve estimation accuracy with additional coalition samples, +from a previous Shapley value computation based on `shapr::explain()` with the iterative estimation procedure. +This can be done either by passing an existing object of class `shapr`, or by passing a string with the path to +the intermediately saved results. +The latter is found at `SHAPR_OBJ$saving_path`, defaults to a temporary folder, +and is updated after each iteration. +This can be particularly handy for long-running computations. +```{r} +# First we run the computation with the iterative estimation procedure for a limited number of coalition samples +library(xgboost) +library(data.table) +data("airquality") +data <- data.table::as.data.table(airquality) +data <- data[complete.cases(data), ] - +x_var <- c("Solar.R", "Wind", "Temp", "Month","Day") +y_var <- "Ozone" -
+ind_x_explain <- 1:6 +x_train <- data[-ind_x_explain, ..x_var] +y_train <- data[-ind_x_explain, get(y_var)] +x_explain <- data[ind_x_explain, ..x_var] -# Scalability and efficency +# Set seed for reproducibility +set.seed(123) -## Batch computation +# Fitting a basic xgboost model to the training data +model <- xgboost::xgboost( + data = as.matrix(x_train), + label = y_train, + nround = 20, + verbose = FALSE +) -The computational complexity of Shapley value based explanations grows -fast in the number of features, as the number of conditional -expectations one needs to estimate in the Shapley formula grows -exponentially. As outlined [above](#KSHAP), the estimating of each of -these conditional expectations is also computationally expensive, -typically requiring estimation of a conditional probability -distribution, followed by Monte Carlo integration. These computations -are not only heavy for the CPU, they also require a lot of memory (RAM), -which typically is a limited resource. By doing the most resource hungry -computations (the computation of v(S)) in sequential batches with -different feature subsets $S$, the memory usage can be significantly -reduces. Such batching comes at the cost of an increase in computation -time, which depends on the number of feature subsets (`n_combinations`), -the number of features, the estimation `approach` and so on. When -calling `shapr::explain()`, we allow the user to set the number of -batches with the argument `n_batches`. The default of this argument is -`NULL`, which uses a (hopefully) reasonable trade-off between -computation speed and memory consumption which depends on -`n_combinations` and `approach`. The memory/computation time trade-off -is most apparent for models with more than say 6-7 features. Below we a -basic example where `n_batches=10`: +# Specifying the phi_0, i.e. 
the expected prediction without any features +p0 <- mean(y_train) -```{r} -explanation_batch <- explain( +# Initial explanation computation +ex_init <- explain( model = model, x_explain = x_explain, x_train = x_train, - approach = "empirical", - prediction_zero = p0, - n_batches = 10 + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 20, + iterative = TRUE +) + +# Using the ex_init object to continue the computation with 5 more coalition samples +ex_further <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + max_n_coalitions = 25, + iterative_args = list(convergence_tol = 0.005), # Decrease the convergence threshold + prev_shapr_object = ex_init +) + +print(ex_further$saving_path) + +# Using the ex_init object to continue the computation for the remaining coalition samples +# but this time using the path to the saved intermediate estimation object +ex_even_further <- explain( + model = model, + x_explain = x_explain, + x_train = x_train, + approach = "gaussian", + phi0 = p0, + max_n_coalitions = NULL, + prev_shapr_object = ex_further$saving_path ) + + ``` -## Parallelized computation + -In addition to reducing the memory consumption, the introduction of the -`n_batch` argument allows computation within each batch to be performed in parallel. -The parallelization in `shapr::explain()` is handled by the -`future_apply` which builds on the `future` environment. The `future` -package works on all OS, allows the user to decide the parallelization -backend (mutliple R procesess or forking), works directly with hpc -clusters, and also supports progress updates for the parallelized task -(see below). +
-Note that, since it takes some time to duplicate data into different -processes/machines when running in parallel, it is not always -preferrable to run `shapr::explain()` in parallel, at least not with -many parallel sessions (hereby called **workers**). Parallelization also -increases the memory consumption proportionally, so you want to limit -the number of workers for that reason too. In a future version of -`shapr` we will provide experienced based automatic selection of the -number of workers. In the meanwhile, this is all let to the user, and we -advice that `n_batches` equals some positive integer multiplied by the -number of workers. Below is a basic example of a parallelization with -two workers. + +# Explaining a forecasting model using `explain_forecast` + +`shapr` provides a specific function, `explain_forecast`, to explain +forecasts from time series models, at one or more steps into the future. +The main difference compared to `explain` is that the data is supplied +as (a set of) time series, in addition to index arguments (`train_idx` and +`explain_idx`) specifying which time points represent the train +and explain parts of the data. See `?explain_forecast` for more +information. + +To demonstrate how to use the function, 500 observations are generated +which follow an AR(1) structure, i.e. +$y_t = 0.5 y_{t-1} + \varepsilon_t$. To this data an arima model of +order (2, 0, 0) is fitted, and we therefore would like to explain the +forecasts in terms of the two previous lags of the time series. This is +specified through the argument `explain_y_lags = 2`. Note that some +models may also put restrictions on the amount of data required to make +a forecast. The AR(2) model we use here, for instance, requires two +previous time points to make a forecast. + +In the example, two separate forecasts, each three steps ahead, are +explained. To set the starting points of the two forecasts, +`explain_idx` is set to `499:500`.
This means that one forecast of +$t = (500, 501, 502)$ and another of $t = (501, 502, 503)$, will be +explained. In other words, `explain_idx` tells `shapr` at which points +in time data was available up until, when making the forecast to +explain. + +In the same way, `train_idx` denotes the points in time used to estimate +the conditional expectations used to explain the different forecasts. +Note that since we want to explain the forecasts in terms of the two +previous lags (`explain_y_lags = 2`), the smallest value of `train_idx` +must also be 2, because at time $t = 1$ there was only a single +observation available. + +Since the data is stationary, the mean of the data is used as value of +`phi0` (i.e. $\phi_0$). This can however be chosen +differently depending on the data and application. + +For a multivariate model such as a VAR (Vector AutoRegressive model), it +may be of more interesting to explain the impact of each variable, +rather than each lag of each variable. This can be done by setting +`group_lags = TRUE`. ```{r} -library(future) -future::plan(multisession, workers = 2) +# Simulate time series data with AR(1)-structure. +set.seed(1) +data_ts <- data.frame(Y = arima.sim(list(order = c(1, 0, 0), ar = .5), n = 500)) +data_ts <- data.table::as.data.table(data_ts) -explanation_par <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, +# Fit an ARIMA(2, 0, 0) model. +arima_model <- arima(data_ts, order = c(2, 0, 0)) + +# Set prediction zero as the mean of the data for each forecast point. +p0_ar <- rep(mean(data_ts$Y), 3) + +# Explain forecasts from points t = 499 and t = 500. 
+explain_idx <- 499:500 + +explanation_forecast <- explain_forecast( + model = arima_model, + y = data_ts, + train_idx = 2:498, + explain_idx = 499:500, + explain_y_lags = 2, + horizon = 3, approach = "empirical", - prediction_zero = p0, - n_batches = 10 + phi0 = p0_ar, + group_lags = FALSE ) - -future::plan(sequential) # To return to non-parallel computation +explanation_forecast ``` -## Progress updates +Note that for a multivariate model such as a VAR (Vector AutoRegressive +model), or for models also including several exogenous variables, it may +be of more informative to explain the impact of each variable, rather +than each lag of each variable. This can be done by setting +`group_lags = TRUE`. This does not make sense for this model, however, +as that would result in decomposing the forecast into a single group. -`shapr` provides progress updates of the computation of the Shapley -values through the R-package `progressr`. This gives the user full -control over the visual appearance of the progress updates, and also -integrates seamlessly with the parallelization framework `future` used -by `shapr` (see above). Note that the progress is updated as the batches -are completed, meaning that if you have chosen `n_batches=1`, you will -not get intermediate updates, while if you set `n_batches=10` you will -get updates on every 10% of the computation. +We now give a more hands on example of how to use the `explain_forecast` +function. Say that we have an AR(2) model which describes the change +over time of the variable `Temp` in the dataset `airquality`. It seems +reasonable to assume that the temperature today should affect the +temperature tomorrow. To a lesser extent, we may also suggest that the +temperature today should also have an impact on that of the day after +tomorrow. -Progress updates are enabled for the current R-session by running the -command `progressr::handlers(local=TRUE)`, before calling -`shapr::explain()`. 
To use progress updates for only a single call to -`shapr::explain()`, one can wrap the call using -`progressr::with_progress` as follows: -`progressr::with_progress({ shapr::explain() })` The default appearance -of the progress updates is a basic ASCII-based horizontal progress bar. -Other variants can be chosen by passing different strings to -`progressr::handlers()`, some of which require additional packages. If -you are using Rstudio, the progress can be displayed directly in the gui -with `progressr::handlers('rstudio')` (requires the `rstudioapi` -package). If you are running Windows, you may use the pop-up gui -progress bar `progressr::handlers('handler_winprogressbar')`. A wrapper -for progressbar of the flexible `cli` package is also available -`progressr::handlers('cli')` (requires the `cli` package). +We start by building our AR(2) model, naming it `model_ar_temp`. This +model is then used to make a forecast of the temperature of the day that +comes after the last day in the data, this forecast starts from index +153. -For a full list of all progression handlers and the customization -options available with `progressr`, see the `progressr` -[vignette](https://cran.r-project.org/web/packages/progressr/vignettes/progressr-intro.html). 
-A full code example of using `progressr` with `shapr` is shown below: +```{r} +data_ts2 <- data.table::as.data.table(airquality) -```{r,eval = FALSE} -library(progressr) -progressr::handlers(global = TRUE) -# If no progression handler is specified, the txtprogressbar is used -# Other progression handlers: -# progressr::handlers('rstudio') # requires the 'rstudioapi' package -# progressr::handlers('handler_winprogressbar') # Window only -# progressr::handlers('cli') # requires the 'cli' package -explanation <- explain( - model = model, - x_explain = x_explain, - x_train = x_train, +model_ar_temp <- ar(data_ts2$Temp, order = 2) + +predict(model_ar_temp, n.ahead = 2)$pred +``` + +First, we pass the model and the data as `model` and `y`. Since we have +an AR(2) model, we want to explain the forecasts in terms of the two +previous lags, which we specify with `explain_y_lags = 2`. Then, we let +`shapr` know which time indices to use as training data through the +argument `train_idx`. We use `2:152`, meaning that we skip the first +index, as we want to explain the two previous lags. Letting the training +indices go up until 152 means that every point in time except the first +and last will be used as training data. + +The last index, 153, is passed as the argument `explain_idx`, which means +that we want to explain a forecast made from time point 153 in the data. +The argument `horizon` is set to 2 in order to explain a forecast of +length 2. + +The argument `phi0` is set to the mean of the time series, +and is repeated two times. Each value of `phi0` is the +baseline for each forecast horizon. In our example, we assume that given +no effect from the two lags, the temperature would just be the average +during the observed period. Finally, we opt to not group the lags by +setting `group_lags` to `FALSE`. This means that lag 1 and 2 will be +explained separately.
Grouping lags may be more interesting to do in a +model with multiple variables, as it is then possible to explain each +variable separately. + +```{r} +explanation_forecast <- explain_forecast( + model = model_ar_temp, + y = data_ts2[, "Temp"], + train_idx = 2:152, + explain_idx = 153, + explain_y_lags = 2, + horizon = 2, approach = "empirical", - prediction_zero = p0, - n_batches = 10 + phi0 = rep(mean(data$Temp), 2), + group_lags = FALSE ) -handlers("progress") -#| [=================================>----------------------] 60% Estimating v(S) +print(explanation_forecast) ``` +The results are presented per value of `explain_idx` and forecast +horizon. We can see that the mean temperature was around 77.9 degrees. +At horizon 1, the first lag in the model caused it to be 6.6 degrees +lower, and the second lag had just a minor effect. At horizon 2, the +first lag has a slightly smaller negative impact, and the second lag has +a slightly larger impact. +It is also possible to explain a forecasting model which uses exogenous +regressors. The previous example is expanded to use an ARIMA(2,0,0) +model with `Wind` as an exogenous regressor. Since the exogenous +regressor must be available for the predicted time points, the model is +just fit on the 151 first observations, leaving two observations of +`Wind` to be used as exogenous values during the prediction phase. +```{r} +data_ts3 <- data.table::as.data.table(airquality) + +data_fit <- data_ts3[seq_len(151), ] + +model_arimax_temp <- arima(data_fit$Temp, order = c(2, 0, 0), xreg = data_fit$Wind) + +newxreg <- data_ts3[-seq_len(151), "Wind", drop = FALSE] + +predict(model_arimax_temp, n.ahead = 2, newxreg = newxreg)$pred +``` + +The `shapr` package can then explain not only the two autoregressive +lags, but also the single lag of the exogenous regressor. In order to do +so, the `Wind` variable is passed as the argument `xreg`, and +`explain_xreg_lags` is set to 1. 
Notice how only the first 151 +observations are used for `y` and all 153 are used for `xreg`. This +makes it possible for `shapr` to not only explain the effect of the +first lag of the exogenous variable, but also the contemporary effect +during the forecasting period. + +```{r} +explanation_forecast <- explain_forecast( + model = model_ar_temp, + y = data_fit[, "Temp"], + xreg = data_ts3[, "Wind"], + train_idx = 2:150, + explain_idx = 151, + explain_y_lags = 2, + explain_xreg_lags = 1, + horizon = 2, + approach = "empirical", + phi0 = rep(mean(data_fit$Temp), 2), + group_lags = FALSE +) + +print(explanation_forecast$shapley_values_est) +``` - -
-# Comparison to Lundberg & Lee's implementation - -As mentioned above, the original (independence assuming) Kernel SHAP -implementation can be approximated by setting a large $\sigma$ value -using our empirical approach. If we specify that the distances to *all* -training observations should be used (i.e. setting -`approach = "empirical"` and `empirical.eta = 1` when using `explain`, -we can approximate the original method arbitrarily well by increasing -$\sigma$. For completeness of the `shapr` package, we have also -implemented a version of the original method, which samples training -observations independently with respect to their distances to test -observations (i.e. without the large-$\sigma$ approximation). This -method is available by using `approach = "independence"` in `explain`. - -We have compared the results using these two variants with the original -implementation of @lundberg2017unified, available through the Python -library [`shap`](https://github.com/slundberg/shap). As above, we used -the Boston housing data, trained via `xgboost`. We specify that *all* -training observations should be used when explaining all of the 6 test -observations. To run the individual explanation method in the `shap` -Python library we use the `reticulate` `R`-package, allowing Python code -to run within `R`. As this requires installation of Python package, the -comparison code and results is not included in this vignette, but can be -found -[here](https://github.com/NorskRegnesentral/shapr/blob/master/inst/scripts/compare_shap_python.R). -As indicated by the (commented out) results in the file above both -methods in our `R`-package give (up to numerical approximation error) -identical results to the original implementation in the Python `shap` -library.
diff --git a/vignettes/understanding_shapr_asymmetric_causal.Rmd b/vignettes/understanding_shapr_asymmetric_causal.Rmd new file mode 100644 index 0000000000000000000000000000000000000000..cd1e3e55c354cbd0480e280140ebcc7f6d6805b7 --- /dev/null +++ b/vignettes/understanding_shapr_asymmetric_causal.Rmd @@ -0,0 +1,2054 @@ +--- +title: "Asymmetric and causal Shapley value explanations" +author: "Lars Henry Berge Olsen" +output: + rmarkdown::html_vignette: + toc: true + fig_caption: yes +bibliography: ../inst/REFERENCES.bib +vignette: > + %\VignetteEncoding{UTF-8} + %\VignetteIndexEntry{Asymmetric and causal Shapley value explanations} + %\VignetteEngine{knitr::rmarkdown} +editor_options: + markdown: + wrap: 72 + toc: true +--- + + + + +# Overview {#Vignette} + +This vignette elaborates and demonstrates the asymmetric and +causal Shapley value frameworks introduced by @frye2020asymmetric +and @heskes2020causal, respectively. We also consider the marginal +and conditional Shapley value frameworks, see @lundberg2017unified +and @aas2019explaining, respectively. We demonstrate the frameworks +on the [bike sharing](https://archive.ics.uci.edu/dataset/275/) +dataset from the UCI Machine Learning Repository. The setup is +based on the `CauSHAPley` package, which is the +[code supplement](https://proceedings.neurips.cc/paper/2020/hash/32e54441e6382a7fbacbbbaf3c450059-Abstract.html) +to the @heskes2020causal paper. The `CauSHAPley` package was based +on an old version of `shapr` and was restricted to the `gaussian` approach (see section 6 in @heskes2020causal for more details). + +We have extended the causal Shapley value framework to work for all +Monte Carlo-based approaches (`independence` (not recommended), `empirical`, `gaussian`, `copula`, `ctree`, `vaeac` and `categorical`), while the extension of the asymmetric +Shapley value framework works for both the Monte Carlo and regression-based approaches. 
+Our generalization is of utmost importance, as many real-world data sets +are far from the Gaussian distribution, and, compared to `CauSHAPley`, our implementation +can utilize all of `shapr`'s new features, such as batch computation, parallelization and +iterative computation for both feature-wise and group-wise Shapley values. + +The main difference between the marginal, conditional, and causal Shapley value +frameworks is that they sample/generate the Monte Carlo samples from the +marginal distribution, (conventional) observational conditional distribution, +and interventional conditional distribution, respectively. Asymmetric means +that we do not consider all possible coalitions, but rather only the coalitions +that respect a causal ordering. + + + +# Asymmetric conditional Shapley values {#AsymSV} + +Asymmetric (conditional) Shapley values were proposed by @frye2020asymmetric as +a way to incorporate causal knowledge in the real world by computing the Shapley +value explanations using only the feature combinations/coalitions consistent with +a (partial) causal ordering. See the figure below for a schematic overview of the causal ordering we are going to use in the examples in this vignette. In the figure, we see +that our causal ordering consists of three components: $\tau_1 = \{X_1\}$, $\tau_2 = \{X_2, X_3\}$, and $\tau_3 = \{X_4, X_5, X_6, X_7\}$. See the [code section](#code-example) for what the features represent. + +To elaborate, instead of considering the $2^M$ possible coalitions, +where $M$ is the number of features, asymmetric Shapley values only +consider the subset of coalitions that respect the causal ordering. +For our causal ordering, this means that the asymmetric Shapley value explanation +framework skips the coalitions where $X_2$ is included but *not* $X_1$, +as $X_1$ is the ancestor of $X_2$. This will skew the explanations towards +distal/root causes, see Section 3.2 in @frye2020asymmetric.
+ +We can use all approaches in `shapr`, both Monte Carlo-based and +regression-based methods, to compute the asymmetric Shapley values. +This is because the asymmetric Shapley value explanation framework does not change +how we compute the contribution functions $v(S)$, but rather which of +the coalitions $S$ are used to compute the Shapley value explanations. +This means that the number of coalitions is no longer $O(2^M)$, but rather +$O(2^{\tau_0})$, where $\tau_0 = \operatorname{max}_i |\tau_i|$ +is the number of features ($|\tau_i|$) in the largest component of the causal ordering. + +Furthermore, asymmetric Shapley values support groups of features, but +then the causal ordering must be given on the group level instead of on the +feature level. The asymmetric Shapley value framework also supports +sampling of coalitions where the sampling is done from the +set of coalitions that respect the causal ordering. + +Finally, we remark that asymmetric conditional Shapley values are +equivalent to asymmetric causal Shapley values (see below) when we only +use the coalitions respecting the causal ordering and assuming that all +dependencies within chain components are induced by mutual interactions. + + +
+Schematic overview of the causal ordering used in this vignette. +

Schematic overview of the causal ordering used in this vignette.

+
+ + +# Causal Shapley values {#CausSV} + +Causal Shapley values were proposed by @heskes2020causal as a way +to explain the total effect of features on the prediction by taking +into account their causal relationships and adapting the sampling +procedure in `shapr`. More precisely, they propose to employ Pearl’s +do-calculus to circumvent the independence assumption, made by +@lundberg2017unified, without sacrificing any of the desirable +properties of the Shapley value framework. The causal Shapley value +explanation framework can also separate the contribution of direct +and indirect effects, which makes them principally different from +marginal and conditional Shapley values. The framework also provides +a more direct and robust way to incorporate causal knowledge, compared +to the asymmetric Shapley value explanation framework. + +To compute causal Shapley values, we have to specify a (partial) causal +ordering and make an assumption about the confounding in each component. +Together, they form a causal chain graph which contains directed and undirected +edges. All features that are treated on an equal footing are linked +together with undirected edges and become part of the same chain component. +Edges between chain components are directed and represent causal relationships. +In the figure below, we have the same causal ordering as above, but we +have in addition made the assumption that we have confounding in the +second component, but no confounding in the first and third components. +This allows us to correctly distinguish between dependencies that are +due to confounding and those due to mutual interactions. That is, in the figure, +the dependencies in chain component $\tau_2$ are assumed to be the result +of a common confounder, and those in $\tau_3$ of mutual interactions, while +we have no mutual interactions in $\tau_1$ as it is a singleton.
+ +Computing the effect of an intervention depends on how we interpret the +generative process that led to the feature dependencies within each component. +If they are the result of marginalizing out a common confounder, +then intervention on a particular feature will break the dependency +with the other features, and we denote the set of these chain components +by $\mathcal{T}_{\text{confounding}}$. For the components with mutual +feature interactions, setting the value of a feature affects the +distribution of the variables within the same component. We denote +the set of these components by $\mathcal{T}_{\,\overline{\text{confounding}}}$. + +@heskes2020causal described how any expectation by intervention needed +to compute the causal Shapley values can be translated to an expectation +by observation, by using the interventional formula for causal chain graphs: +\begin{align} +\label{eq:do} +P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) += & +\prod_{\tau \in \mathcal{T}_{\,\text{confounding}}} +P(X_{\tau \cap \bar{\mathcal{S}}} \mid X_{\text{pa}(\tau) \cap \bar{\mathcal{S}}}, x_{\text{pa}(\tau) \cap \mathcal{S}}) \times \tag{1} \\ +& \quad +\prod_{\tau \in \mathcal{T}_{\,\overline{\text{confounding}}}} +P(X_{\tau \cap \bar{\mathcal{S}}} \mid X_{\text{pa}(\tau) \cap \bar{\mathcal{S}}}, x_{\text{pa}(\tau) \cap \mathcal{S}}, x_{\tau \cap \mathcal{S}}). +\end{align} +Here, any of the Monte Carlo-based approaches in `shapr` can be +used to compute the conditional distributions/observational expectations. The marginals +are estimated from the training data for all approaches except +`gaussian`, for which we use the marginals of the Gaussian +distribution instead. + +For specific causal chain graphs, the causal Shapley value framework +simplifies to symmetric conditional, asymmetric conditional, and marginal +Shapley values, see Corollaries 1 to 3 in the supplement of @heskes2020causal.
+ + + +``` +#> Error in knitr::include_graphics("figure_asymmetric_causal/causal_ordering.png"): Cannot find the file(s): "figure_asymmetric_causal/causal_ordering.png" +``` + + +# Marginal Shapley values {#MarginaSV} +Causal Shapley values are equivalent to marginal Shapley values when all $M$ +features are combined into a single component $\tau = \mathcal{M} = \{1,2,...,M\}$ and +all dependencies are induced by confounding. Then $\text{pa}(\tau) = \emptyset$, and +$P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ in Equation (\ref{eq:do}) +simplifies to $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) = P(X_{\bar{\mathcal{S}}})$, +as specified in @lundberg2017unified. + +The Monte Carlo samples for the marginals are generated by sampling from the +training data, except for the `gaussian` approach where we use the marginals +of the estimated multivariate Gaussian distribution. This means that for all +other approaches, this is the same as using the `independence` approach +in the conditional Shapley value explanation framework. + +# Symmetric conditional Shapley values {#ConditionalSV} +Causal Shapley values are equivalent to symmetric conditional Shapley values when all $M$ +features are combined in a single component $\tau = \mathcal{M} = \{1,2,...,M\}$ and +all dependencies are induced by mutual interaction. Then $\text{pa}(\tau) = \emptyset$, +and $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ in Equation +(\ref{eq:do}) simplifies to +$P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) = P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$, +as specified in @aas2019explaining. Symmetric means that we consider all coalitions. + + + + + + +# Code example +## Overview +We demonstrate the frameworks on the [bike sharing](https://archive.ics.uci.edu/dataset/275/) +dataset from the UCI Machine Learning Repository.
We let the features be the +number of days since January 2011 (`trend`), two cyclical variables representing +the season (`cosyear`, `sinyear`), temperature (`temp`), feeling temperature +(`atemp`), wind speed (`windspeed`), and humidity (`hum`). The first three features +are considered to be potential causes of the four weather-related features. +The bike rental is strongly seasonal and shows an upward trend, as illustrated in the figure below. +The bike data is split randomly into a training (80%) and test/explicand (20%) set. +We train an `XGBoost` model for 100 rounds with default hyperparameters to act as the model +we want to explain. + +In the table below, we highlight the Shapley value explanation frameworks introduced above +and how to access them by changing the arguments `asymmetric`, `causal_ordering`, and `confounding` in `shapr::explain()`. +Note that symmetric conditional Shapley values are the default version, i.e., by default +`asymmetric = FALSE`, `causal_ordering = NULL`, `confounding = NULL`. + +| Framework | Sampling | Approaches | `asymmetric` | `causal_ordering` | `confounding` | +|:-------------------|:-----------------------|:---------------------|:-------------|:------------|:--------------| +| Sym. Conditional | $P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$ | All | `FALSE` | `NULL` | `NULL` | +| Asym. Conditional | $P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$ | All | `TRUE` | `list(...)` | `NULL` | +| Sym. Causal | $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ | All MC-based | `FALSE` | `list(...)` | `c(...)` | +| Asym. Causal | $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ | All MC-based | `TRUE` | `list(...)` | `c(...)` | +| Sym. Marginal | $P(X_{\bar{\mathcal{S}}})$ | `indep.`, `gaussian` | `FALSE` | `NULL` | `TRUE` | + + + +## Code setup +First, we load the needed libraries, set up the training/explicand data, plot the data, and train an `xgboost` model. 
+ +``` r +# Libraries +library(ggplot2) +require(GGally) +library(ggpubr) +library(gridExtra) +library(xgboost) +library(data.table) +library(shapr) + +# Ensure that shapr's functions are prioritzed, otherwise we need to use the `shapr::` +# prefix when calling explain(). The `conflicted` package is imported by `tidymodels`. +conflicted::conflicts_prefer(shapr::explain, shapr::prepare_data) + +# Set up the data +# Can also download the data set from the source https://archive.ics.uci.edu/dataset/275/bike+sharing+dataset +# temp <- tempfile() +# download.file("https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip", temp) +# bike <- read.csv(unz(temp, "day.csv")) +# unlink(temp) +bike <- read.csv("../inst/extdata/day.csv") +# Difference in days, which takes DST into account +bike$trend <- as.numeric(difftime(bike$dteday, bike$dteday[1], units = "days")) +bike$cosyear <- cospi(bike$trend / 365 * 2) +bike$sinyear <- sinpi(bike$trend / 365 * 2) +# Unnormalize variables (see data set information in link above) +bike$temp <- bike$temp * (39 - (-8)) + (-8) +bike$atemp <- bike$atemp * (50 - (-16)) + (-16) +bike$windspeed <- 67 * bike$windspeed +bike$hum <- 100 * bike$hum + +# Plot the data +ggplot(bike, aes(x = trend, y = cnt, color = temp)) + + geom_point(size = 0.75) + + scale_color_gradient(low = "blue", high = "red") + + labs(colour = "temp") + + xlab("Days since 1 January 2011") + + ylab("Number of bikes rented") + + theme_minimal() + + theme(legend.position = "right", legend.title = element_text(size = 10)) +``` + +![](figure_asymmetric_causal/setup_1-1.png) + + +``` r +# Define the features and the response variable +x_var <- c("trend", "cosyear", "sinyear", "temp", "atemp", "windspeed", "hum") +y_var <- "cnt" + +# NOTE: Encountered RNG reproducibility issues across different systems, +# Load the training-test split. 
80% training and 20% test +train_index <- readRDS("../inst/extdata/train_index.rds") + +# Training data +x_train <- as.matrix(bike[train_index, x_var]) +y_train_nc <- as.matrix(bike[train_index, y_var]) # not centered +y_train <- y_train_nc - mean(y_train_nc) + +# Plot pairs plot +GGally::ggpairs(x_train) +``` + +![](figure_asymmetric_causal/setup_2-1.png) + + +``` r +# Test/explicand data +x_explain <- as.matrix(bike[-train_index, x_var]) +y_explain_nc <- as.matrix(bike[-train_index, y_var]) # not centered +y_explain <- y_explain_nc - mean(y_train_nc) + +# Get 6 explicands to plot the Shapley values of with a wide spread in their predicted outcome +n_index_x_explain <- 6 +index_x_explain <- order(y_explain)[seq(1, length(y_explain), length.out = n_index_x_explain)] +y_explain[index_x_explain] +#> [1] -3900.0324 -1872.0324 -377.0324 411.9676 1690.9676 3889.9676 + +# Fit an XGBoost model to the training data +model <- xgboost::xgboost( + data = x_train, + label = y_train, + nround = 100, + verbose = FALSE +) + +# Save the phi0 +phi0 <- mean(y_train) + +# Look at the root mean squared error +sqrt(mean((predict(model, x_explain) - y_explain)^2)) +#> [1] 798.7148 +ggplot( + data.table("response" = y_explain[, 1], "predicted_response" = predict(model, x_explain)), + aes(response, predicted_response) +) + + geom_point() +``` + +![](figure_asymmetric_causal/setup_3-1.png) + + +We are going to use the `causal_ordering` and `confounding` illustrated in the figures above. +For `causal_ordering`, we can either provide the feature indices or the feature names. +Thus, the following two versions of `causal_ordering` will produce equivalent results. +Furthermore, we assume that we have confounding for the second component (i.e., the season has +an effect on the weather) and no confounding for the third component (i.e., we do not know +how to model the intricate relations between the weather features). 
+ + +``` r +causal_ordering <- list(1, c(2, 3), c(4:7)) +causal_ordering <- list("trend", c("cosyear", "sinyear"), c("temp", "atemp", "windspeed", "hum")) +confounding <- c(FALSE, TRUE, FALSE) +``` + + +To make the rest of the vignette easier to follow, we create some helper +functions that plot and summarize the results of the explanation methods. +This code block is optional to understand and can be skipped. + + +``` r +# Return a data frame with one row per explanation holding its `MSEv` and `MSEv_sd` entries and the elapsed seconds between `timing$init_time` and `timing$end_time`, all rounded to two decimals +print_MSEv_scores_and_time <- function(explanation_list) { + res <- as.data.frame(t(sapply( + explanation_list, + function(explanation) { + round(c( + explanation$MSEv$MSEv$MSEv, + explanation$MSEv$MSEv$MSEv_sd, + difftime(explanation$timing$end_time, explanation$timing$init_time, units = "secs") + ), 2) + } + ))) + colnames(res) <- c("MSEv", "MSEv_sd", "Time (secs)") + return(res) +} + +# Collect `timing$total_time_secs` from every explanation and transpose the result +print_time <- function(explanation_list) { + t(sapply(explanation_list, function(explanation) explanation$timing$total_time_secs)) +} + +# Draw one beeswarm plot per explanation, stacked in a single column on a common y-scale range; `title` is the overall heading and `...` is forwarded to plot() +plot_beeswarms <- function(explanation_list, title = "", ...) { + # Make one beeswarm plot per explanation, titled with its list name (underscores replaced by spaces, then title-cased) + grobs <- lapply(seq(length(explanation_list)), function(explanation_idx) { + gg <- plot(explanation_list[[explanation_idx]], plot_type = "beeswarm", ...) 
+ + ggplot2::ggtitle(tools::toTitleCase(gsub("_", " ", names(explanation_list)[[explanation_idx]]))) + + # Reverse the order of the feature levels on the discrete axis and drop the "none" level + gg <- gg + + ggplot2::scale_x_discrete(limits = rev(levels(gg$data$variable)[levels(gg$data$variable) != "none"])) + }) + + # Find the smallest and largest y-scale value across all the built plots + ylim <- sapply(grobs, function(grob) ggplot2::ggplot_build(grob)$layout$panel_scales_y[[1]]$range$range) + ylim <- c(min(ylim), max(ylim)) + + # Flip the coordinates of every plot and apply the common limits (messages are suppressed) + grobs <- suppressMessages(lapply(grobs, function(grob) grob + ggplot2::coord_flip(ylim = ylim))) + + # Arrange the plots in a single column beneath the overall title + gridExtra::grid.arrange( + grobs = grobs, ncol = 1, + top = grid::textGrob(title, gp = grid::gpar(fontsize = 18, font = 8)) + ) +} +``` + + + +## Symmetric conditional Shapley values (default) +We start by demonstrating how to compute symmetric conditional Shapley values. +This is the default version in `shapr` and there is no need to specify the arguments below. +However, we have specified them for the sake of clarity. +We use the `gaussian`, `ctree`, and `regression_separate`(`xgboost` with default hyperparameters) +approaches, but any other approach can also be used. + + + +``` r +# List to store the results +explanation_sym_con <- list() + +explanation_sym_con[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + n_MC_samples = 1000, + asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:04:06 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e573f83ea.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. +#> +#> ── Iteration 4 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 38 of 128 coalitions, 2 new. + +explanation_sym_con[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "ctree", + phi0 = phi0, + n_MC_samples = 1000, + asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:04:14 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e33ec0fdc.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. +#> +#> ── Iteration 4 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 54 of 128 coalitions, 18 new. +#> +#> ── Iteration 5 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 64 of 128 coalitions, 10 new. + +explanation_sym_con[["xgboost"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + approach = "regression_separate", + regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), + asymmetric = FALSE, # Default value (TRUE will give the same as `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:02 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e74d3a17a.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. +#> +#> ── Iteration 4 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 54 of 128 coalitions, 18 new. +#> +#> ── Iteration 5 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 64 of 128 coalitions, 10 new. +``` +We can then look at the $\operatorname{MSE}_v$ evaluation scores to compare the approaches. +All approaches give comparable scores, but `gaussian` and `xgboost` are clearly faster than `ctree`. + + +``` r +print_MSEv_scores_and_time(explanation_sym_con) +#> MSEv MSEv_sd Time (secs) +#> gaussian 1098008 77896.33 8.17 +#> ctree 1095957 69223.49 48.53 +#> xgboost 1154565 66463.44 9.82 +``` + +We can then plot the Shapley values for the six explicands chosen above. 
+ + +``` r +plot_SV_several_approaches(explanation_sym_con, index_x_explain) + + theme(legend.position = "bottom") +``` + +![](figure_asymmetric_causal/explanation_sym_con_SV-1.png) + + + +We can also make beeswarm plots of the Shapley values to look at the structure +of the Shapley values for all explicands. The figures are quite similar, but +with minor differences. E.g., the `gaussian` approach produces almost no +Shapley values around $500$ for the `trend` feature. + + +``` r +plot_beeswarms(explanation_sym_con, title = "Symmetric conditional Shapley values") +``` + +![](figure_asymmetric_causal/explanation_sym_con_beeswarm-1.png) + + + + +## Asymmetric conditional Shapley values +Then we look at the asymmetric conditional Shapley values. To obtain these +types of Shapley values, we have to specify that `asymmetric = TRUE` and a +`causal_ordering`. We use `causal_ordering = list(1, c(2, 3), c(4:7))`. + + + +``` r +explanation_asym_con <- list() + +explanation_asym_con[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:14 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e1f04f5c1.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. + +explanation_asym_con[["gaussian_non_iterative"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL, # Default value + iterative = FALSE +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:16 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: FALSE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e265c8e5c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 20 of 20 coalitions. 
+ +explanation_asym_con[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "ctree", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:18 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e5d25406a.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. + +explanation_asym_con[["xgboost"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + approach = "regression_separate", + regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:26 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e36b9166e.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. +``` + +The asymmetric conditional Shapley value framework is faster as we only +consider $20$ coalitions (including empty and grand coalition) +instead of all $128$ coalitions (see code below). + + +``` r +print_MSEv_scores_and_time(explanation_asym_con) +#> MSEv MSEv_sd Time (secs) +#> gaussian 330603.3 36828.70 1.66 +#> gaussian_non_iterative 306457.7 35411.60 1.52 +#> ctree 260562.1 29428.95 8.75 +#> xgboost 307562.1 39362.81 1.60 + +# Look at the number of coalitions considered. Decreased from 128 to 20. 
+explanation_sym_con$gaussian$internal$parameters$max_n_coalitions +#> [1] 128 +explanation_asym_con$gaussian$internal$parameters$max_n_coalitions +#> [1] 20 + +# Here we can see the 20 coalitions that respects the causal ordering +explanation_asym_con$gaussian$internal$objects$dt_valid_causal_coalitions[["coalitions"]] +#> [[1]] +#> integer(0) +#> +#> [[2]] +#> [1] 1 +#> +#> [[3]] +#> [1] 1 2 +#> +#> [[4]] +#> [1] 1 3 +#> +#> [[5]] +#> [1] 1 2 3 +#> +#> [[6]] +#> [1] 1 2 3 4 +#> +#> [[7]] +#> [1] 1 2 3 5 +#> +#> [[8]] +#> [1] 1 2 3 6 +#> +#> [[9]] +#> [1] 1 2 3 7 +#> +#> [[10]] +#> [1] 1 2 3 4 5 +#> +#> [[11]] +#> [1] 1 2 3 4 6 +#> +#> [[12]] +#> [1] 1 2 3 4 7 +#> +#> [[13]] +#> [1] 1 2 3 5 6 +#> +#> [[14]] +#> [1] 1 2 3 5 7 +#> +#> [[15]] +#> [1] 1 2 3 6 7 +#> +#> [[16]] +#> [1] 1 2 3 4 5 6 +#> +#> [[17]] +#> [1] 1 2 3 4 5 7 +#> +#> [[18]] +#> [1] 1 2 3 4 6 7 +#> +#> [[19]] +#> [1] 1 2 3 5 6 7 +#> +#> [[20]] +#> [1] 1 2 3 4 5 6 7 +``` + +We can then look at the beeswarm plots of the asymmetric conditional Shapley value. +The `ctree` and `xgboost` approaches produce similar figures, while the `gaussian` +approach both shrinks and groups the Shapley values for the `trend` feature, while +it produces more negative values for the `cosyear` feature. + +When going from symmetric to asymmetric Shapley values, we see that many of the features' +Shapley values are now shrunken closer to zero, especially `temp` and `atemp`. + + +``` r +plot_beeswarms(explanation_asym_con, title = "Asymmetric conditional Shapley values") +``` + +![](figure_asymmetric_causal/explanation_asym_con_beeswarm-1.png) + + + +We can also compare the obtained symmetric and asymmetric conditional Shapley values +for the 6 explicands. We often see that the asymmetric version gives larger Shapley +values to the distal/root causes, i.e., `trend` and `cosyear`, than the symmetric +version. This is in line with Section 3.2 in @frye2020asymmetric. 
+ +``` r +# Order the symmetric and asymmetric conditional explanations into a joint list, selecting by name so that each approach's symmetric version is followed by its asymmetric version +explanation_sym_con_tmp <- copy(explanation_sym_con) +names(explanation_sym_con_tmp) <- paste0(names(explanation_sym_con_tmp), "_sym") +explanation_asym_con_tmp <- copy(explanation_asym_con) +names(explanation_asym_con_tmp) <- paste0(names(explanation_asym_con_tmp), "_asym") +explanation_asym_sym_con <- c(explanation_sym_con_tmp, explanation_asym_con_tmp)[c("gaussian_sym", "gaussian_asym", "ctree_sym", "ctree_asym", "xgboost_sym", "xgboost_asym")] +plot_SV_several_approaches(explanation_asym_sym_con, index_x_explain, brewer_palette = "Paired") + + theme(legend.position = "bottom") +``` + +![](figure_asymmetric_causal/sym_and_asym_Shapley_values-1.png) + + + + +## Symmetric marginal Shapley values +For marginal Shapley values, we can only consider the symmetric version as we must set +`causal_ordering = list(1:7)` (or `NULL`) and `confounding = TRUE`. Setting `asymmetric = TRUE` +will have no effect, as the causal ordering consists of only a single component containing all features, +i.e., all coalitions respect the causal ordering. As stated above, `shapr` generates the +marginal Monte Carlo samples from the Gaussian marginals if `approach = "gaussian"`, +while for all other Monte Carlo approaches the marginals are estimated from the training data, i.e., +assuming feature independence. Thus, it does not matter if we set `approach = "independence"` +or any other of the Monte Carlo-based approaches. We use `approach = "independence"` for clarity. +Furthermore, we also obtain marginal Shapley values by using the +conditional Shapley value framework with the `independence` approach. However, note that there will +be a minuscule difference in the produced Shapley values due to different sampling setups/orders. 
+ + +``` r +explanation_sym_marg <- list() + +# Here we sample from the estimated Gaussian marginals +explanation_sym_marg[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + asymmetric = FALSE, + causal_ordering = list(1:7), + confounding = TRUE +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:30 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend, cosyear, sinyear, temp, atemp, windspeed, hum} +#> • Components with confounding: {trend, cosyear, sinyear, temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7ea85dbd1.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. 
+ +# Here we sample from the marginals of the training data +explanation_sym_marg[["independence_marg"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "independence", + asymmetric = FALSE, + causal_ordering = list(1:7), + confounding = TRUE +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:41 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: independence +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend, cosyear, sinyear, temp, atemp, windspeed, hum} +#> • Components with confounding: {trend, cosyear, sinyear, temp, atemp, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e439e741c.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. 
+ +# Here we use the conditional Shapley value framework with the `independence` approach +explanation_sym_marg[["independence_con"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "independence" +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:05:48 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: independence +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e1d1af448.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. 
+``` + + +We can look at the $\operatorname{MSE}_v$ scores and the beeswarm plots. + + +``` r +print_MSEv_scores_and_time(explanation_sym_marg) +#> MSEv MSEv_sd Time (secs) +#> gaussian 1383295 111844.8 10.20 +#> independence_marg 1382080 111150.6 7.61 +#> independence_con 1382544 111313.8 10.45 + +plot_beeswarms(explanation_sym_marg, title = "Symmetric marginal Shapley values") +``` + +![](figure_asymmetric_causal/explanation_sym_mar_beeswarm-1.png) + + + +## Causal Shapley values +To compute (symmetric/asymmetric) causal Shapley values, we have to provide +the `causal_ordering` and `confounding` objects. We set them to be +`causal_ordering = list(1, 2:3, 4:7)` and `confounding = c(FALSE, TRUE, FALSE)`, +as explained above. + +The causal framework takes longer than the other frameworks, as generating +the Monte Carlo samples often consists of a chain of sampling steps. For example, +for $\mathcal{S} = \{2\}$, we must generate $X_1,X_3,X_4,X_5,X_6,X_7 \mid X_2$. +However, we cannot do this directly due to the `causal_ordering` and `confounding` +specified above. To generate the Monte Carlo samples, we have to follow a chain of +sampling steps. More precisely, we first need to generate $X_1$ from the marginal, +then $X_3 \mid X_1$, and finally $X_4,X_5,X_6,X_7 \mid X_1,X_2,X_3$. The latter two +steps are done by using the provided `approach` to model the conditional distributions. +The `internal$iter_list[[1]]$S_causal_steps_strings` object contains the sampling steps +needed for the different feature combinations/coalitions $\mathcal{S}$. + +For causal Shapley values, only the Monte Carlo-based approaches are applicable. 
+ +### Symmetric + +``` r +explanation_sym_cau <- list() + +explanation_sym_cau[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE), + iterative = FALSE, # Set to FALSE to get a single iteration to illustrate sampling steps below + exact = TRUE +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:00 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: FALSE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e1488ce0d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 128 of 128 coalitions. + +# Look at the sampling steps for the third coalition (S = {2}) +explanation_sym_cau$gaussian$internal$iter_list[[1]]$S_causal_steps_strings$id_coalition_3 +#> [1] "1|" "3|1" "4,5,6,7|1,2,3" + +# Use the copula approach +explanation_sym_cau[["copula"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "copula", + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. 
+#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 128, +#> and is therefore set to 2^n_features = 128. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:30 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: copula +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7ed2cf629.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. +``` + + + +``` r +print_MSEv_scores_and_time(explanation_sym_cau) +#> MSEv MSEv_sd Time (secs) +#> gaussian 1113795 85800.41 29.68 +#> copula 1137608 88376.95 21.05 +plot_beeswarms(explanation_sym_cau, title = "Symmetric causal Shapley values") +``` + +![](figure_asymmetric_causal/explanation_sym_cau_beeswarm-1.png) + + + + +### Asymmetric +We now turn to asymmetric causal Shapley values. That is, we only use the coalitions +that respects the causal ordering. Thus, the computations are faster as the number of +coalitions are reduced. 
+ + +``` r +explanation_asym_cau <- list() + +explanation_asym_cau[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:51 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e1f0d44a2.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. + +# Use the copula approach +explanation_asym_cau[["copula"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "copula", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. 
+#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:54 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: copula +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7ee9098e0.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 14 of 20 coalitions, 1 new. + +# Use the ctree approach (warning: ctree is slow) +explanation_asym_cau[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "ctree", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:06:58 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e65db9137.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 20 coalitions, 13 new. + +# Use the vaeac approach +explanation_asym_cau[["vaeac"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "vaeac", + vaeac.epochs = 20, + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 20, and is therefore set to 20. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:41:21 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • Adaptive estimation: FALSE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e55cf411.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 20 of 20 coalitions. +``` +We can look at the elapsed time. We see that `ctree` is much slower than the other approaches. +See the [implementation details](#implementation-details) for an explanation. + +``` r +print_time(explanation_asym_cau) +#> gaussian copula ctree vaeac +#> [1,] 2.406503 3.795967 34.39729 8.143235 +``` +We can then plot the beeswarm plots. We see that `ctree` provides more spread out Shapley values for the `trend` feature. + + + +``` r +# Plot the beeswarm plots +plot_beeswarms(explanation_asym_cau, title = "Asymmetric causal Shapley values") +``` + +![](figure_asymmetric_causal/explanation_asym_cau_beeswarm-1.png) + + +``` r +# Plot the Shapley values +plot_SV_several_approaches(explanation_asym_cau, index_x_explain) + + theme(legend.position = "bottom") +``` + +![](figure_asymmetric_causal/explanation_asym_cau_SV-1.png) + +We can also use the other Monte Carlo-based approaches (`independence` and `empirical`). + + + +## Comparing the frameworks +Here we plot the obtained Shapley values for the six explicands when using the +`gaussian` approach in the different Shapley value explanation frameworks, and +we see that the different frameworks provide different explanations. +The largest differences are between +whether we use the symmetric or asymmetric version. 
To summarize, asymmetric +conditional/causal Shapley values focus on the root cause, marginal Shapley +values on the more direct effect, and symmetric conditional/causal Shapley values +consider both for a more natural explanation. + + +``` r +explanation_gaussian <- list( + symmetric_marginal = explanation_sym_marg$gaussian, + symmetric_conditional = explanation_sym_con$gaussian, + symmetric_causal = explanation_sym_cau$gaussian, + asymmetric_conditional = explanation_asym_con$gaussian, + asymmetric_causal = explanation_asym_cau$gaussian +) + +plot_SV_several_approaches(explanation_gaussian, index_x_explain) + + theme(legend.position = "bottom") + + guides(fill = guide_legend(nrow = 2)) + + ggtitle("Shapley value prediction explanation (approach = 'gaussian')") + + guides(color = guide_legend(title = "Framework")) +``` + +![](figure_asymmetric_causal/compare_plots-1.png) + +## Scatter plots: marginal vs. causal Shapley values +In this section, we produce scatter plots comparing the symmetric marginal +and symmetric causal Shapley values for the temperature feature `temp` and +the seasonal feature `cosyear` for all explicands. The plots show that the +marginal Shapley values almost purely explain the predictions based on +temperature, while the causal Shapley values also give credit to season. +We can change the features and frameworks in the code below, but we chose +these values to replicate Figure 3 in @heskes2020causal. 
+ + + +``` r +# The color of the points +color <- "temp" + +# The features we want to compare +feature_1 <- "cosyear" +feature_2 <- "temp" + +# The Shapley value frameworks we want to compare +sv_framework_1 <- explanation_sym_marg[["gaussian"]] +sv_framework_1_str <- "Marginal SV" +sv_framework_2 <- explanation_sym_cau[["gaussian"]] +sv_framework_2_str <- "Causal SV" + +# Set up the data.frame we are going to plot +sv_correlation_df <- data.frame( + color = x_explain[, color], + sv_framework_1_feature_1 = sv_framework_1$shapley_values_est[[feature_1]], + sv_framework_2_feature_1 = sv_framework_2$shapley_values_est[[feature_1]], + sv_framework_1_feature_2 = sv_framework_1$shapley_values_est[[feature_2]], + sv_framework_2_feature_2 = sv_framework_2$shapley_values_est[[feature_2]] +) + +# Make the plots +scatterplot_topleft <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_1_feature_2, y = sv_framework_1_feature_1, color = color) + ) + + geom_point(size = 1) + + xlab(paste(sv_framework_1_str, feature_2)) + + ylab(paste(sv_framework_1_str, feature_1)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) + + scale_color_gradient(low = "blue", high = "red") + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_blank(), + axis.text.y = element_text(size = 12), + axis.ticks.x = element_blank(), + axis.title.x = element_blank() + ) + +scatterplot_topright <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_2_feature_1, y = sv_framework_1_feature_1, color = color) + ) + + geom_point(size = 1) + + scale_color_gradient(low = "blue", high = "red") + + xlab(paste(sv_framework_2_str, feature_1)) + + ylab(paste(sv_framework_1_str, feature_1)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) + + theme_minimal() + + theme( + text = 
element_text(size = 12), + axis.title.x = element_blank(), + axis.title.y = element_blank(), + axis.text.x = element_blank(), + axis.ticks.x = element_blank(), + axis.text.y = element_blank(), + axis.ticks.y = element_blank() + ) + +scatterplot_bottomleft <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_1_feature_2, y = sv_framework_2_feature_2, color = color) + ) + + geom_point(size = 1) + + scale_color_gradient(low = "blue", high = "red") + + xlab(paste(sv_framework_1_str, feature_2)) + + ylab(paste(sv_framework_2_str, feature_2)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_text(size = 12), + axis.text.y = element_text(size = 12) + ) + +scatterplot_bottomright <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_2_feature_1, y = sv_framework_2_feature_2, color = color) + ) + + geom_point(size = 1) + + xlab(paste(sv_framework_2_str, feature_1)) + + ylab(paste(sv_framework_2_str, feature_2)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) + + scale_color_gradient(low = "blue", high = "red") + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_text(size = 12), + axis.title.y = element_blank(), + axis.text.y = element_blank(), + axis.ticks.y = element_blank() + ) + +# Plot of the trend of the data +bike_plot_new <- ggplot(bike, aes(x = trend, y = cnt, color = get(color))) + + geom_point(size = 0.75) + + scale_color_gradient(low = "blue", high = "red") + + labs(color = color) + + xlab("Days since 1 January 2011") + + ylab("Number of bikes rented") + + theme_minimal() + + theme(legend.position = "right", legend.title = element_text(size = 10)) + +# Combine the plots +ggpubr::ggarrange( + bike_plot_new, + ggpubr::ggarrange( + 
scatterplot_topleft, + scatterplot_topright, + scatterplot_bottomleft, + scatterplot_bottomright, + legend = "none" + ), + nrow = 2, heights = c(1, 2) +) +``` + +![](figure_asymmetric_causal/scatter_plots-1.png) + +## Investigating two similar days + +We investigate the difference between symmetric/asymmetric conditional, +symmetric/asymmetric causal, and marginal Shapley values for two days: +October 10 and December 3, 2012. They have more or less the same +temperature of 13 and 13.27 degrees Celsius, and predicted bike counts +of 6117 and 6241, respectively. The figure below is an extension of +Figure 4 in @heskes2020causal, as they only included asymmetric +conditional, symmetric causal, and marginal Shapley values. + +We plot the various Shapley values for the `cosyear` and `temp` features +below. We obtain the same results as @heskes2020causal obtained, namely, +that the marginal Shapley value explanation framework provides similar +explanations for both days. I.e., it only considers the direct effect of `temp`. +The asymmetric conditional and causal Shapley values are almost +indistinguishable and put the most weight on the ‘root’ cause `cosyear`. +@heskes2020causal states that the symmetric causal Shapley values provide +a sensible balance between the two extremes and give credit to both season and temperature, +but still different explanations for the two days. + +However, as we also include symmetric conditional Shapley values, +we see that they are extremely similar to symmetric causal Shapley values. +I.e., the conditional Shapley value explanation framework also provides +a sensible balance between marginal and asymmetric Shapley values. +To summarize: +as concluded by @heskes2020causal in their Figure 4, the +asymmetric conditional/causal Shapley values focus on the +root cause, marginal Shapley values on the more direct effect, and symmetric +conditional/causal Shapley values consider both for a more natural explanation. 
+ + +``` r +# Features of interest +features <- c("cosyear", "temp") + +# Get explicands with similar temperature: 2012-10-09 (October) and 2012-12-03 (December) +dates <- c("2012-10-09", "2012-12-03") +dates_idx <- sapply(dates, function(data) which(as.integer(row.names(x_explain)) == which(bike$dteday == data))) +# predict(model, x_explain)[dates_idx] + mean(y_train_nc) # predicted values for the two points + +# List of the Shapley value explanations +explanations <- list( + "Sym. Mar." = explanation_sym_marg[["gaussian"]], + "Sym. Con." = explanation_sym_con[["gaussian"]], + "Sym. Cau." = explanation_sym_cau[["gaussian"]], + "Asym. Con." = explanation_asym_con[["gaussian"]], + "Asym. Cau." = explanation_asym_cau[["gaussian"]] +) + +# Extract the relevant Shapley values +explanations_extracted <- data.table::rbindlist(lapply(seq_along(explanations), function(idx) { + explanations[[idx]]$shapley_values_est[ + dates_idx, ..features + ][, `:=`(Date = dates, type = names(explanations)[idx])] +})) + +# Set type to be an ordered factor +explanations_extracted[, type := factor(type, levels = names(explanations), ordered = TRUE)] + +# Convert from wide to long data table +dt_all <- data.table::melt(explanations_extracted, + id.vars = c("Date", "type"), + variable.name = "feature" +) + +# Make the plot +ggplot(dt_all, aes( + x = feature, y = value, group = interaction(Date, feature), + fill = Date, label = round(value, 2) +)) + + geom_col(position = "dodge") + + theme_classic() + + ylab("Shapley value") + + facet_wrap(vars(type)) + + theme(axis.title.x = element_blank()) + + scale_fill_manual(values = c("indianred4", "ivory4")) + + theme( + legend.position.inside = c(0.75, 0.25), axis.title = element_text(size = 20), + legend.title = element_text(size = 16), legend.text = element_text(size = 14), + axis.text.x = element_text(size = 12), axis.text.y = element_text(size = 12), + strip.text.x = element_text(size = 14) + ) +``` + 
+![](figure_asymmetric_causal/two_dates_1-1.png) + +We can also make a similar plot using the `plot_SV_several_approaches` function in `shapr`, +but then we get each explicand in a separate facet instead of a facet for each framework. + +``` r +# Here 2012-10-09 is the left facet and 2012-12-03 the right facet +plot_SV_several_approaches(explanations, + index_explicands = dates_idx, + only_these_features = features, # Can include more features. + facet_scales = "free_x", + horizontal_bars = FALSE, + axis_labels_n_dodge = 1 +) + theme(legend.position = "bottom") +``` + +![](figure_asymmetric_causal/two_dates_2-1.png) + +Furthermore, instead of doing as @heskes2020causal and only considering the features +`cosyear` and `temp`, we can plot all features, too, to get a more complete overview. + +``` r +# Here 2012-10-09 is the left facet and 2012-12-03 the right facet +plot_SV_several_approaches(explanations, + index_explicands = dates_idx, + facet_scales = "free_x", + horizontal_bars = FALSE, + axis_labels_rotate_angle = 45, + digits = 2 +) + theme(legend.position = "bottom") +``` + +![](figure_asymmetric_causal/two_dates_3-1.png) + + +## Sampling of coalitions + +We can use `max_n_coalitions` to specify/reduce the number of coalitions +to use when computing the Shapley value explanation framework. This applies +to marginal, conditional, and causal Shapley values, both the symmetric and +asymmetric versions. However, recall that the asymmetric versions already +have fewer valid coalitions due to the causal ordering. + +In the example below, we demonstrate the sampling of coalitions for the +asymmetric and symmetric causal Shapley value explanation frameworks. +We halve the number of coalitions for both versions +and see that the elapsed times are approximately halved, too. 
+ +``` r +explanation_n_coal <- list() + +explanation_n_coal[["sym_cau_gaussian_64"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE), + max_n_coalitions = 64 # Instead of 128 +) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:49:37 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e33b2f318.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 128 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 128 coalitions, 12 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 36 of 128 coalitions, 10 new. + +explanation_n_coal[["asym_cau_gaussian_10"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE), + paired_shap_sampling = FALSE, + verbose = c("basic", "convergence", "shapley"), + max_n_coalitions = 10 # Instead of 20 +) +#> Note: Feature classes extracted from the model contains NA. 
+#> Assuming feature classes from the data are correct. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:49:49 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of feature-wise Shapley values: 7 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 20 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp, atemp, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e3cd42aab.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 10 of 10 coalitions, 10 new. +#> +#> ── Convergence info +#> ✔ Converged after 10 coalitions: +#> Maximum number of coalitions reached! +#> +#> ── Final estimated Shapley values (sd) +#> none trend cosyear sinyear temp +#> +#> 1: 0.00 (0.00) -2181.910 (374.04) -825.541 (352.22) -236.730 (257.69) -33.813 ( 53.47) +#> 2: 0.00 (0.00) -2174.357 (371.76) -846.615 (359.93) -187.083 (274.61) -44.966 ( 55.78) +#> 3: 0.00 (0.00) -2088.959 (360.14) -793.628 (341.68) -186.335 (247.79) -104.809 ( 41.77) +#> 4: 0.00 (0.00) -2103.364 (368.62) -798.135 (356.43) -110.331 (268.86) 169.736 (102.28) +#> 5: 0.00 (0.00) -2003.877 (349.11) -723.936 (323.40) -231.863 (226.28) 36.505 ( 31.04) +#> --- +#> 140: 0.00 (0.00) 1575.954 (585.55) -1014.078 (542.93) 236.336 (357.62) -68.170 (206.64) +#> 141: 0.00 (0.00) 1588.686 (607.20) -1057.223 (537.25) 33.370 (256.33) 1.919 ( 28.84) +#> 142: 0.00 (0.00) 1466.745 (593.37) -1109.151 (522.54) -96.687 (257.10) -44.555 ( 72.30) +#> 143: 0.00 (0.00) 1003.943 (616.41) -1780.473 (602.42) -101.586 (368.10) 19.062 ( 60.30) +#> 144: 0.00 (0.00) 711.139 (724.53) -2635.898 (777.15) -178.609 (570.02) 36.623 (145.89) +#> atemp windspeed hum +#> +#> 1: -0.059 ( 65.71) 116.495 ( 59.05) 
10.180 ( 92.70) +#> 2: 34.569 ( 42.77) 13.436 ( 21.38) 185.698 ( 89.52) +#> 3: -18.460 ( 53.67) 244.081 ( 54.64) -122.150 ( 58.66) +#> 4: 45.240 ( 57.08) -182.944 ( 58.37) -207.757 (103.50) +#> 5: 4.713 ( 45.34) 203.889 ( 37.12) -30.464 ( 60.24) +#> --- +#> 140: 16.388 (172.46) 362.193 (170.55) 627.943 (272.06) +#> 141: 7.102 ( 42.19) 216.846 ( 28.41) -71.698 ( 50.86) +#> 142: 129.756 ( 52.97) 80.036 ( 24.09) 272.476 (108.16) +#> 143: -3.841 ( 66.41) 48.680 ( 20.19) -236.844 ( 96.01) +#> 144: -66.292 (113.43) -469.159 ( 73.57) 562.362 (233.86) + +# Look at the times +explanation_n_coal[["sym_cau_gaussian_all_128"]] <- explanation_sym_cau$gaussian +explanation_n_coal[["asym_cau_gaussian_all_20"]] <- explanation_asym_cau$gaussian +explanation_n_coal <- explanation_n_coal[c(1, 3, 2, 4)] +print_time(explanation_n_coal) +#> sym_cau_gaussian_64 sym_cau_gaussian_all_128 asym_cau_gaussian_10 asym_cau_gaussian_all_20 +#> [1,] 11.35182 29.67625 2.171 2.406503 +``` + +We can then plot the beeswarm plots and the Shapley values for the six selected explicands. +We see that there are only minuscule differences between the Shapley values we obtain when we use +all the coalitions and those we obtain when we use half of the valid coalitions. + + +``` r +plot_beeswarms(explanation_n_coal, title = "Shapley values (gaussian) exact vs. approximation") +``` + +![](figure_asymmetric_causal/n_coalitions_plot_beeswarm-1.png) + + +``` r +plot_SV_several_approaches(explanation_n_coal, index_x_explain) + + theme(legend.position = "bottom") + + guides(fill = guide_legend(nrow = 2)) +``` + +![](figure_asymmetric_causal/n_coalitions_plot_SV-1.png) + + + +## Groups of features +In this section, we demonstrate that we can compute marginal, symmetric/asymmetric +conditional, and symmetric/asymmetric causal Shapley values for groups of features, too. +For group Shapley values, we need to specify the causal ordering on the group level +instead of the feature level. 
We demonstrate with the `gaussian` approach, but other approaches +are applicable, too. + +In the pairs plot above (and below), we see that it can be natural to group the +features `temp` and `atemp` due to their (conceptual) similarity and high correlation. + + +``` r +GGally::ggpairs(x_train[, 4:5]) +``` + +![](figure_asymmetric_causal/group_cor-1.png) + +We set up the groups and update the causal ordering to be on the group level. + +``` r +group_list <- list( + trend = "trend", + cosyear = "cosyear", + sinyear = "sinyear", + temp_group = c("temp", "atemp"), + windspeed = "windspeed", + hum = "hum" +) + +causal_ordering_group <- + list("trend", c("cosyear", "sinyear"), c("temp_group", "windspeed", "hum")) +confounding <- c(FALSE, TRUE, FALSE) +``` + + +We can then compute the (group) Shapley values using the different Shapley value frameworks. + +``` r +explanation_group_gaussian <- list() + +explanation_group_gaussian[["symmetric_marginal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(seq(length(group_list))), # or `NULL` + confounding = TRUE, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 64, +#> and is therefore set to 2^n_groups = 64. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:49:54 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend, cosyear, sinyear, temp_group, windspeed, hum} +#> • Components with confounding: {trend, cosyear, sinyear, temp_group, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e6a0055ef.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 64 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 20 of 64 coalitions, 6 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 24 of 64 coalitions, 4 new. + +explanation_group_gaussian[["symmetric_conditional"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(seq(length(group_list))), # or `NULL` + confounding = NULL, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 64, +#> and is therefore set to 2^n_groups = 64. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:50:01 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e64cac371.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 64 coalitions, 13 new. +#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 20 of 64 coalitions, 6 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 64 coalitions, 6 new. + +explanation_group_gaussian[["asymmetric_conditional"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = causal_ordering_group, + confounding = NULL, + paired_shap_sampling = FALSE, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 12, and is therefore set to 12. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:50:06 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 12 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp_group, windspeed, hum} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e7bf9af79.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 12 of 12 coalitions, 12 new. + +explanation_group_gaussian[["symmetric_causal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = causal_ordering_group, + confounding = confounding, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 64, +#> and is therefore set to 2^n_groups = 64. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:50:08 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp_group, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e5c0d6350.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 13 of 64 coalitions, 13 new. 
+#> +#> ── Iteration 2 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 20 of 64 coalitions, 6 new. +#> +#> ── Iteration 3 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 26 of 64 coalitions, 6 new. + +explanation_group_gaussian[["asymmetric_causal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = causal_ordering_group, + confounding = confounding, + paired_shap_sampling = FALSE, + n_MC_samples = 1000, + group = group_list + ) +#> Note: Feature classes extracted from the model contains NA. +#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or number of coalitions respecting the causal +#> ordering 12, and is therefore set to 12. +#> +#> ── Starting `shapr::explain()` at 2024-10-11 20:50:16 ─────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • Adaptive estimation: TRUE +#> • Number of group-wise Shapley values: 6 +#> • Number of observations to explain: 144 +#> • Number of asymmetric coalitions: 12 +#> • Causal ordering: {trend}, {cosyear, sinyear}, {temp_group, windspeed, hum} +#> • Components with confounding: {cosyear, sinyear} +#> • Computations (temporary) saved at: '/tmp/Rtmp4fQ9Oe/shapr_obj_15ba7e6d67207a.rds' +#> +#> ── Adaptive computation started ── +#> +#> ── Iteration 1 ────────────────────────────────────────────────────────────────────────────────────────────────── +#> ℹ Using 12 of 12 coalitions, 12 new. 
+ +# Look at the elapsed times (symmetric takes the longest time) +print_time(explanation_group_gaussian) +#> symmetric_marginal symmetric_conditional asymmetric_conditional symmetric_causal asymmetric_causal +#> [1,] 6.838362 5.180942 1.792266 8.226479 2.392054 +``` + +We can then make the beeswarm plots and Shapley values plots for the six selected explicands. +For the beeswarm plots, we set `include_group_feature_means = TRUE` to make the plots. +This means that the plot function use the mean of the `temp` and `atemp` features as the feature +value. This only makes sense due to the high correlation between the two features. + +The main difference between the feature-wise and group-wise Shapley values +is that we now see a much wider spread in the Shapley values for `temp_group` +than we did for `temp` and `atemp`. +For example, for the symmetric causal framework, we saw above that the `temp` and `atemp` +obtained Shapley values between (around) $-500$ to $500$, while the grouped version +`temp_group` obtains Shapley values between $-1000$ to $1000$ + + + +``` r +plot_beeswarms(explanation_group_gaussian, + title = "Group Shapley values (gaussian)", + include_group_feature_means = TRUE +) +``` + +![](figure_asymmetric_causal/group_gaussian_plot_beeswarm-1.png) + + + +``` r +plot_SV_several_approaches(explanation_group_gaussian, index_x_explain) + + ggtitle("Shapley value prediction explanation (gaussian)") + + theme(legend.position = "bottom") + guides(fill = guide_legend(nrow = 2)) +``` + +![](figure_asymmetric_causal/group_gaussian_plot_SV-1.png) + + + + + +## Implementation details + +The `shapr` package is built to estimate conditional Shapley values, thus, +it parallelize over the coalitions. This makes perfect sense for said +framework as each batch of coalitions are independent of other batches, +which means that it is easy to parallelize. 
Furthermore, by using many +batches we drastically reduce the memory usage as `shapr` does not need +to store the Monte Carlo samples for all coalitions. + +This setup is not optimal for the causal Shapley value framework as the +chains of sampling steps for two coalitions $\mathcal{S}$ and $\mathcal{S}^*$ +can contain many of the same steps. Ideally, each unique sampling step +should only be modeled once to save computation time, but some of the +sampling steps will occur in many of the chains. Thus, we would then have +to store the Monte Carlo samples for all coalitions where this sampling +step is included, and we can therefore run into memory consumption problems. +Thus, in the current implementation, we treat each coalition $\mathcal{S}$ +independently and remodel the needed sampling steps for each coalition. + +Furthermore, in the conditional Shapley value framework, we have that +$\bar{\mathcal{S}} = \mathcal{M} \backslash \mathcal{S}$, thus `shapr` +will by default generate Monte Carlo samples for all features not in +$\mathcal{S}$. For the causal Shapley value framework, this is not the +case, i.e., $\bar{\mathcal{S}} \neq \mathcal{M} \backslash \mathcal{S}$ +in general. To reuse the code, we generate Monte Carlo samples for all +features not in $\mathcal{S}$, but only keep the samples for the features +in $\bar{\mathcal{S}}$. To speed up `shapr` further, one could rewrite +all the approaches to support that $\bar{\mathcal{S}}$ is not +the complement of $\mathcal{S}$. + +In the code below, we see the unique coalitions/set of features to condition +on to generate the Monte Carlo samples for all coalitions and the number of +times that set of conditional features is needed in the symmetric causal Shapley +value framework for the setup above. We see that most of the conditional +distributions will now be remodeled eight times.
For the `gaussian` approach, +which is very fast to estimate the conditional distributions, this does not +have a major impact on the time. However, for, e.g., the `ctree` approach which +is much slower, this will take a significant amount of extra time. The `vaeac` +approach trains only on these relevant coalitions. + +``` r +S_causal_steps <- explanation_sym_cau$gaussian$internal$iter_list[[1]]$S_causal_steps +S_causal_unlist <- do.call(c, unlist(S_causal_steps, recursive = FALSE)) +S_causal_steps_freq <- S_causal_unlist[grepl("\\.S(?!bar)", names(S_causal_unlist), perl = TRUE)] +S_causal_steps_freq <- S_causal_steps_freq[!sapply(S_causal_steps_freq, is.null)] # Remove NULLs +S_causal_steps_freq <- S_causal_steps_freq[sapply(S_causal_steps_freq, length) > 0] # Remove extra integer(0) +table(sapply(S_causal_steps_freq, paste0, collapse = ",")) +#> +#> 1 1,2,3 1,2,3,4 1,2,3,4,5 1,2,3,4,5,6 1,2,3,4,5,7 1,2,3,4,6 1,2,3,4,6,7 1,2,3,4,7 +#> 95 7 8 8 8 8 8 8 8 +#> 1,2,3,5 1,2,3,5,6 1,2,3,5,6,7 1,2,3,5,7 1,2,3,6 1,2,3,6,7 1,2,3,7 +#> 8 8 8 8 8 8 8 +``` + +The `independence`, `empirical`, `ctree`, and `categorical` approaches produce +weighted Monte Carlo samples. That means that they do not necessarily generate +`n_MC_samples`. To ensure `n_MC_samples`, we sample `n_MC_samples` samples using weighted +sampling with replacements where the weights are the weights returned by the approaches. + +The marginal Shapley value explanation framework can be extended to +support modeling the marginal distributions using the `copula` and +`vaeac` approaches as both of these methods support unconditional sampling. 
+ + +# References diff --git a/vignettes/understanding_shapr_asymmetric_causal.Rmd.orig b/vignettes/understanding_shapr_asymmetric_causal.Rmd.orig new file mode 100644 index 0000000000000000000000000000000000000000..4f97fb01099c6fd6abe3d5f4f54cbe81cfb4c97b --- /dev/null +++ b/vignettes/understanding_shapr_asymmetric_causal.Rmd.orig @@ -0,0 +1,1333 @@ +--- +title: "Asymmetric and causal Shapley value explanations" +author: "Lars Henry Berge Olsen" +output: + rmarkdown::html_vignette: + toc: true + fig_caption: yes +bibliography: ../inst/REFERENCES.bib +vignette: > + %\VignetteEncoding{UTF-8} + %\VignetteIndexEntry{Asymmetric and causal Shapley value explanations} + %\VignetteEngine{knitr::rmarkdown} +editor_options: + markdown: + wrap: 72 + toc: true +--- + +```{r, include = FALSE} +knitr::opts_chunk$set( + collapse = TRUE, + comment = "#>", + fig.cap = "", + fig.width = 7.2, + fig.height = 6, + fig.path = "figure_asymmetric_causal/", # Ensure that figures are saved in the right folder (build vignette manually) + cache.path = "cache_asymmetric_causal/", # Ensure that cached objects are saved in the right folder + warning = FALSE, + message = TRUE +) +``` + + +# Overview {#Vignette} + +This vignette elaborates and demonstrates the asymmetric and +causal Shapley value frameworks introduced by @frye2020asymmetric +and @heskes2020causal, respectively. We also consider the marginal +and conditional Shapley value frameworks, see @lundberg2017unified +and @aas2019explaining, respectively. We demonstrate the frameworks +on the [bike sharing](https://archive.ics.uci.edu/dataset/275/) +dataset from the UCI Machine Learning Repository. The setup is +based on the `CauSHAPley` package, which is the +[code supplement](https://proceedings.neurips.cc/paper/2020/hash/32e54441e6382a7fbacbbbaf3c450059-Abstract.html) +to the @heskes2020causal paper. 
The `CauSHAPley` package was based +on an old version of `shapr` and was restricted to the `gaussian` approach (see section 6 in @heskes2020causal for more details). + +We have extended the causal Shapley value framework to work for all +Monte Carlo-based approaches (`independence` (not recommended), `empirical`, `gaussian`, `copula`, `ctree`, `vaeac` and `categorical`), while the extension of the asymmetric +Shapley value framework works for both the Monte Carlo and regression-based approaches. +Our generalization is of utmost importance, as many real-world data sets +are far from the Gaussian distribution, and, compared to `CauSHAPley`, our implementation +can utilize all of `shapr`'s new features, such as batch computation, parallelization and +iterative computation for both feature-wise and group-wise Shapley values. + +The main difference between the marginal, conditional, and causal Shapley value +frameworks is that they sample/generate the Monte Carlo samples from the +marginal distribution, (conventional) observational conditional distribution, +and interventional conditional distribution, respectively. Asymmetric means +that we do not consider all possible coalitions, but rather only the coalitions +that respect a causal ordering. + + + +# Asymmetric conditional Shapley values {#AsymSV} + +Asymmetric (conditional) Shapley values were proposed by @frye2020asymmetric as +a way to incorporate causal knowledge in the real world by computing the Shapley +value explanations using only the feature combinations/coalitions consistent with +a (partial) causal ordering. See the figure below for a schematic overview of the causal ordering we are going to use in the examples in this vignette. In the figure, we see +that our causal ordering consists of three components: $\tau_1 = \{X_1\}$, $\tau_2 = \{X_2, X_3\}$, and $\tau_3 = \{X_4, X_5, X_6, X_7\}$. See the [code section](#code-example) for what the features represent.
+ +To elaborate, instead of considering the $2^M$ possible coalitions, +where $M$ is the number of features, asymmetric Shapley values only +consider the subset of coalitions that respect the causal ordering. +For our causal ordering, this means that the asymmetric Shapley value explanation +framework skips the coalitions where $X_2$ is included but *not* $X_1$, +as $X_1$ is an ancestor of $X_2$. This will skew the explanations towards +distal/root causes, see Section 3.2 in @frye2020asymmetric. + +We can use all approaches in `shapr`, both Monte Carlo-based and +regression-based methods, to compute the asymmetric Shapley values. +This is because the asymmetric Shapley value explanation framework does not change +how we compute the contribution functions $v(S)$, but rather which of +the coalitions $S$ are used to compute the Shapley value explanations. +This means that the number of coalitions is no longer $O(2^M)$, but rather +$O(2^{\tau_0})$, where $\tau_0 = \operatorname{max}_i |\tau_i|$ +is the number of features ($|\tau_i|$) in the largest component of the causal ordering. + +Furthermore, asymmetric Shapley values support groups of features, but +then the causal ordering must be given on the group level instead of on the +feature level. The asymmetric Shapley value framework also supports +sampling of coalitions where the sampling is done from the +set of coalitions that respect the causal ordering. + +Finally, we remark that asymmetric conditional Shapley values are +equivalent to asymmetric causal Shapley values (see below) when we only +use the coalitions respecting the causal ordering and assuming that all +dependencies within chain components are induced by mutual interactions.
+ + +```{r asymmetric_ordering, echo=FALSE, fig.cap="Schematic overview of the causal ordering used in this vignette.", fig.align='center', out.width = '50%'} +knitr::include_graphics("figure_asymmetric_causal/Asymmetric_ordering.png") +``` + + +# Causal Shapley values {#CausSV} + +Causal Shapley values were proposed by @heskes2020causal as a way +to explain the total effect of features on the prediction by taking +into account their causal relationships and adapting the sampling +procedure in `shapr`. More precisely, they propose to employ Pearl’s +do-calculus to circumvent the independence assumption, made by +@lundberg2017unified, without sacrificing any of the desirable +properties of the Shapley value framework. The causal Shapley value +explanation framework can also separate the contribution of direct +and indirect effects, which makes it fundamentally different from +marginal and conditional Shapley values. The framework also provides +a more direct and robust way to incorporate causal knowledge, compared +to the asymmetric Shapley value explanation framework. + +To compute causal Shapley values, we have to specify a (partial) causal +ordering and make an assumption about the confounding in each component. +Together, they form a causal chain graph which contains directed and undirected +edges. All features that are treated on an equal footing are linked +together with undirected edges and become part of the same chain component. +Edges between chain components are directed and represent causal relationships. +In the figure below, we have the same causal ordering as above, but we +have in addition made the assumption that we have confounding in the +second component, but no confounding in the first and third components. +This allows us to correctly distinguish between dependencies that are +due to confounding and mutual interactions.
That is, in the figure, +the dependencies in chain component $\tau_2$ are assumed to be the result +of a common confounder, and those in $\tau_3$ of mutual interactions, while +we have no mutual interactions in $\tau_1$ as it is a singleton. + +Computing the effect of an intervention depends on how we interpret the +generative process that led to the feature dependencies within each component. +If they are the result of marginalizing out a common confounder, +then intervention on a particular feature will break the dependency +with the other features, and we denote the set of these chain components +by $\mathcal{T}_{\text{confounding}}$. For the components with mutual +feature interactions, setting the value of a feature affects the +distribution of the variables within the same component. We denote +the set of these components by $\mathcal{T}_{\,\overline{\text{confounding}}}$. + +@heskes2020causal described how any expectation by intervention needed +to compute the causal Shapley values can be translated to an expectation +by observation, by using the interventional formula for causal chain graphs: +\begin{align} +\label{eq:do} +P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) += & +\prod_{\tau \in \mathcal{T}_{\,\text{confounding}}} +P(X_{\tau \cap \bar{\mathcal{S}}} \mid X_{\text{pa}(\tau) \cap \bar{\mathcal{S}}}, x_{\text{pa}(\tau) \cap \mathcal{S}}) \times \tag{1} \\ +& \quad +\prod_{\tau \in \mathcal{T}_{\,\overline{\text{confounding}}}} +P(X_{\tau \cap \bar{\mathcal{S}}} \mid X_{\text{pa}(\tau) \cap \bar{\mathcal{S}}}, x_{\text{pa}(\tau) \cap \mathcal{S}}, x_{\tau \cap \mathcal{S}}). +\end{align} +Here, any of the Monte Carlo-based approaches in `shapr` can be +used to compute the conditional distributions/observational expectations. The marginals +are estimated from the training data for all approaches except +`gaussian`, for which we use the marginals of the Gaussian +distribution instead.
+ +For specific causal chain graphs, the causal Shapley value framework +simplifies to symmetric conditional, asymmetric conditional, and marginal +Shapley values, see Corollaries 1 to 3 in the supplement of @heskes2020causal. + + +```{r pressure, echo=FALSE, fig.cap="Schematic overview of the causal chain graph used in this vignette.", out.width = '50%'} +knitr::include_graphics("figure_asymmetric_causal/Causal_ordering.png") +``` + + +# Marginal Shapley values {#MarginaSV} +Causal Shapley values are equivalent to marginal Shapley values when all $M$ +features are combined into a single component $\tau = \mathcal{M} = \{1,2,...,M\}$ and +all dependencies are induced by confounding. Then $\text{pa}(\tau) = \emptyset$, and +$P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ in Equation (\ref{eq:do}) +simplifies to $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) = P(X_{\bar{\mathcal{S}}})$, +as specified in @lundberg2017unified. + +The Monte Carlo samples for the marginals are generated by sampling from the +training data, except for the `gaussian` approach where we use the marginals +of the estimated multivariate Gaussian distribution. This means that for all +other approaches, this is the same as using the `independence` approach +in the conditional Shapley value explanation framework. + +# Symmetric conditional Shapley values {#ConditionalSV} +Causal Shapley values are equivalent to symmetric conditional Shapley values when all $M$ +features are combined in a single component $\tau = \mathcal{M} = \{1,2,...,M\}$ and +all dependencies are induced by mutual interaction. Then $\text{pa}(\tau) = \emptyset$, +and $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ in Equation +(\ref{eq:do}) simplifies to +$P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S})) = P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$, +as specified in @aas2019explaining. Symmetric means that we consider all coalitions.
+ + + + + + +# Code example +## Overview +We demonstrate the frameworks on the [bike sharing](https://archive.ics.uci.edu/dataset/275/) +dataset from the UCI Machine Learning Repository. We let the features be the +number of days since January 2011 (`trend`), two cyclical variables representing +the season (`cosyear`, `sinyear`), temperature (`temp`), feeling temperature +(`atemp`), wind speed (`windspeed`), and humidity (`hum`). The first three features +are considered to be a potential cause of the four weather-related features. +The bike rental is strongly seasonal and shows an upward trend, as illustrated in the figure below. +The bike data is split randomly into a training (80%) and test/explicand (20%) set. +We train an `XGBoost` model for 100 rounds with default hyperparameters to act as the model +we want to explain. + +In the table below, we highlight the Shapley value explanation frameworks introduced above +and how to access them by changing the arguments `asymmetric`, `causal_ordering`, and `confounding` in `shapr::explain()`. +Note that symmetric conditional Shapley values are the default version, i.e., by default +`asymmetric = FALSE`, `causal_ordering = NULL`, `confounding = NULL`. + +| Framework | Sampling | Approaches | `asymmetric` | `causal_ordering` | `confounding` | +|:-------------------|:-----------------------|:---------------------|:-------------|:------------|:--------------| +| Sym. Conditional | $P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$ | All | `FALSE` | `NULL` | `NULL` | +| Asym. Conditional | $P(X_{\bar{\mathcal{S}}} \mid X_\mathcal{S} = x_\mathcal{S})$ | All | `TRUE` | `list(...)` | `NULL` | +| Sym. Causal | $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ | All MC-based | `FALSE` | `list(...)` | `c(...)` | +| Asym. Causal | $P(X_{\bar{\mathcal{S}}} \mid do(X_\mathcal{S} = x_\mathcal{S}))$ | All MC-based | `TRUE` | `list(...)` | `c(...)` | +| Sym.
Marginal | $P(X_{\bar{\mathcal{S}}})$ | `indep.`, `gaussian` | `FALSE` | `NULL` | `TRUE` | + + + +## Code setup +First, we load the needed libraries, set up the training/explicand data, plot the data, and train an `xgboost` model. +```{r setup_1, message = FALSE, fig.height = 4, cache = TRUE} +# Libraries +library(ggplot2) +require(GGally) +library(ggpubr) +library(gridExtra) +library(xgboost) +library(data.table) +library(shapr) + +# Ensure that shapr's functions are prioritzed, otherwise we need to use the `shapr::` +# prefix when calling explain(). The `conflicted` package is imported by `tidymodels`. +conflicted::conflicts_prefer(shapr::explain, shapr::prepare_data) + +# Set up the data +# Can also download the data set from the source https://archive.ics.uci.edu/dataset/275/bike+sharing+dataset +# temp <- tempfile() +# download.file("https://archive.ics.uci.edu/static/public/275/bike+sharing+dataset.zip", temp) +# bike <- read.csv(unz(temp, "day.csv")) +# unlink(temp) +bike <- read.csv("../inst/extdata/day.csv") +# Difference in days, which takes DST into account +bike$trend <- as.numeric(difftime(bike$dteday, bike$dteday[1], units = "days")) +bike$cosyear <- cospi(bike$trend / 365 * 2) +bike$sinyear <- sinpi(bike$trend / 365 * 2) +# Unnormalize variables (see data set information in link above) +bike$temp <- bike$temp * (39 - (-8)) + (-8) +bike$atemp <- bike$atemp * (50 - (-16)) + (-16) +bike$windspeed <- 67 * bike$windspeed +bike$hum <- 100 * bike$hum + +# Plot the data +ggplot(bike, aes(x = trend, y = cnt, color = temp)) + + geom_point(size = 0.75) + + scale_color_gradient(low = "blue", high = "red") + + labs(colour = "temp") + + xlab("Days since 1 January 2011") + + ylab("Number of bikes rented") + + theme_minimal() + + theme(legend.position = "right", legend.title = element_text(size = 10)) +``` + +```{r setup_2, message = FALSE, fig.height = 7, cache = TRUE} +# Define the features and the response variable +x_var <- c("trend", "cosyear", "sinyear", 
"temp", "atemp", "windspeed", "hum") +y_var <- "cnt" + +# NOTE: Encountered RNG reproducibility issues across different systems, +# Load the training-test split. 80% training and 20% test +train_index <- readRDS("../inst/extdata/train_index.rds") + +# Training data +x_train <- as.matrix(bike[train_index, x_var]) +y_train_nc <- as.matrix(bike[train_index, y_var]) # not centered +y_train <- y_train_nc - mean(y_train_nc) + +# Plot pairs plot +GGally::ggpairs(x_train) +``` + +```{r setup_3, message = FALSE, fig.height = 4, cache = TRUE} +# Test/explicand data +x_explain <- as.matrix(bike[-train_index, x_var]) +y_explain_nc <- as.matrix(bike[-train_index, y_var]) # not centered +y_explain <- y_explain_nc - mean(y_train_nc) + +# Get 6 explicands to plot the Shapley values of with a wide spread in their predicted outcome +n_index_x_explain <- 6 +index_x_explain <- order(y_explain)[seq(1, length(y_explain), length.out = n_index_x_explain)] +y_explain[index_x_explain] + +# Fit an XGBoost model to the training data +model <- xgboost::xgboost( + data = x_train, + label = y_train, + nround = 100, + verbose = FALSE +) + +# Save the phi0 +phi0 <- mean(y_train) + +# Look at the root mean squared error +sqrt(mean((predict(model, x_explain) - y_explain)^2)) +ggplot( + data.table("response" = y_explain[, 1], "predicted_response" = predict(model, x_explain)), + aes(response, predicted_response) +) + + geom_point() +``` + + +We are going to use the `causal_ordering` and `confounding` illustrated in the figures above. +For `causal_ordering`, we can either provide the index of feature or the feature names. +Thus, the following two versions of `causal_ordering` will produce equivalent results. +Furthermore, we assume that we have confounding for the second component (i.e., the season has +an effect on the weather) and no confounding for the third component (i.e., we do not +how to model the intricate relations between the weather features). 
+ +```{r causal_ordering, cache = TRUE} +causal_ordering <- list(1, c(2, 3), c(4:7)) +causal_ordering <- list("trend", c("cosyear", "sinyear"), c("temp", "atemp", "windspeed", "hum")) +confounding <- c(FALSE, TRUE, FALSE) +``` + + +To make the rest of the vignette easier to follow, we create some helper +functions that plot and summarize the results of the explanation methods. +This code block is optional to understand and can be skipped. + +```{r set_up_functions, cache = TRUE} +# Extract the MSEv criterion scores and elapsed times +print_MSEv_scores_and_time <- function(explanation_list) { + res <- as.data.frame(t(sapply( + explanation_list, + function(explanation) { + round(c( + explanation$MSEv$MSEv$MSEv, + explanation$MSEv$MSEv$MSEv_sd, + difftime(explanation$timing$end_time, explanation$timing$init_time, units = "secs") + ), 2) + } + ))) + colnames(res) <- c("MSEv", "MSEv_sd", "Time (secs)") + return(res) +} + +# Print the full time information +print_time <- function(explanation_list) { + t(sapply(explanation_list, function(explanation) explanation$timing$total_time_secs)) +} + +# Make beeswarm plots +plot_beeswarms <- function(explanation_list, title = "", ...) { + # Make the beeswarm plots + grobs <- lapply(seq(length(explanation_list)), function(explanation_idx) { + gg <- plot(explanation_list[[explanation_idx]], plot_type = "beeswarm", ...) 
+ + ggplot2::ggtitle(tools::toTitleCase(gsub("_", " ", names(explanation_list)[[explanation_idx]]))) + + # Flip the order such that the features comes in the right order + gg <- gg + + ggplot2::scale_x_discrete(limits = rev(levels(gg$data$variable)[levels(gg$data$variable) != "none"])) + }) + + # Get the limits + ylim <- sapply(grobs, function(grob) ggplot2::ggplot_build(grob)$layout$panel_scales_y[[1]]$range$range) + ylim <- c(min(ylim), max(ylim)) + + # Update the limits + grobs <- suppressMessages(lapply(grobs, function(grob) grob + ggplot2::coord_flip(ylim = ylim))) + + # Make the combined plot + gridExtra::grid.arrange( + grobs = grobs, ncol = 1, + top = grid::textGrob(title, gp = grid::gpar(fontsize = 18, font = 8)) + ) +} +``` + + + +## Symmetric conditional Shapley values (default) +We start by demonstrating how to compute symmetric conditional Shapley values. +This is the default version in `shapr` and there is no need to specify the arguments below. +However, we have specified them for the sake of clarity. +We use the `gaussian`, `ctree`, and `regression_separate`(`xgboost` with default hyperparameters) +approaches, but any other approach can also be used. 
+ + +```{r sym_con, cache = TRUE} +# list to store the results +explanation_sym_con <- list() + +explanation_sym_con[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + n_MC_samples = 1000, + asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) + +explanation_sym_con[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "ctree", + phi0 = phi0, + n_MC_samples = 1000, + asymmetric = FALSE, # Default value (TRUE will give the same since `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) + +explanation_sym_con[["xgboost"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + approach = "regression_separate", + regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), + asymmetric = FALSE, # Default value (TRUE will give the same as `causal_ordering = NULL`) + causal_ordering = NULL, # Default value + confounding = NULL # Default value +) +``` +We can then look at the $\operatorname{MSE}_v$ evaluation scores to compare the approaches. +All approaches are comparable, but `xgboost` is clearly the fastest approach. + +```{r, cache = TRUE} +print_MSEv_scores_and_time(explanation_sym_con) +``` + +We can then plot the Shapley values for the six explicands chosen above. + +```{r explanation_sym_con_SV, fig.height = 7, cache = TRUE} +plot_SV_several_approaches(explanation_sym_con, index_x_explain) + + theme(legend.position = "bottom") +``` + + + +We can also make beeswarm plots of the Shapley values to look at the structure +of the Shapley values for all explicands. The figures are quite similar, but +with minor differences. 
E.g., the `gaussian` approach produces almost no +Shapley values around $500$ for the `trend` feature. + +```{r explanation_sym_con_beeswarm, fig.height = 9, cache = TRUE} +plot_beeswarms(explanation_sym_con, title = "Symmetric conditional Shapley values") +``` + + + + +## Asymmetric conditional Shapley values +Then we look at the asymmetric conditional Shapley values. To obtain these +types of Shapley values, we have to specify that `asymmetric = TRUE` and a +`causal_ordering`. We use `causal_ordering = list(1, c(2, 3), c(4:7))`. + + +```{r asym_con_gaussian, cache = TRUE} +explanation_asym_con <- list() + +explanation_asym_con[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) + +explanation_asym_con[["gaussian_non_iterative"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL, # Default value + iterative = FALSE +) + +explanation_asym_con[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "ctree", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) + +explanation_asym_con[["xgboost"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + approach = "regression_separate", + regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = causal_ordering, + confounding = NULL # Default value +) +``` + +The asymmetric conditional Shapley value framework is 
faster as we only +consider $20$ coalitions (including empty and grand coalition) +instead of all $128$ coalitions (see code below). + +```{r, cache = TRUE} +print_MSEv_scores_and_time(explanation_asym_con) + +# Look at the number of coalitions considered. Decreased from 128 to 20. +explanation_sym_con$gaussian$internal$parameters$max_n_coalitions +explanation_asym_con$gaussian$internal$parameters$max_n_coalitions + +# Here we can see the 20 coalitions that respects the causal ordering +explanation_asym_con$gaussian$internal$objects$dt_valid_causal_coalitions[["coalitions"]] +``` + +We can then look at the beeswarm plots of the asymmetric conditional Shapley value. +The `ctree` and `xgboost` approaches produce similar figures, while the `gaussian` +approach both shrinks and groups the Shapley values for the `trend` feature, while +it produces more negative values for the `cosyear` feature. + +When going from symmetric to asymmetric Shapley values, we see that many of the features' +Shapley values are now shrunken closer to zero, especially `temp` and `atemp`. + +```{r explanation_asym_con_beeswarm, fig.height = 9, cache = TRUE} +plot_beeswarms(explanation_asym_con, title = "Asymmetric conditional Shapley values") +``` + + + +We can also compare the obtained symmetric and asymmetric conditional Shapley values +for the 6 explicands. We often see that the asymmetric version gives larger Shapley +values to the distal/root causes, i.e., `trend` and `cosyear`, than the symmetric +version. This is in line with Section 3.2 in @frye2020asymmetric. 
+```{r sym_and_asym_Shapley_values, fig.height = 7, cache = TRUE} +# Order the symmetric and asymmetric conditional explanations into a joint list +explanation_sym_con_tmp <- copy(explanation_sym_con) +names(explanation_sym_con_tmp) <- paste0(names(explanation_sym_con_tmp), "_sym") +explanation_asym_con_tmp <- copy(explanation_asym_con) +names(explanation_asym_con_tmp) <- paste0(names(explanation_asym_con_tmp), "_asym") +explanation_asym_sym_con <- c(explanation_sym_con_tmp, explanation_asym_con_tmp)[c(1, 4, 2, 5, 3, 6)] +plot_SV_several_approaches(explanation_asym_sym_con, index_x_explain, brewer_palette = "Paired") + + theme(legend.position = "bottom") +``` + + + + +## Symmetric marginal Shapley values +For marginal Shapley values, we can only consider the symmetric version as we must set +`causal_ordering = list(1:7)` (or `NULL`) and `confounding = TRUE`. Setting `asymmetric = TRUE` +will have no effect, as the causal ordering consists of only a single component containing all features, +i.e., all coalitions respect the causal ordering. As stated above, `shapr` generates the +marginal Monte Carlo samples from the Gaussian marginals if `approach = "gaussian"`, +while for all other Monte Carlo approaches the marginals are estimated from the training data, i.e., +assuming feature independence. Thus, it does not matter if we set `approach = "independence"` +or any other of the Monte Carlo-based approaches. We use `approach = "independence"` for clarity. +Furthermore, we also obtain marginal Shapley values by using the +conditional Shapley value framework with the `independence` approach. However, note that there will +be a minuscule difference in the produced Shapley values due to different sampling setups/orders. 
+ +```{r sym_marg, cache = TRUE} +explanation_sym_marg <- list() + +# Here we sample from the estimated Gaussian marginals +explanation_sym_marg[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + asymmetric = FALSE, + causal_ordering = list(1:7), + confounding = TRUE +) + +# Here we sample from the marginals of the training data +explanation_sym_marg[["independence_marg"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "independence", + asymmetric = FALSE, + causal_ordering = list(1:7), + confounding = TRUE +) + +# Here we use the conditional Shapley value framework with the `independence` approach +explanation_sym_marg[["independence_con"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "independence" +) +``` + + +We can look at the beeswarm plots. + +```{r explanation_sym_mar_beeswarm, fig.height = 9, cache = TRUE} +print_MSEv_scores_and_time(explanation_sym_marg) + +plot_beeswarms(explanation_sym_marg, title = "Symmetric marginal Shapley values") +``` + + + +## Causal Shapley values +To compute (symmetric/asymmetric) causal Shapley values, we have to provide +the `causal_ordering` and `confounding` objects. We set them to be +`causal_ordering = list(1, 2:3, 4:7)` and `confounding = c(FALSE, TRUE, FALSE)`, +as explained above. + +The causal framework takes longer than the other frameworks, as generating +the Monte Carlo samples often consists of a chain of sampling steps. For example, +for $\mathcal{S} = \{2\}$, we must generate $X_1,X_3,X_4,X_5,X_6,X_7 \mid X_2$. +However, we cannot do this directly due to the `causal_ordering` and `confounding` +specified above. To generate the Monte Carlo samples, we have to follow a chain of +sampling steps. 
More precisely, we first need to generate $X_1$ from the marginal, +then $X_3 \mid X_1$, and finally $X_4,X_5,X_6,X_7 \mid X_1,X_2,X_3$. The latter two +steps are done by using the provided `approach` to model the conditional distributions. +The `internal$iter_list[[1]]$S_causal_steps_strings` object contains the sampling steps +needed for the different feature combinations/coalitions $\mathcal{S}$. + +For causal Shapley values, only the Monte Carlo-based approaches are applicable. + +### Symmetric +```{r sym_cau, cache = TRUE} +explanation_sym_cau <- list() + +explanation_sym_cau[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE), + iterative = FALSE, # Set to FALSE to get a single iteration to illustrate sampling steps below + exact = TRUE +) + +# Look at the sampling steps for the third coalition (S = {2}) +explanation_sym_cau$gaussian$internal$iter_list[[1]]$S_causal_steps_strings$id_coalition_3 + +# Use the copula approach +explanation_sym_cau[["copula"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "copula", + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) +``` + + +```{r explanation_sym_cau_beeswarm, fig.height = 6} +print_MSEv_scores_and_time(explanation_sym_cau) +plot_beeswarms(explanation_sym_cau, title = "Symmetric causal Shapley values") +``` + + + + +### Asymmetric +We now turn to asymmetric causal Shapley values. That is, we only use the coalitions +that respect the causal ordering. Thus, the computations are faster as the number of +coalitions is reduced. 
+ +```{r asym_cau, cache = TRUE} +explanation_asym_cau <- list() + +explanation_asym_cau[["gaussian"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "gaussian", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) + +# Use the copula approach +explanation_asym_cau[["copula"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "copula", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) + +# Use the ctree approach (warning: ctree is slow) +explanation_asym_cau[["ctree"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "ctree", + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) +) + +# Use the vaeac approach +explanation_asym_cau[["vaeac"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + phi0 = phi0, + n_MC_samples = 1000, + approach = "vaeac", + vaeac.epochs = 20, + paired_shap_sampling = FALSE, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE) + ) +``` +We can look at the elapsed time. We see that `ctree` is much slower than the other approaches. +See the [implementation details](#Implementation_details) for an explanation. +```{r, cache = TRUE} +print_time(explanation_asym_cau) +``` +We can then plot the beeswarm plots. We see that `ctree` provides more spread out Shapley values for the `trend` feature. 
+ + +```{r explanation_asym_cau_beeswarm, fig.height = 9, cache = TRUE} +# Plot the beeswarm plots +plot_beeswarms(explanation_asym_cau, title = "Asymmetric causal Shapley values") +``` + +```{r explanation_asym_cau_SV, fig.height = 8, cache = TRUE} +# Plot the Shapley values +plot_SV_several_approaches(explanation_asym_cau, index_x_explain) + + theme(legend.position = "bottom") +``` + +We can also use the other Monte Carlo-based approaches (`independence` and `empirical`). + + + +## Comparing the frameworks +Here we plot the obtained Shapley values for the six explicands when using the +`gaussian` approach in the different Shapley value explanation frameworks, and +we see that the different frameworks provide different explanations. +The largest difference is +whether we use the symmetric or asymmetric version. To summarize, asymmetric +conditional/causal Shapley values focus on the root cause, marginal Shapley +values on the more direct effect, and symmetric conditional/causal Shapley +values consider both for a more natural explanation. + +```{r compare_plots, cache = TRUE, fig.height = 8, cache = TRUE} +explanation_gaussian <- list( + symmetric_marginal = explanation_sym_marg$gaussian, + symmetric_conditional = explanation_sym_con$gaussian, + symmetric_causal = explanation_sym_cau$gaussian, + asymmetric_conditional = explanation_asym_con$gaussian, + asymmetric_causal = explanation_asym_cau$gaussian +) + +plot_SV_several_approaches(explanation_gaussian, index_x_explain) + + theme(legend.position = "bottom") + + guides(fill = guide_legend(nrow = 2)) + + ggtitle("Shapley value prediction explanation (approach = 'gaussian')") + + guides(color = guide_legend(title = "Framework")) +``` + +## Scatter plots: marginal vs. causal Shapley values +In this section, we produce scatter plots comparing the symmetric marginal +and symmetric causal Shapley values for the temperature feature `temp` and +the seasonal feature `cosyear` for all explicands. 
The plots shows that the +marginal Shapley values almost purely explain the predictions based on +temperature, while the causal Shapley values also give credit to season. +We can change the features and frameworks in the code below, but we chose +these values to replicate Figure 3 in @heskes2020causal. + + +```{r scatter_plots, cache = TRUE, fig.height = 6, cache = TRUE} +# The color of the points +color <- "temp" + +# The features we want to compare +feature_1 <- "cosyear" +feature_2 <- "temp" + +# The Shapley value frameworks we want to compare +sv_framework_1 <- explanation_sym_marg[["gaussian"]] +sv_framework_1_str <- "Marginal SV" +sv_framework_2 <- explanation_sym_cau[["gaussian"]] +sv_framework_2_str <- "Causal SV" + +# Set up the data.frame we are going to plot +sv_correlation_df <- data.frame( + color = x_explain[, color], + sv_framework_1_feature_1 = sv_framework_1$shapley_values_est[[feature_1]], + sv_framework_2_feature_1 = sv_framework_2$shapley_values_est[[feature_1]], + sv_framework_1_feature_2 = sv_framework_1$shapley_values_est[[feature_2]], + sv_framework_2_feature_2 = sv_framework_2$shapley_values_est[[feature_2]] +) + +# Make the plots +scatterplot_topleft <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_1_feature_2, y = sv_framework_1_feature_1, color = color) + ) + + geom_point(size = 1) + + xlab(paste(sv_framework_1_str, feature_2)) + + ylab(paste(sv_framework_1_str, feature_1)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) + + scale_color_gradient(low = "blue", high = "red") + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_blank(), + axis.text.y = element_text(size = 12), + axis.ticks.x = element_blank(), + axis.title.x = element_blank() + ) + +scatterplot_topright <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_2_feature_1, y = sv_framework_1_feature_1, color = color) + ) + + 
geom_point(size = 1) + + scale_color_gradient(low = "blue", high = "red") + + xlab(paste(sv_framework_2_str, feature_1)) + + ylab(paste(sv_framework_1_str, feature_1)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-500, 500), breaks = c(-500, 0, 500)) + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.title.x = element_blank(), + axis.title.y = element_blank(), + axis.text.x = element_blank(), + axis.ticks.x = element_blank(), + axis.text.y = element_blank(), + axis.ticks.y = element_blank() + ) + +scatterplot_bottomleft <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_1_feature_2, y = sv_framework_2_feature_2, color = color) + ) + + geom_point(size = 1) + + scale_color_gradient(low = "blue", high = "red") + + xlab(paste(sv_framework_1_str, feature_2)) + + ylab(paste(sv_framework_2_str, feature_2)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_text(size = 12), + axis.text.y = element_text(size = 12) + ) + +scatterplot_bottomright <- + ggplot( + sv_correlation_df, + aes(x = sv_framework_2_feature_1, y = sv_framework_2_feature_2, color = color) + ) + + geom_point(size = 1) + + xlab(paste(sv_framework_2_str, feature_1)) + + ylab(paste(sv_framework_2_str, feature_2)) + + scale_x_continuous(limits = c(-1500, 1000), breaks = c(-1000, 0, 1000)) + + scale_y_continuous(limits = c(-1000, 1000), breaks = c(-500, 0, 500)) + + scale_color_gradient(low = "blue", high = "red") + + theme_minimal() + + theme( + text = element_text(size = 12), + axis.text.x = element_text(size = 12), + axis.title.y = element_blank(), + axis.text.y = element_blank(), + axis.ticks.y = element_blank() + ) + +# Plot of the trend of the data +bike_plot_new <- ggplot(bike, aes(x = trend, y = cnt, color = get(color))) + + 
geom_point(size = 0.75) + + scale_color_gradient(low = "blue", high = "red") + + labs(color = color) + + xlab("Days since 1 January 2011") + + ylab("Number of bikes rented") + + theme_minimal() + + theme(legend.position = "right", legend.title = element_text(size = 10)) + +# Combine the plots +ggpubr::ggarrange( + bike_plot_new, + ggpubr::ggarrange( + scatterplot_topleft, + scatterplot_topright, + scatterplot_bottomleft, + scatterplot_bottomright, + legend = "none" + ), + nrow = 2, heights = c(1, 2) +) +``` + +## Investigating two similar days + +We investigate the difference between symmetric/asymmetric conditional, +symmetric/asymmetric causal, and marginal Shapley values for two days: +October 10 and December 3, 2012. They have more or less the same +temperature of 13 and 13.27 degrees Celsius, and predicted bike counts +of 6117 and 6241, respectively. The figure below is an extension of +Figure 4 in @heskes2020causal, as they only included asymmetric +conditional, symmetric causal, and marginal Shapley values. + +We plot the various Shapley values for the `cosyear` and `temp` features +below. We obtain the same results as @heskes2020causal, namely, +that the marginal Shapley value explanation framework provides similar +explanations for both days. I.e., it only considers the direct effect of `temp`. +The asymmetric conditional and causal Shapley values are almost +indistinguishable and put the most weight on the ‘root’ cause `cosyear`. +@heskes2020causal state that the symmetric causal Shapley values provide +a sensible balance between the two extremes and give credit to both season and temperature, +but still yield different explanations for the two days. + +However, as we also include symmetric conditional Shapley values, +we see that they are extremely similar to symmetric causal Shapley values. +I.e., the conditional Shapley value explanation framework also provides +a sensible balance between marginal and asymmetric Shapley values. 
+To summarize: +as concluded by @heskes2020causal in their Figure 4, the +asymmetric conditional/causal Shapley values focus on the +root cause, marginal Shapley values on the more direct effect, and symmetric +conditional/causal Shapley consider both for a more natural explanation. + +```{r two_dates_1, cache = TRUE, fig.height = 5, cache = TRUE} +# Features of interest +features <- c("cosyear", "temp") + +# Get explicands with similar temperature: 2012-10-09 (October) and 2012-12-03 (December) +dates <- c("2012-10-09", "2012-12-03") +dates_idx <- sapply(dates, function(data) which(as.integer(row.names(x_explain)) == which(bike$dteday == data))) +# predict(model, x_explain)[dates_idx] + mean(y_train_nc) # predicted values for the two points + +# List of the Shapley value explanations +explanations <- list( + "Sym. Mar." = explanation_sym_marg[["gaussian"]], + "Sym. Con." = explanation_sym_con[["gaussian"]], + "Sym. Cau." = explanation_sym_cau[["gaussian"]], + "Asym. Con." = explanation_asym_con[["gaussian"]], + "Asym. Cau." 
= explanation_asym_cau[["gaussian"]] +) + +# Extract the relevant Shapley values +explanations_extracted <- data.table::rbindlist(lapply(seq_along(explanations), function(idx) { + explanations[[idx]]$shapley_values_est[ + dates_idx, ..features + ][, `:=`(Date = dates, type = names(explanations)[idx])] +})) + +# Set type to be a ordered factor +explanations_extracted[, type := factor(type, levels = names(explanations), ordered = TRUE)] + +# Convert from wide to long data table +dt_all <- data.table::melt(explanations_extracted, + id.vars = c("Date", "type"), + variable.name = "feature" +) + +# Make the plot +ggplot(dt_all, aes( + x = feature, y = value, group = interaction(Date, feature), + fill = Date, label = round(value, 2) +)) + + geom_col(position = "dodge") + + theme_classic() + + ylab("Shapley value") + + facet_wrap(vars(type)) + + theme(axis.title.x = element_blank()) + + scale_fill_manual(values = c("indianred4", "ivory4")) + + theme( + legend.position.inside = c(0.75, 0.25), axis.title = element_text(size = 20), + legend.title = element_text(size = 16), legend.text = element_text(size = 14), + axis.text.x = element_text(size = 12), axis.text.y = element_text(size = 12), + strip.text.x = element_text(size = 14) + ) +``` + +We can also make a similar plot using the `plot_SV_several_approaches` function in `shapr`, +but then we get each explicand in a separate facet instead of a facet for each framework. +```{r two_dates_2, cache = TRUE, fig.height = 4, cache = TRUE} +# Here 2012-10-09 is the left facet and 2012-12-03 the right facet +plot_SV_several_approaches(explanations, + index_explicands = dates_idx, + only_these_features = features, # Can include more features. 
+ facet_scales = "free_x", + horizontal_bars = FALSE, + axis_labels_n_dodge = 1 +) + theme(legend.position = "bottom") +``` + +Furthermore, instead of doing as @heskes2020causal and only considering the features +`cosyear` and `temp`, we can plot all features, too, to get a more complete overview. +```{r two_dates_3, cache = TRUE, fig.height = 5, cache = TRUE} +# Here 2012-10-09 is the left facet and 2012-12-03 the right facet +plot_SV_several_approaches(explanations, + index_explicands = dates_idx, + facet_scales = "free_x", + horizontal_bars = FALSE, + axis_labels_rotate_angle = 45, + digits = 2 +) + theme(legend.position = "bottom") +``` + + +## Sampling of coalitions + +We can use `max_n_coalitions` to specify/reduce the number of coalitions +to use when computing the Shapley value explanations. This applies +to marginal, conditional, and causal Shapley values, both the symmetric and +asymmetric versions. However, recall that the asymmetric versions already +have fewer valid coalitions due to the causal ordering. + +In the example below, we demonstrate the sampling of coalitions for the +asymmetric and symmetric causal Shapley value explanation frameworks. +We halve the number of coalitions for both versions +and see that the elapsed times are approximately halved, too. 
+```{r n_coalitions, cache = TRUE, cache = TRUE} +explanation_n_coal <- list() + +explanation_n_coal[["sym_cau_gaussian_64"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE), + max_n_coalitions = 64 # Instead of 128 +) + +explanation_n_coal[["asym_cau_gaussian_10"]] <- explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = list(1, 2:3, 4:7), + confounding = c(FALSE, TRUE, FALSE), + paired_shap_sampling = FALSE, + verbose = c("basic", "convergence", "shapley"), + max_n_coalitions = 10 # Instead of 20 +) + +# Look at the times +explanation_n_coal[["sym_cau_gaussian_all_128"]] <- explanation_sym_cau$gaussian +explanation_n_coal[["asym_cau_gaussian_all_20"]] <- explanation_asym_cau$gaussian +explanation_n_coal <- explanation_n_coal[c(1, 3, 2, 4)] +print_time(explanation_n_coal) +``` + +We can then plot the beeswarm plots and the Shapley values for the six selected explicands. +We see that there are only minuscule differences between the Shapley values we obtain when we use +all the coalitions and those we obtain when we use half of the valid coalitions. + +```{r n_coalitions_plot_beeswarm, cache = TRUE, fig.height = 12, cache = TRUE} +plot_beeswarms(explanation_n_coal, title = "Shapley values (gaussian) exact vs. approximation") +``` + +```{r n_coalitions_plot_SV, cache = TRUE, fig.height = 8, cache = TRUE} +plot_SV_several_approaches(explanation_n_coal, index_x_explain) + + theme(legend.position = "bottom") + + guides(fill = guide_legend(nrow = 2)) +``` + + + +## Groups of features +In this section, we demonstrate that we can compute marginal, asymmetric +conditional, and symmetric/asymmetric Shapley values for groups of features, too. 
+For group Shapley values, we need to specify the causal ordering on the group level +and feature level. We demonstrate with the `gaussian` approach, but other approaches +are applicable, too. + +In the pairs plot above (and below), we see that it can be natural to group the +features `temp` and `atemp` due to their (conceptual) similarity and high correlation. + +```{r group_cor, cache = TRUE, fig.height = 4, cache = TRUE} +GGally::ggpairs(x_train[, 4:5]) +``` + +We set up the groups and update the causal ordering to be on the group level. +```{r group_group, cache = TRUE, cache = TRUE} +group_list <- list( + trend = "trend", + cosyear = "cosyear", + sinyear = "sinyear", + temp_group = c("temp", "atemp"), + windspeed = "windspeed", + hum = "hum" +) + +causal_ordering_group <- + list("trend", c("cosyear", "sinyear"), c("temp_group", "windspeed", "hum")) +confounding <- c(FALSE, TRUE, FALSE) +``` + + +We can then compute the (group) Shapley values using the different Shapley value frameworks. 
+```{r group_gaussian, cache = TRUE, cache = TRUE} +explanation_group_gaussian <- list() + +explanation_group_gaussian[["symmetric_marginal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(seq(length(group_list))), # or `NULL` + confounding = TRUE, + n_MC_samples = 1000, + group = group_list + ) + +explanation_group_gaussian[["symmetric_conditional"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = list(seq(length(group_list))), # or `NULL` + confounding = NULL, + n_MC_samples = 1000, + group = group_list + ) + +explanation_group_gaussian[["asymmetric_conditional"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = causal_ordering_group, + confounding = NULL, + paired_shap_sampling = FALSE, + n_MC_samples = 1000, + group = group_list + ) + +explanation_group_gaussian[["symmetric_causal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = FALSE, + causal_ordering = causal_ordering_group, + confounding = confounding, + n_MC_samples = 1000, + group = group_list + ) + +explanation_group_gaussian[["asymmetric_causal"]] <- + explain( + model = model, + x_train = x_train, + x_explain = x_explain, + approach = "gaussian", + phi0 = phi0, + asymmetric = TRUE, + causal_ordering = causal_ordering_group, + confounding = confounding, + paired_shap_sampling = FALSE, + n_MC_samples = 1000, + group = group_list + ) + +# Look at the elapsed times (symmetric takes the longest time) +print_time(explanation_group_gaussian) +``` + +We can then make the beeswarm plots and Shapley values plots for the six selected explicands. 
+For the beeswarm plots, we set `include_group_feature_means = TRUE` to make the plots. +This means that the plot function uses the mean of the `temp` and `atemp` features as the feature +value. This only makes sense due to the high correlation between the two features. + +The main difference between the feature-wise and group-wise Shapley values +is that we now see a much wider spread in the Shapley values for `temp_group` +than we did for `temp` and `atemp`. +For example, for the symmetric causal framework, we saw above that `temp` and `atemp` +obtained Shapley values between (around) $-500$ and $500$, while the grouped version +`temp_group` obtains Shapley values between $-1000$ and $1000$. + + +```{r group_gaussian_plot_beeswarm, cache = TRUE, fig.height = 15, fig.width = 7.2, cache = TRUE} +plot_beeswarms(explanation_group_gaussian, + title = "Group Shapley values (gaussian)", + include_group_feature_means = TRUE +) +``` + + +```{r group_gaussian_plot_SV, cache = TRUE, fig.height = 8, fig.width = 7.2, cache = TRUE} +plot_SV_several_approaches(explanation_group_gaussian, index_x_explain) + + ggtitle("Shapley value prediction explanation (gaussian)") + + theme(legend.position = "bottom") + guides(fill = guide_legend(nrow = 2)) +``` + + + + + +## Implementation details + +The `shapr` package is built to estimate conditional Shapley values, thus, +it parallelizes over the coalitions. This makes perfect sense for said +framework as each batch of coalitions is independent of other batches, +which means that it is easy to parallelize. Furthermore, by using many +batches we drastically reduce the memory usage as `shapr` does not need +to store the Monte Carlo samples for all coalitions. + +This setup is not optimal for the causal Shapley value framework as the +chains of sampling steps for two coalitions $\mathcal{S}$ and $\mathcal{S}^*$ +can contain many of the same steps. 
Ideally, each unique sampling step +should only be modeled once to save computation time, but some of the +sampling steps will occur in many of the chains. Thus, we would then have +to store the Monte Carlo samples for all coalitions where this sampling +step is included, and we can therefore run into memory consumption problems. +Thus, in the current implementation, we treat each coalition $\mathcal{S}$ +independently and remodel the needed sampling steps for each coalition. + +Furthermore, in the conditional Shapley value framework, we have that +$\bar{\mathcal{S}} = \mathcal{M} \backslash \mathcal{S}$, thus `shapr` +will by default generate Monte Carlo samples for all features not in +$\mathcal{S}$. For the causal Shapley value framework, this is not the +case, i.e., $\bar{\mathcal{S}} \neq \mathcal{M} \backslash \mathcal{S}$ +in general. To reuse the code, we generate Monte Carlo samples for all +features not in $\mathcal{S}$, but only keep the samples for the features +in $\bar{\mathcal{S}}$. To speed up `shapr` further, one could rewrite +all the approaches to support that $\bar{\mathcal{S}}$ is not +the complement of $\mathcal{S}$. + +In the code below, we see the unique coalitions/sets of features to condition +on to generate the Monte Carlo samples for all coalitions and the number of +times that set of conditional features is needed in the symmetric causal Shapley +value framework for the setup above. We see that most of the conditional +distributions will now be remodeled eight times. For the `gaussian` approach, +which is very fast at estimating the conditional distributions, this does not +have a major impact on the time. However, for, e.g., the `ctree` approach which +is much slower, this will take a significant amount of extra time. The `vaeac` +approach trains only on these relevant coalitions. 
+```{r implementation_details, cache = TRUE} +S_causal_steps <- explanation_sym_cau$gaussian$internal$iter_list[[1]]$S_causal_steps +S_causal_unlist <- do.call(c, unlist(S_causal_steps, recursive = FALSE)) +S_causal_steps_freq <- S_causal_unlist[grepl("\\.S(?!bar)", names(S_causal_unlist), perl = TRUE)] +S_causal_steps_freq <- S_causal_steps_freq[!sapply(S_causal_steps_freq, is.null)] # Remove NULLs +S_causal_steps_freq <- S_causal_steps_freq[sapply(S_causal_steps_freq, length) > 0] # Remove extra integer(0) +table(sapply(S_causal_steps_freq, paste0, collapse = ",")) +``` + +The `independence`, `empirical`, `ctree`, and `categorical` approaches produce +weighted Monte Carlo samples. That means that they do not necessarily generate +`n_MC_samples`. To ensure `n_MC_samples`, we sample `n_MC_samples` samples using weighted +sampling with replacements where the weights are the weights returned by the approaches. + +The marginal Shapley value explanation framework can be extended to +support modeling the marginal distributions using the `copula` and +`vaeac` approaches as both of these methods support unconditional sampling. + + +# References diff --git a/vignettes/understanding_shapr_regression.Rmd b/vignettes/understanding_shapr_regression.Rmd index 964ae972bf8498c71dc6f0f5dabdee4187ea4943..84cd223d792f0b1647c91d72727c751e7410147a 100644 --- a/vignettes/understanding_shapr_regression.Rmd +++ b/vignettes/understanding_shapr_regression.Rmd @@ -167,7 +167,7 @@ the `tidymodels` framework: `parsnip`, `recipes`, `workflows`, which package the functions originate from in the `tidymodels` framework. -```r +``` r # Either use `library(tidymodels)` or separately specify the libraries indicated above library(tidymodels) @@ -213,7 +213,7 @@ functions that plot and summarize the results of the explanation methods. This code block is optional to understand and can be skipped. 
-```r +``` r # Plot the MSEv criterion scores as horizontal bars and add dashed line of one method's score plot_MSEv_scores <- function(explanation_list, method_line = NULL) { fig <- plot_MSEv_eval_crit(explanation_list) + @@ -256,18 +256,32 @@ with default hyperparameters. In the last section, we include all Monte Carlo-based methods implemented in `shapr` to make an extensive comparison. -```r +``` r # Compute the Shapley value explanations using the empirical method explanation_list$MC_empirical <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_batches = 4 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:09:54 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: empirical +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553378f592d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -276,18 +290,32 @@ Then we compute the Shapley value explanations using a linear regression model and the separate regression method class. -```r +``` r explanation_list$sep_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg() ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:00 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533c20e191.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` A linear model is often not flexible enough to properly model the @@ -297,7 +325,7 @@ outperforms the linear regression model approach quite significantly concerning the $\operatorname{MSE}_v$ evaluation criterion. -```r +``` r plot_MSEv_scores(explanation_list) ``` @@ -335,13 +363,12 @@ the feature itself. This regression model is called principal component regression. -```r +``` r explanation_list$sep_pcr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -350,6 +377,21 @@ explanation_list$sep_pcr <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:01 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55318b105b2.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
``` Second, we apply a pre-processing step that computes the basis @@ -357,13 +399,12 @@ expansions of the features using natural splines with two degrees of freedom. This is similar to fitting a generalized additive model. -```r +``` r explanation_list$sep_splines <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -372,6 +413,21 @@ explanation_list$sep_splines <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:02 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553a209912.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` Finally, we provide an example where we include interactions @@ -385,7 +441,7 @@ Furthermore, we stress that the purpose of this example is to highlight the framework's flexibility, NOT that the transformations below are reasonable. 
-```r +``` r # Example function of how to apply step functions from the recipes package to specific features regression.recipe_func <- function(recipe) { # Get the names of the present features @@ -419,14 +475,28 @@ explanation_list$sep_reicpe_example <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = regression.recipe_func ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:03 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55334f61d01.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` We can examine the $\operatorname{MSE}_v$ evaluation scores, and we @@ -434,23 +504,23 @@ see that the method using natural splines significantly outperforms the other methods. 
-```r +``` r # Compare the MSEv criterion of the different explanation methods plot_MSEv_scores(explanation_list, method_line = "MC_empirical") ``` ![](figure_regression/preproc-plot-1.png) -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 ``` @@ -472,14 +542,13 @@ we see that the default hyperparameter values for the model are `tree_depth = 30`, `min_n = 2`, and `cost_complexity = 0.01`. -```r +``` r # Decision tree with specified parameters (stumps) explanation_list$sep_tree_stump <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = 1, @@ -491,19 +560,48 @@ explanation_list$sep_tree_stump <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:04 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553108eb1.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Decision tree with default parameters explanation_list$sep_tree_default <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:05 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5534d028986.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` We can also set `regression.model = parsnip::decision_tree(tree_depth = 1, min_n = 2, cost_complexity = 0.01) %>% parsnip::set_engine("rpart") %>% parsnip::set_mode("regression")` @@ -516,24 +614,24 @@ the empirical approach. We obtained a worse method by using stumps, i.e., trees with depth one. 
-```r +``` r # Compare the MSEv criterion of the different explanation methods plot_MSEv_scores(explanation_list, method_line = "MC_empirical") ``` ![](figure_regression/decision-tree-plot-1.png) -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 ``` @@ -573,7 +671,7 @@ Note that `dials` have several other grid functions, e.g., `dials::grid_random() and `dials::grid_latin_hypercube()`. -```r +``` r # Possible ways to define the `regression.tune_values` object. # function(x) dials::grid_regular(dials::tree_depth(), levels = 4) dials::grid_regular(dials::tree_depth(), levels = 4) @@ -594,14 +692,13 @@ both the `tree_depth` and `cost_complexity` parameters, but we will manually specify the possible hyperparameter values this time. -```r +``` r # Decision tree with cross validated depth (default values other parameters) explanation_list$sep_tree_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), engine = "rpart", mode = "regression" @@ -611,14 +708,28 @@ explanation_list$sep_tree_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:06 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553173580dc.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Use trees with cross-validation on the depth and cost complexity. Manually set the values. explanation_list$sep_tree_cv_2 <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -632,6 +743,21 @@ explanation_list$sep_tree_cv_2 <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:19 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531b0af982.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` We also include one example with a random forest model where @@ -640,34 +766,46 @@ Thus, `regression.tune_values` must be a function that returns a data.frame where the hyperparameter values for `mtry` will change based on the coalition size. If we do not let `regression.tune_values` be a function, then `tidymodels` will crash for any `mtry` higher -than 1. 
Furthermore, by setting `verbose = 2`, we receive messages -about which batch and coalition/combination that `shapr` processes -and the results of the cross-validation procedure. Note that the tested -hyperparameter value combinations change based on the coalition size. +than 1. Furthermore, by letting `"vS_details" %in% verbose`, +we get messages with the results of the cross-validation procedure run within `shapr`. +Note that the tested hyperparameter value combinations change based on the coalition size. -```r +``` r # Using random forest with default parameters explanation_list$sep_rf <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:45 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5535c02b48f.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions.
# Using random forest with parameters tuned by cross-validation explanation_list$sep_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 1, # One batch to get printouts in chronological order - verbose = 2, # To get printouts + phi0 = p0, + verbose = c("basic","vS_details"), # To get printouts approach = "regression_separate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -680,139 +818,142 @@ explanation_list$sep_rf_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Starting 'setup_approach.regression_separate'. -#> When using `approach = 'regression_separate'` the `explanation$timing$timing_secs` object -#> can be missleading as `setup_computation` does not contain the training times of the -#> regression models as they are trained on the fly in `compute_vS`. This is to reduce memory -#> usage and to improve efficency. -#> Done with 'setup_approach.regression_separate'. -#> Working on batch 1 of 1 in `prepare_data.regression_separate()`. -#> Working on combination with id 2 of 16. -#> Results of the 5-fold cross validation (top 3 best configurations): -#> #1: mtry = 1 trees = 750 rmse = 34.85 rmse_std_err = 2.99 -#> #2: mtry = 1 trees = 400 rmse = 34.95 rmse_std_err = 3.05 -#> #3: mtry = 1 trees = 50 rmse = 34.99 rmse_std_err = 2.81 -#> -#> Working on combination with id 3 of 16. -#> Results of the 5-fold cross validation (top 3 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 27.48 rmse_std_err = 1.50 -#> #2: mtry = 1 trees = 750 rmse = 27.52 rmse_std_err = 1.29 -#> #3: mtry = 1 trees = 400 rmse = 27.74 rmse_std_err = 1.30 -#> -#> Working on combination with id 4 of 16. 
-#> Results of the 5-fold cross validation (top 3 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 23.60 rmse_std_err = 3.17 -#> #2: mtry = 1 trees = 750 rmse = 23.63 rmse_std_err = 3.17 -#> #3: mtry = 1 trees = 50 rmse = 24.24 rmse_std_err = 3.37 -#> -#> Working on combination with id 5 of 16. -#> Results of the 5-fold cross validation (top 3 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 33.31 rmse_std_err = 2.81 -#> #2: mtry = 1 trees = 750 rmse = 33.34 rmse_std_err = 2.81 -#> #3: mtry = 1 trees = 50 rmse = 33.41 rmse_std_err = 2.87 -#> -#> Working on combination with id 6 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 21.25 rmse_std_err = 2.24 -#> #2: mtry = 1 trees = 400 rmse = 21.69 rmse_std_err = 2.38 -#> #3: mtry = 1 trees = 750 rmse = 21.81 rmse_std_err = 2.40 -#> #4: mtry = 2 trees = 400 rmse = 22.38 rmse_std_err = 2.11 -#> #5: mtry = 2 trees = 750 rmse = 22.68 rmse_std_err = 2.04 -#> #6: mtry = 2 trees = 50 rmse = 22.91 rmse_std_err = 1.97 -#> -#> Working on combination with id 7 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 2 trees = 50 rmse = 22.18 rmse_std_err = 2.93 -#> #2: mtry = 2 trees = 400 rmse = 22.28 rmse_std_err = 2.74 -#> #3: mtry = 1 trees = 750 rmse = 22.31 rmse_std_err = 2.90 -#> #4: mtry = 2 trees = 750 rmse = 22.35 rmse_std_err = 2.76 -#> #5: mtry = 1 trees = 400 rmse = 22.40 rmse_std_err = 2.80 -#> #6: mtry = 1 trees = 50 rmse = 22.62 rmse_std_err = 2.71 -#> -#> Working on combination with id 8 of 16. 
-#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 29.35 rmse_std_err = 2.17 -#> #2: mtry = 1 trees = 400 rmse = 29.45 rmse_std_err = 2.37 -#> #3: mtry = 1 trees = 750 rmse = 29.57 rmse_std_err = 2.32 -#> #4: mtry = 2 trees = 750 rmse = 30.43 rmse_std_err = 2.21 -#> #5: mtry = 2 trees = 400 rmse = 30.49 rmse_std_err = 2.18 -#> #6: mtry = 2 trees = 50 rmse = 30.51 rmse_std_err = 2.19 -#> -#> Working on combination with id 9 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 750 rmse = 18.61 rmse_std_err = 1.56 -#> #2: mtry = 2 trees = 400 rmse = 18.63 rmse_std_err = 1.56 -#> #3: mtry = 1 trees = 400 rmse = 18.80 rmse_std_err = 1.55 -#> #4: mtry = 2 trees = 750 rmse = 19.00 rmse_std_err = 1.70 -#> #5: mtry = 1 trees = 50 rmse = 19.02 rmse_std_err = 1.86 -#> #6: mtry = 2 trees = 50 rmse = 19.50 rmse_std_err = 1.72 -#> -#> Working on combination with id 10 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 23.61 rmse_std_err = 1.61 -#> #2: mtry = 1 trees = 50 rmse = 23.72 rmse_std_err = 1.49 -#> #3: mtry = 1 trees = 750 rmse = 23.79 rmse_std_err = 1.64 -#> #4: mtry = 2 trees = 750 rmse = 23.86 rmse_std_err = 0.83 -#> #5: mtry = 2 trees = 400 rmse = 23.91 rmse_std_err = 0.80 -#> #6: mtry = 2 trees = 50 rmse = 24.74 rmse_std_err = 0.68 -#> -#> Working on combination with id 11 of 16. -#> Results of the 5-fold cross validation (top 6 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 22.99 rmse_std_err = 4.29 -#> #2: mtry = 1 trees = 750 rmse = 23.08 rmse_std_err = 4.33 -#> #3: mtry = 1 trees = 50 rmse = 23.16 rmse_std_err = 4.28 -#> #4: mtry = 2 trees = 50 rmse = 23.80 rmse_std_err = 3.70 -#> #5: mtry = 2 trees = 400 rmse = 23.85 rmse_std_err = 3.72 -#> #6: mtry = 2 trees = 750 rmse = 24.07 rmse_std_err = 3.79 -#> -#> Working on combination with id 12 of 16. 
-#> Results of the 5-fold cross validation (top 9 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 16.86 rmse_std_err = 2.19 -#> #2: mtry = 1 trees = 400 rmse = 16.90 rmse_std_err = 1.83 -#> #3: mtry = 1 trees = 750 rmse = 16.91 rmse_std_err = 1.93 -#> #4: mtry = 2 trees = 50 rmse = 17.47 rmse_std_err = 1.47 -#> #5: mtry = 2 trees = 750 rmse = 17.53 rmse_std_err = 1.77 -#> #6: mtry = 2 trees = 400 rmse = 17.82 rmse_std_err = 1.67 -#> #7: mtry = 3 trees = 50 rmse = 18.03 rmse_std_err = 1.84 -#> #8: mtry = 3 trees = 750 rmse = 18.47 rmse_std_err = 1.91 -#> #9: mtry = 3 trees = 400 rmse = 18.49 rmse_std_err = 1.82 -#> -#> Working on combination with id 13 of 16. -#> Results of the 5-fold cross validation (top 9 best configurations): -#> #1: mtry = 1 trees = 50 rmse = 19.27 rmse_std_err = 2.13 -#> #2: mtry = 2 trees = 750 rmse = 19.80 rmse_std_err = 1.59 -#> #3: mtry = 1 trees = 750 rmse = 20.03 rmse_std_err = 1.95 -#> #4: mtry = 2 trees = 400 rmse = 20.21 rmse_std_err = 1.59 -#> #5: mtry = 3 trees = 50 rmse = 20.42 rmse_std_err = 1.64 -#> #6: mtry = 1 trees = 400 rmse = 20.49 rmse_std_err = 2.13 -#> #7: mtry = 2 trees = 50 rmse = 20.59 rmse_std_err = 1.26 -#> #8: mtry = 3 trees = 400 rmse = 20.61 rmse_std_err = 1.68 -#> #9: mtry = 3 trees = 750 rmse = 20.85 rmse_std_err = 1.74 -#> -#> Working on combination with id 14 of 16. 
-#> Results of the 5-fold cross validation (top 9 best configurations): -#> #1: mtry = 1 trees = 750 rmse = 21.96 rmse_std_err = 3.12 -#> #2: mtry = 1 trees = 400 rmse = 22.36 rmse_std_err = 2.96 -#> #3: mtry = 1 trees = 50 rmse = 22.53 rmse_std_err = 3.01 -#> #4: mtry = 2 trees = 750 rmse = 22.59 rmse_std_err = 2.53 -#> #5: mtry = 2 trees = 400 rmse = 22.76 rmse_std_err = 2.39 -#> #6: mtry = 2 trees = 50 rmse = 22.80 rmse_std_err = 2.41 -#> #7: mtry = 3 trees = 400 rmse = 23.19 rmse_std_err = 2.26 -#> #8: mtry = 3 trees = 750 rmse = 23.42 rmse_std_err = 2.07 -#> #9: mtry = 3 trees = 50 rmse = 23.69 rmse_std_err = 2.22 -#> -#> Working on combination with id 15 of 16. -#> Results of the 5-fold cross validation (top 9 best configurations): -#> #1: mtry = 1 trees = 400 rmse = 18.33 rmse_std_err = 2.07 -#> #2: mtry = 1 trees = 750 rmse = 18.59 rmse_std_err = 2.25 -#> #3: mtry = 2 trees = 750 rmse = 18.78 rmse_std_err = 1.59 -#> #4: mtry = 2 trees = 400 rmse = 18.81 rmse_std_err = 1.58 -#> #5: mtry = 3 trees = 50 rmse = 18.93 rmse_std_err = 1.53 -#> #6: mtry = 3 trees = 400 rmse = 19.11 rmse_std_err = 1.57 -#> #7: mtry = 3 trees = 750 rmse = 19.17 rmse_std_err = 1.71 -#> #8: mtry = 2 trees = 50 rmse = 19.18 rmse_std_err = 1.33 -#> #9: mtry = 1 trees = 50 rmse = 19.94 rmse_std_err = 2.02 +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:10:46 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5534bc71658.rds' +#> +#> ── Additional details about the regression model +#> Random Forest Model Specification (regression) +#> +#> Main Arguments: mtry = hardhat::tune() trees = hardhat::tune() +#> +#> Computational engine: ranger +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. +#> +#> ── Extra info about the tuning of the regression model ── +#> +#> ── Top 6 best configs for v(1 4) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 28.43 rmse_std_err = 3.02 +#> #2: mtry = 1 trees = 750 rmse = 28.76 rmse_std_err = 2.57 +#> #3: mtry = 1 trees = 400 rmse = 28.80 rmse_std_err = 2.64 +#> #4: mtry = 2 trees = 50 rmse = 29.27 rmse_std_err = 2.29 +#> #5: mtry = 2 trees = 400 rmse = 29.42 rmse_std_err = 2.40 +#> #6: mtry = 2 trees = 750 rmse = 29.46 rmse_std_err = 2.20 +#> +#> ── Top 6 best configs for v(2 4) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 21.12 rmse_std_err = 0.73 +#> #2: mtry = 1 trees = 750 rmse = 21.21 rmse_std_err = 0.66 +#> #3: mtry = 2 trees = 400 rmse = 21.27 rmse_std_err = 1.02 +#> #4: mtry = 2 trees = 750 rmse = 21.31 rmse_std_err = 1.01 +#> #5: mtry = 1 trees = 400 rmse = 21.34 rmse_std_err = 0.69 +#> #6: mtry = 2 trees = 50 rmse = 21.65 rmse_std_err = 0.94 +#> +#> ── Top 6 best configs for v(1 3) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 21.34 rmse_std_err = 3.18 +#> #2: mtry = 1 trees = 400 rmse = 21.56 rmse_std_err = 3.13 +#> #3: mtry = 1 trees = 750 rmse = 21.68 rmse_std_err = 3.13 +#> #4: mtry = 2 trees = 50 rmse = 21.79 rmse_std_err = 3.10 +#> #5: mtry = 2 trees = 750 rmse = 21.85 rmse_std_err = 2.98 +#> #6: mtry = 2 trees = 400 rmse = 
21.89 rmse_std_err = 2.97 +#> +#> ── Top 6 best configs for v(3 4) (using 5-fold CV) +#> #1: mtry = 1 trees = 750 rmse = 22.94 rmse_std_err = 4.33 +#> #2: mtry = 1 trees = 400 rmse = 23.13 rmse_std_err = 4.23 +#> #3: mtry = 1 trees = 50 rmse = 23.43 rmse_std_err = 4.13 +#> #4: mtry = 2 trees = 400 rmse = 23.86 rmse_std_err = 3.77 +#> #5: mtry = 2 trees = 750 rmse = 24.00 rmse_std_err = 3.78 +#> #6: mtry = 2 trees = 50 rmse = 24.57 rmse_std_err = 4.08 +#> +#> ── Top 6 best configs for v(2 3) (using 5-fold CV) +#> #1: mtry = 2 trees = 50 rmse = 17.46 rmse_std_err = 2.26 +#> #2: mtry = 2 trees = 750 rmse = 17.53 rmse_std_err = 2.43 +#> #3: mtry = 2 trees = 400 rmse = 17.64 rmse_std_err = 2.38 +#> #4: mtry = 1 trees = 750 rmse = 17.80 rmse_std_err = 2.09 +#> #5: mtry = 1 trees = 50 rmse = 17.81 rmse_std_err = 1.79 +#> #6: mtry = 1 trees = 400 rmse = 17.89 rmse_std_err = 2.13 +#> +#> ── Top 3 best configs for v(3) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 22.55 rmse_std_err = 4.68 +#> #2: mtry = 1 trees = 400 rmse = 22.59 rmse_std_err = 4.63 +#> #3: mtry = 1 trees = 750 rmse = 22.64 rmse_std_err = 4.65 +#> +#> ── Top 6 best configs for v(1 2) (using 5-fold CV) +#> #1: mtry = 1 trees = 400 rmse = 21.57 rmse_std_err = 2.25 +#> #2: mtry = 1 trees = 750 rmse = 21.59 rmse_std_err = 2.29 +#> #3: mtry = 1 trees = 50 rmse = 22.38 rmse_std_err = 2.10 +#> #4: mtry = 2 trees = 400 rmse = 22.54 rmse_std_err = 2.09 +#> #5: mtry = 2 trees = 750 rmse = 22.65 rmse_std_err = 2.09 +#> #6: mtry = 2 trees = 50 rmse = 23.12 rmse_std_err = 2.23 +#> +#> ── Top 3 best configs for v(4) (using 5-fold CV) +#> #1: mtry = 1 trees = 750 rmse = 32.14 rmse_std_err = 4.32 +#> #2: mtry = 1 trees = 400 rmse = 32.21 rmse_std_err = 4.31 +#> #3: mtry = 1 trees = 50 rmse = 32.21 rmse_std_err = 4.25 +#> +#> ── Top 3 best configs for v(1) (using 5-fold CV) +#> #1: mtry = 1 trees = 50 rmse = 30.34 rmse_std_err = 3.40 +#> #2: mtry = 1 trees = 750 rmse = 30.53 rmse_std_err = 3.31 +#> #3: mtry = 1 trees 
= 400 rmse = 30.63 rmse_std_err = 3.32 #> +#> ── Top 3 best configs for v(2) (using 5-fold CV) +#> #1: mtry = 1 trees = 750 rmse = 26.62 rmse_std_err = 2.33 +#> #2: mtry = 1 trees = 400 rmse = 26.72 rmse_std_err = 2.29 +#> #3: mtry = 1 trees = 50 rmse = 26.97 rmse_std_err = 2.24 +#> +#> ── Top 9 best configs for v(1 2 4) (using 5-fold CV) +#> #1: mtry = 2 trees = 750 rmse = 19.81 rmse_std_err = 1.53 +#> #2: mtry = 2 trees = 400 rmse = 19.85 rmse_std_err = 1.64 +#> #3: mtry = 1 trees = 750 rmse = 19.93 rmse_std_err = 1.93 +#> #4: mtry = 1 trees = 400 rmse = 20.18 rmse_std_err = 1.90 +#> #5: mtry = 2 trees = 50 rmse = 20.41 rmse_std_err = 1.56 +#> #6: mtry = 3 trees = 50 rmse = 20.69 rmse_std_err = 1.54 +#> #7: mtry = 3 trees = 750 rmse = 20.74 rmse_std_err = 1.69 +#> #8: mtry = 3 trees = 400 rmse = 20.77 rmse_std_err = 1.76 +#> #9: mtry = 1 trees = 50 rmse = 20.79 rmse_std_err = 1.89 +#> +#> ── Top 9 best configs for v(1 2 3) (using 5-fold CV) +#> #1: mtry = 2 trees = 400 rmse = 16.16 rmse_std_err = 2.75 +#> #2: mtry = 3 trees = 400 rmse = 16.30 rmse_std_err = 2.80 +#> #3: mtry = 2 trees = 750 rmse = 16.41 rmse_std_err = 2.79 +#> #4: mtry = 3 trees = 750 rmse = 16.43 rmse_std_err = 2.82 +#> #5: mtry = 3 trees = 50 rmse = 16.52 rmse_std_err = 2.52 +#> #6: mtry = 1 trees = 750 rmse = 16.69 rmse_std_err = 3.15 +#> #7: mtry = 2 trees = 50 rmse = 16.89 rmse_std_err = 2.76 +#> #8: mtry = 1 trees = 400 rmse = 16.98 rmse_std_err = 2.93 +#> #9: mtry = 1 trees = 50 rmse = 17.69 rmse_std_err = 3.16 +#> +#> ── Top 9 best configs for v(1 3 4) (using 5-fold CV) +#> #1: mtry = 1 trees = 400 rmse = 21.88 rmse_std_err = 4.33 +#> #2: mtry = 1 trees = 750 rmse = 21.96 rmse_std_err = 4.38 +#> #3: mtry = 1 trees = 50 rmse = 22.03 rmse_std_err = 4.07 +#> #4: mtry = 2 trees = 400 rmse = 22.65 rmse_std_err = 4.11 +#> #5: mtry = 2 trees = 750 rmse = 22.72 rmse_std_err = 4.09 +#> #6: mtry = 2 trees = 50 rmse = 22.89 rmse_std_err = 3.97 +#> #7: mtry = 3 trees = 400 rmse = 23.38 rmse_std_err = 
3.80 +#> #8: mtry = 3 trees = 750 rmse = 23.50 rmse_std_err = 3.77 +#> #9: mtry = 3 trees = 50 rmse = 23.88 rmse_std_err = 3.64 +#> +#> ── Top 9 best configs for v(2 3 4) (using 5-fold CV) +#> #1: mtry = 3 trees = 50 rmse = 17.96 rmse_std_err = 1.34 +#> #2: mtry = 1 trees = 50 rmse = 17.97 rmse_std_err = 2.40 +#> #3: mtry = 1 trees = 750 rmse = 18.63 rmse_std_err = 1.99 +#> #4: mtry = 2 trees = 400 rmse = 18.76 rmse_std_err = 1.42 +#> #5: mtry = 1 trees = 400 rmse = 18.79 rmse_std_err = 2.14 +#> #6: mtry = 2 trees = 750 rmse = 18.80 rmse_std_err = 1.49 +#> #7: mtry = 3 trees = 750 rmse = 19.12 rmse_std_err = 1.68 +#> #8: mtry = 3 trees = 400 rmse = 19.14 rmse_std_err = 1.65 +#> #9: mtry = 2 trees = 50 rmse = 19.33 rmse_std_err = 1.67 ``` We can look at the $\operatorname{MSE}_v$ evaluation criterion, @@ -827,7 +968,7 @@ we include a vertical line at the $\operatorname{MSE}_v$ score of the `empirical` method for easier comparison. -```r +``` r plot_MSEv_scores(explanation_list, method_line = "MC_empirical") ``` @@ -842,21 +983,21 @@ This result indicates that even though we do hyperparameter tuning, we still overfit the data. -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 ``` @@ -878,7 +1019,7 @@ parallel to speed up the computations. 
The fourth model is run in parallel but also tunes the depth of the trees and not only the number of trees. -A small side note: If we set `verbose = 2`, we can see which +A small side note: If we let `"vS_details" %in% verbose`, we can see which `tree` value `shapr` chooses for each coalition. We would then see that the values 25, 50, 100, and 500 are never chosen. Thus, we can remove these values without influencing the result @@ -886,27 +1027,40 @@ and instead do a finer grid search among the lower values. We do this in the fourth method. -```r +``` r # Regular xgboost with default parameters explanation_list$sep_xgboost <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:11:21 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553b9eedb1.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Cross validate the number of trees explanation_list$sep_xgboost_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(trees = hardhat::tune(), engine = "xgboost", mode = "regression"), @@ -915,6 +1069,21 @@ explanation_list$sep_xgboost_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:11:22 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5536c6c263f.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Cross validate the number of trees in parallel on two threads future::plan(future::multisession, workers = 2) @@ -922,8 +1091,7 @@ explanation_list$sep_xgboost_cv_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(trees = hardhat::tune(), engine = "xgboost", mode = "regression"), @@ -932,6 +1100,21 @@ explanation_list$sep_xgboost_cv_par <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:11:37 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55375979516.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Use a finer grid of low values for `trees` and also tune `tree_depth` future::plan(future::multisession, workers = 4) # Change to 4 threads due to more complex CV @@ -939,8 +1122,7 @@ explanation_list$sep_xgboost_cv_2_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -953,6 +1135,21 @@ explanation_list$sep_xgboost_cv_2_par <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:11:50 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553e0b863c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. future::plan(future::sequential) # To return to non-parallel computation ``` @@ -973,25 +1170,25 @@ note that we obtain the same value whether we run the cross-validation in parallel or sequentially. 
-```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 -#> sep_xgboost 197.72 0.99 -#> sep_xgboost_cv 164.69 20.72 -#> sep_xgboost_cv_par 164.69 17.53 -#> sep_xgboost_cv_2_par 146.51 21.94 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 +#> sep_xgboost 197.72 0.95 +#> sep_xgboost_cv 169.83 14.95 +#> sep_xgboost_cv_par 169.83 12.55 +#> sep_xgboost_cv_2_par 153.13 14.34 ``` @@ -1026,40 +1223,67 @@ cross-validation), and `xgboost` (with and without (some) cross-validation). -```r +``` r # Compute the Shapley value explanations using a surrogate linear regression model explanation_list$sur_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:05 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533f5d53fc.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using xgboost with default parameters as the surrogate model explanation_list$sur_xgboost <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:05 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553657c3c72.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using xgboost with parameters tuned by cross-validation as the surrogate model explanation_list$sur_xgboost_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1072,27 +1296,55 @@ explanation_list$sur_xgboost_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. 
#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:06 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55349e5a38c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using random forest with default parameters as the surrogate model explanation_list$sur_rf <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:08 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553eebc9ea.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Using random forest with parameters tuned by cross-validation as the surrogate model explanation_list$sur_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1106,6 +1358,21 @@ explanation_list$sur_rf_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:09 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5537965b6b3.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -1127,15 +1394,14 @@ can cause it to be slower than running the code sequentially for smaller problems. -```r +``` r # Cross validate the number of trees in parallel on four threads future::plan(future::multisession, workers = 4) explanation_list$sur_rf_cv_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1149,12 +1415,27 @@ explanation_list$sur_rf_cv_par <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. 
+#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:37 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533f7681e8.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. future::plan(future::sequential) # To return to non-parallel computation # Check that we get identical Shapley value explanations all.equal( - explanation_list$sur_rf_cv$shapley_values, - explanation_list$sur_rf_cv_par$shapley_values + explanation_list$sur_rf_cv$shapley_values_est, + explanation_list$sur_rf_cv_par$shapley_values_est ) #> [1] TRUE ``` @@ -1170,31 +1451,31 @@ identical and independent of whether they were run sequentially or in parallel. 
-```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 -#> sep_xgboost 197.72 0.99 -#> sep_xgboost_cv 164.69 20.72 -#> sep_xgboost_cv_par 164.69 17.53 -#> sep_xgboost_cv_2_par 146.51 21.94 -#> sur_lm 649.61 0.31 -#> sur_xgboost 169.92 0.26 -#> sur_xgboost_cv 169.87 2.37 -#> sur_rf 195.10 0.52 -#> sur_rf_cv 171.84 30.55 -#> sur_rf_cv_par 171.84 33.24 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 +#> sep_xgboost 197.72 0.95 +#> sep_xgboost_cv 169.83 14.95 +#> sep_xgboost_cv_par 169.83 12.55 +#> sep_xgboost_cv_2_par 153.13 14.34 +#> sur_lm 649.61 0.48 +#> sur_xgboost 169.92 0.50 +#> sur_xgboost_cv 169.87 2.10 +#> sur_rf 201.23 0.77 +#> sur_rf_cv 172.09 27.69 +#> sur_rf_cv_par 172.09 19.73 # Compare the MSEv criterion of the different explanation methods. # Include vertical line corresponding to the MSEv of the empirical method. @@ -1221,7 +1502,7 @@ on adding new regression models. We refer to that guide for more details and explanations of the code below. -```r +``` r # Step 1: register the model, modes, and arguments parsnip::set_new_model(model = "ppr_reg") parsnip::set_model_mode(model = "ppr_reg", mode = "regression") @@ -1338,27 +1619,40 @@ terms `num_terms` to a specific value or use cross-validation to tune the hyperparameter. We do all four combinations below. 
-```r +``` r # PPR separate with specified number of terms explanation_list$sep_ppr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = ppr_reg(num_terms = 2) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:58 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553791592c7.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # PPR separate with cross-validated number of terms explanation_list$sep_ppr_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = ppr_reg(num_terms = hardhat::tune()), regression.tune_values = dials::grid_regular(dials::num_terms(c(1, 4)), levels = 3), @@ -1366,27 +1660,55 @@ explanation_list$sep_ppr_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:12:58 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531ac3859d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # PPR surrogate with specified number of terms explanation_list$sur_ppr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = ppr_reg(num_terms = 3) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:09 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55339bdd72a.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # PPR surrogate with cross-validated number of terms explanation_list$sur_ppr_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = ppr_reg(num_terms = hardhat::tune()), regression.tune_values = dials::grid_regular(dials::num_terms(c(1, 8)), levels = 4), @@ -1394,6 +1716,21 @@ explanation_list$sur_ppr_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. 
#> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:09 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5532987aff5.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` We can then compare the $\operatorname{MSE}_v$ and some of the Shapley value explanations. @@ -1401,35 +1738,35 @@ We see that conducting cross-validation improves the evaluation criterion, but also increase the running time. -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_empirical 179.43 2.22 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 -#> sep_xgboost 197.72 0.99 -#> sep_xgboost_cv 164.69 20.72 -#> sep_xgboost_cv_par 164.69 17.53 -#> sep_xgboost_cv_2_par 146.51 21.94 -#> sur_lm 649.61 0.31 -#> sur_xgboost 169.92 0.26 -#> sur_xgboost_cv 169.87 2.37 -#> sur_rf 195.10 0.52 -#> sur_rf_cv 171.84 30.55 -#> sur_rf_cv_par 171.84 33.24 -#> sep_ppr 327.23 1.41 -#> sep_ppr_cv 269.74 15.46 -#> sur_ppr 395.42 0.29 -#> sur_ppr_cv 415.62 1.86 +#> MC_empirical 179.43 5.43 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 
219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 +#> sep_xgboost 197.72 0.95 +#> sep_xgboost_cv 169.83 14.95 +#> sep_xgboost_cv_par 169.83 12.55 +#> sep_xgboost_cv_2_par 153.13 14.34 +#> sur_lm 649.61 0.48 +#> sur_xgboost 169.92 0.50 +#> sur_xgboost_cv 169.87 2.10 +#> sur_rf 201.23 0.77 +#> sur_rf_cv 172.09 27.69 +#> sur_rf_cv_par 172.09 19.73 +#> sep_ppr 327.23 0.79 +#> sep_ppr_cv 246.28 10.40 +#> sur_ppr 395.42 0.47 +#> sur_ppr_cv 415.62 1.63 # Compare the MSEv criterion of the different explanation methods plot_MSEv_scores(explanation_list, method_line = "MC_empirical") @@ -1450,7 +1787,7 @@ In the code chunk below, we compute the Shapley value explanations using the different Monte Carlo-based methods. -```r +``` r explanation_list_MC <- list() # Compute the Shapley value explanations using the independence method @@ -1458,12 +1795,26 @@ explanation_list_MC$MC_independence <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "independence", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:11 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: independence +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533cae2265.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Copy the Shapley value explanations for the empirical method explanation_list_MC$MC_empirical <- explanation_list$MC_empirical @@ -1473,49 +1824,105 @@ explanation_list_MC$MC_gaussian <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "gaussian", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:12 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: gaussian +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5532491f5ab.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Compute the Shapley value explanations using the copula method explanation_list_MC$MC_copula <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "copula", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:13 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: copula +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553451ae5c5.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Compute the Shapley value explanations using the ctree method explanation_list_MC$MC_ctree <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:13 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: ctree +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533d628d5e.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Compute the Shapley value explanations using the vaeac method explanation_list_MC$MC_vaeac <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "vaeac", - prediction_zero = p0, + phi0 = p0, vaeac.epochs = 10 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:13:15 ─────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5534050514.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Combine the two explanations lists explanation_list$MC_empirical <- NULL @@ -1528,40 +1935,40 @@ include a vertical line corresponding to the $\operatorname{MSE}_v$ of the `MC_empirical` method to make the comparison easier. -```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list) #> MSEv Time -#> MC_independence 206.92 0.50 -#> MC_empirical 179.43 2.22 -#> MC_gaussian 245.19 0.49 -#> MC_copula 247.29 0.46 -#> MC_ctree 191.82 1.72 -#> MC_vaeac 141.88 72.61 -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_reicpe_example 687.45 1.74 -#> sep_tree_stump 218.05 1.03 -#> sep_tree_default 177.68 0.89 -#> sep_tree_cv 169.96 17.31 -#> sep_tree_cv_2 166.17 35.01 -#> sep_rf 210.99 1.58 -#> sep_rf_cv 212.88 38.41 -#> sep_xgboost 197.72 0.99 -#> sep_xgboost_cv 164.69 20.72 -#> sep_xgboost_cv_par 164.69 17.53 -#> sep_xgboost_cv_2_par 146.51 21.94 -#> sur_lm 649.61 0.31 -#> sur_xgboost 169.92 0.26 -#> sur_xgboost_cv 169.87 2.37 -#> sur_rf 195.10 0.52 -#> sur_rf_cv 171.84 30.55 -#> sur_rf_cv_par 171.84 33.24 -#> sep_ppr 327.23 1.41 -#> sep_ppr_cv 269.74 15.46 -#> sur_ppr 395.42 0.29 -#> sur_ppr_cv 415.62 1.86 +#> MC_independence 206.92 0.66 +#> MC_empirical 179.43 5.43 +#> MC_gaussian 235.15 0.52 +#> MC_copula 237.35 0.52 +#> MC_ctree 190.82 1.56 +#> MC_vaeac 145.06 2.09 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_reicpe_example 687.45 1.28 +#> sep_tree_stump 218.05 0.80 +#> sep_tree_default 177.68 0.79 +#> sep_tree_cv 222.71 12.72 +#> sep_tree_cv_2 219.45 25.69 +#> sep_rf 217.00 1.42 +#> sep_rf_cv 212.64 34.73 +#> sep_xgboost 197.72 0.95 +#> sep_xgboost_cv 169.83 14.95 +#> sep_xgboost_cv_par 169.83 12.55 +#> sep_xgboost_cv_2_par 153.13 14.34 +#> sur_lm 649.61 0.48 +#> sur_xgboost 169.92 0.50 +#> sur_xgboost_cv 169.87 2.10 +#> sur_rf 201.23 0.77 +#> sur_rf_cv 172.09 27.69 +#> sur_rf_cv_par 172.09 19.73 +#> sep_ppr 327.23 
0.79 +#> sep_ppr_cv 246.28 10.40 +#> sur_ppr 395.42 0.47 +#> sur_ppr_cv 415.62 1.63 # Compare the MSEv criterion of the different explanation methods # Include vertical line corresponding to the MSEv of the MC_empirical method @@ -1580,7 +1987,7 @@ We can also order the methods to more easily look at the order of the methods according to the $\operatorname{MSE}_v$ criterion. -```r +``` r order <- get_k_best_methods(explanation_list, k = length(explanation_list)) plot_MSEv_scores(explanation_list[order], method_line = "MC_empirical") ``` @@ -1595,19 +2002,19 @@ some differences for the less important features. These tendencies/discrepancies are often more visible for the methods with poor/larger $\operatorname{MSE}_v$ values. -```r +``` r plot_SV_several_approaches(explanation_list[order], index_explicands = c(1, 2), facet_ncol = 1) ``` ![](figure_regression/SV-sum-1.png) -```r +``` r plot_SV_several_approaches(explanation_list[order], index_explicands = c(3, 4), facet_ncol = 1) ``` ![](figure_regression/SV-sum-2.png) -```r +``` r plot_SV_several_approaches(explanation_list[order], index_explicands = c(5, 6), facet_ncol = 1) ``` @@ -1618,7 +2025,7 @@ easier to analyze the individual Shapley value explanations, and we see a quite strong agreement between the different methods. -```r +``` r # Extract the 5 best methods (and empirical) best_methods <- get_k_best_methods(explanation_list, k = 5) if (!"MC_empirical" %in% best_methods) best_methods <- c(best_methods, "MC_empirical") @@ -1646,7 +2053,7 @@ this below using the `regression.recipe_func` function. First, we copy the setup from the main vignette. -```r +``` r # convert the month variable to a factor data_cat <- copy(data)[, Month_factor := as.factor(Month)] @@ -1677,33 +2084,78 @@ explanation_list_mixed <- list() Second, we compute the explanations using the Monte Carlo-based methods. 
-```r +``` r explanation_list_mixed$MC_independence <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "independence" ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:15:23 ──────────────────────── +#> • Model class: +#> • Approach: independence +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: +#> '/tmp/RtmpRxPm0I/shapr_obj_10e55313bdf15c.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. explanation_list_mixed$MC_ctree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "ctree" ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:15:24 ──────────────────────── +#> • Model class: +#> • Approach: ctree +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: +#> '/tmp/RtmpRxPm0I/shapr_obj_10e55371bbf8d6.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. explanation_list_mixed$MC_vaeac <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "vaeac" ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:15:26 ──────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: +#> '/tmp/RtmpRxPm0I/shapr_obj_10e553641ecd3b.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -1714,50 +2166,91 @@ regression methods. We use many of the same regression models as we did above for the continuous data examples. -```r +``` r # Standard linear regression explanation_list_mixed$sep_lm <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::linear_reg() ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:18:46 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533131b08d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Linear regression where we have added splines to the numerical features explanation_list_mixed$sep_splines <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { return(step_ns(regression_recipe, all_numeric_predictors(), deg_free = 2)) } ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:18:47 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55366111c0d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Decision tree with default parameters explanation_list_mixed$sep_tree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:18:48 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531ad27dab.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Use trees with cross-validation on the depth and cost complexity. Manually set the values. explanation_list_mixed$sep_tree_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -1769,25 +2262,53 @@ explanation_list_mixed$sep_tree_cv <- explain( expand.grid(tree_depth = c(1, 3, 5), cost_complexity = c(0.001, 0.01, 0.1)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:18:49 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55348ac2c82.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Random forest with default hyperparameters. Do NOT need to use dummy features. explanation_list_mixed$sep_rf <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:19:18 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55369de7bb8.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Random forest with cross validated hyperparameters. explanation_list_mixed$sep_rf_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1798,28 +2319,56 @@ explanation_list_mixed$sep_rf_cv <- explain( }, regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:19:20 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5537f540ca9.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Xgboost with default hyperparameters, but we have to dummy encode the factors explanation_list_mixed$sep_xgboost <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression_recipe) { return(step_dummy(regression_recipe, all_factor_predictors())) } ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:13 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533286e2bf.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Xgboost with cross validated hyperparameters and we dummy encode the factors explanation_list_mixed$sep_xgboost_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1833,6 +2382,21 @@ explanation_list_mixed$sep_xgboost_cv <- explain( regression.tune_values = expand.grid(trees = c(5, 15, 25), tree_depth = c(2, 6, 10)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:14 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531fa7a245.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -1843,17 +2407,31 @@ regression methods. We use the same regression models as we did above for separate regression method class. -```r +``` r # Standard linear regression explanation_list_mixed$sur_lm <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:33 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55365b26da6.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Linear regression where we have added splines to the numerical features # NOTE, that we remove the augmented mask variables to avoid a rank-deficient fit @@ -1861,33 +2439,60 @@ explanation_list_mixed$sur_splines <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(recipe) { return(step_ns(recipe, all_numeric_predictors(), -starts_with("mask_"), deg_free = 2)) } ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:34 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5537f7cd475.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Decision tree with default parameters explanation_list_mixed$sur_tree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:34 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55342bb266a.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Use trees with cross-validation on the depth and cost complexity. Manually set the values. explanation_list_mixed$sur_tree_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -1899,25 +2504,53 @@ explanation_list_mixed$sur_tree_cv <- explain( expand.grid(tree_depth = c(1, 3, 5), cost_complexity = c(0.001, 0.01, 0.1)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:35 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553263d5b45.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Random forest with default hyperparameters. Do NOT need to use dummy features. 
explanation_list_mixed$sur_rf <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:37 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5536f402f15.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Random forest with cross validated hyperparameters. explanation_list_mixed$sur_rf_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1925,28 +2558,56 @@ explanation_list_mixed$sur_rf_cv <- explain( regression.tune_values = expand.grid(mtry = c(1, 2, 4), trees = c(50, 250, 500, 750)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:38 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55321ef0397.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Xgboost with default hyperparameters, but we have to dummy encode the factors explanation_list_mixed$sur_xgboost <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression_recipe) { return(step_dummy(regression_recipe, all_factor_predictors())) } ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:52 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5535b569440.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Xgboost with cross validated hyperparameters and we dummy encode the factors explanation_list_mixed$sur_xgboost_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1960,6 +2621,21 @@ explanation_list_mixed$sur_xgboost_cv <- explain( regression.tune_values = expand.grid(trees = c(5, 15, 25), tree_depth = c(2, 6, 10)), regression.vfold_cv_para = list(v = 5) ) +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:52 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5532e902f01.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. ``` @@ -1973,29 +2649,29 @@ methods. More specifically, three separate regression methods and three surrogate regression methods. 
-```r +``` r # Print the MSEv scores and the elapsed time (in seconds) for the different methods print_MSEv_scores_and_time(explanation_list_mixed) -#> MSEv Time -#> MC_independence 641.82 0.69 -#> MC_ctree 554.50 2.36 -#> MC_vaeac 629.43 147.26 -#> sep_lm 550.06 1.53 -#> sep_splines 541.36 1.80 -#> sep_tree 753.84 0.84 -#> sep_tree_cv 756.27 41.75 -#> sep_rf 521.79 1.10 -#> sep_rf_cv 609.58 51.42 -#> sep_xgboost 792.17 1.13 -#> sep_xgboost_cv 595.98 26.29 -#> sur_lm 610.61 0.51 -#> sur_splines 596.86 0.55 -#> sur_tree 677.04 0.38 -#> sur_tree_cv 789.37 3.34 -#> sur_rf 414.15 0.55 -#> sur_rf_cv 533.06 15.50 -#> sur_xgboost 606.92 0.40 -#> sur_xgboost_cv 429.06 3.05 +#> MSEv Time +#> MC_independence 641.82 0.80 +#> MC_ctree 555.58 1.99 +#> MC_vaeac 629.56 3.32 +#> sep_lm 550.06 0.78 +#> sep_splines 541.36 1.03 +#> sep_tree 753.84 0.87 +#> sep_tree_cv 756.27 29.41 +#> sep_rf 518.27 1.52 +#> sep_rf_cv 619.81 53.24 +#> sep_xgboost 792.17 1.08 +#> sep_xgboost_cv 595.98 18.29 +#> sur_lm 610.61 0.45 +#> sur_splines 596.86 0.50 +#> sur_tree 677.04 0.48 +#> sur_tree_cv 789.37 2.53 +#> sur_rf 407.76 0.76 +#> sur_rf_cv 520.63 13.70 +#> sur_xgboost 606.92 0.50 +#> sur_xgboost_cv 429.06 2.24 # Compare the MSEv criterion of the different explanation methods # Include vertical line corresponding to the MSEv of the empirical method. @@ -2013,7 +2689,7 @@ We can also order the methods to more easily look at the order of the methods according to the $\operatorname{MSE}_v$ criterion. -```r +``` r order <- get_k_best_methods(explanation_list_mixed, k = length(explanation_list_mixed)) plot_MSEv_scores(explanation_list_mixed[order], method_line = "MC_ctree") ``` @@ -2024,7 +2700,7 @@ We also look at some of the Shapley value explanations and see that many methods produce similar explanations. 
-```r +``` r plot_SV_several_approaches(explanation_list_mixed[order], index_explicands = c(1, 2), facet_ncol = 1) ``` @@ -2035,7 +2711,7 @@ methods according to the $\operatorname{MSE}_v$ criterion. We also include the `ctree` method, the best-performing Monte Carlo-based method. -```r +``` r best_methods <- get_k_best_methods(explanation_list_mixed, k = 5) if (!"MC_ctree" %in% best_methods) best_methods <- c(best_methods, "MC_ctree") plot_SV_several_approaches(explanation_list_mixed[best_methods], index_explicands = 1:4) @@ -2057,26 +2733,39 @@ that we obtain identical $\operatorname{MSE}_v$ scores for the string and non-string versions. -```r +``` r explanation_list_str <- list() explanation_list_str$sep_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::linear_reg()" ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:57 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5532789a643.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
explanation_list_str$sep_pcr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::linear_reg()", regression.recipe_func = "function(regression_recipe) { @@ -2085,13 +2774,27 @@ explanation_list_str$sep_pcr <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:58 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e553c707510.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. explanation_list_str$sep_splines <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = "function(regression_recipe) { @@ -2100,13 +2803,27 @@ explanation_list_str$sep_splines <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:20:59 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5535c2a1de9.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. explanation_list_str$sep_tree_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::decision_tree( tree_depth = hardhat::tune(), engine = 'rpart', mode = 'regression' @@ -2116,14 +2833,28 @@ explanation_list_str$sep_tree_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:21:00 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5531d5a6c89.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. 
# Using random forest with parameters tuned by cross-validation explanation_list_str$sep_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 1, # As we used this for the non-string version + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression' @@ -2136,14 +2867,28 @@ explanation_list_str$sep_rf_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:21:12 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_separate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e5533d1c027d.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # Using random forest with parameters tuned by cross-validation as the surrogate model explanation_list_str$sur_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = "parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression' @@ -2157,24 +2902,39 @@ explanation_list_str$sur_rf_cv <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-09 16:21:47 ──────────────────────────────────────────── +#> • Model class: +#> • Approach: regression_surrogate +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 20 +#> • Computations (temporary) saved at: '/tmp/RtmpRxPm0I/shapr_obj_10e55364f1a477.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. # See that the evaluation scores match the non-string versions. print_MSEv_scores_and_time(explanation_list_str) #> MSEv Time -#> sep_lm 745.21 1.14 -#> sep_pcr 784.91 1.19 -#> sep_splines 165.13 1.15 -#> sep_tree_cv 169.96 20.65 -#> sep_rf_cv 212.88 39.29 -#> sur_rf_cv 171.84 30.51 +#> sep_lm 745.21 0.74 +#> sep_pcr 784.91 0.95 +#> sep_splines 165.13 0.98 +#> sep_tree_cv 222.71 12.90 +#> sep_rf_cv 212.64 34.89 +#> sur_rf_cv 172.09 27.16 print_MSEv_scores_and_time(explanation_list[names(explanation_list_str)]) #> MSEv Time -#> sep_lm 745.21 0.77 -#> sep_pcr 784.91 1.32 -#> sep_splines 165.13 1.09 -#> sep_tree_cv 169.96 17.31 -#> sep_rf_cv 212.88 38.41 -#> sur_rf_cv 171.84 30.55 +#> sep_lm 745.21 0.73 +#> sep_pcr 784.91 0.90 +#> sep_splines 165.13 0.96 +#> sep_tree_cv 222.71 12.72 +#> sep_rf_cv 212.64 34.73 +#> sur_rf_cv 172.09 27.69 ``` diff --git a/vignettes/understanding_shapr_regression.Rmd.orig b/vignettes/understanding_shapr_regression.Rmd.orig index 8db1271ee400fa5b191c7a4f549ed1f889806789..5c170fcd5a7344b86aa0073eb7df56c37a408628 100644 --- a/vignettes/understanding_shapr_regression.Rmd.orig +++ b/vignettes/understanding_shapr_regression.Rmd.orig @@ -272,8 +272,7 @@ explanation_list$MC_empirical <- explain( x_explain = x_explain, x_train = x_train, approach = "empirical", - prediction_zero = p0, - n_batches = 4 + phi0 = p0 ) ``` @@ -287,8 +286,7 @@ explanation_list$sep_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = 
"regression_separate", regression.model = parsnip::linear_reg() ) @@ -340,8 +338,7 @@ explanation_list$sep_pcr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -359,8 +356,7 @@ explanation_list$sep_splines <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -413,8 +409,7 @@ explanation_list$sep_reicpe_example <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = regression.recipe_func @@ -457,8 +452,7 @@ explanation_list$sep_tree_stump <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = 1, @@ -474,8 +468,7 @@ explanation_list$sep_tree_default <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) @@ -559,8 +552,7 @@ explanation_list$sep_tree_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), engine = "rpart", mode = "regression" @@ -574,8 +566,7 @@ explanation_list$sep_tree_cv_2 <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - 
n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -595,10 +586,9 @@ Thus, `regression.tune_values` must be a function that returns a data.frame where the hyperparameter values for `mtry` will change based on the coalition size. If we do not let `regression.tune_values` be a function, then `tidymodels` will crash for any `mtry` higher -than 1. Furthermore, by setting `verbose = 2`, we receive messages -about which batch and coalition/combination that `shapr` processes -and the results of the cross-validation procedure. Note that the tested -hyperparameter value combinations change based on the coalition size. +than 1. Furthermore, by letting `"vS_details" %in% verbose`, +we receive messages with the results of the cross-validation procedure run within `shapr`. +Note that the tested hyperparameter value combinations change based on the coalition size. ```{r rf-cv, cache=TRUE} # Using random forest with default parameters @@ -606,8 +596,7 @@ explanation_list$sep_rf <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) @@ -617,9 +606,8 @@ explanation_list$sep_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 1, # One batch to get printouts in chronological order - verbose = 2, # To get printouts + phi0 = p0, + verbose = c("basic","vS_details"), # To get printouts approach = "regression_separate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -679,7 +667,7 @@ parallel to speed up the computations. The fourth model is run in parallel but also tunes the depth of the trees and not only the number of trees.
-A small side note: If we set `verbose = 2`, we can see which +A small side note: If we let `"vS_details" %in% verbose`, we can see which `tree` value `shapr` chooses for each coalition. We would then see that the values 25, 50, 100, and 500 are never chosen. Thus, we can remove these values without influencing the result @@ -692,8 +680,7 @@ explanation_list$sep_xgboost <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression") ) @@ -703,8 +690,7 @@ explanation_list$sep_xgboost_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(trees = hardhat::tune(), engine = "xgboost", mode = "regression"), @@ -718,8 +704,7 @@ explanation_list$sep_xgboost_cv_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree(trees = hardhat::tune(), engine = "xgboost", mode = "regression"), @@ -733,8 +718,7 @@ explanation_list$sep_xgboost_cv_2_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -806,8 +790,7 @@ explanation_list$sur_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) @@ -817,8 +800,7 @@ explanation_list$sur_xgboost <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = 
parsnip::boost_tree(engine = "xgboost", mode = "regression") ) @@ -828,8 +810,7 @@ explanation_list$sur_xgboost_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -846,8 +827,7 @@ explanation_list$sur_rf <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) @@ -857,8 +837,7 @@ explanation_list$sur_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -897,8 +876,7 @@ explanation_list$sur_rf_cv_par <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -914,8 +892,8 @@ future::plan(future::sequential) # To return to non-parallel computation # Check that we get identical Shapley value explanations all.equal( - explanation_list$sur_rf_cv$shapley_values, - explanation_list$sur_rf_cv_par$shapley_values + explanation_list$sur_rf_cv$shapley_values_est, + explanation_list$sur_rf_cv_par$shapley_values_est ) ``` @@ -1077,8 +1055,7 @@ explanation_list$sep_ppr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = ppr_reg(num_terms = 2) ) @@ -1088,8 +1065,7 @@ explanation_list$sep_ppr_cv <- explain( model = model, x_explain = 
x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = ppr_reg(num_terms = hardhat::tune()), regression.tune_values = dials::grid_regular(dials::num_terms(c(1, 4)), levels = 3), @@ -1101,8 +1077,7 @@ explanation_list$sur_ppr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = ppr_reg(num_terms = 3) ) @@ -1112,8 +1087,7 @@ explanation_list$sur_ppr_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = ppr_reg(num_terms = hardhat::tune()), regression.tune_values = dials::grid_regular(dials::num_terms(c(1, 8)), levels = 4), @@ -1153,9 +1127,8 @@ explanation_list_MC$MC_independence <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "independence", - prediction_zero = p0 + phi0 = p0 ) # Copy the Shapley value explanations for the empirical method @@ -1166,9 +1139,8 @@ explanation_list_MC$MC_gaussian <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "gaussian", - prediction_zero = p0 + phi0 = p0 ) # Compute the Shapley value explanations using the copula method @@ -1176,9 +1148,8 @@ explanation_list_MC$MC_copula <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "copula", - prediction_zero = p0 + phi0 = p0 ) # Compute the Shapley value explanations using the ctree method @@ -1186,9 +1157,8 @@ explanation_list_MC$MC_ctree <- explain( model = model, x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "ctree", - prediction_zero = p0 + phi0 = p0 ) # Compute the Shapley value explanations using the vaeac method @@ -1196,9 +1166,8 @@ explanation_list_MC$MC_vaeac <- explain( model = model, 
x_explain = x_explain, x_train = x_train, - n_batches = 4, approach = "vaeac", - prediction_zero = p0, + phi0 = p0, vaeac.epochs = 10 ) @@ -1312,8 +1281,7 @@ explanation_list_mixed$MC_independence <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "independence" ) @@ -1321,8 +1289,7 @@ explanation_list_mixed$MC_ctree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "ctree" ) @@ -1330,8 +1297,7 @@ explanation_list_mixed$MC_vaeac <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "vaeac" ) ``` @@ -1349,8 +1315,7 @@ explanation_list_mixed$sep_lm <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::linear_reg() ) @@ -1360,8 +1325,7 @@ explanation_list_mixed$sep_splines <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(regression_recipe) { @@ -1374,8 +1338,7 @@ explanation_list_mixed$sep_tree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) @@ -1385,8 +1348,7 @@ explanation_list_mixed$sep_tree_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = 
parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -1404,8 +1366,7 @@ explanation_list_mixed$sep_rf <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) @@ -1415,8 +1376,7 @@ explanation_list_mixed$sep_rf_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1433,8 +1393,7 @@ explanation_list_mixed$sep_xgboost <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression_recipe) { @@ -1447,8 +1406,7 @@ explanation_list_mixed$sep_xgboost_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_separate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1477,8 +1435,7 @@ explanation_list_mixed$sur_lm <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::linear_reg() ) @@ -1489,8 +1446,7 @@ explanation_list_mixed$sur_splines <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::linear_reg(), regression.recipe_func = function(recipe) { @@ 
-1503,8 +1459,7 @@ explanation_list_mixed$sur_tree <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::decision_tree(engine = "rpart", mode = "regression") ) @@ -1514,8 +1469,7 @@ explanation_list_mixed$sur_tree_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::decision_tree( tree_depth = hardhat::tune(), @@ -1533,8 +1487,7 @@ explanation_list_mixed$sur_rf <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::rand_forest(engine = "ranger", mode = "regression") ) @@ -1544,8 +1497,7 @@ explanation_list_mixed$sur_rf_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = "ranger", mode = "regression" @@ -1559,8 +1511,7 @@ explanation_list_mixed$sur_xgboost <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::boost_tree(engine = "xgboost", mode = "regression"), regression.recipe_func = function(regression_recipe) { @@ -1573,8 +1524,7 @@ explanation_list_mixed$sur_xgboost_cv <- explain( model = model_cat, x_explain = x_explain_cat, x_train = x_train_cat, - prediction_zero = p0_cat, - n_batches = 4, + phi0 = p0_cat, approach = "regression_surrogate", regression.model = parsnip::boost_tree( trees = hardhat::tune(), @@ -1658,8 +1608,7 @@ 
explanation_list_str$sep_lm <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::linear_reg()" ) @@ -1668,8 +1617,7 @@ explanation_list_str$sep_pcr <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::linear_reg()", regression.recipe_func = "function(regression_recipe) { @@ -1681,8 +1629,7 @@ explanation_list_str$sep_splines <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = parsnip::linear_reg(), regression.recipe_func = "function(regression_recipe) { @@ -1694,8 +1641,7 @@ explanation_list_str$sep_tree_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::decision_tree( tree_depth = hardhat::tune(), engine = 'rpart', mode = 'regression' @@ -1709,8 +1655,7 @@ explanation_list_str$sep_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 1, # As we used this for the non-string version + phi0 = p0, approach = "regression_separate", regression.model = "parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression' @@ -1727,8 +1672,7 @@ explanation_list_str$sur_rf_cv <- explain( model = model, x_explain = x_explain, x_train = x_train, - prediction_zero = p0, - n_batches = 4, + phi0 = p0, approach = "regression_surrogate", regression.model = "parsnip::rand_forest( mtry = hardhat::tune(), trees = hardhat::tune(), engine = 'ranger', mode = 'regression' diff --git a/vignettes/understanding_shapr_vaeac.Rmd b/vignettes/understanding_shapr_vaeac.Rmd index 
79053b197b3f1f9b7d2c64d1a38dd1a57bc8d2a0..91ca7f772e030a7ba77d631395cb281f33ff90c6 100644 --- a/vignettes/understanding_shapr_vaeac.Rmd +++ b/vignettes/understanding_shapr_vaeac.Rmd @@ -26,7 +26,7 @@ editor_options: > [Pretrained vaeac (path)](#pretrained_vaeac_path) -> [Subset of coalitions](#n_combinations) +> [Subset of coalitions](#n_coalitions) > [Paired sampling](#paired_sampling) @@ -109,9 +109,10 @@ Here we go through how to use the `vaeac` approach on the same data as in the ma First we set up the model we want to explain. -```r +``` r library(xgboost) library(data.table) +#> data.table 1.15.4 using 16 threads (see ?getDTthreads). Latest news: r-datatable.com data("airquality") data <- data.table::as.data.table(airquality) @@ -134,7 +135,7 @@ model <- xgboost( ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) ``` @@ -144,9 +145,8 @@ prediction_zero <- mean(y_train) We are now going to explain predictions made by the model using the `vaeac` approach. -```r -n_samples <- 25 # Low number of MC samples to make the vignette build faster -n_batches <- 1 # Do all coalitions in one batch +``` r +n_MC_samples <- 25 # Low number of MC samples to make the vignette build faster vaeac.n_vaeacs_initialize <- 2 # Initialize several vaeacs to counteract bad initialization values vaeac.epochs <- 4 # The number of training epochs @@ -155,30 +155,32 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = vaeac.epochs, vaeac.n_vaeacs_initialize = vaeac.n_vaeacs_initialize ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. 
+#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. ``` We can look at the Shapley values. -```r +``` r # Printing and ploting the Shapley values. # See ?shapr::explain for interpretation of the values. -print(explanation$shapley_values) -#> none Solar.R Wind Temp Month -#> -#> 1: 43.086 6.1207 3.1430 -18.6779 -2.88614 -#> 2: 43.086 -2.0779 -2.5548 -20.1182 0.69569 -#> 3: 43.086 3.0385 -5.5121 -18.2575 -2.55871 -#> 4: 43.086 3.0009 -4.7220 -8.9452 -3.92486 -#> 5: 43.086 -1.1022 -4.4319 -13.5459 -5.29567 -#> 6: 43.086 3.9320 -9.8445 -11.9489 -3.56018 +print(explanation$shapley_values_est) +#> explain_id none Solar.R Wind Temp Month +#> +#> 1: 1 43.086 4.35827 -0.49487 -16.7173 0.55352 +#> 2: 2 43.086 -2.06968 -2.76668 -17.3760 -1.84287 +#> 3: 3 43.086 1.24259 -5.05865 -18.7919 -0.68187 +#> 4: 4 43.086 5.20834 -10.03741 -8.4807 -1.28136 +#> 5: 5 43.086 0.22127 -3.05847 -17.9177 -3.62080 +#> 6: 6 43.086 4.25576 -9.58514 -18.7123 2.62017 plot(explanation) ``` @@ -191,42 +193,44 @@ if we want to explain new predictions using the same combinations/coalitions as `x_explain`. Note that the new `x_explain` must have the same features as before. The `vaeac` model is accessible via `explanation$internal$parameters$vaeac`. -Note that if we set `verbose = 2` in `explain()`, then `shapr` will give a message +Note that if we let `'vS_details' %in% verbose` in `explain()`, then `shapr` will give a message that it loads a pretrained `vaeac` model instead of training it from scratch. In this example, we extract the trained `vaeac` model from the previous example and send it to `explain()`. 
-```r +``` r # Send the pre-trained vaeac model expl_pretrained_vaeac <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = n_samples, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac ) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # Check that this version provides the same Shapley values -all.equal(explanation$shapley_values, expl_pretrained_vaeac$shapley_values) +all.equal(explanation$shapley_values_est, expl_pretrained_vaeac$shapley_values_est) #> [1] TRUE ``` ## Pre-trained vaeac (path) {#pretrained_vaeac_path} We can also just provide a path to the stored `vaeac` model. This is beneficial if we have only stored the `vaeac` model on the computer but not the whole `explanation` object. The possible save paths are stored in -`explanation$internal$parameters$vaeac$model`. Note that if we set `verbose = 2` in `explain()`, then `shapr` will give +`explanation$internal$parameters$vaeac$models`. Note that if we let `'vS_details' %in% verbose` in `explain()`, then `shapr` will give a message that it loads a pretrained `vaeac` model instead of training it from scratch. -```r +``` r # Call `explanation$internal$parameters$vaeac$model` to see possible vaeac models. We use `best` below. 
# send the pre-trained vaeac path expl_pretrained_vaeac_path <- explain( @@ -234,144 +238,129 @@ expl_pretrained_vaeac_path <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = n_samples, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac$models$best ) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # Check that this version provides the same Shapley values -all.equal(explanation$shapley_values, expl_pretrained_vaeac_path$shapley_values) +all.equal(explanation$shapley_values_est, expl_pretrained_vaeac_path$shapley_values_est) #> [1] TRUE ``` -## Specified n_combinations and more batches {#n_combinations} +## Specified n_coalitions {#n_coalitions} -In this section, we discuss two general `shapr` parameters in the `explain()` function -that are method independent, namely, `n_combinations` and `n_batches`. +In this section, we discuss a general `shapr` parameter in the `explain()` function +that is method independent, namely, `n_coalitions`. The user can limit the Shapley value computations to only a subset of coalitions by setting the -`n_combinations` parameter to a value lower than $2^{n_\text{features}}$. To lower the memory -usage, the user can split the coalitions into several batches by setting `n_batches` to a desired -number. In this example, we set `n_batches = 5` and `n_combinations = 10` which is less than -the maximum of `16`. +`n_coalitions` parameter to a value lower than $2^{n_\text{features}}$. Note that we do not need to train a new `vaeac` model as we can use the one above trained on all `16` coalitions as we are now only using a subset of them. 
This is not applicable the other way around. -```r +``` r # send the pre-trained vaeac path expl_batches_combinations <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_combinations = 10, - n_batches = 5, - n_samples = n_samples, + phi0 = phi0, + n_coalitions = 10, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac ) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # Gives different Shapley values as the latter one are only based on a subset of coalitions plot_SV_several_approaches(list("Original" = explanation, "Other combi." = expl_batches_combinations)) ``` -![](figure_vaeac/check-n_combinations-and-more-batches-1.png) - -```r -# Here we can see that the samples coalitions are in different batches and have different weights -expl_batches_combinations$internal$objects$X -#> Key: -#> Index: -#> id_combination features n_features N shapley_weight approach batch -#> -#> 1: 1 0 1 1000000 NA -#> 2: 2 3 1 4 1 vaeac 1 -#> 3: 3 4 1 4 1 vaeac 3 -#> 4: 4 2 1 4 1 vaeac 2 -#> 5: 5 2,3 2 6 2 vaeac 5 -#> 6: 6 1,4 2 6 1 vaeac 2 -#> 7: 7 1,3,4 3 4 2 vaeac 5 -#> 8: 8 2,3,4 3 4 1 vaeac 4 -#> 9: 9 1,2,3 3 4 1 vaeac 4 -#> 10: 10 1,2,3,4 4 1 1000000 1 +![](figure_vaeac/check-n_coalitions-1.png) + +``` r # Can compare that to the situation where we have exact computations (i.e., include all coalitions) explanation$internal$objects$X -#> Key: -#> id_combination features n_features N shapley_weight approach batch -#> -#> 1: 1 0 1 1.00e+06 NA -#> 2: 2 1 1 4 2.50e-01 vaeac 1 -#> 3: 3 2 1 4 2.50e-01 vaeac 1 -#> 4: 4 3 1 4 2.50e-01 vaeac 1 -#> 5: 5 4 1 4 2.50e-01 vaeac 1 -#> 6: 6 1,2 2 6 1.25e-01 vaeac 1 -#> 7: 7 1,3 2 6 
1.25e-01 vaeac 1 -#> 8: 8 1,4 2 6 1.25e-01 vaeac 1 -#> 9: 9 2,3 2 6 1.25e-01 vaeac 1 -#> 10: 10 2,4 2 6 1.25e-01 vaeac 1 -#> 11: 11 3,4 2 6 1.25e-01 vaeac 1 -#> 12: 12 1,2,3 3 4 2.50e-01 vaeac 1 -#> 13: 13 1,2,4 3 4 2.50e-01 vaeac 1 -#> 14: 14 1,3,4 3 4 2.50e-01 vaeac 1 -#> 15: 15 2,3,4 3 4 2.50e-01 vaeac 1 -#> 16: 16 1,2,3,4 4 1 1.00e+06 1 +#> id_coalition coalitions coalition_size N shapley_weight sample_freq features approach +#> +#> 1: 1 0 1 1.00e+06 NA vaeac +#> 2: 2 1 1 4 2.50e-01 NA 1 vaeac +#> 3: 3 2 1 4 2.50e-01 NA 2 vaeac +#> 4: 4 3 1 4 2.50e-01 NA 3 vaeac +#> 5: 5 4 1 4 2.50e-01 NA 4 vaeac +#> 6: 6 1,2 2 6 1.25e-01 NA 1,2 vaeac +#> 7: 7 1,3 2 6 1.25e-01 NA 1,3 vaeac +#> 8: 8 1,4 2 6 1.25e-01 NA 1,4 vaeac +#> 9: 9 2,3 2 6 1.25e-01 NA 2,3 vaeac +#> 10: 10 2,4 2 6 1.25e-01 NA 2,4 vaeac +#> 11: 11 3,4 2 6 1.25e-01 NA 3,4 vaeac +#> 12: 12 1,2,3 3 4 2.50e-01 NA 1,2,3 vaeac +#> 13: 13 1,2,4 3 4 2.50e-01 NA 1,2,4 vaeac +#> 14: 14 1,3,4 3 4 2.50e-01 NA 1,3,4 vaeac +#> 15: 15 2,3,4 3 4 2.50e-01 NA 2,3,4 vaeac +#> 16: 16 1,2,3,4 4 1 1.00e+06 NA 1,2,3,4 vaeac ``` Note that if we train a `vaeac` model from scratch with the setup above, then the `vaeac` model will not use a missing completely as random (MCAR) mask generator, but rather a mask generator that ensures that the `vaeac` model is only trained on the specified set of coalitions. In this case, it will be the set of the -`n_combinations - 2` sampled coalitions. The minus two is because the `vaeac` model will +`n_coalitions - 2` sampled coalitions. The minus two is because the `vaeac` model will not train on the empty and grand coalitions as they are not needed in the Shapley value computations. 
-```r +``` r expl_batches_combinations_2 <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_combinations = 10, - n_batches = 1, - n_samples = n_samples, + phi0 = phi0, + n_coalitions = 10, + n_MC_samples = n_MC_samples, vaeac.n_vaeacs_initialize = 1, vaeac.epochs = 3, - verbose = 2 + verbose = "vS_details" ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. -#> Using 'specified_masks_mask_generator' with '8' coalitions. +#> Using 'mcar_mask_generator' with 'masking_ratio = 0.5'. #> The vaeac model contains 17032 trainable parameters. -#> Initializing vaeac number 1 of 1. -#> Best vaeac inititalization was number 1 (of 1) with a training VLB = -6.451 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `last` vaeac model at epoch 3. -#> +#> Initializing vaeac model number 1 of 1. +#> Best vaeac inititalization was number 1 (of 1) with a training VLB = -6.593 after 2 epochs. Continue to train this inititalization. +#> ■■■■■■■■■■■■■■■■■■■■■ 67% | Training vaeac (init. 1 of 1): Epoch: 2 | VLB: -6.593 | IWAE: -3.321 | ETA: 1s #> Results of the `vaeac` training process: -#> Best epoch: 3. VLB = -4.824 IWAE = -3.252 IWAE_running = -3.540 -#> Best running avg epoch: 3. VLB = -4.824 IWAE = -3.252 IWAE_running = -3.540 -#> Last epoch: 3. VLB = -4.824 IWAE = -3.252 IWAE_running = -3.540 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 1. 
-#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. +#> Best epoch: 3. VLB = -4.688 IWAE = -3.124 IWAE_running = -3.465 +#> Best running avg epoch: 3. VLB = -4.688 IWAE = -3.124 IWAE_running = -3.465 +#> Last epoch: 3. VLB = -4.688 IWAE = -3.124 IWAE_running = -3.465 +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.55.18.126022_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.55.18.126022_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.55.18.126022_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' ``` @@ -382,7 +371,7 @@ The `vaeac` approach can use paired sampling to improve the stability of the tra When using paired sampling, each observation in the training batches will be duplicated, but the first version will be masked by $S$ and the second verion will be masked by the complement $\bar{S}$. The mask are taken from the `explanation$internal$objects$S` matrix. Note that `vaeac` does not check if the complement is also in said matrix. -This means that if the Shapley value explanations are computed based on a subset of coalitions, i.e., `n_combinations` +This means that if the Shapley value explanations are computed based on a subset of coalitions, i.e., `n_coalitions` is less than $2^{n_\text{features}}$, then the `vaeac` model might be trained on coalitions which are not used when computing the Shapley values. This should not be considered as redundant training as it increases the stablility and performance of the `vaeac` model as a whole, hence, we reccomend to use paried samping (default). Furthermore, the masks @@ -390,43 +379,48 @@ are randomly selected for each observation in the batch. The training time when comparison to random sampling due to more complex implementation. 
-```r +``` r expl_paired_sampling_TRUE <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = 10, vaeac.n_vaeacs_initialize = 1, vaeac.extra_parameters = list(vaeac.paired_sampling = TRUE) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. expl_paired_sampling_FALSE <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = 10, vaeac.n_vaeacs_initialize = 1, vaeac.extra_parameters = list(vaeac.paired_sampling = FALSE) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. ``` We can compare the results by looking at the training and validation errors and by the $MSE_v$ evaluation criterion. We do this by using the `vaeac_plot_eval_crit()` and `plot_MSEv_eval_crit()` functions in the `shapr` package, respectively. -```r +``` r explanation_list <- list("Regular samp." = expl_paired_sampling_FALSE, "Paired samp." 
= expl_paired_sampling_TRUE) vaeac_plot_eval_crit(explanation_list, plot_type = "criterion") @@ -434,7 +428,7 @@ vaeac_plot_eval_crit(explanation_list, plot_type = "criterion") ![](figure_vaeac/paired-sampling-plotting-1.png) -```r +``` r plot_MSEv_eval_crit(explanation_list) ``` @@ -443,14 +437,14 @@ plot_MSEv_eval_crit(explanation_list) By looking at the time, we see that the paired version takes (a bit) longer time in the `setup_computation` phase, that is, in the training phase. -```r +``` r rbind( - "Paired" = expl_paired_sampling_TRUE$timing$timing_secs, - "Regular" = expl_paired_sampling_FALSE$timing$timing_secs + "Paired" = expl_paired_sampling_TRUE$timing$main_timing_secs, + "Regular" = expl_paired_sampling_FALSE$timing$main_timing_secs ) -#> setup test_prediction setup_computation compute_vS shapley_computation -#> Paired 0.10987 0.055879 7.1928 0.29876 0.0043712 -#> Regular 0.05501 0.037705 6.2180 0.30362 0.0044370 +#> setup test_prediction iterative_estimation finalize_explanation +#> Paired 0.048088 0.036740 11.721 0.0049973 +#> Regular 0.047131 0.036345 11.517 0.0049357 ``` @@ -458,74 +452,70 @@ rbind( ## Progressr {#progress_bar} As discussed in the main vignette, the `shapr` package provides two ways for receiving information about the progress of the approach. First, the `shapr` package provides progress updates of the computation of the Shapley values through -the `progressr` package. Second, the user can also get information by setting `verbose = 2` in `explain()`, which -will print out extra information related to the `vaeac` approach. The `verbose` parameter works independently of the -`progressr` package. Meaning that the user can chose to use none, either, or both options simultaneously. We give -two examples here, and refer the reader to the main vignette for more detailed information. +the `progressr` package. Second, the user can also get various forms of information through `verbose` in `explain()`. 
+By letting `'vS_details' %in% verbose`, we get extra information related to the `vaeac` approach. +The `verbose` parameter works independently of the `progressr` package. +This means that the user can choose to use none, either, or both options simultaneously. +We give two examples here, and refer the reader to the main vignette for more detailed information. -By setting `verbose = 2`, we get messages about the progress of the `vaeac` approach. +By setting `verbose = c("basic", "vS_details")`, we get both basic messages about the explanation case, and +messages about the estimation of the `vaeac` approach. -```r +``` r expl_with_messages <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = 5, - verbose = 2, + phi0 = phi0, + n_MC_samples = n_MC_samples, + verbose = c("basic","vS_details"), vaeac.epochs = 5, vaeac.n_vaeacs_initialize = 2 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. +#> +#> ── Starting `shapr::explain()` at 2024-10-04 14:57:22 ───────────────────────────────────────────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpIQRVZ2/shapr_obj_acefb1be76dcf.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. +#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. #> Using 'mcar_mask_generator' with 'masking_ratio = 0.5'. #> The vaeac model contains 17032 trainable parameters. 
-#> Initializing vaeac number 1 of 2. -#> Initializing vaeac number 2 of 2. -#> Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 4. -#> Saving `best` vaeac model at epoch 5. -#> Saving `best_running` vaeac model at epoch 5. -#> Saving `last` vaeac model at epoch 5. -#> +#> Initializing vaeac model number 1 of 2. +#> Initializing vaeac model number 2 of 2. +#> ■■■■■■■■■■ 29% | Training vaeac (init. 1 of 2): Epoch: 2 | VLB: -6.593 | IWAE: -3.321 | ETA: 4s Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. +#> ■■■■■■■■■■■■■■■■■■ 57% | Training vaeac (init. 2 of 2): Epoch: 2 | VLB: -4.566 | IWAE: -3.226 | ETA: 2s #> Results of the `vaeac` training process: #> Best epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 #> Best running avg epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 #> Last epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 2 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 3 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 4 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 5 of 5. 
-#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.22.930756_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.22.930756_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.22.930756_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' ``` - -For more visual information, we can use the `progressr` package. This can help us see the progress of the training -step for the final `vaeac` model. Note that one can set `verbose = 0` to not get any messages from the `vaeac` +For more visual information, we can use the `progressr` package. +This can help us see detailed progress of the training step for the final `vaeac` model. +Note that by default `vS_details` is not part of `verbose`, meaning that we do not get any messages from the `vaeac` approach and only get the progress bars. See the main vignette for examples for how to change the progress bar. -```r +``` r library(progressr) progressr::handlers("cli") # Use `progressr::handlers("void")` to silence all `progressr` updates progressr::with_progress({ @@ -534,56 +524,38 @@ progressr::with_progress({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = 5, - verbose = 2, + phi0 = phi0, + n_MC_samples = n_MC_samples, + verbose = "vS_details", vaeac.epochs = 5, vaeac.n_vaeacs_initialize = 2 ) }) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. #> Using 'mcar_mask_generator' with 'masking_ratio = 0.5'. #> The vaeac model contains 17032 trainable parameters. -#> Initializing vaeac number 1 of 2. -#> Initializing vaeac number 2 of 2. -#> Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 4. -#> Saving `best` vaeac model at epoch 5. -#> Saving `best_running` vaeac model at epoch 5. -#> Saving `last` vaeac model at epoch 5. -#> +#> Initializing vaeac model number 1 of 2. +#> Initializing vaeac model number 2 of 2. +#> ■■■■■■■■■■ 29% | Training vaeac (init. 1 of 2): Epoch: 2 | VLB: -6.593 | IWAE: -3.321 | ETA: 4s Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. +#> ■■■■■■■■■■■■■■■■■■ 57% | Training vaeac (init. 2 of 2): Epoch: 2 | VLB: -4.566 | IWAE: -3.226 | ETA: 2s #> Results of the `vaeac` training process: #> Best epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 #> Best running avg epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 #> Last epoch: 5. VLB = -3.318 IWAE = -3.049 IWAE_running = -3.149 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 2 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 3 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. 
-#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 4 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 5 of 5. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -all.equal(expl_with_messages$shapley_values, expl_with_progressr$shapley_values) +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.33.088772_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.33.088772_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.33.088772_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' +all.equal(expl_with_messages$shapley_values_est, expl_with_progressr$shapley_values_est) #> [1] TRUE ``` @@ -591,7 +563,7 @@ all.equal(expl_with_messages$shapley_values, expl_with_progressr$shapley_values) In the case the user has set a too low number of training epochs and sees that the network is still learning, then the user can continue to train the network from where it stopped. Thus, a good workflow can therefore -be to call the `explain()` function with a `n_samples = 1` (to not waste to much time to generate MC samples), +be to call the `explain()` function with a `n_MC_samples = 1` (to not waste to much time to generate MC samples), then look at the training and evaluation plots of the `vaeac`. If not satisfied, then train more. If satisfied, then call the `explain()` function again but this time by using the extra parameter `vaeac.pretrained_vaeac_model`, as illustrated above. Note that we have set the number of `vaeac.epochs` to be very low in this example and we @@ -605,15 +577,14 @@ data. 
However, recall that the `vaeac` model is never trained on the empty coali be taken with a grain of salt. -```r +``` r expl_little_training <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 250, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 3, vaeac.n_vaeacs_initialize = 2 ) @@ -624,7 +595,7 @@ vaeac_plot_eval_crit(list("Original" = expl_little_training), plot_type = "metho ![](figure_vaeac/continue-training-1.png) -```r +``` r # Can also see how well vaeac generates data from the full joint distribution. Quite good. vaeac_plot_imputed_ggpairs( explanation = expl_little_training, @@ -635,7 +606,7 @@ vaeac_plot_imputed_ggpairs( ![](figure_vaeac/continue-training-2.png) -```r +``` r # Make a copy of the explanation object and continue to train the vaeac model some more epochs expl_train_more <- expl_little_training expl_train_more$internal$parameters$vaeac <- @@ -651,9 +622,8 @@ expl_train_more_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_train_more$internal$parameters$vaeac ) @@ -668,7 +638,7 @@ vaeac_plot_eval_crit( ![](figure_vaeac/continue-training-3.png) -```r +``` r # Continue to train the vaeac model some more epochs expl_train_even_more <- expl_train_more expl_train_even_more$internal$parameters$vaeac <- @@ -684,9 +654,8 @@ expl_train_even_more_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_train_even_more$internal$parameters$vaeac ) @@ -705,7 +674,7 @@ vaeac_plot_eval_crit( 
![](figure_vaeac/continue-training-4.png) -```r +``` r # Can also see how well vaeac generates data from the full joint distribution vaeac_plot_imputed_ggpairs( explanation = expl_train_even_more, @@ -719,7 +688,7 @@ vaeac_plot_imputed_ggpairs( We can see that the extra training has decreased the MSEv score. The Shapley value explanations have also changed, but they are often comparable. -```r +``` r plot_MSEv_eval_crit(list( "Few epochs" = expl_little_training, "More epochs" = expl_train_more_vaeac, @@ -729,7 +698,7 @@ plot_MSEv_eval_crit(list( ![](figure_vaeac/continue-training-2-1.png) -```r +``` r # We see that the Shapley values have changed, but they are often comparable plot_SV_several_approaches(list( "Few epochs" = expl_little_training, @@ -749,58 +718,57 @@ If we do not want to specify the number of `epochs`, as we are uncertain how man model will stop the training procedure if there has been no improvement in the validation score for `5` epochs. -```r +``` r # Low value for `vaeac.epochs_early_stopping` here to build the vignette faster expl_early_stopping <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 250, - n_batches = 1, - verbose = 2, + phi0 = phi0, + n_MC_samples = 250, + verbose = c("basic","vS_details"), vaeac.epochs = 1000, # Set it to a big number vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. 
+#> +#> ── Starting `shapr::explain()` at 2024-10-04 14:57:44 ───────────────────────────────────────────────────────────────────────────────────────────────────────── +#> • Model class: +#> • Approach: vaeac +#> • iterative estimation: FALSE +#> • Number of feature-wise Shapley values: 4 +#> • Number of observations to explain: 6 +#> • Computations (temporary) saved at: '/tmp/RtmpIQRVZ2/shapr_obj_acefb6c654eee.rds' +#> +#> ── Main computation started ── +#> +#> ℹ Using 16 of 16 coalitions. +#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. #> Using 'mcar_mask_generator' with 'masking_ratio = 0.5'. #> The vaeac model contains 17032 trainable parameters. -#> Initializing vaeac number 1 of 2. -#> Initializing vaeac number 2 of 2. -#> Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 4. -#> Saving `best` vaeac model at epoch 5. -#> Saving `best_running` vaeac model at epoch 5. -#> Saving `best_running` vaeac model at epoch 6. -#> Saving `best` vaeac model at epoch 7. -#> Saving `best_running` vaeac model at epoch 7. -#> Saving `best` vaeac model at epoch 8. -#> Saving `best_running` vaeac model at epoch 8. -#> Saving `best_running` vaeac model at epoch 9. -#> Saving `best` vaeac model at epoch 10. -#> Saving `best_running` vaeac model at epoch 10. -#> Saving `best_running` vaeac model at epoch 11. -#> Saving `best` vaeac model at epoch 12. -#> Saving `best_running` vaeac model at epoch 12. -#> No IWAE improvment in 2 epochs. Apply early stopping at epoch 14. -#> Saving `last` vaeac model at epoch 14. +#> Initializing vaeac model number 1 of 2. +#> Initializing vaeac model number 2 of 2. +#> ■ 0% | Training vaeac (init. 
1 of 2): Epoch: 2 | VLB: -6.593 | IWAE: -3.321 | ETA: 12m Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.566 after 2 epochs. Continue to train this inititalization. +#> ■ 0% | Training vaeac (init. 2 of 2): Epoch: 2 | VLB: -4.566 | IWAE: -3.226 | ETA: 13m No IWAE improvment in 2 epochs. Apply early stopping at epoch 14. #> #> Results of the `vaeac` training process: #> Best epoch: 12. VLB = -2.958 IWAE = -2.930 IWAE_running = -2.991 #> Best running avg epoch: 12. VLB = -2.958 IWAE = -2.930 IWAE_running = -2.991 #> Last epoch: 14. VLB = -2.971 IWAE = -2.955 IWAE_running = -2.996 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 1. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.43.853271_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.43.853271_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.57.43.853271_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' # Look at the training and validation errors. We are quite happy with it. vaeac_plot_eval_crit( @@ -815,7 +783,7 @@ However, we can train it further for a fixed amount of epochs if desired. This c happy with the IWAE curve or we feel that we set `vaeac.epochs_early_stopping` to a too low value or if the max number of epochs (`vaeac.epochs`) were reached. -```r +``` r # Make a copy of the explanation object which we are to train further. 
expl_early_stopping_train_more <- expl_early_stopping @@ -825,7 +793,7 @@ expl_early_stopping_train_more$internal$parameters$vaeac <- explanation = expl_early_stopping_train_more, epochs_new = 15, x_train = x_train, - verbose = 0 + verbose = NULL ) # Can even do it twice if desired @@ -834,7 +802,7 @@ expl_early_stopping_train_more$internal$parameters$vaeac <- explanation = expl_early_stopping_train_more, epochs_new = 10, x_train = x_train, - verbose = 0 + verbose = NULL ) # Look at the training and validation errors. We see some improvement @@ -852,22 +820,24 @@ vaeac_plot_eval_crit( We can then use the extra trained version to compute the Shapley value explanations and compare it with the previous version that used early stopping. We see a non-significant difference. -```r +``` r # Use extra trained vaeac model to compute Shapley values again. expl_early_stopping_train_more <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_early_stopping_train_more$internal$parameters$vaeac ) ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # We can compare their MSEv scores plot_MSEv_eval_crit(list( @@ -878,7 +848,7 @@ plot_MSEv_eval_crit(list( ![](figure_vaeac/early-stopping-3-1.png) -```r +``` r # We see that the Shapley values have changed, but only slightly plot_SV_several_approaches(list( "Vaeac early stopping" = expl_early_stopping, @@ -900,47 +870,43 @@ The same goes for group B. Note that in this setup, there are only `4` possible `2` coalitions as the empty and grand coalitions as they are not needed in the Shapley value computations. 
-```r +``` r expl_group <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, + phi0 = phi0, group = list(A = c("Temp", "Month"), B = c("Wind", "Solar.R")), - n_batches = 2, - n_samples = n_samples, - verbose = 2, + n_MC_samples = n_MC_samples, + verbose = "vS_details", vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. -#> Setting up the `vaeac` approach. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_groups = 4, +#> and is therefore set to 2^n_groups = 4. +#> +#> ── Extra info about the pretrained vaeac model ── +#> #> Training the `vaeac` model with the provided parameters from scratch on CPU. -#> Using 'specified_masks_mask_generator' with '2' coalitions. +#> Using 'specified_masks_mask_generator' with '4' coalitions. #> The vaeac model contains 17032 trainable parameters. -#> Initializing vaeac number 1 of 2. -#> Initializing vaeac number 2 of 2. -#> Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.814 after 2 epochs. Continue to train this inititalization. -#> Saving `best` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 3. -#> Saving `best_running` vaeac model at epoch 4. -#> Saving `last` vaeac model at epoch 4. -#> +#> Initializing vaeac model number 1 of 2. +#> Initializing vaeac model number 2 of 2. +#> ■■■■■■■■■■■ 33% | Training vaeac (init. 1 of 2): Epoch: 2 | VLB: -6.489 | IWAE: -3.322 | ETA: 3s Best vaeac inititalization was number 2 (of 2) with a training VLB = -4.453 after 2 epochs. Continue to train this inititalization. +#> ■■■■■■■■■■■■■■■■■■■■■ 67% | Training vaeac (init. 2 of 2): Epoch: 2 | VLB: -4.453 | IWAE: -3.174 | ETA: 2s #> Results of the `vaeac` training process: -#> Best epoch: 3. VLB = -3.935 IWAE = -3.124 IWAE_running = -3.267 -#> Best running avg epoch: 4. 
VLB = -3.619 IWAE = -3.138 IWAE_running = -3.235 -#> Last epoch: 4. VLB = -3.619 IWAE = -3.138 IWAE_running = -3.235 -#> Done with setting up the `vaeac` approach. -#> Generating Monte Carlo samples using `vaeac` for batch 1 of 2. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. -#> Generating Monte Carlo samples using `vaeac` for batch 2 of 2. -#> Preprocessing the explicands. -#> Generating the MC samples. -#> Postprocessing the Monte Carlo samples. +#> Best epoch: 4. VLB = -3.514 IWAE = -3.114 IWAE_running = -3.153 +#> Best running avg epoch: 4. VLB = -3.514 IWAE = -3.114 IWAE_running = -3.153 +#> Last epoch: 4. VLB = -3.514 IWAE = -3.114 IWAE_running = -3.153 +#> +#> ℹ The trained `vaeac` models are saved to folder '/tmp/RtmpIQRVZ2' at +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.58.51.032657_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.58.51.032657_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_best_running.pt' +#> '/tmp/RtmpIQRVZ2/X2024.10.04.14.58.51.032657_n_features_4_n_train_105_depth_3_width_32_latent_8_lr_0.001_epoch_last.pt' # Plot the resulting explanations plot(expl_group) @@ -954,7 +920,7 @@ plot(expl_group) Here we look at a setup with mixed data, i.e., the data contains both categorical and continuous features. First we set up the data and the model. -```r +``` r library(ranger) data <- data.table::as.data.table(airquality) data <- data[complete.cases(data), ] @@ -977,26 +943,28 @@ model <- ranger(as.formula(paste0(y_var, " ~ ", paste0(x_var_cat, collapse = " + ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(data_train_cat[, get(y_var)]) +phi0 <- mean(data_train_cat[, get(y_var)]) ``` Then we compute explanations using the `ctree` and `vaeac` approaches. 
For the `vaeac` approach, we consider two setups: the default architecture, and a simpler one without skip connections. We do this to illustrate that the skip connections improve the `vaeac` method. We use `ctree` with default parameters. -```r +``` r # Here we use the ctree approach expl_ctree <- explain( model = model, x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250 + phi0 = phi0, + n_MC_samples = 250 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # Then we use the vaeac approach expl_vaeac_with <- explain( @@ -1004,14 +972,17 @@ expl_vaeac_with <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 4 ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. +#> +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # Then we use the vaeac approach expl_vaeac_without <- explain( @@ -1019,9 +990,8 @@ expl_vaeac_without <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 4, vaeac.extra_parameters = list( @@ -1031,6 +1001,10 @@ expl_vaeac_without <- explain( ) #> Note: Feature classes extracted from the model contains NA. #> Assuming feature classes from the data are correct. 
+#> +#> Success with message: +#> max_n_coalitions is NULL or larger than or 2^n_features = 16, +#> and is therefore set to 2^n_features = 16. # We see that the `vaeac` model without the skip connections perform worse vaeac_plot_eval_crit( @@ -1044,7 +1018,7 @@ vaeac_plot_eval_crit( ![](figure_vaeac/vaeac-mixed-data-1.png) -```r +``` r # The vaeac model with skip connections have the lowest/best MSE_Frye evaluation criterion score plot_MSEv_eval_crit(list( "Vaeac w.o. skip-con." = expl_vaeac_without, @@ -1055,7 +1029,7 @@ plot_MSEv_eval_crit(list( ![](figure_vaeac/vaeac-mixed-data-2.png) -```r +``` r # Can compare the Shapley values. Ctree and vaeac with skip connections produce similar explanations. plot_SV_several_approaches( list( @@ -1079,7 +1053,7 @@ Finally, note that if the user specifies `vaeac.cuda = TRUE`, but there is no av a warning and falls back to use CPU instead. -```r +``` r # Load necessary library library(mvtnorm) @@ -1117,7 +1091,7 @@ x_explain <- dt_explain[, -1] model <- lm(y ~ ., dt_train) # Specifying the phi_0, i.e. 
the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Fit vaeac model using the CPU time_cpu <- system.time({ @@ -1126,9 +1100,8 @@ time_cpu <- system.time({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 100, - n_batches = 5, + phi0 = phi0, + n_MC_samples = 100, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.cuda = FALSE) @@ -1142,9 +1115,8 @@ time_cuda <- system.time({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 100, - n_batches = 5, + phi0 = phi0, + n_MC_samples = 100, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.cuda = TRUE) @@ -1168,7 +1140,7 @@ rbind("Vaeac CPU" = time_cpu, "Vaeac GPU" = time_cuda) It is no possible to set same random state on the CPU and GPU, hence, the results are not equivalent. The difference is due to different initialization values. -```r +``` r vaeac_plot_eval_crit( list("Vaeac CPU" = explanation_cpu, "Vaeac GPU" = explanation_cuda), plot_type = "criterion" @@ -1179,7 +1151,7 @@ vaeac_plot_eval_crit( We also get almost identical $\text{MSE}_v$ values. -```r +``` r plot_MSEv_eval_crit(list("Vaeac CPU" = explanation_cpu, "Vaeac GPU" = explanation_cuda)) ``` @@ -1188,7 +1160,7 @@ plot_MSEv_eval_crit(list("Vaeac CPU" = explanation_cpu, We can also compare the Shapley values and see that we get comparable explanations. 
-```r +``` r plot_SV_several_approaches( list("Vaeac CPU" = explanation_cpu, "Vaeac GPU" = explanation_cuda), index_explicands = 1:3, diff --git a/vignettes/understanding_shapr_vaeac.Rmd.orig b/vignettes/understanding_shapr_vaeac.Rmd.orig index 9d12a64c50a38406a91d282768ad8d0ca06da664..3d621ff48321e66883cfcf4335c0603313503324 100644 --- a/vignettes/understanding_shapr_vaeac.Rmd.orig +++ b/vignettes/understanding_shapr_vaeac.Rmd.orig @@ -40,7 +40,7 @@ library(shapr) > [Pretrained vaeac (path)](#pretrained_vaeac_path) -> [Subset of coalitions](#n_combinations) +> [Subset of coalitions](#n_coalitions) > [Paired sampling](#paired_sampling) @@ -147,7 +147,7 @@ model <- xgboost( ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) ``` @@ -157,8 +157,7 @@ prediction_zero <- mean(y_train) We are now going to explain predictions made by the model using the `vaeac` approach. ```{r first-vaeac, cache = TRUE} -n_samples <- 25 # Low number of MC samples to make the vignette build faster -n_batches <- 1 # Do all coalitions in one batch +n_MC_samples <- 25 # Low number of MC samples to make the vignette build faster vaeac.n_vaeacs_initialize <- 2 # Initialize several vaeacs to counteract bad initialization values vaeac.epochs <- 4 # The number of training epochs @@ -167,9 +166,8 @@ explanation <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = vaeac.epochs, vaeac.n_vaeacs_initialize = vaeac.n_vaeacs_initialize ) @@ -179,7 +177,7 @@ We can look at the Shapley values. ```{r first-vaeac-plots, cache = TRUE} # Printing and ploting the Shapley values. # See ?shapr::explain for interpretation of the values. 
-print(explanation$shapley_values) +print(explanation$shapley_values_est) plot(explanation) ``` @@ -190,7 +188,7 @@ if we want to explain new predictions using the same combinations/coalitions as `x_explain`. Note that the new `x_explain` must have the same features as before. The `vaeac` model is accessible via `explanation$internal$parameters$vaeac`. -Note that if we set `verbose = 2` in `explain()`, then `shapr` will give a message +Note that if we let `'vS_details' %in% verbose` in `explain()`, then `shapr` will give a message that it loads a pretrained `vaeac` model instead of training it from scratch. In this example, we extract the trained `vaeac` model from the previous example and send it to `explain()`. @@ -202,22 +200,21 @@ expl_pretrained_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = n_samples, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac ) ) # Check that this version provides the same Shapley values -all.equal(explanation$shapley_values, expl_pretrained_vaeac$shapley_values) +all.equal(explanation$shapley_values_est, expl_pretrained_vaeac$shapley_values_est) ``` ## Pre-trained vaeac (path) {#pretrained_vaeac_path} We can also just provide a path to the stored `vaeac` model. This is beneficial if we have only stored the `vaeac` model on the computer but not the whole `explanation` object. The possible save paths are stored in -`explanation$internal$parameters$vaeac$model`. Note that if we set `verbose = 2` in `explain()`, then `shapr` will give +`explanation$internal$parameters$vaeac$models`. Note that if we let `'vS_details' %in% verbose` in `explain()`, then `shapr` will give a message that it loads a pretrained `vaeac` model instead of training it from scratch.
```{r pretrained-vaeac-path, cache = TRUE} @@ -228,45 +225,40 @@ expl_pretrained_vaeac_path <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = n_samples, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac$models$best ) ) # Check that this version provides the same Shapley values -all.equal(explanation$shapley_values, expl_pretrained_vaeac_path$shapley_values) +all.equal(explanation$shapley_values_est, expl_pretrained_vaeac_path$shapley_values_est) ``` -## Specified n_combinations and more batches {#n_combinations} +## Specified n_coalitions {#n_coalitions} -In this section, we discuss two general `shapr` parameters in the `explain()` function -that are method independent, namely, `n_combinations` and `n_batches`. +In this section, we discuss a general `shapr` parameter in the `explain()` function +that is method independent, namely, `n_coalitions`. The user can limit the Shapley value computations to only a subset of coalitions by setting the -`n_combinations` parameter to a value lower than $2^{n_\text{features}}$. To lower the memory -usage, the user can split the coalitions into several batches by setting `n_batches` to a desired -number. In this example, we set `n_batches = 5` and `n_combinations = 10` which is less than -the maximum of `16`. +`n_coalitions` parameter to a value lower than $2^{n_\text{features}}$. Note that we do not need to train a new `vaeac` model as we can use the one above trained on all `16` coalitions as we are now only using a subset of them. This is not applicable the other way around. 
-```{r check-n_combinations-and-more-batches, cache = TRUE} +```{r check-n_coalitions, cache = TRUE} # send the pre-trained vaeac path expl_batches_combinations <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_combinations = 10, - n_batches = 5, - n_samples = n_samples, + phi0 = phi0, + n_coalitions = 10, + n_MC_samples = n_MC_samples, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = explanation$internal$parameters$vaeac ) @@ -274,8 +266,6 @@ expl_batches_combinations <- explain( # Gives different Shapley values as the latter one are only based on a subset of coalitions plot_SV_several_approaches(list("Original" = explanation, "Other combi." = expl_batches_combinations)) -# Here we can see that the samples coalitions are in different batches and have different weights -expl_batches_combinations$internal$objects$X # Can compare that to the situation where we have exact computations (i.e., include all coalitions) explanation$internal$objects$X @@ -284,21 +274,20 @@ explanation$internal$objects$X Note that if we train a `vaeac` model from scratch with the setup above, then the `vaeac` model will not use a missing completely as random (MCAR) mask generator, but rather a mask generator that ensures that the `vaeac` model is only trained on the specified set of coalitions. In this case, it will be the set of the -`n_combinations - 2` sampled coalitions. The minus two is because the `vaeac` model will +`n_coalitions - 2` sampled coalitions. The minus two is because the `vaeac` model will not train on the empty and grand coalitions as they are not needed in the Shapley value computations. 
-```{r check-n_combinations-and-more-batches-2, cache = TRUE} +```{r check-n_coalitions-2, cache = TRUE} expl_batches_combinations_2 <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_combinations = 10, - n_batches = 1, - n_samples = n_samples, + phi0 = phi0, + n_coalitions = 10, + n_MC_samples = n_MC_samples, vaeac.n_vaeacs_initialize = 1, vaeac.epochs = 3, - verbose = 2 + verbose = "vS_details" ) ``` @@ -310,7 +299,7 @@ The `vaeac` approach can use paired sampling to improve the stability of the tra When using paired sampling, each observation in the training batches will be duplicated, but the first version will be masked by $S$ and the second verion will be masked by the complement $\bar{S}$. The mask are taken from the `explanation$internal$objects$S` matrix. Note that `vaeac` does not check if the complement is also in said matrix. -This means that if the Shapley value explanations are computed based on a subset of coalitions, i.e., `n_combinations` +This means that if the Shapley value explanations are computed based on a subset of coalitions, i.e., `n_coalitions` is less than $2^{n_\text{features}}$, then the `vaeac` model might be trained on coalitions which are not used when computing the Shapley values. This should not be considered as redundant training as it increases the stablility and performance of the `vaeac` model as a whole, hence, we reccomend to use paried samping (default). 
Furthermore, the masks @@ -323,9 +312,8 @@ expl_paired_sampling_TRUE <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = 10, vaeac.n_vaeacs_initialize = 1, vaeac.extra_parameters = list(vaeac.paired_sampling = TRUE) @@ -336,9 +324,8 @@ expl_paired_sampling_FALSE <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = n_MC_samples, vaeac.epochs = 10, vaeac.n_vaeacs_initialize = 1, vaeac.extra_parameters = list(vaeac.paired_sampling = FALSE) @@ -359,8 +346,8 @@ By looking at the time, we see that the paired version takes (a bit) longer time phase, that is, in the training phase. ```{r paired-sampling-timing} rbind( - "Paired" = expl_paired_sampling_TRUE$timing$timing_secs, - "Regular" = expl_paired_sampling_FALSE$timing$timing_secs + "Paired" = expl_paired_sampling_TRUE$timing$main_timing_secs, + "Regular" = expl_paired_sampling_FALSE$timing$main_timing_secs ) ``` @@ -369,29 +356,30 @@ rbind( ## Progressr {#progress_bar} As discussed in the main vignette, the `shapr` package provides two ways for receiving information about the progress of the approach. First, the `shapr` package provides progress updates of the computation of the Shapley values through -the `progressr` package. Second, the user can also get information by setting `verbose = 2` in `explain()`, which -will print out extra information related to the `vaeac` approach. The `verbose` parameter works independently of the -`progressr` package. Meaning that the user can chose to use none, either, or both options simultaneously. We give -two examples here, and refer the reader to the main vignette for more detailed information. 
- -By setting `verbose = 2`, we get messages about the progress of the `vaeac` approach. +the `progressr` package. Second, the user can also get various forms of information through `verbose` in `explain()`. +By letting `'vS_details' %in% verbose`, we get extra information related to the `vaeac` approach. +The `verbose` parameter works independently of the `progressr` package. +Meaning that the user can choose to use none, either, or both options simultaneously. +We give two examples here, and refer the reader to the main vignette for more detailed information. + +By setting `verbose = c("basic", "vS_details")`, we get both basic messages about the explanation case, and +messages about the estimation of the `vaeac` approach. ```{r progressr-false-verbose-2, cache = TRUE} expl_with_messages <- explain( model = model, x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = 5, - verbose = 2, + phi0 = phi0, + n_MC_samples = n_MC_samples, + verbose = c("basic","vS_details"), vaeac.epochs = 5, vaeac.n_vaeacs_initialize = 2 ) ``` - -For more visual information, we can use the `progressr` package. This can help us see the progress of the training -step for the final `vaeac` model. Note that one can set `verbose = 0` to not get any messages from the `vaeac` +For more visual information we can use the `progressr` package. +This can help us see detailed progress of the training step for the final `vaeac` model. +Note that by default `vS_details` is not part of `verbose`, meaning that we do not get any messages from the `vaeac` approach and only get the progress bars. See the main vignette for examples for how to change the progress bar. 
```{r progressr-true-verbose-2, cache = TRUE} library(progressr) @@ -402,22 +390,21 @@ progressr::with_progress({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = n_samples, - n_batches = 5, - verbose = 2, + phi0 = phi0, + n_MC_samples = n_MC_samples, + verbose = "vS_details", vaeac.epochs = 5, vaeac.n_vaeacs_initialize = 2 ) }) -all.equal(expl_with_messages$shapley_values, expl_with_progressr$shapley_values) +all.equal(expl_with_messages$shapley_values_est, expl_with_progressr$shapley_values_est) ``` ## Continue the training of the vaeac approach {#continue_training} In the case the user has set a too low number of training epochs and sees that the network is still learning, then the user can continue to train the network from where it stopped. Thus, a good workflow can therefore -be to call the `explain()` function with a `n_samples = 1` (to not waste to much time to generate MC samples), +be to call the `explain()` function with a `n_MC_samples = 1` (to not waste too much time to generate MC samples), then look at the training and evaluation plots of the `vaeac`. If not satisfied, then train more. If satisfied, then call the `explain()` function again but this time by using the extra parameter `vaeac.pretrained_vaeac_model`, as illustrated above. 
Note that we have set the number of `vaeac.epochs` to be very low in this example and we @@ -436,9 +423,8 @@ expl_little_training <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 250, - n_batches = n_batches, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 3, vaeac.n_vaeacs_initialize = 2 ) @@ -466,9 +452,8 @@ expl_train_more_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_train_more$internal$parameters$vaeac ) @@ -494,9 +479,8 @@ expl_train_even_more_vaeac <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_train_even_more$internal$parameters$vaeac ) @@ -550,10 +534,9 @@ expl_early_stopping <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 250, - n_batches = 1, - verbose = 2, + phi0 = phi0, + n_MC_samples = 250, + verbose = c("basic","vS_details"), vaeac.epochs = 1000, # Set it to a big number vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.epochs_early_stopping = 2) @@ -579,7 +562,7 @@ expl_early_stopping_train_more$internal$parameters$vaeac <- explanation = expl_early_stopping_train_more, epochs_new = 15, x_train = x_train, - verbose = 0 + verbose = NULL ) # Can even do it twice if desired @@ -588,7 +571,7 @@ expl_early_stopping_train_more$internal$parameters$vaeac <- explanation = expl_early_stopping_train_more, epochs_new = 10, x_train = x_train, - verbose = 0 + verbose = NULL ) # Look at the training and validation errors. 
We see some improvement @@ -610,9 +593,8 @@ expl_early_stopping_train_more <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = n_batches, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.extra_parameters = list( vaeac.pretrained_vaeac_model = expl_early_stopping_train_more$internal$parameters$vaeac ) @@ -647,11 +629,10 @@ expl_group <- explain( x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, + phi0 = phi0, group = list(A = c("Temp", "Month"), B = c("Wind", "Solar.R")), - n_batches = 2, - n_samples = n_samples, - verbose = 2, + n_MC_samples = n_MC_samples, + verbose = "vS_details", vaeac.epochs = 4, vaeac.n_vaeacs_initialize = 2 ) @@ -688,7 +669,7 @@ model <- ranger(as.formula(paste0(y_var, " ~ ", paste0(x_var_cat, collapse = " + ) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(data_train_cat[, get(y_var)]) +phi0 <- mean(data_train_cat[, get(y_var)]) ``` Then we compute explanations using the `ctree` and `vaeac` approaches. For the `vaeac` approach, we consider two setups: the default architecture, and a simpler one without skip connections. 
We do this @@ -701,9 +682,8 @@ expl_ctree <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "ctree", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250 + phi0 = phi0, + n_MC_samples = 250 ) # Then we use the vaeac approach @@ -712,9 +692,8 @@ expl_vaeac_with <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 4 ) @@ -725,9 +704,8 @@ expl_vaeac_without <- explain( x_explain = x_explain_cat, x_train = x_train_cat, approach = "vaeac", - prediction_zero = prediction_zero, - n_batches = 1, - n_samples = 250, + phi0 = phi0, + n_MC_samples = 250, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 4, vaeac.extra_parameters = list( @@ -808,7 +786,7 @@ x_explain <- dt_explain[, -1] model <- lm(y ~ ., dt_train) # Specifying the phi_0, i.e. the expected prediction without any features -prediction_zero <- mean(y_train) +phi0 <- mean(y_train) # Fit vaeac model using the CPU time_cpu <- system.time({ @@ -817,9 +795,8 @@ time_cpu <- system.time({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 100, - n_batches = 5, + phi0 = phi0, + n_MC_samples = 100, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.cuda = FALSE) @@ -833,9 +810,8 @@ time_cuda <- system.time({ x_explain = x_explain, x_train = x_train, approach = "vaeac", - prediction_zero = prediction_zero, - n_samples = 100, - n_batches = 5, + phi0 = phi0, + n_MC_samples = 100, vaeac.epochs = 50, vaeac.n_vaeacs_initialize = 2, vaeac.extra_parameters = list(vaeac.cuda = TRUE)