Skip to content
Snippets Groups Projects
Commit 5064ac88 authored by Jean Chalard's avatar Jean Chalard Committed by Android (Google) Code Review
Browse files

Merge "Be careful about the dictionary size in detection methods"

parents eab27c1e 03f8c6ae
No related branches found
No related tags found
No related merge requests found
......@@ -109,7 +109,8 @@ static jlong latinime_BinaryDictionary_open(JNIEnv *env, jclass clazz, jstring s
}
Dictionary *dictionary = 0;
if (BinaryFormat::UNKNOWN_FORMAT
== BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf))) {
== BinaryFormat::detectFormat(static_cast<uint8_t *>(dictBuf),
static_cast<int>(dictSize))) {
AKLOGE("DICT: dictionary format is unknown, bad magic number");
#ifdef USE_MMAP_FOR_DICTIONARY
releaseDictBuf(static_cast<const char *>(dictBuf) - adjust, adjDictSize, fd);
......
......@@ -64,13 +64,14 @@ class BinaryFormat {
static const int UNKNOWN_FORMAT = -1;
static const int SHORTCUT_LIST_SIZE_SIZE = 2;
static int detectFormat(const uint8_t *const dict);
static int getHeaderSize(const uint8_t *const dict);
static int getFlags(const uint8_t *const dict);
static int detectFormat(const uint8_t *const dict, const int dictSize);
static int getHeaderSize(const uint8_t *const dict, const int dictSize);
static int getFlags(const uint8_t *const dict, const int dictSize);
static bool hasBlacklistedOrNotAWordFlag(const int flags);
static void readHeaderValue(const uint8_t *const dict, const char *const key, int *outValue,
const int outValueSize);
static int readHeaderValueInt(const uint8_t *const dict, const char *const key);
static void readHeaderValue(const uint8_t *const dict, const int dictSize,
const char *const key, int *outValue, const int outValueSize);
static int readHeaderValueInt(const uint8_t *const dict, const int dictSize,
const char *const key);
static int getGroupCountAndForwardPointer(const uint8_t *const dict, int *pos);
static uint8_t getFlagsAndForwardPointer(const uint8_t *const dict, int *pos);
static int getCodePointAndForwardPointer(const uint8_t *const dict, int *pos);
......@@ -96,7 +97,7 @@ class BinaryFormat {
const uint8_t *bigramFilter, const int unigramProbability);
static int getBigramProbabilityFromHashMap(const int position,
const hash_map_compat<int, int> *bigramMap, const int unigramProbability);
static float getMultiWordCostMultiplier(const uint8_t *const dict);
static float getMultiWordCostMultiplier(const uint8_t *const dict, const int dictSize);
static void fillBigramProbabilityToHashMap(const uint8_t *const root, int position,
hash_map_compat<int, int> *bigramMap);
static int getBigramProbability(const uint8_t *const root, int position,
......@@ -122,6 +123,8 @@ class BinaryFormat {
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_TWOBYTES = 0x20;
static const int FLAG_ATTRIBUTE_ADDRESS_TYPE_THREEBYTES = 0x30;
// Any file smaller than this is not a dictionary.
static const int DICTIONARY_MINIMUM_SIZE = 4;
// Originally, format version 1 had a 16-bit magic number, then the version number `01'
// then options that must be 0. Hence the first 32-bits of the format are always as follow
// and it's okay to consider them a magic number as a whole.
......@@ -131,6 +134,8 @@ class BinaryFormat {
// number, so we had to change it so that version 2 files would be rejected by older
// implementations. On this occasion, we made the magic number 32 bits long.
static const int FORMAT_VERSION_2_MAGIC_NUMBER = -1681835266; // 0x9BC13AFE
// Magic number (4 bytes), version (2 bytes), options (2 bytes), header size (4 bytes) = 12
static const int FORMAT_VERSION_2_MINIMUM_SIZE = 12;
static const int CHARACTER_ARRAY_TERMINATOR_SIZE = 1;
static const int MINIMAL_ONE_BYTE_CHARACTER_VALUE = 0x20;
......@@ -141,8 +146,11 @@ class BinaryFormat {
static int skipBigrams(const uint8_t *const dict, const uint8_t flags, const int pos);
};
AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict, const int dictSize) {
// The magic number is stored big-endian.
// If the dictionary is less than 4 bytes, we can't even read the magic number, so we don't
// understand this format.
if (dictSize < DICTIONARY_MINIMUM_SIZE) return UNKNOWN_FORMAT;
const int magicNumber = (dict[0] << 24) + (dict[1] << 16) + (dict[2] << 8) + dict[3];
switch (magicNumber) {
case FORMAT_VERSION_1_MAGIC_NUMBER:
......@@ -152,6 +160,10 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
// Options (2 bytes) must be 0x00 0x00
return 1;
case FORMAT_VERSION_2_MAGIC_NUMBER:
// Version 2 dictionaries are at least 12 bytes long (see below details for the header).
// If this dictionary has the version 2 magic number but is less than 12 bytes long, then
// it's an unknown format and we need to avoid confidently reading the next bytes.
if (dictSize < FORMAT_VERSION_2_MINIMUM_SIZE) return UNKNOWN_FORMAT;
// Format 2 header is as follows:
// Magic number (4 bytes) 0x9B 0xC1 0x3A 0xFE
// Version number (2 bytes) 0x00 0x02
......@@ -163,8 +175,8 @@ AK_FORCE_INLINE int BinaryFormat::detectFormat(const uint8_t *const dict) {
}
}
inline int BinaryFormat::getFlags(const uint8_t *const dict) {
switch (detectFormat(dict)) {
inline int BinaryFormat::getFlags(const uint8_t *const dict, const int dictSize) {
switch (detectFormat(dict, dictSize)) {
case 1:
return NO_FLAGS; // TODO: NO_FLAGS is unused anywhere else?
default:
......@@ -176,8 +188,8 @@ inline bool BinaryFormat::hasBlacklistedOrNotAWordFlag(const int flags) {
return (flags & (FLAG_IS_BLACKLISTED | FLAG_IS_NOT_A_WORD)) != 0;
}
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
switch (detectFormat(dict)) {
inline int BinaryFormat::getHeaderSize(const uint8_t *const dict, const int dictSize) {
switch (detectFormat(dict, dictSize)) {
case 1:
return FORMAT_VERSION_1_HEADER_SIZE;
case 2:
......@@ -188,12 +200,12 @@ inline int BinaryFormat::getHeaderSize(const uint8_t *const dict) {
}
}
inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char *const key,
int *outValue, const int outValueSize) {
inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const int dictSize,
const char *const key, int *outValue, const int outValueSize) {
int outValueIndex = 0;
// Only format 2 and above have header attributes as {key,value} string pairs. For prior
// formats, we just return an empty string, as if the key wasn't found.
if (2 <= detectFormat(dict)) {
if (2 <= detectFormat(dict, dictSize)) {
const int headerOptionsOffset = 4 /* magic number */
+ 2 /* dictionary version */ + 2 /* flags */;
const int headerSize =
......@@ -236,11 +248,12 @@ inline void BinaryFormat::readHeaderValue(const uint8_t *const dict, const char
if (outValueIndex >= 0) outValue[outValueIndex] = 0;
}
inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const char *const key) {
inline int BinaryFormat::readHeaderValueInt(const uint8_t *const dict, const int dictSize,
const char *const key) {
const int bufferSize = LARGEST_INT_DIGIT_COUNT;
int intBuffer[bufferSize];
char charBuffer[bufferSize];
BinaryFormat::readHeaderValue(dict, key, intBuffer, bufferSize);
BinaryFormat::readHeaderValue(dict, dictSize, key, intBuffer, bufferSize);
for (int i = 0; i < bufferSize; ++i) {
charBuffer[i] = intBuffer[i];
}
......@@ -256,8 +269,10 @@ AK_FORCE_INLINE int BinaryFormat::getGroupCountAndForwardPointer(const uint8_t *
return ((msb & 0x7F) << 8) | dict[(*pos)++];
}
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict) {
const int headerValue = readHeaderValueInt(dict, "MULTIPLE_WORDS_DEMOTION_RATE");
inline float BinaryFormat::getMultiWordCostMultiplier(const uint8_t *const dict,
const int dictSize) {
const int headerValue = readHeaderValueInt(dict, dictSize,
"MULTIPLE_WORDS_DEMOTION_RATE");
if (headerValue == S_INT_MIN) {
return 1.0f;
}
......
......@@ -34,9 +34,11 @@ namespace latinime {
Dictionary::Dictionary(void *dict, int dictSize, int mmapFd, int dictBufAdjust)
: mDict(static_cast<unsigned char *>(dict)),
mOffsetDict((static_cast<unsigned char *>(dict)) + BinaryFormat::getHeaderSize(mDict)),
mOffsetDict((static_cast<unsigned char *>(dict))
+ BinaryFormat::getHeaderSize(mDict, dictSize)),
mDictSize(dictSize), mMmapFd(mmapFd), mDictBufAdjust(dictBufAdjust),
mUnigramDictionary(new UnigramDictionary(mOffsetDict, BinaryFormat::getFlags(mDict))),
mUnigramDictionary(new UnigramDictionary(mOffsetDict,
BinaryFormat::getFlags(mDict, dictSize))),
mBigramDictionary(new BigramDictionary(mOffsetDict)),
mGestureSuggest(new Suggest(GestureSuggestPolicyFactory::getGestureSuggestPolicy())),
mTypingSuggest(new Suggest(TypingSuggestPolicyFactory::getTypingSuggestPolicy())) {
......
......@@ -64,7 +64,8 @@ static TraverseSessionFactoryRegisterer traverseSessionFactoryRegisterer;
void DicTraverseSession::init(const Dictionary *const dictionary, const int *prevWord,
int prevWordLength) {
mDictionary = dictionary;
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict());
mMultiWordCostMultiplier = BinaryFormat::getMultiWordCostMultiplier(mDictionary->getDict(),
mDictionary->getDictSize());
if (!prevWord) {
mPrevWordPos = NOT_VALID_WORD;
return;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment