Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions abcfold/abcfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
prediction_argparse_util,
protenix_argparse_util,
raise_argument_errors,
rosettafold_argparse_util,
visuals_argparse_util)
from abcfold.html.html_utils import (PORT, NoCacheHTTPRequestHandler,
get_all_cif_files, get_model_data,
Expand All @@ -30,6 +31,7 @@
from abcfold.output.file_handlers import superpose_models
from abcfold.output.openfold3 import OpenfoldOutput
from abcfold.output.protenix import ProtenixOutput
from abcfold.output.rosettafold3 import RosettafoldOutput
from abcfold.output.utils import (get_gap_indicies, insert_none_by_minus_one,
make_dummy_m8_file, verify_config_file)
from abcfold.scripts.abc_script_utils import (check_input_json, make_dir,
Expand All @@ -44,7 +46,12 @@
PLOTS_DIR = ".plots"

ModelOutput = Union[
AlphafoldOutput, BoltzOutput, ChaiOutput, ProtenixOutput, OpenfoldOutput
AlphafoldOutput,
BoltzOutput,
ChaiOutput,
ProtenixOutput,
OpenfoldOutput,
RosettafoldOutput,
]


Expand Down Expand Up @@ -281,6 +288,27 @@ def run(args, config, defaults, config_file):
outputs.append(oo)
successful_runs.append(openfold_success)

if args.rosettafold3:
from abcfold.rosettafold3.run_rosettafold3 import run_rosettafold

rosettafold_success = run_rosettafold(
input_json=run_json,
output_dir=args.output_dir,
save_input=args.save_input,
number_of_models=args.number_of_models,
config=rt_config,
)

if rosettafold_success:
rosettafold_output_dirs = list(
args.output_dir.glob("rosettafold_results*")
)
ro = RosettafoldOutput(
rosettafold_output_dirs, input_params, name, args.save_input
)
outputs.append(ro)
successful_runs.append(rosettafold_success)

if args.no_visuals:
logger.info("Visuals disabled")
return
Expand Down Expand Up @@ -440,12 +468,41 @@ def run(args, config, defaults, config_file):
)
protenix_models["models"].append(model_data)

rosettafold_models: Dict[str, List[Dict[str, Any]]] = {"models": []}
if args.rosettafold3:
if rosettafold_success:
programs_run.append("RosettaFold3")
for seed in ro.output.keys():
for idx in ro.output[seed].keys():
if idx >= 0:
model = ro.output[seed][idx]["cif"]
model.check_clashes()
score_file = ro.output[seed][idx]["scores"]
plddt = model.residue_plddts
pae = ro.output[seed][idx]["af3_pae"]
if len(indicies) > 0:
plddt = insert_none_by_minus_one(
indicies[index_counter], plddt
)
index_counter += 1
model_data = get_model_data(
model,
plot_dict,
"RosettaFold3",
plddt,
pae,
score_file,
args.output_dir,
)
rosettafold_models["models"].append(model_data)

combined_models = (
alphafold_models["models"]
+ boltz_models["models"]
+ chai_models["models"]
+ openfold_models["models"]
+ protenix_models["models"]
+ rosettafold_models["models"]
)

# Make the output directory for the models
Expand All @@ -463,6 +520,8 @@ def run(args, config, defaults, config_file):
output_name = "openfold_model_" + model["model_id"][-1] + ".cif"
elif model["model_source"] == "Protenix":
output_name = "protenix_model_" + model["model_id"][-1] + ".cif"
elif model["model_source"] == "RosettaFold3":
output_name = "rosettafold_model_" + model["model_id"][-1] + ".cif"
shutil.copy(
cif_file,
args.output_dir.joinpath("output_models").joinpath(output_name),
Expand Down Expand Up @@ -551,7 +610,7 @@ def run(args, config, defaults, config_file):

def main():
"""
Run AlphaFold3 / Boltz / Chai-1 / OpenFold3 / Protenix
Run AlphaFold3 / Boltz / Chai-1 / OpenFold3 / Protenix / RosettaFold3
"""
import argparse

Expand Down Expand Up @@ -581,7 +640,9 @@ def main():
defaults.update(dict(config.items(section)))

parser = argparse.ArgumentParser(
description="Run AlphaFold3 / Boltz / Chai-1 / OpenFold3 / Protenix",
description=(
"Run AlphaFold3 / Boltz / Chai-1 / OpenFold3 / Protenix / RosettaFold3"
),
parents=[config_parser],
)

Expand All @@ -591,6 +652,7 @@ def main():
parser = chai_argparse_util(parser)
parser = openfold_argparse_util(parser)
parser = protenix_argparse_util(parser)
parser = rosettafold_argparse_util(parser)
parser = mmseqs2_argparse_util(parser)
parser = custom_template_argpase_util(parser)
parser = prediction_argparse_util(parser)
Expand Down
15 changes: 13 additions & 2 deletions abcfold/argparse_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,16 @@ def openfold_argparse_util(parser):
return parser


def rosettafold_argparse_util(parser):
parser.add_argument(
"-r",
"--rosettafold3",
action="store_true",
help="Run RosettaFold 3",
)
return parser


def alphafold_argparse_util(parser):
parser.add_argument(
"--database",
Expand Down Expand Up @@ -336,10 +346,11 @@ def raise_argument_errors(args):
and not args.chai1
and not args.protenix
and not args.openfold3
and not args.rosettafold3
):
logger.info(
dedent("None of AlphaFold3, Boltz, Chai-1, Protenix or OpenFold3 selected. \
Running AlphaFold3 by default")
dedent("None of AlphaFold3, Boltz, Chai-1, Protenix, OpenFold3 or \
RosettaFold3 selected. Running AlphaFold3 by default")
)
args.alphafold3 = True

Expand Down
5 changes: 4 additions & 1 deletion abcfold/data/config.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,23 @@ af3_sif_path = None
model_params = /mnt/ligandpro/soft/protein/alphafold3/models
boltz_weights = None
openfold_weights = None
rosettafold_weights = None

[Environments]
af3_docker_env = romerolabduke/alphafast:latest
boltz_env = abcfold-boltz-py311
chai_env = abcfold-chai-py311
openfold_env = abcfold-openfold-py311
protenix_env = abcfold-protenix-py311
rosettafold_env = abcfold-rosetta-py312

[Versions]
af3_version = 3.0.0
boltz_version = 2.2.1
chai_version = 0.6.1
openfold_version = 0.3.1
openfold_version = 0.4.1
protenix_version = 2.0.0
rosetta_version = 0.1.12

[Models]
protenix_model = protenix-v2
Expand Down
2 changes: 2 additions & 0 deletions abcfold/html/abcfold_vue.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ Vue.component('abc-table', {
return 'btn-source4';
case 'OpenFold3':
return 'btn-source5';
case 'RosettaFold3':
return 'btn-source6';
default:
return 'btn-default';
}
Expand Down
6 changes: 6 additions & 0 deletions abcfold/html/html_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from abcfold.output.file_handlers import ConfidenceJsonFile, NpzFile
from abcfold.output.openfold3 import OpenfoldOutput
from abcfold.output.protenix import ProtenixOutput
from abcfold.output.rosettafold3 import RosettafoldOutput
from abcfold.plots.pae_plot import create_pae_plots
from abcfold.plots.plddt_plot import plot_plddt
from abcfold.scripts.ipsae import Ipsae
Expand Down Expand Up @@ -293,6 +294,11 @@ def get_all_cif_files(outputs) -> Dict[str, list]:
if "Protenix" not in method_cif_objs:
method_cif_objs["Protenix"] = []
method_cif_objs["Protenix"].extend(output.cif_files[seed])
elif isinstance(output, RosettafoldOutput):
for seed in output.seeds:
if "RosettaFold3" not in method_cif_objs:
method_cif_objs["RosettaFold3"] = []
method_cif_objs["RosettaFold3"].extend(output.cif_files[seed])

return method_cif_objs

Expand Down
12 changes: 11 additions & 1 deletion abcfold/html/style.css
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,16 @@ rect{
transition: background-color 0.3s ease;
}

.btn-source6 {
background-color: #FFA533;
color: #3b3b3d;
border: none;
border-radius: 12px;
padding: 7px 15px;
cursor: pointer;
transition: background-color 0.3s ease;
}

.btn-default {
background-color: rgb(199, 195, 195);
color: white;
Expand All @@ -506,6 +516,6 @@ rect{
transition: background-color 0.3s ease;
}

.btn-source1:hover, .btn-source2:hover, .btn-source3:hover, .btn-default:hover {
.btn-source1:hover, .btn-source2:hover, .btn-source3:hover, .btn-source4:hover, .btn-source5:hover, .btn-source6:hover, .btn-default:hover {
opacity: 0.8;
}
20 changes: 20 additions & 0 deletions abcfold/output/file_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,26 @@ def from_openfold(cls,
cls._fix_openfold_mmcif(str(openfold_path), str(tmp_path))
return cls(str(tmp_path), input_params)

@classmethod
def from_rosettafold(cls,
rosettafold_cif: Union[str, Path],
input_params: Optional[dict] = None) -> "CifFile":
"""
Create a CifFile from a RosettaFold3 mmCIF and convert pLDDT scores
stored in the B-factor field from 0-1 scale to 0-100 scale.
"""
rosettafold_path = Path(rosettafold_cif)
tmp_path = rosettafold_path.parent / f"{rosettafold_path.stem}_fixed.cif"
parser = MMCIFParser(QUIET=True)
model = parser.get_structure(rosettafold_path.stem, rosettafold_path)
for atom in model.get_atoms():
atom.bfactor = atom.bfactor * 100.0
io = MMCIFIO()
io.set_structure(model)
io.save(str(tmp_path))

return cls(str(tmp_path), input_params)


class ConfidenceJsonFile(FileBase):
def __init__(self, json_file: Union[str, Path]):
Expand Down
Loading
Loading