# NOTE(review): removed stray VCS/UI residue ("Newer" / "Older") — as bare
# names they would raise NameError at import time.
import argparse
from src.utils import *
from src.SumMeVideo import SumMeVideo
from src.TVSumVideo import TVSumVideo
from src.FeatureExtractor import FeatureExtractor
import pickle
from PIL import Image
from tqdm import tqdm
def create_video_obj(dataset, video_name, video_path, gt_dir):
    """Instantiate the dataset-specific video wrapper.

    Args:
        dataset: dataset name; 'summe' (case-insensitive) selects SumMeVideo,
            anything else falls through to TVSumVideo.
        video_name: video file name without its extension.
        video_path: full path to the video file.
        gt_dir: directory holding the ground-truth annotations.

    Returns:
        A SumMeVideo or TVSumVideo instance.
    """
    if str(dataset).lower() == 'summe':
        return SumMeVideo(video_name, video_path, gt_dir)
    # default: treat anything that is not SumMe as TVSum
    return TVSumVideo(video_name, video_path, gt_dir)
def arg_parser():
    """Build the command-line parser for the feature-extraction script.

    Reconstructed: the original file lost the ``def arg_parser():`` header,
    leaving a bare module-level ``return`` (SyntaxError) while ``__main__``
    still calls ``arg_parser()``.

    Returns:
        argparse.ArgumentParser: parser exposing --dataset, --videos_dir,
        --gt, --fps and --model_arch options.
    """
    parser = argparse.ArgumentParser(description='Extract Features')
    parser.add_argument('--dataset', default='TVSum', type=str, help='SumMe or TVSum')
    parser.add_argument('--videos_dir', metavar='DIR', default='../data/TVSum/video', help='path input videos')
    parser.add_argument('--gt', metavar='GT_Dir', default='../data/TVSum/data', help='path ground truth')
    parser.add_argument('--fps', default=2, type=int, help='Frames per second for the extraction')
    parser.add_argument('--model_arch', default='googlenet',
                        help='pretrained model architecture e.g. resnet50 or alexnet')
    return parser
if __name__ == '__main__':
    # Entry point: sample frames + ground truth from every video in a dataset
    # directory and extract per-frame CNN features with a pretrained model.
    parser = arg_parser()
    args = parser.parse_args()
    # dataset, videos, GT, Sample rate, model architecture
    dataset, videos_dir, gt_dir, fps, model_arch = args.dataset, args.videos_dir, args.gt, args.fps, args.model_arch
    # define feature extractor model
    model = FeatureExtractor(model_arch)
    print(model)
    # features[idx] -> list of feature vectors for the idx-th listed video
    features = dict()
    # NOTE(review): `os` and `drop_file_extension`/`sample_from_video_with_gt`
    # are not imported here explicitly — presumably provided by
    # `from src.utils import *`; confirm, or add an explicit `import os`.
    # iterate over videos, sample frames and gt, and extract features
    for idx, video in enumerate(os.listdir(videos_dir)):
        features[idx] = []
        # sample video an ground truth
        video_name = drop_file_extension(video)
        video_path = os.path.join(videos_dir, video)
        # create video according to dataset
        video_obj = create_video_obj(dataset, video_name, video_path, gt_dir)
        # get video ground-truth data
        video_gt = video_obj.get_gt()
        # get video frames
        print('getting frames for {}...'.format(video_name))
        video_frames = video_obj.get_frames()
        # sample video frames
        # NOTE(review): the CLI --fps value is unused here — sampling relies on
        # the video's own fps/duration; `sampled_gt` is also never consumed.
        # Confirm both are intentional.
        print('sampling from video frames...')
        sampled_frames, sampled_gt = sample_from_video_with_gt(video_frames, video_gt, video_obj.duration,
                                                               video_obj.fps)
        # release the video object (and the frames it holds) before extraction
        del video_obj
        print('frames retrieved')
        # iterate over sampled frames to extract their features using pretrained model
        print('Extracting features for {} ...'.format(video_name))
        for f in tqdm(range(len(sampled_frames))):
            # convert to PIL
            PIL_image = Image.fromarray(sampled_frames[f])
            # NOTE(review): `tranform` (sic) must match the attribute name
            # defined on FeatureExtractor — don't fix the spelling here alone.
            frame = model.tranform(PIL_image)
            # extend dim: add a leading batch axis of size 1
            frame = frame.view((1,) + frame.shape)
            # get features; detach to a plain numpy vector, drop the batch axis
            feat = model(frame)
            features[idx].append(feat.cpu().detach().numpy()[0])
    print('Saving features ...')
    # NOTE(review): 'ab' appends a new pickle record on every run, so readers
    # must unpickle repeatedly; use 'wb' if one-record overwrite is intended.
    with open('features.pickle', 'ab') as handle:
        pickle.dump(features, handle, protocol=pickle.HIGHEST_PROTOCOL)
    print('features saved')