diff --git a/frustratometer/__init__.py b/frustratometer/__init__.py index c381858d..cb1d16d2 100644 --- a/frustratometer/__init__.py +++ b/frustratometer/__init__.py @@ -13,6 +13,7 @@ from . import align from . import frustration from . import optimization +from . import numba_util # Handle versioneer from ._version import get_versions diff --git a/frustratometer/classes/AWSEM.py b/frustratometer/classes/AWSEM.py index 770f4ad0..ea0f86ca 100644 --- a/frustratometer/classes/AWSEM.py +++ b/frustratometer/classes/AWSEM.py @@ -1,18 +1,32 @@ +import warnings import numpy as np from ..utils import _path from .. import frustration +from .. import numba_util +from ..numba_util import ham, algos from .Frustratometer import Frustratometer from .Gamma import Gamma from pydantic import BaseModel, Field, ConfigDict from pydantic.types import Path -from typing import List,Optional,Union +from typing import List,Optional,Union,Generator +import os -__all__ = ['AWSEM'] +__all__ = ['AWSEM','AWSEMIndicators','DecoyEnsemble', 'AWSEMVariancePotts'] + +class ParametersAWSEM(BaseModel): + # due to the presence of the pydantic BaseModel, + # these variables that look like class attributes + # are actually instance attributes -class AWSEMParameters(BaseModel): model_config = ConfigDict(extra='ignore', arbitrary_types_allowed=True) """Default parameters for AWSEM energy calculations.""" - k_contact: float = Field(4.184, description="Coefficient for contact potential. (kJ/mol)") + k_contact: float = Field(4.184, description=""" + Scale factor for contact potential. + Many parameters used to be given in kcal/mol, + but we want our results in kJ/mol, so this is + set to the appropriate conversion factor by default. 
+ Note that the electrostatic parameter is not scaled + by k_contact.""") #Density eta: float = Field(5.0, description="Sharpness of the distance-based switching function (Angstrom^-1).") @@ -37,7 +51,6 @@ class AWSEMParameters(BaseModel): r_maxII: float = Field(9.5, description="Maximum distance for mediated contact potential. (Angstrom)") eta_sigma: float = Field(7.0, description="Sharpness of the density-based switching function between protein-mediated and water-mediated contacts.") - #Membrane membrane_gamma: Union[Path,Gamma] = Field(_path/'data'/'AWSEM_membrane_2015.json', description="File or Gamma object containing the membrane Gamma values (for membrane proteins)") eta_switching: int = Field(10, description="Switching distance for the membrane switching function") @@ -45,215 +58,756 @@ class AWSEMParameters(BaseModel): #Electrostatics min_sequence_separation_electrostatics: Optional[int] = Field(1, description="Minimum sequence separation for electrostatics calculation.") k_electrostatics: float = Field(17.3636, description="Coefficient for electrostatic interactions. (kJ/mol)") - electrostatics_screening_length: float = Field(10, description="Screening length for electrostatic interactions. (Angstrom)") + electrostatics_screening_length: float = Field(10.0, description="Screening length for electrostatic interactions. (Angstrom)") + + # We might not know the order of amino acids in our alphabet at the time of initialization + # (this happens the above gammas are Paths), so we'll have to build the electrostatic "gamma" + # later (see self.model_post_init) + charge_dict : dict = Field({'A':0.0,'C':0.0,'D':-1.0,'E':-1.0, + 'F':0.0,'G':0.0,'H':0.0,'I':0.0, + 'K':1.0,'L':0.0,'M':0.0,'N':0.0,'P':0.0, + 'Q':0.0,'R':1.0,'S':0.0,'T':0.0, + 'V':0.0,'W':0.0,'Y':0.0}, + description='charge of each amino acid type that may be used') + + def model_post_init(self, __context__): + """Pydantic v2 hook called after model initialization. 
+ + The signature must accept a single positional argument named + ``__context__`` (per pydantic v2). If ``gamma`` was provided + as a path, convert it to a ``Gamma`` instance here. + """ + if isinstance(self.gamma, Path): + self.gamma = Gamma(self.gamma) -class AWSEM(Frustratometer): - #Mapping to DCA - q = 20 - aa_map_awsem_list = [0, 0, 4, 3, 6, 13, 7, 8, 9, 11, 10, 12, 2, 14, 5, 1, 15, 16, 19, 17, 18] #A gap has no energy - aa_map_awsem_x, aa_map_awsem_y = np.meshgrid(aa_map_awsem_list, aa_map_awsem_list, indexing='ij') + # make properties that reshape gammas stored inside the Gamma object (self.gamma) + @property + def burial_gamma(self): + # check shape of gamma + gb = self.gamma['Burial'] + if gb.shape == (3,self.gamma.q): + return gb.T + elif gb.shape == (self.gamma.q,3): + return gb + else: + raise ValueError(f""" + Don't know how to parse burial gamma with shape {gb.shape}. + Expected ({self.gamma.q},3) or (3,{self.gamma.q}).""") + @burial_gamma.setter + def burial_gamma(self, _): + raise AttributeError(""" + Modifying burial_gamma directly is not allowed. + Instead, modify the underlying Gamma object, accessible at self.gamma.""") + @property + def direct_gamma(self): + gd = np.squeeze(self.gamma['Direct']) # gammas commonly formatted as (1,q,q) + # check shape of gamma + if gd.shape != (self.gamma.q,self.gamma.q): + raise ValueError(f""" + Don't know how to parse direct gamma with shape {gd.shape}. + Expected ({self.gamma.q}, {self.gamma.q})""") + else: + return gd + @direct_gamma.setter + def direct_gamma(self, _): + raise AttributeError(""" + Modifying direct_gamma directly is not allowed. + Instead, modify the underlying Gamma object, accessible at self.gamma.""") + @property + def protein_gamma(self): + gp = np.squeeze(self.gamma['Protein']) # gammas commonly formatted as (1,q,q) + # check shape of gamma + if gp.shape != (self.gamma.q,self.gamma.q): + raise ValueError(f""" + Don't know how to parse protein gamma with shape {gp.shape}. 
+ Expected ({self.gamma.q}, {self.gamma.q})""") + else: + return gp + @protein_gamma.setter + def protein_gamma(self, _): + raise AttributeError(""" + Modifying protein_gamma directly is not allowed. + Instead, modify the underlying Gamma object, accessible at self.gamma.""") + @property + def water_gamma(self): + gw = np.squeeze(self.gamma['Water']) # gammas commonly formatted as (1,q,q) + # check shape of gamma + if gw.shape != (self.gamma.q,self.gamma.q): + raise ValueError(f""" + Don't know how to parse water gamma with shape {gw.shape}. + Expected ({self.gamma.q}, {self.gamma.q})""") + else: + return gw + @water_gamma.setter + def water_gamma(self, _): + raise AttributeError(""" + Modifying water_gamma directly is not allowed. + Instead, modify the underlying Gamma object, accessible at self.gamma.""") + @property + def electrostatic_gamma(self): + charges = [] + for oneletter in self.gamma.alphabet: + if oneletter not in self.charge_dict.keys(): + raise ValueError(f""" + One letter code {oneletter} found in Gamma.alphabet + is not known to the electrostatic potential. + Provide your ParametersAWSEM object with a complete + charge_dict specifying the electric charge, in fundamental + units, of {oneletter} and all other amino acids + in your alphabet so that an electrostatic "gamma" array + of the same shape and amino acid order as your direct, + protein, and water gammas can be created.""") + else: + charges.append(self.charge_dict[oneletter]) + assert len(charges) == self.gamma.q + return np.outer(charges, charges) # our electrostatic "gamma" + @electrostatic_gamma.setter + def electrostatic_gamma(self, _): + raise AttributeError(""" + Electrostatic_gamma is a property computed from the + ParametersAWSEM.charge_dict and ParametersAWSEM.Gamma.alphabet.
+ To set the electrostatic gamma, change the charge_dict + of your parameters object.""") + @property + def electrostatics_gamma(self): + return self.electrostatic_gamma + @electrostatics_gamma.setter + def electrostatics_gamma(self, _): + self.electrostatic_gamma = _ + # if we're going to make gamma arrays from self.gamma available as + # properties of this class, we should also make the alphabet and q from + # self.gamma available as a property of this class + @property + def alphabet(self): + return self.gamma.alphabet + @alphabet.setter + def alphabet(self, new_alphabet): + # we might need new_alphabet to be either a list or str, but not the other + self.gamma.reorder(alphabet=new_alphabet) + #raise AttributeError(""" + #Resetting the alphabet must be done using self.gamma.reorder() + #(self.gamma is an instance of the Gamma class). + #Changes made to the underlying Gamma object will then + #propagate upward.""") + @property + def q(self): + return self.gamma.q + # gamma.q is itself a property that calculates seq len and returns it + +class _AWSEMBase(Frustratometer): + """ + Base class for potts model and frustration calculations + with the AWSEM Hamiltonian. + """ + def __init__(self, - pdb_structure: object, - sequence: str =None, + sequence: str, expose_indicator_functions: bool=False, + potts_option: bool=False, **parameters)->object: """ - Generate AWSEM object + Set attributes that do not depend on the implementations of + the indicator function and potts model setup calculations. Parameters ---------- - pdb_structure : object - Structure object generated by Structure class - sequence : str - The amino acid sequence of the protein. The sequence is assumed to be in one-letter code. + sequence: str + The amino acid sequence expose_indicator_functions: bool If set to True, indicator functions of the contact and burial energy terms can be accessed by user. 
- + potts_option: bool + Whether to set up the potts model (can be RAM-intensive and time-intensive), + which is unnecessary if all you want to get is the indicator functions. + **parameters: + Used to initialize an ParametersAWSEM, which becomes an attribute + of this class and helps us organize the parameters of our AWSEM Hamiltonian + Returns ------- - AWSEM object + _AWSEMBase object """ - - #Set attributes - p = AWSEMParameters(**parameters) - if p.min_sequence_separation_contact is None: - p.min_sequence_separation_contact = 1 - if p.min_sequence_separation_rho is None: - p.min_sequence_separation_rho = 1 - if p.min_sequence_separation_electrostatics is None: - p.min_sequence_separation_electrostatics = 1 - - for field, value in p: - setattr(self, field, value) - - #Gamma parameters - if isinstance(p.gamma, Gamma): - gamma = p.gamma - elif isinstance(p.gamma, Path): - gamma = Gamma(p.gamma) - else: - raise ValueError("Gamma parameter must be a path or a Gamma object.") + + # set sequence based on arguments + self._sequence = sequence + + # set indicator function exposure based on argument + # (not exposing them saves a tiny bit of RAM but it's useful to Ezequiel) + self.expose_indicator_functions = expose_indicator_functions + + # whether to store the potts model as an object attribute, + # which requires a lot of ram + self.potts_option = potts_option + + # check consistency of potts_option and expose_indicator_functions arguments + if self.potts_option and not self.expose_indicator_functions: + warnings.warn(f""" + You requested storing the potts model as an object attribute by using potts_option=True + but requested NOT storing the indicator functions as object attributes by using + expose_indicator_functions=False. Since the potts model requires far more RAM than + the indicator functions, we will override your indicator function request + and store them anyway. This will have no effect on the accuracy of any calculations. 
- self.gamma=gamma - self.burial_gamma = gamma['Burial'].T - self.direct_gamma = gamma['Direct'][0] - self.protein_gamma = gamma['Protein'][0] - self.water_gamma = gamma['Water'][0] - self.burial_in_context=p.burial_in_context - - #Structure details - self.full_to_aligned_index_dict=pdb_structure.full_to_aligned_index_dict - if sequence is None: - self.sequence=pdb_structure.sequence + Setting {self.__class__}.expose_indicator_functions = True""") + self.expose_indicator_functions = True + + # parse other arguments + self.p = ParametersAWSEM(**parameters) + + # i don't know why these aren't the defaults in ParametersAWSEM + # if we're going to override them anyway + if self.p.min_sequence_separation_contact is None: + self.p.min_sequence_separation_contact = 1 + if self.p.min_sequence_separation_rho is None: + self.p.min_sequence_separation_rho = 1 + if self.p.min_sequence_separation_electrostatics is None: + self.p.min_sequence_separation_electrostatics = 1 + + # set other attributes + self.burial_in_context = self.p.burial_in_context #i'd prefer to move this out of ParametersAWSEM completely + self._decoy_fluctuation = {} # used for non-configurational frustration calculations + self._minimally_frustrated_threshold=.78 # this should be a class variable or an argument to __init__ + self._native_energy = None + self._potts_model = None + + # although the alphabet is really an attribute of the AWSEM + # Hamiltonian, and therefore belongs in the ParametersAWSEM instance, + # at least one method in the Frustratometer class requires + # it to be accessible in this namespace, so we make it a property + @property + def alphabet(self): + return self.p.alphabet + @alphabet.setter + def alphabet(self, new_alphabet): + # we might need new_alphabet to be either a list or str, but not the other + self.p.alphabet = new_alphabet + + # we make these attributes into properties for protection; + # may write setters at some point to update everything appropriately + # and allow 
modification of initialized objects + @property + def sequence(self): + return self._sequence + @sequence.setter + def sequence(self, _): + raise NotImplementedError("Modifying the sequence is not permitted. May add support at some point.") + @property + def minimally_frustrated_threshold(self): + return self._minimally_frustrated_threshold + @minimally_frustrated_threshold.setter + def minimally_frustrated_threshold(self, value): + raise NotImplementedError(f"Cannot directly set {self.__class__}.minimally_frustrated_threshold. May add support in the future") + + # these attributes are computed from other attributes, + # so we make them into properties + @property + def sequence_cutoff(self): + if self.p.k_electrostatics == 0: + return self.p.min_sequence_separation_contact + else: + return min(self.p.min_sequence_separation_contact, + self.p.min_sequence_separation_electrostatics) + @property + def distance_cutoff(self): + # the distance cutoff might not exist in all subclasses, + # but they just won't use it. there's no harm in defining it + if self.p.k_electrostatics == 0: + return self.p.distance_cutoff_contact + else: + return None + @property + def N(self): + return len(self.sequence) + @property # emphasizes that seq_index is computed from the alphabet + def seq_index(self): + self._seq_index = np.array([self.p.alphabet.index(aa) for aa in self.sequence]) + return self._seq_index + @property + def aa_freq(self): + return frustration.compute_aa_freq(self.sequence, self.p.alphabet) + @property + def contact_freq(self): + return frustration.compute_contact_freq(self.sequence, self.p.alphabet) + @property + def potts_model(self): + if self._potts_model is None: + if self.potts_option: + #raise AssertionError(f""" + #The user requested potts model calculation but apparently + #it wasn't done upon initialization of {self.__class__}. 
+ #This is likely an issue with the _AWSEMBase class.""") + # + # the user may have changed their mind after initializing + # the object (see the else block of this conditional), + # so the above assertion is inappropriate + self.calculate_energy_and_potts() + return self._potts_model + else: + warnings.warn(f""" + Attempting to access (N,N,q,q)-shaped numpy array of potts model, + but {self.__class__}.potts_option evaluated to False. + Will not return potts model. + To get the potts model, set self.potts_option=True and try again.""") + return None else: - self.sequence=sequence - self.structure=pdb_structure.structure - self.chain=pdb_structure.chain - self.pdb_file=pdb_structure.pdb_file - self.init_index_shift=pdb_structure.init_index_shift - self.distance_matrix=pdb_structure.distance_matrix - self.full_pdb_distance_matrix=pdb_structure.full_pdb_distance_matrix - selection_CB = self.structure.select('name CB or (resname GLY IGL and name CA)') - - resid = selection_CB.getResindices() - self.resid=resid - self.N=len(self.resid) - assert self.N == len(self.sequence), "The pdb is incomplete. Try setting 'repair_pdb=True' when constructing the Structure object." 
+ return self._potts_model + @potts_model.setter + def potts_model(self, value): + raise NotImplementedError(f"Cannot directly set {self.__class__}.potts_model") + # this is like derived properties (see above), but certain functions + # require native_energy to be a callable method instead of a property + def native_energy(self, *, chain_starts=None, chain_ends=None): + """Return the native energy: delegate to the parent class when a potts model is stored, + otherwise call ham.compute_potential_total on the distance matrix. + chain_starts / chain_ends are keyword-only and default to a single chain spanning + the whole sequence, the same defaults calculate_energy_and_potts uses.""" + if self._potts_model is not None: + energy = super().native_energy() # method to compute native energy given potts model + else: + # chain_starts / chain_ends were previously undefined in this scope (NameError on this path) + if chain_starts is None: + chain_starts = np.array([0]) + if chain_ends is None: + chain_ends = np.array([self.N - 1]) + l_D = float(self.p.electrostatics_screening_length) + min_seq_sep_rho = self.p.min_sequence_separation_rho + min_seq_sep_contact = self.p.min_sequence_separation_contact + min_seq_sep_electrostatic = self.p.min_sequence_separation_electrostatics + energy = ham.compute_potential_total( + l_D, min_seq_sep_rho, min_seq_sep_contact, min_seq_sep_electrostatic, + chain_starts, chain_ends, + dist_mat=self.distance_matrix, + lambda_direct=self.p.k_contact, direct_gamma=self.p.direct_gamma, + lambda_protein=self.p.k_contact, protein_gamma=self.p.protein_gamma, + lambda_water=self.p.k_contact, water_gamma=self.p.water_gamma, + lambda_burial=self.p.k_contact, burial_gamma=self.p.burial_gamma, + lambda_electrostatic=self.p.k_electrostatics, electrostatic_gamma=self.p.electrostatic_gamma, + seq_index=self.seq_index, parallel=True) # can set to False if having numba issues + #self._native_energy = energy # maybe _native_energy is needed for compatibility with certain things?
+ return energy + + # this format is a little bit unusual but is useful for the optimization code + @property + def coefficient_lambda_gamma_array(self): + _coefficient_lambda_gamma_array = [] + _coefficient_lambda_gamma_array.append(-0.5 * self.p.k_contact * self.p.burial_gamma[:,0]) + _coefficient_lambda_gamma_array.append(-0.5 * self.p.k_contact * self.p.burial_gamma[:,1]) + _coefficient_lambda_gamma_array.append(-0.5 * self.p.k_contact * self.p.burial_gamma[:,2]) + _coefficient_lambda_gamma_array.append(-0.5 * self.p.k_contact * self.p.direct_gamma) + _coefficient_lambda_gamma_array.append(-0.5 * self.p.k_contact * self.p.protein_gamma) + _coefficient_lambda_gamma_array.append(-0.5 * self.p.k_contact * self.p.water_gamma) + _coefficient_lambda_gamma_array.append(0.5 * self.p.k_electrostatics * self.p.electrostatic_gamma) + # not a typo, supposed to be positive ^^^ + # charges2 is our electrostatic "gamma" + return _coefficient_lambda_gamma_array + @coefficient_lambda_gamma_array.setter # clarifies that this is derived from more fundamental quantities + def coefficient_lambda_gamma_array(self, _): + raise AttributeError(f"""Setting {self.__class__}.coefficient_lambda_gamma_array + directly is not allowed. 
Initialize a new instance with a different + {self.__class__}.p.k_contact, {self.__class__}.burial_gamma, {self.__class__}.direct_gamma, + {self.__class__}.p.gamma.protein_gamma, or {self.__class__}.water_gamma instead.""") + + + + + + + + + @property + def sequence_mask_rho(self): + return self._sequence_mask_rho + @sequence_mask_rho.setter + def sequence_mask_rho(self, value): + raise NotImplementedError(f"Cannot directly set {self.__class__}.sequence_mask_rho") + + @property + def sequence_mask_contact(self): + return self._sequence_mask_contact + @sequence_mask_contact.setter + def sequence_mask_contact(self, value): + raise NotImplementedError(f"Cannot directly set {self.__class__}.sequence_mask_contact") + + @property + def electrostatics_mask(self): + return self._electrostatics_mask + @electrostatics_mask.setter + def electrostatics_mask(self, value): + raise NotImplementedError(f"Cannot directly set {self.__class__}.electrostatics_mask") + + @property + def mask(self): + return self._mask + @mask.setter + def mask(self, value): + raise NotImplementedError(f"Cannot directly set {self.__class__}.mask") + + @property + def selected_matrix(self): + return self._selected_matrix + @selected_matrix.setter + def selected_matrix(self, value): + raise NotImplementedError(f"Cannot directly set {self.__class__}.selected_matrix") + + + ################################################################################## + + # methods for subclass initialization + def subclass_setup_helper(self): + """ + This method calls methods to calculate native indicator functions (optional), + masks (based on the native distance matrix), native energy (optional), + and potts model (optional). + + This method is intended to be called as the last step of __init__ + in each subclass of _AWSEMBase. 
The subclasses may differ in how + they load in the structural information (the part of __init__ + preceding the call to this method) and how they implement + the calculate_indicators and calculate_masks methods called + by subclass_setup_helper + """ + self.calculate_masks() # subclasses should (re)define this method as needed + self.calculate_indicators() # subclasses should (re)define this method as needed + self.calculate_energy_and_potts() + + def calculate_indicators(self): + raise NotImplementedError("Subclasses must implement this method") + + def calculate_masks(self): + # calculate masks if self.burial_in_context==True: selected_matrix=self.full_pdb_distance_matrix else: selected_matrix=self.distance_matrix - sequence_mask_rho = frustration.compute_mask(selected_matrix, + self._sequence_mask_rho = frustration.compute_mask(selected_matrix, maximum_contact_distance=None, - minimum_sequence_separation = p.min_sequence_separation_rho) - sequence_mask_contact = frustration.compute_mask(self.distance_matrix, - maximum_contact_distance=p.distance_cutoff_contact, - minimum_sequence_separation = p.min_sequence_separation_contact) - - self._decoy_fluctuation = {} - self.minimally_frustrated_threshold=.78 - - # Calculate rho - rho = 0.25 - rho *= (1 + np.tanh(p.eta * (selected_matrix- p.r_min))) - rho *= (1 + np.tanh(p.eta * (p.r_max - selected_matrix))) - rho *= sequence_mask_rho - self.rho=rho - - #Calculate sigma water - rho_r = (rho).sum(axis=1) - if self.full_pdb_distance_matrix.shape!=self.distance_matrix.shape: - if self.burial_in_context==True: - self.init_index_shift=pdb_structure.init_index_shift - self.fin_index_shift=pdb_structure.fin_index_shift - rho_r=rho_r[self.init_index_shift:self.fin_index_shift] - self.rho_r=rho_r - rho_b = np.expand_dims(rho_r, 1) - rho1 = np.expand_dims(rho_r, 0) - rho2 = np.expand_dims(rho_r, 1) - sigma_water = 0.25 * (1 - np.tanh(p.eta_sigma * (rho1 - p.rho_0))) * (1 - np.tanh(p.eta_sigma * (rho2 - p.rho_0))) - sigma_protein 
= 1 - sigma_water + minimum_sequence_separation = self.p.min_sequence_separation_rho) + self._sequence_mask_contact = frustration.compute_mask(self.distance_matrix, + maximum_contact_distance=self.p.distance_cutoff_contact, + minimum_sequence_separation = self.p.min_sequence_separation_contact) + self._electrostatics_mask = frustration.compute_mask(self.distance_matrix, + maximum_contact_distance=None, + minimum_sequence_separation=self.p.min_sequence_separation_electrostatics) + #with open('my_data.txt','w') as f: + # f.write(f"self.distance_cutoff: {self.distance_cutoff}\n") + # f.write(f"self.sequence_cutoff: {self.sequence_cutoff}\n") + #np.save('my_distance_matrix.npy',self.distance_matrix) + self._mask = frustration.compute_mask(self.distance_matrix, + maximum_contact_distance=self.distance_cutoff, + minimum_sequence_separation = self.sequence_cutoff) + #np.save('my_mask_new.npy',self.mask) + self._selected_matrix = selected_matrix # we'll need this in the calculate_indicators function - #Calculate theta and indicators - theta = 0.25 * (1 + np.tanh(p.eta * (self.distance_matrix - p.r_min))) * (1 + np.tanh(p.eta * (p.r_max - self.distance_matrix))) - thetaII = 0.25 * (1 + np.tanh(p.eta * (self.distance_matrix - p.r_minII))) * (1 + np.tanh(p.eta * (p.r_maxII - self.distance_matrix))) - burial_indicator = np.tanh(p.burial_kappa * (rho_b - p.burial_ro_min)) + np.tanh(p.burial_kappa * (p.burial_ro_max - rho_b)) - direct_indicator = theta[:, :, np.newaxis, np.newaxis] - water_indicator = thetaII[:, :, np.newaxis, np.newaxis] * sigma_water[:, :, np.newaxis, np.newaxis] - protein_indicator = thetaII[:, :, np.newaxis, np.newaxis] * sigma_protein[:, :, np.newaxis, np.newaxis] - - if expose_indicator_functions: - self.indicators=[] - self.indicators.append(burial_indicator[:,0]) - self.indicators.append(burial_indicator[:,1]) - self.indicators.append(burial_indicator[:,2]) + def calculate_energy_and_potts(self, chain_starts=None, chain_ends=None): + # chain_starts and 
chain_ends should be calculated based on object attributes + # (maybe Structure.chain?) but i don't know how to do that, so we'll do this for now + if self.potts_option: + warnings.warn(""" + Constructing full potts model in RAM. + If you don't want to do this, set potts_option=False + """) + self._potts_model = {'h':None, 'J':None} + if chain_starts is None: + chain_starts = np.array([0]) + if chain_ends is None: + chain_ends = np.array([len(self.seq_index)-1]) + if self.p.distance_cutoff_contact is None: + contact_max_dist = 12.5 + else: + contact_max_dist = self.p.distance_cutoff_contact + self._potts_model['h'] = ham.compute_potts_model_h_parallel( + self.p.min_sequence_separation_rho, + chain_starts, chain_ends, + self.distance_matrix, + self.p.k_contact, self.p.burial_gamma) + self._potts_model['J'] = ham.compute_potts_model_J_parallel( + self.p.electrostatics_screening_length, self.p.min_sequence_separation_rho, + self.p.min_sequence_separation_contact, self.p.min_sequence_separation_electrostatics, + chain_starts, chain_ends, + contact_max_dist, 10*self.p.electrostatics_screening_length, # maximum distance for contact potential, maximum for electrostatics + self.distance_matrix, + self.p.k_contact, self.p.direct_gamma, + self.p.k_contact, self.p.protein_gamma, + self.p.k_contact, self.p.water_gamma, + self.p.k_electrostatics, self.p.electrostatic_gamma) + #breakpoint() + #self.potts_model['J'] = ham.compute_potts_model_J( + # self.distance_matrix, ) + J_index = np.meshgrid(range(self.N), range(self.N), range(self.p.q), range(self.p.q), indexing='ij', sparse=False) + h_index = np.meshgrid(range(self.N), range(self.p.q), indexing='ij', sparse=False) - self.indicators.append(direct_indicator[:,:,0,0]*sequence_mask_contact) - self.indicators.append(protein_indicator[:,:,0,0]*sequence_mask_contact) - self.indicators.append(water_indicator[:,:,0,0]*sequence_mask_contact) - - self.gamma_array=[] - temp_burial_gamma=self.burial_gamma[self.aa_map_awsem_list] - 
temp_burial_gamma[0]=0 - temp_burial_gamma *= -0.5 * p.k_contact - self.gamma_array.append(temp_burial_gamma[:,0]) - self.gamma_array.append(temp_burial_gamma[:,1]) - self.gamma_array.append(temp_burial_gamma[:,2]) - - for contact_gamma in [self.direct_gamma, self.protein_gamma, self.water_gamma]: - temp_gamma = contact_gamma[self.aa_map_awsem_x, self.aa_map_awsem_y].copy() - temp_gamma[0, :] = 0 - temp_gamma[:, 0] = 0 - temp_gamma *= -0.5 * self.k_contact - self.gamma_array.append(temp_gamma) - - self.burial_indicator = burial_indicator - self.direct_indicator = direct_indicator - self.water_indicator = water_indicator - self.protein_indicator = protein_indicator + # compute burial and contact energies + old_burial_energy = 0.5 * self.p.k_contact * self.p.burial_gamma[h_index[1]] * self.burial_indicator[:, np.newaxis, :] + direct = self.direct_indicator * self.p.direct_gamma[J_index[2], J_index[3]] + water_mediated = self.water_indicator * self.p.water_gamma[J_index[2], J_index[3]] + protein_mediated = self.protein_indicator * self.p.protein_gamma[J_index[2], J_index[3]] + contact_energy = self.p.k_contact * np.array([direct, water_mediated, protein_mediated]) * self.sequence_mask_contact[np.newaxis, :, :, np.newaxis, np.newaxis] + electrostatics_energy = -self.p.k_electrostatics * self.p.electrostatic_gamma[np.newaxis,np.newaxis,:,:] * self.electrostatics_indicator[:,:,np.newaxis,np.newaxis]\ + * self.electrostatics_mask[:,:,np.newaxis,np.newaxis] + contact_energy = np.append(contact_energy, electrostatics_energy[np.newaxis,:,:,:,:], axis=0) + old_contact_energy = contact_energy + # Compute potts model + old_potts_model = {} + old_potts_model['h'] = old_burial_energy.sum(axis=-1)[:, :] + old_potts_model['J'] = old_contact_energy.sum(axis=0)[:, :, :, :] + diff_h = np.max(np.abs(old_potts_model['h'] - self.potts_model['h'])) + assert self.distance_matrix.shape == (self.N, self.N) + assert self.distance_matrix.shape == (len(self.sequence), len(self.sequence)) + 
assert diff_h < 3E-4, diff_h + diff_J = np.max(np.abs(old_potts_model['J'] - self.potts_model['J'])) + assert diff_J < 3E-4, diff_J + else: + warnings.warn(""" + potts_option was False; will not calculate and store potts model. + Energies will be computed on the fly as needed for frustration calculations and then discarded. + If you want to get the energies for your own purposes, set self.potts_option=True + and then call calculate_energy_and_potts.""") - J_index = np.meshgrid(range(self.N), range(self.N), range(self.q), range(self.q), indexing='ij', sparse=False) - h_index = np.meshgrid(range(self.N), range(self.q), indexing='ij', sparse=False) - #Burial energy - burial_energy = 0.5 * p.k_contact * self.burial_gamma[h_index[1]] * burial_indicator[:, np.newaxis, :] - self.burial_energy = burial_energy - #Contact energy - direct = direct_indicator * self.direct_gamma[J_index[2], J_index[3]] - water_mediated = water_indicator * self.water_gamma[J_index[2], J_index[3]] - protein_mediated = protein_indicator * self.protein_gamma[J_index[2], J_index[3]] - contact_energy = p.k_contact * np.array([direct, water_mediated, protein_mediated]) * sequence_mask_contact[np.newaxis, :, :, np.newaxis, np.newaxis] + # methods to calculate different kinds of frustration + def compute_configurational_decoy_statistics(self): + raise NotImplementedError("Subclasses must define this method") - # Compute electrostatics - if p.k_electrostatics!=0: - self.sequence_cutoff=min(p.min_sequence_separation_electrostatics, p.min_sequence_separation_contact) - self.distance_cutoff=None - - - electrostatics_mask = frustration.compute_mask(self.distance_matrix, maximum_contact_distance=None, minimum_sequence_separation=p.min_sequence_separation_electrostatics) - # ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] - charges = np.array([0, 1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) - charges2 = 
charges[:,np.newaxis]*charges[np.newaxis,:] + def compute_configurational_energies(self): + raise NotImplementedError("Subclasses must define this method") + + def configurational_frustration(self,aa_freq=None, correction=0, n_decoys=4000): + mean_decoy_energy, std_decoy_energy = self.compute_configurational_decoy_statistics(n_decoys=n_decoys,aa_freq=aa_freq) + return -(self.compute_configurational_energies()-mean_decoy_energy)/(std_decoy_energy+correction) - electrostatics_indicator = 1 / (self.distance_matrix + 1E-6) * np.exp(-self.distance_matrix / p.electrostatics_screening_length) * electrostatics_mask - electrostatics_energy = -p.k_electrostatics * (charges2[np.newaxis,np.newaxis,:,:]*electrostatics_indicator[:,:,np.newaxis,np.newaxis]) + def mutational_frustration(self): + # This algorithm is defined in the Frustratometer class + # because it applies to both AWSEM and DCA frustratometry, + # and both the _AWSEMBase and DCA classes inherit from Frustratometer. + # Our goal here is just to provide an interface that matches that used + # for configurational frustration, which has no DCA analog and therefore + # is not defined in Frustratometer (although Frustratometer.frustration + # calls the configurational_frustration method of this class if passed + # the kind='configurational' argument) + return super().frustration(kind='mutational') - contact_energy = np.append(contact_energy, electrostatics_energy[np.newaxis,:,:,:,:], axis=0) - if expose_indicator_functions: - self.indicators.append(electrostatics_indicator) - temp_gamma=0.5 * p.k_electrostatics * charges2[self.aa_map_awsem_x, self.aa_map_awsem_y] - temp_gamma[0,:]=0 - temp_gamma[:,0]=0 - self.gamma_array.append(temp_gamma) + def singleresidue_frustration(self): + # see note for mutational_frustration method + return super().frustration(kind='singleresidue') + + +class AWSEM(_AWSEMBase): + """ + The main class that the user will invoke + for potts model and frustration calculations + with the AWSEM 
Hamiltonian. + + However, users may also be interested in AWSEMIndicators. + """ + def __init__(self, + pdb_structure: object | tuple, # tuple is an object, but this clarifies what we expect + sequence: str =None, + expose_indicator_functions: bool=False, + potts_option: bool=False, + alt_sigma_wat: bool=False, + **parameters)->object: + """ + Pass parameters to _AWSEMBase to + set attributes that DO NOT depend on the implementations of + the indicator function and potts model setup calculations, + then set attributes that DO depend on the implementations + of the indicator function and potts model setup calculations. + + Parameters + ---------- + pdb_structure: object | tuple + A Structure object or tuple of distance matrices characterizing the conformer + to be used (see self.setup_structure) + sequence: str + The amino acid sequence + expose_indicator_functions: bool + If set to True, indicator functions of the contact and burial energy terms can be accessed by user. + potts_option: bool + Whether to set up the potts model (can be RAM-intensive and therefore time-intensive), + which is unnecessary if all you want to get is the indicator functions or + perform certain frustration calculations, for which energies can be computed + on the fly instead of saved in memory. + alt_sigma_wat : bool=False + Whether to use alternative functional form for sigma_wat (experimental feature) + **parameters: + Used to initialize an ParametersAWSEM, which becomes an attribute + of this class and helps us organize the parameters of our AWSEM Hamiltonian + + Returns + ------- + AWSEM object + """ + + # assume the user wanted the sequence from the pdb structure if not given + if not sequence: + try: + sequence = pdb_structure.sequence + except: + if isinstance(pdb_structure,tuple): + raise ValueError("""It seems that you are trying to use + the tuple pdb_structure format, which + specifies a conformation but not a sequence. 
+ In this case, you must provide the sequence + as a separate argument to this class.""") + else: + raise + + # load structure-independent parameters and methods + super().__init__(sequence, expose_indicator_functions, potts_option, **parameters) + self.alt_sigma_wat = alt_sigma_wat + + # set up strucure + self.setup_structure(pdb_structure) + self.subclass_setup_helper() + + # methods for structure-dependent stuff + def setup_structure(self, pdb_structure): + if not isinstance(pdb_structure, tuple): # alt_conf should be our custom Structure object + # maybe our type check here should be more restrictive, + # but the __init__ only requires pdb_structure to be an object, + # so I'll take my cue from that + # check that the sequence of our Structure is consistent with the current sequence + if self.sequence != pdb_structure.sequence: + raise NotImplementedError(f""" + You are attempting to modify the sequence of your {self.__class__} + by passing in a Structure with a different sequence than self.sequence. + This is currently not supported but may be in the future.""") + # check that the length of the sequence of our Structure is consistent + if self.N != len(pdb_structure.sequence): + raise NotImplementedError(f""" + You are attempting to modify the length of the sequence of your + {self.__class__} by passing in a Structure with a different sequence than + self.sequence. 
This is currently not supported but may be in the future.""") + # check structure + selection_CB = pdb_structure.structure.select('name CB or (resname GLY IGL and name CA)') + resid = selection_CB.getResindices() + self.resid = resid + # set structure-dependent properties + self._pdb_structure = pdb_structure + self.structure=pdb_structure.structure + self.chain=pdb_structure.chain + self.pdb_file=pdb_structure.pdb_file + self.init_index_shift=pdb_structure.init_index_shift + self.full_to_aligned_index_dict=pdb_structure.full_to_aligned_index_dict + self.distance_matrix=pdb_structure.distance_matrix + self.full_pdb_distance_matrix=pdb_structure.full_pdb_distance_matrix + self.midpoint_matrix = pdb_structure.midpoint_matrix + # midpoint matrix is used to map interacting pairs to a single point in space + elif isinstance(pdb_structure, tuple): # pdb_structure is defined by a few distance matrices + if len(pdb_structure)==3\ + and isinstance(pdb_structure[0],np.ndarray)\ + and isinstance(pdb_structure[1],np.ndarray)\ + and isinstance(pdb_structure[2],np.ndarray) or pdb_structure[2] is None: + # pdb_structure is a full_pdb_distance_matrix + # followed by a distance_matrix + # followed by a midpoint matrix (or None) + self._pdb_structure = None # we're getting our conformer from within python, not a pdb file + self.structure = None # we're getting our conformer from within python, not a pdb file + self.full_pdb_distance_matrix = pdb_structure[0] + self.distance_matrix = pdb_structure[1] + self.midpoint_matrix = pdb_structure[2] + # midpoint matrix is used to map interacting pairs to a single point in space; + # usually not necessary, so it will usually be None + # + # the rest of the attributes that are set in the case that pdb_structure is a Structure + # either remain the same (if this method has been previously called with a Structure) + # or go undefined (if we are passing a list of arrays the first time that we are calling + # this method) + else: + raise 
TypeError(""" + Could not parse pdb_structure tuple. + Check the source code or make pdb_structure into a Structure object.""") else: - self.sequence_cutoff=p.min_sequence_separation_contact - self.distance_cutoff=p.distance_cutoff_contact - self.mask = frustration.compute_mask(self.distance_matrix, maximum_contact_distance=self.distance_cutoff, minimum_sequence_separation = self.sequence_cutoff) + raise AssertionError("unexpected else block") - self.contact_energy = contact_energy + def calculate_indicators(self): + if self.expose_indicator_functions: + # Calculate rho + rho = 0.25 + rho *= (1 + np.tanh(self.p.eta * (self.selected_matrix - self.p.r_min))) + rho *= (1 + np.tanh(self.p.eta * (self.p.r_max - self.selected_matrix))) + rho *= self.sequence_mask_rho + self.rho=rho + #Calculate sigma water + rho_r = (rho).sum(axis=1) + if self.full_pdb_distance_matrix.shape!=self.distance_matrix.shape: + if self.burial_in_context==True: + self.init_index_shift=self.pdb_structure.init_index_shift + self.fin_index_shift=self.pdb_structure.fin_index_shift + rho_r=rho_r[self.init_index_shift:self.fin_index_shift] + self.rho_r=rho_r + rho_b = np.expand_dims(rho_r, 1) + rho1 = np.expand_dims(rho_r, 0) + rho2 = np.expand_dims(rho_r, 1) + sigma_water = 0.25 * (1 - np.tanh(self.p.eta_sigma * (rho1 - self.p.rho_0))) * (1 - np.tanh(self.p.eta_sigma * (rho2 - self.p.rho_0))) + if self.alt_sigma_wat: + sigma_water = -sigma_water + 0.5*( (1 - np.tanh(self.p.eta_sigma * (rho1 - self.p.rho_0))) + (1 - np.tanh(self.p.eta_sigma * (rho2 - self.p.rho_0)))) + sigma_protein = 1 - sigma_water + #Calculate theta and indicators + theta = 0.25 * (1 + np.tanh(self.p.eta * (self.distance_matrix - self.p.r_min))) * (1 + np.tanh(self.p.eta * (self.p.r_max - self.distance_matrix))) + thetaII = 0.25 * (1 + np.tanh(self.p.eta * (self.distance_matrix - self.p.r_minII))) * (1 + np.tanh(self.p.eta * (self.p.r_maxII - self.distance_matrix))) + burial_indicator = np.tanh(self.p.burial_kappa * (rho_b - 
self.p.burial_ro_min)) + np.tanh(self.p.burial_kappa * (self.p.burial_ro_max - rho_b)) + direct_indicator = theta[:, :, np.newaxis, np.newaxis] + water_indicator = thetaII[:, :, np.newaxis, np.newaxis] * sigma_water[:, :, np.newaxis, np.newaxis] + protein_indicator = thetaII[:, :, np.newaxis, np.newaxis] * sigma_protein[:, :, np.newaxis, np.newaxis] + self.burial_indicator = burial_indicator + self.direct_indicator = direct_indicator + self.water_indicator = water_indicator + self.protein_indicator = protein_indicator + electrostatics_indicator = 1 / (self.distance_matrix + 1E-6) * np.exp(-self.distance_matrix / self.p.electrostatics_screening_length) + self.electrostatics_indicator = electrostatics_indicator + else: + warnings.warn(""" + self.expose_indicator_functions was False; + will not calculate and store indicator functions. + Indicator functions will be computed on the fly as needed + for energy calculations and then discarded. + If you want to get the indicator functions for your own purposes, + set expose_indicator_functions=True + and then call calculate_indicators().""") - # Compute fast properties - self.aa_freq = frustration.compute_aa_freq(self.sequence) - self.contact_freq = frustration.compute_contact_freq(self.sequence) - self.potts_model = {} - self.potts_model['h'] = burial_energy.sum(axis=-1)[:, self.aa_map_awsem_list] - self.potts_model['J'] = contact_energy.sum(axis=0)[:, :, self.aa_map_awsem_x, self.aa_map_awsem_y] - - # Set the gap energy to zero - self.potts_model['h'][:, 0] = 0 - self.potts_model['J'][:, :, 0, :] = 0 - self.potts_model['J'][:, :, :, 0] = 0 - self._native_energy=None + # make self.pdb_structure into a property so that structure-dependent + # stuff is recalculated automatically when we change the conformation + @property + def pdb_structure(self): + return self._pdb_structure + @pdb_structure.setter + def pdb_structure(self,pdb_structure): + # reset structural attributes + self.setup_structure(pdb_structure) + # check that 
our new structure is compatible with our old one + if self.N != len(self.sequence): + breakpoint() + raise ValueError("The pdb is incomplete. Try setting 'repair_pdb=True' when constructing the Structure object.") + self.subclass_setup_helper() + def change_conformation(self,alt_conf): + # this method is an alias for the setter + # Keep this method if the setter is too slow + self.pdb_structure = alt_conf + # self.masked_indicators is calculated from other attributes, + # so it should be made into a property + @property + def masked_indicators(self): + # store indicators and gammas for our particular sequence as attributes + _masked_indicators=[] + _masked_indicators.append(self.burial_indicator[:,0]) + _masked_indicators.append(self.burial_indicator[:,1]) + _masked_indicators.append(self.burial_indicator[:,2]) + _masked_indicators.append(self.direct_indicator[:,:,0,0]*self.sequence_mask_contact) + _masked_indicators.append(self.protein_indicator[:,:,0,0]*self.sequence_mask_contact) + _masked_indicators.append(self.water_indicator[:,:,0,0]*self.sequence_mask_contact) + _masked_indicators.append(self.electrostatics_indicator*self.electrostatics_mask) + return _masked_indicators + @masked_indicators.setter + def masked_indicators(self): + raise AttributeError(f"""Setting {self.__class__}.indicators directly is not allowed. 
+ Modify {self.__class__}.burial_indicator, {self.__class__}.direct_indicator, + {self.__class__}.protein_indicator, {self.__class__}.water_indicator, + {self.__class__}.electrostatic_indicator, + {self.__class__}.sequence_mask_contact, + or {self.__class__}.electrostatics_mask instead.""") + + # implementations of frustration algorithms def compute_configurational_decoy_statistics(self, n_decoys=4000,aa_freq=None): # ['A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'] - _AA='ARNDCQEGHILKMFPSTWYV' + _AA = self.p.alphabet #'ARNDCQEGHILKMFPSTWYV' if aa_freq is None: - seq_index = np.array([_AA.find(aa) for aa in self.sequence]) + seq_index = self.seq_index N=self.N else: N=self.N*10 @@ -262,22 +816,22 @@ def compute_configurational_decoy_statistics(self, n_decoys=4000,aa_freq=None): seq_index = np.random.choice(a=len(aa_freq), size=N, p=probabilities) distances = np.triu(self.distance_matrix) - distances = distances[(distances0)] + distances = distances[(distances0)] rho_b = np.expand_dims(self.rho_r, 1) #(n,1) rho1 = np.expand_dims(self.rho_r, 0) #(1,n) rho2 = np.expand_dims(self.rho_r, 1) #(n,1) - sigma_water = 0.25 * (1 - np.tanh(self.eta_sigma * (rho1 - self.rho_0))) * (1 - np.tanh(self.eta_sigma * (rho2 - self.rho_0))) #(n,n) + sigma_water = 0.25 * (1 - np.tanh(self.p.eta_sigma * (rho1 - self.p.rho_0))) * (1 - np.tanh(self.p.eta_sigma * (rho2 - self.p.rho_0))) #(n,n) sigma_protein = 1 - sigma_water #(n,n) #Calculate theta and indicators - theta = 0.25 * (1 + np.tanh(self.eta * (distances - self.r_min))) * (1 + np.tanh(self.eta * (self.r_max - distances))) # (c,) - thetaII = 0.25 * (1 + np.tanh(self.eta * (distances - self.r_minII))) * (1 + np.tanh(self.eta * (self.r_maxII - distances))) #(c,) - burial_indicator = np.tanh(self.burial_kappa * (rho_b - self.burial_ro_min)) + np.tanh(self.burial_kappa * (self.burial_ro_max - rho_b)) #(n,3) + theta = 0.25 * (1 + np.tanh(self.p.eta * (distances - self.p.r_min))) * (1 + 
np.tanh(self.p.eta * (self.p.r_max - distances))) # (c,) + thetaII = 0.25 * (1 + np.tanh(self.p.eta * (distances - self.p.r_minII))) * (1 + np.tanh(self.p.eta * (self.p.r_maxII - distances))) #(c,) + burial_indicator = np.tanh(self.p.burial_kappa * (rho_b - self.p.burial_ro_min)) + np.tanh(self.p.burial_kappa * (self.p.burial_ro_max - rho_b)) #(n,3) charges = np.array([0, 1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) - electrostatics_indicator = np.exp(-distances / self.electrostatics_screening_length) / distances + electrostatics_indicator = np.exp(-distances / self.p.electrostatics_screening_length) / distances decoy_energies=np.zeros(n_decoys) #decoy_data=[None]*n_decoys @@ -292,14 +846,14 @@ def compute_configurational_decoy_statistics(self, n_decoys=4000,aa_freq=None): q2=seq_index[qi2] - burial_energy1 = (-0.5 * self.k_contact * self.burial_gamma[q1] * burial_indicator[n1]).sum(axis=0) - burial_energy2 = (-0.5 * self.k_contact * self.burial_gamma[q2] * burial_indicator[n2]).sum(axis=0) + burial_energy1 = (-0.5 * self.p.k_contact * self.p.burial_gamma[q1] * burial_indicator[n1]).sum(axis=0) + burial_energy2 = (-0.5 * self.p.k_contact * self.p.burial_gamma[q2] * burial_indicator[n2]).sum(axis=0) - direct = theta[c] * self.direct_gamma[q1, q2] - water_mediated = sigma_water[n1,n2] * thetaII[c] * self.water_gamma[q1,q2] - protein_mediated = sigma_protein[n1,n2] * thetaII[c] * self.protein_gamma[q1,q2] - contact_energy = -self.k_contact * (direct+water_mediated+protein_mediated) - electrostatics_energy = self.k_electrostatics * electrostatics_indicator[c]*charges[q1]*charges[q2] + direct = theta[c] * self.p.direct_gamma[q1, q2] + water_mediated = sigma_water[n1,n2] * thetaII[c] * self.p.water_gamma[q1,q2] + protein_mediated = sigma_protein[n1,n2] * thetaII[c] * self.p.protein_gamma[q1,q2] + contact_energy = -self.p.k_contact * (direct+water_mediated+protein_mediated) + electrostatics_energy = self.p.k_electrostatics * 
electrostatics_indicator[c]*charges[q1]*charges[q2] decoy_energies[i]=(burial_energy1+burial_energy2+contact_energy+electrostatics_energy) #decoy_data[i]=[i, qi1, qi2, q1, q2, n1, n2, distances[c], self.rho_r[n1], self.rho_r[n2], contact_energy/4.184, burial_energy1/4.184, burial_energy2/4.184, electrostatics_energy/4.184, decoy_energies[i]] @@ -309,15 +863,15 @@ def compute_configurational_decoy_statistics(self, n_decoys=4000,aa_freq=None): return mean_decoy_energy, std_decoy_energy def compute_configurational_energies(self): - _AA='ARNDCQEGHILKMFPSTWYV' - seq_index = np.array([_AA.find(aa) for aa in self.sequence]) + _AA= self.p.alphabet #'ARNDCQEGHILKMFPSTWYV' + seq_index = self.seq_index distances = np.triu(self.distance_matrix) - distances = distances[(distances0)] + distances = distances[(distances0)] n_contacts=len(distances) n = self.distance_matrix.shape[0] # Assuming self.distance_matrix is defined and square tri_upper_indices = np.triu_indices(n, k=1) # k=1 excludes the diagonal - valid_pairs = (self.distance_matrix[tri_upper_indices] < self.distance_cutoff_contact) & \ + valid_pairs = (self.distance_matrix[tri_upper_indices] < self.p.distance_cutoff_contact) & \ (self.distance_matrix[tri_upper_indices] > 0) indices1,indices2 = (tri_upper_indices[0][valid_pairs], tri_upper_indices[1][valid_pairs]) @@ -328,16 +882,16 @@ def compute_configurational_energies(self): rho1 = np.expand_dims(self.rho_r, 0) #(1,n) rho2 = np.expand_dims(self.rho_r, 1) #(n,1) - sigma_water = 0.25 * (1 - np.tanh(self.eta_sigma * (rho1 - self.rho_0))) * (1 - np.tanh(self.eta_sigma * (rho2 - self.rho_0))) #(n,n) + sigma_water = 0.25 * (1 - np.tanh(self.p.eta_sigma * (rho1 - self.p.rho_0))) * (1 - np.tanh(self.p.eta_sigma * (rho2 - self.p.rho_0))) #(n,n) sigma_protein = 1 - sigma_water #(n,n) #Calculate theta and indicators - theta = 0.25 * (1 + np.tanh(self.eta * (distances - self.r_min))) * (1 + np.tanh(self.eta * (self.r_max - distances))) # (c,) - thetaII = 0.25 * (1 + 
np.tanh(self.eta * (distances - self.r_minII))) * (1 + np.tanh(self.eta * (self.r_maxII - distances))) #(c,) - burial_indicator = np.tanh(self.burial_kappa * (rho_b - self.burial_ro_min)) + np.tanh(self.burial_kappa * (self.burial_ro_max - rho_b)) #(n,3) + theta = 0.25 * (1 + np.tanh(self.p.eta * (distances - self.p.r_min))) * (1 + np.tanh(self.p.eta * (self.p.r_max - distances))) # (c,) + thetaII = 0.25 * (1 + np.tanh(self.p.eta * (distances - self.p.r_minII))) * (1 + np.tanh(self.p.eta * (self.p.r_maxII - distances))) #(c,) + burial_indicator = np.tanh(self.p.burial_kappa * (rho_b - self.p.burial_ro_min)) + np.tanh(self.p.burial_kappa * (self.p.burial_ro_max - rho_b)) #(n,3) charges = np.array([0, 1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]) - electrostatics_indicator = np.exp(-distances / self.electrostatics_screening_length) / distances + electrostatics_indicator = np.exp(-distances / self.p.electrostatics_screening_length) / distances # decoy_data_columns=['decoy_i','i_resno','j_resno','ires_type','jres_type','aa1','aa2','rij','rho_i','rho_j','water_energy','burial_energy_i','burial_energy_j','electrostatic_energy','total_energies'] # decoy_data=[] @@ -348,14 +902,14 @@ def compute_configurational_energies(self): q1=seq_index[n1] q2=seq_index[n2] - burial_energy1 = (-0.5 * self.k_contact * self.burial_gamma[q1] * burial_indicator[n1]).sum(axis=0) - burial_energy2 = (-0.5 * self.k_contact * self.burial_gamma[q2] * burial_indicator[n2]).sum(axis=0) + burial_energy1 = (-0.5 * self.p.k_contact * self.p.burial_gamma[q1] * burial_indicator[n1]).sum(axis=0) + burial_energy2 = (-0.5 * self.p.k_contact * self.p.burial_gamma[q2] * burial_indicator[n2]).sum(axis=0) - direct = theta[c] * self.direct_gamma[q1, q2] - water_mediated = sigma_water[n1,n2] * thetaII[c] * self.water_gamma[q1,q2] - protein_mediated = sigma_protein[n1,n2] * thetaII[c] * self.protein_gamma[q1,q2] - contact_energy = -self.k_contact * (direct+water_mediated+protein_mediated) - 
electrostatics_energy = self.k_electrostatics * electrostatics_indicator[c]*charges[q1]*charges[q2] + direct = theta[c] * self.p.direct_gamma[q1, q2] + water_mediated = sigma_water[n1,n2] * thetaII[c] * self.p.water_gamma[q1,q2] + protein_mediated = sigma_protein[n1,n2] * thetaII[c] * self.p.protein_gamma[q1,q2] + contact_energy = -self.p.k_contact * (direct+water_mediated+protein_mediated) + electrostatics_energy = self.p.k_electrostatics * electrostatics_indicator[c]*charges[q1]*charges[q2] energy=(burial_energy1+burial_energy2+contact_energy+electrostatics_energy) configurational_energies[n1,n2]=energy @@ -364,6 +918,586 @@ def compute_configurational_energies(self): # import pandas as pd return configurational_energies #, pd.DataFrame(decoy_data, columns=decoy_data_columns) - def configurational_frustration(self,aa_freq=None, correction=0, n_decoys=4000): - mean_decoy_energy, std_decoy_energy = self.compute_configurational_decoy_statistics(n_decoys=n_decoys,aa_freq=aa_freq) - return -(self.compute_configurational_energies()-mean_decoy_energy)/(std_decoy_energy+correction) \ No newline at end of file + +class AWSEMIndicators(_AWSEMBase): # PottsEvaluatorFromIndicators or PottsEnergyEvaluatorFromIndicators? + """ + This class is intended to be equivalent to AWSEM + but allows initialization from numpy arrays of + indicator functions, rather than calculating + those indicators from a Structure or set of + distance matrices. 
+ """ + def __init__(self, + burial_indicator: np.ndarray, + direct_indicator: np.ndarray, + protein_indicator: np.ndarray, + water_indicator: np.ndarray, + electrostatics_indicator: Union[np.ndarray, None], + sequence: str, # sequence is optional if we initialize from a Structure but not here + expose_indicator_functions: bool=False, + potts_option : bool=False, + absolute_value_gamma: bool=False, + **parameters)->object: + """ + Pass parameters to _AWSEMBase to + set attributes that DO NOT depend on the implementations of + the indicator function and potts model setup calculations, + then set attributes that DO depend on the implementations + of the indicator function and potts model setup calculations. + + Parameters + ---------- + burial_indicator : np.ndarray + Burial indicator array, most likely accessed using the burial_indicator attribute of an AWSEM + direct_indicator : np.ndarray + Direct indicator array, most likely accessed using the direct_indicator attribute of an AWSEM + protein_indicator : np.ndarray + Protein indicator array, most likely accessed using the protein_indicator attribute of an AWSEM + water_indicator : np.ndarray + Water indicator array, most likely accessed using the water_indicator attribute of an AWSEM + electrostatics_indicator : Union[np.ndarray, None] + Electrostatics indicator array, most likely accessed using the electrostatics_indicator attribute of an AWSEM. + May be None is electrostatics were turned off (k_electrostatics=0). + sequence : str + The amino acid sequence of the protein. The sequence is assumed to be in one-letter code. + expose_indicator_functions: bool + If set to True, indicator functions of the contact and burial energy terms can be accessed by user. 
+ potts_option : bool + Whether to set up the potts model (can be RAM-intensive and therefore time-intensive), + which is unnecessary if all you want is to perform certain frustration calculations, + for which energies can be computed on the fly instead of saved in memory. + absolute_value_gamma: bool + If True, replace gammas with their absolute values. This is helpful for the standard deviation approximation + + Returns + ------- + AWSEMIndicators object + + """ + super().__init__(sequence, expose_indicator_functions, potts_option, **parameters) + self.burial_indicator = burial_indicator + self.direct_indicator = direct_indicator + self.protein_indicator = protein_indicator + self.water_indicator = water_indicator + self.electrostatics_indicator = electrostatics_indicator + # we don't have a distance matrix to a apply a minimum sequence separation to-- + # we have to assume that this consideration was already made when computing the indicators. + # So we just set the "distance" matrix to zeros and set no maximum cutoff, so that nothing changes + # however, we can apply a minimum sequence separation-based mask to the matrix + self.sequence_mask_contact = frustration.compute_mask(np.zeros((self.N,self.N)), + maximum_contact_distance=None, + minimum_sequence_separation = self.p.min_sequence_separation_contact) + self.electrostatics_mask = frustration.compute_mask(np.zeros((self.N,self.N)), + maximum_contact_distance=None, + minimum_sequence_separation=self.p.min_sequence_separation_electrostatics) + self.mask = frustration.compute_mask(np.zeros((self.N,self.N)), + maximum_contact_distance=self.distance_cutoff, + minimum_sequence_separation = self.sequence_cutoff) + #if absolute_value_gamma: + # self.burial_gamma = np.abs(self.burial_gamma) + # self.direct_gamma = np.abs(self.direct_gamma) + # self.protein_gamma = np.abs(self.protein_gamma) + # self.water_gamma = np.abs(self.water_gamma) + # self.electrostatics_gamma = np.abs(self.electrostatics_gamma) + 
#self.absolute_value_gamma = absolute_value_gamma + #np.save('absolute_value_gamma_1.npy',absolute_value_gamma) + #np.save('burial_indicator_1.npy',burial_indicator) + #np.save('direct_indicator_1.npy', direct_indicator) + #np.save('protein_indicator_1.npy', protein_indicator) + #np.save('water_indicator_1.npy', water_indicator) + #np.save('electrostatics_indicator_1.npy', electrostatics_indicator) + self.subclass_setup_helper() + + def calculate_indicators(self): + pass # the function was initialized with indicators, so there's nothing to do + +class AWSEMVariancePotts(_AWSEMBase): + """ + EXPERIMENTAL CLASS THAT TRIES TO REPURPOSE OUR CODE TO CREATE + A SPECIAL KIND OF "POTTS MODEL" WHERE THE "ENERGY" IS ACTUALLY + THE VARIANCE OF THE ENERGIES OF A PREDEFINED SET OF DECOY CONFORMERS. + THIS CLASS IS STILL UNDER DEVELOPMENT. + """ + def __init__(self, + covariance_matrix: np.ndarray, + sequence: str, # sequence is optional if we initialize from a Structure but not here + expose_indicator_functions: bool=False, + absolute_value_gamma: bool=False, + **parameters)->object: + """ + EXPERIMENTAL CLASS THAT TRIES TO REPURPOSE OUR CODE TO CREATE + A SPECIAL KIND OF "POTTS MODEL" WHERE THE "ENERGY" IS ACTUALLY + THE VARIANCE OF THE ENERGIES OF A PREDEFINED SET OF DECOY CONFORMERS. + THIS CLASS IS STILL UNDER DEVELOPMENT. + + Parameters + ---------- + covariance_matrix: np.ndarray + Covariance matrix of all __indicator functions___ (not residues) over a decoy set + sequence : str + The amino acid sequence of the protein. The sequence is assumed to be in one-letter code. + expose_indicator_functions: bool + If set to True, indicator functions of the contact and burial energy terms can be accessed by user. + absolute_value_gamma: bool + If True, replace gammas with their absolute values. 
This is helpful for the standard deviation approximation + + Returns + ------- + AWSEMVariancePotts object + + """ + # if we already have our indicator functions, + # our goal is probably to compute the potts model, + # so we'll just hard code a value of True for that argument VVVV + super().__init__(sequence, expose_indicator_functions, potts_option=True, **parameters) + self.covariance_matrix = covariance_matrix + self.num_indicators = 3*self.N + 4*(self.N**2-self.N)/2 # low, med, high burial for each N, 4 classes of pair interactions + self.subclass_setup_helper() + + @staticmethod # trying to avoid loading down memory with too many permanent attributes + def pairwise_mask(l): # l for length + # Helps us figure out where a 1D array's elements were in an upper triangular matrix, + # assuming the matrix was flattened row-major style and the main diagonal was excluded. + # Each index i of this list gives us the indices of the 1D array that were in row i of the matrix + #NbyN_matrix_rows = [[range(n*l-int(((n**2)+n)/2),(n+1)*l-int((((n+1)**2)+(n+1))/2)), + # ] for n in range(l)] + mask = np.zeros((l,l)) + for i in range(l): + temp = np.zeros((l,l)) + # set elements involving i equal to 1 + temp[:,i] = 1 + temp[i,:] = 1 + temp = temp[np.triu_indices(l,k=1)] # this flattens the array, keeping only the upper triangle + found = np.where(temp==1)[0] + try: + mask[i, :] = [1 if index in found else 0 for index in range(l)]#[1 if index in NbyN_matrix_rows[i] else 0 for index in range(l)] + except: + import pdb; pdb.set_trace() + return mask + + def calculate_indicators(self): + print("start calculate indicators") + assert len(self.covariance_matrix.shape)==2, self.covariance_matrix.shape + assert self.covariance_matrix.shape[0] == self.covariance_matrix.shape[1] + print("assertions complete") + # each "indicator function" is actually a covariance of two indicator functions. + # There are a few different kinds of pairs of indicator functions. 
+ # The first kind is self-covariances, AKA variances, which we further break down + # into burial (dependent on AA identity at a single position) + # and pairwise (dependent on two positions) + self.burial_variances = np.diag(self.covariance_matrix[:3*self.N,:3*self.N]) # shape (3N,) + self.pairwise_variances = np.diag(self.covariance_matrix[3*self.N:,3*self.N:]) # shape (4N,) + np.save("burial_variances.npy", self.burial_variances) + np.save("pairwise_variances.npy", self.pairwise_variances) + print("variances calculated") + # The other kind of covariance is a covarience between indicator functions. + # We break these down into burial-burial covariances (dependent on two identities), + # burial-pairwise indicator covariances (some dependent on 2, others dependent on 3 identities), + # and pairwise indicator-pairwise indicator covariances (dependent on 3 or 4 identities) + self.burial_burial_covariances = self.covariance_matrix[np.triu_indices(3*self.N,k=1)] # shape (((3N)**2-3N)/2,) + print("first covariances calculated") + num_upper = int((self.N**2-self.N)/2) + burial_pairwise_covariances_2 = np.zeros((3*self.N, 4*num_upper)) # shape (3N, 4((N**2-N)/2)) + #self.burial_pairwise_covariances_2 = np.concatenate([ # can be represented by 2-body term in Potts model + # pairwise_mask tells us which pairwise indicator functions in a given row + # involve the residue whose burial covariances are evaluated in that row + # (the covariance matrix has more elements than there are energy terms-- + # some elements represent relationships between residues that don't interact directly, + # so there are many indicators in each row i involving residues j and k but not i) + # + # We repeat pairwise_mask 3 times because each residue is repeated 3 times (low, med, high density) + # self.covariance_matrix[:3*self.N, 3*self.N+i*num_upper:3*self.N+(i+1)*num_upper]\ + # *self.pairwise_mask(num_upper)[:self.N,:].repeat(3,axis=0)\ + # for i in range(4)], axis=1) # shape (3N,4((N**2-N)/2)) 
    def calculate_energy_and_potts(self):
        """
        Assemble the one-body ('h') and two-body ('J') terms of the variance Potts model.

        Reads self.N, self.q, the gammas and scale factors stored on self.p,
        self.burial_variances, self.pairwise_variances and
        self.burial_pairwise_covariances_2.

        Sets self.burial_energy, self.contact_energy and self._potts_model
        ('h' is burial_energy summed over its last axis, 'J' is contact_energy
        summed over its first axis), and resets self._native_energy to None.

        NOTE(review): this method looks like work in progress. Several statements
        below appear unable to run as written; see the NOTE(review) comments.
        Confirm the intended array shapes before relying on the result.
        """

        # index grids used to broadcast the gamma tables over residues and amino acid types
        J_index = np.meshgrid(range(self.N), range(self.N), range(self.q), range(self.q), indexing='ij', sparse=False)
        h_index = np.meshgrid(range(self.N), range(self.q), indexing='ij', sparse=False)

        # compute burial and contact energies
        # the "energy" of our potts model representing the covariance, not a physical energy
        # this "burial energy" is the sum of variances of the burial indicators (the one-body part of the model)
        self.burial_energy = (0.5*self.p.k_contact*self.p.burial_gamma[h_index[1]])**2 * self.burial_variances.reshape((self.N,1,3))
        # the "contact energy" is ordinarily the sum of all two-body components of the model
        # (direct, protein, water, electrostatics), so we do the analogous thing here
        template = np.zeros((self.N,self.N))
        num_upper = int((self.N**2-self.N)/2)
        triu_indices = np.triu_indices(self.N,k=1)
        # pairwise_variances is consumed in consecutive slices of length num_upper:
        # direct, protein-mediated, water-mediated, then (optionally) electrostatics
        template[triu_indices] = self.pairwise_variances[:num_upper]
        direct = (template+template.T)[:,:,np.newaxis,np.newaxis] * self.p.direct_gamma[J_index[2], J_index[3]]**2
        template[triu_indices] = self.pairwise_variances[num_upper:2*num_upper]
        protein_mediated = (template+template.T)[:,:,np.newaxis,np.newaxis] * self.p.protein_gamma[J_index[2], J_index[3]]**2
        template[triu_indices] = self.pairwise_variances[2*num_upper:3*num_upper]
        water_mediated = (template+template.T)[:,:,np.newaxis,np.newaxis] * self.p.water_gamma[J_index[2], J_index[3]]**2
        contact_energy = self.p.k_contact * np.array([direct, protein_mediated, water_mediated])
        if self.p.k_electrostatics!=0:
            template[triu_indices] = self.pairwise_variances[3*num_upper:]
            # NOTE(review): here the variance template is squared and the gamma is not,
            # the opposite of the three terms above -- confirm which is intended
            electrostatics_energy = -self.p.k_electrostatics * self.p.electrostatics_gamma[np.newaxis,np.newaxis,:,:] * (template+template.T)[:,:,np.newaxis,np.newaxis]**2
            contact_energy = np.append(contact_energy, electrostatics_energy[np.newaxis,:,:,:,:], axis=0)
        # for the variance potts model, there is one more kind of two-body interaction:
        # burial-pairwise covariance when the pairwise energy term involves the residue in the burial term
        # self.burial_pairwise_covariances_2 has shape (3N, 4(N^2-N)/2)
        # we first multiply each row by the appropriate burial energy
        temp = self.burial_pairwise_covariances_2
        low = temp[::3,:,np.newaxis]*0.5*self.p.k_contact*self.p.burial_gamma[h_index[1],0]
        # NOTE(review): `low` adds an axis with np.newaxis, but `med` and `high` index temp with
        # three slices; if temp is 2-D as stated above, these two lines would raise IndexError
        med = temp[1::3,:,:]*0.5*self.p.k_contact*self.p.burial_gamma[h_index[1],1]
        high = temp[2::3,:,:]*0.5*self.p.k_contact*self.p.burial_gamma[h_index[1],2]
        # we can now collapse our 3 burial indicator types
        temp = np.sum(np.concatenate((low[None,...], med[None,...], high[None,...]), axis=0), axis=0)
        assert temp.shape == (self.N, 4*((self.N**2-self.N)/2)), temp.shape
        # now we split into our 4 pairwise contact types, keeping only the elements of each row
        # that represent a pairwise interaction involving the residue whose burial covariances are found in that row
        # NOTE(review): boolean-mask indexing (temp[temp!=0]) returns a 1-D array, so
        # np.split(..., axis=1) looks like it cannot succeed here -- confirm the intended selection
        direct, prot, wat, elec = np.split(temp[temp!=0], 4, axis=1)
        # now we need to go from shape (N, (N^2-N)/2) to (N,N)
        # (each residue burial indicator covaries with (N^2-N)/2 pairwise indicators,
        # but only N of them include the same residue from the burial indicator;
        # others have a value of 0, which we can easily eliminate)
        # we also need to multiply by our pairwise gammas
        direct = direct[direct != 0].reshape((self.N,self.N,self.q))[...,np.newaxis]*self.p.direct_gamma[J_index[3]]*self.p.k_contact
        prot = prot[prot != 0].reshape((self.N,self.N,self.q))[...,np.newaxis]*self.p.protein_gamma[J_index[3]]*self.p.k_contact
        wat = wat[wat != 0].reshape((self.N,self.N,self.q))[...,np.newaxis]*self.p.water_gamma[J_index[3]]*self.p.k_contact
        ############################################################################################################################
        elec = elec[elec != 0].reshape((self.N,self.N,self.q))[...,np.newaxis]*self.p.electrostatics_gamma[np.newaxis,np.newaxis,:,:][J_index[3]]*self.p.k_contact
        # ???????????????? why are we multiplying electrostatics by k_contact?
        # electrostatics_gamma already had the electrostatics weight k_electrostatics multiplied in and k_electrostatics
        # isn't necessarily equal to k_contact. Anyway, i'm now going to factor k_electrostatics out of electrostatics_gamma
        # in the _AWSEMBase class
        # probably should be
        # elec = elec[elec != 0].reshape((self.N,self.N,self.q))[...,np.newaxis]*self.electrostatics_gamma[np.newaxis,np.newaxis,:,:][J_index[3]]*(-self.k_electrostatics)
        #############################################################################################################################

        # stack the burial-pairwise terms after the pairwise terms along the first axis
        contact_energy = np.append(contact_energy, direct[np.newaxis,...], axis=0)
        contact_energy = np.append(contact_energy, prot[np.newaxis,...], axis=0)
        contact_energy = np.append(contact_energy, wat[np.newaxis,...], axis=0)
        contact_energy = np.append(contact_energy, elec[np.newaxis,...], axis=0)

        # the string literal below is disabled draft code kept by the author; it is never executed
        """
        direct = np.zeros((self.N,self.N))
        direct[triu_indices] = 
        direct

        num_upper = int((self.N**2-self.N)/2)
        triu_indices = np.triu_indices(self.N,k=1)
        template = np.zeros((self.N,self.N))

        assert direct.shape==(self.N,self.N), direct.shape
        assert prot.shape==(self.N,self.N), prot.shape
        assert wat.shape==(self.N,self.N), wat.shape
        assert elec.shape==(self.N,self.N), elec.shape


        for counter,row in enumerate(self.burial_pairwise_covariances_2):
            direct_indicators = row[:len(row)//4]
            direct_energy = direct_indicators[direct_indicators>0].reshape((-1,1))\
                            *0.5*self.p.k_contact*self.p.burial_gamma[:,counter%3]\
                            *self.p.k_contact*self.direct_gamma[]
            direct = row[row!=0] * 0.5*self.p.k_contact*self.p.burial_gamma[]
            template[counter] = row[row==1] # where the residue corresponding to the burial row is involved in the pairwise indicator
        burial_pairwise_2 = self.burial_pairwise_covariances_2[]
        """
        ####################################################################
        # the potts model that we're using (AWSEMEnergy) multiplies each of the J terms by 1/2,
        # so we should multiply them by 2 to cancel that out
        # NOTE(review): contact_energy.shape[0] is the number of stacked interaction terms, so
        # np.diag_indices selects elements [k, k] over the first two axes (term index == first
        # residue index), not the residue diagonal i == j -- confirm this is what is wanted
        contact_energy[np.diag_indices(contact_energy.shape[0])] *= 1/2
        contact_energy *= 2
        ####################################################################

        self.contact_energy = contact_energy

        # Compute potts model
        self._potts_model = {}
        self._potts_model['h'] = self.burial_energy.sum(axis=-1)[:, :]#self.aa_map_awsem_list]
        self._potts_model['J'] = self.contact_energy.sum(axis=0)[:, :, :, :]#self.aa_map_awsem_x, self.aa_map_awsem_y]
        # Set the gap energy to zero
        #self.potts_model['h'][:, 0] = 0
        #self.potts_model['J'][:, :, 0, :] = 0
        #self.potts_model['J'][:, :, :, 0] = 0
        self._native_energy=None # don't know what this does
class DecoyEnsemble():
    """
    Experimental class that iteratively computes AWSEM indicator functions
    and their statistics for a set of conformers (decoys), specified as an
    iterable of Structure objects.

    To keep memory use low, the indicator arrays of every decoy are written to
    disk as ``<kind>_indicators/<kind>_indicator_<i>.npy`` (relative to the
    current working directory), where ``<kind>`` is one of burial, direct,
    protein, water or electrostatics and ``<i>`` is the position of the decoy
    in the input. All statistics are then computed by streaming those files.

    This class is still under development and maybe should be moved out of
    this module.

    NOTE(review): ``covariance_matrix()`` stores its result on the instance
    under the name ``covariance_matrix`` (kept for backward compatibility), so
    after the first call the attribute is an array and the method can no
    longer be called on that instance.
    """

    # indicator families; each one maps to the directory "<kind>_indicators"
    _INDICATOR_KINDS = ("burial", "direct", "protein", "water", "electrostatics")

    def __init__(self,
                 pdb_structures: Generator[object,None,None],
                 **parameters)->None:
        """
        Generate DecoyEnsemble object

        Parameters
        ----------
        pdb_structures : Generator[object,None,None]
            yields Structure objects representing decoy structures
            (any iterable is accepted)
        other parameters:
            masks and cutoffs affecting the AWSEM class's indicator function calculations;
            they are applied to all structures in the ensemble; burial_in_context also available, but use at your own risk

        Raises
        ------
        ValueError
            If pdb_structures yields no structures.
        """
        structures = iter(pdb_structures)
        try:
            first_structure = next(structures)
        except StopIteration:
            raise ValueError("pdb_structures did not yield any structures") from None
        # the AWSEM class takes care of the indicator calculation (including masking) for us
        # AWSEM normally accepts an amino acid sequence argument, but we don't need that here
        # However, we do need to pass through parameters used to generate the indicator functions
        awsem_obj = AWSEM(first_structure, expose_indicator_functions=True, repair_pdb=True, **parameters)
        self.N = awsem_obj.N # number of residues
        for kind in self._INDICATOR_KINDS:
            os.makedirs(f"{kind}_indicators", exist_ok=True)
        self._save_indicators(awsem_obj, 0)
        # number of decoys written by this object; get_indicators uses it to
        # ignore leftover files from an earlier, larger ensemble
        self.num_decoys = 1
        # iterate over the rest of the structures without re-initializing the entire AWSEM class
        for pdb_structure in structures:
            # we can use the pdb_structure setter to update structural
            # stuff without fully re-initializing the object
            awsem_obj.pdb_structure = pdb_structure
            self._save_indicators(awsem_obj, self.num_decoys)
            self.num_decoys += 1
        # averages are needed to compute standard deviation
        # having these be attributes allows them to be easily passed
        # from the avg method to the std method
        self.avg_burial = None
        self.avg_direct = None
        self.avg_prot = None
        self.avg_wat = None
        self.avg_elec = None
        # and standard deviations can help us check our work
        # on the covariance matrix calculation
        self.std_burial = None
        self.std_direct = None
        self.std_prot = None
        self.std_wat = None
        self.std_elec = None

    @staticmethod
    def _save_indicators(awsem_obj, index):
        """Write the five indicator arrays of the current structure as decoy number `index`."""
        direct = np.asarray(awsem_obj.direct_indicator)
        electrostatics = getattr(awsem_obj, 'electrostatics_indicator', None)
        if electrostatics is None:
            # Without electrostatics we store zeros so the term has no impact on any
            # statistic. (Saving None would need pickling and could not be re-loaded.)
            # Assumes the electrostatics indicator has the shape of the direct one,
            # as covariance_matrix indexes both with the same upper-triangle indices.
            electrostatics = np.zeros(direct.shape, dtype=float)
        arrays = {
            "burial": awsem_obj.burial_indicator,
            "direct": direct,
            "protein": awsem_obj.protein_indicator,
            "water": awsem_obj.water_indicator,
            "electrostatics": electrostatics,
        }
        for kind, array in arrays.items():
            # np.save on a path overwrites the file; appending (mode 'ab') would
            # leave stale data in front, and np.load only reads the first array
            np.save(os.path.join(f"{kind}_indicators", f"{kind}_indicator_{index}.npy"), array)

    # To manage memory, we need the indicator attributes to be generators (see self.get_indicators).
    # But to ensure that we can iterate over them more than once, we need to
    # be able to reinitialize the generators. We accomplish this with properties
    @property
    def burial_indicators(self):
        return self.get_indicators("burial_indicators")
    @property
    def direct_indicators(self):
        return self.get_indicators("direct_indicators")
    @property
    def protein_indicators(self):
        return self.get_indicators("protein_indicators")
    @property
    def water_indicators(self):
        return self.get_indicators("water_indicators")
    @property
    def electrostatics_indicators(self):
        return self.get_indicators("electrostatics_indicators")

    @staticmethod
    def _decoy_order(filename):
        """Sort key placing files by their trailing decoy number (2 before 10), others last by name."""
        stem = filename[:-len(".npy")]
        suffix = stem.rsplit("_", 1)[-1]
        if suffix.isdigit():
            return (0, int(suffix), filename)
        return (1, 0, filename)

    def get_indicators(self, directory):
        """
        Yield the arrays stored in `directory`, one .npy file at a time, in decoy order.

        This allows us to process indicators without holding them all in memory;
        it requires that every method that acts on the indicators iterates over them.
        If self.num_decoys is set, only that many files are yielded.
        """
        filenames = sorted((name for name in os.listdir(directory) if name.endswith(".npy")),
                           key=self._decoy_order)
        limit = getattr(self, "num_decoys", None)
        if limit is not None:
            filenames = filenames[:limit]
        for filename in filenames:
            yield np.load(os.path.join(directory, filename))

    @staticmethod
    def _mean_over(arrays):
        """Return the element-wise mean of an iterable of equally shaped arrays."""
        total = None
        count = 0
        for array in arrays:
            # accumulate in float so integer-valued indicators can be divided afterwards
            values = np.asarray(array, dtype=float)
            if total is None:
                total = values.copy()
            else:
                total += values
            count += 1
        if count == 0:
            raise ValueError("no indicator files were found; cannot average an empty ensemble")
        return total / count

    @staticmethod
    def _std_over(arrays, mean):
        """Return the element-wise population standard deviation around a known mean."""
        squared = None
        count = 0
        for array in arrays:
            deviation = (np.asarray(array, dtype=float) - mean) ** 2
            squared = deviation if squared is None else squared + deviation
            count += 1
        if count == 0:
            raise ValueError("no indicator files were found; cannot compute a standard deviation")
        return np.sqrt(squared / count)

    @staticmethod
    def _indicator_vector(burial, pairwise, triu_indices):
        """Flatten burial plus the upper triangle of each pairwise array into one vector."""
        parts = [np.asarray(burial, dtype=float).reshape(-1)]
        # reshape(-1) rather than squeeze(): squeeze turns a single pair (N == 2)
        # into a 0-d array, which np.concatenate rejects
        parts.extend(np.asarray(matrix, dtype=float)[triu_indices].reshape(-1) for matrix in pairwise)
        return np.concatenate(parts)

    def _ensure_averages(self):
        """Compute the averages if any of them has not been computed yet."""
        averages = (self.avg_burial, self.avg_direct, self.avg_prot, self.avg_wat, self.avg_elec)
        if any(average is None for average in averages):
            self.avg()

    def avg(self):
        """
        Average each indicator function over all decoys.

        Returns
        -------
        tuple of np.ndarray
            (avg_burial, avg_direct, avg_prot, avg_wat, avg_elec); the same
            arrays are stored on the attributes of the same names.
        """
        self.avg_burial = self._mean_over(self.burial_indicators)
        self.avg_direct = self._mean_over(self.direct_indicators)
        self.avg_prot = self._mean_over(self.protein_indicators)
        self.avg_wat = self._mean_over(self.water_indicators)
        self.avg_elec = self._mean_over(self.electrostatics_indicators)
        return self.avg_burial, self.avg_direct, self.avg_prot, self.avg_wat, self.avg_elec

    def std(self):
        """
        Standard deviation of each indicator function over all decoys.

        ** averaging these DOES NOT equal the variance over all structures
        of the sum of the indicator functions of each structure **

        Returns
        -------
        tuple of np.ndarray
            (std_burial, std_direct, std_prot, std_wat, std_elec); the same
            arrays are stored on the attributes of the same names.
        """
        self._ensure_averages()
        self.std_burial = self._std_over(self.burial_indicators, self.avg_burial)
        self.std_direct = self._std_over(self.direct_indicators, self.avg_direct)
        self.std_prot = self._std_over(self.protein_indicators, self.avg_prot)
        self.std_wat = self._std_over(self.water_indicators, self.avg_wat)
        self.std_elec = self._std_over(self.electrostatics_indicators, self.avg_elec)
        return self.std_burial, self.std_direct, self.std_prot, self.std_wat, self.std_elec

    def covariance_matrix(self):
        """
        Covariance matrix (E[xy] - E[x]E[y]) of all indicators over the decoys.

        The indicator vector of a decoy is every burial indicator followed by
        the unique (upper-triangle) direct, protein, water and electrostatics
        indicators, so the matrix is square with side 3N + 4(N^2-N)/2.

        The result is returned and also stored on self.covariance_matrix,
        which replaces this method on the instance (see the class docstring).
        """
        self._ensure_averages()
        triu_indices = np.triu_indices(self.N, k=1)
        all_avg = self._indicator_vector(
            self.avg_burial,
            (self.avg_direct, self.avg_prot, self.avg_wat, self.avg_elec),
            triu_indices)
        ex_ey = np.outer(all_avg, all_avg)
        # compute covariances
        number_indicators = 3*self.N + 4*((self.N**2 - self.N)//2)
        assert ex_ey.shape == (number_indicators, number_indicators), f"ex_ey.shape: {ex_ey.shape}, number_indicators: {number_indicators}"
        exy = np.zeros((number_indicators, number_indicators))
        num_decoys = 0
        # we want all the burial indicators,
        # but only the unique pairwise indicators (no need to double count)
        for b, d, p, w, e in zip(self.burial_indicators, self.direct_indicators,
                                 self.protein_indicators, self.water_indicators,
                                 self.electrostatics_indicators):
            all_decoy = self._indicator_vector(b, (d, p, w, e), triu_indices)
            exy += np.outer(all_decoy, all_decoy)
            num_decoys += 1
        if num_decoys == 0:
            raise ValueError("no indicator files were found; cannot compute a covariance matrix")
        exy /= num_decoys
        covariance_matrix = exy - ex_ey
        # sanity checks on our own arithmetic
        assert np.all(covariance_matrix==covariance_matrix.T)
        assert covariance_matrix.shape == exy.shape == ex_ey.shape
        self.covariance_matrix = covariance_matrix
        return covariance_matrix

    def all_decoy_indicators(self):
        """
        Return five lists (burial, direct, protein, water, electrostatics)
        holding the indicator arrays of every decoy, in decoy order.

        Memory scales with the size of the structure and of the decoy set.
        """
        columns = ([], [], [], [], [])
        for decoy in zip(self.burial_indicators, self.direct_indicators,
                         self.protein_indicators, self.water_indicators,
                         self.electrostatics_indicators):
            for column, array in zip(columns, decoy):
                column.append(array)
        return columns
self.potts_model["familycouplings"].reshape(int(len(self.filtered_aligned_sequence)),21,int(len(self.filtered_aligned_sequence)),21).transpose(0,2,1,3) if self.filtered_aligned_sequence is not None: - self.aa_freq = frustration.compute_aa_freq(self.sequence) - self.contact_freq = frustration.compute_contact_freq(self.sequence) + self.aa_freq = frustration.compute_aa_freq(self.sequence, self.alphabet) + self.contact_freq = frustration.compute_contact_freq(self.sequence, self.alphabet) else: self.aa_freq = None self.contact_freq = None @@ -222,8 +224,8 @@ def from_pottsmodel(cls,pdb_structure : object, self.potts_model["J"]= self.potts_model["familycouplings"].reshape(int(len(self.filtered_aligned_sequence)),21,int(len(self.filtered_aligned_sequence)),21).transpose(0,2,1,3) if self.filtered_aligned_sequence is not None: - self.aa_freq = frustration.compute_aa_freq(self.sequence) - self.contact_freq = frustration.compute_contact_freq(self.sequence) + self.aa_freq = frustration.compute_aa_freq(self.sequence, self.alphabet) + self.contact_freq = frustration.compute_contact_freq(self.sequence, self.alphabet) else: self.aa_freq = None self.contact_freq = None diff --git a/frustratometer/classes/Frustratometer.py b/frustratometer/classes/Frustratometer.py index d8cceec1..ee608706 100644 --- a/frustratometer/classes/Frustratometer.py +++ b/frustratometer/classes/Frustratometer.py @@ -61,9 +61,27 @@ def native_energy(self,sequence:str = None,ignore_couplings_of_gaps:bool=False,i if sequence is None: sequence=self.sequence else: - return frustration.compute_native_energy(sequence, self.potts_model, self.mask,ignore_couplings_of_gaps,ignore_fields_of_gaps) + return frustration.compute_native_energy(sequence, self.potts_model, self.mask, self.alphabet, + ignore_couplings_of_gaps, ignore_fields_of_gaps) if not self._native_energy: - self._native_energy=frustration.compute_native_energy(sequence, self.potts_model, self.mask,ignore_couplings_of_gaps,ignore_fields_of_gaps) + 
self._native_energy=frustration.compute_native_energy(sequence, self.potts_model, self.mask, self.alphabet, + ignore_couplings_of_gaps, ignore_fields_of_gaps) + else: + new = frustration.compute_native_energy( + sequence, self.potts_model, self.mask, self.alphabet, + ignore_couplings_of_gaps, ignore_fields_of_gaps) + if not (self._native_energy == new): + raise AssertionError(f""" + It seems that you have changed parameters of an object such that + the native energy of your system is now different from what it was + originally computed to be. Our code probably should prevent this + from happening, but you can prevent it too by not changing the alphabet + or any other parameters after initializing your DCA or AWSEM-family + class (anything that inherits from _AWSEMBase). + + Previous value of {self.__class__}._native_energy: {self._native_energy} + New value of {self.__class__}._native_energy: {new}""") + energy_value=self._native_energy return energy_value @@ -89,7 +107,7 @@ def sequences_energies(self, sequences:np.array, split_couplings_and_fields:bool output (if split_couplings_and_fields==True): np.array Array containing computed fields and couplings energies of the protein sequences. 
""" - output=frustration.compute_sequences_energy(sequences, self.potts_model, self.mask, split_couplings_and_fields) + output=frustration.compute_sequences_energy(sequences, self.potts_model, self.mask, self.alphabet, split_couplings_and_fields) return output def fields_energy(self, sequence:str = None, ignore_fields_of_gaps:bool = False) -> float: @@ -114,7 +132,7 @@ def fields_energy(self, sequence:str = None, ignore_fields_of_gaps:bool = False) """ if sequence is None: sequence=self.sequence - fields_energy=frustration.compute_fields_energy(sequence, self.potts_model,ignore_fields_of_gaps) + fields_energy=frustration.compute_fields_energy(sequence, self.potts_model, self.alphabet, ignore_fields_of_gaps) return fields_energy def couplings_energy(self, sequence:str = None,ignore_couplings_of_gaps:bool = False) -> float: @@ -139,7 +157,7 @@ def couplings_energy(self, sequence:str = None,ignore_couplings_of_gaps:bool = F """ if sequence is None: sequence=self.sequence - couplings_energy=frustration.compute_couplings_energy(sequence, self.potts_model, self.mask,ignore_couplings_of_gaps) + couplings_energy=frustration.compute_couplings_energy(sequence, self.potts_model, self.mask, self.alphabet, ignore_couplings_of_gaps) return couplings_energy def decoy_fluctuation(self, sequence:str = None,kind:str = 'singleresidue',mask:np.array = None) -> np.array: @@ -167,13 +185,13 @@ def decoy_fluctuation(self, sequence:str = None,kind:str = 'singleresidue',mask: if not isinstance(mask, np.ndarray): mask=self.mask if kind == 'singleresidue': - fluctuation = frustration.compute_singleresidue_decoy_energy_fluctuation(sequence, self.potts_model, mask) + fluctuation = frustration.compute_singleresidue_decoy_energy_fluctuation(sequence, self.potts_model, mask, self.alphabet) elif kind == 'mutational': - fluctuation = frustration.compute_mutational_decoy_energy_fluctuation(sequence, self.potts_model, mask) + fluctuation = 
frustration.compute_mutational_decoy_energy_fluctuation(sequence, self.potts_model, mask, self.alphabet) elif kind == 'configurational': - fluctuation = frustration.compute_configurational_decoy_energy_fluctuation(sequence, self.potts_model, mask) + fluctuation = frustration.compute_configurational_decoy_energy_fluctuation(sequence, self.potts_model, mask, self.alphabet) elif kind == 'contact': - fluctuation = frustration.compute_contact_decoy_energy_fluctuation(sequence, self.potts_model, mask) + fluctuation = frustration.compute_contact_decoy_energy_fluctuation(sequence, self.potts_model, mask, self.alphabet) else: raise Exception("Wrong kind of decoy generation selected") self._decoy_fluctuation[kind] = fluctuation @@ -211,7 +229,8 @@ def scores(self): """ return frustration.compute_scores(self.potts_model) - def frustration(self, sequence:str = None, kind:str = 'singleresidue', mask:np.array = None, aa_freq:np.array = None, correction:int = 0) -> np.array: + def frustration(self, sequence:str = None, kind:str = 'singleresidue', mask:np.array = None, aa_freq:np.array = None, + correction:int = 0, n_decoys:int = 4000) -> np.array: """ Calculates frustration index values. 
@@ -242,9 +261,11 @@ def frustration(self, sequence:str = None, kind:str = 'singleresidue', mask:np.a frustration_values=frustration.compute_single_frustration(decoy_fluctuation, aa_freq, correction) return frustration_values elif kind in ['mutational', 'configurational', 'contact']: - if kind == 'configurational' and 'configurational_frustration' in dir(self): - #TODO: Correct this function for different aa_freq than WT - return self.configurational_frustration(None, correction) + if kind == 'configurational': + if 'configurational_frustration' in dir(self): + return self.configurational_frustration(aa_freq=aa_freq, correction=correction, n_decoys=n_decoys) + else: + raise ValueError("kind='configurational' may only be used on objects implementing self.configurational_frustration") if aa_freq is None: aa_freq = self.contact_freq frustration_values=frustration.compute_pair_frustration(decoy_fluctuation, aa_freq, correction) @@ -268,7 +289,7 @@ def plot_decoy_energy(self, sequence:str = None, kind:str = 'singleresidue', met native_energy = self.native_energy(sequence=sequence) decoy_energy = self.decoy_energy(kind=kind,sequence=sequence) if kind == 'singleresidue': - g = frustration.plot_singleresidue_decoy_energy(decoy_energy, native_energy, method) + g = frustration.plot_singleresidue_decoy_energy(decoy_energy, native_energy, method, self.alphabet) return g def roc(self): @@ -292,6 +313,7 @@ def auc(self): return frustration.compute_auc(self.roc()) def vmd(self, sequence: str = None, single:Union[str,np.array] = 'singleresidue', pair:Union[str,np.array] = 'mutational', + tcl_script:str = 'frustration.tcl', call_vmd:bool=True, aa_freq:np.array = None, correction:int = 0, max_connections:Union[int,None] = None, movie_name=None, still_image_name=None): """ Calculates frustration indices and superimposes frustration patterns onto PDB structure using the VMD software. 
@@ -317,12 +339,14 @@ def vmd(self, sequence: str = None, single:Union[str,np.array] = 'singleresidue' from the sequence that was passed to this vmd function. Proceeding further may not\n\ perform the computation that you intend to perform.") - + #breakpoint() tcl_script = frustration.write_tcl_script(self.pdb_file, self.chain, self.mask, self.distance_matrix, self.distance_cutoff, -self.frustration(kind=single, sequence=sequence, aa_freq=aa_freq), -self.frustration(kind=pair, sequence=sequence, aa_freq=aa_freq), - max_connections=max_connections, movie_name=movie_name, still_image_name=still_image_name) - frustration.call_vmd(self.pdb_file, tcl_script) + max_connections=max_connections, movie_name=movie_name, still_image_name=still_image_name, + tcl_script=tcl_script,) + if call_vmd: + frustration.call_vmd(self.pdb_file, tcl_script) def view_pair_frustration(self, sequence:str = None, pair:str = 'mutational', aa_freq:np.array = None): """ diff --git a/frustratometer/classes/Gamma.py b/frustratometer/classes/Gamma.py index 821bffea..d7103f04 100644 --- a/frustratometer/classes/Gamma.py +++ b/frustratometer/classes/Gamma.py @@ -34,6 +34,10 @@ def __init__(self, data, segment_definition=None, description=None, alphabet=Non self._validate_segments() + @property + def q(self): + return len(self.alphabet) + def _init_from_array(self, gamma_array): self.gamma_array = gamma_array @@ -399,7 +403,7 @@ def correlate_segments(self, other): return correlations # Plotting - def plot_gamma(self, new_order=None): + def plot_gamma(self, new_order=None, scale=[-5,5]): import matplotlib.pyplot as plt import seaborn as sns if new_order: @@ -408,16 +412,21 @@ def plot_gamma(self, new_order=None): # Plot setup f, axes = plt.subplots(2, 2, figsize=(18, 16)) - titles = ['Burial Gammas', 'Direct Gammas', 'Water Gammas', 'Protein Gammas'] - + f.subplots_adjust(hspace=50) # fix overlap between axis ticks of upper subplots and titles of lower subplots + titles = ['Burial Gammas', 'Direct 
Gammas', 'Protein Gammas', 'Water Gammas'] for i, (title, name) in enumerate(zip(titles, segments)): ax = axes[i // 2, i % 2] - sns.heatmap(segments[name].reshape(-1, 20), ax=ax, cmap='RdBu_r', center=0) + foo = sns.heatmap(segments[name].reshape(-1, 20), ax=ax, cmap='RdBu_r', center=0, vmin=scale[0], vmax=scale[1]) + foo.collections[0].colorbar.ax.tick_params(labelsize=16) ax.set_title(title) ax.set_xticks(np.arange(len(self.alphabet)) + 0.5) - ax.set_xticklabels(self.alphabet) - ax.set_yticks(np.arange(segments[name].shape[0] // 20) + 0.5) - ax.set_yticklabels(range(segments[name].shape[0] // 20)) + ax.set_xticklabels(self.alphabet, size=16) + if i==0: # burial + ax.set_yticks([0.5,1.5,2.5]) + ax.set_yticklabels(['low','medium','high'], rotation=45, size=16) + else: # direct, prot, or wat + ax.set_yticks(np.arange(len(self.alphabet)) + 0.5) + ax.set_yticklabels(self.alphabet, rotation=0, fontsize=16) plt.tight_layout() plt.show() @@ -648,4 +657,4 @@ class O(): self.gamma1 = Gamma(np.arange(0,1260,1)) self.gamma2 = Gamma(np.arange(0,1260,1)*5+10) - self.gamma3 = Gamma(np.arange(1260,0,-1)*2-4) \ No newline at end of file + self.gamma3 = Gamma(np.arange(1260,0,-1)*2-4) diff --git a/frustratometer/classes/Structure.py b/frustratometer/classes/Structure.py index cf71c2fe..1021e4ea 100644 --- a/frustratometer/classes/Structure.py +++ b/frustratometer/classes/Structure.py @@ -14,7 +14,7 @@ class Structure: def __init__(self, pdb_file: Union[Path,str], chain: Union[str,None]=None, seq_selection: str = None, aligned_sequence: str = None, filtered_aligned_sequence: str = None, - distance_matrix_method:str = 'CB', pdb_directory: Path = Path.cwd(), repair_pdb:bool = True)->object: + distance_matrix_method:str = 'CB', pdb_directory: Path = Path.cwd(), repair_pdb:bool = True, return_distance_midpoints:bool = False)->object: """ Generates structure object. Both PDB and CIF format files are accepted as input. 
@@ -55,6 +55,11 @@ def __init__(self, pdb_file: Union[Path,str], chain: Union[str,None]=None, seq_s If True, provided pdb file will be repaired with missing residues inserted and heteroatoms removed. Note that a pdb file will be produced, regardless of input file format. + return_distance_midpoints: bool + Whether to return a matrix of the same shape as distance_matrix representing the same contacts as distance_matrix + that indicates the absolute coordinates of the midpoint between the pair of atoms. This helps us compute the pair distribution + functions of the different classes of contacts. So this matrix isn't really a matrix because each "element" has 3 channels: x, y, and z + Returns ------- Structure object @@ -77,7 +82,7 @@ def __init__(self, pdb_file: Union[Path,str], chain: Union[str,None]=None, seq_s self.pdbID=pdb_file.stem self.pdb_file=pdb_file - self.chain=chain + self.chain=chain # will be None if no chain supplied self.distance_matrix_method=distance_matrix_method self.filtered_aligned_sequence=filtered_aligned_sequence self.aligned_sequence=aligned_sequence @@ -88,18 +93,18 @@ def __init__(self, pdb_file: Union[Path,str], chain: Union[str,None]=None, seq_s self.init_index_shift=0 if repair_pdb: - fixer=pdb.repair_pdb(pdb_file, chain, pdb_directory) + fixer=pdb.repair_pdb(pdb_file, chain, pdb_directory) # for this function, chain can be str or list (or None) self.pdb_file=str(pdb_directory/f"{self.pdbID}_cleaned.pdb") if ".pdb" in str(pdb_file) or repair_pdb==True: - self.structure = prody.parsePDB(str(self.pdb_file), chain=self.chain).select(f"protein") + self.structure = prody.parsePDB(str(self.pdb_file), chain=self.chain).select(f"protein") # for this function, chain should be a string containing the chain ids like "AB" or "A B" else: - self.structure=prody.parseMMCIF(str(self.pdb_file),chain=self.chain).select(f"protein") + self.structure=prody.parseMMCIF(str(self.pdb_file),chain=self.chain).select(f"protein") # for this function, chain should 
be a string containing the chain ids like "AB" or "A B" else: assert len(self.seq_selection.replace("to"," to ").replace(":"," : ").split())>=4, "Please correctly input your residue selection" if self.chain==None: - raise ValueError("Please provide a chain name") + raise ValueError("self.chain==None. Please provide chain name(s)") self.init_index=int(self.seq_selection.replace("to"," to ").replace(":"," : ").split()[1].replace("`","")) self.fin_index=int(self.seq_selection.replace("to"," to ").replace(":"," : ").split()[3].replace("`","")) @@ -116,7 +121,7 @@ def __init__(self, pdb_file: Union[Path,str], chain: Union[str,None]=None, seq_s with open(pdb_file,"r") as f: for line in f: - if line.split()[0]=="ATOM" and line.split()[4+shift]==self.chain: + if line.split()[0]=="ATOM" and (line.split()[4+shift] in self.chain): try: res_index=''.join(i for i in line.split()[5+index_shift] if i.isdigit()) next_res_index=''.join(i for i in next(f).split()[5+index_shift] if i.isdigit()) @@ -133,27 +138,35 @@ def __init__(self, pdb_file: Union[Path,str], chain: Union[str,None]=None, seq_s self.init_index_shift=self.init_index-self.pdb_init_index self.fin_index_shift=self.fin_index-self.pdb_init_index+1 if repair_pdb: - fixer=pdb.repair_pdb(pdb_file, chain, pdb_directory) + fixer=pdb.repair_pdb(pdb_file, chain, pdb_directory) # for this function, chain can be str or list (or None) self.pdb_file=f"{pdb_directory}/{self.pdbID}_cleaned.pdb" self.select_gap_indices=[i for i in gap_indices if self.init_index<=i<=self.fin_index] self.fin_index_shift-=len(self.select_gap_indices) - self.seq_selection=f"resnum `{self.init_index_shift+1}to{self.fin_index_shift}`" + #self.seq_selection=f"resnum `{self.init_index_shift+1}to{self.fin_index_shift}`" # WE'RE KEEPING IDs NOW SO DON'T WANT TO DO THIS!!!! 
elif "resindex" in self.seq_selection: self.init_index_shift=self.init_index self.fin_index_shift=self.fin_index+1 if repair_pdb: - fixer=pdb.repair_pdb(pdb_file, chain, pdb_directory) + fixer=pdb.repair_pdb(pdb_file, chain, pdb_directory) # for this function, chain can be str or list (or None) self.pdb_file=f"{pdb_directory}/{self.pdbID}_cleaned.pdb" - self.chain="A" + # self.chain="A" I don't know why we would want to change the chain ID here if ".pdb" in str(pdb_file) or repair_pdb==True: self.structure = prody.parsePDB(str(self.pdb_file), chain=self.chain).select(f"protein and {self.seq_selection}") else: self.structure=prody.parseMMCIF(str(self.pdb_file),chain=self.chain).select(f"protein and {self.seq_selection}") - self.sequence=pdb.get_sequence(self.pdb_file,self.chain) - self.distance_matrix=pdb.get_distance_matrix(pdb_file=self.pdb_file,chain=self.chain, - method=self.distance_matrix_method) + self.sequence, self.start_mask = pdb.get_sequence(self.pdb_file,self.chain,return_start_mask=True) # this function can now accept chain as list or string + if return_distance_midpoints: + self.distance_matrix, self.midpoint_matrix = pdb.get_distance_matrix(pdb_file=self.pdb_file,chain=self.chain, # for this function, chain should be a string containing the chain ids + method=self.distance_matrix_method, # separated by a space, like "A B" + return_distance_midpoints=True) + else: + self.distance_matrix=pdb.get_distance_matrix(pdb_file=self.pdb_file,chain=self.chain, # for this function, chain should be a string containing the chain ids + method=self.distance_matrix_method, # separated by a space, like "A B" + return_distance_midpoints=False) + self.midpoint_matrix = None + self.full_pdb_distance_matrix=self.distance_matrix self.z_coordinates=self.structure.select('((name CB) or (resname GLY and name CA))').getCoords() @@ -180,7 +193,7 @@ def __init__(self, pdb_file: Union[Path,str], chain: Union[str,None]=None, seq_s else: 
self.full_to_aligned_index_dict=dict(zip(range(self.init_index_shift,self.fin_index_shift+1), range(len(self.sequence)))) self.mapped_distance_matrix=self.distance_matrix - + @classmethod def full_pdb(cls,pdb_file: Union[Path,str], chain: Union[str,None]=None, aligned_sequence: str = None, filtered_aligned_sequence: str = None, distance_matrix_method:str = 'CB', pdb_directory: Path = Path.cwd(), repair_pdb:bool = True): diff --git a/frustratometer/classes/__init__.py b/frustratometer/classes/__init__.py index 6621727b..d469d8c0 100644 --- a/frustratometer/classes/__init__.py +++ b/frustratometer/classes/__init__.py @@ -7,7 +7,7 @@ """ from .DCA import DCA -from .AWSEM import AWSEM +from .AWSEM import AWSEM, AWSEMIndicators, DecoyEnsemble, AWSEMVariancePotts from .Structure import Structure from .Map import Map from .Gamma import Gamma diff --git a/frustratometer/frustration/__init__.py b/frustratometer/frustration/__init__.py index 7d6d920c..fed7a071 100644 --- a/frustratometer/frustration/__init__.py +++ b/frustratometer/frustration/__init__.py @@ -7,6 +7,8 @@ """ from .frustration import * +#from .numba_hamiltonian import * + __all__ = ['compute_mask', 'compute_native_energy', 'compute_fields_energy', 'compute_couplings_energy', 'compute_sequences_energy', 'compute_singleresidue_decoy_energy_fluctuation', @@ -14,4 +16,4 @@ 'compute_contact_decoy_energy_fluctuation', 'compute_decoy_energy', 'compute_aa_freq', 'compute_contact_freq', 'compute_single_frustration', 'compute_pair_frustration', 'compute_scores', 'compute_roc', 'compute_auc', 'plot_roc', 'plot_singleresidue_decoy_energy', 'write_tcl_script', - 'call_vmd', 'canvas'] + 'call_vmd', 'canvas', 'ham'] diff --git a/frustratometer/frustration/frustration.py b/frustratometer/frustration/frustration.py index 6f2b9baa..93374813 100644 --- a/frustratometer/frustration/frustration.py +++ b/frustratometer/frustration/frustration.py @@ -4,8 +4,6 @@ from typing import Union from pathlib import Path -_AA = 
'-ACDEFGHIKLMNPQRSTVWY' - def compute_mask(distance_matrix: np.array, maximum_contact_distance: Union[float, None] = None, minimum_sequence_separation: Union[int, None] = None) -> np.array: @@ -58,6 +56,7 @@ def compute_mask(distance_matrix: np.array, def compute_native_energy(seq: str, potts_model: dict, mask: np.array, + AA : str, ignore_gap_couplings: bool = False, ignore_gap_fields: bool = False) -> float: @@ -107,7 +106,7 @@ def compute_native_energy(seq: str, .. todo:: Optimize the computation. """ - seq_index = np.array([_AA.find(aa) for aa in seq]) + seq_index = np.array([AA.index(aa) for aa in seq]) seq_len = len(seq_index) pos1, pos2 = np.meshgrid(np.arange(seq_len), np.arange(seq_len), indexing='ij', sparse=True) @@ -133,6 +132,7 @@ def compute_native_energy(seq: str, def compute_fields_energy(seq: str, potts_model: dict, + AA : str, ignore_fields_of_gaps: bool = False) -> float: """ Computes the fields energy of a protein sequence based on a given Potts model. @@ -165,7 +165,7 @@ def compute_fields_energy(seq: str, >>> fields_energy = compute_fields_energy(seq, potts_model) >>> print(f"Computed fields energy: {fields_energy:.2f}") """ - seq_index = np.array([_AA.find(aa) for aa in seq]) + seq_index = np.array([AA.index(aa) for aa in seq]) seq_len = len(seq_index) h = -potts_model['h'][range(seq_len), seq_index] @@ -180,6 +180,7 @@ def compute_fields_energy(seq: str, def compute_couplings_energy(seq: str, potts_model: dict, mask: np.array, + AA : str, ignore_couplings_of_gaps: bool = False) -> float: """ Computes the couplings energy of a protein sequence based on a given Potts model and an interaction mask. @@ -223,7 +224,7 @@ def compute_couplings_energy(seq: str, .. todo:: Optimize the computation. 
""" - seq_index = np.array([_AA.find(aa) for aa in seq]) + seq_index = np.array([AA.index(aa) for aa in seq]) seq_len = len(seq_index) pos1, pos2 = np.meshgrid(np.arange(seq_len), np.arange(seq_len), indexing='ij', sparse=True) aa1, aa2 = np.meshgrid(seq_index, seq_index, indexing='ij', sparse=True) @@ -241,6 +242,7 @@ def compute_couplings_energy(seq: str, def compute_sequences_energy(seqs: list, potts_model: dict, mask: np.array, + AA : str, split_couplings_and_fields = False) -> np.array: """ Computes the energy of multiple protein sequences based on a given Potts model and an interaction mask. @@ -288,7 +290,7 @@ def compute_sequences_energy(seqs: list, .. todo:: Optimize the computation. """ - seq_index = np.array([[_AA.find(aa) for aa in seq] for seq in seqs]) + seq_index = np.array([[AA.index(aa) for aa in seq] for seq in seqs]) N_seqs, seq_len = seq_index.shape pos_index=np.repeat([np.arange(seq_len)], N_seqs,axis=0) @@ -312,7 +314,8 @@ def compute_sequences_energy(seqs: list, def compute_singleresidue_decoy_energy_fluctuation(seq: str, potts_model: dict, - mask: np.array) -> np.array: + mask: np.array, + AA : str) -> np.array: """ Computes a (Lx21) matrix for a sequence of length L. Row i contains all possible changes in energy upon mutating residue i. @@ -325,14 +328,14 @@ def compute_singleresidue_decoy_energy_fluctuation(seq: str, seq : str The amino acid sequence of the protein. The sequence is assumed to be in one-letter code. Gaps are represented as '-'. The length of the sequence (L) should match the dimensions of the Potts model. potts_model : dict - A dictionary containing the Potts model parameters 'h' (fields) and 'J' (couplings). The fields are a 2D array of shape (L, 20), where L is the length of the sequence and 20 is the number of amino acids. The couplings are a 4D array of shape (L, L, 20, 20). The fields and couplings are assumed to be in units of energy. 
+ A dictionary containing the Potts model parameters 'h' (fields) and 'J' (couplings). The fields are a 2D array of shape (L, q), where L is the length of the sequence and q is the number of amino acids. The couplings are a 4D array of shape (L, L, q, q). The fields and couplings are assumed to be in units of energy. mask : np.array A 2D Boolean array that determines which residue pairs should be considered in the energy computation. The mask should have dimensions (L, L), where L is the length of the sequence. Returns ------- decoy_energy: np.array - (Lx21) matrix describing the energetic changes upon mutating a single residue. + (Lxq) matrix describing the energetic changes upon mutating a single residue. Examples -------- @@ -354,16 +357,18 @@ def compute_singleresidue_decoy_energy_fluctuation(seq: str, .. todo:: Optimize the computation. """ - seq_index = np.array([_AA.find(aa) for aa in seq]) + q = len(AA) + seq_index = np.array([AA.index(aa) for aa in seq]) seq_len = len(seq_index) # Create decoys - pos1, aa1 = np.meshgrid(np.arange(seq_len), np.arange(21), indexing='ij', sparse=True) + pos1, aa1 = np.meshgrid(np.arange(seq_len), np.arange(q), indexing='ij', sparse=True) - decoy_energy = np.zeros([seq_len, 21]) + decoy_energy = np.zeros([seq_len, q]) + # potts_model['h'][pos1, aa1] == potts_model['h'] decoy_energy -= (potts_model['h'][pos1, aa1] - potts_model['h'][pos1, seq_index[pos1]]) # h correction aa1 - j_correction = np.zeros([seq_len, seq_len, 21]) + j_correction = np.zeros([seq_len, seq_len, q]) # J correction interactions with other aminoacids reduced_j = potts_model['J'][range(seq_len), :, seq_index, :].astype(np.float32) j_correction += reduced_j[:, pos1, seq_index[pos1]] * mask[:, pos1] @@ -377,9 +382,11 @@ def compute_singleresidue_decoy_energy_fluctuation(seq: str, def compute_mutational_decoy_energy_fluctuation(seq: str, potts_model: dict, - mask: np.array, ) -> np.array: + mask: np.array, + AA : str) -> np.array: """ - Computes a (LxLx21x21) 
matrix for a sequence of length L. Matrix[i,j] describes all possible changes in energy upon mutating residue i and j simultaneously. + Computes a (LxLxqxq) matrix for a sequence of length L and AA of length q. + Matrix[i,j] describes all possible changes in energy upon mutating residue i and j simultaneously. .. math:: \Delta H_{ij} = H_i - H_{i'} + H_{j}-H_{j'} + J_{ij} -J_{ij'} + J_{i'j'} - J_{i'j} + \\sum_k {J_{ik} - J_{i'k} + J_{jk} -J_{j'k}} @@ -389,14 +396,14 @@ def compute_mutational_decoy_energy_fluctuation(seq: str, seq : str The amino acid sequence of the protein. The sequence is assumed to be in one-letter code. Gaps are represented as '-'. The length of the sequence (L) should match the dimensions of the Potts model. potts_model : dict - A dictionary containing the Potts model parameters 'h' (fields) and 'J' (couplings). The fields are a 2D array of shape (L, 20), where L is the length of the sequence and 20 is the number of amino acids. The couplings are a 4D array of shape (L, L, 20, 20). The fields and couplings are assumed to be in units of energy. + A dictionary containing the Potts model parameters 'h' (fields) and 'J' (couplings). The fields are a 2D array of shape (L, q), where L is the length of the sequence and q is the number of amino acids. The couplings are a 4D array of shape (L, L, q, q). The fields and couplings are assumed to be in units of energy. mask : np.array A 2D Boolean array that determines which residue pairs should be considered in the energy computation. The mask should have dimensions (L, L), where L is the length of the sequence. Returns ------- decoy_energy2: np.array - (LxLx21x21) matrix describing the energetic changes upon mutating two residues simultaneously. + (LxLxqxq) matrix describing the energetic changes upon mutating two residues simultaneously. Examples -------- @@ -418,23 +425,23 @@ def compute_mutational_decoy_energy_fluctuation(seq: str, .. todo:: Optimize the computation. 
""" - seq_index = np.array([_AA.find(aa) for aa in seq]) + q = len(AA) + seq_index = np.array([AA.index(aa) for aa in seq]) seq_len = len(seq_index) - # Create masked decoys - pos1,pos2=np.where(mask>0) + # get indices and amino acid types for just the unmasked contacts + pos1,pos2=np.where(mask>0) contacts_len=len(pos1) - - pos1,aa1,aa2=np.meshgrid(pos1, np.arange(21), np.arange(21), indexing='ij', sparse=True) - pos2,aa1,aa2=np.meshgrid(pos2, np.arange(21), np.arange(21), indexing='ij', sparse=True) + pos1,aa1,aa2=np.meshgrid(pos1, np.arange(q), np.arange(q), indexing='ij', sparse=True) + pos2,aa1,aa2=np.meshgrid(pos2, np.arange(q), np.arange(q), indexing='ij', sparse=True) #Compute fields - decoy_energy = np.zeros([contacts_len, 21, 21]) + decoy_energy = np.zeros([contacts_len, q, q]) decoy_energy -= (potts_model['h'][pos1, aa1] - potts_model['h'][pos1, seq_index[pos1]]) # h correction aa1 decoy_energy -= (potts_model['h'][pos2, aa2] - potts_model['h'][pos2, seq_index[pos2]]) # h correction aa2 #Compute couplings - j_correction = np.zeros([contacts_len, 21, 21]) + j_correction = np.zeros([contacts_len, q, q]) for pos, aa in enumerate(seq_index): # J correction interactions with other aminoacids reduced_j = potts_model['J'][pos, :, aa, :].astype(np.float32) @@ -449,14 +456,15 @@ def compute_mutational_decoy_energy_fluctuation(seq: str, j_correction -= potts_model['J'][pos1, pos2, aa1, aa2] * mask[pos1, pos2] # Correct combination decoy_energy += j_correction - decoy_energy2=np.zeros([seq_len,seq_len,21,21]) + decoy_energy2=np.zeros([seq_len,seq_len,q,q]) decoy_energy2[mask]=decoy_energy return decoy_energy2 def compute_configurational_decoy_energy_fluctuation(seq: str, potts_model: dict, - mask: np.array, ) -> np.array: + mask: np.array, + AA : str) -> np.array: """ Computes a (LxLx21x21) matrix for a sequence of length L. Matrix[i,j] describes all possible changes in energy upon mutating and altering the local densities of residue i and j simultaneously. 
@@ -498,7 +506,7 @@ def compute_configurational_decoy_energy_fluctuation(seq: str, .. todo:: Optimize the computation. """ - seq_index = np.array([_AA.find(aa) for aa in seq]) + seq_index = np.array([AA.index(aa) for aa in seq]) seq_len = len(seq_index) # Create masked decoys @@ -536,7 +544,8 @@ def compute_configurational_decoy_energy_fluctuation(seq: str, def compute_contact_decoy_energy_fluctuation(seq: str, potts_model: dict, - mask: np.array) -> np.array: + mask: np.array, + AA : str) -> np.array: r""" $$ \Delta DCA_{ij} = \Delta j_{ij} $$ :param seq: @@ -545,7 +554,7 @@ def compute_contact_decoy_energy_fluctuation(seq: str, :return: """ - seq_index = np.array([_AA.find(aa) for aa in seq]) + seq_index = np.array([AA.index(aa) for aa in seq]) seq_len = len(seq_index) # Create decoys @@ -559,7 +568,7 @@ def compute_contact_decoy_energy_fluctuation(seq: str, return decoy_energy -def compute_decoy_energy(seq: str, potts_model: dict, mask: np.array, kind='singleresidue') -> np.array: +def compute_decoy_energy(seq: str, potts_model: dict, mask: np.array, AA : str, kind='singleresidue') -> np.array: """ Computes all possible decoy energies. @@ -600,18 +609,18 @@ def compute_decoy_energy(seq: str, potts_model: dict, mask: np.array, kind='sing .. todo:: Optimize the computation. 
""" - native_energy = compute_native_energy(seq, potts_model, mask) + native_energy = compute_native_energy(seq, potts_model, mask, AA) if kind == 'singleresidue': - decoy_energy=native_energy + compute_singleresidue_decoy_energy_fluctuation(seq, potts_model, mask) + decoy_energy=native_energy + compute_singleresidue_decoy_energy_fluctuation(seq, potts_model, mask, AA) elif kind == 'mutational': - decoy_energy=native_energy + compute_mutational_decoy_energy_fluctuation(seq, potts_model, mask) + decoy_energy=native_energy + compute_mutational_decoy_energy_fluctuation(seq, potts_model, mask, AA) elif kind == 'configurational': - decoy_energy=native_energy + compute_configurational_decoy_energy_fluctuation(seq, potts_model, mask) + decoy_energy=native_energy + compute_configurational_decoy_energy_fluctuation(seq, potts_model, mask, AA) elif kind == 'contact': - decoy_energy=native_energy + compute_contact_decoy_energy_fluctuation(seq, potts_model, mask) + decoy_energy=native_energy + compute_contact_decoy_energy_fluctuation(seq, potts_model, mask, AA) return decoy_energy -def compute_aa_freq(seq, include_gaps=True): +def compute_aa_freq(seq, AA, include_gaps=True,): """ Calculates amino acid frequencies in given sequence @@ -619,6 +628,8 @@ def compute_aa_freq(seq, include_gaps=True): ---------- seq : str The amino acid sequence of the protein. The sequence is assumed to be in one-letter code. Gaps are represented as '-'. + AA : str + The alphabet of allowed residues include_gaps: bool If True, frequencies of gaps ('-') in the sequence are set to 0. Default is True. 
@@ -627,16 +638,17 @@ def compute_aa_freq(seq, include_gaps=True): Returns ------- aa_freq: np.array - Array of frequencies of all 21 possible amino acids within sequence + Array of frequencies of all q possible amino acids within sequence """ - seq_index = np.array([_AA.find(aa) for aa in seq]) - aa_freq = np.array([(seq_index == i).sum() for i in range(21)]) + q = len(AA) + seq_index = np.array([AA.index(aa) for aa in seq]) + aa_freq = np.array([(seq_index == i).sum() for i in range(q)]) if not include_gaps: aa_freq[0] = 0 return aa_freq -def compute_contact_freq(seq): +def compute_contact_freq(seq, AA): """ Calculates contact frequencies in given sequence @@ -644,14 +656,17 @@ def compute_contact_freq(seq): ---------- seq : str The amino acid sequence of the protein. The sequence is assumed to be in one-letter code. Gaps are represented as '-'. - + AA : str + The alphabet of allowed residues + Returns ------- contact_freq: np.array - 21x21 array of frequencies of all possible contacts within sequence. + qxq array of frequencies of all possible contacts within sequence. """ - seq_index = np.array([_AA.find(aa) for aa in seq]) - aa_freq = np.array([(seq_index == i).sum() for i in range(21)], dtype=np.float64) + q = len(AA) + seq_index = np.array([AA.index(aa) for aa in seq]) + aa_freq = np.array([(seq_index == i).sum() for i in range(q)], dtype=np.float64) aa_freq /= aa_freq.sum() contact_freq = (aa_freq[:, np.newaxis] * aa_freq[np.newaxis, :]) return contact_freq @@ -666,17 +681,18 @@ def compute_single_frustration(decoy_fluctuation, Parameters ---------- decoy_fluctuation: np.array - (Lx21) matrix for a sequence of length L, describing the energetic changes upon mutating a single residue. + (Lxq) matrix for a sequence of length L, describing the energetic changes upon mutating a single residue. 
aa_freq: np.array - Array of frequencies of all 21 possible amino acids within sequence + Array of frequencies of all q possible amino acids within sequence Returns ------- frustration: np.array Array of length L featuring single residue frustration indices. """ + q = decoy_fluctuation.shape[1] if aa_freq is None: - aa_freq = np.ones(21) + aa_freq = np.ones(q) mean_energy = (aa_freq * decoy_fluctuation).sum(axis=1) / aa_freq.sum() std_energy = np.sqrt( ((aa_freq * (decoy_fluctuation - mean_energy[:, np.newaxis]) ** 2) / aa_freq.sum()).sum(axis=1)) @@ -694,9 +710,9 @@ def compute_pair_frustration(decoy_fluctuation, Parameters ---------- decoy_fluctuation: np.array - (LxLx21x21) matrix for a sequence of length L, describing the energetic changes upon mutating two residues simultaneously. + (LxLxqxq) matrix for a sequence of length L, describing the energetic changes upon mutating two residues simultaneously. contact_freq: np.array - 21x21 array of frequencies of all possible contacts within sequence. + qxq array of frequencies of all possible contacts within sequence. 
Returns ------- @@ -704,12 +720,17 @@ def compute_pair_frustration(decoy_fluctuation, LxL array featuring pair frustration indices (mutational or configurational frustration, depending on decoy_fluctuation matrix provided) """ + q = decoy_fluctuation.shape[2] # also could have chosen decoy_fluctuation.shape[3] if contact_freq is None: - contact_freq = np.ones([21, 21]) + contact_freq = np.ones([q, q]) decoy_energy = decoy_fluctuation seq_len = decoy_fluctuation.shape[0] - average = np.average(decoy_energy.reshape(seq_len * seq_len, 21 * 21), weights=contact_freq.flatten(), axis=-1) - variance = np.average((decoy_energy.reshape(seq_len * seq_len, 21 * 21) - average[:, np.newaxis]) ** 2, + try: + average = np.average(decoy_energy.reshape(seq_len * seq_len, q * q), weights=contact_freq.flatten(), axis=-1) + except: + raise Exception(f'contact_freq.shape: {contact_freq.shape}, decoy_flucuation.shape: {decoy_fluctuation.shape}') + + variance = np.average((decoy_energy.reshape(seq_len * seq_len, q * q) - average[:, np.newaxis]) ** 2, weights=contact_freq.flatten(), axis=-1) mean_energy = average.reshape(seq_len, seq_len) std_energy = np.sqrt(variance).reshape(seq_len, seq_len) @@ -820,14 +841,14 @@ def plot_roc(roc_score): plt.plot([0, 1], [0, 1], '--') -def plot_singleresidue_decoy_energy(decoy_energy, native_energy, method='clustermap'): +def plot_singleresidue_decoy_energy(decoy_energy, native_energy, AA, method='clustermap'): """ Plot comparison of single residue decoy energies, relative to the native energy Parameters ---------- decoy_energy : np.array - Lx21 array of decoy energies + Lxq array of decoy energies native_energy : float Native energy value method : str @@ -841,7 +862,7 @@ def plot_singleresidue_decoy_energy(decoy_energy, native_energy, method='cluster g = f(decoy_energy, cmap='RdBu_r', vmin=native_energy - decoy_energy.std() * 3, vmax=native_energy + decoy_energy.std() * 3) - AA_dict = {str(i): _AA[i] for i in range(len(_AA))} + AA_dict = {str(i): 
AA[i] for i in range(len(AA))} new_ticklabels = [] if method == 'clustermap': ax_heatmap = g.ax_heatmap @@ -890,20 +911,28 @@ def write_tcl_script(pdb_file: Union[Path,str], chain: str, mask: np.array, dist tcl_script : Path or str tcl script file """ + fo = open(tcl_script, 'w+') single_frustration = np.nan_to_num(single_frustration,nan=0,posinf=0,neginf=0) pair_frustration = np.nan_to_num(pair_frustration,nan=0,posinf=0,neginf=0) structure = prody.parsePDB(str(pdb_file)) - selection = structure.select('protein', chain=chain) + if chain is not None: + selection = structure.select('protein', chain=chain) + else: + selection = structure.select('protein', chain='_') # select all chains residues = np.unique(selection.getResnums()) fo.write(f'[atomselect top all] set beta 0\n') # Single residue frustration for r, f in zip(residues, single_frustration): # print(f) - fo.write(f'[atomselect top "chain {chain} and residue {int(r)}"] set beta {f}\n') + if chain is not None: + fo.write(f'[atomselect top "chain {chain} and residue {int(r)}"] set beta {f}\n') + else: + fo.write(f'[atomselect top "residue {int(r)}"] set beta {f}\n') # 'residue' corresponds to unique residue id in vmd, + # so this is okay if there are multiple chains # Mutational frustration: r1, r2 = np.meshgrid(residues, residues, indexing='ij') @@ -929,13 +958,21 @@ def write_tcl_script(pdb_file: Union[Path,str], chain: str, mask: np.array, dist r2=int(r2) if abs(r1-r2) == 1: # don't draw interactions between residues adjacent in sequence continue - pos1 = selection.select(f'resid {r1} and chain {chain} and (name CB or (resname GLY and name CA))').getCoords()[0] - pos2 = selection.select(f'resid {r2} and chain {chain} and (name CB or (resname GLY and name CA))').getCoords()[0] + if chain is not None: + pos1 = selection.select(f'resid {r1} and chain {chain} and (name CB or (resname GLY and name CA))').getCoords()[0] + pos2 = selection.select(f'resid {r2} and chain {chain} and (name CB or (resname GLY and name 
CA))').getCoords()[0] + else: + pos1 = selection.select(f'resid {r1} and (name CB or (resname GLY and name CA))').getCoords()[0] + pos2 = selection.select(f'resid {r2} and (name CB or (resname GLY and name CA))').getCoords()[0] distance = np.linalg.norm(pos1 - pos2) if d > 9.5 or d < 3.5: continue - fo.write(f'lassign [[atomselect top "resid {r1} and name CA and chain {chain}"] get {{x y z}}] pos1\n') - fo.write(f'lassign [[atomselect top "resid {r2} and name CA and chain {chain}"] get {{x y z}}] pos2\n') + if chain is not None: + fo.write(f'lassign [[atomselect top "resid {r1} and name CA and chain {chain}"] get {{x y z}}] pos1\n') + fo.write(f'lassign [[atomselect top "resid {r2} and name CA and chain {chain}"] get {{x y z}}] pos2\n') + else: + fo.write(f'lassign [[atomselect top "resid {r1} and name CA"] get {{x y z}}] pos1\n') + fo.write(f'lassign [[atomselect top "resid {r2} and name CA"] get {{x y z}}] pos2\n') if 3.5 <= distance <= 6.5: fo.write(f'draw line $pos1 $pos2 style solid width 2\n') else: @@ -953,8 +990,12 @@ def write_tcl_script(pdb_file: Union[Path,str], chain: str, mask: np.array, dist r2=int(r2) if d > 9.5 or d < 3.5: continue - fo.write(f'lassign [[atomselect top "resid {r1} and name CA and chain {chain}"] get {{x y z}}] pos1\n') - fo.write(f'lassign [[atomselect top "resid {r2} and name CA and chain {chain}"] get {{x y z}}] pos2\n') + if chain is not None: + fo.write(f'lassign [[atomselect top "resid {r1} and name CA and chain {chain}"] get {{x y z}}] pos1\n') + fo.write(f'lassign [[atomselect top "resid {r2} and name CA and chain {chain}"] get {{x y z}}] pos2\n') + else: + fo.write(f'lassign [[atomselect top "resid {r1} and name CA"] get {{x y z}}] pos1\n') + fo.write(f'lassign [[atomselect top "resid {r2} and name CA"] get {{x y z}}] pos2\n') if 3.5 <= d <= 6.5: fo.write(f'draw line $pos1 $pos2 style solid width 2\n') else: diff --git a/frustratometer/numba_util/frustration_algorithms.py b/frustratometer/numba_util/frustration_algorithms.py 
new file mode 100644 index 00000000..3409ae32 --- /dev/null +++ b/frustratometer/numba_util/frustration_algorithms.py @@ -0,0 +1,396 @@ +""" +Functions for frustration calculations with numba. +Relies upon the hamiltonian module to evaluate the potential. + +Sometimes, the Potts model for a system requires more RAM than we have available. +One solution to this challenge is to calculate energies on the fly instead of +storing them in a massive array. To speed up evaluation of the many loops needed +to calculate quantities on the fly, we would like to use numba. Unfortunately, +numba struggles to jit-compile most python objects, like Structure and +AWSEM. Our solution is to define functions that take attributes from our +python objects as parameters, which we can then jit without issue. + +The object-oriented interfaces found elsewhere in this repository should +offer an option called something like "use_numba" or "ram_limited" that, +when set to True, results in these numba utilities being called. +""" + +import numpy as np +import numba +from numba import njit, prange, int64, float64, boolean + +from . import hamiltonian as ham + +signature = numba.types.UniTuple(float64,2)( + float64[:,:], + float64[:], float64[:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + int64, int64[:], int64[:]) +def pair_decoy_stats( + allowed_thetaIthetaIIelectrostatic, + allowed_rho_i, allowed_rho_j, + lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_burial, burial_gamma, + lambda_electrostatic, electrostatic_gamma, + n_decoys, seq_index_i, seq_index_j): + """ + Generate distribution of pair energies by randomly sampling + indicators and gammas, then compute the mean + and standard deviation of the distribution.
+ + The sampling is performed in the following way: + - Randomly select a row from allowed_thetaIthetaIIelectrostatic, + representing the pairwise distance-based indicator functions + for a particular pair of residues (i,j) + - Randomly select a rho value for residue i from allowed_rho_i + - Randomly select a rho value for residue j from allowed_rho_j + - Randomly select an amino acid type for residue i + from seq_index_i + - Randomly select an amino acid type for residue j + from seq_index_j + - Randomly select an amino acid type for residue j + - Get the appropriate gammas for the pair (i,j) + - Compute the pair energy given the indicators and gammas + + When writing this function, I'm thinking of seq_index_i and + seq_index_j as the seq_index of the protein (list whose length equals + the length of the protein where each element represents the + amino acid type at its position). But you can get aa_freq + behavior by replacing seq_index with a different array having + different amino acid types in your desired proportions. + + Similarly, I'm thinking of allowed_thetaIthetaIIelectrostatic as + including one set of {thetaI, thetaII, electrostatic_indicator} + for each pair of residues in the protein meeting some mask + condition (applied by the user before calling this function). + + Note that this function uses the legacy np.random.choice() + function to take a uniform random sample of our arrays. Getting + random number generators to work with numba is tricky, so it's + probably best to stick with this way of doing things. + + Parameters + ---------- + - allowed_thetaIthetaIIelectrostatic : np.array(C_1, 3) + thetaI, thetaII, and electrostatic indicator values for + all C_1 allowed contacts. + Each set {thetaI_i, thetaII_i, electrostatic_indicator_i} + should be repeated multiple times in proportion to the + desired probability. + - allowed_rho_i : np.array(C_2,) + All C_2 choices of rho allowed for residue "i".
+ Each unique value should be repeated multiple times + in proportion to the desired probability. + - allowed_rho_j : np.array(C_3,) + All C_3 choices of rho allowed for residue "j". + Each unique value should be repeated multiple times + in proportion to the desired probability. + - lambda_direct : float + Scale factor for direct interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_protein and lambda_water. + - direct_gamma : np.array(20,20) + Array formatted in the same way as self.direct_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). + - lambda_protein : float + Scale factor for protein-mediated interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_direct and lambda_water. + - protein_gamma : np.array(20,20) + Array formatted in the same way as self.protein_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). + - lambda_water : float + Scale factor for water-mediated interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_direct and lambda_protein. + - water_gamma : np.array(20,20) + Array formatted in the same way as self.water_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). + - lambda_burial : float + Scale factor for burial interaction energies. + Should be 1 kcal/mol (4.184 kJ/mol). + - burial_gamma : np.array(20,3) + Array formatted in the same way as self.burial_gamma from the AWSEM class. + Order along axis 0 may vary (ACDE vs. ARND), but should always be ordered as + [low density, medium density, high density] along axis 1. + - lambda_electrostatic : float + Our electrostatic "lambda" and "gamma" are different from those for our other + terms in that they seek to represent fundamental from the bottom up, + rather than the top-down optimization followed for the other gammas. 
+ Specifically, the "lambda" is the conversion factor from fundamental + charge units to kJ/mol, adjusted for the (uniform component of the) + solvent dielectric screening. (Heterogeneities in the solvation structures of + ions are accounted for in the electrostatics indicator function). + - electrostatic_gamma : np.array(20,20) + Our electrostatic "lambda" and "gamma" are different from those for our other + terms in that they seek to represent fundamental from the bottom up, + rather than the top-down optimization followed for the other gammas. + Specifically, the "gamma" is the product of the expected fundamental charges + of the side chain -- usually +/-1, but we could do -2 for phosphorylation + - n_decoys : int + Number of samples draw to construct the distribution of pair energies. + Ideally, n_decoys = infinity. + - seq_index_i : np.array(C_4,) + All C_4 choices of amino acid type allowed for residue "i". + Necessarily repeats amino acid types for C_4 > 20. + Each unique value should be repeated multiple times + in proportion to the desired probability. + - seq_index_j : np.array(C_5,) + All C_5 choices of amino acid type allowed for residue "j". + Necessarily repeats amino acid types for C_5 > 20. + Each unique value should be repeated multiple times + in proportion to the desired probability. 
+ + Returns + ------- + mean : float + Average energy of the decoys + stdev : float + Standard deviation of the energies of the decoys + """ + # randomly choose (with replacement) indices to sample, + # then generate arrays containing the randomly sampled values + thetaIthetaIIelectrostatic_array = allowed_thetaIthetaIIelectrostatic\ + [np.random.choice(allowed_thetaIthetaIIelectrostatic.shape[0],size=n_decoys),:] + rho_i_array = allowed_rho_i[np.random.choice(allowed_rho_i.shape[0],size=n_decoys)] + rho_j_array = allowed_rho_j[np.random.choice(allowed_rho_j.shape[0],size=n_decoys)] + aa_i_array = seq_index_i[np.random.choice(seq_index_i.shape[0],size=n_decoys)] + aa_j_array = seq_index_j[np.random.choice(seq_index_j.shape[0],size=n_decoys)] + # calculate pair energies and fill array + pair_energies = np.zeros(n_decoys) + for counter in prange(n_decoys): + thetaI = thetaIthetaIIelectrostatic_array[counter,0] + thetaII = thetaIthetaIIelectrostatic_array[counter,1] + electrostatic_indicator = thetaIthetaIIelectrostatic_array[counter,2] + rho_i = rho_i_array[counter] + rho_j = rho_j_array[counter] + aa_i = seq_index_i[aa_i_array[counter]] + aa_j = seq_index_j[aa_j_array[counter]] + gamma_bi = burial_gamma[aa_i,:] + gamma_bj = burial_gamma[aa_j,:] + gamma_d = direct_gamma[aa_i, aa_j] + gamma_p = protein_gamma[aa_i, aa_j] + gamma_w = water_gamma[aa_i, aa_j] + gamma_e = electrostatic_gamma[aa_i, aa_j] + pair_energy = ham.compute_pair_energy_ij_useful( + rho_i, rho_j, thetaI, thetaII, electrostatic_indicator, + lambda_direct, gamma_d, lambda_protein, gamma_p, lambda_water, gamma_w, + lambda_burial, gamma_bi, gamma_bj, lambda_electrostatic, gamma_e) + #burial_energy_i = ham.compute_burial_potential_i_from_rho_gamma(rho_i, lambda_burial, gamma_bi) + #burial_energy_j = ham.compute_burial_potential_i_from_rho_gamma(rho_j, lambda_burial, gamma_bj) + #direct_energy = ham.compute_direct_potential_ij_from_thetaI_gamma(thetaI, lambda_direct, gamma_d) + #protein_energy, 
water_energy = ham.compute_long_potentials_ij_from_rho_thetaII_gamma( + # rho_i, rho_j, thetaII, lambda_protein, gamma_p, lambda_water, gamma_w) + #electrostatic_energy = ham.compute_electrostatic_potential_ij_from_indicator_gamma( + # lambda_electrostatic, gamma_e, electrostatic_indicator) + #pair_energy = burial_energy_i + burial_energy_j + direct_energy +\ + # protein_energy + water_energy + electrostatic_energy + pair_energies[counter] = pair_energy + mean = np.average(pair_energies) + stdev = np.std(pair_energies) + return mean, stdev +pair_decoy_stats_parallel = njit(signature_or_function=signature, parallel=True)(pair_decoy_stats) +pair_decoy_stats = njit(signature_or_function=signature)(pair_decoy_stats) +# +@njit(signature_or_function=numba.types.UniTuple(float64,2)( + float64, int64, int64[:], int64[:], float64[:,:], + float64, float64, + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + int64, int64[:]), + parallel=True) # we definitely want to parallelize this function +def standard_config_decoy_stats( + l_D, min_seq_sep_rho, chain_starts, chain_ends, dist_mat, + min_dist_decoy_gen, max_dist_decoy_gen, + lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_burial, burial_gamma, + lambda_electrostatic, electrostatic_gamma, + n_decoys, seq_index): + """ + Get mean and standard deviation of decoy energies + following the standard configurational frustration algorithm. + + Parameters + ---------- + l_D : float + Screening length for Debye-Huckel electrostatics, in units of Angstroms + min_seq_sep_rho : int + The minimum distance in sequence for two residues to contribute to each others' + rho. Include i,j (i.e., set mask bit bool to True) if |i-j| >= min_seq_sep_rho. + chain_starts : np.array(N_c) + List of 0-indexed residue indices marking the start of each chain, + for example, array([0]) for the case of a single chain (N_c==1). 
+ chain_ends : np.array(N_c) + List of 0-indexed residue indices marking the end of each chain, + for example, array([L-1]) for the case of a single chain (N_c==1). + dist_mat : np.array(L,L) + Pairwise distance matrix for the entire protein system + min_dist_decoy_gen : float + Discard distances lower than this value from the distribution + max_dist_decoy_gen : float + Discard distances greater than this value from the distribution + lambda_direct : float + Scale factor for direct interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_protein and lambda_water. + direct_gamma : np.array(20,20) + Array formatted in the same way as self.direct_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). + lambda_protein : float + Scale factor for protein-mediated interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_direct and lambda_water. + protein_gamma : np.array(20,20) + Array formatted in the same way as self.protein_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). + lambda_water : float + Scale factor for water-mediated interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_direct and lambda_protein. + water_gamma : np.array(20,20) + Array formatted in the same way as self.water_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). + lambda_burial : float + Scale factor for burial interaction energies. + Should be 1 kcal/mol (4.184 kJ/mol). + burial_gamma : np.array(20,3) + Array formatted in the same way as self.burial_gamma from the AWSEM class. + Order along axis 0 may vary (ACDE vs. ARND), but should always be ordered as + [low density, medium density, high density] along axis 1. 
+ lambda_electrostatic : float + Our electrostatic "lambda" and "gamma" are different from those for our other + terms in that they seek to represent fundamental from the bottom up, + rather than the top-down optimization followed for the other gammas. + Specifically, the "lambda" is the conversion factor from fundamental + charge units to kJ/mol, adjusted for the (uniform component of the) + solvent dielectric screening. (Heterogeneities in the solvation structures of + ions are accounted for in the electrostatics indicator function). + electrostatic_gamma : np.array(20,20) + Our electrostatic "lambda" and "gamma" are different from those for our other + terms in that they seek to represent fundamental from the bottom up, + rather than the top-down optimization followed for the other gammas. + Specifically, the "gamma" is the product of the expected fundamental charges + of the side chain -- usually +/-1, but we could do -2 for phosphorylation + n_decoys : int + Number of samples draw to construct the distribution of pair energies. + Ideally, n_decoys = infinity. + seq_index : np.array(L,) + Array equal in length to the number of amino acids in the protein, + where each element is the numerical code for the amino acid + at that position. Numerical codes are determined by the position + of the one-letter code of the amino acid in the string of all + one-letter amino acid codes, and so should range from 0 to 19. + The string of all one-letter amino acid codes is probably + "ARND..." or "ACDE...", alphabetical by 3-letter code or 1-letter code. 
+ + Returns + ------- + mean : float + Average energy of the decoys + stdev : float + Standard deviation of the energies of the decoys + """ + # calculate rho + C_2 = dist_mat.shape[0] + allowed_rho_i = np.zeros(C_2) + for counter in prange(C_2): + allowed_rho_i[counter] = ham.compute_rho_i(counter, + min_seq_sep_rho, chain_starts, chain_ends, dist_mat) + allowed_rho_j = allowed_rho_i + # calculate distance-based indicators + #triu_indices = np.triu_indices(C_2,k=1) + #distances = dist_mat[triu_indices[0], triu_indices[1]] + #distances = distances[(distances<=max_dist_decoy_gen)&(distances>=min_dist_decoy_gen)] + distances = np.zeros(((C_2**2)-C_2)//2) # maximum possible number of distances + num_distances = 0 + for i in range(C_2): + for j in range(i+1, C_2): + dist_ij = dist_mat[i,j] + if min_dist_decoy_gen <= dist_ij <= max_dist_decoy_gen: + distances[num_distances] = dist_ij + num_distances += 1 + distances = distances[:num_distances+1] + C_1 = distances.shape[0] + allowed_thetaIthetaIIelectrostatic = np.zeros((C_1, 3)) + for counter in prange(C_1): + dist_ij = distances[counter] + allowed_thetaIthetaIIelectrostatic[counter,0] = ham.compute_thetaI(dist_ij) + allowed_thetaIthetaIIelectrostatic[counter,1] = ham.compute_thetaII(dist_ij) + allowed_thetaIthetaIIelectrostatic[counter,2] = ham.compute_electrostatic_indicator(l_D, dist_ij) + # assign pools of aa types to draw from + seq_index_i = seq_index + seq_index_j = seq_index + # send our formatted data to numba function for rapid sampling + mean, stdev = pair_decoy_stats(allowed_thetaIthetaIIelectrostatic, + allowed_rho_i, allowed_rho_j, + lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_burial, burial_gamma, + lambda_electrostatic, electrostatic_gamma, + n_decoys, seq_index_i, seq_index_j) + return mean, stdev +# + + +## no numba for this function +#def compute_frustration_matrix(dist_mat, +# min_seq_sep_rho, min_seq_sep_frust_index, +# chain_starts, 
chain_ends, +# seq_index, +# lambda_direct, direct_gamma, +# lambda_protein, protein_gamma, +# lambda_water, water_gamma, +# lambda_burial, burial_gamma, +# lambda_electrostatic, electrostatic_gamma, l_D, +# decoy_stats_method): +# """ +# Calculate matrix of frustration indices +# +# Parameters +# ---------- +# decoy_stats_method : callable +# function that returns decoy mean and standard deviation +# (recommend numba_util.pair_decoy_stats_config) +# others : +# See module-level docstring +# +# Returns +# ------- +# frustration_matrix: +# Matrix of the same shape as dist_mat, where each element (i,j) +# is, if unmasked, the frustration index of the pair (i,j), or, +# if masked, np.nan. +# """ +# pair_energy_matrix = compute_pair_energy_matrix( +# dist_mat, +# min_seq_sep_rho, min_seq_sep_frust_index, +# chain_starts, chain_ends, +# seq_index, +# lambda_direct, direct_gamma, +# lambda_protein, protein_gamma, +# lambda_water, water_gamma, +# lambda_burial, burial_gamma, +# lambda_electrostatic, electrostatic_gamma, l_D) +# mean, stdev = decoy_stats_method(dist_mat, +# min_dist_decoy_gen, max_dist_decoy_gen, +# min_seq_sep_rho, +# lambda_direct, direct_gamma, +# lambda_protein, protein_gamma, +# lambda_water, water_gamma, +# lambda_burial, burial_gamma, +# lambda_electrostatic, electrostatic_gamma, l_D) +# # will generate warnings about np.nan +# frustration_matrix = (pair_energy_matrix - mean) / stdev +# return frustration_matrix \ No newline at end of file diff --git a/frustratometer/numba_util/hamiltonian.py b/frustratometer/numba_util/hamiltonian.py new file mode 100644 index 00000000..10fb58d2 --- /dev/null +++ b/frustratometer/numba_util/hamiltonian.py @@ -0,0 +1,1587 @@ +""" +Hierarchy of functions for AMW/tertiary/frustratometer/potts +Hamiltonian calculations with numba. + +Sometimes, the Potts model for a system requires more RAM than we have available. 
+One solution to this challenge is to calculate energies on the fly instead of +storing them in a massive array. To speed up evaluation of the many loops needed +to calculate quantities on the fly, we would like to use numba. Unfortunately, +numba struggles to jit-compile most python objects, like Structure and +AWSEM. Our solution is to define functions that take attributes from our +python objects as parameters, which we can then jit without issue. + +The object-oriented interfaces found elsewhere in this repository should +offer an option called something like "use_numba" or "ram_limited" that, +when set to True, results in these numba utilities being called. + +Conventions +----------- + +This is the complete list of parameters that may be used by any function: +i, j, l_D, min_seq_sep_rho, min_seq_sep_contact, min_seq_sep_electrostatic, +min_seq_sep_frust_index, min_seq_sep, seq_sep, chain_starts, chain_ends, +same_chain, min_dist, max_dist, dist_mat, dist_ij, rho_i, rho_j, +thetaI, thetaII, sigma_water, lambda_direct, direct_gamma, lambda_protein, +protein_gamma, gamma_p, lambda_water, water_gamma, gamma_w, lambda_burial, burial_gamma, +lambda_electrostatic, electrostatic_gamma, gamma, seq_index, parallel + +No function uses all these parameters, but all functions use a subset of +these parameters. The subset of parameters is always ordered the same +as it is in the above list. The meanings of the parameters are given +below, in order. + +Parameters to select the residue(s) for a computation +- i : int + 0-indexed position of residue "i" in the complete system +- j : int + 0-indexed position of residue "j" in the complete system + +Mathematical parameters of the indicator functions +- l_D: float + Screening length for Debye-Huckel electrostatics, in units of Angstroms + +Parameters for evaluating mask conditions +- min_seq_sep_rho : int + The minimum distance in sequence for two residues to contribute to each others' + rho. 
Include i,j (i.e., set mask bit bool to True) if |i-j| >= min_seq_sep_rho. +- min_seq_sep_contact : int + The minimum distance in sequence for a contact to be considered "real" and unmasked. + Include i,j (i.e., set mask bit bool to True) if |i-j| >= min_seq_sep_contact. +- min_seq_sep_electrostatic : int + The minimum distance in sequence for a charged pair to be considered "real" and unmasked. + Include i,j (i.e., set mask bit bool to True) if |i-j| >= min_seq_sep_electrostatic. +- min_seq_sep_frust_index : int + The minimum distance in sequence for a pair's frustration index to be + calculated (frustration index is set to np.nan if not satisfied) +- min_seq_sep : int + Sequence separation used to determine whether two residues "see" each other; + what it means for two residues to "see" each other depends on the context + (see min_seq_sep_contact and min_seq_sep_rho) +- seq_sep : int + Actual distance in sequence between two residues, |i-j| +- chain_starts : np.array(N_c) + List of 0-indexed residue indices marking the start of each chain, + for example, array([0]) for the case of a single chain (N_c==1). +- chain_ends : np.array(N_c) + List of 0-indexed residue indices marking the end of each chain, + for example, array([L-1]) for the case of a single chain (N_c==1).
+- same_chain : bool + Whether the two residues i and j are part of the same chain +- min_dist : float + Residues closer in space than this distance are masked +- max_dist : float + Residues further in space than this distance are masked +- max_dist_contact : float + Like the plain max_dist argument (see above) +- max_dist_electrostatic : float + Like the plain max_dist argument (see above) + +Parameters holding the values of indicator functions or +quantities needed to compute indicator functions +- dist_mat : np.array (L,L) + Distance matrix for all residue pairs +- dist_ij : float + Distance between two residues, in angstroms +- rho_i : float + Rho value of residue "i" +- rho_j : float + Rho value of residue "j" +- burial_indicator : np.array(3,) + Low, medium, and high burial components for a burial indicator + function for a particular residue +- thetaI : float + Value of the short-range indicator function for a pair of residues. + This indicator function is used to compute the direct interaction + and as an input to the rho computation +- thetaII : float + Value of the long-range indicator function for a pair of residues. + This indicator function is used to compute the protein-mediated + and water-mediated interactions +- sigma_water : float + Used to determine whether a pair of residues is in a solvent- + exposed or buried environment. +- electrostatic_indicator : float + Effective interaction strength of two charged residues, + based on their distance and debye-huckel screening length + +Parameters needed to compute energies but not indicator functions +- lambda_direct : float + Scale factor for direct interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_protein and lambda_water. +- direct_gamma : np.array(20,20) + Array formatted in the same way as self.direct_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). 
+- gamma_d : float + Like the plain gamma argument (see below) +- lambda_protein : float + Scale factor for protein-mediated interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_direct and lambda_water. +- protein_gamma : np.array(20,20) + Array formatted in the same way as self.protein_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). +- gamma_p : float + Like the plain gamma argument (see below), but differentiates the protein gamma from + the water gamma in the long-range potential calculation function +- lambda_water : float + Scale factor for water-mediated interaction energies. + Should probably be 1 kcal/mol (4.184 kJ/mol), + but has sometimes been set to 0.75 kcal/mol along with lambda_direct and lambda_protein. +- water_gamma : np.array(20,20) + Array formatted in the same way as self.water_gamma from the AWSEM class. + Order may vary (ACDE vs. ARND). +- gamma_w : float + Like the plain gamma argument (see below), but differentiates the water gamma from + the protein gamma in the long-range potential calculation function +- lambda_burial : float + Scale factor for burial interaction energies. + Should be 1 kcal/mol (4.184 kJ/mol). +- burial_gamma : np.array(20,3) + Array formatted in the same way as self.burial_gamma from the AWSEM class. + Order along axis 0 may vary (ACDE vs. ARND), but should always be ordered as + [low density, medium density, high density] along axis 1. +- gamma_bi : np.array(3,) + Like the plain gamma argument (see below), but a length-3 vector (one entry per burial well) instead of a scalar +- gamma_bj : np.array(3,) + Like the plain gamma argument (see below), but a length-3 vector (one entry per burial well) instead of a scalar +- lambda_electrostatic : float + Our electrostatic "lambda" and "gamma" are different from those for our other + terms in that they seek to represent fundamental physics from the bottom up, + rather than the top-down optimization followed for the other gammas.
+ Specifically, the "lambda" is the conversion factor from fundamental + charge units to kJ/mol, adjusted for the (uniform component of the) + solvent dielectric screening. (Heterogeneities in the solvation structures of + ions are accounted for in the electrostatics indicator function). +- electrostatic_gamma : np.array(20,20) + Our electrostatic "lambda" and "gamma" are different from those for our other + terms in that they seek to represent fundamental physics from the bottom up, + rather than the top-down optimization followed for the other gammas. + Specifically, the "gamma" is the product of the expected fundamental charges + of the side chains -- usually +/-1, but we could do -2 for phosphorylation +- gamma_e : float + Like the plain gamma argument (see below) +- gamma : float + Scalar gamma that has been selected from a gamma array based on the + amino acid types of residues i and j +- seq_index : np.array(L,) + Array equal in length to the number of amino acids in the protein, + where each element is the numerical code for the amino acid + at that position. Numerical codes are determined by the position + of the one-letter code of the amino acid in the string of all + one-letter amino acid codes, and so should range from 0 to 19. + The string of all one-letter amino acid codes is probably + "ARND..." or "ACDE...", alphabetical by 3-letter code or 1-letter code. + +Parameters to optimize computation efficiency +- parallel : bool + Whether to call numba parallelized or not + +Notes +----- +What we call the "(AWSEM) tertiary Hamiltonian" +or the "frustratometer Hamiltonian" or the "AMW Hamiltonian" +without electrostatics was defined in its modern form in + +Papoian, Ulander, Eastwood, Luthey-Schulten, and Wolynes, +PNAS 2004 (https://www.pnas.org/doi/10.1073/pnas.0307851100) + +This paper also gave us the gammas for the contact and burial interactions. + +Electrostatics were introduced in + +Tsai, Zheng, Balamurugan, Schafer, Kim, Cheung, and Wolynes, +Prot. Sci.
2016 (https://doi.org/10.1002%2Fpro.2751) +""" + +import numpy as np +import numba +from numba import njit, prange, int64, float64, boolean + +################################################################################ +# FUNCTIONS TO CALCULATE masks, thetaI, thetaII, rho, sigma_wat, sigma_prot, +# burial_indicator, AND THE ELECTROSTATICS INDICATOR, +# GIVEN A SINGLE RESIDUE i OR A PAIR OF RESIDUES (i,j), AS APPROPRIATE +# THESE FUNCTIONS **DON'T** CHECK MASK CONDITIONS! +# +@njit(signature_or_function=boolean(int64, int64, int64[:], int64[:])) +def check_same_chain(i, j, chain_starts, chain_ends): + """ + Checks whether two zero-indexed residue indices, i and j, belong to the same chain + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + same_chain : bool + Whether i and j are in the same chain + """ + same_chain = False + for counter in range(len(chain_starts)): # should be same length as chain_ends + if (chain_starts[counter] <= i <= chain_ends[counter]) and (chain_starts[counter] <= j <= chain_ends[counter]): + same_chain = True + break # this could save us a couple iterations, probably doesn't matter + return same_chain +# +@njit(signature_or_function=boolean(int64, int64, float64, float64, boolean, float64)) +def mask_of_pair(min_seq_sep, seq_sep, min_dist, max_dist, same_chain, dist_ij): + """ + Get a bool representing whether a pair of residues having + sequence separation seq_sep and distance dist_ij should be + considered (True) or ignored (False), given the supplied parameters. + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + mask_bit : bool + Whether pair should be considered unmasked (True) or masked (False). 
+ """ + if (min_dist<=dist_ij) and (dist_ij<=max_dist) and ((min_seq_sep<=seq_sep) or (not same_chain)): + mask_bit = True + else: + mask_bit = False + return mask_bit +# +@njit(signature_or_function=float64(float64, float64, float64)) +def _compute_theta(dist_ij, r_min, r_max): + # This function may be called to evaluate either thetaI or thetaII. + # Since thetaI is used to compute both contact indicators and rho, + # we have to worry about min_seq_sep_contact vs min_seq_sep_rho. + # So we do not check the mask here, but instead check it before thetaI + # or thetaII is called. + # 5 (Angstrom^-1) is "eta" + theta = 0.25 * (1 + np.tanh(5*(dist_ij-r_min))) * (1 + np.tanh(5*(r_max-dist_ij))) + return theta +@njit(signature_or_function=float64(float64)) +def compute_thetaI(dist_ij): + """ + Computes thetaI, the short-range indicator function + that tells us whether two residues are close but not overlapping. + This function does not check whether the ij interaction should + be blocked by a mask; this should be done in the calling scope. + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + thetaI : float + The short-range indicator function. + """ + return _compute_theta(dist_ij, 4.5, 6.5) +@njit(signature_or_function=float64(float64)) +def compute_thetaII(dist_ij): + """ + Computes thetaII, the long-range switching function + that tells us whether two residues are close but not in direct contact. + This function does not check whether the ij interaction should + be blocked by a mask; this should be done in the calling scope. + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + thetaII : float + The long-range indicator function + """ + return _compute_theta(dist_ij, 6.5, 9.5) +# +signature = float64(int64, int64, int64[:], int64[:], float64[:,:]) +def compute_rho_i(i, min_seq_sep_rho, chain_starts, chain_ends, dist_mat): + """ + Compute the "local density," rho, of a given 0-indexed + residue index, i. 
The quantity rho_i may be loosely thought of + as the number of neighbors (coordination number) of residue i. + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + rho_i : float + The local density of residue i + """ + rho_i = 0.0 + for j in prange(dist_mat.shape[1]): + # check mask + same_chain = check_same_chain(i, j, chain_starts, chain_ends) + # 2.5 and 8.5 cutoffs: effectively, + # we're truncating the potential where the indicators are almost zero + # (thetaI(2.5)==thetaI(8.5)==2.0611536367E-9) + if mask_of_pair(min_seq_sep_rho, abs(i-j), 2.5, 8.5, + same_chain, dist_mat[i,j]): + # only let the residue contribute if it isn't caught by the mask + rho_i += compute_thetaI(dist_mat[i,j]) + return rho_i +compute_rho_i_parallel = njit(signature_or_function=signature, parallel=True)(compute_rho_i) +compute_rho_i = njit(signature_or_function=signature)(compute_rho_i) +# +@njit(signature_or_function=float64[:](float64)) +def compute_burial_indicator_i(rho_i): + """ + Compute the vector-valued burial indicator function (one element + each for low, medium, and high density). + + Parameters + ---------- + See module-level docstring + + Returns + ------- + burial_indicator : np.array(3,) + The burial indicator for residue i in each well + Remember that the burial indicator is defined as ranging from 0 to 2 + """ + + burial_indicator = np.zeros(3) + # 4.0 is "burial_kappa" + burial_indicator[0] = (np.tanh(4.0*(rho_i-0.0)) + np.tanh(4.0*(3.0-rho_i))) + burial_indicator[1] = (np.tanh(4.0*(rho_i-3.0)) + np.tanh(4.0*(6.0-rho_i))) + burial_indicator[2] = (np.tanh(4.0*(rho_i-6.0)) + np.tanh(4.0*(9.0-rho_i))) + return burial_indicator +# +@njit(signature_or_function=float64(float64,float64)) +def compute_sigma_water(rho_i, rho_j): + """ + Compute sigma_water based on local densities of the two residues in the pair. + + If both residues are exposed ((rho_i < rho_0) && (rho_j < rho_0)), + then water-mediated interactions dominate (sigma_water ~ 1). 
+ If either is buried ((rho_i > rho_0) || (rho_j > rho_0)), + then water-mediated interactions are small (sigma_water ~ 0). + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + sigma_water : float + Fraction of water-mediated interactions (0 to 1) + """ + #sigma_water = 0.25 * (1 - np.tanh(eta_sigma * (rho_i - rho_0))) * (1 - np.tanh(eta_sigma * (rho_j - rho_0))) + sigma_water = 0.25*(1-np.tanh(7*(rho_i-2.6)))*(1-np.tanh(7*(rho_j-2.6))) + return sigma_water +# +@njit(signature_or_function=float64(float64, float64)) +def compute_electrostatic_indicator(l_D, dist_ij): + """ + Computes electrostatics indicator function, which gives an + effective proximity (higher <==> closer) of two residues, + capturing not only the 1/r decay of the coulomb energy, + but also the screening effects of counterions; l_D + should be negatively correlated with the ionic strength. + + Parameters + ---------- + See module-level docstring + + Returns + ------- + electrostatics_indicator(i,j,dist_mat[i,j]), parameterized by l_D + """ + if dist_ij >= 1: + safe_dist = dist_ij + else: + raise ValueError("Distance between two residue was less than 1 angstrom!") + electrostatics_indicator = np.exp(-safe_dist / l_D) / safe_dist + return electrostatics_indicator +# +########################################################################### +# FUNCTIONS TO CALCULATE ENERGIES, GIVEN A SINGLE RESIDUE i OR A PAIR (i,j) +# THESE FUNCTIONS **DON'T** CHECK MASK CONDITIONS! +# +# BURIAL POTENTIAL +""" +@njit(signature_or_function=float64(float64[:], float64, float64[:])) +def compute_burial_potential_i_from_indicator_gamma(burial_indicator, lambda_burial, gamma): + """""" + Compute the burial energy for residue i based on its local density. + Note that this function computes and sums the 3 types of burial energies: + low-density, medium-density, and high-density. + + Parameters + ---------- + See module-level docstring. 
+ + Returns + ------- + burial_energy : float + Total burial energy for residue i, sum across all three burial wells. + """""" + # Caution: the burial indicator functions range from 0 to 2, + # not 0 to 1, like the other indicator functions. + # This is why we have a coefficient of 0.5 in the energy expression. + low_indicator = burial_indicator[0] + low_gamma = gamma[0] + medium_indicator = burial_indicator[1] + medium_gamma = gamma[1] + high_indicator = burial_indicator[2] + high_gamma = gamma[2] + burial_energy = -0.5*lambda_burial *\ + (low_indicator*low_gamma+medium_indicator*medium_gamma+high_indicator*high_gamma) + return burial_energy +""" +@njit(signature_or_function=float64(float64, float64, float64[:])) +def compute_burial_potential_i_from_rho_gamma(rho_i, lambda_burial, gamma): + """ + Compute the burial energy for residue i based on its local density. + Note that this function computes and sums the 3 types of burial energies: + low-density, medium-density, and high-density. + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + burial_energy : float + Total burial energy for residue i, sum across all three burial wells. + """ + burial_indicator = compute_burial_indicator_i(rho_i) + #burial_energy = compute_burial_potential_i_from_indicator_gamma(burial_indicator, lambda_burial, gamma) + low_indicator = burial_indicator[0] + low_gamma = gamma[0] + medium_indicator = burial_indicator[1] + medium_gamma = gamma[1] + high_indicator = burial_indicator[2] + high_gamma = gamma[2] + burial_energy = -0.5*lambda_burial *\ + (low_indicator*low_gamma+medium_indicator*medium_gamma+high_indicator*high_gamma) + return burial_energy +""" +@njit(signature_or_function=float64(int64, float64, float64, float64[:,:], int64[:])) +def compute_burial_potential_i_from_rho(i, rho_i, lambda_burial, burial_gamma, seq_index): + """""" + Compute the burial energy for residue i based on its local density. 
@njit(signature_or_function=float64(int64, int64, int64[:], int64[:], float64[:,:], float64, float64[:]))
def compute_burial_potential_i_from_gamma(i, min_seq_sep_rho, chain_starts, chain_ends, dist_mat, lambda_burial, gamma):
    """
    Burial energy of residue i, computing its local density on the fly.

    Thin wrapper: obtains rho_i from ``compute_rho_i`` and forwards it to
    ``compute_burial_potential_i_from_rho_gamma``.

    Parameters
    ----------
    See module-level docstring.

    Returns
    -------
    burial_energy : float
        Total burial energy for residue i, summed over the three burial wells.
    """
    return compute_burial_potential_i_from_rho_gamma(
        compute_rho_i(i, min_seq_sep_rho, chain_starts, chain_ends, dist_mat),
        lambda_burial,
        gamma)
@njit(signature_or_function=float64(float64, float64, float64))
def compute_direct_potential_ij_from_distij_gamma(dist_ij, lambda_direct, gamma):
    """
    Direct-contact energy of one residue pair from its distance and gamma.

    No mask is applied here; the caller decides whether the pair counts.

    Parameters
    ----------
    dist_ij : float
        Distance between the two residues.
    lambda_direct : float
        Scale factor of the direct-contact term.
    gamma : float
        Direct-contact gamma for this pair of residue types.

    Returns
    -------
    direct_energy : float
        -lambda_direct * thetaI(dist_ij) * gamma
    """
    return -lambda_direct * compute_thetaI(dist_ij) * gamma
@njit(signature_or_function=float64(int64, int64, float64[:,:], float64, float64[:,:], int64[:]))
def compute_direct_potential_ij(i, j, dist_mat, lambda_direct, direct_gamma, seq_index):
    """
    Direct-contact energy of residues i and j, looking up gamma by residue type.

    No mask is applied here; the caller decides whether the pair counts.

    Parameters
    ----------
    See module-level docstring.

    Returns
    -------
    direct_energy : float
        Energy of the direct contact term for the pair (i,j).
    """
    type_i = seq_index[i]
    type_j = seq_index[j]
    pair_gamma = direct_gamma[type_i, type_j]
    return compute_direct_potential_ij_from_gamma(i, j, dist_mat, lambda_direct, pair_gamma)
@njit(signature_or_function=numba.types.UniTuple(float64,2)(int64, int64, float64, float64,
        float64, float64[:,:], float64, float64[:,:], int64[:]))
def compute_long_potentials_ij_from_sigmawater_distij(i, j, dist_ij, sigma_water,
        lambda_protein, protein_gamma, lambda_water, water_gamma, seq_index):
    """
    Compute the protein-mediated and water-mediated (long-range) potentials
    for a pair of residues, given their distance and a precomputed sigma_water.

    No mask is applied here; the caller decides whether the pair counts.

    Parameters
    ----------
    i, j : int
        Residue indices, used only to look up residue types in seq_index.
    dist_ij : float
        Distance between residues i and j.
    sigma_water : float
        Fraction of the contact treated as water-mediated; the
        protein-mediated fraction is taken as 1 - sigma_water.
    lambda_protein, lambda_water : float
        Scale factors of the protein- and water-mediated terms.
    protein_gamma, water_gamma : float64[:,:]
        Gamma matrices indexed by (type of i, type of j).
    seq_index : int64[:]
        Residue type index of each residue.

    Returns
    -------
    protein_energy : float
        Energy of the protein-mediated contact term for the pair (i,j).
    water_energy : float
        Energy of the water-mediated contact term for the pair (i,j).
    """
    gamma_p = protein_gamma[seq_index[i], seq_index[j]]
    gamma_w = water_gamma[seq_index[i], seq_index[j]]
    # The *_sigmawater_distij_gamma helper this function used to delegate to
    # is disabled (commented out), so the expression is evaluated inline.
    thetaII = compute_thetaII(dist_ij)
    sigma_protein = 1.0 - sigma_water
    protein_energy = -lambda_protein * thetaII * sigma_protein * gamma_p
    water_energy = -lambda_water * thetaII * sigma_water * gamma_w
    return protein_energy, water_energy
@njit(signature_or_function=numba.types.UniTuple(float64,2)(
    float64, float64, float64, float64, float64, float64, float64))
def compute_long_potentials_ij_from_rho_distij_gamma(dist_ij, rho_i, rho_j,
        lambda_protein, gamma_p, lambda_water, gamma_w):
    """
    Protein- and water-mediated contact energies of one residue pair,
    starting from the pair distance and the two local densities.

    No mask is applied here; the caller decides whether the pair counts.

    Parameters
    ----------
    dist_ij : float
        Distance between the two residues.
    rho_i, rho_j : float
        Local densities of the two residues.
    lambda_protein, lambda_water : float
        Scale factors of the two mediated terms.
    gamma_p, gamma_w : float
        Protein- and water-mediated gammas for this pair of residue types.

    Returns
    -------
    protein_energy : float
        Protein-mediated contact energy of the pair.
    water_energy : float
        Water-mediated contact energy of the pair.
    """
    water_fraction = compute_sigma_water(rho_i, rho_j)
    contact = compute_thetaII(dist_ij)
    protein_fraction = 1.0 - water_fraction
    return (-lambda_protein * contact * protein_fraction * gamma_p,
            -lambda_water * contact * water_fraction * gamma_w)
@njit(numba.types.UniTuple(float64,2)(int64, int64, int64, int64[:], int64[:], float64[:,:],
        float64, float64[:,:], float64, float64[:,:], int64[:]))
def compute_long_potentials_ij(i, j, min_seq_sep_rho, chain_starts, chain_ends, dist_mat,
        lambda_protein, protein_gamma, lambda_water, water_gamma, seq_index):
    """
    Compute the protein-mediated and water-mediated (long-range) potentials
    for a pair of residues, looking up both gammas by residue type.

    No mask is applied here; the caller decides whether the pair counts.

    Parameters
    ----------
    See module-level docstring.

    Returns
    -------
    protein_energy : float
        Energy of the protein-mediated contact term for the pair (i,j).
    water_energy : float
        Energy of the water-mediated contact term for the pair (i,j).
    """
    gamma_p = protein_gamma[seq_index[i], seq_index[j]]
    gamma_w = water_gamma[seq_index[i], seq_index[j]]
    # Argument order must match the callee exactly:
    # (..., lambda_protein, gamma_p, lambda_water, gamma_w).
    # All four are float64, so a swapped order compiles and silently
    # produces wrong energies.
    return compute_long_potentials_ij_from_gamma(i, j, min_seq_sep_rho, chain_starts, chain_ends, dist_mat,
                                                 lambda_protein, gamma_p, lambda_water, gamma_w)
@njit(signature_or_function=float64(float64, float64, float64, float64))
def compute_electrostatic_potential_ij_from_distij_gamma(l_D, dist_ij, lambda_electrostatic, gamma):
    """
    Compute the solvation-averaged electrostatic potential
    for a pair of residues from their distance and electrostatic gamma.

    No mask is applied here; the caller decides whether the pair counts.

    Parameters
    ----------
    l_D : float
        Screening length passed to the electrostatic indicator.
    dist_ij : float
        Distance between the two residues.
    lambda_electrostatic : float
        Scale factor of the electrostatic term.
    gamma : float
        Electrostatic gamma for this pair of residue types.

    Returns
    -------
    electrostatic_energy : float
        Energy of the electrostatic interaction between residues i and j

    Raises
    ------
    ValueError
        Propagated from compute_electrostatic_indicator when dist_ij is
        below its minimum allowed distance.
    """
    electrostatic_indicator = compute_electrostatic_indicator(l_D, dist_ij)
    # gamma is negative if the interaction is favorable and positive if
    # unfavorable, and our lambdas and indicators are all positive by
    # convention, so this expression is NOT preceded by a negative sign
    # (unlike the contact terms).
    electrostatic_energy = lambda_electrostatic * electrostatic_indicator * gamma
    return electrostatic_energy
signature = float64(int64,
                    int64[:], int64[:],
                    float64[:,:],
                    float64, float64[:,:], int64[:])
def compute_burial_potential_total(min_seq_sep_rho,
                                   chain_starts, chain_ends,
                                   dist_mat,
                                   lambda_burial, burial_gamma, seq_index):
    """
    Compute the total burial potential for all residues in the protein system.
    Iterates over all residues and sums burial energies.

    Parameters
    ----------
    See module-level docstring.

    Returns
    -------
    total_burial_energy : float
        Sum of burial energies for all residues
    """
    num_res = dist_mat.shape[0]
    total_burial_energy = 0.0
    # compute_burial_potential_i derives rho_i (and the burial indicators)
    # itself, so no per-residue scratch arrays are needed here.
    # The "+=" on a scalar is a reduction that numba supports under prange.
    for i in prange(num_res):
        total_burial_energy += compute_burial_potential_i(
            i, min_seq_sep_rho, chain_starts, chain_ends, dist_mat,
            lambda_burial, burial_gamma, seq_index)
    return total_burial_energy
# Two compiled variants of the same Python function: with parallel=True the
# prange loop is parallelized; without it, prange behaves like range.
compute_burial_potential_total_parallel = njit(signature_or_function=signature, parallel=True)(compute_burial_potential_total)
compute_burial_potential_total = njit(signature_or_function=signature)(compute_burial_potential_total)
signature = numba.types.UniTuple(float64,2)(int64, int64,
                                            int64[:], int64[:],
                                            float64[:,:],
                                            float64, float64[:,:],
                                            float64, float64[:,:],
                                            int64[:])
def compute_long_potentials_total(min_seq_sep_rho, min_seq_sep_contact,
                                  chain_starts, chain_ends,
                                  dist_mat,
                                  lambda_protein, protein_gamma,
                                  lambda_water, water_gamma,
                                  seq_index):
    """
    Compute the total protein-mediated and water-mediated contact potentials
    for the entire protein structure. Iterates over all residue pairs (i < j)
    and sums long-range interaction energies, considering local densities
    for the sigma (protein vs. water mediated) weighting.
    This function also applies the mask as appropriate.

    Parameters
    ----------
    See module-level docstring.

    Returns
    -------
    total_protein_energy : float
        Sum of all protein-mediated contact energies
    total_water_energy : float
        Sum of all water-mediated contact energies
    """
    num_res = dist_mat.shape[0]
    total_protein_energy = 0.0
    total_water_energy = 0.0
    # Pre-compute rho once per residue so the pair loop does not recompute it
    rho_array = np.zeros(num_res)
    for i in prange(num_res):
        rho_array[i] = compute_rho_i(i, min_seq_sep_rho, chain_starts, chain_ends, dist_mat)
    # compute pairwise energies and add to the total of each type
    for i in prange(num_res):
        # parallelizing inner loop doesn't make much of a difference
        for j in range(i+1, num_res):
            # check contact mask
            same_chain = check_same_chain(i, j, chain_starts, chain_ends)
            # 4.5 and 11.5 cutoffs: effectively,
            # we're truncating the potential where the indicators are almost zero
            # (thetaII(4.5)==thetaII(11.5)==2.0611536367E-9)
            if not mask_of_pair(min_seq_sep_contact, abs(i-j),
                                4.5, 11.5, same_chain, dist_mat[i,j]):
                continue # just call it 0 energy if the pair is masked
            # sigma_water is derived from the two rhos inside the callee,
            # so it is not computed separately here
            protein_energy, water_energy = compute_long_potentials_ij_from_rho_distij_gamma(
                dist_mat[i,j], rho_array[i], rho_array[j],
                lambda_protein, protein_gamma[seq_index[i], seq_index[j]],
                lambda_water, water_gamma[seq_index[i], seq_index[j]])
            total_protein_energy += protein_energy
            total_water_energy += water_energy
    return total_protein_energy, total_water_energy
compute_long_potentials_total_parallel = njit(signature_or_function=signature, parallel=True)(compute_long_potentials_total)
# Bind the serial compiled variant to the same name as the Python function,
# matching the burial/direct/electrostatic totals, so callers of
# compute_long_potentials_total get compiled code rather than the interpreter.
compute_long_potentials_total = njit(signature_or_function=signature)(compute_long_potentials_total)
# Backward-compatible alias for the earlier (singular) spelling.
compute_long_potential_total = compute_long_potentials_total
min_seq_sep_electrostatic, + chain_starts, chain_ends, + dist_mat, + lambda_electrostatic, electrostatic_gamma, + seq_index): + """ + Compute the total Debye-Huckel electrostatic potential + for the entire protein structure. Iterates over all residue pairs and sums + electrostatic interaction energies. + This function also applies the mask as appropriate. + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + total_electrostatic_energy : float + Sum of all electrostatic energies, masked as appropriate + """ + # will move this check and/or get rid of it + #if lambda_electrostatic == 0: + # return 0.0 # save some time if we're going to set everything to 0 anyway + num_res = dist_mat.shape[0] + total_electrostatic_energy = 0.0 + # loop over all pairs of residues + for i in prange(num_res): + # parallelizing inner loop doesn't make much of a difference + #for j in prange(i+1, num_res): + for j in range(i+1, num_res): + # check mask + same_chain = check_same_chain(i, j, chain_starts, chain_ends) + # unlike the other potentials, the electrostatic potential doesn't + # decay below some minimum distance, so the lower bound is 0; + # the upper bound varies with the debye length + if not mask_of_pair(min_seq_sep_electrostatic, abs(i-j), 0, 10*l_D, + same_chain, dist_mat[i,j]): + continue # just call it 0 energy if the pair is masked + energy = compute_electrostatic_potential_ij(i, j, l_D, dist_mat, lambda_electrostatic, electrostatic_gamma, seq_index) + total_electrostatic_energy += energy + return total_electrostatic_energy +compute_electrostatic_potential_total_parallel = njit(signature_or_function=signature, parallel=True)(compute_electrostatic_potential_total) +compute_electrostatic_potential_total = njit(signature_or_function=signature)(compute_electrostatic_potential_total) +# +# no numba for this function, since it doesn't have any loops or do any intensive computation +def compute_potential_total(l_D, min_seq_sep_rho, min_seq_sep_contact, 
min_seq_sep_electrostatic, + chain_starts, chain_ends, + dist_mat, + lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_burial, burial_gamma, + lambda_electrostatic, electrostatic_gamma, + seq_index, parallel): + """ + Compute the total AWSEM energy for the entire protein system. + + CAUTION: this is NOT the sum over all i and j of compute_pair_energy_ij + (compute_pair_energy_ij is found below with the frustration utilities). + Taking the sum of compute_pair_energy_ij over all i and j would overcount + each residue's burial energy, since it is included in the pair energy + of all contacts in which that residue participates, but the burial energy + should only be counted once for each residue. + + Aggregates direct , protein-mediated, water-mediated, + burial, and electrostatic terms. + + Parameters + ---------- + See module-level docstring. + + Returns + ------- + total_energy : float + Total AWSEM energy for the protein system + """ + direct_args = (min_seq_sep_contact, + chain_starts, chain_ends, + dist_mat, + lambda_direct, direct_gamma, + seq_index) + long_args = (min_seq_sep_rho, min_seq_sep_contact, + chain_starts, chain_ends, + dist_mat, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + seq_index) + burial_args = (min_seq_sep_rho, + chain_starts, chain_ends, + dist_mat, + lambda_burial, burial_gamma, + seq_index) + electrostatic_args = (l_D, min_seq_sep_electrostatic, + chain_starts, chain_ends, + dist_mat, + lambda_electrostatic, electrostatic_gamma, + seq_index) + if parallel: + direct_e = compute_direct_potential_total_parallel(*direct_args) + protein_e, water_e = compute_long_potentials_total_parallel(*long_args) + burial_e = compute_burial_potential_total_parallel(*burial_args) + electrostatic_e = compute_electrostatic_potential_total_parallel(*electrostatic_args) + else: + direct_e = compute_direct_potential_total(*direct_args) + protein_e, water_e = 
@njit(signature_or_function=float64(
    int64, int64, float64, int64, int64[:], int64[:],
    float64[:,:],
    float64, float64,
    float64, float64,
    float64, float64,
    float64, float64[:], float64[:],
    float64, float64))
def compute_pair_energy_ij_from_gamma(
        i, j, l_D, min_seq_sep_rho, chain_starts, chain_ends,
        dist_mat,
        lambda_direct, gamma_d,
        lambda_protein, gamma_p,
        lambda_water, gamma_w,
        lambda_burial, gamma_bi, gamma_bj,
        lambda_electrostatic, gamma_e):
    """
    Compute the "pair energy" for residues i and j from already-selected
    gammas, defined as the sum of:
    - Burial energies for both residues
    - Direct contact energy
    - Protein-mediated contact energy
    - Water-mediated contact energy
    - Electrostatic interaction energy

    No mask is applied here. The total energy of the system is NOT the sum
    of this quantity over all pairs, because each residue's burial energy
    would then be counted once per pair it takes part in.

    Parameters
    ----------
    gamma_d, gamma_p, gamma_w, gamma_e : float
        Direct, protein-mediated, water-mediated and electrostatic gammas
        for this pair of residue types.
    gamma_bi, gamma_bj : float64[:]
        Burial gammas (low, medium, high) for the types of residues i and j.
    For the remaining parameters, see module-level docstring.

    Returns
    -------
    pair_energy : float
        Total "pair energy" of residues i and j
    """
    direct_energy = compute_direct_potential_ij_from_gamma(i, j,
                                                           dist_mat,
                                                           lambda_direct, gamma_d)
    protein_energy, water_energy = compute_long_potentials_ij_from_gamma(
        i, j, min_seq_sep_rho, chain_starts, chain_ends,
        dist_mat,
        lambda_protein, gamma_p, lambda_water, gamma_w)
    burial_energy_i = compute_burial_potential_i_from_gamma(i, min_seq_sep_rho, chain_starts, chain_ends,
                                                            dist_mat,
                                                            lambda_burial, gamma_bi)
    burial_energy_j = compute_burial_potential_i_from_gamma(j, min_seq_sep_rho, chain_starts, chain_ends,
                                                            dist_mat,
                                                            lambda_burial, gamma_bj)
    # The (i, j, dist_mat) electrostatic wrapper is disabled in this file,
    # so index the distance matrix here and call the distance-based kernel.
    electrostatic_energy = compute_electrostatic_potential_ij_from_distij_gamma(
        l_D, dist_mat[i,j], lambda_electrostatic, gamma_e)
    pair_energy = burial_energy_i + burial_energy_j + direct_energy +\
                  protein_energy + water_energy + electrostatic_energy
    return pair_energy
compute_burial_potential_i_from_rho(i, rho_i, + lambda_burial, burial_gamma, + seq_index) + burial_energy_j = compute_burial_potential_i_from_rho(j, rho_j, + lambda_burial, burial_gamma, + seq_index) + electrostatic_energy = compute_electrostatic_potential_ij(i, j, l_D, + dist_mat, + lambda_electrostatic, electrostatic_gamma, + seq_index) + pair_energy = burial_energy_i + burial_energy_j + direct_energy +\ + protein_energy + water_energy + electrostatic_energy + return pair_energy +""" +# +""" +@njit(signature_or_function=float64(int64, int64, float64, + float64[:,:], + float64, float64, + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + int64[:])) +def compute_pair_energy_ij_from_rho(i, j, l_D, + dist_mat, + rho_i, rho_j, + lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_burial, burial_gamma, + lambda_electrostatic, electrostatic_gamma, + seq_index): + """""" + Compute the "pair energy" for residues i and j, defined as the sum of: + - Direct contact energy + - Protein-mediated contact energy + - Water-mediated contact energy + - Burial energies for both residues + - Electrostatic interaction energy, if requested + + This quantity is used in the calculation of the frustration index: + Frustration Index = -1 * (pair energy - DECOY_AVERAGE) / DECOY_STDEV + + Parameters + ---------- + See module-level docstring. 
@njit(signature_or_function=float64(int64, int64, float64, int64, int64[:], int64[:],
                                    float64[:,:],
                                    float64, float64[:,:],
                                    float64, float64[:,:],
                                    float64, float64[:,:],
                                    float64, float64[:,:],
                                    float64, float64[:,:],
                                    int64[:]))
def compute_pair_energy_ij(i, j, l_D, min_seq_sep_rho, chain_starts, chain_ends,
                           dist_mat,
                           lambda_direct, direct_gamma,
                           lambda_protein, protein_gamma,
                           lambda_water, water_gamma,
                           lambda_burial, burial_gamma,
                           lambda_electrostatic, electrostatic_gamma,
                           seq_index):
    """
    "Pair energy" of residues i and j: both burial energies plus the direct,
    protein-mediated, water-mediated and electrostatic terms of the pair.

    This quantity is used in the calculation of the frustration index:
    Frustration Index = -1 * (pair energy - DECOY_AVERAGE) / DECOY_STDEV

    No mask is applied here.

    Parameters
    ----------
    See module-level docstring.

    Returns
    -------
    pair_energy : float
        Total "pair energy" of residues i and j
    """
    # single-residue terms
    e_burial_i = compute_burial_potential_i(i, min_seq_sep_rho, chain_starts, chain_ends,
                                            dist_mat, lambda_burial, burial_gamma, seq_index)
    e_burial_j = compute_burial_potential_i(j, min_seq_sep_rho, chain_starts, chain_ends,
                                            dist_mat, lambda_burial, burial_gamma, seq_index)
    # pairwise terms
    e_direct = compute_direct_potential_ij(i, j, dist_mat, lambda_direct, direct_gamma, seq_index)
    e_protein, e_water = compute_long_potentials_ij(
        i, j, min_seq_sep_rho, chain_starts, chain_ends, dist_mat,
        lambda_protein, protein_gamma, lambda_water, water_gamma, seq_index)
    e_electrostatic = compute_electrostatic_potential_ij(
        i, j, l_D, dist_mat, lambda_electrostatic, electrostatic_gamma, seq_index)
    # An equivalent formulation would look up each gamma via seq_index here
    # and delegate to compute_pair_energy_ij_from_gamma.
    return e_burial_i + e_burial_j + e_direct +\
           e_protein + e_water + e_electrostatic
dist_mat.shape[1] + num_aa = dist_mat.shape[0] + num_aa_types = burial_gamma.shape[0] + assert burial_gamma.shape[1] == 3 + h = np.zeros((num_aa, num_aa_types)) + for i in prange(num_aa): + for q in range(num_aa_types): + gamma = burial_gamma[q] + h[i,q] = compute_burial_potential_i_from_gamma( + i, min_seq_sep_rho, chain_starts, chain_ends, dist_mat, lambda_burial, gamma) + h = -h # i guess we define it as the negative of the actual potential? + return h +compute_potts_model_h_parallel = njit(signature_or_function=signature, parallel=True)(compute_potts_model_h) +compute_potts_model_h = njit(signature_or_function=signature)(compute_potts_model_h) + + +def potts_model_functions( l_D, min_seq_sep_rho, min_seq_sep_contact, min_seq_sep_electrostatic, + chain_starts, chain_ends, max_dist_contact, max_dist_electrostatic): + + signature = float64[:,:,:,:]( + float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:]) + def compute_potts_model_J(dist_mat, lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_electrostatic, electrostatic_gamma): + return J + return compute_potts_model_J + +compute_potts_model_J_function = potts_model_functions(l_D, min_seq_sep_rho, min_seq_sep_contact, min_seq_sep_electrostatic, + chain_starts, chain_ends, max_dist_contact, max_dist_electrostatic) + +compute_potts_model_J_function(dist_mat, lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_electrostatic, electrostatic_gamma) + + + +# +signature = float64[:,:,:,:]( + float64, int64, int64, int64, + int64[:], int64[:], float64, float64, + float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:]) +def compute_potts_model_J( + l_D, min_seq_sep_rho, min_seq_sep_contact, min_seq_sep_electrostatic, + chain_starts, chain_ends, max_dist_contact, max_dist_electrostatic, + dist_mat, + 
lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_electrostatic, electrostatic_gamma): + # check input + assert dist_mat.shape[0] == dist_mat.shape[1] + assert direct_gamma.shape[0] == direct_gamma.shape[1] + assert direct_gamma.shape == protein_gamma.shape == water_gamma.shape == electrostatic_gamma.shape + num_aa = dist_mat.shape[0] + num_aa_types = direct_gamma.shape[0] + # precompute rho + rho_array = np.zeros(num_aa) + for i in prange(num_aa): + rho_array[i] = compute_rho_i(i, min_seq_sep_rho, chain_starts, chain_ends, dist_mat) + J = np.zeros((num_aa, num_aa, num_aa_types, num_aa_types)) + for i in prange(num_aa): + for j in range(num_aa): + if i==j: + J[i,j,:,:] = 0.0 + continue + dist_ij = dist_mat[i,j] + rho_i = rho_array[i] + rho_j = rho_array[j] + same_chain = check_same_chain(i, j, chain_starts, chain_ends) + contact_mask_ij = mask_of_pair(min_seq_sep=min_seq_sep_contact, seq_sep=abs(j-i), + min_dist=0.0, max_dist=max_dist_contact, + same_chain=same_chain, dist_ij=dist_ij) + electrostatic_mask_ij = mask_of_pair(min_seq_sep_electrostatic, abs(j-i), 0, max_dist_electrostatic, same_chain, dist_ij) + for qi in range(num_aa_types): + for qj in range(qi, num_aa_types): + gamma_dij = direct_gamma[qi,qj] + gamma_pij = protein_gamma[qi,qj] + gamma_wij = water_gamma[qi,qj] + gamma_eij = electrostatic_gamma[qi,qj] + direct_energy = compute_direct_potential_ij_from_distij_gamma(dist_ij, lambda_direct, gamma_dij) + protein_energy, water_energy = compute_long_potentials_ij_from_rho_distij_gamma( + dist_ij, rho_i, rho_j, lambda_protein, gamma_pij, lambda_water, gamma_wij) + contact_energy = contact_mask_ij * (direct_energy + protein_energy + water_energy) + electrostatic_energy = electrostatic_mask_ij * compute_electrostatic_potential_ij_from_distij_gamma( + l_D, dist_ij, lambda_electrostatic, gamma_eij) + # + energy = contact_energy + electrostatic_energy + J[i,j,qi,qj] = energy + J[i,j,qj,qi] = energy + J = -J # i 
guess we define it as the negative of the actual potential? + return J +compute_potts_model_J_parallel = njit(signature_or_function=signature, parallel=True)(compute_potts_model_J) +compute_potts_model_J = njit(signature_or_function=signature)(compute_potts_model_J) +# +######################################################################################## +# PAIR ENERGY MATRIX FOR FRUSTRATION CALCULATIONS -- NOT SURE WHAT TO DO WITH THIS. MIGHT DELETE +""" +signature = float64[:,:](float64, + int64, int64, + int64[:], int64[:], float64, + float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + float64, float64[:,:], + int64[:]) +def compute_pair_energy_matrix(l_D, + min_seq_sep_rho, min_seq_sep_frust_index, + chain_starts, chain_ends, max_dist, + dist_mat, + lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_burial, burial_gamma, + lambda_electrostatic, electrostatic_gamma, + seq_index): + """""" + Make matrix of the same shape as the distance matrix, + where each element is the pair energy, or np.nan if masked. 
+ + Parameters + ---------- + See module-level docstring + + Returns + ------- + pair_energy_matrix : np.array(dist_mat.shape) + matrix where the element (i,j) is the pair energy of (i,j) + (if unmasked) or np.nan (if masked) + """""" + # Pre-compute rho for all residues + num_res = dist_mat.shape[0] + rho_array = np.zeros(num_res) + for i in prange(num_res): + rho_array[i] = compute_rho_i(i, min_seq_sep_rho, chain_starts, chain_ends, dist_mat) + # fill in the matrix + num_res = dist_mat.shape[0] + pair_energy_matrix = np.empty((num_res, num_res)) + for i in prange(num_res): + for j in range(i,num_res): + # check mask + same_chain = check_same_chain(i, j, chain_starts, chain_ends) + # the idea is that this is the matrix we'll use to calculate frustration indices, + # so we set the minimum distance to 0 (as it always is for frustration calculations) + # and let the maximum distance be a variable + unmasked = mask_of_pair(min_seq_sep_frust_index, abs(i-j), 0.0, max_dist, + same_chain, dist_mat[i,j],) + if unmasked: + pair_energy_matrix[i,j] = compute_pair_energy_ij_from_rho( + i, j, l_D, + dist_mat, + rho_array[i], rho_array[j], + lambda_direct, direct_gamma, + lambda_protein, protein_gamma, + lambda_water, water_gamma, + lambda_burial, burial_gamma, + lambda_electrostatic, electrostatic_gamma, + seq_index) + else: + pair_energy_matrix[i,j] = np.nan + pair_energy_matrix[j,i] = pair_energy_matrix[i,j] + return pair_energy_matrix +compute_pair_energy_matrix_parallel = njit(signature_or_function=signature, parallel=True)(compute_pair_energy_matrix) +compute_pair_energy_matrix = njit(signature_or_function=signature)(compute_pair_energy_matrix) +""" + +#def get_args_for_numba(self, target): +# name_dict = {'l_D' : self.electrostatics_screening_length, +# # etc. 
} +# needed_args = inspect.get_args(target) +# return name_dict[needed_args] \ No newline at end of file diff --git a/frustratometer/optimization/EnergyTerm.py b/frustratometer/optimization/EnergyTerm.py index 6cd87bc6..db6f558a 100644 --- a/frustratometer/optimization/EnergyTerm.py +++ b/frustratometer/optimization/EnergyTerm.py @@ -216,8 +216,13 @@ def __sub__(self, other): def __truediv__(self, other): new_energy_term = EnergyTerm() + if isinstance(other, EnergyTerm): new_energy_term.use_numba = self.use_numba and other.use_numba + #new_energy_term.total_energies = other.total_energies + #new_energy_term.consider_total_energies = other.consider_total_energies + #new_energy_term.stds = other.stds + #new_energy_term.consider_stds = other.consider_stds e1=self.energy_function; e2=other.energy_function m1=self.denergy_mutation_function; m2=other.denergy_mutation_function s1=self.denergy_swap_function; s2=other.denergy_swap_function diff --git a/frustratometer/optimization/inner_product.py b/frustratometer/optimization/inner_product.py index f8421890..a619bd23 100644 --- a/frustratometer/optimization/inner_product.py +++ b/frustratometer/optimization/inner_product.py @@ -156,19 +156,34 @@ def compute_region_means_1_by_2(indicator_0, indicator_1): @jit(types.Array(types.float64, 1, 'C')(types.Array(types.float64, 1, 'A', readonly=True), types.Array(types.float64, 1, 'A', readonly=True)), nopython=True, cache=True) def compute_region_means_1_by_1(indicator_0, indicator_1): + # indicator_0: an element of an indicator1D list + # (so either low, med, or high density burial, + # with length equal to the number of residues in the protein) + # indicator_1: also an element of an indicator1D list + # (may be identical to indicator_0) + # in other words, these are 1D numpy arrays with axis length equal to the + # number of residues in the protein + # + # the calculation of region_mean n = indicator_0.shape[0] region_sum = np.zeros(2, dtype=np.float64) region_count = np.zeros(2, 
dtype=np.int64) - (ij, ii) = range(2) + (ij, ii) = range(2) # so ij=0, ii=1 + # ii: covariance between burial indicators (varying the density well) + # for a single residue + # ij: covariance between burial indicators (varying the density well) + # for a pair of residues for i in range(n): region_sum[ii] += indicator_0[i] * indicator_1[i] region_count[ii]=n region_mean = np.zeros(2, dtype=np.float64) if region_count[ii] > 0: - region_mean[ii] = region_sum[ii] / region_count[ii] + region_mean[ii] = region_sum[ii] / region_count[ii] # inner product of indicators / number of residues if n>1: + # it looks like region_sum[0] is always zero, so region_sum.sum()==region_sum[1]==region_sum[ii] region_mean[ij]=indicator_0.mean()*indicator_1.mean()*(n/(n - 1))-region_sum.sum()/(n*(n-1)) - return region_mean + return region_mean # (product of means - normalized dot product of indicators, + # normalized dot product of indicators) @jit(types.Array(types.float64, 2, 'C')(types.Array(types.int64, 1, 'A', readonly=True), types.Array(types.float64, 1, 'A', readonly=True)),nopython=True, cache=True) def mean_inner_product_2_by_2(repetitions,region_mean): @@ -295,10 +310,42 @@ def mean_inner_product_1_by_2(repetitions,region_mean): @jit(types.Array(types.float64, 2, 'C')(types.Array(types.int64, 1, 'A', readonly=True), types.Array(types.float64, 1, 'A', readonly=True)),nopython=True, cache=True) def mean_inner_product_1_by_1(repetitions,region_mean): - ij, ii = range(2) + # This function computes a 20x20 block of the matrix , + # which represents the covariances between different classes + # of burial indicator functions (the 1-body terms). + # The 20x20 block may or may not be centered on the main diagonal of the full matrix. 
+ # In the case that the 20x20 block is centered on the main diagonal, + # it represents the variance of each amino acid type within + # a burial indicator class (low, medium, or high), + # AND the covariances between each combination of amino acid types + # within this same indicator class. + # In the case that the 20x20 block is not centered on the main diagonal, + # it represents the covariances of each amino acid type across + # two burial indicator classes (low-med, low-high, or med-high), + # AND the covariances between all combinations of amino acid types + # between those two indicator classes. + # repetitions: number of amino acids of each type (so probably shape (20,)) + # This is a parameter of build_mean_inner_product_matrix, + # which is the only important function that calls this function. + # The repetitions argument is not modified by build_mean_inner_product_matrix + # before or after this function is called; in other words, it is passed straight through. + # region_mean: 1D float64 array indexed by (ij, ii) = (0, 1) + # The build_mean_inner_product_matrix function has a parameter called region_means + # that is indexed during the call to this function. So it must be that region_mean + # is specific to this particular combination of burial indicator classes + # (low-low, low-med, low-high, med-med, med-high, or high-high). + # Digging deeper, we find that the region_means passed to build_mean_inner_product_matrix + # comes from the output of compute_all_region_means, which repeatedly + # calls compute_region_means_1_by_1, compute_region_means_1_by_2, and + # compute_region_means_2_by_2 to populate a 3D array of shape (num_matrices, num_matrices, 15). 
So, to understand + # the region_mean parameter of this function, we should look at + # compute_region_means_1_by_1 for different burial indicator class arguments + # (low-low, low-med, low-high, med-med, med-high, or high-high) + + ij, ii = range(2) # so ij=0, ii=1 n=repetitions - n_elements= len(repetitions) + n_elements= len(repetitions) # number of amino acid types mean_inner_product = np.zeros(n_elements**2) @@ -308,33 +355,40 @@ def mean_inner_product_1_by_1(repetitions,region_mean): if i==j: #ii mean_inner_product[id]=n[i]*region_mean[ii]+n[i]*(n[i]-1)*region_mean[ij] else: #ij + # for different amino acid types, we scale the average value by the number of each type mean_inner_product[id]=n[i]*n[j]*region_mean[ij] + # this return value has to be the outer product of the indicator function vector + # (weighted by number of contacts of each type), averaged element-wise return mean_inner_product.reshape(n_elements, n_elements) -@jit(types.Array(types.float64, 2, 'C')( - types.Array(types.int64, 1, 'A', readonly=True), - types.Array(types.float64, 2, 'A', readonly=True), - types.Array(types.float64, 3, 'A', readonly=True), - types.Array(types.float64, 3, 'A', readonly=True)), - nopython=True, parallel=False, cache=True) +#@jit(types.Array(types.float64, 2, 'C')( +# types.Array(types.int64, 1, 'A', readonly=True), +# types.Array(types.float64, 2, 'A', readonly=True), +# types.Array(types.float64, 3, 'A', readonly=True), +# types.Array(types.float64, 3, 'A', readonly=True)), +# nopython=True, parallel=False, cache=True) def build_mean_inner_product_matrix(repetitions, indicators1d, indicators2d, region_means): + # repetitions: number of amino acids of each type (so probably shape (20,)) + # indicators1D: list of 3 elements (low density, medium density, high density) + # indicators2D: list of 3 or 4 elements (dir, prot, wat, possibly elec) + num_matrices1d = len(indicators1d) num_matrices2d = len(indicators2d) n_elements = len(repetitions) - num_matrices = 
num_matrices1d + num_matrices2d + num_matrices = num_matrices1d + num_matrices2d # probably equal to 6 or 7 # Compute the size of each block and the total size - block_sizes = np.empty(num_matrices, dtype=np.int64) + block_sizes = np.empty(num_matrices, dtype=np.int64) # creates an array without setting elements block_sizes[:num_matrices1d] = n_elements block_sizes[num_matrices1d:] = n_elements**2 + # at this point, block_sizes looks something like [20,20,20,400,400,400] total_size = np.sum(block_sizes) - # Create the resulting matrix filled with zeros + # Create the resulting matrix (which is returned by this function) filled with zeros R = np.zeros((total_size, total_size)) # Compute the starting indices for each matrix - #start_indices = np.cumsum([0] + block_sizes[:-1]) start_indices=np.zeros(len(block_sizes),dtype=np.int64) start=0 for i in range(1,len(block_sizes)): @@ -365,21 +419,33 @@ def build_mean_inner_product_matrix(repetitions, indicators1d, indicators2d, reg if i != j: R[sj:ej, si:ei] = R[si:ei, sj:ej].T + # if we have i==j, then the transposed region is the original region and there's nothing to fill in - return R + return R # The average (over sequence shuffles) of the outer product + # of the vector formed from the set of all indicator types. 
+ # The shuffling average was performed by multiplying by the + # proportion of amino acids in the sequence by each indicator type -@jit(types.Array(types.float64, 3, 'C')( - types.Array(types.float64, 2, 'A', readonly=True), - types.Array(types.float64, 3, 'A', readonly=True)), - nopython=True, cache=True) +#@jit(types.Array(types.float64, 3, 'C')( +# types.Array(types.float64, 2, 'A', readonly=True), +# types.Array(types.float64, 3, 'A', readonly=True)), +# nopython=True, cache=True) def compute_all_region_means(indicators1d, indicators2d): - num_matrices1d = len(indicators1d) - num_matrices2d = len(indicators2d) - num_matrices = num_matrices1d + num_matrices2d + """indicators1d: burial indicators, in the order of low, medium, high + indicators2d: contact indicators, in the order of direct, protein, water + Each array has axis length(s) equal to the number of residues in the protein""" + num_matrices1d = len(indicators1d) # 3 (low density, med density, high density) + num_matrices2d = len(indicators2d) # 3 or 4 (dir, prot, wat, possibly elec) + num_matrices = num_matrices1d + num_matrices2d # Create the resulting matrix filled with zeros R = np.zeros((num_matrices,num_matrices,15),dtype=np.float64) + # 15 deep because we need to unpack 15 return values in the cases + # where i and j correspond to 2d matrices; in the cases that + # i and j represent 1 or 0 2d matrices, we won't need to unpack + # as many return values, so we'll just fill in the first few elements + # of the third axis and the rest will remain as 0 for ij in prange(num_matrices**2): i=ij//num_matrices diff --git a/frustratometer/optimization/optimization.py b/frustratometer/optimization/optimization.py index 0bdc5e59..d4daa6ca 100644 --- a/frustratometer/optimization/optimization.py +++ b/frustratometer/optimization/optimization.py @@ -4,10 +4,12 @@ import csv from functools import wraps from datetime import datetime +import copy +from frustratometer import frustration from frustratometer.classes 
import Frustratometer from frustratometer.classes import Structure -from frustratometer.classes import AWSEM +from frustratometer.classes import AWSEM, AWSEMIndicators, DecoyEnsemble, AWSEMVariancePotts from frustratometer.optimization.EnergyTerm import EnergyTerm from frustratometer.optimization.inner_product import compute_all_region_means from frustratometer.optimization.inner_product import build_mean_inner_product_matrix @@ -303,6 +305,8 @@ def compute_energy(seq_index: np.array) -> float: energy_J -= model_J[i, j, aa_i, aa_j] * mask[i, j] total_energy = energy_h + energy_J / 2 + #with open('energies.txt','a') as f: + # f.write(f"{total_energy}\n") return total_energy def compute_denergy_mutation(seq_index: np.ndarray, pos: int, aa_new: int) -> float: @@ -365,6 +369,172 @@ def regression_test(self): energy=self.compute_energy(seq_index) assert np.isclose(energy,expected_energy), f"Expected energy {expected_energy} but got {energy}" +#@numba.experimental.jitclass([('_use_numba',numba.float32),('std',numba.float32),('total_energies',numba.float32[:])]) +class AwsemStdSlow(EnergyTerm): + """ Computes the standard deviation of the AWSEM energies of a set of decoy structures + by computing the energy of each decoy structure and then computing the std of the energies + """ + def __init__(self, all_burial, all_direct, all_prot, all_wat, all_elec, sequence, + alphabet=_AA, use_numba=True, **parameters): + + self._use_numba=use_numba + self.alphabet=alphabet + + self.models_h = [] + self.models_J = [] + for burial, direct, prot, wat, elec in zip(all_burial, all_direct, all_prot, all_wat, all_elec): + model = AWSEMIndicators(burial, direct, prot, wat, elec, sequence, **parameters) + self.models_h.append(model.potts_model['h']) + self.models_J.append(model.potts_model['J']) + self.mask = model.mask # should be the same for all, so we put this outside the loop + + if alphabet!=_AA: + raise NotImplementedError("Reindex your potts models according to your alphabet") + 
self.reindex_dca=[_AA.index(aa) for aa in alphabet] + self.model_h=self.model_h[:,self.reindex_dca] + self.model_J=self.model_J[:,:,self.reindex_dca][:,:,:,self.reindex_dca] + + self.stds = [] + self.total_energies = [] + self.consider_stds = [] + self.consider_total_energies = [] + + self.initialize_functions() + + def initialize_functions(self): + mask=self.mask.copy() + models_h=self.models_h.copy() + models_J=self.models_J.copy() + + def compute_energy(seq_index: np.array) -> float: + seq_len = len(seq_index) + to_append = np.zeros(len(models_h)) # a new array for each seq index, with a length equal to the number of decoys + for counter, models in enumerate(zip(models_h, models_J)): + model_h = models[0] + model_J = models[1] + energy_h = 0.0 + energy_J = 0.0 + for i in range(seq_len): + energy_h -= model_h[i, seq_index[i]] + for i in range(seq_len): + for j in range(seq_len): + aa_i = seq_index[i] + aa_j = seq_index[j] + energy_J -= model_J[i, j, aa_i, aa_j] * mask[i, j] + decoy_energy = energy_h + energy_J / 2 + to_append[counter] = decoy_energy + + if len(self.total_energies) == 0: + self.total_energies.append(to_append) # this may result in the total_energies list being repeated a few times because compute_energy is called a few times + self.consider_total_energies.append(copy.deepcopy(to_append)) + else: + self.total_energies[0] = to_append + self.consider_total_energies[0] = copy.deepcopy(to_append) + #breakpoint() + + std = to_append.std() + if len(self.stds) == 0: + self.stds.append(std) # this may result in the list being too long because compute_energy is called a few times + self.consider_stds.append(std) + else: + self.stds[0] = std + self.consider_stds[0] = std + #std = np.array([1,2]).var()#total_energies.var() # doing variance for now because variances are additive + #self.stds.append(std) + return std + + def compute_denergy_mutation(seq_index: np.ndarray, pos: int, aa_new: int) -> float: + aa_old=seq_index[pos] + + for counter, models in 
enumerate(zip(models_h, models_J)): + model_h = models[0] + model_J = models[1] + #import pdb; pdb.set_trace() + energy_difference = -model_h[pos,aa_new] + model_h[pos,aa_old] + # Initialize j_correction to 0 + j_correction = 0.0 + # Manually iterate over the sequence indices + for idx in range(len(seq_index)): + aa_idx = seq_index[idx] # The amino acid at the current position + # Accumulate corrections for positions other than the mutated one + j_correction += model_J[idx, pos, aa_idx, aa_old] * mask[idx, pos] + j_correction -= model_J[idx, pos, aa_idx, aa_new] * mask[idx, pos] + # For self-interaction, subtract the old interaction and add the new one + j_correction -= model_J[pos, pos, aa_old, aa_old] * mask[pos, pos] + j_correction += model_J[pos, pos, aa_new, aa_new] * mask[pos, pos] + energy_difference += j_correction + # our mutation might be rejected, so we don't want to overwrite self.total_energies + self.consider_total_energies[0][counter] = self.total_energies[0][counter] + energy_difference + #print(f"energy difference: {energy_difference}") + #assert not np.all(np.array(self.total_energies[0])==np.array(self.consider_total_energies[0])) + #import pdb; pdb.set_trace() + new_std = self.consider_total_energies[0].std() + self.consider_stds[0] = new_std # our mutation might be rejected, so we don't want to overwrite self.stds + delta_std = new_std - self.stds[0] + #print(f"mutation: {delta_std}") + return delta_std + + def compute_denergy_swap(seq_index, pos1, pos2): + aa2 , aa1 = seq_index[pos1],seq_index[pos2] + + for counter, models in enumerate(zip(models_h, models_J)): + model_h = models[0] + model_J = models[1] + #Compute fields + energy_difference = 0 + energy_difference -= (model_h[pos1, aa1] - model_h[pos1, seq_index[pos1]]) # h correction aa1 + energy_difference -= (model_h[pos2, aa2] - model_h[pos2, seq_index[pos2]]) # h correction aa2 + + #Compute couplings + j_correction = 0.0 + for pos in range(len(seq_index)): + aa = seq_index[pos] + # 
Corrections for interactions with pos1 and pos2 + j_correction += model_J[pos, pos1, aa, seq_index[pos1]] * mask[pos, pos1] + j_correction -= model_J[pos, pos1, aa, aa1] * mask[pos, pos1] + j_correction += model_J[pos, pos2, aa, seq_index[pos2]] * mask[pos, pos2] + j_correction -= model_J[pos, pos2, aa, aa2] * mask[pos, pos2] + + # J correction, interaction with self aminoacids + j_correction -= model_J[pos1, pos2, seq_index[pos1], seq_index[pos2]] * mask[pos1, pos2] # Taken two times + j_correction += model_J[pos1, pos2, aa1, seq_index[pos2]] * mask[pos1, pos2] # Correction for incorrect addition in the for loop + j_correction += model_J[pos1, pos2, seq_index[pos1], aa2] * mask[pos1, pos2] # Correction for incorrect addition in the for loop + j_correction -= model_J[pos1, pos2, aa1, aa2] * mask[pos1, pos2] # Correct combination + energy_difference += j_correction + #import pdb; pdb.set_trace() + #self.total_energies[counter] += energy_difference + self.consider_total_energies[0][counter] = self.total_energies[0][counter] + energy_difference + #print(f"energy difference: {energy_difference}") + #breakpoint() + #import pdb; pdb.set_trace() + #assert not np.all(np.array(self.total_energies[0])==np.array(self.consider_total_energies[0])) + new_std = self.consider_total_energies[0].std() + self.consider_stds[0] = new_std + delta_std = new_std - self.stds[0] + #print(f"swap: {delta_std}") + return delta_std + + self.compute_energy=compute_energy + self.compute_denergy_mutation=compute_denergy_mutation + self.compute_denergy_swap=compute_denergy_swap + +class FourBodyPottsModel(EnergyTerm): + """ Potts model with 3-body and 4-body terms. + This is mainly the same as the AWSEMEnergy class, but 2 important differences: + 1. We don't recognize any mask that may be associated with the "model" input + that's because this class is intended to evaluate changes in a covariance matrix, + which should not be masked + 2. 
By definition, all coefficients are 1: "Energy" = h + J + K + L + unlike real interactions, the sum over the entire covariance matrix + isn't double counting, so we don't need to multiply by 1/2 + """ + def __init__(self, model:Frustratometer, alphabet=_AA, use_numba=True): + self._use_numba=use_numba + self.model=model + self.alphabet=alphabet + self.model_h = model.potts_model['h'] + self.model_J = model.potts_model['J'] + class AwsemEnergyAverage(EnergyTerm): def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA): self._use_numba=use_numba @@ -373,7 +543,7 @@ def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA): self.reindex_dca=[_AA.index(aa) for aa in alphabet] assert "indicators" in model.__dict__.keys(), "Indicator functions were not exposed. Initialize AWSEM function with `expose_indicator_functions=True` first." - self.indicators = model.indicators + self.indicators = model.masked_indicators self.alphabet_size=len(alphabet) self.model=model self.model_h = model.potts_model['h'][:,self.reindex_dca] @@ -382,7 +552,7 @@ def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA): self.indicators1D=np.array([ind for ind in self.indicators if len(ind.shape)==1]) self.indicators2D=np.array([ind for ind in self.indicators if len(ind.shape)==2]) #TODO: Fix the gamma matrix to account for elecrostatics - self.gamma = np.concatenate([(a[self.reindex_dca].ravel() if len(a.shape)==1 else a[self.reindex_dca][:,self.reindex_dca].ravel()) for a in model.gamma_array]) + self.gamma = np.concatenate([(a[self.reindex_dca].ravel() if len(a.shape)==1 else a[self.reindex_dca][:,self.reindex_dca].ravel()) for a in model.coefficient_lambda_gamma_array]) self.initialize_functions() @@ -512,15 +682,287 @@ def regression_test(self, seq_index): energy=self.compute_energy(seq_index) assert np.isclose(energy,expected_energy), f"Expected energy {expected_energy} but got {energy}" -class AwsemEnergyVariance(EnergyTerm): +class 
PairEnergyAverage(EnergyTerm): + """ + Computes the average pairwise energy for a given sequence. + This class is designed to compute the average pairwise energy of a sequence + using the AWSEM model. It calculates the energy contributions from pairwise interactions + between amino acids in the sequence, averaged over all possible pairs. + """ def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA): self._use_numba=use_numba self.model=model self.alphabet=alphabet self.reindex_dca=[_AA.index(aa) for aa in alphabet] + assert "indicators" in model.__dict__.keys(), "Indicator functions were not exposed. Initialize AWSEM function with `expose_indicator_functions=True` first." + self.indicators = model.masked_indicators + self.alphabet_size=len(alphabet) + self.model_h = model.potts_model['h'][:,self.reindex_dca] + self.model_J = model.potts_model['J'][:,:,self.reindex_dca][:,:,:,self.reindex_dca] + self.mask = model.mask + self.gamma = np.concatenate([(a[self.reindex_dca].ravel() if len(a.shape)==1 else a[self.reindex_dca][:,self.reindex_dca].ravel()) for a in model.coefficient_lambda_gamma_array]) + self.initialize_functions() + + def initialize_functions(self): + len_alphabet=self.alphabet_size + + distances = np.triu(self.model.distance_matrix) + ########################################################################################### + distances = distances[(distances0)] # USE THIS NORMALLY + #distances = distances[distances>0] # USE THIS FOR TESTING THING WHERE WE NEED ALL PAIRS + ########################################################################################### + len_distances = len(distances) + + rho_b = np.expand_dims(self.model.rho_r, 1) #(n,1) + rho1 = np.expand_dims(self.model.rho_r, 0) #(1,n) + rho2 = np.expand_dims(self.model.rho_r, 1) #(n,1) + + sigma_water = 0.25 * (1 - np.tanh(self.model.eta_sigma * (rho1 - self.model.rho_0))) * (1 - np.tanh(self.model.eta_sigma * (rho2 - self.model.rho_0))) #(n,n) + sigma_protein = 1 - 
sigma_water #(n,n) + + #Calculate theta and indicators + theta = 0.25 * (1 + np.tanh(self.model.eta * (distances - self.model.r_min))) * (1 + np.tanh(self.model.eta * (self.model.r_max - distances))) # (c,) + thetaII = 0.25 * (1 + np.tanh(self.model.eta * (distances - self.model.r_minII))) * (1 + np.tanh(self.model.eta * (self.model.r_maxII - distances))) #(c,) + burial_indicator = np.tanh(self.model.burial_kappa * (rho_b - self.model.burial_ro_min)) + np.tanh(self.model.burial_kappa * (self.model.burial_ro_max - rho_b)) #(n,3) + # gap has 0 charge + # gap, A,C,D, E, F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y + charges = np.array([0, 0,0,-1,-1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0]) + charges = charges[self.reindex_dca] # remove unused gap, C, and P + electrostatics_indicator = np.exp(-distances / self.model.electrostatics_screening_length) / distances + + N = self.model.N + k_contact = self.model.k_contact + burial_gamma = self.model.burial_gamma[self.model.aa_map_awsem_list][self.reindex_dca] + direct_gamma = self.model.direct_gamma[self.model.aa_map_awsem_x, self.model.aa_map_awsem_y][self.reindex_dca][:,self.reindex_dca] + water_gamma = self.model.water_gamma[self.model.aa_map_awsem_x, self.model.aa_map_awsem_y][self.reindex_dca][:,self.reindex_dca] + protein_gamma = self.model.protein_gamma[self.model.aa_map_awsem_x, self.model.aa_map_awsem_y][self.reindex_dca][:,self.reindex_dca] + k_contact = self.model.k_contact + k_electrostatics = self.model.k_electrostatics + mask = self.mask + sequence_mask_contact = self.model.sequence_mask_contact + assert np.all(mask==mask.T), "Mask should be symmetric" + assert mask.shape == (N, N), f"Mask shape {mask.shape} does not match expected shape {(N, N)}" + + n_decoys=9000 + + # these lines used for the random sampling and analytic calculation + foo = np.triu(self.model.distance_matrix) + sigma_water = sigma_water[(foo0)] + sigma_protein = sigma_protein[(foo0)] + + def compute_energy(seq_index): + # adapted from 
AWSEM.compute_configurational_decoy_statistics, + # modified to be numba-friendly + + # analytic calculation + aa_freq = np.array([(seq_index == i).sum() for i in range(len_alphabet)]) # frequency of each amino acid in the sequence + + temp = burial_gamma * aa_freq[:, np.newaxis] # (20,3) * (20,1) -> (20,3) + scaled_burial_gamma = np.zeros((3,)) # burial gamma is a 3D vector, one for each burial indicator (low, med, high) + for counter in range(temp.shape[0]): + scaled_burial_gamma += temp[counter] + scaled_burial_gamma /= temp.shape[0] # average burial gammas, weighted by amino acid frequencies + #burial_energy = np.average(-1 * k_contact * scaled_burial_gamma * burial_indicator) + avg_burial_indicator = np.zeros((3,)) + for counter in range(burial_indicator.shape[0]): + avg_burial_indicator += burial_indicator[counter] + avg_burial_indicator /= burial_indicator.shape[0] # average burial indicator + burial_energy = -1* k_contact * avg_burial_indicator * scaled_burial_gamma + burial_energy = (burial_energy[0] + burial_energy[1] + burial_energy[2])/(N-1)#*N #/(N-1) # sum over the three burial indicators + #breakpoint() + #assert type(burial_energy) == float + # direct, water-mediated, and protein-mediated contact energies + direct = np.average(theta) * np.average(direct_gamma * np.outer(aa_freq, aa_freq)) + #assert type(direct) == float + water_mediated = np.average(thetaII*sigma_water) * np.average(water_gamma * np.outer(aa_freq, aa_freq)) + #assert type(water_mediated) == float + protein_mediated = np.average(thetaII*sigma_protein) * np.average(protein_gamma * np.outer(aa_freq, aa_freq)) + #assert type(protein_mediated) == float + #contact_energy = -k_contact * (direct*len(theta) + (water_mediated+protein_mediated)*len(thetaII)) # multiply by number of contacts + contact_energy = -k_contact * (direct + (water_mediated+protein_mediated)) # multiply by number of contacts + #electrostatics_energy = k_electrostatics * np.average(electrostatics_indicator) * 
np.average(np.outer(aa_freq, aa_freq)*charges[:, np.newaxis]*charges[np.newaxis, :]) * len(electrostatics_indicator) # multiply by number of contacts + electrostatics_energy = k_electrostatics * np.average(electrostatics_indicator) * np.average(np.outer(aa_freq, aa_freq)*charges[:, np.newaxis]*charges[np.newaxis, :]) + #assert type(electrostatics_energy) == float + mean_decoy_energy = burial_energy + contact_energy + electrostatics_energy + #import pdb; pdb.set_trace() + + """# constructing the distribution by sampling, then computing the average + decoy_energies=np.zeros(n_decoys) + for i in range(n_decoys): + c=np.random.randint(0,len_distances) + n1=np.random.randint(0,N) + n2=np.random.randint(0,N) + qi1=np.random.randint(0,N) + qi2=np.random.randint(0,N) + q1=seq_index[qi1] + q2=seq_index[qi2] + + burial_energy1 = (-0.5 * k_contact * burial_gamma[q1] * burial_indicator[n1]).sum(axis=0) + burial_energy2 = (-0.5 * k_contact * burial_gamma[q2] * burial_indicator[n2]).sum(axis=0) + burial_energy = (burial_energy1+burial_energy2)/(N-1) # normalize because double counting carlos thing + + direct = theta[c] * direct_gamma[q1, q2] + water_mediated = sigma_water[n1,n2] * thetaII[c] * water_gamma[q1,q2] + protein_mediated = sigma_protein[n1,n2] * thetaII[c] * protein_gamma[q1,q2] + contact_energy = -k_contact * (direct+water_mediated+protein_mediated) + electrostatics_energy = k_electrostatics * electrostatics_indicator[c]*charges[q1]*charges[q2] + + decoy_energies[i]=(burial_energy+contact_energy+electrostatics_energy) + mean_decoy_energy = np.mean(decoy_energies) + #std_decoy_energy = np.std(decoy_energies) + """ + """# checking that these energy functions are able to compute the total energy of the sequence + assert len(distances) == (N**2-N)/2, f"len(distances): {len(distances)} != (N**2-N)/2: {(N**2-N)/2}" + decoy_energies = np.zeros((len(seq_index)**2 - len(seq_index))//2) + index = 0 + #breakpoint() + for i in range(len(seq_index)): + aa1 = seq_index[i] + for j 
in range(i+1, len(seq_index)): + aa2 = seq_index[j] + + burial_energy1 = (-0.5 * k_contact * burial_gamma[aa1] * burial_indicator[i]).sum(axis=0) + burial_energy2 = (-0.5 * k_contact * burial_gamma[aa2] * burial_indicator[j]).sum(axis=0) + burial_energy = (burial_energy1 + burial_energy2) / ((N - 1)) #/ 2) + + direct = theta[index] * direct_gamma[aa1, aa2] + water_mediated = sigma_water[i, j] * thetaII[index] * water_gamma[aa1, aa2] + protein_mediated = sigma_protein[i, j] * thetaII[index] * protein_gamma[aa1, aa2] + contact_energy = -k_contact * (direct+water_mediated+protein_mediated)*mask[i, j]*sequence_mask_contact[i, j] + electrostatics_energy = k_electrostatics * electrostatics_indicator[index]*charges[aa1]*charges[aa2]*mask[i,j] + decoy_energies[index] = contact_energy+burial_energy+electrostatics_energy#(contact_energy+electrostatics_energy)#(burial_energy + contact_energy + electrostatics_energy) + index += 1 + mean_decoy_energy = np.sum(decoy_energies) # for testing, we return the total energy, not the average + """ + + #import pdb; pdb.set_trace() + return mean_decoy_energy#, std_decoy_energy + + compute_energy_numba = self.numbify(compute_energy) + + def denergy_mutation(seq_index, pos, aa): + seq_index_new = seq_index.copy() + seq_index_new[pos] = aa + return compute_energy_numba(seq_index_new) - compute_energy_numba(seq_index) + + self.compute_energy = compute_energy + self.compute_denergy_mutation = denergy_mutation + + def regression_test(self, seq_index): + raise NotImplementedError("sorry") + +class PairEnergyStd(EnergyTerm): + def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA): + self._use_numba=use_numba + self.model=model + self.alphabet=alphabet + self.reindex_dca=[_AA.index(aa) for aa in alphabet] + assert "indicators" in model.__dict__.keys(), "Indicator functions were not exposed. Initialize AWSEM function with `expose_indicator_functions=True` first." 
+ self.indicators = model.masked_indicators + self.alphabet_size=len(alphabet) + self.model_h = model.potts_model['h'][:,self.reindex_dca] + self.model_J = model.potts_model['J'][:,:,self.reindex_dca][:,:,:,self.reindex_dca] + self.mask = model.mask + self.gamma = np.concatenate([(a[self.reindex_dca].ravel() if len(a.shape)==1 else a[self.reindex_dca][:,self.reindex_dca].ravel()) for a in model.coefficient_lambda_gamma_array]) + self.initialize_functions() + + def initialize_functions(self): + len_alphabet=self.alphabet_size + + distances = np.triu(self.model.distance_matrix) + ########################################################################################### + distances = distances[(distances0)] # USE THIS NORMALLY + #distances = distances[distances>0] # USE THIS FOR TESTING THING WHERE WE NEED ALL PAIRS + ########################################################################################### + len_distances = len(distances) + + rho_b = np.expand_dims(self.model.rho_r, 1) #(n,1) + rho1 = np.expand_dims(self.model.rho_r, 0) #(1,n) + rho2 = np.expand_dims(self.model.rho_r, 1) #(n,1) + + sigma_water = 0.25 * (1 - np.tanh(self.model.eta_sigma * (rho1 - self.model.rho_0))) * (1 - np.tanh(self.model.eta_sigma * (rho2 - self.model.rho_0))) #(n,n) + sigma_protein = 1 - sigma_water #(n,n) + + #Calculate theta and indicators + theta = 0.25 * (1 + np.tanh(self.model.eta * (distances - self.model.r_min))) * (1 + np.tanh(self.model.eta * (self.model.r_max - distances))) # (c,) + thetaII = 0.25 * (1 + np.tanh(self.model.eta * (distances - self.model.r_minII))) * (1 + np.tanh(self.model.eta * (self.model.r_maxII - distances))) #(c,) + burial_indicator = np.tanh(self.model.burial_kappa * (rho_b - self.model.burial_ro_min)) + np.tanh(self.model.burial_kappa * (self.model.burial_ro_max - rho_b)) #(n,3) + # gap has 0 charge + # gap, A,C,D, E, F,G,H,I,K,L,M,N,P,Q,R,S,T,V,W,Y + charges = np.array([0, 0,0,-1,-1,0,0,0,0,1,0,0,0,0,0,1,0,0,0,0,0]) + charges = 
charges[self.reindex_dca] # remove unused gap, C, and P + electrostatics_indicator = np.exp(-distances / self.model.electrostatics_screening_length) / distances + + N = self.model.N + k_contact = self.model.k_contact + burial_gamma = self.model.burial_gamma[self.model.aa_map_awsem_list][self.reindex_dca] + direct_gamma = self.model.direct_gamma[self.model.aa_map_awsem_x, self.model.aa_map_awsem_y][self.reindex_dca][:,self.reindex_dca] + water_gamma = self.model.water_gamma[self.model.aa_map_awsem_x, self.model.aa_map_awsem_y][self.reindex_dca][:,self.reindex_dca] + protein_gamma = self.model.protein_gamma[self.model.aa_map_awsem_x, self.model.aa_map_awsem_y][self.reindex_dca][:,self.reindex_dca] + k_contact = self.model.k_contact + k_electrostatics = self.model.k_electrostatics + mask = self.mask + sequence_mask_contact = self.model.sequence_mask_contact + assert np.all(mask==mask.T), "Mask should be symmetric" + assert mask.shape == (N, N), f"Mask shape {mask.shape} does not match expected shape {(N, N)}" + + n_decoys=9000 + #foo = np.triu(self.model.distance_matrix) + #sigma_water = sigma_water[(foo0)] + #sigma_protein = sigma_protein[(foo0)] + + def compute_energy(seq_index): + # constructing the distribution by sampling, then computing the average + decoy_energies=np.zeros(n_decoys) + for i in range(n_decoys): + c=np.random.randint(0,len_distances) + n1=np.random.randint(0,N) + n2=np.random.randint(0,N) + qi1=np.random.randint(0,N) + qi2=np.random.randint(0,N) + q1=seq_index[qi1] + q2=seq_index[qi2] + + burial_energy1 = (-0.5 * k_contact * burial_gamma[q1] * burial_indicator[n1]).sum(axis=0) + burial_energy2 = (-0.5 * k_contact * burial_gamma[q2] * burial_indicator[n2]).sum(axis=0) + burial_energy = (burial_energy1+burial_energy2)/(N-1) # normalize because double counting carlos thing + + direct = theta[c] * direct_gamma[q1, q2] + water_mediated = sigma_water[n1,n2] * thetaII[c] * water_gamma[q1,q2] + protein_mediated = sigma_protein[n1,n2] * thetaII[c] * 
protein_gamma[q1,q2] + contact_energy = -k_contact * (direct+water_mediated+protein_mediated) + electrostatics_energy = k_electrostatics * electrostatics_indicator[c]*charges[q1]*charges[q2] + + decoy_energies[i]=(burial_energy+contact_energy+electrostatics_energy) + + #mean_decoy_energy = np.mean(decoy_energies) + std_decoy_energy = np.std(decoy_energies) + return 1#std_decoy_energy + + compute_energy_numba = self.numbify(compute_energy) + + def denergy_mutation(seq_index, pos, aa): + seq_index_new = seq_index.copy() + seq_index_new[pos] = aa + return compute_energy_numba(seq_index_new) - compute_energy_numba(seq_index) + + self.compute_energy = compute_energy + self.compute_denergy_mutation = denergy_mutation + + def regression_test(self, seq_index): + raise NotImplementedError("sorry") + + +class AwsemEnergyVariance(EnergyTerm): + def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA): + self._use_numba=use_numba + self.model=model + self.alphabet=alphabet + self.reindex_dca=[_AA.index(aa) for aa in alphabet] + assert "indicators" in model.__dict__.keys(), "Indicator functions were not exposed. Initialize AWSEM function with `expose_indicator_functions=True` first." 
- self.indicators = model.indicators + self.indicators = model.masked_indicators self.alphabet_size=len(alphabet) self.model=model self.model_h = model.potts_model['h'][:,self.reindex_dca] @@ -529,7 +971,7 @@ def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA): self.indicators1D=np.array([ind for ind in self.indicators if len(ind.shape)==1]) self.indicators2D=np.array([ind for ind in self.indicators if len(ind.shape)==2]) #TODO: Fix the gamma matrix to account for elecrostatics - self.gamma = np.concatenate([(a[self.reindex_dca].ravel() if len(a.shape)==1 else a[self.reindex_dca][:,self.reindex_dca].ravel()) for a in model.gamma_array]) + self.gamma = np.concatenate([(a[self.reindex_dca].ravel() if len(a.shape)==1 else a[self.reindex_dca][:,self.reindex_dca].ravel()) for a in model.coefficient_lambda_gamma_array]) self.initialize_functions() @@ -697,6 +1139,7 @@ def regression_test(self, seq_index): energy=self.compute_energy(seq_index) assert np.isclose(energy,expected_energy), f"Expected energy {expected_energy} but got {energy}" + class AwsemEnergyStd(EnergyTerm): def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA, n_decoys=None): self._use_numba=use_numba @@ -707,7 +1150,7 @@ def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA, n_decoys= self.n_decoys=n_decoys assert "indicators" in model.__dict__.keys(), "Indicator functions were not exposed. Initialize AWSEM function with `expose_indicator_functions=True` first." 
- self.indicators = model.indicators + self.indicators = model.masked_indicators self.alphabet_size=len(alphabet) self.model=model self.model_h = model.potts_model['h'][:,self.reindex_dca] @@ -716,8 +1159,8 @@ def __init__(self, model:Frustratometer, use_numba=True, alphabet=_AA, n_decoys= self.indicators1D=np.array([ind for ind in self.indicators if len(ind.shape)==1]) self.indicators2D=np.array([ind for ind in self.indicators if len(ind.shape)==2]) #TODO: Fix the gamma matrix to account for elecrostatics - self.gamma = np.concatenate([(a[self.reindex_dca].ravel() if len(a.shape)==1 else a[self.reindex_dca][:,self.reindex_dca].ravel()) for a in model.gamma_array]) - + self.gamma = np.concatenate([(a[self.reindex_dca].ravel() if len(a.shape)==1 else a[self.reindex_dca][:,self.reindex_dca].ravel()) for a in model.coefficient_lambda_gamma_array]) + self.initialize_functions() def initialize_functions(self): @@ -726,7 +1169,8 @@ def initialize_functions(self): len_alphabet=self.alphabet_size phi_len= indicators1D.shape[0]*len_alphabet + indicators2D.shape[0]*len_alphabet**2 gamma=self.gamma - + rng = np.random.default_rng() + # Precompute the mean of the indicators indicator_means=np.zeros(len(indicators1D)+len(indicators2D)) c=0 @@ -750,10 +1194,13 @@ def compute_energy(seq_index): """ Function to compute the variance of the energy of permutations of a sequence using random shuffling. 
This function is much faster than compute_energy_permutation but is an approximation""" energies=np.zeros(n_decoys) - shuffled_index=seq_index.copy() + shuffled_index=seq_index.copy() + for _ in numba.prange(20): + to_replace = rng.integers(low=0,high=len(seq_index)) + shuffled_index[to_replace] = rng.integers(low=0,high=len_alphabet) for i in numba.prange(n_decoys): energies[i]=awsem_energy(shuffled_index[np.random.permutation(len(shuffled_index))]) - return np.var(energies) + return np.std(energies) else: def compute_energy(seq_index): counts = np.zeros(len_alphabet, dtype=np.int64) @@ -774,11 +1221,23 @@ def compute_energy(seq_index): for i in range(len_indicators2D): for j in range(len_alphabet): for k in range(len_alphabet): - t=1 if j==k else 0 - phi_mean[c] = indicator_means[i+ len_indicators1D] * counts[j] * (counts[k] - t) + t=1 if j==k else 0 # I don't know why we do this + phi_mean[c] = indicator_means[i+ len_indicators1D] * counts[j] * (counts[k] - t) c += 1 B = build_mean_inner_product_matrix(counts,indicators1D,indicators2D,region_means) + # B[i,j] - phi_mean[i]*phi_mean[j] is the covariance of some avg-indicator/gamma + # product i with some other avg-indicator/gamma product j + # + # we can think of computing the total variance (summing this matrix) + # as evaluating a potts model (covariances playing the role of "energies") that has + # 3 fields (the burial indicator function for each density bin) and + # 4 couplings (direct, prot, wat, and elec pairwise indicators) + # + # because we averaged over all indicators and they're playing the role of + # couplings and fields in our model, this is a "mean field" approach + # + energy=0 for i in range(phi_len): for j in range(phi_len): @@ -791,6 +1250,9 @@ def denergy_mutation(seq_index, pos, aa): seq_index_new = seq_index.copy() seq_index_new[pos] = aa return compute_energy_numba(seq_index_new) - compute_energy_numba(seq_index) + + def denergy_swap(seq_index, pos1, pos2): + return 0 self.compute_energy = 
compute_energy self.compute_denergy_mutation = denergy_mutation @@ -812,6 +1274,202 @@ def regression_test(self, seq_index): energy=self.compute_energy(seq_index) assert np.isclose(energy,expected_energy), f"Expected energy {expected_energy} but got {energy}" + +class AwsemEnergyStdFromCovMatrix(EnergyTerm): + def __init__(self, covariance_matrix: np.ndarray, + burial_gamma: np.ndarray, + direct_gamma: np.ndarray, + protein_gamma: np.ndarray, + water_gamma: np.ndarray, + electrostatics_gamma: np.ndarray, + use_numba = True, alphabet = _AA): + """ + covariance_matrix: np.ndarray + Covariance matrix of all __indicator functions___ (not residues) over a decoy set. + Should have the following structure: + ___________________________________________________________________________________________ + burial pairwise + _____________________________________________ + position 1 low | . | + ... | . | + position N low | . | + position 1 med | burial-burial . burial-pairwise | + burial ... | . | + position N med | . | + position 1 high | covariances . covariances | + ... | . | + position N high | . | + ------------------------------------------------------------------------------------------- + direct interaction 1 | . | + ... | . | + direct interaction (N**2-N)/2 | . | + prot interaction 1 | . | + ... | . | + pairwise prot interaction (N**2-N)/2 | burial-pairwise . pairwise-pairwise | + wat interaction 1 | covariances . covariances | + ... | . | + wat interaction (N**2-N)/2 | . | + elec interaction 1 | . | + ... | . | + elec interaction (N**2-N)/2 | . | + ___________________________________________________________________________________________| + + This matrix should have the same form as the covariance matrix that would be passed into + AwsemVariancePotts, if we were doing things that way. + + gamma arrays: INPUTS MUST BE REINDEXED according to the alphabet used! 
+ + """ + # check input + if not len(covariance_matrix.shape) == 2: + raise ValueError(f"covariance_matrix must have dimension 2 but was {len(covariance_matrix.shape)}") + if not covariance_matrix.shape[0] == covariance_matrix.shape[1]: + raise ValueError(f"covariance_matrix dimensions were not equal. covariance_matrix.shape: {covariance_matrix.shape}") + if not burial_gamma.shape[1] == 3: + raise ValueError(f"burial_gamma.shape[1] should be 3 but was {burial_gamma.shape[1]}") + if not direct_gamma.shape[0]==direct_gamma.shape[1]\ + or not protein_gamma.shape[0]==protein_gamma.shape[1]\ + or not water_gamma.shape[0]==water_gamma.shape[1]\ + or not electrostatics_gamma.shape[0]==electrostatics_gamma.shape[1]: + raise ValueError("check gamma shapes") + if not burial_gamma.shape[0] == len(alphabet): + raise ValueError(f"alphabet {alphabet} and burial_gamma shape {burial_gamma.shape} are inconsistent") + if not direct_gamma.shape[0] == len(alphabet): + raise ValueError(f"alphabet {alphabet} and direct_gamma shape {direct_gamma.shape} are inconsistent") + if not protein_gamma.shape[0] == len(alphabet): + raise ValueError(f"alphabet {alphabet} and protein_gamma shape {protein_gamma.shape} are inconsistent") + if not water_gamma.shape[0] == len(alphabet): + raise ValueError(f"alphabet {alphabet} and water_gamma shape {water_gamma.shape} are inconsistent") + if not electrostatics_gamma.shape[0] == len(alphabet): + raise ValueError(f"alphabet {alphabet} and electrostatics_gamma shape {electrostatics_gamma.shape} are inconsistent") + # set attributes + self.covariance_matrix = covariance_matrix + self._use_numba = use_numba + self.alphabet = alphabet + self.alphabet_size = len(alphabet) + N = 0 + while 3*N + 4*((N**2-N)/2) < self.covariance_matrix.shape[0]: + N += 1 + if not 3*N + 4*((N**2-N)/2) == self.covariance_matrix.shape[0]: + raise ValueError(f"the covariance matrix seems to have been constructed incorrectly. 
covariance_matrix.shape: {covariance_matrix.shape}") + self.N = N # number of amino acids + # compute products of gamma parameters for each indicator class + # (burial low density, burial med, burial high, direct, prot, wat, elec) + # for each combination of amino acids (in general, 4 total because we have + # to evaluate the covariance of two pairwise indicators each depending on + # the amino acid identity at 2 different positions in the sequence) + gamma = np.zeros((7,7,len(alphabet),len(alphabet),len(alphabet),len(alphabet))) + # we don't need the third and fourth axes for burial-burial covariances, so we copy the + # 2D outer product along both new axes so that the gamma array is not ragged + # (the third and fourth axes are needed for other terms) + #gamma[0,0] = np.repeat(np.outer(burial_gamma[:,0],burial_gamma[:,0])[:,None,:,None], len(alphabet), axis=1)# low burial- low burial + gamma[0,0] = np.outer(burial_gamma[:,0],burial_gamma[:,0])[:,None,:,None]# low burial- low burial + gamma[0,1] = np.outer(burial_gamma[:,0],burial_gamma[:,1])[:,None,:,None]# low burial- med burial + gamma[0,2] = np.outer(burial_gamma[:,0],burial_gamma[:,2])[:,None,:,None]# low burial- high burial + gamma[0,3] = np.einsum('i,jk->ijk', burial_gamma[:,0], direct_gamma)[:,None,:,:] # low burial- direct + gamma[0,4] = np.einsum('i,jk->ijk', burial_gamma[:,0], protein_gamma)[:,None,:,:] # low burial- prot + gamma[0,5] = np.einsum('i,jk->ijk', burial_gamma[:,0], water_gamma)[:,None,:,:] # low burial- wat + gamma[0,6] = np.einsum('i,jk->ijk', burial_gamma[:,0], electrostatics_gamma)[:,None,:,:] # low burial- elec + gamma[1,1] = np.outer(burial_gamma[:,1],burial_gamma[:,1])[:,None,:,None]# med burial- med burial + gamma[1,2] = np.outer(burial_gamma[:,1],burial_gamma[:,2])[:,None,:,None]# med burial- high burial + gamma[1,3] = np.einsum('i,jk->ijk', burial_gamma[:,1], direct_gamma)[:,None,:,:] # med burial- direct + gamma[1,4] = np.einsum('i,jk->ijk', burial_gamma[:,1], 
protein_gamma)[:,None,:,:] # med burial- prot + gamma[1,5] = np.einsum('i,jk->ijk', burial_gamma[:,1], water_gamma)[:,None,:,:] # med burial- wat + gamma[1,6] = np.einsum('i,jk->ijk', burial_gamma[:,1], electrostatics_gamma)[:,None,:,:] # med burial- elec + gamma[2,2] = np.outer(burial_gamma[:,2],burial_gamma[:,2])[:,None,:,None]# high burial- high burial + gamma[2,3] = np.einsum('i,jk->ijk', burial_gamma[:,2], direct_gamma)[:,None,:,:] # high burial- direct + gamma[2,4] = np.einsum('i,jk->ijk', burial_gamma[:,2], protein_gamma)[:,None,:,:] # high burial- prot + gamma[2,5] = np.einsum('i,jk->ijk', burial_gamma[:,2], water_gamma)[:,None,:,:] # high burial- wat + gamma[2,6] = np.einsum('i,jk->ijk', burial_gamma[:,2], electrostatics_gamma)[:,None,:,:] # high burial- elec + gamma[3,3] = np.einsum('ij,kl->ijkl', direct_gamma, direct_gamma) # direct- direct + gamma[3,4] = np.einsum('ij,kl->ijkl', direct_gamma, protein_gamma) # direct- prot + gamma[3,5] = np.einsum('ij,kl->ijkl', direct_gamma, water_gamma) # direct- wat + gamma[3,6] = np.einsum('ij,kl->ijkl', direct_gamma, electrostatics_gamma) # direct- elec + gamma[4,4] = np.einsum('ij,kl->ijkl', protein_gamma, protein_gamma) # prot- prot + gamma[4,5] = np.einsum('ij,kl->ijkl', protein_gamma, water_gamma) # prot- wat + gamma[4,6] = np.einsum('ij,kl->ijkl', protein_gamma, electrostatics_gamma) # prot- elec + gamma[5,5] = np.einsum('ij,kl->ijkl', water_gamma, water_gamma) # wat- wat + gamma[5,6] = np.einsum('ij,kl->ijkl', water_gamma, electrostatics_gamma) # wat- elec + gamma[6,6] = np.einsum('ij,kl->ijkl', electrostatics_gamma, electrostatics_gamma) # elec- elec + self.gamma = gamma + gamma.transpose((0,1,5,4,3,2)) # keep the indicator class axes the same + # but transpose the gamma values + # define energy evaluation functions + self.initialize_functions() + + @staticmethod + def covariance_type(N, i, j): + # N: total number of residues + # i: first position in the covariance matrix + # j: second position in the 
covariance matrix + if i < 3*N: + type_i = i//N + else: + type_i = (i-3*N)//((N**2-N)//2) + if j < 3*N: + type_j = j//N + else: + type_j = (j-3*N)//((N**2-N)//2) + return (type_i, type_j) + + @staticmethod + def residue_identities(N, i, j, seq_index, indexing_helper_rowflatten, indexing_helper_columnflatten): + # indexing helper rowflatten looks like [0, ..., 0, 1, ..., 1, ..., N] + # ^ repeated N times + # ^ repeated N-1 times + # ^ repeated once + # indexing helper columnflatten looks like [0, ..., N, 1, ..., N, ..., N] + if i < 3*N: + i_pos = i%3 + i_aa = (seq_index[i_pos], seq_index[i_pos]) + else: + i_pos1 = indexing_helper_rowflatten[(i-3*N)%((N**2-N)//2)] + i_pos2 = indexing_helper_columnflatten[(i-3*N)%((N**2-N)//2)] + i_aa = (seq_index[i_pos1], seq_index[i_pos2]) + if j < 3*N: + j_pos = j%3 + j_aa = (seq_index[j_pos], seq_index[j_pos]) + else: + j_pos1 = indexing_helper_rowflatten[(j-3*N)%((N**2-N)//2)] + j_pos2 = indexing_helper_columnflatten[(j-3*N)%((N**2-N)//2)] + j_aa = (seq_index[j_pos1], seq_index[j_pos2]) + return i_aa + j_aa # concatenating tuples + + def initialize_functions(self): + covariance_matrix = self.covariance_matrix + gamma = self.gamma + N = self.N # number of amino acids + indexing_helper_rowflatten = np.repeat(np.arange(N).reshape((N,1)),N,axis=1)[np.triu_indices(N)] + indexing_helper_columnflatten = np.transpose(np.repeat(np.arange(N).reshape((N,1)),N,axis=1))[np.triu_indices(N)] + covariance_type = self.covariance_type + residue_identities = self.residue_identities + + def compute_energy(seq_index): + energy = 0 + for i in range(covariance_matrix.shape[0]): + for j in range(i,covariance_matrix.shape[1]): + try: + energy += covariance_matrix[i,j] * gamma[covariance_type(N,i,j)+residue_identities(N, i,j,seq_index, indexing_helper_rowflatten, indexing_helper_columnflatten )] + except: + breakpoint() + return energy**0.5 + compute_energy_numba=self.numbify(compute_energy) + + def compute_denergy_mutation(seq_index, pos, aa): + 
seq_index_new = seq_index.copy() + seq_index_new[pos] = aa + return compute_energy_numba(seq_index_new) - compute_energy_numba(seq_index) + + def compute_denergy_swap(seq_index, pos1, pos2): + seq_index_new = seq_index.copy() + aa2 , aa1 = seq_index[pos1],seq_index[pos2] + seq_index_new[pos1] = aa1 + seq_index_new[pos2] = aa2 + return compute_energy_numba(seq_index_new) - compute_energy_numba(seq_index) + + self.compute_energy = compute_energy + self.compute_denergy_mutation = compute_denergy_mutation + self.compute_denergy_swap = compute_denergy_swap + class Similarity(EnergyTerm): """ Computes the energy of a sequence based on the similarity to a target sequence. The similarity is calculated as the number of positions that are the same in the two sequences. @@ -843,8 +1501,6 @@ def denergy_swap(seq_index, pos1, pos2): self.compute_denergy_mutation = denergy_mutation self.compute_denergy_swap = denergy_swap - - class MonteCarlo: def __init__(self, sequence: str, energy: EnergyTerm, alphabet:str=_AA, use_numba:bool=True, evaluation_energies:dict={}): self.seq_len=len(sequence) @@ -908,9 +1564,17 @@ def montecarlo_steps(temperature, seq_index, n_steps = 1000, kb = 0.008314) -> n for _ in range(n_steps): new_sequence, energy_difference = sequence_swap(seq_index) if np.random.random() > 0.5 else sequence_mutation(seq_index) exponent= (-energy_difference) / (kb * temperature + 1E-10) + #breakpoint() acceptance_probability = np.exp(min(0, exponent)) + #assert ((acceptance_probability == 0) or (acceptance_probability ==1)), acceptance_probability + #print(acceptance_probability) if np.random.random() < acceptance_probability: seq_index = new_sequence + #print(f"before reassignment: {self.energy.stds}") + #self.energy.stds = copy.deepcopy(self.energy.consider_stds) + #print(f"after reassignment: {self.energy.stds}") + #self.energy.total_energies = copy.deepcopy(self.energy.consider_total_energies) + #print(f"energy_difference: {energy_difference}") return seq_index 
montecarlo_steps=self.numbify(montecarlo_steps) @@ -987,7 +1651,7 @@ def parallel_tempering(self, seq_indices=None, temperatures=np.logspace(0,6,25), # Run the simulation and append data periodically for s, updated_seq_indices, total_energy in self.parallel_tempering_steps(seq_indices, temperatures, n_steps, n_steps_per_cycle): # Prepare data for this chunk - energies={key:energy_term.energies(seq_indices) for key,energy_term in self.evaluation_energies.items()} + energies={key:energy_term.energies(updated_seq_indices) for key,energy_term in self.evaluation_energies.items()} for i, temp in enumerate(temperatures): sequence_str = index_to_sequence(updated_seq_indices[i],alphabet=self.alphabet) # Convert sequence index back to string step_data=({'Step': (s+1) * n_steps_per_cycle, 'Temperature': temp, 'Sequence': sequence_str, 'Total Energy': total_energy[i]}) @@ -1000,7 +1664,7 @@ def annealing(self, seq_index=None, temperatures=np.arange(500,0,-1), n_steps=in seq_index = self.generate_random_sequences(1)[0] done_steps=0 - total_energy = self.energy.energy(seq_index) + total_energy = self.energy.energy(seq_index, ) #Write data to file step_data={'Step': done_steps, 'Temperature': temperatures[0], 'Sequence': index_to_sequence(seq_index,alphabet=self.alphabet), 'TotalEnergy': total_energy} @@ -1009,6 +1673,7 @@ def annealing(self, seq_index=None, temperatures=np.arange(500,0,-1), n_steps=in for t,temp in enumerate(temperatures): steps=(n_steps-done_steps)//(len(temperatures)-t) + assert steps >= 1, f"steps: {steps}" seq_index= self.montecarlo_steps(temp, seq_index, n_steps=steps) total_energy = self.energy.energy(seq_index) done_steps+=steps @@ -1109,111 +1774,33 @@ def find_optimal_replicas(self, max_replicas=32, n_repeats=5, n_steps=10000): if __name__ == '__main__': - native_pdb = "tests/data/1r69.pdb" + reduced_alphabet = 'ADEFGHIKLMNQRSTVWY' + pdb = "tests/data/1r69.pdb" + s = Structure(pdb, chain=None) + model = AWSEM(s, expose_indicator_functions=True, + 
distance_cutoff_contact=10, min_sequence_separation_contact=2,) + variance = AwsemEnergyVariance(model, alphabet=reduced_alphabet) + monte_carlo = MonteCarlo(sequence="SISSRVKSKRIQLGLNQAELAQKVGTTQQSIEQLENGKTKRPRFLPELASALGVSVDWLLNGT", + energy=variance, alphabet=reduced_alphabet) + monte_carlo.annealing(n_steps=10) + exit() + + pdb_list = ["tests/data/1r69.pdb","tests/data/1r69.pdb","tests/data/1r69.pdb"] + pdb_structures = (Structure(pdb, chain=None) for pdb in pdb_list) + ensemble = DecoyEnsemble(pdb_structures, distance_cutoff_contact=10, min_sequence_separation_contact=10) + burial_indicators, direct_indicators, protein_indicators, water_indicators, electrostatics_indicators = ensemble.average() + average = AWSEMIndicators(burial_indicators, direct_indicators, protein_indicators, water_indicators, electrostatics_indicators, + "SISSRVKSKRIQLGLNQAELAQKVGTTQQSIEQLENGKTKRPRFLPELASALGVSVDWLLNGT") - structure_bound = Structure.full_pdb(native_pdb, chain=None) - structure_free = Structure.full_pdb(native_pdb, "A") - - model_bound = AWSEM(structure_bound, distance_cutoff_contact=10, min_sequence_separation_contact=2, expose_indicator_functions=True) - model_free = AWSEM(structure_free, distance_cutoff_contact=10, min_sequence_separation_contact=2, expose_indicator_functions=True) - reduced_alphabet = 'ADEFHIKLMNQRSTVWY' - - print(model_bound.sequence) - print(model_free.sequence) - - # binding_region=np.array([1, 2, 3, 4, 26, 27, 28, 29, 30, 31, 32, 33, 49, 50, 51, 52, 53, 54, 55, 56, 57, 68, 69, 70, 90, 91, 92, 93, 94, 95, 96, 97, 109, 110, 111, 112, 113, 114, 115, 116, 117, 127, 128, 129, 130, 131, 132, 133, 134, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 190, 191, 192, 193, 194, 195, 196, 197])-1 - energy_bound = AwsemEnergySelected(model_bound, alphabet=reduced_alphabet, selection=np.array(range(len(model_free.sequence)))) - energy_free = AwsemEnergySelected(model_free, alphabet=reduced_alphabet) - energy_average = AwsemEnergyAverage(model_free, 
alphabet=reduced_alphabet) - energy_std = AwsemEnergyStd(model_free, alphabet=reduced_alphabet) - energy_std_100 = AwsemEnergyStd(model_free, alphabet=reduced_alphabet, n_decoys=100) - energy_std_1000 = AwsemEnergyStd(model_free, alphabet=reduced_alphabet, n_decoys=1000) - energy_std_10000 = AwsemEnergyStd(model_free, alphabet=reduced_alphabet, n_decoys=10000) - energy_variance = AwsemEnergyVariance(model_free, alphabet=reduced_alphabet) - heterogeneity = Heterogeneity(exact=False, use_numba=True) - similarity = Similarity(model_free.sequence, use_numba=True) - - # energy_mix = energy_free - 20 * heterogeneity - energy_mix = (energy_free - energy_average) / energy_std - # energy_mix = energy_bound - energy_free - - energy_mixes = {"EnergyFree": energy_free, - "EnergyBound": energy_bound, - "Heterogeneity": heterogeneity, - "EnergyAverage": energy_average, - "EnergyStd_ndecoys100": energy_std_100, - "EnergyStd_ndecoys1000": energy_std_1000, - "EnergyStd_ndecoys10000": energy_std_10000, - "EnergyStd": energy_std, - "Zscore_ndecoys10000":(energy_free - energy_average) / energy_std_10000, - "Zscore":(energy_free - energy_average) / energy_std, - "EnergyVariance": energy_variance, - "Binding": (energy_bound - energy_free), - "Similarity": similarity, - "Ivan":energy_bound - 40 * heterogeneity, - "Takada": (energy_bound - energy_average) / energy_std, - "Ivan_binding":(energy_bound - energy_free) - 40 * heterogeneity, - "Takada_binding":(energy_free - energy_average) / energy_std + (energy_bound - energy_free), - "Ivan_Takada_binding": (energy_free - energy_average) / energy_std + (energy_bound - energy_free) - 40 * heterogeneity, - "Corrected_Takada": (energy_bound - energy_average) / (energy_std+5), - "Corrected_Takada_binding":(energy_free - energy_average) / (energy_std+5) + (energy_bound - energy_free), - "Ivan_Corrected_Takada_binding": (energy_free - energy_average) / (energy_std+5) + (energy_bound - energy_free) - 40 * heterogeneity, - "Ivan_bindidng similarity": 
(energy_bound - energy_free) - 40 * heterogeneity - 100*similarity, - "Corrected_Takada_binding_similarity":(energy_free - energy_average) / (energy_std+5) + (energy_bound - energy_free) - 100*similarity, - "Ivan_bindidng_similarityv2": (energy_bound - energy_free) - 40 * heterogeneity} - - for energy_name,energy_term in energy_mixes.items(): - print (f"Energy term: {energy_name}") - energy_term.benchmark(seq_indices=np.random.randint(0, len(reduced_alphabet), size=(100,len(structure_free.sequence)))) - if "ndecoys" not in energy_name: - energy_term.test(seq_index=np.random.randint(0, len(reduced_alphabet), size=len(structure_free.sequence))) - - monte_carlo = MonteCarlo(sequence = structure_free.sequence, energy=energy_term, alphabet=reduced_alphabet) - monte_carlo.benchmark_montecarlo_steps(n_repeats=3,n_steps=10000) - monte_carlo.benchmark_parallel_montecarlo_steps(n_repeats=3, n_steps=10000, n_replicas=8) - - + reduced_alphabet = 'ADEFGHIKLMNQRSTVWY' - # Profiling of the parallel tempering - import cProfile - import pstats - import io - - monte_carlo = MonteCarlo(sequence=model_free.sequence, energy=energy_mix, alphabet=reduced_alphabet) - # evaluation_energies={"EnergyFree": energy_free, "Heterogeneity": heterogeneity, - # "EnergyAverage": energy_average, "EnergyStd": energy_std, - # "Similarity": similarity, "Zscore":(energy_free - energy_average) / energy_std}) - - monte_carlo.benchmark_montecarlo_steps(n_repeats=3, n_steps=100) - for n_replicas in [1, 2, 4, 8, 16]: - print(f"Running parallel tempering with {n_replicas} replicas") - monte_carlo.benchmark_parallel_montecarlo_steps(n_repeats=3, n_steps=100, n_replicas=n_replicas) - - for n_replicas in [1, 2, 4, 8, 16]: - print(f"Running parallel tempering with {n_replicas} replicas") - monte_carlo.benchmark_parallel_montecarlo_steps(n_repeats=3, n_steps=1000, n_replicas=n_replicas) - - monte_carlo.find_optimal_replicas(max_replicas=32, n_repeats=5, n_steps=1000) - 
monte_carlo.find_optimal_replicas(max_replicas=8, n_repeats=5, n_steps=10000) - monte_carlo.find_optimal_replicas(max_replicas=8, n_repeats=5, n_steps=100000) - - # # Run the profiler - # profiler = cProfile.Profile() - # profiler.enable() + awsem_energy = AwsemEnergy(average, alphabet=reduced_alphabet) + heterogeneity = Heterogeneity(exact=False, use_numba=True) - # monte_carlo.parallel_tempering(temperatures=np.logspace(3,-4,8), n_steps=1E4, n_steps_per_cycle=1E2) - # profiler.disable() - - # # Print the stats - # s = io.StringIO() - # ps = pstats.Stats(profiler, stream=s).sort_stats('cumulative') - # ps.print_stats() - # ps.dump_stats('parallel_temperingv2.prof') - - - - + energy_mix = awsem_energy - 10*heterogeneity + monte_carlo = MonteCarlo(sequence = "SISSRVKSKRIQLGLNQAELAQKVGTTQQSIEQLENGKTKRPRFLPELASALGVSVDWLLNGT", energy=energy_mix, alphabet=reduced_alphabet) + monte_carlo.annealing(n_steps=1000) diff --git a/frustratometer/pdb/fix.py b/frustratometer/pdb/fix.py index 142656cb..e8c0185f 100644 --- a/frustratometer/pdb/fix.py +++ b/frustratometer/pdb/fix.py @@ -13,7 +13,7 @@ def repair_pdb(pdb_file: str, chain: str, pdb_directory: Path= Path.cwd()) -> PD pdb_file: str, PDB file location. 
chain: str, - Chain ID + Chain ID -- can be formatted as str or list (or None) pdb_directory: str, PDB file location @@ -51,5 +51,19 @@ def repair_pdb(pdb_file: str, chain: str, pdb_directory: Path= Path.cwd()) -> PD print("Unable to add missing atoms") fixer.addMissingHydrogens(7.0) - PDBFile.writeFile(fixer.topology, fixer.positions, open(f"{pdb_directory}/{pdbID}_cleaned.pdb", 'w')) + + # renumber residues so that each chain starts at 1 + new_top = type(fixer.topology)() # an openmm.app.Topology accesed without needing to import it separately + for old_chain in fixer.topology.chains(): + new_chain = new_top.addChain(id=old_chain.id) + for old_residue in old_chain.residues(): + new_residue = new_top.addResidue(old_residue.name,new_chain,id=None) # allow the class to choose residue id + for old_atom in old_residue.atoms(): + new_atom = new_top.addAtom(old_atom.name,old_atom.element,new_residue,id=old_atom.id) + + # use keepIds=True when writing to preserve chain IDs + # changing it to False causes test_multichain_density to fail for structure_file1-density_file1 + # because chains a, b, c, ... get renamed to A, B, C, ... and end up getting confused + # with the real chains A, B, C, ... + PDBFile.writeFile(new_top, fixer.positions, open(f"{pdb_directory}/{pdbID}_cleaned.pdb", 'w'),keepIds=True) return fixer \ No newline at end of file diff --git a/frustratometer/pdb/pdb.py b/frustratometer/pdb/pdb.py index 853e6e62..b5d5f0bb 100644 --- a/frustratometer/pdb/pdb.py +++ b/frustratometer/pdb/pdb.py @@ -37,7 +37,8 @@ def download(pdbID: str,directory: Union[Path,str]=Path.cwd()) -> Path: return pdb_file def get_sequence(pdb_file: str, - chain: str + chain: str, + return_start_mask: bool=False ) -> str: """ Get a protein sequence from a pdb file @@ -46,8 +47,10 @@ def get_sequence(pdb_file: str, ---------- pdb_file : str, PDB file location. - chain: str, - Chain ID of the selected protein. + chain: str or list, + Chain ID(s) of the selected protein. 
+ return_start_mask: bool, + Return binary mask list indicating whether each position is the start of a chain Returns ------- @@ -58,37 +61,66 @@ def get_sequence(pdb_file: str, Get a protein sequence from a PDB file :param pdb: PDB file location - :param chain: chain name of PDB file to get sequence + :param chain: chain name(s) of PDB file to get sequence :return: protein sequence """ - if ".cif" in str(pdb_file): - parser = MMCIFParser() - else: - parser = PDBParser() - structure = parser.get_structure('name', pdb_file) + if ".cif" in str(pdb_file): # BIOPYTHON + parser = MMCIFParser() # BIOPYTHON + else: # BIOPYTHON + parser = PDBParser() # BIOPYTHON + structure = parser.get_structure('name', pdb_file) #BIOPYTHON + #structure = prody.parsePDB(str(pdb_file)) # PRODY + #hv = structure.getHierView() # PRODY if chain==None: - all_chains=[i.get_id() for i in structure.get_chains()] + all_chains=[i.get_id() for i in structure.get_chains()] # BIOPYTHON + #all_chains = [structure_chain.getChid() for structure_chain in hv] # PRODY else: - all_chains=[chain] + if type(chain) == list: + all_chains = chain + elif type(chain) == str: + all_chains = [id for id in chain if id != " "] # remove spaces if present in string + else: + raise TypeError(f"chain must be list or str but was {type(chain)}") sequence = "" - for chain in all_chains: - c = structure[0][chain] + start_mask = [] + for single_chain in all_chains: + c = structure[0][single_chain] # BIOPYTHON + #c = hv[single_chain] # PRODY chain_seq = "" for residue in c: - is_regular_res = residue.has_id('CA') and residue.has_id('O') - res_id = residue.get_id()[0] - if (res_id==' ' or res_id=='H_MSE' or res_id=='H_M3L' or res_id=='H_CAS') and is_regular_res: - residue_name = residue.get_resname() + is_regular_res = residue.has_id('CA') and residue.has_id('O') # BIOPYTHON + #atom_names = [atom.getName() for atom in residue] # PRODY + #is_regular_res = ("CA" in atom_names and "O" in atom_names) # PRODY + res_id = 
residue.get_id()[0] #BIOPYTHON + okay_resids = [' ', 'H_MSE', 'H_M3L', 'H_CAS', 'H_ALA', 'H_CYS', 'H_ASP', + 'H_GLU', 'H_PHE', 'H_GLY', 'H_HIS', 'H_ILE', 'H_LYS', 'H_LEU', 'H_MET', + 'H_ASN', 'H_PRO', 'H_GLN', 'H_ARG', 'H_SER', 'H_THR', 'H_VAL', 'H_TRP', 'H_TYR'] + if res_id in okay_resids and is_regular_res: # BIOPYTHON + # i don't know what H_HSE, H_M3L, and H_CAS are doing + # because they aren't in three_to_one, so those should throw an error + # long story short, I don't think we have to worry about them when switching from biopython to prody + #if is_regular_res: # PRODY + residue_name = residue.get_resname() # BIOPYTHON + #residue_name = residue.getResname() # PRODY chain_seq += three_to_one[residue_name] + if chain_seq == "": # empty chain, like a nucleic acid chain (see 8ZWK) + continue # FYI, currently, a non-empty chain with certain invalid residues will throw an error at the three_to_one[residue_name] above sequence += chain_seq - return sequence + start_mask.append(1) + for _ in range(1,len(chain_seq)): + start_mask.append(0) + if return_start_mask: + return (sequence,start_mask) + else: + return sequence def get_distance_matrix(pdb_file: Union[Path,str], chain: str, - method: str = 'CB' + method: str = 'CB', + return_distance_midpoints: bool = False, ) -> np.array: """ Calculate the distance matrix of the specified atoms in a PDB file. @@ -106,6 +138,10 @@ def get_distance_matrix(pdb_file: Union[Path,str], 'CA' for using only the CA atom, 'minimum' for using the minimum distance between all atoms in each residue, 'CB_force' computes a new coordinate for the CB atom based on the CA, C, and N atoms and uses CB distance even for glycine. + return_distance_midpoints: bool + Whether to return a matrix of the same shape as distance_matrix representing the same contacts as distance_matrix + that indicates the absolute coordinates of the midpoint between the pair of atoms. 
This helps us compute the pair distribution + functions of the different classes of contacts. So this matrix isn't really a matrix because each "element" has 3 channels: x, y, and z Returns: np.array: The distance matrix of the selected atoms. @@ -121,7 +157,7 @@ def get_distance_matrix(pdb_file: Union[Path,str], if method == 'CA': coords = structure.select('protein and name CA' + chain_selection).getCoords() elif method == 'CB': - coords = structure.select('(protein and (name CB) or (resname GLY and name CA))' + chain_selection).getCoords() + coords = structure.select('(protein and (name CB) or (resname GLY IGL and name CA))' + chain_selection).getCoords() elif method == 'minimum': selection = structure.select('protein' + chain_selection) coords = selection.getCoords() @@ -163,8 +199,20 @@ def get_distance_matrix(pdb_file: Union[Path,str], if len(coords) == 0: raise IndexError('Empty selection for distance map') + # coords should be a numpy array of shape (N,3) distance_matrix = sdist.squareform(sdist.pdist(coords)) - return distance_matrix + assert distance_matrix.shape[0] == distance_matrix.shape[1] + if return_distance_midpoints: + midpoint_matrix = np.zeros((distance_matrix.shape[0],distance_matrix.shape[1],3)) + for i in range(distance_matrix.shape[0]): + for j in range(distance_matrix.shape[1]): + midpoint_matrix[i,j,:] = (coords[None,i,:] + coords[None,j,:])/2 + # check that indexing is consistent with distance_matrix + assert np.allclose(np.linalg.norm(coords[i,:]-coords[j,:]),distance_matrix[i,j]) + assert np.allclose(midpoint_matrix,midpoint_matrix.transpose(1,0,2)) # check symmetry + return distance_matrix, midpoint_matrix + else: + return distance_matrix def full_to_filtered_aligned_mapping(aligned_sequence: str, diff --git a/tests/test_awsem_frustratometer.py b/tests/test_awsem_frustratometer.py index b1ff257e..f186ccd9 100644 --- a/tests/test_awsem_frustratometer.py +++ b/tests/test_awsem_frustratometer.py @@ -2,6 +2,7 @@ import pandas as pd import 
numpy as np import frustratometer +from frustratometer.numba_util import hamiltonian as ham from pathlib import Path @@ -28,7 +29,9 @@ def test_prody_expected_error(): def test_density_residues(test_data): structure = frustratometer.Structure(test_data_path/f"{test_data['pdb']}.pdb") sequence_separation = 2 if test_data['seqsep'] == 3 else 13 - model = frustratometer.AWSEM(structure, distance_cutoff_contact=9.5, min_sequence_separation_rho=sequence_separation, k_electrostatics=0) + model = frustratometer.AWSEM(structure, distance_cutoff_contact=9.5, + min_sequence_separation_rho=sequence_separation, k_electrostatics=0, + expose_indicator_functions=True, potts_option=True) data = pd.read_csv(test_data['singleresidue'], delim_whitespace=True) data['Calculated_density'] = model.rho_r data['Expected_density'] = data['DensityRes'] @@ -45,7 +48,9 @@ def test_density_residues(test_data): def test_single_residue_frustration(test_data): structure = frustratometer.Structure(test_data_path/f"{test_data['pdb']}.pdb") sequence_separation = 2 if test_data['seqsep'] == 3 else 13 - model = frustratometer.AWSEM(structure, distance_cutoff_contact=9.5, min_sequence_separation_rho=sequence_separation, min_sequence_separation_contact=2, k_electrostatics=test_data['k_electrostatics'] * 4.184, min_sequence_separation_electrostatics=1) + model = frustratometer.AWSEM(structure, distance_cutoff_contact=9.5, min_sequence_separation_rho=sequence_separation, + min_sequence_separation_contact=2, k_electrostatics=test_data['k_electrostatics'] * 4.184, + min_sequence_separation_electrostatics=1, expose_indicator_functions=True, potts_option=True) data = pd.read_csv(test_data['singleresidue'], delim_whitespace=True) data['Calculated_frustration'] = model.frustration(kind='singleresidue') data['Expected_frustration'] = data['FrstIndex'] @@ -63,7 +68,9 @@ def test_mutational_frustration(test_data): if test_data['k_electrostatics']==1000: assert True return - model = frustratometer.AWSEM(structure, 
distance_cutoff_contact=9.5, min_sequence_separation_rho=sequence_separation, min_sequence_separation_contact=0, k_electrostatics=test_data['k_electrostatics'] * 4.184, min_sequence_separation_electrostatics=1) + model = frustratometer.AWSEM(structure, distance_cutoff_contact=9.5, min_sequence_separation_rho=sequence_separation, + min_sequence_separation_contact=0, k_electrostatics=test_data['k_electrostatics'] * 4.184, + min_sequence_separation_electrostatics=1, expose_indicator_functions=True, potts_option=True) data = pd.read_csv(test_data['mutational'], delim_whitespace=True) if test_data['pdb']!="ijge": @@ -104,7 +111,8 @@ def test_configurational_frustration(test_data): min_sequence_separation_rho=sequence_separation, min_sequence_separation_contact=0, k_electrostatics=test_data['k_electrostatics'] * 4.184, - min_sequence_separation_electrostatics=1) + min_sequence_separation_electrostatics=1, + expose_indicator_functions=True, potts_option=True) data = pd.read_csv(test_data['configurational'], delim_whitespace=True) @@ -142,47 +150,51 @@ def test_residue_density_calculation(): structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") model=frustratometer.AWSEM(structure,distance_cutoff_contact=9.499, - min_sequence_separation_contact=2) + min_sequence_separation_contact=2, + expose_indicator_functions=True, potts_option=True) assert np.round(model.rho_r,2).all()==np.round(expected_rho_values,2).all() def test_AWSEM_native_energy(): structure=frustratometer.Structure(test_data_path/f'1l63.pdb',"A") - model=frustratometer.AWSEM(structure,k_electrostatics=0, min_sequence_separation_contact = 10, distance_cutoff_contact = None) + model=frustratometer.AWSEM(structure,k_electrostatics=0, min_sequence_separation_contact = 10, + distance_cutoff_contact = None, expose_indicator_functions=True, potts_option=True) e = model.native_energy() print(e) assert np.round(e, 0) == -915 def test_AWSEM_fields_energy(): 
structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") - model=frustratometer.AWSEM(structure,k_electrostatics=0, min_sequence_separation_contact = 10, distance_cutoff_contact = None) + model=frustratometer.AWSEM(structure,k_electrostatics=0, min_sequence_separation_contact = 10, + distance_cutoff_contact = None, expose_indicator_functions=True, potts_option=True) e = model.fields_energy() print(e) assert np.round(e, 0) == -555 def test_AWSEM_couplings_energy(): structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") - model=frustratometer.AWSEM(structure,k_electrostatics=0, min_sequence_separation_contact = 10, distance_cutoff_contact = None) + model=frustratometer.AWSEM(structure,k_electrostatics=0, min_sequence_separation_contact = 10, distance_cutoff_contact = None, + expose_indicator_functions=True, potts_option=True) e = model.couplings_energy() print(e) assert np.round(e, 0) == -362 def test_fields_couplings_AWSEM_energy(): structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") - model = frustratometer.AWSEM(structure) + model = frustratometer.AWSEM(structure, expose_indicator_functions=True, potts_option=True) assert model.fields_energy() + model.couplings_energy() - model.native_energy() < 1E-6 def test_single_residue_AWSEM_energy(): - _AA = '-ACDEFGHIKLMNPQRSTVWY' #Import Lammps AWSEM Frustratometer single residue frustration values lammps_single_frustration_dataframe=pd.read_csv(test_data_path/f"6U5E_A_tertiary_frustration_singleresidue_1E8decoys_AWSEM_Frustratometer_LAMMPS_Carlos.dat",header=0,sep="\s+") ### structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") model=frustratometer.AWSEM(structure,distance_cutoff_contact=9.499, min_sequence_separation_contact=2, - k_electrostatics=0) + k_electrostatics=0, + expose_indicator_functions=True, potts_option=True) #Calculate fields - seq_index = np.array([_AA.find(aa) for aa in structure.sequence]) + seq_index = np.array([model.alphabet.index(aa) for aa in 
structure.sequence]) seq_len = len(seq_index) h = -model.potts_model['h'][range(seq_len), seq_index] @@ -196,8 +208,77 @@ def test_single_residue_AWSEM_energy(): assert (abs(np.array(lammps_single_frustration_dataframe["native_energy"])-test_residue_total_energy) < 1E-1).all() + + +def test_numba_potts_construction(): + """Check that the potts models constructed in the old (numpy vectorized) + and new (numba) ways are identical. By doing this for a sufficiently + diverse set of structures, we also implicitly verify that the numba + functions (well, the ones invoked by the potts model setup) + compute the potential accurately.""" + structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") + model=frustratometer.AWSEM(structure,distance_cutoff_contact=9.499, + min_sequence_separation_contact=2, + k_electrostatics=0, + expose_indicator_functions=True, potts=False) + ######################################## + # old way -- note that these are all negatives of the actual energy: + # burial, direct, protein, water don't have a factor of -1 when the should, + # and electrostatics has a factor of -1 when it shouldn't. + # For some reason, this is just how we do the potts model. 
+ J_index = np.meshgrid(range(model.N), range(model.N), range(model.p.q), range(model.p.q), indexing='ij', sparse=False) + h_index = np.meshgrid(range(model.N), range(model.p.q), indexing='ij', sparse=False) + + # compute burial and contact energies + old_burial_energy = 0.5 * model.p.k_contact * model.p.burial_gamma[h_index[1]] * model.burial_indicator[:, np.newaxis, :] + direct = model.direct_indicator * model.p.direct_gamma[J_index[2], J_index[3]] + water_mediated = model.water_indicator * model.p.water_gamma[J_index[2], J_index[3]] + protein_mediated = model.protein_indicator * model.p.protein_gamma[J_index[2], J_index[3]] + contact_energy = model.p.k_contact * np.array([direct, water_mediated, protein_mediated]) * model.sequence_mask_contact[np.newaxis, :, :, np.newaxis, np.newaxis] + + electrostatics_energy = -model.p.k_electrostatics * model.p.electrostatics_gamma[np.newaxis,np.newaxis,:,:] * model.electrostatics_indicator[:,:,np.newaxis,np.newaxis]\ + * model.electrostatics_mask[:,:,np.newaxis,np.newaxis] + contact_energy = np.append(contact_energy, electrostatics_energy[np.newaxis,:,:,:,:], axis=0) + old_contact_energy = contact_energy + # Compute potts model + old_potts_model = {} + old_potts_model['h'] = old_burial_energy.sum(axis=-1)[:, :] + old_potts_model['J'] = old_contact_energy.sum(axis=0)[:, :, :, :] + np.save('old_way_h.npy',old_potts_model['h']) + np.save('old_way_J.npy',old_potts_model['J']) + ############################################### + # new way + new_potts_model = {'h':None, 'J':None} + chain_starts = np.array([0]) + chain_ends = np.array([len(model.seq_index)-1]) + if model.p.distance_cutoff_contact is None: + contact_max_dist = 12.5 + else: + contact_max_dist = model.p.distance_cutoff_contact + new_potts_model['h'] = ham.compute_potts_model_h_parallel( + model.p.min_sequence_separation_rho, + chain_starts, chain_ends, + model.distance_matrix, + model.p.k_contact, model.p.burial_gamma) + new_potts_model['J'] = 
ham.compute_potts_model_J_parallel( + model.p.electrostatics_screening_length, model.p.min_sequence_separation_rho, + model.p.min_sequence_separation_contact, model.p.min_sequence_separation_electrostatics, + chain_starts, chain_ends, + contact_max_dist, 10*model.p.electrostatics_screening_length, # maximum distance for contact potential, maximum for electrostatics + model.distance_matrix, + model.p.k_contact, model.p.direct_gamma, + model.p.k_contact, model.p.protein_gamma, + model.p.k_contact, model.p.water_gamma, + model.p.k_electrostatics, model.p.electrostatics_gamma) + #np.save('new_way_h.npy',new_potts_model['h']) + #np.save('new_way_J.npy',new_potts_model['J']) + assert np.max(np.abs(old_potts_model['h'] - new_potts_model['h'])) < 1E-5 # 10^-5 kJ/mol error is acceptable + assert np.max(np.abs(old_potts_model['J'] - new_potts_model['J'])) < 1E-5 # 10^-5 kJ/mol error is acceptable + + + + def test_contact_pair_AWSEM_energy(): - _AA = '-ACDEFGHIKLMNPQRSTVWY' #Import Lammps AWSEM Frustratometer mutational frustration values lammps_mutational_frustration_dataframe=pd.read_csv(test_data_path/f"6U5E_A_tertiary_frustration_mutational_1E6decoys_AWSEM_Frustratometer_LAMMPS_Carlos.dat",header=0,sep="\s+") lammps_mutational_frustration_dataframe["i"]=lammps_mutational_frustration_dataframe["i"]-1 @@ -206,9 +287,9 @@ def test_contact_pair_AWSEM_energy(): structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") model=frustratometer.AWSEM(structure,distance_cutoff_contact=9.499, min_sequence_separation_contact=0, - k_electrostatics=0) + k_electrostatics=0, expose_indicator_functions=True, potts_option=True) #Calculate fields - seq_index = np.array([_AA.find(aa) for aa in structure.sequence]) + seq_index = np.array([model.alphabet.index(aa) for aa in structure.sequence]) seq_len = len(seq_index) h = -model.potts_model['h'][range(seq_len), seq_index] @@ -226,13 +307,15 @@ def test_contact_pair_AWSEM_energy(): def 
test_selected_subsequence_AWSEM_contact_energy_matrix(): structure=frustratometer.Structure(test_data_path/f'4wnc.pdb',"A",seq_selection="resnum 3to26") - model=frustratometer.AWSEM(structure) - assert model.potts_model['h'].shape==(24,21) + model=frustratometer.AWSEM(structure, expose_indicator_functions=True, potts_option=True) + q = len(model.gamma.alphabet) + assert model.potts_model['h'].shape==(24,q) def test_selected_subsequence_AWSEM_burial_energy_matrix(): structure=frustratometer.Structure(test_data_path/f'4wnc.pdb',"A",seq_selection="resnum 150to315") - model=frustratometer.AWSEM(structure) - assert model.potts_model['J'].shape==(166,166,21,21) + model=frustratometer.AWSEM(structure, expose_indicator_functions=True, potts_option=True) + q = len(model.gamma.alphabet) + assert model.potts_model['J'].shape==(166,166,q,q) ##### #Test Protein Segment Native AWSEM Energy Calculation @@ -241,12 +324,14 @@ def test_selected_subsequence_AWSEM_burial_energy_matrix(): def test_selected_subsequence_AWSEM_rho_calculations(): #Substructure object substructure=frustratometer.Structure(test_data_path/f'1MBA_A.pdb',"A",seq_selection="resnum 39to146") - model_1=frustratometer.AWSEM(substructure, k_electrostatics=0.0,min_sequence_separation_contact=10,distance_cutoff_contact=10.0) + model_1=frustratometer.AWSEM(substructure, k_electrostatics=0.0,min_sequence_separation_contact=10, + distance_cutoff_contact=10.0, expose_indicator_functions=True, potts_option=True) model_1_init_index=model_1.init_index_shift; model_1_fin_index=model_1.fin_index_shift #Full structure object structure=frustratometer.Structure(test_data_path/f'1MBA_A.pdb',"A") - model_2=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10,distance_cutoff_contact=10.0) + model_2=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10, + distance_cutoff_contact=10.0, expose_indicator_functions=True, potts_option=True) #Check if shape and entries 
of rho matrices are identical assert model_1.rho_r.shape==model_2.rho_r[model_1_init_index:model_1_fin_index].shape @@ -255,12 +340,14 @@ def test_selected_subsequence_AWSEM_rho_calculations(): def test_selected_subsequence_AWSEM_burial_energy(): #Substructure object substructure=frustratometer.Structure(test_data_path/f'1MBA_A.pdb',"A",seq_selection="resnum 39to146") - model_1=frustratometer.AWSEM(substructure, k_electrostatics=0.0,min_sequence_separation_contact=10,distance_cutoff_contact=10.0) + model_1=frustratometer.AWSEM(substructure, k_electrostatics=0.0,min_sequence_separation_contact=10, + distance_cutoff_contact=10.0, expose_indicator_functions=True, potts_option=True) model_1_init_index=model_1.init_index_shift; model_1_fin_index=model_1.fin_index_shift #Full structure object structure=frustratometer.Structure(test_data_path/f'1MBA_A.pdb',"A") - model_2=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10,distance_cutoff_contact=10.0) + model_2=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10, + distance_cutoff_contact=10.0, expose_indicator_functions=True, potts_option=True) #Check if burial energies are identical assert model_1.burial_energy.shape==model_2.burial_energy[model_1_init_index:model_1_fin_index].shape @@ -269,12 +356,14 @@ def test_selected_subsequence_AWSEM_burial_energy(): def test_selected_subsequence_AWSEM_contact_energy(): #Substructure object substructure=frustratometer.Structure(test_data_path/f'1MBA_A.pdb',"A",seq_selection="resnum 39to146") - model_1=frustratometer.AWSEM(substructure, k_electrostatics=0.0,min_sequence_separation_contact=10,distance_cutoff_contact=10.0) + model_1=frustratometer.AWSEM(substructure, k_electrostatics=0.0,min_sequence_separation_contact=10, + distance_cutoff_contact=10.0, expose_indicator_functions=True, potts_option=True) model_1_init_index=model_1.init_index_shift; model_1_fin_index=model_1.fin_index_shift #Full structure object 
structure=frustratometer.Structure(test_data_path/f'1MBA_A.pdb',"A") - model_2=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10,distance_cutoff_contact=10.0) + model_2=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10, + distance_cutoff_contact=10.0, expose_indicator_functions=True, potts_option=True) #Check if contact energies are identical assert model_1.contact_energy.shape==model_2.contact_energy[:,model_1_init_index:model_1_fin_index,model_1_init_index:model_1_fin_index,:,:].shape @@ -282,27 +371,29 @@ def test_selected_subsequence_AWSEM_contact_energy(): def test_selected_subsequence_AWSEM_burial_energy_without_protein_context(): structure=frustratometer.Structure(test_data_path/f'1MBA_A.pdb',"A",seq_selection="resnum 39to146") - model=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10,distance_cutoff_contact=10.0,burial_in_context=False) + model=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10, + distance_cutoff_contact=10.0,burial_in_context=False, expose_indicator_functions=True, potts_option=True) selected_region_burial=model.fields_energy() # Energy units are in kJ/mol assert np.round(selected_region_burial, 2) == -377.95 def test_selected_subsequence_AWSEM_contact_energy_without_protein_context(): structure=frustratometer.Structure(test_data_path/f'1MBA_A.pdb',"A",seq_selection="resnum 39to146") - model=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10,distance_cutoff_contact=10.0,burial_in_context=False) + model=frustratometer.AWSEM(structure, k_electrostatics=0.0,min_sequence_separation_contact=10, + distance_cutoff_contact=10.0,burial_in_context=False, expose_indicator_functions=True, potts_option=True) selected_region_contact=model.couplings_energy() # Energy units are in kJ/mol assert np.round(selected_region_contact, 2) == -148.92 def 
test_single_residue_decoy_AWSEM_energy_statistics(): - _AA = '-ACDEFGHIKLMNPQRSTVWY' #Import Lammps AWSEM Frustratometer single residue frustration values lammps_single_frustration_dataframe=pd.read_csv(test_data_path/f"6U5E_A_tertiary_frustration_singleresidue_1E8decoys_AWSEM_Frustratometer_LAMMPS_Carlos.dat",header=0,sep="\s+") ### structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") - model=frustratometer.AWSEM(structure,distance_cutoff_contact=9.499, min_sequence_separation_contact=2, k_electrostatics=0) + model=frustratometer.AWSEM(structure,distance_cutoff_contact=9.499, min_sequence_separation_contact=2, k_electrostatics=0, + expose_indicator_functions=True, potts_option=True) #Calculate fields - seq_index = np.array([_AA.find(aa) for aa in structure.sequence]) + seq_index = np.array([model.alphabet.index(aa) for aa in structure.sequence]) seq_len = len(seq_index) h = -model.potts_model['h'][range(seq_len), seq_index] @@ -324,16 +415,18 @@ def test_single_residue_decoy_AWSEM_energy_statistics(): assert (abs(np.array(lammps_single_frustration_dataframe["std(decoy_energies)"])-(expected_std_decoy_energy)) < 1.2E-1).all() def test_contact_pair_decoy_AWSEM_energy_statistics(): - _AA = '-ACDEFGHIKLMNPQRSTVWY' #Import Lammps AWSEM Frustratometer mutational frustration values lammps_mutational_frustration_dataframe=pd.read_csv(test_data_path/f"6U5E_A_tertiary_frustration_mutational_1E6decoys_AWSEM_Frustratometer_LAMMPS_Carlos.dat",header=0,sep="\s+") lammps_mutational_frustration_dataframe["i"]=lammps_mutational_frustration_dataframe["i"]-1 lammps_mutational_frustration_dataframe["j"]=lammps_mutational_frustration_dataframe["j"]-1 ### structure=frustratometer.Structure(test_data_path/f'6u5e.pdb',"A") - model=frustratometer.AWSEM(structure,distance_cutoff_contact=9.5, min_sequence_separation_contact=None, k_electrostatics=0) + model=frustratometer.AWSEM(structure,distance_cutoff_contact=9.5, min_sequence_separation_contact=None, k_electrostatics=0, + 
expose_indicator_functions=True, potts_option=True) + q = len(model.alphabet) + #Calculate fields - seq_index = np.array([_AA.find(aa) for aa in structure.sequence]) + seq_index = np.array([model.alphabet.index(aa) for aa in structure.sequence]) seq_len = len(seq_index) h = -model.potts_model['h'][range(seq_len), seq_index] @@ -351,10 +444,10 @@ def test_contact_pair_decoy_AWSEM_energy_statistics(): calculated_mutational_frustration_dataframe["j"]=j.ravel() ### decoy_fluctuations=(model.decoy_fluctuation(kind='mutational'))/4.184 - weighted_decoy_fluctations=np.average(decoy_fluctuations.reshape(seq_len * seq_len, 21 * 21), weights=model.contact_freq.flatten(), axis=-1) + weighted_decoy_fluctations=np.average(decoy_fluctuations.reshape(seq_len * seq_len, q * q), weights=model.contact_freq.flatten(), axis=-1) calculated_mutational_frustration_dataframe["Weighted_Decoy_Fluctuations"]=weighted_decoy_fluctations.ravel() calculated_mutational_frustration_dataframe["Test_Mean_Decoy_Energy"]=calculated_mutational_frustration_dataframe["Test_Native_Energy"]+calculated_mutational_frustration_dataframe["Weighted_Decoy_Fluctuations"] - calculated_mutational_frustration_dataframe["STD_Decoy_Energy"]=np.average((decoy_fluctuations.reshape(seq_len * seq_len, 21 * 21)-calculated_mutational_frustration_dataframe["Weighted_Decoy_Fluctuations"].astype(float).values[:,np.newaxis]) ** 2,weights=model.contact_freq.flatten(), axis=-1) + calculated_mutational_frustration_dataframe["STD_Decoy_Energy"]=np.average((decoy_fluctuations.reshape(seq_len * seq_len, q * q)-calculated_mutational_frustration_dataframe["Weighted_Decoy_Fluctuations"].astype(float).values[:,np.newaxis]) ** 2,weights=model.contact_freq.flatten(), axis=-1) calculated_mutational_frustration_dataframe["STD_Decoy_Energy"]=np.sqrt(calculated_mutational_frustration_dataframe["STD_Decoy_Energy"]) merged_dataframe=calculated_mutational_frustration_dataframe.merge(lammps_mutational_frustration_dataframe,on=["i","j"]) @@ -372,21 
+465,90 @@ def structure(): @pytest.mark.parametrize("distance_cutoff_contact", [None, 10]) def test_expose_indicators(structure, k_electrostatics, min_sequence_separation_contact, distance_cutoff_contact): """ Check that the AWSEM indicators exposed can reproduce the native energy, where E_native = -sum_{i} h_i - sum_{i,j} J_ij = sum_{i} gamma_i * I_i """ - _AA = '-ACDEFGHIKLMNPQRSTVWY' - model=frustratometer.AWSEM(structure,k_electrostatics=k_electrostatics, min_sequence_separation_contact = min_sequence_separation_contact, distance_cutoff_contact = distance_cutoff_contact, expose_indicator_functions=True) - model_seq_index=np.array([_AA.find(aa) for aa in model.sequence]) - indicators1D=np.array(model.indicators[0:3]) - indicators2D=np.array(model.indicators[3:]) - true_indicator1D=np.array([indicators1D[:,model_seq_index==i].sum(axis=1) for i in range(21)]).T - true_indicator2D=np.array([indicators2D[:,model_seq_index==i][:,:, model_seq_index==j].sum(axis=(1,2)) for i in range(21) for j in range(21)]).reshape(21,21,-1).T - burial_gamma=np.concatenate(model.gamma_array[:3]) + model=frustratometer.AWSEM(structure,k_electrostatics=k_electrostatics, + min_sequence_separation_contact = min_sequence_separation_contact, + distance_cutoff_contact = distance_cutoff_contact, expose_indicator_functions=True, potts_option=True) + q = len(model.alphabet) + model_seq_index=np.array([model.alphabet.index(aa) for aa in model.sequence]) + indicators1D=np.array(model.masked_indicators[0:3]) + indicators2D=np.array(model.masked_indicators[3:]) + true_indicator1D=np.array([indicators1D[:,model_seq_index==i].sum(axis=1) for i in range(q)]).T + true_indicator2D=np.array([indicators2D[:,model_seq_index==i][:,:, model_seq_index==j].sum(axis=(1,2)) for i in range(q) for j in range(q)]).reshape(q,q,-1).T + burial_gamma=np.concatenate(model.coefficient_lambda_gamma_array[:3]) burial_energy_predicted = (burial_gamma * np.concatenate(true_indicator1D)).sum() burial_energy_expected = 
-model.potts_model['h'][range(len(model_seq_index)), model_seq_index].sum() assert np.isclose(burial_energy_predicted,burial_energy_expected), f"Expected energy {burial_energy_expected} but got {burial_energy_predicted}" - contact_gamma=np.concatenate([a.ravel() for a in model.gamma_array[3:]]) + contact_gamma=np.concatenate([a.ravel() for a in model.coefficient_lambda_gamma_array[3:]]) + #assert indicators2D.shape == "foo", indicators2D.shape + #assert true_indicators2D.shape == "foo", true_indicators2D.shape + #assert contact_gamma.shape == "foo", contact_gamma.shape contact_energy_predicted = (contact_gamma * np.concatenate([a.ravel() for a in true_indicator2D])).sum() contact_energy_expected = model.couplings_energy() assert np.isclose(contact_energy_predicted,contact_energy_expected), f"Expected energy {contact_energy_expected} but got {contact_energy_predicted}" +#@pytest.mark.parametrize("k_electrostatics", [0, 4]) +#@pytest.mark.parametrize("min_sequence_separation_contact", [2, 10]) +#@pytest.mark.parametrize("distance_cutoff_contact", [None, 10]) +#def test_numba_potts_construction(structure, k_electrostatics, min_sequence_separation_contact, distance_cutoff_contact): +# """Check that the potts models constructed in the old (numpy vectorized) +# and new (numba) ways are identical. 
By doing this for a sufficiently +# diverse set of structures, we also implicitly verify that the numba +# functions (well, the ones invoked by the potts model setup) +# compute the potential accurately.""" +# model=frustratometer.AWSEM(structure,k_electrostatics=k_electrostatics, +# min_sequence_separation_contact = min_sequence_separation_contact, +# distance_cutoff_contact = distance_cutoff_contact, expose_indicator_functions=True, potts_option=True) +# ######################################## +# # old way -- note that these are all negatives of the actual energy: +# # burial, direct, protein, water don't have a factor of -1 when they should, +# # and electrostatics has a factor of -1 when it shouldn't. +# # For some reason, this is just how we do the potts model. +# J_index = np.meshgrid(range(model.N), range(model.N), range(model.q), range(model.q), indexing='ij', sparse=False) +# h_index = np.meshgrid(range(model.N), range(model.q), indexing='ij', sparse=False) +# +# # compute burial and contact energies +# old_burial_energy = 0.5 * model.p.k_contact * model.burial_gamma[h_index[1]] * model.burial_indicator[:, np.newaxis, :] +# direct = model.direct_indicator * model.direct_gamma[J_index[2], J_index[3]] +# water_mediated = model.water_indicator * model.water_gamma[J_index[2], J_index[3]] +# protein_mediated = model.protein_indicator * model.protein_gamma[J_index[2], J_index[3]] +# contact_energy = model.p.k_contact * np.array([direct, water_mediated, protein_mediated]) * model.sequence_mask_contact[np.newaxis, :, :, np.newaxis, np.newaxis] +# +# electrostatics_energy = -model.k_electrostatics * model.electrostatics_gamma[np.newaxis,np.newaxis,:,:] * model.electrostatics_indicator[:,:,np.newaxis,np.newaxis]\ +# * model.electrostatics_mask[:,:,np.newaxis,np.newaxis] +# contact_energy = np.append(contact_energy, electrostatics_energy[np.newaxis,:,:,:,:], axis=0) +# old_contact_energy = contact_energy +# # Compute potts model +# old_potts_model = {} +# 
old_potts_model['h'] = old_burial_energy.sum(axis=-1)[:, :] +# old_potts_model['J'] = old_contact_energy.sum(axis=0)[:, :, :, :] +# ############################################### +# # new way +# new_potts_model = {'h':None, 'J':None} +# chain_starts = np.array([0]) +# chain_ends = np.array([len(model.seq_index)-1]) +# if model.distance_cutoff_contact is None: +# contact_max_dist = 12.5 +# else: +# contact_max_dist = model.distance_cutoff_contact +# new_potts_model['h'] = ham.compute_potts_model_h_parallel( +# model.min_sequence_separation_rho, +# chain_starts, chain_ends, +# model.distance_matrix, +# model.k_contact, model.burial_gamma) +# new_potts_model['J'] = ham.compute_potts_model_J_parallel( +# model.electrostatics_screening_length, model.min_sequence_separation_rho, +# model.min_sequence_separation_contact, model.min_sequence_separation_electrostatics, +# chain_starts, chain_ends, +# contact_max_dist, 10*model.electrostatics_screening_length, # maximum distance for contact potential, maximum for electrostatics +# model.distance_matrix, +# model.k_contact, model.direct_gamma, +# model.k_contact, model.protein_gamma, +# model.k_contact, model.water_gamma, +# model.k_electrostatics, model.electrostatics_gamma) +# #np.save('new_way_h.npy',new_potts_model['h']) +# #np.save('new_way_J.npy',new_potts_model['J']) +# assert np.max(np.abs(old_potts_model['h'] - new_potts_model['h'])) < 1E-5 # 10^-5 kJ/mol error is acceptable +# assert np.max(np.abs(old_potts_model['J'] - new_potts_model['J'])) < 1E-5 # 10^-5 kJ/mol error is acceptable + if __name__ == "__main__": pytest.main() diff --git a/tests/test_dca_frustratometer.py b/tests/test_dca_frustratometer.py index e02df7bc..c7a97b25 100644 --- a/tests/test_dca_frustratometer.py +++ b/tests/test_dca_frustratometer.py @@ -225,7 +225,7 @@ def test_functional_compute_DCA_native_energy(): distance_matrix = frustratometer.pdb.get_distance_matrix(pdb_path, chain_id, method='minimum') potts_model = 
frustratometer.dca.matlab.load_potts_model(potts_model_path) mask = frustratometer.frustration.compute_mask(distance_matrix, maximum_contact_distance=4, minimum_sequence_separation=0) - energy = frustratometer.frustration.compute_native_energy(sequence, potts_model, mask) + energy = frustratometer.frustration.compute_native_energy(sequence, potts_model, mask, '-ACDEFGHIKLMNPQRSTVWY') assert np.round(energy, 4) == expected_energy @@ -406,8 +406,8 @@ def test_compute_singleresidue_DCA_decoy_energy(): seq = [aa for aa in seq] seq[pos_x] = AA[aa_x] seq = ''.join(seq) - test_energy = frustratometer.frustration.compute_native_energy(seq, potts_model, mask) - decoy_energy = frustratometer.frustration.compute_decoy_energy(seq, potts_model, mask, 'singleresidue') + test_energy = frustratometer.frustration.compute_native_energy(seq, potts_model, mask, AA) + decoy_energy = frustratometer.frustration.compute_decoy_energy(seq, potts_model, mask, '-ACDEFGHIKLMNPQRSTVWY', 'singleresidue') assert (decoy_energy[pos_x, aa_x] - test_energy) ** 2 < 1E-16 @@ -427,8 +427,8 @@ def test_compute_mutational_DCA_decoy_energy(): seq[pos_x] = AA[aa_x] seq[pos_y] = AA[aa_y] seq = ''.join(seq) - test_energy = frustratometer.frustration.compute_native_energy(seq, potts_model, mask) - decoy_energy = frustratometer.frustration.compute_decoy_energy(seq, potts_model, mask, 'mutational') + test_energy = frustratometer.frustration.compute_native_energy(seq, potts_model, mask, AA) + decoy_energy = frustratometer.frustration.compute_decoy_energy(seq, potts_model, mask, '-ACDEFGHIKLMNPQRSTVWY', 'mutational') assert (decoy_energy[pos_x, pos_y, aa_x, aa_y] - test_energy) ** 2 < 1E-16 diff --git a/tests/test_optimization.py b/tests/test_optimization.py index 7b939a17..afb777af 100644 --- a/tests/test_optimization.py +++ b/tests/test_optimization.py @@ -347,7 +347,7 @@ def test_diff_mean_inner_product_1_by_1(n_elements = 10): def model(request): native_pdb = "tests/data/1bfz.pdb" distance_cutoff_contact, 
min_sequence_separation_contact, k_electrostatics = request.param - structure = Structure.full_pdb(native_pdb, "A") + structure = Structure(native_pdb, "A") model = AWSEM(structure, distance_cutoff_contact=distance_cutoff_contact, min_sequence_separation_contact=min_sequence_separation_contact, expose_indicator_functions=True, k_electrostatics=k_electrostatics) return model @@ -428,8 +428,8 @@ def test_awsem_energy_variance(model, reduced_alphabet, use_numba): # from itertools import permutations # decoy_sequences = np.array(list(permutations(seq_index))) -# indicators1D=np.array(model.indicators[:3]) -# indicators2D=np.array(model.indicators[3:]) +# indicators1D=np.array(model.masked_indicators[:3]) +# indicators2D=np.array(model.masked_indicators[3:]) # indicator_arrays=[] # energies=[] # for decoy_index in decoy_sequences: @@ -443,36 +443,36 @@ def test_awsem_energy_variance(model, reduced_alphabet, use_numba): # ind2D[i] =np.bincount(decoy_index2D.ravel(), weights=indicators2D[i].ravel(), minlength=21*21) # indicator_array = np.concatenate([ind1D.ravel(),ind2D.ravel()]) -# gamma_array = np.concatenate([a.ravel() for a in model.gamma_array]) +# gamma_array = np.concatenate([a.ravel() for a in model.coefficient_lambda_gamma_array]) -# energy_i = gamma_array @ indicator_array +# energy_i = coefficient_lambda_gamma_array @ indicator_array # assert np.isclose(model.native_energy(index_to_sequence(decoy_index,alphabet=_AA)),energy_i), f"Expected energy {model.native_energy(index_to_sequence(decoy_index,alphabet=_AA))} but got {energy_i}" # energies.append(energy_i) # indicator_arrays.append(indicator_array) # indicator_arrays = np.array(indicator_arrays) # energies = np.array(energies) -# assert np.isclose(gamma_array@indicator_arrays.mean(axis=0),energies.mean()), f"Expected mean energy {gamma_array@indicator_arrays.mean(axis=0)} but got {np.mean(energies)}" +# assert np.isclose(coefficient_lambda_gamma_array@indicator_arrays.mean(axis=0),energies.mean()), 
f"Expected mean energy {coefficient_lambda_gamma_array@indicator_arrays.mean(axis=0)} but got {np.mean(energies)}" # # I will code something like this using numpy einsums: # # np.array([[np.outer(indicator_arrays[:,i],indicator_arrays[:,j]).mean() - indicator_arrays[:,i].mean()*indicator_arrays[:,i].mean() for i in range(indicator_arrays.shape[1])] for j in range(indicator_arrays.shape[1])]) # outer_product = np.einsum('ij,ik->ijk', indicator_arrays, indicator_arrays) # mean_outer_product = outer_product.mean(axis=0) # mean_outer_product -= np.outer(indicator_arrays.mean(axis=0), indicator_arrays.mean(axis=0)) -# assert np.allclose(gamma_array @ mean_outer_product @ gamma_array, energies.var()), "Covariance matrix is not correct" +# assert np.allclose(coefficient_lambda_gamma_array @ mean_outer_product @ coefficient_lambda_gamma_array, energies.var()), "Covariance matrix is not correct" # # Indicator tests -# indicators1D=np.array(model.indicators[0:3]) -# indicators2D=np.array(model.indicators[3:]) -# gamma=model.gamma_array +# indicators1D=np.array(model.masked_indicators[0:3]) +# indicators2D=np.array(model.masked_indicators[3:]) +# gamma=model.coefficient_lambda_gamma_array # true_indicator1D=np.array([indicators1D[:,model_seq_index==i].sum(axis=1) for i in range(21)]).T # true_indicator2D=np.array([indicators2D[:,model_seq_index==i][:,:, model_seq_index==j].sum(axis=(1,2)) for i in range(21) for j in range(21)]).reshape(21,21,3).T # true_indicator=np.concatenate([true_indicator1D.ravel(),true_indicator2D.ravel()]) -# burial_gamma=np.concatenate(model.gamma_array[:3]) +# burial_gamma=np.concatenate(model.coefficient_lambda_gamma_array[:3]) # burial_energy_predicted = (burial_gamma * np.concatenate(true_indicator1D)).sum() # burial_energy_expected = -model.potts_model['h'][range(len(model_seq_index)), model_seq_index].sum() # assert np.isclose(burial_energy_predicted,burial_energy_expected), f"Expected energy {burial_energy_expected} but got 
{burial_energy_predicted}" -# contact_gamma=np.concatenate([a.ravel() for a in model.gamma_array[3:]]) +# contact_gamma=np.concatenate([a.ravel() for a in model.coefficient_lambda_gamma_array[3:]]) # contact_energy_predicted = (contact_gamma * np.concatenate([a.ravel() for a in true_indicator2D])).sum() # contact_energy_expected = model.couplings_energy() # assert np.isclose(contact_energy_predicted,contact_energy_expected), f"Expected energy {contact_energy_expected} but got {contact_energy_predicted}" \ No newline at end of file