Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 148 additions & 0 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
name: Sphinx Docs

on:
push:
branches:
- main
pull_request:

permissions:
contents: write

defaults:
run:
shell: bash

jobs:
get_notebooks:
name: Get list of notebooks
runs-on: ubuntu-latest
steps:
- name: Checkout LION repo
uses: actions/checkout@v6

- id: set-matrix
run: |
echo "notebook_paths=$(find examples/notebooks -type f -name '*.ipynb' | jq -R -s -c 'split("\n")[:-1]')" >> $GITHUB_OUTPUT

- name: Notebook overview
run: |
echo "jupyter-notebooks: ${{ steps.set-matrix.outputs.notebook_paths }}"

outputs:
notebook_paths: ${{ steps.set-matrix.outputs.notebook_paths }}

run_notebook:
name: Run notebook
needs: get_notebooks
runs-on: ubuntu-latest
permissions:
pull-requests: write
contents: write
strategy:
fail-fast: false
matrix:
notebook_path: ${{ fromJson(needs.get_notebooks.outputs.notebook_paths) }}
steps:
- name: Checkout repo
uses: actions/checkout@v6

- name: Install LION and dependencies
run: pip install -e .[notebooks]

- name: Notebook name
run: |
echo "current jupyter-notebook: ${{ matrix.notebook_path }}"

- name: Add nb-myst download badge
run: |
notebook=${{ matrix.notebook_path }}
notebook_name=$(basename $notebook)
download_badge_md="[![Download notebook](https://img.shields.io/badge/Download-notebook-blue?logo=jupyter)](path:$notebook_name)"
python_command="import nbformat as nbf\n\
nb = nbf.read(open('$notebook'), as_version=4)\n\
# if the 1st cell is md and has colab text => add space after\n\
if nb['cells'][0]['cell_type'] == 'markdown' and 'colab' in nb['cells'][0]['source'].lower():\n\
nb['cells'][0]['source'] += ' '\n\
# if there is no md cell with colab => create empty md cell on top\n\
else:\n\
nb['cells'].insert(0, nbf.v4.new_markdown_cell())\n\
nb['cells'][0]['source'] += '$download_badge_md'\n\
nbf.write(nb, open('$notebook', 'w'))"

python -c "exec (\"$python_command\")"

- name: Run notebook
uses: fzimmermann89/run-notebook@v3
env:
RUNNER: ${{ toJson(runner) }}
with:
notebook: ${{ matrix.notebook_path }}

- name: Get artifact names
id: artifact_names
run: |
notebook=${{ matrix.notebook_path }}
echo "ARTIFACT_NAME=$(basename ${notebook/.ipynb})" >> $GITHUB_OUTPUT
echo "IPYNB_EXECUTED=$(basename $notebook)" >> $GITHUB_OUTPUT

- name: Upload notebook
uses: actions/upload-artifact@v5
if: always()
with:
name: ${{ steps.artifact_names.outputs.ARTIFACT_NAME }}
path: ${{ github.workspace }}/nb-runner.out/${{ steps.artifact_names.outputs.IPYNB_EXECUTED }}
env:
RUNNER: ${{ toJson(runner) }}

create_documentation:
name: Build and deploy documentation
needs: run_notebook
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v6
with:
fetch-depth: 0 # fetch history for github links
fetch-tags: true

- name: Install LION and dependencies
run: pip install -e .[docs]

- name: Download executed notebook ipynb files
id: download
uses: actions/download-artifact@v6
with:
path: ./docs/source/_notebooks/
merge-multiple: true

- name: Build docs
run: |
sphinx-build -b html ./docs/source ./docs/build/html
rm -rf ./docs/build/html/.doctrees

- name: Upload documentation artifact
id: upload_docs
uses: actions/upload-artifact@v5
with:
name: Documentation
path: docs/build/html/

# if the one of above steps fails the "artifact-url" will be an empty string
- name: Dump documentation info
if: always()
run: |
echo "${{ steps.upload_docs.outputs.artifact-url }}" > artifact_url

- name: Upload docs-metadata artifact
if: always()
uses: actions/upload-artifact@v5
with:
name: artifact_url
path: artifact_url

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}

# Cancel in-progress runs when a new workflow with the same group name is triggered
cancel-in-progress: true
6 changes: 5 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@ dist
example.py
/slurm

wandb
wandb

# Sphinx documentation
docs/build/
docs/source/_notebooks/
11 changes: 5 additions & 6 deletions LION/CTtools/ct_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,19 @@

# math/science imports
import numpy as np
import torch
import tomosipo as ts
import torch

# AItomotools imports
from LION.CTtools.ct_geometry import Geometry
from LION.operators import CTProjectionOp


def from_HU_to_normal(img):
"""
Converts image in Hounsfield Units (air-> -1000, bone->500) into a [0-1] image.
Comercial scanners use a piecewise linear function. Check STIR for real values. (https://raw.githubusercontent.com/UCL/STIR/85cc1940c297b1749cf44a9fba937d7cefdccd47/src/utilities/share/ct_slopes.json)
"""

if isinstance(img, np.ndarray):
return np.minimum(np.maximum((img.astype(np.float32) + 1000) / 3000, 0), 1)
elif isinstance(img, torch.Tensor):
Expand All @@ -37,7 +37,6 @@ def from_HU_to_mu(img):
bone->1.52 g/cm^3). Approximate.
Comercial scanners use a piecewise linear function. Check STIR for real values. (https://raw.githubusercontent.com/UCL/STIR/85cc1940c297b1749cf44a9fba937d7cefdccd47/src/utilities/share/ct_slopes.json)
"""

if isinstance(img, np.ndarray):
return np.maximum(
((1.52 - 0.0012) / (500 + 1000)) * (img.astype(np.float32) + 1000) + 0.0012,
Expand Down Expand Up @@ -154,7 +153,7 @@ def from_HU_to_material_id(img):
return materials


def make_operator(geometry: Geometry):
def make_operator(geometry: Geometry) -> CTProjectionOp:
if not isinstance(geometry, Geometry):
raise ValueError(
"Input geometry is not of class LION.CTtools.ct_geometry.Geometry"
Expand Down Expand Up @@ -182,6 +181,7 @@ def make_operator(geometry: Geometry):
else:
raise ValueError("Geometry mode not understood, has to be 'fan' or 'parallel'")
A = ts.operator(vg, pg)
A = CTProjectionOp(A)
return A


Expand All @@ -195,7 +195,6 @@ def forward_projection(
distances from source to detector DSD and distance from source to object DSO.
May support other backends than tomosipo
"""

if backend != "tomosipo":
raise ValueError("Only tomosipo backend for CT supported")
# You can add other backends here
Expand All @@ -209,7 +208,7 @@ def forward_projection(
if image.shape[0] > 1: # there is no reason to have this constraint
raise ValueError("Image must be 2D")
elif len(image.shape) == 2:
image = torch.unsqueeze(image, axis=0)
image = torch.unsqueeze(image, 0)
else:
raise ValueError("Image must be 2D")

Expand Down
4 changes: 3 additions & 1 deletion LION/classical_algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

from LION.classical_algorithms.conjugate_gradient import conjugate_gradient
from LION.classical_algorithms.fdk import fdk
from LION.classical_algorithms.fista import fista_l1
from LION.classical_algorithms.sirt import sirt
from LION.classical_algorithms.spgl1_torch import spgl1_torch
from LION.classical_algorithms.tv_min import tv_min

__all__ = ["conjugate_gradient", "fdk", "sirt", "tv_min"]
__all__ = ["conjugate_gradient", "fdk", "fista_l1", "sirt", "spgl1_torch", "tv_min"]
141 changes: 141 additions & 0 deletions LION/classical_algorithms/fista.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
"""FISTA algorithm for l1-regularized problems."""

import math

import torch
from tqdm import tqdm

from LION.operators import Operator
from LION.utils.math import power_method


def soft_threshold(v: torch.Tensor, tau: float) -> torch.Tensor:
r"""Soft thresholding operator.

It is defined as:

.. math::

S_{\tau}(v) = \mathrm{sign}(v) \cdot \\max(|v| - \tau, 0)

Parameters
----------
v : torch.Tensor
Input tensor.
tau : float
Threshold parameter.

Returns
-------
torch.Tensor
Result after applying soft thresholding.
"""
return torch.sign(v) * torch.clamp(torch.abs(v) - tau, min=0.0)


def fista_l1(
op: Operator,
y: torch.Tensor,
lam: float,
max_iter: int = 200,
tol: float = 1e-4,
L: float | None = None,
verbose: bool = False,
progress_bar: bool = False,
) -> torch.Tensor:
r"""Solve :math:`\min_w \tfrac12\lVert A w - y\rVert_2^2 + \lambda \lVert w\rVert_1`
by FISTA.

Implements the Fast Iterative Shrinkage-Thresholding Algorithm (FISTA) for
:math:`\ell_1`-regularised least squares [BeckTeboulle2009]_. FISTA is an
accelerated proximal-gradient method for composite objectives
:math:`f(w) + \lambda \lVert w\rVert_1` with smooth data-fidelity term
:math:`f(w) = \tfrac12\lVert A w - y\rVert_2^2`; see
[DaubechiesDefriseDeMol2004]_ for the original ISTA scheme and
[ParikhBoyd2014]_ for a general overview of proximal-gradient methods.

Parameters
----------
op : Operator
Linear operator implementing the forward map and its adjoint. It is
called as ``op(w)`` and ``op.adjoint(r)``.
y : torch.Tensor
Measurements, shape ``(M,)``.
lam : float
:math:`\ell_1` regularisation parameter.
max_iter : int
Maximum number of iterations.
tol : float
Relative stopping threshold on :math:`w`. The iteration stops once
``norm(w_next - w) / (norm(w) + 1e-8) < tol``.
L : float or None
Lipschitz constant of :math:`A^\top A`. If ``None``, estimated by a
power method on the normal operator :math:`A^\top A`, following the
standard practice in FISTA-type schemes [BeckTeboulle2009]_.
verbose : bool
If True, prints basic progress such as objective value and relative
change.
progress_bar : bool
If True, wraps the iteration in a ``tqdm`` progress bar.

Returns
-------
w : torch.Tensor
Estimated coefficient vector, shape ``(Nw,)``.

References
----------
.. [DaubechiesDefriseDeMol2004] I. Daubechies, M. Defrise, and C. De Mol,
"An iterative thresholding algorithm for linear inverse problems with a
sparsity constraint", Communications on Pure and Applied Mathematics,
57(11):1413-1457, 2004.
.. [BeckTeboulle2009] A. Beck and M. Teboulle, "A fast iterative
shrinkage-thresholding algorithm for linear inverse problems", SIAM
Journal on Imaging Sciences, 2(1):183-202, 2009.
.. [ParikhBoyd2014] N. Parikh and S. Boyd, "Proximal Algorithms",
Foundations and Trends in Optimization, 1(3):127-239, 2014.
"""
y = y.detach()
device = y.device

# Dimension inferred from one adjoint call
w0: torch.Tensor = op.adjoint(torch.zeros_like(y))
n: int = w0.numel()

if L is None:
# Power method estimates ||A||_2; Lipschitz constant is ||A||_2^2
L = power_method(op, device=device).item() ** 2
step = 1.0 / (L + 1e-12)

w = torch.zeros(n, dtype=torch.float32, device=device)
z = w.clone()
t = 1.0

iterator = range(max_iter)
if progress_bar:
iterator = tqdm(iterator, desc="FISTA l1")
for k in iterator:
Az: torch.Tensor = op(z)
grad = op.adjoint(Az - y) # gradient of data term, shape (n,)

w_next = soft_threshold(z - step * grad, lam * step)
t_next = 0.5 * (1.0 + math.sqrt(1.0 + 4.0 * t * t))
z = w_next + (t - 1.0) / t_next * (w_next - w)

rel_change = torch.norm(w_next - w) / (torch.norm(w) + 1e-8)
w = w_next
t = t_next

if verbose:
data_term = 0.5 * torch.norm((op(w)) - y).pow(2).item()
l1_term = lam * torch.norm(w, p=1).item()
print(
f"Iter {k:4d} f = {data_term + l1_term:.4e} "
f"rel_change = {rel_change.item():.2e} tol = {tol:.2e} "
f"rel_change < tol: {rel_change.item() < tol}"
)

if rel_change.item() < tol:
break

return w
Loading
Loading