#define LOG_TAG "LatinIME: jni: LanguageModel"
#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
#include <cstring> // for memset()
#include <vector>
#include "jni.h"
#include "jni_common.h"
#include "ggml/LanguageModel.h"
#include "defines.h"
#include "suggest/core/layout/proximity_info.h"
#define EPS 0.0001

#if 0 // set to 1 to enable timing instrumentation
#define TIME_START(name) const int64_t start_##name = ggml_time_us();
#define TIME_END(name) const int64_t end_##name = ggml_time_us(); \
const int64_t time_taken_##name = (end_##name - start_##name) / 1000L; \
AKLOGI("%s: Time taken by %s: %d ms\n", __func__, #name, (int)time_taken_##name);
#else
#define TIME_START(name)
#define TIME_END(name)
#endif
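// Usage: TIME_START(Foo) ... TIME_END(Foo) logs the elapsed milliseconds for
// the enclosed code when the timing #if above is enabled.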
#define RETURNVAL_AUTOCORRECT "autocorrect"
#define RETURNVAL_UNCERTAIN "uncertain"
#define RETURNVAL_CLUELESS "clueless"
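// Strip leading and trailing ASCII whitespace from a string.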
static std::string trim(const std::string &s) {
auto start = s.begin();
while (start != s.end() && std::isspace(*start)) {
start++;
}
auto end = s.end();
do {
end--;
} while (std::distance(start, end) > 0 && std::isspace(*end));
return {start, end + 1};
}
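// Helpers for ranking (probability, payload) pairs from most to least likely.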
template<typename T>
bool sortProbabilityPairDescending(const std::pair<float, T>& a, const std::pair<float, T>& b) {
return a.first > b.first;
}
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec) {
std::sort(vec.begin(), vec.end(), sortProbabilityPairDescending<T>);
}
template<typename T>
static inline void sortProbabilityPairVectorDescending(std::vector<std::pair<float, T>> &vec, size_t partial) {
if(partial > vec.size()) partial = vec.size();
std::partial_sort(vec.begin(), vec.begin() + partial, vec.end(), sortProbabilityPairDescending<T>);
}
typedef struct potential_sequence_data {
    token_sequence tokens;
    llama_seq_id seq_id;
} potential_sequence_data;
// P = P(tokens[0]) * P(tokens[1]) * [...]
typedef std::pair<float, potential_sequence_data> potential_sequence;
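// A banned word stored as its token sequence plus a cheap additive hash; the
// hash lets MatchesBanned() reject most candidates without a full comparison.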
typedef struct banned_sequence {
    token_sequence sequence;
    int hash;
} banned_sequence;
int compute_sequence_hash(const token_sequence &seq) {
int hash = 0;
for(llama_token t : seq) {
hash = (hash + t) % 999999999;
}
return hash;
}
int append_sequence_hash(int hash, llama_token t) {
return (hash + t) % 999999999;
}
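// In-place numerically stable softmax: subtracting the maximum logit before
// exponentiating (the log-sum-exp trick) avoids overflow in expf().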
static void softmax(float * input, size_t input_len) {
float m = -INFINITY;
for (size_t i = 0; i < input_len; i++) {
if (input[i] > m) {
m = input[i];
}
}
float sum = 0.0;
for (size_t i = 0; i < input_len; i++) {
sum += expf(input[i] - m);
}
float offset = m + logf(sum);
for (size_t i = 0; i < input_len; i++) {
input[i] = expf(input[i] - offset);
}
}
#define NUM_TOKEN_MIX 4
struct TokenMix {
    float x;
    float y;

    struct {
float weight;
llama_token token;
} mixes[NUM_TOKEN_MIX];
};
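// Result of a decode pass: which batch row to read logits from, and the total
// number of positions now occupied in the KV cache.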
struct DecodeResult {
int logits_head;
int size;
};
enum WordCapitalizeMode {
IgnoredCapitals, // partialWord = "t" or partialWord = "test"
FirstCapital, // partialWord = "T" or partialWord = "Test"
AllCapitals // partialWord = "TE" or partialWord = "TEST"
};
bool isLowercase(unsigned char c, bool strict) {
if(strict) {
return islower(c);
} else {
return isupper(c) == 0;
}
}
bool isFirstCharLowercase(const char* str, bool strict) {
if (str == nullptr || str[0] == '\0')
return false;
    return isLowercase(static_cast<unsigned char>(str[0]), strict);
}
bool hasLowercase(const char* str, bool strict) {
if (str == nullptr)
return false;
for (; *str != '\0'; ++str) {
if (isLowercase(static_cast<unsigned char>(*str), strict))
return true;
}
return false;
}
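// Compare two words for equality, ignoring apostrophes, dashes and spaces.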
bool isExactMatch(const std::string &a, const std::string &b){
auto preprocess = [](const std::string &str) -> std::string {
std::string result;
        for(char c : str) {
            if(c != '\'' && c != '-' && c != ' ') {
                result += c;
            }
        }
return result;
};
return preprocess(a) == preprocess(b);
}
bool isTokenMixRoughlyEqual(const TokenMix &a, const TokenMix &b) {
return (a.mixes[0].token == b.mixes[0].token) && std::abs(a.mixes[0].weight - b.mixes[0].weight) < EPS &&
(a.mixes[1].token == b.mixes[1].token) && std::abs(a.mixes[1].weight - b.mixes[1].weight) < EPS &&
(a.mixes[2].token == b.mixes[2].token) && std::abs(a.mixes[2].weight - b.mixes[2].weight) < EPS &&
(a.mixes[3].token == b.mixes[3].token) && std::abs(a.mixes[3].weight - b.mixes[3].weight) < EPS;
}
struct LanguageModelState {
    std::unique_ptr<LanguageModel> model;

    struct {
        int SPACE = 0;
        int DASH = 0;
        int STAR = 0;
        int XBU = 0;
        int XBC = 0;
        int XEC = 0;
        int XC0_SWIPE_MODE = 0;
        int LETTERS_TO_IDS[26] = { 0 };

        std::vector<int> banned_start_of_word_tokens;
        std::vector<int> banned_tokens_for_first_capital;
        std::vector<int> banned_tokens_for_all_capitals;
        std::vector<int> banned_tokens_word_separators; // probabilities add to space token
        std::vector<int> general_banned_tokens;
    } specialTokens;
bool Initialize(const std::string &paths){
model = std::unique_ptr<LanguageModel>(LlamaAdapter::createLanguageModel(paths));
if(!model) {
AKLOGE("GGMLDict: Could not load model");
return false;
}
specialTokens.SPACE = model->tokenToId("▁"); // ▁
specialTokens.DASH = model->tokenToId("-");
specialTokens.STAR = model->tokenToId("*");
if(model->adapter->hasFeature(FEATURE_AUTOCORRECT)) {
specialTokens.XBU = model->tokenToId("<XBU>");
specialTokens.XBC = model->tokenToId("<XBC>");
specialTokens.XEC = model->tokenToId("<XEC>");
specialTokens.LETTERS_TO_IDS[0] = model->tokenToId("<CHAR_A>");
ASSERT(specialTokens.XBU != 0);
ASSERT(specialTokens.XBC != 0);
ASSERT(specialTokens.XEC != 0);
ASSERT(specialTokens.LETTERS_TO_IDS[0] != 0);
for(int i = 1; i < 26; i++) {
specialTokens.LETTERS_TO_IDS[i] = specialTokens.LETTERS_TO_IDS[0] + i;
}
if(model->adapter->hasFeature(FEATURE_SWIPE_TYPING)) {
specialTokens.XC0_SWIPE_MODE = model->tokenToId("<XC0>");
ASSERT(specialTokens.XC0_SWIPE_MODE != 0);
}
} else {
specialTokens.XBU = -1;
specialTokens.XBC = -1;
specialTokens.XEC = -1;
}
specialTokens.banned_tokens_word_separators = { };
specialTokens.general_banned_tokens = { model->tokenToId("-▁") };
//int permitted_period_token = model->tokenToId(".");
const char *blacklist_symbols = ".!@#$%^&*()_=?/,\\][{};:\"><+`~|\r\n\t\x0b\x0c";
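    // Scan the vocabulary once and precompute ban lists: tokens containing
    // symbol characters, lowercase tokens (banned when the user typed a
    // leading capital), and tokens that may not begin a word.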
for(int i = 0; i < model->getVocabSize(); i++) {
//if(i == permitted_period_token) continue;
const char *token = model->getToken(i);
bool has_symbol = false;
for(char c : std::string(token)){
if(strchr(blacklist_symbols, c) != nullptr) {
has_symbol = true;
break;
}
}
        if(has_symbol) {
            specialTokens.banned_tokens_word_separators.emplace_back(i);
        }
    }

    size_t n_vocab = llama_n_vocab(model->model());
    for(int i = 0; i < (int)n_vocab; i++) {
        const char *text = model->adapter->getToken(i);
if(isFirstCharLowercase(text, true)) {
specialTokens.banned_tokens_for_first_capital.push_back(i);
specialTokens.banned_tokens_for_all_capitals.push_back(i);
}else if(hasLowercase(text, true)){
specialTokens.banned_tokens_for_all_capitals.push_back(i);
}
if(text[0] == '\'' || text[0] == '-') {
specialTokens.banned_start_of_word_tokens.push_back(i);
}
    }

    return true;
}
bool transform_logits(float *logits, size_t n_vocab, bool is_first_token, bool allow_correction_token, WordCapitalizeMode capitals, llama_token prev_token){
for(size_t i = 0; i < n_vocab; i++) {
if(isnan(logits[i])){
return false;
}
}
softmax(logits, n_vocab);
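    // Fold the probability mass of banned word-separator tokens into the space
    // token, then push the separators themselves far below every real candidate.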
for(int x : specialTokens.banned_tokens_word_separators) {
if(allow_correction_token && x == specialTokens.XEC) continue;
logits[specialTokens.SPACE] += std::max(0.0f, logits[x]);
logits[x] = -999.0f;
}
    if(is_first_token) {
        logits[specialTokens.SPACE] = -999.0f;
for(int i : specialTokens.banned_start_of_word_tokens) {
logits[i] = -999.0f;
}
}
for(int i : specialTokens.general_banned_tokens) {
logits[i] = -999.0f;
}
if(prev_token == specialTokens.DASH) {
logits[specialTokens.DASH] = -999.0f;
}
if(capitals == WordCapitalizeMode::FirstCapital && is_first_token) {
for(int i : specialTokens.banned_tokens_for_first_capital) {
logits[i] = -999.0f;
}
}else if(capitals == WordCapitalizeMode::AllCapitals) {
// Note: In case the word is something like "AMD's" we may not wish to ban lowercase completely
for(int i : specialTokens.banned_tokens_for_all_capitals) {
logits[i] = -999.0f;
        }
    }

    return true;
}
std::vector<TokenMix> past_mixes = { };
int GetCachedMixAmount(const std::vector<TokenMix> &mixes) {
TIME_START(GetcachedMixAmount)
    size_t i = 0;
    for(i = 0; i < std::min(past_mixes.size(), mixes.size()); i++) {
if(!isTokenMixRoughlyEqual(past_mixes[i], mixes[i])) break;
    }
    TIME_END(GetcachedMixAmount)
    return (int)i;
}
DecodeResult DecodePromptAndMixes(const token_sequence &prompt, const std::vector<TokenMix> &mixes) {
TIME_START(PromptDecode)
llama_context *ctx = model->context();
llama_batch batch = model->adapter->batch;
LlamaAdapter *llamaAdapter = model->adapter.get();
size_t n_embd = llama_n_embd(llama_get_model(ctx));
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
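    // Fast-forward against the cached transformer state: only the suffix of the
    // prompt that differs from the previous call needs to be decoded again.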
auto prompt_ff = transformer_context_fastforward(model->transformerContext, prompt, !mixes.empty());
int n_batch = llamaAdapter->n_batch;
int head = -1;
if(!prompt_ff.first.empty()) {
for (size_t b = 0; b < (prompt_ff.first.size() + n_batch - 1) / n_batch; b++) {
batch.n_tokens = std::min((int)n_batch, (int)(prompt_ff.first.size() - b*n_batch));
for (int i = 0; i < batch.n_tokens; i++) {
batch.token[i] = prompt_ff.first[n_batch*b + i];
batch.pos[i] = (llama_pos)(prompt_ff.second + n_batch*b + i);
batch.seq_id[i][0] = 0;
batch.n_seq_id[i] = 1;
batch.logits[i] = false;
}
batch.logits[batch.n_tokens - 1] = (int8_t)(mixes.empty());
if(mixes.empty()) head = batch.n_tokens - 1;
llama_kv_cache_seq_rm(ctx, 0, (llama_pos)prompt_ff.second, -1);
if (llama_decode(ctx, batch) != 0) {
AKLOGE("llama_decode() failed");
return {};
            }
        }
    } else {
        //AKLOGI("No need to recompute prompt, proceeding to mixes");
    }
transformer_context_apply(model->transformerContext, prompt_ff);
TIME_END(PromptDecode)
    TIME_START(EmbedMixing)
    size_t size = prompt.size();
    std::vector<float> embeds;
bool useEncoder = !llamaAdapter->encoder_weight.empty();
//AKLOGI("DecodePromptAndMixes: useEncoder=%d", useEncoder);
for(auto &mix : mixes) {
int num_added = 0;
std::vector<float> mix_f(n_embd, 0.0f);
if(useEncoder && mix.x >= 0.0f && mix.y >= 0.0f) {
for(size_t i=0; i<n_embd; i++) {
mix_f[i] = llamaAdapter->encoder_bias[i]
+ llamaAdapter->encoder_weight[i*2]*mix.x
                    + llamaAdapter->encoder_weight[i*2 + 1]*mix.y;
            }
            num_added = 1;

            //AKLOGI("DEBUG: pos %.4f %.4f got this: [%.4f %.4f %.4f %.4f %.4f %.4f %.4f ...",
            //       mix.x, mix.y,
            //       mix_f[0], mix_f[1], mix_f[2], mix_f[3], mix_f[4], mix_f[5], mix_f[6]);
        } else {
for (auto &t: mix.mixes) {
if (t.weight < EPS) continue;
if (t.token < 0 || t.token >= (int)n_vocab) continue;
float *src = llamaAdapter->embeddings.data() +
(t.token * n_embd);
float weight = t.weight;
for (size_t i = 0; i < n_embd; i++) {
mix_f[i] += src[i] * weight;
}
num_added++;
}
}
if(num_added == 0){
AKLOGE("Somehow a token mix had 0 weight for everything");
ASSERT(false);
}
embeds.insert(embeds.end(), mix_f.begin(), mix_f.end());
size++;
}
TIME_END(EmbedMixing)
TIME_START(CachedMixAmount)
int n_tokens = int32_t(mixes.size());
int n_past = GetCachedMixAmount(mixes);
past_mixes = mixes;
if(!prompt_ff.first.empty()) n_past = 0; // We have to recompute embeds completely if prompt changed
llama_kv_cache_seq_rm(ctx, 0, (llama_pos)prompt.size() + n_past, -1);
TIME_END(CachedMixAmount)
if(!embeds.empty()) {
TIME_START(DecodeEmbeds)
// TODO: This is only processing one embd at a time, increasing n_tokens doesn't seem to work
for(int h = n_past; h < n_tokens; h++ ) {
llama_batch embd_batch = {
1,
nullptr,
embeds.data() + h*n_embd,
batch.pos,
batch.n_seq_id,
batch.seq_id,
batch.logits,
batch.all_pos_0,
batch.all_pos_1,
                batch.all_seq_id
            };
batch.pos[0] = (llama_pos)(prompt.size() + h);
batch.seq_id[0][0] = 0;
batch.n_seq_id[0] = 1;
batch.logits[0] = false;
if (llama_decode(ctx, embd_batch) != 0) {
AKLOGE("llama_decode() with embeds failed");
return {};
}
}
TIME_END(DecodeEmbeds)
TIME_START(DecodeXBC)
// We always force an XBC token after
size += 1;
batch.n_tokens = 1;
batch.token[0] = specialTokens.XBC;
batch.seq_id[0][0] = 0;
batch.n_seq_id[0] = 1;
batch.logits[0] = true;
batch.pos[0] = (llama_pos)(prompt.size() + n_tokens);
head = 0;
if (llama_decode(ctx, batch) != 0) {
AKLOGE("llama_decode() for XBC failed");
return {};
}
TIME_END(DecodeXBC)
ASSERT(size == prompt.size() + n_tokens + 1);
        ASSERT(size == prompt.size() + (embeds.size() / n_embd) + 1);
    } else {
        ASSERT(size == prompt.size());
        //ASSERT(head == prompt_ff.first.size() - 1);
    }
//AKLOGI("-- Decode");
//AKLOGI("First we processed the prompt (%d):", prompt_ff.first.size());
//for(auto t : prompt) {
// AKLOGI(" - [%s]", model->getToken(t));
//}
//AKLOGI("Then %d embeds (cached %d)", embeds.size(), n_past);
//AKLOGI("The final size is %d and head is %d", size, head);
    TIME_START(FinishRm)
    llama_kv_cache_seq_rm(ctx, 0, (llama_pos)size, -1);
    TIME_END(FinishRm)
    return {
        head,
        (int)size
    };
}
bool MatchesBanned(const token_sequence &prior, int prior_hash, llama_token next, const std::vector<banned_sequence> &banned_sequences) const {
int new_hash = append_sequence_hash(prior_hash, next);
for(const auto &banned_sequence : banned_sequences) {
if(banned_sequence.sequence.back() == specialTokens.STAR && (prior.size() >= banned_sequence.sequence.size() - 1)) {
bool matches = true;
for(size_t i = 0; i < banned_sequence.sequence.size() - 1; i++) {
if(prior[i] != banned_sequence.sequence[i]) {
matches = false;
break;
}
}
if(matches){
auto priorTxt = model->decode(prior);
auto nextTxt = model->decode({next});
auto bannedTxt = model->decode(banned_sequence.sequence);
//AKLOGI("Tokens [%s] + [%s] matches banned wildcard [%s]", priorTxt.c_str(), nextTxt.c_str(), bannedTxt.c_str());
return true;
}
}else if((banned_sequence.sequence.size() == prior.size() + 1) && (banned_sequence.hash == new_hash)) {
if(banned_sequence.sequence.back() == next) {
bool matches = true;
for(size_t i = 0; i < prior.size(); i++) {
if(prior[i] != banned_sequence.sequence[i]) {
matches = false;
break;
}
}
if(matches) {
auto priorTxt = model->decode(prior);
auto nextTxt = model->decode({next});
auto bannedTxt = model->decode(banned_sequence.sequence);
//AKLOGI("Tokens [%s] + [%s] matches banned [%s]", priorTxt.c_str(), nextTxt.c_str(), bannedTxt.c_str());
return true;
}
}
}
}
return false;
}
std::vector<std::pair<float, token_sequence>> Sample(DecodeResult decodeResult, int n_results, WordCapitalizeMode capitals, const std::vector<banned_sequence> &banned_sequences) {
llama_context *ctx = model->context();
llama_batch batch = model->adapter->batch;
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
std::vector<potential_sequence> sequences;
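    // The end-of-correction token <XEC> is only meaningful when sampling starts
    // from the forced XBC position (logits_head == 0).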
bool allow_correction_token = decodeResult.logits_head == 0;
float *logits = llama_get_logits_ith(ctx, decodeResult.logits_head);
//AKLOGI("Value of [the ] before transform: %f", logits[561]);
bool is_bugged = logits[561] == 0.0f;
if(!transform_logits(logits, n_vocab, true, allow_correction_token, capitals, 0)) {
AKLOGE("logits have NaN!");
return { };
}
is_bugged = is_bugged && logits[561] < -990.0f && logits[561] > -1100.0f;
if(is_bugged) {
AKLOGE("Detected bug!!!! Trying to mitigate. Let's just reset cache and exit");
llama_kv_cache_seq_rm(ctx, -1, -1, -1);
model->transformerContext.active_context = { };
return { };
}
//AKLOGI("Value of [the ] after transform: %f", logits[561]);
std::vector<std::pair<float, int>> index_value;
index_value.clear();
for (size_t i = 0; i < n_vocab; i++) {
index_value.emplace_back(logits[i], i);
}
sortProbabilityPairVectorDescending(index_value, n_results * 2);
const token_sequence blank = {};
for(int i = 0; i < n_results * 2; i++) {
if(MatchesBanned(blank, 0, index_value[i].second, banned_sequences)) {
index_value[i].first = 0.0f;
}
}
sortProbabilityPairVectorDescending(index_value, n_results);
for (int i = 0; i < n_results; i++) {
        sequences.emplace_back(
                index_value[i].first,
                potential_sequence_data{
                        {index_value[i].second},
                        i
                }
        );
}
is_bugged = true;
for(const auto &seq : sequences) {
if(seq.second.tokens.front() > 48 || seq.first != sequences[0].first) {
is_bugged = false;
break;
}
}
if(is_bugged) {
AKLOGE("Detected bug2!!!! Trying to mitigate. Let's just reset cache and exit");
llama_kv_cache_seq_rm(ctx, -1, -1, -1);
model->transformerContext.active_context = { };
return { };
}
for (auto &sequence: sequences) {
if (sequence.second.seq_id == 0) continue;
        llama_kv_cache_seq_cp(ctx, 0, sequence.second.seq_id, 0, decodeResult.size);
    }
std::vector<potential_sequence> next_sequences;
std::vector<std::pair<float, token_sequence>> outputs;
    // Beam-style expansion: repeatedly extend every live candidate sequence
    // until each one has terminated at a word boundary or an end-of-correction
    // token, or no candidates remain.
    while (true) {
next_sequences.clear();
for (auto sequence: std::move(sequences)) {
int next_token = sequence.second.tokens[sequence.second.tokens.size() - 1];
// Check if this is the end of correction
if (next_token == specialTokens.XEC) {
token_sequence resulting_tokens = std::move(sequence.second.tokens);
resulting_tokens.resize(resulting_tokens.size() - 1);
outputs.emplace_back(sequence.first, resulting_tokens);
continue;
}
// Check if this is the end of a word
std::string token = model->getToken(next_token);
if (token.size() >= 3 && (token[token.size() - 1] == '\x81') &&
(token[token.size() - 2] == '\x96') && token[token.size() - 3] == '\xe2') {
outputs.emplace_back(sequence.first, std::move(sequence.second.tokens));
continue;
}
next_sequences.emplace_back(sequence);
}
sequences = next_sequences;
next_sequences.clear();
size_t remaining_count = n_results - outputs.size();
batch.n_tokens = 0;
//for(int i=0; i<batch.n_tokens; i++) batch.logits[i] = false;
for (auto &sequence: sequences) {
batch.token[batch.n_tokens] = sequence.second.tokens[sequence.second.tokens.size() - 1];
batch.pos[batch.n_tokens] = (llama_pos)(decodeResult.size + (sequence.second.tokens.size() - 1));
batch.seq_id[batch.n_tokens][0] = sequence.second.seq_id;
batch.n_seq_id[batch.n_tokens] = 1;
batch.logits[batch.n_tokens] = true;
batch.n_tokens += 1;
}
ASSERT(batch.n_tokens == (int)remaining_count); // usually 3
if (batch.n_tokens == 0) {
break;
}
llama_decode(ctx, batch);
for (int seq = 0; seq < (int)remaining_count; seq++) {
const potential_sequence &parent_seq = sequences[seq];
auto hash = compute_sequence_hash(parent_seq.second.tokens);
llama_token prev_token = 0;
if(!parent_seq.second.tokens.empty()) prev_token = parent_seq.second.tokens.back();
logits = llama_get_logits_ith(ctx, seq);
if(!transform_logits(logits, n_vocab, false, allow_correction_token, capitals, prev_token)) {
AKLOGE("Logits have NaN!");
return { };
}
index_value.clear();
for (size_t i = 0; i < n_vocab; i++) {
index_value.emplace_back(logits[i], i);
}
sortProbabilityPairVectorDescending(index_value, remaining_count * 2);
for(size_t i = 0; i < remaining_count * 2; i++) {
if(MatchesBanned(parent_seq.second.tokens, hash, index_value[i].second, banned_sequences)) {
index_value[i].first = 0.0f;
}
}
sortProbabilityPairVectorDescending(index_value, remaining_count);
for (size_t i = 0; i < remaining_count; i++) {
token_sequence new_sequence = parent_seq.second.tokens;
new_sequence.push_back(index_value[i].second);
if (index_value[i].first > 1.0f || index_value[i].first < 0.0f) {
AKLOGE("Expected index_value to be probability [%.2f]",
index_value[i].first);
}
                    if (parent_seq.first > 1.0f || parent_seq.first < 0.0f) {
                        AKLOGE("Expected sequences value to be probability [%.2f]",
                               parent_seq.first);
                    }
                    next_sequences.emplace_back(
                            index_value[i].first * parent_seq.first,
potential_sequence_data{
new_sequence,
parent_seq.second.seq_id
}
);
}
}
sortProbabilityPairVectorDescending(next_sequences, remaining_count);
next_sequences.resize(remaining_count);
sequences.clear();
// In some cases we may have picked a sequence from the same parent sequence
// We must re-assign the seq_id
            std::vector<int> seq_id_use_count(n_results, 0);
for (auto &seq: next_sequences) seq_id_use_count[seq.second.seq_id] += 1;
for (auto &seq: next_sequences) {
if (seq_id_use_count[seq.second.seq_id] > 1) {
int old_seq_id = seq.second.seq_id;
int new_seq_id = -1;
for (int i = 0; i < n_results; i++) {
if (seq_id_use_count[i] == 0) {
new_seq_id = i;
break;
}
}
if (new_seq_id == -1) {
AKLOGE("Couldn't find an empty sequence id to use. This should never happen.");
return {};
}
seq_id_use_count[old_seq_id]--;
seq_id_use_count[new_seq_id]++;
llama_kv_cache_seq_cp(
ctx,
old_seq_id,
new_seq_id,
0, // could start from prompt.size()
(llama_pos)(decodeResult.size + (seq.second.tokens.size() - 1))
);
seq.second.seq_id = new_seq_id;
}
}
sequences = next_sequences;
}
for (int i = 1; i < n_results; i++) {
llama_kv_cache_seq_rm(ctx, i, 0, -1);
}
return outputs;
}
std::vector<std::pair<float, std::string>> PredictNextWord(const std::string &context, const std::vector<std::string> &banned_words) {
std::vector<banned_sequence> banned_sequences;
for(const std::string &bw : banned_words) {
auto tokenized = model->tokenize(trim(bw) + " ");
banned_sequences.push_back({ tokenized, compute_sequence_hash(tokenized) });
auto tokenized2 = model->tokenize(trim(bw));
banned_sequences.push_back({ tokenized2, compute_sequence_hash(tokenized2) });
}
token_sequence next_context = model->tokenize(trim(context) + " ");
next_context.insert(next_context.begin(), 1); // BOS
auto decoding_result = DecodePromptAndMixes(next_context, { });
auto results = Sample(decoding_result, 3, WordCapitalizeMode::IgnoredCapitals, banned_sequences);
std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
str_results.emplace_back(result.first, model->decode(result.second));
}
        return str_results;
    }
std::vector<std::pair<float, std::string>> PredictCorrection(const std::string &context, const std::vector<TokenMix> &mixes, bool swipe_mode, WordCapitalizeMode capitals, const std::vector<std::string> &banned_words) {
if(specialTokens.XBU == -1) return { };
std::vector<banned_sequence> banned_sequences;
for(const std::string &bw : banned_words) {
auto tokenized = model->tokenize(trim(bw) + " ");
banned_sequences.push_back({ tokenized, compute_sequence_hash(tokenized) });
auto tokenized2 = model->tokenize(trim(bw));
banned_sequences.push_back({ tokenized2, compute_sequence_hash(tokenized2) });
}
        token_sequence next_context;
        if(!context.empty()) {
            next_context = model->tokenize(trim(context) + " ");
        }
next_context.insert(next_context.begin(), 1); // BOS
next_context.push_back(specialTokens.XBU);
if(swipe_mode) {
next_context.push_back(specialTokens.XC0_SWIPE_MODE);
}
auto decoding_result = DecodePromptAndMixes(next_context, mixes);
auto results = Sample(decoding_result, 3, capitals, banned_sequences);
std::vector<std::pair<float, std::string>> str_results;
for(const auto& result : results) {
str_results.emplace_back(result.first, model->decode(result.second));
}
        return str_results;
    }
};
struct SuggestionItemToRescore {
int index;
int originalScore;
float transformedScore;
std::string word;
token_sequence tokens;
};
namespace latinime {
static jlong xlm_LanguageModel_open(JNIEnv *env, jclass clazz, jstring modelDir) {
AKLOGI("open LM");
const jsize sourceDirUtf8Length = env->GetStringUTFLength(modelDir);
if (sourceDirUtf8Length <= 0) {
AKLOGE("DICT: Can't get sourceDir string");
return 0;
}
char sourceDirChars[sourceDirUtf8Length + 1];
env->GetStringUTFRegion(modelDir, 0, env->GetStringLength(modelDir), sourceDirChars);
sourceDirChars[sourceDirUtf8Length] = '\0';
auto *state = new LanguageModelState();
if(!state->Initialize(sourceDirChars)) {
return 0;
}
return reinterpret_cast<jlong>(state);
}
static void xlm_LanguageModel_close(JNIEnv *env, jclass clazz, jlong statePtr) {
GGML_UNUSED(env);
GGML_UNUSED(clazz);
AKLOGI("LanguageModel_close called!");
auto *state = reinterpret_cast<LanguageModelState *>(statePtr);
if(state == nullptr) return;
delete state;
}
// (JLjava/lang/String;[Ljava/lang/String;[I[I)V
// TODO: This will also need caching to not make things extremely slow by recomputing every time
static void xlm_LanguageModel_rescoreSuggestions(JNIEnv *env, jclass clazz,
jlong dict,
jstring context,
jobjectArray inWords,
jintArray inScores,
jintArray outScores
) {
GGML_UNUSED(clazz);
auto *state = reinterpret_cast<LanguageModelState *>(dict);
std::string contextString = jstring2string(env, context);
jsize inputSize = env->GetArrayLength(inScores);
int scores[inputSize];
env->GetIntArrayRegion(inScores, 0, inputSize, scores);
float maxScore = -INFINITY;
float minScore = INFINITY;
for(int score : scores) {
auto scoref = (float)score;
if(scoref > maxScore) maxScore = scoref;
if(scoref < minScore) minScore = scoref;
}
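    // Stretch the range downward so the weakest input word still maps to a
    // positive normalized score rather than exactly zero.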
minScore -= (maxScore - minScore) * 0.33f;
std::vector<SuggestionItemToRescore> words;
jsize numWords = env->GetArrayLength(inWords);
for(jsize i=0; i<numWords; i++) {
auto jstr = (jstring)env->GetObjectArrayElement(inWords, i);
SuggestionItemToRescore item = {
(int) i,
scores[i],
((float)scores[i] - minScore) / (maxScore - minScore),
jstring2string(env, jstr),
{}
};
item.tokens = state->model->tokenize(trim(item.word) + " ");
words.push_back(item);
        env->DeleteLocalRef(jstr);
    }
llama_context *ctx = state->model->context();
size_t n_vocab = llama_n_vocab(llama_get_model(ctx));
token_sequence next_context = state->model->tokenize(trim(contextString) + " ");
next_context.insert(next_context.begin(), 1); // BOS
auto decoding_result = state->DecodePromptAndMixes(next_context, { });
float *logits = llama_get_logits_ith(ctx, decoding_result.logits_head);
softmax(logits, n_vocab);
AKLOGI("Iter");
for(auto &entry : words) {
float pseudoScore = logits[entry.tokens[0]] / (float)entry.tokens.size();
AKLOGI("Word [%s], %d tokens, prob[0] = %.8f", entry.word.c_str(), entry.tokens.size(), pseudoScore);
entry.transformedScore *= pseudoScore * 1000.0f;
}
// TODO: Transform here
// Output scores
jint *outArray = env->GetIntArrayElements(outScores, nullptr);
for(const auto &entry : words) {
outArray[entry.index] = (jint)(entry.transformedScore * (maxScore - minScore) + minScore);
}
env->ReleaseIntArrayElements(outScores, outArray, 0);
}
static void xlm_LanguageModel_getSuggestions(JNIEnv *env, jclass clazz,
// inputs
jlong dict,
jstring context,
jstring partialWord,
jintArray inComposeX,
jintArray inComposeY,
jobjectArray bannedWordsArray,
// outputs
jobjectArray outPredictions,
jfloatArray outProbabilities
) {