Skip to content
Snippets Groups Projects
Commit 0f7e10ba authored by Hussain Kanafani's avatar Hussain Kanafani
Browse files

feature extraction using pretrained models implemented

parent d4dcc9b6
No related branches found
No related tags found
No related merge requests found
import torchvision.models as models
import torchvision as tv
import torch
import torch.nn as nn
class FeatureExtractor(nn.Module):
def __init__(self, arch):
super(FeatureExtractor, self).__init__()
# set model architecture according to architecture input name
self.set_model_arch(arch)
# resize frame and normalize
self.tranform = tv.transforms.Compose([
tv.transforms.Resize([224, 224]), tv.transforms.ToTensor(),
tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
# get the pool layer of the model
self.model = nn.Sequential(*(list(self.arch.children())[:-2] + [nn.AvgPool2d(1, 1)]))
def forward(self, frame):
features = self.model(frame)
features = features.reshape((features.shape[0], -1))
return features
def set_model_arch(self, arch):
# model architecture
if arch == 'alexnet':
self.arch = models.alexnet(pretrained=True)
elif arch == 'resnet50':
self.arch = models.resnet50(pretrained=True)
else:
self.arch = torch.hub.load('pytorch/vision:v0.6.0', 'googlenet', pretrained=True)
......@@ -2,6 +2,7 @@ from scipy.io import loadmat
import os
from moviepy.editor import VideoFileClip
class SumMeVideo(VideoFileClip):
def __init__(self, video_name, video_path, gt_base_dir):
self.video_clip = VideoFileClip(video_path)
......
import os
import numpy as np
import pandas as pd
from moviepy.editor import VideoFileClip
class TVSumVideo(VideoFileClip):
GT_FILE = 'ydata-tvsum50-anno.tsv'
def __init__(self, video_name, video_path, gt_base_dir):
self.video_clip = VideoFileClip(video_path)
self.fps = int(self.video_clip.fps)
self.duration = int(self.video_clip.duration)
self.gt_path = os.path.join(gt_base_dir, self.GT_FILE)
self.video_name = video_name
def get_gt(self):
gt_df = pd.read_csv(self.gt_path, sep="\t", header=None, index_col=0)
sub_gt = gt_df.loc[self.video_name]
users_gt = []
for i in range(len(sub_gt)):
users_gt.append(sub_gt.iloc[i, -1].split(","))
users_gt = np.array(users_gt)
return self.__avg_array(users_gt)
def __avg_array(self, users_gt):
users_gt = users_gt.astype(int)
return users_gt.mean(axis=0)
def get_frames(self):
return list(self.video_clip.iter_frames(with_times=False))
......@@ -12,17 +12,30 @@ from src.utils import digits_in_string
class FeatureExtractor(nn.Module):
def __init__(self, arch):
super(FeatureExtractor, self).__init__()
# set model architecture according to architecture input name
self.set_model_arch(arch)
# resize frame and normalize
self.tranform = tv.transforms.Compose([
tv.transforms.Resize([224, 224]), tv.transforms.ToTensor(),
tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
self.model = nn.Sequential(*(list(arch.children())[:-2] + [nn.AvgPool2d(1, 1)]))
# get the pool layer of the model
self.model = nn.Sequential(*(list(self.arch.children())[:-2] + [nn.AvgPool2d(1, 1)]))
def forward(self, frame):
features = self.model(frame)
features = features.reshape((features.shape[0], -1))
return features
def set_model_arch(self, arch):
# model architecture
if arch == 'alexnet':
self.arch = models.alexnet(pretrained=True)
elif arch == 'resnet50':
self.arch = models.resnet50(pretrained=True)
else:
self.arch = torch.hub.load('pytorch/vision:v0.6.0', 'googlenet', pretrained=True)
def argParser():
parser = argparse.ArgumentParser(description='Features Extraction')
......@@ -39,28 +52,22 @@ if __name__ == '__main__':
args = parser.parse_args()
# frames_dir
frames_dir = args.frames
# model architecture
if args.model == 'alexnet':
model_arch = models.alexnet(pretrained=True)
elif args.model == 'resnet50':
model_arch = models.resnet50(pretrained=True)
else:
model_arch = torch.hub.load('pytorch/vision:v0.6.0', 'googlenet', pretrained=True)
isCuda = torch.cuda.is_available()
if isCuda:
model = FeatureExtractor(model_arch).cuda()
is_cuda = torch.cuda.is_available()
if is_cuda:
model = FeatureExtractor(args.model).cuda()
else:
model = FeatureExtractor(model_arch)
model = FeatureExtractor(args.model)
features = dict()
# print model architecture
print(model_arch)
print(model)
# sort videos according to their number video1,video2,..
all_videos = sorted(os.listdir(frames_dir), key=digits_in_string)
# print(all_videos)
# iterate over the videos
features = dict()
for i, video in enumerate(all_videos):
video = os.path.join(frames_dir, video)
print(video)
......
import argparse
from src.utils import *
from src.SumMeVideo import SumMeVideo
from src.TVSumVideo import TVSumVideo
from src.FeatureExtractor import FeatureExtractor
import pickle
from PIL import Image
from tqdm import tqdm
def create_video_obj(dataset, video_name, video_path, gt_dir):
if str(dataset).lower() == 'summe':
return SumMeVideo(video_name, video_path, gt_dir)
else:
return TVSumVideo(video_name, video_path, gt_dir)
def arg_parser():
parser = argparse.ArgumentParser(description='Extract frames')
parser.add_argument('--videos_dir', metavar='DIR',
default='../data/SumMe/videos',
help='path input videos')
parser.add_argument('--gt', metavar='GT_Dir',
default='../data/SumMe/GT',
help='path ground truth')
parser.add_argument('--fps', default=2, type=int,
help='Frames per second for the extraction')
parser.add_argument('--out_dir',
default='../out',
help='path to extracted frames')
parser = argparse.ArgumentParser(description='Extract Features')
parser.add_argument('--dataset', default='TVSum', type=str, help='SumMe or TVSum')
parser.add_argument('--videos_dir', metavar='DIR', default='../data/TVSum/video', help='path input videos')
parser.add_argument('--gt', metavar='GT_Dir', default='../data/TVSum/data', help='path ground truth')
parser.add_argument('--fps', default=2, type=int, help='Frames per second for the extraction')
parser.add_argument('--model_arch', default='googlenet',
help='pretrained model architecture e.g. resnet50 or alexnet')
return parser
if __name__ == '__main__':
parser = arg_parser()
args = parser.parse_args()
# videos,GT, Sample rate, output dir
videos_dir, gt_dir, fps, out_dir = args.videos_dir, args.gt, args.fps, args.out_dir
# dataset, videos, GT, Sample rate, model architecture
dataset, videos_dir, gt_dir, fps, model_arch = args.dataset, args.videos_dir, args.gt, args.fps, args.model_arch
# define feature extractor model
model = FeatureExtractor(model_arch)
print(model)
# stores all sampled data in array of dict
sampled_data = []
features = dict()
# iterate over videos, sample frames and gt, and extract features
for idx, video in enumerate(os.listdir(videos_dir)):
features[idx] = []
# sample video an ground truth
video_name = drop_file_extension(video)
video_path = os.path.join(videos_dir, video)
sumMe_video = SumMeVideo(video_name, video_path, gt_dir)
video_gt = sumMe_video.get_gt()
video_frames = sumMe_video.get_frames()
sampled_frames, sampled_gt = sample_from_video_with_gt(video_frames, video_gt, sumMe_video.duration,
sumMe_video.fps)
del sumMe_video
# create video according to dataset
video_obj = create_video_obj(dataset, video_name, video_path, gt_dir)
# get video ground-truth data
video_gt = video_obj.get_gt()
# get video frames
print('getting frames for {}...'.format(video_name))
video_frames = video_obj.get_frames()
# sample video frames
print('sampling from video frames...')
sampled_frames, sampled_gt = sample_from_video_with_gt(video_frames, video_gt, video_obj.duration,
video_obj.fps)
# delete video object
del video_obj
print('frames retrieved')
# iterate over sampled frames to extract their features using pretrained model
print('Extracting features for {} ...'.format(video_name))
for f in tqdm(range(len(sampled_frames))):
# convert to PIL
PIL_image = Image.fromarray(sampled_frames[f])
frame = model.tranform(PIL_image)
# extend dim
frame = frame.view((1,) + frame.shape)
# get features
feat = model(frame)
features[idx].append(feat.cpu().detach().numpy()[0])
print('Saving features ...')
with open('features.pickle', 'ab') as handle:
pickle.dump(features, handle, protocol=pickle.HIGHEST_PROTOCOL)
print('features saved')
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment