87 lines
3.9 KiB
Python
87 lines
3.9 KiB
Python
import unittest
|
|
|
|
import mne
|
|
import pandas as pd
|
|
|
|
from decoding_tf_analysis import events_to_labels, permutation_test
|
|
from erp_analysis import create_peak_difference_feature
|
|
from utils.file_utils import get_keys_for_events, get_epochs
|
|
from pandas.testing import assert_frame_equal
|
|
|
|
|
|
class TestFileUtils(unittest.TestCase):
    """Tests for ``get_keys_for_events`` and ``get_epochs`` from ``utils.file_utils``."""

    # Recording used to build the reference epochs. Forward slashes are used because
    # they are accepted on Windows as well as on POSIX systems, whereas a
    # backslash-separated literal only resolves on Windows.
    RAW_PATH = "../Dataset/n170/sub-001/ses-n170/eeg/sub-001_task-n170_cleaned.fif"

    @staticmethod
    def _read_lines(path):
        """Return the lines of the text file at ``path`` with surrounding whitespace stripped."""
        with open(path) as f:
            return [line.strip() for line in f]

    def setUp(self):
        """Load the expected event keys and build the reference epochs on channel P7."""
        # Load true values for keys
        self.face = self._read_lines('test_files/face.txt')
        self.face_intact = self._read_lines('test_files/face_intact.txt')
        self.face_scrambled = self._read_lines('test_files/face_scrambled.txt')

        # Load true epochs
        self.raw = mne.io.read_raw_fif(self.RAW_PATH)
        wanted_keys = get_keys_for_events("face", "intact")
        events, events_dict = mne.events_from_annotations(self.raw)
        # Keep only the wanted keys that are present in the annotations of this recording
        events_dict_key = {k: events_dict[k] for k in wanted_keys if k in events_dict}
        self.given = mne.Epochs(self.raw, events, events_dict_key, tmin=-0.2, tmax=0.5,
                                reject_by_annotation=False, picks='P7')
        self.given.drop_bad()
        self.given.load_data()

    def test_keys_for_events(self):
        # Test only for face key generation, as the rest is generated analogously
        self.assertEqual("stimulus", get_keys_for_events(stimulus=None, condition=None))
        self.assertEqual(self.face, get_keys_for_events(stimulus='face', condition=None))
        self.assertEqual(self.face_intact, get_keys_for_events(stimulus='face', condition='intact'))
        self.assertEqual(self.face_scrambled, get_keys_for_events(stimulus='face', condition='scrambled'))

    def test_get_epochs(self):
        # Get an epoch to compare against; the second return value is not needed here
        epochs, _ = get_epochs(self.raw, [("face", "intact")], picks='P7', tmin=-0.2, tmax=0.5)
        # NOTE(review): data is loaded on both objects before the comparison, presumably
        # because MNE can only compare preloaded epochs - confirm
        epochs.load_data()
        self.assertEqual(self.given, epochs)
|
|
|
|
|
|
class TestERPAnalysis(unittest.TestCase):
    """Tests for ``create_peak_difference_feature`` from ``erp_analysis``."""

    def test_difference_features(self):
        # Check if the correct features are created: one peak per stimulus/condition
        # combination is given for a single subject
        peak_table = pd.DataFrame(data={
            'subject_id': [1, 1, 1, 1],
            'stimulus': ['face', 'face', 'car', 'car'],
            'condition': ['intact', 'scrambled', 'intact', 'scrambled'],
            'peak': [1, 2, 3, 4],
        })

        # Hand-written frame the computed features are compared against
        expected = pd.DataFrame(data={
            'subject_id': ['001'],
            'mean_face': [1.5],
            'mean_car': [3.5],
            'peak_diff_overall': [-2.0],
            'diff_intact': [2.0],
            'diff_scrambled': [2.0],
            'diff_face': [1.0],
            'diff_fc_ci': [-1.0],
            'diff_fi_rest': [-2.0],
        })

        actual = create_peak_difference_feature(peak_table, 1)
        assert_frame_equal(expected, actual)
|
|
|
|
|
|
class TestDecodingTFAnalysis(unittest.TestCase):
    """Tests for ``events_to_labels`` and ``permutation_test`` from ``decoding_tf_analysis``."""

    def test_events_to_labels(self):
        # Only the first and last stimulus of the ranges 1-40, 41-80, 101-140 and 141-180
        # are checked, as no other stimuli are possible
        stimulus_codes = {
            'stimulus:1': 1, 'stimulus:40': 2,
            'stimulus:41': 3, 'stimulus:80': 4,
            'stimulus:101': 5, 'stimulus:140': 6,
            'stimulus:141': 7, 'stimulus:180': 8,
        }
        event_ids = [1, 2, 3, 4, 5, 6, 7, 8]
        labels = events_to_labels(event_ids, stimulus_codes, [0, 1, 2, 3])

        # assertCountEqual compares the elements regardless of their order
        wanted_labels = [0, 0, 1, 1, 2, 2, 3, 3]
        self.assertCountEqual(wanted_labels, labels)

    def test_permutation_test(self):
        # Check permutation test: two identical samples are expected to give exactly 1
        p_identical = permutation_test([0, 0, 0, 0], [0, 0, 0, 0], 100)
        self.assertEqual(1, p_identical)

        # Two clearly different samples are expected to give a value below 0.05
        p_different = permutation_test([0, 0, 0, 0, 0], [5, 10, 15, 10, 5], 100)
        self.assertGreater(0.05, p_different)
|
|
|
|
|
|
if __name__ == "__main__":
    # Lower MNE's log verbosity to WARNING so the console is not flooded
    mne.set_log_level("WARNING")
    unittest.main()
|