From 9a23f0fba25137760a60e9bfaf6bf20a5889648c Mon Sep 17 00:00:00 2001
From: Keisuke Kuroyanagi <ksk@google.com>
Date: Tue, 12 Aug 2014 20:32:42 +0900
Subject: [PATCH] Add bigrams to language model content.

Bug: 14425059

Change-Id: Id81e3775ea0104750a23e3dca62c00681ed8dc2e
---
 .../v402/ver4_patricia_trie_node_writer.cpp   |  4 +-
 .../content/language_model_dict_content.cpp   | 15 ++++++-
 .../v4/content/language_model_dict_content.h  |  2 +-
 .../structure/v4/content/probability_entry.h  | 24 ++++++++++-
 .../structure/v4/ver4_dict_constants.cpp      |  2 +-
 .../structure/v4/ver4_dict_constants.h        |  2 +-
 .../v4/ver4_patricia_trie_node_writer.cpp     | 41 ++++++++++++++-----
 .../v4/ver4_patricia_trie_node_writer.h       |  6 +--
 native/jni/src/utils/int_array_view.h         |  5 +++
 .../jni/tests/utils/int_array_view_test.cpp   | 17 +++++++-
 10 files changed, 96 insertions(+), 22 deletions(-)

diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
index 278f2b199d..f7179f68d0 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/backward/v402/ver4_patricia_trie_node_writer.cpp
@@ -234,8 +234,8 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
 bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
         const BigramProperty *const bigramProperty, bool *const outAddedNewEntry) {
     if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewEntry)) {
-        AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
-                sourcePtNodeParams->getTerminalId(), targetPtNodeParam->getTerminalId());
+        AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
+                prevWordIds[0], wordId);
         return false;
     }
     const int ptNodePos =
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
index 5dc91ba100..f3bc4a0cbf 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.cpp
@@ -46,7 +46,7 @@ ProbabilityEntry LanguageModelDictContent::getNgramProbabilityEntry(
 
 bool LanguageModelDictContent::setNgramProbabilityEntry(const WordIdArrayView prevWordIds,
         const int terminalId, const ProbabilityEntry *const probabilityEntry) {
-    const int bitmapEntryIndex = getBitmapEntryIndex(prevWordIds);
+    const int bitmapEntryIndex = createAndGetBitmapEntryIndex(prevWordIds);
     if (bitmapEntryIndex == TrieMap::INVALID_INDEX) {
         return false;
     }
@@ -80,6 +80,19 @@ bool LanguageModelDictContent::runGCInner(
     return true;
 }
 
+int LanguageModelDictContent::createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds) {
+    if (prevWordIds.empty()) {
+        return mTrieMap.getRootBitmapEntryIndex();
+    }
+    const int lastBitmapEntryIndex =
+            getBitmapEntryIndex(prevWordIds.limit(prevWordIds.size() - 1));
+    if (lastBitmapEntryIndex == TrieMap::INVALID_INDEX) {
+        return TrieMap::INVALID_INDEX;
+    }
+    return mTrieMap.getNextLevelBitmapEntryIndex(prevWordIds[prevWordIds.size() - 1],
+            lastBitmapEntryIndex);
+}
+
 int LanguageModelDictContent::getBitmapEntryIndex(const WordIdArrayView prevWordIds) const {
     int bitmapEntryIndex = mTrieMap.getRootBitmapEntryIndex();
     for (const int wordId : prevWordIds) {
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
index 18f2e01702..104ee2520a 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/language_model_dict_content.h
@@ -76,7 +76,7 @@ class LanguageModelDictContent {
     bool runGCInner(const TerminalPositionLookupTable::TerminalIdMap *const terminalIdMap,
             const TrieMap::TrieMapRange trieMapRange, const int nextLevelBitmapEntryIndex,
             int *const outNgramCount);
-
+    int createAndGetBitmapEntryIndex(const WordIdArrayView prevWordIds);
     int getBitmapEntryIndex(const WordIdArrayView prevWordIds) const;
 };
 } // namespace latinime
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
index feff6b57f5..ed77bd20ef 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/content/probability_entry.h
@@ -21,6 +21,8 @@
 #include <cstdint>
 
 #include "defines.h"
+#include "suggest/core/dictionary/property/bigram_property.h"
+#include "suggest/core/dictionary/property/unigram_property.h"
 #include "suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h"
 #include "suggest/policyimpl/dictionary/utils/historical_info.h"
 
@@ -45,6 +47,20 @@ class ProbabilityEntry {
             const HistoricalInfo *const historicalInfo)
             : mFlags(flags), mProbability(probability), mHistoricalInfo(*historicalInfo) {}
 
+    // Create from unigram property.
+    // TODO: Set flags.
+    ProbabilityEntry(const UnigramProperty *const unigramProperty)
+            : mFlags(0), mProbability(unigramProperty->getProbability()),
+              mHistoricalInfo(unigramProperty->getTimestamp(), unigramProperty->getLevel(),
+                      unigramProperty->getCount()) {}
+
+    // Create from bigram property.
+    // TODO: Set flags.
+    ProbabilityEntry(const BigramProperty *const bigramProperty)
+            : mFlags(0), mProbability(bigramProperty->getProbability()),
+              mHistoricalInfo(bigramProperty->getTimestamp(), bigramProperty->getLevel(),
+                      bigramProperty->getCount()) {}
+
     const ProbabilityEntry createEntryWithUpdatedProbability(const int probability) const {
         return ProbabilityEntry(mFlags, probability, &mHistoricalInfo);
     }
@@ -54,6 +70,10 @@ class ProbabilityEntry {
         return ProbabilityEntry(mFlags, mProbability, historicalInfo);
     }
 
+    bool isValid() const {
+        return (mProbability != NOT_A_PROBABILITY) || hasHistoricalInfo();
+    }
+
     bool hasHistoricalInfo() const {
         return mHistoricalInfo.isValid();
     }
@@ -89,7 +109,7 @@ class ProbabilityEntry {
     static ProbabilityEntry decode(const uint64_t encodedEntry, const bool hasHistoricalInfo) {
         if (hasHistoricalInfo) {
             const int flags = readFromEncodedEntry(encodedEntry,
-                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
+                    Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
                     Ver4DictConstants::TIME_STAMP_FIELD_SIZE
                             + Ver4DictConstants::WORD_LEVEL_FIELD_SIZE
                             + Ver4DictConstants::WORD_COUNT_FIELD_SIZE);
@@ -106,7 +126,7 @@ class ProbabilityEntry {
             return ProbabilityEntry(flags, NOT_A_PROBABILITY, &historicalInfo);
         } else {
             const int flags = readFromEncodedEntry(encodedEntry,
-                    Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE,
+                    Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE,
                     Ver4DictConstants::PROBABILITY_SIZE);
             const int probability = readFromEncodedEntry(encodedEntry,
                     Ver4DictConstants::PROBABILITY_SIZE, 0 /* pos */);
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
index 93d4e562da..e622442ba2 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.cpp
@@ -46,7 +46,7 @@ const int Ver4DictConstants::SHORTCUT_BUFFERS_INDEX =
 
 const int Ver4DictConstants::NOT_A_TERMINAL_ID = -1;
 const int Ver4DictConstants::PROBABILITY_SIZE = 1;
-const int Ver4DictConstants::FLAGS_IN_PROBABILITY_FILE_SIZE = 1;
+const int Ver4DictConstants::FLAGS_IN_LANGUAGE_MODEL_SIZE = 1;
 const int Ver4DictConstants::TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE = 3;
 const int Ver4DictConstants::NOT_A_TERMINAL_ADDRESS = 0;
 const int Ver4DictConstants::TERMINAL_ID_FIELD_SIZE = 4;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
index 6950ca70fa..8d29f60d4c 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_dict_constants.h
@@ -41,7 +41,7 @@ class Ver4DictConstants {
 
     static const int NOT_A_TERMINAL_ID;
     static const int PROBABILITY_SIZE;
-    static const int FLAGS_IN_PROBABILITY_FILE_SIZE;
+    static const int FLAGS_IN_LANGUAGE_MODEL_SIZE;
     static const int TERMINAL_ADDRESS_TABLE_ADDRESS_SIZE;
     static const int NOT_A_TERMINAL_ADDRESS;
     static const int TERMINAL_ID_FIELD_SIZE;
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
index 857222f5d3..2c848cb297 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.cpp
@@ -145,10 +145,11 @@ bool Ver4PatriciaTrieNodeWriter::updatePtNodeUnigramProperty(
     const ProbabilityEntry originalProbabilityEntry =
             mBuffers->getLanguageModelDictContent()->getProbabilityEntry(
                     toBeUpdatedPtNodeParams->getTerminalId());
-    const ProbabilityEntry probabilityEntry = createUpdatedEntryFrom(&originalProbabilityEntry,
-            unigramProperty);
+    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
+    const ProbabilityEntry updatedProbabilityEntry =
+            createUpdatedEntryFrom(&originalProbabilityEntry, &probabilityEntryOfUnigramProperty);
     return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
-            toBeUpdatedPtNodeParams->getTerminalId(), &probabilityEntry);
+            toBeUpdatedPtNodeParams->getTerminalId(), &updatedProbabilityEntry);
 }
 
 bool Ver4PatriciaTrieNodeWriter::updatePtNodeProbabilityAndGetNeedsToKeepPtNodeAfterGC(
@@ -216,16 +217,36 @@ bool Ver4PatriciaTrieNodeWriter::writeNewTerminalPtNodeAndAdvancePosition(
     }
     // Write probability.
     ProbabilityEntry newProbabilityEntry;
+    const ProbabilityEntry probabilityEntryOfUnigramProperty = ProbabilityEntry(unigramProperty);
     const ProbabilityEntry probabilityEntryToWrite = createUpdatedEntryFrom(
-            &newProbabilityEntry, unigramProperty);
+            &newProbabilityEntry, &probabilityEntryOfUnigramProperty);
     return mBuffers->getMutableLanguageModelDictContent()->setProbabilityEntry(
             terminalId, &probabilityEntryToWrite);
 }
 
 bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds, const int wordId,
         const BigramProperty *const bigramProperty, bool *const outAddedNewBigram) {
+    // TODO: Support n-gram.
+    LanguageModelDictContent *const languageModelDictContent =
+            mBuffers->getMutableLanguageModelDictContent();
+    const ProbabilityEntry probabilityEntry =
+            languageModelDictContent->getNgramProbabilityEntry(
+                    prevWordIds.limit(1 /* maxSize */), wordId);
+    const ProbabilityEntry probabilityEntryOfBigramProperty(bigramProperty);
+    const ProbabilityEntry updatedProbabilityEntry = createUpdatedEntryFrom(
+            &probabilityEntry, &probabilityEntryOfBigramProperty);
+    if (!languageModelDictContent->setNgramProbabilityEntry(
+            prevWordIds.limit(1 /* maxSize */), wordId, &updatedProbabilityEntry)) {
+        AKLOGE("Cannot add new ngram entry. prevWordId: %d, wordId: %d",
+                prevWordIds[0], wordId);
+        return false;
+    }
+    if (!probabilityEntry.isValid() && outAddedNewBigram) {
+        *outAddedNewBigram = true;
+    }
+    // TODO: Remove.
     if (!mBigramPolicy->addNewEntry(prevWordIds[0], wordId, bigramProperty, outAddedNewBigram)) {
-        AKLOGE("Cannot add new bigram entry. terminalId: %d, targetTerminalId: %d",
+        AKLOGE("Cannot add new bigram entry. prevWordId: %d, wordId: %d",
                 prevWordIds[0], wordId);
         return false;
     }
@@ -234,6 +255,7 @@ bool Ver4PatriciaTrieNodeWriter::addNgramEntry(const WordIdArrayView prevWordIds
 
 bool Ver4PatriciaTrieNodeWriter::removeNgramEntry(const WordIdArrayView prevWordIds,
         const int wordId) {
+    // TODO: Remove.
     return mBigramPolicy->removeEntry(prevWordIds[0], wordId);
 }
 
@@ -352,20 +374,19 @@ bool Ver4PatriciaTrieNodeWriter::writePtNodeAndGetTerminalIdAndAdvancePosition(
 
 const ProbabilityEntry Ver4PatriciaTrieNodeWriter::createUpdatedEntryFrom(
         const ProbabilityEntry *const originalProbabilityEntry,
-        const UnigramProperty *const unigramProperty) const {
+        const ProbabilityEntry *const probabilityEntry) const {
     // TODO: Consolidate historical info and probability.
     if (mHeaderPolicy->hasHistoricalInfoOfWords()) {
-        const HistoricalInfo historicalInfoForUpdate(unigramProperty->getTimestamp(),
-                unigramProperty->getLevel(), unigramProperty->getCount());
         const HistoricalInfo updatedHistoricalInfo =
                 ForgettingCurveUtils::createUpdatedHistoricalInfo(
                         originalProbabilityEntry->getHistoricalInfo(),
-                        unigramProperty->getProbability(), &historicalInfoForUpdate, mHeaderPolicy);
+                        probabilityEntry->getProbability(), probabilityEntry->getHistoricalInfo(),
+                        mHeaderPolicy);
         return originalProbabilityEntry->createEntryWithUpdatedHistoricalInfo(
                 &updatedHistoricalInfo);
     } else {
         return originalProbabilityEntry->createEntryWithUpdatedProbability(
-                unigramProperty->getProbability());
+                probabilityEntry->getProbability());
     }
 }
 
diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
index 6703dba045..5d73b6ea39 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_node_writer.h
@@ -98,12 +98,12 @@ class Ver4PatriciaTrieNodeWriter : public PtNodeWriter {
             const PtNodeParams *const ptNodeParams, int *const outTerminalId,
             int *const ptNodeWritingPos);
 
-    // Create updated probability entry using given unigram property. In addition to the
+    // Create updated probability entry using given probability property. In addition to the
     // probability, this method updates historical information if needed.
-    // TODO: Update flags belonging to the unigram property.
+    // TODO: Update flags.
     const ProbabilityEntry createUpdatedEntryFrom(
             const ProbabilityEntry *const originalProbabilityEntry,
-            const UnigramProperty *const unigramProperty) const;
+            const ProbabilityEntry *const probabilityEntry) const;
 
     bool updatePtNodeFlags(const int ptNodePos, const bool isBlacklisted, const bool isNotAWord,
             const bool isTerminal, const bool hasMultipleChars);
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index c1ddc9812b..53f2d29712 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -91,6 +91,11 @@ class IntArrayView {
         return mPtr + mSize;
     }
 
+    // Returns the view whose size is smaller than or equal to the given count.
+    const IntArrayView limit(const size_t maxSize) const {
+        return IntArrayView(mPtr, std::min(maxSize, mSize));
+    }
+
  private:
     DISALLOW_ASSIGNMENT_OPERATOR(IntArrayView);
 
diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp
index bd843ab025..ecc451af0a 100644
--- a/native/jni/tests/utils/int_array_view_test.cpp
+++ b/native/jni/tests/utils/int_array_view_test.cpp
@@ -53,9 +53,24 @@ TEST(IntArrayViewTest, TestConstructFromArray) {
 TEST(IntArrayViewTest, TestConstructFromObject) {
     const int object = 10;
     const auto intArrayView = IntArrayView::fromObject(&object);
-    EXPECT_EQ(1, intArrayView.size());
+    EXPECT_EQ(1u, intArrayView.size());
     EXPECT_EQ(object, intArrayView[0]);
 }
 
+TEST(IntArrayViewTest, TestLimit) {
+    const std::vector<int> intVector = {3, 2, 1, 0, -1, -2};
+    IntArrayView intArrayView(intVector);
+
+    EXPECT_TRUE(intArrayView.limit(0).empty());
+    EXPECT_EQ(intArrayView.size(), intArrayView.limit(intArrayView.size()).size());
+    EXPECT_EQ(intArrayView.size(), intArrayView.limit(1000).size());
+
+    IntArrayView subView = intArrayView.limit(4);
+    EXPECT_EQ(4u, subView.size());
+    for (size_t i = 0; i < subView.size(); ++i) {
+        EXPECT_EQ(intVector[i], subView[i]);
+    }
+}
+
 }  // namespace
 }  // namespace latinime
-- 
GitLab