Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
LatinIME
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Deploy
Releases
Package Registry
Container Registry
Model registry
Operate
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
keyboard
LatinIME
Commits
9f6941ef
Commit
9f6941ef
authored
1 year ago
by
Aleksandras Kostarevas
Browse files
Options
Downloads
Patches
Plain Diff
Move certain tensors to static companion
parent
3acb8b5e
No related branches found
Branches containing commit
No related tags found
Tags containing commit
1 merge request
!1
Integrate voice input
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
+42
-25
42 additions, 25 deletions
...n/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
with
42 additions
and
25 deletions
voiceinput-shared/src/main/java/org/futo/voiceinput/shared/whisper/WhisperModel.kt
+
42
−
25
View file @
9f6941ef
package
org.futo.voiceinput.shared.whisper
package
org.futo.voiceinput.shared.whisper
import
android.content.Context
import
android.content.Context
import
kotlinx.coroutines.DelicateCoroutinesApi
import
kotlinx.coroutines.Dispatchers
import
kotlinx.coroutines.Dispatchers
import
kotlinx.coroutines.launch
import
kotlinx.coroutines.launch
import
kotlinx.coroutines.newSingleThreadContext
import
kotlinx.coroutines.newSingleThreadContext
import
kotlinx.coroutines.runBlocking
import
kotlinx.coroutines.withContext
import
kotlinx.coroutines.withContext
import
kotlinx.coroutines.yield
import
kotlinx.coroutines.yield
import
org.futo.voiceinput.shared.types.DecodedMetadata
import
org.futo.voiceinput.shared.types.DecodedMetadata
...
@@ -16,6 +16,11 @@ import org.tensorflow.lite.DataType
...
@@ -16,6 +16,11 @@ import org.tensorflow.lite.DataType
import
org.tensorflow.lite.support.model.Model
import
org.tensorflow.lite.support.model.Model
import
org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import
org.tensorflow.lite.support.tensorbuffer.TensorBuffer
/**
* This is necessary to synchronize so two threads don't try to use the same tensor at once,
* free a model while it's in use, etc.
*/
@OptIn
(
DelicateCoroutinesApi
::
class
)
private
val
inferenceContext
=
newSingleThreadContext
(
"InferenceContext"
)
private
val
inferenceContext
=
newSingleThreadContext
(
"InferenceContext"
)
class
WhisperModel
(
class
WhisperModel
(
...
@@ -23,30 +28,23 @@ class WhisperModel(
...
@@ -23,30 +28,23 @@ class WhisperModel(
val
loader
:
ModelLoader
,
val
loader
:
ModelLoader
,
)
{
)
{
private
var
closed
=
false
private
var
closed
=
false
private
class
InferenceSession
(
private
class
InferenceSession
(
val
model
:
WhisperModel
,
val
bannedTokens
:
IntArray
val
model
:
WhisperModel
,
val
bannedTokens
:
IntArray
)
:
ModelInferenceSession
{
)
:
ModelInferenceSession
{
private
val
seqLenArray
=
FloatArray
(
1
)
private
val
inputIdsArray
=
FloatArray
(
1
)
private
var
seqLen
=
0
private
var
seqLen
=
0
private
var
xAtn
:
TensorBuffer
?
=
null
private
var
xAtn
:
TensorBuffer
?
=
null
private
val
decodedTokens
=
mutableListOf
(
model
.
tokenizer
.
decodeStartToken
)
private
val
decodedTokens
=
mutableListOf
(
model
.
tokenizer
.
decodeStartToken
)
private
suspend
fun
decodeStep
(
forceOption
:
Int
?
=
null
):
Int
{
private
fun
decodeStep
(
forceOption
:
Int
?
=
null
):
Int
{
if
(
xAtn
==
null
)
{
if
(
xAtn
==
null
)
{
throw
IllegalStateException
(
"melToFeatures must be called before starting decoding"
)
throw
IllegalStateException
(
"melToFeatures must be called before starting decoding"
)
}
}
seqLenArray
[
0
]
=
seqLen
.
toFloat
()
model
.
loadSeqLenInputId
(
seqLen
,
decodedTokens
.
last
())
inputIdsArray
[
0
]
=
decodedTokens
.
last
().
toFloat
()
model
.
seqLenTensor
.
loadArray
(
seqLenArray
)
val
decoderOutputs
=
model
.
runDecoder
(
xAtn
!!
,
model
.
cacheTensor
)
model
.
inputIdTensor
.
loadArray
(
inputIdsArray
)
val
decoderOutputs
=
model
.
runDecoder
(
xAtn
!!
,
model
.
seqLenTensor
,
model
.
cacheTensor
,
model
.
inputIdTensor
)
model
.
cacheTensor
.
loadBuffer
(
decoderOutputs
.
nextCache
.
buffer
.
duplicate
())
model
.
cacheTensor
.
loadBuffer
(
decoderOutputs
.
nextCache
.
buffer
.
duplicate
())
val
selectedToken
=
if
(
forceOption
!=
null
)
{
val
selectedToken
=
if
(
forceOption
!=
null
)
{
...
@@ -159,6 +157,7 @@ class WhisperModel(
...
@@ -159,6 +157,7 @@ class WhisperModel(
init
{
init
{
val
cpuOption
=
Model
.
Options
.
Builder
().
setDevice
(
Model
.
Device
.
CPU
).
build
()
val
cpuOption
=
Model
.
Options
.
Builder
().
setDevice
(
Model
.
Device
.
CPU
).
build
()
// NNAPI is disabled due to reported issues
val
(
encoder
,
decoder
)
=
loader
.
loadEncoderDecoder
(
context
,
cpuOption
)
val
(
encoder
,
decoder
)
=
loader
.
loadEncoderDecoder
(
context
,
cpuOption
)
...
@@ -192,29 +191,48 @@ class WhisperModel(
...
@@ -192,29 +191,48 @@ class WhisperModel(
this
.
bannedTokens
=
bannedTokens
this
.
bannedTokens
=
bannedTokens
}
}
// Must be called within inferenceContext
private
fun
runEncoderAndGetXatn
(
mel
:
FloatArray
):
TensorBuffer
{
private
fun
runEncoderAndGetXatn
(
mel
:
FloatArray
):
TensorBuffer
{
if
(
closed
)
if
(
closed
)
throw
IllegalStateException
(
"Cannot run session after model has been closed"
)
throw
IllegalStateException
(
"Cannot run session after model has been closed"
)
audioFeatures
.
loadArray
(
mel
)
audioFeatures
.
loadArray
(
mel
)
return
encoderModel
.
process
(
audioFeatures
).
crossAttention
return
encoderModel
.
process
(
audioFeatures
).
crossAttention
}
}
// Must be called within inferenceContext
private
fun
runDecoder
(
private
fun
runDecoder
(
xAtn
:
TensorBuffer
,
seqLen
:
TensorBuffer
,
cache
:
TensorBuffer
,
inputId
:
TensorBuffer
xAtn
:
TensorBuffer
,
cache
:
TensorBuffer
):
DecoderModel
.
Outputs
{
):
DecoderModel
.
Outputs
{
if
(
closed
)
if
(
closed
)
throw
IllegalStateException
(
"Cannot run session after model has been closed"
)
throw
IllegalStateException
(
"Cannot run session after model has been closed"
)
return
decoderModel
.
process
(
return
decoderModel
.
process
(
crossAttention
=
xAtn
,
seqLen
=
seqLen
,
cache
=
cache
,
inputIds
=
inputId
crossAttention
=
xAtn
,
seqLen
=
seqLen
Tensor
,
cache
=
cache
,
inputIds
=
inputId
Tensor
)
)
}
}
private
val
audioFeatures
=
// TODO: Ideally this should be shared between model instances as well.
TensorBuffer
.
createFixedSize
(
intArrayOf
(
1
,
80
,
3000
),
DataType
.
FLOAT32
)
private
val
seqLenTensor
=
TensorBuffer
.
createFixedSize
(
intArrayOf
(
1
),
DataType
.
FLOAT32
)
private
val
cacheTensor
=
private
val
cacheTensor
=
TensorBuffer
.
createFixedSize
(
decoderModel
.
getCacheTensorShape
(),
DataType
.
FLOAT32
)
TensorBuffer
.
createFixedSize
(
decoderModel
.
getCacheTensorShape
(),
DataType
.
FLOAT32
)
private
val
inputIdTensor
=
TensorBuffer
.
createFixedSize
(
intArrayOf
(
1
,
1
),
DataType
.
FLOAT32
)
companion
object
{
private
val
audioFeatures
=
TensorBuffer
.
createFixedSize
(
intArrayOf
(
1
,
80
,
3000
),
DataType
.
FLOAT32
)
private
val
seqLenTensor
=
TensorBuffer
.
createFixedSize
(
intArrayOf
(
1
),
DataType
.
FLOAT32
)
private
val
inputIdTensor
=
TensorBuffer
.
createFixedSize
(
intArrayOf
(
1
,
1
),
DataType
.
FLOAT32
)
private
val
seqLenArray
=
FloatArray
(
1
)
private
val
inputIdsArray
=
FloatArray
(
1
)
}
// Must be called within inferenceContext
private
fun
loadSeqLenInputId
(
seqLen
:
Int
,
inputId
:
Int
)
{
// TFLite has sketchy support for ints, so the model takes floats as input and casts them
// back to int internally
seqLenArray
[
0
]
=
seqLen
.
toFloat
()
inputIdsArray
[
0
]
=
inputId
.
toFloat
()
seqLenTensor
.
loadArray
(
seqLenArray
)
inputIdTensor
.
loadArray
(
inputIdsArray
)
}
init
{
init
{
val
shape
=
cacheTensor
.
shape
val
shape
=
cacheTensor
.
shape
...
@@ -223,8 +241,7 @@ class WhisperModel(
...
@@ -223,8 +241,7 @@ class WhisperModel(
}
}
fun
startInferenceSession
(
settings
:
DecodingConfiguration
):
ModelInferenceSession
{
fun
startInferenceSession
(
settings
:
DecodingConfiguration
):
ModelInferenceSession
{
if
(
closed
)
if
(
closed
)
throw
IllegalStateException
(
"Cannot start session after model has been closed"
)
throw
IllegalStateException
(
"Cannot start session after model has been closed"
)
updateBannedTokens
(
settings
)
updateBannedTokens
(
settings
)
return
InferenceSession
(
return
InferenceSession
(
...
@@ -233,7 +250,7 @@ class WhisperModel(
...
@@ -233,7 +250,7 @@ class WhisperModel(
}
}
suspend
fun
close
()
{
suspend
fun
close
()
{
if
(
closed
)
return
if
(
closed
)
return
closed
=
true
closed
=
true
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment