From cd3d5a284f403e183ba0a8585cb3337972bd22ff Mon Sep 17 00:00:00 2001
From: Aleksandras Kostarevas <alex@futo.org>
Date: Tue, 21 Nov 2023 20:26:23 +0200
Subject: [PATCH] Automatically schedule training

---
 .../latin/uix/settings/pages/TrainDev.kt      | 21 +++---
 .../latin/xlm/LanguageModelFacilitator.kt     | 13 ++--
 .../inputmethod/latin/xlm/TrainingWorker.kt   | 64 ++++++++++++++++++-
 3 files changed, 83 insertions(+), 15 deletions(-)

diff --git a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt
index b6c97abec7..625a8dab58 100644
--- a/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt
+++ b/java/src/org/futo/inputmethod/latin/uix/settings/pages/TrainDev.kt
@@ -23,6 +23,9 @@ import org.futo.inputmethod.latin.xlm.TrainingState
 import org.futo.inputmethod.latin.xlm.TrainingWorker
 import org.futo.inputmethod.latin.xlm.TrainingWorkerStatus
 import org.futo.inputmethod.latin.xlm.loadHistoryLogBackup
+import org.futo.inputmethod.latin.xlm.scheduleTrainingWorkerImmediately
+import org.futo.inputmethod.latin.xlm.NUM_TRAINING_RUNS_KEY
+import org.futo.inputmethod.latin.uix.getSettingFlow
 import java.util.concurrent.TimeUnit
 import kotlin.math.roundToInt
 
@@ -45,22 +48,24 @@ fun TrainDevScreen(navController: NavHostController = rememberNavController()) {
         trainingDataAmount = data.size
     }
 
+    val numTrains = context.getSettingFlow(NUM_TRAINING_RUNS_KEY, 0).collectAsState(initial = 0)
+
     ScrollableList {
         ScreenTitle("Training", showBack = true, navController)
 
-        Text("There are $trainingDataAmount pending training examples.")
+        Text("The model has been trained ${numTrains.value} times in total.")
 
-        Button(onClick = {
-            val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
-                .setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
-                .build()
+        Text("There are $trainingDataAmount pending training examples (minimum for training is 100)")
 
-            WorkManager.getInstance(context).enqueue(workRequest)
-        }, enabled = !TrainingWorkerStatus.isTraining.value) {
+        Button(onClick = {
+            scheduleTrainingWorkerImmediately(context)
+        }, enabled = (!TrainingWorkerStatus.isTraining.value) && (trainingDataAmount >= 100)) {
             if(TrainingWorkerStatus.isTraining.value) {
                 Text("Currently training (${(progress.value * 100.0f).roundToInt()}%, loss ${loss.value})")
-            } else {
+            } else if(trainingDataAmount > 100) {
                 Text("Train model")
+            } else {
+                Text("Train model (not enough data)")
             }
         }
 
diff --git a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt
index 6cd8879caa..581b3fb14f 100644
--- a/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt
+++ b/java/src/org/futo/inputmethod/latin/xlm/LanguageModelFacilitator.kt
@@ -58,6 +58,7 @@ import androidx.savedstate.SavedStateRegistryController
 import androidx.savedstate.SavedStateRegistryOwner
 import androidx.savedstate.findViewTreeSavedStateRegistryOwner
 import androidx.savedstate.setViewTreeSavedStateRegistryOwner
+import androidx.work.WorkManager
 import kotlinx.coroutines.Job
 import kotlinx.coroutines.flow.Flow
 import kotlinx.coroutines.flow.MutableSharedFlow
@@ -237,12 +238,16 @@ public class LanguageModelFacilitator(
             }
         }
 
-        withContext(Dispatchers.Default) {
-            sharedFlow.conflate().collect { value ->
-                println("LatinIME: Collecting")
-                processUpdateSuggestionStrip(value)
+        launch {
+            withContext(Dispatchers.Default) {
+                sharedFlow.conflate().collect { value ->
+                    println("LatinIME: Collecting")
+                    processUpdateSuggestionStrip(value)
+                }
             }
         }
+
+        scheduleTrainingWorkerBackground(context)
     }
 
     public fun updateSuggestionStripAsync(inputStyle: Int) {
diff --git a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt
index 00e962499a..2528b695ef 100644
--- a/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt
+++ b/java/src/org/futo/inputmethod/latin/xlm/TrainingWorker.kt
@@ -13,14 +13,23 @@ import androidx.work.CoroutineWorker
 import androidx.work.ForegroundInfo
 import androidx.work.WorkManager
 import androidx.work.WorkerParameters
+import androidx.work.Constraints
+import androidx.work.PeriodicWorkRequest
+import androidx.work.OneTimeWorkRequestBuilder
+import androidx.datastore.preferences.core.intPreferencesKey
 import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.flow.MutableSharedFlow
 import kotlinx.coroutines.withContext
 import org.futo.inputmethod.latin.R
+import org.futo.inputmethod.latin.uix.setSetting
+import org.futo.inputmethod.latin.uix.getSetting
 import java.io.File
 import java.io.FileOutputStream
 import java.io.IOException
 import java.io.OutputStream
+import java.util.concurrent.TimeUnit
+
+val NUM_TRAINING_RUNS_KEY = intPreferencesKey("training_runs_count")
 
 const val CHANNEL_ID = "TRAINING"
 const val NOTIFICATION_ID = 1
@@ -52,12 +61,14 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
                 NotificationManager
 
     override suspend fun doWork(): Result {
+        println("TrainingWorker is starting")
         TrainingWorkerStatus.state.emit(TrainingState.Starting)
         TrainingWorkerStatus.isTraining.value = true
         setForeground(createForegroundInfo("Training..."))
 
         TrainingWorkerStatus.state.emit(train())
         TrainingWorkerStatus.isTraining.value = false
+        println("TrainingWorker has ended")
         return Result.success()
     }
 
@@ -65,6 +76,10 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
         val data = mutableListOf<HistoryLogForTraining>()
         loadHistoryLogBackup(applicationContext, data)
 
+        if(data.size < 100) {
+            return ""
+        }
+
         return data.map { entry ->
             if(entry.misspelledWord != null) {
                 if(entry.importance == 3) {
@@ -118,6 +133,11 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
     }
 
     private suspend fun train(): TrainingState {
+        val data = getTrainingData()
+        if(data.isEmpty()) {
+            return TrainingState.ErrorInadequateData
+        }
+
         val cacheLoraPath = File(applicationContext.cacheDir, "adapter.bin")
 
         val builder = AdapterTrainerBuilder(
@@ -132,7 +152,6 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
 
         builder.setWeight(0.75f)
 
-        val data = getTrainingData()
         builder.addExamples(data.lines())
 
         val trainer = try {
@@ -146,14 +165,22 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
         withContext(Dispatchers.Default) {
             println("Staring to train")
             wakeLock.acquire(120*60*1000L /*1 hour*/)
-            trainer.train()
-            wakeLock.release()
+            try {
+                trainer.train()
+            } finally {
+                wakeLock.release()
+            }
             println("Finished training")
         }
 
+        // In case there's no one to receive ClearTrainingLog, save an empty log
+        saveHistoryLogBackup(applicationContext, listOf())
+
         TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ResetModel)
         TrainingWorkerStatus.lmRequest.emit(LanguageModelFacilitatorRequest.ClearTrainingLog)
 
+        applicationContext.setSetting(NUM_TRAINING_RUNS_KEY, applicationContext.getSetting(NUM_TRAINING_RUNS_KEY, 0) + 1)
+
         return TrainingState.Finished
     }
     // Creates an instance of ForegroundInfo which can be used to update the
@@ -194,4 +221,35 @@ class TrainingWorker(context: Context, parameters: WorkerParameters) : Coroutine
 
         notificationManager.createNotificationChannel(channel)
     }
+}
+
+private val WORKER_TAG: String = "TRAINING_WORKER"
+public fun scheduleTrainingWorkerBackground(context: Context) {
+    val workManager = WorkManager.getInstance(context)
+    workManager.cancelAllWorkByTag(WORKER_TAG)
+
+    val constraints = Constraints.Builder()
+        .setRequiresBatteryNotLow(true)
+        .setRequiresCharging(true)
+        .setRequiresDeviceIdle(true)
+        .build()
+    
+    val request = PeriodicWorkRequest.Builder(
+        TrainingWorker::class.java,
+        20L, TimeUnit.HOURS,
+        // 12L, TimeUnit.HOURS
+    ).addTag(WORKER_TAG).setConstraints(constraints).build()
+
+    workManager.enqueue(request)
+}
+
+public fun scheduleTrainingWorkerImmediately(context: Context) {
+    val workManager = WorkManager.getInstance(context)
+
+    val workRequest = OneTimeWorkRequestBuilder<TrainingWorker>()
+        .setInitialDelay(0, TimeUnit.SECONDS) // Run immediately
+        .addTag(WORKER_TAG)
+        .build()
+
+    workManager.enqueue(workRequest)
 }
\ No newline at end of file
-- 
GitLab