From f4b171aba27d10985917fbe1b1f586533329231f Mon Sep 17 00:00:00 2001 From: "Nikolenko.Sergei" Date: Sat, 20 Jun 2026 17:46:11 +0300 Subject: [PATCH 1/3] Sync OpenFold PAE and iPAE fixes --- abcfold/data/config.ini | 2 +- abcfold/output/utils.py | 2 +- abcfold/scripts/ipsae.py | 21 ++++++++++++--------- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/abcfold/data/config.ini b/abcfold/data/config.ini index edde52b..456f1eb 100644 --- a/abcfold/data/config.ini +++ b/abcfold/data/config.ini @@ -20,7 +20,7 @@ protenix_env = abcfold-protenix-py311 af3_version = 3.0.0 boltz_version = 2.2.1 chai_version = 0.6.1 -openfold_version = 0.3.1 +openfold_version = 0.4.1 protenix_version = 2.0.0 [Models] diff --git a/abcfold/output/utils.py b/abcfold/output/utils.py index 335ebc3..cc4e688 100644 --- a/abcfold/output/utils.py +++ b/abcfold/output/utils.py @@ -328,7 +328,7 @@ def from_openfold3(cls, scores: dict, cif_file: CifFile): ] ) - pae_matrix = np.asarray(scores["pde"]) + pae_matrix = np.asarray(scores["pae"] if "pae" in scores else scores["pde"]) af3_scores["pae"] = pae_matrix.tolist() af3_scores["atom_chain_ids"] = atom_chain_ids af3_scores["atom_plddts"] = atom_plddts diff --git a/abcfold/scripts/ipsae.py b/abcfold/scripts/ipsae.py index 6d900db..086ef5c 100644 --- a/abcfold/scripts/ipsae.py +++ b/abcfold/scripts/ipsae.py @@ -127,17 +127,20 @@ def process_input(self): else: raise ValueError(f"Unsupported PAE file type: {suffix}") - # Construct input_params obj dict with protein seqs to avoid error msg self.struct.input_params = {"sequences": []} for chain in self.struct.get_chains(): - seq_data = {} - seq_data["protein"] = {} - if all([is_aa(res.resname) for res in chain.get_residues()]): - seq_data["protein"]["ID"] = [chain.id] - seq_data["protein"]["sequence"] = "".join( - [seq1(residue.get_resname()) for residue in chain] - ) - self.struct.input_params["sequences"].append(seq_data) + if all(is_aa(res.resname) for res in chain.get_residues()): + seq_data = { + "protein": { + "id": [chain.id], + "sequence": "".join( + seq1(residue.get_resname()) for residue in chain + ), + } + } + else: + seq_data = {"ligand": {"id": [chain.id]}} + self.struct.input_params["sequences"].append(seq_data) # Get PAE data for different formats if self.pae_format == "alphafold2": From f4e82054d2e8ea7875f787a4d0edce079bb37bd6 Mon Sep 17 00:00:00 2001 From: "Nikolenko.Sergei" Date: Sat, 20 Jun 2026 17:55:18 +0300 Subject: [PATCH 2/3] Add RosettaFold3 backend support --- abcfold/abcfold.py | 68 ++- abcfold/argparse_utils.py | 15 +- abcfold/data/config.ini | 3 + abcfold/html/abcfold_vue.js | 2 + abcfold/html/html_utils.py | 6 + abcfold/html/style.css | 12 +- abcfold/output/file_handlers.py | 20 + abcfold/output/rosettafold3.py | 210 +++++++ abcfold/output/utils.py | 33 ++ .../css/paeViewerStandaloneLayoutRosetta.css | 519 ++++++++++++++++++ abcfold/plots/pae_plot.py | 28 + abcfold/plots/plddt_plot.py | 1 + abcfold/rosettafold3/af3_to_rosettafold3.py | 170 ++++++ abcfold/rosettafold3/check_install.py | 73 +++ abcfold/rosettafold3/run_rosettafold3.py | 162 ++++++ pyproject.toml | 1 + tests/test_af3_to_rosettafold.py | 161 ++++++ tests/test_rosettafold_output.py | 48 ++ tests/test_run_rosettafold.py | 46 ++ 19 files changed, 1572 insertions(+), 6 deletions(-) create mode 100644 abcfold/output/rosettafold3.py create mode 100644 abcfold/plots/pae-viewer-main/standalone/css/paeViewerStandaloneLayoutRosetta.css create mode 100644 abcfold/rosettafold3/af3_to_rosettafold3.py create mode 100644 abcfold/rosettafold3/check_install.py create mode 100644 abcfold/rosettafold3/run_rosettafold3.py create mode 100644 tests/test_af3_to_rosettafold.py create mode 100644 tests/test_rosettafold_output.py create mode 100644 tests/test_run_rosettafold.py diff --git a/abcfold/abcfold.py b/abcfold/abcfold.py index 77e5794..22d690b 100644 --- a/abcfold/abcfold.py +++ b/abcfold/abcfold.py @@ -17,6 +17,7 @@ openfold_argparse_util, prediction_argparse_util, protenix_argparse_util, + rosettafold_argparse_util, raise_argument_errors, visuals_argparse_util) from abcfold.html.html_utils import (PORT, NoCacheHTTPRequestHandler, @@ -30,6 +31,7 @@ from abcfold.output.file_handlers import superpose_models from abcfold.output.openfold3 import OpenfoldOutput from abcfold.output.protenix import ProtenixOutput +from abcfold.output.rosettafold3 import RosettafoldOutput from abcfold.output.utils import (get_gap_indicies, insert_none_by_minus_one, make_dummy_m8_file, verify_config_file) from abcfold.scripts.abc_script_utils import (check_input_json, make_dir, @@ -44,7 +46,12 @@ PLOTS_DIR = ".plots" ModelOutput = Union[ - AlphafoldOutput, BoltzOutput, ChaiOutput, ProtenixOutput, OpenfoldOutput + AlphafoldOutput, + BoltzOutput, + ChaiOutput, + ProtenixOutput, + OpenfoldOutput, + RosettafoldOutput, ] @@ -281,6 +288,27 @@ def run(args, config, defaults, config_file): outputs.append(oo) successful_runs.append(openfold_success) + if args.rosettafold3: + from abcfold.rosettafold3.run_rosettafold3 import run_rosettafold + + rosettafold_success = run_rosettafold( + input_json=run_json, + output_dir=args.output_dir, + save_input=args.save_input, + number_of_models=args.number_of_models, + config=rt_config, + ) + + if rosettafold_success: + rosettafold_output_dirs = list( + args.output_dir.glob("rosettafold_results*") + ) + ro = RosettafoldOutput( + rosettafold_output_dirs, input_params, name, args.save_input + ) + outputs.append(ro) + successful_runs.append(rosettafold_success) + if args.no_visuals: logger.info("Visuals disabled") return @@ -440,12 +468,41 @@ def run(args, config, defaults, config_file): ) protenix_models["models"].append(model_data) + rosettafold_models: Dict[str, List[Dict[str, Any]]] = {"models": []} + if args.rosettafold3: + if rosettafold_success: + programs_run.append("RosettaFold3") + for seed in ro.output.keys(): + for idx in ro.output[seed].keys(): + if idx >= 0: + model = ro.output[seed][idx]["cif"] + model.check_clashes() + score_file = ro.output[seed][idx]["scores"] + plddt = model.residue_plddts + pae = ro.output[seed][idx]["af3_pae"] + if len(indicies) > 0: + plddt = insert_none_by_minus_one( + indicies[index_counter], plddt + ) + index_counter += 1 + model_data = get_model_data( + model, + plot_dict, + "RosettaFold3", + plddt, + pae, + score_file, + args.output_dir, + ) + rosettafold_models["models"].append(model_data) + combined_models = ( alphafold_models["models"] + boltz_models["models"] + chai_models["models"] + openfold_models["models"] + protenix_models["models"] + + rosettafold_models["models"] ) # Make the output directory for the models @@ -463,6 +520,8 @@ def run(args, config, defaults, config_file): output_name = "openfold_model_" + model["model_id"][-1] + ".cif" elif model["model_source"] == "Protenix": output_name = "protenix_model_" + model["model_id"][-1] + ".cif" + elif model["model_source"] == "RosettaFold3": + output_name = "rosettafold_model_" + model["model_id"][-1] + ".cif" shutil.copy( cif_file, args.output_dir.joinpath("output_models").joinpath(output_name), @@ -551,7 +610,7 @@ def run(args, config, defaults, config_file): def main(): """ - Run AlphaFold3 / Boltz / Chai-1 / OpenFold3 / Protenix + Run AlphaFold3 / Boltz / Chai-1 / OpenFold3 / Protenix / RosettaFold3 """ import argparse @@ -581,7 +640,9 @@ def main(): defaults.update(dict(config.items(section))) parser = argparse.ArgumentParser( - description="Run AlphaFold3 / Boltz / Chai-1 / OpenFold3 / Protenix", + description=( + "Run AlphaFold3 / Boltz / Chai-1 / OpenFold3 / Protenix / RosettaFold3" + ), parents=[config_parser], ) @@ -591,6 +652,7 @@ def main(): parser = chai_argparse_util(parser) parser = openfold_argparse_util(parser) parser = protenix_argparse_util(parser) + parser = rosettafold_argparse_util(parser) parser = mmseqs2_argparse_util(parser) parser = custom_template_argpase_util(parser) parser = prediction_argparse_util(parser) diff --git a/abcfold/argparse_utils.py b/abcfold/argparse_utils.py index f4306bd..78fe50d 100644 --- a/abcfold/argparse_utils.py +++ b/abcfold/argparse_utils.py @@ -267,6 +267,16 @@ def openfold_argparse_util(parser): return parser +def rosettafold_argparse_util(parser): + parser.add_argument( + "-r", + "--rosettafold3", + action="store_true", + help="Run RosettaFold 3", + ) + return parser + + def alphafold_argparse_util(parser): parser.add_argument( "--database", @@ -336,10 +346,11 @@ def raise_argument_errors(args): and not args.chai1 and not args.protenix and not args.openfold3 + and not args.rosettafold3 ): logger.info( - dedent("None of AlphaFold3, Boltz, Chai-1, Protenix or OpenFold3 selected. \ - Running AlphaFold3 by default") + dedent("None of AlphaFold3, Boltz, Chai-1, Protenix, OpenFold3 or \ + RosettaFold3 selected. Running AlphaFold3 by default") ) args.alphafold3 = True diff --git a/abcfold/data/config.ini b/abcfold/data/config.ini index 456f1eb..fb84a4e 100644 --- a/abcfold/data/config.ini +++ b/abcfold/data/config.ini @@ -8,6 +8,7 @@ af3_sif_path = None model_params = /mnt/ligandpro/soft/protein/alphafold3/models boltz_weights = None openfold_weights = None +rosettafold_weights = None [Environments] af3_docker_env = romerolabduke/alphafast:latest @@ -15,6 +16,7 @@ boltz_env = abcfold-boltz-py311 chai_env = abcfold-chai-py311 openfold_env = abcfold-openfold-py311 protenix_env = abcfold-protenix-py311 +rosettafold_env = abcfold-rosetta-py312 [Versions] af3_version = 3.0.0 @@ -22,6 +24,7 @@ boltz_version = 2.2.1 chai_version = 0.6.1 openfold_version = 0.4.1 protenix_version = 2.0.0 +rosetta_version = 0.1.12 [Models] protenix_model = protenix-v2 diff --git a/abcfold/html/abcfold_vue.js b/abcfold/html/abcfold_vue.js index 21ca19c..b76c9bc 100644 --- a/abcfold/html/abcfold_vue.js +++ b/abcfold/html/abcfold_vue.js @@ -62,6 +62,8 @@ Vue.component('abc-table', { return 'btn-source4'; case 'OpenFold3': return 'btn-source5'; + case 'RosettaFold3': + return 'btn-source6'; default: return 'btn-default'; } diff --git a/abcfold/html/html_utils.py b/abcfold/html/html_utils.py index 5a8bd53..db220c5 100644 --- a/abcfold/html/html_utils.py +++ b/abcfold/html/html_utils.py @@ -16,6 +16,7 @@ from abcfold.output.file_handlers import ConfidenceJsonFile, NpzFile from abcfold.output.openfold3 import OpenfoldOutput from abcfold.output.protenix import ProtenixOutput +from abcfold.output.rosettafold3 import RosettafoldOutput from abcfold.plots.pae_plot import create_pae_plots from abcfold.plots.plddt_plot import plot_plddt from abcfold.scripts.ipsae import Ipsae @@ -293,6 +294,11 @@ def get_all_cif_files(outputs) -> Dict[str, list]: if "Protenix" not in method_cif_objs: method_cif_objs["Protenix"] = [] method_cif_objs["Protenix"].extend(output.cif_files[seed]) + elif isinstance(output, RosettafoldOutput): + for seed in output.seeds: + if "RosettaFold3" not in method_cif_objs: + method_cif_objs["RosettaFold3"] = [] + method_cif_objs["RosettaFold3"].extend(output.cif_files[seed]) return method_cif_objs diff --git a/abcfold/html/style.css b/abcfold/html/style.css index c46d947..7c20fc1 100644 --- a/abcfold/html/style.css +++ b/abcfold/html/style.css @@ -496,6 +496,16 @@ rect{ transition: background-color 0.3s ease; } +.btn-source6 { + background-color: #FFA533; + color: #3b3b3d; + border: none; + border-radius: 12px; + padding: 7px 15px; + cursor: pointer; + transition: background-color 0.3s ease; +} + .btn-default { background-color: rgb(199, 195, 195); color: white; @@ -506,6 +516,6 @@ rect{ transition: background-color 0.3s ease; } -.btn-source1:hover, .btn-source2:hover, .btn-source3:hover, .btn-default:hover { +.btn-source1:hover, .btn-source2:hover, .btn-source3:hover, .btn-source4:hover, .btn-source5:hover, .btn-source6:hover, .btn-default:hover { opacity: 0.8; } diff --git a/abcfold/output/file_handlers.py b/abcfold/output/file_handlers.py index 59d141d..c67abee 100644 --- a/abcfold/output/file_handlers.py +++ b/abcfold/output/file_handlers.py @@ -848,6 +848,26 @@ def from_openfold(cls, cls._fix_openfold_mmcif(str(openfold_path), str(tmp_path)) return cls(str(tmp_path), input_params) + @classmethod + def from_rosettafold(cls, + rosettafold_cif: Union[str, Path], + input_params: Optional[dict] = None) -> "CifFile": + """ + Create a CifFile from a RosettaFold3 mmCIF and convert pLDDT scores + stored in the B-factor field from 0-1 scale to 0-100 scale. + """ + rosettafold_path = Path(rosettafold_cif) + tmp_path = rosettafold_path.parent / f"{rosettafold_path.stem}_fixed.cif" + parser = MMCIFParser(QUIET=True) + model = parser.get_structure(rosettafold_path.stem, rosettafold_path) + for atom in model.get_atoms(): + atom.bfactor = atom.bfactor * 100.0 + io = MMCIFIO() + io.set_structure(model) + io.save(str(tmp_path)) + + return cls(str(tmp_path), input_params) + class ConfidenceJsonFile(FileBase): def __init__(self, json_file: Union[str, Path]): diff --git a/abcfold/output/rosettafold3.py b/abcfold/output/rosettafold3.py new file mode 100644 index 0000000..a128b4f --- /dev/null +++ b/abcfold/output/rosettafold3.py @@ -0,0 +1,210 @@ +import logging +from pathlib import Path +from typing import Union + +from abcfold.output.file_handlers import CifFile, ConfidenceJsonFile, FileTypes +from abcfold.output.utils import Af3Pae + +logger = logging.getLogger("logger") + + +class RosettafoldOutput: + def __init__( + self, + rosettafold_output_dirs: list[Union[str, Path]], + input_params: dict, + name: str, + save_input: bool = False, + ): + """ + Object to process the output of an RosettaFold 3 run + + Args: + rosettafold_output_dirs (list[Union[str, Path]]): Path to the RosettaFold 3 + output directory + input_params (dict): Dictionary containing the input parameters used for the + RosettaFold 3 run + name (str): Name given to the RosettaFold 3 run + save_input (bool): If True, RosettaFold 3 was run with the save_input flag + + Attributes: + output_dirs (list): List of paths to the RosettaFold 3 output directory(s) + input_params (dict): Dictionary containing the input parameters used for the + RosettaFold 3 run + name (str): Name given to the RosettaFold 3 run + output (dict): Dictionary containing the processed output the contents + of the RosettaFold 3 output directory(s). The dictionary is structured as + follows: + + { + "seed-1": { + 1: { + "cif": CifFile, + "scores": ConfidenceJsonFile, + "af3_pae": ConfidenceJsonFile, + }, + 2: { + "cif": CifFile, + "scores": ConfidenceJsonFile, + "af3_pae": ConfidenceJsonFile, + }, + }, + etc... + } + pae_files (list): Ordered list of ConfidenceJsonFile objects containing the + PAE data + cif_files (list): Ordered list of CifFile objects containing the model data + scores_files (list): Ordered list of ConfidenceJsonFile objects containing + the model scores + """ + self.output_dirs = [Path(x) for x in rosettafold_output_dirs] + self.input_params = input_params + self.name = name + self.save_input = save_input + + parent_dir = self.output_dirs[0].parent + new_parent = parent_dir / f"rosettafold_{self.name}" + new_parent.mkdir(parents=True, exist_ok=True) + + if self.save_input: + rosettafold_json = list(parent_dir.glob("*.json"))[0] + if rosettafold_json.exists(): + rosettafold_json.rename(new_parent / "rosettafold_input.json") + + rosettafold_msas = list(parent_dir.glob("*/*.a3m")) + if rosettafold_msas: + for rosettafold_msa in rosettafold_msas: + if rosettafold_msa.exists(): + rosettafold_msa.rename(new_parent / rosettafold_msa.name) + + new_output_dirs = [] + for output_dir in self.output_dirs: + if output_dir.name.startswith("rosettafold_results_"): + new_path = new_parent / output_dir.name + output_dir.rename(new_path) + new_output_dirs.append(new_path) + else: + new_output_dirs.append(output_dir) + self.output_dirs = new_output_dirs + + self.output = self.process_rosettafold_output() + + self.seeds = list(self.output.keys()) + self.pae_files = { + seed: [value["pae"] for value in self.output[seed].values()] + for seed in self.seeds + } + self.cif_files = { + seed: [value["cif"] for value in self.output[seed].values()] + for seed in self.seeds + } + self.scores_files = { + seed: [value["score"] for value in self.output[seed].values()] + for seed in self.seeds + } + self.pae_to_af3() + self.af3_pae_files = { + seed: [value["af3_pae"] for value in self.output[seed].values()] + for seed in self.seeds + } + + def process_rosettafold_output(self): + """ + Function to process the output of a RosettaFold 3 run + """ + + file_groups: dict[str, dict[int, list]] = {} + for pathway in self.output_dirs: + seed = pathway.name.split("_")[-1] + if seed not in file_groups: + file_groups[seed] = {} + + for output in pathway.rglob("*"): + number = None + number_str = output.stem.split("_sample-")[-1].split('_')[0] + if not number_str.isdigit(): + continue + number = int(number_str) + + file_type = output.suffix[1:] + + file_: Union[CifFile, ConfidenceJsonFile] + if file_type == FileTypes.CIF.value: + file_ = CifFile.from_rosettafold(output, self.input_params) + elif file_type == FileTypes.JSON.value: + file_ = ConfidenceJsonFile(str(output)) + else: + continue + if number not in file_groups[seed]: + file_groups[seed][number] = [file_] + else: + file_groups[seed][number].append(file_) + + seed_dict = {} + for seed, models in file_groups.items(): + model_number_file_type_file = {} + for model_number, files in models.items(): + intermediate_dict: dict[ + str, Union[CifFile, ConfidenceJsonFile] + ] = {} + for file_ in sorted(files, key=lambda x: x.suffix): + if ( + "confidences" in file_.pathway.stem + and "summary" not in file_.pathway.stem + ) and isinstance(file_, ConfidenceJsonFile): + intermediate_dict["pae"] = file_ + elif ( + "summary_confidences" in file_.pathway.stem + ) and isinstance(file_, ConfidenceJsonFile): + intermediate_dict["score"] = file_ + elif isinstance(file_, CifFile): + if file_.pathway.suffix == ".cif": + file_.name = f"rosettafold_{seed}_{model_number}" + intermediate_dict["cif"] = file_ + else: + continue + + model_number_file_type_file[model_number] = intermediate_dict + + model_number_file_type_file = { + key: model_number_file_type_file[key] + for key in sorted(model_number_file_type_file) + } + seed_dict[seed] = model_number_file_type_file + + return seed_dict + + def pae_to_af3(self): + """ + Convert the PAE data from OpenFold 3 to the format used by Alphafold3 + + Returns: + None + """ + new_pae_files: dict[str, list[ConfidenceJsonFile]] = {} + for seed in self.seeds: + for (pae_file, cif_file) in zip(self.pae_files[seed], self.cif_files[seed]): + pae = Af3Pae.from_rosettafold3( + pae_file.data, + cif_file, + ) + + out_name = pae_file.pathway + + pae.to_file(out_name) + + if seed not in new_pae_files: + new_pae_files[seed] = [] + new_pae_files[seed].append(ConfidenceJsonFile(out_name)) + + self.output = { + seed: { + i: { + "cif": cif_file, + "af3_pae": new_pae_files[seed][i], + "scores": self.output[seed][i]["score"], + } + for i, cif_file in enumerate(self.cif_files[seed]) + } + for seed in self.seeds + } diff --git a/abcfold/output/utils.py b/abcfold/output/utils.py index cc4e688..9c12575 100644 --- a/abcfold/output/utils.py +++ b/abcfold/output/utils.py @@ -338,6 +338,39 @@ def from_openfold3(cls, scores: dict, cif_file: CifFile): return cls(af3_scores) + @classmethod + def from_rosettafold3(cls, scores: dict, cif_file: CifFile): + af3_scores = AF3TEMPLATE.copy() + + chain_lengths = cif_file.chain_lengths(mode="residues", ligand_atoms=True) + residue_lengths = cif_file.chain_lengths(mode="all", ligand_atoms=True) + + atom_chain_ids = flatten( + [[key] * value for key, value in residue_lengths.items()] + ) + + atom_plddts = cif_file.plddts + token_chain_ids = flatten( + [[key] * value for key, value in chain_lengths.items()] + ) + + token_res_ids = flatten( + [ + [value for value in values] + for _, values in cif_file.token_residue_ids().items() + ] + ) + + pae_matrix = np.asarray(scores["pae"]) + af3_scores["pae"] = pae_matrix.tolist() + af3_scores["atom_chain_ids"] = atom_chain_ids + af3_scores["atom_plddts"] = atom_plddts + af3_scores["contact_probs"] = np.zeros(shape=pae_matrix.shape).tolist() + af3_scores["token_chain_ids"] = token_chain_ids + af3_scores["token_res_ids"] = token_res_ids + + return cls(af3_scores) + def __init__(self, af3_scores: dict): self.scores = af3_scores diff --git a/abcfold/plots/pae-viewer-main/standalone/css/paeViewerStandaloneLayoutRosetta.css b/abcfold/plots/pae-viewer-main/standalone/css/paeViewerStandaloneLayoutRosetta.css new file mode 100644 index 0000000..66ad3dd --- /dev/null +++ b/abcfold/plots/pae-viewer-main/standalone/css/paeViewerStandaloneLayoutRosetta.css @@ -0,0 +1,519 @@ +:root { + --primary-color: #49aaa1eb; + --secondary-color: #2196F3; + --accent-color: gray; + --font-color: #564b47; +} + +/* css released under Creative Commons License - http://creativecommons.org/licenses/by/2.0/deed.en */ +/* html5 + CSS 3 Template created by miss monorom http://intensivstation.ch 2013 */ + +* { + box-sizing: border-box; +} + +/* renders html5 elements as block */ + + +header, footer, main, aside, nav, article { + display: block; +} + +body { + background-color: rgba(170, 170, 170, 0.4); +} + +a { + text-decoration: none; /*#999; */ + color: #1565c0; +} + +a:hover { + text-decoration: underline; +} + +h1 { + font-size: 1.5em; +} + +h2 { + font-size: 1.25em; +} + +img.download { + vertical-align: middle; +} + +img { + border: none; +} + +aside { + font-size: small; +} + +li { + line-height: 1.5em; + list-style-position: inside; + text-indent: -1.3em; + padding-left: 1.3em; +} + +p { + line-height: 1.5em; +} + +ul { + padding-left: 0; +} + +/* ----------container centers the layout -------------- */ + +#mini { + width: 50px; +} + +#container { + max-width: 85em; + min-width: 20em; + margin: 0 auto; + background-color: #eee; + border-left: 1px solid #dedede; + -moz-box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.07); + box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.07); + position: relative; +} + + +/* form */ + +label, .label { + font-weight: bold; + /*color: var(--secondary-color);*/ + margin-right: 10px; +} + +input[type=submit] { + margin-top: 1em; +} + +input[type=radio] { + margin-right: 0.2em; +} + +.searchGroup > button { + background-color: #33b5e5; +} + +button:not(:first-child), +a.button:not(:first-child), +.button:not(:first-child), +input[type=submit]:not(:first-child) { + margin-left: 5px; +} + +a.button:hover { + text-decoration: none; + color: white; +} + +a.button:visited { + color: white; +} + +button:focus, .button:focus { + outline: 0; +} + +button:active, .button:active { + transform: translate(1px, 1px); +} + +textarea, .textarea { + font-family: monospace; + line-height: 1.2em; + resize: vertical; +} + + +/* ----------header for logo-------------- */ +header { + background-color: var(--primary-color); + box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.07); + display: flex; + flex-direction: row; + flex-wrap: wrap; + justify-content: left; + align-items: baseline; + color: whitesmoke; + padding: 0.5em 2em 0.5em 2em; + column-gap: 2em; + white-space:nowrap; + overflow: hidden; +} + +.title { + font-size: xxx-large; + font-variant: small-caps; +} + +.subtitle { + font-size: xx-large; +} + +/* new nav */ +.top-flex { + display: flex; + flex-direction: column; + justify-content: space-between; +} + +header nav { + align-self: flex-end; + padding: 0 1em 0 1em; +} + +header nav ul { + display: flex; + flex-flow: row wrap; + justify-content: flex-end; + list-style-type: none; + margin: 0; +} + +header nav li { + margin: 0 0.1em 0.2em 0.1em; + line-height: 1.5em; + list-style: none; + text-indent: 0; + padding-left: 0; +} + +header nav a, nav a:visited { + background-color: #E3F2FD; + color: #333; + display: block; + padding: 5px 15px; + text-decoration: none; + border-radius: 0 0 2px 2px; + text-overflow: ellipsis; + overflow: hidden; + white-space: nowrap; +} + +header nav a:hover { + background-color: white; + color: #333; +} + +header nav a:active { + color: black; +} + +header nav a.selected { + color: white; + background-color: #f67; +} + +#searchWrapper { + align-self: flex-end; + margin: 1em; +} + +.searchGroup { + display: flex; + align-items: center; + flex-wrap: wrap; +} + +.searchGroup > * { + min-width: 0; + margin: 2px; +} + +#searchBox { + width: 20em; + max-width: 20em; + flex-grow: 1; + flex-shrink: 1; + min-width: 5em; +} + +#floatTop { + padding: 10px 20px; + background-color: var(--primary-color); + position: relative; + display: none; + top: 0; + width: 100%; + overflow: hidden; + z-index: 20; + color: white; + box-shadow: 1px 2px 2px 1px rgba(0, 0, 0, 0.07); +} + +#floatTop span { + font-size: 2em; +} + +#floatTop nav { + float: right; + text-align: right +} + +#floatTop nav ul { + margin: 0; + padding: 10px 0; +} + +#floatTop nav li { + display: inline-block; +} + +#floatTop nav li a { + color: white; + background: transparent; + padding: 7.5px 10px; + border: 1px solid white; + margin: 2.5px; + border-radius: 3px; +} + +#floatTop nav li a:hover { + color: var(--primary-color); + background: #E8F5E9; + text-decoration: none; +} + +#floatTop nav li a:visited { + color: white; +} + +#floatTop nav li a:visited:hover { + color: var(--primary-color); +} + +.footnote { + font-size: smaller; + background: #eee; + padding: 10px; +} + +.footnote p { + margin: 0; +} + +/* ----------------- content--------------------- */ +main:before { + height: 0; + content: "."; + display: block; + clear: both; + visibility: hidden; +} + +main { + padding: 20px; + background-color: white; + min-height: 600px; +} + +/* -------------- side infos ------------- */ + +aside { + padding: 20px; + float: right; + width: 24%; +} + +aside img { + border: 1px solid #bbb; + box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.07); + -moz-box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.07); +} + +/* -----------footer--------------------------- */ + +footer { + padding: 20px; + clear: both; + background-color: #424242; + color: #fff; + display: flex; + flex-direction: row-reverse; +} + +footer a { + color: #aaa; + text-decoration: none; +} + +footer a:visited { + color: #aaa; + text-decoration: none; +} + +.footer-segment { + display: inline-block; + vertical-align: top; + width: 25%; +} + +.footer-segment ul { + padding: 0; + list-style: none; +} + +/*The following line are formatting the browser!*/ + +.box { + padding: 15px; + margin-bottom: 10px; + border: 1px solid #bbb; + border-radius: 3px; + box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.07); + -moz-box-shadow: 1px 1px 2px 1px rgba(0, 0, 0, 0.07); +} + +#float_container { + position: fixed; + right: 5px; + bottom: 10%; + z-index: 100; +} + +#float_container img { + display: block; + width: 50px; + height: 50px; + border-radius: 4px; + padding: 10px; + background-color: #000; + opacity: 0.5; + filter: alpha(opacity=50); + font-size: 100%; + font-weight: bold; + margin: 10px; + background-position: center; + background-size: 60% auto; + background-repeat: no-repeat; +} + +#float_container img:hover { + opacity: 0.7; + filter: alpha(opacity=70); +} + +/* -------------------- Media Queries -------------------- */ +@media only screen and (max-width: 860px) { + form input { + width: 90%; + } + + header nav { + padding: 0; + align-self: stretch; + } + + header nav ul { + justify-content: space-between; + } + + header nav li { + flex-grow: 1; + } + + #floatTop { + height: 0; + width: 0; + padding: 0; + margin: 0; + display: none; + overflow: hidden; + } +} + +@media only screen and (max-width: 768px) { + header { + flex-direction: column; + align-items: center; + } + + .subtitle { + display: none; + } + + main#content { + float: none; + width: 100%; + padding: 10px; + } + + aside { + padding: 20px; + float: none; + width: 100%; + } + + .footer-segment { + display: inline-block; + vertical-align: top; + width: 49%; + padding: 0 0 20px; + } + + .footer-segment:last-child { + border-top: 1px solid #999; + display: block; + width: 100%; + padding: 20px 0 0; + } +} + +@media only screen and (max-width: 480px) { + header nav ul { + padding-left: 0; + } + + header nav li { + float: none; + margin: 0; + width: 100%; + display: block; + } + + header nav a { + width: 100%; + position: relative; + } + + header nav a:not(.selected):after { + content: '>'; + position: absolute; + right: 10px; + } + + footer.footer { + padding: 0; + } + + .footer-segment, .footer-segment:last-child { + display: block; + vertical-align: top; + width: 100%; + border-bottom: 1px solid #999; + padding: 20px; + } +} + +.card { + --bs-card-cap-color: #F8F9FA; +} + +.card-header { + padding:var(--bs-card-cap-padding-y) var(--bs-card-cap-padding-x); + margin-bottom:0; + color:var(--bs-card-cap-color); + background-color:var(--primary-color); + border-bottom:var(--bs-card-border-width) solid var(--bs-card-border-color) + } diff --git a/abcfold/plots/pae_plot.py b/abcfold/plots/pae_plot.py index 02a56b8..0b5e494 100644 --- a/abcfold/plots/pae_plot.py +++ b/abcfold/plots/pae_plot.py @@ -15,6 +15,7 @@ from abcfold.output.file_handlers import CifFile from abcfold.output.openfold3 import OpenfoldOutput from abcfold.output.protenix import ProtenixOutput +from abcfold.output.rosettafold3 import RosettafoldOutput logger = logging.getLogger(__name__) @@ -24,6 +25,7 @@ "C": "paeViewerStandaloneLayoutChai.css", "O": "paeViewerStandaloneLayoutOpenfold.css", "P": "paeViewerStandaloneLayoutProtenix.css", + "R": "paeViewerStandaloneLayoutRosetta.css", } PAEVIEWER = Path(__file__).parent.joinpath( @@ -193,6 +195,32 @@ def create_pae_plots( continue + elif isinstance(output, RosettafoldOutput): + css_path = CSSPATHS["R"] + template_file = plots_dir.joinpath("rosettafold_template.html") + template_files.append(template_file) + cmd = get_template_run_script( + "ABCFold - RosettaFold 3 Output", + css_path, + template_file, + output_dir.joinpath(".pae_viewer"), + ) + run_script(cmd) + + for seed in output.seeds: + run_scripts.extend( + prepare_scripts( + output.cif_files[seed], + output.af3_pae_files[seed], + plots_dir, + pathway_plot, + template_file, + True, + ) + ) + + continue + elif isinstance(output, AlphafoldOutput): css_path = CSSPATHS["A"] template_file = plots_dir.joinpath("af3_template.html") diff --git a/abcfold/plots/plddt_plot.py b/abcfold/plots/plddt_plot.py index 2ff5fd0..39b12bf 100644 --- a/abcfold/plots/plddt_plot.py +++ b/abcfold/plots/plddt_plot.py @@ -56,6 +56,7 @@ def plot_plddt( "Chai-1": px.colors.qualitative.Prism, "Protenix": px.colors.qualitative.Set3, "OpenFold3": px.colors.qualitative.Alphabet, + "RosettaFold3": px.colors.qualitative.Dark24, } line_ranges: dict = {} diff --git a/abcfold/rosettafold3/af3_to_rosettafold3.py b/abcfold/rosettafold3/af3_to_rosettafold3.py new file mode 100644 index 0000000..10041fe --- /dev/null +++ b/abcfold/rosettafold3/af3_to_rosettafold3.py @@ -0,0 +1,170 @@ +import json +import logging +import random +import string +from pathlib import Path +from typing import Any, Dict, Union + +logger = logging.getLogger("logger") + + +class Rosettafoldjson: + """ + Object to convert an AlphaFold3 json file to a RosettaFold3 JSON file. + """ + + def __init__(self, working_dir: Union[str, Path], + create_files: bool = True): + self.working_dir = working_dir + self.seeds: list = [42] + self.__ids: Dict = {} + self.__create_files = create_files + self.name = "" + self.rosettafold_dict: Dict = {} + + @property + def chain_ids(self) -> Dict: + return self.__ids + + def msa_to_file(self, msa: str, file_path: Union[str, Path]): + """ + Takes a msa string and writes it to a file + + Args: + msa (str): msa string + file_path (Union[str, Path]): file path to write the msa to + + Returns: + None + """ + + with open(file_path, "w") as f: + f.write(msa) + + def json_to_json( + self, + json_file_or_dict: Union[dict, str, Path], + ): + """ + Main function to convert an AF3 json file or dict to a RosettaFold3 json string + + Args: + json_file_or_dict (Union[dict, str, Path]): json file or dict + + Returns: + Dict: RosettaFold3 dictionary + """ + logger.info("Converting input json to a RosettaFold3 compatible json file") + if isinstance(json_file_or_dict, str) or isinstance(json_file_or_dict, Path): + with open(json_file_or_dict, "r") as f: + json_dict = json.load(f) + else: + json_dict = json_file_or_dict + + rosettafold_sequences = [] + for key, value in json_dict.items(): + if key == "name": + self.name = value + if key == "modelSeeds": + if isinstance(value, list): + self.seeds = value + elif isinstance(value, int): + self.seeds = [value] + if key == "sequences": + for entry in value: + if "protein" in entry: + for chain_id in entry["protein"].get("id", []): + chain_entry = self.convert_component(entry["protein"], + chain_id) + rosettafold_sequences.append(chain_entry) + elif "rna" in entry: + for chain_id in entry["rna"].get("id", []): + chain_entry = self.convert_component(entry["rna"], + chain_id) + rosettafold_sequences.append(chain_entry) + elif "dna" in entry: + for chain_id in entry["dna"].get("id", []): + chain_entry = self.convert_component(entry["dna"], + chain_id) + rosettafold_sequences.append(chain_entry) + elif "ligand" in entry: + for chain_id in entry["ligand"].get("id", []): + chain_entry = self.convert_ligand(entry["ligand"]) + rosettafold_sequences.append(chain_entry) + + self.rosettafold_dict = { + "name": self.name, + "components": rosettafold_sequences + } + + return self.rosettafold_dict + + def convert_component(self, seq_dict, chain_id) -> Dict[str, Any]: + sequence = seq_dict["sequence"] + modifications = seq_dict.get("modifications", []) + unpaired_msa = seq_dict.get("unpairedMsa") + + sequence_list = list(sequence) + if modifications: + for mod in modifications: + if "ptmType" in mod and "ptmPosition" in mod: + ptm_type = mod['ptmType'] + position = int(mod['ptmPosition']) - 1 + sequence_list[position] = f"({ptm_type})" + if unpaired_msa is not None: + msa_lines = unpaired_msa.splitlines() + input_seq = msa_lines[1] + idx = int(mod['ptmPosition']) - 1 + input_seq = input_seq[:idx] + 'X' + input_seq[idx+1:] + msa_lines[1] = input_seq + seq_dict['unpairedMsa'] = "\n".join(msa_lines) + elif "modificationType" in mod and "basePosition" in mod: + mod_type = mod['modificationType'] + position = int(mod['basePosition']) - 1 + sequence_list[position] = f"({mod_type})" + sequence = ''.join(sequence_list) + + chain = { + "seq": sequence, + "chain_id": chain_id, + } + + random_string = ''.join(random.choices(string.ascii_letters, k=5)) + msa_dir = Path(self.working_dir) / random_string + if unpaired_msa and self.__create_files: + msa_out = msa_dir / "colabfold_main.a3m" + if not msa_dir.exists(): + msa_dir.mkdir(parents=True, exist_ok=True) + self.msa_to_file( + unpaired_msa, + msa_out + ) + chain["msa_path"] = msa_out.resolve().as_posix() + + return chain + + def convert_ligand(self, seq_dict) -> Dict[str, Any]: + ligand_chain = {} + + if "ccdCodes" in seq_dict: + ligand_id = seq_dict["ccdCodes"][0] + ligand_chain["ccd_code"] = ligand_id + else: + ligand_id = seq_dict["smiles"] + ligand_chain["smiles"] = ligand_id + + return ligand_chain + + def write_json(self, out_file: Union[str, Path]): + """ + Write the RosettaFold3 json to a file + + Args: + out_file (Union[str, Path]): output file path + + Returns: + None + """ + + with open(out_file, "w") as f: + json.dump([self.rosettafold_dict], f, indent=4) diff --git a/abcfold/rosettafold3/check_install.py b/abcfold/rosettafold3/check_install.py new file mode 100644 index 0000000..a0b8efc --- /dev/null +++ b/abcfold/rosettafold3/check_install.py @@ -0,0 +1,73 @@ +import logging +import urllib.request +from pathlib import Path + +from abcfold.backend_envs import MicromambaEnv + +logger = logging.getLogger("logger") + +RF3_BASE_URL = "http://files.ipd.uw.edu/pub/rf3" +CHECKPOINT_NAME = "rf3_foundry_01_24_latest_remapped.ckpt" +RF3_URL = f"{RF3_BASE_URL}/{CHECKPOINT_NAME}" + + +def ensure_rosettafold_checkpoint(target_path: Path) -> Path: + if target_path.exists(): + return target_path + + target_path.parent.mkdir(parents=True, exist_ok=True) + + try: + logger.info( + "Downloading RoseTTAFold3 checkpoint via HTTPS " + "(this may take a while)..." + ) + urllib.request.urlretrieve(RF3_URL, target_path) + except Exception as e: + raise RuntimeError( + "Failed to download RoseTTAFold3 checkpoint.\n" + f"Target: {target_path}\n" + f"Error: {e}" + ) + + if not target_path.exists(): + raise RuntimeError("Checkpoint download completed but file not found") + + return target_path + + +def ensure_rosettafold_env(config: dict) -> MicromambaEnv: + ROSETTAFOLD_ENV = config['rosettafold_env'] + ROSETTAFOLD_VERSION = config['rosetta_version'] + + env = MicromambaEnv(ROSETTAFOLD_ENV) + # 1. Ensure env exists + env.create(python_version="3.12") + + # 2. Check installed rosettafold version + installed = env.get_installed_version("rc-foundry") + + if installed != ROSETTAFOLD_VERSION: + if installed is None: + logger.info("RosettaFold3 not found. Installing rc-foundry version: %s", + ROSETTAFOLD_VERSION) + else: + logger.info( + "RosettaFold3 version mismatch (found %s). " + "Installing correct version: %s", + installed, + ROSETTAFOLD_VERSION, + ) + + env.pip_install([ + f"rc-foundry[rf3]=={ROSETTAFOLD_VERSION}", + ]) + else: + logger.info("RosettaFold3 is already up-to-date (%s)", ROSETTAFOLD_ENV) + + # 3. Ensure runtime deps you *actually* need + env.ensure_package("numpy") + env.ensure_package("typer") + env.ensure_package("matplotlib") + + return env diff --git a/abcfold/rosettafold3/run_rosettafold3.py b/abcfold/rosettafold3/run_rosettafold3.py new file mode 100644 index 0000000..b0a7fd7 --- /dev/null +++ b/abcfold/rosettafold3/run_rosettafold3.py @@ -0,0 +1,162 @@ +import logging +import subprocess +import tempfile +from pathlib import Path +from typing import Union + +from abcfold.rosettafold3.af3_to_rosettafold3 import Rosettafoldjson +from abcfold.rosettafold3.check_install import (CHECKPOINT_NAME, + ensure_rosettafold_checkpoint, + ensure_rosettafold_env) + +logger = logging.getLogger("logger") + + +def run_rosettafold( + input_json: Union[str, Path], + output_dir: Union[str, Path], + config: dict, + save_input: bool = False, + test: bool = False, + number_of_models: int = 5, +) -> bool: + """ + Run RosettaFold3 using the input JSON file + + Args: + input_json (Union[str, Path]): Path to the input JSON file + output_dir (Union[str, Path]): Path to the output directory + config (dict): Configuration dictionary + save_input (bool): If True, save the input JSON file and MSA to the output + directory + test (bool): If True, run the test command + number_of_models (int): Number of models to generate + + Returns: + Bool: True if the RosettaFold3 run was successful, False otherwise + + Raises: + subprocess.CalledProcessError: If the RosettaFold3 command returns an error + + """ + input_json = Path(input_json) + output_dir = Path(output_dir) + + env = None + rosettafold_ckpt = Path("test.ckpt") + if not test: + logger.debug("Checking if RosettaFold3 is installed") + env = ensure_rosettafold_env(config=config) + + rosettafold_weight_dir = config["rosettafold_weights"] + if rosettafold_weight_dir is not None and rosettafold_weight_dir != "None": + cache_path = Path(rosettafold_weight_dir) + else: + cache_path = Path.home().joinpath(".rosettafold3") + + default_ckpt = cache_path.joinpath(CHECKPOINT_NAME) + if not default_ckpt.exists(): + logger.info( + "No Checkpoint file found. " + f"Downloading RosettaFold3 checkpoint to {default_ckpt}" + ) + rosettafold_ckpt = ensure_rosettafold_checkpoint(default_ckpt) + else: + rosettafold_ckpt = default_ckpt + + with tempfile.TemporaryDirectory() as temp_dir: + working_dir = Path(temp_dir) + if save_input: + logger.info("Saving msa to the output directory") + working_dir = output_dir + + rosettafold_json = Rosettafoldjson(working_dir) + rosettafold_json.json_to_json(input_json) + + for seed in rosettafold_json.seeds: + out_file = working_dir.joinpath(f"{input_json.stem}_seed-{seed}.json") + + rosettafold_json.write_json(out_file) + logger.info("Running RosettaFold3 using seed: %s", seed) + rosettafold_out_dir = output_dir / f"rosettafold_results_seed-{seed}" + cmd = generate_rosettafold_command( + out_file, + rosettafold_out_dir, + rosettafold_ckpt, + number_of_models, + seed, + ) + + if test: + continue + + try: + assert env is not None + env.run(cmd) + except subprocess.CalledProcessError as e: + stderr = e.stderr or "" + if stderr: + if working_dir.exists(): + output_err_file = working_dir / "rosettafold_error.log" + else: + output_err_file = working_dir.parent / "rosettafold_error.log" + output_err_file.write_text(stderr) + logger.error( + "RosettaFold3 run failed. Error log is in %s", output_err_file + ) + else: + logger.error("RosettaFold3 run failed") + return False + + logger.info("RosettaFold3 run complete") + logger.info("Output files are in %s", output_dir) + return True + + +def generate_rosettafold_command( + input_json: Union[str, Path], + output_dir: Union[str, Path], + ckpt_path: Union[str, Path], + number_of_models: int = 5, + seed: int = 42, +) -> list: + """ + Generate the RosettaFold3 command + + Args: + input_json (Union[str, Path]): Path to the input JSON file + output_dir (Union[str, Path]): Path to the output directory + ckpt_path (Union[str, Path]): Path to the inference CheckPoint + number_of_models (int): Number of models to generate + seed (int): Random seed to use for the RosettaFold3 run + + Returns: + list: The RosettaFold3 command + """ + return [ + "rf3", + "fold", + f"inputs='{input_json}'", + f"out_dir='{output_dir}'", + f"diffusion_batch_size={str(number_of_models)}", + f"seed={str(seed)}", + f"ckpt_path='{ckpt_path}'", + ] + + +def generate_rosettafold_test_command() -> list: + """ + Generate the test command for RosettaFold3 + + Args: + None + + Returns: + list: The OpenFold 3 test command + """ + + return [ + "rf3", + "fold", + "--help", + ] diff --git a/pyproject.toml b/pyproject.toml index c249c9d..7482e64 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ packages = [ "abcfold.chai1", "abcfold.protenix", "abcfold.openfold3", + "abcfold.rosettafold3", "abcfold.html", "abcfold.output", "abcfold.plots", diff --git a/tests/test_af3_to_rosettafold.py b/tests/test_af3_to_rosettafold.py new file mode 100644 index 0000000..96d0bd6 --- /dev/null +++ b/tests/test_af3_to_rosettafold.py @@ -0,0 +1,161 @@ +import json +import tempfile +from pathlib import Path + +from abcfold.rosettafold3.af3_to_rosettafold3 import Rosettafoldjson + +# flake8: noqa + + +def test_af3_to_rosettafold(test_data): + with tempfile.TemporaryDirectory() as temp_dir: + rosettafold_json = Rosettafoldjson(temp_dir) + + data = rosettafold_json.json_to_json(test_data.test_inputAB_json) + + reference = { + 'name': '2PV7', + 'components': + [ + {'seq': 'GMRES', 'chain_id': 'A'}, + {'seq': 'GMRES', 'chain_id': 'B'}, + {'seq': 'YANEN', 'chain_id': 'C'}, + {'ccd_code': 'ATP'}, + {'ccd_code': 'ATP'}, + {'smiles': 'CC(=O)OC1C[NH+]2CCC1CC2'} + ] + } + + + assert data == reference + + +def test_af3_to_rosettafold_rna(test_data): + with tempfile.TemporaryDirectory() as temp_dir: + rosettafold_json = Rosettafoldjson(temp_dir) + + data = rosettafold_json.json_to_json(test_data.test_inputRNA_json) + reference = { + 'name': 'RNA_example', + 'components': + [ + {'seq': 'AGCU', 'chain_id': 'A'}, + ] + } + + assert data == reference + + +def test_af3_to_rosettafold_dna(test_data): + with tempfile.TemporaryDirectory() as temp_dir: + rosettafold_json = Rosettafoldjson(temp_dir) + + data = rosettafold_json.json_to_json(test_data.test_inputDNA_json) + reference = { + 'name': 'DNA_example', + 'components': + [ + {'seq': 'AGCT', 'chain_id': 'A'}, + {'seq': 'AGCT', 'chain_id': 'B'} + ] + } + + assert data == reference + + +def test_af3_to_rosettafold_ligand(test_data): + with tempfile.TemporaryDirectory() as temp_dir: + rosettafold_json = Rosettafoldjson(temp_dir) + + data = rosettafold_json.json_to_json(test_data.test_inputLIG_json) + reference = { + 'name': '2PV7', + 'components': + [ + {'seq': 'GMRESYANENQFGFKTINSDIHKIVIVGGYGKLGGLFARYLRASGYPISILDREDWAVAESILANADVVIVSVPINLTLETIERLKPYLTENMLLADLTSVKREPLAKMLEVHTGAVLGLHPMFGADIASMAKQVVVRCDGRFPERYEWLLEQIQIWGAKIYQTNATEHDHNMTYIQALRHFSTFANGLHLSKQPINLANLLALSSPIYRLELAMIGRLFAQDAELYADIIMDKSENLAVIETLKQTYDEALTFFENNDRQGFIDAFHKVRDWFGDYSEQFLKESRQLLQQANDLKQG', 'chain_id': 'A'}, + {'seq': 'GMRESYANENQFGFKTINSDIHKIVIVGGYGKLGGLFARYLRASGYPISILDREDWAVAESILANADVVIVSVPINLTLETIERLKPYLTENMLLADLTSVKREPLAKMLEVHTGAVLGLHPMFGADIASMAKQVVVRCDGRFPERYEWLLEQIQIWGAKIYQTNATEHDHNMTYIQALRHFSTFANGLHLSKQPINLANLLALSSPIYRLELAMIGRLFAQDAELYADIIMDKSENLAVIETLKQTYDEALTFFENNDRQGFIDAFHKVRDWFGDYSEQFLKESRQLLQQANDLKQG', 'chain_id': 'B'}, + {'ccd_code': 'ATP'}, + {'ccd_code': 'ATP'}, + {'smiles': 'CC(=O)OC1C[NH+]2CCC1CC2'}, + {'smiles': 'CCCCCCCCCCCC(O)=O'}, + {'smiles': 'CCCCCCCCCCCC(O)=O'}, + {'ccd_code': 'MG'} + ] + } + + assert data == reference + + +def test_af3_to_rosettafold_ptm(test_data): + with tempfile.TemporaryDirectory() as temp_dir: + rosettafold_json = Rosettafoldjson(temp_dir) + + data = rosettafold_json.json_to_json(test_data.test_inputPTM_json) + reference = { + 'name': 'PTM example', + 'components': + [ + {'seq': '(HY3)VLS(P1L)GEWQL', 'chain_id': 'A'}, + {'seq': '(2MG)GC(5MC)', 'chain_id': 'B'} + ] + } + + assert data == reference + +def test_rosettafold_output_msa(test_data): + with tempfile.TemporaryDirectory() as temp_dir: + rosettafold_json = Rosettafoldjson(temp_dir) + + data = rosettafold_json.json_to_json(test_data.test_inputAmsa_json) + msa_path = ( + data["components"][0].get("msa_path") + ) + # MSA directory has a random path, so just check that it exists then give + # it a placeholder value for comparison + assert msa_path is not None + assert Path(msa_path).exists() + data["components"][0]["msa_path"] = ( + "PRECOMPUTED_MSA" + ) + + reference = { + 'name': '2PV7', + 'components': + [ + {'seq': 'GMRESYANENQFGFKTINSDIHKIVIVGGYGKLGGLFARYLRASGYPISILDREDWAVAESILANADVVIVSVPINLTLETIERLKPYLTENMLLADLTSVKREPLAKMLEVHTGAVLGLHPMFGADIASMAKQVVVRCDGRFPERYEWLLEQIQIWGAKIYQTNATEHDHNMTYIQALRHFSTFANGLHLSKQPINLANLLALSSPIYRLELAMIGRLFAQDAELYADIIMDKSENLAVIETLKQTYDEALTFFENNDRQGFIDAFHKVRDWFGDYSEQFLKESRQLLQQANDLKQG', + 'chain_id': 'A', + 'msa_path': 'PRECOMPUTED_MSA'}, + ] + } + + assert data == reference + +def test_rosettafold_write_json(test_data): + with tempfile.TemporaryDirectory() as temp_dir: + rosettafold_json = Rosettafoldjson(temp_dir) + + rosettafold_json.json_to_json(test_data.test_inputAB_json) + out_file = Path(temp_dir) / "rosettafold_output.json" + rosettafold_json.write_json(out_file) + + reference = [ + { + 'name': '2PV7', + 'components': + [ + {'seq': 'GMRES', 'chain_id': 'A'}, + {'seq': 'GMRES', 'chain_id': 'B'}, + {'seq': 'YANEN', 'chain_id': 'C'}, + {'ccd_code': 'ATP'}, + {'ccd_code': 'ATP'}, + {'smiles': 'CC(=O)OC1C[NH+]2CCC1CC2'} + ] + } + ] + + with open(out_file, "r") as f: + written_data = f.read() + written_data = json.loads(written_data) + + + assert written_data == reference diff --git a/tests/test_rosettafold_output.py b/tests/test_rosettafold_output.py new file mode 100644 index 0000000..fbede68 --- /dev/null +++ b/tests/test_rosettafold_output.py @@ -0,0 +1,48 @@ +import json +import shutil + +from abcfold.output.file_handlers import CifFile, ConfidenceJsonFile +from abcfold.output.rosettafold3 import RosettafoldOutput +from abcfold.output.utils import Af3Pae, flatten + + +def test_process_rosettafold_output(test_data, tmp_path): + with open("tests/test_data/alphafold3_6BJ9/6bj9_data.json", "r") as f: + input_params = json.load(f) + + output_dir = tmp_path / "rosettafold_results_seed-1" + output_dir.mkdir() + cif_path = output_dir / "6BJ9_seed-1_sample-0_model.cif" + shutil.copyfile("tests/test_data/alphafold3_6BJ9/6bj9_model.cif", cif_path) + + cif_file = CifFile(cif_path, input_params) + token_count = len(flatten(cif_file.token_residue_ids().values())) + pae = [[0.0 for _ in range(token_count)] for _ in range(token_count)] + + confidence_path = output_dir / "6BJ9_seed-1_sample-0_confidences.json" + confidence_path.write_text(json.dumps({"pae": pae})) + summary_path = output_dir / "6BJ9_seed-1_sample-0_summary_confidences.json" + summary_path.write_text(json.dumps({"ptm": 0.2, "iptm": 0.3})) + + rosettafold_output = RosettafoldOutput([output_dir], input_params, "6BJ9") + + assert "seed-1" in rosettafold_output.output + assert 0 in rosettafold_output.output["seed-1"] + assert isinstance(rosettafold_output.cif_files["seed-1"][0], CifFile) + assert isinstance(rosettafold_output.pae_files["seed-1"][0], ConfidenceJsonFile) + assert isinstance(rosettafold_output.scores_files["seed-1"][0], ConfidenceJsonFile) + + +def test_rosettafold_pae_to_af3_pae(test_data): + with open("tests/test_data/alphafold3_6BJ9/6bj9_data.json", "r") as f: + input_params = json.load(f) + + cif_file = CifFile("tests/test_data/alphafold3_6BJ9/6bj9_model.cif", input_params) + token_count = len(flatten(cif_file.token_residue_ids().values())) + pae_matrix = [[0.0 for _ in range(token_count)] for _ in range(token_count)] + pae = Af3Pae.from_rosettafold3({"pae": pae_matrix}, cif_file) + + assert len(pae.scores["pae"]) == token_count + assert len(pae.scores["contact_probs"]) == token_count + assert len(pae.scores["token_chain_ids"]) == token_count + assert len(pae.scores["token_res_ids"]) == token_count diff --git a/tests/test_run_rosettafold.py b/tests/test_run_rosettafold.py new file mode 100644 index 0000000..505525d --- /dev/null +++ b/tests/test_run_rosettafold.py @@ -0,0 +1,46 @@ +import os +import tempfile + +import pytest + +from abcfold.rosettafold3.run_rosettafold3 import ( + generate_rosettafold_command, run_rosettafold) + + +@pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping test in CI environment") +def test_run_rosettafold(test_data): + + with tempfile.TemporaryDirectory() as temp_dir: + try: + run_rosettafold( + test_data.test_inputA_json, + temp_dir, + config=test_data.config_dict, + save_input=True, + test=True, + ) + except Exception as e: + print(e) + assert False + + +def test_generate_rosettafold_command(test_data): + input_json = "/road/to/nowhere.json" + output_dir = "/road/to/nowhere" + ckpt_path = "/road/to/nowhere.ckpt" + + cmd = generate_rosettafold_command( + input_json=input_json, + output_dir=output_dir, + ckpt_path=ckpt_path, + number_of_models=5, + seed=42 + ) + + assert "rf3" in cmd + assert "fold" in cmd + assert f"inputs='{input_json}'" in cmd + assert f"out_dir='{output_dir}'" in cmd + assert f"ckpt_path='{ckpt_path}'" in cmd + assert f"diffusion_batch_size={str(5)}" in cmd + assert f"seed={str(42)}" in cmd From fa1a845c485bba687221b60e4b92f5822fc416f4 Mon Sep 17 00:00:00 2001 From: "Nikolenko.Sergei" Date: Sat, 20 Jun 2026 19:23:11 +0300 Subject: [PATCH 3/3] Fix RosettaFold import ordering --- abcfold/abcfold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/abcfold/abcfold.py b/abcfold/abcfold.py index 22d690b..19ab535 100644 --- a/abcfold/abcfold.py +++ b/abcfold/abcfold.py @@ -17,8 +17,8 @@ openfold_argparse_util, prediction_argparse_util, protenix_argparse_util, - rosettafold_argparse_util, raise_argument_errors, + rosettafold_argparse_util, visuals_argparse_util) from abcfold.html.html_utils import (PORT, NoCacheHTTPRequestHandler, get_all_cif_files, get_model_data,