Commit 52dcc433 authored by Hussain Kanafani

redundant files deleted

parent 5bfec854
import torchvision.models as models
import torchvision as tv
import torch
import torch.nn as nn
import os
from PIL import Image
import argparse
import pickle
from src.utils import digits_in_string

class FeatureExtractor(nn.Module):
    def __init__(self, arch):
        super(FeatureExtractor, self).__init__()
        # set model architecture according to the architecture name
        self.set_model_arch(arch)
        # resize frame and normalize with ImageNet statistics
        self.transform = tv.transforms.Compose([
            tv.transforms.Resize([224, 224]), tv.transforms.ToTensor(),
            tv.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                    std=[0.229, 0.224, 0.225])])
        # drop the classifier head of the backbone and append a 1x1 average pool
        self.model = nn.Sequential(*(list(self.arch.children())[:-2] + [nn.AvgPool2d(1, 1)]))

    def forward(self, frame):
        features = self.model(frame)
        features = features.reshape((features.shape[0], -1))
        return features

    def set_model_arch(self, arch):
        # select the pretrained backbone; GoogLeNet is the default
        if arch == 'alexnet':
            self.arch = models.alexnet(pretrained=True)
        elif arch == 'resnet50':
            self.arch = models.resnet50(pretrained=True)
        else:
            self.arch = torch.hub.load('pytorch/vision:v0.6.0', 'googlenet', pretrained=True)


def argParser():
    parser = argparse.ArgumentParser(description='Features Extraction')
    parser.add_argument('--frames', metavar='Frames-dir',
                        default='./frames',
                        help='path to input frames')
    parser.add_argument('--model', default='googlenet',
                        help='pretrained model architecture, e.g. resnet50 or alexnet')
    return parser


if __name__ == '__main__':
    parser = argParser()
    args = parser.parse_args()
    # frames dir
    frames_dir = args.frames
    is_cuda = torch.cuda.is_available()
    if is_cuda:
        model = FeatureExtractor(args.model).cuda()
    else:
        model = FeatureExtractor(args.model)
    # print model architecture
    print(model)
    # sort videos according to their number: video1, video2, ...
    all_videos = sorted(os.listdir(frames_dir), key=digits_in_string)
    # print(all_videos)
    # iterate over the videos
    features = dict()
    for i, video in enumerate(all_videos):
        video = os.path.join(frames_dir, video)
        print(video)
        features[i] = []
        # iterate over the frames of the video in numeric order
        for frame in sorted(os.listdir(video), key=digits_in_string):
            frame = os.path.join(video, frame)
            print(frame)
            img = Image.open(frame)
            img = model.transform(img)
            img = img.view((1,) + img.shape)
            if is_cuda:
                img = img.cuda()
            feat = model(img)
            # print(feat.shape)
            features[i].append(feat.cpu().detach().numpy()[0])
        print(len(features[i]))
    # save extracted features in a pickle file
    with open('features.pickle', 'wb') as handle:
        pickle.dump(features, handle, protocol=pickle.HIGHEST_PROTOCOL)
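
The helper digits_in_string is imported from src.utils but is not included in this commit. A minimal sketch of what it presumably does, assuming it extracts the integer embedded in a name such as video12 or frame_0042.jpg so that sorted(..., key=digits_in_string) orders items numerically:

import re

def digits_in_string(s):
    # hypothetical stand-in for src.utils.digits_in_string:
    # return the first run of digits in the name as an int (0 if none)
    match = re.search(r'\d+', s)
    return int(match.group()) if match else 0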
import pickle
import torch
from scipy.io import loadmat


def main():
    data = pickle.load(open('sampled_data.pickle', 'rb'))
    print(data)
    # torch.FloatTensor(features[0]).shape
    x = loadmat('./data/SumMe/Vide/Base jumping.mat')
    print(x.keys())


if __name__ == '__main__':
    main()
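
Once extraction has finished, the saved dictionary can be loaded back with a plain pickle.load. This is an illustrative snippet, not part of the commit:

import pickle

with open('features.pickle', 'rb') as handle:
    features = pickle.load(handle)

# features maps video index -> list of per-frame feature vectors (numpy arrays)
print(len(features))           # number of videos
print(features[0][0].shape)    # dimensionality of one frame's feature vector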