import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
As a first attempt, I need to match the dimension of the weights to that of the input. This can be done in two ways: brute-force pooling of the weights down to a matching size, or a random matrix projection (tried later).
This is a typical conv net for the MNIST task (a CIFAR10 variant comes later):
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, 5, 1)   # 28x28x1 -> 24x24x5
        self.conv2 = nn.Conv2d(5, 10, 5, 1)  # 12x12x5 -> 8x8x10
        self.fc1 = nn.Linear(4*4*10, 50)     # 4x4x10 after the second 2x2 max pool
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
The fully connected layers are a good place to start. Let's brute-force integers that make the dimensions work: search over pooling kernel sizes (n, m) and extra-input counts x, keeping only combinations where the pooled weight matrices divide evenly and fit into the x extra input slots:
$\left(\frac{x + 160}{m} + \frac{10}{2}\right)\frac{50}{n} \le x$
# fc1: Linear(a + dx, b), fc2: Linear(b, c)
a = 160  # 4*4*10 conv features
b = 50
c = 10
for _dx in range(10, 40):      # candidate number of extra inputs
    for n in range(10, b):     # pooling kernel height
        for m in range(10, a): # pooling kernel width
            _mm = (_dx + a) / m  # pooled fc1 width
            _nn = b / n          # pooled height
            if int(_mm) == _mm and int(_nn) == _nn:  # both divide evenly
                dw = (_mm + c / 2) * _nn             # flattened feedback size
                if _dx - dw > 0:                     # fits into the extra inputs
                    print("dw = %d, dx = %d, dx-dw = %d, m = %d, n = %d" % (dw, _dx, _dx - dw, m, n))
From the printed candidates, let's use dw = 22, dx = 26, dx-dw = 4, m = 31, n = 25.
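As a quick sanity check, here is the arithmetic for this pick redone as a minimal standalone snippet:

m, n = 31, 25
a, b, c, dx = 160, 50, 10, 26
cols = (dx + a) / m         # 186 / 31 = 6 pooled columns
rows = b / n                # 50 / 25 = 2 pooled rows
dw = (cols + c / 2) * rows  # (6 + 5) * 2 = 22
assert dw == 22 and dx - dw == 4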
The only things we need to update are the extra inputs in the initialization and the concatenation in forward:
m = 31
n = 25
a = 160
b = 50
c = 10
dx = 26
dw = 22
class WeightNetMnist(nn.Module):
    def __init__(self):
        super(WeightNetMnist, self).__init__()
        self.conv1 = nn.Conv2d(1, 5, 5, 1)
        self.conv2 = nn.Conv2d(5, 10, 5, 1)
        self.fc1 = nn.Linear(4*4*10 + dx, 50)  # dx extra inputs for the weight feedback
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # Max-pool both fc weight matrices down to dw values in total;
        # .data detaches them, so no gradient flows through the feedback path.
        m1 = torch.unsqueeze(self.fc1.weight.data, 0)
        m1 = F.max_pool2d(m1, (n, m))
        m1 = torch.squeeze(m1, 0)
        m2 = torch.unsqueeze(self.fc2.weight.data, 0)
        m2 = F.max_pool2d(m2, (2, n))
        m2 = torch.squeeze(m2, 0)
        me = torch.cat((m1, m2.t()), 1)
        me = me.view(-1, dw)
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*10)
        # Pad the dx - dw leftover slots with ones and tile across the batch.
        fill = torch.ones([1, dx - dw], device=x.device)
        me = torch.cat((me, fill), 1).expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
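Before wiring up training, a quick smoke test with a dummy batch confirms the shapes line up:

net = WeightNetMnist()
dummy = torch.randn(4, 1, 28, 28)  # fake batch of four MNIST images
print(net(dummy).shape)            # torch.Size([4, 10])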
Let's set up the experiment infrastructure:
from types import SimpleNamespace
args = {
"batch_size": 64,
"test_batch_size": 1000,
"epochs": 6,
"lr": 0.01,
"momentum": 0.5,
"no_cuda": False,
"seed": 1,
"dataset": "mnist",
"log_interval": 40,
"save_model": False
}
args = SimpleNamespace(**args)
args.dataset = 'mnist'
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}
all_loss = []
all_test_loss = []
all_acc = []
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
all_loss.append(loss.item())
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
all_test_loss.append(test_loss)
all_acc.append(100. * correct / len(test_loader.dataset))
def run(modelClass):
if args.dataset == "mnist":
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
elif args.dataset == "cifar":
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(root='../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
model = modelClass().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
    if args.save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")
run(Net)
run(WeightNetMnist)
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
all_acc_mnist_origin = all_acc[0:6]
all_acc_mnist_weightnet = all_acc[6:12]
t = np.arange(0, 6)
fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_mnist_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_mnist_weightnet, label='w/ weight feedback')
ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()
This is only slightly better; we need to run it against a larger dataset:
m = 45
n = 50
class NetCifar10(nn.Module):
    def __init__(self):
        super(NetCifar10, self).__init__()
        self.conv1 = nn.Conv2d(3, 20, 5, 1)   # 32x32x3 -> 28x28x20
        self.conv2 = nn.Conv2d(20, 50, 5, 1)  # 14x14x20 -> 10x10x50
        self.fc1 = nn.Linear(5*5*50, 700)     # 5x5x50 after the second 2x2 max pool
        self.fc2 = nn.Linear(700, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
class WeightNetCifar10(nn.Module):
    def __init__(self):
        super(WeightNetCifar10, self).__init__()
        self.dx = 460
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(5*5*50 + self.dx, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        dx = self.dx
        # Pool the fc weights; .data detaches them from the graph.
        m1 = torch.unsqueeze(self.fc1.weight.data, 0)
        m1 = F.max_pool2d(m1, (n, m))
        m1 = torch.squeeze(m1, 0)
        m2 = torch.unsqueeze(self.fc2.weight.data, 0)
        m2 = F.max_pool2d(m2, (2, n))
        m2 = torch.squeeze(m2, 0)
        me = torch.cat((m1, m2.t()), 1)
        size = (((5*5*50 + dx) / m) + 5) * (500 / n)  # flattened size of pooled weights: 430
        me = me.view(-1, int(size))
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*50)
        # Pad the remaining dx - size = 30 slots with ones, tile across the batch.
        fill = torch.ones([1, dx - int(size)], device=x.device)
        me = torch.cat((me, fill), 1).expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
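Same smoke test for the CIFAR10 version (note it picks up the m and n defined above):

net = WeightNetCifar10()
dummy = torch.randn(4, 3, 32, 32)  # fake batch of four CIFAR10 images
print(net(dummy).shape)            # torch.Size([4, 10])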
args.dataset = 'cifar'
args.epochs = 20
run(NetCifar10)
run(WeightNetCifar10)
all_acc_cifar_origin = all_acc[12:12+20]
all_acc_cifar_weightnet = all_acc[32:32+20]
t = np.arange(0, 20)
fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_cifar_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_cifar_weightnet, label='w/ weight feedback')
ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()
Dimension matching is pretty ugly, since all we want is to feed a signal back into the input. A handy trick is random matrix projection; it has been used, for example, to measure the intrinsic dimension of reinforcement learning tasks.
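In isolation the trick looks like this (a minimal sketch using the MNIST numbers from above, where the two fc weight matrices flatten to (4*4*10 + 26 + 10) * 50 = 9800 values):

flat = torch.randn(1, 9800)       # stand-in for the flattened fc weights
P = torch.empty(9800, 26)
torch.nn.init.xavier_uniform_(P)  # fixed random projection, drawn once and never trained
z = torch.matmul(flat, P)         # dx-dimensional feedback signal
print(z.shape)                    # torch.Size([1, 26])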
class WeightNetMnistP(nn.Module):
    def __init__(self):
        super(WeightNetMnistP, self).__init__()
        self.dx = 26
        self.conv1 = nn.Conv2d(1, 5, 5, 1)
        self.conv2 = nn.Conv2d(5, 10, 5, 1)
        self.fc1 = nn.Linear(4*4*10 + self.dx, 50)
        self.fc2 = nn.Linear(50, 10)
        in_dim = (4*4*10 + self.dx + 10) * 50
        out_dim = self.dx
        # Register P as a buffer rather than a plain attribute so it moves
        # with .to(device) but is never updated by the optimizer.
        self.register_buffer("P", torch.nn.init.xavier_uniform_(torch.empty(in_dim, out_dim)))

    def forward(self, x):
        # Flatten both fc weight matrices and project them down to dx values;
        # .data detaches them, so no gradient flows through the feedback path.
        me = torch.cat((self.fc1.weight.data, self.fc2.weight.data.t()), 1)
        size = (4*4*10 + self.dx + 10) * 50
        me = me.view(-1, size)
        me = torch.matmul(me, self.P)
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4*4*10)
        me = me.expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
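Because P is registered as a buffer in the class above, it travels with the model's state (a quick check):

net = WeightNetMnistP()
print('P' in net.state_dict())  # True: P is saved, loaded, and moved by .to(device)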
args.dataset = 'mnist'
args.epochs = 6
run(WeightNetMnistP)
t = np.arange(0, 6)
all_acc_mnist_weightnetP = all_acc[52:52+6]
fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_mnist_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_mnist_weightnet, label='w/ weight feedback')
line3 = ax.plot(t, all_acc_mnist_weightnetP, label='w/ weight feedback RMP')  # RMP = random matrix projection
ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()
class WeightNetCifar10P(nn.Module):
    def __init__(self):
        super(WeightNetCifar10P, self).__init__()
        self.dx = 460
        self.conv1 = nn.Conv2d(3, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        self.fc1 = nn.Linear(5*5*50 + self.dx, 500)
        self.fc2 = nn.Linear(500, 10)
        in_dim = (5*5*50 + self.dx + 10) * 500
        out_dim = self.dx
        # Buffer, not parameter: moves with .to(device), never trained.
        self.register_buffer("P", torch.nn.init.xavier_uniform_(torch.empty(in_dim, out_dim)))

    def forward(self, x):
        me = torch.cat((self.fc1.weight.data, self.fc2.weight.data.t()), 1)
        size = (5*5*50 + self.dx + 10) * 500
        me = me.view(-1, size)
        me = torch.matmul(me, self.P)
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 5*5*50)
        me = me.expand((x.shape[0], -1))
        x = torch.cat((x, me), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
args.dataset = 'cifar'
args.epochs = 20
run(WeightNetCifar10P)
all_acc_cifar_origin = all_acc[12:12+20]
all_acc_cifar_weightnet = all_acc[32:32+20]
all_acc_cifar_weightnetP = all_acc[58:58+20]
t = np.arange(0, 20)
fig, ax = plt.subplots()
line1 = ax.plot(t, all_acc_cifar_origin, label='w/o weight feedback')
line2 = ax.plot(t, all_acc_cifar_weightnet, label='w/ weight feedback')
line3 = ax.plot(t, all_acc_cifar_weightnetP, label='w/ weight feedback RMP')
ax.set(xlabel='epochs', ylabel='accuracy %')
ax.grid()
ax.legend()
# fig.savefig("test.png")
plt.show()