commit fc4eec6ac71b6f89349325a36f8ecf9da1004026 Author: Julius Voggesberger Date: Sat Mar 27 14:58:49 2021 +0100 Initial commit diff --git a/decoding_tf_analysis.py b/decoding_tf_analysis.py new file mode 100644 index 0000000..e895648 --- /dev/null +++ b/decoding_tf_analysis.py @@ -0,0 +1,239 @@ +import math +import mne +import numpy as np +import matplotlib.pyplot as plt + +from mne.decoding import SlidingEstimator, cross_val_multiscore, Vectorizer, Scaler +from mne.time_frequency import tfr_morlet + +from sklearn.linear_model import LogisticRegression +from sklearn.pipeline import make_pipeline + +from utils.file_utils import load_preprocessed_data, get_epochs +from utils.plot_utils import plot_tf_cluster, plot_oscillation_bands + +VERBOSE_LEVEL = 'CRITICAL' + + +def events_to_labels(evts, events_dict, mask=None): # TODO Test schreiben + """ + Converts the event labels of epochs to class labels for classification + :param evts: the event labels to be converted + :param events_dict: a dictionary of event keys + :param mask: an optional label mask with 4-entries, where: + 1. entry: 'face intact', 2. entry: 'car intact', 3. entry: 'face scrambled', 4. entry: 'face scrambled' + If None the entries are [0,1,0,1] i.e. all faces are in class 0 and all cars are in class 1 + :return: The list of class labels + """ + events = evts.copy() + if mask is None: + mask = [0, 1, 0, 1] + for i in range(len(events)): + key = list(events_dict.keys())[list(events_dict.values()).index(events[i])] + k = int(key.split(':')[1]) + if k < 41: + events[i] = mask[0] # Face intact + elif 40 < k < 81: + events[i] = mask[1] # Car intact + elif 100 < k < 141: + events[i] = mask[2] # Face scrambled + elif 140 < k < 181: + events[i] = mask[3] # Car scrambled + + return events + + +def permutation_test(baseline, score, n_iter): + """ + An implementation of a permutation test for classification scores. + :param baseline: The classification scores of the baseline, i.e. selection by chance + :param score: The classification scores which are tested for significance + :param n_iter: number of permutations + :return: p-value + """ + all_data = np.concatenate((baseline, score)) + + # Base statistic. The statistic used here is the difference of means + given_diff = np.mean(score) - np.mean(baseline) + all_diffs = [given_diff] + + # Permutation iterations + for i in range(n_iter): + # Create a permutation of indices and then use indices from index 0 to len(baseline) to get data for baseline. + # Analogously for scores + perm_indices = np.random.permutation(list(range(len(all_data)))) + mean_diff = np.mean(all_data[perm_indices[len(baseline):]]) - np.mean(all_data[perm_indices[:len(baseline)]]) + all_diffs.append(mean_diff) + + p_val = len(np.where(np.asarray(all_diffs) >= given_diff)[0]) / (n_iter + 1) + + return p_val + + +def decoding(dataset, filename, compute_metric=True, mask=None): + """ + Runs decoding over time for all subjects + + :param dataset: The dataset for which the decoding is done + :param filename: filename of either the file from which the classifier scores will be loaded + or to which they will be saved + :param compute_metric: If True the classifier will be run, else the result will be loaded from a precomputed file + :param mask: an optional label mask with 4-entries, where: + 1. entry: 'face intact', 2. entry: 'car intact', 3. entry: 'face scrambled', 4. entry: 'face scrambled' + If None the entries are [0,1,0,1] i.e. all faces are in class 0 and all cars are in class 1 + """ + if mask is None: + mask = [0, 1, 0, 1] + times = None + time_scale = 1100 + metric = [] + p_values = [] + + if compute_metric: + # Computes classifier scores for all subjects + for i in range(1, 41): + subj = "0" + str(i) + if len(str(i)) == 1: + subj = "0" + subj + + # Load data + raw = load_preprocessed_data(subj, dataset) + epochs, events_dict = get_epochs(raw, picks=mne.pick_types(raw.info, eeg=True, eog=False)) + data = epochs.get_data() + labels = events_to_labels(epochs.events[:, 2], events_dict, mask) + + # Classify + clf = make_pipeline(Scaler(epochs.info), Vectorizer(), LogisticRegression(solver='lbfgs')) + time_decode = SlidingEstimator(clf) + scores = cross_val_multiscore(time_decode, data, labels, cv=10, n_jobs=4) + metric.append(np.mean(scores, axis=0)) + + if times is None: + times = epochs.times + np.save('cached_data/decoding_data/' + filename, metric) + else: + # Dummy time which is created according to epoch.times + times = np.linspace(-0.09960938, 1, 1127) + metric = np.load('cached_data/decoding_data/' + filename + '.npy') + + # Compute index of time point 0 + index = math.floor((len(metric[0]) / time_scale) * 100) + baseline = np.array(metric[:index]).flatten() + plt.plot(np.linspace(-200, 1000, 1127), np.mean(metric, axis=0)) + plt.ylabel('Accuracy (%)') + plt.xlabel('Time (ms)') + plt.title('Mean Accuracy over Subjects for Faces vs. Cars') + plt.show() + + # Compute the permutation tests + for t in range(len(metric[0][index:])): + score_t = np.asarray(metric[:, t + index]) + p = permutation_test(baseline, score_t, 100) + p_values.append(p) + if t % 50 == 0: + print(str(t) + " Out of " + str(len(metric[0][index:]))) + + plt.plot(times[index:], p_values) + plt.ylabel('P-Value') + plt.xlabel('Time (ms)') + plt.title('P-Values for Faces vs. Cars') + plt.show() + + +def create_tfr(raw, condition, freqs, n_cycles, response='induced', baseline=None): + """ + Compute the time frequency representation (TFR) of data for a given condition via morlet wavelets + :param raw: the data + :param condition: the condition for which to compute the TFR. Given as a list of tuples of the form (stimulus, condition) # TODO ambiguous use of condition + :param freqs: the frequencies for which to compute the TFR + :param n_cycles: the number of cycles used by the morlet wavelets + :param response: type of expected TFR. Can be total, induced or evoked. Default is induced + :param baseline: baseline used to correct the power. A tuple of the form (start, end). + Default is None and no baseline correction will be applid + :return: The TFR or the given data for a given condition. Has type AverageTFR + """ + epochs, _ = get_epochs(raw, condition, tmin=-0.2, tmax=1) + print(' ' + str(condition)) + + if response == 'total': + print(' Power Total') + power = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4) + elif response == 'induced': + print(' Power Induced') + power = tfr_morlet(epochs.subtract_evoked(), freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4) + else: + print(' Power Evoked') + power_total = tfr_morlet(epochs, freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4) + power_induced = tfr_morlet(epochs.subtract_evoked(), freqs=freqs, n_cycles=n_cycles, return_itc=False, n_jobs=4) + power = mne.combine_evoked([power_total, power_induced], weights=[1, -1]) + # power.plot(picks='P7') + power.apply_baseline(mode='ratio', baseline=baseline) + # plot_oscillation_bands(power) + # power.plot(picks='P7') + return power + + +def time_frequency(dataset, filename, compute_tfr=True): + """ + Runs time frequency analysis + + :param dataset: The dataset for which the decoding is done + :param filename: Filename of either the file from which the TFRs will be loaded + or to which they will be saved + :param compute_tfr: If True the TFRs will be created, else the TFRs will be loaded from a precomputed file + """ + # Parameters + # Frequency space (from, to, steps) -> Control frequency resolution : Between num=50-80 good for 1-50Hz + freqs = np.logspace(*np.log10([0.5, 50]), num=50) # + # Number of cycles -> Controls time resolution ? At ~freqs/2 good for high frequency resolution + n_cycles = freqs / 2 # 1 for high time resolution & freq smoothing, freqs/2 for high freq resolution & time smooth + # Baseline -> Should not go post-stimulus, i.e. > 0 -> Best ist pre-stimulus (e.g. -400 to -200ms) + baseline = [-0.5, 0] + cond1 = [] + cond2 = [] + times = None + + if compute_tfr: + for i in range(1, 41): + subj = "0" + str(i) + if len(str(i)) == 1: + subj = "0" + subj + print("########## SUBJECT " + subj + " ##########") + + # Load data + raw = load_preprocessed_data(subj, dataset) + raw.set_channel_types({'HEOG_left': 'eog', 'HEOG_right': 'eog', 'VEOG_lower': 'eog'}) + raw.set_montage('standard_1020', match_case=False) + + # Create the two conditions we want to compare + power_cond1 = create_tfr(raw, [('face', 'intact')], freqs, n_cycles, 'induced', (-0.2, 0)) + print(' CONDITION 1 LOADED') + cond1.append(power_cond1) + + power_cond2 = create_tfr(raw, [('face', 'scrambled'), ('car', None)], freqs, n_cycles, 'induced', + (-0.2, 0)) + print(' CONDITION 2 LOADED') + cond2.append(power_cond2) + print(' DONE') + + np.save('cached_data/tf_data/' + filename + '_cond1', cond1) + np.save('cached_data/tf_data/' + filename + '_cond2', cond2) + else: + cond1 = np.load('cached_data/tf_data/' + filename + '_cond1.npy', allow_pickle=True).tolist() + cond2 = np.load('cached_data/tf_data/' + filename + '_cond2.npy', allow_pickle=True).tolist() + if times is None: + times = cond1[0].times + mne.grand_average(cond2).plot(picks=['P7'], vmin=-3, vmax=3, title='Grand Average P7') + plot_oscillation_bands(mne.grand_average(cond1)) + plot_oscillation_bands(mne.grand_average(cond2)) + F, clusters, cluster_p_values, h0 = mne.stats.permutation_cluster_test( + [mne.grand_average(cond1).data, mne.grand_average(cond2).data], n_jobs=4, verbose='INFO', + seed=123) + plot_tf_cluster(F, clusters, cluster_p_values, freqs, times) + + +if __name__ == '__main__': + mne.set_log_level(verbose=VERBOSE_LEVEL) + ds = 'N170' + # decoding(ds, 'faces_vs_cars_100iters', False) + time_frequency(ds, 'face_intact_vs_all_0.1_50hz_ncf2', False) diff --git a/erp_analysis.py b/erp_analysis.py new file mode 100644 index 0000000..9d671b7 --- /dev/null +++ b/erp_analysis.py @@ -0,0 +1,134 @@ +import mne +import pandas as pd +from scipy.stats import ttest_1samp, f_oneway + +from utils.file_utils import load_preprocessed_data, get_epochs + +VERBOSE_LEVEL = 'CRITICAL' + + +def extract_erp_peak(raw, subject, stimulus, condition, channel): + """ + Extracts the erp peak for a given subject, stimulus and condition as a single value. + + :param raw: The raw object, from which the epochs are generated + :param subject: The subject for which the peak is extracted + :param stimulus: The stimulus we look at: Either 'car' or 'face' + :param condition: The condition of the stimulus: Either 'intact' or 'scrambled' + :param channel: The currently selected channel, for which the erp_peak should be extracted + :return: A dictionary conforming to the data frame format: + {'subject_id': subject, 'stimulus': stimulus, 'condition': condition, 'peak': peak} + """ + # Epoch the data + epochs, _ = get_epochs(raw, [(stimulus, condition)], picks=channel) + # Check only for negative peaks, as only the channels P7,P07,P8,P08 are used + ch, latency, peak = epochs.average().get_peak(tmin=0.13, tmax=0.2, mode='neg', return_amplitude=True) + return {'subject_id': subject, 'stimulus': stimulus, 'condition': condition, 'peak': peak} + + +def precompute_erp_df(dataset): + """ + This method generates a .csv file where the erp peaks for each stimulus-condition pair for each subject are saved + :param dataset: The dataset for which the erp peaks are computed + """ + chs = ['P7', 'PO7', 'P8', 'PO8'] + events = [('face', 'intact'), ('face', 'scrambled'), ('car', 'intact'), ('car', 'scrambled')] + + for ch in chs: + df = pd.DataFrame(data={'subject_id': [], 'stimulus': [], 'condition': [], 'peak': []}) + for i in range(1, 41): + subj = "0" + str(i) + if len(str(i)) == 1: + subj = "0" + subj + # Load preprocessed .fif data files + raw = load_preprocessed_data(subj, dataset) + # Extract ERP peaks + for ev in events: + row = extract_erp_peak(raw, subj, ev[0], ev[1], ch) + df = df.append(row, ignore_index=True) + df.to_csv('cached_data/erp_peaks/erp_peaks_' + ch + '.csv') + + +def create_peak_difference_feature(df, max_subj=40): + """ + Compute the difference of two N170 peaks for different conditions for all subjects. + I.e. the difference of face(intact)-car(intact),face(scrambled)-car(scrambled),etc. + :param max_subj: the maximum subject till which the features are computed. + :param df: A pandas dataframe containing the peak information for all conditions and subjects + :return: A pandas dataframe containing the peak-difference for multiple condition differences + """ + peak_diff_df = pd.DataFrame( + data={'subject_id': [], 'mean_face': [], 'mean_car': [], 'peak_diff_overall': [], 'diff_intact': [], + 'diff_scrambled': [], 'diff_face': [], 'diff_fc_ci': [], 'diff_fi_rest': []}) + + for i in range(1, max_subj + 1): + + subj = "0" + str(i) + if len(str(i)) == 1: + subj = "0" + subj + sub_df = df.loc[df['subject_id'] == i] + # difference of face and car (intact) + diff_intact = sub_df.loc[df['condition'] == 'intact']['peak'].diff().to_numpy()[1] + # difference of face and car (scrambled) + diff_scrambled = sub_df.loc[df['condition'] == 'scrambled']['peak'].diff().to_numpy()[1] + # Difference of Face intact and Face scrambled + diff_face = sub_df.loc[df['stimulus'] == 'face']['peak'].diff().to_numpy()[1] + # Difference of Face scrambled and Car intact + diff_fs_ci = sub_df.loc[(df['stimulus'] == 'face') & (df['condition'] == 'scrambled')]['peak'].values[0] - \ + sub_df.loc[(df['stimulus'] == 'car') & (df['condition'] == 'intact')]['peak'].values[0] + # Mean of face (intact) and face (scrambled) + mean_face = sub_df.loc[df['stimulus'] == 'face']['peak'].mean() + # Mean of car (intact) and car (scrambled) + mean_car = sub_df.loc[df['stimulus'] == 'car']['peak'].mean() + mean_rest = sub_df.loc[(df['stimulus'] == 'car') | ((df['stimulus'] == 'face') & (df['condition'] == 'scrambled'))]['peak'].mean() + diff_fi_rest = sub_df.loc[df['stimulus'] == 'face']['peak'].values[0] - mean_rest + # Difference of face (overall) and car (overall) + diff = mean_face - mean_car + peak_diff_df = peak_diff_df.append( + {'subject_id': subj, 'mean_face': mean_face, 'mean_car': mean_car, 'peak_diff_overall': diff, + 'diff_intact': diff_intact, 'diff_scrambled': diff_scrambled, 'diff_face': diff_face, + 'diff_fc_ci': diff_fs_ci, 'diff_fi_rest': diff_fi_rest}, ignore_index=True) + return peak_diff_df + + +def analyze_erp(channels): + """ + Execute several statistical tests for different hypothesis, to analyze ERPs + :param channels: The channels for which the tests are executed + """ + for c in channels: + print("CHANNEL: " + c) + erp_df = pd.read_csv('cached_data/erp_peaks/erp_peaks_' + c + '.csv', index_col=0) + feature_df = create_peak_difference_feature(erp_df) + # 1. H_a : There is a difference between the N170 peak of recognizing faces and cars + # Run one-sample ttest against 0 mean + stat, p_val = ttest_1samp(feature_df['peak_diff_overall'].to_numpy(), 0) + print("Peak Difference Faces-Car (All)") + print("P-Value=" + str(p_val)) + # 2. H_a : There is a difference between the peak difference of intact faces&cars, + # to the peak difference of scrambled faces&cars + # Run ANOVA for two samples. 1. Diff of intact faces&cars, 2. Diff of scrambled faces&cars + stat, p_val = f_oneway(feature_df['diff_intact'].to_numpy(), feature_df['diff_scrambled'].to_numpy()) + print("Difference of peak-differences face-car (intact) vs. face-car (scrambled)") + print("P-Value=" + str(p_val)) + # # 3. H_a : There is a difference in the peak-difference of face-car (intact) + stat, p_val = ttest_1samp(feature_df['diff_intact'].to_numpy(), 0) + print("Peak Difference Faces-Car (Intact)") + print("P-Value=" + str(p_val)) + # # 4. H_a : There is a difference in the peak-difference of face-car (scrambled) + stat, p_val = ttest_1samp(feature_df['diff_scrambled'].to_numpy(), 0) + print("Peak Difference Faces-Car (Scrambled)") + print("P-Value=" + str(p_val)) + # # 5. H_a : There is a Difference between Face (scrambled) and Face (intact) in the peak difference + stat, p_val = ttest_1samp(feature_df['diff_face'].to_numpy(), 0) + print("Peak Difference Face intact and scrambled") + print("P-Value=" + str(p_val)) + stat, p_val = ttest_1samp(feature_df['diff_fi_rest'].to_numpy(), 0) + print("Peak Difference Face intact and Rest") + print("P-Value=" + str(p_val)) + + +if __name__ == '__main__': + mne.set_log_level(verbose=VERBOSE_LEVEL) + # precompute_erp_df('N170') + analyze_erp(['P7', 'PO7', 'P8', 'PO8']) diff --git a/plotting.py b/plotting.py new file mode 100644 index 0000000..12c71c6 --- /dev/null +++ b/plotting.py @@ -0,0 +1,110 @@ +import mne +from mne.preprocessing import create_eog_epochs + +from mne_bids import BIDSPath, read_raw_bids +from utils.ccs_eeg_utils import read_annotations_core +from utils.file_utils import get_epochs + + +def load_unprocessed_subject(subject, dataset): + """ + Load the eeg data of a subject + :param subject: The subject of which the data will be loaded + :param dataset: The dataset which will be loaded + :return: the subject data + """ + bids_path = BIDSPath(subject=subject, task=dataset, session=dataset, datatype='eeg', suffix='eeg', + root='Dataset\\' + dataset) + raw = read_raw_bids(bids_path) + # Add annotations + read_annotations_core(bids_path, raw) + return raw + + +def filter_data(raw): + """ + Filter the data of a single subject with a bandpass filter. + The lower bound ist 0.5Hz to compensate the slow drifts. + The upper bound is 50Hz to compensate the high frequencies, including the power line spike at 60Hz + :param raw: The data to be filtered + :return: The filtered data + """ + raw.filter(0.5, 48, fir_design='firwin') + return raw + + +def plot_filter_data(): + ds = 'N170' + for subj in ['014']: + data = load_unprocessed_subject(subj, ds) + data.load_data() + # data.plot(n_channels=len(data.ch_names), block=True, scalings=40e-6) + filter_data(data) + fig = mne.viz.plot_raw_psd(data, fmax=80, average=True, show=False) + fig.savefig("plots/frequency_filtered_subj_" + subj + "_48Hz.png") + # data.plot(n_channels=len(data.ch_names), block=True, scalings=40e-6) + + +def plot_filter_data_epoched(subj): + ds = 'N170' + data = load_unprocessed_subject(subj, ds) + data.load_data() + filter_data(data) + get_epochs(data)[0].average().plot() + + +def plot_cleaning(): + ds = 'N170' + for subj in ['014']: + data = load_unprocessed_subject(subj, ds) + data.load_data() + filter_data(data) + folder = "Dataset\\" + ds + "\\sub-" + subj + "\\ses-" + ds + "\\eeg\\" + filepath = folder + "sub-" + subj + "_task-" + ds + print(filepath) + ann = mne.read_annotations(filepath + "_" + "badannotations.csv") + data.annotations.append(ann.onset, ann.duration, ann.description) + data.plot(n_channels=len(data.ch_names), block=True, scalings=40e-6) + + +def plot_ica(): + ds = 'N170' + for subj in ['014']: + data = load_unprocessed_subject(subj, ds) + folder = "Dataset\\" + ds + "\\sub-" + subj + "\\ses-" + ds + "\\eeg\\" + filepath = folder + "sub-" + subj + "_task-" + ds + ann = mne.read_annotations(filepath + "_" + "badannotations.csv") + data.annotations.append(ann.onset, ann.duration, ann.description) + data.load_data() + data.set_channel_types({'HEOG_left': 'eog', 'HEOG_right': 'eog', 'VEOG_lower': 'eog'}) + data.set_montage('standard_1020', match_case=False) + ica_raw = data.copy() + ica_raw.filter(l_freq=1, h_freq=None) + + # Then run ICA + ica = mne.preprocessing.ICA(method="fastica", + random_state=123) # Use a random state for reproducable results #TODO Old Random state 123 or new one? + ica.fit(ica_raw, verbose=True) + ica_raw.load_data() + # ica.plot_components(inst=ica_raw, ch_type='eeg', contours=0, topomap_args={'extrapolate': 'head'}, + # psd_args={'fmin': 0, 'fmax': 80}) + ica.plot_sources(ica_raw) + ica.plot_properties(inst=ica_raw, dB=False, topomap_args={'extrapolate': 'head', 'contours': 0}, + psd_args={'fmin': 3, 'fmax': 50}, picks=['eeg']) + + +def plot_joint_eog_plots(): + ds = 'N170' + for subj in ['014']: + data = load_unprocessed_subject(subj, ds) + data.load_data() + data.set_channel_types({'HEOG_left': 'eog', 'HEOG_right': 'eog', 'VEOG_lower': 'eog'}) + data.set_montage('standard_1020', match_case=False) + eog_evoked = create_eog_epochs(data).average() + eog_evoked.apply_baseline(baseline=(None, -0.2)) + eog_evoked.plot_joint() + + +# plot_ica() +# plot_joint_eog_plots() +plot_filter_data_epoched('003') diff --git a/preprocessing_and_cleaning.py b/preprocessing_and_cleaning.py new file mode 100644 index 0000000..803553a --- /dev/null +++ b/preprocessing_and_cleaning.py @@ -0,0 +1,182 @@ +import os +import mne + +from mne_bids import (BIDSPath, read_raw_bids) +from utils.ccs_eeg_semesterproject import load_precomputed_badData, load_precomputed_ica +from utils.ccs_eeg_utils import read_annotations_core + + +def load_subject(subject, dataset): + """ + Load the eeg data of a subject + :param subject: The subject of which the data will be loaded + :param dataset: The dataset which will be loaded + :return: the subject data + """ + bids_path = BIDSPath(subject=subject, task=dataset, session=dataset, datatype='eeg', suffix='eeg', + root='Dataset\\' + dataset) + raw = read_raw_bids(bids_path) + # Add annotations + read_annotations_core(bids_path, raw) + return raw + + +def load_given_preprocessing_data(subject, dataset): + """ + Loads given pre-processing information for a given subject. + This is used for all subjects which were not manually preprocessed + :param subject: The subject to load the data for + :param dataset: The dataset currently viewed + :return: The bad annotations, bad channels, ica object, bad ICs + """ + anno, bc = load_precomputed_badData("Dataset\\" + dataset + "\\", subject, + dataset) # Loads annotations and bad channels + ica, bad_comp = load_precomputed_ica("Dataset\\" + dataset + "\\", subject, + dataset) # Loads ica and bad components + return anno, bc, ica, bad_comp + + +def save_subject(raw, subject, dataset): + """ + Save a raw object to a .fif file + :param raw: the raw object to be saved + :param subject: the subject, which the raw object belongs to + :param dataset: the dataset currently viewed + """ + folder = "Dataset\\" + dataset + "\\sub-" + subject + "\\ses-" + dataset + "\\eeg\\" + filepath = folder + "sub-" + subject + "_task-" + dataset + raw.save(filepath + "_cleaned.fif", overwrite=True) + + +def filter_data(raw): + """ + Filter the data of a single subject with a bandpass filter. + The lower bound ist 0.5Hz to compensate the slow drifts. + The upper bound is 50Hz to compensate the high frequencies, including the power line spike at 60Hz + :param raw: The data to be filtered + :return: The filtered data + """ + raw.filter(0.5, 48, fir_design='firwin') + return raw + + +def clean_data(raw, subject, dataset, cleaned=False): + """ + Clean the data of a single subject, meaning finding the bad segments and channels of a subject. + If these were already found, they are loaded onto the data + :param raw: the subject data + :param subject: the subject which data will be viewed + :param cleaned: If True the data was already viewed and the 'BAD_' annotations as well as the bad channels will be loaded + :return: the bad channels + """ + channels = None + folder = "Dataset\\" + dataset + "\\sub-" + subject + "\\ses-" + dataset + "\\eeg\\" + filepath = folder + "sub-" + subject + "_task-" + dataset + + # If nothing was marked yet, plot the data to mark bad segments + if not cleaned: + raw.plot(n_channels=len(raw.ch_names), block=True, scalings=40e-6) + # Get indices of bad annotations + bad_idx = [idx for (idx, annot) in enumerate(raw.annotations) if annot['description'] == "BAD_"] + # If bad intervals were found save + if bad_idx: + raw.annotations[bad_idx].save(filepath + "_badannotations.csv") + + if os.path.isfile(filepath + "_badannotations.csv"): + annotations = mne.read_annotations(filepath + "_badannotations.csv") + raw.annotations.append(annotations.onset, annotations.duration, annotations.description) + + # Set the bad channels for each subject + if subject == '001': + channels = ['F8'] # Maybe also FP2? + elif subject == '003': + channels = [] + elif subject == '014': + channels = [] + + return channels + + +def run_ica(raw, dataset, subject, search='manual'): + """ + Runs Independent Component Analysis. Depending on the 'search' mode, it is either used to find bad ICs or to exclude + bad ICs + :param raw: the data to be preprocessed + :param dataset: the dataset currently viewed + :param subject: the subject currently viewed + :param search: default value 'manual': The user views different plots for all ICs found + 'eog' : Uses the eog channels to find bad ICs + 'done' : Applies the bad ICs that were found + """ + # First filter the data to remove slow drifts - this is done with 1Hz, as proposed by the mne Tutorial at: + # https://mne.tools/dev/auto_tutorials/preprocessing/plot_40_artifact_correction_ica.html#filtering-to-remove-slow-drifts + ica_raw = raw.copy() + ica_raw.filter(l_freq=1, h_freq=None) + + # Then run ICA + ica = mne.preprocessing.ICA(method="fastica", random_state=123) # Use a random state for reproducable results + ica.fit(ica_raw, verbose=True) + + if search == 'manual': + ica_raw.load_data() + # ica.plot_components(inst=ica_raw, ch_type='eeg', contours=0, topomap_args={'extrapolate': 'head'}, + # psd_args={'fmin': 0, 'fmax': 80}) + ica.plot_properties(inst=ica_raw, dB=False, topomap_args={'extrapolate': 'head', 'contours': 0}, + psd_args={'fmin': 0, 'fmax': 50}, picks=['eeg']) + ica.plot_sources(ica_raw) + elif search == 'eog': + eog_indices, _ = ica.find_bads_eog(raw) + ica.exclude = eog_indices + print('BAD COMPONENTS VIA EOG: ' + str(eog_indices)) + ica.plot_overlay(ica_raw, exclude=eog_indices) + elif search == 'done': + exclude = None + if subj == '001': + exclude = [0, 1, 2, 4, 8, 14, 16, 25] # Through eog: 0,1 + elif subj == '003': + exclude = [0, 2] # Through eog: 0, 2 + elif subj == '014': + exclude = [0, 1, 9] # Through eog: 0,1 + # ica.plot_overlay(ica_raw, exclude=exclude) # Plot differences through exclude + # ica.exclude = exclude + # Apply ica to the raw object + raw.load_data() + # ica.plot_overlay(ica_raw, exclude=exclude) + raw = ica.apply(raw, exclude=exclude) + # Lastly save the ica to a file + folder = "Dataset\\" + dataset + "\\sub-" + subject + "\\ses-" + dataset + "\\eeg\\" + filepath = folder + "sub-" + subject + "_task-" + dataset + ica.save(filepath + "-ica.fif") + return raw + + +if __name__ == '__main__': + ds = 'N170' + for i in range(1, 41): + subj = "0" + str(i) + if len(str(i)) == 1: + subj = "0" + subj + data = load_subject(subj, ds) + # Load data into memory + data.load_data() + # Filter data with a bandpass filter + filter_data(data) + if subj in ["001", "003", "014"]: + # Manual preprocessing + # Clean the data + b_ch = clean_data(data, subj, ds, True) + # Run ICA + data.set_channel_types({'HEOG_left': 'eog', 'HEOG_right': 'eog', 'VEOG_lower': 'eog'}) + data.set_montage('standard_1020', match_case=False) + data = run_ica(data, ds, subj, 'done') + else: + # Provided cleaning and preprocessing information + ann, b_ch, ica_pre, bad_component = load_given_preprocessing_data(subj, ds) + data.annotations.append(ann.onset, ann.duration, ann.description) + data = ica_pre.apply(data, exclude=bad_component) + # Interpolate bad channels + data.interpolate_bads(b_ch) + # Re-Reference the data + data_re = data.copy().set_eeg_reference('average') + # Save preprocessed and cleaned data set + save_subject(data_re, subj, ds) diff --git a/test/test.py b/test/test.py new file mode 100644 index 0000000..3a9e35e --- /dev/null +++ b/test/test.py @@ -0,0 +1,86 @@ +import unittest + +import mne +import pandas as pd + +from decoding_tf_analysis import events_to_labels, permutation_test +from erp_analysis import create_peak_difference_feature +from utils.file_utils import get_keys_for_events, get_epochs +from pandas.testing import assert_frame_equal + + +class TestFileUtils(unittest.TestCase): + + def setUp(self): + # Load true values for keys + with open('test_files/face.txt') as f: + face = f.readlines() + self.face = [x.strip() for x in face] + with open('test_files/face_intact.txt') as f: + face_intact = f.readlines() + self.face_intact = [x.strip() for x in face_intact] + with open('test_files/face_scrambled.txt') as f: + face_scrambled = f.readlines() + self.face_scrambled = [x.strip() for x in face_scrambled] + + # Load true epochs + self.raw = mne.io.read_raw_fif("..\\Dataset\\n170\\sub-001\\ses-n170\\eeg\\sub-001_task-n170_cleaned.fif") + wanted_keys = get_keys_for_events("face", "intact") + events, events_dict = mne.events_from_annotations(self.raw) + events_dict_key = dict((k, events_dict[k]) for k in wanted_keys if k in events_dict) + self.given = mne.Epochs(self.raw, events, events_dict_key, tmin=-0.2, tmax=0.5, reject_by_annotation=False, + picks='P7') + self.given.drop_bad() + self.given.load_data() + + def test_keys_for_events(self): + # Test only for face key generation, as the rest is generated analogously + self.assertEqual("stimulus", get_keys_for_events(stimulus=None, condition=None)) + self.assertEqual(self.face, get_keys_for_events(stimulus='face', condition=None)) + self.assertEqual(self.face_intact, get_keys_for_events(stimulus='face', condition='intact')) + self.assertEqual(self.face_scrambled, get_keys_for_events(stimulus='face', condition='scrambled')) + + def test_get_epochs(self): + # Get a epoch to compare against + epochs, key = get_epochs(self.raw, [("face", "intact")], picks='P7', tmin=-0.2, tmax=0.5) + epochs.load_data() + self.assertEqual(self.given, epochs) + + +class TestERPAnalysis(unittest.TestCase): + + def test_difference_features(self): + # Check if the correct features are created + subject_ids = [1, 1, 1, 1] + stimuli = ['face', 'face', 'car', 'car'] + conditions = ['intact', 'scrambled', 'intact', 'scrambled'] + peaks = [1, 2, 3, 4] + df = pd.DataFrame(data={'subject_id': subject_ids, 'stimulus': stimuli, 'condition': conditions, + 'peak': peaks}) + diff_df_true = pd.DataFrame( + data={'subject_id': ['001'], 'mean_face': [1.5], 'mean_car': [3.5], 'peak_diff_overall': [-2.0], + 'diff_intact': [2.0], 'diff_scrambled': [2.0], 'diff_face': [1.0], 'diff_fc_ci': [-1.0], + 'diff_fi_rest': [-2.0]}) + diff_df = create_peak_difference_feature(df, 1) + assert_frame_equal(diff_df_true, diff_df) + + +class TestDecodingTFAnalysis(unittest.TestCase): + def test_events_to_labels(self): + # Only check for stimuli 1-40, 41-80, 101-140, 141-180 as no other stimuli are possible + events_dict = {'stimulus:1': 1, 'stimulus:40': 2, 'stimulus:41': 3, 'stimulus:80': 4, 'stimulus:101': 5, + 'stimulus:140': 6, 'stimulus:141': 7, 'stimulus:180': 8} + labels = events_to_labels([1, 2, 3, 4, 5, 6, 7, 8], events_dict, [0, 1, 2, 3]) + self.assertCountEqual([0, 0, 1, 1, 2, 2, 3, 3], labels) + + def test_permutation_test(self): + # Check permutation test + p = permutation_test([0, 0, 0, 0], [0, 0, 0, 0], 100) + self.assertEqual(1, p) + p = permutation_test([0, 0, 0, 0, 0], [5, 10, 15, 10, 5], 100) + self.assertGreater(0.05, p) + + +if __name__ == '__main__': + mne.set_log_level(verbose='WARNING') # Avoid full console + unittest.main() diff --git a/test/test_files/face.csv b/test/test_files/face.csv new file mode 100644 index 0000000..6c77daf --- /dev/null +++ b/test/test_files/face.csv @@ -0,0 +1,160 @@ +stimulus:1 +stimulus:2 +stimulus:3 +stimulus:4 +stimulus:5 +stimulus:6 +stimulus:7 +stimulus:8 +stimulus:9 +stimulus:10 +stimulus:11 +stimulus:12 +stimulus:13 +stimulus:14 +stimulus:15 +stimulus:16 +stimulus:17 +stimulus:18 +stimulus:19 +stimulus:20 +stimulus:21 +stimulus:22 +stimulus:23 +stimulus:24 +stimulus:25 +stimulus:26 +stimulus:27 +stimulus:28 +stimulus:29 +stimulus:30 +stimulus:31 +stimulus:32 +stimulus:33 +stimulus:34 +stimulus:35 +stimulus:36 +stimulus:37 +stimulus:38 +stimulus:39 +stimulus:40 +stimulus:41 +stimulus:42 +stimulus:43 +stimulus:44 +stimulus:45 +stimulus:46 +stimulus:47 +stimulus:48 +stimulus:49 +stimulus:50 +stimulus:51 +stimulus:52 +stimulus:53 +stimulus:54 +stimulus:55 +stimulus:56 +stimulus:57 +stimulus:58 +stimulus:59 +stimulus:60 +stimulus:61 +stimulus:62 +stimulus:63 +stimulus:64 +stimulus:65 +stimulus:66 +stimulus:67 +stimulus:68 +stimulus:69 +stimulus:70 +stimulus:71 +stimulus:72 +stimulus:73 +stimulus:74 +stimulus:75 +stimulus:76 +stimulus:77 +stimulus:78 +stimulus:79 +stimulus:80 +stimulus:101 +stimulus:102 +stimulus:103 +stimulus:104 +stimulus:105 +stimulus:106 +stimulus:107 +stimulus:108 +stimulus:109 +stimulus:110 +stimulus:111 +stimulus:112 +stimulus:113 +stimulus:114 +stimulus:115 +stimulus:116 +stimulus:117 +stimulus:118 +stimulus:119 +stimulus:120 +stimulus:121 +stimulus:122 +stimulus:123 +stimulus:124 +stimulus:125 +stimulus:126 +stimulus:127 +stimulus:128 +stimulus:129 +stimulus:130 +stimulus:131 +stimulus:132 +stimulus:133 +stimulus:134 +stimulus:135 +stimulus:136 +stimulus:137 +stimulus:138 +stimulus:139 +stimulus:140 +stimulus:141 +stimulus:142 +stimulus:143 +stimulus:144 +stimulus:145 +stimulus:146 +stimulus:147 +stimulus:148 +stimulus:149 +stimulus:150 +stimulus:151 +stimulus:152 +stimulus:153 +stimulus:154 +stimulus:155 +stimulus:156 +stimulus:157 +stimulus:158 +stimulus:159 +stimulus:160 +stimulus:161 +stimulus:162 +stimulus:163 +stimulus:164 +stimulus:165 +stimulus:166 +stimulus:167 +stimulus:168 +stimulus:169 +stimulus:170 +stimulus:171 +stimulus:172 +stimulus:173 +stimulus:174 +stimulus:175 +stimulus:176 +stimulus:177 +stimulus:178 +stimulus:179 +stimulus:180 diff --git a/test/test_files/face.txt b/test/test_files/face.txt new file mode 100644 index 0000000..65cfe23 --- /dev/null +++ b/test/test_files/face.txt @@ -0,0 +1,80 @@ +stimulus:1 +stimulus:2 +stimulus:3 +stimulus:4 +stimulus:5 +stimulus:6 +stimulus:7 +stimulus:8 +stimulus:9 +stimulus:10 +stimulus:11 +stimulus:12 +stimulus:13 +stimulus:14 +stimulus:15 +stimulus:16 +stimulus:17 +stimulus:18 +stimulus:19 +stimulus:20 +stimulus:21 +stimulus:22 +stimulus:23 +stimulus:24 +stimulus:25 +stimulus:26 +stimulus:27 +stimulus:28 +stimulus:29 +stimulus:30 +stimulus:31 +stimulus:32 +stimulus:33 +stimulus:34 +stimulus:35 +stimulus:36 +stimulus:37 +stimulus:38 +stimulus:39 +stimulus:40 +stimulus:101 +stimulus:102 +stimulus:103 +stimulus:104 +stimulus:105 +stimulus:106 +stimulus:107 +stimulus:108 +stimulus:109 +stimulus:110 +stimulus:111 +stimulus:112 +stimulus:113 +stimulus:114 +stimulus:115 +stimulus:116 +stimulus:117 +stimulus:118 +stimulus:119 +stimulus:120 +stimulus:121 +stimulus:122 +stimulus:123 +stimulus:124 +stimulus:125 +stimulus:126 +stimulus:127 +stimulus:128 +stimulus:129 +stimulus:130 +stimulus:131 +stimulus:132 +stimulus:133 +stimulus:134 +stimulus:135 +stimulus:136 +stimulus:137 +stimulus:138 +stimulus:139 +stimulus:140 diff --git a/test/test_files/face_intact.txt b/test/test_files/face_intact.txt new file mode 100644 index 0000000..6486e2d --- /dev/null +++ b/test/test_files/face_intact.txt @@ -0,0 +1,40 @@ +stimulus:1 +stimulus:2 +stimulus:3 +stimulus:4 +stimulus:5 +stimulus:6 +stimulus:7 +stimulus:8 +stimulus:9 +stimulus:10 +stimulus:11 +stimulus:12 +stimulus:13 +stimulus:14 +stimulus:15 +stimulus:16 +stimulus:17 +stimulus:18 +stimulus:19 +stimulus:20 +stimulus:21 +stimulus:22 +stimulus:23 +stimulus:24 +stimulus:25 +stimulus:26 +stimulus:27 +stimulus:28 +stimulus:29 +stimulus:30 +stimulus:31 +stimulus:32 +stimulus:33 +stimulus:34 +stimulus:35 +stimulus:36 +stimulus:37 +stimulus:38 +stimulus:39 +stimulus:40 \ No newline at end of file diff --git a/test/test_files/face_scrambled.txt b/test/test_files/face_scrambled.txt new file mode 100644 index 0000000..87b1d18 --- /dev/null +++ b/test/test_files/face_scrambled.txt @@ -0,0 +1,40 @@ +stimulus:101 +stimulus:102 +stimulus:103 +stimulus:104 +stimulus:105 +stimulus:106 +stimulus:107 +stimulus:108 +stimulus:109 +stimulus:110 +stimulus:111 +stimulus:112 +stimulus:113 +stimulus:114 +stimulus:115 +stimulus:116 +stimulus:117 +stimulus:118 +stimulus:119 +stimulus:120 +stimulus:121 +stimulus:122 +stimulus:123 +stimulus:124 +stimulus:125 +stimulus:126 +stimulus:127 +stimulus:128 +stimulus:129 +stimulus:130 +stimulus:131 +stimulus:132 +stimulus:133 +stimulus:134 +stimulus:135 +stimulus:136 +stimulus:137 +stimulus:138 +stimulus:139 +stimulus:140 \ No newline at end of file diff --git a/test/visual_sanity_checks.py b/test/visual_sanity_checks.py new file mode 100644 index 0000000..be13f27 --- /dev/null +++ b/test/visual_sanity_checks.py @@ -0,0 +1,19 @@ +from utils.file_utils import load_preprocessed_data, get_epochs + + +def check_peaks(): + """ + Sanity check for the "get_peaks" method + """ + import matplotlib.pyplot as plt + raw = load_preprocessed_data('002', 'N170') + epochs, _ = get_epochs(raw, [('face', 'intact')], picks='P7') + ch, latency, peak = epochs.average().get_peak(tmin=0.13, tmax=0.2, mode='neg', return_amplitude=True) + import numpy as np + plt.plot(epochs.times, np.squeeze(epochs.average().data.T)) + plt.vlines([0.13, 0.2], -0.00001, 0.00001, colors='r', linestyles='dotted') + plt.vlines(latency, -0.00001, 0.00001, colors='gray', linestyles='dotted') + plt.show() + + +check_peaks() diff --git a/utils/__pycache__/ccs_eeg_semesterproject.cpython-37.pyc b/utils/__pycache__/ccs_eeg_semesterproject.cpython-37.pyc new file mode 100644 index 0000000..479b31c Binary files /dev/null and b/utils/__pycache__/ccs_eeg_semesterproject.cpython-37.pyc differ diff --git a/utils/__pycache__/ccs_eeg_semesterproject.cpython-38.pyc b/utils/__pycache__/ccs_eeg_semesterproject.cpython-38.pyc new file mode 100644 index 0000000..4f28132 Binary files /dev/null and b/utils/__pycache__/ccs_eeg_semesterproject.cpython-38.pyc differ diff --git a/utils/__pycache__/ccs_eeg_utils.cpython-37.pyc b/utils/__pycache__/ccs_eeg_utils.cpython-37.pyc new file mode 100644 index 0000000..59b6bfd Binary files /dev/null and b/utils/__pycache__/ccs_eeg_utils.cpython-37.pyc differ diff --git a/utils/__pycache__/ccs_eeg_utils.cpython-38.pyc b/utils/__pycache__/ccs_eeg_utils.cpython-38.pyc new file mode 100644 index 0000000..4653f0c Binary files /dev/null and b/utils/__pycache__/ccs_eeg_utils.cpython-38.pyc differ diff --git a/utils/__pycache__/file_utils.cpython-37.pyc b/utils/__pycache__/file_utils.cpython-37.pyc new file mode 100644 index 0000000..eb8a040 Binary files /dev/null and b/utils/__pycache__/file_utils.cpython-37.pyc differ diff --git a/utils/__pycache__/plot_utils.cpython-37.pyc b/utils/__pycache__/plot_utils.cpython-37.pyc new file mode 100644 index 0000000..445f3f1 Binary files /dev/null and b/utils/__pycache__/plot_utils.cpython-37.pyc differ diff --git a/utils/ccs_eeg_semesterproject.py b/utils/ccs_eeg_semesterproject.py new file mode 100644 index 0000000..025e9b4 --- /dev/null +++ b/utils/ccs_eeg_semesterproject.py @@ -0,0 +1,70 @@ +import os +import mne +import numpy as np +import pandas as pd +from mne_bids import (BIDSPath, read_raw_bids) + + +def _get_filepath(bids_root, subject_id, task): + bids_path = BIDSPath(subject=subject_id, task=task, session=task, + datatype='eeg', suffix='eeg', + root=bids_root) + # this is not a bids-conform file format, but a derivate/extension. Therefore we have to hack a bit + # Depending on path structure, this might push a warning. + fn = os.path.splitext(bids_path.fpath.__str__())[0] + assert (fn[-3:] == "eeg") + fn = fn[0:-3] + return fn + + +def load_precomputed_ica(bids_root, subject_id, task): + # returns ICA and badComponents (starting at component = 0). + # Note the existance of add_ica_info in case you want to plot something. + fn = _get_filepath(bids_root, subject_id, task) + 'ica' + + # import the eeglab ICA. I used eeglab because the "amica" ICA is a bit more powerful than runica + ica = mne.preprocessing.read_ica_eeglab(fn + '.set') + # ica = custom_read_eeglab_ica(fn+'.set') + # Potentially for plotting one might want to copy over the raw.info, but in this function we dont have access / dont want to load it + # ica.info = raw.info + ica._update_ica_names() + badComps = np.loadtxt(fn + '.tsv', delimiter="\t") + badComps -= 1 # start counting at 0 + + # if only a single component is in the file, we get an error here because it is an ndarray with n-dim = 0. + if len(badComps.shape) == 0: + badComps = [float(badComps)] + return ica, badComps + + +def add_ica_info(raw, ica): + # This function exists due to a MNE bug: https://github.com/mne-tools/mne-python/issues/8581 + # In case you want to plot your ICA components, this function will generate a ica.info + ch_raw = raw.info['ch_names'] + ch_ica = ica.ch_names + + ix = [k for k, c in enumerate(ch_raw) if c in ch_ica and not c in raw.info['bads']] + info = raw.info.copy() + mne.io.pick.pick_info(info, ix, copy=False) + ica.info = info + + return ica + + +def load_precomputed_badData(bids_root, subject_id, task): + # return precomputed annotations and bad channels (first channel = 0) + + fn = _get_filepath(bids_root, subject_id, task) + print(fn) + + tmp = pd.read_csv(fn + 'badSegments.csv') + # print(tmp) + annotations = mne.Annotations(tmp.onset, tmp.duration, tmp.description) + # Unfortunately MNE assumes that csv files are in milliseconds and only txt files in seconds.. wth? + # annotations = mne.read_annotations(fn+'badSegments.csv') + badChannels = np.loadtxt(fn + 'badChannels.tsv', delimiter='\t') + badChannels = badChannels.astype(int) + badChannels -= 1 # start counting at 0 + + # badChannels = [int(b) for b in badChannels] + return annotations, badChannels \ No newline at end of file diff --git a/utils/ccs_eeg_utils.py b/utils/ccs_eeg_utils.py new file mode 100644 index 0000000..721a8a8 --- /dev/null +++ b/utils/ccs_eeg_utils.py @@ -0,0 +1,275 @@ +from osfclient import cli +import os +from mne_bids.read import _from_tsv, _drop +from mne_bids import (BIDSPath, read_raw_bids) +import mne +import numpy as np + +import scipy.ndimage +import scipy.signal +from numpy import sin as sin + + +def read_annotations_core(bids_path, raw): + tsv = os.path.join(bids_path.directory, bids_path.update(suffix="events", extension=".tsv").basename) + _handle_events_reading_core(tsv, raw) + + +def _handle_events_reading_core(events_fname, raw): + """Read associated events.tsv and populate raw. + Handle onset, duration, and description of each event. + """ + events_dict = _from_tsv(events_fname) + + if ('value' in events_dict) and ('trial_type' in events_dict): + events_dict = _drop(events_dict, 'n/a', 'trial_type') + events_dict = _drop(events_dict, 'n/a', 'value') + + descriptions = np.asarray([a + ':' + b for a, b in zip(events_dict["trial_type"], events_dict["value"])], + dtype=str) + + # Get the descriptions of the events + elif 'trial_type' in events_dict: + + # Drop events unrelated to a trial type + events_dict = _drop(events_dict, 'n/a', 'trial_type') + descriptions = np.asarray(events_dict['trial_type'], dtype=str) + + # If we don't have a proper description of the events, perhaps we have + # at least an event value? + elif 'value' in events_dict: + # Drop events unrelated to value + events_dict = _drop(events_dict, 'n/a', 'value') + descriptions = np.asarray(events_dict['value'], dtype=str) + # Worst case, we go with 'n/a' for all events + else: + descriptions = 'n/a' + # Deal with "n/a" strings before converting to float + ons = [np.nan if on == 'n/a' else on for on in events_dict['onset']] + dus = [0 if du == 'n/a' else du for du in events_dict['duration']] + onsets = np.asarray(ons, dtype=float) + durations = np.asarray(dus, dtype=float) + # Keep only events where onset is known + good_events_idx = ~np.isnan(onsets) + onsets = onsets[good_events_idx] + durations = durations[good_events_idx] + descriptions = descriptions[good_events_idx] + del good_events_idx + # Add Events to raw as annotations + annot_from_events = mne.Annotations(onset=onsets, + duration=durations, + description=descriptions, + orig_time=None) + raw.set_annotations(annot_from_events) + return raw + + +# taken from the osfclient tutorial https://github.com/ubcbraincircuits/osfclienttutorial +class args: + def __init__(self, project, username=None, update=True, force=False, destination=None, source=None, recursive=False, + target=None, output=None, remote=None, local=None): + self.project = project + self.username = username + self.update = update # applies to upload, clone, and fetch + self.force = force # applies to fetch and upload + # upload arguments: + self.destination = destination + self.source = source + self.recursive = recursive + # remove argument: + self.target = target + # clone argument: + self.output = output + # fetch arguments: + self.remote = remote + self.local = local + + +def download_erpcore(task="MMN", subject=1, localpath="local/bids/"): + project = "9f5w7" # after recent change they put everything as "sessions" in one big BIDS file + + arguments = args(project) # project ID + for extension in ["channels.tsv", "events.tsv", "eeg.fdt", "eeg.json", "eeg.set"]: + targetpath = '/sub-{:03d}/ses-{}/eeg/sub-{:03d}_ses-{}_task-{}_{}'.format(subject, task, subject, task, task, + extension) + print("Downloading {}".format(targetpath)) + arguments.remote = "\\ERP_CORE_BIDS_Raw_Files/" + targetpath + arguments.local = localpath + targetpath + cli.fetch(arguments) + + +def simulate_ICA(dims=4): + A = [[-0.3, 0.2], [.2, 0.1]] + sample_rate = 100.0 + nsamples = 1000 + t = np.arange(nsamples) / sample_rate + + s = [] + + # boxcars + s.append(np.mod(np.array(range(0, nsamples)), 250) > 125) + # a triangle staircase + trend + s.append((np.mod(np.array(range(0, nsamples)), 100) + np.array(range(0, nsamples)) * 0.05) / 100) + if dims == 4: + A = np.array([[.7, 0.3, 0.2, -0.5], [0.2, -0.5, -0.2, 0.3], [-.3, 0.1, 0, 0.2], [-0.5, -0.3, -0.2, 0.8]]) + + # some sinosoids + s.append(np.cos(2 * np.pi * 0.5 * t) + 0.2 * np.sin(2 * np.pi * 2.5 * t + 0.1) + 0.2 * np.sin( + 2 * np.pi * 15.3 * t) + 0.1 * np.sin(2 * np.pi * 16.7 * t + 0.1) + 0.1 * np.sin(2 * np.pi * 23.45 * t + .8)) + # uniform noise + s.append(0.2 * np.random.rand(nsamples)) + x = np.matmul(A, np.array(s)) + return x + + +def spline_matrix(x, knots): + # bah, spline-matrices are a pain to implement. + # But package "patsy" with function "bs" crashed my notebook... + # Anyway, knots define where the spline should be anchored. The default should work + # X defines where the spline set should be evaluated. + # e.g. call using: spline_matrix(np.linspace(0,0.95,num=100)) + import scipy.interpolate as si + + x_tup = si.splrep(knots, knots, k=3) + nknots = len(x_tup[0]) + x_i = np.empty((len(x), nknots - 4)) + for i in range(nknots - 4): + vec = np.zeros(nknots) + vec[i] = 1.0 + x_list = list(x_tup) + x_list[1] = vec.tolist() + x_i[:, i] = si.splev(x, x_list) + return x_i + + +def simulate_TF(signal=1, noise=True): + # signal can be 1 (image), 2(chirp) or 3 (steps) + import imageio + + if signal == 2: + im = imageio.imread('ex9_tf.png') + + im = im[0:60, :, 3] - im[0:60, :, 1] + # im = scipy.ndimage.zoom(im,[1,1]) + im = np.flip(im, axis=0) + + # plt.imshow(im,origin='lower') + + # sig = (scipy.fft.irfft(im.T,axis=1)) + + nov = 10; + im.shape[0] * 0.5 + nperseg = 50; + im.shape[0] - 1 + t, sig = scipy.signal.istft(im, fs=500, noverlap=nov, nperseg=nperseg) + sig = sig / 300 # normalize + elif signal == 3: + sig = scipy.signal.chirp(t=np.arange(0, 10, 1 / 500), f0=1, f1=100, t1=2, method='linear', phi=90) + elif signal == 1: + + x = np.arange(0, 2, 1 / 500) + sig_steps = np.concatenate([1.0 * sin(2 * np.pi * x * 50), 1.2 * sin(2 * np.pi * x * 55 + np.pi / 2), + 0.8 * sin(2 * np.pi * x * 125 + np.pi), + 1.0 * sin(2 * np.pi * x * 120 + 3 * np.pi / 2)]) + + sig = sig_steps + if noise: + sig = sig + 0.1 * np.std(sig) * np.random.randn(sig.shape[0]) + + return sig + + +def get_TF_dataset(subject_id='002', bids_root="../local/bids"): + bids_path = BIDSPath(subject=subject_id, task="P3", session="P3", + datatype='eeg', suffix='eeg', + root=bids_root) + + raw = read_raw_bids(bids_path) + read_annotations_core(bids_path, raw) + # raw.pick_channels(["Cz"])#["Pz","Fz","Cz"]) + raw.load_data() + raw.set_montage('standard_1020', match_case=False) + + evts, evts_dict = mne.events_from_annotations(raw) + wanted_keys = [e for e in evts_dict.keys() if "response" in e] + evts_dict_stim = dict((k, evts_dict[k]) for k in wanted_keys if k in evts_dict) + epochs = mne.Epochs(raw, evts, evts_dict_stim, tmin=-1, tmax=2) + return epochs + + +def get_classification_dataset(subject=1, typeInt=4): + # TypeInt: + # Task 1 (open and close left or right fist) + # Task 2 (imagine opening and closing left or right fist) + # Task 3 (open and close both fists or both feet) + # Task 4 (imagine opening and closing both fists or both feet) + assert (typeInt >= 1) + assert (typeInt <= 4) + from mne.io import concatenate_raws, read_raw_edf + from mne.datasets import eegbci + tmin, tmax = -1., 4. + runs = [3, 7, 11] + runs = [r + typeInt - 1 for r in runs] + print("loading subject {} with runs {}".format(subject, runs)) + if typeInt <= 1: + event_id = dict(left=2, right=3) + else: + event_id = dict(hands=2, feet=3) + + raw_fnames = eegbci.load_data(subject, runs) + raws = [read_raw_edf(f, preload=True) for f in raw_fnames] + raw = concatenate_raws(raws) + + raw.filter(7., 30., fir_design='firwin', skip_by_annotation='edge') + + eegbci.standardize(raw) # set channel names + montage = mne.channels.make_standard_montage('standard_1005') + raw.set_montage(montage) + raw.rename_channels(lambda x: x.strip('.')) + events, _ = mne.events_from_annotations(raw, event_id=dict(T1=2, T2=3)) + + picks = mne.pick_types(raw.info, meg=False, eeg=True, stim=False, eog=False, + exclude='bads') + + # Read epochs (train will be done only between 1 and 2s) + # Testing will be done with a running classifier + epochs = mne.Epochs(raw, events, event_id, tmin, tmax, proj=True, picks=picks, + baseline=None, preload=True) + return (epochs) + + +def ex8_simulateData(width=40, n_subjects=15, signal_mean=100, noise_between=30, noise_within=10, smooth_sd=4, + rng_seed=43): + # adapted and extended from https://mne.tools/dev/auto_tutorials/discussions/plot_background_statistics.html#sphx-glr-auto-tutorials-discussions-plot-background-statistics-py + rng = np.random.RandomState(rng_seed) + # For each "subject", make a smoothed noisy signal with a centered peak + + X = noise_within * rng.randn(n_subjects, width, width) + # Add three signals + X[:, width // 6 * 2, width // 6 * 2] -= signal_mean / 3 * 3 + rng.randn(n_subjects) * noise_between + X[:, width // 6 * 4, width // 6 * 4] += signal_mean / 3 * 2 + rng.randn(n_subjects) * noise_between + X[:, width // 6 * 5, width // 6 * 5] += signal_mean / 3 * 2 + rng.randn(n_subjects) * noise_between + # Spatially smooth with a 2D Gaussian kernel + size = width // 2 - 1 + gaussian = np.exp(-(np.arange(-size, size + 1) ** 2 / float(smooth_sd ** 2))) + for si in range(X.shape[0]): + for ri in range(X.shape[1]): + X[si, ri, :] = np.convolve(X[si, ri, :], gaussian, 'same') + for ci in range(X.shape[2]): + X[si, :, ci] = np.convolve(X[si, :, ci], gaussian, 'same') + # X += 10 * rng.randn(n_subjects, width, width) + return X + + +def stc_plot2img(h, title="SourceEstimate", closeAfterwards=False, crop=True): + h.add_text(0.1, 0.9, title, 'title', font_size=16) + screenshot = h.screenshot() + if closeAfterwards: + h.close() + + if crop: + nonwhite_pix = (screenshot != 255).any(-1) + nonwhite_row = nonwhite_pix.any(1) + nonwhite_col = nonwhite_pix.any(0) + screenshot = screenshot[nonwhite_row][:, nonwhite_col] + return screenshot diff --git a/utils/file_utils.py b/utils/file_utils.py new file mode 100644 index 0000000..b8faeca --- /dev/null +++ b/utils/file_utils.py @@ -0,0 +1,89 @@ +import os +import mne + + +def load_bad_annotations(filepath, fileending="badSegments.csv"): + """ + Loads the annotations for bad segments + :param filepath: The path to the file we want to load + :param fileending: Depending if the subject, for which we load the annotations, was preprocessed manually, + the ending of the filename will be different. + The default are file endings of the given preprocessed annotations: "badSegments.csv" + For the manual preprocessed annotations, the file endings are: "badannotations.csv" + :return: The mne annotations + """ + if os.path.isfile(filepath + "_" + fileending): + return mne.read_annotations(filepath + "_" + fileending) + + +def load_preprocessed_data(subject, dataset): + """ + Load the raw object as well as the annotations of the preprocessed file + :param subject: The subject, for which we want to load the raw object + :param dataset: The currently viewed dataset + :param selected_subjects: The manually preprocessed subjects + :return: The raw object + """ + folder = "Dataset\\" + dataset + "\\sub-" + subject + "\\ses-" + dataset + "\\eeg\\" + filepath = folder + "sub-" + subject + "_task-" + dataset + raw = mne.io.read_raw_fif(filepath + "_cleaned.fif") + return raw + + +def get_keys_for_events(stimulus=None, condition=None): + """ + For a given stimulus and condition get all the event keys. + :param stimulus: Either 'face' or 'car' or 'None' for no stimulus + :param condition: Either 'intact' or 'scrambled' or 'None' for no condition + :return: A list of keys or 'stimulus' if neither stimulus or condition was given + """ + if stimulus == 'face': + if condition == 'intact': + return ["stimulus:{}".format(k) for k in range(1, 41)] + elif condition == 'scrambled': + return ["stimulus:{}".format(k) for k in range(101, 141)] + else: # All faces + return ["stimulus:{}".format(k) for k in list(range(1, 41)) + list(range(101, 141))] + elif stimulus == 'car': + if condition == 'intact': + return ["stimulus:{}".format(k) for k in range(41, 81)] + elif condition == 'scrambled': + return ["stimulus:{}".format(k) for k in range(141, 181)] + else: # All cars + return ["stimulus:{}".format(k) for k in list(range(41, 81)) + list(range(141, 181))] + else: # If no stimulus is given + if condition == 'intact': + return ["stimulus:{}".format(k) for k in range(1, 41) and range(41, 81)] + elif condition == 'scrambled': + return ["stimulus:{}".format(k) for k in list(range(101, 141)) + list(range(141, 181))] + else: # Every stimulus + return "stimulus" + + +def get_epochs(raw, conditions=None, picks=None, tmin=-0.1, tmax=1): + """ + Returns the epochs for a given dataset + :param raw: the dataset + :param conditions: A List of tuples, of the Form [(stimulus, condition), (stimulus,condition)] + i.e. [('face',None), ('car', 'scrambled')] returns the epochs where the stimulus is face and the stim+condition is car+scrambled + default is None, i.e. everything + :param picks: a list. Additional criteria for picking the epochs, e.g. channels, etc. + :param tmin: onset before the event + :param tmax: end after the event + :return: + """ + + events, events_dict = mne.events_from_annotations(raw) + events_dict_key = {} + if conditions is None: + conditions = [(None, None)] + + for condition in conditions: + wanted_keys = get_keys_for_events(condition[0], condition[1]) + if wanted_keys == "stimulus": + wanted_keys = [e for e in events_dict.keys() if "stimulus" in e] + events_dict_key.update(dict((k, events_dict[k]) for k in wanted_keys if k in events_dict)) + epochs = mne.Epochs(raw, events, events_dict_key, tmin=tmin, tmax=tmax, reject_by_annotation=False, picks=picks) + epochs.drop_bad() + + return epochs, events_dict_key diff --git a/utils/plot_utils.py b/utils/plot_utils.py new file mode 100644 index 0000000..349172e --- /dev/null +++ b/utils/plot_utils.py @@ -0,0 +1,89 @@ +import mne + +import matplotlib.pyplot as plt +import numpy as np + +from matplotlib import cm +from matplotlib.colors import LogNorm + +from utils.file_utils import load_preprocessed_data, get_keys_for_events + + +def plot_grand_average(dataset): + """ + Plot the grand average ERPs + :param dataset: the datset for which the grand average is computed + """ + evtss = [('face', 'intact'), ('face', 'scrambled'), ('car', 'intact'), ('car', 'scrambled')] + chs = ['P7', 'PO7', 'P8', 'PO8'] + for ch in chs: + fi = [] + fs = [] + ci = [] + cs = [] + for i in range(1, 41): + subj = "0" + str(i) + if len(str(i)) == 1: + subj = "0" + subj + # Load preprocessed .fif data files + raw = load_preprocessed_data(subj, dataset) + # Epoch the data + for ev in evtss: + wanted_keys = get_keys_for_events(ev[0], ev[1]) + events, events_dict = mne.events_from_annotations(raw) + events_dict_key = dict((k, events_dict[k]) for k in wanted_keys if k in events_dict) + epochs = mne.Epochs(raw, events, events_dict_key, tmin=-0.1, tmax=1, reject_by_annotation=True, + picks=[ch]) + # Get the N170 peak + # First construct a data frame + if ev[0] == 'face' and ev[1] == 'intact': + fi.append(epochs.average(picks=[ch])) + elif ev[0] == 'face' and ev[1] == 'scrambled': + fs.append(epochs.average(picks=[ch])) + elif ev[0] == 'car' and ev[1] == 'intact': + ci.append(epochs.average(picks=[ch])) + elif ev[0] == 'car' and ev[1] == 'scrambled': + cs.append(epochs.average(picks=[ch])) + ga_fi = mne.grand_average(fi) + ga_ci = mne.grand_average(ci) + ga_fs = mne.grand_average(fs) + ga_cs = mne.grand_average(cs) + ga_fi.comment = 'Face Intact' + ga_ci.comment = 'Car Intact' + ga_fs.comment = 'Face Scrambled' + ga_cs.comment = 'Car Scrambled' + mne.viz.plot_compare_evokeds([ga_fi, ga_ci, ga_fs, ga_cs], picks=ch, colors=['blue', 'red', 'blue', 'red'], + linestyles=['solid', 'solid', 'dotted', 'dotted']) + + +def plot_tf_cluster(F, clusters, cluster_p_values, freqs, times): + """ + Plot teh F-Statistic values of permutation clusters with p-values <= 0.05 in color and > 0.05 in grey. + + :param F: F-Statistics of the permutation clusters + :param clusters: all permutation clusters + :param cluster_p_values: p-values of the clusters + :param freqs: frequency domain + :param times: time domain + """ + good_c = np.nan * np.ones_like(F) + for clu, p_val in zip(clusters, cluster_p_values): + if p_val <= 0.05: + good_c[clu] = F[clu] + + bbox = [times[0], times[-1], freqs[0], freqs[-1]] + plt.imshow(F, aspect='auto', origin='lower', cmap=cm.gray, extent=bbox, interpolation='None') + a = plt.imshow(good_c, cmap=cm.RdBu_r, aspect='auto', origin='lower', extent=bbox, interpolation='None') + plt.colorbar(a) + plt.xlabel('Time (s)') + plt.ylabel('Frequency (Hz)') + plt.show() + + +def plot_oscillation_bands(condition): + fig, axis = plt.subplots(1, 5, figsize=(25, 5)) + condition.plot_topomap(baseline=(-0.2, 0), fmin=0, fmax=4, title='Delta', axes=axis[0], show=False, vmin=0, vmax=1.5, tmin=0, tmax=1) + condition.plot_topomap(baseline=(-0.2, 0), fmin=4, fmax=8, title='Theta', axes=axis[1], show=False, vmin=0, vmax=0.7, tmin=0, tmax=1) + condition.plot_topomap(baseline=(-0.2, 0), fmin=8, fmax=12, title='Alpha', axes=axis[2], show=False, vmin=-0.15, vmax=0.2, tmin=0, tmax=1) + condition.plot_topomap(baseline=(-0.2, 0), fmin=13, fmax=30, title='Beta', axes=axis[3], show=False, vmin=-0.18, vmax=0.2, tmin=0, tmax=1) + condition.plot_topomap(baseline=(-0.2, 0), fmin=30, fmax=45, title='Gamma', axes=axis[4], vmin=0, vmax=0.2, tmin=0, tmax=1)