From 39425e71f335ed8309a75493c8ed7e1c8f75cdcc Mon Sep 17 00:00:00 2001
From: Jakob Nybo Andersen
Date: Sat, 10 Jan 2026 10:15:15 +0100
Subject: [PATCH] Instantiate tensors on correct devices

Vamb used to have the following pattern in a lot of places:

```python
my_tensor = _torch.zeros(n)
if self.usecuda:
    my_tensor = my_tensor.cuda()
```

This is wasteful: the tensor is first instantiated on the CPU and then
copied to the GPU, instead of being created directly on the GPU. (In a
few call sites the result of `.cuda()` was not even assigned back, so
the tensor was silently never moved at all.) In this commit, where
possible, we instantiate tensors on the correct device from the
beginning.
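Every PyTorch factory function (`zeros`, `ones`, `empty`, `randn`, ...)
accepts a `device` argument, so the replacement is a one-liner. A
minimal sketch, with `n` and `self.usecuda` standing in for whatever
size and flag a given call site uses:

```python
# Allocate directly on the target device: no CPU tensor, no host-to-GPU copy.
device = "cuda" if self.usecuda else "cpu"
my_tensor = _torch.zeros(n, device=device)
```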
---
 vamb/aamb_encode.py           | 30 ++++++---------
 vamb/cluster.py               | 19 ++++------
 vamb/encode.py                | 25 ++++++-------
 vamb/semisupervised_encode.py | 70 ++++++++++++++++-------------------
 vamb/taxvamb_encode.py        | 34 ++++++++---------
 5 files changed, 81 insertions(+), 97 deletions(-)

diff --git a/vamb/aamb_encode.py b/vamb/aamb_encode.py
index d7ce663d..e6867a16 100644
--- a/vamb/aamb_encode.py
+++ b/vamb/aamb_encode.py
@@ -122,9 +122,6 @@ def _reparameterization(self, mu, logvar):
         std = torch.exp(logvar / 2)
 
         sampled_z = Variable(Tensor(self.rng.normal(0, 1, (mu.size(0), self.ld))))
-
-        if self.usecuda:
-            sampled_z = sampled_z.cuda()
         z = sampled_z * std + mu
 
         return z
@@ -248,7 +245,7 @@ def trainmodel(
         # Define adversarial loss for the discriminators
         adversarial_loss = torch.nn.BCELoss()
         if self.usecuda:
-            adversarial_loss.cuda()
+            adversarial_loss = adversarial_loss.cuda()
 
         #### Optimizers
         optimizer_E = torch.optim.Adam(enc_params, lr=lr)
@@ -290,16 +287,15 @@ def trainmodel(
                 )
 
                 # Sample noise as discriminator Z,Y ground truth
+                device = "cuda" if self.usecuda else "cpu"
                 if self.usecuda:
                     z_prior = torch.cuda.FloatTensor(nrows, self.ld).normal_()
-                    z_prior.cuda()
 
                     ohc = RelaxedOneHotCategorical(
-                        torch.tensor([T], device="cuda"),
-                        torch.ones([nrows, self.y_len], device="cuda"),
+                        torch.tensor([T], device=device),
+                        torch.ones([nrows, self.y_len], device=device),
                     )
                     y_prior = ohc.sample()
-                    y_prior = y_prior.cuda()
 
                 else:
                     z_prior = Variable(Tensor(self.rng.normal(0, 1, (nrows, self.ld))))
@@ -308,9 +304,8 @@
 
                 del ohc
 
-                if self.usecuda:
-                    depths_in = depths_in.cuda()
-                    tnfs_in = tnfs_in.cuda()
+                depths_in = depths_in.to(device)
+                tnfs_in = tnfs_in.to(device)
 
                 # -----------------
                 #  Train Generator
@@ -461,15 +456,15 @@ def get_latents(
         with torch.no_grad():
             for depths_in, tnfs_in, _, _ in new_data_loader:
                 nrows, _ = depths_in.shape
+                device = "cuda" if self.usecuda else "cpu"
+
                 if self.usecuda:
                     z_prior = torch.cuda.FloatTensor(nrows, self.ld).normal_()
-                    z_prior.cuda()
                     ohc = RelaxedOneHotCategorical(
-                        torch.tensor([0.15], device="cuda"),
-                        torch.ones([nrows, self.y_len], device="cuda"),
+                        torch.tensor([0.15], device=device),
+                        torch.ones([nrows, self.y_len], device=device),
                     )
                     y_prior = ohc.sample()
-                    y_prior = y_prior.cuda()
 
                 else:
                     z_prior = Variable(Tensor(self.rng.normal(0, 1, (nrows, self.ld))))
@@ -478,9 +473,8 @@
                     )
                     y_prior = ohc.sample()
 
-                if self.usecuda:
-                    depths_in = depths_in.cuda()
-                    tnfs_in = tnfs_in.cuda()
+                depths_in = depths_in.to(device)
+                tnfs_in = tnfs_in.to(device)
 
                 mu, _, _, _, y_sample = self(depths_in, tnfs_in)[0:5]
 
diff --git a/vamb/cluster.py b/vamb/cluster.py
index c6bf9b84..8db575f4 100644
--- a/vamb/cluster.py
+++ b/vamb/cluster.py
@@ -224,9 +224,9 @@ def _check_params(
 
     def _init_histogram_kept_mask(self, N: int) -> tuple[_Tensor, _Tensor]:
         "N is number of contigs"
-        kept_mask = _torch.ones(N, dtype=_torch.bool)
-        if self.cuda:
-            kept_mask = kept_mask.cuda()
+        kept_mask = _torch.ones(
+            N, dtype=_torch.bool, device="cuda" if self.cuda else "cpu"
+        )
         histogram = _torch.empty(_ceil(_XMAX / _DELTA_X))
         return histogram, kept_mask
 
@@ -258,10 +258,9 @@ def __init__(
         _normalize(torch_matrix, inplace=True)
 
         # Move to GPU
-        torch_lengths = _torch.Tensor(lengths)
-        if cuda:
-            torch_matrix = torch_matrix.cuda()
-            torch_lengths = torch_lengths.cuda()
+        device = "cuda" if cuda else "cpu"
+        torch_matrix = torch_matrix.to(device)
+        torch_lengths = _torch.tensor(lengths, dtype=_torch.float32, device=device)
 
         self.maxsteps: int = maxsteps
         self.minsuccesses: int = minsuccesses
@@ -274,9 +273,7 @@
         self.indices = _torch.arange(len(matrix))
         self.order = _np.argsort(lengths)[::-1]
         self.order_index = 0
-        self.lengths = _torch.Tensor(lengths)
-        if cuda:
-            self.lengths = self.lengths.cuda()
+        self.lengths = torch_lengths
         self.n_emitted_clusters = 0
         self.n_remaining_points = len(torch_matrix)
         self.peak_valley_ratio = 0.1
@@ -321,7 +318,7 @@ def pack(self):
             cpu_kept_mask = self.kept_mask.cpu()
             self.matrix = _vambtools.torch_inplace_maskarray(
                 self.matrix.cpu(), cpu_kept_mask
-            ).cuda()
+            ).to("cuda")
 
             self.indices = self.indices[cpu_kept_mask]
         else:
diff --git a/vamb/encode.py b/vamb/encode.py
index 67894d33..2260d52d 100644
--- a/vamb/encode.py
+++ b/vamb/encode.py
@@ -274,10 +273,9 @@ def _encode(self, tensor: Tensor) -> Tensor:
 
     # sample with gaussian noise
     def reparameterize(self, mu: Tensor) -> Tensor:
-        epsilon = _torch.randn(mu.size(0), mu.size(1))
-
-        if self.usecuda:
-            epsilon = epsilon.cuda()
+        epsilon = _torch.randn(
+            mu.size(0), mu.size(1), device="cuda" if self.usecuda else "cpu"
+        )
 
         epsilon.requires_grad = True
 
@@ -392,11 +391,11 @@ def trainepoch(
             tnf_in.requires_grad = True
             abundance_in.requires_grad = True
 
-            if self.usecuda:
-                depths_in = depths_in.cuda()
-                tnf_in = tnf_in.cuda()
-                abundance_in = abundance_in.cuda()
-                weights = weights.cuda()
+            device = "cuda" if self.usecuda else "cpu"
+            depths_in = depths_in.to(device)
+            tnf_in = tnf_in.to(device)
+            abundance_in = abundance_in.to(device)
+            weights = weights.to(device)
 
             optimizer.zero_grad()
@@ -465,10 +464,10 @@ def encode(self, data_loader) -> _np.ndarray:
         with _torch.no_grad():
             for depths, tnf, ab, _ in new_data_loader:
                 # Move input to GPU if requested
-                if self.usecuda:
-                    depths = depths.cuda()
-                    tnf = tnf.cuda()
-                    ab = ab.cuda()
+                device = "cuda" if self.usecuda else "cpu"
+                depths = depths.to(device)
+                tnf = tnf.to(device)
+                ab = ab.to(device)
 
                 # Evaluate
                 _, _, _, mu = self(depths, tnf, ab)
diff --git a/vamb/semisupervised_encode.py b/vamb/semisupervised_encode.py
index e9c3c5d9..ce092b79 100644
--- a/vamb/semisupervised_encode.py
+++ b/vamb/semisupervised_encode.py
@@ -238,9 +238,7 @@ def _decode(self, tensor):
 
     def forward(self, labels):
         mu = self._encode(labels)
-        logsigma = _torch.zeros(mu.size())
-        if self.usecuda:
-            logsigma = logsigma.cuda()
+        logsigma = _torch.zeros(mu.size(), device="cuda" if self.usecuda else "cpu")
         latent = self.reparameterize(mu)
         labels_out = self._decode(latent)
         return labels_out, mu, logsigma
@@ -281,8 +279,7 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps):
             labels_in = labels_in[0]
             labels_in.requires_grad = True
 
-            if self.usecuda:
-                labels_in = labels_in.cuda()
+            labels_in = labels_in.to("cuda" if self.usecuda else "cpu")
 
             optimizer.zero_grad()
@@ -343,8 +340,7 @@ def encode(self, data_loader):
             for labels in new_data_loader:
                 labels = labels[0]
                 # Move input to GPU if requested
-                if self.usecuda:
-                    labels = labels.cuda()
+                labels = labels.to("cuda" if self.usecuda else "cpu")
 
                 # Evaluate
                 out_labels, mu, logsigma = self(labels)
@@ -504,9 +500,7 @@ def _decode(self, tensor):
     def forward(self, depths, tnf, abundance, labels):
         tensor = _torch.cat((depths, tnf, abundance, labels), 1)
         mu = self._encode(tensor)
-        logsigma = _torch.zeros(mu.size())
-        if self.usecuda:
-            logsigma = logsigma.cuda()
+        logsigma = _torch.zeros(mu.size(), device="cuda" if self.usecuda else "cpu")
         latent = self.reparameterize(mu)
 
         depths_out, tnf_out, abundance_out, labels_out = self._decode(latent)
@@ -596,12 +590,12 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps):
            abundance_in.requires_grad = True
            labels_in.requires_grad = True
 
-            if self.usecuda:
-                depths_in = depths_in.cuda()
-                tnf_in = tnf_in.cuda()
-                abundance_in = abundance_in.cuda()
-                weights = weights.cuda()
-                labels_in = labels_in.cuda()
+            device = "cuda" if self.usecuda else "cpu"
+            depths_in = depths_in.to(device)
+            tnf_in = tnf_in.to(device)
+            abundance_in = abundance_in.to(device)
+            weights = weights.to(device)
+            labels_in = labels_in.to(device)
 
             optimizer.zero_grad()
@@ -677,11 +671,11 @@ def encode(self, data_loader):
         with _torch.no_grad():
             for depths, tnf, ab, weights, labels in new_data_loader:
                 # Move input to GPU if requested
-                if self.usecuda:
-                    depths = depths.cuda()
-                    tnf = tnf.cuda()
-                    ab = ab.cuda()
-                    labels = labels.cuda()
+                device = "cuda" if self.usecuda else "cpu"
+                depths = depths.to(device)
+                tnf = tnf.to(device)
+                ab = ab.to(device)
+                labels = labels.to(device)
 
                 # Evaluate
                 _, _, _, _, mu, _ = self(depths, tnf, ab, labels)
@@ -882,17 +876,17 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps):
             abundance_in_unsup.requires_grad = True
             labels_in_unsup.requires_grad = True
 
-            if self.VAEVamb.usecuda:
-                depths_in_sup = depths_in_sup.cuda()
-                tnf_in_sup = tnf_in_sup.cuda()
-                abundance_in_sup = abundance_in_sup.cuda()
-                weights_in_sup = weights_in_sup.cuda()
-                labels_in_sup = labels_in_sup.cuda()
-                depths_in_unsup = depths_in_unsup.cuda()
-                tnf_in_unsup = tnf_in_unsup.cuda()
-                abundance_in_unsup = abundance_in_unsup.cuda()
-                weights_in_unsup = weights_in_unsup.cuda()
-                labels_in_unsup = labels_in_unsup.cuda()
+            device = "cuda" if self.VAEVamb.usecuda else "cpu"
+            depths_in_sup = depths_in_sup.to(device)
+            tnf_in_sup = tnf_in_sup.to(device)
+            abundance_in_sup = abundance_in_sup.to(device)
+            weights_in_sup = weights_in_sup.to(device)
+            labels_in_sup = labels_in_sup.to(device)
+            depths_in_unsup = depths_in_unsup.to(device)
+            tnf_in_unsup = tnf_in_unsup.to(device)
+            abundance_in_unsup = abundance_in_unsup.to(device)
+            weights_in_unsup = weights_in_unsup.to(device)
+            labels_in_unsup = labels_in_unsup.to(device)
 
             optimizer.zero_grad()
@@ -913,15 +907,15 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps):
                 abundance_out_unsup,
                 mu_vamb_unsup,
             ) = self.VAEVamb(depths_in_unsup, tnf_in_unsup, abundance_in_unsup)
-            logsigma_vamb_unsup = _torch.zeros(mu_vamb_unsup.size())
-            if self.usecuda:
-                logsigma_vamb_unsup = logsigma_vamb_unsup.cuda()
+            logsigma_vamb_unsup = _torch.zeros(
+                mu_vamb_unsup.size(), device="cuda" if self.VAEVamb.usecuda else "cpu"
+            )
             _, _, _, mu_vamb_sup_s = self.VAEVamb(
                 depths_in_sup, tnf_in_sup, abundance_in_sup
             )
-            logsigma_vamb_sup_s = _torch.zeros(mu_vamb_sup_s.size())
-            if self.usecuda:
-                logsigma_vamb_sup_s = logsigma_vamb_sup_s.cuda()
+            logsigma_vamb_sup_s = _torch.zeros(
+                mu_vamb_sup_s.size(), device="cuda" if self.VAEVamb.usecuda else "cpu"
+            )
             labels_out_unsup, mu_labels_unsup, logsigma_labels_unsup = self.VAELabels(
                 labels_in_unsup
             )
diff --git a/vamb/taxvamb_encode.py b/vamb/taxvamb_encode.py
index 924ca6c0..5ea9d4f8 100644
--- a/vamb/taxvamb_encode.py
+++ b/vamb/taxvamb_encode.py
@@ -900,11 +900,11 @@ def predict(self, data_loader) -> Iterable[tuple[_np.ndarray, _np.ndarray]]:
         with _torch.no_grad():
             for depths, tnf, abundances, weights in new_data_loader:
                 # Move input to GPU if requested
-                if self.usecuda:
-                    depths = depths.cuda()
-                    tnf = tnf.cuda()
-                    abundances = abundances.cuda()
-                    weights = weights.cuda()
+                device = "cuda" if self.usecuda else "cpu"
+                depths = depths.to(device)
+                tnf = tnf.to(device)
+                abundances = abundances.to(device)
+                weights = weights.to(device)
 
                 # Evaluate
                 labels = self(depths, tnf, abundances, weights)
@@ -931,12 +931,12 @@ def predict_with_ground_truth(
         with _torch.no_grad():
             for depths_in, tnf_in, abundances_in, weights, labels_in in new_data_loader:
 
-                if self.usecuda:
-                    depths_in = depths_in.cuda()
-                    tnf_in = tnf_in.cuda()
-                    abundances_in = abundances_in.cuda()
-                    weights = weights.cuda()
-                    labels_in = labels_in.cuda()
+                device = "cuda" if self.usecuda else "cpu"
+                depths_in = depths_in.to(device)
+                tnf_in = tnf_in.to(device)
+                abundances_in = abundances_in.to(device)
+                weights = weights.to(device)
+                labels_in = labels_in.to(device)
 
                 labels_out = self(depths_in, tnf_in, abundances_in, weights)
                 loss, correct_labels = self.calc_loss(labels_in, labels_out)
@@ -1004,12 +1004,12 @@ def trainepoch(
             abundances_in.requires_grad = True
             labels_in.requires_grad = True
 
-            if self.usecuda:
-                depths_in = depths_in.cuda()
-                tnf_in = tnf_in.cuda()
-                abundances_in = abundances_in.cuda()
-                weights = weights.cuda()
-                labels_in = labels_in.cuda()
+            device = "cuda" if self.usecuda else "cpu"
+            depths_in = depths_in.to(device)
+            tnf_in = tnf_in.to(device)
+            abundances_in = abundances_in.to(device)
+            weights = weights.to(device)
+            labels_in = labels_in.to(device)
 
             optimizer.zero_grad()