"""Decode module."""
import os.path as op
from abc import ABCMeta, abstractmethod
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from nilearn import masking
from nimare import decode
from nimare.annotate.gclda import GCLDAModel
from nimare.decode.utils import weight_priors
from nimare.extract import download_abstracts, fetch_neuroquery, fetch_neurosynth
from nimare.io import convert_neurosynth_to_dataset
from nimare.meta.cbma.mkda import MKDAChi2
from tqdm.auto import tqdm
from gradec.fetcher import (
_fetch_features,
_fetch_metamaps,
_fetch_model,
_fetch_spinsamples,
)
from gradec.model import _get_counts, annotate_lda
from gradec.stats import _permtest_pearson
from gradec.transform import _vol_to_surf
from gradec.utils import _check_ncores, _conform_features, get_data_dir
def _get_dataset(dset_nm, data_dir):
data_dir = get_data_dir(op.join(data_dir, "data", dset_nm))
if dset_nm == "neurosynth":
files = fetch_neurosynth(
data_dir=data_dir,
version="7",
overwrite=False,
source="abstract",
vocab="terms",
)
elif dset_nm == "neuroquery":
files = fetch_neuroquery(
data_dir=data_dir,
version="1",
overwrite=False,
source="combined",
vocab="neuroquery6308",
type="tfidf",
)
dataset_db = files[0]
dset = convert_neurosynth_to_dataset(
coordinates_file=dataset_db["coordinates"],
metadata_file=dataset_db["metadata"],
annotations_files=dataset_db["features"],
)
dset.update_path(data_dir)
class Decoder(metaclass=ABCMeta):
"""Base class for decoders in :mod:`~gradec.decode`."""
def __init__(
self,
space="fsLR",
density="32k",
feature_group=None,
use_fetchers=True,
calc_pvals=True,
data_dir=None,
n_cores=1,
):
self.space = space
self.density = density
self.feature_group = feature_group
self.use_fetchers = use_fetchers
self.calc_pvals = calc_pvals
self.data_dir = op.abspath(data_dir) if data_dir else op.abspath(".")
self.neuromaps_dir = get_data_dir(op.join(self.data_dir, "neuromaps"))
self.n_cores = _check_ncores(n_cores)
self.n_top_words = None # Set in child classes
@abstractmethod
def transform(self, dset_nm):
"""Apply decoding to dataset and output a DF."""
@abstractmethod
def _train_decoder(self, dset_nm):
"""Train decoder on dataset."""
@abstractmethod
def _get_features(self, dset_nm):
"""Get features from dataset."""
def fit(self, dset_nm, dset=None):
"""Fit decoder to dataset."""
self.dset_nm = dset_nm
if self.use_fetchers:
metamaps_surf = _fetch_metamaps(
self.dset_nm,
self.model_nm,
space=self.space,
density=self.density,
data_dir=self.data_dir,
)
features = _fetch_features(self.dset_nm, self.model_nm, data_dir=self.data_dir)
else:
self.dset = dset or _get_dataset(dset_nm, self.data_dir)
metamaps = self._train_decoder()
metamaps_surf = _vol_to_surf(
metamaps,
space=self.space,
density=self.density,
neuromaps_dir=self.neuromaps_dir,
)
features = self._get_features()
# This feature is desabled for now because it takes too long.
"""
if self.calc_pvals:
spinsamples_surf = _gen_spinsamples(
self.space,
self.density,
self.neuromaps_dir,
self.n_samples,
self.n_cores,
)
"""
if self.calc_pvals:
spinsamples_surf = _fetch_spinsamples(
n_samples=self.n_samples,
space=self.space,
density=self.density,
data_dir=self.data_dir,
)
self.spinsamples_ = spinsamples_surf
else:
self.spinsamples_ = None
self.maps_ = metamaps_surf
self.features_ = _conform_features(features, self.model_nm, self.n_top_words)
class CorrelationDecoder(Decoder):
"""Decode an unthresholded image by correlating the image with meta-analytic maps."""
def __init__(self, n_samples=1000, **kwargs):
super().__init__(**kwargs)
self.n_samples = n_samples
def transform(self, grad_maps):
"""Correlate target image with each feature-specific meta-analytic map.
Parameters
----------
grad_maps : :obj:`~nibabel.nifti1.Nifti1Image`
Image to decode. Must be in same space as ``dataset``.
Returns
-------
out_df : :obj:`pandas.DataFrame`
DataFrame with one row for each feature, an index named "feature", and one column: "r".
"""
results_lst = tqdm(
zip(
*Parallel(return_as="generator", n_jobs=self.n_cores)(
delayed(_permtest_pearson)(grad_map, self.maps_, self.spinsamples_)
for grad_map in grad_maps
),
),
total=len(grad_maps),
)
data_df_lst = []
for results in results_lst:
data = np.array(results).T
data_df = pd.DataFrame(index=self.features_, data=data)
data_df.index.name = "feature"
if self.calc_pvals:
data_df_lst.append(data_df)
else:
return data_df
return data_df_lst
[docs]
class TermDecoder(CorrelationDecoder):
"""Decode an unthresholded image by correlating the image with meta-analytic maps."""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.model_nm = "term"
def _get_features(self):
decoder_features = list(self.decoder.results_.maps.keys())
return [f.split("__")[-1] for f in decoder_features]
def _train_decoder(self):
frequency_threshold = 0.001
decoder = decode.CorrelationDecoder(
frequency_threshold=frequency_threshold,
meta_estimator=MKDAChi2,
feature_group=self.feature_group,
target_image="z_desc-association",
n_cores=self.n_cores,
)
decoder.fit(self.dset)
self.decoder = decoder
metamaps_arr = np.array(list(decoder.results_.maps.values()))
return decoder.results_.masker.inverse_transform(metamaps_arr)
[docs]
class LDADecoder(CorrelationDecoder):
"""Decode an unthresholded image by correlating the image with meta-analytic maps."""
def __init__(self, n_topics=200, n_top_words=3, **kwargs):
super().__init__(**kwargs)
self.n_topics = n_topics
self.n_top_words = n_top_words
self.model_nm = "lda"
def _get_features(self):
topic_word_weights = self.model.distributions_["p_topic_g_word"]
sorted_weights_idxs = np.argsort(-topic_word_weights, axis=1)
vocabulary = self.model.distributions_["p_topic_g_word_df"].columns.to_numpy()
n_topics = topic_word_weights.shape[0]
features = [vocabulary[sorted_weights_idxs[topic_i, :]] for topic_i in range(n_topics)]
# Get only the topics that were used to trained the decoder (those that
# have at least one study over the 0.05 threshold)
decoder_features = list(self.decoder.results_.maps.keys())
feat_used_ids = [int(f.split("__")[-1].split("_")[0]) - 1 for f in decoder_features]
return [features[i] for i in feat_used_ids]
def _train_decoder(self):
if ("abstract" not in self.dset.texts) and (self.dset_nm == "neurosynth"):
# Download abstracts only for Neurosynth
self.dset = download_abstracts(self.dset, "jpera054@fiu.edu")
self.dset, self.model = annotate_lda(
self.dset,
self.dset_nm,
self.feature_group,
n_topics=self.n_topics,
n_cores=self.n_cores,
)
frequency_threshold = 0.05
new_feature_group = f"LDA{self.n_topics}"
decoder = decode.CorrelationDecoder(
frequency_threshold=frequency_threshold,
meta_estimator=MKDAChi2,
feature_group=new_feature_group,
target_image="z_desc-association",
n_cores=self.n_cores,
)
decoder.fit(self.dset)
self.decoder = decoder
metamaps_arr = np.array(list(decoder.results_.maps.values()))
return decoder.results_.masker.inverse_transform(metamaps_arr)
[docs]
class GCLDADecoder(CorrelationDecoder):
"""Decode an unthresholded image by correlating the image with meta-analytic maps."""
def __init__(self, n_iters=1000, n_topics=200, n_top_words=3, **kwargs):
super().__init__(**kwargs)
self.n_iters = n_iters
self.n_topics = n_topics
self.n_top_words = n_top_words
self.model_nm = "gclda"
def _get_features(self):
topic_word_weights = self.model.p_word_g_topic_.T
sorted_weights_idxs = np.argsort(-topic_word_weights, axis=1)
vocabulary = np.array(self.model.vocabulary)
n_topics = topic_word_weights.shape[0]
return [vocabulary[sorted_weights_idxs[topic_i, :]] for topic_i in range(n_topics)]
def _train_decoder(self):
counts_df = _get_counts(self.dset, self.dset_nm, self.feature_group)
gclda_model = GCLDAModel(
counts_df,
self.dset.coordinates,
mask=self.dset.masker.mask_img,
n_topics=self.n_topics,
n_regions=4,
symmetric=True,
)
gclda_model.fit(n_iters=self.n_iters, loglikely_freq=10)
self.model = gclda_model
metamaps_arr = gclda_model.p_topic_g_voxel_.T
return masking.unmask(metamaps_arr, gclda_model.mask)