Compressed Sensing 2D

This example illustrates how to perform 2D compressed sensing image reconstruction from Cartesian sampled MRI data with 1-norm regularization of orthogonal wavelet coefficients, using the Julia language.

This page comes from a single Julia file: 2-cs-wl-l1-2d.jl.

You can access the source code for such Julia documentation using the 'Edit on GitHub' link in the top right. You can view the corresponding notebook in nbviewer here: 2-cs-wl-l1-2d.ipynb, or open it in binder here: 2-cs-wl-l1-2d.ipynb.

This demo is somewhat similar to Fig. 3 in the survey paper "Optimization methods for MR image reconstruction," in Jan 2020 IEEE Signal Processing Magazine, except that the sampling is 1D phase encoding instead of 2D.

Packages used in this demo (run Pkg.add as needed):

using ImagePhantoms: shepp_logan, SheppLoganEmis, spectrum, phantom
using MIRT: embed, Afft, Aodwt
using MIRTjim: jim, prompt
using MIRT: pogm_restart
using LinearAlgebra: norm
using Plots; default(markerstrokecolor=:auto, label="")
using FFTW: fft
using Random: seed!
using InteractiveUtils: versioninfo

The following line is helpful when running this jl-file as a script; this way it will prompt user to hit a key after each image is displayed.

isinteractive() && jim(:prompt, true);

Create (synthetic) data

Shepp-Logan phantom (unrealistic because real-valued):

nx,ny = 192,256
object = shepp_logan(SheppLoganEmis(); fovs=(ny,ny))
Xtrue = phantom(-(nx÷2):(nx÷2-1), -(ny÷2):(ny÷2-1), object, 2)
Xtrue = reverse(Xtrue, dims=2)
clim = (0,9)
jim(Xtrue, "true image"; clim)
Example block output

Somewhat random 1D phase-encode sampling:

seed!(0); sampfrac = 0.2; samp = rand(ny) .< sampfrac; sig = 1
mod2 = (N) -> mod.((0:N-1) .+ Int(N/2), N) .- Int(N/2)
samp .|= (abs.(mod2(ny)) .< Int(ny/8)) # fully sampled center rows
samp = trues(nx) * samp'
jim(samp, fft0=true, title="k-space sampling ($(100count(samp)/(nx*ny))%)")
Example block output

Generate noisy, under-sampled k-space data (inverse-crime simulation):

ytrue = fft(Xtrue)[samp]
y = ytrue + sig * √(2) * randn(ComplexF32, size(ytrue)) # complex noise!
y = ComplexF32.(y) # save memory
ysnr = 20 * log10(norm(ytrue) / norm(y-ytrue)) # data SNR in dB
54.77991363373502

Display zero-filled data:

logger = (x; min=-6) -> log10.(max.(abs.(x) / maximum(abs, x), (10.)^min))
jim(:abswarn, false) # suppress warnings about showing magnitude
jim(logger(embed(ytrue,samp)), fft0=true, title="k-space |data| (zero-filled)",
    xlabel="kx", ylabel="ky")
Example block output

Prepare to reconstruct

Creating a system matrix (encoding matrix) and an initial image The system matrix is a LinearMapAA object, akin to a fatrix in Matlab MIRT.

System model ("encoding matrix") from MIRT:

F = Afft(samp) # operator!
LinearMapAO: 18432 × 49152 odim=(18432,) idim=(192, 256) T=ComplexF32 Do=1 Di=2
prop = (name = "fft", dims = 1:2, fft_forward = true)
map = 18432×49152 LinearMaps.FunctionMap{ComplexF32,true}(#5, #7; issymmetric=false, ishermitian=false, isposdef=false)

Initial image based on zero-filled reconstruction:

nrmse = (x) -> round(norm(x - Xtrue) / norm(Xtrue) * 100, digits=1)
X0 = 1.0f0/(nx*ny) * (F' * y)
jim(X0, "|X0|: initial image; NRMSE $(nrmse(X0))%")
Example block output

Wavelet sparsity in synthesis form

The image reconstruction optimization problem here is

\[\arg \min_{x} \frac{1}{2} \| A x - y \|_2^2 + \beta \; \| W x \|_1\]

where $y$ is the k-space data, $A$ is the system model (simply Fourier encoding F here), $W$ is an orthogonal discrete (Haar) wavelet transform, again implemented as a LinearMapAA object. Because $W$ is unitary, we make the change of variables $z = W x$ and solve for $z$ and then let $x = W' z$ at the end. In fact we use a weighted 1-norm where only the detail wavelet coefficients are regularized, not the approximation coefficients.

Orthogonal discrete wavelet transform operator (LinearMapAO):

W, scales, _ = Aodwt((nx,ny) ; T = eltype(F))
isdetail = scales .> 0
jim(
    jim(scales, "wavelet scales"),
    jim(real(W * Xtrue) .* isdetail, "wavelet detail coefficients"),
)
Example block output

Inputs needed for proximal gradient methods:

Az = F * W' # another operator!
Fnullz = (z) -> 0 # cost function in `z` not needed
f_gradz = (z) -> Az' * (Az * z - y)
f_Lz = nx*ny # Lipschitz constant for single coil Cartesian DFT
regz = 0.03 * nx * ny # oracle from Xtrue wavelet coefficients!
costz = (z) -> 1/2 * norm(Az * z - y)^2 + regz * norm(z,1) # 1-norm regularizer
soft = (z,c) -> sign(z) * max(abs(z) - c, 0) # soft thresholding
g_prox = (z,c) -> soft.(z, isdetail .* (regz * c)) # proximal operator (shrink details only)
z0 = W * X0
jim(z0, "Initial wavelet coefficients")
Example block output

Iterate

Run ISTA=PGM and FISTA=FPGM and POGM, the latter two with adaptive restart See Kim & Fessler, 2018 for adaptive restart algorithm details.

Functions for tracking progress:

function fun_ista(iter, xk_z, yk, is_restart)
    xh = W' * xk_z
    return (costz(xk_z), nrmse(xh), is_restart) # , psnr(xh)) # time()
end

function fun_fista(iter, xk, yk_z, is_restart)
    xh = W' * yk_z
    return (costz(yk_z), nrmse(xh), is_restart) # , psnr(xh)) # time()
end;

Run and compare three proximal gradient methods:

niter = 50
z_ista, out_ista = pogm_restart(z0, Fnullz, f_gradz, f_Lz; mom=:pgm, niter=niter,
    restart=:none, restart_cutoff=0., g_prox=g_prox, fun=fun_ista)
Xista = W'*z_ista
@show nrmse(Xista)

z_fista, out_fista = pogm_restart(z0, Fnullz, f_gradz, f_Lz; mom=:fpgm, niter=niter,
    restart=:gr, restart_cutoff=0., g_prox=g_prox, fun=fun_fista)
Xfista = W'*z_fista
@show nrmse(Xfista)

z_pogm, out_pogm = pogm_restart(z0, Fnullz, f_gradz, f_Lz; mom=:pogm, niter=niter,
    restart=:gr, restart_cutoff=0., g_prox=g_prox, fun=fun_fista)
Xpogm = W'*z_pogm
@show nrmse(Xpogm)

jim(
    jim(Xfista, "FISTA/FPGM"),
    jim(Xpogm, "POGM with ODWT"),
)
Example block output

POGM is fastest

Plot cost function vs iteration:

cost_ista = [out_ista[k][1] for k=1:niter+1]
cost_fista = [out_fista[k][1] for k=1:niter+1]
cost_pogm = [out_pogm[k][1] for k=1:niter+1]
cost_min = min(minimum(cost_ista), minimum(cost_pogm))
plot(xlabel="iteration k", ylabel="Relative cost")
scatter!(0:niter, cost_ista  .- cost_min, label="Cost ISTA")
scatter!(0:niter, cost_fista .- cost_min, markershape=:square, label="Cost FISTA")
scatter!(0:niter, cost_pogm  .- cost_min, markershape=:utriangle, label="Cost POGM")
Example block output
isinteractive() && prompt();

Plot NRMSE vs iteration:

nrmse_ista = [out_ista[k][2] for k=1:niter+1]
nrmse_fista = [out_fista[k][2] for k=1:niter+1]
nrmse_pogm = [out_pogm[k][2] for k=1:niter+1]
pn = plot(xlabel="iteration k", ylabel="NRMSE %", ylims=(3,6.5))
scatter!(0:niter, nrmse_ista, label="NRMSE ISTA")
scatter!(0:niter, nrmse_fista, markershape=:square, label="NRMSE FISTA")
scatter!(0:niter, nrmse_pogm, markershape=:utriangle, label="NRMSE POGM")
Example block output

Show error images:

p1 = jim(Xtrue, "true")
p2 = jim(X0, "X0: initial")
p3 = jim(Xpogm, "POGM recon")
p5 = jim(X0 - Xtrue, "X0 error", clim=(0,2))
p6 = jim(Xpogm - Xtrue, "Xpogm error", clim=(0,2))
pe = jim(p2, p3, p5, p6)
Example block output

Reproducibility

This page was generated with the following version of Julia:

using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
12-element Vector{SubString{String}}:
 "Julia Version 1.10.3"
 "Commit 0b4590a5507 (2024-04-30 10:59 UTC)"
 "Build Info:"
 "  Official https://julialang.org/ release"
 "Platform Info:"
 "  OS: Linux (x86_64-linux-gnu)"
 "  CPU: 4 × AMD EPYC 7763 64-Core Processor"
 "  WORD_SIZE: 64"
 "  LIBM: libopenlibm"
 "  LLVM: libLLVM-15.0.7 (ORCJIT, znver3)"
 "Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
 ""

And with the following package versions

import Pkg; Pkg.status()
Status `~/work/Examples/Examples/docs/Project.toml`
  [e30172f5] Documenter v1.4.1
  [940e8692] Examples v0.0.1 `~/work/Examples/Examples`
  [7a1cc6ca] FFTW v1.8.0
  [9ee76f2b] ImageGeoms v0.11.1
  [787d08f9] ImageMorphology v0.4.5
  [71a99df6] ImagePhantoms v0.8.1
  [b964fa9f] LaTeXStrings v1.3.1
  [7031d0ef] LazyGrids v1.0.0
  [599c1a8e] LinearMapsAA v0.12.0
  [98b081ad] Literate v2.18.0
  [23992714] MAT v0.10.7
  [7035ae7a] MIRT v0.18.1
  [170b2178] MIRTjim v0.25.0
  [efe261a4] NFFT v0.13.3
  [91a5bcdd] Plots v1.40.4
  [2913bbd2] StatsBase v0.34.3
  [1986cc42] Unitful v1.20.0
  [b77e0a4c] InteractiveUtils
  [37e2e46d] LinearAlgebra
  [44cfe95a] Pkg v1.10.0
  [9a3f8284] Random

This page was generated using Literate.jl.