Commit 4a2d7570 authored by E144069X's avatar E144069X

Added random representative vector dropout

parent 6d86edd5
......@@ -869,7 +869,7 @@ class LinearSecondModel(SecondModel):
class DeepSecondModel(SecondModel):
def __init__(self, nbFeat, nbClass, dropout,hidSizes=[512,1024]):
def __init__(self, nbFeat, nbClass, dropout,hidSizes=[512,1024],reprVecInDrop=0.5):
super(DeepSecondModel, self).__init__(nbFeat, nbClass)
self.dropout = nn.Dropout(p=dropout)
......@@ -879,8 +879,15 @@ class DeepSecondModel(SecondModel):
self.layers.append(nn.Linear(layerSizes[-2],layerSizes[-1]))
self.layers = nn.Sequential(*self.layers)
self.inDrop = reprVecInDrop
def forward(self,x):
if self.training:
x = x[:,:int(x.size(1)*(1-self.inDrop))].contiguous()
origSize = x.size()
x = x.view(x.size(0)*x.size(1),x.size(2))
x = self.layers(x)
x = x.view(origSize[0],origSize[1],x.size(-1))
......@@ -1057,7 +1064,7 @@ def netBuilder(args):
bil_cluster_ensemble_gate=args.bil_cluster_ensemble_gate,hidLay=args.hid_lay,gate_drop=args.bil_cluster_ensemble_gate_drop,\
gate_randdrop=args.bil_cluster_ensemble_gate_randdrop,**zoomArgs)
elif args.second_mod == "deepLinear":
secondModel = DeepSecondModel(nbFeat,args.class_nb,args.dropout)
secondModel = DeepSecondModel(nbFeat,args.class_nb,args.dropout,reprVecInDrop=args.repr_vec_in_drop)
else:
raise ValueError("Unknown second model type : ", args.second_mod)
......@@ -1082,6 +1089,9 @@ def addArgs(argreader):
argreader.parser.add_argument('--dropout', type=float, metavar='D',
help='The dropout amount on each layer of the RNN except the last one')
argreader.parser.add_argument('--repr_vec_in_drop', type=float, metavar='S',
help='The percentage of representative vectors dropped during training.')
argreader.parser.add_argument('--second_mod', type=str, metavar='MOD',
help='The temporal model. Can be "linear", "lstm" or "score_conv".')
......@@ -1222,4 +1232,6 @@ def addArgs(argreader):
help="To select random vectors as initial estimation instead of vectors with high norms.")
return argreader
......@@ -15,6 +15,7 @@ dataset_val = CUB_200_2011_train
dataset_test = CUB_200_2011_test
with_seg=True
repr_vec=False
repr_vec_in_drop=0
train_prop = 90
......
from torchvision.datasets.vision import VisionDataset
from PIL import Image
import os
import os.path
import scipy
import torch
import numpy as np
import random
from torchvision import transforms
def collateSeq(batch):
res = list(zip(*batch))
res[0] = torch.cat(res[0],dim=0)
res[1] = torch.cat(res[1],dim=0)
res[2] = torch.cat(res[2],dim=0)
return res
def has_file_allowed_extension(filename, extensions):
"""Checks if a file is an allowed extension.
Args:
filename (string): path to a file
extensions (tuple of strings): extensions to consider (lowercase)
Returns:
bool: True if the filename ends with one of given extensions
"""
return filename.lower().endswith(extensions)
def is_image_file(filename):
"""Checks if a file is an allowed image extension.
Args:
filename (string): path to a file
Returns:
bool: True if the filename ends with a known image extension
"""
return has_file_allowed_extension(filename, IMG_EXTENSIONS)
def make_dataset(directory, class_to_idx, extensions=None, is_valid_file=None):
instances = []
directory = os.path.expanduser(directory)
both_none = extensions is None and is_valid_file is None
both_something = extensions is not None and is_valid_file is not None
if both_none or both_something:
raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")
if extensions is not None:
def is_valid_file(x):
return has_file_allowed_extension(x, extensions)
for target_class in sorted(class_to_idx.keys()):
class_index = class_to_idx[target_class]
target_dir = os.path.join(directory, target_class)
if not os.path.isdir(target_dir):
continue
for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if is_valid_file(path):
item = path, class_index
instances.append(item)
return instances
class DatasetFolder(VisionDataset):
def __init__(self, root, extensions=None,
target_transform=None, is_valid_file=None):
super(DatasetFolder, self).__init__(root, transform=None,
target_transform=target_transform)
classes, class_to_idx = self._find_classes(self.root)
samples = make_dataset(self.root, class_to_idx, extensions, is_valid_file)
if len(samples) == 0:
msg = "Found 0 files in subfolders of: {}\n".format(self.root)
if extensions is not None:
msg += "Supported extensions are: {}".format(",".join(extensions))
raise RuntimeError(msg)
self.extensions = extensions
self.classes = classes
self.class_to_idx = class_to_idx
self.samples = samples
self.targets = [s[1] for s in samples]
self.reprVec = np.load("../results/{}_reprVec.npy".format(root.split("/")[-1]))
def _find_classes(self, dir):
"""
Finds the class folders in a dataset.
Args:
dir (string): Root directory path.
Returns:
tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary.
Ensures:
No class is a subdirectory of another.
"""
classes = [d.name for d in os.scandir(dir) if d.is_dir()]
classes.sort()
class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
return classes, class_to_idx
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (sample, target) where target is class_index of the target class.
"""
_, target = self.samples[index]
sample = self.reprVec[index%len(self.reprVec)]
sample = np.random.permutation(sample)
if self.target_transform is not None:
target = self.target_transform(target)
return sample,target
def __len__(self):
return len(self.samples)
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')
class ReprVec(DatasetFolder):
"""A generic data loader where the images are arranged in this way: ::
root/dog/xxx.png
root/dog/xxy.png
root/dog/xxz.png
root/cat/123.png
root/cat/nsdf3.png
root/cat/asd932_.png
Args:
root (string): Root directory path.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
target_transform (callable, optional): A function/transform that takes in the
target and transforms it.
loader (callable, optional): A function to load an image given its path.
is_valid_file (callable, optional): A function that takes path of an Image file
and check if the file is a valid file (used to check of corrupt files)
Attributes:
classes (list): List of the class names sorted alphabetically.
class_to_idx (dict): Dict with items (class_name, class_index).
imgs (list): List of (image path, class_index) tuples
"""
def __init__(self, root,target_transform=None, is_valid_file=None):
super(ReprVec, self).__init__(root, IMG_EXTENSIONS if is_valid_file is None else None,
target_transform=target_transform,
is_valid_file=is_valid_file)
self.imgs = self.samples
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment