Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 123 additions & 2 deletions src/openknotscore/pipeline/scoring.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import re
import pandas
import statistics
from arnie.utils import convert_dotbracket_to_bp_list, post_process_struct
import numpy as np
from arnie.utils import (
get_helices,
convert_dotbracket_to_bp_list,
post_process_struct
)

def calculateEternaClassicScore(structure, data, score_start_idx, score_end_idx, filter_singlets=False):
"""Calculates an Eterna score for a predicted structure and accompanying reactivity dataset
Expand Down Expand Up @@ -301,4 +306,120 @@ def calculateOpenKnotScore(row, prediction_tags):
df["ensemble_structures"] = list(row[top_scoring_index].values)
df["ensemble_structures_ecs"] = list(top_scoring_by_ecs.values)

return df
return df


def remove_bps_in_blanked_regions(bps, num_res, BLANK_OUT5, BLANK_OUT3):
    """Drop base pairs that touch the blanked-out 5' or 3' flanking regions.

    Residue indices are treated as 0-based, so the region with data is
    ``BLANK_OUT5 <= i < num_res - BLANK_OUT3``. This is the same window
    calculateOpenKnotScoreBatched uses when it slices the structure and
    counts crossed residues.

    Args:
        bps: iterable of base pairs, each indexable as ``(i, j)`` residue indices
        num_res: int, total number of residues in the structure
        BLANK_OUT5: int, number of bases blanked out at the 5' end
        BLANK_OUT3: int, number of bases blanked out at the 3' end

    Returns:
        list: the entries of ``bps`` (same objects, original order) that are
        ordered as ``i < j`` and have both residues inside the data region.
    """
    first_with_data = BLANK_OUT5
    end_with_data = num_res - BLANK_OUT3  # exclusive upper bound

    # Only pairs ordered as i < j are kept. For such a pair, both residues lie
    # in the data region exactly when i is not in the 5' flank and j is not in
    # the 3' flank.
    return [
        bp for bp in bps
        if bp[0] < bp[1] and first_with_data <= bp[0] and bp[1] < end_with_data
    ]


def calculateOpenKnotScoreBatched(
    target_sec_struct,
    reactivity_data_arr,
    threshold_SHAPE_fixed=0.5,
    threshold_SHAPE_fixed_pair=0.25,
    min_SHAPE_fixed=0.0,
    BLANK_OUT5=0,
    BLANK_OUT3=0
):
    """
    Calculates the OpenKnot score for a DBN structure and multiple accompanying datasets in a batched manner.

    eterna_score =
        100 * (number of residues in the data region whose data agrees with the
               paired/unpaired call of the structure) /
              (length of region with data)

    crossed_pair_score =
        100 * min(1, (weighted number of residues in crossed pairs with data < threshold_SHAPE_fixed_pair) /
                     [0.7 * max(length of region with data - 20, 20)])

    crossed_pair_quality_score =
        100 * (weighted number of residues in crossed pairs with data < threshold_SHAPE_fixed_pair) /
              (number of residues modeled to be crossed pairs in the data region)

    openknot_score = 0.5 * eterna_score + 0.5 * crossed_pair_quality_score

    A crossed residue counts with weight 1 when it is still reported as crossed
    after base pairs touching the blanked flanking regions are removed, and
    with weight 0.5 otherwise.

    Args:
        target_sec_struct: str, target secondary structure with pseudoknots
        reactivity_data_arr: np.array, 2D array of predicted 2A3 chemical modifications
            with one row per dataset and one column per residue of the data region
            (len(target_sec_struct) - BLANK_OUT5 - BLANK_OUT3 columns)
        threshold_SHAPE_fixed: float, threshold value for SHAPE data
        threshold_SHAPE_fixed_pair: float, threshold value for SHAPE data of residues in crossed pairs
        min_SHAPE_fixed: float, minimum value threshold of paired or unpaired in SHAPE data
        BLANK_OUT5: int, number of bases to blank out at the 5' end
        BLANK_OUT3: int, number of bases to blank out at the 3' end

    Returns:
        tuple of four np.array, each with one entry per dataset (row of reactivity_data_arr):
            eterna_scores: Eterna score
            crossed_pair_score: crossed pair score
            crossed_pair_quality_score: crossed pair quality score
            openknot_score: openknot score
        All four arrays are zero when the structure is a failed prediction
        (a string consisting only of "x" characters, or an empty string).
    """
    num_datasets = reactivity_data_arr.shape[0]
    num_res = len(target_sec_struct)

    # Some algorithms will produce a string of all x's when they fail on a sequence.
    # Bail out before doing any work, and return the same 4-tuple-of-arrays shape
    # as the normal path so callers can unpack the result unconditionally.
    if all(char == "x" for char in target_sec_struct):
        return (
            np.zeros(num_datasets),
            np.zeros(num_datasets),
            np.zeros(num_datasets),
            np.zeros(num_datasets),
        )

    # Convert the input structure to a paired/unpaired mask.
    # We only have trustworthy data for the bases in the middle of the sequence
    # so we skip the blanked regions at the beginning and end
    data_end = num_res - BLANK_OUT3
    is_unpaired = np.array(
        [char == "." for char in target_sec_struct[BLANK_OUT5:data_end]],
        dtype=bool,
    )

    # Calculate correct hits for unpaired bases
    unpaired_cutoff = 0.25 * threshold_SHAPE_fixed + 0.75 * min_SHAPE_fixed
    unpaired_hits = is_unpaired & (reactivity_data_arr > unpaired_cutoff)

    # Calculate correct hits for paired bases
    paired_hits = ~is_unpaired & (reactivity_data_arr < threshold_SHAPE_fixed)

    # A residue is either paired or unpaired, so the union is the set of correct hits
    correct_hits = unpaired_hits | paired_hits

    # Calculate Eterna scores (percentage of correct residues per dataset)
    data_region_length = reactivity_data_arr.shape[1]
    eterna_scores = (np.sum(correct_hits, axis=1) / data_region_length) * 100

    # Pad the data back out to full structure length so that it can be indexed
    # with residue indices of the structure. The flanks are NaN, which never
    # compares below the threshold and therefore never counts as a hit.
    padded_data_arr = np.pad(
        reactivity_data_arr,
        ((0, 0), (BLANK_OUT5, BLANK_OUT3)),
        constant_values=np.nan,
    )

    # Remove singlet base pairs
    stems = [stem for stem in get_helices(target_sec_struct) if len(stem) > 1]
    # Convert the list of helices into a list of base pairs (flatten 3D list to 2D list)
    bp_list = [bp for stem in stems for bp in stem]

    # Get indexes for bases in crossed pairs
    crossed_res = np.array(identify_crossing_bps(bp_list))

    # Filter out base pairs that involve residues in the blanked out flanking regions
    bps_filtered = remove_bps_in_blanked_regions(bp_list, num_res, BLANK_OUT5, BLANK_OUT3)
    crossed_res_filtered = np.array(identify_crossing_bps(bps_filtered))

    # Number of crossed residues that fall inside the region with data
    max_count = np.sum((crossed_res >= BLANK_OUT5) & (crossed_res < data_end))

    if len(crossed_res) == 0:
        # An empty index array cannot be used for fancy indexing below
        num_crossed_pairs = np.zeros(num_datasets)
    else:
        protected = padded_data_arr[:, crossed_res] < threshold_SHAPE_fixed_pair
        still_crossed = np.isin(crossed_res, crossed_res_filtered)
        # Full credit for residues still crossed after filtering, half credit otherwise
        num_crossed_pairs = (
            np.sum(protected & still_crossed, axis=1)
            + 0.5 * np.sum(protected & ~still_crossed, axis=1)
        )

    # The denominator is floored so that short data regions do not divide by
    # zero or a negative number; the score is capped at 100.
    max_crossed_pairs = 0.7 * np.maximum(data_region_length - 20, 20)
    crossed_pair_score = 100 * np.minimum(num_crossed_pairs / max_crossed_pairs, 1.0)

    if max_count > 0:
        crossed_pair_quality_score = 100 * (num_crossed_pairs / max_count)
    else:
        # Keep the per-dataset array shape even when nothing is crossed
        crossed_pair_quality_score = np.zeros(num_datasets)

    openknot_score = 0.5 * eterna_scores + 0.5 * crossed_pair_quality_score

    return eterna_scores, crossed_pair_score, crossed_pair_quality_score, openknot_score