From 0f7e10ba2248c1dc21f931ad8c01219745d1668b Mon Sep 17 00:00:00 2001
From: Hussain Kanafani <hussainkanafani@gmail.com>
Date: Wed, 8 Jul 2020 00:50:20 +0200
Subject: [PATCH] feature extraction using pretrained models implemented

---
 src/FeatureExtractor.py   | 32 ++++++++++++++++
 src/SumMeVideo.py         |  1 +
 src/TVSumVideo.py         | 31 +++++++++++++++
 src/feature_extraction.py | 35 ++++++++++-------
 src/video2frames.py       | 79 ++++++++++++++++++++++++++++-----------
 5 files changed, 142 insertions(+), 36 deletions(-)
 create mode 100644 src/FeatureExtractor.py
 create mode 100644 src/TVSumVideo.py

diff --git a/src/FeatureExtractor.py b/src/FeatureExtractor.py
new file mode 100644
index 0000000..3cb1bc4
--- /dev/null
+++ b/src/FeatureExtractor.py
@@ -0,0 +1,32 @@
+import torchvision.models as models
+import torchvision as tv
+import torch
+import torch.nn as nn
+
+
+class FeatureExtractor(nn.Module):
+    """Pretrained CNN backbone that returns flattened features for a batch of frames."""
+    def __init__(self, arch):
+        super().__init__()
+        self.set_model_arch(arch)  # stores the pretrained network as self.arch
+        # preprocessing: resize to 224x224, convert to tensor, normalize per channel
+        steps = [tv.transforms.Resize([224, 224]), tv.transforms.ToTensor(),
+                 tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                         std=[0.229, 0.224, 0.225])]
+        self.tranform = tv.transforms.Compose(steps)
+        # keep every child module except the last two, then append a 1x1 average pool
+        self.model = nn.Sequential(*list(self.arch.children())[:-2], nn.AvgPool2d(1, 1))
+
+    def forward(self, frame):
+        """Run the truncated network and reshape its output to (batch, -1)."""
+        out = self.model(frame)
+        return out.reshape((out.shape[0], -1))
+
+    def set_model_arch(self, arch):
+        """Load the pretrained network named by arch; any other name selects googlenet."""
+        if arch == 'resnet50':
+            self.arch = models.resnet50(pretrained=True)
+        elif arch == 'alexnet':
+            self.arch = models.alexnet(pretrained=True)
+        else:
+            self.arch = torch.hub.load('pytorch/vision:v0.6.0', 'googlenet', pretrained=True)
diff --git a/src/SumMeVideo.py b/src/SumMeVideo.py
index 5fada48..d86f5c1 100644
--- a/src/SumMeVideo.py
+++ b/src/SumMeVideo.py
@@ -2,6 +2,7 @@ from scipy.io import loadmat
 import os
 from moviepy.editor import VideoFileClip
 
+
 class SumMeVideo(VideoFileClip):
     def __init__(self, video_name, video_path, gt_base_dir):
         self.video_clip = VideoFileClip(video_path)
diff --git a/src/TVSumVideo.py b/src/TVSumVideo.py
new file mode 100644
index 0000000..29b849d
--- /dev/null
+++ b/src/TVSumVideo.py
@@ -0,0 +1,31 @@
+import os
+import numpy as np
+import pandas as pd
+from moviepy.editor import VideoFileClip
+
+
+class TVSumVideo(VideoFileClip):
+    """TVSum video clip paired with the dataset's annotation file."""
+    GT_FILE = 'ydata-tvsum50-anno.tsv'
+
+    def __init__(self, video_name, video_path, gt_base_dir):
+        clip = VideoFileClip(video_path)
+        self.video_clip = clip
+        self.video_name = video_name
+        self.gt_path = os.path.join(gt_base_dir, self.GT_FILE)
+        # fps and duration are truncated to whole numbers
+        self.fps, self.duration = int(clip.fps), int(clip.duration)
+
+    def get_gt(self):
+        """Average the comma-separated lists in the last column over this video's rows."""
+        table = pd.read_csv(self.gt_path, sep="\t", header=None, index_col=0)
+        last_col = table.loc[self.video_name].iloc[:, -1]
+        scores = np.array([cell.split(",") for cell in last_col])
+        return self.__avg_array(scores)
+
+    def __avg_array(self, users_gt):
+        """Cast the score strings to int and take the mean across rows (axis 0)."""
+        return users_gt.astype(int).mean(axis=0)
+
+    def get_frames(self):
+        return list(self.video_clip.iter_frames(with_times=False))
diff --git a/src/feature_extraction.py b/src/feature_extraction.py
index 3e128fe..836ec49 100644
--- a/src/feature_extraction.py
+++ b/src/feature_extraction.py
@@ -12,17 +12,30 @@ from src.utils import digits_in_string
 class FeatureExtractor(nn.Module):
     def __init__(self, arch):
         super(FeatureExtractor, self).__init__()
+        # set model architecture according to architecture input name
+        self.set_model_arch(arch)
+        # resize frame and normalize
         self.tranform = tv.transforms.Compose([
             tv.transforms.Resize([224, 224]), tv.transforms.ToTensor(),
             tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])
-        self.model = nn.Sequential(*(list(arch.children())[:-2] + [nn.AvgPool2d(1, 1)]))
+        # get the pool layer of the model
+        self.model = nn.Sequential(*(list(self.arch.children())[:-2] + [nn.AvgPool2d(1, 1)]))
 
     def forward(self, frame):
         features = self.model(frame)
         features = features.reshape((features.shape[0], -1))
         return features
 
+    def set_model_arch(self, arch):
+        # model architecture
+        if arch == 'alexnet':
+            self.arch = models.alexnet(pretrained=True)
+        elif arch == 'resnet50':
+            self.arch = models.resnet50(pretrained=True)
+        else:
+            self.arch = torch.hub.load('pytorch/vision:v0.6.0', 'googlenet', pretrained=True)
+
 
 def argParser():
     parser = argparse.ArgumentParser(description='Features Extraction')
@@ -39,28 +52,22 @@ if __name__ == '__main__':
     args = parser.parse_args()
     # frames_dir
     frames_dir = args.frames
-    # model architecture
-    if args.model == 'alexnet':
-        model_arch = models.alexnet(pretrained=True)
-    elif args.model == 'resnet50':
-        model_arch = models.resnet50(pretrained=True)
-    else:
-        model_arch = torch.hub.load('pytorch/vision:v0.6.0', 'googlenet', pretrained=True)
 
-    isCuda = torch.cuda.is_available()
-    if isCuda:
-        model = FeatureExtractor(model_arch).cuda()
+    is_cuda = torch.cuda.is_available()
+    if is_cuda:
+        model = FeatureExtractor(args.model).cuda()
     else:
-        model = FeatureExtractor(model_arch)
+        model = FeatureExtractor(args.model)
+
 
-    features = dict()
     # print model architecture
-    print(model_arch)
+    print(model)
 
     # sort videos according to their number video1,video2,..
     all_videos = sorted(os.listdir(frames_dir), key=digits_in_string)
     # print(all_videos)
     # iterate over the videos
+    features = dict()
     for i, video in enumerate(all_videos):
         video = os.path.join(frames_dir, video)
         print(video)
diff --git a/src/video2frames.py b/src/video2frames.py
index d5d2a01..b456fb9 100644
--- a/src/video2frames.py
+++ b/src/video2frames.py
@@ -1,39 +1,74 @@
 import argparse
 from src.utils import *
 from src.SumMeVideo import SumMeVideo
+from src.TVSumVideo import TVSumVideo
+from src.FeatureExtractor import FeatureExtractor
+import pickle
+from PIL import Image
+from tqdm import tqdm
+
+
+def create_video_obj(dataset, video_name, video_path, gt_dir):
+    """Build a SumMeVideo when dataset is 'summe' (case-insensitive), else a TVSumVideo."""
+    is_summe = str(dataset).lower() == 'summe'
+    video_cls = SumMeVideo if is_summe else TVSumVideo
+    return video_cls(video_name, video_path, gt_dir)
 
 
 def arg_parser():
-    parser = argparse.ArgumentParser(description='Extract frames')
-    parser.add_argument('--videos_dir', metavar='DIR',
-                        default='../data/SumMe/videos',
-                        help='path input videos')
-    parser.add_argument('--gt', metavar='GT_Dir',
-                        default='../data/SumMe/GT',
-                        help='path ground truth')
-    parser.add_argument('--fps', default=2, type=int,
-                        help='Frames per second for the extraction')
-    parser.add_argument('--out_dir',
-                        default='../out',
-                        help='path to extracted frames')
+    parser = argparse.ArgumentParser(description='Extract Features')
+    parser.add_argument('--dataset', default='TVSum', type=str, help='SumMe or TVSum')
+    parser.add_argument('--videos_dir', metavar='DIR', default='../data/TVSum/video', help='path input videos')
+    parser.add_argument('--gt', metavar='GT_Dir', default='../data/TVSum/data', help='path ground truth')
+    parser.add_argument('--fps', default=2, type=int, help='Frames per second for the extraction')
+    parser.add_argument('--model_arch', default='googlenet',
+                        help='pretrained model architecture e.g. resnet50 or alexnet')
     return parser
 
 
 if __name__ == '__main__':
     parser = arg_parser()
     args = parser.parse_args()
-    # videos,GT, Sample rate, output dir
-    videos_dir, gt_dir, fps, out_dir = args.videos_dir, args.gt, args.fps, args.out_dir
-
+    # dataset, videos, GT, Sample rate, model architecture
+    dataset, videos_dir, gt_dir, fps, model_arch = args.dataset, args.videos_dir, args.gt, args.fps, args.model_arch
+    # inference mode: eval() disables dropout and makes batch norm use its running statistics
+    model = FeatureExtractor(model_arch).eval()
+    print(model)
     # stores all sampled data in array of dict
-    sampled_data = []
+    features = dict()
+    # iterate over videos, sample frames and gt, and extract features
     for idx, video in enumerate(os.listdir(videos_dir)):
+        features[idx] = []
         # sample video an ground truth
         video_name = drop_file_extension(video)
         video_path = os.path.join(videos_dir, video)
-        sumMe_video = SumMeVideo(video_name, video_path, gt_dir)
-        video_gt = sumMe_video.get_gt()
-        video_frames = sumMe_video.get_frames()
-        sampled_frames, sampled_gt = sample_from_video_with_gt(video_frames, video_gt, sumMe_video.duration,
-                                                               sumMe_video.fps)
-        del sumMe_video
+        # create video according to dataset
+        video_obj = create_video_obj(dataset, video_name, video_path, gt_dir)
+        # get video ground-truth data
+        video_gt = video_obj.get_gt()
+        # get video frames
+        print('getting frames for {}...'.format(video_name))
+        video_frames = video_obj.get_frames()
+        # sample video frames
+        print('sampling from video frames...')
+        sampled_frames, sampled_gt = sample_from_video_with_gt(video_frames, video_gt, video_obj.duration,
+                                                               video_obj.fps)
+        # delete video object
+        del video_obj
+        print('frames retrieved')
+        # run every sampled frame through the pretrained model
+        print('Extracting features for {} ...'.format(video_name))
+        for sampled_frame in tqdm(sampled_frames):
+            # convert to PIL, preprocess and add a leading batch dimension
+            PIL_image = Image.fromarray(sampled_frame)
+            frame = model.tranform(PIL_image).unsqueeze(0)
+            feat = model(frame)
+            features[idx].append(feat.cpu().detach().numpy()[0])
+
+        # 'features' holds every video processed so far, so overwrite the file ('wb')
+        # instead of appending one ever-larger pickle per video ('ab'); a single
+        # pickle.load() then returns the complete dict
+        print('Saving features ...')
+        with open('features.pickle', 'wb') as handle:
+            pickle.dump(features, handle, protocol=pickle.HIGHEST_PROTOCOL)
+        print('features saved')
-- 
GitLab