Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
#pragma once
#include "suggest/core/dicnode/dic_node.h"
#include "suggest/core/session/dic_traverse_session.h"
#include "suggest/core/layout/proximity_info.h"
#include "suggest/core/policy/weighting.h"
#include "suggest/policyimpl/typing/scoring_params.h"
namespace util {
static AK_FORCE_INLINE int getDistanceBetweenPoints(const latinime::DicTraverseSession *const traverseSession, int codePoint, int index) {
auto proximityInfoState = traverseSession->getProximityInfoState(0);
auto proximityInfo = traverseSession->getProximityInfo();
int px = proximityInfoState->getInputX(index);
int py = proximityInfoState->getInputY(index);
int keyIdx = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint));
int kx = proximityInfo->getSweetSpotCenterXAt(keyIdx);
int ky = proximityInfo->getSweetSpotCenterYAt(keyIdx);
return sqrtf(latinime::GeometryUtils::getDistanceSq(px, py, kx, ky));
}
static AK_FORCE_INLINE float findMinimumPointDistance(int px, int py, int l0x, int l0y, int l1x, int l1y) {
int ax = l0x;
int ay = l0y;
int bx = l1x - l0x;
int by = l1y - l0y;
if(bx == 0 && by == 0) {
int dx = px - ax;
int dy = py - ay;
return (dx * dx + dy * dy);
}
int p_dot_b = px * bx + py * by;
int a_dot_b = ax * bx + ay * by;
int b_len_sq = bx * bx + by * by;
float t = (float)(p_dot_b - a_dot_b) / (float)b_len_sq;
if(t < 0.0f) t = 0.0f;
if(t > 1.0f) t = 1.0f;
float cx = (px - (ax + t * bx));
float cy = (py - (ay + t * by));
return sqrtf(cx * cx + cy * cy);
}
static AK_FORCE_INLINE float getDistanceLine(const latinime::DicTraverseSession *const traverseSession, int codePoint, int index0, int index1) {
auto proximityInfoState = traverseSession->getProximityInfoState(0);
auto proximityInfo = traverseSession->getProximityInfo();
int l0x = proximityInfoState->getInputX(index0);
int l0y = proximityInfoState->getInputY(index0);
int l1x = proximityInfoState->getInputX(index1);
int l1y = proximityInfoState->getInputY(index1);
int keyIdx = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint));
int px = proximityInfo->getSweetSpotCenterXAt(keyIdx);
int py = proximityInfo->getSweetSpotCenterYAt(keyIdx);
return findMinimumPointDistance(px, py, l0x, l0y, l1x, l1y);
}
static AK_FORCE_INLINE float getDistanceCodePointLine(const latinime::DicTraverseSession *const traverseSession, int codePoint0, int codePoint1, int index) {
auto proximityInfoState = traverseSession->getProximityInfoState(0);
auto proximityInfo = traverseSession->getProximityInfo();
int px = proximityInfoState->getInputX(index);
int py = proximityInfoState->getInputY(index);
int keyIdx0 = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint0));
int keyIdx1 = proximityInfo->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint1));
int l0x = proximityInfo->getSweetSpotCenterXAt(keyIdx0);
int l0y = proximityInfo->getSweetSpotCenterYAt(keyIdx0);
int l1x = proximityInfo->getSweetSpotCenterXAt(keyIdx1);
int l1y = proximityInfo->getSweetSpotCenterYAt(keyIdx1);
return findMinimumPointDistance(px, py, l0x, l0y, l1x, l1y);
}
static AK_FORCE_INLINE float pow2(float f){
return f * f;
}
static AK_FORCE_INLINE float calcLineDeviationPunishment(
const latinime::DicTraverseSession *const traverseSession,
int codePoint0, int codePoint1,
int lowerLimit, int upperLimit,
float threshold
) {
float totalDistance = 0.0;
const int ki_0 = traverseSession->getProximityInfo()->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint0));
const int ki_1 = traverseSession->getProximityInfo()->getKeyIndexOf(latinime::CharUtils::toBaseLowerCase(codePoint1));
const float l0x = traverseSession->getProximityInfo()->getSweetSpotCenterXAt(ki_0);
const float l0y = traverseSession->getProximityInfo()->getSweetSpotCenterYAt(ki_0);
const float l1x = traverseSession->getProximityInfo()->getSweetSpotCenterXAt(ki_1);
const float l1y = traverseSession->getProximityInfo()->getSweetSpotCenterYAt(ki_1);
for(int j = lowerLimit; j < upperLimit; j++) {
const float distance = getDistanceCodePointLine(traverseSession, codePoint0, codePoint1, j);
totalDistance += distance;
if(distance > threshold) {
//AKLOGI("Attention please: at %d (%d->%d) [%c->%c], distance %.2f exceeds threshold %.2f", j, lowerLimit, upperLimit, (char)codePoint0, (char)codePoint1, distance, threshold);
return MAX_VALUE_FOR_WEIGHTING;
}
if(j > 1) {
const float px = traverseSession->getProximityInfoState(0)->getInputX(j);
const float py = traverseSession->getProximityInfoState(0)->getInputY(j);
const float pxp = traverseSession->getProximityInfoState(0)->getInputX(j - 1);
const float pyp = traverseSession->getProximityInfoState(0)->getInputY(j - 1);
float swipedx = px - pxp;
float swipedy = py - pyp;
const float swipelen = sqrtf(swipedx * swipedx + swipedy * swipedy);
swipedx /= swipelen;
swipedy /= swipelen;
float linedx = l1x - l0x;
float linedy = l1y - l0y;
const float linelen = sqrtf(linedx * linedx + linedy * linedy);
linedx /= linelen;
linedy /= linelen;
const float dotDirection = swipedx * linedx + swipedy * linedy;
if (dotDirection < 0.0) {
totalDistance += swipelen * -dotDirection;
}
}
}
return totalDistance;
}
static AK_FORCE_INLINE float getThresholdBase(const latinime::DicTraverseSession *const traverseSession) {
return traverseSession->getProximityInfo()->getMostCommonKeyWidth() / 48.0f;
}
}
namespace latinime {
class SwipeWeighting : public Weighting {
public:
static const SwipeWeighting *getInstance() { return &sInstance; }
AK_FORCE_INLINE float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode,
const DicNode *const dicNode) const override {
const int codePoint = dicNode->getNodeCodePoint();
const float distanceThreshold = util::getThresholdBase(traverseSession);
const float distance = util::getDistanceBetweenPoints(traverseSession, codePoint,
traverseSession->getInputSize() - 1);
if(distance > (distanceThreshold * 128.0f)) {
//AKLOGI("Terminal spatial for %c:%c fails due to exceeding distance", (parentDicNode != nullptr) ? (char)(parentDicNode->getNodeCodePoint()) : '?', (char)codePoint);
//dicNode->dump("TERMINAL");
return MAX_VALUE_FOR_WEIGHTING;
}
float totalDistance = distance * distance;
if(parentDicNode != nullptr) {
int codePoint0;
if(parentDicNode->isZeroCostOmission() || parentDicNode->canBeIntentionalOmission()) {
codePoint0 = parentDicNode->getPrevCodePointG(0);
} else {
codePoint0 = parentDicNode->getNodeCodePoint();
}
if(codePoint0 != NOT_A_CODE_POINT) {
const int codePoint1 = codePoint;
const int lowerLimit = dicNode->getInputIndex(0);
const int upperLimit = traverseSession->getInputSize();
const float threshold = (distanceThreshold * 86.0f);
const float extraDistance = 8.0f * util::calcLineDeviationPunishment(
traverseSession, codePoint0, codePoint1, lowerLimit, upperLimit, threshold);
totalDistance += extraDistance;
} else {
totalDistance += MAX_VALUE_FOR_WEIGHTING;
}
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
//AKLOGI("Terminal spatial for %c:%c - %d:%d : extra %.2f %.2f", (char)codePoint0, (char)codePoint1, lowerLimit, upperLimit, distance, extraDistance);
//dicNode->dump("TERMINAL");
return totalDistance;
} else {
AKLOGE("Nullptr parent unexpected! for terminal");
return MAX_VALUE_FOR_WEIGHTING;
}
}
AK_FORCE_INLINE float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmission();
const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
// If the traversal omitted the first letter then the dicNode should now be on the second.
const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
float cost = MAX_VALUE_FOR_WEIGHTING;
if(isZeroCostOmission || isIntentionalOmission || isFirstLetterOmission || sameCodePoint) {
cost = 0.0f;
}
return cost;
}
AK_FORCE_INLINE float getMatchedCost(const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const override {
const int codePoint = dicNode->getNodeCodePoint();
const float distanceThreshold = util::getThresholdBase(traverseSession);
if(dicNode->isFirstLetter()) { // Add the first point (from when swiping starts)
const float distance = util::getDistanceBetweenPoints(traverseSession, codePoint, 0);
if (distance < (40.0f * distanceThreshold)) {
inputStateG->mNeedsToUpdateInputStateG = true;
inputStateG->mInputIndex = 1;
inputStateG->mRawLength = distance;
inputStateG->mPrevCodePoint = NOT_A_CODE_POINT;
return distance;
} else {
return MAX_VALUE_FOR_WEIGHTING;
}
} else if(parentDicNode != nullptr && parentDicNode->getNodeCodePoint() == codePoint) {
inputStateG->mNeedsToUpdateInputStateG = true;
inputStateG->mInputIndex = dicNode->getInputIndex(0);
inputStateG->mRawLength = 0.0f;
inputStateG->mPrevCodePoint = parentDicNode->getPrevCodePointG(0);
return 0.0f;
} else if(dicNode->isZeroCostOmission() || dicNode->canBeIntentionalOmission()) {
inputStateG->mNeedsToUpdateInputStateG = true;
inputStateG->mInputIndex = dicNode->getInputIndex(0);
inputStateG->mRawLength = 0.0f;
if(parentDicNode != nullptr) {
inputStateG->mPrevCodePoint = parentDicNode->getNodeCodePoint();
} else {
inputStateG->mPrevCodePoint = NOT_A_CODE_POINT;
}
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
return 0.0f;
} else { // Add middle points
const int inputIndex = dicNode->getInputIndex(0);
const int swipeLength = traverseSession->getInputSize();
int minEdgeIndex = -1;
float minEdgeDistance = MAX_VALUE_FOR_WEIGHTING;
bool found = false;
bool headedTowardsCharacterYet = false;
const float keyThreshold = (80.0f * distanceThreshold);
//AKLOGI("commence search for %c", (char)codePoint);
for (int i = inputIndex; i < swipeLength; i++) {
if (i == 0) continue;
const float distance = util::getDistanceLine(traverseSession, codePoint, i - 1, i);
//AKLOGI("[%c:%d] distance %.2f, min %.2f. thresh %.2f", (char)codePoint, i, distance, minEdgeDistance, keyThreshold);
if (distance < minEdgeDistance) {
if(minEdgeIndex != -1) headedTowardsCharacterYet = true;
minEdgeDistance = distance;
minEdgeIndex = i;
}
if (((distance > minEdgeDistance) || (i >= (swipeLength - 1))) && (minEdgeDistance < keyThreshold) && headedTowardsCharacterYet) {
//AKLOGI("found!");
found = true;
break;
}
}
if(found && parentDicNode != nullptr && minEdgeDistance < MAX_VALUE_FOR_WEIGHTING) {
float totalDistance = 24.0f * pow(minEdgeDistance, 1.6f);
int codePoint0;
if(parentDicNode->isZeroCostOmission() || parentDicNode->canBeIntentionalOmission()) {
codePoint0 = parentDicNode->getPrevCodePointG(0);
} else {
codePoint0 = parentDicNode->getNodeCodePoint();
}
if(codePoint0 != NOT_A_CODE_POINT) {
const int codePoint1 = codePoint;
const int lowerLimit = inputIndex;
const int upperLimit = minEdgeIndex;
const float threshold = (distanceThreshold * 86.0f);
const float punishment = util::calcLineDeviationPunishment(
traverseSession, codePoint0, codePoint1, lowerLimit, upperLimit,
threshold);
if (punishment >= MAX_VALUE_FOR_WEIGHTING) {
//AKLOGI("Culled due to too large distance (%.2f, %.2f)", totalDistance, punishment);
//dicNode->dump("CULLED");
return MAX_VALUE_FOR_WEIGHTING;
}
totalDistance += punishment;
}
inputStateG->mNeedsToUpdateInputStateG = true;
inputStateG->mInputIndex = minEdgeIndex;
inputStateG->mRawLength = totalDistance;
inputStateG->mPrevCodePoint = codePoint0;
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
return totalDistance;
} else {
//AKLOGI("Culled due to not found or nullptr parent %p %d %.2f. inputIndex is %d and swipeLength is %d", parentDicNode, found, minEdgeDistance, inputIndex, swipeLength);
//dicNode->dump("CULLED");
}
if(parentDicNode == nullptr) {
AKLOGE("Nullptr parent unexpected! for match");
}
}
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE bool isProximityDicNode(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return false;
}
AK_FORCE_INLINE float getTranspositionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE float getTransitionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
int idx = dicNode->getInputIndex(0);
if(true || idx < 0 || idx >= traverseSession->getProximityInfoState(0)->size())
return MAX_VALUE_FOR_WEIGHTING;
return 1.0f * traverseSession->getProximityInfoState(0)->getProbability(idx, NOT_AN_INDEX);
}
AK_FORCE_INLINE float getInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const parentDicNode, const DicNode *const dicNode) const override {
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE float getSpaceOmissionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, DicNode_InputStateG *const inputStateG) const override {
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::SPACE_OMISSION_COST;
}
AK_FORCE_INLINE float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const override {
return DicNodeUtils::getBigramNodeImprobability(
traverseSession->getDictionaryStructurePolicy(),
dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
}
AK_FORCE_INLINE float getCompletionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::COST_COMPLETION;
}
AK_FORCE_INLINE float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return ScoringParams::TERMINAL_INSERTION_COST;
}
AK_FORCE_INLINE float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode, float dicNodeLanguageImprobability) const override {
//return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
//return //dicNode->getSpatialDistanceForScoring() * dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
return dicNodeLanguageImprobability;
}
AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const override {
return false;
}
AK_FORCE_INLINE float getAdditionalProximityCost() const override {
return MAX_VALUE_FOR_WEIGHTING;// ScoringParams::ADDITIONAL_PROXIMITY_COST;
}
AK_FORCE_INLINE float getSubstitutionCost() const override {
return MAX_VALUE_FOR_WEIGHTING;
}
AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
const DicNode *const dicNode) const override {
return 1.5f;
}
AK_FORCE_INLINE ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
const DicTraverseSession *const traverseSession, const DicNode *const parentDicNode,
const DicNode *const dicNode) const override {
return ErrorTypeUtils::PROXIMITY_CORRECTION;
}
private:
DISALLOW_COPY_AND_ASSIGN(SwipeWeighting);
static const SwipeWeighting sInstance;
SwipeWeighting() {}
~SwipeWeighting() {}
};
};