From 45a874ed770147e10a1afef75faa47eb32b69642 Mon Sep 17 00:00:00 2001 From: Chencheng Xu Date: Mon, 2 Mar 2026 12:17:14 -0500 Subject: [PATCH 1/3] add workers to dataloader --- src/clm/commands/train_models_RNN.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/clm/commands/train_models_RNN.py b/src/clm/commands/train_models_RNN.py index 84ace720..b932a824 100644 --- a/src/clm/commands/train_models_RNN.py +++ b/src/clm/commands/train_models_RNN.py @@ -81,6 +81,9 @@ def add_args(parser): parser.add_argument( "--learning_rate", type=float, help="Learning rate for the optimizer" ) + parser.add_argument( + "--num_workers", type=int, help="number of workers for loading training data", default=3 + ) parser.add_argument( "--max_epochs", type=int, help="Maximum number of epochs for training" @@ -220,6 +223,7 @@ def train_models_RNN( smiles_file, model_file, loss_file, + num_workers, conditional=False, conditional_emb=False, conditional_emb_l=True, @@ -321,7 +325,7 @@ def train_models_RNN( logger.info(dataset.vocabulary.dictionary) loader = DataLoader( - dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate + dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate, num_workers=num_workers ) optim = Adam( model.parameters(), betas=(0.9, 0.999), eps=1e-08, lr=learning_rate @@ -411,4 +415,5 @@ def main(args): conditional_dec=args.conditional_dec, conditional_dec_l=args.conditional_dec_l, conditional_h=args.conditional_h, + num_workers=args.num_workers ) From 733fee21afe2ccd2e2a39931a039512c71f57c6a Mon Sep 17 00:00:00 2001 From: Chencheng Xu Date: Mon, 2 Mar 2026 12:21:36 -0500 Subject: [PATCH 2/3] preload conditions in sampling --- src/clm/commands/sample_molecules_RNN.py | 30 +++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/src/clm/commands/sample_molecules_RNN.py b/src/clm/commands/sample_molecules_RNN.py index fd9685df..9f517258 100644 --- a/src/clm/commands/sample_molecules_RNN.py +++ b/src/clm/commands/sample_molecules_RNN.py @@ -86,6 +86,11 @@ def add_args(parser): action="store_true", help="Add descriptor in hidden and cell state", ) + parser.add_argument( + "--preload_condition", + action="store_true", + help="Add descriptor in hidden and cell state", + ) parser.add_argument( "--heldout_file", type=str, @@ -137,6 +142,7 @@ def sample_molecules_RNN( conditional_dec=False, conditional_dec_l=True, conditional_h=False, + preload_condition=False, heldout_file=None, ): os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True) @@ -294,20 +300,27 @@ def sample_molecules_RNN( # Erase file contents if there are any open(output_file, "w").close() - + if preload_condition and heldout_dataset is not None: + preload_descriptors=torch.stack([_[1] for _ in heldout_dataset], 0).to(model.device) + while len(preload_descriptors)<4*batch_size: + preload_descriptors=torch.cat([preload_descriptors, preload_descriptors], 0) with tqdm(total=sample_mols) as pbar: for i in range(0, sample_mols, batch_size): n_sequences = min(batch_size, sample_mols - i) descriptors = None if heldout_dataset is not None: # Use modulo to cycle through heldout_dataset - descriptor_indices = [ - (i + j) % len(heldout_dataset) for j in range(n_sequences) - ] - descriptors = torch.stack( - [heldout_dataset[idx][1] for idx in descriptor_indices] - ) - descriptors = descriptors.to(model.device) + if preload_condition: + s=i%len(heldout_dataset) + descriptors=preload_descriptors[s:s+n_sequences] + else: + descriptor_indices = [ + (i + j) % len(heldout_dataset) for j in range(n_sequences) + ] + descriptors = torch.stack( + [heldout_dataset[idx][1] for idx in descriptor_indices] + ) + descriptors = descriptors.to(model.device) sampled_smiles, losses = model.sample( descriptors=descriptors, n_sequences=n_sequences, @@ -349,4 +362,5 @@ def main(args): conditional_dec_l=args.conditional_dec_l, conditional_h=args.conditional_h, heldout_file=args.heldout_file, + preload_condition=args.preload_condition ) From fd43e8e00a6d561c4569754a32f698780893b098 Mon Sep 17 00:00:00 2001 From: Chencheng Xu Date: Mon, 2 Mar 2026 15:01:52 -0500 Subject: [PATCH 3/3] should concate descriptor twice --- src/clm/commands/sample_molecules_RNN.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/clm/commands/sample_molecules_RNN.py b/src/clm/commands/sample_molecules_RNN.py index 9f517258..2c64c7fb 100644 --- a/src/clm/commands/sample_molecules_RNN.py +++ b/src/clm/commands/sample_molecules_RNN.py @@ -302,8 +302,10 @@ def sample_molecules_RNN( open(output_file, "w").close() if preload_condition and heldout_dataset is not None: preload_descriptors=torch.stack([_[1] for _ in heldout_dataset], 0).to(model.device) - while len(preload_descriptors)<4*batch_size: + while len(preload_descriptors)