Skip to content
Snippets Groups Projects
Commit ef831148 authored by abb128's avatar abb128
Browse files

Add LanguageModel wrapper, word strategies

parent a104e952
No related tags found
No related merge requests found
Showing
with 466 additions and 125 deletions
......@@ -25,6 +25,7 @@ LATIN_IME_CORE_SRC_FILES := \
ggml/common.cpp \
ggml/context.cpp \
ggml/gpt_neox.cpp \
ggml/LanguageModel.cpp \
$(addprefix dictionary/header/, \
header_policy.cpp \
header_read_write_utils.cpp) \
......
This diff is collapsed.
......@@ -114,6 +114,8 @@ class DictionaryStructureWithBufferPolicy {
virtual bool isCorrupted() const = 0;
virtual int getWordStrategy(const char *text) const = 0;
protected:
DictionaryStructureWithBufferPolicy() {}
......
......@@ -657,6 +657,18 @@ int Ver4PatriciaTriePolicy::getTerminalPtNodePosFromWordId(const int wordId) con
return wordId == NOT_A_WORD_ID ? NOT_A_DICT_POS : wordId;
}
int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition());
const int strategy = readingHelper.searchWordAndReturnStrategy(text);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in getWordId().");
}
return strategy;
}
} // namespace v402
} // namespace backward
} // namespace latinime
......@@ -139,6 +139,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
return mIsCorrupted;
}
int getWordStrategy(const char *text) const;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);
......
......@@ -600,4 +600,15 @@ int Ver4PatriciaTriePolicy::getNextWordAndNextToken(const int token, int *const
return nextToken;
}
int Ver4PatriciaTriePolicy::getWordStrategy(const char *text) const {
DynamicPtReadingHelper readingHelper(&mNodeReader, &mPtNodeArrayReader);
readingHelper.initWithPtNodeArrayPos(getRootPosition());
const int strategy = readingHelper.searchWordAndReturnStrategy(text);
if (readingHelper.isError()) {
mIsCorrupted = true;
AKLOGE("Dictionary reading error in createAndGetAllChildDicNodes().");
}
return strategy;
}
} // namespace latinime
......@@ -118,6 +118,7 @@ class Ver4PatriciaTriePolicy : public DictionaryStructureWithBufferPolicy {
return mIsCorrupted;
}
int getWordStrategy(const char *text) const;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(Ver4PatriciaTriePolicy);
......
//
// Created by alex on 7/24/23.
//
#include "LanguageModel.h"
LanguageModelAdapter::~LanguageModelAdapter() {};
// Wraps the given adapter. NOTE(review): the raw pointer is stored and never
// deleted anywhere visible in this file — confirm intended ownership.
LanguageModel::LanguageModel(LanguageModelAdapter *adapter): adapter(adapter) {
}
// Vocabulary size as recorded in the loaded model's hyperparameters.
int GPTNeoXAdapter::getVocabSize() const {
    const auto &hparams = model.hparams;
    return hparams.n_vocab;
}
// Returns the token string for `id`. Note: id_to_token.at() throws for an
// unknown id, and the returned pointer stays valid only while the vocab
// entry is alive and unmodified.
const char *GPTNeoXAdapter::getToken(int id) const {
    const std::string &token = vocab.id_to_token.at(id);
    return token.c_str();
}
// Runs the model over `input` starting at cache position `nPast`, writing the
// resulting logits into outLogits. Returns whatever gpt_neox_eval reports.
bool GPTNeoXAdapter::eval(int nPast, token_sequence input, std::vector<float> &outLogits) {
    // TODO: handle context-window overflow gracefully instead of asserting.
    ASSERT(nPast + input.size() < model.hparams.n_ctx);

    return gpt_neox_eval(model, numThreads, nPast, input, outLogits, memPerToken);
}
// Delegates to the shared GPT tokenizer, using this model's vocabulary.
std::vector<int> GPTNeoXAdapter::tokenize(const char *text) {
    std::vector<int> ids = gpt_tokenize(vocab, text);
    return ids;
}
// Converts a token sequence back to text by concatenating the token strings.
// For now we just merge the tokens together; this may need to be different
// for other languages and unicode.
std::string GPTNeoXAdapter::decode(const token_sequence &tokens) const {
    // First pass: total output length, so the append loop never reallocates.
    size_t length = 0;
    for(int token : tokens) length += strlen(getToken(token));

    // Bug fix: the previous code did `std::string result(length);`, which is
    // not a valid std::string constructor (there is no single-count overload);
    // even a (count, ch) form would have produced a string of twice the
    // intended length after appending. reserve() is the correct idiom.
    std::string result;
    result.reserve(length);
    for(int token : tokens) result.append(getToken(token));

    return result;
}
// Factory: loads a GPT-NeoX model from `path` and wraps it in a LanguageModel.
// Returns nullptr if loading fails; on success the caller owns the result.
LanguageModel *GPTNeoXAdapter::createLanguageModel(const char *path) {
    // The constructor is private, so construction happens here.
    auto *adapter = new GPTNeoXAdapter();

    if (!gpt_neox_model_load(path, adapter->model, adapter->vocab)) {
        delete adapter;
        return nullptr;
    }

    return new LanguageModel(adapter);
}
GPTNeoXAdapter::GPTNeoXAdapter() = default;
//
// Created by alex on 7/24/23.
//
#ifndef LATINIME_LANGUAGEMODEL_H
#define LATINIME_LANGUAGEMODEL_H
#include <vector>
#include <unordered_set>
#include "context.h"
#include "defines.h"
#include "gpt_neox.h"
// Abstract backend interface for LanguageModel. Concrete adapters (e.g.
// GPTNeoXAdapter) implement tokenization, detokenization, and evaluation.
class LanguageModelAdapter {
public:
    // Number of CPU threads an implementation may use for evaluation.
    int numThreads = 4;

    // Size of the model's token vocabulary.
    virtual int getVocabSize() const = 0;

    // String form of a single token id.
    virtual const char *getToken(int id) const = 0;

    // Evaluates `input` starting at cache position nPast; fills outLogits.
    // Returns false on failure.
    virtual bool eval(int nPast, token_sequence input, std::vector<float> &outLogits) = 0;

    // Text -> token ids.
    virtual std::vector<int> tokenize(const char *text) = 0;

    // Token ids -> text.
    virtual std::string decode(const token_sequence &tokens) const = 0;

    // Pure virtual to keep the class abstract; defined in LanguageModel.cpp.
    virtual ~LanguageModelAdapter() = 0;
};
// Convenience wrapper around a LanguageModelAdapter that maintains a
// transformer KV-cache context and lazily evaluates only the suffix of the
// context that actually changed (via transformer_context_fastforward).
class LanguageModel {
public:
    // NOTE(review): the adapter pointer is stored but never deleted (no
    // destructor here) — confirm intended ownership.
    LanguageModel(LanguageModelAdapter *adapter);

    // Tokenizes the given text to tokens
    AK_FORCE_INLINE std::vector<int> tokenize(const char *text) const {
        return adapter->tokenize(text);
    }

    AK_FORCE_INLINE std::vector<int> tokenize(const std::string &text) const {
        return tokenize(text.c_str());
    }

    AK_FORCE_INLINE std::string decode(const token_sequence &tokens) const {
        return adapter->decode(tokens);
    }

    // Fast forward the context: computes which suffix of newContext still
    // needs evaluation and defers that work until the next infer() call.
    AK_FORCE_INLINE void updateContext(const std::vector<int> &newContext) {
        auto result = transformer_context_fastforward(transformerContext, newContext);

        pendingEvaluationSequence = result.first;
        pendingNPast = result.second;

        pendingContext = newContext;
    }

    AK_FORCE_INLINE void updateContext(const char *text) {
        // (Was `return updateContext(...)` — returning a void expression.)
        updateContext(tokenize(text));
    }

    // Appends one token to the pending context and re-fast-forwards.
    AK_FORCE_INLINE void pushToContext(int token) {
        pendingContext.push_back(token);
        updateContext(pendingContext);
    }

    // Evaluates any pending sequence and returns the latest logits.
    // TODO: This method returns a copy of 128kB of data
    AK_FORCE_INLINE std::vector<float> infer() {
        if(pendingEvaluationSequence.empty()){
            AKLOGI("LanguageModel: evaluation skipped due to empty pending evaluation sequence\n");
            return outLogits;
        }

        if(!adapter->eval(pendingNPast, pendingEvaluationSequence, outLogits)) {
            ASSERT(false);
        }

        transformer_context_apply(transformerContext, {pendingEvaluationSequence, pendingNPast});
        pendingEvaluationSequence.clear();

        return outLogits;
    }

    // Infers the given tokens on top of the active context without updating cache.
    // TODO: This method returns a copy of 128kB of data
    AK_FORCE_INLINE std::vector<float> temporarilyInfer(const std::vector<int> &tokens) {
        ASSERT(pendingEvaluationSequence.empty());
        ASSERT(!tokens.empty());

        if(!adapter->eval(transformerContext.active_context.size(), tokens, tmpOutLogits)) {
            ASSERT(false);
        }

        return tmpOutLogits;
    }

    AK_FORCE_INLINE int getVocabSize() const {
        return adapter->getVocabSize();
    }

    AK_FORCE_INLINE const char *getToken(int token) const {
        return adapter->getToken(token);
    }

    AK_FORCE_INLINE bool isPendingEvaluation() const {
        // (Was `size() > 0`; empty() is the idiomatic check.)
        return !pendingEvaluationSequence.empty();
    }

private:
    token_sequence pendingContext;             // full desired context
    token_sequence pendingEvaluationSequence;  // suffix not yet evaluated
    int pendingNPast = 0;                      // cache position where the suffix starts

    LanguageModelAdapter *adapter;

    transformer_context transformerContext;

    std::vector<float> outLogits;
    std::vector<float> tmpOutLogits;

    std::unordered_set<int> punctIds;  // NOTE(review): unused in this file — confirm purpose
};
class GPTNeoXAdapter : public LanguageModelAdapter {
public:
int getVocabSize() const;
const char *getToken(int id) const;
bool eval(int nPast, token_sequence input, std::vector<float> &outLogits);
virtual std::vector<int> tokenize(const char *text);
virtual std::string decode(const token_sequence &tokens) const;
static LanguageModel *createLanguageModel(const char *path);
private:
GPTNeoXAdapter();
gpt_vocab vocab;
gpt_neox_model model;
size_t memPerToken = 0;
};
#endif //LATINIME_LANGUAGEMODEL_H
......@@ -196,6 +196,11 @@ int Dictionary::getNextWordAndNextToken(const int token, int *const outCodePoint
token, outCodePoints, outCodePointCount);
}
// Returns the word strategy for `text`, delegating to the structure policy.
int Dictionary::getWordStrategy(const char *text) const {
    TimeKeeper::setCurrentTime();
    const int strategy = mDictionaryStructureWithBufferPolicy->getWordStrategy(text);
    return strategy;
}
void Dictionary::logDictionaryInfo(JNIEnv *const env) const {
int dictionaryIdCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];
int versionStringCodePointBuffer[HEADER_ATTRIBUTE_BUFFER_SIZE];
......
......@@ -118,6 +118,7 @@ class Dictionary {
void logDictionaryInfo(JNIEnv *const env) const;
int getWordStrategy(const char *word) const;
private:
DISALLOW_IMPLICIT_CONSTRUCTORS(Dictionary);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment