# video2frames.py
import argparse
import os
from src.utils import *
from src.SumMeVideo import SumMeVideo
from src.TVSumVideo import TVSumVideo
from src.FeatureExtractor import FeatureExtractor
import pickle
from PIL import Image
from tqdm import tqdm
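
# Note: drop_file_extension and sample_from_video_with_gt are expected to come
# from src.utils (imported via the wildcard above). As an illustration only,
# drop_file_extension presumably behaves like:
#     os.path.splitext(video_filename)[0]
# and sample_from_video_with_gt subsamples the frames (and their aligned
# ground-truth scores) down to the requested --fps rate. Both descriptions are
# assumptions about src.utils, not definitions made here.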


def create_video_obj(dataset, video_name, video_path, gt_dir):
    # SumMe and TVSum store their ground truth differently, so each dataset
    # gets its own video wrapper class
    if str(dataset).lower() == 'summe':
        return SumMeVideo(video_name, video_path, gt_dir)
    return TVSumVideo(video_name, video_path, gt_dir)


def arg_parser():
    parser = argparse.ArgumentParser(description='Extract Features')
    parser.add_argument('--dataset', default='TVSum', type=str, help='SumMe or TVSum')
    parser.add_argument('--videos_dir', metavar='DIR', default='../data/TVSum/video', help='path to input videos')
    parser.add_argument('--gt', metavar='GT_Dir', default='../data/TVSum/data', help='path to ground truth')
    parser.add_argument('--fps', default=2, type=int, help='frames per second for the extraction')
    parser.add_argument('--model_arch', default='googlenet',
                        help='pretrained model architecture, e.g. resnet50 or alexnet')
    return parser
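
# Example invocation (the SumMe paths below are illustrative; adjust to your
# own data layout):
#     python video2frames.py --dataset SumMe --videos_dir ../data/SumMe/video \
#         --gt ../data/SumMe/GT --fps 2 --model_arch googlenet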


if __name__ == '__main__':
    parser = arg_parser()
    args = parser.parse_args()
    # dataset, videos, GT, Sample rate, model architecture
    dataset, videos_dir, gt_dir, fps, model_arch = args.dataset, args.videos_dir, args.gt, args.fps, args.model_arch
    # define feature extractor model
    model = FeatureExtractor(model_arch)
    print(model)
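    # FeatureExtractor is assumed to wrap a torchvision backbone (googlenet by
    # default), expose a torchvision-style `transform` for preprocessing, and
    # return penultimate-layer activations on the forward pass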
    # maps each video's index to its list of per-frame feature vectors
    features = dict()
    # iterate over videos, sample frames and gt, and extract features
    for idx, video in enumerate(os.listdir(videos_dir)):
        # sample video and ground truth
        video_name = drop_file_extension(video)
        video_path = os.path.join(videos_dir, video)
        # create video according to dataset
        video_obj = create_video_obj(dataset, video_name, video_path, gt_dir)
        # get video ground-truth data
        video_gt = video_obj.get_gt()
        # get video frames
        print('getting frames for {}...'.format(video_name))
        video_frames = video_obj.get_frames()
        # sample video frames
        print('sampling from video frames...')
        sampled_frames, sampled_gt = sample_from_video_with_gt(video_frames, video_gt, video_obj.duration,
                                                               video_obj.fps)
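        # sampled_frames and sampled_gt are assumed to stay aligned: one
        # ground-truth score per retained frame, subsampled from the native
        # frame rate (video_obj.fps) down to the requested --fps rate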
        # delete video object
        del video_obj
        print('frames retrieved')
        # iterate over sampled frames and extract their features with the pretrained model
        print('Extracting features for {} ...'.format(video_name))
        features[idx] = []
        for f in tqdm(range(len(sampled_frames))):
            # convert the numpy frame to PIL for the torchvision transform
            PIL_image = Image.fromarray(sampled_frames[f])
            frame = model.transform(PIL_image)
            # add a batch dimension: (C, H, W) -> (1, C, H, W)
            frame = frame.view((1,) + frame.shape)
            # forward pass; detach, move to CPU, and drop the batch dimension
            feat = model(frame)
            features[idx].append(feat.cpu().detach().numpy()[0])

        print('Saving features ...')
        # overwrite the pickle after each video with the full dict so far;
        # append mode ('ab') would stack multiple pickled dicts in one file
        with open('features.pickle', 'wb') as handle:
            pickle.dump(features, handle, protocol=pickle.HIGHEST_PROTOCOL)
        print('features saved')
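
# To read the saved features back (a sketch, using the path written above):
#     with open('features.pickle', 'rb') as handle:
#         features = pickle.load(handle)
#     # features[idx] is then a list of 1-D numpy feature vectors for video idx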