Today we will start looking at the MNIST dataset, a set of images of handwritten digits. The learning goal is to predict which digit (0-9) each image represents. It is a canonical dataset for basic image processing and was probably the first dataset that a large community of researchers used as a universal benchmark for computer vision. Today I am using just a subset of the full dataset, slightly downsampled; it is the example data given in the Elements of Statistical Learning. We will explore the larger dataset over the next two weeks as we start introducing neural networks.
I’ll read in the training and testing datasets; the split into train and test should be the same as that used by most other sources:
set.seed(1)
train <- read.csv("data/mnist_train.psv", sep="|", as.is=TRUE, header=FALSE)
test <- read.csv("data/mnist_test.psv", sep="|", as.is=TRUE, header=FALSE)
Looking at the data, we see that it has 257 columns: the first column gives the true digit class and the remaining 256 give the pixel intensities (on a scale from -1 to 1) of the 16x16 pixel image.
dim(train)
## [1] 7291 257
train[1:10,1:10]
## V1 V2 V3 V4 V5 V6 V7 V8 V9 V10
## 1 6 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.631 0.862
## 2 5 -1 -1 -1 -0.813 -0.671 -0.809 -0.887 -0.671 -0.853
## 3 4 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -1.000 -1.000
## 4 7 -1 -1 -1 -1.000 -1.000 -0.273 0.684 0.960 0.450
## 5 3 -1 -1 -1 -1.000 -1.000 -0.928 -0.204 0.751 0.466
## 6 6 -1 -1 -1 -1.000 -1.000 -0.397 0.983 -0.535 -1.000
## 7 3 -1 -1 -1 -0.830 0.442 1.000 1.000 0.479 -0.328
## 8 1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 0.510 -0.213
## 9 0 -1 -1 -1 -1.000 -1.000 -0.454 0.879 -0.745 -1.000
## 10 1 -1 -1 -1 -1.000 -1.000 -1.000 -1.000 -0.909 0.801
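It is also worth a quick check (my own addition, not part of the original analysis) that the digit classes are reasonably balanced in the training set:
# counts and proportions of each digit class in the training data
table(train[,1])
round(table(train[,1]) / nrow(train), 3)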
We can plot what the image actually looks like in R using the rasterImage function:
# reshape the 256 pixel values of a single image into a 16x16 matrix
y <- matrix(as.matrix(train[3400,-1]),16,16,byrow=TRUE)
# map intensities from [-1,1] to [0,1] grayscale, so strokes show up dark on white
y <- 1 - (y + 1)*0.5
plot(0,0)
rasterImage(y,-1,-1,1,1)
With a minimal amount of work, we can build a much better visualization of what these digits actually look like. Here is a grid of 35 observations. How hard do you think it will be to predict the correct classes?
iset <- sample(1:nrow(train),5*7)
par(mar=c(0,0,0,0))
par(mfrow=c(5,7))
for (j in iset) {
y <- matrix(as.matrix(train[j,-1]),16,16,byrow=TRUE)
y <- 1 - (y + 1)*0.5
plot(0,0,xlab="",ylab="",axes=FALSE)
rasterImage(y,-1,-1,1,1)
box()
text(-0.8,-0.7, train[j,1], cex=3, col="red")
}
Now, let’s extract the data matrices and class labels:
Xtrain <- as.matrix(train[,-1])
Xtest <- as.matrix(test[,-1])
ytrain <- train[,1]
ytest <- test[,1]
I now want to apply a suite of techniques that we have studied to try to predict the correct class of each handwritten digit.
As a simple baseline, we can use k-nearest neighbors. I set k equal to three, which in a multiclass setting means each test point takes the label of its nearest neighbor unless the second and third nearest neighbors agree on a different class.
library(FNN)
predKnn <- knn(Xtrain,Xtest,ytrain,k=3)
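If you want a quick sanity check on the choice of k, the FNN package also provides leave-one-out cross-validation on the training set; the particular grid of k values below is my own choice, just a sketch:
# Sketch (my addition): leave-one-out CV error on the training set for a few k
for (k in c(1, 3, 5, 7)) {
  predCv <- knn.cv(Xtrain, ytrain, k=k)
  cat("k =", k, " cv error:", mean(predCv != ytrain), "\n")
}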
For ridge regression, I’ll directly use the multinomial loss function and let the R function do the cross-validation for me. The glmnet package is my preferred package for ridge regression; just remember to set alpha to 0.
library(glmnet)
outLm <- cv.glmnet(Xtrain, ytrain, alpha=0, nfolds=3,
family="multinomial")
predLm <- apply(predict(outLm, Xtest, s=outLm$lambda.min,
type="response"), 1, which.max) - 1L
The predicted classes are indexed starting at 1, so I subtract one from the results to get the correct class labels.
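If I am remembering the glmnet interface correctly, the class labels can also be requested directly with type="class", which skips the which.max step and gives the same answer:
# Assumption: predict() with type="class" returns the labels for a multinomial fit
predLm2 <- as.integer(predict(outLm, Xtest, s=outLm$lambda.min, type="class"))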
We can also run a random forest model. It will run significantly faster if I restrict the maximum number of nodes somewhat.
library(randomForest)
outRf <- randomForest(Xtrain, factor(ytrain), maxnodes=10)
predRf <- predict(outRf, Xtest)
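The forest also tracks an out-of-bag error estimate as it grows, which gives a rough sense of performance without touching the test set. Pulling it out of the fitted object like this is my own addition (and assumes I recall the randomForest object layout correctly):
# OOB error after the final tree; classification forests store an "OOB" column in err.rate
outRf$err.rate[nrow(outRf$err.rate), "OOB"]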
Gradient boosted trees also run directly on the multiclass labels. The model performs much better if I increase the interaction depth slightly; pushing it past 2-3 can be beneficial in large problems, but is rarely useful in smaller cases like this. I could also tune the learning rate, but I won’t fiddle with that here.
library(gbm)
outGbm <- gbm.fit(Xtrain, factor(ytrain), distribution="multinomial",
n.trees=500, interaction.depth=2)
## Iter TrainDeviance ValidDeviance StepSize Improve
## 1 2.3026 nan 0.0010 0.0084
## 2 2.2982 nan 0.0010 0.0086
## 3 2.2936 nan 0.0010 0.0084
## 4 2.2891 nan 0.0010 0.0082
## 5 2.2849 nan 0.0010 0.0082
## 6 2.2805 nan 0.0010 0.0081
## 7 2.2763 nan 0.0010 0.0083
## 8 2.2720 nan 0.0010 0.0081
## 9 2.2677 nan 0.0010 0.0081
## 10 2.2634 nan 0.0010 0.0079
## 20 2.2228 nan 0.0010 0.0076
## 40 2.1471 nan 0.0010 0.0068
## 60 2.0781 nan 0.0010 0.0064
## 80 2.0147 nan 0.0010 0.0058
## 100 1.9555 nan 0.0010 0.0052
## 120 1.9015 nan 0.0010 0.0049
## 140 1.8512 nan 0.0010 0.0045
## 160 1.8040 nan 0.0010 0.0040
## 180 1.7596 nan 0.0010 0.0042
## 200 1.7176 nan 0.0010 0.0038
## 220 1.6777 nan 0.0010 0.0038
## 240 1.6399 nan 0.0010 0.0034
## 260 1.6037 nan 0.0010 0.0031
## 280 1.5690 nan 0.0010 0.0031
## 300 1.5360 nan 0.0010 0.0032
## 320 1.5041 nan 0.0010 0.0029
## 340 1.4735 nan 0.0010 0.0028
## 360 1.4445 nan 0.0010 0.0027
## 380 1.4164 nan 0.0010 0.0027
## 400 1.3893 nan 0.0010 0.0024
## 420 1.3634 nan 0.0010 0.0022
## 440 1.3384 nan 0.0010 0.0023
## 460 1.3146 nan 0.0010 0.0021
## 480 1.2917 nan 0.0010 0.0021
## 500 1.2693 nan 0.0010 0.0022
predGbm <- apply(predict(outGbm, Xtest, n.trees=outGbm$n.trees),1,which.max) - 1L
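The raw predictions come back on the link scale as an n-by-10-by-1 array, which is why the apply/which.max step is needed. If I recall the gbm interface correctly, asking for type="response" returns class probabilities in the same shape instead:
# Assumption: type="response" gives class probabilities for a multinomial gbm fit
probGbm <- predict(outGbm, Xtest, n.trees=outGbm$n.trees, type="response")
dim(probGbm)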
Finally, we will also fit a support vector machine. We can give the multiclass problem directly to the support vector machine, which handles it by fitting one-vs-one classifiers on all pairs of classes and voting over their predictions. I found that the radial kernel performed the best and that the default cost also worked well:
library(e1071)
outSvm <- svm(Xtrain, factor(ytrain), kernel="radial", cost=1)
predSvm <- predict(outSvm, Xtest)
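As a quick diagnostic (my own addition), the fitted object records how many support vectors the one-vs-one machines ended up keeping, in total and per class:
# Total number of support vectors and the count contributed by each class
outSvm$tot.nSV
outSvm$nSV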
We see that the methods differ substantially in how predictive they are on the test dataset:
mean(predKnn != ytest)
## [1] 0.05530643
mean(predLm != ytest)
## [1] 0.0896861
mean(predRf != ytest)
## [1] 0.2670653
mean(predGbm != ytest)
## [1] 0.2321873
mean(predSvm != ytest)
## [1] 0.06178376
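To make the comparison easier to scan, here is the same information gathered into a single named vector; this is just a convenience computed from the predictions above:
# Collect the test error rates of all five models into one named vector
errRates <- sapply(list(knn=predKnn, ridge=predLm, rf=predRf, gbm=predGbm, svm=predSvm),
                   function(pred) mean(pred != ytest))
round(errRates, 4)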
The tree-based models perform far worse than the others. The ridge regression works quite well given that it is constrained to linear separating boundaries. The support vector machine cuts the ridge regression’s error rate by roughly a third, using the kernel trick to fit non-linear decision boundaries, and k-nearest neighbors performs just slightly better than the support vector machine.
You may have noticed that some of the techniques (when verbose is set to TRUE) print a mis-classification rate by class. This is useful for assessing a model when there are more than two categories. For example, look at where the ridge regression and support vector machine make the majority of their errors:
tapply(predLm != ytest, ytest, mean)
## 0 1 2 3 4 5
## 0.03064067 0.04545455 0.16161616 0.14457831 0.10500000 0.13125000
## 6 7 8 9
## 0.07058824 0.08843537 0.15060241 0.05084746
tapply(predSvm != ytest, ytest, mean)
## 0 1 2 3 4 5
## 0.02228412 0.04166667 0.07575758 0.12048193 0.07000000 0.08125000
## 6 7 8 9
## 0.07058824 0.06122449 0.09036145 0.03954802
We see that 3 and 8 are particularly difficult (as is 2 for the ridge regression), while 0 and 1 are quite easy to predict.
We might think that a lot of 8’s and 3’s are being mis-classified as one another (they do look similar in some ways). Looking at the confusion matrices, we see that this is not quite the case:
table(predLm,ytest)
## ytest
## predLm 0 1 2 3 4 5 6 7 8 9
## 0 348 0 5 4 2 7 2 0 6 0
## 1 0 252 0 0 2 0 0 0 0 2
## 2 2 0 166 3 5 0 3 1 4 1
## 3 2 2 4 142 0 7 0 1 5 0
## 4 3 4 10 1 179 2 3 5 2 3
## 5 0 0 2 11 0 139 3 0 4 1
## 6 2 4 3 0 4 0 158 0 1 0
## 7 0 0 1 2 1 0 0 134 1 1
## 8 1 1 7 2 1 1 1 1 141 1
## 9 1 1 0 1 6 4 0 5 2 168
table(predSvm,ytest)
## ytest
## predSvm 0 1 2 3 4 5 6 7 8 9
## 0 351 0 2 0 0 3 4 0 4 0
## 1 0 253 0 0 1 0 0 0 0 0
## 2 6 1 183 5 3 2 4 2 2 0
## 3 0 0 4 146 0 3 0 0 3 0
## 4 1 5 3 0 186 1 2 5 0 4
## 5 0 1 0 11 1 147 1 0 2 1
## 6 0 3 1 0 2 0 158 0 1 0
## 7 0 1 1 1 3 0 0 138 0 0
## 8 1 0 4 3 1 1 1 0 151 2
## 9 0 0 0 0 3 3 0 2 3 170
We also see that the points where these two models make mistakes only partially overlap:
table(predSvm != ytest, predLm != ytest)
##
## FALSE TRUE
## FALSE 1805 78
## TRUE 22 102
This gives evidence that stacking could be beneficial.
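A quick way to quantify the potential gain (my own back-of-the-envelope check, not a full stacking fit) is the error rate of an idealized combiner that is correct whenever at least one of the two models is correct:
# Lower bound on the error of any combiner of these two models:
# an error remains only when both models are wrong on the same test point
mean((predLm != ytest) & (predSvm != ytest))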
We can pick out a sample of images of 3’s and see whether they look particularly difficult to classify.
iset <- sample(which(train[,1] == 3),5*7)
par(mar=c(0,0,0,0))
par(mfrow=c(5,7))
for (j in iset) {
y <- matrix(as.matrix(train[j,-1]),16,16,byrow=TRUE)
y <- 1 - (y + 1)*0.5
plot(0,0,xlab="",ylab="",axes=FALSE)
rasterImage(y,-1,-1,1,1)
box()
text(-0.8,-0.7, train[j,1], cex=3, col="red")
}
For the most part though, these do not seem difficult for a human to classify.
What if we look at the actual mis-classified points? Here are the ones from the support vector machine:
iset <- sample(which(predSvm != ytest),7*7)
par(mar=c(0,0,0,0))
par(mfrow=c(7,7))
for (j in iset) {
y <- matrix(as.matrix(test[j,-1]),16,16,byrow=TRUE)
y <- 1 - (y + 1)*0.5
plot(0,0,xlab="",ylab="",axes=FALSE)
rasterImage(y,-1,-1,1,1)
box()
text(-0.8,-0.7, test[j,1], cex=3, col="red")
text(0.8,-0.7, predSvm[j], cex=3, col="blue")
}
And the ridge regression:
iset <- sample(which(predLm != ytest),7*7)
par(mar=c(0,0,0,0))
par(mfrow=c(7,7))
for (j in iset) {
y <- matrix(as.matrix(test[j,-1]),16,16,byrow=TRUE)
y <- 1 - (y + 1)*0.5
plot(0,0,xlab="",ylab="",axes=FALSE)
rasterImage(y,-1,-1,1,1)
box()
text(-0.8,-0.7, test[j,1], cex=3, col="red")
text(0.8,-0.7, predLm[j], cex=3, col="blue")
}
Can you rationalize the differences between the kinds of errors the two models make?