L+S 2D dynamic recon
This page illustrates dynamic parallel MRI image reconstruction using a low-rank plus sparse (L+S) model optimized by a fast algorithm described in the paper by Claire Lin and Jeff Fessler (EECS Department, University of Michigan), "Efficient Dynamic Parallel MRI Reconstruction for the Low-Rank Plus Sparse Model," IEEE Trans. on Computational Imaging, 5(1):17-26, 2019.
The Julia code here is a translation of part of the Matlab code used in the original paper.
If you use this code, please cite that paper.
This page comes from a single Julia file: 5-l-plus-s.jl.
Setup
Packages needed here.
# using Unitful: s
using Plots # provides plot, plot!, cgrad, default
default(markerstrokecolor=:auto, label="")
using MIRT: Afft, Asense, embed
using MIRT: pogm_restart, poweriter
using MIRTjim: jim, prompt
using FFTW: fft!, bfft!, fftshift!
using LinearMapsAA: LinearMapAA, block_diag, redim, undim
using MAT: matread
import Downloads # todo: use Fetch or DataDeps?
using LinearAlgebra: dot, norm, svd, svdvals, Diagonal, I
using Random: seed!
using StatsBase: mean
using LaTeXStrings
The helper jif calls jim with prompt=false; it is used below for the individual panels of a combined figure. The last line is helpful when running this file as a script: it prompts the user to hit a key after each figure is displayed.
jif(args...; kwargs...) = jim(args...; prompt=false, kwargs...) # jim without the key prompt
isinteractive() ? jim(:prompt, true) : prompt(:draw);
Overview
Dynamic image reconstruction using a "low-rank plus sparse" or "L+S" approach was proposed by Otazo et al. and uses the following cost function:
\[X = \hat{L} + \hat{S}, \qquad (\hat{L}, \hat{S}) = \arg\min_{L,S} \frac{1}{2} \| E (L + S) - d \|_2^2 + \lambda_L \| L \|_* + \lambda_S \| \operatorname{vec}(T S) \|_1\]
where $T$ is a temporal unitary FFT, $E$ is an encoding operator (system matrix), and $d$ is Cartesian undersampled multicoil k-space data.
The Otazo paper used an iterative soft thresholding algorithm (ISTA) to solve this optimization problem. Using FISTA is faster, but using the proximal optimized gradient method (POGM) with adaptive restart is even faster.
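To make the difference between these methods concrete, here is a minimal sketch of ISTA (PGM) and FISTA (FPGM) on a hypothetical small lasso problem, not the MRI cost. Both take the same proximal gradient step with step size 1/L; FISTA only adds a Nesterov extrapolation (momentum) step. POGM adds a further correction and adaptive restart, which this sketch does not attempt to reproduce, and none of this is MIRT's pogm_restart implementation. The problem sizes, λ, and names are illustrative assumptions, and everything is wrapped in a let block so the names do not clash with variables on this page.

```julia
using LinearAlgebra: norm, opnorm
using Random: seed!

# Hypothetical toy lasso problem: min_x 0.5*‖B x - b‖² + λ‖x‖₁ (not the MRI problem).
cost_ista, cost_fista = let
    seed!(0)
    B = randn(40, 20)
    b = randn(40)
    λ = 0.5
    cost(x) = 0.5 * norm(B * x - b)^2 + λ * norm(x, 1)
    grad(x) = B' * (B * x - b)                     # gradient of the data-fit term
    shrink(v, c) = sign(v) * max(abs(v) - c, 0)    # soft thresholding
    Lf = opnorm(B)^2                               # Lipschitz constant of grad
    prox_step(z) = shrink.(z - grad(z) / Lf, λ / Lf) # one proximal gradient step

    # ISTA (PGM): plain proximal gradient iterations
    x = zeros(20)
    ci = [cost(x)]
    for _ in 1:100
        x = prox_step(x)
        push!(ci, cost(x))
    end

    # FISTA (FPGM): same step, but taken at an extrapolated point z
    x = zeros(20); z = copy(x); t = 1.0
    cf = [cost(x)]
    for _ in 1:100
        xnew = prox_step(z)
        tnew = (1 + sqrt(1 + 4t^2)) / 2
        z = xnew + ((t - 1) / tnew) * (xnew - x)   # momentum step
        x, t = xnew, tnew
        push!(cf, cost(x))
    end
    (ci, cf)
end
```

With step size 1/L the ISTA cost is guaranteed to be nonincreasing, while FISTA trades that monotonicity for a faster worst-case convergence rate.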
This example reproduces part of Figures 1 & 2 in Claire Lin's paper, based on the cardiac perfusion example.
Read data
if !@isdefined(data)
url = "https://github.com/JeffFessler/MIRTdata/raw/main/mri/lin-19-edp/"
dataurl = url * "cardiac_perf_R8.mat"
data = matread(Downloads.download(dataurl))
xinfurl = url * "Xinf.mat"
Xinf = matread(Downloads.download(xinfurl))["Xinf"]["perf"] # (128,128,40)
end;
Show converged image as a preview:
pinf = jim(Xinf, L"\mathrm{Converged\ image\ sequence } X_∞")
Organize k-space data:
if !@isdefined(ydata0)
ydata0 = data["kdata"] # zero-filled k-space data (zeros at unsampled locations)
ydata0 = permutedims(ydata0, [1, 2, 4, 3]) # (nx,ny,nc,nt)
ydata0 = ComplexF32.(ydata0)
end
(nx, ny, nc, nt) = size(ydata0)
(128, 128, 12, 40)
Extract sampling pattern from zeros of k-space data:
if !@isdefined(samp)
samp = ydata0[:,:,1,:] .!= 0
for ic in 2:nc # verify it is same for all coils
@assert samp == (ydata0[:,:,ic,:] .!= 0)
end
kx = -(nx÷2):(nx÷2-1)
ky = -(ny÷2):(ny÷2-1)
psamp = jim(kx, ky, samp, "Sampling patterns for $nt frames";
xlabel=L"k_x", ylabel=L"k_y")
end
Are all k-space rows sampled in at least one of the 40 frames? Sadly, no: the 10 blue rows shown below are never sampled. A better sampling pattern design could have avoided this issue.
samp_sum = sum(samp, dims=3)
color = cgrad([:blue, :black, :white], [0, 1/2nt, 1])
pssum = jim(kx, ky, samp_sum; xlabel="kx", ylabel="ky",
color, clim=(0,nt), title="Number of sampled frames out of $nt")
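The check above can also be done numerically rather than visually. A minimal sketch on a hypothetical toy mask of size (nx, ny, nt) = (8, 10, 3), assuming full readout lines along the first dimension so that each phase-encode line is indexed by the second dimension; the names toy_mask, frames_per_loc, and never_sampled are illustrative.

```julia
# Hypothetical toy sampling mask (nx, ny, nt) = (8, 10, 3); each frame samples 3 full lines.
toy_mask = falses(8, 10, 3)
toy_mask[:, [1, 4, 7], 1] .= true
toy_mask[:, [2, 4, 8], 2] .= true
toy_mask[:, [3, 7, 10], 3] .= true

# Number of frames in which each k-space location is sampled (like samp_sum above)
frames_per_loc = dropdims(sum(toy_mask; dims=3); dims=3)
# A line counts as sampled if any location on it is sampled in any frame
line_ever_sampled = vec(any(frames_per_loc .> 0; dims=1))
never_sampled = findall(!, line_ever_sampled) # indices of lines missing from every frame
```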
Prepare coil sensitivity maps
if !@isdefined(smaps)
smaps_raw = data["b1"] # raw coil sensitivity maps
jim(smaps_raw, "Raw |coil maps| for $nc coils")
sum_last = (f, x) -> selectdim(sum(f, x; dims=ndims(x)), ndims(x), 1)
ssos_fun = smap -> sqrt.(sum_last(abs2, smap)) # SSoS
ssos_raw = ssos_fun(smaps_raw)
smaps = smaps_raw ./ ssos_raw
ssos = ssos_fun(smaps)
@assert all(≈(1), ssos)
pmap = jim(smaps, "Normalized |coil maps| for $nc coils")
end
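Normalizing by the square root of the sum of squares (SSoS) means that, at every pixel, the coil weights have unit total energy, so weighting an image by all coil maps preserves its norm. A minimal sketch with hypothetical random complex maps of size (nx, ny, nc) = (4, 5, 3); toy_ssos mirrors ssos_fun above but is an illustrative name.

```julia
using LinearAlgebra: norm
using Random: seed!

seed!(1)
toy_maps = randn(ComplexF64, 4, 5, 3) # hypothetical (nx, ny, nc) coil maps
# Root sum of squares over the last (coil) dimension, returned as an (nx, ny) array
toy_ssos(m) = sqrt.(dropdims(sum(abs2, m; dims=ndims(m)); dims=ndims(m)))
toy_maps_norm = toy_maps ./ toy_ssos(toy_maps) # (4,5,3) ./ (4,5) broadcasts over coils

toy_img = randn(ComplexF64, 4, 5)
coil_imgs = toy_maps_norm .* toy_img # one coil-weighted image per coil, size (4, 5, 3)
```

Because the SSoS of the normalized maps is 1, the total energy of the coil images equals the energy of the image itself, which is one reason the system operator below has operator norm at most 1.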
Temporal unitary FFT sparsifying transform for an image sequence of size (nx, ny, nt):
TF = Afft((nx,ny,nt), 3; unitary=true) # unitary FFT along 3rd (time) dimension
if false # verify adjoint
tmp1 = randn(ComplexF32, nx, ny, nt)
tmp2 = randn(ComplexF32, nx, ny, nt)
@assert dot(tmp2, TF * tmp1) ≈ dot(TF' * tmp2, tmp1)
@assert TF' * (TF * tmp1) ≈ tmp1
(size(TF), TF._odim, TF._idim)
end
Examine temporal Fourier sparsity of Xinf. The low temporal frequencies dominate, as expected, because Xinf was reconstructed using this regularizer!
tmp = TF * Xinf
ptfft = jim(tmp, "|Temporal FFT of Xinf|")
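Why do low temporal frequencies dominate? For a pixel whose intensity changes slowly over time, almost all energy of its unitary temporal DFT sits at zero frequency. A minimal sketch with an explicit unitary DFT matrix Ft (the page itself uses the operator TF from Afft) and a hypothetical pixel time curve: a static background plus a slow contrast wash-in, not taken from the data.

```julia
using LinearAlgebra: I

toy_nt = 40
# Unitary DFT matrix along time: Ft' * Ft = I
Ft = [cis(-2π * j * k / toy_nt) for j in 0:toy_nt-1, k in 0:toy_nt-1] / sqrt(toy_nt)

# Hypothetical pixel time curve: baseline 1.0 plus a slow exponential wash-in
toy_t = 0:toy_nt-1
curve = 1.0 .+ 0.3 .* (1 .- exp.(-toy_t ./ 10))

energy = abs2.(Ft * curve)           # energy per temporal frequency
dc_fraction = energy[1] / sum(energy) # share of energy at zero temporal frequency
```

Since Ft is unitary, the total energy is unchanged (Parseval), so a large dc_fraction means the remaining coefficients are small, which is what makes the 1-norm of T S a sensible sparsity penalty.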
System matrix
Construct the dynamic parallel MRI system model. It is block diagonal because each frame has its own sampling pattern. The input (image) here has size (nx=128, ny=128, nt=40). The output (data) has size (nsamp=2048, nc=12, nt=40) because every frame has 16 phase-encode lines of 128 samples (16 × 128 = 2048).
todo: precompute (i)fft along readout direction to save time
The code in the original Otazo et al. paper used an ifft in the forward model and an fft in the adjoint, so we set the flag fft_forward=false here to match that model.
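With unitary scaling, the two conventions only differ in the sign of the exponent, and the adjoint of the unitary inverse DFT is the unitary forward DFT. A minimal 1D sketch of an Otazo-style sampled model (ifft forward, zero-fill then fft adjoint), with a dot-product test confirming the pair is consistent. The length 8, the sampled indices, and the names toy_forward and toy_adjoint are hypothetical; this is not the Asense implementation.

```julia
using LinearAlgebra: dot
using Random: seed!

toy_n = 8
dft_u  = [cis(-2π * j * k / toy_n) for j in 0:toy_n-1, k in 0:toy_n-1] / sqrt(toy_n) # unitary "fft"
idft_u = [cis(+2π * j * k / toy_n) for j in 0:toy_n-1, k in 0:toy_n-1] / sqrt(toy_n) # unitary "ifft"
toy_rows = [1, 3, 4] # hypothetical sampled k-space indices

toy_forward(x) = (idft_u * x)[toy_rows] # Otazo style: ifft in the forward model
function toy_adjoint(y)
    z = zeros(ComplexF64, toy_n)
    z[toy_rows] = y # zero-fill the unsampled k-space locations
    return dft_u * z # fft in the adjoint
end
```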
Aotazo = (samp, smaps) -> Asense(samp, smaps; unitary=true, fft_forward=false) # Otazo style
A = block_diag([Aotazo(s, smaps) for s in eachslice(samp, dims=3)]...)
#A = ComplexF32(1/sqrt(nx*ny)) * A # match Otazo's scaling
(size(A), A._odim, A._idim)
((983040, 655360), (2048, 12, 40), (128, 128, 40))
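A block-diagonal system matrix is equivalent to applying each frame's own operator to that frame. A minimal sketch with small hypothetical dense per-frame matrices (block_diag from LinearMapsAA builds the same structure lazily; here Julia's cat with dims=(1, 2) builds it explicitly):

```julia
using Random: seed!

seed!(2)
# Hypothetical per-frame operators: 4 frames, each mapping 5 pixels to 3 samples
toy_frames = [randn(ComplexF64, 3, 5) for _ in 1:4]
toy_bd = cat(toy_frames...; dims=(1, 2)) # explicit 12×20 block-diagonal matrix
toy_x = randn(ComplexF64, 5, 4)          # image sequence, one column per frame

y_block = toy_bd * vec(toy_x) # block-diagonal matrix applied to the stacked frames
y_frames = reduce(hcat, [toy_frames[t] * toy_x[:, t] for t in 1:4]) # frame by frame, 3×4
```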
Reshape data to match the system model
if !@isdefined(ydata)
tmp = reshape(ydata0, :, nc, nt)
tmp = [tmp[vec(s),:,it] for (it,s) in enumerate(eachslice(samp, dims=3))]
ydata = cat(tmp..., dims=3) # (nsamp,nc,nt) = (2048,12,40) no "zeros"
end
size(ydata)
(2048, 12, 40)
Final encoding operator E for L+S. Because we stack X = [L;S] along a new last dimension, E = A * AII, where AII maps X to L + S.
tmp = LinearMapAA(I(nx*ny*nt);
odim=(nx,ny,nt), idim=(nx,ny,nt), T=ComplexF32, prop=(;name="I"))
tmp = kron([1 1], tmp)
AII = redim(tmp; odim=(nx,ny,nt), idim=(nx,ny,nt,2)) # "squeeze" odim
E = A * AII;
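The operator AII is a lazy version of the matrix [I I]. A minimal sketch with a hypothetical 6-pixel image, showing that [I I] applied to the stacked X = [L;S] gives L + S, that its adjoint duplicates its input, and that its spectral norm is √2 (AII' * AII = [I I; I I] has eigenvalues 0 and 2). The names npix, AIImat, toy_L, toy_S, and toy_X are illustrative.

```julia
using LinearAlgebra: I, opnorm, eigvals, Hermitian

npix = 6 # hypothetical number of pixels
AIImat = kron([1 1], Matrix(1.0I, npix, npix)) # [I I], size 6×12
toy_L = randn(npix)
toy_S = randn(npix)
toy_X = hcat(toy_L, toy_S) # last dimension indexes the L and S parts, like X above

toy_M = AIImat * vec(toy_X) # vec(toy_X) = [L; S], so this is L + S
```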
Run power iteration to verify that opnorm(E) = √2
if false
(_, σ1E) = poweriter(undim(E)) # 1.413 ≈ √2
else
σ1E = √2
end
1.4142135623730951
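Power iteration estimates the largest singular value by repeatedly applying M' * M to a vector and renormalizing. A minimal sketch on a small hypothetical dense matrix; power_norm_estimate is an illustrative helper, not MIRT's poweriter, and the iteration count is an arbitrary choice.

```julia
using LinearAlgebra: norm, opnorm
using Random: seed!

# Hypothetical helper: estimate opnorm(M) by power iteration on M'M
function power_norm_estimate(M; niter=500)
    x = randn(eltype(M), size(M, 2))
    for _ in 1:niter
        x = M' * (M * x)
        x /= norm(x) # renormalize to avoid overflow
    end
    return norm(M * x) # ≈ σ₁(M) once x aligns with the top right singular vector
end

seed!(3)
toy_M = randn(30, 10)
σ_est = power_norm_estimate(toy_M)
```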
Check scale factor of Xinf. (It should be ≈1.)
tmp = A * Xinf
scale0 = dot(tmp, ydata) / norm(tmp)^2 # 1.009 ≈ 1
1.0090379f0 + 8.4778144f-8im
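The scale factor above is the least-squares solution of min over complex c of ‖c u - v‖ with u = A * Xinf and v = ydata, namely c = u'v / u'u; Julia's dot conjugates its first argument, so dot(u, v) / norm(u)^2 computes exactly this. A minimal sketch with hypothetical vectors pred and meas standing in for A * Xinf and ydata:

```julia
using LinearAlgebra: dot, norm
using Random: seed!

seed!(4)
pred = randn(ComplexF64, 50)                             # stands in for A * Xinf
meas = (1.01 + 0.0im) .* pred .+ 0.001 .* randn(ComplexF64, 50) # stands in for ydata
c_hat = dot(pred, meas) / norm(pred)^2 # least-squares scale: minimizes ‖c*pred - meas‖
```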
Crude initial image
L0 = A' * ydata # adjoint (zero-filled)
S0 = zeros(ComplexF32, nx, ny, nt)
X0 = cat(L0, S0, dims=ndims(L0)+1) # (nx, ny, nt, 2) = (128, 128, 40, 2)
M0 = AII * X0 # L0 + S0
pm0 = jim(M0, "|Initial L+S via zero-filled recon|")
L+S reconstruction
Prepare for proximal gradient methods
Scalars to match Otazo's results
scaleL = 130 / 1.2775 # Otazo's stopping St(1) / b1 constant squared
scaleS = 1 / 1.2775; # 1 / b1 constant squared
L+S regularizer
lambda_L = 0.01 # regularization parameter
lambda_S = 0.01 * scaleS
Lpart = X -> selectdim(X, ndims(X), 1) # extract "L" from X
Spart = X -> selectdim(X, ndims(X), 2) # extract "S" from X
nucnorm(L::AbstractMatrix) = sum(svdvals(L)) # nuclear norm
nucnorm(L::AbstractArray) = nucnorm(reshape(L, :, nt)); # (nx*ny, nt) for L
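The nuclear norm is applied to the Casorati matrix of size (nx*ny, nt), with one column per frame. A sequence whose frames are all scaled copies of one spatial pattern gives a rank-1 Casorati matrix, so its nuclear norm equals its Frobenius norm. A minimal sketch with a hypothetical 4×5 pattern over 6 frames:

```julia
using LinearAlgebra: svdvals, norm

pat = randn(4, 5)                  # hypothetical spatial pattern
wts = range(1, 2; length=6)        # hypothetical temporal weights
seq = cat([w .* pat for w in wts]...; dims=3) # image sequence of size (4, 5, 6)

casorati = reshape(seq, :, size(seq, 3)) # (nx*ny, nt): one column per frame
sv = svdvals(casorati)
nuc = sum(sv) # nuclear norm, as in nucnorm above
```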
Overall composite cost function for the optimization:
Fcost = X -> 0.5 * norm(E * X - ydata)^2 +
lambda_L * scaleL * nucnorm(Lpart(X)) + # note scaleL !
lambda_S * norm(TF * Spart(X), 1);
f_grad = X -> E' * (E * X - ydata); # gradient of data-fit term
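The Lipschitz constant of the gradient of the data-fit term is $\|E\|_2^2$, which is at most 2: A has operator norm at most 1 (row sampling of a unitary FFT applied to SSoS-normalized coil maps), and AII' * AII acts like ones(2,2), whose largest eigenvalue is 2. This is consistent with the power iteration result opnorm(E) ≈ √2 above.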
f_L = 2; # σ1E^2
Proximal operator for the scaled nuclear norm $\beta \| X \|_*$: singular value soft thresholding (SVST).
function SVST(X::AbstractArray, β)
dims = size(X)
X = reshape(X, :, dims[end]) # assume time frame is the last dimension
U,s,V = svd(X)
sthresh = @. max(s - β, 0)
keep = findall(>(0), sthresh)
X = U[:,keep] * Diagonal(sthresh[keep]) * V[:,keep]'
X = reshape(X, dims)
return X
end;
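SVST keeps the singular vectors and soft-thresholds the singular values by β, so the output's singular values are max(s - β, 0). A minimal matrix-only sketch that checks this property; svst_matrix is a hypothetical name and the test matrix is random, not MRI data.

```julia
using LinearAlgebra: svd, svdvals, Diagonal
using Random: seed!

# Hypothetical matrix version of SVST: prox of β‖Y‖_*
function svst_matrix(Y::AbstractMatrix, β::Real)
    F = svd(Y)                     # thin SVD: Y = U * Diagonal(S) * Vt
    s = max.(F.S .- β, 0)          # soft-threshold the singular values
    return F.U * Diagonal(s) * F.Vt
end

seed!(5)
Ytoy = randn(ComplexF64, 12, 6)
βtoy = 1.5
Ztoy = svst_matrix(Ytoy, βtoy)
```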
Combine proximal operators for L and S parts to make overall prox for X
soft = (v,c) -> sign(v) * max(abs(v) - c, 0) # soft threshold function
S_prox = (S, β) -> TF' * soft.(TF * S, β) # 1-norm proximal mapping for unitary TF
g_prox = (X, c) -> cat(dims=ndims(X),
SVST(Lpart(X), c * lambda_L * scaleL),
S_prox(Spart(X), c * lambda_S),
);
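For complex values, soft thresholding shrinks the magnitude by c and keeps the phase, since sign(v) = v / |v|. Because TF is unitary, ‖TF * S - TF * V‖ = ‖S - V‖, so the prox of ‖TF * S‖₁ is TF' applied to the soft-thresholded TF * S, which is what S_prox does. A minimal scalar sketch, under the hypothetical name soft_c, that also checks the prox optimality on nearby points:

```julia
# Complex soft thresholding (same formula as soft above, hypothetical name)
soft_c(v, c) = sign(v) * max(abs(v) - c, 0)

z1 = soft_c(3 + 4im, 1.0)     # |3+4i| = 5 is shrunk to 4, phase unchanged
z2 = soft_c(0.3 - 0.4im, 1.0) # |z| = 0.5 < 1, so the result is zero
```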
if false # check functions
@assert Fcost(X0) isa Real
tmp = f_grad(X0)
@assert size(tmp) == size(X0)
tmp = SVST(Lpart(X0), 1)
@assert size(tmp) == size(L0)
tmp = S_prox(S0, 1)
@assert size(tmp) == size(S0)
tmp = g_prox(X0, 1)
@assert size(tmp) == size(X0)
end
niter = 10
fun = (iter, xk, yk, is_restart) -> (Fcost(xk), xk); # logger
Run PGM
if !@isdefined(Mpgm)
f_mu = 2/0.99 - f_L # trick to match 0.99 step size in Lin 2019; overridden by the next line
f_mu = 0
xpgm, out_pgm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
f_mu, mom = :pgm, niter, g_prox, fun)
Mpgm = AII * xpgm
end;
Run FPGM (FISTA)
if !@isdefined(Mfpgm)
xfpgm, out_fpgm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
mom = :fpgm, niter, g_prox, fun)
Mfpgm = AII * xfpgm
end;
Run POGM
if !@isdefined(Mpogm)
xpogm, out_pogm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
mom = :pogm, niter, g_prox, fun)
Mpogm = AII * xpogm
end;
Look at final POGM image components
px = jim(
jif(Lpart(xpogm), "L"),
jif(Spart(xpogm), "S"),
jif(Mpogm, "M=L+S"),
jif(Xinf, "Minf"),
)
Plot cost function
costs = out -> [o[1] for o in out]
nrmsd = out -> [norm(AII*o[2]-Xinf)/norm(Xinf) for o in out]
cost_pgm = costs(out_pgm)
cost_fpgm = costs(out_fpgm)
cost_pogm = costs(out_pogm)
pc = plot(xlabel = "iteration", ylabel = "cost")
plot!(0:niter, cost_pgm, marker=:circle, label="PGM (ISTA)")
plot!(0:niter, cost_fpgm, marker=:square, label="FPGM (FISTA)")
plot!(0:niter, cost_pogm, marker=:star, label="POGM")
Plot NRMSD vs Matlab Xinf
nrmsd_pgm = nrmsd(out_pgm)
nrmsd_fpgm = nrmsd(out_fpgm)
nrmsd_pogm = nrmsd(out_pogm)
pd = plot(xlabel = "iteration", ylabel = "NRMSD vs Matlab Xinf")
plot!(0:niter, nrmsd_pgm, marker=:circle, label="PGM (ISTA)")
plot!(0:niter, nrmsd_fpgm, marker=:square, label="FPGM (FISTA)")
plot!(0:niter, nrmsd_pogm, marker=:star, label="POGM")
Discussion
todo
Reproducibility
This page was generated with the following version of Julia:
using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
12-element Vector{SubString{String}}:
"Julia Version 1.10.3"
"Commit 0b4590a5507 (2024-04-30 10:59 UTC)"
"Build Info:"
" Official https://julialang.org/ release"
"Platform Info:"
" OS: Linux (x86_64-linux-gnu)"
" CPU: 4 × AMD EPYC 7763 64-Core Processor"
" WORD_SIZE: 64"
" LIBM: libopenlibm"
" LLVM: libLLVM-15.0.7 (ORCJIT, znver3)"
"Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)"
""
And with the following package versions
import Pkg; Pkg.status()
Status `~/work/Examples/Examples/docs/Project.toml`
[e30172f5] Documenter v1.4.1
[940e8692] Examples v0.0.1 `~/work/Examples/Examples`
[7a1cc6ca] FFTW v1.8.0
[9ee76f2b] ImageGeoms v0.11.1
[787d08f9] ImageMorphology v0.4.5
[71a99df6] ImagePhantoms v0.8.1
[b964fa9f] LaTeXStrings v1.3.1
[7031d0ef] LazyGrids v1.0.0
[599c1a8e] LinearMapsAA v0.12.0
[98b081ad] Literate v2.18.0
[23992714] MAT v0.10.7
[7035ae7a] MIRT v0.18.1
[170b2178] MIRTjim v0.25.0
[efe261a4] NFFT v0.13.3
[91a5bcdd] Plots v1.40.4
[2913bbd2] StatsBase v0.34.3
[1986cc42] Unitful v1.20.0
[b77e0a4c] InteractiveUtils
[37e2e46d] LinearAlgebra
[44cfe95a] Pkg v1.10.0
[9a3f8284] Random
This page was generated using Literate.jl.