256 lines
10 KiB
Python
256 lines
10 KiB
Python
import math
|
|
import mne
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
from mne.decoding import SlidingEstimator, cross_val_multiscore, Vectorizer, Scaler
|
|
from mne.time_frequency import tfr_morlet
|
|
|
|
from sklearn.linear_model import LogisticRegression
|
|
from sklearn.pipeline import make_pipeline
|
|
|
|
from utils.file_utils import load_preprocessed_data, get_epochs
|
|
from utils.plot_utils import plot_tf_cluster, plot_oscillation_bands
|
|
|
|
# Log level handed to mne.set_log_level() when this file is run as a script.
VERBOSE_LEVEL = 'CRITICAL'
|
|
|
|
|
|
def events_to_labels(evts, events_dict, mask=None):  # TODO: add unit tests
    """
    Converts the event codes of epochs to class labels for classification.

    Each event code is mapped back to its key in ``events_dict``. The number after the
    first ':' in that key selects the stimulus category:

    * below 41: 'face intact'
    * 41-80:    'car intact'
    * 101-140:  'face scrambled'
    * 141-180:  'car scrambled'

    :param evts: the event codes to be converted (e.g. ``epochs.events[:, 2]``); not modified
    :param events_dict: a dictionary mapping event keys of the form '<name>:<number>' to event codes
    :param mask: an optional label mask with 4 entries, where:
        1. entry: 'face intact', 2. entry: 'car intact', 3. entry: 'face scrambled', 4. entry: 'car scrambled'
        If None the entries are [0,1,0,1] i.e. all faces are in class 0 and all cars are in class 1
    :return: a copy of ``evts`` in which every event code is replaced by its class label
    :raises ValueError: if an event code is not a value of ``events_dict``, or if its key's number
        lies outside the four stimulus ranges (leaving it unmapped would add a spurious class)
    """
    if mask is None:
        mask = [0, 1, 0, 1]

    # Invert the dictionary once instead of searching its values for every event.
    # setdefault keeps the first key for a duplicated code, matching list.index semantics.
    code_to_key = {}
    for key, code in events_dict.items():
        code_to_key.setdefault(code, key)

    events = evts.copy()
    for i, code in enumerate(evts):
        try:
            key = code_to_key[code]
        except KeyError:
            raise ValueError('event code {!r} is not a value of events_dict'.format(code)) from None

        k = int(key.split(':')[1])
        if k < 41:
            events[i] = mask[0]  # Face intact
        elif k < 81:
            events[i] = mask[1]  # Car intact
        elif 100 < k < 141:
            events[i] = mask[2]  # Face scrambled
        elif 140 < k < 181:
            events[i] = mask[3]  # Car scrambled
        else:
            raise ValueError('event key {!r} does not belong to any stimulus category'.format(key))

    return events
|
|
|
|
|
|
def permutation_test(baseline, score, n_iter, random_state=None):
    """
    One-sided permutation test for a difference of mean classification scores.

    The statistic is ``mean(score) - mean(baseline)``. The pooled values are shuffled
    ``n_iter`` times; in each shuffle the first ``len(baseline)`` values form the baseline
    group and the remaining values form the score group. The p-value is the fraction of all
    statistics (the observed one included) that are >= the observed statistic, i.e.
    ``(1 + #{permuted >= observed}) / (n_iter + 1)``.

    :param baseline: The classification scores of the baseline, i.e. selection by chance
    :param score: The classification scores which are tested for significance
    :param n_iter: number of permutations (must be non-negative)
    :param random_state: optional int seed, ``numpy.random.Generator`` or ``RandomState`` for
        reproducible results. If None, numpy's global random state is used (previous behaviour).
    :return: p-value
    :raises ValueError: if ``baseline`` or ``score`` is empty, or ``n_iter`` is negative
    """
    baseline = np.asarray(baseline)
    score = np.asarray(score)
    # An empty group has an undefined (NaN) mean, which would silently yield p = 0.
    if len(baseline) == 0 or len(score) == 0:
        raise ValueError('baseline and score must both be non-empty')
    if n_iter < 0:
        raise ValueError('n_iter must be non-negative')

    # Both the np.random module and Generator/RandomState instances provide permutation().
    if random_state is None:
        rng = np.random
    elif isinstance(random_state, (np.random.Generator, np.random.RandomState)):
        rng = random_state
    else:
        rng = np.random.default_rng(random_state)

    all_data = np.concatenate((baseline, score))
    n_base = len(baseline)
    n_total = len(all_data)

    # Base statistic: the difference of means
    given_diff = np.mean(score) - np.mean(baseline)

    diffs = np.empty(n_iter + 1)
    diffs[0] = given_diff
    for i in range(1, n_iter + 1):
        # Shuffled indices: [:n_base] -> baseline group, [n_base:] -> score group
        perm = rng.permutation(n_total)
        diffs[i] = np.mean(all_data[perm[n_base:]]) - np.mean(all_data[perm[:n_base]])

    return np.count_nonzero(diffs >= given_diff) / (n_iter + 1)
|
|
|
|
|
|
def decoding(dataset, filename, compute_metric=True, mask=None):
    """
    Runs decoding over time for all subjects and tests each post-stimulus time point
    against the pre-stimulus scores.

    For every subject ('001' to '040') a Scaler -> Vectorizer -> LogisticRegression pipeline is
    fitted per time point (SlidingEstimator) and scored with 10-fold cross-validation. The
    fold-averaged scores form one row of a (subjects x time points) score matrix. The mean score
    time course is plotted. Then every time point from t = 0 on is compared with
    ``permutation_test`` (1000 permutations) against the pre-stimulus scores pooled over all
    subjects, and the resulting p-values are plotted.

    :param dataset: The dataset for which the decoding is done
    :param filename: filename (inside 'cached_data/decoding_data/') of either the file from which the
        classifier scores will be loaded or to which they will be saved; '.npy' is appended by numpy
    :param compute_metric: If True the classifier will be run, else the result will be loaded from a
        precomputed file
    :param mask: an optional label mask with 4 entries, where:
        1. entry: 'face intact', 2. entry: 'car intact', 3. entry: 'face scrambled', 4. entry: 'car scrambled'
        If None the entries are [0,1,0,1] i.e. all faces are in class 0 and all cars are in class 1
    """
    if mask is None:
        mask = [0, 1, 0, 1]
    cache_file = 'cached_data/decoding_data/' + filename

    if compute_metric:
        # Computes classifier scores for all subjects
        times = None
        subject_scores = []
        for i in range(1, 41):
            subj = str(i).zfill(3)  # '001' ... '040'

            # Load data
            raw = load_preprocessed_data(subj, dataset)
            epochs, events_dict = get_epochs(raw, picks=mne.pick_types(raw.info, eeg=True, eog=False))
            data = epochs.get_data()
            labels = events_to_labels(epochs.events[:, 2], events_dict, mask)

            # Classify: one classifier per time point, 10-fold cross-validated
            clf = make_pipeline(Scaler(epochs.info), Vectorizer(), LogisticRegression(solver='lbfgs'))
            time_decode = SlidingEstimator(clf)
            scores = cross_val_multiscore(time_decode, data, labels, cv=10, n_jobs=4)
            subject_scores.append(np.mean(scores, axis=0))  # average over folds

            if times is None:
                times = epochs.times

        metric = np.asarray(subject_scores)
        np.save(cache_file, metric)
    else:
        metric = np.load(cache_file + '.npy')
        # Dummy time axis reproducing epochs.times of the cached data (first sample at
        # -0.09960938 s, last at 1 s), sized to the number of cached time points.
        times = np.linspace(-0.09960938, 1, metric.shape[1])

    # Index of the sample closest to t = 0; all samples before it are pre-stimulus.
    index = int(np.argmin(np.abs(times)))
    # Pool the pre-stimulus scores of all subjects (rows) as the chance-level reference.
    # (Slicing rows with metric[:index] would select subjects, not time points.)
    baseline = metric[:, :index].ravel()

    # Plot the mean score time course. Assumes the scores are accuracies in [0, 1]
    # (default scoring of LogisticRegression), so they are scaled to percent here.
    plt.plot(times * 1000, 100 * np.mean(metric, axis=0))
    plt.ylabel('Accuracy (%)')
    plt.xlabel('Time (ms)')
    plt.title('Mean Accuracy over Subjects for Faces vs. Cars')
    plt.show()

    # Compute the permutation tests for every sample from t = 0 on
    n_tests = metric.shape[1] - index
    p_values = []
    for t in range(n_tests):
        p_values.append(permutation_test(baseline, metric[:, index + t], 1000))
        if t % 50 == 0:
            print(str(t) + " Out of " + str(n_tests))

    # Plot the result (times are in seconds, the axis in milliseconds)
    plt.plot(times[index:] * 1000, p_values)
    plt.ylabel('P-Value')
    plt.xlabel('Time (ms)')
    plt.title('P-Values for Faces vs. Cars')
    plt.show()
|
|
|
|
|
|
def create_tfr(raw, condition, freqs, n_cycles, response='induced', baseline=None, plot=False):
    """
    Compute the time frequency representation (TFR) of data for a given condition via Morlet wavelets.

    Epochs from -0.2 s to 1 s are extracted for ``condition``. Their power is computed with
    ``tfr_morlet`` (no inter-trial coherence) and then baseline-corrected with mode 'ratio'.

    :param raw: the data
    :param condition: the condition for which to compute the TFR. Given as a list of tuples of the form
        (stimulus, texture)
    :param freqs: the frequencies for which to compute the TFR
    :param n_cycles: the number of cycles used by the Morlet wavelets
    :param response: type of expected TFR:
        'total'   - power of the epochs as they are,
        'induced' - power after subtracting the evoked response from every epoch,
        'evoked'  - total power minus induced power.
        Default is induced; the others were not used for the report, only for exploration
    :param baseline: baseline used to correct the power. A tuple of the form (start, end).
        Default is None and no baseline correction will be applied
    :param plot: True if results should be plotted, else false.
    :return: The TFR of the given data for a given condition. Has type AverageTFR
    :raises ValueError: if ``response`` is not 'total', 'induced' or 'evoked' (previously any other
        value silently computed the evoked response)
    """
    if response not in ('total', 'induced', 'evoked'):
        raise ValueError("response must be 'total', 'induced' or 'evoked', got {!r}".format(response))

    epochs, _ = get_epochs(raw, condition, tmin=-0.2, tmax=1)
    print(' ' + str(condition))

    if response == 'total':
        print(' Power Total')
        power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4)
    elif response == 'induced':
        print(' Power Induced')
        power = tfr_morlet(epochs.subtract_evoked(), freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4)
    else:
        print(' Power Evoked')
        # Total power must be computed first: subtract_evoked() modifies ``epochs`` in place.
        power_total = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4)
        power_induced = tfr_morlet(epochs.subtract_evoked(), freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4)
        # NOTE(review): mne.combine_evoked is documented for Evoked objects; confirm it handles
        # AverageTFR in the installed MNE version.
        power = mne.combine_evoked([power_total, power_induced], weights=[1, -1])

    # NOTE(review): confirm this pre-baseline plot is meant for every response type
    # (it may originally have belonged to the 'evoked' branch only).
    if plot:
        power.plot(picks='P7')

    # Apply a baseline correction to the power data ('ratio': divide by the baseline mean)
    power.apply_baseline(mode='ratio', baseline=baseline)

    if plot:
        plot_oscillation_bands(power)
        power.plot(picks='P7')

    return power
|
|
|
|
|
|
def time_frequency(dataset, filename, compute_tfr=True):
    """
    Runs time frequency analysis.

    For every subject ('001' to '040') the induced power of two conditions is computed with
    ``create_tfr`` (baseline -0.2 s to 0 s), or all TFRs are loaded from a cache. The grand
    averages of both conditions are plotted. They are then compared with
    ``mne.stats.permutation_cluster_test`` and the clusters are plotted.

    :param dataset: The dataset for which the analysis is done
    :param filename: Filename prefix (inside 'cached_data/tf_data/') of either the files from which
        the TFRs will be loaded or to which they will be saved ('_cond1.npy' / '_cond2.npy')
    :param compute_tfr: If True the TFRs will be created, else the TFRs will be loaded from a
        precomputed file
    """
    # Parameters
    freqs = np.linspace(0.1, 50, num=50)  # Use this for linear space scaling
    # freqs = np.logspace(*np.log10([0.1, 50]), num=50)  # alternative: logarithmic scaling
    n_cycles = freqs / 2
    cache_prefix = 'cached_data/tf_data/' + filename

    if compute_tfr:
        cond1 = []
        cond2 = []
        for i in range(1, 41):
            subj = str(i).zfill(3)  # '001' ... '040'
            print("########## SUBJECT " + subj + " ##########")

            # Load data
            raw = load_preprocessed_data(subj, dataset)
            raw.set_channel_types({'HEOG_left': 'eog', 'HEOG_right': 'eog', 'VEOG_lower': 'eog'})
            raw.set_montage('standard_1020', match_case=False)

            # Create the two conditions we want to compare
            # IMPORTANT: If different conditions should be compared you have to change them here, by altering the
            # second argument passed to create_tfr
            power_cond1 = create_tfr(raw, [('face', 'intact')], freqs, n_cycles, 'induced', (-0.2, 0))
            print(' CONDITION 1 LOADED')
            cond1.append(power_cond1)

            power_cond2 = create_tfr(raw, [('face', 'scrambled'), ('car', None)], freqs, n_cycles, 'induced',
                                     (-0.2, 0))
            print(' CONDITION 2 LOADED')
            cond2.append(power_cond2)
            print(' DONE')

        # Save the data so we can access the results more easily
        np.save(cache_prefix + '_cond1', cond1)
        np.save(cache_prefix + '_cond2', cond2)
    else:
        # If the data should not be recomputed, load the given filename.
        # allow_pickle is required because the cache stores pickled objects; only load
        # cache files this script wrote itself.
        cond1 = np.load(cache_prefix + '_cond1.npy', allow_pickle=True).tolist()
        cond2 = np.load(cache_prefix + '_cond2.npy', allow_pickle=True).tolist()

    # Both branches need the time axis for the cluster plot.
    times = cond1[0].times

    grand_avg1 = mne.grand_average(cond1)
    grand_avg2 = mne.grand_average(cond2)

    # Some plots (on copies, so nothing a plotting routine does to its input can reach the statistics)
    grand_avg1.copy().plot(picks=['P7'], vmin=-3, vmax=3, title='Grand Average P7')
    grand_avg2.copy().plot(picks=['P7'], vmin=-3, vmax=3, title='Grand Average P7')
    plot_oscillation_bands(grand_avg1.copy())
    plot_oscillation_bands(grand_avg2.copy())

    # Compute the cluster permutation
    # NOTE(review): permutation_cluster_test treats the first axis of each array as observations.
    # For grand-average data (presumably (n_channels, n_freqs, n_times)) that means channels act as
    # observations, not subjects. Confirm this is intended.
    F, clusters, cluster_p_values, h0 = mne.stats.permutation_cluster_test(
        [grand_avg1.data, grand_avg2.data], n_jobs=4, verbose='INFO',
        seed=123)
    plot_tf_cluster(F, clusters, cluster_p_values, freqs, times)
|
|
|
|
|
|
if __name__ == '__main__':
    # Set MNE's global log level before running the analyses.
    mne.set_log_level(verbose=VERBOSE_LEVEL)

    dataset_name = 'N170'
    # Both analyses recompute their results instead of loading cached files.
    decoding(dataset_name, 'faces_vs_cars', compute_metric=True)
    time_frequency(dataset_name, 'face_intact_vs_all_0.1_50hz_ncf2_linscale', compute_tfr=True)
|