diff --git a/probe.py b/probe.py index 7aa6834..beac743 100644 --- a/probe.py +++ b/probe.py @@ -419,7 +419,12 @@ def _seg_extract_features(model, mean, std, device, images_np): for i in range(0, len(images_np), SEGMENTATION_BATCH_SIZE): batch = torch.from_numpy(np.ascontiguousarray(images_np[i : i + SEGMENTATION_BATCH_SIZE, 16:240, 16:240, :])).permute(0, 3, 1, 2).float().to(device) / 255.0 with autocast: - feats.append(model.encode_image((batch - mean) / std)[:, model.registers :].float().cpu()) + # float16 host-RAM cache: the dense seg features are [N, 1024, 3072]; at full pannuke + # (~5k imgs) a float32 cache is ~65GB and OOM-kills the probe worker on an 85GB single-GPU + # host (the train_1gpu.sbatch target). float16 halves it to ~32GB. bf16 -> float16 is exact only + # within float16's range (|x| <= 65504; very small magnitudes lose bits or flush to 0), so the + # jaccard should be unchanged in practice (features are upcast to float32 at the head call sites). + feats.append(model.encode_image((batch - mean) / std)[:, model.registers :].half().cpu()) return torch.cat(feats, dim=0) @@ -441,7 +446,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v for i in range(0, n, SEGMENTATION_BATCH_SIZE): idx = perm[i : i + SEGMENTATION_BATCH_SIZE] labels = train_labels_t[idx].to(device) - logits = F.interpolate(head(train_feats[idx].to(device)), (256, 256), mode="bilinear") + logits = F.interpolate(head(train_feats[idx].to(device).float()), (256, 256), mode="bilinear") loss = multiclass_dice_loss(logits, labels, torch.ones_like(labels, dtype=torch.bool)) opt.zero_grad() loss.backward() @@ -451,7 +456,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v with torch.no_grad(): for i in range(0, len(val_feats), SEGMENTATION_BATCH_SIZE): labels = val_labels_t[i : i + SEGMENTATION_BATCH_SIZE].to(device) - logits = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device)), (256, 256), mode="bilinear") + logits = F.interpolate(head(val_feats[i : i + 
SEGMENTATION_BATCH_SIZE].to(device).float()), (256, 256), mode="bilinear") val_loss_sum += multiclass_dice_loss(logits, labels, torch.ones_like(labels, dtype=torch.bool)).item() val_batches += 1 val_loss = val_loss_sum / max(1, val_batches) @@ -464,7 +469,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v # Report the Thunder-compatible per-image macro Jaccard with bg-only reweighting. with torch.no_grad(): for i in range(0, len(val_feats), SEGMENTATION_BATCH_SIZE): - preds = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device)), (256, 256), mode="bilinear").argmax(dim=1).cpu().numpy() + preds = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device).float()), (256, 256), mode="bilinear").argmax(dim=1).cpu().numpy() true_chunk = val_labels[i : i + SEGMENTATION_BATCH_SIZE] for k in range(preds.shape[0]): t = true_chunk[k].reshape(-1)