Optical Character Recognition with KNN Classifier

Coding and explaining the working principles of Optical Character Recognition through the popular K-Nearest Neighbors algorithm

Riccardo Andreoni
Towards Data Science


Handwritten digits from the MNIST dataset

Optical Character Recognition (OCR) is present in our daily life more often than we might imagine. When we use Google Translate to translate text from pictures, we’re using OCR. When we send a letter and it arrives at its destination, OCR is working for us. When a visually impaired person scans a document and a machine reads it aloud for them, again, OCR takes the credit. These are just a few examples of OCR’s presence in our routine, but countless more exist today, and their number will only increase in the near future.

This article aims to illustrate what Optical Character Recognition is and to present a step-by-step application to get familiar with it. I will use the MNIST dataset to train a machine learning model, which will be capable of recognizing handwritten digits from pictures. After fine-tuning the algorithm, I will assess its accuracy.

What is Optical Character Recognition?

Optical Character Recognition (OCR), also known as Text Recognition, is a technology capable of converting printed or handwritten characters into machine-readable text. OCR makes use of both hardware and software tools. While hardware often consists of optical scanners or cameras, software, which is the focus of this article, consists of Machine Learning algorithms such as KNN Classifier, SVM Classifier, or Neural Networks.

Current applications of OCR are endless. Just to cite some of them:

  • Generation of electronic versions (soft copies) of old books or documents (see Google Books)
  • Scanning of road signs for autonomous driving
  • Creation of tools that enable blind and visually impaired people to read (see AFB)
  • Real-time conversion of handwritten text
  • Automatic extraction of information from passports or insurance documents

Model Selection

Optical Character Recognition is a classification problem, which means it is a Machine Learning problem where the output is categorical (i.e. it belongs to a finite set of values). The output classes are, of course, the different characters.
In this application we deal with handwritten digit recognition: the output classes are the 10 integers from 0 to 9.

There are many ready-to-use classification algorithms, so we need to choose one that is well suited for this purpose. The main constraint in the algorithm selection is that it must be suitable for multiclass classification: since there are more than 2 classes, this can’t be treated as a binary classification problem.

Multiclass Classification

Some algorithms are natively capable of handling multiclass classification problems (such as Stochastic Gradient Descent, Random Forest, or Naive Bayes classifiers). Other algorithms (such as Support Vector Machine or Logistic Regression classifiers) are natively binary classifiers. Even if techniques to perform multiclass classification with binary classifiers exist (the one-vs-all or one-vs-one strategies), for the scope of this article I will stick with a native multiclass classifier.

K-Nearest Neighbors Classifier

As there is no accepted theory about which learning algorithm is best for each type of problem, my choice falls on the k-Nearest Neighbors Classifier (KNN), mainly for its simplicity: it has few hyperparameters to tune (the number of neighbors k to consider and the distance function, which usually corresponds to the Euclidean or the Manhattan distance) and it is, furthermore, non-parametric (meaning that it makes no assumptions about the data distribution). As a drawback, KNN scales poorly with large datasets because of its intrinsic nature: it has to scan the whole dataset whenever a new example needs to be classified. For this reason, it makes sense to later explore additional learning algorithms, such as Decision Trees or Random Forests.

The idea behind KNN is that an example i has a high probability of belonging to a particular class, let’s call it m, if m is also the most popular class of the k nearest examples in the feature space. More formally, the KNN learning algorithm follows these steps:

  1. Given one unknown example, measure its distance from all the labeled examples in the dataset
  2. Take the k labeled examples that are nearest to the unlabeled one
  3. Based on the classes of the k nearest neighbors, predict the class of the unlabeled example
  4. Repeat steps 1–3 for all the unlabeled examples.
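
To make these steps concrete, here is a minimal from-scratch sketch in NumPy. It is for illustration only (the function and variable names are my own), and it is not the implementation used later in the article, which relies on scikit-learn.

```python
import numpy as np

def knn_predict(X_train, y_train, X_new, k=5):
    """Classify each row of X_new by majority vote of its k nearest
    training examples, using the Euclidean distance."""
    predictions = []
    for x in X_new:
        # Step 1: distance from the unknown example to every labeled example
        distances = np.linalg.norm(X_train - x, axis=1)
        # Step 2: indices of the k nearest labeled examples
        nearest = np.argsort(distances)[:k]
        # Step 3: majority vote among the classes of the k nearest neighbors
        classes, counts = np.unique(y_train[nearest], return_counts=True)
        predictions.append(classes[np.argmax(counts)])
    # Step 4: the loop above repeats steps 1-3 for every unlabeled example
    return np.array(predictions)
```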

After this operation, the feature space can be visualized as divided into as many regions as there are classes. Depending on where a new example lands, it will be predicted as belonging to a particular class.

Division of the feature space into classes (Image source: Wikipedia.org)

In order to apply KNN to a classification problem we need to specify:

  • A distance metric. The most commonly used is the Euclidean distance, which is the particular case of the Minkowski distance with p=2. Given two points X and Y of the same dimension n, the Minkowski distance D(X,Y) is computed as D(X,Y) = (Σᵢ |xᵢ − yᵢ|^p)^(1/p), where the sum runs over the n coordinates.
  • The number of neighbors k to take into consideration. This is maybe the most critical hyperparameter of KNN because:
    - a value too small for k may lead to overfitting
    - a value too big for k may lead to underfitting
    It is highly recommended to optimize k with a cross-validation set.
  • An optional weighting function. It is often ignored, but, in some applications, better results are achieved by using a weighting function that gives more weight to the closest neighbors.
  • Aggregation method. The simple majority vote is typically selected.
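
In scikit-learn, these choices map directly onto the constructor arguments of KNeighborsClassifier. The values below are just the library defaults, shown for illustration, not the ones selected later by hyperparameter tuning:

```python
from sklearn.neighbors import KNeighborsClassifier

knn = KNeighborsClassifier(
    n_neighbors=5,       # k, the number of neighbors to consider
    metric="minkowski",  # distance metric
    p=2,                 # with p=2 the Minkowski distance is the Euclidean distance
    weights="uniform",   # weighting function: "uniform" or "distance"
)                        # aggregation: a (possibly weighted) majority vote
```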

This is only a concise snapshot of the KNN algorithm; for more detailed information about how it works, check this article.

MNIST Dataset

To train our model I will use the MNIST dataset, which is a large database of handwritten digits. The dataset contains 70,000 small images (28 x 28 pixels), each of them labeled with the digit it represents.

First, we have to import the dataset.
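
The original import snippet is not reproduced here; a common way to load MNIST is through scikit-learn’s fetch_openml helper:

```python
import numpy as np
from sklearn.datasets import fetch_openml

# Download MNIST: 70,000 images of 28 x 28 pixels, flattened to 784 features
mnist = fetch_openml("mnist_784", version=1, as_frame=False)
X, y = mnist["data"], mnist["target"].astype(np.uint8)

print(X.shape)  # (70000, 784)
print(y.shape)  # (70000,)
```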

To get a sense of the dataset, I define a function to print one of the 70,000 images:
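
A possible version of such a function, using matplotlib (the name plot_digit is my own, not necessarily the one used in the original code):

```python
import matplotlib.pyplot as plt

def plot_digit(image_row):
    """Reshape a 784-feature row back to 28 x 28 pixels and display it."""
    plt.imshow(image_row.reshape(28, 28), cmap="binary")
    plt.axis("off")
    plt.show()

plot_digit(X[0])  # display the first image of the dataset
print(y[0])       # and print its label
```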

The output will be an image of a single handwritten digit.

The dataset is composed of 70,000 examples, each of them having 784 features (28 x 28 pixels). Each feature represents the intensity of a given pixel on a greyscale colormap, with the feature values ranging from 0 to 255.
A 0 value corresponds to a white pixel, whereas a 255 value coincides with a black pixel.

Model Training

We split the dataset into a training set and a test set, as is common practice. I chose to assign 25% of the examples to the test set. The test set will be set aside and never touched until the end: its only purpose is to provide an unbiased evaluation of how well the model generalizes. Hyperparameter tuning will be performed with cross-validation.
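
With scikit-learn, the split can be done with train_test_split. The 25% test size matches the text; the random seed is my own choice for reproducibility:

```python
from sklearn.model_selection import train_test_split

# Hold out 25% of the examples as a test set, never touched until the end
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42
)
```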

Let’s first try to fit the classifier with the default parameters.
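
A minimal version of this step, assuming the variables from the split above. The text does not specify how the baseline accuracy is measured, so here I estimate it with 3-fold cross-validation on the training set, keeping the test set untouched:

```python
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

knn_clf = KNeighborsClassifier()  # default hyperparameters: k=5, uniform weights
knn_clf.fit(X_train, y_train)

# Estimate the accuracy of the default model with cross-validation
scores = cross_val_score(knn_clf, X_train, y_train, cv=3, scoring="accuracy")
print(scores.mean())
```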

This should output an accuracy score of about 96.8%, which is not that bad considering the low effort put into building the model.

Hyperparameters Tuning

To achieve a better result, I will try to tune the model hyperparameters with the Grid Search method. Grid Search is an exhaustive search method, so, for computational reasons, I lower the number of cross-validation folds. I explore the hyperparameter space by considering the following combinations:
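
The exact grid is not reproduced here; a plausible one, consistent with the description below (3 values of k, 2 weighting schemes, 3 cross-validation folds), could look like this:

```python
from sklearn.model_selection import GridSearchCV
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical grid: 3 values of k x 2 weighting schemes = 6 combinations
param_grid = [{"n_neighbors": [3, 4, 5], "weights": ["uniform", "distance"]}]

grid_search = GridSearchCV(
    KNeighborsClassifier(),
    param_grid,
    cv=3,                 # 3 cross-validation folds -> 18 training runs in total
    scoring="accuracy",
    verbose=3,
)
grid_search.fit(X_train, y_train)
```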

In other words, Grid Search will explore the combinations of 3 different values of the number of nearest neighbors k and 2 ways of aggregating the neighbors’ classes (with uniform weights or distance-weighted votes). With 6 possible hyperparameter combinations and 3 cross-validation folds, this totals 18 training runs.

Another possible hyperparameter that may be worth inspecting is the distance metric.

After the grid search and cross-validation process, we can access the best estimator through the best_estimator_ attribute of the fitted grid search object.
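
Assuming the grid_search object from the sketch above:

```python
print(grid_search.best_estimator_)
print(grid_search.best_params_)  # e.g. {'n_neighbors': 4, 'weights': 'distance'}
```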

In my case, the best estimator is the one with k=4 and using distance-weighted aggregation.

I will now evaluate the tuned model on the test set. The steps here are the following:

  1. Use the trained model to predict the label of each example in the test set
  2. Compare the predicted labels with the actual labels
  3. Compute the accuracy score

Scikit-learn allows us to perform all of that with a single line of code:
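
The original one-liner is not shown; a plausible equivalent, again assuming the grid_search object from before (with the default refit=True, its best estimator is already retrained on the full training set):

```python
from sklearn.metrics import accuracy_score

# Predict the test labels with the best estimator and compare them to the truth
print(accuracy_score(y_test, grid_search.best_estimator_.predict(X_test)))
```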

The accuracy of the model increased to 97.3% by simply applying grid search to find better hyperparameters.

Next Steps

To further improve the OCR model, we can try different approaches, like running other learning algorithms, expanding the hyperparameter search space, or increasing the size of the training set by synthetically modifying the original data (adding noise, shifting, or rotating the images). A follow-up article, in which a new model is trained with synthetically generated samples, will come next.

Thanks for reading, I hope you enjoyed it!
