Download the pretrained CORnet-Z weights (`cornet_z_epoch25.pth.tar`) from here.
To run the test:

```
python circle_vs_ellipse.py circle_vs_ellipse --n_train 100 --n_test 100 --restore_path cornet_z_epoch25.pth.tar -j 4 --batch_size 100
```
The test script, `circle_vs_ellipse.py`:

```python
import sys
import os
import argparse
import time
import glob
import pickle
import subprocess
import shlex
import io
from collections import OrderedDict

import numpy as np
import pandas
import tqdm
import fire
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import sklearn.svm  # LinearSVC used below lives in sklearn.svm, not sklearn.linear_model
import skimage.draw
from PIL import Image

Image.warnings.simplefilter('ignore')

np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = True

normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
parser = argparse.ArgumentParser(description='ImageNet Training')
parser.add_argument('--data_path', default='./',
                    help='path to ImageNet folder that contains train and val folders')
parser.add_argument('-o', '--output_path', default=None,
                    help='path for storing outputs')
parser.add_argument('--model', choices=['Z', 'R', 'S'], default='Z',
                    help='which model to train')
parser.add_argument('--times', default=5, type=int,
                    help='number of time steps to run the model (only R and S models)')
parser.add_argument('--ngpus', default=1, type=int,
                    help='number of GPUs to use')
parser.add_argument('-j', '--workers', default=4, type=int,
                    help='number of data loading workers')
parser.add_argument('--epochs', default=20, type=int,
                    help='number of total epochs to run')
parser.add_argument('--batch_size', default=256, type=int,
                    help='mini-batch size')
parser.add_argument('--lr', '--learning_rate', default=.1, type=float,
                    help='initial learning rate')
parser.add_argument('--step_size', default=10, type=int,
                    help='after how many epochs learning rate should be decreased 10x')
parser.add_argument('--momentum', default=.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=1e-4, type=float,
                    help='weight decay')
FLAGS, FIRE_FLAGS = parser.parse_known_args()

class Flatten(nn.Module):
    """
    Helper module for flattening the input tensor to 1-D for use in Linear modules
    """

    def forward(self, x):
        return x.view(x.size(0), -1)


class Identity(nn.Module):
    """
    Helper module that passes the tensor through unchanged. Useful for accessing
    a layer's output by name via forward hooks
    """

    def forward(self, x):
        return x


class CORblock_Z(nn.Module):

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=kernel_size // 2)
        self.nonlin = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.output = Identity()  # for easy access to this block's output

    def forward(self, inp):
        x = self.conv(inp)
        x = self.nonlin(x)
        x = self.pool(x)
        x = self.output(x)  # for easy access to this block's output
        return x


def CORnet_Z():
    model = nn.Sequential(OrderedDict([
        ('V1', CORblock_Z(3, 64, kernel_size=7, stride=2)),
        ('V2', CORblock_Z(64, 128)),
        ('V4', CORblock_Z(128, 256)),
        ('IT', CORblock_Z(256, 512)),
        ('decoder', nn.Sequential(OrderedDict([
            ('avgpool', nn.AdaptiveAvgPool2d(1)),
            ('flatten', Flatten()),
            ('linear', nn.Linear(512, 1000)),
            ('output', Identity())
        ])))
    ]))

    # weight initialization
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            m.weight.data.fill_(1)
            m.bias.data.zero_()

    return model

class GenEllipse(torch.utils.data.Dataset):
    """Generates random filled ellipses on the fly; the target is the aspect ratio"""

    def __init__(self, imsize=224, transform=None, min_aspect_ratio=.8):
        self.imsize = imsize
        self.transform = transform
        self.min_aspect_ratio = min_aspect_ratio

    def __len__(self):
        return sys.maxsize  # effectively infinite, since images are generated on demand

    def __getitem__(self, index):
        c_radius = np.random.uniform(1, self.imsize / 2)
        r_radius = np.random.uniform(1, self.min_aspect_ratio * c_radius)
        rr, cc = skimage.draw.ellipse(
            r=np.random.uniform(c_radius, self.imsize - c_radius),
            c=np.random.uniform(c_radius, self.imsize - c_radius),
            r_radius=r_radius,
            c_radius=c_radius,
            rotation=np.random.uniform(-np.pi, np.pi)
        )
        im = np.zeros((self.imsize, self.imsize, 3)).astype('float32')
        im[rr, cc] = 1
        if self.transform is not None:
            im = self.transform(im)
        return im, r_radius / c_radius


class GenCircle(torch.utils.data.Dataset):
    """Generates random filled circles on the fly; the aspect ratio is always 1"""

    def __init__(self, imsize=224, transform=None):
        self.imsize = imsize
        self.transform = transform

    def __len__(self):
        return sys.maxsize  # effectively infinite, since images are generated on demand

    def __getitem__(self, index):
        c_radius = np.random.uniform(1, self.imsize / 2)
        rr, cc = skimage.draw.ellipse(
            r=np.random.uniform(c_radius, self.imsize - c_radius),
            c=np.random.uniform(c_radius, self.imsize - c_radius),
            r_radius=c_radius,
            c_radius=c_radius,
            rotation=np.random.uniform(-np.pi, np.pi)
        )
        im = np.zeros((self.imsize, self.imsize, 3)).astype('float32')
        im[rr, cc] = 1
        if self.transform is not None:
            im = self.transform(im)
        return im, 1

def circle_vs_ellipse(n_train=100, n_test=10, restore_path=None, imsize=224,
                      use_gpu=False):
    model = CORnet_Z()
    model = torch.nn.DataParallel(model)
    if use_gpu:
        model = model.cuda()
    if restore_path is not None:
        ckpt_data = torch.load(restore_path, map_location='cpu')
        model.load_state_dict(ckpt_data['state_dict'])
    model.eval()

    def _get_features(n, kind):

        def _store_feats(layer, inp, output):
            """An ugly but effective way of accessing intermediate model features"""
            # move to CPU before converting to numpy, in case the model runs on GPU
            _model_feats.append(output.reshape(len(output), -1).cpu().numpy())

        handle = model_layer.register_forward_hook(_store_feats)

        dataset = kind(imsize,
                       torchvision.transforms.Compose([
                           torchvision.transforms.ToTensor(),
                           normalize,
                       ]))
        data_loader = torch.utils.data.DataLoader(dataset,
                                                  batch_size=FLAGS.batch_size,
                                                  shuffle=False,
                                                  num_workers=FLAGS.workers,
                                                  pin_memory=True)

        with torch.no_grad():
            model_feats = []
            aspect_ratios = []
            for i, ims in enumerate(data_loader):
                if i * FLAGS.batch_size >= n:
                    break
                aspect_ratios.append(ims[1])
                _model_feats = []
                model(ims[0])
                model_feats.append(_model_feats[0])
            model_feats = np.concatenate(model_feats)[:n]
            aspect_ratios = np.concatenate(aspect_ratios)[:n]

        handle.remove()
        return model_feats, aspect_ratios
    model_layer = model._modules['module'].decoder.flatten
    train_circles, _ = _get_features(n_train, kind=GenCircle)
    train_ellipses, _ = _get_features(n_train, kind=GenEllipse)
    test_circles, test_car = _get_features(n_test, kind=GenCircle)
    test_ellipses, test_ear = _get_features(n_test, kind=GenEllipse)

    clf = sklearn.svm.LinearSVC()
    train_feats = np.concatenate([train_circles, train_ellipses], axis=0)
    train_labels = np.concatenate([np.zeros(len(train_circles)),
                                   np.ones(len(train_ellipses))])
    test_feats = np.concatenate([test_circles, test_ellipses], axis=0)
    test_labels = np.concatenate([np.zeros(len(test_circles)),
                                  np.ones(len(test_ellipses))])
    test_ars = np.concatenate([test_car, test_ear])

    clf.fit(train_feats, train_labels)
    preds = clf.predict(test_feats)

    df = pandas.DataFrame(np.stack([preds, test_labels, test_ars]).T,
                          columns=['prediction', 'actual', 'aspect_ratio'])
    df['acc'] = df.prediction == df.actual
    print('accuracy:', df.acc.mean())
    agg = df.groupby(pandas.cut(df.aspect_ratio, np.arange(0, 1.2, .1),
                                include_lowest=True, right=False)).acc.mean()
    print(agg)


if __name__ == '__main__':
    fire.Fire(command=FIRE_FLAGS)
```
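
Note that in the run command above, `--n_train`, `--n_test`, and `--restore_path` are routed to the `circle_vs_ellipse` function by Fire, while `-j` and `--batch_size` are consumed by the argparse parser (`FLAGS`). If you prefer to skip the CLI, here is a minimal sketch of calling the test directly from Python, assuming the script is saved as `circle_vs_ellipse.py` and the checkpoint has been downloaded to the working directory:

```python
# Minimal sketch (not part of the script above): invoke the test directly.
# Assumes circle_vs_ellipse.py is importable and cornet_z_epoch25.pth.tar
# is in the working directory. Batch size and worker count still come from
# the argparse defaults (256 and 4) unless overridden on sys.argv.
from circle_vs_ellipse import circle_vs_ellipse

circle_vs_ellipse(n_train=100, n_test=100,
                  restore_path='cornet_z_epoch25.pth.tar')
```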