Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 196 additions & 0 deletions R/model.R
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ cmdstan_model <- function(stan_file = NULL, exe_file = NULL, compile = TRUE, ...
#' [`$hpp_file()`][model-method-compile] | Return the file path to the `.hpp` file containing the generated C++ code. |
#' [`$save_hpp_file()`][model-method-compile] | Save the `.hpp` file containing the generated C++ code. |
#' [`$expose_functions()`][model-method-expose_functions] | Expose Stan functions for use in R. |
#' [`$get_cmdstan_args()`][model-method-get_cmdstan_args] | Get CmdStan default argument values for a method. |
#'
#' ## Diagnostics
#'
Expand Down Expand Up @@ -2209,6 +2210,201 @@ expose_functions = function(global = FALSE, verbose = FALSE) {
CmdStanModel$set("public", name = "expose_functions", value = expose_functions)


#' Get CmdStan default argument values
#'
#' @name model-method-get_cmdstan_args
#' @aliases get_cmdstan_args
#' @family CmdStanModel methods
#'
#' @description The `$get_cmdstan_args()` method of a [`CmdStanModel`]
#' object queries the compiled model binary for the default argument
#' values used by a given inference method. The returned list uses
#' cmdstanr-style argument names (e.g., `iter_sampling` instead of
#' CmdStan's `num_samples`).
#'
#' The model must be compiled before calling this method.
#'
#' @param method (string) The inference method whose defaults to
#' retrieve. One of `"sample"`, `"optimize"`, `"variational"`,
#' `"pathfinder"`, or `"laplace"`.
#' @return A named list of default argument values for the specified
#' method, with cmdstanr-style argument names.
#'
#' @template seealso-docs
#'
#' @examples
#' \dontrun{
#' mod <- cmdstan_model(file.path(cmdstan_path(),
#' "examples/bernoulli/bernoulli.stan"))
#' mod$get_cmdstan_args("sample")
#' mod$get_cmdstan_args("optimize")
#' }
#'
get_cmdstan_args <- function(method = c("sample", "optimize", "variational",
"pathfinder", "laplace")) {
method <- match.arg(method)
if (length(self$exe_file()) == 0 || !file.exists(self$exe_file())) {
stop(
"'$get_cmdstan_args()' requires a compiled model. ",
"Please compile the model first with '$compile()'.",
call. = FALSE
)
}
parse_cmdstan_args(self$exe_file(), method)
}
CmdStanModel$set("public", name = "get_cmdstan_args", value = get_cmdstan_args)


# get_cmdstan_args helpers ------------------------------------------------

#' Parse CmdStan default argument values from model binary
#'
#' Runs a CmdStan model binary with `help-all` to extract valid arguments
#' and their default values for a given inference method, returning them
#' with cmdstanr argument names.
#'
#' @noRd
#' @param model_binary Path to the CmdStan model binary.
#' @param method Inference method: `"sample"`, `"optimize"`,
#' `"variational"`, `"pathfinder"`, or `"laplace"`.
#' @return A named list with cmdstanr-style argument names and default
#' values.
parse_cmdstan_args <- function(model_binary, method) {
ret <- wsl_compatible_run(
command = wsl_safe_path(model_binary),
args = c(method, "help-all"),
error_on_status = FALSE
)
output <- strsplit(ret$stdout, "\n")[[1]]

arguments <- map_cmdstan_to_cmdstanr(method)
target_args <- vapply(arguments, function(p) {
parts <- strsplit(p, "\\.")[[1]]
parts[length(parts)]
}, FUN.VALUE = character(1), USE.NAMES = TRUE)

result <- list()
n <- length(output)

for (i in seq_len(n)) {
line <- output[i]
content <- trimws(line)

# Match argument lines like "num_samples=<int>" or "t0=<double>"
arg_match <- regmatches(content, regexec("^([a-z_][a-z0-9_]*)=", content))[[1]]

if (length(arg_match) >= 2) {
arg_name <- arg_match[2]

# Check if this is one of our target arguments
matches <- which(target_args == arg_name)

if (length(matches) > 0) {
# Look ahead for "Defaults to" line
default_value <- NULL
for (j in (i + 1):min(i + 5, n)) {
next_content <- trimws(output[j])
if (grepl("^Defaults to", next_content)) {
default_value <- parse_default_value(next_content)
break
}
# Stop if we hit another argument
if (grepl("^[a-z_][a-z0-9_]*=", next_content)) break
}

# Add to result for each matching cmdstanr argument name
for (m in matches) {
cmdstanr_name <- names(target_args)[m]
result[[cmdstanr_name]] <- default_value
}
}
}
}

result
}

#' Parse default value from "Defaults to ..." line
#' @noRd
parse_default_value <- function(line) {
val_str <- sub("^Defaults to\\s*", "", line)
if (val_str %in% c("true", "false")) return(val_str == "true")
if (grepl("^-?[0-9]+$", val_str)) return(as.integer(val_str))
if (grepl("^-?[0-9]*\\.?[0-9]+([eE][+-]?[0-9]+)?$", val_str)) return(as.numeric(val_str))
val_str
}

#' Map CmdStan argument names to CmdStanR argument names
#' @noRd
map_cmdstan_to_cmdstanr <- function(method) {
switch(method,
sample = c(
iter_sampling = "sample.num_samples",
iter_warmup = "sample.num_warmup",
save_warmup = "sample.save_warmup",
thin = "sample.thin",
adapt_engaged = "sample.adapt.engaged",
adapt_delta = "sample.adapt.delta",
init_buffer = "sample.adapt.init_buffer",
term_buffer = "sample.adapt.term_buffer",
window = "sample.adapt.window",
save_metric = "sample.adapt.save_metric",
max_treedepth = "sample.algorithm.hmc.engine.nuts.max_depth",
metric = "sample.algorithm.hmc.metric",
metric_file = "sample.algorithm.hmc.metric_file",
step_size = "sample.algorithm.hmc.stepsize",
num_chains = "sample.num_chains"
),
optimize = c(
algorithm = "optimize.algorithm",
jacobian = "optimize.jacobian",
iter = "optimize.iter",
save_iterations = "optimize.save_iterations",
init_alpha = "optimize.algorithm.lbfgs.init_alpha",
tol_obj = "optimize.algorithm.lbfgs.tol_obj",
tol_rel_obj = "optimize.algorithm.lbfgs.tol_rel_obj",
tol_grad = "optimize.algorithm.lbfgs.tol_grad",
tol_rel_grad = "optimize.algorithm.lbfgs.tol_rel_grad",
tol_param = "optimize.algorithm.lbfgs.tol_param",
history_size = "optimize.algorithm.lbfgs.history_size"
),
variational = c(
algorithm = "variational.algorithm",
iter = "variational.iter",
grad_samples = "variational.grad_samples",
elbo_samples = "variational.elbo_samples",
eta = "variational.eta",
adapt_engaged = "variational.adapt.engaged",
adapt_iter = "variational.adapt.iter",
tol_rel_obj = "variational.tol_rel_obj",
eval_elbo = "variational.eval_elbo",
output_samples = "variational.output_samples"
),
pathfinder = c(
init_alpha = "pathfinder.init_alpha",
tol_obj = "pathfinder.tol_obj",
tol_rel_obj = "pathfinder.tol_rel_obj",
tol_grad = "pathfinder.tol_grad",
tol_rel_grad = "pathfinder.tol_rel_grad",
tol_param = "pathfinder.tol_param",
history_size = "pathfinder.history_size",
draws = "pathfinder.num_psis_draws",
num_paths = "pathfinder.num_paths",
save_single_paths = "pathfinder.save_single_paths",
psis_resample = "pathfinder.psis_resample",
calculate_lp = "pathfinder.calculate_lp",
max_lbfgs_iters = "pathfinder.max_lbfgs_iters",
single_path_draws = "pathfinder.num_draws",
num_elbo_draws = "pathfinder.num_elbo_draws"
),
laplace = c(
jacobian = "laplace.jacobian",
draws = "laplace.draws"
),
character(0)
)
}


# internal ----------------------------------------------------------------
assert_valid_stanc_options <- function(stanc_options) {
Expand Down
1 change: 1 addition & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ reference:
- read_cmdstan_csv
- write_stan_json
- write_stan_file
- print_stan_file
- draws_to_csv
- as_mcmc.list
- as_draws.CmdStanMCMC
Expand Down
1 change: 1 addition & 0 deletions man/CmdStanModel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 10 additions & 10 deletions man/cmdstanr-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-check_syntax.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-compile.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-diagnose.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-expose_functions.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-format.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-generate-quantities.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions man/model-method-get_cmdstan_args.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-laplace.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-optimize.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/model-method-pathfinder.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading