Commit 29288690 authored by Hussain Kanafani

fixed the reading order of videos

parent 9d8eebb4
-import numpy as np
 import torchvision.models as models
 import torchvision as tv
 import torch
@@ -6,16 +5,18 @@ import torch.nn as nn
 import os
 from PIL import Image
 import argparse
+import pickle
+from utils import digits_in_string
 
 
 class FeatureExtractor(nn.Module):
-    def __init__(self, arch=tv.models.resnet50):
+    def __init__(self, arch):
         super(FeatureExtractor, self).__init__()
         self.tranform = tv.transforms.Compose([
             tv.transforms.Resize([224, 224]), tv.transforms.ToTensor(),
             tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])])
-        self.model = nn.Sequential(*(list(arch.children())[:-2] + [nn.MaxPool2d(4, 1)]))
+        self.model = nn.Sequential(*(list(arch.children())[:-2] + [nn.AvgPool2d(1, 1)]))
 
     def forward(self, frame):
         features = self.model(frame)
@@ -23,30 +24,48 @@ class FeatureExtractor(nn.Module):
         return features
 
 
-if __name__ == '__main__':
+def argParser():
     parser = argparse.ArgumentParser(description='Features Extraction')
     parser.add_argument('--frames', metavar='Frames-dir',
                         default='./frames',
                         help='path to input frames')
-    parser.add_argument('--model', default='resnet50',
+    parser.add_argument('--model', default='googlenet',
                         help='pretrained model architecture e.g. resnet50 or alexnet')
+    return parser
+
+
+if __name__ == '__main__':
+    parser = argParser()
     args = parser.parse_args()
+    # frames_dir
     frames_dir = args.frames
-    if args.model =='alexnet':
+    # model architecture
+    if args.model == 'alexnet':
         model_arch = models.alexnet(pretrained=True)
-    else:
+    elif args.model == 'resnet50':
         model_arch = models.resnet50(pretrained=True)
+    else:
+        model_arch = torch.hub.load('pytorch/vision:v0.6.0', 'googlenet', pretrained=True)
 
     isCuda = torch.cuda.is_available()
     if isCuda:
         model = FeatureExtractor(model_arch).cuda()
     else:
         model = FeatureExtractor(model_arch)
 
-    feature = dict()
-    for i, video in enumerate(os.listdir(frames_dir)):
+    features = dict()
+    # print model architecture
+    print(model_arch)
+    # sort videos according to their number video1,video2,..
+    all_videos = sorted(os.listdir(frames_dir), key=digits_in_string)
+    # print(all_videos)
+    # iterate over the videos
+    for i, video in enumerate(all_videos):
         video = os.path.join(frames_dir, video)
         print(video)
-        feature[i] = []
+        features[i] = []
+        # iterate over the frames of the videos
        for frame in os.listdir(video):
             frame = os.path.join(video, frame)
             print(frame)
@@ -54,5 +73,10 @@ if __name__ == '__main__':
             img = model.tranform(img)
             img = img.view((1,) + img.shape)
             feat = model(img)
-            print(feat.shape)
-            feature[i].append(feat.cpu().detach().numpy()[0])
+            # print(feat.shape)
+            features[i].append(feat.cpu().detach().numpy()[0])
+        print(len(features[i]))
+
+    # save extracted features in pickle file
+    pickle.dump(features, open('features.pkl', 'ab'))
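
The sort key digits_in_string comes from the repository's utils module and is not shown in this diff. A rough sketch of what such a helper presumably does (an assumption, not the actual implementation): it extracts the number embedded in a folder name like video12, so that sorted(..., key=digits_in_string) orders video2 before video10 instead of sorting lexicographically.

    import re

    def digits_in_string(name):
        # Hypothetical stand-in for utils.digits_in_string:
        # pull the first run of digits out of e.g. "video12" and sort on it numerically.
        match = re.search(r'\d+', name)
        return int(match.group()) if match else 0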
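
Note that features.pkl is opened in append mode ('ab'), so every run of the script adds another pickled dict to the same file. An illustrative loader under that assumption (not part of the commit) has to keep calling pickle.load until the file is exhausted:

    import pickle

    def load_all_features(path='features.pkl'):
        # Read back every dict appended to the pickle file, one record per run of the extractor.
        records = []
        with open(path, 'rb') as f:
            while True:
                try:
                    records.append(pickle.load(f))
                except EOFError:
                    break
        return records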
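
For context, the script expects one sub-folder per video under --frames and writes one feature vector per frame. A hypothetical invocation (the script name extract_features.py and the frame layout below are assumptions, not shown in the diff):

    frames/
        video1/
            000001.jpg
            000002.jpg
        video2/
            ...

    python extract_features.py --frames ./frames --model googlenet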