Skip to content
Snippets Groups Projects
Commit 569719e0 authored by Hussain Kanafani's avatar Hussain Kanafani
Browse files

VSUMM added

parent 5e1782bc
No related branches found
No related tags found
No related merge requests found
import os
from moviepy.editor import VideoFileClip
import numpy as np
from src.utils import digits_in_string, has_string
class VSUMMVideo(VideoFileClip):
def __init__(self, video_name, video_path, gt_base_dir):
self.video_clip = VideoFileClip(video_path)
self.fps = int(self.video_clip.fps)
self.duration = int(self.video_clip.duration)
self.gt_path = os.path.join(gt_base_dir, video_name)
self.video_name = video_name
self.nframes = self.video_clip.reader.infos['video_nframes']
def get_gt(self):
vid_scores = []
for i, user in enumerate(os.listdir(self.gt_path)):
user_dir = os.path.join(self.gt_path, user)
frame_idx = []
for idx, summary in enumerate(os.listdir(user_dir)):
if has_string(summary, 'frame'):
frame_idx.append(digits_in_string(summary)) # Frame123--> 123
vid_scores.append(frame_idx)
binary_scores = self.bin_classify_user_score(vid_scores)
return binary_scores
def bin_classify_user_score(self, user_scores_list):
result = []
for frames_idx in list(user_scores_list):
vid_frames = np.zeros(self.nframes)
np.put(vid_frames, frames_idx, [1]) # replace all selected frames with 1
result.append(vid_frames)
return np.asarray(result).T # (n_frames,n_annotator)
def get_frames(self):
return list(self.video_clip.iter_frames(with_times=False))
......@@ -2,24 +2,30 @@ import argparse
from src.utils import *
from src.SumMeVideo import SumMeVideo
from src.TVSumVideo import TVSumVideo
from src.VSUMMVideo import VSUMMVideo
from src.FeatureExtractor import FeatureExtractor
import pickle
from PIL import Image
from tqdm import tqdm
def create_video_obj(dataset, video_name, video_path, gt_dir):
if str(dataset).lower() == 'summe':
dataset = str(dataset).lower()
if dataset == 'summe':
return SumMeVideo(video_name, video_path, gt_dir)
else:
elif dataset == 'tvsumm':
return TVSumVideo(video_name, video_path, gt_dir)
else:
return VSUMMVideo(video_name, video_path, gt_dir)
def arg_parser():
# ../data/SumMe/videos ../data/SumMe/videos/GT
# ../ data / TVSum / video / ../data/TVSum/data
# ../data/VSUMM/new_database ../data/VSUMM/newUserSummary
parser = argparse.ArgumentParser(description='Extract Features')
parser.add_argument('--dataset', default='TVSum', type=str, help='SumMe or TVSum')
parser.add_argument('--videos_dir', metavar='DIR', default='../data/TVSum/video/', help='path input videos')
parser.add_argument('--gt', metavar='GT_Dir', default='../data/TVSum/data', help='path ground truth')
parser.add_argument('--dataset', default='VSUMM', type=str, help='SumMe, TVSum or VSUMM')
parser.add_argument('--videos_dir', metavar='DIR', default='../data/VSUMM/new_database', help='path input videos')
parser.add_argument('--gt', metavar='GT_Dir', default='../data/VSUMM/newUserSummary', help='path ground truth')
parser.add_argument('--fps', default=2, type=int, help='Frames per second for the extraction')
parser.add_argument('--model_arch', default='googlenet',
help='pretrained model architecture e.g. resnet50 or alexnet')
......@@ -37,8 +43,12 @@ if __name__ == '__main__':
# stores all sampled data in array of dict
features = dict()
ground_truth = dict()
# sort videos according to their number video1,video2,..
all_videos = sorted(os.listdir(videos_dir), key=digits_in_string)
# iterate over videos, sample frames and gt, and extract features
for idx, video in enumerate(os.listdir(videos_dir)):
for idx, video in enumerate(all_videos):
features[idx] = []
ground_truth[idx] = []
# sample video an ground truth
......@@ -51,6 +61,7 @@ if __name__ == '__main__':
# get video frames
print('getting frames for {}...'.format(video_name))
video_frames = video_obj.get_frames()
# sample video frames
print('sampling from video frames...')
sampled_frames, sampled_gt = sample_from_video_with_gt(video_frames, video_gt, video_obj.duration,
......@@ -70,5 +81,5 @@ if __name__ == '__main__':
feat = model(frame)
features[idx].append(feat.cpu().detach().numpy()[0])
ground_truth[idx].append(sampled_gt)
save_pickle_file('features',features)
save_pickle_file('GT',ground_truth)
\ No newline at end of file
save_pickle_file('features', features)
save_pickle_file('GT', ground_truth)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment