diff --git a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
index 1336a6229b4d0b576f47a0893752fb998bdd67d2..d537711b0ab8610547e3443a3a1ee5f6942869dc 100644
--- a/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
+++ b/native/jni/src/suggest/policyimpl/dictionary/structure/v4/ver4_patricia_trie_policy.cpp
@@ -310,30 +310,34 @@ bool Ver4PatriciaTriePolicy::addNgramEntry(const PrevWordsInfo *const prevWordsI
     if (prevWordIds.empty()) {
         return false;
     }
-    // TODO: Support N-gram.
-    if (prevWordIds[0] == NOT_A_WORD_ID) {
-        if (prevWordsInfo->isNthPrevWordBeginningOfSentence(1 /* n */)) {
-            const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
-            const UnigramProperty beginningOfSentenceUnigramProperty(
-                    true /* representsBeginningOfSentence */, true /* isNotAWord */,
-                    false /* isBlacklisted */, MAX_PROBABILITY /* probability */,
-                    NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts);
-            if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(1 /* n */),
-                    &beginningOfSentenceUnigramProperty)) {
-                AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
-                return false;
-            }
-            // Refresh word ids.
-            prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
-        } else {
+    // A previous word without an id is only acceptable when it is the beginning-of-sentence
+    // marker, which is then registered on demand. Position i corresponds to n == i + 1.
+    for (size_t i = 0; i < prevWordIds.size(); ++i) {
+        if (prevWordIds[i] != NOT_A_WORD_ID) {
+            continue;
+        }
+        if (!prevWordsInfo->isNthPrevWordBeginningOfSentence(i + 1 /* n */)) {
             return false;
         }
+        const std::vector<UnigramProperty::ShortcutProperty> shortcuts;
+        const UnigramProperty beginningOfSentenceUnigramProperty(
+                true /* representsBeginningOfSentence */, true /* isNotAWord */,
+                false /* isBlacklisted */, MAX_PROBABILITY /* probability */,
+                NOT_A_TIMESTAMP /* timestamp */, 0 /* level */, 0 /* count */, &shortcuts);
+        if (!addUnigramEntry(prevWordsInfo->getNthPrevWordCodePoints(i + 1 /* n */),
+                &beginningOfSentenceUnigramProperty)) {
+            AKLOGE("Cannot add unigram entry for the beginning-of-sentence.");
+            return false;
+        }
+        // Refresh word ids.
+        prevWordsInfo->getPrevWordIds(this, &prevWordIdArray, false /* tryLowerCaseSearch */);
     }
     const int wordId = getWordId(CodePointArrayView(*bigramProperty->getTargetCodePoints()),
             false /* forceLowerCaseSearch */);
     if (wordId == NOT_A_WORD_ID) {
         return false;
     }
+    // TODO: Support N-gram.
     bool addedNewEntry = false;
     WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordsPtNodePos;
     for (size_t i = 0; i < prevWordsPtNodePos.size(); ++i) {
@@ -375,8 +379,7 @@ bool Ver4PatriciaTriePolicy::removeNgramEntry(const PrevWordsInfo *const prevWor
     WordIdArray<MAX_PREV_WORD_COUNT_FOR_N_GRAM> prevWordIdArray;
     const WordIdArrayView prevWordIds = prevWordsInfo->getPrevWordIds(this, &prevWordIdArray,
             false /* tryLowerCaseSerch */);
-    // TODO: Support N-gram.
-    if (prevWordIds.empty() || prevWordIds[0] == NOT_A_WORD_ID) {
+    if (prevWordIds.empty() || prevWordIds.contains(NOT_A_WORD_ID)) {
         return false;
     }
     const int wordId = getWordId(wordCodePoints, false /* forceLowerCaseSearch */);
diff --git a/native/jni/src/utils/int_array_view.h b/native/jni/src/utils/int_array_view.h
index caa13d97625f8e6e2d61a6b40acf5b054e05c51f..cc5f328ba938db8806571ad97626e52a7b887c68 100644
--- a/native/jni/src/utils/int_array_view.h
+++ b/native/jni/src/utils/int_array_view.h
@@ -17,6 +17,7 @@
 #ifndef LATINIME_INT_ARRAY_VIEW_H
 #define LATINIME_INT_ARRAY_VIEW_H
 
+#include <algorithm>
 #include <array>
 #include <cstdint>
 #include <cstring>
@@ -92,12 +93,17 @@ class IntArrayView {
         return mPtr + mSize;
     }
 
+    // Returns true if any element of the view is equal to the given value.
+    AK_FORCE_INLINE bool contains(const int value) const {
+        return std::find(begin(), end(), value) != end();
+    }
+
     // Returns the view whose size is smaller than or equal to the given count.
-    const IntArrayView limit(const size_t maxSize) const {
+    AK_FORCE_INLINE const IntArrayView limit(const size_t maxSize) const {
         return IntArrayView(mPtr, std::min(maxSize, mSize));
     }
 
-    const IntArrayView skip(const size_t n) const {
+    AK_FORCE_INLINE const IntArrayView skip(const size_t n) const {
         if (mSize <= n) {
             return IntArrayView();
         }
diff --git a/native/jni/tests/utils/int_array_view_test.cpp b/native/jni/tests/utils/int_array_view_test.cpp
index 3bc294cdde66bde468d0b85ba4dc038699bc136e..934e27e1c56c3ad0210458518a0cb5bf8c16ee22 100644
--- a/native/jni/tests/utils/int_array_view_test.cpp
+++ b/native/jni/tests/utils/int_array_view_test.cpp
@@ -58,6 +58,19 @@ TEST(IntArrayViewTest, TestConstructFromObject) {
     EXPECT_EQ(object, intArrayView[0]);
 }
 
+TEST(IntArrayViewTest, TestContains) {
+    EXPECT_FALSE(IntArrayView().contains(0));
+    EXPECT_FALSE(IntArrayView().contains(1));
+
+    const std::vector<int> intVector = {3, 2, 1, 0, -1, -2};
+    IntArrayView intArrayView(intVector);
+    EXPECT_TRUE(intArrayView.contains(0));
+    EXPECT_TRUE(intArrayView.contains(3));
+    EXPECT_TRUE(intArrayView.contains(-2));
+    EXPECT_FALSE(intArrayView.contains(-3));
+    EXPECT_FALSE(intArrayView.limit(0).contains(3));
+}
+
 TEST(IntArrayViewTest, TestLimit) {
     const std::vector<int> intVector = {3, 2, 1, 0, -1, -2};
     IntArrayView intArrayView(intVector);