
Chess Recognition Problem: A Deep Dive Solution

An end-to-end solution to the 2D chess recognition problem.

Photo by Randy Fath on Unsplash

1. Introduction

The problem of identifying the configuration of chess pieces from an image of a physical chessboard is often referred to as chess recognition. The ability of a computer to recognize the chess pieces on a chessboard is the first step toward an intelligent system that can play chess, solve chess problems and puzzles, and perform chess analysis.

The goal of my project is to recognize the chess pieces and their positions on the chessboard, and to describe them in a structured format such as Forsyth–Edwards Notation (FEN), which is compatible with various chess engines. I also added an extra layer of interpretation that takes the generated FEN as input, reports any potential attacks (checks), and flags illegal chess positions.


2. Dataset Overview

The dataset has 100000 images of randomly generated chess positions of 5–15 pieces (2 kings and 3–13 pawns/pieces). All images are 400 by 400 pixels.

Pieces were generated with the following probability distribution:

  1. 30% for Pawn.
  2. 20% for Bishop.
  3. 20% for Knight.
  4. 20% for Rook.
  5. 10% for Queen.
  6. 2 Kings are guaranteed to be on the board.
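The dataset's generation script is not part of this article. As a hypothetical illustration only, the sketch below draws positions with the distribution above: 2 kings plus 3 to 13 other pieces on distinct random squares, each extra piece typed by the listed weights. The equal split between white and black, and the helper names `random_placement` and `compress_row`, are my assumptions.

```python
import random

# Hypothetical generator: the dataset's real generation script is not shown here.
# Weights follow the distribution listed above; the two kings are placed separately.
PIECE_WEIGHTS = {'p': 30, 'b': 20, 'n': 20, 'r': 20, 'q': 10}


def compress_row(row):
    """Collapse runs of empty squares ('1') into FEN digits, e.g. '1K111111' -> '1K6'."""
    out, run = '', 0
    for ch in row:
        if ch == '1':
            run += 1
        else:
            out += (str(run) if run else '') + ch
            run = 0
    return out + (str(run) if run else '')


def random_placement(rng, min_extra=3, max_extra=13):
    """Return a FEN piece placement with 2 kings and min_extra..max_extra other pieces."""
    n_extra = rng.randint(min_extra, max_extra)
    squares = rng.sample(range(64), k=2 + n_extra)
    board = ['1'] * 64
    # The first two sampled squares hold the white and the black king.
    board[squares[0]], board[squares[1]] = 'K', 'k'
    types, weights = list(PIECE_WEIGHTS), list(PIECE_WEIGHTS.values())
    for sq in squares[2:]:
        piece = rng.choices(types, weights=weights, k=1)[0]
        # Assumption: each extra piece is white or black with equal probability.
        board[sq] = piece.upper() if rng.random() < 0.5 else piece
    rows = [''.join(board[r * 8:(r + 1) * 8]) for r in range(8)]
    return '/'.join(compress_row(row) for row in rows)


print(random_placement(random.Random(0)))
```

Nothing in this sketch stops pawns from landing on the first or last row, so a generator of this kind also produces some illegal positions.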

The labels are encoded in the filenames, in FEN format but with dashes instead of slashes.
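Recovering a FEN label from a filename is then a string replacement. A minimal sketch, assuming the label is the file stem (the name without directory and extension); the sample filename and its extension are made up:

```python
from pathlib import Path


def fen_from_filename(path):
    """Turn a dataset-style filename into a FEN piece-placement string."""
    # The label is the file stem; rows are separated by '-' instead of '/'.
    return Path(path).stem.replace('-', '/')


print(fen_from_filename('train/1B1B2K1-8-8-8-8-8-8-7k.jpeg'))
# → 1B1B2K1/8/8/8/8/8/8/7k
```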

The dataset is in the public domain. Please check the citations for the source of the dataset [1].

2.1. Forsyth–Edwards Notation (FEN)

Forsyth–Edwards Notation (FEN) is a standard notation for describing a particular board position of a chess game. The purpose of FEN is to provide all the information necessary to restart a game from a particular position.

A FEN record defines a particular game position, all in one text line and using only the ASCII character set [2].

A FEN record has 6 fields:

  • Piece placement data
  • Active color
  • Castling
  • En passant
  • Halfmove clock
  • Fullmove number

Note: Since the dataset consists of static images, only the piece placement field of the FEN can be generated.
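To make the six fields concrete, the sketch below splits the standard starting position into its fields and pads a placement-only string into a full record, which some engines require. The padding values (White to move, no castling, no en passant square, clocks 0 and 1) are assumptions, since none of them can be read from a static image; `split_fen` and `complete_fen` are hypothetical helper names.

```python
FIELD_NAMES = ('placement', 'active_color', 'castling',
               'en_passant', 'halfmove_clock', 'fullmove_number')


def split_fen(fen):
    """Split a full six-field FEN record into a dict keyed by field name."""
    fields = fen.split()
    if len(fields) != len(FIELD_NAMES):
        raise ValueError(f'expected 6 fields, got {len(fields)}')
    return dict(zip(FIELD_NAMES, fields))


def complete_fen(placement, active_color='w'):
    """Pad a placement-only FEN with placeholder values for the other five fields.

    Assumption: these values cannot be read from a static image, so we guess
    White to move, no castling rights, no en passant square, and fresh clocks.
    """
    return f'{placement} {active_color} - - 0 1'


start = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'
print(split_fen(start)['castling'])          # → KQkq
print(complete_fen('4k3/8/8/8/8/8/8/4K3'))   # → 4k3/8/8/8/8/8/8/4K3 w - - 0 1
```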


3. Exploratory Data Analysis

This is a critical stage to understand and investigate the patterns within the dataset.

I imported the chess_positions module, which I developed and perfected for two main reasons:

  1. EDA
  2. Interpretation of the FEN.

3.1. Schema of the dataset

The training dataset has 80000 images and the test dataset has 20000 images.

3.2. Checking for duplicates

Checking the labels shows that all of them (and therefore all the images) are unique.
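A duplicate check needs only a frequency count of the labels. A minimal sketch with `collections.Counter`; the labels are made up, and in practice the training and test labels would be combined so that duplicates across the two splits are caught as well:

```python
from collections import Counter


def find_duplicate_labels(labels):
    """Return the labels that occur more than once (empty list if all are unique)."""
    counts = Counter(labels)
    return [label for label, n in counts.items() if n > 1]


# Hypothetical labels standing in for the dataset filenames.
labels = ['8/8/8/4k3/8/8/8/4K3', '8/8/8/3k4/8/8/8/4K3', '8/8/8/4k3/8/8/8/4K3']
print(find_duplicate_labels(labels))    # → ['8/8/8/4k3/8/8/8/4K3']
print(len(set(labels)) == len(labels))  # → False
```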

3.3. Chess piece distribution

In chess, there are 6 different piece types (12 when color is taken into account).

  1. K – White King, k – Black King.
  2. Q – White Queen, q – Black Queen.
  3. B – White Bishop, b – Black Bishop.
  4. N – White Knight, n – Black Knight.
  5. R – White Rook, r – Black Rook.
  6. P – White Pawn, p – Black Pawn.
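The bar plots below count each of these 12 symbols over all boards. A minimal sketch of that tally with `collections.Counter`, assuming FEN placement strings as input (the helper names are hypothetical):

```python
from collections import Counter

PIECES = 'KQBNRPkqbnrp'  # white pieces upper case, black pieces lower case


def count_pieces(fen):
    """Count each of the 12 piece symbols on one board (digits and '/' are ignored)."""
    counts = Counter(ch for ch in fen if ch in PIECES)
    return {p: counts[p] for p in PIECES}


def total_piece_counts(fens):
    """Sum the per-board counts over a dataset: the numbers behind a bar plot."""
    total = Counter({p: 0 for p in PIECES})
    for fen in fens:
        total.update(count_pieces(fen))
    return dict(total)


board = '1B1B2K1/8/8/8/8/8/8/7k'
print(count_pieces(board)['B'])  # → 2
```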

3.3.1. Chess piece distribution of the training dataset

Image by Author - Total pieces (by type) in the training dataset

3.3.2. Chess piece distribution of the test dataset

Image by Author - Total pieces (by type) in the test dataset

Conclusion of the above chess piece distribution plots:

  1. Neither the training set nor the test set has a total of 8 pawns (white and black) on the boards.
  2. Each board has exactly one white king and one black king, as a legal position requires.
  3. Queens are the least frequent pieces in both the training and test sets.

3.4. Density plots

A density plot is a smoothed visual estimate of the probability density function (PDF) of a set of points; it shows how the data are distributed.
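Because piece counts per board are small integers, the distribution that a density plot smooths can also be tabulated exactly. A minimal sketch that computes the empirical probability of a board holding exactly n pieces of one symbol; the boards are toy examples:

```python
from collections import Counter


def per_board_distribution(fens, symbol):
    """Empirical P(board has exactly n pieces of `symbol`) for each observed n."""
    per_board = Counter(fen.count(symbol) for fen in fens)
    return {n: per_board[n] / len(fens) for n in sorted(per_board)}


fens = ['4k3/8/8/8/8/8/8/4K3', 'r7/8/8/4k3/8/8/8/4K2R',
        'r6r/8/8/4k3/8/8/8/4K3', '4k3/8/8/8/8/8/8/R3K3']
print(per_board_distribution(fens, 'r'))  # → {0: 0.5, 1: 0.25, 2: 0.25}
```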

3.4.1. PDF of the training dataset

Image by Author - PDF plot of chess pieces in the training dataset

3.4.2. PDF of the test dataset

Image by Author - PDF plot of chess pieces in the test dataset

Conclusion of the above density plots:

  1. Every board has exactly 1 black king and 1 white king.
  2. The distributions of all the other pieces are very similar, and each has a large share of boards with 0 pieces of that type (0 pawns, 0 rooks, 0 knights, 0 bishops, or 0 queens).

3.5. Total pieces vs No. of boards

As already mentioned, the training dataset has 80000 chessboards and the test dataset has 20000 chessboards.

Image by Author - Histogram of total pieces vs no. of boards

Conclusion of the above board-wise pieces distribution plot:

  1. The maximum number of chess pieces is 15.
  2. The minimum number of chess pieces is 5.
  3. The most common number of pieces per board is 15.
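The histogram's minimum, maximum, and peak can be read off directly by counting the letters in each placement string. A minimal sketch with toy boards (they hold fewer pieces than the dataset's boards); the helper names are hypothetical:

```python
from collections import Counter


def pieces_on_board(fen):
    """Number of occupied squares: every letter in the placement field is a piece."""
    return sum(ch.isalpha() for ch in fen)


def total_pieces_summary(fens):
    """Min, max and histogram of pieces per board."""
    totals = [pieces_on_board(fen) for fen in fens]
    histogram = dict(sorted(Counter(totals).items()))
    return min(totals), max(totals), histogram


fens = ['4k3/8/8/8/8/8/8/4K3', 'r6r/8/8/4k3/8/8/8/4K2R', '1B1B2K1/8/8/8/8/8/8/7k']
print(total_pieces_summary(fens))  # → (2, 5, {2: 1, 4: 1, 5: 1})
```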

3.6. Finding checks and illegal positioning of chess pieces

I developed and perfected the chess_positions module to detect checks and illegal piece placements from the FEN labels of the images. Below are the 3 classes in the module that detect checks and illegal positions.

3.6.1. Defining the chessboard

import re

import numpy as np


class Board(object):
    """
    This class defines the chessboard.
    """

    def __init__(self, fen_label):
        # Expand each digit (a run of empty squares) into that many '1's so that
        # every row is exactly 8 characters long. Rows must be separated by '/'.
        self.fen_label = re.sub(pattern=r'\d',
                                repl=lambda x: self.get_ones(char=x.group()),
                                string=fen_label)
        self.fen_matrix = self.get_fen_matrix()

    def get_ones(self, char):
        """
        This method returns repetitive 1s based on the input digit character.
        """
        if char.isdigit():
            return '1' * int(char)
        return char

    def get_fen_matrix(self):
        """
        This method constructs an 8 x 8 FEN matrix (one character per square).
        """
        fen_matrix = np.array([list(row) for row in self.fen_label.split('/')])
        return fen_matrix

    def get_piece_positions(self, notation):
        """
        This method returns the 2D indices (rows, columns) of the piece in the FEN matrix.
        Both arrays are empty when the piece is not on the board.
        """
        (i, j) = np.where(self.fen_matrix == notation)
        return i, j
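The digit expansion in `Board.__init__` is what turns a FEN row such as `4k3` into 8 characters that fit a matrix. The stdlib-only sketch below shows the same idea with plain lists, plus the inverse operation; `fen_to_grid` and `grid_to_fen` are hypothetical names, not part of the chess_positions module.

```python
import re


def fen_to_grid(fen):
    """Expand a FEN placement into 8 rows of 8 characters, '1' marking an empty square."""
    expanded = re.sub(r'\d', lambda m: '1' * int(m.group()), fen)
    grid = [list(row) for row in expanded.split('/')]
    if len(grid) != 8 or any(len(row) != 8 for row in grid):
        raise ValueError(f'malformed FEN placement: {fen!r}')
    return grid


def grid_to_fen(grid):
    """Inverse of fen_to_grid: collapse each run of '1's back into a single digit."""
    rows = (''.join(row) for row in grid)
    return '/'.join(re.sub(r'1+', lambda m: str(len(m.group())), row) for row in rows)


grid = fen_to_grid('r6r/8/8/4k3/8/8/8/4K2R')
print(''.join(grid[0]))   # → r111111r
print(grid_to_fen(grid))  # → r6r/8/8/4k3/8/8/8/4K2R
```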

3.6.2. Checks in the chessboard

class Check(Board):
    """
    This class finds if there are any checks in the chessboard.
    """

    def __init__(self, fen_label):
        super().__init__(fen_label=fen_label)

    def get_sub_matrix(self, ai, aj, di, dj):
        """
        This method chops the chessboard to a sub-matrix.
        """
        corners = np.array([(ai, aj), (di, aj), (ai, dj), (di, dj)])
        min_i, max_i = min(corners[:, 0]), max(corners[:, 0])
        min_j, max_j = min(corners[:, 1]), max(corners[:, 1])
        sub_matrix = self.fen_matrix[min_i:max_i+1, min_j:max_j+1]
        return sub_matrix, sub_matrix.shape

    def get_straight_checks(self, ai, aj, di, dj, a, d):
        """
        This method returns the checks along the straight path.
        """
        checks = list()
        for (i, j) in zip(ai, aj):
            if di == i:
                attack_path = self.fen_matrix[di]
            elif dj == j:
                attack_path = self.fen_matrix[:, dj]
            else:
                continue
            a_ind = np.where(attack_path == a)[0]
            d_ind = np.where(attack_path == d)[0][0]
            for a_i_ in a_ind:
                attack_path_ = attack_path[min(a_i_, d_ind): max(a_i_, d_ind)+1]
                checks.append(np.where(attack_path_ != '1')[0])
        checks = list(filter(lambda x: len(x) == 2, checks))
        return checks

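    def get_diagonal_checks(self, ai, aj, di, dj, a):
        """
        This method returns the checks along the diagonal path.
        """
        checks = list()
        for (i, j) in zip(ai, aj):
            sub_mat, sub_shape = self.get_sub_matrix(ai=i, aj=j, di=di, dj=dj)
            if sub_shape[0] == sub_shape[1]:
                # Attacker and king are diagonally aligned. If they sit on the
                # anti-diagonal of the sub-matrix, flip it so they are on the main one.
                # (Testing the relative position avoids being fooled by another piece
                # of the same type that happens to lie on the main diagonal.)
                if (i - di) * (j - dj) < 0:
                    sub_mat = np.flipud(m=sub_mat)
                # Only the attacker and the king may occupy the diagonal (no blockers).
                checks.append(np.where(sub_mat.diagonal() != '1')[0])
        checks = list(filter(lambda x: len(x) == 2, checks))
        return checks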

    def get_knight_checks(self, ai, aj, di, dj):
        """
        This method returns the checks along the L-shaped paths for knights.
        """
        checks = list()
        for (i, j) in zip(ai, aj):
            attack_positions = [(i-2, j-1), (i-2, j+1),
                                (i-1, j-2), (i-1, j+2),
                                (i+1, j-2), (i+1, j+2),
                                (i+2, j-1), (i+2, j+1)]
            if (di, dj) in attack_positions:
                checks.append((i, j))
        return checks

    def get_pawn_checks(self, ai, aj, di, dj):
        """
        This method returns the checks for pawns.
        """
        checks = list()
        for (i, j) in zip(ai, aj):
            _, sub_shape = self.get_sub_matrix(ai=i, aj=j, di=di, dj=dj)
            if sub_shape[0] == 2 and sub_shape[1] == 2:
                checks.append((i, j))
            else:
                continue
        return checks

    def king_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the other king.
        Adjacent kings never occur in a legal position, so this serves as a validation rule.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        ai, aj = ai[0], aj[0]
        attack_positions = [(di, dj-1), (di, dj+1),
                            (di-1, dj), (di+1, dj),
                            (di-1, dj+1), (di-1, dj-1),
                            (di+1, dj-1), (di+1, dj+1)]
        if (ai, aj) in attack_positions:
            flag = True
        return flag

    def rook_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the rook.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        checks = self.get_straight_checks(
            ai=ai, aj=aj, di=di, dj=dj, a=attacker, d=defendant)
        if checks:
            flag = True
        return flag

    def bishop_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the bishop.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        checks = self.get_diagonal_checks(
            ai=ai, aj=aj, di=di, dj=dj, a=attacker)
        if checks:
            flag = True
        return flag

    def knight_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the knight.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        checks = self.get_knight_checks(ai=ai, aj=aj, di=di, dj=dj)
        if checks:
            flag = True
        return flag

    def queen_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the queen.
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        straight_checks = self.get_straight_checks(
            ai=ai, aj=aj, di=di, dj=dj, a=attacker, d=defendant)
        diagonal_checks = self.get_diagonal_checks(
            ai=ai, aj=aj, di=di, dj=dj, a=attacker)
        if straight_checks or diagonal_checks:
            flag = True
        return flag

    def pawn_checks_king(self, attacker, defendant):
        """
        This method checks if the king is being attacked by the pawn.

        Note: It is hard to determine from an image which side of
              the chessboard belongs to Black and which to White.
              Hence, this method assumes the pawn is attacking the king
              if both pieces are diagonally adjacent (1 step apart).
        """
        flag = False
        di, dj = self.get_piece_positions(notation=defendant)
        if len(di) == 1 and len(dj) == 1:
            di, dj = di[0], dj[0]
        else:
            return flag
        ai, aj = self.get_piece_positions(notation=attacker)
        checks = self.get_pawn_checks(ai=ai, aj=aj, di=di, dj=dj)
        if checks:
            flag = True
        return flag
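For cross-checking the sub-matrix approach of the Check class, the sketch below uses a different, common technique: walk outward from the king along each rook and bishop ray and stop at the first occupied square, so blockers are handled without slicing the board. It keeps the article's simplification for pawns (any diagonally adjacent enemy pawn gives check, whatever the board orientation). All names are hypothetical and this is not the chess_positions implementation.

```python
import re

STRAIGHT = [(-1, 0), (1, 0), (0, -1), (0, 1)]     # rook / queen rays
DIAGONAL = [(-1, -1), (-1, 1), (1, -1), (1, 1)]   # bishop / queen rays
KNIGHT = [(-2, -1), (-2, 1), (-1, -2), (-1, 2),
          (1, -2), (1, 2), (2, -1), (2, 1)]


def to_grid(fen):
    """Expand a FEN placement into 8 lists of 8 characters ('1' = empty)."""
    expanded = re.sub(r'\d', lambda m: '1' * int(m.group()), fen)
    return [list(row) for row in expanded.split('/')]


def on_board(r, c):
    return 0 <= r < 8 and 0 <= c < 8


def checkers(fen, king):
    """List (symbol, row, col) of every enemy piece giving check to `king` ('K' or 'k')."""
    grid = to_grid(fen)
    enemy = str.lower if king == 'K' else str.upper
    positions = [(r, c) for r in range(8) for c in range(8) if grid[r][c] == king]
    if len(positions) != 1:
        raise ValueError(f'expected exactly one {king!r} on the board')
    kr, kc = positions[0]
    attackers = []
    # Sliding pieces: walk each ray and stop at the first occupied square,
    # so any piece in between blocks the check.
    for rays, sliders in ((STRAIGHT, 'rq'), (DIAGONAL, 'bq')):
        for dr, dc in rays:
            r, c = kr + dr, kc + dc
            while on_board(r, c):
                if grid[r][c] != '1':
                    if grid[r][c] in enemy(sliders):
                        attackers.append((grid[r][c], r, c))
                    break
                r, c = r + dr, c + dc
    # Knights jump, so only the eight L-shaped squares matter. Pawns: like
    # pawn_checks_king, any diagonally adjacent enemy pawn counts.
    for offsets, piece in ((KNIGHT, 'n'), (DIAGONAL, 'p')):
        for dr, dc in offsets:
            r, c = kr + dr, kc + dc
            if on_board(r, c) and grid[r][c] == enemy(piece):
                attackers.append((grid[r][c], r, c))
    return attackers


print(checkers('4k3/8/8/8/8/8/8/4R1K1', 'k'))  # → [('R', 7, 4)]
```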

3.6.2.1. Check distribution (only on legal images)

Below is the check distribution of the training dataset.

Image by Author - Check distribution of the training dataset

Below is the check distribution of the test dataset.

Image by Author - Check distribution of the test dataset

Conclusion of the above check distribution plots:

  1. Rooks give check to the opponent king more often than any other chess piece.
  2. Pawns give check to the opponent king less often than any other chess piece.

3.6.3. Illegal positions on the chessboard

class IllegalPosition(Check):
    """
    This class finds if the pieces are illegally positioned in the chessboard.
    """

    def __init__(self, fen_label):
        super().__init__(fen_label=fen_label)

    def are_kings_less(self):
        """
        Rule on kings: each side must have a king.
        """
        return self.fen_label.count('k') < 1 or self.fen_label.count('K') < 1

    def are_kings_more(self):
        """
        Rule on kings: each side has at most one king.
        """
        return self.fen_label.count('k') > 1 or self.fen_label.count('K') > 1

    def are_queens_more(self):
        """
        Rule on queens: at most 9 per side (1 original + 8 promoted pawns).
        """
        return self.fen_label.count('q') > 9 or self.fen_label.count('Q') > 9

    def are_bishops_more(self):
        """
        Rule on bishops: at most 10 per side (2 original + 8 promoted pawns).
        """
        return self.fen_label.count('b') > 10 or self.fen_label.count('B') > 10

    def are_knights_more(self):
        """
        Rule on knights: at most 10 per side (2 original + 8 promoted pawns).
        """
        return self.fen_label.count('n') > 10 or self.fen_label.count('N') > 10

    def are_rooks_more(self):
        """
        Rule on rooks: at most 10 per side (2 original + 8 promoted pawns).
        """
        return self.fen_label.count('r') > 10 or self.fen_label.count('R') > 10

    def are_pawns_more(self):
        """
        Rule on pawns: at most 8 per side.
        """
        return self.fen_label.count('p') > 8 or self.fen_label.count('P') > 8

    def rule_1(self):
        """
        This method checks the count of the kings and the pieces on the board.
        1. The count of white king and black king should always be 1.
        2. The count of white queens and/or black queens should not exceed 9.
        3. The count of white bishops and/or black bishops should not exceed 10.
        4. The count of white knights and/or black knights should not exceed 10.
        5. The count of white rooks and/or black rooks should not exceed 10.
        6. The count of white pawns and/or black pawns should not exceed 8.
        7. The chessboard should never be empty (guaranteed by rule 1).
        """
        return (self.are_kings_less() or self.are_kings_more()
                or self.are_queens_more() or self.are_bishops_more()
                or self.are_knights_more() or self.are_rooks_more()
                or self.are_pawns_more())

    def rule_2(self):
        """
        This method checks if there are pawns on the first or last row of the board.
        1. No pawn should be on the first row or on the last row.
           Pawns start on the second row from their own side and only move forward,
           and a pawn that reaches the last row is always promoted.
        """
        fen_label_list = self.fen_label.split('/')
        f_row, l_row = fen_label_list[0], fen_label_list[-1]
        # Flag a pawn of either color on either edge row.
        return any(pawn in row for pawn in 'pP' for row in (f_row, l_row))

    def rule_3(self):
        """
        This method checks if the king is attacking the other king.
        1. The king never checks the other king.
        2. The king can attack other enemy pieces except the enemy king.
        """
        return self.king_checks_king(attacker='k', defendant='K')

    def rule_4(self):
        """
        This method checks if the kings are under check simultaneously.
        1. The two kings are never under check at the same time.
        """
        r_checks_K = self.rook_checks_king(attacker='r', defendant='K')
        n_checks_K = self.knight_checks_king(attacker='n', defendant='K')
        b_checks_K = self.bishop_checks_king(attacker='b', defendant='K')
        q_checks_K = self.queen_checks_king(attacker='q', defendant='K')
        p_checks_K = self.pawn_checks_king(attacker='p', defendant='K')
        R_checks_k = self.rook_checks_king(attacker='R', defendant='k')
        N_checks_k = self.knight_checks_king(attacker='N', defendant='k')
        B_checks_k = self.bishop_checks_king(attacker='B', defendant='k')
        Q_checks_k = self.queen_checks_king(attacker='Q', defendant='k')
        P_checks_k = self.pawn_checks_king(attacker='P', defendant='k')
        is_K_checked = r_checks_K or n_checks_K or b_checks_K or q_checks_K or p_checks_K
        is_k_checked = R_checks_k or N_checks_k or B_checks_k or Q_checks_k or P_checks_k
        return is_K_checked and is_k_checked

    def is_illegal(self):
        """
        This method is a consolidation of all the above basic rules of chess.
        """
        return self.rule_1() or self.rule_2() or self.rule_3() or self.rule_4()
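The count limits in rule_1 and the pawn-row rule in rule_2 can also be expressed compactly with one `Counter` pass. A minimal sketch, not a replacement for IllegalPosition (rule_3 and rule_4 are omitted); `violates_count_rules` is a hypothetical name:

```python
from collections import Counter

# Per-side limits used by rule_1: one king, and at most 8 promotions on top
# of the starting queen (1 -> 9) and rooks/bishops/knights (2 -> 10); 8 pawns.
MAX_PER_SIDE = {'k': 1, 'q': 9, 'r': 10, 'b': 10, 'n': 10, 'p': 8}


def violates_count_rules(fen):
    """Return True if a FEN placement breaks rule_1 (piece counts) or rule_2 (pawn rows)."""
    counts = Counter(ch for ch in fen if ch.isalpha())
    for piece, limit in MAX_PER_SIDE.items():
        if counts[piece] > limit or counts[piece.upper()] > limit:
            return True
    if counts['k'] != 1 or counts['K'] != 1:
        return True
    rows = fen.split('/')
    # Pawns never stand on the first or the last row.
    return any(pawn in rows[0] or pawn in rows[-1] for pawn in 'pP')


print(violates_count_rules('P3k3/8/8/8/8/8/8/4K3'))  # → True
```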

3.6.3.1. Legal and illegal positions on the chessboard

After filtering the training and test datasets with the IllegalPosition class, I obtained the following results.

Training dataset (80000 chess images).

  1. The legal training chess image count is 67813 (84.8%).
  2. The illegal training chess image count is 12187 (15.2%).

Test dataset (20000 chess images).

  1. The legal testing chess image count is 17019 (85.1%).
  2. The illegal testing chess image count is 2981 (14.9%).

3.6.3.2. Sample plot of legal chess images

Image by Author - Sample of legal chess images

3.6.3.3. Sample plot of illegal chess images

Image by Author - Sample of illegal chess images

3.7. Ratio, height, and width of all images

Image by Author - Ratio, height, and width

Conclusion of the above image sizes plot:

  1. The aspect ratio (height / width) of every image is 1.
  2. Every image is 400 pixels wide.
  3. Every image is 400 pixels high.

4. Data Pipeline

I created a data pipeline class to feed the data to the model (learner).

Before showing it, here is the preprocessing step: I resized each 400 by 400 chess image by 50% (to 200 by 200 pixels) and divided it into 64 blocks (squares) of 25 by 25 pixels. The main advantage of resizing is lower memory usage, so the RAM is not overburdened.
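The block split that `view_as_blocks` performs can also be done with a NumPy reshape and transpose, which makes the ordering explicit: block 0 is the top-left square, and blocks run left to right, top to bottom. A minimal sketch, assuming the image height and width are divisible by 8; `split_into_squares` is a hypothetical name:

```python
import numpy as np


def split_into_squares(img, rows=8, cols=8):
    """Split an (H, W, C) board image into rows*cols square tiles, row by row."""
    h, w, c = img.shape
    sh, sw = h // rows, w // cols
    # (rows, sh, cols, sw, c) -> (rows, cols, sh, sw, c) -> (rows*cols, sh, sw, c)
    tiles = img.reshape(rows, sh, cols, sw, c).transpose(0, 2, 1, 3, 4)
    return tiles.reshape(rows * cols, sh, sw, c)


# A 200 x 200 RGB image (400 x 400 resized by 50%) gives 64 tiles of 25 x 25.
img = np.arange(200 * 200 * 3, dtype=np.int64).reshape(200, 200, 3)
squares = split_into_squares(img)
print(squares.shape)  # → (64, 25, 25, 3)
```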

Chess image before preprocessing:

Image by Author - Before resizing and preprocessing

Chess image after preprocessing:

Image by Author - After resizing and preprocessing

I used the one-hot encoding logic of Pavel Koryakin (the author of the dataset) to encode the labels.

Below is the data pipeline class.

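import re

import cv2 as cv
import numpy as np
import tensorflow as tf
from skimage.util.shape import view_as_blocks


class DataPipeline(object):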
    """
    This class is a data pipeline for Deep Learning model.
    """

    def __init__(self, tr_images, tr_labels, cv_images, cv_labels, te_images, te_labels):
        self.rows, self.cols = (8, 8)
        self.square = None
        self.h, self.w, self.c = None, None, None
        self.N = 13
        self.tr_images = np.array(tr_images)
        self.tr_labels = np.array(tr_labels)
        self.cv_images = np.array(cv_images)
        self.cv_labels = np.array(cv_labels)
        self.te_images = np.array(te_images)
        self.te_labels = np.array(te_labels)
        self.piece_symbols = 'prbnkqPRBNKQ'

    def preprocess_input_image(self, imagefile, resize_scale=(200, 200)):
        """
        This method reads the image, resizes it, and splits it into 64 square blocks.
        """
        img = cv.imread(filename=imagefile)
        img = cv.resize(src=img, dsize=resize_scale)

        self.h, self.w, self.c = img.shape
        self.square = self.h // self.rows

        img_blocks = view_as_blocks(
            arr_in=img, block_shape=(self.square, self.square, self.c))
        img_blocks = img_blocks.reshape(
            self.rows * self.cols, self.square, self.square, self.c)

        return img_blocks

    def tr_data_generator(self):
        """
        This method yields preprocessed training images with their one-hot labels.
        """
        for i, l in zip(self.tr_images, self.tr_labels):
            yield (self.preprocess_input_image(imagefile=i),
                   self.onehot_from_fen(fen=l))

    def cv_data_generator(self):
        """
        This method yields preprocessed validation images with their one-hot labels.
        """
        for i, l in zip(self.cv_images, self.cv_labels):
            yield (self.preprocess_input_image(imagefile=i),
                   self.onehot_from_fen(fen=l))

    def te_data_generator(self):
        """
        This method yields preprocessed test images (without labels).
        """
        for i in self.te_images:
            yield self.preprocess_input_image(imagefile=i)

    def onehot_from_fen(self, fen):
        """
        This method converts FEN to onehot.
        The original author of this method is 'Pavel Koryakin'.
        Pavel Koryakin is also the maintainer of the Chess Positions dataset.
        """
        eye = np.eye(N=self.N)
        output = np.empty(shape=(0, self.N))
        # Dataset labels separate rows with '-', standard FEN with '/'; drop both.
        fen = re.sub(pattern='[-/]', repl='', string=fen)

        for char in fen:
            if char in '12345678':
                output = np.append(
                    arr=output,
                    values=np.tile(A=eye[self.N-1], reps=(int(char), 1)), axis=0
                )
            else:
                idx = self.piece_symbols.index(char)
                output = np.append(
                    arr=output,
                    values=eye[idx].reshape((1, self.N)), axis=0
                )

        return output

    def fen_from_onehot(self, onehot):
        """
        This method converts an 8 x 8 array of class indices (e.g. the argmax of
        the one-hot model output, reshaped to 8 x 8) to FEN.
        The original author of this method is 'Pavel Koryakin'.
        Pavel Koryakin is also the maintainer of the Chess Positions dataset.
        """
        output = str()
        for j in range(self.rows):
            for i in range(self.cols):
                if onehot[j][i] == 12:  # Class index 12 encodes an empty square.
                    output += ' '
                else:
                    output += self.piece_symbols[int(onehot[j][i])]
            if j != self.rows - 1:
                output += '/'

        for i in range(self.rows, 0, -1):
            output = output.replace(' ' * i, str(i))

        return output

    def construct_dataset(self):
        """
        This method constructs the dataset.
        """
        tr_dataset = tf.data.Dataset.from_generator(
            generator=self.tr_data_generator, output_types=(tf.int64, tf.int64))
        tr_dataset = tr_dataset.repeat()

        cv_dataset = tf.data.Dataset.from_generator(
            generator=self.cv_data_generator, output_types=(tf.int64, tf.int64))
        cv_dataset = cv_dataset.repeat()

        te_dataset = tf.data.Dataset.from_generator(
            generator=self.te_data_generator, output_types=tf.int64)
        te_dataset = te_dataset.repeat()

        it_tr = iter(tr_dataset)
        it_cv = iter(cv_dataset)
        it_te = iter(te_dataset)

        return it_tr, it_cv, it_te

With the above class, I created the training, the validation, and the test dataset generators that can be used for modeling.
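The `onehot_from_fen` method grows its output with `np.append` inside a loop, which copies the array on every call. A vectorized variant builds the 64 class indices first and then selects rows of the 13 by 13 identity matrix in one step. The sketch below keeps the same symbol order, uses index 12 for empty squares, and round-trips a dataset-style label; the function names are hypothetical.

```python
import re

import numpy as np

PIECE_SYMBOLS = 'prbnkqPRBNKQ'  # same order as DataPipeline.piece_symbols
EMPTY = 12                      # class index of an empty square
N_CLASSES = 13


def fen_to_indices(fen):
    """64 class indices, row by row, from a placement separated by '/' or '-'."""
    indices = []
    for ch in re.sub('[-/]', '', fen):
        if ch.isdigit():
            indices.extend([EMPTY] * int(ch))
        else:
            indices.append(PIECE_SYMBOLS.index(ch))
    return np.array(indices)


def fen_to_onehot(fen):
    """(64, 13) one-hot matrix: select rows of the identity matrix in one step."""
    return np.eye(N_CLASSES)[fen_to_indices(fen)]


def fen_from_indices(indices):
    """Inverse mapping, e.g. for np.argmax(predictions, axis=1) over 64 squares."""
    chars = ['1' if i == EMPTY else PIECE_SYMBOLS[i] for i in indices]
    rows = [''.join(chars[r * 8:(r + 1) * 8]) for r in range(8)]
    return '/'.join(re.sub('1+', lambda m: str(len(m.group())), row) for row in rows)


label = 'r6r-8-8-4k3-8-8-8-4K2R'      # dataset-style label with dashes
onehot = fen_to_onehot(label)
print(onehot.shape)                   # → (64, 13)
print(fen_from_indices(onehot.argmax(axis=1)))  # → r6r/8/8/4k3/8/8/8/4K2R
```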


5. Modeling

I created a modeling class that first tunes the model and then fits it to the dataset. Tuning was done with KerasTuner and took 20 hours to find the best hyperparameters.

5.1. Base model

Below is the modeling class.

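import os

import keras_tuner as kt
import pandas as pd
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input
from tensorflow.keras.models import Model


class ChessModel(object):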
    """
    This class is for deep learning model for chess recognition problem.
    """

    def __init__(self,
                 tr_dataset,
                 cv_dataset,
                 tr_size,
                 cv_size,
                 filepath_tuner,
                 filepath_fitter,
                 filepath_tracker):
        self.tr_dataset = tr_dataset
        self.cv_dataset = cv_dataset
        self.input_shape = (25, 25, 3)
        self.batch_size = 64
        self.output_units = 13 # 12 for chess pieces and 1 for empty square.
        self.tr_size = tr_size
        self.cv_size = cv_size
        self.filepath_tuner = filepath_tuner
        self.filepath_fitter = filepath_fitter
        self.filepath_tracker = filepath_tracker

    def build_model(self, hp):
        """
        This method builds the optimized model.
        """
        hp_activations = hp.Choice(
            name='activation', values=['relu', 'tanh', 'sigmoid'])
        hp_filters_1 = hp.Int(
            name='filter_1', min_value=32, max_value=64, step=10)
        hp_filters_2 = hp.Int(
            name='filter_2', min_value=32, max_value=64, step=10)
        hp_kernel_1 = hp.Int(
            name='Kernel_1', min_value=2, max_value=5, step=None)
        hp_kernel_2 = hp.Int(
            name='Kernel_2', min_value=2, max_value=5, step=None)
        hp_units = hp.Int(
            name='dense', min_value=32, max_value=64, step=10)
        hp_learning_rate = hp.Choice(
            name='learning_rate', values=[1e-2, 1e-3, 1e-4])

        input_layer = Input(
            shape=self.input_shape, batch_size=self.batch_size, name='Input')
        conv_2d_layer_1 = Conv2D(
            filters=hp_filters_1, kernel_size=hp_kernel_1,
            activation=hp_activations, name='Conv2D_1')(input_layer)
        conv_2d_layer_2 = Conv2D(
            filters=hp_filters_2, kernel_size=hp_kernel_2,
            activation=hp_activations, name='Conv2D_2')(conv_2d_layer_1)
        flatten_layer = Flatten(name='Flatten')(conv_2d_layer_2)
        dense_layer = Dense(
            units=hp_units, activation=hp_activations, name='Dense')(flatten_layer)
        output_layer = Dense(
            units=self.output_units, activation='softmax', name='Output')(dense_layer)

        model = Model(inputs=input_layer, outputs=output_layer, name='Chess_Model')

        optimizer = tf.keras.optimizers.Adam(learning_rate=hp_learning_rate)
        model.compile(
            optimizer=optimizer, loss='categorical_crossentropy',
            metrics=['accuracy'])

        return model

    def model_tuner(self):
        """
        This method tunes the chess model.
        """
        if not os.path.isfile(path=self.filepath_tuner):
            print("Tuning the model.")
            stop_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3)
            tuner = kt.Hyperband(
                hypermodel=self.build_model, objective='val_accuracy', max_epochs=10)
            tuner.search(
                x=self.tr_dataset, epochs=50, steps_per_epoch=self.tr_size,
                validation_data=self.cv_dataset, validation_steps=self.cv_size,
                callbacks=[stop_early])
            print("Tuning completed.")

            best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
            model = tuner.hypermodel.build(best_hps)
            tf.keras.models.save_model(model=model, filepath=self.filepath_tuner)
            print("Saved the best model to the file.")
        else:
            print("Model is already tuned, and is also saved.")
            model = tf.keras.models.load_model(filepath=self.filepath_tuner)
            print("Loaded the tuned model and ready for fitting.")

        return model

    def model_fitter(self):
        """
        This method fits the tuned model.
        """
        model = self.model_tuner()
        print()
        model.summary()
        print()

        model_save_callback = ModelCheckpoint(
            filepath=self.filepath_fitter, monitor='val_accuracy',
            verbose=1, save_best_only=True, mode='auto')
        callbacks = [model_save_callback]

        if not os.path.isfile(path=self.filepath_fitter):
            print("Fitting the model.")

            epochs = 10
            tracker = model.fit(
                x=self.tr_dataset, validation_data=self.cv_dataset, epochs=epochs,
                steps_per_epoch=self.tr_size, validation_steps=self.cv_size,
                callbacks=callbacks)
            print("\nSaved the fitted model.")

            tracker_df = pd.DataFrame(data=tracker.history)
            tracker_df.to_csv(path_or_buf=self.filepath_tracker, index=False)
            print("Saved the history to the file.")
        else:
            print("Model is already fitted, and is also saved.")
            model = tf.keras.models.load_model(filepath=self.filepath_fitter)
            print("Loaded the fitted model and ready for prediction.")

            tracker_df = pd.read_csv(filepath_or_buffer=self.filepath_tracker)

        print()
        # plot_model_performance is a plotting helper from the author's project (not shown here).
        plot_model_performance(tracker_df=tracker_df)

        return model

5.2. Model architecture

Model: "Chess_Model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 Input (InputLayer)          [(64, 25, 25, 3)]         0         

 Conv2D_1 (Conv2D)           (64, 21, 21, 32)          2432      

 Conv2D_2 (Conv2D)           (64, 19, 19, 62)          17918     

 Flatten (Flatten)           (64, 22382)               0         

 Dense (Dense)               (64, 42)                  940086    

 Output (Dense)              (64, 13)                  559       

=================================================================
Total params: 960,995
Trainable params: 960,995
Non-trainable params: 0
_________________________________________________________________
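The parameter counts in the summary can be checked by hand. The kernel sizes are not printed, so the sketch assumes they are 5 and 3, inferred from the output shapes (25 → 21 → 19) under Keras' default 'valid' padding and stride 1; the helper names are hypothetical.

```python
def conv2d_params(kernel, in_channels, filters):
    # One kernel x kernel x in_channels weight tensor per filter, plus one bias per filter.
    return kernel * kernel * in_channels * filters + filters


def dense_params(in_units, out_units):
    # Full weight matrix plus one bias per output unit.
    return in_units * out_units + out_units


def conv_output_size(size, kernel):
    # 'valid' padding (the Keras default) with stride 1.
    return size - kernel + 1


side_1 = conv_output_size(25, 5)       # 21
side_2 = conv_output_size(side_1, 3)   # 19
flat = side_2 * side_2 * 62            # 22382
layers = [conv2d_params(5, 3, 32), conv2d_params(3, 32, 62),
          dense_params(flat, 42), dense_params(42, 13)]
print(layers, sum(layers))  # → [2432, 17918, 940086, 559] 960995
```

Almost all of the parameters (940,086 of 960,995) sit in the Dense layer after Flatten.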

5.3. Model performance – loss and accuracy

After fitting the model for 10 epochs, I tracked the model's loss and accuracy.

Image by Author – Accuracy and loss

Below is the confusion matrix for 25 test images.

Image by Author – Confusion matrix of 25 test images
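A confusion matrix like this one can be computed from per-square labels. This minimal NumPy sketch assumes both true and predicted labels are integer class indices (0 to 11 for "prbnkqPRBNKQ", 12 for an empty square), so 25 test images contribute 25 x 64 = 1,600 squares. Whether the figure above was produced exactly this way is an assumption, and the function names are illustrative.

```python
import numpy as np

NUM_CLASSES = 13  # 12 piece symbols plus class 12 for an empty square

def confusion_matrix(y_true, y_pred, num_classes=NUM_CLASSES):
    """Count (true, predicted) pairs. Rows are true classes, columns are predictions."""
    y_true = np.asarray(y_true).ravel()   # e.g. shape (25, 64) -> (1600,)
    y_pred = np.asarray(y_pred).ravel()
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (true, pred) pair repeats.
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

def per_class_recall(cm):
    """Fraction of each true class predicted correctly (NaN for classes never seen)."""
    support = cm.sum(axis=1)
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.where(support > 0, np.diag(cm) / support, np.nan)
```

With scikit-learn installed, `sklearn.metrics.confusion_matrix(y_true, y_pred, labels=list(range(13)))` gives the same matrix; passing the labels explicitly keeps classes that never appear in the sample as zero rows.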

6. Productionization of Data Product

Productionization is the process of moving a model out of the local Jupyter Notebook environment and exposing it to the outside world. Here, I exported the model trained in the modeling stage to a file, chess_model.h5. This file stores the learned parameters, so the model can be applied to test data without retraining.

6.1. Data product pipeline

from chess_positions import Check, IllegalPosition
from skimage.util.shape import view_as_blocks

import cv2 as cv
import plotly.express as px
import tensorflow as tf
import warnings
warnings.filterwarnings(action='ignore')

class Pipeline(object):
    """
    This class is a pipeline mechanism to feed the
    query chess image into the model for FEN prediction.
    """

    def __init__(self, chess_image):
        self.piece_symbols = "prbnkqPRBNKQ"
        self.rows, self.cols = (8, 8)
        self.square = None
        self.h, self.w, self.c = None, None, None
        self.chess_image = chess_image
        self.chess_model = tf.keras.models.load_model(
            filepath='chess_model.h5')
        self.chess_image_display = self.display_image()

    def display_image(self):
        """
        This method reads the image and
        gives plotly fig for the final display.
        """
        image = cv.imread(filename=self.chess_image)
        image = cv.cvtColor(src=image, code=cv.COLOR_BGR2RGB)

        image_fig = px.imshow(img=image)
        image_fig.update_layout(
            coloraxis_showscale=False, autosize=True,
            margin=dict(l=0, r=0, b=0, t=0))
        image_fig.update_xaxes(showticklabels=False)
        image_fig.update_yaxes(showticklabels=False)
        return image_fig

    def preprocess(self, resize_scale=(200, 200)):
        """
        This method preprocesses the chess image.
        """
        img = cv.imread(filename=self.chess_image)
        img = cv.resize(src=img, dsize=resize_scale)

        self.h, self.w, self.c = img.shape
        self.square = self.h // self.rows

        img_blocks = view_as_blocks(
            arr_in=img, block_shape=(self.square, self.square, self.c))
        img_blocks = img_blocks.reshape(
            self.rows * self.cols, self.square, self.square, self.c)

        return img_blocks

    def fen_from_onehot(self, onehot):
        """
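        This method converts the 8x8 grid of predicted class indices
        (the argmax of the model's one-hot output) into the
        piece-placement field of a FEN string.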
        The original author of this method is 'Pavel Koryakin'.
        Pavel Koryakin is also the maintainer of Chess Positions dataset.
        """
        output = str()
        for j in range(self.rows):
            for i in range(self.cols):
                if onehot[j][i] == 12:  # Class 12 is the empty-square label.
                    output += ' '
                else:
                    output += self.piece_symbols[int(onehot[j][i])]
            if j != self.rows - 1:
                output += '/'

        for i in range(self.rows, 0, -1):
            output = output.replace(' ' * i, str(i))

        return output

    def predict(self):
        """
        This method predicts the FEN of the query chess image.
        """
        chess_image_blocks = self.preprocess()

        onehot = self.chess_model.predict(x=chess_image_blocks)
        onehot = onehot.argmax(axis=1).reshape(-1, 8, 8)[0]

        fen_label = self.fen_from_onehot(onehot=onehot)

        interpretation = self.illegal_interpreter(fen_label=fen_label)
        if interpretation:
            interpretation = f"This is an illegal chess position because {interpretation}"
        else:
            interpretation = self.check_interpreter(fen_label=fen_label)

        fen_label = f"The Forsyth-Edwards Notation (FEN) of the uploaded chess image is {fen_label}."
        interpretation = f"Further interpretation: {interpretation}"

        return fen_label, interpretation

    def illegal_interpreter(self, fen_label):
        """
        This method checks the predicted FEN for illegal positions and
        returns the first reason found (an empty string if none).
        """
        reason = str()

        chess_illegal = IllegalPosition(fen_label=fen_label)

        if chess_illegal.are_kings_less():
            reason += "the white king, the black king, or both are missing."
        elif chess_illegal.are_kings_more():
            reason += "one or both sides have more than one king."
        elif chess_illegal.are_queens_more():
            reason += "one or both sides have more than 9 queens."
        elif chess_illegal.are_bishops_more():
            reason += "one or both sides have more than 10 bishops."
        elif chess_illegal.are_knights_more():
            reason += "one or both sides have more than 10 knights."
        elif chess_illegal.are_rooks_more():
            reason += "one or both sides have more than 10 rooks."
        elif chess_illegal.are_pawns_more():
            reason += "one or both sides have more than 8 pawns."
        elif chess_illegal.rule_2():
            reason += "a white or black pawn is on the first or last rank."
        elif chess_illegal.rule_3():
            reason += "one king checks the other king."
        elif chess_illegal.rule_4():
            reason += "the white king and the black king are under attack simultaneously."

        return reason

    def check_interpreter(self, fen_label):
        """
        This method checks whether either king is under attack
        in the predicted FEN.
        """
        reason = str()

        chess_check = Check(fen_label=fen_label)

        r_checks_K = chess_check.rook_checks_king(
            attacker='r', defendant='K')
        n_checks_K = chess_check.knight_checks_king(
            attacker='n', defendant='K')
        b_checks_K = chess_check.bishop_checks_king(
            attacker='b', defendant='K')
        q_checks_K = chess_check.queen_checks_king(
            attacker='q', defendant='K')
        p_checks_K = chess_check.pawn_checks_king(
            attacker='p', defendant='K')
        R_checks_k = chess_check.rook_checks_king(
            attacker='R', defendant='k')
        N_checks_k = chess_check.knight_checks_king(
            attacker='N', defendant='k')
        B_checks_k = chess_check.bishop_checks_king(
            attacker='B', defendant='k')
        Q_checks_k = chess_check.queen_checks_king(
            attacker='Q', defendant='k')
        P_checks_k = chess_check.pawn_checks_king(
            attacker='P', defendant='k')

        is_K_checked = r_checks_K or n_checks_K or b_checks_K or q_checks_K or p_checks_K
        is_k_checked = R_checks_k or N_checks_k or B_checks_k or Q_checks_k or P_checks_k

        if is_K_checked:
            reason += "The white king is under attack."
        elif is_k_checked:
            reason += "The black king is under attack."
        else:
            reason += "Both kings are safe."

        return reason
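The two array steps in predict() can be reproduced with plain NumPy, which makes their behaviour easy to test in isolation. In this sketch, `split_into_squares` is a reshape/transpose equivalent of the `view_as_blocks` call in preprocess(), assuming the image sides are divisible by 8, and `fen_placement` is an alternative to fen_from_onehot that counts runs of empty squares directly instead of collapsing spaces with `str.replace`. Both function names are hypothetical.

```python
import numpy as np

PIECE_SYMBOLS = "prbnkqPRBNKQ"  # same ordering as Pipeline.piece_symbols
EMPTY = 12                      # class index of an empty square

def split_into_squares(img, rows=8, cols=8):
    """Split an (H, W, C) board image into rows*cols tiles, row by row.

    Same result as view_as_blocks(...).reshape(rows * cols, s, s, C),
    assuming H and W are divisible by rows and cols.
    """
    h, w, c = img.shape
    sh, sw = h // rows, w // cols
    # (rows, sh, cols, sw, c) -> (rows, cols, sh, sw, c): group pixels by tile.
    tiles = img.reshape(rows, sh, cols, sw, c).transpose(0, 2, 1, 3, 4)
    return tiles.reshape(rows * cols, sh, sw, c)

def fen_placement(grid):
    """Encode an 8x8 grid of class indices as the FEN piece-placement field."""
    ranks = []
    for row in grid:
        rank, run = "", 0
        for idx in row:
            if idx == EMPTY:
                run += 1                  # extend the current run of empty squares
            else:
                if run:
                    rank += str(run)      # flush the run before writing a piece
                    run = 0
                rank += PIECE_SYMBOLS[int(idx)]
        if run:
            rank += str(run)              # trailing empty squares
        ranks.append(rank)
    return "/".join(ranks)
```

Counting runs avoids the ordering requirement in fen_from_onehot, whose `range(self.rows, 0, -1)` loop must replace the longest runs of spaces first so that a fully empty rank becomes `8` rather than `11111111`.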

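illegal_interpreter() and check_interpreter() delegate to `IllegalPosition` and `Check` from the `chess_positions` module imported above, whose implementation is not shown here. The sketch below is a hypothetical, standalone illustration of how a few of the same rules can be tested directly on the FEN piece-placement field: one king per side, the per-type limits used above (the original pieces plus up to 8 promoted pawns, e.g. 1 + 8 = 9 queens), no pawns on the first or last rank, non-adjacent kings, and a knight check. Unlike illegal_interpreter(), it returns every violation it finds rather than only the first.

```python
# Per-side limits used by illegal_interpreter(), checked per piece type.
MAX_COUNTS = {"q": 9, "r": 10, "b": 10, "n": 10, "p": 8}

KNIGHT_JUMPS = [(1, 2), (2, 1), (-1, 2), (-2, 1),
                (1, -2), (2, -1), (-1, -2), (-2, -1)]

def expand_placement(placement):
    """Turn a FEN piece-placement field into 8 rows of 8 chars ('.' = empty)."""
    board = []
    for rank in placement.split("/"):
        row = "".join("." * int(ch) if ch.isdigit() else ch for ch in rank)
        if len(row) != 8:
            raise ValueError(f"rank {rank!r} does not describe 8 squares")
        board.append(row)
    if len(board) != 8:
        raise ValueError("expected 8 ranks")
    return board

def find(board, piece):
    """(row, col) of every square holding `piece`."""
    return [(r, c) for r, row in enumerate(board)
            for c, ch in enumerate(row) if ch == piece]

def illegal_reasons(placement):
    """List every rule violation found (an empty list means none)."""
    board = expand_placement(placement)
    reasons = []
    kings = {side: find(board, side) for side in "kK"}
    if any(len(squares) == 0 for squares in kings.values()):
        reasons.append("a king is missing")
    if any(len(squares) > 1 for squares in kings.values()):
        reasons.append("a side has more than one king")
    for piece, limit in MAX_COUNTS.items():
        for symbol in (piece, piece.upper()):
            if len(find(board, symbol)) > limit:
                reasons.append(f"more than {limit} '{symbol}' pieces")
    if any(ch in "pP" for ch in board[0] + board[7]):
        reasons.append("a pawn is on the first or last rank")
    if len(kings["k"]) == 1 and len(kings["K"]) == 1:
        (r1, c1), (r2, c2) = kings["k"][0], kings["K"][0]
        # Kings one square apart (including diagonally) would attack each other.
        if max(abs(r1 - r2), abs(c1 - c2)) == 1:
            reasons.append("the kings are adjacent")
    return reasons

def knight_checks_king(board, attacker, defendant):
    """True if a knight `attacker` (e.g. 'N') attacks the king `defendant` (e.g. 'k')."""
    for r, c in find(board, defendant):
        for dr, dc in KNIGHT_JUMPS:
            rr, cc = r + dr, c + dc
            if 0 <= rr < 8 and 0 <= cc < 8 and board[rr][cc] == attacker:
                return True
    return False
```

Sliding pieces (rooks, bishops, queens) need one more step: walk along each ray from the king until the first occupied square and test whether it holds an attacker of the right type.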
6.2. Demo of the data product

Data product link: https://huggingface.co/spaces/mohd-saifuddin/Chess-Recognition-2D

Note that you need test images to use this data product, so I recommend downloading them from the dataset source [1].


7. Learning Outcomes

Here is what I learned while working on this project.

  1. I learned to perform detailed EDA on chess images and FEN labels.
  2. I learned data preprocessing and working with the TensorFlow Data module.
  3. I learned to perform hyperparameter tuning using KerasTuner (though there are still many concepts to learn).
  4. Finally, I learned to develop a data product and publish it on the Streamlit platform.

8. References

[1] Pavel Koryakin, Chess Positions, Kaggle.

[2] Forsyth–Edwards Notation, Wikipedia.


9. End

Thank you for reading. If you have any suggestions, please let me know.

  1. Deep learning code: here.
  2. Streamlit application code: here.

You can connect with me on LinkedIn: here.

