import math

import matplotlib.pyplot as plt
import mne
import numpy as np
from mne.decoding import SlidingEstimator, cross_val_multiscore, Vectorizer, Scaler
from mne.time_frequency import tfr_morlet
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

from utils.file_utils import load_preprocessed_data, get_epochs
from utils.plot_utils import plot_tf_cluster, plot_oscillation_bands

VERBOSE_LEVEL = 'CRITICAL'


def events_to_labels(evts, events_dict, mask=None):
    # TODO: write a test
    """
    Converts the event labels of epochs to class labels for classification.

    :param evts: the event labels to be converted
    :param events_dict: a dictionary of event keys
    :param mask: an optional label mask with 4 entries, where:
        1. entry: 'face intact', 2. entry: 'car intact',
        3. entry: 'face scrambled', 4. entry: 'car scrambled'.
        If None the entries are [0, 1, 0, 1], i.e. all faces are in class 0
        and all cars are in class 1
    :return: the list of class labels
    """
    events = evts.copy()
    if mask is None:
        mask = [0, 1, 0, 1]
    for i in range(len(events)):
        # Recover the event key (e.g. 'stimulus:12') from its numeric event id
        key = list(events_dict.keys())[list(events_dict.values()).index(events[i])]
        k = int(key.split(':')[1])
        if k < 41:
            events[i] = mask[0]  # Face intact
        elif 40 < k < 81:
            events[i] = mask[1]  # Car intact
        elif 100 < k < 141:
            events[i] = mask[2]  # Face scrambled
        elif 140 < k < 181:
            events[i] = mask[3]  # Car scrambled
    return events


def permutation_test(baseline, score, n_iter):
    """
    An implementation of a permutation test for classification scores.

    :param baseline: the classification scores of the baseline, i.e. selection by chance
    :param score: the classification scores which are tested for significance
    :param n_iter: number of permutations
    :return: p-value
    """
    all_data = np.concatenate((baseline, score))
    # Base statistic. The statistic used here is the difference of means
    given_diff = np.mean(score) - np.mean(baseline)
    all_diffs = [given_diff]
    # Permutation iterations
    for i in range(n_iter):
        # Create a permutation of indices, then use the first len(baseline)
        # indices for the baseline group and the remaining ones for the scores
        perm_indices = np.random.permutation(len(all_data))
        mean_diff = np.mean(all_data[perm_indices[len(baseline):]]) - np.mean(all_data[perm_indices[:len(baseline)]])
        all_diffs.append(mean_diff)
    # The observed statistic is included in the null distribution, hence n_iter + 1
    p_val = len(np.where(np.asarray(all_diffs) >= given_diff)[0]) / (n_iter + 1)
    return p_val
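
# A hedged sketch (not part of the original analysis): on synthetic scores with
# a clear effect, permutation_test should return a p-value close to its minimum
# of 1 / (n_iter + 1). All numbers below are made up purely for illustration;
# the helper is never called by the analysis itself.
def _demo_permutation_test(seed=0):
    rng = np.random.RandomState(seed)
    chance_scores = rng.normal(0.5, 0.02, size=40)  # baseline scores around chance level
    effect_scores = rng.normal(0.6, 0.02, size=40)  # scores well above chance
    # With such a strong effect no permuted difference should exceed the
    # observed one, so the returned p-value is 1 / 1001
    return permutation_test(chance_scores, effect_scores, n_iter=1000)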
def decoding(dataset, filename, compute_metric=True, mask=None):
    """
    Runs decoding over time for all subjects.

    :param dataset: the dataset for which the decoding is done
    :param filename: filename of either the file from which the classifier scores
        will be loaded or to which they will be saved
    :param compute_metric: if True the classifier will be run, else the result
        will be loaded from a precomputed file
    :param mask: an optional label mask with 4 entries, where:
        1. entry: 'face intact', 2. entry: 'car intact',
        3. entry: 'face scrambled', 4. entry: 'car scrambled'.
        If None the entries are [0, 1, 0, 1], i.e. all faces are in class 0
        and all cars are in class 1
    """
    if mask is None:
        mask = [0, 1, 0, 1]
    times = None
    time_scale = 1100  # epoch length in ms (-100 to 1000 ms)
    metric = []
    p_values = []
    if compute_metric:
        # Compute classifier scores for all subjects
        for i in range(1, 41):
            subj = str(i).zfill(3)
            # Load data
            raw = load_preprocessed_data(subj, dataset)
            epochs, events_dict = get_epochs(raw, picks=mne.pick_types(raw.info, eeg=True, eog=False))
            data = epochs.get_data()
            labels = events_to_labels(epochs.events[:, 2], events_dict, mask)
            # Classify at every time point with a sliding logistic regression
            clf = make_pipeline(Scaler(epochs.info), Vectorizer(), LogisticRegression(solver='lbfgs'))
            time_decode = SlidingEstimator(clf)
            scores = cross_val_multiscore(time_decode, data, labels, cv=10, n_jobs=4)
            metric.append(np.mean(scores, axis=0))
            if times is None:
                times = epochs.times
        np.save('cached_data/decoding_data/' + filename, metric)
        metric = np.asarray(metric)
    else:
        # Dummy time axis created according to epochs.times
        times = np.linspace(-0.09960938, 1, 1127)
        metric = np.load('cached_data/decoding_data/' + filename + '.npy')
    # Compute the index of time point 0 (the first 100 ms of the epoch are pre-stimulus)
    index = math.floor((len(metric[0]) / time_scale) * 100)
    # Use the pre-stimulus scores of all subjects as the chance-level baseline
    baseline = metric[:, :index].flatten()
    # Plot the result
    plt.plot(times * 1000, np.mean(metric, axis=0))
    plt.ylabel('Accuracy')
    plt.xlabel('Time (ms)')
    plt.title('Mean Accuracy over Subjects for Faces vs. Cars')
    plt.show()
    # Compute a permutation test for every post-stimulus time point
    for t in range(len(metric[0][index:])):
        score_t = np.asarray(metric[:, t + index])
        p = permutation_test(baseline, score_t, 1000)
        p_values.append(p)
        if t % 50 == 0:
            print(str(t) + " out of " + str(len(metric[0][index:])))
    # Plot the result
    plt.plot(times[index:] * 1000, p_values)
    plt.ylabel('P-Value')
    plt.xlabel('Time (ms)')
    plt.title('P-Values for Faces vs. Cars')
    plt.show()
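
# Hedged sketch: the 4-entry mask lets decoding define contrasts other than the
# default faces vs. cars. For example, decoding intact vs. scrambled stimuli
# puts both intact conditions in class 0 and both scrambled conditions in
# class 1 (the filename below is made up for the example):
#
#     decoding('N170', 'intact_vs_scrambled', True, mask=[0, 0, 1, 1])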
def create_tfr(raw, condition, freqs, n_cycles, response='induced', baseline=None, plot=False):
    """
    Computes the time frequency representation (TFR) of data for a given
    condition via Morlet wavelets.

    :param raw: the data
    :param condition: the condition for which to compute the TFR, given as a
        list of tuples of the form (stimulus, texture)
    :param freqs: the frequencies for which to compute the TFR
    :param n_cycles: the number of cycles used by the Morlet wavelets
    :param response: type of expected TFR. Can be 'total', 'induced' or 'evoked'.
        Default is 'induced'; the others were not used for the report, only for
        exploration
    :param baseline: baseline used to correct the power, a tuple of the form
        (start, end). Default is None, in which case no baseline correction is
        applied
    :param plot: True if results should be plotted, else False
    :return: the TFR of the given data for the given condition, of type AverageTFR
    """
    epochs, _ = get_epochs(raw, condition, tmin=-0.2, tmax=1)
    print('    ' + str(condition))
    if response == 'total':
        print('    Power Total')
        power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4)
    elif response == 'induced':
        print('    Power Induced')
        # Subtracting the evoked response from each epoch leaves the induced activity
        power = tfr_morlet(epochs.subtract_evoked(), freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4)
    else:
        print('    Power Evoked')
        # Evoked power = total power - induced power
        power_total = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4)
        power_induced = tfr_morlet(epochs.subtract_evoked(), freqs=freqs, n_cycles=n_cycles, return_itc=False,
                                   n_jobs=4)
        power = mne.combine_evoked([power_total, power_induced], weights=[1, -1])
    if plot:
        power.plot(picks='P7')
    # Apply a baseline correction to the power data (a no-op if baseline is None)
    power.apply_baseline(mode='ratio', baseline=baseline)
    if plot:
        plot_oscillation_bands(power)
        power.plot(picks='P7')
    return power


def time_frequency(dataset, filename, scaling='lin', compute_tfr=True):
    """
    Runs the time frequency analysis.

    :param dataset: the dataset for which the analysis is done
    :param filename: filename of either the file from which the TFRs will be
        loaded or to which they will be saved
    :param scaling: default 'lin' for linear frequency spacing, else 'log' for
        logarithmic spacing
    :param compute_tfr: if True the TFRs will be created, else they will be
        loaded from a precomputed file
    """
    # Parameters
    if scaling == 'lin':
        freqs = np.linspace(0.1, 50, num=50)  # linearly spaced frequencies
    else:
        freqs = np.logspace(*np.log10([0.1, 50]), num=50)
    n_cycles = freqs / 2
    cond1 = []
    cond2 = []
    times = None
    if compute_tfr:
        for i in range(1, 41):
            subj = str(i).zfill(3)
            print("########## SUBJECT " + subj + " ##########")
            # Load data
            raw = load_preprocessed_data(subj, dataset)
            raw.set_channel_types({'HEOG_left': 'eog', 'HEOG_right': 'eog', 'VEOG_lower': 'eog'})
            raw.set_montage('standard_1020', match_case=False)
            # Create the two conditions we want to compare
            # IMPORTANT: If different conditions should be compared you have to
            # change them here, by altering the second argument passed to
            # create_tfr (see the sketch after this function)
            power_cond1 = create_tfr(raw, [('face', 'intact')], freqs, n_cycles, 'induced', (-0.2, 0))
            print('    CONDITION 1 LOADED')
            cond1.append(power_cond1)
            power_cond2 = create_tfr(raw, [('face', 'scrambled'), ('car', None)], freqs, n_cycles, 'induced',
                                     (-0.2, 0))
            print('    CONDITION 2 LOADED')
            cond2.append(power_cond2)
            print('    DONE')
        # Save the data so we can access the results more easily
        np.save('cached_data/tf_data/' + filename + '_cond1', cond1)
        np.save('cached_data/tf_data/' + filename + '_cond2', cond2)
    else:
        # If the data should not be recomputed, load the given filename
        cond1 = np.load('cached_data/tf_data/' + filename + '_cond1.npy', allow_pickle=True).tolist()
        cond2 = np.load('cached_data/tf_data/' + filename + '_cond2.npy', allow_pickle=True).tolist()
    if times is None:
        times = cond1[0].times
    # Some plots
    mne.grand_average(cond1).plot(picks=['P7'], vmin=-3, vmax=3, title='Grand Average P7')
    mne.grand_average(cond2).plot(picks=['P7'], vmin=-3, vmax=3, title='Grand Average P7')
    plot_oscillation_bands(mne.grand_average(cond1))
    plot_oscillation_bands(mne.grand_average(cond2))
    # Compute the cluster permutation test
    F, clusters, cluster_p_values, h0 = mne.stats.permutation_cluster_test(
        [mne.grand_average(cond1).data, mne.grand_average(cond2).data],
        n_jobs=4, verbose='INFO', seed=123)
    plot_tf_cluster(F, clusters, cluster_p_values, freqs, times, scaling)
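
# Hedged sketch: conditions are passed to create_tfr as (stimulus, texture)
# tuples, so a different contrast, e.g. intact vs. scrambled faces, would only
# change the two condition lists inside time_frequency:
#
#     power_cond1 = create_tfr(raw, [('face', 'intact')], freqs, n_cycles, 'induced', (-0.2, 0))
#     power_cond2 = create_tfr(raw, [('face', 'scrambled')], freqs, n_cycles, 'induced', (-0.2, 0))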
if __name__ == '__main__':
    mne.set_log_level(verbose=VERBOSE_LEVEL)
    ds = 'N170'
    decoding(ds, 'faces_vs_cars', True)
    time_frequency(ds, 'face_intact_vs_all_0.1_50hz_ncf2', 'log', True)
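
# Hedged usage note: to re-plot from the cached results instead of re-running
# the classifier and the TFR computation, pass False as the compute flag
# (this assumes the corresponding .npy files already exist under cached_data/):
#
#     decoding(ds, 'faces_vs_cars', False)
#     time_frequency(ds, 'face_intact_vs_all_0.1_50hz_ncf2', 'log', False)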