diff --git a/src/clm/commands/sample_molecules_RNN.py b/src/clm/commands/sample_molecules_RNN.py index fd9685df..2c64c7fb 100644 --- a/src/clm/commands/sample_molecules_RNN.py +++ b/src/clm/commands/sample_molecules_RNN.py @@ -86,6 +86,11 @@ def add_args(parser): action="store_true", help="Add descriptor in hidden and cell state", ) + parser.add_argument( + "--preload_condition", + action="store_true", + help="Add descriptor in hidden and cell state", + ) parser.add_argument( "--heldout_file", type=str, @@ -137,6 +142,7 @@ def sample_molecules_RNN( conditional_dec=False, conditional_dec_l=True, conditional_h=False, + preload_condition=False, heldout_file=None, ): os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True) @@ -294,20 +300,29 @@ def sample_molecules_RNN( # Erase file contents if there are any open(output_file, "w").close() - + if preload_condition and heldout_dataset is not None: + preload_descriptors=torch.stack([_[1] for _ in heldout_dataset], 0).to(model.device) + while len(preload_descriptors)