Skip to content
Snippets Groups Projects
Commit 9f6941ef authored by Aleksandras Kostarevas's avatar Aleksandras Kostarevas
Browse files

Move certain tensors to static companion

parent 3acb8b5e
No related branches found
No related tags found
1 merge request!1Integrate voice input
package org.futo.voiceinput.shared.whisper
import android.content.Context
import kotlinx.coroutines.DelicateCoroutinesApi
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.launch
import kotlinx.coroutines.newSingleThreadContext
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withContext
import kotlinx.coroutines.yield
import org.futo.voiceinput.shared.types.DecodedMetadata
......@@ -16,6 +16,11 @@ import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.model.Model
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
/**
* This is necessary to synchronize so two threads don't try to use the same tensor at once,
* free a model while it's in use, etc.
*/
@OptIn(DelicateCoroutinesApi::class)
private val inferenceContext = newSingleThreadContext("InferenceContext")
class WhisperModel(
......@@ -23,30 +28,23 @@ class WhisperModel(
val loader: ModelLoader,
) {
private var closed = false
private class InferenceSession(
val model: WhisperModel, val bannedTokens: IntArray
) : ModelInferenceSession {
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
private var seqLen = 0
private var xAtn: TensorBuffer? = null
private val decodedTokens = mutableListOf(model.tokenizer.decodeStartToken)
private suspend fun decodeStep(forceOption: Int? = null): Int {
private fun decodeStep(forceOption: Int? = null): Int {
if (xAtn == null) {
throw IllegalStateException("melToFeatures must be called before starting decoding")
}
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = decodedTokens.last().toFloat()
model.loadSeqLenInputId(seqLen, decodedTokens.last())
model.seqLenTensor.loadArray(seqLenArray)
model.inputIdTensor.loadArray(inputIdsArray)
val decoderOutputs =
model.runDecoder(xAtn!!, model.seqLenTensor, model.cacheTensor, model.inputIdTensor)
val decoderOutputs = model.runDecoder(xAtn!!, model.cacheTensor)
model.cacheTensor.loadBuffer(decoderOutputs.nextCache.buffer.duplicate())
val selectedToken = if (forceOption != null) {
......@@ -159,6 +157,7 @@ class WhisperModel(
init {
val cpuOption = Model.Options.Builder().setDevice(Model.Device.CPU).build()
// NNAPI is disabled due to reported issues
val (encoder, decoder) = loader.loadEncoderDecoder(context, cpuOption)
......@@ -192,29 +191,48 @@ class WhisperModel(
this.bannedTokens = bannedTokens
}
// Must be called within inferenceContext
private fun runEncoderAndGetXatn(mel: FloatArray): TensorBuffer {
if(closed)
throw IllegalStateException("Cannot run session after model has been closed")
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
audioFeatures.loadArray(mel)
return encoderModel.process(audioFeatures).crossAttention
}
// Must be called within inferenceContext
private fun runDecoder(
xAtn: TensorBuffer, seqLen: TensorBuffer, cache: TensorBuffer, inputId: TensorBuffer
xAtn: TensorBuffer, cache: TensorBuffer
): DecoderModel.Outputs {
if(closed)
throw IllegalStateException("Cannot run session after model has been closed")
if (closed) throw IllegalStateException("Cannot run session after model has been closed")
return decoderModel.process(
crossAttention = xAtn, seqLen = seqLen, cache = cache, inputIds = inputId
crossAttention = xAtn, seqLen = seqLenTensor, cache = cache, inputIds = inputIdTensor
)
}
private val audioFeatures =
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
// TODO: Ideally this should be shared between model instances as well.
private val cacheTensor =
TensorBuffer.createFixedSize(decoderModel.getCacheTensorShape(), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
companion object {
private val audioFeatures =
TensorBuffer.createFixedSize(intArrayOf(1, 80, 3000), DataType.FLOAT32)
private val seqLenTensor = TensorBuffer.createFixedSize(intArrayOf(1), DataType.FLOAT32)
private val inputIdTensor = TensorBuffer.createFixedSize(intArrayOf(1, 1), DataType.FLOAT32)
private val seqLenArray = FloatArray(1)
private val inputIdsArray = FloatArray(1)
}
// Must be called within inferenceContext
private fun loadSeqLenInputId(seqLen: Int, inputId: Int) {
// TFLite has sketchy support for ints, so the model takes floats as input and casts them
// back to int internally
seqLenArray[0] = seqLen.toFloat()
inputIdsArray[0] = inputId.toFloat()
seqLenTensor.loadArray(seqLenArray)
inputIdTensor.loadArray(inputIdsArray)
}
init {
val shape = cacheTensor.shape
......@@ -223,8 +241,7 @@ class WhisperModel(
}
fun startInferenceSession(settings: DecodingConfiguration): ModelInferenceSession {
if(closed)
throw IllegalStateException("Cannot start session after model has been closed")
if (closed) throw IllegalStateException("Cannot start session after model has been closed")
updateBannedTokens(settings)
return InferenceSession(
......@@ -233,7 +250,7 @@ class WhisperModel(
}
suspend fun close() {
if(closed) return
if (closed) return
closed = true
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment