diff --git a/gffquant/__init__.py b/gffquant/__init__.py index f3a023d6..31f4177f 100644 --- a/gffquant/__init__.py +++ b/gffquant/__init__.py @@ -5,7 +5,7 @@ from enum import Enum, auto, unique -__version__ = "2.18.5" +__version__ = "2.19.0" __tool__ = "gffquant" diff --git a/gffquant/__main__.py b/gffquant/__main__.py index c5804587..4133d1a1 100644 --- a/gffquant/__main__.py +++ b/gffquant/__main__.py @@ -128,26 +128,34 @@ def main(): **kwargs, ) - if args.input_type == "fastq": + if args.input_type in ("fastq", "bam", "sam"): - stream_alignments(args, profiler) + if args.input_type == "fastq": - else: + stream_alignments(args, profiler) - input_file = args.bam if args.input_type == "bam" else args.sam - debug_samfile = None - if profiler.debug: - debug_samfile = f"{profiler.out_prefix}.{args.input_type}.filtered.sam" - - profiler.count_alignments( - sys.stdin if input_file == "-" else input_file, - aln_format=args.input_type, - min_identity=args.min_identity, - min_seqlen=args.min_seqlen, - external_readcounts=args.import_readcounts, - unmarked_orphans=args.unmarked_orphans, - debug_samfile=debug_samfile, - ) + else: + + input_file = args.bam if args.input_type == "bam" else args.sam + debug_samfile = None + if profiler.debug: + debug_samfile = f"{profiler.out_prefix}.{args.input_type}.filtered.sam" + + profiler.count_alignments( + sys.stdin if input_file == "-" else input_file, + aln_format=args.input_type, + min_identity=args.min_identity, + min_seqlen=args.min_seqlen, + external_readcounts=args.import_readcounts, + unmarked_orphans=args.unmarked_orphans, + debug_samfile=debug_samfile, + ) + + profiler.report_alignments() + + else: + + ... 
def annotate_external(self, fn, db, gene_group_db=False):
    """
    Annotate gene counts that are read from an external gene count table.

    Each data row of the tab-separated table `fn` is added to this annotator's
    per-gene counts and running totals. The row's reference id is then looked
    up in `db`: if an annotation is found, the counts are passed on to
    `distribute_feature_counts`, otherwise they are added to the unannotated
    counts. Scaling factors are recalculated once after all rows are read.

    Parameters:
        fn: path to a tab-separated file with a header line. The columns
            `gene`, `uniq_raw`, `uniq_lnorm`, `combined_raw` and
            `combined_lnorm` are read; any other column is ignored.
        db: object providing `query_sequence(ref_id)`. It must return `None`
            for ids without annotation, or a 3-item sequence whose last item
            is the annotation handed to `distribute_feature_counts`.
        gene_group_db: if true, the text after the last "." of the `gene`
            column is used for the database lookup and the text before it as
            the gene id. Ids without a "." are used unchanged for both.

    Raises:
        KeyError: if one of the required columns is missing from the header.
        ValueError, TypeError: if a count value cannot be converted to float.
    """
    count_columns = ("uniq_raw", "uniq_lnorm", "combined_raw", "combined_lnorm")

    with open(fn, "rt", encoding="UTF-8") as _in:
        for row in csv.DictReader(_in, delimiter="\t"):
            counts = tuple(float(row[column]) for column in count_columns)
            ref = row["gene"]

            gene_id, ggroup_id = ref, ref
            if gene_group_db:
                prefix, separator, suffix = ref.rpartition(".")
                # Without a separator, slicing at rfind() == -1 would silently
                # drop the last character of the gene id; keep the id intact.
                if separator:
                    gene_id, ggroup_id = prefix, suffix

            # NOTE(review): assumes self.bins is compatible with the four
            # count columns read above -- confirm against the class setup.
            gcounts = self.gene_counts.setdefault(gene_id, np.zeros(self.bins))
            gcounts += counts
            self.total_gene_counts += counts[:4]

            region_annotation = db.query_sequence(ggroup_id)
            if region_annotation is not None:
                # Only the last item of the database record is needed here.
                _, _, region_annotation = region_annotation
                self.distribute_feature_counts(counts, region_annotation)
            else:
                self.unannotated_counts += counts[:4]

    self.calculate_scaling_factors()
args.sam))).count(True) != 1: - raise ValueError(f"Need exactly one type of input: bam={bool(args.bam)} sam={bool(args.sam)} fastq={bool(has_fastq)}.") + if tuple(map(bool, (has_fastq, args.bam, args.sam, args.gene_counts))).count(True) != 1: + raise ValueError( + "Need exactly one type of input: " + f"bam={bool(args.bam)} sam={bool(args.sam)} fastq={bool(has_fastq)} " + f"gene_counts={bool(args.gene_counts)}." + ) + + args.input_type = "fastq" if has_fastq else ("bam" if args.bam else ("sam" if args.sam else "gene_counts")) - args.input_type = "fastq" if has_fastq else ("bam" if args.bam else "sam") + if has_fastq: + if not bool(args.reference and args.aligner): + raise ValueError("--fastq- input requires --reference and --aligner to be set.") - if (args.reference or args.aligner) and not has_fastq: - raise ValueError("--reference/--aligner are not needed with alignment input (bam, sam).") - if bool(args.reference and args.aligner) != has_fastq: - raise ValueError("--fastq requires --reference and --aligner to be set.") + if (args.aligner == "bwa" and not check_bwa_index(args.reference)) or (args.aligner == "minimap" and not check_minimap2_index(args.reference)): + raise ValueError(f"Cannot find `${args.aligner}` reference index at `{args.reference}`.") - if args.input_type == "fastq": args.input_data = check_input_reads( fwd_reads=args.reads1, rev_reads=args.reads2, single_reads=args.singles, orphan_reads=args.orphans, ) - + else: + if bool(args.reference or args.aligner): + raise ValueError("--reference/--aligner parameters are only required for --fastq- input.") + + if (args.strand_specific and args.gene_counts): + raise NotImplementedError("External gene count input is not implemented for strand-specific counts.") + + # if (args.reference or args.aligner) and not has_fastq: + # raise ValueError("--reference/--aligner parameters are only required for --fastq- input.") + + # if bool(args.reference and args.aligner) != has_fastq: + # raise ValueError("--fastq- 
input requires --reference and --aligner to be set.") + if args.restrict_metrics: restrict_metrics = set(args.restrict_metrics.split(",")) invalid = restrict_metrics.difference(('raw', 'lnorm', 'scaled', 'rpkm')) @@ -189,6 +204,12 @@ def handle_args(args): # Input from STDOUT can be used with '-'.""" # ), # ) + ap.add_argument( + "--gene_counts", + type=str, + help="Path to a file containing a gene_counts matrix from a previous gffquant run." + ) + ap.add_argument( "--bam", type=str, diff --git a/gffquant/profilers/feature_quantifier.py b/gffquant/profilers/feature_quantifier.py index c66fb190..b83c5ba5 100644 --- a/gffquant/profilers/feature_quantifier.py +++ b/gffquant/profilers/feature_quantifier.py @@ -13,6 +13,7 @@ from dataclasses import dataclass, asdict from .panda_coverage_profiler import PandaCoverageProfiler +from .reference_hit import ReferenceHit from ..alignment import AlignmentGroup, AlignmentProcessor, SamFlags from ..annotation import GeneCountAnnotator, RegionCountAnnotator, CountWriter from ..counters import CountManager @@ -24,39 +25,6 @@ logger = logging.getLogger(__name__) -@dataclass(slots=True) -class ReferenceHit: - rid: int = None - start: int = None - end: int = None - rev_strand: bool = None - cov_start: int = None - cov_end: int = None - has_annotation: bool = None - n_aln: int = None - is_ambiguous: bool = None - library_mod: int = None - mate_id: int = None - - def __hash__(self): - return hash(tuple(asdict(self).values())) - - def __eq__(self, other): - return all( - item[0][1] == item[1][1] - for item in zip( - sorted(asdict(self).items()), - sorted(asdict(other).items()) - ) - ) - - def __str__(self): - return "\t".join(map(str, asdict(self).values())) - - def __repr__(self): - return str(self) - - class FeatureQuantifier(ABC): """ Three groups of alignments: @@ -153,31 +121,44 @@ def process_counters( dump_counters=True, in_memory=True, gene_group_db=False, + external_gene_counts=None, ): if self.adm is None: self.adm = 
AnnotationDatabaseManager.from_db(self.db, in_memory=in_memory) - if dump_counters: + if dump_counters and not external_gene_counts: self.count_manager.dump_raw_counters(self.out_prefix, self.reference_manager) report_scaling_factors = restrict_reports is None or "scaled" in restrict_reports - Annotator = (GeneCountAnnotator, RegionCountAnnotator)[self.run_mode.overlap_required] + Annotator = (GeneCountAnnotator, RegionCountAnnotator)[self.run_mode.overlap_required and not external_gene_counts] count_annotator = Annotator(self.strand_specific, report_scaling_factors=report_scaling_factors) - count_annotator.annotate(self.reference_manager, self.adm, self.count_manager, gene_group_db=gene_group_db,) + + if external_gene_counts: + count_annotator.annotate_external(external_gene_counts, self.adm, gene_group_db=gene_group_db,) + total_readcount = 1 + filtered_readcount = 1 + has_ambig_counts = True + else: + count_annotator.annotate(self.reference_manager, self.adm, self.count_manager, gene_group_db=gene_group_db,) + total_readcount = self.aln_counter["read_count"] + filtered_readcount = self.aln_counter["filtered_read_count"] + has_ambig_counts = self.count_manager.has_ambig_counts() count_writer = CountWriter( self.out_prefix, - has_ambig_counts=self.count_manager.has_ambig_counts(), + has_ambig_counts=has_ambig_counts, strand_specific=self.strand_specific, restrict_reports=restrict_reports, report_category=report_category, - total_readcount=self.aln_counter["read_count"], - filtered_readcount=self.aln_counter["filtered_read_count"], + total_readcount=total_readcount, + filtered_readcount=filtered_readcount, ) - unannotated_reads = self.count_manager.get_unannotated_reads() - unannotated_reads += self.aln_counter["unannotated_ambig"] + unannotated_reads = 0 + if not external_gene_counts: + unannotated_reads += self.count_manager.get_unannotated_reads() + unannotated_reads += self.aln_counter["unannotated_ambig"] count_writer.write_feature_counts( self.adm, @@ -185,11 
def report_alignments(self):
    """
    Write the alignment statistics to disk and log a summary.

    The complete `self.aln_counter` mapping is dumped as JSON to
    `<out_prefix>.aln_stats.json`. Selected counters are then logged at INFO
    level, followed by the alignment rate and the filter pass rate. Both
    rates are percentages of `full_read_count`, rounded to one decimal, and
    are reported as `None` when `full_read_count` is missing or zero.
    """
    with open(f"{self.out_prefix}.aln_stats.json", "wt", encoding="UTF-8") as aln_stats_out:
        json.dump(self.aln_counter, aln_stats_out)

    for metric, key in (
        ("Input reads", "full_read_count"),
        ("Aligned reads", "read_count"),
        ("Alignments", "pysam_total"),
        ("Reads passing filters", "filtered_read_count"),
        ("Alignments passing filters", "pysam_passed"),
        (" - Discarded due to seqid", "pysam_seqid_filt"),
        (" - Discarded due to length", "pysam_len_filt"),
    ):
        logger.info("%s: %s", metric, self.aln_counter.get(key))

    full_read_count = self.aln_counter.get("full_read_count")
    if full_read_count:
        # Scale before rounding so the result carries no float noise such as
        # 33.300000000000004; no trailing commas, the rates are plain numbers.
        alignment_rate = round(
            100 * self.aln_counter.get("read_count", 0) / full_read_count, 1
        )
        filter_pass_rate = round(
            100 * self.aln_counter.get("filtered_read_count", 0) / full_read_count, 1
        )
    else:
        # Nothing to divide by: report the rates as unavailable.
        alignment_rate, filter_pass_rate = None, None

    logger.info(
        "Alignment rate: %s%%, Filter pass rate: %s%%",
        alignment_rate,
        filter_pass_rate,
    )
self.count_manager.load_data(gene_count_matrix) + def finalise( self, restrict_reports=None, @@ -338,33 +364,13 @@ def finalise( dump_counters=False, in_memory=True, gene_group_db=False, + external_gene_counts=None, ): - with gzip.open(f"{self.out_prefix}.aln_stats.txt.gz", "wt") as aln_stats_out: - print( - AlignmentProcessor.get_alignment_stats_str( - [ - v - for k, v in self.aln_counter.items() - if k.startswith("pysam_") and not k.endswith("total") - ], - table=True, - ), - file=aln_stats_out - ) - - if self.aln_counter.get("aln_count"): + if self.aln_counter.get("aln_count") or external_gene_counts: if self.adm is None: self.adm = AnnotationDatabaseManager.from_db(self.db, in_memory=in_memory) - report_args = { - "restrict_reports": restrict_reports, - "report_category": report_category, - "report_unannotated": report_unannotated, - } - - # self.write_coverage() - self.process_counters( restrict_reports=restrict_reports, report_category=report_category, @@ -372,25 +378,8 @@ def finalise( dump_counters=dump_counters, in_memory=in_memory, gene_group_db=gene_group_db, - ) - - for metric, value in ( - ("Input reads", "full_read_count"), - ("Aligned reads", "read_count"), - ("Alignments", "pysam_total"), - ("Reads passing filters", "filtered_read_count"), - ("Alignments passing filters", "pysam_passed"), - (" - Discarded due to seqid", "pysam_seqid_filt"), - (" - Discarded due to length", "pysam_len_filt"), - # ("Unannotated multimappers", "unannotated_ambig"), - ): - logger.info("%s: %s", metric, self.aln_counter.get(value)) - - logger.info( - "Alignment rate: %s%%, Filter pass rate: %s%%", - round(self.aln_counter["read_count"] / self.aln_counter["full_read_count"], 3) * 100, - round(self.aln_counter["filtered_read_count"] / self.aln_counter["full_read_count"], 3) * 100, - ) + external_gene_counts=external_gene_counts, + ) self.adm.clear_caches() diff --git a/gffquant/profilers/reference_hit.py b/gffquant/profilers/reference_hit.py new file mode 100644 index 
00000000..4cb30f45 --- /dev/null +++ b/gffquant/profilers/reference_hit.py @@ -0,0 +1,32 @@ +""" module docstring """ + +from dataclasses import asdict, dataclass + + +@dataclass(slots=True) +class ReferenceHit: + rid: int = None + start: int = None + end: int = None + rev_strand: bool = None + cov_start: int = None + cov_end: int = None + has_annotation: bool = None + n_aln: int = None + is_ambiguous: bool = None + library_mod: int = None + mate_id: int = None + + def __hash__(self): + return hash(tuple(asdict(self).values())) + + def __eq__(self, other): + if not isinstance(other, ReferenceHit): + raise NotImplementedError(f"{other} is not a valid ReferenceHit object.") + return sorted(asdict(self).items()) == sorted(asdict(other).items()) + + def __str__(self): + return "\t".join(map(str, asdict(self).values())) + + def __repr__(self): + return str(self) \ No newline at end of file