# MNIST digit prediction

## The data

Today we will start looking at the MNIST dataset. This is a set of images of handwritten digits, and the learning goal is to predict which digit (0-9) each image represents. It is a canonical dataset for basic image processing and was probably the first dataset that a large community of researchers used as a universal benchmark for computer vision. Today I am using just a subset of the full dataset (slightly downsampled); it is the example data given in the *Elements of Statistical Learning*. We will explore the larger dataset over the next two weeks as we start introducing neural networks.

I'll read in the training and testing datasets; the split into train and test should be the same as that used by most other sources:

```{r}
set.seed(1)
train <- read.csv("../../data/mnist_train.psv", sep="|", as.is=TRUE, header=FALSE)
test <- read.csv("../../data/mnist_test.psv", sep="|", as.is=TRUE, header=FALSE)
```

Looking at the data, we see that it has 257 columns: the first column gives the true digit class and the others give the pixel intensities (on a scale from -1 to 1) of the 16x16 pixel image.

```{r}
dim(train)
train[1:10,1:10]
```

We can plot what the image actually looks like in R using the *rasterImage* function:

```{r}
y <- matrix(as.matrix(train[3400,-1]),16,16,byrow=TRUE)
y <- 1 - (y + 1)*0.5
plot(0,0)
rasterImage(y,-1,-1,1,1)
```

With a minimal amount of work, we can build a much better visualization of what these digits actually look like. Here is a grid of 35 observations. How hard do you think it will be to predict the correct classes?

```{r, fig.width=7, fig.height=6}
iset <- sample(1:nrow(train),5*7)
par(mar=c(0,0,0,0))
par(mfrow=c(5,7))
for (j in iset) {
  y <- matrix(as.matrix(train[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5
  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, train[j,1], cex=3, col="red")
}
```

Now, let's extract out the matrices and classes:

```{r}
Xtrain <- as.matrix(train[,-1])
Xtest <- as.matrix(test[,-1])
ytrain <- train[,1]
ytest <- test[,1]
```

I now want to apply a suite of techniques that we have studied to predicting the correct class for each handwritten digit.

## Model fitting

### K-nearest neighbors

As a simple model, we can use k-nearest neighbors. I set k equal to three, which for a multi-class problem means we use the closest point's label unless the next two closest points agree on a different class.

```{r}
library(FNN)
predKnn <- knn(Xtrain,Xtest,ytrain,k=3)
```

### Ridge regression

For ridge regression, I'll use the multinomial loss function directly and let the R function do the cross-validation for me. The **glmnet** package is my preferred package for ridge regressions; just remember to set *alpha* to 0.

```{r}
library(glmnet)
outLm <- cv.glmnet(Xtrain, ytrain, alpha=0, nfolds=3, family="multinomial")
predLm <- apply(predict(outLm, Xtest, s=outLm$lambda.min, type="response"), 1, which.max) - 1L
```

The predicted classes are indexed starting at 1, so I subtract one off of the results to get the correct class labels.

### Random forest

We can also run a random forest model. It will run significantly faster if I restrict the maximum number of nodes somewhat.

```{r}
library(randomForest)
outRf <- randomForest(Xtrain, factor(ytrain), maxnodes=10)
predRf <- predict(outRf, Xtest)
```
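The claim about *maxnodes* is easy to check for yourself. Here is a small sketch (not evaluated here) that refits the forest at two illustrative values and compares run time and test error; the particular values 10 and 50 are arbitrary choices for illustration.

```{r, eval=FALSE}
# Sketch: compare run time and test error as maxnodes grows.
# Uses Xtrain, ytrain, Xtest, ytest defined above; values of m are arbitrary.
for (m in c(10, 50)) {
  tm <- system.time(
    fit <- randomForest(Xtrain, factor(ytrain), maxnodes=m)
  )
  err <- mean(predict(fit, Xtest) != ytest)
  cat("maxnodes =", m, "| elapsed:", round(tm["elapsed"], 1),
      "sec | test error:", round(err, 3), "\n")
}
```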
### Gradient boosted trees

Gradient boosted trees also run directly on the multiclass labels. The model performs much better if I increase the interaction depth slightly. Increasing it past 2-3 is beneficial in large problems, but is rarely useful for smaller cases like this one. I could also play with the learning rate, but I won't fiddle with that here.

```{r}
library(gbm)
outGbm <- gbm.fit(Xtrain, factor(ytrain), distribution="multinomial", n.trees=500, interaction.depth=2)
predGbm <- apply(predict(outGbm, Xtest, n.trees=outGbm$n.trees),1,which.max) - 1L
```

### Support vector machines

Finally, we will also fit a support vector machine. We can give the multiclass problem directly to the support vector machine, and one-vs-one prediction is done on all combinations of the classes. I found that the radial kernel performed the best and that the default cost also worked well:

```{r}
library(e1071)
outSvm <- svm(Xtrain, factor(ytrain), kernel="radial", cost=1)
predSvm <- predict(outSvm, Xtest)
```

## Prediction performance and comparison

### Misclassification rates

We see that the methods differ substantially in how predictive they are on the test dataset:

```{r}
mean(predKnn != ytest)
mean(predLm != ytest)
mean(predRf != ytest)
mean(predGbm != ytest)
mean(predSvm != ytest)
```

The tree models perform far worse than the others. The ridge regression works quite well given that it is constrained to linear separating boundaries. The support vector machine does about twice as well by utilizing the kernel trick to easily fit higher-dimensional models. The k-nearest neighbors classifier performs just slightly better than the support vector machine.

### Misclassification rates by class

You may have noticed that some of the techniques (when verbose is set to *TRUE*) print a misclassification rate by class. This is useful for assessing the model when there are more than two categories. For example, look at where the ridge regression and support vector machine make the majority of their errors:

```{r}
tapply(predLm != ytest, ytest, mean)
tapply(predSvm != ytest, ytest, mean)
```

We see that 8 and 3 are particularly difficult, with 1 being quite easy to predict.

### Confusion matrix

We might think that a lot of 8's and 3's are being misclassified as one another (they do look similar in some ways). Looking at the confusion matrices, we see that this is not quite the case:

```{r}
table(predLm,ytest)
table(predSvm,ytest)
```

We also see that the points where these two models make mistakes do not overlap all that much.

```{r}
table(predSvm != ytest, predLm != ytest)
```

This gives evidence that stacking could be beneficial.
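As a rough illustration of that idea (a simple average of the two models' class probabilities, not a proper stacked model trained on out-of-fold predictions), we could refit the SVM with probability estimates turned on and blend it with the ridge fit. This is only a sketch and is not evaluated here:

```{r, eval=FALSE}
# Crude blend: average the ridge and SVM class probabilities on the test set
# and predict the most likely class. A real stacking approach would train a
# second-stage model on out-of-fold predictions instead.
probLm  <- predict(outLm, Xtest, s=outLm$lambda.min, type="response")[,,1]
outSvmP <- svm(Xtrain, factor(ytrain), kernel="radial", cost=1, probability=TRUE)
probSvm <- attr(predict(outSvmP, Xtest, probability=TRUE), "probabilities")
probSvm <- probSvm[, order(as.integer(colnames(probSvm)))]  # align columns to 0-9
predBlend <- apply(probLm + probSvm, 1, which.max) - 1L
mean(predBlend != ytest)
```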
### Misclassified digits

We can pick out a large sample of images that are 3's and see if these are particularly difficult to detect.

```{r, fig.width=7, fig.height=6}
iset <- sample(which(train[,1] == 3),5*7)
par(mar=c(0,0,0,0))
par(mfrow=c(5,7))
for (j in iset) {
  y <- matrix(as.matrix(train[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5
  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, train[j,1], cex=3, col="red")
}
```

For the most part, though, these do not seem difficult for a human to classify. What if we look at the actual misclassified points? Here are the ones from the support vector machine:

```{r, fig.width=7, fig.height=7}
iset <- sample(which(predSvm != ytest),7*7)
par(mar=c(0,0,0,0))
par(mfrow=c(7,7))
for (j in iset) {
  y <- matrix(as.matrix(test[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5
  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, test[j,1], cex=3, col="red")
  text(0.8,-0.7, predSvm[j], cex=3, col="blue")
}
```

And here are the ones from the ridge regression:

```{r, fig.width=7, fig.height=7}
iset <- sample(which(predLm != ytest),7*7)
par(mar=c(0,0,0,0))
par(mfrow=c(7,7))
for (j in iset) {
  y <- matrix(as.matrix(test[j,-1]),16,16,byrow=TRUE)
  y <- 1 - (y + 1)*0.5
  plot(0,0,xlab="",ylab="",axes=FALSE)
  rasterImage(y,-1,-1,1,1)
  box()
  text(-0.8,-0.7, test[j,1], cex=3, col="red")
  text(0.8,-0.7, predLm[j], cex=3, col="blue")
}
```

Can you rationalize what differs between the kinds of errors the two models make?
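One way to start answering that question is to compare the two models directly on the test points that both of them get wrong. Here is a small sketch (not evaluated here) that cross-tabulates their predicted labels on those points:

```{r, eval=FALSE}
# Cross-tabulate the SVM and ridge predictions on the test points that both
# models misclassify, to see whether they tend toward the same mistakes.
both <- which(predSvm != ytest & predLm != ytest)
table(svm=predSvm[both], ridge=predLm[both])
```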