Fix MPS tensor layout for WiLoR inference #22

Open
lyonsno wants to merge 3 commits into warmshao:main from lyonsno:codex/mps-layout-smoke-0522

@lyonsno lyonsno commented May 22, 2026

Summary

This fixes two places where a non-contiguous tensor crosses a module boundary and breaks, or risks breaking, PyTorch MPS inference on Apple Silicon:

  • make the cropped ViT/backbone input contiguous after x[:, :, :, 32:-32]
  • make refinement image features contiguous before the first convolution in both deconv variants
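
A minimal sketch of why the crop needs an explicit copy. The tensor shape is a hypothetical stand-in for a batch of hand crops, not a value taken from the model code. Slicing the last dimension returns a strided view over the original storage, and `.contiguous()` copies it into a dense buffer without changing any values:

```python
import torch

# Hypothetical stand-in for a batch of image crops, laid out as (N, C, H, W).
x = torch.randn(1, 3, 256, 256)

# Trimming 32 columns from each side returns a view that still strides
# over the original 256-wide rows, so it is not contiguous.
cropped = x[:, :, :, 32:-32]
print(cropped.shape)            # torch.Size([1, 3, 256, 192])
print(cropped.stride())         # (196608, 65536, 256, 1)
print(cropped.is_contiguous())  # False

# .contiguous() materializes a dense copy; the values are identical.
fixed = cropped.contiguous()
print(fixed.stride())               # (147456, 49152, 192, 1)
print(fixed.is_contiguous())        # True
print(torch.equal(fixed, cropped))  # True
```

The copy costs one extra allocation per forward pass, which is the trade-off for giving the downstream module a dense input on every device.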

The patch is intentionally narrow: it does not change model weights, detector behavior, output schema, or CPU semantics.

Verification

Base checked on 2026-05-22: warmshao/WiLoR-mini@ebec42f94c389070cdd7dda6fd1bf0b4a659c960. Open PR #21 was also checked (alex-bene:deps-fix-focal-input-det-conf@cf8f582223142a57f2f486cd7b2fc0545663473d); that PR covers dependency/API work and does not include this MPS tensor-layout repair.

Local test command:

.venv/bin/python -m pytest tests/test_mps_layout.py -q

Current result on PR head e6f1ceee111b3792303c32db8168f4485eb807fa:

6 passed in 0.55s

The layout tests now run on CPU in ordinary CI and additionally on MPS when it is available. As a fail-first check, running the current test file against base fails, as intended, on the contiguity assertions for the backbone crop and for both refinement deconv variants.
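
For illustration only, one way such a contiguity assertion could be written. This is not the content of `tests/test_mps_layout.py`; the helper names, layer sizes, and the permuted input are assumptions made for the sketch. A forward pre-hook inspects the tensor that reaches a convolution, and the check runs on CPU and additionally on MPS when available:

```python
import torch
import torch.nn as nn


def require_contiguous_input(module, args):
    # Forward pre-hook: called with the positional inputs just before forward().
    if not args[0].is_contiguous():
        raise AssertionError(f"{type(module).__name__} received a non-contiguous input")


def first_conv_accepts(feats, device="cpu"):
    """Return True if a hooked conv layer sees a contiguous input (hypothetical helper)."""
    conv = nn.Conv2d(feats.shape[1], 8, kernel_size=3, padding=1).to(device)
    handle = conv.register_forward_pre_hook(require_contiguous_input)
    try:
        conv(feats.to(device))
        return True
    except AssertionError:
        return False
    finally:
        handle.remove()


# CPU always runs; MPS is added only when the backend is usable.
devices = ["cpu"] + (["mps"] if torch.backends.mps.is_available() else [])

# A channels-last buffer viewed as NCHW is a typical non-contiguous feature map.
feats = torch.randn(1, 16, 16, 4).permute(0, 3, 1, 2)

for device in devices:
    before = first_conv_accepts(feats, device)
    after = first_conv_accepts(feats.contiguous(), device)
    print(device, before, after)
```

On CPU this prints `cpu False True`: the permuted view trips the hook, and the same data passes once `.contiguous()` is applied upstream.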

Saved-image smoke on Mac MPS, using assets/img.png, float32, predict_with_bboxes, and cached WiLoR-mini model artifacts:

{"torch":"2.5.0","mps_built":true,"mps_available":true}
{"init_seconds":1.655}
{"elapsed_seconds":6.144,"num_outputs":1,"pred_vertices_shape":[1,778,3],"pred_keypoints_3d_shape":[1,21,3],"pred_cam_t_full_shape":[1,3],"device_route":"mps","dtype":"float32"}

CPU smoke with the same saved image/cache also passed:

{"torch":"2.5.0","init_seconds":1.405,"elapsed_seconds":0.298,"num_outputs":1,"pred_vertices_shape":[1,778,3],"pred_keypoints_3d_shape":[1,21,3],"pred_cam_t_full_shape":[1,3],"device_route":"cpu","dtype":"float32"}

Dependency facts from the smoke venv:

{"python":"3.10.19","torch":"2.5.0","torchvision":"0.20.0","ultralytics":"8.1.34","chumpy":"0.71","dill":"0.4.1","mps_built":true,"mps_available":true}

One dependency note: the detector checkpoint load required dill, which the current requirements do not install, and Ultralytics' auto-install fallback failed in a uv-created venv because pip was not present. I installed dill locally to continue the MPS smoke, but did not add that dependency in this PR, to keep it scoped to the tensor-layout repair.
