Output RMSE map and time series for decay model fit (#1044)
* Draft function to calculate decay model fit.

* Calculate root mean squared error instead.

* Incorporate metrics.

* Output RMSE results.

* Output results in tedana.

* Hopefully fix things.

* Update decay.py

* Try improving performance.

* Update decay.py

* Fix again.

* Use tqdm.

* Update decay.py

* Update decay.py

* Update decay.py

* Update expected outputs.

* Add figures.

* Update outputs.

* Include global signal in confounds file.

* Update fiu_four_echo_outputs.txt

* Rename function.

* Rename function.

* Update tedana.py

* Update tedana/decay.py

Co-authored-by: Dan Handwerker <[email protected]>

* Update decay.py

* Update decay.py

* Whoops.

* Apply suggestions from code review

Co-authored-by: Dan Handwerker <[email protected]>

* Fix things maybe.

* Fix things.

* Update decay.py

* Remove any files that are built through appending.

* Update outputs.

* Add section on plots to docs.

* Fix the description.

* Update docs/outputs.rst

Co-authored-by: Dan Handwerker <[email protected]>

* Update docs/outputs.rst

* Fix docstring.

---------

Co-authored-by: Dan Handwerker <[email protected]>
tsalo and handwerkerd authored Apr 29, 2024
1 parent af5e99a commit 0f6cbe1
Showing 16 changed files with 324 additions and 18 deletions.
Binary file added docs/_static/rmse_plots.png
31 changes: 29 additions & 2 deletions docs/outputs.rst
@@ -115,6 +115,8 @@ report.txt A s
"high kappa ts img": desc-optcomAccepted_bold.nii.gz High-kappa time series. This dataset does not
include thermal noise or low variance components.
Not the recommended dataset for analysis.
"confounds tsv": desc-confounds_timeseries.tsv Summary time series measures, including RMSE measures
of T2*/S0 model fit.
references.bib The BibTeX entries for references cited in
report.txt.

@@ -167,8 +169,8 @@ If ``gscontrol`` includes 'gsr'
Key: Filename Content
================================================================= =====================================================
"gs img": desc-globalSignal_map.nii.gz Spatial global signal
"global signal time series tsv": desc-globalSignal_timeseries.tsv Time series of global signal from optimally combined
data.
"confounds tsv": desc-confounds_timeseries.tsv Time series of global signal from optimally combined
data will be added to this file.
"has gs combined img": desc-optcomWithGlobalSignal_bold.nii.gz Optimally combined time series with global signal
retained.
"removed gs combined img": desc-optcomNoGlobalSignal_bold.nii.gz Optimally combined time series with global signal
@@ -563,6 +565,31 @@ It is important to note that the histogram is limited from 0 to the 98th percent
:height: 400px


*********************
Decay Model Fit Plots
*********************

Below the T2* and S0 summary plots are the decay model fit plots.
These plots show root mean squared error (RMSE) values for the
monoexponential decay model, based on the T2* and S0 maps.

The first plot is the mean RMSE brain plot, which shows the mean RMSE over time for each voxel in the brain.
The color scale of this plot is limited to the range from the 2nd percentile to the 98th percentile.

The second plot is a time series of RMSE values summarized across voxels in the brain.
This plot includes the median RMSE time series,
along with an error band spanning the 25th to 75th percentiles,
and dashed lines indicating the 2nd and 98th percentile RMSE values.

The fit quality will vary depending on acquisition parameters and will likely be worse near signal drop-out areas.
For a study with consistent acquisition parameters,
relatively high RMSE values for runs or timepoints might be a marker of an underlying data quality issue.

.. image:: /_static/rmse_plots.png
:align: center
:height: 400px
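
As a rough illustration of what the two plots summarize, the sketch below computes RMSE across echoes for a monoexponential model, S(TE) = S0 * exp(-TE / T2*), and then reduces it to a per-voxel map and a per-volume time series. The helper ``monoexp``, the echo times, and the simulated values are all hypothetical and are not taken from tedana.

```python
import numpy as np


def monoexp(te, s0, t2star):
    """Monoexponential decay (hypothetical helper): S(TE) = S0 * exp(-TE / T2*)."""
    return s0 * np.exp(-te / t2star)


# Assumed values for illustration only: 3 echo times (ms) and 3 voxels.
tes = np.array([15.0, 39.0, 63.0])
s0 = np.array([1000.0, 800.0, 1200.0])
t2s = np.array([30.0, 45.0, 25.0])

# Model prediction per voxel and echo, shape (voxels, echoes).
model = monoexp(tes[None, :], s0[:, None], t2s[:, None])

# Simulated data, shape (voxels, echoes, volumes): model plus Gaussian noise.
rng = np.random.default_rng(0)
n_vols = 5
data = model[:, :, None] + rng.normal(scale=5.0, size=(3, 3, n_vols))

# RMSE across echoes, shape (voxels, volumes).
rmse = np.sqrt(np.mean((data - model[:, :, None]) ** 2, axis=1))

rmse_map = rmse.mean(axis=1)  # mean over time for each voxel (brain plot)
rmse_ts = np.median(rmse, axis=0)  # median across voxels per volume (time series)
```

The brain plot corresponds to ``rmse_map`` and the central line of the time series plot corresponds to ``rmse_ts``.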


**************************
Citable workflow summaries
**************************
1 change: 1 addition & 0 deletions pyproject.toml
@@ -33,6 +33,7 @@ dependencies = [
"scikit-learn>=0.21, <=1.4.2",
"scipy>=1.2.0, <=1.13.0",
"threadpoolctl",
"tqdm",
]
dynamic = ["version"]

109 changes: 107 additions & 2 deletions tedana/decay.py
@@ -1,10 +1,14 @@
"""Functions to estimate S0 and T2* from multi-echo data."""

import logging
from typing import List, Literal, Tuple

import numpy as np
import numpy.matlib
import pandas as pd
import scipy
from scipy import stats
from tqdm.auto import tqdm

from tedana import utils

@@ -112,7 +116,7 @@ def fit_monoexponential(data_cat, echo_times, adaptive_mask, report=True):
"estimate T2* and S0. In cases of model fit failure, T2*/S0 "
"estimates from the log-linear fit were retained instead."
)
n_samp, n_echos, n_vols = data_cat.shape
n_samp, _, n_vols = data_cat.shape

# Currently unused
# fit_data = np.mean(data_cat, axis=2)
@@ -151,7 +155,7 @@ def fit_monoexponential(data_cat, echo_times, adaptive_mask, report=True):
# perform a monoexponential fit of echo times against MR signal
# using loglin estimates as initial starting points for fit
fail_count = 0
for voxel in voxel_idx:
for voxel in tqdm(voxel_idx, desc=f"{echo_num}-echo monoexponential"):
try:
popt, cov = scipy.optimize.curve_fit(
monoexponential,
@@ -460,3 +464,104 @@ def fit_decay_ts(data, tes, mask, adaptive_mask, fittype):
report = False

return t2s_limited_ts, s0_limited_ts, t2s_full_ts, s0_full_ts


def rmse_of_fit_decay_ts(
    *,
    data: np.ndarray,
    tes: List[float],
    adaptive_mask: np.ndarray,
    t2s: np.ndarray,
    s0: np.ndarray,
    fitmode: Literal["all", "ts"],
) -> Tuple[np.ndarray, pd.DataFrame]:
    """Estimate model fit of voxel- and timepoint-wise monoexponential decay models to ``data``.

    Parameters
    ----------
    data : (S x E x T) :obj:`numpy.ndarray`
        Multi-echo data array, where `S` is samples, `E` is echos, and `T` is time.
    tes : (E,) :obj:`list`
        Echo times.
    adaptive_mask : (S,) :obj:`numpy.ndarray`
        Array where each value indicates the number of echoes with good signal for that voxel.
        This mask may be thresholded; for example, with values less than 3 set to 0.
        For more information on thresholding, see :func:`~tedana.utils.make_adaptive_mask`.
    t2s : (S [x T]) :obj:`numpy.ndarray`
        Voxel-wise (and possibly volume-wise) T2* estimates from
        :func:`~tedana.decay.fit_decay_ts`.
    s0 : (S [x T]) :obj:`numpy.ndarray`
        Voxel-wise (and possibly volume-wise) S0 estimates from :func:`~tedana.decay.fit_decay_ts`.
    fitmode : {"all", "ts"}
        Whether the T2* and S0 estimates are volume-wise ("ts") or not ("all").

    Returns
    -------
    rmse_map : (S,) :obj:`numpy.ndarray`
        Mean root mean squared error of the model fit across all volumes at each voxel.
    rmse_df : :obj:`pandas.DataFrame`
        Each column is the root mean squared error of the model fit at each timepoint.
        Columns are mean, standard deviation, and percentiles across voxels. Column labels are
        "rmse_mean", "rmse_std", "rmse_min", "rmse_percentile02", "rmse_percentile25",
        "rmse_median", "rmse_percentile75", "rmse_percentile98", and "rmse_max".
    """
    n_samples, _, n_vols = data.shape
    tes = np.array(tes)

    rmse = np.full([n_samples, n_vols], np.nan, dtype=np.float32)
    # n_good_echoes iterates from 2 through the number of echoes.
    # 0 and 1 are excluded because there aren't T2* and S0 estimates
    # for fewer than 2 good echoes. 2 echoes will have a bad estimate so consider
    # how/if we want to distinguish those
    for n_good_echoes in range(2, len(tes) + 1):
        # a boolean mask for voxels with a specific num of good echoes
        use_vox = adaptive_mask == n_good_echoes
        data_echo = data[use_vox, :n_good_echoes, :]
        if fitmode == "all":
            # Voxel-wise estimates: repeat each value across volumes -> (voxels x volumes)
            s0_echo = numpy.matlib.repmat(s0[use_vox].T, n_vols, 1).T
            t2s_echo = numpy.matlib.repmat(t2s[use_vox], n_vols, 1).T
        elif fitmode == "ts":
            # Volume-wise estimates are already (voxels x volumes)
            s0_echo = s0[use_vox, :]
            t2s_echo = t2s[use_vox, :]
        else:
            raise ValueError(f"Unknown fitmode option {fitmode}")

        predicted_data = np.full([use_vox.sum(), n_good_echoes, n_vols], np.nan, dtype=np.float32)
        # Need to loop by echo since monoexponential can take either single vals for s0 and t2star
        # or a single TE value.
        # We could expand that func, but this is a functional solution
        for echo_num in range(n_good_echoes):
            predicted_data[:, echo_num, :] = monoexponential(
                tes=tes[echo_num],
                s0=s0_echo,
                t2star=t2s_echo,
            )
        # RMSE across echoes (axis 1) for each voxel and volume
        rmse[use_vox, :] = np.sqrt(np.mean((data_echo - predicted_data) ** 2, axis=1))

    rmse_map = np.nanmean(rmse, axis=1)
    rmse_timeseries = np.nanmean(rmse, axis=0)
    rmse_sd_timeseries = np.nanstd(rmse, axis=0)
    rmse_percentiles_timeseries = np.nanpercentile(rmse, [0, 2, 25, 50, 75, 98, 100], axis=0)

    rmse_df = pd.DataFrame(
        columns=[
            "rmse_mean",
            "rmse_std",
            "rmse_min",
            "rmse_percentile02",
            "rmse_percentile25",
            "rmse_median",
            "rmse_percentile75",
            "rmse_percentile98",
            "rmse_max",
        ],
        data=np.column_stack(
            (
                rmse_timeseries,
                rmse_sd_timeseries,
                rmse_percentiles_timeseries.T,
            )
        ),
    )

    return rmse_map, rmse_df
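
The per-echo loop in the function above could, in principle, be replaced by NumPy broadcasting. The sketch below illustrates that idea for the voxel-wise ("all") case only. The function name ``rmse_by_good_echoes`` and the toy inputs are hypothetical, and the decay formula S0 * exp(-TE / T2*) is written inline rather than calling tedana's ``monoexponential``; this is not the implementation in the commit.

```python
import numpy as np


def rmse_by_good_echoes(data, tes, adaptive_mask, t2s, s0):
    """Hypothetical sketch: RMSE across echoes for voxel-wise T2*/S0 estimates.

    Voxels with fewer than 2 good echoes keep NaN, as in the looped version.
    """
    tes = np.asarray(tes, dtype=float)
    n_samples, n_echoes, n_vols = data.shape
    rmse = np.full((n_samples, n_vols), np.nan)
    for n_good in range(2, n_echoes + 1):
        use_vox = adaptive_mask == n_good
        if not use_vox.any():
            continue
        # Broadcast (V, 1) parameters against (1, n_good) echo times -> (V, n_good)
        pred = s0[use_vox, None] * np.exp(-tes[None, :n_good] / t2s[use_vox, None])
        # Same prediction applies to every volume: (V, n_good, 1) vs (V, n_good, T)
        resid = data[use_vox, :n_good, :] - pred[:, :, None]
        rmse[use_vox, :] = np.sqrt(np.mean(resid**2, axis=1))
    return rmse


# Toy check: noise-free data should give zero RMSE wherever a fit exists.
tes = [15.0, 39.0, 63.0]
s0 = np.array([1000.0, 900.0, 800.0])
t2s = np.array([30.0, 40.0, 50.0])
adaptive_mask = np.array([3, 2, 1])  # the last voxel has too few good echoes
clean = s0[:, None] * np.exp(-np.array(tes)[None, :] / t2s[:, None])
data = np.repeat(clean[:, :, None], 4, axis=2)
rmse = rmse_by_good_echoes(data, tes, adaptive_mask, t2s, s0)
```

Only the first ``n_good`` echoes contribute for each voxel, so a voxel with two good echoes is evaluated on two data points per volume.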
2 changes: 1 addition & 1 deletion tedana/gscontrol.py
@@ -86,7 +86,7 @@ def gscontrol_raw(catd, optcom, n_echos, io_generator, dtrank=4):
glsig = stats.zscore(glsig, axis=None)

glsig_df = pd.DataFrame(data=glsig.T, columns=["global_signal"])
io_generator.save_file(glsig_df, "global signal time series tsv")
io_generator.add_df_to_file(glsig_df, "confounds tsv")
glbase = np.hstack([legendre_arr, glsig.T])

# Project global signal out of optimally combined data
33 changes: 33 additions & 0 deletions tedana/io.py
@@ -155,6 +155,14 @@ def __init__(
LGR.info(f"Generating figures directory: {self.figures_dir}")
os.mkdir(self.figures_dir)

# Remove files that are appended to instead of overwritten.
if overwrite:
files_to_remove = ["confounds tsv"]
for file_ in files_to_remove:
filepath = self.get_name(file_)
if op.exists(filepath):
os.remove(filepath)
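
The reason for this cleanup is that a file built by appending columns would otherwise accumulate stale columns across reruns. A minimal standalone sketch of the same pattern follows; the function ``remove_appended_outputs`` and the file contents are hypothetical and use only the standard library, not tedana's ``OutputGenerator``.

```python
import os
import tempfile


def remove_appended_outputs(out_dir, filenames, overwrite):
    """Hypothetical helper: delete files that a workflow would otherwise append to."""
    removed = []
    if overwrite:
        for filename in filenames:
            path = os.path.join(out_dir, filename)
            if os.path.exists(path):
                os.remove(path)
                removed.append(path)
    return removed


with tempfile.TemporaryDirectory() as tmp:
    # Simulate a leftover confounds file from a previous run.
    stale = os.path.join(tmp, "desc-confounds_timeseries.tsv")
    with open(stale, "w") as f:
        f.write("global_signal\n0.1\n")

    removed = remove_appended_outputs(
        tmp, ["desc-confounds_timeseries.tsv"], overwrite=True
    )
    still_there = os.path.exists(stale)
```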

def _determine_extension(self, description, name):
"""Infer the extension for a file based on its description.
@@ -346,6 +354,31 @@ def save_tsv(self, data, name):
deblanked = data.replace("", np.nan)
deblanked.to_csv(name, sep="\t", lineterminator="\n", na_rep="n/a", index=False)

    def add_df_to_file(self, data, description, **kwargs):
        """Add a DataFrame to a tsv file, which may or may not exist.

        Parameters
        ----------
        data : pandas.DataFrame
            Data to save to file. If the file already exists,
            the columns of ``data`` are appended to the existing columns.
        description : str
            Description of the data, used to determine the appropriate filename from
            ``self.config``.

        Returns
        -------
        name : str
            The full file path of the saved file.
        """
        name = self.get_name(description, **kwargs)
        if op.isfile(name):
            # Append the new columns to the right of the existing ones
            old_data = pd.read_table(name)
            data = pd.concat([old_data, data], axis=1, ignore_index=False)

        self.save_tsv(data, name)

        return name
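
The append-or-create behavior can be tried outside of tedana with plain pandas. The sketch below is a standalone approximation: ``append_columns`` is a hypothetical function, and the column names and values are made up for illustration.

```python
import os
import tempfile

import pandas as pd


def append_columns(df, path):
    """Hypothetical helper: write df to a TSV, appending columns if the file exists."""
    if os.path.isfile(path):
        old = pd.read_table(path)
        # axis=1 concatenates column-wise and aligns rows on the index.
        df = pd.concat([old, df], axis=1)
    df.to_csv(path, sep="\t", index=False, na_rep="n/a")
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "confounds.tsv")
    append_columns(pd.DataFrame({"global_signal": [0.1, -0.2, 0.3]}), path)
    append_columns(pd.DataFrame({"rmse_median": [1.5, 1.7, 1.6]}), path)
    combined = pd.read_table(path)
```

Because the concatenation aligns on the row index, DataFrames with different numbers of rows would be padded with missing values rather than rejected, so callers should pass one row per volume each time.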

def save_self(self):
"""Save the registry to a json file.
7 changes: 7 additions & 0 deletions tedana/reporting/data/html/report_body_template.html
@@ -200,6 +200,13 @@ <h2>S0</h2>
<div class="carpet-plots-image">
<img id="s0Histogram" src="$s0Histogram" style="height: 500px" />
</div>
<h2>T2* and S0 model fit (RMSE). (Scaled between 2nd and 98th percentiles)</h2>
<div class="carpet-plots-image">
<img id="rmseBrainPlot" src="$rmseBrainPlot" style="height:500px;" />
</div>
<div class="carpet-plots-image">
<img id="rmseTimeseries" src="$rmseTimeseries" style="height:500px;" />
</div>
</div>
</div>
<div class="info">
5 changes: 5 additions & 0 deletions tedana/reporting/html_report.py
@@ -151,6 +151,8 @@ def _update_template_bokeh(bokeh_id, info_table, about, prefix, references, boke
t2star_histogram = f"./figures/{prefix}t2star_histogram.svg"
s0_brain = f"./figures/{prefix}s0_brain.svg"
s0_histogram = f"./figures/{prefix}s0_histogram.svg"
rmse_brain = f"./figures/{prefix}rmse_brain.svg"
rmse_timeseries = f"./figures/{prefix}rmse_timeseries.svg"

# Convert bibtex to html
references, bibliography = _bib2html(references)
@@ -162,6 +164,7 @@ def _update_template_bokeh(bokeh_id, info_table, about, prefix, references, boke
body_template_path = resource_path.joinpath(body_template_name)
with open(str(body_template_path)) as body_file:
body_tpl = Template(body_file.read())

body = body_tpl.substitute(
content=bokeh_id,
info=info_table,
@@ -173,6 +176,8 @@ def _update_template_bokeh(bokeh_id, info_table, about, prefix, references, boke
t2starHistogram=t2star_histogram,
s0BrainPlot=s0_brain,
s0Histogram=s0_histogram,
rmseBrainPlot=rmse_brain,
rmseTimeseries=rmse_timeseries,
references=references,
javascript=bokeh_js,
buttons=buttons,
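
The HTML template and ``html_report.py`` have to change together because ``string.Template.substitute`` raises ``KeyError`` for any placeholder that is not supplied. The sketch below shows this with a two-line template modeled on the new ``$rmseBrainPlot`` and ``$rmseTimeseries`` placeholders. The prefix ``sub-01_`` is a made-up value.

```python
from string import Template

# Minimal template modeled on the two placeholders added to the report body.
body_tpl = Template(
    '<img id="rmseBrainPlot" src="$rmseBrainPlot" />\n'
    '<img id="rmseTimeseries" src="$rmseTimeseries" />'
)

prefix = "sub-01_"  # hypothetical output prefix
body = body_tpl.substitute(
    rmseBrainPlot=f"./figures/{prefix}rmse_brain.svg",
    rmseTimeseries=f"./figures/{prefix}rmse_timeseries.svg",
)

# Omitting a placeholder fails loudly instead of leaving "$rmseTimeseries" in the HTML.
try:
    body_tpl.substitute(rmseBrainPlot="./figures/rmse_brain.svg")
    missing_key = None
except KeyError as exc:
    missing_key = exc.args[0]
```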
86 changes: 86 additions & 0 deletions tedana/reporting/static_figures.py
@@ -618,6 +618,92 @@ def plot_t2star_and_s0(
)


def plot_rmse(
    *,
    io_generator: io.OutputGenerator,
    adaptive_mask: np.ndarray,
):
    """Plot the root mean squared error map and time series for the monoexponential model fit.

    Parameters
    ----------
    io_generator : :obj:`~tedana.io.OutputGenerator`
        The output generator for this workflow.
    adaptive_mask : (S,) :obj:`numpy.ndarray`
        A mask where each value is the number of good echoes.
        Since the T2* and S0 estimations require a minimum of 2 good echoes,
        the outputted plots will only include mask values of at least 2.
    """
    import pandas as pd

    rmse_img = io_generator.get_name("rmse img")
    confounds_file = io_generator.get_name("confounds tsv")
    # Mask that only includes values >=2 (i.e. at least 2 good echoes)
    mask_img = io.new_nii_like(io_generator.reference_img, (adaptive_mask >= 2).astype(np.int32))

    rmse_data = masking.apply_mask(rmse_img, mask_img)
    rmse_p02, rmse_p98 = np.percentile(rmse_data, [2, 98])

    # Get repetition time from reference image
    tr = io_generator.reference_img.header.get_zooms()[-1]

    # Load the confounds file
    confounds_df = pd.read_table(confounds_file)

    fig, ax = plt.subplots(figsize=(10, 6))
    rmse_arr = confounds_df["rmse_median"].values
    p25_arr = confounds_df["rmse_percentile25"].values
    p75_arr = confounds_df["rmse_percentile75"].values
    p02_arr = confounds_df["rmse_percentile02"].values
    p98_arr = confounds_df["rmse_percentile98"].values
    time_arr = np.arange(confounds_df.shape[0]) * tr
    ax.plot(time_arr, rmse_arr, color="black")
    ax.fill_between(
        time_arr,
        p25_arr,
        p75_arr,
        color="blue",
        alpha=0.2,
    )
    ax.plot(time_arr, p02_arr, color="black", linestyle="dashed")
    ax.plot(time_arr, p98_arr, color="black", linestyle="dashed")
    ax.set_ylabel("RMSE", fontsize=16)
    ax.set_xlabel(
        "Time (s)",
        fontsize=16,
    )
    ax.legend(["Median", "25th-75th percentiles", "2nd and 98th percentiles"])
    ax.set_title("Root mean squared error of T2* and S0 fit across voxels", fontsize=20)
    rmse_ts_plot = os.path.join(
        io_generator.out_dir,
        "figures",
        f"{io_generator.prefix}rmse_timeseries.svg",
    )
    # Limit the x-axis to the sampled time range, including the last volume
    ax.set_xlim(0, time_arr[-1])
    fig.savefig(rmse_ts_plot)
    plt.close(fig)

    # Plot RMSE
    rmse_brain_plot = os.path.join(
        io_generator.out_dir,
        "figures",
        f"{io_generator.prefix}rmse_brain.svg",
    )
    plotting.plot_stat_map(
        rmse_img,
        bg_img=None,
        display_mode="mosaic",
        cut_coords=4,
        symmetric_cbar=False,
        black_bg=True,
        cmap="Reds",
        vmin=rmse_p02,
        vmax=rmse_p98,
        annotate=False,
        output_file=rmse_brain_plot,
    )


def plot_adaptive_mask(
*,
optcom: np.ndarray,