Residual networks in torch (MNIST 100 layers)

The residual network architecture was introduced by MSR (Microsoft Research), and it helped them win the ImageNet competition in 2015, ahead of Google, which had won the previous year. Using residual networks, they were able to train networks as deep as 150 layers, far deeper than previously trained, and obtained better solutions thanks to the architecture. Here we will look at this relatively simple change to traditional deep neural networks, which can be added to your existing networks without changing their architecture radically.

The main problem faced by very deep neural networks is that gradients have a hard time propagating through all the layers, which is a well known problem. Many solutions have been proposed previously, among others highway networks, gradient clipping in RNNs, and batch normalization, which ameliorates vanishing and exploding gradients.

Suppose you have a layer L, which takes input x and gives output y = L(x). If the layer has to learn the function y = 0, L does not need to learn much: it can simply set all its weights to zero. If instead the layer has to learn the function y = x, it has a much harder time, since it must tune its weights to reproduce the identity. Here we will instead start each layer off as the identity transformation, and see how this helps in learning very deep networks.

A common belief in deep learning is that the more layers we add, the better the prediction accuracy. But it is not so simple: performance peaks at some depth, and adding more layers after that actually makes performance worse (both the training error and the validation error). Yet the (smaller) best performing net can always be turned into a deeper net by stacking identity layers on top of it. Hence, if layers are identity transformations by default, deeper nets should have no trouble at least matching the solution of the smaller nets.

One way to do this is: if a layer's transformation is y = F(x), redefine it as y = F(x) + x. Here we assume that the dimensions of y and x are the same. Remember that F is a layer transformation and can be nonlinear. See the following image taken from the paper.

[Figure from the paper: a residual building block, computing y = F(x) + x via a shortcut connection.]

If the dimensions of input and output differ in a layer, i.e. dim(x) is not equal to dim(y) for y = F(x), add a simple linear projection on the shortcut: y = F(x) + W * x. For example, if x is 64×28×28 and F(x) is 128×14×14, W must map 64×28×28 to 128×14×14, which a 1×1 convolution with 128 filters and stride 2 does.

This is similar in spirit to highway networks: gradients can flow back to the earliest layers through the shortcuts without exploding or vanishing.

Implementation

Since we are using MNIST, let's focus on convolutional layers only; this can be extended to other layer types easily. We mostly use 3×3 filters with stride 1 and padding 1 as the convolutional layer. This ensures that the output dimensions of that particular convolutional layer are the same as its input dimensions.

nn.SpatialConvolution(64 -> 64, 3x3, 1,1, 1,1)
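As a quick sanity check (a minimal sketch, assuming torch and nn are installed), pushing a dummy tensor through such a layer confirms that the spatial size is preserved:

 require 'nn'
 -- 3x3 convolution, stride 1, padding 1: width and height are unchanged
 local conv = nn.SpatialConvolution(64, 64, 3, 3, 1, 1, 1, 1)
 local x = torch.randn(1, 64, 28, 28)   -- a batch of one 64x28x28 input
 print(conv:forward(x):size())          -- prints 1x64x28x28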

Instead of max pooling layers, we use 3×3 convolutions with stride 2 (and padding 1). This halves the width and height of the feature maps, while the number of filters is doubled. This keeps the amount of computation at each layer roughly the same.

nn.SpatialConvolution(64 -> 128, 3x3, 2,2, 1,1)
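The same kind of check (again just a sketch with a dummy input) shows the halving of width and height and the doubling of the filters:

 require 'nn'
 -- 3x3 convolution, stride 2, padding 1: spatial size is halved, filters doubled
 local down = nn.SpatialConvolution(64, 128, 3, 3, 2, 2, 1, 1)
 local x = torch.randn(1, 64, 28, 28)
 print(down:forward(x):size())          -- prints 1x128x14x14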

So a plain (non-residual) network may look like:

 net = nn.Sequential()
 net:add(nn.Reshape(1,28,28))
 net:add(nn.SpatialConvolution(1,64,3,3,1,1,1,1))
 net:add(nn.ReLU(true))
 net:add(nn.SpatialConvolution(64,64,3,3,1,1,1,1))
 net:add(nn.ReLU(true))
 net:add(nn.SpatialConvolution(64,128,3,3,2,2,1,1))
 net:add(nn.ReLU(true))
 net:add(nn.SpatialConvolution(128,128,3,3,1,1,1,1))
 .......

Now we will see how to create a residual network from this plain network.

First, suppose layer L has input (x) and output (y) of equal dimensions. Then the transformation is simple: instead of y = L(x), we compute y = L(x) + x. This is easily done with the nn.CAddTable module. Also, a unit here is taken to be two consecutive convolutional layers, since this gives the block more nonlinearity.

Here, unit represents the layer L:

local cat = nn.ConcatTable()
cat:add(unit)               -- F(x)
cat:add(nn.Identity())      -- x
local net = net or nn.Sequential()
net:add(cat)
net:add(nn.CAddTable())     -- F(x) + x
net:add(nn.ReLU(true))
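This pattern can be wrapped in a small helper so that units are easy to stack. The following is only a sketch: the function name addResidualUnit is mine, and the batch normalization between the convolutions mirrors the printed network shown further below.

 require 'nn'

 -- Sketch (the name addResidualUnit is mine): appends one residual unit of two
 -- 3x3 convolutions, plus an identity shortcut, to an existing nn.Sequential.
 local function addResidualUnit(net, nf)
   local unit = nn.Sequential()
   unit:add(nn.SpatialConvolution(nf, nf, 3, 3, 1, 1, 1, 1))
   unit:add(nn.SpatialBatchNormalization(nf))
   unit:add(nn.ReLU(true))
   unit:add(nn.SpatialConvolution(nf, nf, 3, 3, 1, 1, 1, 1))
   unit:add(nn.SpatialBatchNormalization(nf))
   unit:add(nn.ReLU(true))

   local cat = nn.ConcatTable()
   cat:add(unit)             -- F(x)
   cat:add(nn.Identity())    -- x
   net:add(cat)
   net:add(nn.CAddTable())   -- F(x) + x
   net:add(nn.ReLU(true))
   return net
 end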

Normally, to reduce the spatial dimensions of an image we use max pooling. The paper instead uses convolutional layers with stride 2, and also doubles the number of filters, which keeps the amount of computation at each layer roughly the same. As the dimensions change at this layer, we use y = L(x) + W * x, where W is implemented as a 1×1 convolutional layer with stride 2 (and twice the filters).

Here, unit is a block whose first convolution has stride 2 and double the filters:

 local cat = nn.ConcatTable()
 cat:add(unit)                                          -- F(x)
 cat:add(nn.SpatialConvolution(fin,2*fin,1,1,2,2))      -- W * x (1x1 conv, stride 2)
 local net = net or nn.Sequential()
 net:add(cat)
 net:add(nn.CAddTable())                                -- F(x) + W * x
 net:add(nn.ReLU(true))
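Putting the two block types together, a deep network is just a loop over such units. Again a sketch only: addResidualUnit is the helper from the sketch above, addDownsampleUnit is a name of my own, and the block counts and filter sizes are illustrative rather than the exact configuration used for the results below.

 -- Sketch (the name addDownsampleUnit is mine): a residual unit whose first
 -- convolution has stride 2 and twice the filters; the shortcut is a 1x1
 -- convolution with stride 2, implementing W * x.
 local function addDownsampleUnit(net, fin)
   local fout = 2 * fin
   local unit = nn.Sequential()
   unit:add(nn.SpatialConvolution(fin, fout, 3, 3, 2, 2, 1, 1))
   unit:add(nn.SpatialBatchNormalization(fout))
   unit:add(nn.ReLU(true))
   unit:add(nn.SpatialConvolution(fout, fout, 3, 3, 1, 1, 1, 1))
   unit:add(nn.SpatialBatchNormalization(fout))
   unit:add(nn.ReLU(true))

   local cat = nn.ConcatTable()
   cat:add(unit)                                          -- F(x)
   cat:add(nn.SpatialConvolution(fin, fout, 1, 1, 2, 2))  -- W * x
   net:add(cat)
   net:add(nn.CAddTable())
   net:add(nn.ReLU(true))
   return net
 end

 -- stacking: each residual unit adds two weight layers, so going deeper is just
 -- a matter of increasing the loop counts
 local net = nn.Sequential()
 net:add(nn.Reshape(1, 28, 28))
 net:add(nn.SpatialConvolution(1, 64, 3, 3, 1, 1, 1, 1))
 net:add(nn.SpatialBatchNormalization(64))
 net:add(nn.ReLU(true))
 for i = 1, 3 do addResidualUnit(net, 64) end   -- 28x28 maps, 64 filters
 addDownsampleUnit(net, 64)                     -- down to 14x14, 128 filters
 for i = 1, 3 do addResidualUnit(net, 128) end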

So a residual network looks like:

  (1): nn.Reshape(1x28x28)
  (2): nn.SpatialConvolution(1 -> 64, 3x3, 1,1, 1,1)
  (3): nn.SpatialBatchNormalization
  (4): nn.ReLU
  (5): nn.ConcatTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
      |      (1): nn.SpatialConvolution(64 -> 64, 3x3, 1,1, 1,1)
      |      (2): nn.SpatialBatchNormalization
      |      (3): nn.ReLU
      |      (4): nn.SpatialConvolution(64 -> 64, 3x3, 1,1, 1,1)
      |      (5): nn.SpatialBatchNormalization
      |      (6): nn.ReLU
      |    }
      |`-> (2): nn.Identity
       ... -> output
  }
  (6): nn.CAddTable
  (7): nn.ReLU
  (8): nn.ConcatTable {
    input
      |`-> (1): nn.Sequential {
      |      [input -> (1) -> (2) -> (3) -> (4) -> (5) -> (6) -> output]
      |      (1): nn.SpatialConvolution(64 -> 128, 3x3, 2,2, 1,1)
      |      (2): nn.SpatialBatchNormalization
      |      (3): nn.ReLU
      |      (4): nn.SpatialConvolution(128 -> 128, 3x3, 1,1, 1,1)
      |      (5): nn.SpatialBatchNormalization
      |      (6): nn.ReLU
      |    }
      |`-> (2): nn.SpatialConvolution(64 -> 128, 1x1, 2,2)
       ... -> output
  }
  (9): nn.CAddTable
  (10): nn.ReLU

Results

Using this, we trained plain and residual networks on MNIST. As can be seen, at the same depth the residual networks perform better. We were also able to train networks 100 layers deep with no problem.

MNIST plain 34 layer net – 99.16%

MNIST residual 34 layer net – 99.23%

MNIST residual 100 layer net – 99.49%

The code is on GitHub here.

The paper also removes the final fully connected weight layer, using global average pooling instead. I have not implemented it that way; I use a linear layer after max pooling. Using data augmentation and dropout may lead to better results, but the goal here is only to demonstrate how to train very deep networks.
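For reference, a minimal sketch of such a classifier head, assuming the last residual block outputs 128 feature maps of size 14×14 (an assumption for illustration, not the exact shape of the trained network):

 require 'nn'
 -- Sketch only: max pooling followed by a linear layer, rather than the
 -- global average pooling used in the paper. Shapes below are assumed.
 local head = nn.Sequential()
 head:add(nn.SpatialMaxPooling(2, 2, 2, 2))   -- 128x14x14 -> 128x7x7
 head:add(nn.View(128 * 7 * 7))
 head:add(nn.Linear(128 * 7 * 7, 10))         -- 10 MNIST classes
 head:add(nn.LogSoftMax())                    -- pair with nn.ClassNLLCriterion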
