Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 53 additions & 9 deletions smaca/sma.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,15 +120,59 @@ def __init__(self, bam_list, ref, n_jobs=1):
self.dup_id[:] = dup_id_memmap[:]

self.r_ij = self.D1_ij + self.D2_ij
#TODO: consider using only "the bests" HK
self.z_ik = self.c_ix.sum(axis=1).reshape((self.n_bam, 1)) / self.H_ik
self.std_k = np.std(self.H_ik, axis=0)
self.std_i = np.std(self.H_ik, axis=1)
self.zmean_k = self.z_ik.sum(axis=0) / self.n_bam
self.theta_i = (self.z_ik / self.zmean_k).sum(axis=1) / len(
C.POSITIONS[ref]["GENES"])
self.pi_ij = self.theta_i.reshape(
(self.n_bam, 1)) * (self.D1_ij / self.r_ij)

# Detect samples with zero-coverage control genes and warn
gene_names = list(C.POSITIONS[ref]["GENES"].keys())
zero_cov_samples = np.any(self.H_ik == 0, axis=1)

for i in range(self.n_bam):
if zero_cov_samples[i]:
zero_genes = [gene_names[k] for k in range(len(gene_names))
if self.H_ik[i, k] == 0]
click.echo(
f"Warning: Sample '{self.bam_list[i]}' has zero coverage"
f" in control gene(s): {', '.join(zero_genes)}. "
f"Results for this sample will be set to NaN.",
err=True
)

# Compute statistics with safe division
with np.errstate(divide='ignore', invalid='ignore'):
self.z_ik = np.where(
self.H_ik > 0,
self.c_ix.sum(axis=1).reshape((self.n_bam, 1)) / self.H_ik,
0.0
)
self.std_k = np.std(self.H_ik, axis=0)
self.std_i = np.std(self.H_ik, axis=1)
# Exclude zero-coverage samples from zmean_k so they don't
# corrupt statistics for healthy samples
valid_mask = ~zero_cov_samples
n_valid = valid_mask.sum()
if n_valid > 0:
self.zmean_k = (
self.z_ik[valid_mask].sum(axis=0) / n_valid
)
else:
self.zmean_k = np.zeros(self.z_ik.shape[1])
self.theta_i = np.where(
self.zmean_k > 0,
self.z_ik / self.zmean_k,
0.0
).sum(axis=1) / len(C.POSITIONS[ref]["GENES"])
self.pi_ij = np.where(
self.r_ij > 0,
self.theta_i.reshape((self.n_bam, 1))
* (self.D1_ij / np.where(self.r_ij > 0, self.r_ij, 1.0)),
np.nan
)

# Overwrite all computed stats to NaN for zero-coverage samples
for i in range(self.n_bam):
if zero_cov_samples[i]:
self.pi_ij[i, :] = np.nan
self.theta_i[i] = np.nan
self.z_ik[i, :] = np.nan

def write_stats(self, output_file):
"""
Expand Down
22 changes: 20 additions & 2 deletions smaca/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,17 @@ def test_bamclass(self):
def test_sma_stats(self):
s = SmaCalculator(BAM_LIST, ref=C.REF_HG19)

# Sample 0 (HG002) has full coverage; sample 1 (HG007) has zero
# coverage in 6 control genes and is excluded from zmean_k
np.testing.assert_array_almost_equal(
s.pi_ij[0], [0.8027337510925755, 0.7296692915157783, 0.6580751624125392])
self.assertAlmostEqual(s.zmean_k[0], 0.6135924580475719)
s.pi_ij[0], [0.65682657, 0.59704251, 0.53846154])
self.assertAlmostEqual(s.zmean_k[0], 0.9887272198987905)
np.testing.assert_array_almost_equal(
s.std_i, [21.18266345, 5.07109509])
self.assertAlmostEqual(s.std_k[0], 15.516854393351194)
self.assertEqual(s.dup_id[0][0], b'T [[0], [0], [0], [106]]')
# Sample 1 should have NaN due to zero-coverage genes
self.assertTrue(np.all(np.isnan(s.pi_ij[1])))

def test_get_chr_prefix(self):
sam_file = pysam.AlignmentFile(BAM_hg19, "rb")
Expand Down Expand Up @@ -110,5 +114,19 @@ def test_hg19_hg38_coverages(self):
err_msg=";".join(C.POSITIONS[C.REF_HG19][ranges]))


def test_zero_coverage_gene_warning(self):
    """Check NaN handling for a sample with zero-coverage control genes.

    Builds a ``SmaCalculator`` from ``BAM_LIST`` with the hg19 reference and
    verifies that the first sample's statistics are finite while the second
    sample's statistics are entirely NaN. Only the computed values are
    asserted here; any warning text emitted during construction is not
    inspected.
    """
    calc = SmaCalculator(BAM_LIST, ref=C.REF_HG19)

    # Sample 0 (HG002) has full coverage, so its results must be finite.
    covered_pi = calc.pi_ij[0]
    self.assertTrue(np.isfinite(covered_pi).all())
    covered_theta = calc.theta_i[0]
    self.assertTrue(np.isfinite(covered_theta))

    # Sample 1 (HG007) has zero coverage in some control genes, so every
    # statistic reported for it must be NaN.
    uncovered_pi = calc.pi_ij[1]
    self.assertTrue(np.isnan(uncovered_pi).all())
    uncovered_theta = calc.theta_i[1]
    self.assertTrue(np.isnan(uncovered_theta))
    uncovered_z = calc.z_ik[1]
    self.assertTrue(np.isnan(uncovered_z).all())


# Run the unittest command-line runner only when this module is executed
# directly as a script, not when it is imported.
if __name__ == '__main__':
    unittest.main()