diff --git a/R/csv.R b/R/csv.R index 5fce18f47..a628f9e45 100644 --- a/R/csv.R +++ b/R/csv.R @@ -453,6 +453,7 @@ read_cmdstan_csv <- function(files, } list( metadata = metadata, + time = list(total = NA_integer_, chains = metadata$time), generated_quantities = draws ) } else if (metadata$method == "pathfinder") { @@ -783,6 +784,10 @@ read_csv_metadata <- function(csv_file) { tmp <- gsub("seconds (Total)", "", tmp, fixed = TRUE) tmp <- trimws(gsub(" Elapsed Time: ", "", tmp, fixed = TRUE)) total_time <- as.numeric(tmp) + } else if (grepl("(Generated Quantities)", tmp, fixed = TRUE)) { + tmp <- gsub("seconds (Generated Quantities)", "", tmp, fixed = TRUE) + tmp <- trimws(gsub("Elapsed Time:", "", tmp, fixed = TRUE)) + total_time <- as.numeric(tmp) } if (!is.null(csv_file_info$method) && csv_file_info$method == "diagnose" && @@ -824,6 +829,11 @@ read_csv_metadata <- function(csv_file) { sampling = sampling_time, total = total_time ) + } else if (csv_file_info$method == "generate_quantities") { + csv_file_info$time <- data.frame( + chain_id = csv_file_info$id, + total = total_time + ) } csv_file_info$model <- NULL csv_file_info$engaged <- NULL diff --git a/R/run.R b/R/run.R index eefa6d7ce..08952cbc9 100644 --- a/R/run.R +++ b/R/run.R @@ -1135,6 +1135,8 @@ CmdStanGQProcs <- R6::R6Class( if (self$is_still_working(id) && !self$is_queued(id) && !self$is_alive(id)) { # if the process just finished make sure we process all # input and mark the process finished + self$process_output(id) + self$process_error_output(id) if (self$get_proc(id)$get_exit_status() == 0) { self$set_proc_state(id = id, new_state = 5) # mark_proc_stop will mark this process successful } else { @@ -1156,6 +1158,8 @@ CmdStanGQProcs <- R6::R6Class( if (nzchar(line)) { if (self$proc_state(id) == 1 && grepl("refresh = ", line, perl = TRUE)) { self$set_proc_state(id, new_state = 1.5) + } else if (grepl("Elapsed Time:", line, fixed = TRUE)) { + private$proc_total_time_[[id]] <- as.double(trimws(sub("Elapsed Time:", "", sub("seconds (Generated Quantities)", "", line, fixed = TRUE), fixed = TRUE))) } else if (self$proc_state(id) >= 2 && private$show_stdout_messages_) { cat("Chain", id, line, "\n") } diff --git a/tests/testthat/resources/csv/bernoulli_ppc-1-gq-with-timing.csv b/tests/testthat/resources/csv/bernoulli_ppc-1-gq-with-timing.csv new file mode 100644 index 000000000..715d91184 --- /dev/null +++ b/tests/testthat/resources/csv/bernoulli_ppc-1-gq-with-timing.csv @@ -0,0 +1,26 @@ +# stan_version_major = 2 +# stan_version_minor = 39 +# stan_version_patch = 0 +# model = bernoulli_ppc_model +# method = generate_quantities +# generate_quantities +# fitted_params = /tmp/RtmpCvOIQ1/fittedParams-202006271227-1-b85b52.csv +# id = 1 +# data +# file = /home/rok/.cmdstanr/cmdstan-2.23.0/examples/bernoulli/bernoulli.data.json +# init = 2 (Default) +# random +# seed = 123 +# output +# file = /tmp/RtmpCvOIQ1/bernoulli_ppc-202006271227-1-986540.csv +# diagnostic_file = (Default) +# refresh = 250 +y_rep.1,y_rep.2,y_rep.3,y_rep.4,y_rep.5,y_rep.6,y_rep.7,y_rep.8,y_rep.9,y_rep.10 +0,0,0,0,0,0,0,0,0,0 +0,1,1,0,1,0,1,1,0,0 +0,0,0,0,0,0,0,0,0,0 +0,0,0,0,0,1,0,0,0,1 +1,0,0,0,0,0,0,0,0,0 +# +# Elapsed Time: 0.123 seconds (Generated Quantities) +# diff --git a/tests/testthat/resources/csv/bernoulli_ppc-2-gq-with-timing.csv b/tests/testthat/resources/csv/bernoulli_ppc-2-gq-with-timing.csv new file mode 100644 index 000000000..2c1cda221 --- /dev/null +++ b/tests/testthat/resources/csv/bernoulli_ppc-2-gq-with-timing.csv @@ -0,0 +1,26 @@ +# stan_version_major = 2 +# stan_version_minor = 39 +# stan_version_patch = 0 +# model = bernoulli_ppc_model +# method = generate_quantities +# generate_quantities +# fitted_params = /tmp/RtmpCvOIQ1/fittedParams-202006271227-2-b85b52.csv +# id = 2 +# data +# file = /home/rok/.cmdstanr/cmdstan-2.23.0/examples/bernoulli/bernoulli.data.json +# init = 2 (Default) +# random +# seed = 456 +# output +# file = /tmp/RtmpCvOIQ1/bernoulli_ppc-202006271227-2-986540.csv +# diagnostic_file = (Default) +# refresh = 250 +y_rep.1,y_rep.2,y_rep.3,y_rep.4,y_rep.5,y_rep.6,y_rep.7,y_rep.8,y_rep.9,y_rep.10 +1,0,1,0,0,0,1,0,0,1 +0,1,0,0,1,0,0,1,0,0 +1,0,0,0,0,1,0,0,0,0 +0,0,1,0,0,0,0,1,0,0 +0,1,0,0,0,0,0,0,1,0 +# +# Elapsed Time: 0.456 seconds (Generated Quantities) +# diff --git a/tests/testthat/test-csv.R b/tests/testthat/test-csv.R index 89c0faf36..edd4c9d5f 100644 --- a/tests/testthat/test-csv.R +++ b/tests/testthat/test-csv.R @@ -525,6 +525,58 @@ test_that("time from read_cmdstan_csv matches time from fit$time()", { ) }) +test_that("returning time works for gq read_cmdstan_csv from static CSV", { + csv_files <- test_path("resources", "csv", "bernoulli_ppc-1-gq-with-timing.csv") + csv_data <- read_cmdstan_csv(csv_files) + expect_equal(csv_data$time$total, NA_integer_) + expect_equal(csv_data$time$chains, data.frame( + chain_id = 1, + total = 0.123 + )) + + csv_files <- c( + test_path("resources", "csv", "bernoulli_ppc-1-gq-with-timing.csv"), + test_path("resources", "csv", "bernoulli_ppc-2-gq-with-timing.csv") + ) + csv_data <- read_cmdstan_csv(csv_files) + expect_equal(csv_data$time$total, NA_integer_) + expect_equal(csv_data$time$chains, data.frame( + chain_id = c(1, 2), + total = c(0.123, 0.456) + )) +}) + +test_that("returning time works for gq read_cmdstan_csv from fit object", { + gq_csv <- read_cmdstan_csv(fit_gq$output_files()) + expect_equal(gq_csv$time$total, NA_integer_) + checkmate::expect_data_frame( + gq_csv$time$chains, + any.missing = FALSE, + types = c("numeric", "numeric"), + nrows = fit_gq$num_chains(), + ncols = 2 + ) + expect_named(gq_csv$time$chains, c("chain_id", "total")) + expect_true(all(gq_csv$time$chains$total > 0)) +}) + +test_that("gq time from read_cmdstan_csv matches time from fit_gq$time()", { + expect_equivalent( + read_cmdstan_csv(fit_gq$output_files())$time$chains, + fit_gq$time()$chains + ) +}) + +test_that("returning time is NULL for gq CSV without timing", { + csv_files <- test_path("resources", "csv", "bernoulli_ppc-1-gq.csv") + csv_data <- read_cmdstan_csv(csv_files) + expect_equal(csv_data$time$total, NA_integer_) + expect_equal(csv_data$time$chains, data.frame( + chain_id = 1, + total = 0 + )) +}) + test_that("read_cmdstan_csv reads seed correctly", { opt <- read_cmdstan_csv(fit_bernoulli_optimize$output_files()) vi <- read_cmdstan_csv(fit_bernoulli_variational$output_files()) diff --git a/tests/testthat/test-fit-gq.R b/tests/testthat/test-fit-gq.R index 33cd62bed..e9f7d0cf0 100644 --- a/tests/testthat/test-fit-gq.R +++ b/tests/testthat/test-fit-gq.R @@ -118,6 +118,10 @@ test_that("time() works after gq", { nrows = fit_gq$runset$num_procs(), ncols = 2 ) + expect_named(run_times$chains, c("chain_id", "total")) + # per-chain times should be non-zero (parsed from CmdStan timing output) + expect_true(all(run_times$chains$total > 0)) + expect_true(run_times$total > 0) }) test_that("fitted_params_files() works", {