
Spelling Correction: How to make an accurate and fast corrector


Source: https://unsplash.com/photos/wRgNwR9CZDA

Dirty data leads to bad model quality. In real-world NLP problems we often encounter texts with a lot of typos. As a result, we are unable to reach the best score. As painful as it may be, the data should be cleaned before fitting.

We need an automatic spelling corrector that can fix words with typos and, at the same time, not break correct spellings.

But how can we achieve this?

Let's start with Norvig's spelling corrector and iteratively extend its capabilities.

Norvig’s approach

Peter Norvig (director of research at Google) described the following approach to spelling correction.

Let's take a word and brute-force all possible edits: deletes, inserts, transposes, replaces and splits. E.g. for the word abc the possible candidates will be: ab ac bc bac cba acb a_bc ab_c aabc abbc acbc adbc aebc etc.

Every generated word is added to a candidate list. We then repeat the procedure for every candidate to cover a bigger edit distance (for cases with two errors).

Each candidate is scored with a unigram language model: word frequencies are pre-calculated for every vocabulary word on a big text collection, and the candidate with the highest frequency is taken as the answer.
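
Here is a minimal sketch of this approach in Python (word frequencies are assumed to be pre-counted into a word_freq dict; split edits are omitted for brevity):

import string

ALPHABET = string.ascii_lowercase

def edits1(word):
    # all strings at edit distance 1: deletes, transposes, replaces and inserts
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = [l + r[1:] for l, r in splits if r]
    transposes = [l + r[1] + r[0] + r[2:] for l, r in splits if len(r) > 1]
    replaces = [l + c + r[1:] for l, r in splits if r for c in ALPHABET]
    inserts = [l + c + r for l, r in splits for c in ALPHABET]
    return set(deletes + transposes + replaces + inserts)

def correct(word, word_freq):
    # prefer the word itself, then distance-1, then distance-2 candidates;
    # among the known candidates take the most frequent one
    def known(words):
        return {w for w in words if w in word_freq}
    candidates = (known({word})
                  or known(edits1(word))
                  or known(e2 for e1 in edits1(word) for e2 in edits1(e1))
                  or {word})
    return max(candidates, key=lambda w: word_freq.get(w, 0))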

Adding some context

The first improvement is adding an n-gram language model (3-grams). Let's pre-calculate frequencies not only for single words, but also for a word together with a small context (the 3 nearest words). We estimate the probability of a text fragment as the product of the probabilities of all its n-grams of size n:
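
P(w_1 \dots w_m) \approx \prod_{i=1}^{m-2} P(w_i\, w_{i+1}\, w_{i+2})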

To keep everything simple, let's compute the probability of an n-gram of size n as a product of the probabilities of its lower-order grams (there are actually smoothing techniques, like Kneser–Ney, that improve the model's accuracy, but let's talk about that later – see the "Improve accuracy" paragraph below):
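
P(w_i\, w_{i+1}\, w_{i+2}) = P(w_i) \cdot P(w_{i+1} \mid w_i) \cdot P(w_{i+2} \mid w_i\, w_{i+1})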

To turn raw occurrence counts into n-gram probabilities we need to normalize them (e.g. divide the count of a 3-gram by the count of its 2-gram prefix, and so on):
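
P(w_{i+2} \mid w_i\, w_{i+1}) = \frac{count(w_i\, w_{i+1}\, w_{i+2})}{count(w_i\, w_{i+1})}, \qquad P(w_{i+1} \mid w_i) = \frac{count(w_i\, w_{i+1})}{count(w_i)}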

Now we can use our extended language model to score candidates together with their context.

Sentence probability can be calculated like this:

def predict(self, sentence):
    # score a sentence as the sum of log joint probabilities of all its 3-gram windows
    # (chain rule: unigram * bigram * trigram conditionals)
    result = 0
    for i in range(0, len(sentence) - 2):
        p3 = self.getGram3Prob(sentence[i], sentence[i + 1], sentence[i + 2])
        p2 = self.getGram2Prob(sentence[i], sentence[i + 1])
        p1 = self.getGram1Prob(sentence[i])
        result += math.log(p3) + math.log(p2) + math.log(p1)
    return result

And n-gram probabilities like this:

def getGram1Prob(self, wordID):
    # unigram probability with additive (add-K) smoothing
    wordCounts = self.gram1.get(wordID, 0) + SimpleLangModel.K
    vocabSize = len(self.gram1)
    return float(wordCounts) / (self.totalWords + vocabSize)

def getGram2Prob(self, wordID1, wordID2):
    # bigram probability: count(w1 w2) / count(w1), smoothed
    countsWord1 = self.gram1.get(wordID1, 0) + self.totalWords
    countsBigram = self.gram2.get((wordID1, wordID2), 0) + SimpleLangModel.K
    return float(countsBigram) / countsWord1

def getGram3Prob(self, wordID1, wordID2, wordID3):
    # trigram probability: count(w1 w2 w3) / count(w1 w2), smoothed
    countsGram2 = self.gram2.get((wordID1, wordID2), 0) + self.totalWords
    countsGram3 = self.gram3.get((wordID1, wordID2, wordID3), 0) + SimpleLangModel.K
    return float(countsGram3) / countsGram2

Now we get much better accuracy. However, the model becomes really huge and everything works slowly: with a 600 MB training text, both memory consumption and correction speed become a problem.

Improving speed – SymSpell approach

To improve speed, let's borrow an idea from SymSpell. The idea is quite elegant: instead of generating all possible edits every time we meet an incorrect word, we can pre-calculate all delete typos at index time (other typos can be derived from deletes). You can read more details in the original SymSpell article.
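
The core of the idea as a rough Python sketch (the names here, including the vocabulary variable, are illustrative rather than the actual SymSpell or jamspell code):

def deletes(word, max_distance=2):
    # all strings obtained from word by removing up to max_distance letters
    results, frontier = {word}, {word}
    for _ in range(max_distance):
        frontier = {w[:i] + w[i + 1:] for w in frontier for i in range(len(w))}
        results |= frontier
    return results

# index time: map every delete-variant of every dictionary word back to that word
index = {}
for word in vocabulary:
    for d in deletes(word):
        index.setdefault(d, set()).add(word)

# query time: generate only the deletes of the misspelled word and look them up
def candidates(word):
    found = set()
    for d in deletes(word):
        found |= index.get(d, set())
    return found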

Obviously, we won't be able to reach the speed of the original SymSpell (because we use a language model and look at the context, not only at a single word), but we can still improve performance significantly. The cost is additional memory consumption.

Improve memory consumption

To get the best possible accuracy we need a big dataset (at least a few gigabytes). Training the n-gram model on a 600 MB file already leads to significant memory consumption (25 GB). Half of that is used by the language model, and the other half by the SymSpell index.

One reason for such high memory usage is that we don't store plain text – we store n-gram frequencies instead. For example, for the following text of 5 words, "a b c a b", we store these frequencies:

a => 2, b => 2, c => 1
a b => 2, b c => 1, c a => 1
a b c => 1, b c a => 1, c a b => 1

Another reason is the high memory overhead of the hash table data structure (a hash table is what backs Python's dict and C++'s std::unordered_map).

To compress our n-gram model, let's use the approach described in the Efficient Minimal Perfect Hash Language Models paper: a perfect hash (Compress, Hash and Displace) to store the n-gram counts. A perfect hash guarantees no collisions on the key set it was built for, so it's possible to store only the values (the count frequencies) and not the original n-grams. To make sure the hash of an unknown n-gram doesn't get mapped onto an existing one, we additionally use a Bloom filter of the known n-grams. We can also apply non-linear quantization to pack the 32-bit count frequencies into 16-bit values – this doesn't affect the final metrics but reduces memory usage.

Quantization:

static const uint32_t MAX_REAL_NUM = 268435456;
static const uint32_t MAX_AVAILABLE_NUM = 65536;

// non-linear quantization: map a 32-bit count into a 16-bit value,
// spending more precision on small counts (x^0.2 curve)
uint16_t PackInt32(uint32_t num) {
    double r = double(num) / double(MAX_REAL_NUM);
    assert(r >= 0.0 && r <= 1.0);
    r = pow(r, 0.2);
    r *= MAX_AVAILABLE_NUM;
    return uint16_t(r);
}

// inverse transform: restore an approximate 32-bit count from a 16-bit value
uint32_t UnpackInt32(uint16_t num) {
    double r = double(num) / double(MAX_AVAILABLE_NUM);
    r = pow(r, 5.0);
    r *= MAX_REAL_NUM;
    return uint32_t(ceil(r));
}

Count frequency extraction:

template<typename T>
TCount GetGramHashCount(T key,
                        const TPerfectHash& ph,
                        const std::vector<uint16_t>& buckets,
                        TBloomFilter& filter)
{
    constexpr int TMP_BUF_SIZE = 128;
    static char tmpBuff[TMP_BUF_SIZE];
    static MemStream tmpBuffStream(tmpBuff, TMP_BUF_SIZE - 1);
    static std::ostream out(&tmpBuffStream);

    tmpBuffStream.Reset();

    NHandyPack::Dump(out, key);
    if (!filter.Contains(tmpBuff, tmpBuffStream.Size())) {
        return TCount();
    }

    uint32_t bucket = ph.Hash(tmpBuff, tmpBuffStream.Size());

    assert(bucket < ph.BucketsNumber());

    return UnpackInt32(buckets[bucket]);
}

First we check whether the key exists in the Bloom filter, and then we read the count from the bucket that the perfect hash maps the key to.

To compress the SymSpell index, let's use a Bloom filter as well. A Bloom filter is a space-efficient probabilistic data structure used to test whether an element is a member of a set. Let's put all delete hashes into a Bloom filter and use this index to skip non-existent candidates.

Here is the modified second step of the SymSpell algorithm. We take candidate strings that were previously generated by removing one or several letters from the original word, and check whether each of them is contained in the index. Deletes1 and Deletes2 are the Bloom filters.

TWords CheckCandidate(const std::wstring& w)
{
    TWords results;
    // w is the original word with letters removed; the Bloom filters tell us
    // whether such a delete exists in the pre-calculated index
    if (Deletes1->Contains(w)) {
        Inserts(w, results);
    }
    if (Deletes2->Contains(w)) {
        Inserts2(w, results);
    }
    return results;
}

void TSpellCorrector::Inserts(const std::wstring& w, TWords& result) const
{
    // restore one deleted letter: try every alphabet character at every position
    for (size_t i = 0; i < w.size() + 1; ++i) {
        for (auto&& ch: LangModel.GetAlphabet()) {
            std::wstring s = w.substr(0, i) + ch + w.substr(i);
            TWord c = LangModel.GetWord(s);
            if (!c.Empty()) {
                result.push_back(c);
            }
        }
    }
}

void TSpellCorrector::Inserts2(const std::wstring& w, TWords& result) const
{
    // restore two deleted letters: insert one letter, and if the resulting string
    // is itself a known delete-variant, restore the second letter via Inserts
    for (size_t i = 0; i < w.size() + 1; ++i) {
        for (auto&& ch: LangModel.GetAlphabet()) {
            std::wstring s = w.substr(0, i) + ch + w.substr(i);
            if (Deletes1->Contains(WideToUTF8(s))) {
                Inserts(s, result);
            }
        }
    }
}

After this optimization the model size was reduced significantly, down to 800 MB.

Improve accuracy

To improve accuracy, let's add several machine learning models. The first one is a classifier that decides whether a word contains an error or not. The second one is a ranking model used to order the candidates. The ranker partly plays the role of language model smoothing (it gets every gram as a separate input feature and decides how much impact each gram has).

For candidate ranking we will train a CatBoost (gradient boosted decision trees) ranking model with the following features:

  • word frequency
  • n-gram frequencies, separate for each gram size (2, 3)
  • frequencies of nearby words at distances 3 and 4
  • n-gram model prediction
  • edit distance between the candidate and the source word
  • number of candidates with a better edit distance
  • word length
  • word presence in a clean static dictionary

from catboost import CatBoost

params = {
    'loss_function': 'PairLogit',
    'iterations': 400,
    'learning_rate': 0.1,
    'depth': 8,
    'verbose': False,
    'random_seed': 42,
    'early_stopping_rounds': 50,
    'border_count': 64,
    'leaf_estimation_backtracking': 'AnyImprovement',
    'leaf_estimation_iterations': 2,
    'leaf_estimation_method': 'Newton',
    'task_type': 'CPU'
}

model = CatBoost(params)
model.fit(trainX, trainY, pairs=trainPairs, group_id=groupIDs, eval_set=evalPool, verbose=1)

For error detection we will train a binary classifier. Let's use the same features, calculated for each word of the original text, and let the classifier decide whether the word contains an error. This gives us the ability to detect errors depending on the context, even for dictionary words.

from catboost import CatBoostClassifier
model = CatBoostClassifier(
    iterations=400,
    learning_rate=0.3,
    depth=8
)

model.fit(trainX, trainY, sample_weight=trainWeight, verbose=False)

This improves accuracy even more. However, it's not free – we get reduced performance. Still, for most applications the resulting performance is more than enough, and accuracy is usually more important.

Evaluate

To evaluate the model we need some datasets. We can generate artificial errors based on clean text, and we can also use public datasets – one of them is the SpellRuEval dataset. Let's check accuracy both on a couple of artificial datasets and on a real one, and compare with some alternative spell checkers from large IT companies. Here jamspell is our spell checker; we measured the following benchmarks:

  • RU, artificial, literature
  • RU, real, internet posts
  • EN, artificial, news

  • errRate – number of errors left in the text after automatic correction
  • fixRate – number of errors that were fixed
  • broken – number of correct words that were broken by the corrector (computing the first three metrics is sketched below)
  • pr, re, f1 – precision, recall and F1 score
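
The first three metrics can be computed from word-aligned texts, roughly like this (an illustrative sketch; the exact normalization may differ):

def evaluate(orig_words, typo_words, fixed_words):
    # orig_words: clean text, typo_words: text with typos, fixed_words: corrector output
    errors_before = errors_after = fixed = broken = 0
    for o, t, f in zip(orig_words, typo_words, fixed_words):
        had_error = (o != t)
        has_error = (o != f)
        errors_before += had_error
        errors_after += has_error
        fixed += had_error and not has_error
        broken += (not had_error) and has_error
    return {
        'errRate': errors_after / max(errors_before, 1),
        'fixRate': fixed / max(errors_before, 1),
        'broken': broken,
    }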

Further steps

The next step to improve accuracy is gathering a large parallel corpus of texts with errors and their corrections (separate corpora for mobile and desktop platforms) and training dedicated error models.

Another possible way to improve accuracy is to add a dynamic learning option. We can learn on the fly while making corrections, or we can do a two-pass correction: on the first pass the model learns some statistics, and on the second pass it makes the actual corrections.
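
A speculative sketch of the two-pass variant (correct_word and its extra_score hook are hypothetical, not the current jamspell API):

from collections import Counter

def correct_two_pass(corrector, sentences):
    # pass 1: collect word statistics from the document we are about to correct
    doc_freq = Counter(w for s in sentences for w in s)

    # pass 2: correct, giving candidates that are frequent in this document a score bonus
    return [
        [corrector.correct_word(w, extra_score=lambda cand: doc_freq[cand]) for w in s]
        for s in sentences
    ]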

Neural-network language models (a bidirectional LSTM or BERT) may also give an additional accuracy boost. They didn't work well with a straightforward approach (an LSTM error classifier, a seq2seq LSTM model, using BERT outputs as candidate ranking weights), but maybe their predictions will be useful as features in the ranking model / error detector.

Here are some details about the approaches we were unable to make work (that doesn't mean they can't work – it's just our experience).

Using BERT to select best candidate

We trained BERT and tried to use it for predicting the best candidate.

from transformers import (
    RobertaConfig,
    RobertaTokenizerFast,
    RobertaForMaskedLM,
    LineByLineTextDataset,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

config = RobertaConfig(
    vocab_size=52_000,
    max_position_embeddings=514,
    num_attention_heads=12,
    num_hidden_layers=6,
    type_vocab_size=1,
)

# assumed setup: a byte-level BPE tokenizer trained beforehand
# (TOKENIZER_PATH is a placeholder) and a masked-LM model built from the config above
tokenizer = RobertaTokenizerFast.from_pretrained(TOKENIZER_PATH)
model = RobertaForMaskedLM(config=config)

dataset = LineByLineTextDataset(
    tokenizer=tokenizer,
    file_path=TRAIN_TEXT_FILE,
    block_size=128,
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

training_args = TrainingArguments(
    output_dir="~/transformers",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_gpu_train_batch_size=32,
    save_steps=10_000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset,
    prediction_loss_only=False,
)

trainer.train()

It was superior to the n-gram language model on the masked word prediction task (30% accuracy for the BERT model vs. 20% for the n-gram model), but it performed worse when selecting the best word from a list of candidates. Our hypothesis is that this is because BERT knows nothing about edit distance or similarity to the original word. We believe that adding the BERT prediction as a feature to the CatBoost ranking model could give an accuracy boost.

Using LSTM as an error classifier

We tried to use an LSTM as an error detector (to predict whether a word has an error or not), but the best result we could get was about the same as the regular n-gram language model error predictor plus manual heuristics, while training took much longer, so we decided not to use it for now. Here is the model that gave the best score. The input is word-level GloVe embeddings trained on the same file.

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 9, 200)]          0         
_________________________________________________________________
bidirectional (Bidirectional (None, 9, 1800)           7927200   
_________________________________________________________________
dropout (Dropout)            (None, 9, 1800)           0         
_________________________________________________________________
attention (Attention)        (None, 1800)              1809      
_________________________________________________________________
dense (Dense)                (None, 724)               1303924   
_________________________________________________________________
dropout_1 (Dropout)          (None, 724)               0         
_________________________________________________________________
batch_normalization (BatchNo (None, 724)               2896      
_________________________________________________________________
dense_1 (Dense)              (None, 2)                 1450      
=================================================================
Total params: 9,237,279
Trainable params: 9,235,831
Non-trainable params: 1,448
_________________________________________________________________

Conclusion

We started from a very simple model and iteratively increased its capabilities, and finally got a strong, production-level spell checker. Still, this isn't the end: there are a lot of steps left on the long road to our goal – making the best possible spell checker in the world.
