diff --git a/java/src/org/futo/inputmethod/latin/xlm/BatchInputConverter.kt b/java/src/org/futo/inputmethod/latin/xlm/BatchInputConverter.kt index 19416fe42ed5fc6bcd4c3e714539a098ddd51f57..90e04af655f7a6cf4ab9681caf8fd7e53c247d66 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/BatchInputConverter.kt +++ b/java/src/org/futo/inputmethod/latin/xlm/BatchInputConverter.kt @@ -46,7 +46,7 @@ object BatchInputConverter { val dot = dot(directionFromLastCoord, directionFromNextCoord) // TODO: Figure out a good threshold - if(dot < 0.95) { + if(dot < 0.86) { val key = keyDetector.detectHitKey(coords[i].first, coords[i].second)?.label ?: continue if(s.isNotEmpty() && s.last() == key.first()) continue diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java index 38125dfcc7fffd6d8836080a413f2a95d9f8dd03..5b9aef40bd4487402f1eda219b5e45f1c5c86f62 100644 --- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java +++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModel.java @@ -47,7 +47,7 @@ public class LanguageModel { } if(mNativeState == 0){ - throw new RuntimeException("Failed to load R.raw.ml4_1_f16, R.raw.ml3_tokenizer model"); + throw new RuntimeException("Failed to load models " + modelPath); } } }; @@ -102,7 +102,9 @@ public class LanguageModel { int[] xCoords; int[] yCoords; + int inputMode = 0; if(isGesture) { + inputMode = 1; List<Integer> xCoordsList = new ArrayList<>(); List<Integer> yCoordsList = new ArrayList<>(); // Partial word is gonna be derived from batch data @@ -170,7 +172,7 @@ public class LanguageModel { String[] outStrings = new String[maxResults]; // TOOD: Pass multiple previous words information for n-gram. - getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities); + getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, inputMode, xCoords, yCoords, outStrings, outProbabilities); final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>(); @@ -262,6 +264,7 @@ public class LanguageModel { long proximityInfoHandle, String context, String partialWord, + int inputMode, int[] inComposeX, int[] inComposeY, diff --git a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp index d3f7f6de27a177e9bc2404783c91e45fc87e4ebb..3a7ba3b13706320880ca2c4653758d3c5fa34078 100644 --- a/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp +++ b/native/jni/org_futo_inputmethod_latin_xlm_LanguageModel.cpp @@ -75,6 +75,8 @@ static void softmax(float * input, size_t input_len) { #define NUM_TOKEN_MIX 4 struct TokenMix { + float x; + float y; struct { float weight; llama_token token; @@ -99,6 +101,8 @@ struct LanguageModelState { int XBC; int XEC; + int XC0_SWIPE_MODE; + int LETTERS_TO_IDS[26]; } specialTokens; @@ -132,6 +136,7 @@ struct LanguageModelState { specialTokens.XBU = model->tokenToId("<XBU>"); specialTokens.XBC = model->tokenToId("<XBC>"); specialTokens.XEC = model->tokenToId("<XEC>"); + specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>"); specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>"); ASSERT(specialTokens.XBU != 0); @@ -173,22 +178,8 @@ struct LanguageModelState { TIME_START(GetcachedMixAmount) int i = 0; for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) { - bool flagged = false; - for(int m = 0; m < NUM_TOKEN_MIX; m++) { - if(std::abs(past_mixes[i].mixes[m].weight - 
mixes[i].mixes[m].weight) >= EPS){ - flagged = true; - break; - } - } - if(flagged) break; - - for(int m = 0; m < NUM_TOKEN_MIX; m++) { - if(past_mixes[i].mixes[m].weight >= EPS && past_mixes[i].mixes[m].token != mixes[i].mixes[m].token){ - flagged = true; - break; - } - } - if(flagged) break; + if(std::abs(past_mixes[i].x - mixes[i].x) >= EPS) break; + if(std::abs(past_mixes[i].y - mixes[i].y) >= EPS) break; } TIME_END(GetcachedMixAmount) @@ -200,6 +191,7 @@ struct LanguageModelState { TIME_START(PromptDecode) llama_context *ctx = ((LlamaAdapter *) model->adapter)->context; llama_batch batch = ((LlamaAdapter *) model->adapter)->batch; + LlamaAdapter *llamaAdapter = ((LlamaAdapter *)model->adapter); size_t n_embd = llama_n_embd(llama_get_model(ctx)); size_t n_vocab = llama_n_vocab(llama_get_model(ctx)); @@ -240,22 +232,41 @@ struct LanguageModelState { std::vector<float> embeds; + bool useEncoder = !llamaAdapter->encoder_weight.empty(); + AKLOGI("DecodePromptAndMixes: useEncoder=%d", useEncoder); + for(auto &mix : mixes) { int num_added = 0; std::vector<float> mix_f(n_embd, 0.0f); - for(auto &t : mix.mixes) { - if(t.weight < EPS) break; - float *src = ((LlamaAdapter *)model->adapter)->embeddings.data() + (t.token * n_embd); - float weight = t.weight; + if(useEncoder) { + num_added = 1; - for(size_t i = 0; i < n_embd; i++){ - mix_f[i] += src[i] * weight; + for(size_t i=0; i<n_embd; i++) { + mix_f[i] = llamaAdapter->encoder_bias[i] + + llamaAdapter->encoder_weight[i*2]*mix.x + + llamaAdapter->encoder_weight[i*2 + 1]*mix.y; } - num_added++; + //AKLOGI("DEBUG: pos %.4f %.4f got this: [%.4f %.4f %.4f %.4f %.4f %.4f %.4f ...", + // mix.x, mix.y, + // mix_f[0], mix_f[1], mix_f[2], mix_f[3], mix_f[4], mix_f[5], mix_f[6]); + } else { + for (auto &t: mix.mixes) { + if (t.weight < EPS) break; + + float *src = ((LlamaAdapter *) model->adapter)->embeddings.data() + + (t.token * n_embd); + float weight = t.weight; + + for (size_t i = 0; i < n_embd; i++) { + mix_f[i] += src[i] * weight; + } + + num_added++; + } } if(num_added == 0){ @@ -290,6 +301,10 @@ struct LanguageModelState { batch.n_seq_id, batch.seq_id, batch.logits, + + batch.all_pos_0, + batch.all_pos_1, + batch.all_seq_id }; batch.pos[0] = prompt.size() + h; @@ -386,7 +401,7 @@ struct LanguageModelState { llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, decodeResult.size); } - std::vector<potential_sequence> next_sequences; + std::vector<potential_sequence> next_sequences; std::vector<std::pair<float, token_sequence>> outputs; @@ -543,7 +558,7 @@ struct LanguageModelState { return str_results; } - std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes) { + std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, std::string &word, const std::vector<TokenMix> &mixes, bool swipe_mode) { token_sequence next_context; if(context.length() != 0) { next_context = model->tokenize(trim(context) + " "); @@ -552,6 +567,10 @@ struct LanguageModelState { next_context.insert(next_context.begin(), 1); // BOS next_context.push_back(specialTokens.XBU); + if(swipe_mode) { + next_context.push_back(specialTokens.XC0_SWIPE_MODE); + } + auto decoding_result = DecodePromptAndMixes(next_context, mixes); auto results = Sample(decoding_result, 3); @@ -598,6 +617,7 @@ namespace latinime { jlong proximityInfo, jstring context, jstring partialWord, + jint inputMode, jintArray inComposeX, jintArray inComposeY, @@ -608,10 +628,8 @@ namespace latinime 
{ LanguageModelState *state = reinterpret_cast<LanguageModelState *>(dict); ProximityInfo *pInfo = reinterpret_cast<ProximityInfo *>(proximityInfo); - size_t inputSize = env->GetArrayLength(inComposeX); - const char* cstr = env->GetStringUTFChars(context, nullptr); std::string contextString(cstr); env->ReleaseStringUTFChars(context, cstr); @@ -679,13 +697,17 @@ namespace latinime { index_value[j].first /= total_sum; } - AKLOGI("%d | Char %c, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i], + TokenMix results; + results.x = ((float)xCoordinates[i]) / ((float)pInfo->getKeyboardWidth()); + results.y = ((float)yCoordinates[i]) / ((float)pInfo->getKeyboardHeight()); + + AKLOGI("%d | Char %c, pos %.6f %.6f, nearest is %c at %.2f, then %c at %.2f, finally %c at %.2f", i, partialWordString[i], + results.x, results.y, (char)(pInfo->getKeyCodePoint(index_value[0].second)), (float)(index_value[0].first), (char)(pInfo->getKeyCodePoint(index_value[1].second)), (float)(index_value[1].first), (char)(pInfo->getKeyCodePoint(index_value[2].second)), (float)(index_value[2].first) ); - TokenMix results; for(int j=0; j<NUM_TOKEN_MIX; j++) { char c = (char) (pInfo->getKeyCodePoint(index_value[j].second)); @@ -719,7 +741,8 @@ namespace latinime { //} } else { isAutoCorrect = true; - results = state->PredictCorrection(contextString, partialWordString, mixes); + bool swipeMode = inputMode == 1; + results = state->PredictCorrection(contextString, partialWordString, mixes, swipeMode); //for(const auto &result : results) { // AKLOGI("LanguageModel correction %.2f [%s] -> [%s]", result.first, partialWordString.c_str(), result.second.c_str()); @@ -755,7 +778,7 @@ namespace latinime { }, { const_cast<char *>("getSuggestionsNative"), - const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;[I[I[Ljava/lang/String;[F)V"), + const_cast<char *>("(JJLjava/lang/String;Ljava/lang/String;I[I[I[Ljava/lang/String;[F)V"), reinterpret_cast<void *>(xlm_LanguageModel_getSuggestions) } }; diff --git a/native/jni/src/ggml/LanguageModel.cpp b/native/jni/src/ggml/LanguageModel.cpp index ea2da3948e06c9dee172083f7e7def2a1168abd6..b5167c8269cce44fd1fdf423ac1beb9fb2b60213 100644 --- a/native/jni/src/ggml/LanguageModel.cpp +++ b/native/jni/src/ggml/LanguageModel.cpp @@ -85,7 +85,44 @@ LanguageModel *LlamaAdapter::createLanguageModel(const std::string &paths) { auto tensor = llama_get_model_tensor(adapter->model, "token_embd.weight"); assert(tensor); - ggml_internal_get_type_traits(tensor->type).to_float(tensor->data, adapter->embeddings.data(), adapter->embeddings.size()); + + if(tensor->type != GGML_TYPE_F32) { + ggml_internal_get_type_traits(tensor->type).to_float(tensor->data, + adapter->embeddings.data(), + adapter->embeddings.size()); + } else { + ASSERT((tensor->ne[0] * tensor->ne[1]) == adapter->embeddings.size()); + memcpy(adapter->embeddings.data(), tensor->data, adapter->embeddings.size() * sizeof(float)); + } + + auto encoder_weight_tensor = llama_get_model_tensor(adapter->model, "encoder.weight"); + auto encoder_bias_tensor = llama_get_model_tensor(adapter->model, "encoder.bias"); + if(encoder_weight_tensor && encoder_bias_tensor) { + adapter->encoder_weight.resize(llama_n_embd(adapter->model) * 2); + adapter->encoder_bias.resize(llama_n_embd(adapter->model)); + + if(encoder_weight_tensor->type != GGML_TYPE_F32) { + ggml_internal_get_type_traits(encoder_weight_tensor->type).to_float( + encoder_weight_tensor->data, + adapter->encoder_weight.data(), + adapter->encoder_weight.size() + ); + } 
else { + ASSERT((encoder_weight_tensor->ne[0] * encoder_weight_tensor->ne[1]) == adapter->encoder_weight.size()); + memcpy(adapter->encoder_weight.data(), encoder_weight_tensor->data, adapter->encoder_weight.size() * sizeof(float)); + } + + if(encoder_bias_tensor->type != GGML_TYPE_F32) { + ggml_internal_get_type_traits(encoder_bias_tensor->type).to_float( + encoder_bias_tensor->data, + adapter->encoder_bias.data(), + adapter->encoder_bias.size() + ); + } else { + ASSERT(encoder_bias_tensor->ne[0] == adapter->encoder_bias.size()); + memcpy(adapter->encoder_bias.data(), encoder_bias_tensor->data, adapter->encoder_bias.size() * sizeof(float)); + } + } return new LanguageModel(adapter); } diff --git a/native/jni/src/ggml/LanguageModel.h b/native/jni/src/ggml/LanguageModel.h index 6ca4c62a4722f178a424e5dda7fc7c37abc73414..81ac72cbdfd62edffcceff6f01f757f48e99e445 100644 --- a/native/jni/src/ggml/LanguageModel.h +++ b/native/jni/src/ggml/LanguageModel.h @@ -138,6 +138,10 @@ public: llama_batch batch; std::vector<float> embeddings; + + std::vector<float> encoder_weight = {}; + std::vector<float> encoder_bias = {}; + private: LlamaAdapter(); diff --git a/native/jni/src/ggml/llama.cpp b/native/jni/src/ggml/llama.cpp index 9b9ed24465c2d30e09be0cb9f433eb8b41de465a..1cdd8a296b97b4f518606701b015b6658fb88dcf 100644 --- a/native/jni/src/ggml/llama.cpp +++ b/native/jni/src/ggml/llama.cpp @@ -1368,6 +1368,9 @@ struct llama_model { llama_hparams hparams = {}; llama_vocab vocab; + struct ggml_tensor * pos_encoder; + struct ggml_tensor * pos_encoder_b; + struct ggml_tensor * tok_embd; struct ggml_tensor * pos_embd; struct ggml_tensor * tok_norm; @@ -2715,6 +2718,14 @@ static void llm_load_tensors( case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: { + if (strcmp(ml.get_tensor_name(0), "encoder.bias") == 0) { + model.pos_encoder_b = ml.create_tensor(ctx, "encoder.bias", {n_embd}, GGML_BACKEND_CPU); + model.pos_encoder = ml.create_tensor(ctx, "encoder.weight", {2, n_embd}, GGML_BACKEND_CPU); + } else { + model.pos_encoder_b = nullptr; + model.pos_encoder = nullptr; + } + model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); // output
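Note on the coordinate encoder wired in above: when the model file ships encoder.weight (shape {2, n_embd}) and encoder.bias ({n_embd}) tensors, DecodePromptAndMixes skips the weighted nearest-key token-embedding mix and instead embeds each swipe point as an affine function of its normalized keyboard position, e = W*(x, y) + b, with W stored row-major so that row i holds the x and y weights for embedding dimension i. A minimal standalone sketch of that math, under the same layout assumptions; embed_point and the toy values in main are illustrative, not part of the patch:

#include <cstddef>
#include <cstdio>
#include <vector>

// Sketch of the affine position encoder applied per swipe point.
// Assumptions: w is the dequantized encoder.weight ({2, n_embd}, row-major,
// so w[i*2] and w[i*2 + 1] are the x and y weights for dimension i), b is
// encoder.bias, and x/y are key coordinates already divided by the keyboard
// width/height, matching TokenMix::x and TokenMix::y in the patch.
static std::vector<float> embed_point(const std::vector<float> &w,
                                      const std::vector<float> &b,
                                      float x, float y) {
    const size_t n_embd = b.size();
    std::vector<float> out(n_embd);
    for (size_t i = 0; i < n_embd; i++) {
        out[i] = b[i] + w[i * 2] * x + w[i * 2 + 1] * y;  // e = W*(x, y) + b
    }
    return out;
}

int main() {
    // Toy n_embd = 2 example: identity-like weights, zero bias.
    std::vector<float> w = {1.0f, 0.0f, 0.0f, 1.0f};
    std::vector<float> b = {0.0f, 0.0f};
    std::vector<float> e = embed_point(w, b, 0.25f, 0.75f);
    std::printf("%.2f %.2f\n", e[0], e[1]);  // prints: 0.25 0.75
}

Storing the raw normalized (x, y) on TokenMix is also what lets the KV-cache prefix check in GetCachedMixAmount reduce to two float comparisons per point, rather than comparing every weight and token of the old NUM_TOKEN_MIX blend.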