Skip to content
Snippets Groups Projects
Commit f31db527 authored by abb128's avatar abb128
Browse files

Add whisper.cpp

parent 7ce4f317
No related branches found
No related tags found
2 merge requests!7Merge model-metadata to master,!4Merge lm-2-finetuning-whisperggml into lm-2-finetuning
Showing
with 7580 additions and 28 deletions
......@@ -7,26 +7,22 @@ import android.os.Build
import android.os.PowerManager
import androidx.annotation.RequiresApi
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.setValue
import androidx.core.app.NotificationCompat
import androidx.datastore.preferences.core.intPreferencesKey
import androidx.work.Constraints
import androidx.work.CoroutineWorker
import androidx.work.ForegroundInfo
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.PeriodicWorkRequest
import androidx.work.WorkManager
import androidx.work.WorkerParameters
import androidx.work.Constraints
import androidx.work.PeriodicWorkRequest
import androidx.work.OneTimeWorkRequestBuilder
import androidx.datastore.preferences.core.intPreferencesKey
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.setSetting
import org.futo.inputmethod.latin.uix.getSetting
import org.futo.inputmethod.latin.uix.setSetting
import java.io.File
import java.io.FileOutputStream
import java.io.IOException
import java.io.OutputStream
import java.util.concurrent.TimeUnit
val NUM_TRAINING_RUNS_KEY = intPreferencesKey("training_runs_count")
......
......@@ -19,6 +19,7 @@ LATIN_IME_JNI_SRC_FILES := \
org_futo_inputmethod_latin_DicTraverseSession.cpp \
org_futo_inputmethod_latin_xlm_LanguageModel.cpp \
org_futo_inputmethod_latin_xlm_AdapterTrainer.cpp \
org_futo_voiceinput_WhisperGGML.cpp \
jni_common.cpp
LOCAL_C_INCLUDES += $(LOCAL_PATH)/src/sentencepiece/builtin_pb
......@@ -29,12 +30,14 @@ LOCAL_C_INCLUDES += $(LOCAL_PATH)/src/third_party/darts_clone
LOCAL_C_INCLUDES += $(LOCAL_PATH)/src/third_party/absl
LATIN_IME_CORE_SRC_FILES := \
jni_utils.cpp \
ggml/context.cpp \
ggml/ggml.c \
ggml/ggml-alloc.c \
ggml/ggml-quants.c \
ggml/ggml-backend.c \
ggml/llama.cpp \
ggml/whisper.cpp \
ggml/finetune.cpp \
ggml/train.cpp \
ggml/common.cpp \
......
......@@ -25,12 +25,13 @@
#include "org_futo_inputmethod_latin_xlm_LanguageModel.h"
#include "defines.h"
#include "org_futo_inputmethod_latin_xlm_AdapterTrainer.h"
#include "org_futo_voiceinput_WhisperGGML.h"
/*
* Returns the JNI version on success, -1 on failure.
*/
jint JNI_OnLoad(JavaVM *vm, void *reserved) {
JNIEnv *env = 0;
JNIEnv *env = nullptr;
if (vm->GetEnv(reinterpret_cast<void **>(&env), JNI_VERSION_1_6) != JNI_OK) {
AKLOGE("ERROR: GetEnv failed");
......@@ -65,6 +66,10 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
AKLOGE("ERROR: AdapterTrainer native registration failed");
return -1;
}
if (!voiceinput::register_WhisperGGML(env)) {
AKLOGE("ERROR: WhisperGGML native registration failed");
return -1;
}
/* success -- return valid version number */
return JNI_VERSION_1_6;
}
......
......@@ -8,20 +8,7 @@
#include "jni_common.h"
#include "ggml/finetune.h"
#include "sentencepiece/sentencepiece_processor.h"
std::string jstring2string(JNIEnv *env, jstring jStr) {
const jsize stringUtf8Length = env->GetStringUTFLength(jStr);
if (stringUtf8Length <= 0) {
AKLOGE("Can't get jStr");
return "";
}
char stringChars[stringUtf8Length + 1];
env->GetStringUTFRegion(jStr, 0, env->GetStringLength(jStr), stringChars);
stringChars[stringUtf8Length] = '\0';
return {stringChars};
}
#include "jni_utils.h"
namespace latinime {
struct AdapterTrainerState {
......
//
// Created by hp on 11/22/23.
//
#include <string>
#include <bits/sysconf.h>
#include "org_futo_voiceinput_WhisperGGML.h"
#include "jni_common.h"
#include "defines.h"
#include "ggml/whisper.h"
#include "jni_utils.h"
// Native-side state backing one Kotlin WhisperGGML instance; the Kotlin side
// holds a pointer to this struct as an opaque jlong handle.
struct WhisperModelState {
int n_threads = 4; // NOTE(review): appears unused — infer() recomputes thread count from sysconf; confirm
struct whisper_context *context = nullptr; // owned; created by whisper_init_from_* in the open functions
};
/**
 * JNI: loads a Whisper model from a filesystem path.
 *
 * Returns a heap-allocated WhisperModelState cast to jlong, or 0L when the
 * model could not be initialized (the state is freed in that case).
 */
static jlong WhisperGGML_open(JNIEnv *env, jclass clazz, jstring model_dir) {
    const std::string path = jstring2string(env, model_dir);

    auto *new_state = new WhisperModelState();
    new_state->context = whisper_init_from_file(path.c_str());

    if(new_state->context == nullptr) {
        AKLOGE("Failed to initialize whisper_context from path %s", path.c_str());
        delete new_state;
        return 0L;
    }

    return reinterpret_cast<jlong>(new_state);
}
/**
 * JNI: loads a Whisper model from a direct java.nio.Buffer.
 *
 * Returns a heap-allocated WhisperModelState cast to jlong, or 0L on failure.
 * The buffer must be a *direct* buffer; for a heap-backed buffer the JVM
 * returns a null address and a capacity of -1, which the original code
 * passed straight into whisper_init_from_buffer.
 */
static jlong WhisperGGML_openFromBuffer(JNIEnv *env, jclass clazz, jobject buffer) {
    void* buffer_address = env->GetDirectBufferAddress(buffer);
    jlong buffer_capacity = env->GetDirectBufferCapacity(buffer);

    // Guard against non-direct or empty buffers before touching the memory.
    if(buffer_address == nullptr || buffer_capacity <= 0) {
        AKLOGE("openFromBufferNative requires a non-empty direct buffer");
        return 0L;
    }

    auto *state = new WhisperModelState();
    state->context = whisper_init_from_buffer(buffer_address, buffer_capacity);

    if(!state->context){
        AKLOGE("Failed to initialize whisper_context from direct buffer");
        delete state;
        return 0L;
    }

    return reinterpret_cast<jlong>(state);
}
/**
 * JNI: runs full Whisper inference over float PCM samples.
 *
 * `handle` must be a value previously returned by one of the open functions.
 * Recognized segments are currently only logged; the commented-out tail is a
 * mel-spectrogram experiment left for reference.
 *
 * Fix: the original leaked the JNI float-array elements on every call
 * (GetFloatArrayElements without a matching ReleaseFloatArrayElements).
 */
static void WhisperGGML_infer(JNIEnv *env, jobject instance, jlong handle, jfloatArray samples_array, jstring prompt) {
    auto *state = reinterpret_cast<WhisperModelState *>(handle);

    size_t num_samples = env->GetArrayLength(samples_array);
    jfloat *samples = env->GetFloatArrayElements(samples_array, nullptr);

    AKLOGI("Received %d samples", (int)num_samples);

    long num_procs = sysconf(_SC_NPROCESSORS_ONLN);
    if(num_procs < 2 || num_procs > 16) num_procs = 6; // Make sure the number is sane

    AKLOGI("num procs = %d", (int)num_procs);

    whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
    wparams.print_progress = false;
    wparams.print_realtime = false;
    wparams.print_special = false;
    wparams.print_timestamps = false;
    wparams.max_tokens = 256;
    wparams.n_threads = (int)num_procs;
    //wparams.audio_ctx = (int)ceil((double)num_samples / (double)(160.0 * 2.0));
    wparams.temperature_inc = 0.0f;

    //std::string prompt_str = jstring2string(env, prompt);
    //wparams.initial_prompt = prompt_str.c_str();
    //AKLOGI("Initial prompt is [%s]", prompt_str.c_str());

    // Log each newly decoded segment as it arrives.
    wparams.new_segment_callback = [](struct whisper_context * ctx, struct whisper_state * state, int n_new, void * user_data) {
        const int n_segments = whisper_full_n_segments(ctx);
        const int s0 = n_segments - n_new;
        if (s0 == 0) {
            AKLOGI("s0 == 0, \\n");
        }
        for (int i = s0; i < n_segments; i++) {
            auto seg = whisper_full_get_segment_text(ctx, i);
            AKLOGI("WhisperGGML new segment %s", seg);
        }
    };

    AKLOGI("Calling whisper_full");
    int res = whisper_full(state->context, wparams, samples, (int)num_samples);
    if(res != 0) {
        AKLOGE("WhisperGGML whisper_full failed with non-zero code %d", res);
    }
    AKLOGI("whisper_full finished :3");
    whisper_print_timings(state->context);

    // Release the pinned/copied Java array. JNI_ABORT because samples were
    // only read — there is nothing to copy back into the Java array.
    env->ReleaseFloatArrayElements(samples_array, samples, JNI_ABORT);

    /*
    ASSERT(mel_count % 80 == 0);
    whisper_set_mel(state->context, mel, (int)(mel_count / 80), 80);

    whisper_encode(state->context, 0, 4);

    whisper_token tokens[512] = { 0 };
    whisper_decode(state->context, tokens, 512, 0, 4);
    */
}
/**
 * JNI: releases the native state created by one of the open functions.
 * Safe to call with a 0 handle (no-op).
 *
 * Fix: the original deleted the wrapper struct but never freed the
 * whisper_context, leaking the entire loaded model.
 */
static void WhisperGGML_close(JNIEnv *env, jclass clazz, jlong handle) {
    auto *state = reinterpret_cast<WhisperModelState *>(handle);
    if(!state) return;

    if(state->context) {
        whisper_free(state->context);
    }
    delete state;
}
namespace voiceinput {
    // Table mapping Java-side method names/signatures on
    // org.futo.voiceinput.shared.ggml.WhisperGGML to the native
    // implementations above.
    static const JNINativeMethod sMethods[] = {
        { const_cast<char *>("openNative"),
          const_cast<char *>("(Ljava/lang/String;)J"),
          reinterpret_cast<void *>(WhisperGGML_open) },
        { const_cast<char *>("openFromBufferNative"),
          const_cast<char *>("(Ljava/nio/Buffer;)J"),
          reinterpret_cast<void *>(WhisperGGML_openFromBuffer) },
        { const_cast<char *>("inferNative"),
          const_cast<char *>("(J[FLjava/lang/String;)V"),
          reinterpret_cast<void *>(WhisperGGML_infer) },
        { const_cast<char *>("closeNative"),
          const_cast<char *>("(J)V"),
          reinterpret_cast<void *>(WhisperGGML_close) }
    };

    // Registers the native methods; called from JNI_OnLoad in jni_common.
    int register_WhisperGGML(JNIEnv *env) {
        const char *const kClassPathName = "org/futo/voiceinput/shared/ggml/WhisperGGML";
        return latinime::registerNativeMethods(env, kClassPathName, sMethods, NELEMS(sMethods));
    }
} // namespace voiceinput
\ No newline at end of file
//
// Created by hp on 11/22/23.
//
#ifndef LATINIME_ORG_FUTO_VOICEINPUT_WHISPERGGML_H
#define LATINIME_ORG_FUTO_VOICEINPUT_WHISPERGGML_H
#include "jni.h"
namespace voiceinput {
// Registers the WhisperGGML native methods with the JVM; called from
// JNI_OnLoad. Returns the result of latinime::registerNativeMethods
// (nonzero on success).
int register_WhisperGGML(JNIEnv *env);
} // namespace voiceinput
#endif //LATINIME_ORG_FUTO_VOICEINPUT_WHISPERGGML_H
......@@ -339,4 +339,5 @@ typedef enum {
// Create new word with space substitution
CT_NEW_WORD_SPACE_SUBSTITUTION,
} CorrectionType;
#endif // LATINIME_DEFINES_H
This diff is collapsed.
This diff is collapsed.
//
// Created by hp on 11/22/23.
//
#include "jni_utils.h"
#include <string>
#include "defines.h"
/**
 * Converts a jstring to a std::string of modified-UTF-8 bytes.
 * Returns "" when the string is empty or its UTF length cannot be obtained.
 *
 * Fix: the original used a variable-length array (`char buf[n + 1]`), which
 * is not standard C++ and puts an unbounded allocation on the stack; write
 * the bytes directly into an appropriately-sized std::string instead.
 */
std::string jstring2string(JNIEnv *env, jstring jStr) {
    const jsize stringUtf8Length = env->GetStringUTFLength(jStr);
    if (stringUtf8Length <= 0) {
        AKLOGE("Can't get jStr");
        return "";
    }

    std::string result(static_cast<size_t>(stringUtf8Length), '\0');
    // Copies the whole UTF-16 range, translated to modified UTF-8.
    env->GetStringUTFRegion(jStr, 0, env->GetStringLength(jStr), &result[0]);
    return result;
}
\ No newline at end of file
//
// Created by hp on 11/22/23.
//
#ifndef LATINIME_JNI_UTILS_H
#define LATINIME_JNI_UTILS_H
#include <string>
#include "../jni_common.h"
// Converts a jstring to a std::string (modified UTF-8); returns "" when the
// string is empty or its length cannot be obtained.
std::string jstring2string(JNIEnv *env, jstring jStr);
#endif //LATINIME_JNI_UTILS_H
......@@ -4,6 +4,7 @@ import android.Manifest
import android.content.Context
import android.content.Intent
import android.content.pm.PackageManager
import android.content.res.AssetManager
import android.hardware.SensorPrivacyManager
import android.media.AudioFormat
import android.media.AudioRecord
......@@ -24,6 +25,7 @@ import kotlinx.coroutines.Job
import kotlinx.coroutines.launch
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.ggml.WhisperGGML
import org.futo.voiceinput.shared.types.AudioRecognizerListener
import org.futo.voiceinput.shared.types.InferenceState
import org.futo.voiceinput.shared.types.Language
......@@ -33,14 +35,17 @@ import org.futo.voiceinput.shared.types.ModelLoader
import org.futo.voiceinput.shared.whisper.DecodingConfiguration
import org.futo.voiceinput.shared.whisper.ModelManager
import org.futo.voiceinput.shared.whisper.MultiModelRunConfiguration
import org.futo.voiceinput.shared.whisper.MultiModelRunner
import org.futo.voiceinput.shared.whisper.isBlankResult
import org.tensorflow.lite.support.common.FileUtil
import java.io.FileInputStream
import java.nio.FloatBuffer
import java.nio.ShortBuffer
import java.nio.channels.FileChannel
import kotlin.math.min
import kotlin.math.pow
import kotlin.math.sqrt
data class AudioRecognizerSettings(
val modelRunConfiguration: MultiModelRunConfiguration,
val decodingConfiguration: DecodingConfiguration
......@@ -58,13 +63,16 @@ class AudioRecognizer(
private var isRecording = false
private var recorder: AudioRecord? = null
private val modelRunner = MultiModelRunner(modelManager)
//private val modelRunner = MultiModelRunner(modelManager)
private val floatSamples: FloatBuffer = FloatBuffer.allocate(16000 * 30)
private var recorderJob: Job? = null
private var modelJob: Job? = null
private var loadModelJob: Job? = null
private val buffer = FileUtil.loadMappedFile(context, "ggml-model.tflite")
private val ggmlModel = WhisperGGML(buffer)
@Throws(ModelDoesNotExistException::class)
private fun verifyModelsExist() {
val modelsThatDoNotExist = mutableListOf<ModelLoader>()
......@@ -163,7 +171,7 @@ class AudioRecognizer(
}
private suspend fun preloadModels() {
modelRunner.preload(settings.modelRunConfiguration)
//modelRunner.preload(settings.modelRunConfiguration)
}
private suspend fun recordingJob(recorder: AudioRecord, vad: VadModel) {
......@@ -352,6 +360,12 @@ class AudioRecognizer(
}
private suspend fun runModel() {
val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())
println("RUNNING GGML MODEL")
ggmlModel.infer(floatArray)
println("FINISHED RUNNING GGML MODEL")
/*
loadModelJob?.let {
if (it.isActive) {
println("Model was not finished loading...")
......@@ -359,7 +373,7 @@ class AudioRecognizer(
}
}
val floatArray = floatSamples.array().sliceArray(0 until floatSamples.position())
yield()
val outputText = modelRunner.run(
......@@ -381,6 +395,8 @@ class AudioRecognizer(
listener.finished(text)
}
}
*/
}
private fun onFinishRecording() {
......
package org.futo.voiceinput.shared.ggml
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.withContext
import java.nio.Buffer
// Dedicated single-threaded dispatcher: all native inference calls are
// serialized onto this one thread via withContext in WhisperGGML.infer.
@OptIn(DelicateCoroutinesApi::class)
val inferenceContext = newSingleThreadContext("whisper-ggml-inference")
/**
 * Kotlin wrapper around the whisper.cpp JNI bindings.
 *
 * Loads a model from a direct [Buffer] at construction time and serializes
 * all native calls onto [inferenceContext], since the native context is
 * accessed through a single raw handle.
 *
 * Fix: the original exposed no way to release the native handle even though
 * closeNative exists, so the loaded model could never be freed; [close] is
 * added (backward-compatible) and [infer] now fails fast after close instead
 * of handing a dead handle to native code.
 *
 * @throws IllegalArgumentException if the model cannot be loaded from [buffer].
 */
class WhisperGGML(
    buffer: Buffer
) {
    // 0L means "no native state" (never opened, or already closed).
    private var handle: Long = 0L

    init {
        handle = openFromBufferNative(buffer)
        if(handle == 0L) {
            throw IllegalArgumentException("The Whisper model could not be loaded from the given buffer")
        }
    }

    /** Runs inference over [samples]; must not be called after [close]. */
    suspend fun infer(samples: FloatArray) = withContext(inferenceContext) {
        check(handle != 0L) { "WhisperGGML has already been closed" }
        inferNative(handle, samples, "")
    }

    /** Releases the native model. Idempotent; [infer] is invalid afterwards. */
    suspend fun close() = withContext(inferenceContext) {
        if(handle != 0L) {
            closeNative(handle)
            handle = 0L
        }
    }

    external fun openNative(path: String): Long
    external fun openFromBufferNative(buffer: Buffer): Long
    external fun inferNative(handle: Long, samples: FloatArray, prompt: String)
    external fun closeNative(handle: Long)
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment