diff --git a/test/test_utils.py b/test/test_utils.py index 82d3268b12994fed46ee599a8c1b93a6555b97a1..b8b82139e8fd78cf8527bcd872d002fc075a9c35 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -1,10 +1,12 @@ from unittest import TestCase import unittest.mock as mock -from src.utils import digits_in_string, make_directory, read_json +from src.utils import digits_in_string, make_directory, read_json, drop_file_extension, make_sorted_sample, \ + sample_from_video_with_gt import random import os.path as osp import os import json +import numpy as np class TestUtils(TestCase): @@ -48,3 +50,27 @@ class TestUtils(TestCase): read_json('null') self.assertEqual( 'null does not exist.', str(context.exception)) + + def test_drop_file_extension(self): + file = 'test.mat' + self.assertEqual('test', drop_file_extension(file)) + self.assertEqual('', drop_file_extension('')) + with self.assertRaises(ValueError): + drop_file_extension(None) + + def test_make_sorted_sample(self): + arr = np.arange(4) + n_samples = 2 + sample = make_sorted_sample(arr, n_samples) + self.assertTrue(len(sample), n_samples) + self.assertTrue(sample[1] > sample[0]) + + def test_sample_from_video_with_gt(self): + fps = 30 + duration = 20 + n_frames=int(fps*duration) + user_scores = np.random.randint(low=0, high=2, size=n_frames) + user_scores=np.expand_dims(user_scores, axis=1) + video_frames = list(np.random.randn(n_frames)) + sampled_frames, sampled_gt = sample_from_video_with_gt(video_frames, user_scores, duration, fps, n_samples=2) + self.assertEquals(len(sampled_frames), len(sampled_gt)) \ No newline at end of file