Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions src/clm/commands/sample_molecules_RNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ def add_args(parser):
action="store_true",
help="Add descriptor in hidden and cell state",
)
parser.add_argument(
    "--preload_condition",
    action="store_true",
    # The flag only changes how heldout descriptors are staged for sampling
    # (stacked once up front and then sliced per batch); it does not touch
    # the hidden/cell state, so the help text must not repeat that wording.
    help="Preload all heldout descriptors onto the model device before sampling",
)
parser.add_argument(
"--heldout_file",
type=str,
Expand Down Expand Up @@ -137,6 +142,7 @@ def sample_molecules_RNN(
conditional_dec=False,
conditional_dec_l=True,
conditional_h=False,
preload_condition=False,
heldout_file=None,
):
os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)
Expand Down Expand Up @@ -294,20 +300,29 @@ def sample_molecules_RNN(

# Erase file contents if there are any
open(output_file, "w").close()

# Optionally stack every heldout descriptor onto the model device once, so the
# sampling loop can slice a batch out of it instead of re-stacking per batch.
if preload_condition and heldout_dataset is not None:
    preload_descriptors = torch.stack(
        [item[1] for item in heldout_dataset], 0
    ).to(model.device)
    # Repeat whole copies along the row axis until a full batch fits.
    while len(preload_descriptors) < batch_size:
        preload_descriptors = torch.cat(
            [preload_descriptors, preload_descriptors], 0
        )
    # Double once more along the SAME row axis so a slice that starts at
    # `i % len(heldout_dataset)` and spans up to `batch_size` rows never runs
    # off the end. This must be `torch.cat`: `torch.stack` would insert a new
    # leading axis of size 2, and row slicing would then return at most two
    # entries of the wrong rank.
    preload_descriptors = torch.cat(
        [preload_descriptors, preload_descriptors], 0
    )

with tqdm(total=sample_mols) as pbar:
for i in range(0, sample_mols, batch_size):
n_sequences = min(batch_size, sample_mols - i)
descriptors = None
if heldout_dataset is not None:
# Use modulo to cycle through heldout_dataset
descriptor_indices = [
(i + j) % len(heldout_dataset) for j in range(n_sequences)
]
descriptors = torch.stack(
[heldout_dataset[idx][1] for idx in descriptor_indices]
)
descriptors = descriptors.to(model.device)
if preload_condition:
s=i%len(heldout_dataset)
descriptors=preload_descriptors[s:s+n_sequences]
else:
descriptor_indices = [
(i + j) % len(heldout_dataset) for j in range(n_sequences)
]
descriptors = torch.stack(
[heldout_dataset[idx][1] for idx in descriptor_indices]
)
descriptors = descriptors.to(model.device)
sampled_smiles, losses = model.sample(
descriptors=descriptors,
n_sequences=n_sequences,
Expand Down Expand Up @@ -349,4 +364,5 @@ def main(args):
conditional_dec_l=args.conditional_dec_l,
conditional_h=args.conditional_h,
heldout_file=args.heldout_file,
preload_condition=args.preload_condition
)
7 changes: 6 additions & 1 deletion src/clm/commands/train_models_RNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ def add_args(parser):
parser.add_argument(
"--learning_rate", type=float, help="Learning rate for the optimizer"
)
# Number of workers handed to the DataLoader that feeds training.
parser.add_argument(
    "--num_workers",
    default=3,
    type=int,
    help="number of workers for loading training data",
)

parser.add_argument(
"--max_epochs", type=int, help="Maximum number of epochs for training"
Expand Down Expand Up @@ -220,6 +223,7 @@ def train_models_RNN(
smiles_file,
model_file,
loss_file,
num_workers,
conditional=False,
conditional_emb=False,
conditional_emb_l=True,
Expand Down Expand Up @@ -321,7 +325,7 @@ def train_models_RNN(
logger.info(dataset.vocabulary.dictionary)

loader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate
dataset, batch_size=batch_size, shuffle=True, collate_fn=dataset.collate, num_workers=num_workers
)
optim = Adam(
model.parameters(), betas=(0.9, 0.999), eps=1e-08, lr=learning_rate
Expand Down Expand Up @@ -411,4 +415,5 @@ def main(args):
conditional_dec=args.conditional_dec,
conditional_dec_l=args.conditional_dec_l,
conditional_h=args.conditional_h,
num_workers=args.num_workers
)
Loading