Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,17 @@ Collate:
'tupleSelectors.R'
'GwasFineMappingResult.R'
'H2Estimate.R'
'JointGroup.R'
'LdBlocks.R'
'LdData.R'
'LdStatistic.R'
'LdEigen.R'
'LdScore.R'
'MashPrior.R'
'QtlDataset.R'
'MultiStudyQtlDataset.R'
'QtlFineMappingResult.R'
'SldscData.R'
'TwasWeightsEntry.R'
'causalInferencePipeline.R'
'colocPipeline.R'
Expand All @@ -118,6 +121,7 @@ Collate:
'gwasSumStats.R'
'h2Annotations.R'
'h2EstimationWrappers.R'
'jointEngine.R'
'jointSpecification.R'
'ld.R'
'mashPipeline.R'
Expand Down
26 changes: 24 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ export(GenotypeHandle)
export(GwasFineMappingResult)
export(GwasSumStats)
export(LdData)
export(MashPrior)
export(MultiStudyQtlDataset)
export(QtlDataset)
export(QtlFineMappingResult)
export(QtlSumStats)
export(SldscData)
export(TwasWeights)
export(TwasWeightsEntry)
export(adjustPips)
Expand Down Expand Up @@ -86,6 +88,8 @@ export(fsusieGetCs)
export(fsusieWeights)
export(fsusieWrapper)
export(getAf)
export(getAnnotCols)
export(getAnnotData)
export(getAnnotationMeta)
export(getAnnotations)
export(getBaseline)
Expand All @@ -96,14 +100,16 @@ export(getContexts)
export(getCorrelation)
export(getCs)
export(getCtwasMetaData)
export(getCvPerformance)
export(getCvFits)
export(getCvResult)
export(getDataType)
export(getEigenList)
export(getEnrichment)
export(getFineMappingResult)
export(getFits)
export(getFormat)
export(getFrqData)
export(getFullFit)
export(getGenome)
export(getGenotypeCovariates)
export(getGenotypeHandle)
Expand Down Expand Up @@ -149,6 +155,9 @@ export(getSusieFit)
export(getSusieResult)
export(getTauBlocks)
export(getTopLoci)
export(getTraitNames)
export(getTraitRun)
export(getTraitRuns)
export(getTraits)
export(getTwasWeights)
export(getVarY)
Expand Down Expand Up @@ -198,6 +207,7 @@ export(mcpRssWeights)
export(mcpWeights)
export(mergeCtwasBoundaryRegions)
export(mergeMashData)
export(mergeSusieCs)
export(mergeVariantInfo)
export(metaAnalysisPerCell)
export(metaSldscRandom)
Expand All @@ -224,6 +234,8 @@ export(raiss)
export(readAfreq)
export(readAnnotations)
export(readGenotypes)
export(readSldscAnnot)
export(readSldscFrq)
export(readSldscTrait)
export(regionDataToIndInput)
export(regionDataToRssInput)
Expand Down Expand Up @@ -270,29 +282,35 @@ exportClasses(LdData)
exportClasses(LdEigen)
exportClasses(LdScore)
exportClasses(LdStatistic)
exportClasses(MashPrior)
exportClasses(MultiStudyQtlDataset)
exportClasses(SldscData)
exportClasses(SumStatsBase)
exportMethods(adjustPips)
exportMethods(colocboostPipeline)
exportMethods(computeLdScores)
exportMethods(estimateH2)
exportMethods(fineMappingPipeline)
exportMethods(getAf)
exportMethods(getAnnotCols)
exportMethods(getAnnotData)
exportMethods(getAnnotationMeta)
exportMethods(getAnnotations)
exportMethods(getBlockMetadata)
exportMethods(getBlocks)
exportMethods(getContexts)
exportMethods(getCorrelation)
exportMethods(getCs)
exportMethods(getCvPerformance)
exportMethods(getCvFits)
exportMethods(getCvResult)
exportMethods(getDataType)
exportMethods(getEigenList)
exportMethods(getEnrichment)
exportMethods(getFineMappingResult)
exportMethods(getFits)
exportMethods(getFormat)
exportMethods(getFrqData)
exportMethods(getFullFit)
exportMethods(getGenome)
exportMethods(getGenotypeCovariates)
exportMethods(getGenotypeHandle)
Expand Down Expand Up @@ -336,6 +354,9 @@ exportMethods(getSumstatDf)
exportMethods(getSusieFit)
exportMethods(getTauBlocks)
exportMethods(getTopLoci)
exportMethods(getTraitNames)
exportMethods(getTraitRun)
exportMethods(getTraitRuns)
exportMethods(getTraits)
exportMethods(getTwasWeights)
exportMethods(getVarY)
Expand Down Expand Up @@ -445,6 +466,7 @@ importFrom(tibble,tibble)
importFrom(tictoc,tic)
importFrom(tictoc,toc)
importFrom(tidyr,separate)
importFrom(tidyselect,all_of)
importFrom(tools,file_ext)
importFrom(tools,file_path_sans_ext)
importFrom(utils,combn)
Expand Down
92 changes: 83 additions & 9 deletions R/AllGenerics.R
Original file line number Diff line number Diff line change
Expand Up @@ -495,15 +495,6 @@ setGeneric("getWeights", function(x, ...) standardGeneric("getWeights"))
setGeneric(
  name = "getStandardized",
  def = function(x, ...) standardGeneric("getStandardized")
)

#' @title Get CV Performance
#' @description Generic accessor that returns the cross-validation
#'   performance metrics stored on an object.
#' @param x A \code{TwasWeightsEntry} or \code{TwasWeights}.
#' @param ... Selection arguments whose meaning depends on the class of
#'   \code{x}.
#' @return Depends on the method; typically a list.
#' @export
setGeneric(
  name = "getCvPerformance",
  def = function(x, ...) standardGeneric("getCvPerformance")
)

#' @title Get Model Fits
#' @description Extract fitted model objects.
#' @param x A \code{TwasWeightsEntry} or \code{TwasWeights}.
Expand All @@ -512,6 +503,24 @@ setGeneric("getCvPerformance",
#' @export
setGeneric(
  name = "getFits",
  def = function(x, ...) standardGeneric("getFits")
)

#' @title Get the Full-Data Prior from a MashPrior
#' @description Generic accessor for the \code{fullFit} slot, i.e. the
#'   data-driven prior payload obtained from the full data.
#' @param x A \code{MashPrior} object.
#' @param ... Unused.
#' @return The full-data prior payload, or \code{NULL}.
#' @export
setGeneric(
  name = "getFullFit",
  def = function(x, ...) standardGeneric("getFullFit")
)

#' @title Get the Per-Fold Priors from a MashPrior
#' @description Generic accessor for the \code{cvFits} slot, which holds the
#'   per-fold priors together with the \code{samplePartition}.
#' @param x A \code{MashPrior} object.
#' @param ... Unused.
#' @return The \code{cvFits} list, or \code{NULL}.
#' @export
setGeneric(
  name = "getCvFits",
  def = function(x, ...) standardGeneric("getCvFits")
)

#' @title Get Method Names
#' @description Extract method names from a collection class.
#' @param x A \code{FineMappingResult} or \code{TwasWeights} object.
Expand Down Expand Up @@ -846,3 +855,68 @@ setGeneric("getTauBlocks", function(x) standardGeneric("getTauBlocks"))
#' @return Numeric (length 1).
#' @export
setGeneric(
  name = "getH2",
  def = function(x) standardGeneric("getH2")
)

# Internal generics for the unified joint-analysis engine (see R/JointGroup.R
# and dev/jointSpecification-s4-refactor.md). Not exported: the engine and its
# fitters are package-internal machinery.

# fitJointGroup(group, pipeline, token, args)
#   Multiple dispatch on (JointGroup subclass, JointPipeline subclass); the
#   methods are the 4 irreducible joint fits (individual/sumstats x fm/twas).
#   Each call returns one fit entry (FineMappingEntry or TwasWeightsEntry).
setGeneric(
  name = "fitJointGroup",
  def = function(group, pipeline, token, args) standardGeneric("fitJointGroup")
)

# construct(pipeline, rows, ...)
#   Builds the per-pipeline result collection (QtlFineMappingResult vs
#   TwasWeights) out of the accumulated joint rows. The rows accumulator
#   derives each joint row's identity from its group's `conditions`: which
#   axes collapse to "joint", plus jointStudies/Contexts/Traits.
setGeneric(
  name = "construct",
  def = function(pipeline, rows, ...) standardGeneric("construct")
)

# ---- SldscData accessors ----
#' @title Get the annotation table from an SldscData
#' @description Generic accessor for the annotation table held by an
#'   \code{\link{SldscData}} object.
#' @param x An \code{\link{SldscData}} object.
#' @return A \code{data.frame} of annotations (CHR, SNP, annotation columns).
#' @rdname getAnnotData
#' @export
setGeneric(
  name = "getAnnotData",
  def = function(x) standardGeneric("getAnnotData")
)

#' @title Get the allele-frequency table from an SldscData
#' @description Generic accessor for the allele-frequency table held by an
#'   \code{\link{SldscData}} object.
#' @param x An \code{\link{SldscData}} object.
#' @return A \code{data.frame} of reference-panel frequencies (SNP, MAF).
#' @rdname getFrqData
#' @export
setGeneric(
  name = "getFrqData",
  def = function(x) standardGeneric("getFrqData")
)

#' @title Get the per-trait runs list from an SldscData
#' @description Generic accessor for the per-trait runs stored on an
#'   \code{\link{SldscData}} object.
#' @param x An \code{\link{SldscData}} object.
#' @return The named list of per-trait \code{single}/\code{joint} runs.
#' @rdname getTraitRuns
#' @export
setGeneric(
  name = "getTraitRuns",
  def = function(x) standardGeneric("getTraitRuns")
)

#' @title Get the trait names from an SldscData
#' @description Generic accessor for the trait names of an
#'   \code{\link{SldscData}} object.
#' @param x An \code{\link{SldscData}} object.
#' @return A character vector of trait names.
#' @rdname getTraitNames
#' @export
setGeneric(
  name = "getTraitNames",
  def = function(x) standardGeneric("getTraitNames")
)

#' @title Get the annotation column names from an SldscData
#' @description Generic accessor for the annotation column names of an
#'   \code{\link{SldscData}} object.
#' @param x An \code{\link{SldscData}} object.
#' @return A character vector of annotation column names.
#' @rdname getAnnotCols
#' @export
setGeneric(
  name = "getAnnotCols",
  def = function(x) standardGeneric("getAnnotCols")
)

#' @title Get one trait's run from an SldscData
#' @description Generic accessor for the run(s) recorded for a single trait
#'   of an \code{\link{SldscData}} object.
#' @param x An \code{\link{SldscData}} object.
#' @param trait Character. Trait name.
#' @param ... Further arguments: \code{mode}
#'   (\code{"single"}/\code{"joint"}) and \code{idx} (which single run).
#' @return A single run list, the list of single runs, or \code{NULL}.
#' @rdname getTraitRun
#' @export
setGeneric(
  name = "getTraitRun",
  def = function(x, trait, ...) standardGeneric("getTraitRun")
)
111 changes: 111 additions & 0 deletions R/JointGroup.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# =============================================================================
# JointGroup S4 hierarchy + dispatch scaffolding
# -----------------------------------------------------------------------------
# The intermediate contract for the unified joint-analysis engine (see
# dev/jointSpecification-s4-refactor.md). Every enumerator emits a list of
# `JointGroup`s; every fitter consumes one. The grammar/parsing half of
# jointSpecification.R and the auto-detection paths both funnel through this.
#
# JointGroup (VIRTUAL) the conditions fitted jointly: a data.frame with
# one row per fitted condition (= per Y/Z column),
# carrying its (study, context, trait) identity.
# IndividualJointGroup design = individual-level (X, Y)
# SumStatsJointGroup design = summary-statistic (Z, R, N)
#
# The OUTPUT row identity is DERIVED from `conditions`: an axis that takes one
# value across all conditions is fixed (that value); an axis that varies is
# collapsed to "joint" with the distinct members recorded in jointStudies /
# jointContexts / jointTraits. So cross-context / cross-trait / cross-study are
# the single-varying-axis case and composed is the >1-varying-axis case --
# uniformly, with the actual fitted tuples preserved (composed loses nothing).
#
# JointDispatchCell one row of the wiring table: (pattern, dataForm)
# -> enumerator + minGroup
# JointPipeline (VIRTUAL) pipeline marker carrying per-pipeline config
# FmJointPipeline fine-mapping -> QtlFineMappingResult
# TwasJointPipeline twas weights -> TwasWeights
#
# Construction is validated (new() runs validity), so an enumerator cannot emit
# a malformed group and a mistyped dispatch cell fails at package load.
# =============================================================================

#' @include AllGenerics.R
NULL

# ---- JointGroup virtual base ------------------------------------------------
# Abstract parent of every joint group. Its single slot, `conditions`, is a
# data.frame with one row per jointly fitted condition (one per Y/Z column);
# validity requires the identity columns `study`, `context`, `trait` and at
# least one row.
setClass(
  "JointGroup",
  contains = "VIRTUAL",
  slots = c(conditions = "data.frame"),
  validity = function(object) {
    cond <- object@conditions
    identityCols <- c("study", "context", "trait")
    # The row-count rule is only evaluated once the identity columns exist.
    problems <-
      if (length(setdiff(identityCols, names(cond))) > 0L) {
        "'conditions' must have columns 'study', 'context', 'trait'"
      } else if (nrow(cond) < 1L) {
        "a group needs >= 1 condition (Y/Z column)"
      } else {
        character()
      }
    if (length(problems) > 0L) problems else TRUE
  }
)

# ---- IndividualJointGroup ---------------------------------------------------
# Joint group backed by individual-level data (X, Y). `pos` holds the
# per-condition functional position (one entry per Y column); it is populated
# only by the cross-trait enumerator for fsusie (functional SuSiE over the
# trait domain) and left empty for every other pattern/method.
setClass(
  "IndividualJointGroup",
  contains = "JointGroup",
  slots = c(X = "matrix", Y = "matrix", pos = "numeric"),
  validity = function(object) {
    nOutcome <- ncol(object@Y)
    nPos <- length(object@pos)
    # One flag per rule, in the same order as the messages below.
    violated <- c(
      nrow(object@X) != nrow(object@Y),
      nOutcome != nrow(object@conditions),
      nPos > 0L && nPos != nOutcome
    )
    messages <- c(
      "X and Y must share the sample (row) dimension",
      "ncol(Y) must equal nrow(conditions)",
      "when set, 'pos' must have one entry per Y column"
    )
    problems <- messages[violated]
    if (length(problems) > 0L) problems else TRUE
  }
)

# ---- SumStatsJointGroup -----------------------------------------------------
# Joint group backed by summary statistics (Z, R, N). Validity checks that
# `R` is square, that `Z` has one row per `R` row, and that `Z` has one
# column per row of `conditions`.
setClass(
  "SumStatsJointGroup",
  contains = "JointGroup",
  slots = c(Z = "matrix", R = "matrix", N = "numeric"),
  validity = function(object) {
    ldDim <- dim(object@R)
    zDim <- dim(object@Z)
    # One flag per rule, in the same order as the messages below.
    violated <- c(
      ldDim[1L] != ldDim[2L],
      zDim[1L] != ldDim[1L],
      zDim[2L] != nrow(object@conditions)
    )
    messages <- c(
      "'R' (LD) must be square",
      "'Z' rows (variants) must match the 'R' dimension",
      "ncol(Z) must equal nrow(conditions)"
    )
    problems <- messages[violated]
    if (length(problems) > 0L) problems else TRUE
  }
)

# ---- JointDispatchCell ------------------------------------------------------
# One row of the dispatch wiring table: (pattern, dataForm) -> enumerator +
# minGroup. Validity checks `dataForm` and `minGroup`; `pattern` is a free
# label and is not validated here.
setClass(
  "JointDispatchCell",
  representation(
    pattern   = "character", # context / trait / study / composed (a label)
    dataForm  = "character", # individual / sumstats
    enumerate = "function",  # (data, scope, args) -> list<JointGroup>
    minGroup  = "integer"    # smallest fittable condition count (joint cells
                             # use >= 2; the univariate cell uses 1)
  ),
  validity = function(object) {
    errors <- character()
    dataForm <- object@dataForm
    # `%in%` never yields NA, so an NA dataForm is rejected by this test too.
    if (length(dataForm) != 1L ||
        !dataForm %in% c("individual", "sumstats")) {
      errors <- c(errors, "'dataForm' must be 'individual' or 'sumstats'")
    }
    minGroup <- object@minGroup
    # The is.na() guard matters: with NA_integer_ the `<` comparison is NA and
    # `if` would abort with "missing value where TRUE/FALSE needed" instead of
    # reporting the validity message below.
    if (length(minGroup) != 1L || is.na(minGroup) || minGroup < 1L) {
      errors <- c(errors, "'minGroup' must be a single integer >= 1")
    }
    if (length(errors) == 0L) TRUE else errors
  }
)

# ---- Pipeline markers -------------------------------------------------------
# The markers are not empty shells: their `config` list holds the
# per-pipeline parameter tail (coverage/cvFolds/samplePartition/fitFullData/
# retainFit/... for fm; retainFit/retainFitDetail/cvFolds/... for twas).
# Dispatching `construct()` on the concrete subclass picks the result type.
setClass(
  "JointPipeline",
  contains = "VIRTUAL",
  slots = c(config = "list")
)

# Fine-mapping pipeline marker (result type: QtlFineMappingResult).
setClass(Class = "FmJointPipeline", contains = "JointPipeline")
# TWAS-weights pipeline marker (result type: TwasWeights).
setClass(Class = "TwasJointPipeline", contains = "JointPipeline")
Loading
Loading