L+S 2D dynamic recon

This page illustrates dynamic parallel MRI image reconstruction using a low-rank plus sparse (L+S) model optimized by a fast algorithm described in the paper "Efficient Dynamic Parallel MRI Reconstruction for the Low-Rank Plus Sparse Model," IEEE Trans. on Computational Imaging, 5(1):17-26, 2019, by Claire Lin and Jeff Fessler, EECS Department, University of Michigan.

The Julia code here is a translation of part of the Matlab code used in the original paper.

If you use this code, please cite that paper.

This page comes from a single Julia file: 5-l-plus-s.jl.

Setup

Packages needed here.

# using Unitful: s
using Plots; default(markerstrokecolor=:auto, label="")
using MIRT: Afft, Asense, embed
using MIRT: pogm_restart, poweriter
using MIRTjim: jim, prompt
using FFTW: fft!, bfft!, fftshift!
using LinearMapsAA: LinearMapAA, block_diag, redim, undim
using MAT: matread
import Downloads # todo: use Fetch or DataDeps?
using LinearAlgebra: dot, norm, svd, svdvals, Diagonal, I
using Random: seed!
using StatsBase: mean
using LaTeXStrings

The following lines are helpful when running this file as a script; they prompt the user to hit a key after each figure is displayed.

jif(args...; kwargs...) = jim(args...; prompt=false, kwargs...)
isinteractive() ? jim(:prompt, true) : prompt(:draw);

Overview

Dynamic image reconstruction using a "low-rank plus sparse" or "L+S" approach was proposed by Otazo et al. and uses the following cost function:

\[X = \hat{L} + \hat{S} ,\qquad (\hat{L}, \hat{S}) = \arg \min_{L,S} \frac{1}{2} \| E (L + S) - d \|_2^2 + λ_L \| L \|_* + λ_S \| \mathrm{vec}(T S) \|_1\]

where $T$ is a temporal unitary FFT, $E$ is an encoding operator (system matrix), and $d$ is Cartesian undersampled multicoil k-space data.
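
To make the three terms of this cost function concrete, here is a minimal sketch that evaluates it on a tiny toy problem. It assumes $E = I$ (fully sampled, single coil) and builds a small unitary DFT matrix by hand in place of the temporal FFT; the sizes and the names `np`, `ls_cost`, `λL`, and `λS` are hypothetical and not part of the original code.

```julia
using LinearAlgebra: norm, svdvals

np, nt = 6, 4   # toy sizes: pixels per frame, time frames (hypothetical)

# Unitary DFT matrix along time, standing in for the temporal FFT T
T = [cis(-2π * j * k / nt) / sqrt(nt) for j in 0:nt-1, k in 0:nt-1]

# Low-rank part: the same image in every frame (rank 1)
L = ones(ComplexF64, np) * ones(ComplexF64, nt)'

# Sparse part: one pixel oscillating at a single temporal frequency
S = zeros(ComplexF64, np, nt)
S[2, :] = 0.5 * conj.(T[:, 3])

# Assumption: E = I, so the noiseless "data" are simply L + S
d = L + S

# Rows are pixels and columns are time, so T acts on each row: S * transpose(T)
ls_cost(L, S; λL, λS) = 0.5 * norm(L + S - d)^2 +
    λL * sum(svdvals(L)) + λS * norm(S * transpose(T), 1)

cost = ls_cost(L, S; λL = 1, λS = 1)
```

Here the data-fit term is zero, the nuclear norm of the rank-1 matrix is `sqrt(np * nt)`, and the temporal spectrum of `S` has a single nonzero entry of magnitude 0.5.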

The Otazo paper used an iterative soft thresholding algorithm (ISTA) to solve this optimization problem. Using FISTA is faster, but using the proximal optimized gradient method (POGM) with adaptive restart is even faster.
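
The difference between ISTA and FISTA can be sketched on a small sparse regression problem. This is an illustrative toy with a hypothetical random matrix `A`, not the MRI problem, and it does not implement POGM or adaptive restart (the code later on this page uses `pogm_restart` for that).

```julia
using LinearAlgebra: norm, opnorm
using Random: seed!

seed!(0)
A = randn(20, 50)                        # hypothetical toy system matrix
xtrue = zeros(50)
xtrue[[3, 17, 40]] = [2.0, -1.5, 1.0]    # sparse ground truth
y = A * xtrue
λ = 0.1
Lf = opnorm(A)^2                         # Lipschitz constant of the gradient

cost(x) = 0.5 * norm(A * x - y)^2 + λ * norm(x, 1)
soft(v, c) = sign(v) * max(abs(v) - c, 0)
# One proximal gradient step with step size 1/Lf
pgstep(z) = soft.(z - (A' * (A * z - y)) / Lf, λ / Lf)

function ista(niter)
    x = zeros(50)
    costs = [cost(x)]
    for _ in 1:niter
        x = pgstep(x)                    # plain proximal gradient update
        push!(costs, cost(x))
    end
    return costs
end

function fista(niter)
    x = zeros(50); z = copy(x); t = 1.0
    costs = [cost(x)]
    for _ in 1:niter
        xnew = pgstep(z)                 # step from the extrapolated point
        tnew = (1 + sqrt(1 + 4t^2)) / 2
        z = xnew + ((t - 1) / tnew) * (xnew - x)   # momentum
        x, t = xnew, tnew
        push!(costs, cost(x))
    end
    return costs
end

c_ista = ista(50)
c_fista = fista(50)
```

Comparing `c_ista` and `c_fista` shows how momentum changes the cost trajectory; ISTA decreases the cost monotonically with this step size, whereas FISTA need not.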

This example reproduces part of Figures 1 & 2 in Claire Lin's paper, based on the cardiac perfusion example.

Read data

if !@isdefined(data)
    url = "https://web.eecs.umich.edu/~fessler/irt/reproduce/19/lin-19-edp/data/"
    dataurl = url * "cardiac_perf_R8.mat"
    data = matread(Downloads.download(dataurl))
    xinfurl = url * "Xinf.mat"
    Xinf = matread(Downloads.download(xinfurl))["Xinf"]["perf"] # (128,128,40)
end;

Show converged image as a preview:

jim(Xinf, L"\mathrm{Converged\ image\ } X_∞")

Organize k-space data:

if !@isdefined(ydata0)
    ydata0 = data["kdata"] # k-space data full of zeros
    ydata0 = permutedims(ydata0, [1, 2, 4, 3]) # (nx,ny,nc,nt)
    ydata0 = ComplexF32.(ydata0)
end
(nx, ny, nc, nt) = size(ydata0)
(128, 128, 12, 40)

Extract sampling pattern from zeros of k-space data:

if !@isdefined(samp)
    samp = ydata0[:,:,1,:] .!= 0
    for ic in 2:nc # verify it is same for all coils
        @assert samp == (ydata0[:,:,ic,:] .!= 0)
    end
    kx = -(nx÷2):(nx÷2-1)
    ky = -(ny÷2):(ny÷2-1)
    jim(kx, ky, samp, "Sampling patterns for $nt frames"; xlabel=L"k_x", ylabel=L"k_y")
end
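
The mask extraction above can be illustrated on a tiny synthetic array. This is a sketch with hypothetical sizes and values; note the underlying assumption that no sampled location has a measured value of exactly zero, since such a sample would be missed.

```julia
# Toy zero-filled k-space with (nx, ny, nc, nt) = (4, 4, 2, 3); sizes are hypothetical
nx, ny, nc, nt = 4, 4, 2, 3
mask_true = falses(nx, ny, nt)
for it in 1:nt
    mask_true[:, it, it] .= true        # one sampled line per frame
end

kdata = zeros(ComplexF32, nx, ny, nc, nt)
for ic in 1:nc
    # nonzero values at sampled locations only
    kdata[:, :, ic, :] .= mask_true .* ComplexF32(ic + 1im)
end

# Recover the mask from the first coil and confirm the other coils agree
mask = kdata[:, :, 1, :] .!= 0
same = all(mask == (kdata[:, :, ic, :] .!= 0) for ic in 2:nc)
lines_per_frame = [count(mask[:, :, it]) ÷ nx for it in 1:nt]
```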

Prepare coil sensitivity maps

if !@isdefined(smaps)
    smaps_raw = data["b1"] # raw coil sensitivity maps
    jim(smaps_raw, "Raw |coil maps| for $nc coils")
    sum_last = (f, x) -> selectdim(sum(f, x; dims=ndims(x)), ndims(x), 1)
    ssos_fun = smap -> sqrt.(sum_last(abs2, smap)) # SSoS
    ssos_raw = ssos_fun(smaps_raw)
    smaps = smaps_raw ./ ssos_raw
    ssos = ssos_fun(smaps)
    @assert all(≈(1), ssos)
    jim(smaps, "Normalized |coil maps| for $nc coils")
end
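
The square-root sum-of-squares (SSoS) normalization can be checked on a small synthetic example. The sketch below uses made-up coil maps and a hypothetical helper `ssos`; after normalization, the squared magnitudes over coils sum to 1 at every pixel.

```julia
nx, ny, nc = 3, 3, 4    # toy sizes (hypothetical)

# Made-up complex "coil maps" that are nonzero everywhere
raw = [ComplexF32(ix + iy * im) * ic for ix in 1:nx, iy in 1:ny, ic in 1:nc]

# SSoS over the coil (last) dimension, returned as an (nx, ny) array
ssos(s) = sqrt.(dropdims(sum(abs2, s; dims=3); dims=3))

normalized = raw ./ ssos(raw)   # broadcast the (nx, ny) SSoS over coils
check = ssos(normalized)
```

This normalization would fail wherever the raw SSoS is zero, so real data may need a mask or a small floor value.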

Temporal unitary FFT sparsifying transform for image sequence of size (nx, ny, nt):

TF = Afft((nx,ny,nt), 3; unitary=true) # unitary FFT along 3rd (time) dimension
if false # verify adjoint
    tmp1 = randn(ComplexF32, nx, ny, nt)
    tmp2 = randn(ComplexF32, nx, ny, nt)
    @assert dot(tmp2, TF * tmp1) ≈ dot(TF' * tmp2, tmp1)
    @assert TF' * (TF * tmp1) ≈ tmp1
    (size(TF), TF._odim, TF._idim)
end
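
The same adjoint and inverse checks can be reproduced without `Afft` by applying a small unitary DFT matrix along the last dimension. This is a sketch with hypothetical helper names `tf` and `tf_adj`, meant only to show what "unitary along time" implies.

```julia
using LinearAlgebra: dot

nx, ny, nt = 2, 3, 4    # toy sizes (hypothetical)
F = [cis(-2π * j * k / nt) / sqrt(nt) for j in 0:nt-1, k in 0:nt-1]

# Forward: DFT of each pixel's time series (time is the last dimension)
tf(x) = reshape(reshape(x, :, nt) * transpose(F), size(x))
# Adjoint: apply the conjugate-transposed DFT along time
tf_adj(y) = reshape(reshape(y, :, nt) * conj.(F), size(y))

x = randn(ComplexF64, nx, ny, nt)
y = randn(ComplexF64, nx, ny, nt)

adjoint_ok = dot(y, tf(x)) ≈ dot(tf_adj(y), x)   # <y, T x> = <T' y, x>
inverse_ok = tf_adj(tf(x)) ≈ x                   # T' T = I for a unitary T
```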

Examine temporal Fourier sparsity of Xinf. The low temporal frequencies dominate, as expected, because Xinf was reconstructed using this regularizer!

tmp = TF * Xinf
jim(tmp, "|Temporal FFT of Xinf|")

System matrix

Construct dynamic parallel MRI system model. It is block diagonal where each frame has its own sampling pattern. The input (image) here has size (nx=128, ny=128, nt=40). The output (data) has size (nsamp=2048, nc=12, nt=40) because every frame has 16 phase-encode lines of 128 samples.

todo: precompute (i)fft along readout direction to save time

The code in the original Otazo et al. paper used an ifft in the forward model and an fft in the adjoint, so we must use a flag here to match that model.

Aotazo = (samp, smaps) -> Asense(samp, smaps; unitary=true, fft_forward=false) # Otazo style
A = block_diag([Aotazo(s, smaps) for s in eachslice(samp, dims=3)]...)
#A = ComplexF32(1/sqrt(nx*ny)) * A # match Otazo's scaling
(size(A), A._odim, A._idim)
((983040, 655360), (2048, 12, 40), (128, 128, 40))
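
The sizes reported above follow from simple bookkeeping, sketched below. Reading the "R8" in the data file name as an acceleration factor of 8 is an assumption made here for illustration.

```julia
nx, ny, nc, nt = 128, 128, 12, 40
lines_per_frame = 16                    # phase-encode lines sampled per frame

nsamp = lines_per_frame * nx            # samples per coil per frame
odim = (nsamp, nc, nt)                  # data (output) dimensions
idim = (nx, ny, nt)                     # image (input) dimensions
size_A = (prod(odim), prod(idim))       # size of A viewed as a matrix
accel = ny ÷ lines_per_frame            # undersampling factor per frame
```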

Reshape data to match the system model

if !@isdefined(ydata)
    tmp = reshape(ydata0, :, nc, nt)
    tmp = [tmp[vec(s),:,it] for (it,s) in enumerate(eachslice(samp, dims=3))]
    ydata = cat(tmp..., dims=3) # (nsamp,nc,nt) = (2048,12,40) no "zeros"
end
size(ydata)
(2048, 12, 40)
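
The reshaping step relies on logical indexing with each frame's mask. The following toy sketch (hypothetical sizes and values) shows the same pattern; concatenating the frames works only because every frame has the same number of samples.

```julia
nx, ny, nc, nt = 2, 2, 2, 3             # toy sizes (hypothetical)

# Each frame samples 2 of the 4 k-space locations
samp = falses(nx, ny, nt)
samp[:, 1, 1] .= true
samp[:, 2, 2] .= true
samp[1, :, 3] .= true

# Zero-filled data whose values encode (frame, coil, linear k-space index)
vals = [ComplexF32(100it + 10ic + ix + nx * (iy - 1))
        for ix in 1:nx, iy in 1:ny, ic in 1:nc, it in 1:nt]
y0 = vals .* reshape(samp, nx, ny, 1, nt)

# Keep only the sampled entries of each frame, then stack the frames
flat = reshape(y0, :, nc, nt)
frames = [flat[vec(s), :, it] for (it, s) in enumerate(eachslice(samp, dims=3))]
y = cat(frames..., dims=3)              # (nsamp, nc, nt) with no zeros
```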

Final encoding operator E for L+S. Because we stack X = [L;S], we use E = A * [I I] so that E * X = A * (L + S).

tmp = LinearMapAA(I(nx*ny*nt);
    odim=(nx,ny,nt), idim=(nx,ny,nt), T=ComplexF32, prop=(;name="I"))
tmp = kron([1 1], tmp)
AII = redim(tmp; odim=(nx,ny,nt), idim=(nx,ny,nt,2)) # "squeeze" odim
E = A * AII;

Optionally run power iteration to verify that opnorm(E) ≈ √2. The check is disabled by default (if false) and the known value is used directly.

if false
    (_, σ1E) = poweriter(undim(E)) # 1.413 ≈ √2
else
    σ1E = √2
end
1.4142135623730951
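
The factor √2 comes from the summing operator [I I]. A small dense sketch, using plain matrices rather than the `LinearMapAA` objects above, shows both that it adds the two stacked parts and that its spectral norm is √2.

```julia
using LinearAlgebra: I, opnorm

n = 5                                       # toy length (hypothetical)
AII = kron([1 1], Matrix(1.0I, n, n))       # n × 2n matrix [I I]

L = collect(1.0:n)
S = collect(10.0:10:10n)
X = [L; S]                                  # stacked unknown

sum_ok = AII * X == L + S                   # [I I] * [L; S] = L + S
σ1 = opnorm(AII)                            # largest singular value
G = AII' * AII                              # equals kron(ones(2,2), I)
```

Because `G` is `kron(ones(2,2), I)`, its largest eigenvalue is 2, which is the square of `σ1`.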

Check scale factor of Xinf. (It should be ≈1.)

tmp = A * Xinf
scale = dot(tmp, ydata) / norm(tmp)^2 # 1.009 ≈ 1
1.0090379f0 + 8.4778144f-8im
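
The expression for `scale` is the least-squares solution of $\min_α \| α t - y \|_2^2$, namely $α = \langle t, y \rangle / \| t \|_2^2$, where Julia's `dot` conjugates its first argument. A toy check with made-up vectors:

```julia
using LinearAlgebra: dot, norm

t = ComplexF32[1 + 2im, 3 - 1im, -2 + 0.5im]   # stands in for A * Xinf
y = 1.5f0 * t                                  # "data" that are 1.5 times larger

scale = dot(t, y) / norm(t)^2                  # recovers the factor 1.5
```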

Crude initial image

L0 = A' * ydata # adjoint (zero-filled)
S0 = zeros(ComplexF32, nx, ny, nt)
X0 = cat(L0, S0, dims=ndims(L0)+1) # (nx, ny, nt, 2) = (128, 128, 40, 2)
M0 = AII * X0 # L0 + S0
jim(M0, "|Initial L+S via zero-filled recon|")

L+S reconstruction

Prepare for proximal gradient methods

Scalars to match Otazo's results

scaleL = 130 / 1.2775 # Otazo's stopping St(1) / b1 constant squared
scaleS = 1 / 1.2775; # 1 / b1 constant squared

L+S regularizer

lambda_L = 0.01 # regularization parameter
lambda_S = 0.01 * scaleS
Lpart = X -> selectdim(X, ndims(X), 1) # extract "L" from X
Spart = X -> selectdim(X, ndims(X), 2) # extract "S" from X
nucnorm(L::AbstractMatrix) = sum(svdvals(L)) # nuclear norm
nucnorm(L::AbstractArray) = nucnorm(reshape(L, :, nt)); # (nx*ny, nt) for L

Overall composite cost function for the optimization

Fcost = X -> 0.5 * norm(E * X - ydata)^2 +
    lambda_L * scaleL * nucnorm(Lpart(X)) + # note scaleL !
    lambda_S * norm(TF * Spart(X), 1);

f_grad = X -> E' * (E * X - ydata); # gradient of data-fit term
#18 (generic function with 1 method)

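A Lipschitz constant of the gradient of the data-fit term is 2 because A has spectral norm at most 1 (unitary FFT with normalized coil maps) and AII' * AII is like ones(2,2), whose largest eigenvalue is 2.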

f_L = 2; # σ1E^2

Proximal operator for scaled nuclear norm $β \| X \|_*$: singular value soft thresholding (SVST).

function SVST(X::AbstractArray, β)
    dims = size(X)
    X = reshape(X, :, dims[end]) # assume time frame is the last dimension
    U,s,V = svd(X)
    sthresh = @. max(s - β, 0)
    keep = findall(>(0), sthresh)
    X = U[:,keep] * Diagonal(sthresh[keep]) * V[:,keep]'
    X = reshape(X, dims)
    return X
end;
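
A quick way to see what SVST does is to apply it to a matrix with known singular values. The sketch below uses a simplified stand-in, `svst_demo` (a hypothetical name), that thresholds without discarding zeroed components; the result should match what `SVST` computes for a matrix input.

```julia
using LinearAlgebra: svd, svdvals, Diagonal

# Simplified stand-in for SVST: soft-threshold the singular values by β
function svst_demo(X::AbstractMatrix, β)
    U, s, V = svd(X)
    return U * Diagonal(max.(s .- β, 0)) * V'
end

X = [3.0 0 0; 0 2 0; 0 0 0.5; 0 0 0]    # singular values 3, 2, 0.5
β = 1.0
Y = svst_demo(X, β)                      # singular values become 2, 1, 0
```

The smallest singular value falls below the threshold and is set to zero, which is how the nuclear-norm penalty reduces rank.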

Combine proximal operators for L and S parts to make overall prox for X

soft = (v,c) -> sign(v) * max(abs(v) - c, 0) # soft threshold function
S_prox = (S, β) -> TF' * soft.(TF * S, β) # 1-norm proximal mapping for unitary TF
g_prox = (X, c) -> cat(dims=ndims(X),
    SVST(Lpart(X), c * lambda_L * scaleL),
    S_prox(Spart(X), c * lambda_S),
);
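
For complex values, the soft threshold function shrinks the magnitude and preserves the phase, because `sign(v)` returns `v / abs(v)` for nonzero complex `v`. A small check with made-up numbers:

```julia
soft(v, c) = sign(v) * max(abs(v) - c, 0)   # same definition as above

z = 3.0 + 4.0im                 # magnitude 5
w = soft(z, 1.0)                # magnitude shrinks to 4, phase unchanged

small = soft(0.3 + 0.4im, 1.0)  # magnitude 0.5 is below the threshold
r = soft(-2.0, 0.5)             # real inputs keep their sign
```

Because TF is unitary, applying this function in the temporal Fourier domain and transforming back yields the proximal mapping of the 1-norm term, which is what `S_prox` does.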

if false # check functions
    @assert Fcost(X0) isa Real
    tmp = f_grad(X0)
    @assert size(tmp) == size(X0)
    tmp = SVST(Lpart(X0), 1)
    @assert size(tmp) == size(L0)
    tmp = S_prox(S0, 1)
    @assert size(tmp) == size(S0)
    tmp = g_prox(X0, 1)
    @assert size(tmp) == size(X0)
end


niter = 10
fun = (iter, xk, yk, is_restart) -> (Fcost(xk), xk); # logger

Run PGM

if !@isdefined(Mpgm)
    f_mu = 2/0.99 - f_L # trick to match the 0.99 step size in Lin 2019
    f_mu = 0 # overrides the line above, so the trick is disabled here
    xpgm, out_pgm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
        f_mu, mom = :pgm, niter, g_prox, fun)
    Mpgm = AII * xpgm
end;

Run FPGM (FISTA)

if !@isdefined(Mfpgm)
    xfpgm, out_fpgm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
        mom = :fpgm, niter, g_prox, fun)
    Mfpgm = AII * xfpgm
end;

Run POGM

if !@isdefined(Mpogm)
    xpogm, out_pogm = pogm_restart(X0, (x) -> 0, f_grad, f_L ;
        mom = :pogm, niter, g_prox, fun)
    Mpogm = AII * xpogm
end;

Look at final POGM image components

px = jim(
 jif(Lpart(xpogm), "L"),
 jif(Spart(xpogm), "S"),
 jif(Mpogm, "M=L+S"),
 jif(Xinf, "Minf"),
)

Plot cost function

costs = out -> [o[1] for o in out]
nrmsd = out -> [norm(AII*o[2]-Xinf)/norm(Xinf) for o in out]
cost_pgm = costs(out_pgm)
cost_fpgm = costs(out_fpgm)
cost_pogm = costs(out_pogm)
pc = plot(xlabel = "iteration", ylabel = "cost")
plot!(0:niter, cost_pgm, marker=:circle, label="PGM (ISTA)")
plot!(0:niter, cost_fpgm, marker=:square, label="FPGM (FISTA)")
plot!(0:niter, cost_pogm, marker=:star, label="POGM")

Plot NRMSD vs Matlab Xinf

nrmsd_pgm = nrmsd(out_pgm)
nrmsd_fpgm = nrmsd(out_fpgm)
nrmsd_pogm = nrmsd(out_pogm)
pd = plot(xlabel = "iteration", ylabel = "NRMSD vs Matlab Xinf")
plot!(0:niter, nrmsd_pgm, marker=:circle, label="PGM (ISTA)")
plot!(0:niter, nrmsd_fpgm, marker=:square, label="FPGM (FISTA)")
plot!(0:niter, nrmsd_pogm, marker=:star, label="POGM")
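
The NRMSD used here is the norm of the difference divided by the norm of the reference. A toy sketch with a hypothetical two-argument helper `nrmsd_demo` illustrates its scale:

```julia
using LinearAlgebra: norm

# Normalized root-mean-square difference between x and a reference
nrmsd_demo(x, ref) = norm(x - ref) / norm(ref)

ref = [3.0, 4.0]
e_same = nrmsd_demo(ref, ref)           # identical arrays
e_zero = nrmsd_demo(zero(ref), ref)     # all-zero estimate
e_10pc = nrmsd_demo(1.1 * ref, ref)     # estimate scaled by 10%
```

An NRMSD of 1 thus corresponds to an estimate no better than all zeros, and a uniform 10% scaling error gives 0.1.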

Discussion

todo

Reproducibility

This page was generated with the following version of Julia:

using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
10-element Vector{SubString{String}}:
 "Julia Version 1.9.1"
 "Commit 147bdf428cd (2023-06-07 08:27 UTC)"
 "Platform Info:"
 "  OS: Linux (x86_64-linux-gnu)"
 "  CPU: 2 × Intel(R) Xeon(R) CPU E5-2673 v4 @ 2.30GHz"
 "  WORD_SIZE: 64"
 "  LIBM: libopenlibm"
 "  LLVM: libLLVM-14.0.6 (ORCJIT, broadwell)"
 "  Threads: 1 on 2 virtual cores"
 ""

And with the following package versions

import Pkg; Pkg.status()
Status `~/work/Examples/Examples/docs/Project.toml`
  [e30172f5] Documenter v0.27.24
  [940e8692] Examples v0.0.1 `~/work/Examples/Examples`
  [7a1cc6ca] FFTW v1.7.0
  [9ee76f2b] ImageGeoms v0.10.0
  [787d08f9] ImageMorphology v0.4.4
  [71a99df6] ImagePhantoms v0.7.2
  [b964fa9f] LaTeXStrings v1.3.0
  [7031d0ef] LazyGrids v0.5.0
  [599c1a8e] LinearMapsAA v0.11.0
  [98b081ad] Literate v2.14.0
  [23992714] MAT v0.10.5
  [7035ae7a] MIRT v0.17.0
  [170b2178] MIRTjim v0.22.0
  [efe261a4] NFFT v0.13.3
  [91a5bcdd] Plots v1.38.16
  [2913bbd2] StatsBase v0.34.0
  [1986cc42] Unitful v1.14.0
  [b77e0a4c] InteractiveUtils
  [37e2e46d] LinearAlgebra
  [44cfe95a] Pkg v1.9.0
  [9a3f8284] Random

This page was generated using Literate.jl.