Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,12 @@ def _seg_extract_features(model, mean, std, device, images_np):
for i in range(0, len(images_np), SEGMENTATION_BATCH_SIZE):
batch = torch.from_numpy(np.ascontiguousarray(images_np[i : i + SEGMENTATION_BATCH_SIZE, 16:240, 16:240, :])).permute(0, 3, 1, 2).float().to(device) / 255.0
with autocast:
feats.append(model.encode_image((batch - mean) / std)[:, model.registers :].float().cpu())
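# float16 host-RAM cache: the dense seg features are [N, 1024, 3072]; at full pannuke
# (~5k imgs) a float32 cache is ~65GB and OOM-kills the probe worker on an 85GB single-GPU
# host (the train_1gpu.sbatch target). float16 halves it to ~32GB. The cast keeps every
# bf16 mantissa bit (float16 has 10, bf16 has 7), so it is lossless for these bf16-autocast
# features as long as their magnitudes stay inside float16's narrower exponent range
# (|x| <= 65504; larger values overflow to inf, and very small magnitudes lose precision or
# flush to zero). The features are upcast back to float32 at the head call sites below, so
# under that range assumption the segmentation jaccard is unchanged.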
feats.append(model.encode_image((batch - mean) / std)[:, model.registers :].half().cpu())
return torch.cat(feats, dim=0)


Expand All @@ -441,7 +446,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v
for i in range(0, n, SEGMENTATION_BATCH_SIZE):
idx = perm[i : i + SEGMENTATION_BATCH_SIZE]
labels = train_labels_t[idx].to(device)
logits = F.interpolate(head(train_feats[idx].to(device)), (256, 256), mode="bilinear")
logits = F.interpolate(head(train_feats[idx].to(device).float()), (256, 256), mode="bilinear")
loss = multiclass_dice_loss(logits, labels, torch.ones_like(labels, dtype=torch.bool))
opt.zero_grad()
loss.backward()
Expand All @@ -451,7 +456,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v
with torch.no_grad():
for i in range(0, len(val_feats), SEGMENTATION_BATCH_SIZE):
labels = val_labels_t[i : i + SEGMENTATION_BATCH_SIZE].to(device)
logits = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device)), (256, 256), mode="bilinear")
logits = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device).float()), (256, 256), mode="bilinear")
val_loss_sum += multiclass_dice_loss(logits, labels, torch.ones_like(labels, dtype=torch.bool)).item()
val_batches += 1
val_loss = val_loss_sum / max(1, val_batches)
Expand All @@ -464,7 +469,7 @@ def _seg_head_jaccard_from_feats(device, train_feats, train_labels, val_feats, v
# Report the Thunder-compatible per-image macro Jaccard with bg-only reweighting.
with torch.no_grad():
for i in range(0, len(val_feats), SEGMENTATION_BATCH_SIZE):
preds = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device)), (256, 256), mode="bilinear").argmax(dim=1).cpu().numpy()
preds = F.interpolate(head(val_feats[i : i + SEGMENTATION_BATCH_SIZE].to(device).float()), (256, 256), mode="bilinear").argmax(dim=1).cpu().numpy()
true_chunk = val_labels[i : i + SEGMENTATION_BATCH_SIZE]
for k in range(preds.shape[0]):
t = true_chunk[k].reshape(-1)
Expand Down