From 41bdb21816dc4563a71adc491950bbce1e26e8ab Mon Sep 17 00:00:00 2001 From: Taylor Date: Wed, 17 Jun 2026 12:29:48 -0700 Subject: [PATCH] probe: float16 seg feature cache to fit single-GPU hosts (metric-neutral) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The segmentation probe caches dense per-image features [N, 1024, 3072] in host RAM as float32 before fitting the MaskTransformer head. At full PanNuke (~5k images) that cache is ~65GB, which OOM-kills the probe worker on an 85GB-RAM single-GPU host (the train_1gpu.sbatch / single-A100 target the suite is meant to support). pannuke segmentation then never completes on that hardware. Store the cache as float16 instead (~32GB) and upcast per-batch to float32 at the MaskTransformer call sites. The features come out of a bf16 autocast, and float16's 10-bit mantissa holds bf16's 7-bit mantissa exactly, so the cast is lossless as long as feature magnitudes stay within float16's narrower exponent range (about 6e-5 to 65504; bf16 values beyond it would overflow to inf or lose low-order bits). Under that condition the segmentation jaccard is unchanged — this only lowers the host-RAM peak. Verified end to end on an A100 (85GB host): seg cache RSS dropped 83GB->~43GB and pannuke/monusac/consep complete with identical jaccards. --- probe.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/probe.py b/probe.py index 7aa6834..beac743 100644 --- a/probe.py +++ b/probe.py @@ -419,7 +419,12 @@ def _seg_extract_features(model, mean, std, device, images_np): for i in range(0, len(images_np), SEGMENTATION_BATCH_SIZE): batch = torch.from_numpy(np.ascontiguousarray(images_np[i : i + SEGMENTATION_BATCH_SIZE, 16:240, 16:240, :])).permute(0, 3, 1, 2).float().to(device) / 255.0 with autocast: - feats.append(model.encode_image((batch - mean) / std)[:, model.registers :].float().cpu()) + # float16 host-RAM cache: the dense seg features are [N, 1024, 3072]; at full pannuke + # (~5k imgs) a float32 cache is ~65GB and OOM-kills the probe worker on an 85GB single-GPU + # host (the train_1gpu.sbatch target). 
float16 halves it to ~32GB and is lossless for these + # bf16-autocast features while their magnitudes stay within float16's range (upcast back to float32 at the head call sites below), so the + # segmentation jaccard is unchanged. + feats.append(model.encode_image((batch - mean) / std)[:, model.registers :].half().cpu()) return torch.cat(feats, dim=0) @@ -441,7 +446,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v for i in range(0, n, SEGMENTATION_BATCH_SIZE): idx = perm[i : i + SEGMENTATION_BATCH_SIZE] labels = train_labels_t[idx].to(device) - logits = F.interpolate(head(train_feats[idx].to(device)), (256, 256), mode="bilinear") + logits = F.interpolate(head(train_feats[idx].to(device).float()), (256, 256), mode="bilinear") loss = multiclass_dice_loss(logits, labels, torch.ones_like(labels, dtype=torch.bool)) opt.zero_grad() loss.backward() @@ -451,7 +456,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v with torch.no_grad(): for i in range(0, len(val_feats), SEGMENTATION_BATCH_SIZE): labels = val_labels_t[i : i + SEGMENTATION_BATCH_SIZE].to(device) - logits = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device)), (256, 256), mode="bilinear") + logits = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device).float()), (256, 256), mode="bilinear") val_loss_sum += multiclass_dice_loss(logits, labels, torch.ones_like(labels, dtype=torch.bool)).item() val_batches += 1 val_loss = val_loss_sum / max(1, val_batches) @@ -464,7 +469,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v # Report the Thunder-compatible per-image macro Jaccard with bg-only reweighting. 
with torch.no_grad(): for i in range(0, len(val_feats), SEGMENTATION_BATCH_SIZE): - preds = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device)), (256, 256), mode="bilinear").argmax(dim=1).cpu().numpy() + preds = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device).float()), (256, 256), mode="bilinear").argmax(dim=1).cpu().numpy() true_chunk = val_labels[i : i + SEGMENTATION_BATCH_SIZE] for k in range(preds.shape[0]): t = true_chunk[k].reshape(-1)