Initial commit

This commit is contained in:
2021-03-27 14:58:49 +01:00
commit fc4eec6ac7
20 changed files with 1613 additions and 0 deletions

86
test/test.py Normal file
View File

@@ -0,0 +1,86 @@
import unittest
import mne
import pandas as pd
from decoding_tf_analysis import events_to_labels, permutation_test
from erp_analysis import create_peak_difference_feature
from utils.file_utils import get_keys_for_events, get_epochs
from pandas.testing import assert_frame_equal
class TestFileUtils(unittest.TestCase):
def setUp(self):
# Load true values for keys
with open('test_files/face.txt') as f:
face = f.readlines()
self.face = [x.strip() for x in face]
with open('test_files/face_intact.txt') as f:
face_intact = f.readlines()
self.face_intact = [x.strip() for x in face_intact]
with open('test_files/face_scrambled.txt') as f:
face_scrambled = f.readlines()
self.face_scrambled = [x.strip() for x in face_scrambled]
# Load true epochs
self.raw = mne.io.read_raw_fif("..\\Dataset\\n170\\sub-001\\ses-n170\\eeg\\sub-001_task-n170_cleaned.fif")
wanted_keys = get_keys_for_events("face", "intact")
events, events_dict = mne.events_from_annotations(self.raw)
events_dict_key = dict((k, events_dict[k]) for k in wanted_keys if k in events_dict)
self.given = mne.Epochs(self.raw, events, events_dict_key, tmin=-0.2, tmax=0.5, reject_by_annotation=False,
picks='P7')
self.given.drop_bad()
self.given.load_data()
def test_keys_for_events(self):
# Test only for face key generation, as the rest is generated analogously
self.assertEqual("stimulus", get_keys_for_events(stimulus=None, condition=None))
self.assertEqual(self.face, get_keys_for_events(stimulus='face', condition=None))
self.assertEqual(self.face_intact, get_keys_for_events(stimulus='face', condition='intact'))
self.assertEqual(self.face_scrambled, get_keys_for_events(stimulus='face', condition='scrambled'))
def test_get_epochs(self):
# Get a epoch to compare against
epochs, key = get_epochs(self.raw, [("face", "intact")], picks='P7', tmin=-0.2, tmax=0.5)
epochs.load_data()
self.assertEqual(self.given, epochs)
class TestERPAnalysis(unittest.TestCase):
def test_difference_features(self):
# Check if the correct features are created
subject_ids = [1, 1, 1, 1]
stimuli = ['face', 'face', 'car', 'car']
conditions = ['intact', 'scrambled', 'intact', 'scrambled']
peaks = [1, 2, 3, 4]
df = pd.DataFrame(data={'subject_id': subject_ids, 'stimulus': stimuli, 'condition': conditions,
'peak': peaks})
diff_df_true = pd.DataFrame(
data={'subject_id': ['001'], 'mean_face': [1.5], 'mean_car': [3.5], 'peak_diff_overall': [-2.0],
'diff_intact': [2.0], 'diff_scrambled': [2.0], 'diff_face': [1.0], 'diff_fc_ci': [-1.0],
'diff_fi_rest': [-2.0]})
diff_df = create_peak_difference_feature(df, 1)
assert_frame_equal(diff_df_true, diff_df)
class TestDecodingTFAnalysis(unittest.TestCase):
def test_events_to_labels(self):
# Only check for stimuli 1-40, 41-80, 101-140, 141-180 as no other stimuli are possible
events_dict = {'stimulus:1': 1, 'stimulus:40': 2, 'stimulus:41': 3, 'stimulus:80': 4, 'stimulus:101': 5,
'stimulus:140': 6, 'stimulus:141': 7, 'stimulus:180': 8}
labels = events_to_labels([1, 2, 3, 4, 5, 6, 7, 8], events_dict, [0, 1, 2, 3])
self.assertCountEqual([0, 0, 1, 1, 2, 2, 3, 3], labels)
def test_permutation_test(self):
# Check permutation test
p = permutation_test([0, 0, 0, 0], [0, 0, 0, 0], 100)
self.assertEqual(1, p)
p = permutation_test([0, 0, 0, 0, 0], [5, 10, 15, 10, 5], 100)
self.assertGreater(0.05, p)
if __name__ == '__main__':
mne.set_log_level(verbose='WARNING') # Avoid full console
unittest.main()