MNIST digit prediction

The data

Today we will start looking at the MNIST data set. This is a set of images of handwritten digits, and the learning goal is to predict which digit (0-9) each image represents. It is a canonical dataset for basic image processing and was probably the first dataset that a large community of researchers used as a universal benchmark for computer vision. Today I am using a smaller, downsampled dataset of the same kind: the handwritten zip code digits given as example data in The Elements of Statistical Learning. These were scanned from envelopes by the U.S. Postal Service, so strictly speaking they are not a subset of MNIST. We will explore the larger dataset over the next two weeks as we start introducing neural networks.

I’ll read in the training and testing datasets; the split into train and test should be the same as that used by most other sources:

set.seed(1)
train <- read.csv("data/mnist_train.psv", sep="|", as.is=TRUE, header=FALSE)
test <- read.csv("data/mnist_test.psv", sep="|", as.is=TRUE, header=FALSE)

Looking at the data, we see that it has 257 columns: the first gives the true digit class and the remaining 256 give the pixel intensities (on a scale from -1 to 1) of the 16x16 pixel image.

dim(train)
## [1] 7291  257
train[1:10,1:10]
##    V1 V2 V3 V4     V5     V6     V7     V8     V9    V10
## 1   6 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.631  0.862
## 2   5 -1 -1 -1 -0.813 -0.671 -0.809 -0.887 -0.671 -0.853
## 3   4 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -1.000 -1.000
## 4   7 -1 -1 -1 -1.000 -1.000 -0.273  0.684  0.960  0.450
## 5   3 -1 -1 -1 -1.000 -1.000 -0.928 -0.204  0.751  0.466
## 6   6 -1 -1 -1 -1.000 -1.000 -0.397  0.983 -0.535 -1.000
## 7   3 -1 -1 -1 -0.830  0.442  1.000  1.000  0.479 -0.328
## 8   1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000  0.510 -0.213
## 9   0 -1 -1 -1 -1.000 -1.000 -0.454  0.879 -0.745 -1.000
## 10  1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.909  0.801

We can plot an individual image in R using the rasterImage function, after rescaling the intensities to [0, 1] so that the background is white and the ink is black:

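# reshape the 256 pixels (stored row by row) into a 16x16 matrix
y <- matrix(as.matrix(train[3400,-1]),16,16,byrow=TRUE)
# rescale from [-1, 1] to [0, 1], flipped so the background is white
y <- 1 - (y + 1)*0.5

plot(0,0)
rasterImage(y,-1,-1,1,1)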
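The two transformation lines above can be wrapped in a small function. This is a minimal sketch: the name digit_to_raster is hypothetical, and it assumes each image is stored as 256 intensities on [-1, 1], read row by row, as described above.

```r
# Hypothetical helper: convert one image (256 intensities on [-1, 1], stored
# row by row) into a 16x16 matrix on [0, 1] for rasterImage().
# -1 (background) maps to 1 (white) and 1 (ink) maps to 0 (black).
digit_to_raster <- function(pixels) {
  # unlist() accepts either a plain numeric vector or a one-row data frame
  v <- as.numeric(unlist(pixels, use.names = FALSE))
  y <- matrix(v, 16, 16, byrow = TRUE)
  1 - (y + 1) / 2
}
```

For example, rasterImage(digit_to_raster(train[3400, -1]), -1, -1, 1, 1) draws the same image as the code above.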

With a minimal amount of work, we can build a much better visualization of what these digits actually look like. Here is a grid of 35 observations. How hard do you think it will be to predict the correct classes?

iset <- sample(1:nrow(train),5*7)
par(mar=c(0,0,0,0))
par(mfrow=c(5,7))
for (j in iset) {
  y <- matrix(as.matrix(train[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5

  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, train[j,1], cex=3, col="red")
}
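Since we will draw grids of digits like this several more times, the loop can be wrapped in a helper. This is a minimal sketch: plot_digit_grid and its arguments are hypothetical names, and it assumes a numeric matrix with one 256-pixel image per row.

```r
# Hypothetical helper: draw a grid of digits with the true label in red
# (bottom left) and, optionally, a predicted label in blue (bottom right).
plot_digit_grid <- function(X, labels, preds = NULL, n_row = 5, n_col = 7) {
  op <- par(mar = c(0, 0, 0, 0), mfrow = c(n_row, n_col))
  on.exit(par(op))                    # restore the previous graphics settings
  for (i in seq_len(nrow(X))) {
    y <- matrix(as.numeric(X[i, ]), 16, 16, byrow = TRUE)
    y <- 1 - (y + 1) * 0.5            # map [-1, 1] to [0, 1], ink drawn dark
    plot(0, 0, xlab = "", ylab = "", axes = FALSE, type = "n")
    rasterImage(y, -1, -1, 1, 1)
    box()
    text(-0.8, -0.7, labels[i], cex = 3, col = "red")
    if (!is.null(preds)) text(0.8, -0.7, preds[i], cex = 3, col = "blue")
  }
  invisible(nrow(X))                  # number of digits drawn
}
```

For example, plot_digit_grid(as.matrix(train[iset, -1]), train[iset, 1]) draws essentially the same grid as the loop above, and passing a vector of predictions as the third argument adds the blue labels.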

Now, let’s extract the predictor matrices and the class labels:

Xtrain <- as.matrix(train[,-1])
Xtest <- as.matrix(test[,-1])
ytrain <- train[,1]
ytest <- test[,1]

I now want to apply a suite of techniques that we have studied to predict the correct class for each handwritten digit.

Model fitting

K-nearest neighbors

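As a simple model, we can use k-nearest neighbors. I set k equal to three; with majority voting this means using the label of the closest point unless the next two closest points agree on a different label. If all three neighbors disagree, the result depends on how the implementation breaks ties.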

library(FNN)
predKnn <- knn(Xtrain,Xtest,ytrain,k=3)
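To make the voting rule concrete, here is a minimal from-scratch sketch of k-nearest-neighbor classification with Euclidean distance. It is not the FNN implementation: the function name knn_vote is hypothetical, and it breaks three-way ties in favor of the nearest neighbor, which FNN may not do.

```r
# Hypothetical from-scratch k-NN: Euclidean distance plus majority vote.
# Ties are broken in favour of the label of the nearest neighbour.
knn_vote <- function(Xtrain, Xtest, ytrain, k = 3) {
  Xtrain <- as.matrix(Xtrain)
  Xtest <- as.matrix(Xtest)
  apply(Xtest, 1, function(x) {
    d <- colSums((t(Xtrain) - x)^2)      # squared distance to every training row
    nn <- order(d)[seq_len(k)]           # the k closest rows, nearest first
    labs <- ytrain[nn]
    # levels in order of first appearance, so which.max() prefers the
    # label of the nearest neighbour when counts are tied
    counts <- table(factor(labs, levels = unique(labs)))
    as.numeric(names(counts)[which.max(counts)])
  })
}
```

Computing all distances in R is slow for the full data; FNN uses faster nearest-neighbor search, which is why the package call is preferable in practice.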

Ridge regression

For ridge regression, I’ll directly use the multinomial loss function (that is, a ridge-penalized multinomial logistic regression) and let the R function choose the penalty by cross-validation. The glmnet package is my preferred package for ridge regression; just remember to set alpha to 0.
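For reference, my reading of the glmnet documentation is that with family="multinomial" and alpha=0 the model minimizes the penalized negative log-likelihood below, where $N$ is the number of training images, $K = 10$ is the number of classes, $x_i \in \mathbb{R}^{256}$ holds the pixels of image $i$, $y_{ik} = 1$ if image $i$ shows digit $k$ and 0 otherwise, and $\lambda$ is the penalty selected by cv.glmnet. Treat this as a sketch of the model rather than an exact statement of the package internals (glmnet, for instance, standardizes the predictors by default).

```latex
\[
\min_{\{\beta_{0k},\,\beta_k\}_{k=1}^{K}}\;
-\frac{1}{N}\sum_{i=1}^{N}\left[\sum_{k=1}^{K} y_{ik}\left(\beta_{0k} + x_i^\top\beta_k\right)
- \log\sum_{l=1}^{K}\exp\left(\beta_{0l} + x_i^\top\beta_l\right)\right]
+ \frac{\lambda}{2}\sum_{k=1}^{K}\lVert\beta_k\rVert_2^2
\]
\[
\hat p_k(x) = \frac{\exp\left(\beta_{0k} + x^\top\beta_k\right)}{\sum_{l=1}^{K}\exp\left(\beta_{0l} + x^\top\beta_l\right)}
\]
```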

library(glmnet)
outLm <- cv.glmnet(Xtrain, ytrain, alpha=0, nfolds=3,
                    family="multinomial")
predLm <- apply(predict(outLm, Xtest, s=outLm$lambda.min,
                  type="response"), 1, which.max) - 1L

The which.max call returns column indices starting at 1, while the classes run from 0 to 9, so I subtract one from the results to get the correct class labels.
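Because the class probabilities are a softmax of the linear scores, which.max picks the same column whether it is applied to the probabilities or to the raw scores; this is also why the boosting predictions can be taken from predict's default link-scale output. A minimal sketch on a plain matrix of scores with columns ordered 0 through 9 (the helper names are hypothetical):

```r
# Hypothetical helpers for turning class scores into digit labels.
# Rows are observations; columns are classes ordered 0, 1, ..., 9.
softmax_rows <- function(eta) {
  eta <- eta - apply(eta, 1, max)   # subtract each row's max for numerical stability
  p <- exp(eta)
  p / rowSums(p)                    # each row now sums to one
}
scores_to_digit <- function(scores) {
  apply(scores, 1, which.max) - 1L  # which.max is 1-based; the digits start at 0
}
```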

Random forest

We can also run a random forest model. It will run significantly faster if I restrict the maximum number of terminal nodes in each tree (the maxnodes argument).

library(randomForest)
outRf <- randomForest(Xtrain,  factor(ytrain), maxnodes=10)
predRf <- predict(outRf, Xtest)

Gradient boosted trees

Gradient boosted trees can also be fit directly to the multiclass labels. The model performs much better if I increase the interaction depth slightly; increasing it past 2-3 is beneficial in large models but rarely useful in smaller problems like this one. I could also tune the learning rate (the shrinkage argument, shown in the StepSize column of the output below), but I won't do that here.

library(gbm)
outGbm <- gbm.fit(Xtrain,  factor(ytrain), distribution="multinomial",
                  n.trees=500, interaction.depth=2)
## Iter   TrainDeviance   ValidDeviance   StepSize   Improve
##      1        2.3026             nan     0.0010    0.0084
##      2        2.2982             nan     0.0010    0.0086
##      3        2.2936             nan     0.0010    0.0084
##      4        2.2891             nan     0.0010    0.0082
##      5        2.2849             nan     0.0010    0.0082
##      6        2.2805             nan     0.0010    0.0081
##      7        2.2763             nan     0.0010    0.0083
##      8        2.2720             nan     0.0010    0.0081
##      9        2.2677             nan     0.0010    0.0081
##     10        2.2634             nan     0.0010    0.0079
##     20        2.2228             nan     0.0010    0.0076
##     40        2.1471             nan     0.0010    0.0068
##     60        2.0781             nan     0.0010    0.0064
##     80        2.0147             nan     0.0010    0.0058
##    100        1.9555             nan     0.0010    0.0052
##    120        1.9015             nan     0.0010    0.0049
##    140        1.8512             nan     0.0010    0.0045
##    160        1.8040             nan     0.0010    0.0040
##    180        1.7596             nan     0.0010    0.0042
##    200        1.7176             nan     0.0010    0.0038
##    220        1.6777             nan     0.0010    0.0038
##    240        1.6399             nan     0.0010    0.0034
##    260        1.6037             nan     0.0010    0.0031
##    280        1.5690             nan     0.0010    0.0031
##    300        1.5360             nan     0.0010    0.0032
##    320        1.5041             nan     0.0010    0.0029
##    340        1.4735             nan     0.0010    0.0028
##    360        1.4445             nan     0.0010    0.0027
##    380        1.4164             nan     0.0010    0.0027
##    400        1.3893             nan     0.0010    0.0024
##    420        1.3634             nan     0.0010    0.0022
##    440        1.3384             nan     0.0010    0.0023
##    460        1.3146             nan     0.0010    0.0021
##    480        1.2917             nan     0.0010    0.0021
##    500        1.2693             nan     0.0010    0.0022
predGbm <- apply(predict(outGbm, Xtest, n.trees=outGbm$n.trees),1,which.max) - 1L
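The log above shows the training deviance still falling steadily at the last iteration (from 1.2917 at iteration 480 to 1.2693 at 500) with a step size of 0.001, which suggests the boosted model is far from converged. To illustrate how the learning rate controls this, here is a minimal sketch of least-squares gradient boosting with decision stumps on simulated one-dimensional data. It swaps gbm's multinomial deviance for squared-error loss to keep the code short, and all function names are hypothetical.

```r
# Hypothetical sketch: least-squares gradient boosting with decision stumps.
fit_stump <- function(x, r) {
  # try every split between sorted unique x values; keep the smallest RSS
  xs <- sort(unique(x))
  cuts <- (head(xs, -1) + tail(xs, -1)) / 2
  best <- NULL
  best_rss <- Inf
  for (cc in cuts) {
    left <- x <= cc
    mL <- mean(r[left])
    mR <- mean(r[!left])
    rss <- sum((r[left] - mL)^2) + sum((r[!left] - mR)^2)
    if (rss < best_rss) {
      best_rss <- rss
      best <- c(cut = cc, left = mL, right = mR)
    }
  }
  best
}
predict_stump <- function(s, x) ifelse(x <= s[["cut"]], s[["left"]], s[["right"]])

# Returns the training mean squared error after each boosting iteration.
boost_l2 <- function(x, y, n_trees = 500, shrinkage = 0.1) {
  f <- rep(mean(y), length(y))        # start from the constant fit
  loss <- numeric(n_trees)
  for (m in seq_len(n_trees)) {
    r <- y - f                        # negative gradient of squared-error loss
    s <- fit_stump(x, r)
    f <- f + shrinkage * predict_stump(s, x)
    loss[m] <- mean((y - f)^2)
  }
  loss
}
```

With the same number of trees, a run with shrinkage 0.001 ends with a much higher training loss than a run with shrinkage 0.1, so one would either raise the shrinkage or add many more trees (ideally choosing the number of trees by cross-validation).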

Support vector machines

Finally, we will also fit a support vector machine. We can give the multiclass problem directly to the svm function, which fits a one-vs-one classifier for every pair of classes and predicts by voting. I found that the radial kernel performed best and that the default cost of 1 also worked well:

library(e1071)
outSvm <- svm(Xtrain,  factor(ytrain), kernel="radial", cost=1)
predSvm <- predict(outSvm, Xtest)
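With ten digit classes, one-vs-one means choose(10, 2) = 45 binary classifiers. The voting step can be sketched as follows; decide stands in for a trained binary classifier and, like the function name ovo_predict, is hypothetical. libsvm's own decision values and tie-breaking are not reproduced here.

```r
# Hypothetical sketch of one-vs-one voting. `decide(a, b)` plays the role of
# the binary classifier trained on classes a and b and returns a or b.
ovo_predict <- function(classes, decide) {
  pairs <- combn(classes, 2)                   # one column per pair of classes
  votes <- setNames(numeric(length(classes)), classes)
  for (j in seq_len(ncol(pairs))) {
    key <- as.character(decide(pairs[1, j], pairs[2, j]))
    votes[key] <- votes[key] + 1               # each pair casts one vote
  }
  names(votes)[which.max(votes)]               # class with the most votes
}
```

For example, with decide <- function(a, b) if (abs(a - 3.2) <= abs(b - 3.2)) a else b, class 3 wins all nine of its pairs and ovo_predict(0:9, decide) returns "3".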

Prediction Performance and Comparison

Misclassification rates

We see that the methods differ substantially in how predictive they are on the test dataset:

mean(predKnn != ytest)
## [1] 0.05530643
mean(predLm != ytest)
## [1] 0.0896861
mean(predRf != ytest)
## [1] 0.2670653
mean(predGbm != ytest)
## [1] 0.2321873
mean(predSvm != ytest)
## [1] 0.06178376

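The tree models perform far worse than the others, although this is at least partly a consequence of the settings used here: maxnodes=10 keeps the forest's trees very small, and the boosting log shows the training deviance still falling steadily after 500 trees at a step size of 0.001. The ridge regression works quite well given that it is constrained to linear decision boundaries. The support vector machine does noticeably better (about 6.2% versus 9.0% error, or roughly 30% fewer mistakes) by using the kernel trick to fit a nonlinear boundary in a high-dimensional feature space. k-nearest neighbors performs slightly better still than the support vector machine.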
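To judge whether differences such as k-nearest neighbors versus the support vector machine are meaningful, it helps to attach a rough standard error to each test error rate. A minimal sketch using the binomial approximation sqrt(p(1 - p)/n); the function name misclass is hypothetical, and the approximation ignores the fact that all models are evaluated on the same test images.

```r
# Hypothetical helper: test misclassification rate with an approximate
# binomial standard error sqrt(p * (1 - p) / n).
misclass <- function(pred, truth) {
  err <- mean(pred != truth)
  n <- length(truth)
  c(rate = err, se = sqrt(err * (1 - err) / n))
}
```

For example, misclass(predKnn, ytest). With 2007 test images and an error rate near 6%, the standard error is roughly half a percentage point, so the gap between k-nearest neighbors and the support vector machine is only a little over one standard error.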

Mis-classification rates by class

You may have noticed that some of the techniques (when their verbose output is turned on) print a misclassification rate by class. This is useful for assessing a model when there are more than two categories. For example, look at where the ridge regression and the support vector machine make most of their errors:

tapply(predLm != ytest, ytest, mean)
##          0          1          2          3          4          5 
## 0.03064067 0.04545455 0.16161616 0.14457831 0.10500000 0.13125000 
##          6          7          8          9 
## 0.07058824 0.08843537 0.15060241 0.05084746
tapply(predSvm != ytest, ytest, mean)
##          0          1          2          3          4          5 
## 0.02228412 0.04166667 0.07575758 0.12048193 0.07000000 0.08125000 
##          6          7          8          9 
## 0.07058824 0.06122449 0.09036145 0.03954802

We see that 3 and 8 (and, for the ridge regression, 2) are particularly difficult, while 0, 1, and 9 are comparatively easy to predict.
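The tapply calls above compute, for each true class, the fraction of its test images that were misclassified. The same numbers can be read off a confusion matrix laid out like table(pred, truth): one minus the diagonal divided by the column totals. A minimal sketch, assuming numeric class labels (the function names are hypothetical):

```r
# Hypothetical helpers: per-class error rates computed two ways.
per_class_error <- function(pred, truth) tapply(pred != truth, truth, mean)

per_class_error_cm <- function(pred, truth) {
  levs <- sort(unique(c(pred, truth)))
  # rows are predictions, columns are true classes, as in table(pred, truth)
  cm <- table(factor(pred, levels = levs), factor(truth, levels = levs))
  err <- 1 - diag(cm) / colSums(cm)   # correct counts sit on the diagonal
  err[colSums(cm) > 0]                # drop classes that never occur in truth
}
```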

Confusion matrix

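We might think that a lot of 8’s and 3’s are being misclassified as one another (they do look similar in some ways). Looking at the confusion matrices (rows are predicted classes, columns are true classes), we see that this is not quite the case: for both models the most common mistake on a 3 is to predict a 5 (11 cases each), while the errors on 8’s are spread across several classes: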

table(predLm,ytest)
##       ytest
## predLm   0   1   2   3   4   5   6   7   8   9
##      0 348   0   5   4   2   7   2   0   6   0
##      1   0 252   0   0   2   0   0   0   0   2
##      2   2   0 166   3   5   0   3   1   4   1
##      3   2   2   4 142   0   7   0   1   5   0
##      4   3   4  10   1 179   2   3   5   2   3
##      5   0   0   2  11   0 139   3   0   4   1
##      6   2   4   3   0   4   0 158   0   1   0
##      7   0   0   1   2   1   0   0 134   1   1
##      8   1   1   7   2   1   1   1   1 141   1
##      9   1   1   0   1   6   4   0   5   2 168
table(predSvm,ytest)
##        ytest
## predSvm   0   1   2   3   4   5   6   7   8   9
##       0 351   0   2   0   0   3   4   0   4   0
##       1   0 253   0   0   1   0   0   0   0   0
##       2   6   1 183   5   3   2   4   2   2   0
##       3   0   0   4 146   0   3   0   0   3   0
##       4   1   5   3   0 186   1   2   5   0   4
##       5   0   1   0  11   1 147   1   0   2   1
##       6   0   3   1   0   2   0 158   0   1   0
##       7   0   1   1   1   3   0   0 138   0   0
##       8   1   0   4   3   1   1   1   0 151   2
##       9   0   0   0   0   3   3   0   2   3 170

We can also check how much the mistakes of the two models overlap (rows of the table indicate SVM errors, columns ridge regression errors):

table(predSvm != ytest, predLm != ytest)
##        
##         FALSE TRUE
##   FALSE  1805   78
##   TRUE     22  102

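The two models share 102 errors, but on another 100 test images (78 + 22) exactly one of them is wrong. This gives some evidence that stacking could be beneficial.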
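One way to make the stacking argument more concrete is to count the shared and unshared errors and compute the error of an "oracle" that always picks whichever of the two models is right; no rule that only chooses between the two predicted labels can do better. A minimal sketch (the function name is hypothetical), taking two logical error vectors such as predSvm != ytest and predLm != ytest:

```r
# Hypothetical helper: overlap of two models' errors and the error rate of an
# oracle that always picks whichever model is right (wrong only when both are).
error_overlap <- function(errA, errB) {
  c(both_wrong = sum(errA & errB),
    only_A     = sum(errA & !errB),
    only_B     = sum(!errA & errB),
    oracle_err = mean(errA & errB))
}
```

From the table above, both models are wrong on 102 of the 2007 test images, so this oracle would reach an error of 102/2007, about 5.1%, only slightly below the 5.5% of k-nearest neighbors.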

Mis-classified digits

We can pick out a random sample of 3’s from the training set and see whether they are particularly difficult to classify.

iset <- sample(which(train[,1] == 3),5*7)
par(mar=c(0,0,0,0))
par(mfrow=c(5,7))
for (j in iset) {
  y <- matrix(as.matrix(train[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5

  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, train[j,1], cex=3, col="red")
}

For the most part, though, these do not seem difficult for a human to classify.

What if we look at the points that were actually misclassified? Here is a sample of those from the support vector machine (true label in red, predicted label in blue):

iset <- sample(which(predSvm != ytest),7*7)
par(mar=c(0,0,0,0))
par(mfrow=c(7,7))
for (j in iset) {
  y <- matrix(as.matrix(test[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5

  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, test[j,1], cex=3, col="red")
  text(0.8,-0.7, predSvm[j], cex=3, col="blue")
}

And the ridge regression:

iset <- sample(which(predLm != ytest),7*7)
par(mar=c(0,0,0,0))
par(mfrow=c(7,7))
for (j in iset) {
  y <- matrix(as.matrix(test[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5

  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, test[j,1], cex=3, col="red")
  text(0.8,-0.7, predLm[j], cex=3, col="blue")
}

Can you explain what distinguishes the errors made by the two models?