def remove_bps_in_blanked_regions(bps, num_res, BLANK_OUT5, BLANK_OUT3):
    """Drop base pairs that touch the blanked-out 5' or 3' flanking regions.

    A pair ``bp`` is kept only when ``bp[0] < bp[1]`` and both of its indices
    are strictly greater than ``BLANK_OUT5`` and less than or equal to
    ``num_res - BLANK_OUT3``.

    Args:
        bps: iterable of two-element base pairs (index, partner index)
        num_res: int, total number of residues in the structure
        BLANK_OUT5: int, number of bases blanked out at the 5' end
        BLANK_OUT3: int, number of bases blanked out at the 3' end

    Returns:
        list: the pairs of ``bps`` that survive the filter, in input order
    """
    # NOTE(review): these bounds (exclusive low, inclusive high) look like
    # 1-based indexing, whereas calculateOpenKnotScoreBatched indexes data
    # columns 0-based -- confirm which convention the base-pair lists use.
    last_kept = num_res - BLANK_OUT3
    return [
        bp
        for bp in bps
        if bp[0] < bp[1]
        and bp[0] > BLANK_OUT5
        and bp[1] > BLANK_OUT5
        and bp[0] <= last_kept
        and bp[1] <= last_kept
    ]


def calculateOpenKnotScoreBatched(
        target_sec_struct,
        reactivity_data_arr,
        threshold_SHAPE_fixed=0.5,
        threshold_SHAPE_fixed_pair=0.25,
        min_SHAPE_fixed=0.0,
        BLANK_OUT5=0,
        BLANK_OUT3=0
    ):
    """
    Calculates the OpenKnot score for a DBN structure and multiple accompanying datasets in a batched manner.

    eterna_score =
        100 * (number of scored positions whose data agrees with the structure) /
              (number of scored positions)
        An unpaired position agrees when its data is above
        0.25 * threshold_SHAPE_fixed + 0.75 * min_SHAPE_fixed; any other
        position agrees when its data is below threshold_SHAPE_fixed.

    crossed_pair_score =
        100 * min(1, weighted_count / [0.7 * max(length of region with data - 20, 20)])

    crossed_pair_quality_score =
        100 * weighted_count / (number of crossed-pair residues inside the data region),
        or 0 when the structure has no such residues

    openknot_score = 0.5 * eterna_score + 0.5 * crossed_pair_quality_score

    weighted_count sums, over crossed-pair residues with data below
    threshold_SHAPE_fixed_pair, a weight of 1 when the residue is still in a
    crossed pair after pairs touching the blanked flanks are removed, and 0.5
    otherwise.

    Args:
        target_sec_struct: str, target secondary structure with pseudoknots
        reactivity_data_arr: array-like of shape (number of datasets, number of scored positions),
            reactivity data for the region between the blanked-out flanks
        threshold_SHAPE_fixed: float, threshold value for SHAPE data
        threshold_SHAPE_fixed_pair: float, threshold value for SHAPE data at crossed-pair residues
        min_SHAPE_fixed: float, minimum value threshold of paired or unpaired in SHAPE data
        BLANK_OUT5: int, number of bases to blank out at the 5' end
        BLANK_OUT3: int, number of bases to blank out at the 3' end

    Returns:
        tuple of four np.ndarray, each with one entry per dataset:
            eterna_scores, crossed_pair_score, crossed_pair_quality_score, openknot_score
        For a failed structure (every character is "x") both crossed-pair
        scores are zero.
    """
    # Float dtype is required so the flanks can be padded with NaN below;
    # this also lets callers pass nested lists.
    reactivity_data_arr = np.asarray(reactivity_data_arr, dtype=float)
    num_datasets = reactivity_data_arr.shape[0]
    num_res = len(target_sec_struct)

    # We only have trustworthy data for the bases in the middle of the sequence,
    # so the paired/unpaired mask skips the blanked regions at both ends.
    scored_region = target_sec_struct[BLANK_OUT5:num_res - BLANK_OUT3]
    is_unpaired = np.array([char == "." for char in scored_region], dtype=bool)

    unpaired_cutoff = 0.25 * threshold_SHAPE_fixed + 0.75 * min_SHAPE_fixed
    unpaired_hits = is_unpaired & (reactivity_data_arr > unpaired_cutoff)
    paired_hits = ~is_unpaired & (reactivity_data_arr < threshold_SHAPE_fixed)
    correct_hits = unpaired_hits | paired_hits

    eterna_scores = (np.sum(correct_hits, axis=1) / reactivity_data_arr.shape[1]) * 100

    # Some algorithms will produce a string of all x's when they fail on a sequence.
    # Return the same 4-tuple shape as the normal path, with no crossed-pair credit.
    if all(char == "x" for char in target_sec_struct):
        return (
            eterna_scores,
            np.zeros(num_datasets),
            np.zeros(num_datasets),
            0.5 * eterna_scores,
        )

    # Re-align the data with full-structure indices; flank columns are NaN so
    # they never pass the "< threshold" comparison.
    padded_data_arr = np.pad(
        reactivity_data_arr,
        ((0, 0), (BLANK_OUT5, BLANK_OUT3)),
        constant_values=np.nan,
    )

    # Remove singlet base pairs
    stems = [stem for stem in get_helices(target_sec_struct) if len(stem) > 1]
    # Convert the list of helices into a list of base pairs (flatten 3D list to 2D list)
    bp_list = [bp for stem in stems for bp in stem]

    # Residues in crossed pairs, before and after discarding pairs that
    # involve the blanked out flanking regions.
    # NOTE(review): assumes identify_crossing_bps returns a flat sequence of
    # integer residue indices -- confirm against its definition.
    crossed_res = np.array(identify_crossing_bps(bp_list), dtype=int)
    bps_filtered = remove_bps_in_blanked_regions(bp_list, num_res, BLANK_OUT5, BLANK_OUT3)
    crossed_res_filtered = np.array(identify_crossing_bps(bps_filtered), dtype=int)

    # Number of crossed-pair residues that fall inside the data region
    max_count = np.sum((crossed_res >= BLANK_OUT5) & (crossed_res < num_res - BLANK_OUT3))

    if crossed_res.size == 0:
        num_crossed_pairs = np.zeros(num_datasets)
    else:
        protected = padded_data_arr[:, crossed_res] < threshold_SHAPE_fixed_pair
        # Full credit when the residue is still crossed after flank filtering,
        # half credit when its crossing depended on a pair touching the flanks.
        weights = np.where(np.isin(crossed_res, crossed_res_filtered), 1.0, 0.5)
        num_crossed_pairs = np.sum(protected * weights, axis=1)

    data_region_length = padded_data_arr.shape[1] - BLANK_OUT5 - BLANK_OUT3
    max_crossed_pairs = 0.7 * np.maximum(data_region_length - 20, 20)
    crossed_pair_score = 100 * np.minimum(num_crossed_pairs / max_crossed_pairs, 1.0)

    if max_count > 0:
        crossed_pair_quality_score = 100 * (num_crossed_pairs / max_count)
    else:
        # Keep one entry per dataset instead of collapsing to a scalar 0
        crossed_pair_quality_score = np.zeros(num_datasets)

    openknot_score = 0.5 * eterna_scores + 0.5 * crossed_pair_quality_score

    return eterna_scores, crossed_pair_score, crossed_pair_quality_score, openknot_score