Compressed Sensing 2D pMRI
This example illustrates how to perform 2D compressed sensing image reconstruction from Cartesian sampled MRI data for parallel MRI (sensitivity encoding) with 1-norm regularization of orthogonal wavelet coefficients, using the Julia language.
This page comes from a single Julia file: 4-cs-sense-2d.jl.
This demo is somewhat similar to Fig. 3 in the survey paper "Optimization methods for MR image reconstruction," in the Jan. 2020 IEEE Signal Processing Magazine, except that
- the sampling is 1D phase encoding instead of 2D,
- there are multiple coils,
- we use units (WIP - todo),
- the simulation avoids inverse crimes.
Packages used in this demo (run Pkg.add as needed):
using ImagePhantoms: ellipse_parameters, SheppLoganBrainWeb, ellipse, phantom
using ImagePhantoms: mri_smap_fit, mri_spectra
#using ImageFiltering: imfilter, centered
using ImageMorphology: dilate #, label_components # imfill
using LazyGrids: ndgrid
using ImageGeoms: embed, embed!
using MIRT: Aodwt, Asense
using MIRTjim: jim, prompt
using MIRT: ir_mri_sensemap_sim
using MIRT: pogm_restart
using LinearAlgebra: norm, dot
using LinearMapsAA: LinearMapAA
using Plots; default(markerstrokecolor=:auto, label="")
using FFTW: fft!, bfft!, fftshift!
using Random: seed!
using InteractiveUtils: versioninfo
The following line is helpful when running this jl-file as a script; it makes the script prompt the user to hit a key after each image is displayed.
isinteractive() && jim(:prompt, true);
Create (synthetic) data
Image geometry:
uu = 1
fovs = (256, 256) .* uu
nx, ny = (192, 256)
dx, dy = fovs ./ (nx,ny)
x = (-(nx÷2):(nx÷2-1)) * dx
y = (-(ny÷2):(ny÷2-1)) * dy
-128.0:1.0:127.0
Modified Shepp-Logan phantom with random complex phases per ellipse
object = ellipse_parameters(SheppLoganBrainWeb() ; disjoint=true, fovs)
seed!(0)
object = vcat( (object[1][1:end-1]..., 1), # random phases
[(ob[1:end-1]..., randn(ComplexF32)) for ob in object[2:end]]...)
object = ellipse(object)
oversample = 3
Xtrue = phantom(x, y, object, oversample)
cfun = z -> cat(dims = ndims(z)+1, real(z), imag(z))
jim(:aspect_ratio, :equal)
jim(x, y, cfun(Xtrue), "True image\n real | imag"; ncol=2)
Mask (support for image reconstruction)
mask = abs.(Xtrue) .> 0
mask = dilate(dilate(dilate(mask))) # repeated dilate with 3×3 square
@assert mask .* Xtrue == Xtrue
jim(x, y, mask + abs.(Xtrue), "Mask + |Xtrue|")
Make sensitivity maps, normalized so that the SSoS (square root of the sum of squares across coils) is 1 within the mask:
ncoil = 4
smap_raw = ir_mri_sensemap_sim(dims=(nx,ny); ncoil, orbit_start=[45])
jif(args...; kwargs...) = jim(args...; prompt=false, kwargs...)
p1 = jif(x, y, smap_raw, "Sensitivity maps raw");
sum_last = (f, x) -> selectdim(sum(f, x; dims=ndims(x)), ndims(x), 1)
ssos_fun = smap -> sqrt.(sum_last(abs2, smap)) # SSoS
ssos_raw = ssos_fun(smap_raw) # SSoS raw
p2 = jif(x, y, ssos_raw, "SSoS raw, ncoil=$ncoil");
smap = @. smap_raw / ssos_raw * mask # normalize and mask
ssos = ssos_fun(smap) # SSoS
@assert all(≈(1), @view ssos[mask]) # verify ≈ 1
smaps = [eachslice(smap; dims = ndims(smap))...] # unstack into a vector of 2D maps
p3 = jif(x, y, smaps, "|Sensitivity maps normalized|")
p4 = jif(x, y, map(x -> angle.(x), smaps), "∠Sensitivity maps"; color=:hsv)
jim(p1, p2, p3, p4)
Frequency sample vectors; these must match the mri_smap_basis internals!
fx = (-(nx÷2):(nx÷2-1)) / (nx*dx)
fy = (-(ny÷2):(ny÷2-1)) / (ny*dy)
gx, gy = ndgrid(fx, fy);
Somewhat random 1D phase-encode sampling:
seed!(0); sampfrac = 0.3; samp = rand(ny÷2) .< sampfrac
tmp = rand(ny÷2) .< 0.5; samp = [samp .* tmp; reverse(samp .* .!tmp)] # line i or its near-mirror ny+1-i, not both
samp .|= (abs.(fy*dy) .< 1/8) # fully sample the central 1/4 of phase encodes (|fy*dy| < 1/8)
ny_count = count(samp)
samp = trues(nx) * samp'
samp_frac = round(100*count(samp) / (nx*ny), digits=2)
jim(fx, fy, samp, title="k-space sampling ($ny_count / $ny = $samp_frac%)")
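The pairing step in the sampling code above can be checked on a small case. The sketch below uses a hypothetical helper, toy_pe_pattern (not part of MIRT or ImagePhantoms), that rebuilds the same pattern for a short phase-encode axis. It uses a local MersenneTwister RNG so that the global random stream used later for the noise is not disturbed. On the fy grid, line i and line ny+1-i have approximately mirrored frequencies (k and -k-1), and the random step never selects both.

```julia
using Random: MersenneTwister

# Hypothetical helper (not part of MIRT or ImagePhantoms) that mirrors the
# sampling code above for a small ny. It uses a local RNG, so the global
# random stream of the demo is left unchanged.
function toy_pe_pattern(ny::Int; sampfrac = 0.3, seed = 0, dy = 1.0)
    rng = MersenneTwister(seed)
    fy = (-(ny÷2):(ny÷2-1)) / (ny*dy)              # same grid as fy above
    pick = rand(rng, ny÷2) .< sampfrac              # one candidate per line pair
    side = rand(rng, ny÷2) .< 0.5                   # coin flip: which member of the pair
    outer = [pick .* side; reverse(pick .* .!side)] # line i or line ny+1-i, never both
    samp = outer .| (abs.(fy*dy) .< 1/8)            # always keep the central lines
    return samp, outer
end

samp_toy, outer_toy = toy_pe_pattern(16)
# The random step never selects a line together with its partner:
@assert !any(outer_toy[1:8] .& reverse(outer_toy[9:16]))
```

The coin flip spends the sampling budget on distinct frequency pairs rather than on near-redundant mirror lines, while the central lines are always acquired.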
To avoid an inverse crime, we use the 2012 method of Guerquin-Kern et al., combining the analytical k-space values of the phantom with an analytical model for the sensitivity maps.
kmax = 7
fit = mri_smap_fit(smaps, embed; mask, kmax, deltas=(dx,dy))
jim(
jif(x, y, cfun(smaps), "Original maps"; clim=(-1,1), nrow=4),
jif(x, y, cfun(fit.smaps), "Fit maps"; clim=(-1,1), nrow=4),
jif(x, y, cfun(100 * (fit.smaps - smaps)), "error * 100"; nrow=4),
layout = (1,3),
)
Compute the analytical spectra of the complex phantom for all coils, using the fitted sensitivity maps. (No inverse crime here.)
ytrue = mri_spectra(gx[samp], gy[samp], object, fit)
ytrue = hcat(ytrue...)
ytrue = ComplexF32.(ytrue) # save memory
size(ytrue)
(17088, 4)
Noisy under-sampled k-space data:
sig = 1
ydata = ytrue + oneunit(eltype(ytrue)) *
sig * √(2f0) * randn(ComplexF32, size(ytrue)) # complex noise with units!
ysnr = 20 * log10(norm(ytrue) / norm(ydata - ytrue)) # data SNR in dB
46.200874f0
Display the zero-filled (noiseless) data for coil 1:
logger = (x; min=-6, up=maximum(abs,x)) -> log10.(max.(abs.(x) / up, (10.)^min))
jim(:abswarn, false) # suppress warnings about showing magnitude
tmp = embed(ytrue[:,1],samp)
jim(
jif(fx, fy, logger(tmp),
title="k-space |data|\n(zero-filled, coil 1)",
xlabel="νx", ylabel="νy"),
jif(fx, fy, angle.(tmp),
title="∠data, coil 1"; color=:hsv)
)
Prepare to reconstruct
Create a system matrix (encoding matrix) and an initial image.
The system matrix is a LinearMapAO object, akin to a fatrix in Matlab MIRT. This system model ("encoding matrix") maps a 2D image x to k-space data of size count(samp) × ncoil. Asense builds it from FFT and coil-map components, so this is like "seeing the sausage being made."
The dx * dy scale factor here is required because the true k-space data ytrue comes from an analytical Fourier transform, whereas the reconstruction uses a discrete Fourier transform. This factor is also needed from a unit-balance perspective.
Ascale = Float32(dx * dy)
A = Ascale * Asense(samp, smaps) # operator!
(size(A), A._odim, A._idim)
((68352, 49152), (17088, 4), (192, 256))
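Why the dx * dy factor is needed: a discrete sum approximates the continuous Fourier integral only after multiplying by the sample spacing. The 1D sketch below is a hypothetical check (gauss, gauss_ft, and riemann_ft are illustrative names) using the Gaussian exp(-π x²), whose continuous Fourier transform is exp(-π ν²). In 2D the same argument gives the product dx * dy.

```julia
# Hypothetical 1D check of the sample-spacing factor.
gauss(x) = exp(-π * x^2)          # f(x)
gauss_ft(ν) = exp(-π * ν^2)       # its continuous Fourier transform F(ν)

dx_toy = 0.05
xgrid = -10:dx_toy:10             # covers the effective support of f

# Riemann-sum approximation of ∫ f(x) exp(-i 2π ν x) dx
riemann_ft(ν) = dx_toy * sum(gauss(x) * cis(-2π * ν * x) for x in xgrid)

νs = range(-2, 2; length = 9)
err_scaled = maximum(abs(riemann_ft(ν) - gauss_ft(ν)) for ν in νs)
err_unscaled = maximum(abs(riemann_ft(ν) / dx_toy - gauss_ft(ν)) for ν in νs)
```

err_scaled is at round-off level, while dropping the spacing factor (err_unscaled) makes the result too large by a factor of 1/dx_toy.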
Compare the analytical k-space data with the discrete model k-space data:
y0 = embed(ytrue, samp) # analytical
y1 = embed(A * Xtrue, samp) # discrete
jim(
jif(logger(y0; up=maximum(abs,y0)), "analytical"; clim=(-6,0)),
jif(logger(y1; up=maximum(abs,y0)), "discrete"; clim=(-6,0)),
jif(logger(y1 - y0; up=maximum(abs,y0)), "difference"),
)
norm(y1) / norm(y0) # ratio ≈ 1, so the dx*dy scaling is consistent
0.9995024f0
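Initial image based on zero-filled adjoint reconstruction. Note the 1/(nx*ny) and (dx*dy)² scale factors here! With SSoS = 1 and full sampling, A'A = (dx*dy)² nx ny I on the mask, because the unnormalized DFT (fft/bfft) satisfies F'F = nx ny I; dividing A'y by (dx*dy)² nx ny therefore yields an image on the correct scale.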
nrmse = (x) -> round(norm(x - Xtrue) / norm(Xtrue) * 100, digits=1)
X0 = 1.0f0/(nx*ny) * (A' * ydata) / (dx*dy)^2
jim(x, y, X0, "|X0|: initial image; NRMSE $(nrmse(X0))%")
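The 1/(nx*ny) and (dx*dy)² factors in X0 undo A'A under full sampling. The hypothetical 1D sketch below (no coils; Fmat and adjoint_recon are illustrative names) uses an explicit unnormalized DFT matrix, the fft/bfft convention, so F'F = N I rather than I. With all frequencies sampled the scaled adjoint returns the signal exactly. Keeping every other frequency instead averages the signal with a circularly shifted copy of itself, the kind of aliasing visible in X0.

```julia
using LinearAlgebra: I, norm

# Hypothetical 1D toy (no coils): A = s * P * F, with F the unnormalized DFT
# matrix (fft convention) and P keeping only the sampled frequencies.
N_toy = 32
s_toy = 0.5                                  # plays the role of Ascale = dx * dy
Fmat = [cis(-2π * j * k / N_toy) for j in 0:N_toy-1, k in 0:N_toy-1]
@assert Fmat' * Fmat ≈ N_toy * Matrix(I, N_toy, N_toy)   # F'F = N I, not I

x_toy = [sin(0.3k) + 0.1k for k in 0:N_toy-1]             # arbitrary test signal

# Zero-filled adjoint "reconstruction" scaled like X0: A'y / (N s²)
function adjoint_recon(keep)
    A = s_toy * Fmat[keep, :]
    y = A * x_toy                            # noiseless sampled data
    return (A' * y) / (N_toy * s_toy^2)
end

x_full = adjoint_recon(1:N_toy)      # all frequencies: exact recovery
x_half = adjoint_recon(1:2:N_toy)    # every other frequency: aliasing
err_full = norm(x_full - x_toy) / norm(x_toy)
err_half = norm(x_half - x_toy) / norm(x_toy)
```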
Wavelet sparsity in synthesis form
The image reconstruction optimization problem here is
\[\arg \min_{x} \frac{1}{2} \| A x - y \|_2^2 + \beta \; \| W x \|_1\]
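where
- $y$ is the k-space data,
- $A$ is the system model (here SENSE encoding: coil sensitivity weighting followed by Fourier encoding and sampling),
- $W$ is an orthogonal discrete (Haar) wavelet transform, again implemented as a LinearMapAA object,
- $\beta$ is the regularization parameter (regz in the code).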
Because $W$ is unitary, we make the change of variables $z = W x$ and solve for $z$ and then let $x = W' z$ at the end. In fact we use a weighted 1-norm where only the detail wavelet coefficients are regularized, not the approximation coefficients.
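Written out, with $d_j \in \{0, 1\}$ marking the detail coefficients (the isdetail array in the code) and $\beta$ corresponding to regz, the problem solved in the $z$ domain is

\[\hat{z} = \arg \min_{z} \frac{1}{2} \| A W' z - y \|_2^2 + \beta \sum_j d_j \, |z_j|, \qquad \hat{x} = W' \hat{z},\]

and the proximal step with step size $c$ is elementwise soft thresholding, $z_j \mapsto \operatorname{sign}(z_j) \max(|z_j| - c \beta d_j, 0)$, where $\operatorname{sign}(z) = z / |z|$ for complex $z \neq 0$. Because $W'W = I$, every image $x$ corresponds to exactly one $z$, so the two formulations have the same minimizer.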
Orthogonal discrete wavelet transform operator (LinearMapAO):
W, scales, _ = Aodwt((nx,ny) ; T = eltype(A))
isdetail = scales .> 0
jim(
jif(scales, "wavelet scales"),
jif(real(W * Xtrue) .* isdetail, "wavelet detail coefficients\nreal for Xtrue"),
)
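The change of variables relies on $W$ being orthogonal. Below is a hypothetical one-level 1D Haar matrix. haar_matrix is an illustrative helper; Aodwt itself is multilevel and 2D. The sketch checks $W'W = I$ and shows that a piecewise-constant signal has mostly zero detail coefficients, which motivates penalizing only the details.

```julia
using LinearAlgebra: I, norm

# Hypothetical one-level orthonormal Haar transform as an explicit matrix.
# Rows 1:n÷2 are approximation (scale 0) rows; rows n÷2+1:n are detail rows.
function haar_matrix(n::Int)
    @assert iseven(n)
    H = zeros(n, n)
    for k in 1:n÷2
        H[k, 2k-1] = H[k, 2k] = 1/√2      # local average
        H[n÷2 + k, 2k-1] = 1/√2           # local difference
        H[n÷2 + k, 2k] = -1/√2
    end
    return H
end

Wh = haar_matrix(8)
@assert Wh' * Wh ≈ Matrix(I, 8, 8)        # orthogonal: W' is the inverse

v_toy = [1.0, 1.0, 2.0, 2.0, 3.0, 5.0, 4.0, 4.0]   # mostly piecewise constant
coef_toy = Wh * v_toy
isdetail_toy = [falses(4); trues(4)]
details = coef_toy[isdetail_toy]          # only the (3, 5) pair gives a nonzero detail
```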
Inputs needed for proximal gradient methods. The trickiest part is determining a bound on the Lipschitz constant, i.e., the spectral norm of $(AW')'(AW')$, which equals the spectral norm of $A'A$ because $W$ is unitary. Here the coil maps have SSoS = 1, so we need only account for the number of voxels (because fft and bfft are not the unitary DFT) and for the scale factor, which gives the bound Ascale² · nx · ny used for f_Lz. (If the SSoS were nonuniform, we would use eqn. (6) of Matt Muckley's BARISTA paper. Submit an issue if you need an example using that.)
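The bound can be sanity-checked on a small hypothetical 1D SENSE model built from explicit matrices. toy_sense_matrix and power_iter are illustrative helpers, not MIRT functions. Because sampling can only remove energy, A'A = s² Σ_c S_c' F' P' P F S_c ≼ s² N Σ_c S_c' S_c = s² N I when the SSoS is 1. So the exact largest eigenvalue of A'A (from opnorm) and a power-iteration estimate should both stay at or below s² N, with equality under full sampling.

```julia
using LinearAlgebra: Diagonal, norm, opnorm
using Random: MersenneTwister

# Hypothetical 1D toy SENSE model: A = s * [P F S_1; ...; P F S_ncoil]
function toy_sense_matrix(keep; N = 32, ncoil = 3, s = 0.5, seed = 1)
    rng = MersenneTwister(seed)
    F = [cis(-2π * j * k / N) for j in 0:N-1, k in 0:N-1]  # unnormalized DFT
    smaps = randn(rng, ComplexF64, N, ncoil)
    smaps ./= sqrt.(sum(abs2, smaps; dims = 2))            # SSoS = 1 everywhere
    blocks = [F[keep, :] * Diagonal(smaps[:, c]) for c in 1:ncoil]
    return s * reduce(vcat, blocks)
end

# Power iteration: estimate of the largest eigenvalue of A'A
function power_iter(A; niter = 100)
    v = ones(ComplexF64, size(A, 2))
    v /= norm(v)
    λ = 0.0
    for _ in 1:niter
        w = A' * (A * v)
        λ = norm(w)             # ‖A'A v‖ with ‖v‖ = 1, never above ‖A'A‖
        v = w / λ
    end
    return λ
end

L_bound = 0.5^2 * 32                     # s² N, the analogue of f_Lz
A_under = toy_sense_matrix(1:3:32)       # keep every third frequency
A_full = toy_sense_matrix(1:32)          # keep all frequencies
L_under = power_iter(A_under)
L_full = power_iter(A_full)
L_exact = opnorm(A_under)^2              # exact ‖A'A‖ for comparison
```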
f_Lz = Ascale^2 * nx*ny # Lipschitz constant
Az = A * W' # another operator!
Fnullz = (z) -> 0 # cost function in `z` not needed
f_gradz = (z) -> Az' * (Az * z - ydata)
regz = 0.03 * nx * ny # oracle from Xtrue wavelet coefficients!
costz = (z) -> 1/2 * norm(Az * z - ydata)^2 + regz * norm(isdetail .* z, 1) # 1-norm of detail coefficients only (matches g_prox)
soft = (z,c) -> sign(z) * max(abs(z) - c, 0) # soft thresholding
g_prox = (z,c) -> soft.(z, isdetail .* (regz * c)) # proximal operator (shrink details only)
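For complex coefficients, sign(z) in soft is z/|z| (and 0 at z = 0). The shrinkage therefore keeps each coefficient's phase and reduces its magnitude by the threshold. The hypothetical check below applies the same rule as soft and g_prox to a few coefficients; soft_demo and the input values are illustrative. An entry whose threshold is zero, like an approximation coefficient, passes through unchanged.

```julia
# Same rule as `soft` above, under a new name so the demo's own
# definitions are not overwritten.
soft_demo = (z, c) -> sign(z) * max(abs(z) - c, 0)

z_demo = ComplexF64[3 + 4im, 0.3im, 0, -2]
isdetail_demo = [true, true, true, false]   # last entry acts as an approximation coefficient
c_demo = 1.0                                # threshold, like regz * c in g_prox

shrunk = soft_demo.(z_demo, isdetail_demo .* c_demo)
```

Here 3 + 4im (magnitude 5) shrinks to magnitude 4 with the same phase, 0.3im falls below the threshold and becomes 0, and -2 is untouched because its threshold is 0.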
z0 = W * X0
jim(
jif(z0, "|wavelet coefficients|"),
jif(z0 .* isdetail, "|detail coefficients|"),
; plot_title = "Initial",
)
Iterate
Run ISTA=PGM and FISTA=FPGM and POGM, the latter two with adaptive restart. See Kim & Fessler, 2018 for adaptive restart algorithm details.
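To illustrate what momentum and adaptive restart add, here is a hypothetical toy comparison on a small lasso problem; toy_lasso_demo and all sizes are illustrative. The restart test is the gradient scheme of O'Donoghue & Candès: it resets the momentum whenever $(y_k - x_{k+1})'(x_{k+1} - x_k) > 0$. This is assumed to be similar in spirit to restart=:gr, but it is a sketch, not the pogm_restart source, and POGM itself is not reproduced here.

```julia
using LinearAlgebra: norm, dot, opnorm
using Random: MersenneTwister

# Hypothetical lasso toy: minimize 1/2 ||B x - b||² + λ ||x||₁
function toy_lasso_demo(; niter = 50, seed = 0)
    rng = MersenneTwister(seed)
    B = randn(rng, 40, 60)
    xs = zeros(60); xs[[3, 17, 42]] = [2.0, -1.5, 1.0]   # sparse ground truth
    b = B * xs
    λ = 0.5
    L = opnorm(B)^2                                 # Lipschitz constant of the gradient

    cost(x) = 0.5 * norm(B * x - b)^2 + λ * norm(x, 1)
    shrink(v, c) = sign(v) * max(abs(v) - c, 0)     # soft thresholding
    pgstep(y) = shrink.(y - B' * (B * y - b) / L, λ / L)   # one proximal-gradient step

    # ISTA / PGM: repeat the proximal-gradient step
    x_ista = zeros(60)
    for _ in 1:niter
        x_ista = pgstep(x_ista)
    end

    # FISTA / FPGM with gradient-based adaptive restart
    function fista(n)
        x = zeros(60); y = copy(x); t = 1.0
        for _ in 1:n
            xnew = pgstep(y)
            if dot(y - xnew, xnew - x) > 0          # momentum points uphill: restart
                t, y = 1.0, xnew
            else
                tnew = (1 + sqrt(1 + 4t^2)) / 2
                y = xnew + ((t - 1) / tnew) * (xnew - x)
                t = tnew
            end
            x = xnew
        end
        return x
    end

    x_fista = fista(niter)
    x_ref = fista(5000)                             # near-converged reference
    return (; cost, pgstep, x_ista, x_fista, x_ref)
end

demo = toy_lasso_demo()
gap_ista = demo.cost(demo.x_ista) - demo.cost(demo.x_ref)
gap_fista = demo.cost(demo.x_fista) - demo.cost(demo.x_ref)
```

gap_ista and gap_fista measure how far each method is from the near-converged cost after the same number of iterations. The exact values depend on the random problem, so this toy illustrates the mechanism rather than serving as a benchmark.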
Functions for tracking progress; each returns (cost, NRMSE, restart flag) at every iteration:
function fun_ista(iter, xk_z, yk, is_restart)
xh = W' * xk_z
return (costz(xk_z), nrmse(xh), is_restart) # , psnr(xh) # time()
end
function fun_fista(iter, xk, yk_z, is_restart)
xh = W' * yk_z
return (costz(yk_z), nrmse(xh), is_restart) # , psnr(xh) # time()
end;
Run and compare three proximal gradient methods:
niter = 20
z_ista, out_ista = pogm_restart(z0, Fnullz, f_gradz, f_Lz;
mom=:pgm, niter,
restart=:none, restart_cutoff=0., g_prox, fun=fun_ista)
Xista = W'*z_ista
@show nrmse(Xista)
z_fista, out_fista = pogm_restart(z0, Fnullz, f_gradz, f_Lz;
mom=:fpgm, niter,
restart=:gr, restart_cutoff=0., g_prox, fun=fun_fista)
Xfista = W'*z_fista
@show nrmse(Xfista)
z_pogm, out_pogm = pogm_restart(z0, Fnullz, f_gradz, f_Lz;
mom=:pogm, niter,
restart=:gr, restart_cutoff=0., g_prox, fun=fun_fista)
Xpogm = W'*z_pogm
@show nrmse(Xpogm)
jim(
jif(x, y, Xfista, "FISTA/FPGM"),
jif(x, y, Xpogm, "POGM with ODWT"),
)
Convergence rate: POGM is fastest
Plot cost function vs iteration:
cost_ista = [out_ista[k][1] for k in 1:niter+1]
cost_fista = [out_fista[k][1] for k in 1:niter+1]
cost_pogm = [out_pogm[k][1] for k in 1:niter+1]
cost_min = minimum([cost_ista; cost_fista; cost_pogm]) # smallest cost over all three methods
pc = plot(xlabel="iteration k", ylabel="Relative cost")
scatter!(0:niter, cost_ista .- cost_min, label="Cost ISTA")
scatter!(0:niter, cost_fista .- cost_min, markershape=:square, label="Cost FISTA")
scatter!(0:niter, cost_pogm .- cost_min, markershape=:utriangle, label="Cost POGM")
isinteractive() && prompt();
Plot NRMSE vs iteration:
nrmse_ista = [out_ista[k][2] for k in 1:niter+1]
nrmse_fista = [out_fista[k][2] for k in 1:niter+1]
nrmse_pogm = [out_pogm[k][2] for k in 1:niter+1]
pn = plot(xlabel="iteration k", ylabel="NRMSE %")#, ylims=(3,6.5))
scatter!(0:niter, nrmse_ista, label="NRMSE ISTA")
scatter!(0:niter, nrmse_fista, markershape=:square, label="NRMSE FISTA")
scatter!(0:niter, nrmse_pogm, markershape=:utriangle, label="NRMSE POGM")
isinteractive() && prompt();
Show error images:
snrfun = (x) -> round(-20log10(nrmse(x)/100); digits=1)
p1 = jif(x, y, Xtrue, "|true|"; clim=(0,2.5))
p2 = jif(x, y, X0, "|X0|: initial"; clim=(0,2.5))
p3 = jif(x, y, Xpogm, "|POGM recon|"; clim=(0,2.5))
p5 = jif(x, y, X0 - Xtrue, "|X0 error|"; clim=(0,1), color=:cividis,
xlabel = "NRMSE = $(nrmse(X0))%\n SNR = $(snrfun(X0)) dB")
p6 = jif(x, y, Xpogm - Xtrue, "|Xpogm error|"; clim=(0,1), color=:cividis,
xlabel = "NRMSE = $(nrmse(Xpogm))%\n SNR = $(snrfun(Xpogm)) dB")
pe = jim(p2, p3, p5, p6)
Discussion
As reported in the optimization survey paper cited above, POGM converges faster than ISTA and FISTA.
The final images serve as a reminder that NRMSE (and PSNR) are dubious image quality metrics. The NRMSE after 20 iterations may seem only slightly lower than that of the initial image, but the CS-SENSE reconstruction method greatly reduced the aliasing artifacts (ripples).
Reproducibility
This page was generated with the following version of Julia:
using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
12-element Vector{SubString{String}}:
"Julia Version 1.10.3"
"Commit 0b4590a5507 (2024-04-30 10:59 UTC)"
"Build Info:"
" Official https://julialang.org/ release"
"Platform Info:"
" OS: Linux (x86_64-linux-gnu)"
" CPU: 4 × AMD EPYC 7763 64-Core Processor"
" WORD_SIZE: 64"
" LIBM: libopenlibm"
" LLVM: libLLVM-15.0.7 (ORCJIT, znver3)"
"Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
""
And with the following package versions:
import Pkg; Pkg.status()
Status `~/work/Examples/Examples/docs/Project.toml`
[e30172f5] Documenter v1.4.1
[940e8692] Examples v0.0.1 `~/work/Examples/Examples`
[7a1cc6ca] FFTW v1.8.0
[9ee76f2b] ImageGeoms v0.11.1
[787d08f9] ImageMorphology v0.4.5
[71a99df6] ImagePhantoms v0.8.1
[b964fa9f] LaTeXStrings v1.3.1
[7031d0ef] LazyGrids v1.0.0
[599c1a8e] LinearMapsAA v0.12.0
[98b081ad] Literate v2.18.0
[23992714] MAT v0.10.7
[7035ae7a] MIRT v0.18.1
[170b2178] MIRTjim v0.25.0
[efe261a4] NFFT v0.13.3
[91a5bcdd] Plots v1.40.4
[2913bbd2] StatsBase v0.34.3
[1986cc42] Unitful v1.20.0
[b77e0a4c] InteractiveUtils
[37e2e46d] LinearAlgebra
[44cfe95a] Pkg v1.10.0
[9a3f8284] Random
This page was generated using Literate.jl.