Files
semesterproject_lecture_eeg/utils/plot_utils.py

101 lines
4.4 KiB
Python

import mne
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.colors import LogNorm
from utils.file_utils import load_preprocessed_data, get_keys_for_events
def plot_grand_average(dataset, subjects=range(1, 41), channels=('P7', 'PO7', 'P8', 'PO8')):
    """
    Plot the grand-average ERPs of the four stimulus conditions, one figure per channel.

    For every subject the preprocessed recording is loaded once and its annotations are
    converted to events once. For each condition and channel, the data are epoched from
    -0.1 s to 1 s around the matching events (``reject_by_annotation=True``) and averaged.
    The per-subject averages are combined with ``mne.grand_average`` and drawn with
    ``mne.viz.plot_compare_evokeds``. Faces are blue and cars red; intact stimuli are solid
    and scrambled stimuli dotted.

    :param dataset: the dataset for which the grand average is computed
        (passed through to ``load_preprocessed_data``)
    :param subjects: subject numbers to include. Each number is formatted as a zero-padded
        three-digit id ("001", "002", ...). Defaults to subjects 1-40.
    :param channels: names of the channels to plot; one figure is produced per channel
    """
    conditions = [('face', 'intact'), ('face', 'scrambled'), ('car', 'intact'), ('car', 'scrambled')]
    # evokeds[ch][condition] collects one single-channel Evoked per subject
    evokeds = {ch: {cond: [] for cond in conditions} for ch in channels}

    for i in subjects:
        subj = f"{i:03d}"
        # Load the subject's preprocessed .fif data once, not once per channel
        raw = load_preprocessed_data(subj, dataset)
        # The annotations depend on neither channel nor condition, so convert them once
        events, events_dict = mne.events_from_annotations(raw)
        for cond in conditions:
            wanted_keys = get_keys_for_events(cond[0], cond[1])
            # Keep only the event codes of this condition that occur in the recording
            event_id = {k: events_dict[k] for k in wanted_keys if k in events_dict}
            for ch in channels:
                # Epoch per channel, exactly as before, so each average is unchanged
                epochs = mne.Epochs(raw, events, event_id, tmin=-0.1, tmax=1, reject_by_annotation=True,
                                    picks=[ch])
                evokeds[ch][cond].append(epochs.average(picks=[ch]))

    # Plot order and legend labels; must stay aligned with the colors/linestyles below
    plot_spec = [
        (('face', 'intact'), 'Face Intact'),
        (('car', 'intact'), 'Car Intact'),
        (('face', 'scrambled'), 'Face Scrambled'),
        (('car', 'scrambled'), 'Car Scrambled'),
    ]
    for ch in channels:
        grand_averages = []
        for cond, label in plot_spec:
            ga = mne.grand_average(evokeds[ch][cond])
            ga.comment = label
            grand_averages.append(ga)
        mne.viz.plot_compare_evokeds(grand_averages, picks=ch, colors=['blue', 'red', 'blue', 'red'],
                                     linestyles=['solid', 'solid', 'dotted', 'dotted'])
def plot_tf_cluster(F, clusters, cluster_p_values, freqs, times, alpha=0.05, show=True):
    """
    Plot the F-statistic map of a time-frequency permutation cluster test.

    Every value of ``F`` is drawn in grey. Points belonging to a cluster whose p-value is
    <= ``alpha`` are overlaid in color (``RdBu_r``), and the colorbar refers to that
    colored overlay. Currently this only works well for linear frequency scaling: the
    image is stretched linearly between ``freqs[0]`` and ``freqs[-1]``, so logarithmic
    scaling needs a different axis.

    :param F: F-statistics of the permutation clusters. Assumed to be 2D with rows along
        frequency and columns along time (it is drawn that way) -- confirm against callers.
    :param clusters: all permutation clusters; each entry must be usable as an index into
        ``F`` (e.g. a boolean mask or a tuple of index arrays)
    :param cluster_p_values: p-values of the clusters, in the same order as ``clusters``
    :param freqs: frequency domain; only the first and last values set the y-extent
    :param times: time domain; only the first and last values set the x-extent
    :param alpha: significance threshold; clusters with p-value <= alpha are colored
        (default 0.05)
    :param show: whether to call ``plt.show()`` at the end (default True)
    """
    # NaN everywhere except inside significant clusters. imshow masks NaNs, so
    # (with the colormap's default transparent "bad" color) the grey layer shows through.
    significant = np.full_like(F, np.nan, dtype=float)
    for cluster, p_val in zip(clusters, cluster_p_values):
        if p_val <= alpha:
            significant[cluster] = F[cluster]

    extent = [times[0], times[-1], freqs[0], freqs[-1]]
    plt.imshow(F, aspect='auto', origin='lower', cmap=cm.gray, extent=extent, interpolation='none')
    overlay = plt.imshow(significant, cmap=cm.RdBu_r, aspect='auto', origin='lower', extent=extent,
                         interpolation='none')
    plt.colorbar(overlay)
    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (Hz)')
    if show:
        plt.show()
def plot_oscillation_bands(condition):
    """
    Draw one topomap per oscillation band, restricted to 130-200 ms.

    The five bands (delta, theta, alpha, beta, gamma) are placed side by side in a single
    1x5 figure. Each panel is baseline-corrected against (-0.2 s, 0 s) and uses its own
    fixed color limits. Every call except the final (gamma) one passes ``show=False``,
    presumably so the figure is displayed only once, after all panels are drawn.

    :param condition: the condition to plot the oscillation bands for. Presumably an MNE
        time-frequency object exposing ``plot_topomap`` with these keyword arguments --
        confirm against callers.
    """
    # (title, fmin, fmax, vmin, vmax) for each panel, left to right
    bands = (
        ('Delta', 0, 4, 0, 1.5),
        ('Theta', 4, 8, 0, 0.7),
        ('Alpha', 8, 12, -0.25, 0.2),
        ('Beta', 13, 30, -0.21, 0.2),
        ('Gamma', 30, 45, -0.05, 0.2),
    )
    _, panels = plt.subplots(1, len(bands), figsize=(25, 5))
    final_idx = len(bands) - 1
    for idx, (band_name, lo, hi, v_lo, v_hi) in enumerate(bands):
        # Only the last panel leaves `show` at its default
        extra = {} if idx == final_idx else {'show': False}
        condition.plot_topomap(baseline=(-0.2, 0), fmin=lo, fmax=hi, title=band_name, axes=panels[idx],
                               vmin=v_lo, vmax=v_hi, tmin=0.13, tmax=0.2, **extra)