Skip to content
Snippets Groups Projects
Commit 54c091d2 authored by Satoshi Kataoka's avatar Satoshi Kataoka Committed by Android (Google) Code Review
Browse files

Merge "Move policy and session to AOSP"

parents 43341ba0 3107b467
No related branches found
No related tags found
No related merge requests found
......@@ -29,7 +29,9 @@ LATIN_IME_SRC_FULLPATH_DIR := $(LOCAL_PATH)/$(LATIN_IME_SRC_DIR)
LOCAL_C_INCLUDES += \
$(LATIN_IME_SRC_FULLPATH_DIR) \
$(LATIN_IME_SRC_FULLPATH_DIR)/suggest \
$(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/dicnode
$(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/dicnode \
$(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/policy \
$(LATIN_IME_SRC_FULLPATH_DIR)/suggest/core/session
LOCAL_CFLAGS += -Werror -Wall -Wextra -Weffc++ -Wformat=2 -Wcast-qual -Wcast-align \
-Wwrite-strings -Wfloat-equal -Wpointer-arith -Winit-self -Wredundant-decls -Wno-system-headers
......@@ -63,7 +65,10 @@ LATIN_IME_CORE_SRC_FILES := \
unigram_dictionary.cpp \
words_priority_queue.cpp \
suggest/core/dicnode/dic_node.cpp \
suggest/core/dicnode/dic_nodes_cache.cpp \
suggest/core/dicnode/dic_node_utils.cpp \
suggest/core/policy/weighting.cpp \
suggest/core/session/dic_traverse_session.cpp \
suggest/gesture_suggest.cpp \
suggest/typing_suggest.cpp
......
/*
* Copyright (C) 2012 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <list>
#include "defines.h"
#include "dic_node_priority_queue.h"
#include "dic_node_utils.h"
#include "dic_nodes_cache.h"
namespace latinime {
/**
* Truncates all of the dicNodes so that they start at the given commit point.
* Only called for multi-word typing input.
*/
DicNode *DicNodesCache::setCommitPoint(int commitPoint) {
std::list<DicNode> dicNodesList;
while (mCachedDicNodesForContinuousSuggestion->getSize() > 0) {
DicNode dicNode;
mCachedDicNodesForContinuousSuggestion->copyPop(&dicNode);
dicNodesList.push_front(dicNode);
}
// Get the starting words of the top scoring dicNode (last dicNode popped from priority queue)
// up to the commit point. These words have already been committed to the text view.
DicNode *topDicNode = &dicNodesList.front();
DicNode topDicNodeCopy;
DicNodeUtils::initByCopy(topDicNode, &topDicNodeCopy);
// Keep only those dicNodes that match the same starting words.
std::list<DicNode>::iterator iter;
for (iter = dicNodesList.begin(); iter != dicNodesList.end(); iter++) {
DicNode *dicNode = &*iter;
if (dicNode->truncateNode(&topDicNodeCopy, commitPoint)) {
mCachedDicNodesForContinuousSuggestion->copyPush(dicNode);
} else {
// Top dicNode should be reprocessed.
ASSERT(dicNode != topDicNode);
DicNode::managedDelete(dicNode);
}
}
mInputIndex -= commitPoint;
return topDicNode;
}
} // namespace latinime
/*
* Copyright (C) 2012 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_DIC_NODES_CACHE_H
#define LATINIME_DIC_NODES_CACHE_H
#include <stdint.h>
#include "defines.h"
#include "dic_node_priority_queue.h"
#define INITIAL_QUEUE_ID_ACTIVE 0
#define INITIAL_QUEUE_ID_NEXT_ACTIVE 1
#define INITIAL_QUEUE_ID_TERMINAL 2
#define INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION 3
#define PRIORITY_QUEUES_SIZE 4
namespace latinime {
class DicNode;
/**
* Class for controlling dicNode search priority queue and lexicon trie traversal.
*/
class DicNodesCache {
public:
AK_FORCE_INLINE DicNodesCache()
: mActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_ACTIVE]),
mNextActiveDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_NEXT_ACTIVE]),
mTerminalDicNodes(&mDicNodePriorityQueues[INITIAL_QUEUE_ID_TERMINAL]),
mCachedDicNodesForContinuousSuggestion(
&mDicNodePriorityQueues[INITIAL_QUEUE_ID_CACHE_FOR_CONTINUOUS_SUGGESTION]),
mInputIndex(0), mLastCachedInputIndex(0) {
}
AK_FORCE_INLINE virtual ~DicNodesCache() {}
AK_FORCE_INLINE void reset(const int nextActiveSize, const int terminalSize) {
mInputIndex = 0;
mLastCachedInputIndex = 0;
mActiveDicNodes->reset();
mNextActiveDicNodes->clearAndResize(nextActiveSize);
mTerminalDicNodes->clearAndResize(terminalSize);
mCachedDicNodesForContinuousSuggestion->reset();
}
AK_FORCE_INLINE void continueSearch() {
resetTemporaryCaches();
restoreActiveDicNodesFromCache();
}
AK_FORCE_INLINE void advanceActiveDicNodes() {
if (DEBUG_DICT) {
AKLOGI("Advance active %d nodes.", mNextActiveDicNodes->getSize());
}
if (DEBUG_DICT_FULL) {
mNextActiveDicNodes->dump();
}
mNextActiveDicNodes =
moveNodesAndReturnReusableEmptyQueue(mNextActiveDicNodes, &mActiveDicNodes);
}
DicNode *setCommitPoint(int commitPoint);
int activeSize() const { return mActiveDicNodes->getSize(); }
int terminalSize() const { return mTerminalDicNodes->getSize(); }
bool isLookAheadCorrectionInputIndex(const int inputIndex) const {
return inputIndex == mInputIndex - 1;
}
void advanceInputIndex(const int inputSize) {
if (mInputIndex < inputSize) {
mInputIndex++;
}
}
AK_FORCE_INLINE void copyPushTerminal(DicNode *dicNode) {
mTerminalDicNodes->copyPush(dicNode);
}
AK_FORCE_INLINE void copyPushActive(DicNode *dicNode) {
mActiveDicNodes->copyPush(dicNode);
}
AK_FORCE_INLINE bool copyPushContinue(DicNode *dicNode) {
return mCachedDicNodesForContinuousSuggestion->copyPush(dicNode);
}
AK_FORCE_INLINE void copyPushNextActive(DicNode *dicNode) {
DicNode *pushedDicNode = mNextActiveDicNodes->copyPush(dicNode);
if (!pushedDicNode) {
if (dicNode->isCached()) {
dicNode->remove();
}
// We simply drop any dic node that was not cached, ignoring the slim chance
// that one of its children represents what the user really wanted.
}
}
void popTerminal(DicNode *dest) {
mTerminalDicNodes->copyPop(dest);
}
void popActive(DicNode *dest) {
mActiveDicNodes->copyPop(dest);
}
bool hasCachedDicNodesForContinuousSuggestion() const {
return mCachedDicNodesForContinuousSuggestion
&& mCachedDicNodesForContinuousSuggestion->getSize() > 0;
}
AK_FORCE_INLINE bool isCacheBorderForTyping(const int inputSize) const {
// TODO: Move this variable to header
static const int CACHE_BACK_LENGTH = 3;
const int cacheInputIndex = inputSize - CACHE_BACK_LENGTH;
const bool shouldCache = (cacheInputIndex == mInputIndex)
&& (cacheInputIndex != mLastCachedInputIndex);
return shouldCache;
}
AK_FORCE_INLINE void updateLastCachedInputIndex() {
mLastCachedInputIndex = mInputIndex;
}
private:
DISALLOW_COPY_AND_ASSIGN(DicNodesCache);
AK_FORCE_INLINE void restoreActiveDicNodesFromCache() {
if (DEBUG_DICT) {
AKLOGI("Restore %d nodes. inputIndex = %d.",
mCachedDicNodesForContinuousSuggestion->getSize(), mLastCachedInputIndex);
}
if (DEBUG_DICT_FULL || DEBUG_CACHE) {
mCachedDicNodesForContinuousSuggestion->dump();
}
mInputIndex = mLastCachedInputIndex;
mCachedDicNodesForContinuousSuggestion =
moveNodesAndReturnReusableEmptyQueue(
mCachedDicNodesForContinuousSuggestion, &mActiveDicNodes);
}
AK_FORCE_INLINE static DicNodePriorityQueue *moveNodesAndReturnReusableEmptyQueue(
DicNodePriorityQueue *src, DicNodePriorityQueue **dest) {
const int srcMaxSize = src->getMaxSize();
const int destMaxSize = (*dest)->getMaxSize();
DicNodePriorityQueue *tmp = *dest;
*dest = src;
(*dest)->setMaxSize(destMaxSize);
tmp->clearAndResize(srcMaxSize);
return tmp;
}
AK_FORCE_INLINE void resetTemporaryCaches() {
mActiveDicNodes->clear();
mNextActiveDicNodes->clear();
mTerminalDicNodes->clear();
}
DicNodePriorityQueue mDicNodePriorityQueues[PRIORITY_QUEUES_SIZE];
// Active dicNodes currently being expanded.
DicNodePriorityQueue *mActiveDicNodes;
// Next dicNodes to be expanded.
DicNodePriorityQueue *mNextActiveDicNodes;
// Current top terminal dicNodes.
DicNodePriorityQueue *mTerminalDicNodes;
// Cached dicNodes used for continuous suggestion.
DicNodePriorityQueue *mCachedDicNodesForContinuousSuggestion;
int mInputIndex;
int mLastCachedInputIndex;
};
} // namespace latinime
#endif // LATINIME_DIC_NODES_CACHE_H
/*
* Copyright (C) 2013 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_SCORING_H
#define LATINIME_SCORING_H
#include "defines.h"
namespace latinime {
class DicNode;
class DicTraverseSession;
// This class basically tweaks suggestions and distances apart from CompoundDistance
class Scoring {
public:
virtual int calculateFinalScore(const float compoundDistance, const int inputSize,
const bool forceCommit) const = 0;
virtual bool getMostProbableString(
const DicTraverseSession *const traverseSession, const int terminalSize,
const float languageWeight, int *const outputCodePoints, int *const type,
int *const freq) const = 0;
virtual void safetyNetForMostProbableString(const int terminalSize,
const int maxScore, int *const outputCodePoints, int *const frequencies) const = 0;
// TODO: Make more generic
virtual void searchWordWithDoubleLetter(DicNode *terminals,
const int terminalSize, int *doubleLetterTerminalIndex,
DoubleLetterLevel *doubleLetterLevel) const = 0;
virtual float getAdjustedLanguageWeight(DicTraverseSession *const traverseSession,
DicNode *const terminals, const int size) const = 0;
virtual float getDoubleLetterDemotionDistanceCost(const int terminalIndex,
const int doubleLetterTerminalIndex,
const DoubleLetterLevel doubleLetterLevel) const = 0;
virtual bool doesAutoCorrectValidWord() const = 0;
protected:
Scoring() {}
virtual ~Scoring() {}
private:
DISALLOW_COPY_AND_ASSIGN(Scoring);
};
} // namespace latinime
#endif // LATINIME_SCORING_H
/*
* Copyright (C) 2013 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_SUGGEST_POLICY_H
#define LATINIME_SUGGEST_POLICY_H
#include "defines.h"
namespace latinime {
class Traversal;
class Scoring;
class Weighting;
class SuggestPolicy {
public:
SuggestPolicy() {}
virtual ~SuggestPolicy() {}
virtual const Traversal *getTraversal() const = 0;
virtual const Scoring *getScoring() const = 0;
virtual const Weighting *getWeighting() const = 0;
private:
DISALLOW_COPY_AND_ASSIGN(SuggestPolicy);
};
} // namespace latinime
#endif // LATINIME_SUGGEST_POLICY_H
/*
* Copyright (C) 2013 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_TRAVERSAL_H
#define LATINIME_TRAVERSAL_H
#include "defines.h"
namespace latinime {
class Traversal {
public:
virtual int getMaxPointerCount() const = 0;
virtual bool allowsErrorCorrections(const DicNode *const dicNode) const = 0;
virtual bool isOmission(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, const DicNode *const childDicNode) const = 0;
virtual bool isSpaceSubstitutionTerminal(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual bool isSpaceOmissionTerminal(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual bool shouldDepthLevelCache(const DicTraverseSession *const traverseSession) const = 0;
virtual bool shouldNodeLevelCache(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual bool canDoLookAheadCorrection(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual ProximityType getProximityType(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
const DicNode *const childDicNode) const = 0;
virtual bool sameAsTyped(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual bool needsToTraverseAllUserInput() const = 0;
virtual float getMaxSpatialDistance() const = 0;
virtual bool allowPartialCommit() const = 0;
virtual int getDefaultExpandDicNodeSize() const = 0;
virtual int getMaxCacheSize() const = 0;
virtual bool isPossibleOmissionChildNode(
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode) const = 0;
virtual bool isGoodToTraverseNextWord(const DicNode *const dicNode) const = 0;
protected:
Traversal() {}
virtual ~Traversal() {}
private:
DISALLOW_COPY_AND_ASSIGN(Traversal);
};
} // namespace latinime
#endif // LATINIME_TRAVERSAL_H
/*
* Copyright (C) 2013 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "char_utils.h"
#include "defines.h"
#include "dic_node.h"
#include "dic_node_profiler.h"
#include "dic_node_utils.h"
#include "dic_traverse_session.h"
#include "hash_map_compat.h"
#include "weighting.h"
namespace latinime {
static inline void profile(const CorrectionType correctionType, DicNode *const node) {
#if DEBUG_DICT
switch (correctionType) {
case CT_OMISSION:
PROF_OMISSION(node->mProfiler);
return;
case CT_ADDITIONAL_PROXIMITY:
PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
return;
case CT_SUBSTITUTION:
PROF_SUBSTITUTION(node->mProfiler);
return;
case CT_NEW_WORD:
PROF_NEW_WORD(node->mProfiler);
return;
case CT_MATCH:
PROF_MATCH(node->mProfiler);
return;
case CT_COMPLETION:
PROF_COMPLETION(node->mProfiler);
return;
case CT_TERMINAL:
PROF_TERMINAL(node->mProfiler);
return;
case CT_SPACE_SUBSTITUTION:
PROF_SPACE_SUBSTITUTION(node->mProfiler);
return;
case CT_INSERTION:
PROF_INSERTION(node->mProfiler);
return;
case CT_TRANSPOSITION:
PROF_TRANSPOSITION(node->mProfiler);
return;
default:
// do nothing
return;
}
#else
// do nothing
#endif
}
/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
const CorrectionType correctionType,
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, DicNode *const dicNode,
hash_map_compat<int, int16_t> *const bigramCacheMap) {
const int inputSize = traverseSession->getInputSize();
DicNode_InputStateG inputStateG;
inputStateG.mNeedsToUpdateInputStateG = false; // Don't use input info by default
const float spatialCost = Weighting::getSpatialCost(weighting, correctionType,
traverseSession, parentDicNode, dicNode, &inputStateG);
const float languageCost = Weighting::getLanguageCost(weighting, correctionType,
traverseSession, parentDicNode, dicNode, bigramCacheMap);
const bool edit = Weighting::isEditCorrection(correctionType);
const bool proximity = Weighting::isProximityCorrection(weighting, correctionType,
traverseSession, dicNode);
profile(correctionType, dicNode);
if (inputStateG.mNeedsToUpdateInputStateG) {
dicNode->updateInputIndexG(&inputStateG);
} else {
dicNode->forwardInputIndex(0, getForwardInputCount(correctionType),
(correctionType == CT_TRANSPOSITION));
}
dicNode->addCost(spatialCost, languageCost, weighting->needsToNormalizeCompoundDistance(),
inputSize, edit, proximity);
}
/* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
const CorrectionType correctionType,
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) {
switch(correctionType) {
case CT_OMISSION:
return weighting->getOmissionCost(parentDicNode, dicNode);
case CT_ADDITIONAL_PROXIMITY:
// only used for typing
return weighting->getAdditionalProximityCost();
case CT_SUBSTITUTION:
// only used for typing
return weighting->getSubstitutionCost();
case CT_NEW_WORD:
return weighting->getNewWordCost(dicNode);
case CT_MATCH:
return weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
case CT_COMPLETION:
return weighting->getCompletionCost(traverseSession, dicNode);
case CT_TERMINAL:
return weighting->getTerminalSpatialCost(traverseSession, dicNode);
case CT_SPACE_SUBSTITUTION:
return weighting->getSpaceSubstitutionCost();
case CT_INSERTION:
return weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
case CT_TRANSPOSITION:
return weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
default:
return 0.0f;
}
}
/* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode,
hash_map_compat<int, int16_t> *const bigramCacheMap) {
switch(correctionType) {
case CT_OMISSION:
return 0.0f;
case CT_SUBSTITUTION:
return 0.0f;
case CT_NEW_WORD:
return weighting->getNewWordBigramCost(traverseSession, parentDicNode, bigramCacheMap);
case CT_MATCH:
return 0.0f;
case CT_COMPLETION:
return 0.0f;
case CT_TERMINAL: {
const float languageImprobability =
DicNodeUtils::getBigramNodeImprobability(
traverseSession->getOffsetDict(), dicNode, bigramCacheMap);
return weighting->getTerminalLanguageCost(traverseSession, dicNode, languageImprobability);
}
case CT_SPACE_SUBSTITUTION:
return 0.0f;
case CT_INSERTION:
return 0.0f;
case CT_TRANSPOSITION:
return 0.0f;
default:
return 0.0f;
}
}
/* static */ bool Weighting::isEditCorrection(const CorrectionType correctionType) {
switch(correctionType) {
case CT_OMISSION:
return true;
case CT_ADDITIONAL_PROXIMITY:
// Should return true?
return false;
case CT_SUBSTITUTION:
// Should return true?
return false;
case CT_NEW_WORD:
return false;
case CT_MATCH:
return false;
case CT_COMPLETION:
return false;
case CT_TERMINAL:
return false;
case CT_SPACE_SUBSTITUTION:
return false;
case CT_INSERTION:
return true;
case CT_TRANSPOSITION:
return true;
default:
return false;
}
}
/* static */ bool Weighting::isProximityCorrection(const Weighting *const weighting,
const CorrectionType correctionType,
const DicTraverseSession *const traverseSession, const DicNode *const dicNode) {
switch(correctionType) {
case CT_OMISSION:
return false;
case CT_ADDITIONAL_PROXIMITY:
return false;
case CT_SUBSTITUTION:
return false;
case CT_NEW_WORD:
return false;
case CT_MATCH:
return weighting->isProximityDicNode(traverseSession, dicNode);
case CT_COMPLETION:
return false;
case CT_TERMINAL:
return false;
case CT_SPACE_SUBSTITUTION:
return false;
case CT_INSERTION:
return false;
case CT_TRANSPOSITION:
return false;
default:
return false;
}
}
/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
switch(correctionType) {
case CT_OMISSION:
return 0;
case CT_ADDITIONAL_PROXIMITY:
return 0;
case CT_SUBSTITUTION:
return 0;
case CT_NEW_WORD:
return 0;
case CT_MATCH:
return 1;
case CT_COMPLETION:
return 0;
case CT_TERMINAL:
return 0;
case CT_SPACE_SUBSTITUTION:
return 1;
case CT_INSERTION:
return 2;
case CT_TRANSPOSITION:
return 2;
default:
return 0;
}
}
} // namespace latinime
/*
* Copyright (C) 2013 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_WEIGHTING_H
#define LATINIME_WEIGHTING_H
#include "defines.h"
namespace latinime {
class DicNode;
class DicTraverseSession;
struct DicNode_InputStateG;
class Weighting {
public:
static void addCostAndForwardInputIndex(const Weighting *const weighting,
const CorrectionType correctionType,
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, DicNode *const dicNode,
hash_map_compat<int, int16_t> *const bigramCacheMap);
protected:
virtual float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual float getOmissionCost(
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
virtual float getMatchedCost(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
DicNode_InputStateG *inputStateG) const = 0;
virtual bool isProximityDicNode(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual float getTranspositionCost(
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode) const = 0;
virtual float getInsertionCost(
const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const = 0;
virtual float getNewWordCost(const DicNode *const dicNode) const = 0;
virtual float getNewWordBigramCost(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
hash_map_compat<int, int16_t> *const bigramCacheMap) const = 0;
virtual float getCompletionCost(
const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const = 0;
virtual float getTerminalLanguageCost(
const DicTraverseSession *const traverseSession, const DicNode *const dicNode,
float dicNodeLanguageImprobability) const = 0;
virtual bool needsToNormalizeCompoundDistance() const = 0;
virtual float getAdditionalProximityCost() const = 0;
virtual float getSubstitutionCost() const = 0;
virtual float getSpaceSubstitutionCost() const = 0;
Weighting() {}
virtual ~Weighting() {}
private:
DISALLOW_COPY_AND_ASSIGN(Weighting);
static float getSpatialCost(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode,
DicNode_InputStateG *const inputStateG);
static float getLanguageCost(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode,
hash_map_compat<int, int16_t> *const bigramCacheMap);
// TODO: Move to TypingWeighting and GestureWeighting?
static bool isEditCorrection(const CorrectionType correctionType);
// TODO: Move to TypingWeighting and GestureWeighting?
static bool isProximityCorrection(const Weighting *const weighting,
const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
const DicNode *const dicNode);
// TODO: Move to TypingWeighting and GestureWeighting?
static int getForwardInputCount(const CorrectionType correctionType);
};
} // namespace latinime
#endif // LATINIME_WEIGHTING_H
/*
* Copyright (C) 2012 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "defines.h"
#include "dictionary.h"
#include "dic_node_utils.h"
#include "dic_traverse_session.h"
#include "dic_traverse_wrapper.h"
#include "jni.h"
namespace latinime {
const int DicTraverseSession::CACHE_START_INPUT_LENGTH_THRESHOLD = 20;
// A factory method for DicTraverseSession
static void *getSessionInstance(JNIEnv *env, jstring localeStr) {
return new DicTraverseSession(env, localeStr);
}
// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down.
static void initSessionInstance(void *traverseSession, const Dictionary *const dictionary,
const int *prevWord, const int prevWordLength) {
if (traverseSession) {
DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession);
tSession->init(dictionary, prevWord, prevWordLength);
}
}
// TODO: Pass "DicTraverseSession *traverseSession" when the source code structure settles down.
static void releaseSessionInstance(void *traverseSession) {
delete static_cast<DicTraverseSession *>(traverseSession);
}
// An ad-hoc internal class to register the factory method defined above
class TraverseSessionFactoryRegisterer {
public:
TraverseSessionFactoryRegisterer() {
DicTraverseWrapper::setTraverseSessionFactoryMethod(getSessionInstance);
DicTraverseWrapper::setTraverseSessionInitMethod(initSessionInstance);
DicTraverseWrapper::setTraverseSessionReleaseMethod(releaseSessionInstance);
}
private:
DISALLOW_COPY_AND_ASSIGN(TraverseSessionFactoryRegisterer);
};
// To invoke the TraverseSessionFactoryRegisterer constructor in the global constructor.
static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
int prevWordLength) {
mDictionary = dictionary;
if (!prevWord) {
mPrevWordPos = NOT_VALID_WORD;
return;
}
mPrevWordPos = DicNodeUtils::getWordPos(dictionary->getOffsetDict(), prevWord, prevWordLength);
}
void DicTraverseSession::setupForGetSuggestions(const ProximityInfo *pInfo,
const int *inputCodePoints, const int inputSize, const int *const inputXs,
const int *const inputYs, const int *const times, const int *const pointerIds,
const float maxSpatialDistance, const int maxPointerCount) {
mProximityInfo = pInfo;
mMaxPointerCount = maxPointerCount;
initializeProximityInfoStates(inputCodePoints, inputXs, inputYs, times, pointerIds, inputSize,
maxSpatialDistance, maxPointerCount);
}
const uint8_t *DicTraverseSession::getOffsetDict() const {
return mDictionary->getOffsetDict();
}
void DicTraverseSession::resetCache(const int nextActiveCacheSize, const int maxWords) {
mDicNodesCache.reset(nextActiveCacheSize, maxWords);
mBigramCacheMap.clear();
mPartiallyCommited = false;
}
void DicTraverseSession::initializeProximityInfoStates(const int *const inputCodePoints,
const int *const inputXs, const int *const inputYs, const int *const times,
const int *const pointerIds, const int inputSize, const float maxSpatialDistance,
const int maxPointerCount) {
ASSERT(1 <= maxPointerCount && maxPointerCount <= MAX_POINTER_COUNT_G);
mInputSize = 0;
for (int i = 0; i < maxPointerCount; ++i) {
mProximityInfoStates[i].initInputParams(i, maxSpatialDistance, getProximityInfo(),
inputCodePoints, inputSize, inputXs, inputYs, times, pointerIds,
maxPointerCount == MAX_POINTER_COUNT_G
/* TODO: this is a hack. fix proximity info state */);
mInputSize += mProximityInfoStates[i].size();
}
}
} // namespace latinime
/*
* Copyright (C) 2012 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LATINIME_DIC_TRAVERSE_SESSION_H
#define LATINIME_DIC_TRAVERSE_SESSION_H
#include <stdint.h>
#include <vector>
#include "defines.h"
#include "dic_nodes_cache.h"
#include "hash_map_compat.h"
#include "jni.h"
#include "proximity_info_state.h"
namespace latinime {
class Dictionary;
class ProximityInfo;
class DicTraverseSession {
public:
AK_FORCE_INLINE DicTraverseSession(JNIEnv *env, jstring localeStr)
: mPrevWordPos(NOT_VALID_WORD), mProximityInfo(0),
mDictionary(0), mDicNodesCache(), mBigramCacheMap(),
mInputSize(0), mPartiallyCommited(false), mMaxPointerCount(1) {
// NOTE: mProximityInfoStates is an array of instances.
// No need to initialize it explicitly here.
}
// Non virtual inline destructor -- never inherit this class
AK_FORCE_INLINE ~DicTraverseSession() {}
void init(const Dictionary *dictionary, const int *prevWord, int prevWordLength);
// TODO: Remove and merge into init
void setupForGetSuggestions(const ProximityInfo *pInfo, const int *inputCodePoints,
const int inputSize, const int *const inputXs, const int *const inputYs,
const int *const times, const int *const pointerIds, const float maxSpatialDistance,
const int maxPointerCount);
void resetCache(const int nextActiveCacheSize, const int maxWords);
const uint8_t *getOffsetDict() const;
bool canUseCache() const;
//--------------------
// getters and setters
//--------------------
const ProximityInfo *getProximityInfo() const { return mProximityInfo; }
int getPrevWordPos() const { return mPrevWordPos; }
// TODO: REMOVE
void setPrevWordPos(int pos) { mPrevWordPos = pos; }
// TODO: Use proper parameter when changed
int getDicRootPos() const { return 0; }
DicNodesCache *getDicTraverseCache() { return &mDicNodesCache; }
hash_map_compat<int, int16_t> *getBigramCacheMap() { return &mBigramCacheMap; }
const ProximityInfoState *getProximityInfoState(int id) const {
return &mProximityInfoStates[id];
}
int getInputSize() const { return mInputSize; }
void setPartiallyCommited() { mPartiallyCommited = true; }
bool isPartiallyCommited() const { return mPartiallyCommited; }
bool isOnlyOnePointerUsed(int *pointerId) const {
// Not in the dictionary word
int usedPointerCount = 0;
int usedPointerId = 0;
for (int i = 0; i < mMaxPointerCount; ++i) {
if (mProximityInfoStates[i].isUsed()) {
++usedPointerCount;
usedPointerId = i;
}
}
if (usedPointerCount != 1) {
return false;
}
*pointerId = usedPointerId;
return true;
}
void getSearchKeys(const DicNode *node, std::vector<int> *const outputSearchKeyVector) const {
for (int i = 0; i < MAX_POINTER_COUNT_G; ++i) {
if (!mProximityInfoStates[i].isUsed()) {
continue;
}
const int pointerId = node->getInputIndex(i);
const std::vector<int> *const searchKeyVector =
mProximityInfoStates[i].getSearchKeyVector(pointerId);
outputSearchKeyVector->insert(outputSearchKeyVector->end(), searchKeyVector->begin(),
searchKeyVector->end());
}
}
ProximityType getProximityTypeG(const DicNode *const node, const int childCodePoint) const {
ProximityType proximityType = UNRELATED_CHAR;
for (int i = 0; i < MAX_POINTER_COUNT_G; ++i) {
if (!mProximityInfoStates[i].isUsed()) {
continue;
}
const int pointerId = node->getInputIndex(i);
proximityType = mProximityInfoStates[i].getProximityTypeG(pointerId, childCodePoint);
ASSERT(proximityType == UNRELATED_CHAR || proximityType == MATCH_CHAR);
// TODO: Make this more generic
// Currently we assume there are only two types here -- UNRELATED_CHAR
// and MATCH_CHAR
if (proximityType != UNRELATED_CHAR) {
return proximityType;
}
}
return proximityType;
}
AK_FORCE_INLINE bool isCacheBorderForTyping(const int inputSize) const {
return mDicNodesCache.isCacheBorderForTyping(inputSize);
}
/**
* Returns whether or not it is possible to continue suggestion from the previous search.
*/
// TODO: Remove. No need to check once the session is fully implemented.
bool isContinuousSuggestionPossible() const {
if (!mDicNodesCache.hasCachedDicNodesForContinuousSuggestion()) {
return false;
}
ASSERT(mMaxPointerCount < MAX_POINTER_COUNT_G);
for (int i = 0; i < mMaxPointerCount; ++i) {
const ProximityInfoState *const pInfoState = getProximityInfoState(i);
// If a proximity info state is not continuous suggestion possible,
// do not continue searching.
if (pInfoState->isUsed() && !pInfoState->isContinuousSuggestionPossible()) {
return false;
}
}
return true;
}
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(DicTraverseSession);
// threshold to start caching
static const int CACHE_START_INPUT_LENGTH_THRESHOLD;
void initializeProximityInfoStates(const int *const inputCodePoints, const int *const inputXs,
const int *const inputYs, const int *const times, const int *const pointerIds,
const int inputSize, const float maxSpatialDistance, const int maxPointerCount);
int mPrevWordPos;
const ProximityInfo *mProximityInfo;
const Dictionary *mDictionary;
DicNodesCache mDicNodesCache;
// Temporary cache for bigram frequencies
hash_map_compat<int, int16_t> mBigramCacheMap;
ProximityInfoState mProximityInfoStates[MAX_POINTER_COUNT_G];
int mInputSize;
bool mPartiallyCommited;
int mMaxPointerCount;
};
} // namespace latinime
#endif // LATINIME_DIC_TRAVERSE_SESSION_H
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment