VarPro: Two exponentials
This page illustrates fitting a bi-exponential model to noisy data.
See:
- VarPro blog
- VP4Optim.jl has biexponential fits
- Varpro.jl
This page comes from a single Julia file: 99-varpro2.jl.
Setup
Packages needed here.
import ForwardDiff
using ImagePhantoms: ellipse, ellipse_parameters, phantom, SheppLoganBrainWeb
using Statistics: mean, std
using Plots: default, gui, histogram, plot, plot!, scatter, scatter!
using Plots: cgrad, RGB
default(markerstrokecolor=:auto, label="", widen = true)
using LaTeXStrings
using LinearAlgebra: norm, Diagonal, diag, diagm, qr
using MIRTjim: jim
using Random: seed!; seed!(0)
#using Unitful: @u_str, ms, s, mm
ms = 0.001; s = 1; mm = 1 # avoid Unitful due to ForwardDiff
using InteractiveUtils: versioninfo

Double exponential (bi-exponential)
We explore a simple case: fitting a bi-exponential to some noisy data:
\[y_m = c_a e^{- r_a t_m} + c_b e^{- r_b t_m} + ϵ_m ,\quad m = 1,…,M.\]
The four unknown parameters here are:
- the decay rates $r_a, r_b ≥ 0$
- the amplitudes (aka coefficients) $c_a, c_b$ (that could be complex in some MRI settings)
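In matrix form, which underlies the VarPro cost used later, the noiseless model is
\[y = A(r)\,c, \qquad A(r) = \begin{bmatrix} e^{-r_a t_1} & e^{-r_b t_1} \\ ⋮ & ⋮ \\ e^{-r_a t_M} & e^{-r_b t_M} \end{bmatrix}, \qquad c = \begin{bmatrix} c_a \\ c_b \end{bmatrix},\]
where $r = (r_a, r_b)$ collects the nonlinear parameters and $c$ the linear ones.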
Tf = Float32
Tc = Complex{Tf}
M = 18 # how many samples (more than in single exponential demo!)
Δte = 10ms # echo spacing
te1 = 5ms # time of first echo
tm = Tf.(te1 .+ (0:(M-1)) * Δte) # echo times
Lin = (:ca, :cb) # linear model parameter names
Non = (:ra, :rb) # nonlinear model parameter names
Tlin = NamedTuple{Lin}
Tnon = NamedTuple{Non}
Tall = NamedTuple{(Lin..., Non...)}
c_true = Tlin(Tf.([60, 40])) # AU
r_true = Tnon(Tf.([100/s, 20/s]))
x_true = (; c_true..., r_true...) # all parameters
(ca = 60.0f0, cb = 40.0f0, ra = 100.0f0, rb = 20.0f0)

Signal model
This next function defines the signal basis functions from the physics; it is the only model-specific function.
signal_bases(ra::Number, rb::Number, t::Number) =
    [exp(-t * ra); exp(-t * rb)]; # bi-exponential model

The following signal helper functions apply to many models having a mix of linear and nonlinear signal parameters.
Here there is just one scan parameter (echo time).
- These functions would need to be generalized to handle multiple scan parameters (e.g., echo time, phase-cycling factor, flip angle).
- They would also need to be generalized to handle models with "known" parameters (e.g., B0 and B1+).
signal_bases(non::Tnon, t::Number) =
    signal_bases(non..., t)
signal_bases(non::Tnon, tv::AbstractVector) =
    stack(t -> signal_bases(non, t), tv; dims=1);

Signal model that combines nonlinear and linear effects:
signal(lin::Tlin, non::Tnon, tv::AbstractVector) =
    signal_bases(non, tv) * collect(lin);

Signal model helpers:
signal(lin::AbstractVector, non::Tnon, tv::AbstractVector) =
    signal(Tlin(lin), non, tv)
signal(lin, non::AbstractVector, tv::AbstractVector) =
    signal(lin, Tnon(non), tv)
function signal(x::Tall, tv::AbstractVector)
    fun = name -> getfield(x, name)
    signal(Tlin(fun.(Lin)), Tnon(fun.(Non)), tv)
end
signal(x::AbstractVector, tv::AbstractVector) =
    signal(Tall(x), tv)
signal (generic function with 5 methods)

Simulate data:
y_true = signal(c_true, r_true, tm)
@assert y_true == signal(x_true, tm)
tf = Tf.(range(0, M, 201) * Δte) # fine sampling for plots
yf = signal(c_true, r_true, tf)
xaxis_t = ("t [ms]", (0,200), (0:4:M)*Δte/ms) # no units
py = plot( xaxis = xaxis_t, yaxis = ("y", (0,100)) )
plot!(py, tf/ms, yf, color=:black)
scatter!(py, tm/ms, y_true, label = "Noiseless data, M=$M samples")

Random phase and noise
Actual MRI data has an overall phase and additive complex Gaussian noise.
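The helper `snr2sigma` below converts a desired SNR in dB into a noise standard deviation; unpacking the code, the implied definition is
\[\mathrm{SNR}\,\text{[dB]} = 20 \log_{10}\!\left( \frac{‖y‖_2}{σ \sqrt{M}} \right) \;\Longleftrightarrow\; σ = 10^{-\mathrm{SNR}/20}\, \frac{‖y‖_2}{\sqrt{M}}.\]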
phase_true = rand() * 2π # random phase in [0, 2π)
y_true_phased = Tc.(cis(phase_true) * y_true)
snr = 25 # dB
snr2sigma(db, y) = 10^(-db/20) * norm(y) / sqrt(length(y))
σ = Tf(snr2sigma(snr, y_true_phased))
yc = y_true_phased + σ * randn(Tc, M)
@show 20 * log10(norm(yc) / norm(yc - y_true_phased)) # check σ
26.45752f0

The phase of the noisy data becomes unreliable for low signal values:
pp = scatter(tm/ms, angle.(y_true_phased), label = "True data",
    xaxis = xaxis_t,
    yaxis = ("Phase", (-π, π), ((-1:1)*π, ["-π", "0", "π"])),
)
scatter!(pp, tm/ms, angle.(yc), label="Noisy data")

pc = scatter(tm/ms, real(yc),
label = "Noisy data - real part",
xaxis = xaxis_t,
ylim = (-100, 100),
)
scatter!(pc, tm/ms, imag(yc),
label = "Noisy data - imag part",
)Phase correction
Phase-correct the signal using the phase of the first (noisy) data point:
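In Julia, `sign(z)` for a complex `z` returns the unit-modulus phasor `z / abs(z)`, so multiplying by `conj(sign(yc[1]))` rotates the whole signal so that its first sample lands (approximately) on the positive real axis. A minimal illustration (the value of `z` here is an arbitrary example):
z = 3f0 + 4f0im
sign(z) # 0.6f0 + 0.8f0im, i.e., z / abs(z)
conj(sign(z)) * z # 5.0f0 + 0.0f0im: rotated onto the positive real axis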
yr = conj(sign(yc[1])) .* yc
pr = deepcopy(py)
scatter!(pr, tm/ms, real(yr),
    label = "Phase corrected data - real part",
    xaxis = xaxis_t,
    ylim = (-5, 105),
    marker = :square,
)
scatter!(pr, tm/ms, imag(yr),
    label = "Phase corrected data - imag part",
)

Examine the distribution of the real part after phase correction:
function make1_phase_corrected_signal()
    phase_true = rand() * 2π
    y_true_phased = Tc.(cis(phase_true) * y_true)
    yc = y_true_phased + σ * randn(Tc, M)
    return conj(sign(yc[1])) .* yc # phase-corrected signal
end
N = 2000
ysim = stack(_ -> make1_phase_corrected_signal(), 1:N)
tmp = ysim[end,:]; # last echo time, across all N realizations
pe = scatter(real(tmp), imag(tmp), aspect_ratio=1,
    xaxis = ("real(y_$M)", (-4,8), -3:8),
    yaxis = ("imag(y_$M)", (-6,6), -5:5),
)
plot!(pe, real(y_true[end]) * [1,1], [-5, 5])
plot!(pe, [-5, 5] .+ 3, imag(y_true[end]) * [1,1])

Histogram of the real part looks reasonably Gaussian:
(The normalization uses σ/√2 because randn(Tc, M) splits unit variance equally between the real and imaginary parts, so the real part alone has standard deviation σ/√2.)
ph = histogram((real(tmp) .- real(y_true[end])) / (σ/√2), bins=-4:0.1:4,
    xlabel = "Real part of phase-corrected signal y_$M")
roundr(rate) = round(rate, digits=2);

CRB
Compute the CRB for the precision of any unbiased estimator. This requires inverting the Fisher information matrix. If the Fisher information matrix has units, then Julia's built-in matrix inverse `inv` does not work directly. See Section 2.4.5.2 of the Fessler 2024 book for tips.
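For the white Gaussian noise model used here, the Fisher information matrix for $x = (c_a, c_b, r_a, r_b)$ is (matching the `fish = jac' * jac / σ^2` computation below)
\[F = \frac{1}{σ^2} J'J, \qquad J = \frac{∂\,\mathrm{signal}(x)}{∂x}\bigg|_{x_{\mathrm{true}}},\]
and the CRB states that $\mathrm{Cov}(\hat{x}) ⪰ F^{-1}$ for any unbiased estimator $\hat{x}$.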
"""
Matrix inverse for matrix whose units are suitable for inversion,
meaning `X = Diag(left) * Z * Diag(right)`
where `Z` is unitless and `left` and `right` are vectors with units.
(Fisher information matrices always have this structure.)
[irrelevant when units are excluded]
"""
function inv_unitful(X::Matrix{<:Number})
    right = oneunit.(X[1,:]) # units for "right side" of matrix
    left = oneunit.(X[:,1] / right[1]) # units for "left side" of matrix
    N = size(X,1)
    Z = [X[i,j] / (left[i] * right[j]) for i in 1:N, j in 1:N]
    Zinv = inv(Z) # Z should be unitless if X has inverse-appropriate units
    Xinv = [Zinv[i,j] / (right[i] * left[j]) for i in 1:N, j in 1:N]
    return Xinv
end;
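As a quick sanity check (not part of the original demo), `inv_unitful` should reduce to `inv` for a unitless matrix:
Xtest = [4.0 1.0; 1.0 3.0] # arbitrary symmetric positive-definite test matrix
@assert inv_unitful(Xtest) ≈ inv(Xtest)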
signal(x) = signal(x, tm)
@assert y_true == signal(collect(x_true))

Jacobian of the signal w.r.t. both linear and nonlinear parameters:
jac = ForwardDiff.jacobian(signal, collect(x_true));

Fisher information:
fish = jac' * jac / σ^2;

Compute the CRB from the Fisher information via matrix inverse:
crb = inv_unitful(fish)
round3(x) = round(x; digits=3)
crb_std = Tall(round3.(sqrt.(diag(crb)))) # relabel CRB std. deviations
(ca = 5.03f0, cb = 6.177f0, ra = 19.046f0, rb = 2.103f0)

Dictionary matching
This approach is essentially a quantized maximum-likelihood estimator. Here the quantization interval of 0.5/s turns out to be much smaller than the estimator standard deviation, so the quantization error seems negligible.
Simple dot products, as used in single-exponential dictionary matching, do not apply to a two-pool model, so we use VarPro.
The VarPro cost function (for complex coefficients) becomes
\[f(r) = -(A'y)' (A'A)^{-1} A'y,\]
where $A = A(r)$ is an $M × 2$ matrix for each candidate rate pair $r$. This follows from minimizing $‖y - A(r)c‖_2^2$ over $c$ (giving $\hat{c} = (A'A)^{-1} A'y$) and dropping the constant $‖y‖_2^2$ term.
Substituting the QR decomposition $A = QR$, where $Q$ has orthonormal columns, simplifies the cost to $f(r) = -‖Q'y‖_2^2$; minimizing it is equivalent to maximizing $‖Q'y‖_2$, a natural extension of the dot product used in single-component dictionary matching.
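One can verify this simplification numerically; here is a small sketch (the specific rates are arbitrary values, and yc is the noisy data from above):
A = signal_bases((ra = 100f0/s, rb = 20f0/s), tm) # an M × 2 basis matrix
Q = Matrix(qr(A).Q) # thin Q factor with orthonormal columns
f1 = -real((A'yc)' * ((A'A) \ (A'yc))) # VarPro cost as written above
f2 = -norm(Q'yc)^2 # simplified QR form
@assert f1 ≈ f2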
ra_list = Tf.(range(50/s, 160/s, 221)) # linear spacing?
rb_list = Tf.(range(0/s, 40/s, 81)) # linear spacing?
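# Sanity check (added; not in the original script): the grid spacing is far
# smaller than the CRB standard deviations, so the quantization error
# mentioned above is indeed negligible.
@show ra_list[2] - ra_list[1], rb_list[2] - rb_list[1] # 0.5/s grid spacing
@show crb_std.ra, crb_std.rb # ≈ 19/s and 2.1/s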
bases_unnormalized(ra, rb) = signal_bases((;ra, rb), tm)
dict = bases_unnormalized.(ra_list, rb_list');

Plot the fast and slow dictionary components:
tmp = stack(first ∘ eachcol, dict[:,1]) # fast components, varying ra
pd1 = plot(tm/ms, tmp[:,1:5:end]; xaxis=xaxis_t, marker=:o)
tmp = stack(last ∘ eachcol, dict[1,:]) # slow components, varying rb
pd2 = plot(tm/ms, tmp[:,1:5:end]; xaxis=xaxis_t, marker=:o)
pd12 = plot(pd1, pd2, plot_title = "Dictionary")

Orthonormalize each dictionary basis matrix via its (thin) QR factor:
dict_q = map(A -> Matrix(qr(A).Q), dict)
dict_q = map(A -> sign(A[1]) * A, dict_q); # preserve sign of 1st basis

tmp = stack(first ∘ eachcol, dict_q[:,1])
pq1 = plot(tm/ms, tmp[:,1:5:end]; xaxis=xaxis_t, marker=:o)
tmp = stack(last ∘ eachcol, dict_q[1,:])
pq2 = plot(tm/ms, tmp[:,1:5:end]; xaxis=xaxis_t, marker=:o)
pq12 = plot(pq1, pq2, plot_title = "Orthogonalized Dictionary")

Apply dictionary matching via the VarPro cost to the simulated signals:
varpro_cost(Q::Matrix, y::AbstractVector) = -norm(Q'*y)
varpro_best(y) = findmin(Q -> varpro_cost(Q, y), dict_q)[2] # index of best (ra, rb)
if !@isdefined(i_vp) # perform dictionary matching via VarPro
    i_vp = map(varpro_best, eachcol(ysim));
end
ra_dm = map(i -> ra_list[i[1]], i_vp) # dictionary matching estimates
rb_dm = map(i -> rb_list[i[2]], i_vp)
ph_ra_dm = histogram(ra_dm, bins=60:5:160,
    label = "Mean=$(roundr(mean(ra_dm))), σ=$(roundr(std(ra_dm)))",
    xaxis = ("Ra estimate via dictionary matching", (60, 160)./s),
)
plot!(r_true.ra*[1,1], [0, 3e2])
plot!(annotation = (140, 200, "CRB(Ra) = $(roundr(crb_std.ra))", :red))

ph_rb_dm = histogram(rb_dm, bins=14:0.5:26,
label = "Mean=$(roundr(mean(rb_dm))), σ=$(roundr(std(rb_dm)))",
xaxis = ("Rb estimate via dictionary matching", (14, 26)./s),
)
plot!(r_true.rb*[1,1], [0, 3e2])
plot!(annotation = (24, 200, "CRB = $(roundr(crb_std.rb))", :red))Future work
- Compare to ML estimation via VarPro
- Compare to ML estimation via NLLS (nonlinear least squares)
- Cost contours, before and after eliminating the linear coefficients
- MM (majorize-minimize) approach?
- Gradient descent (GD)?
- Newton's method?
- Units?
Reproducibility
This page was generated with the following version of Julia:
using InteractiveUtils: versioninfo
io = IOBuffer(); versioninfo(io); split(String(take!(io)), '\n')
12-element Vector{SubString{String}}:
"Julia Version 1.12.4"
"Commit 01a2eadb047 (2026-01-06 16:56 UTC)"
"Build Info:"
" Official https://julialang.org release"
"Platform Info:"
" OS: Linux (x86_64-linux-gnu)"
" CPU: 4 × AMD EPYC 7763 64-Core Processor"
" WORD_SIZE: 64"
" LLVM: libLLVM-18.1.7 (ORCJIT, znver3)"
" GC: Built with stock GC"
"Threads: 1 default, 1 interactive, 1 GC (on 4 virtual cores)"
""And with the following package versions
import Pkg; Pkg.status()
Status `~/work/Examples/Examples/docs/Project.toml`
[e30172f5] Documenter v1.16.1
[940e8692] Examples v0.0.1 `~/work/Examples/Examples`
[7a1cc6ca] FFTW v1.10.0
[f6369f11] ForwardDiff v1.3.2
[9ee76f2b] ImageGeoms v0.11.2
[787d08f9] ImageMorphology v0.4.7
[71a99df6] ImagePhantoms v0.8.1
[b964fa9f] LaTeXStrings v1.4.0
[7031d0ef] LazyGrids v1.1.0
[599c1a8e] LinearMapsAA v0.12.0
[98b081ad] Literate v2.21.0
[23992714] MAT v0.11.4
[7035ae7a] MIRT v0.18.3
[170b2178] MIRTjim v0.26.0
[efe261a4] NFFT v0.14.3
[91a5bcdd] Plots v1.41.4
[10745b16] Statistics v1.11.1
[2913bbd2] StatsBase v0.34.10
[1986cc42] Unitful v1.28.0
[b77e0a4c] InteractiveUtils v1.11.0
[37e2e46d] LinearAlgebra v1.12.0
[44cfe95a] Pkg v1.12.1
[9a3f8284] Random v1.11.0

This page was generated using Literate.jl.