From 8876b75ca1c218949539dcc2fb6c88a19da9e3f8 Mon Sep 17 00:00:00 2001
From: satok <satok@google.com>
Date: Thu, 4 Aug 2011 18:31:57 +0900
Subject: [PATCH] Move scoring part to the correction state

Change-Id: I2dc4a0869636fce5526f48b3a6267b6bdf61dbfb
---
 native/src/correction_state.cpp   | 131 ++++++++++++++++--
 native/src/correction_state.h     |  63 +++++++--
 native/src/unigram_dictionary.cpp | 221 ++++++++++--------------------
 native/src/unigram_dictionary.h   |  12 +-
 4 files changed, 245 insertions(+), 182 deletions(-)

diff --git a/native/src/correction_state.cpp b/native/src/correction_state.cpp
index b2c77b00d2..9000e9e9cc 100644
--- a/native/src/correction_state.cpp
+++ b/native/src/correction_state.cpp
@@ -25,13 +25,31 @@
 
 namespace latinime {
 
+//////////////////////
+// inline functions //
+//////////////////////
+static const char QUOTE = '\'';
+
+inline bool CorrectionState::needsToSkipCurrentNode(const unsigned short c) {
+    const unsigned short userTypedChar = mProximityInfo->getPrimaryCharAt(mInputIndex);
+    // Skip the ' or other letter and continue deeper
+    return (c == QUOTE && userTypedChar != QUOTE) || mSkipPos == mOutputIndex;
+}
+
+/////////////////////
+// CorrectionState //
+/////////////////////
+
 CorrectionState::CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier)
         : TYPED_LETTER_MULTIPLIER(typedLetterMultiplier), FULL_WORD_MULTIPLIER(fullWordMultiplier) {
 }
 
-void CorrectionState::initCorrectionState(const ProximityInfo *pi, const int inputLength) {
+void CorrectionState::initCorrectionState(const ProximityInfo *pi, const int inputLength,
+        const int maxDepth) {
     mProximityInfo = pi;
     mInputLength = inputLength;
+    mMaxDepth = maxDepth;
+    mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
 }
 
 void CorrectionState::setCorrectionParams(const int skipPos, const int excessivePos,
@@ -58,27 +76,37 @@ int CorrectionState::getFreqForSplitTwoWords(const int firstFreq, const int seco
     return CorrectionState::RankingAlgorithm::calcFreqForSplitTwoWords(firstFreq, secondFreq, this);
 }
 
-int CorrectionState::getFinalFreq(const unsigned short *word, const int freq) {
-    if (mProximityInfo->sameAsTyped(word, mOutputIndex + 1) || mOutputIndex < MIN_SUGGEST_DEPTH) {
+int CorrectionState::getFinalFreq(const int freq, unsigned short **word, int *wordLength) {
+    const int outputIndex = mOutputIndex - 1;
+    const int inputIndex = (mCurrentStateType == TRAVERSE_ALL_ON_TERMINAL
+            || mCurrentStateType == TRAVERSE_ALL_NOT_ON_TERMINAL) ? mInputIndex : mInputIndex - 1;
+    *wordLength = outputIndex + 1;
+    if (mProximityInfo->sameAsTyped(mWord, outputIndex + 1) || outputIndex < MIN_SUGGEST_DEPTH) {
         return -1;
     }
-    const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == mInputIndex + 2)
-            : (mInputLength == mInputIndex + 1);
+    *word = mWord;
+    const bool sameLength = (mExcessivePos == mInputLength - 1) ? (mInputLength == inputIndex + 2)
+            : (mInputLength == inputIndex + 1);
     return CorrectionState::RankingAlgorithm::calculateFinalFreq(
-            mInputIndex, mOutputIndex, mMatchedCharCount, freq, sameLength, this);
+            inputIndex, outputIndex, mMatchedCharCount, freq, sameLength, this);
 }
 
-void CorrectionState::initProcessState(
-        const int matchCount, const int inputIndex, const int outputIndex) {
+void CorrectionState::initProcessState(const int matchCount, const int inputIndex,
+        const int outputIndex, const bool traverseAllNodes, const int diffs) {
     mMatchedCharCount = matchCount;
     mInputIndex = inputIndex;
     mOutputIndex = outputIndex;
+    mTraverseAllNodes = traverseAllNodes;
+    mDiffs = diffs;
 }
 
-void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex) {
+void CorrectionState::getProcessState(int *matchedCount, int *inputIndex, int *outputIndex,
+        bool *traverseAllNodes, int *diffs) {
     *matchedCount = mMatchedCharCount;
     *inputIndex = mInputIndex;
     *outputIndex = mOutputIndex;
+    *traverseAllNodes = mTraverseAllNodes;
+    *diffs = mDiffs;
 }
 
 void CorrectionState::charMatched() {
@@ -95,6 +123,11 @@ int CorrectionState::getInputIndex() {
     return mInputIndex;
 }
 
+// TODO: remove
+bool CorrectionState::needsToTraverseAll() {
+    return mTraverseAllNodes;
+}
+
 void CorrectionState::incrementInputIndex() {
     ++mInputIndex;
 }
@@ -103,6 +136,86 @@ void CorrectionState::incrementOutputIndex() {
     ++mOutputIndex;
 }
 
+void CorrectionState::startTraverseAll() {
+    mTraverseAllNodes = true;
+}
+
+bool CorrectionState::needsToPrune() const {
+    return (mOutputIndex - 1 >= (mTransposedPos >= 0 ? mInputLength - 1 : mMaxDepth)
+            || mDiffs > mMaxEditDistance);
+}
+
+CorrectionState::CorrectionStateType CorrectionState::processCharAndCalcState(
+        const int32_t c, const bool isTerminal) {
+    mCurrentStateType = NOT_ON_TERMINAL;
+    // This has to be done for each virtual char (this forwards the "inputIndex" which
+    // is the index in the user-inputted chars, as read by proximity chars.
+    if (mExcessivePos == mOutputIndex && mInputIndex < mInputLength - 1) {
+        incrementInputIndex();
+    }
+
+    if (mTraverseAllNodes || needsToSkipCurrentNode(c)) {
+        mWord[mOutputIndex] = c;
+        if (needsToTraverseAll() && isTerminal) {
+            mCurrentStateType = TRAVERSE_ALL_ON_TERMINAL;
+        } else {
+            mCurrentStateType = TRAVERSE_ALL_NOT_ON_TERMINAL;
+        }
+    } else {
+        int inputIndexForProximity = mInputIndex;
+
+        if (mTransposedPos >= 0) {
+            if (mInputIndex == mTransposedPos) {
+                ++inputIndexForProximity;
+            }
+            if (mInputIndex == (mTransposedPos + 1)) {
+                --inputIndexForProximity;
+            }
+        }
+
+        int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
+                inputIndexForProximity, c, this);
+        if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
+            mCurrentStateType = UNRELATED;
+            return mCurrentStateType;
+        }
+        mWord[mOutputIndex] = c;
+        // If inputIndex is greater than mInputLength, that means there is no
+        // proximity chars. So, we don't need to check proximity.
+        if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
+            charMatched();
+        }
+
+        if (ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) {
+            incrementDiffs();
+        }
+
+        const bool isSameAsUserTypedLength = mInputLength
+                == getInputIndex() + 1
+                        || (mExcessivePos == mInputLength - 1
+                                    && getInputIndex() == mInputLength - 2);
+        if (isSameAsUserTypedLength && isTerminal) {
+            mCurrentStateType = ON_TERMINAL;
+        }
+        // Start traversing all nodes after the index exceeds the user typed length
+        if (isSameAsUserTypedLength) {
+            startTraverseAll();
+        }
+
+        // Finally, we are ready to go to the next character, the next "virtual node".
+        // We should advance the input index.
+        // We do this in this branch of the 'if traverseAllNodes' because we are still matching
+        // characters to input; the other branch is not matching them but searching for
+        // completions, this is why it does not have to do it.
+        incrementInputIndex();
+    }
+
+    // Also, the next char is one "virtual node" depth more than this char.
+    incrementOutputIndex();
+
+    return mCurrentStateType;
+}
+
 CorrectionState::~CorrectionState() {
 }
 
diff --git a/native/src/correction_state.h b/native/src/correction_state.h
index cc3c3e669d..a548bcb68f 100644
--- a/native/src/correction_state.h
+++ b/native/src/correction_state.h
@@ -29,49 +29,76 @@ class CorrectionState {
 
 public:
     typedef enum {
-        ALLOW_ALL,
+        TRAVERSE_ALL_ON_TERMINAL,
+        TRAVERSE_ALL_NOT_ON_TERMINAL,
         UNRELATED,
-        RELATED
+        ON_TERMINAL,
+        NOT_ON_TERMINAL
     } CorrectionStateType;
 
     CorrectionState(const int typedLetterMultiplier, const int fullWordMultiplier);
-    void initCorrectionState(const ProximityInfo *pi, const int inputLength);
+    void initCorrectionState(
+            const ProximityInfo *pi, const int inputLength, const int maxWordLength);
     void setCorrectionParams(const int skipPos, const int excessivePos, const int transposedPos,
             const int spaceProximityPos, const int missingSpacePos);
     void checkState();
-    void initProcessState(const int matchCount, const int inputIndex, const int outputIndex);
-    void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex);
-    void charMatched();
-    void incrementInputIndex();
-    void incrementOutputIndex();
+    void initProcessState(const int matchCount, const int inputIndex, const int outputIndex,
+            const bool traverseAllNodes, const int diffs);
+    void getProcessState(int *matchedCount, int *inputIndex, int *outputIndex,
+            bool *traverseAllNodes, int *diffs);
     int getOutputIndex();
     int getInputIndex();
+    bool needsToTraverseAll();
 
     virtual ~CorrectionState();
+    int getSpaceProximityPos() const {
+        return mSpaceProximityPos;
+    }
+    int getMissingSpacePos() const {
+        return mMissingSpacePos;
+    }
+
     int getSkipPos() const {
         return mSkipPos;
     }
+
     int getExcessivePos() const {
         return mExcessivePos;
     }
+
     int getTransposedPos() const {
         return mTransposedPos;
     }
-    int getSpaceProximityPos() const {
-        return mSpaceProximityPos;
-    }
-    int getMissingSpacePos() const {
-        return mMissingSpacePos;
-    }
+
+    bool needsToPrune() const;
+
     int getFreqForSplitTwoWords(const int firstFreq, const int secondFreq);
-    int getFinalFreq(const unsigned short *word, const int freq);
+    int getFinalFreq(const int freq, unsigned short **word, int* wordLength);
+
+    CorrectionStateType processCharAndCalcState(const int32_t c, const bool isTerminal);
 
+    int getDiffs() const {
+        return mDiffs;
+    }
 private:
+    void charMatched();
+    void incrementInputIndex();
+    void incrementOutputIndex();
+    void startTraverseAll();
+
+    // TODO: remove
+
+    void incrementDiffs() {
+        ++mDiffs;
+    }
 
     const int TYPED_LETTER_MULTIPLIER;
     const int FULL_WORD_MULTIPLIER;
 
     const ProximityInfo *mProximityInfo;
+
+    int mMaxEditDistance;
+    int mMaxDepth;
     int mInputLength;
     int mSkipPos;
     int mExcessivePos;
@@ -82,6 +109,12 @@ private:
     int mMatchedCharCount;
     int mInputIndex;
     int mOutputIndex;
+    int mDiffs;
+    bool mTraverseAllNodes;
+    CorrectionStateType mCurrentStateType;
+    unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
+
+    inline bool needsToSkipCurrentNode(const unsigned short c);
 
     class RankingAlgorithm {
     public:
diff --git a/native/src/unigram_dictionary.cpp b/native/src/unigram_dictionary.cpp
index b95da99a3f..93d2b84181 100644
--- a/native/src/unigram_dictionary.cpp
+++ b/native/src/unigram_dictionary.cpp
@@ -181,14 +181,14 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
     PROF_START(0);
     initSuggestions(
             proximityInfo, xcoordinates, ycoordinates, codes, codesSize, outWords, frequencies);
-    mCorrectionState->initCorrectionState(mProximityInfo, mInputLength);
     if (DEBUG_DICT) assert(codesSize == mInputLength);
 
-    const int MAX_DEPTH = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH);
+    const int maxDepth = min(mInputLength * MAX_DEPTH_MULTIPLIER, MAX_WORD_LENGTH);
+    mCorrectionState->initCorrectionState(mProximityInfo, mInputLength, maxDepth);
     PROF_END(0);
 
     PROF_START(1);
-    getSuggestionCandidates(-1, -1, -1, MAX_DEPTH);
+    getSuggestionCandidates(-1, -1, -1);
     PROF_END(1);
 
     PROF_START(2);
@@ -198,7 +198,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
             if (DEBUG_DICT) {
                 LOGI("--- Suggest missing characters %d", i);
             }
-            getSuggestionCandidates(i, -1, -1, MAX_DEPTH);
+            getSuggestionCandidates(i, -1, -1);
         }
     }
     PROF_END(2);
@@ -211,7 +211,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
             if (DEBUG_DICT) {
                 LOGI("--- Suggest excessive characters %d", i);
             }
-            getSuggestionCandidates(-1, i, -1, MAX_DEPTH);
+            getSuggestionCandidates(-1, i, -1);
         }
     }
     PROF_END(3);
@@ -224,7 +224,7 @@ void UnigramDictionary::getWordSuggestions(ProximityInfo *proximityInfo,
             if (DEBUG_DICT) {
                 LOGI("--- Suggest transposed characters %d", i);
             }
-            getSuggestionCandidates(-1, -1, i, mInputLength - 1);
+            getSuggestionCandidates(-1, -1, i);
         }
     }
     PROF_END(4);
@@ -272,7 +272,6 @@ void UnigramDictionary::initSuggestions(ProximityInfo *proximityInfo, const int
     mFrequencies = frequencies;
     mOutputChars = outWords;
     mInputLength = codesSize;
-    mMaxEditDistance = mInputLength < 5 ? 2 : mInputLength / 2;
     proximityInfo->setInputParams(codes, codesSize);
     mProximityInfo = proximityInfo;
 }
@@ -342,9 +341,8 @@ static const char QUOTE = '\'';
 static const char SPACE = ' ';
 
 void UnigramDictionary::getSuggestionCandidates(const int skipPos,
-        const int excessivePos, const int transposedPos, const int maxDepth) {
+        const int excessivePos, const int transposedPos) {
     if (DEBUG_DICT) {
-        LOGI("getSuggestionCandidates %d", maxDepth);
         assert(transposedPos + 1 < mInputLength);
         assert(excessivePos < mInputLength);
         assert(missingPos < mInputLength);
@@ -368,32 +366,26 @@ void UnigramDictionary::getSuggestionCandidates(const int skipPos,
     while (depth >= 0) {
         if (mStackChildCount[depth] > 0) {
             --mStackChildCount[depth];
-            bool traverseAllNodes = mStackTraverseAll[depth];
-            int diffs = mStackDiffs[depth];
             int siblingPos = mStackSiblingPos[depth];
             int firstChildPos;
             mCorrectionState->initProcessState(
-                    mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth]);
+                    mStackMatchedCount[depth], mStackInputIndex[depth], mStackOutputIndex[depth],
+                    mStackTraverseAll[depth], mStackDiffs[depth]);
 
-            // depth will never be greater than maxDepth because in that case,
             // needsToTraverseChildrenNodes should be false
             const bool needsToTraverseChildrenNodes = processCurrentNode(siblingPos,
-                    maxDepth, traverseAllNodes, diffs,
-                    mCorrectionState, &childCount,
-                    &firstChildPos, &traverseAllNodes, &diffs,
-                    &siblingPos);
+                    mCorrectionState, &childCount, &firstChildPos, &siblingPos);
             // Update next sibling pos
             mStackSiblingPos[depth] = siblingPos;
             if (needsToTraverseChildrenNodes) {
                 // Goes to child node
                 ++depth;
                 mStackChildCount[depth] = childCount;
-                mStackTraverseAll[depth] = traverseAllNodes;
-                mStackDiffs[depth] = diffs;
                 mStackSiblingPos[depth] = firstChildPos;
 
                 mCorrectionState->getProcessState(&mStackMatchedCount[depth],
-                        &mStackInputIndex[depth], &mStackOutputIndex[depth]);
+                        &mStackInputIndex[depth], &mStackOutputIndex[depth],
+                        &mStackTraverseAll[depth], &mStackDiffs[depth]);
             }
         } else {
             // Goes to parent sibling node
@@ -437,12 +429,12 @@ inline bool UnigramDictionary::needsToSkipCurrentNode(const unsigned short c,
     return (c == QUOTE && userTypedChar != QUOTE) || skipPos == depth;
 }
 
-
-inline void UnigramDictionary::onTerminal(
-        unsigned short int* word, const int freq, CorrectionState *correctionState) {
-    const int finalFreq = correctionState->getFinalFreq(word, freq);
+inline void UnigramDictionary::onTerminal(const int freq, CorrectionState *correctionState) {
+    int wordLength;
+    unsigned short* wordPointer;
+    const int finalFreq = correctionState->getFinalFreq(freq, &wordPointer, &wordLength);
     if (finalFreq >= 0) {
-        addWord(word, correctionState->getOutputIndex() + 1, finalFreq);
+        addWord(wordPointer, wordLength, finalFreq);
     }
 }
 
@@ -657,20 +649,13 @@ int UnigramDictionary::getBigramPosition(int pos, unsigned short *word, int offs
 // there aren't any more nodes at this level, it merely returns the address of the first byte after
 // the current node in nextSiblingPosition. Thus, the caller must keep count of the nodes at any
 // given level, as output into newCount when traversing this level's parent.
-inline bool UnigramDictionary::processCurrentNode(const int initialPos, const int maxDepth,
-        const bool initialTraverseAllNodes, const int initialDiffs,
-        CorrectionState *correctionState, int *newCount, int *newChildrenPosition,
-        bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition) {
-    const int skipPos = correctionState->getSkipPos();
-    const int excessivePos = correctionState->getExcessivePos();
-    const int transposedPos = correctionState->getTransposedPos();
+inline bool UnigramDictionary::processCurrentNode(const int initialPos,
+        CorrectionState *correctionState, int *newCount,
+        int *newChildrenPosition, int *nextSiblingPosition) {
     if (DEBUG_DICT) {
         correctionState->checkState();
     }
     int pos = initialPos;
-    int traverseAllNodes = initialTraverseAllNodes;
-    int diffs = initialDiffs;
-    const int initialInputIndex = correctionState->getInputIndex();
 
     // Flags contain the following information:
     // - Address type (MASK_GROUP_ADDRESS_TYPE) on two bits:
@@ -682,6 +667,9 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
     // - FLAG_HAS_BIGRAMS: whether this node has bigrams or not
     const uint8_t flags = BinaryFormat::getFlagsAndForwardPointer(DICT_ROOT, &pos);
     const bool hasMultipleChars = (0 != (FLAG_HAS_MULTIPLE_CHARS & flags));
+    const bool isTerminalNode = (0 != (FLAG_IS_TERMINAL & flags));
+
+    bool needsToInvokeOnTerminal = false;
 
     // This gets only ONE character from the stream. Next there will be:
     // if FLAG_HAS_MULTIPLE CHARS: the other characters of the same node
@@ -707,111 +695,21 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
         const bool isLastChar = (NOT_A_CHARACTER == nextc);
         // If there are more chars in this nodes, then this virtual node is not a terminal.
         // If we are on the last char, this virtual node is a terminal if this node is.
-        const bool isTerminal = isLastChar && (0 != (FLAG_IS_TERMINAL & flags));
-        // If there are more chars in this node, then this virtual node has children.
-        // If we are on the last char, this virtual node has children if this node has.
-        const bool hasChildren = (!isLastChar) || BinaryFormat::hasChildrenInFlags(flags);
-
-        // This has to be done for each virtual char (this forwards the "inputIndex" which
-        // is the index in the user-inputted chars, as read by proximity chars.
-        if (excessivePos == correctionState->getOutputIndex()
-                && correctionState->getInputIndex() < mInputLength - 1) {
-            correctionState->incrementInputIndex();
-        }
-        if (traverseAllNodes || needsToSkipCurrentNode(
-                c, correctionState->getInputIndex(), skipPos, correctionState->getOutputIndex())) {
-            mWord[correctionState->getOutputIndex()] = c;
-            if (traverseAllNodes && isTerminal) {
-                // The frequency should be here, because we come here only if this is actually
-                // a terminal node, and we are on its last char.
-                const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
-                onTerminal(mWord, freq, mCorrectionState);
-            }
-            if (!hasChildren) {
-                // If we don't have children here, that means we finished processing all
-                // characters of this node (we are on the last virtual node), AND we are in
-                // traverseAllNodes mode, which means we are searching for *completions*. We
-                // should skip the frequency if we have a terminal, and report the position
-                // of the next sibling. We don't have to return other values because we are
-                // returning false, as in "don't traverse children".
-                if (isTerminal) pos = BinaryFormat::skipFrequency(flags, pos);
-                *nextSiblingPosition =
-                        BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
-                return false;
-            }
-        } else {
-            int inputIndexForProximity = correctionState->getInputIndex();
-
-            if (transposedPos >= 0) {
-                if (correctionState->getInputIndex() == transposedPos) {
-                    ++inputIndexForProximity;
-                }
-                if (correctionState->getInputIndex() == (transposedPos + 1)) {
-                    --inputIndexForProximity;
-                }
-            }
-
-            int matchedProximityCharId = mProximityInfo->getMatchedProximityId(
-                    inputIndexForProximity, c, mCorrectionState);
-            if (ProximityInfo::UNRELATED_CHAR == matchedProximityCharId) {
-                // We found that this is an unrelated character, so we should give up traversing
-                // this node and its children entirely.
-                // However we may not be on the last virtual node yet so we skip the remaining
-                // characters in this node, the frequency if it's there, read the next sibling
-                // position to output it, then return false.
-                // We don't have to output other values because we return false, as in
-                // "don't traverse children".
-                if (!isLastChar) {
-                    pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos);
-                }
-                pos = BinaryFormat::skipFrequency(flags, pos);
-                *nextSiblingPosition =
-                        BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
-                return false;
-            }
-            mWord[correctionState->getOutputIndex()] = c;
-            // If inputIndex is greater than mInputLength, that means there is no
-            // proximity chars. So, we don't need to check proximity.
-            if (ProximityInfo::SAME_OR_ACCENTED_OR_CAPITALIZED_CHAR == matchedProximityCharId) {
-                correctionState->charMatched();
-            }
-            const bool isSameAsUserTypedLength = mInputLength
-                    == correctionState->getInputIndex() + 1
-                            || (excessivePos == mInputLength - 1
-                                        && correctionState->getInputIndex() == mInputLength - 2);
-            if (isSameAsUserTypedLength && isTerminal) {
-                const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
-                onTerminal(mWord, freq, mCorrectionState);
-            }
-            // Start traversing all nodes after the index exceeds the user typed length
-            traverseAllNodes = isSameAsUserTypedLength;
-            diffs = diffs
-                    + ((ProximityInfo::NEAR_PROXIMITY_CHAR == matchedProximityCharId) ? 1 : 0);
-            // Finally, we are ready to go to the next character, the next "virtual node".
-            // We should advance the input index.
-            // We do this in this branch of the 'if traverseAllNodes' because we are still matching
-            // characters to input; the other branch is not matching them but searching for
-            // completions, this is why it does not have to do it.
-            correctionState->incrementInputIndex();
-
-            // This character matched the typed character (enough to traverse the node at least)
-            // so we just evaluated it. Now we should evaluate this virtual node's children - that
-            // is, if it has any. If it has no children, we're done here - so we skip the end of
-            // the node, output the siblings position, and return false "don't traverse children".
-            // Note that !hasChildren implies isLastChar, so we know we don't have to skip any
-            // remaining char in this group for there can't be any.
-            if (!hasChildren) {
-                pos = BinaryFormat::skipFrequency(flags, pos);
-                *nextSiblingPosition =
-                        BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
-                return false;
-            }
-        }
-        // Optimization: Prune out words that are too long compared to how much was typed.
-        if (isTerminal
-                && (correctionState->getOutputIndex() >= maxDepth || diffs > mMaxEditDistance)) {
-            // We are giving up parsing this node and its children. Skip the rest of the node,
-            // output the sibling position, and return that we don't want to traverse children.
+        const bool isTerminal = isLastChar && isTerminalNode;
+
+        CorrectionState::CorrectionStateType stateType = correctionState->processCharAndCalcState(
+                c, isTerminal);
+        if (stateType == CorrectionState::TRAVERSE_ALL_ON_TERMINAL
+                || stateType == CorrectionState::ON_TERMINAL) {
+            needsToInvokeOnTerminal = true;
+        } else if (stateType == CorrectionState::UNRELATED) {
+            // We found that this is an unrelated character, so we should give up traversing
+            // this node and its children entirely.
+            // However we may not be on the last virtual node yet so we skip the remaining
+            // characters in this node, the frequency if it's there, read the next sibling
+            // position to output it, then return false.
+            // We don't have to output other values because we return false, as in
+            // "don't traverse children".
             if (!isLastChar) {
                 pos = BinaryFormat::skipOtherCharacters(DICT_ROOT, pos);
             }
@@ -820,8 +718,6 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
                     BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
             return false;
         }
-        // Also, the next char is one "virtual node" depth more than this char.
-        correctionState->incrementOutputIndex();
 
         // Prepare for the next character. Promote the prefetched char to current char - the loop
         // will take care of prefetching the next. If we finally found our last char, nextc will
@@ -829,16 +725,39 @@ inline bool UnigramDictionary::processCurrentNode(const int initialPos, const in
         c = nextc;
     } while (NOT_A_CHARACTER != c);
 
-    // If inputIndex is greater than mInputLength, that means there are no proximity chars.
-    // Here, that's all we are interested in so we don't need to check for isSameAsUserTypedLength.
-    if (mInputLength <= initialInputIndex) {
-        traverseAllNodes = true;
-    }
+    if (isTerminalNode) {
+        if (needsToInvokeOnTerminal) {
+            // The frequency should be here, because we come here only if this is actually
+            // a terminal node, and we are on its last char.
+            const int freq = BinaryFormat::readFrequencyWithoutMovingPointer(DICT_ROOT, pos);
+            onTerminal(freq, mCorrectionState);
+        }
+
+        // If there are more chars in this node, then this virtual node has children.
+        // If we are on the last char, this virtual node has children if this node has.
+        const bool hasChildren = BinaryFormat::hasChildrenInFlags(flags);
+
+        // This character matched the typed character (enough to traverse the node at least)
+        // so we just evaluated it. Now we should evaluate this virtual node's children - that
+        // is, if it has any. If it has no children, we're done here - so we skip the end of
+        // the node, output the siblings position, and return false "don't traverse children".
+        // Note that !hasChildren implies isLastChar, so we know we don't have to skip any
+        // remaining char in this group for there can't be any.
+        if (!hasChildren) {
+            pos = BinaryFormat::skipFrequency(flags, pos);
+            *nextSiblingPosition =
+                    BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
+            return false;
+        }
 
-    // All the output values that are purely computation by this function are held in local
-    // variables. Output them to the caller.
-    *newTraverseAllNodes = traverseAllNodes;
-    *newDiffs = diffs;
+        // Optimization: Prune out words that are too long compared to how much was typed.
+        if (correctionState->needsToPrune()) {
+            pos = BinaryFormat::skipFrequency(flags, pos);
+            *nextSiblingPosition =
+                    BinaryFormat::skipChildrenPosAndAttributes(DICT_ROOT, flags, pos);
+            return false;
+        }
+    }
 
     // Now we finished processing this node, and we want to traverse children. If there are no
     // children, we can't come here.
diff --git a/native/src/unigram_dictionary.h b/native/src/unigram_dictionary.h
index cb86da41ce..a45df24fb0 100644
--- a/native/src/unigram_dictionary.h
+++ b/native/src/unigram_dictionary.h
@@ -87,21 +87,20 @@ private:
             const int *ycoordinates, const int *codes, const int codesSize,
             unsigned short *outWords, int *frequencies);
     void getSuggestionCandidates(const int skipPos, const int excessivePos,
-            const int transposedPos, const int maxDepth);
+            const int transposedPos);
     bool addWord(unsigned short *word, int length, int frequency);
     void getSplitTwoWordsSuggestion(const int inputLength, CorrectionState *correctionState);
     void getMissingSpaceWords(
             const int inputLength, const int missingSpacePos, CorrectionState *correctionState);
     void getMistypedSpaceWords(
             const int inputLength, const int spaceProximityPos, CorrectionState *correctionState);
-    void onTerminal(unsigned short int* word, const int freq, CorrectionState *correctionState);
+    void onTerminal(const int freq, CorrectionState *correctionState);
     bool needsToSkipCurrentNode(const unsigned short c,
             const int inputIndex, const int skipPos, const int depth);
     // Process a node by considering proximity, missing and excessive character
-    bool processCurrentNode(const int initialPos, const int maxDepth,
-            const bool initialTraverseAllNodes, const int initialDiffs,
-            CorrectionState *correctionState, int *newCount, int *newChildPosition,
-            bool *newTraverseAllNodes, int *newDiffs, int *nextSiblingPosition);
+    bool processCurrentNode(const int initialPos,
+            CorrectionState *correctionState, int *newCount,
+            int *newChildPosition, int *nextSiblingPosition);
     int getMostFrequentWordLike(const int startInputIndex, const int inputLength,
             unsigned short *word);
     int getMostFrequentWordLikeInner(const uint16_t* const inWord, const int length,
@@ -134,7 +133,6 @@ private:
     int mInputLength;
     // MAX_WORD_LENGTH_INTERNAL must be bigger than MAX_WORD_LENGTH
     unsigned short mWord[MAX_WORD_LENGTH_INTERNAL];
-    int mMaxEditDistance;
 
     int mStackMatchedCount[MAX_WORD_LENGTH_INTERNAL];
     int mStackChildCount[MAX_WORD_LENGTH_INTERNAL];
-- 
GitLab