diff --git a/R/csv.R b/R/csv.R index 5fce18f4..ffe3b99c 100644 --- a/R/csv.R +++ b/R/csv.R @@ -704,7 +704,13 @@ read_csv_metadata <- function(csv_file) { for (line in metadata[[1]]) { if (!startsWith(line, "#") && is.null(csv_file_info[["variables"]])) { # if no # at the start of line, the line is the CSV header - all_names <- strsplit(line, ",")[[1]] + header_dt <- data.table::fread( + text = line, + header = FALSE, + check.names = FALSE, + data.table = FALSE + ) + all_names <- as.character(header_dt[1, ]) if (all(csv_file_info$algorithm != "fixed_param")) { csv_file_info[["sampler_diagnostics"]] <- all_names[endsWith(all_names, "__")] csv_file_info[["sampler_diagnostics"]] <- csv_file_info[["sampler_diagnostics"]][!(csv_file_info[["sampler_diagnostics"]] %in% c("lp__", "log_p__", "log_g__", "log_q__"))] @@ -922,8 +928,10 @@ check_csv_metadata_matches <- function(csv_metadata) { NULL } -# convert names like beta.1.1 to beta[1,1] +# convert names like beta.1.1 or beta(1,1) to beta[1,1] repair_variable_names <- function(names) { + names <- sub("\\(", "[", names) + names <- gsub("\\)", "", names) names <- sub("\\.", "[", names) names <- gsub("\\.", ",", names) names[grep("\\[", names)] <- @@ -990,7 +998,9 @@ variable_dims <- function(variable_names = NULL) { uniq_variable_names <- unique(gsub("\\[.*\\]", "", variable_names)) var_names <- gsub("\\]", "", variable_names) for (var in uniq_variable_names) { - pattern <- paste0("^", var, "\\[") + # escape regex symbols + esc_var <- gsub("([][{}()+*?.^$|\\\\])", "\\\\\\1", var) + pattern <- paste0("^", esc_var, "\\[") var_indices <- var_names[grep(pattern, var_names)] var_indices <- gsub(pattern, "", var_indices) if (length(var_indices)) { diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index 89c0faf3..299e1080 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -100,6 +100,19 @@ test_that("read_cmdstan_csv() fails with the no params listed", { "Supplied CSV file does not contain any variable names or data!") }) +test_that("variable_dims works for standard Stan names", { + vars <- c("beta[1]", "beta[2]", "beta[3]", "sigma") + dims <- variable_dims(vars) + expect_equal(dims$beta, 3) + expect_equal(dims$sigma, 1) +}) + +test_that("variable_dims handles names with regex metacharacters", { + vars <- c('SIGMA(1,1)[1,2]', 'SIGMA(1,1)[2,2]') + dims <- variable_dims(vars) + expect_equal(dims[["SIGMA(1,1)"]], c(2, 2)) +}) + test_that("read_cmdstan_csv() matches utils::read.csv", { csv_files <- c(test_path("resources", "csv", "model1-1-warmup.csv"), test_path("resources", "csv", "model1-2-warmup.csv")) @@ -844,97 +857,55 @@ test_that("read_cmdstan_csv() works with tilde expansion", { expect_no_error(read_cmdstan_csv(tildified_path)) }) +test_that("as_cmdstan_fit handles parameter names with parentheses and indices", { + skip_on_cran() -test_that("as_cmdstan_fit creates fitted model objects from csv", { - fits <- list( - mle = as_cmdstan_fit(fit_logistic_optimize$output_files()), - vb = as_cmdstan_fit(fit_logistic_variational$output_files()), - laplace = as_cmdstan_fit(fit_logistic_laplace$output_files()), - pathfinder = as_cmdstan_fit(fit_logistic_pathfinder$output_files()), - mcmc = as_cmdstan_fit(fit_logistic_thin_1$output_files()) + csv_file <- withr::local_tempfile(fileext = ".csv") + lines <- c( + "# model = norm_model", + "# method = sample (Default)", + "# id = 1", + "# thin = 1", + "# save_warmup = 0", + 'lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__,"Sigma(1,1)","Sigma(1,2)","Sigma(2,1)","Sigma(2,2)"', + "-65.951579,0.92571393,0.77752825,3,7,0,67.391073,0.2808549,-0.95718644,0.080662461,0.58814086", + "-66.417297,0.89632515,0.77752825,2,3,0,68.026905,0.3014893,-0.97834703,0.069719538,0.89573157" ) + writeLines(lines, csv_file) - for (class in names(fits)) { - fit <- fits[[class]] - if (class == "laplace") { - class_name <- "Laplace" - } else if (class == "pathfinder") { - class_name <- "Pathfinder" - } else { - class_name <- toupper(class) - } - checkmate::expect_r6(fit, classes = paste0("CmdStan", class_name, "_CSV")) - expect_s3_class(fit$draws(), "draws") - checkmate::expect_numeric(fit$lp()) - expect_output(fit$print(), "variable") - expect_length(fit$output_files(), if (class == "mcmc") fit$num_chains() else 1) - expect_s3_class(fit$summary(), "draws_summary") - - if (class == "mcmc") { - expect_s3_class(fit$sampler_diagnostics(), "draws_array") - expect_type(fit$inv_metric(), "list") - expect_equal(fit$time()$total, NA_integer_) - expect_s3_class(fit$time()$chains, "data.frame") - } - if (class == "mle") { - checkmate::expect_numeric(fit$mle()) - } - if (class %in% c("vb", "laplace", "pathfinder")) { - checkmate::expect_numeric(fit$lp_approx()) - } - for (method in unavailable_methods_CmdStanFit_CSV) { - if (!(method == "time" && class == "mcmc")) { - expect_error(fit[[method]](), "This method is not available", info = class) - } - } - } -}) + fit <- as_cmdstan_fit(csv_file, check_diagnostics = FALSE) -test_that("as_cmdstan_fit can check MCMC diagnostics", { - fit_schools <- suppressMessages( - testing_fit("schools", chains = 2, - adapt_delta = 0.5, max_treedepth = 4, - show_messages = FALSE) - ) - expect_message( - as_cmdstan_fit(fit_schools$output_files()), - "transitions ended with a divergence" - ) - expect_message( - as_cmdstan_fit(fit_schools$output_files()), - "transitions hit the maximum treedepth" - ) - expect_silent(as_cmdstan_fit(fit_schools$output_files(), check_diagnostics = FALSE)) -}) - -test_that("as_cmdstan_fit filters variables across methods", { - mcmc_vars <- c("alpha", "beta[2]") - mcmc <- as_cmdstan_fit(fit_logistic_thin_1$output_files(), variables = mcmc_vars) - expect_equal(posterior::variables(mcmc$draws()), mcmc_vars) - expect_equal(mcmc$summary()$variable, mcmc_vars) - expect_equal(mcmc$metadata()$variables, mcmc_vars) - - mle_vars <- c("beta[1]", "beta[3]") - mle <- as_cmdstan_fit(fit_logistic_optimize$output_files(), variables = mle_vars) - expect_equal(posterior::variables(mle$draws()), mle_vars) - expect_equal(mle$summary()$variable, mle_vars) - expect_equal(mle$metadata()$variables, mle_vars) - - vb_vars <- "beta" - vb <- as_cmdstan_fit(fit_logistic_variational$output_files(), variables = vb_vars) - expect_equal(posterior::variables(vb$draws()), c("beta[1]", "beta[2]", "beta[3]")) - expect_equal(vb$summary()$variable, c("beta[1]", "beta[2]", "beta[3]")) - expect_equal(vb$metadata()$variables, c("beta[1]", "beta[2]", "beta[3]")) - - laplace_vars <- "alpha" - laplace <- as_cmdstan_fit(fit_logistic_laplace$output_files(), variables = laplace_vars) - expect_equal(posterior::variables(laplace$draws()), laplace_vars) - expect_equal(laplace$summary()$variable, laplace_vars) - expect_equal(laplace$metadata()$variables, laplace_vars) - - pathfinder_vars <- c("alpha", "beta[1]", "beta[3]") - pathfinder <- as_cmdstan_fit(fit_logistic_pathfinder$output_files(), variables = pathfinder_vars) - expect_equal(posterior::variables(pathfinder$draws()), pathfinder_vars) - expect_equal(pathfinder$summary()$variable, pathfinder_vars) - expect_equal(pathfinder$metadata()$variables, pathfinder_vars) + vars <- posterior::variables(fit$draws()) + expect_true(all( + c("Sigma[1,1]", "Sigma[1,2]", "Sigma[2,1]", "Sigma[2,2]") %in% vars + )) + + dims <- fit$metadata()$stan_variable_sizes + expect_equal(dims[["Sigma"]], c(2, 2)) +}) + +test_that("as_cmdstan_fit handles variable names with parentheses", { + skip_on_cran() + csv_file <- withr::local_tempfile(fileext = ".csv") + writeLines(c( + "# model = norm_model", + "# method = sample (Default)", + "# id = 1", + "# thin = 1", + "# save_warmup = 0", + "THETA4,SIGMA(1,1)", + "2.00000E+00,2.00000E+00", + "2.00000E+00,2.00000E+00" + ), con = csv_file) + + expect_no_error({ + fit <- as_cmdstan_fit(csv_file, check_diagnostics = FALSE, format = "draws_matrix") + }) + + draws <- fit$draws() + vars <- posterior::variables(draws) + + expect_equal(posterior::ndraws(draws), 2L) + expect_true(any(grepl("THETA4", vars))) + expect_true(any(grepl("SIGMA", vars))) })