Skip to content
Snippets Groups Projects
Commit cd3d5a28 authored by Aleksandras Kostarevas's avatar Aleksandras Kostarevas
Browse files

Automatically schedule training

parent cb2edca6
No related branches found
No related tags found
1 merge request!7Merge model-metadata to master
...@@ -23,6 +23,9 @@ import org.futo.inputmethod.latin.xlm.TrainingState ...@@ -23,6 +23,9 @@ import org.futo.inputmethod.latin.xlm.TrainingState
import org.futo.inputmethod.latin.xlm.TrainingWorker import org.futo.inputmethod.latin.xlm.TrainingWorker
import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
import org.futo.inputmethod.latin.uix.getSettingFlow
import java.util.concurrent.TimeUnit import java.util.concurrent.TimeUnit
import kotlin.math.roundToInt import kotlin.math.roundToInt
...@@ -45,22 +48,24 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) { ...@@ -45,22 +48,24 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
trainingDataAmount = data.size trainingDataAmount = data.size
} }
val numTrains = context.getSettingFlow(NUM_TRAINING_RUNS_KEY, 0).collectAsState(initial = 0)
ScrollableList { ScrollableList {
ScreenTitle("Training", showBack = true, navController) ScreenTitle("Training", showBack = true, navController)
Text("There are $trainingDataAmount pending training examples.") Text("The model has been trained ${numTrains.value} times in total.")
Button(onClick = { Text("There are $trainingDataAmount pending training examples (minimum for training is 100)")
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
.build()
WorkManager.getInstance(context).enqueue(workRequest) Button(onClick = {
}, enabled = !TrainingWorkerStatus.isTraining.value) { scheduleTrainingWorkerImmediately(context)
}, enabled = (!TrainingWorkerStatus.isTraining.value) && (trainingDataAmount >= 100)) {
if(TrainingWorkerStatus.isTraining.value) { if(TrainingWorkerStatus.isTraining.value) {
Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})") Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
} else { } else if(trainingDataAmount > 100) {
Text("Train model") Text("Train model")
} else {
Text("Train model (not enough data)")
} }
} }
......
...@@ -58,6 +58,7 @@ import androidx.savedstate.SavedStateRegistryController ...@@ -58,6 +58,7 @@ import androidx.savedstate.SavedStateRegistryController
import androidx.savedstate.SavedStateRegistryOwner import androidx.savedstate.SavedStateRegistryOwner
import androidx.savedstate.findViewTreeSavedStateRegistryOwner import androidx.savedstate.findViewTreeSavedStateRegistryOwner
import androidx.savedstate.setViewTreeSavedStateRegistryOwner import androidx.savedstate.setViewTreeSavedStateRegistryOwner
import androidx.work.WorkManager
import kotlinx.coroutines.Job import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableSharedFlow
...@@ -237,12 +238,16 @@ public class LanguageModelFacilitator( ...@@ -237,12 +238,16 @@ public class LanguageModelFacilitator(
} }
} }
withContext(Dispatchers.Default) { launch {
sharedFlow.conflate().collect { value -> withContext(Dispatchers.Default) {
println("LatinIME: Collecting") sharedFlow.conflate().collect { value ->
processUpdateSuggestionStrip(value) println("LatinIME: Collecting")
processUpdateSuggestionStrip(value)
}
} }
} }
scheduleTrainingWorkerBackground(context)
} }
public fun updateSuggestionStripAsync(inputStyle: Int) { public fun updateSuggestionStripAsync(inputStyle: Int) {
......
...@@ -13,14 +13,23 @@ import androidx.work.CoroutineWorker ...@@ -13,14 +13,23 @@ import androidx.work.CoroutineWorker
import androidx.work.ForegroundInfo import androidx.work.ForegroundInfo
import androidx.work.WorkManager import androidx.work.WorkManager
import androidx.work.WorkerParameters import androidx.work.WorkerParameters
import androidx.work.Constraints
import androidx.work.PeriodicWorkRequest
import androidx.work.OneTimeWorkRequestBuilder
import androidx.datastore.preferences.core.intPreferencesKey
import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.flow.MutableSharedFlow import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.withContext import kotlinx.coroutines.withContext
import org.futo.inputmethod.latin.R import org.futo.inputmethod.latin.R
import org.futo.inputmethod.latin.uix.setSetting
import org.futo.inputmethod.latin.uix.getSetting
import java.io.File import java.io.File
import java.io.FileOutputStream import java.io.FileOutputStream
import java.io.IOException import java.io.IOException
import java.io.OutputStream import java.io.OutputStream
import java.util.concurrent.TimeUnit
val NUM_TRAINING_RUNS_KEY = intPreferencesKey("training_runs_count")
const val CHANNEL_ID = "TRAINING" const val CHANNEL_ID = "TRAINING"
const val NOTIFICATION_ID = 1 const val NOTIFICATION_ID = 1
...@@ -52,12 +61,14 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine ...@@ -52,12 +61,14 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
NotificationManager NotificationManager
override suspend fun doWork(): Result { override suspend fun doWork(): Result {
println("TrainingWorker is starting")
TrainingWorkerStatus.state.emit(TrainingState.Starting) TrainingWorkerStatus.state.emit(TrainingState.Starting)
TrainingWorkerStatus.isTraining.value = true TrainingWorkerStatus.isTraining.value = true
setForeground(createForegroundInfo("Training...")) setForeground(createForegroundInfo("Training..."))
TrainingWorkerStatus.state.emit(train()) TrainingWorkerStatus.state.emit(train())
TrainingWorkerStatus.isTraining.value = false TrainingWorkerStatus.isTraining.value = false
println("TrainingWorker has ended")
return Result.success() return Result.success()
} }
...@@ -65,6 +76,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine ...@@ -65,6 +76,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
val data = mutableListOf<HistoryLogForTraining>() val data = mutableListOf<HistoryLogForTraining>()
loadHistoryLogBackup(applicationContext, data) loadHistoryLogBackup(applicationContext, data)
if(data.size < 100) {
return ""
}
return data.map { entry -> return data.map { entry ->
if(entry.misspelledWord != null) { if(entry.misspelledWord != null) {
if(entry.importance == 3) { if(entry.importance == 3) {
...@@ -118,6 +133,11 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine ...@@ -118,6 +133,11 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
} }
private suspend fun train(): TrainingState { private suspend fun train(): TrainingState {
val data = getTrainingData()
if(data.isEmpty()) {
return TrainingState.ErrorInadequateData
}
val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin") val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
val builder = AdapterTrainerBuilder( val builder = AdapterTrainerBuilder(
...@@ -132,7 +152,6 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine ...@@ -132,7 +152,6 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
builder.setWeight(0.75f) builder.setWeight(0.75f)
val data = getTrainingData()
builder.addExamples(data.lines()) builder.addExamples(data.lines())
val trainer = try { val trainer = try {
...@@ -146,14 +165,22 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine ...@@ -146,14 +165,22 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
withContext(Dispatchers.Default) { withContext(Dispatchers.Default) {
println("Staring to train") println("Staring to train")
wakeLock.acquire(120*60*1000L /*1 hour*/) wakeLock.acquire(120*60*1000L /*1 hour*/)
trainer.train() try {
wakeLock.release() trainer.train()
} finally {
wakeLock.release()
}
println("Finished training") println("Finished training")
} }
// In case there's no one to receive ClearTrainingLog, save an empty log
saveHistoryLogBackup(applicationContext, listOf())
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel) TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog) TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1)
return TrainingState.Finished return TrainingState.Finished
} }
// Creates an instance of ForegroundInfo which can be used to update the // Creates an instance of ForegroundInfo which can be used to update the
...@@ -194,4 +221,35 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine ...@@ -194,4 +221,35 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
notificationManager.createNotificationChannel(channel) notificationManager.createNotificationChannel(channel)
} }
}
private val WORKER_TAG: String = "TRAINING_WORKER"
public fun scheduleTrainingWorkerBackground(context: Context) {
val workManager = WorkManager.getInstance(context)
workManager.cancelAllWorkByTag(WORKER_TAG)
val constraints = Constraints.Builder()
.setRequiresBatteryNotLow(true)
.setRequiresCharging(true)
.setRequiresDeviceIdle(true)
.build()
val request = PeriodicWorkRequest.Builder(
TrainingWorker::class.java,
20L, TimeUnit.HOURS,
// 12L, TimeUnit.HOURS
).addTag(WORKER_TAG).setConstraints(constraints).build()
workManager.enqueue(request)
}
public fun scheduleTrainingWorkerImmediately(context: Context) {
val workManager = WorkManager.getInstance(context)
val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
.setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
.addTag(WORKER_TAG)
.build()
workManager.enqueue(workRequest)
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment