package org.futo.inputmethod.latin.xlm;

import android.content.Context;
import android.util.Log;

import org.futo.inputmethod.latin.Dictionary;
import org.futo.inputmethod.latin.NgramContext;
import org.futo.inputmethod.latin.R;
import org.futo.inputmethod.latin.SuggestedWords;
import org.futo.inputmethod.latin.common.ComposedData;
import org.futo.inputmethod.latin.common.InputPointers;
import org.futo.inputmethod.latin.settings.SettingsValuesForSuggestion;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Locale;
import java.util.function.IntPredicate;

public class LanguageModel extends Dictionary {
    // Handle to the native (JNI) model state; 0 means the model is not loaded yet.
    static long mNativeState = 0;

    private String getPathToModelResource(Context context, int modelResource, int tokenizerResource, boolean forceDelete) {
        File outputDir = context.getCacheDir();
        File outputFile = new File(outputDir, "ggml-model-" + String.valueOf(modelResource) + ".gguf");
        File outputFileTokenizer = new File(outputDir, "tokenizer-" + String.valueOf(tokenizerResource) + ".tokenizer");

        if(forceDelete && outputFile.exists()) {
            outputFile.delete();
            outputFileTokenizer.delete();
        }

        if((!outputFile.exists()) || forceDelete){
            // FIXME: We save this to a random temporary file so that we can have a path instead of an InputStream
            InputStream is = context.getResources().openRawResource(modelResource);
            InputStream is_t = context.getResources().openRawResource(tokenizerResource);

            try {
                OutputStream os = new FileOutputStream(outputFile);

                int read = 0;
                byte[] bytes = new byte[1024];
                while ((read = is.read(bytes)) != -1) {
                    os.write(bytes, 0, read);
                }

                os.flush();
                os.close();
                is.close();

                OutputStream os_t = new FileOutputStream(outputFileTokenizer);

                read = 0;
                while ((read = is_t.read(bytes)) != -1) {
                    os_t.write(bytes, 0, read);
                }

                os_t.flush();
                os_t.close();
                is_t.close();
            } catch(IOException e) {
                e.printStackTrace();
                throw new RuntimeException("Failed to write model asset to file");
            }
        }

        return outputFile.getAbsolutePath() + ":" + outputFileTokenizer.getAbsolutePath();
    }

    Thread initThread = null;

    public LanguageModel(Context context, String dictType, Locale locale) {
        super(dictType, locale);

        initThread = new Thread() {
            @Override
            public void run() {
                if(mNativeState != 0) return;

                String modelPath = getPathToModelResource(context, R.raw.ml3, R.raw.ml3_tokenizer, true);
                mNativeState = openNative(modelPath);

                if(mNativeState == 0){
                    // Retry once, re-extracting the model and tokenizer from resources.
                    modelPath = getPathToModelResource(context, R.raw.ml3, R.raw.ml3_tokenizer, true);
                    mNativeState = openNative(modelPath);
                }

                if(mNativeState == 0){
                    throw new RuntimeException("Failed to load R.raw.ml3, R.raw.ml3_tokenizer model");
                }
            }
        };

        initThread.start();
    }

    @Override
    public ArrayList<SuggestedWords.SuggestedWordInfo> getSuggestions(
            ComposedData composedData,
            NgramContext ngramContext,
            long proximityInfoHandle,
            SettingsValuesForSuggestion settingsValuesForSuggestion,
            int sessionId,
            float weightForLocale,
            float[] inOutWeightOfLangModelVsSpatialModel
    ) {
        Log.d("LanguageModel", "getSuggestions called");

        if (mNativeState == 0) {
            Log.d("LanguageModel", "Exiting because mNativeState == 0");
            return null;
        }

        if (initThread != null && initThread.isAlive()){
            Log.d("LanguageModel", "Exiting because initThread is still running");
            return null;
        }

        final InputPointers inputPointers = composedData.mInputPointers;
        final boolean isGesture = composedData.mIsBatchMode;
        final int inputSize = inputPointers.getPointerSize();
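        // Build the text context handed to the native model: prefer the full sentence
        // context when it is available, then trim it to roughly 128 characters by
        // dropping leading lines or sentences.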
        String context = ngramContext.extractPrevWordsContext()
                .replace(NgramContext.BEGINNING_OF_SENTENCE_TAG, " ").trim();
        if(!ngramContext.fullContext.isEmpty()) {
            context = ngramContext.fullContext.trim();
        }

        // Trim the context
        while(context.length() > 128) {
            if(context.contains("\n")) {
                context = context.substring(context.indexOf("\n") + 1).trim();
            } else if(context.contains(".") || context.contains("?") || context.contains("!")) {
                int v = Arrays.stream(
                        new int[]{
                                context.indexOf("."),
                                context.indexOf("?"),
                                context.indexOf("!")
                        }).filter(i -> i != -1).min().orElse(-1);

                if(v == -1) break; // should be unreachable

                context = context.substring(v + 1).trim();
            } else {
                break;
            }
        }

        String partialWord = composedData.mTypedWord;
        if(!partialWord.isEmpty() && context.endsWith(partialWord)) {
            context = context.substring(0, context.length() - partialWord.length()).trim();
        }

        if(!partialWord.isEmpty()) {
            partialWord = partialWord.trim();
        }

        // TODO: We may want to pass times too, and adjust autocorrect confidence
        // based on time (taking a long time to type a char = trust the typed character
        // more, speed typing = trust it less)
        int[] xCoordsI = composedData.mInputPointers.getXCoordinates();
        int[] yCoordsI = composedData.mInputPointers.getYCoordinates();
        float[] xCoords = new float[composedData.mInputPointers.getPointerSize()];
        float[] yCoords = new float[composedData.mInputPointers.getPointerSize()];

        for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) xCoords[i] = (float)xCoordsI[i];
        for(int i=0; i<composedData.mInputPointers.getPointerSize(); i++) yCoords[i] = (float)yCoordsI[i];

        int maxResults = 128;
        float[] outProbabilities = new float[maxResults];
        String[] outStrings = new String[maxResults];

        // TODO: Pass multiple previous words information for n-gram.
        getSuggestionsNative(mNativeState, proximityInfoHandle, context, partialWord, xCoords, yCoords, outStrings, outProbabilities);
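        // Post-process the native results: first decide whether autocorrection is safe,
        // then wrap each returned string in a SuggestedWordInfo.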
        final ArrayList<SuggestedWords.SuggestedWordInfo> suggestions = new ArrayList<>();

        int kind = SuggestedWords.SuggestedWordInfo.KIND_PREDICTION;

        boolean mustNotAutocorrect = false;
        for(int i=0; i<maxResults; i++) {
            if (outStrings[i] == null) continue;

            if(!partialWord.isEmpty() && partialWord.trim().equalsIgnoreCase(outStrings[i].trim())) {
                // If this prediction matches the partial word ignoring case, and this is the top
                // prediction, then we can break.
                if(i == 0) {
                    break;
                } else {
                    // Otherwise, we cannot autocorrect to the top prediction unless the model is
                    // super confident about this
                    if(outProbabilities[i] * 8.0f >= outProbabilities[0]) {
                        mustNotAutocorrect = true;
                    }
                }
            }
        }

        if(!partialWord.isEmpty() && !mustNotAutocorrect) {
            kind = SuggestedWords.SuggestedWordInfo.KIND_WHITELIST
                    | SuggestedWords.SuggestedWordInfo.KIND_FLAG_APPROPRIATE_FOR_AUTO_CORRECTION;
        }

        for(int i=0; i<maxResults; i++) {
            if(outStrings[i] == null) continue;

            String word = outStrings[i].trim();
            suggestions.add(new SuggestedWords.SuggestedWordInfo(
                    word, context, (int)(outProbabilities[i] * 100.0f), kind, this, 0, 0
            ));
        }

        if(kind == SuggestedWords.SuggestedWordInfo.KIND_PREDICTION) {
            // TODO: Forcing the suggestions to appear by padding with distinct blank entries
            for (int i = suggestions.size(); i < 3; i++) {
                String word = " ";
                for (int j = 0; j < i; j++) word += " ";

                suggestions.add(new SuggestedWords.SuggestedWordInfo(word, context, 1, kind, this, 0, 0));
            }
        }

        Log.d("LanguageModel", "returning " + String.valueOf(suggestions.size()) + " suggestions");

        return suggestions;
    }

    private synchronized void closeInternalLocked() {
        try {
            if (initThread != null) initThread.join();
        } catch (InterruptedException e) {
            e.printStackTrace();
        }

        /*if (mNativeState != 0) {
            closeNative(mNativeState);
            mNativeState = 0;
        }*/
    }

    @Override
    protected void finalize() throws Throwable {
        try {
            closeInternalLocked();
        } finally {
            super.finalize();
        }
    }

    @Override
    public boolean isInDictionary(String word) {
        return false;
    }

    private static native long openNative(String sourceDir);
    private static native void closeNative(long state);
    private static native void getSuggestionsNative(
            // inputs
            long state,
            long proximityInfoHandle,
            String context,
            String partialWord,
            float[] inComposeX,
            float[] inComposeY,

            // outputs
            String[] outStrings,
            float[] outProbs
    );
}