From c3c3f43fa7cd7cf844ae8d80d7c5a355dfc4dc3e Mon Sep 17 00:00:00 2001
From: Taras <tarassmakula@gmail.com>
Date: Thu, 17 Mar 2022 16:52:23 +0200
Subject: [PATCH] Change sdk to handle other auth types

---
 .../sdk/api/auth/data/LoginFlowTypes.kt       |   2 +-
 .../registration/RegistrationFlowResponse.kt  |   1 -
 .../auth/registration/RegistrationWizard.kt   |  15 ++-
 .../sdk/api/auth/registration/Stage.kt        |   3 -
 .../android/sdk/internal/auth/AuthAPI.kt      |  12 +-
 .../internal/auth/registration/AuthParams.kt  |  14 +-
 .../registration/DefaultRegistrationWizard.kt | 123 +++++++++++++-----
 .../auth/registration/RegisterOtherTask.kt    |  31 +++++
 .../registration/RegistrationOtherParams.kt   |  15 +++
 9 files changed, 157 insertions(+), 59 deletions(-)
 create mode 100644 matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/RegisterOtherTask.kt
 create mode 100644 matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/RegistrationOtherParams.kt

diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/data/LoginFlowTypes.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/data/LoginFlowTypes.kt
index ff92dab80..f2cebff19 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/data/LoginFlowTypes.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/data/LoginFlowTypes.kt
@@ -26,6 +26,6 @@ object LoginFlowTypes {
     const val RECAPTCHA = "m.login.recaptcha"
     const val DUMMY = "m.login.dummy"
     const val TERMS = "m.login.terms"
-    const val TOKEN = "m.login.registration_token"
+    const val TOKEN = "m.login.token"
     const val SSO = "m.login.sso"
 }
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/RegistrationFlowResponse.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/RegistrationFlowResponse.kt
index f7a46cbea..978b08b63 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/RegistrationFlowResponse.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/RegistrationFlowResponse.kt
@@ -94,7 +94,6 @@ fun RegistrationFlowResponse.toFlowResult(): FlowResult {
             LoginFlowTypes.TERMS          -> Stage.Terms(isMandatory, params?.get(type) as? TermPolicies ?: emptyMap<String, String>())
             LoginFlowTypes.EMAIL_IDENTITY -> Stage.Email(isMandatory)
             LoginFlowTypes.MSISDN         -> Stage.Msisdn(isMandatory)
-            LoginFlowTypes.TOKEN          -> Stage.Token(isMandatory)
             else                          -> Stage.Other(isMandatory, type, (params?.get(type) as? Map<*, *>))
         }
 
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/RegistrationWizard.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/RegistrationWizard.kt
index fb796b4ff..940a90bb0 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/RegistrationWizard.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/RegistrationWizard.kt
@@ -16,6 +16,9 @@
 
 package org.matrix.android.sdk.api.auth.registration
 
+import org.matrix.android.sdk.api.util.JsonDict
+import org.matrix.android.sdk.internal.auth.registration.AuthParams
+
 /**
  * Set of methods to be able to create an account on a homeserver.
  *
@@ -52,9 +55,11 @@ interface RegistrationWizard {
      * @param password the desired password
      * @param initialDeviceDisplayName the device display name
      */
-    suspend fun createAccount(userName: String?,
-                              password: String?,
-                              initialDeviceDisplayName: String?): RegistrationResult
+    suspend fun createAccount(
+        userName: String?,
+        password: String?,
+        initialDeviceDisplayName: String?
+    ): RegistrationResult
 
     /**
      * Perform the "m.login.recaptcha" stage.
@@ -74,9 +79,9 @@ interface RegistrationWizard {
     suspend fun dummy(): RegistrationResult
 
     /**
-     * Perform the "m.login.registration_token" stage.
+     * Perform the other stage.
      */
-    suspend fun registrationToken(token: String): RegistrationResult
+    suspend fun registrationOther(authParams: JsonDict): RegistrationResult
 
     /**
      * Perform the "m.login.email.identity" or "m.login.msisdn" stage.
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/Stage.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/Stage.kt
index dc258154c..c21b667cf 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/Stage.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/api/auth/registration/Stage.kt
@@ -34,9 +34,6 @@ sealed class Stage(open val mandatory: Boolean) {
     // Undocumented yet: m.login.terms
     data class Terms(override val mandatory: Boolean, val policies: TermPolicies) : Stage(mandatory)
 
-    // m.login.registration_token
-    data class Token(override val mandatory: Boolean) : Stage(mandatory)
-
     // For unknown stages
     data class Other(override val mandatory: Boolean, val type: String, val params: Map<*, *>?) : Stage(mandatory)
 }
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/AuthAPI.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/AuthAPI.kt
index 554e21ce5..bd1799be2 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/AuthAPI.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/AuthAPI.kt
@@ -24,11 +24,11 @@ import org.matrix.android.sdk.internal.auth.data.PasswordLoginParams
 import org.matrix.android.sdk.internal.auth.data.TokenLoginParams
 import org.matrix.android.sdk.internal.auth.data.WebClientConfig
 import org.matrix.android.sdk.internal.auth.login.ResetPasswordMailConfirmed
+import org.matrix.android.sdk.internal.auth.registration.*
 import org.matrix.android.sdk.internal.auth.registration.AddThreePidRegistrationParams
 import org.matrix.android.sdk.internal.auth.registration.AddThreePidRegistrationResponse
+import org.matrix.android.sdk.internal.auth.registration.RegistrationOtherParams
 import org.matrix.android.sdk.internal.auth.registration.RegistrationParams
-import org.matrix.android.sdk.internal.auth.registration.SuccessResult
-import org.matrix.android.sdk.internal.auth.registration.ValidationCodeBody
 import org.matrix.android.sdk.internal.auth.version.Versions
 import org.matrix.android.sdk.internal.network.NetworkConstants
 import retrofit2.http.Body
@@ -68,6 +68,14 @@ internal interface AuthAPI {
     @POST(NetworkConstants.URI_API_PREFIX_PATH_R0 + "register")
     suspend fun register(@Body registrationParams: RegistrationParams): Credentials
 
+    /**
+     * Register to the homeserver, or get error 401 with a RegistrationFlowResponse object if registration is incomplete
+     * method to perform other custom stages
+     * Ref: https://matrix.org/docs/spec/client_server/latest#account-registration-and-management
+     */
+    @POST(NetworkConstants.URI_API_PREFIX_PATH_R0 + "register")
+    suspend fun registerOther(@Body registrationOtherParams: RegistrationOtherParams): Credentials
+
     /**
      * Checks to see if a username is available, and valid, for the server.
      */
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/AuthParams.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/AuthParams.kt
index f2c504bd6..a33746c27 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/AuthParams.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/AuthParams.kt
@@ -45,12 +45,7 @@ internal data class AuthParams(
      */
     @Json(name = "threepid_creds")
     val threePidCredentials: ThreePidCredentials? = null,
-
-    /**
-     * parameter for "m.login.registration_token" type
-     */
-    @Json(name = "token")
-    val token: String? = null
+    
 ) {
 
     companion object {
@@ -99,13 +94,6 @@ internal data class AuthParams(
             )
         }
 
-        fun createForRegistrationToken(session: String, token: String): AuthParams {
-            return AuthParams(
-                type = LoginFlowTypes.TOKEN,
-                session = session,
-                token = token
-            )
-        }
     }
 }
 
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/DefaultRegistrationWizard.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/DefaultRegistrationWizard.kt
index 89c640e77..b29ec98c4 100644
--- a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/DefaultRegistrationWizard.kt
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/DefaultRegistrationWizard.kt
@@ -17,14 +17,12 @@
 package org.matrix.android.sdk.internal.auth.registration
 
 import kotlinx.coroutines.delay
+import org.matrix.android.sdk.api.auth.data.Credentials
 import org.matrix.android.sdk.api.auth.data.LoginFlowTypes
-import org.matrix.android.sdk.api.auth.registration.RegisterThreePid
-import org.matrix.android.sdk.api.auth.registration.RegistrationAvailability
-import org.matrix.android.sdk.api.auth.registration.RegistrationResult
-import org.matrix.android.sdk.api.auth.registration.RegistrationWizard
-import org.matrix.android.sdk.api.auth.registration.toFlowResult
+import org.matrix.android.sdk.api.auth.registration.*
 import org.matrix.android.sdk.api.failure.Failure
 import org.matrix.android.sdk.api.failure.Failure.RegistrationFlowError
+import org.matrix.android.sdk.api.util.JsonDict
 import org.matrix.android.sdk.internal.auth.AuthAPI
 import org.matrix.android.sdk.internal.auth.PendingSessionStore
 import org.matrix.android.sdk.internal.auth.SessionCreator
@@ -39,22 +37,26 @@ internal class DefaultRegistrationWizard(
     private val pendingSessionStore: PendingSessionStore
 ) : RegistrationWizard {
 
-    private var pendingSessionData: PendingSessionData = pendingSessionStore.getPendingSessionData() ?: error("Pending session data should exist here")
+    private var pendingSessionData: PendingSessionData = pendingSessionStore.getPendingSessionData()
+        ?: error("Pending session data should exist here")
 
     private val registerTask: RegisterTask = DefaultRegisterTask(authAPI)
     private val registerAvailableTask: RegisterAvailableTask = DefaultRegisterAvailableTask(authAPI)
-    private val registerAddThreePidTask: RegisterAddThreePidTask = DefaultRegisterAddThreePidTask(authAPI)
+    private val registerAddThreePidTask: RegisterAddThreePidTask =
+        DefaultRegisterAddThreePidTask(authAPI)
+    private val registerOtherTask: RegisterOtherTask = DefaultRegisterOtherTask(authAPI)
     private val validateCodeTask: ValidateCodeTask = DefaultValidateCodeTask(authAPI)
 
     override val currentThreePid: String?
         get() {
             return when (val threePid = pendingSessionData.currentThreePidData?.threePid) {
-                is RegisterThreePid.Email  -> threePid.email
+                is RegisterThreePid.Email -> threePid.email
                 is RegisterThreePid.Msisdn -> {
                     // Take formatted msisdn if provided by the server
-                    pendingSessionData.currentThreePidData?.addThreePidRegistrationResponse?.formattedMsisdn?.takeIf { it.isNotBlank() } ?: threePid.msisdn
+                    pendingSessionData.currentThreePidData?.addThreePidRegistrationResponse?.formattedMsisdn?.takeIf { it.isNotBlank() }
+                        ?: threePid.msisdn
                 }
-                null                       -> null
+                null -> null
             }
         }
 
@@ -69,9 +71,11 @@ internal class DefaultRegistrationWizard(
         return performRegistrationRequest(params)
     }
 
-    override suspend fun createAccount(userName: String?,
-                                       password: String?,
-                                       initialDeviceDisplayName: String?): RegistrationResult {
+    override suspend fun createAccount(
+        userName: String?,
+        password: String?,
+        initialDeviceDisplayName: String?
+    ): RegistrationResult {
         val params = RegistrationParams(
             username = userName,
             password = password,
@@ -96,7 +100,12 @@ internal class DefaultRegistrationWizard(
         val safeSession = pendingSessionData.currentSession
             ?: throw IllegalStateException("developer error, call createAccount() method first")
 
-        val params = RegistrationParams(auth = AuthParams(type = LoginFlowTypes.TERMS, session = safeSession))
+        val params = RegistrationParams(
+            auth = AuthParams(
+                type = LoginFlowTypes.TERMS,
+                session = safeSession
+            )
+        )
         return performRegistrationRequest(params)
     }
 
@@ -115,26 +124,32 @@ internal class DefaultRegistrationWizard(
     }
 
     private suspend fun sendThreePid(threePid: RegisterThreePid): RegistrationResult {
-        val safeSession = pendingSessionData.currentSession ?: throw IllegalStateException("developer error, call createAccount() method first")
+        val safeSession = pendingSessionData.currentSession
+            ?: throw IllegalStateException("developer error, call createAccount() method first")
         val response = registerAddThreePidTask.execute(
             RegisterAddThreePidTask.Params(
                 threePid,
                 pendingSessionData.clientSecret,
-                pendingSessionData.sendAttempt))
+                pendingSessionData.sendAttempt
+            )
+        )
 
-        pendingSessionData = pendingSessionData.copy(sendAttempt = pendingSessionData.sendAttempt + 1)
-            .also { pendingSessionStore.savePendingSessionData(it) }
+        pendingSessionData =
+            pendingSessionData.copy(sendAttempt = pendingSessionData.sendAttempt + 1)
+                .also { pendingSessionStore.savePendingSessionData(it) }
 
         val params = RegistrationParams(
             auth = if (threePid is RegisterThreePid.Email) {
-                AuthParams.createForEmailIdentity(safeSession,
+                AuthParams.createForEmailIdentity(
+                    safeSession,
                     ThreePidCredentials(
                         clientSecret = pendingSessionData.clientSecret,
                         sid = response.sid
                     )
                 )
             } else {
-                AuthParams.createForMsisdnIdentity(safeSession,
+                AuthParams.createForMsisdnIdentity(
+                    safeSession,
                     ThreePidCredentials(
                         clientSecret = pendingSessionData.clientSecret,
                         sid = response.sid
@@ -143,7 +158,13 @@ internal class DefaultRegistrationWizard(
             }
         )
         // Store data
-        pendingSessionData = pendingSessionData.copy(currentThreePidData = ThreePidData.from(threePid, response, params))
+        pendingSessionData = pendingSessionData.copy(
+            currentThreePidData = ThreePidData.from(
+                threePid,
+                response,
+                params
+            )
+        )
             .also { pendingSessionStore.savePendingSessionData(it) }
 
         // and send the sid a first time
@@ -164,14 +185,18 @@ internal class DefaultRegistrationWizard(
     private suspend fun validateThreePid(code: String): RegistrationResult {
         val registrationParams = pendingSessionData.currentThreePidData?.registrationParams
             ?: throw IllegalStateException("developer error, no pending three pid")
-        val safeCurrentData = pendingSessionData.currentThreePidData ?: throw IllegalStateException("developer error, call createAccount() method first")
-        val url = safeCurrentData.addThreePidRegistrationResponse.submitUrl ?: throw IllegalStateException("Missing url to send the code")
+        val safeCurrentData = pendingSessionData.currentThreePidData ?: throw IllegalStateException(
+            "developer error, call createAccount() method first"
+        )
+        val url = safeCurrentData.addThreePidRegistrationResponse.submitUrl
+            ?: throw IllegalStateException("Missing url to send the code")
         val validationBody = ValidationCodeBody(
             clientSecret = pendingSessionData.clientSecret,
             sid = safeCurrentData.addThreePidRegistrationResponse.sid,
             code = code
         )
-        val validationResponse = validateCodeTask.execute(ValidateCodeTask.Params(url, validationBody))
+        val validationResponse =
+            validateCodeTask.execute(ValidateCodeTask.Params(url, validationBody))
         if (validationResponse.isSuccess()) {
             // The entered code is correct
             // Same than validate email
@@ -186,37 +211,67 @@ internal class DefaultRegistrationWizard(
         val safeSession = pendingSessionData.currentSession
             ?: throw IllegalStateException("developer error, call createAccount() method first")
 
-        val params = RegistrationParams(auth = AuthParams(type = LoginFlowTypes.DUMMY, session = safeSession))
+        val params = RegistrationParams(
+            auth = AuthParams(
+                type = LoginFlowTypes.DUMMY,
+                session = safeSession
+            )
+        )
         return performRegistrationRequest(params)
     }
 
-    override suspend fun registrationToken(token: String): RegistrationResult {
+    override suspend fun registrationOther(
+        authParams: JsonDict
+    ): RegistrationResult {
         val safeSession = pendingSessionData.currentSession
             ?: throw IllegalStateException("developer error, call createAccount() method first")
 
-        val params = RegistrationParams(auth = AuthParams.createForRegistrationToken(safeSession, token))
-        return performRegistrationRequest(params)
+        val mutableParams = authParams.toMutableMap()
+        mutableParams["session"] = safeSession
+
+        val params =
+            RegistrationOtherParams(
+                auth = mutableParams
+            )
+        return performRegistrationOtherRequest(params)
     }
 
-    private suspend fun performRegistrationRequest(registrationParams: RegistrationParams,
-                                                   delayMillis: Long = 0): RegistrationResult {
+    private suspend fun performRegistrationRequest(
+        registrationParams: RegistrationParams,
+        delayMillis: Long = 0
+    ): RegistrationResult {
         delay(delayMillis)
+        return register { registerTask.execute(RegisterTask.Params(registrationParams)) }
+    }
+
+    private suspend fun performRegistrationOtherRequest(
+        registrationOtherParams: RegistrationOtherParams
+    ): RegistrationResult {
+        return register { registerOtherTask.execute(RegisterOtherTask.Params(registrationOtherParams)) }
+    }
+
+    private suspend fun register(
+        execute: suspend () -> Credentials
+    ): RegistrationResult {
         val credentials = try {
-            registerTask.execute(RegisterTask.Params(registrationParams))
+            execute.invoke()
         } catch (exception: Throwable) {
             if (exception is RegistrationFlowError) {
-                pendingSessionData = pendingSessionData.copy(currentSession = exception.registrationFlowResponse.session)
-                    .also { pendingSessionStore.savePendingSessionData(it) }
+                pendingSessionData =
+                    pendingSessionData.copy(currentSession = exception.registrationFlowResponse.session)
+                        .also { pendingSessionStore.savePendingSessionData(it) }
                 return RegistrationResult.FlowResponse(exception.registrationFlowResponse.toFlowResult())
             } else {
                 throw exception
             }
         }
 
-        val session = sessionCreator.createSession(credentials, pendingSessionData.homeServerConnectionConfig)
+        val session =
+            sessionCreator.createSession(credentials, pendingSessionData.homeServerConnectionConfig)
         return RegistrationResult.Success(session)
     }
 
+
     override suspend fun registrationAvailable(userName: String): RegistrationAvailability {
         return registerAvailableTask.execute(RegisterAvailableTask.Params(userName))
     }
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/RegisterOtherTask.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/RegisterOtherTask.kt
new file mode 100644
index 000000000..bb2062115
--- /dev/null
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/RegisterOtherTask.kt
@@ -0,0 +1,31 @@
+package org.matrix.android.sdk.internal.auth.registration
+
+import org.matrix.android.sdk.api.auth.data.Credentials
+import org.matrix.android.sdk.api.failure.Failure
+import org.matrix.android.sdk.api.failure.toRegistrationFlowResponse
+import org.matrix.android.sdk.internal.auth.AuthAPI
+import org.matrix.android.sdk.internal.network.executeRequest
+import org.matrix.android.sdk.internal.task.Task
+
+internal interface RegisterOtherTask : Task<RegisterOtherTask.Params, Credentials> {
+    data class Params(
+        val registrationOtherParams: RegistrationOtherParams
+    )
+}
+
+internal class DefaultRegisterOtherTask(
+    private val authAPI: AuthAPI
+) : RegisterOtherTask {
+
+    override suspend fun execute(params: RegisterOtherTask.Params): Credentials {
+        try {
+            return executeRequest(null) {
+                authAPI.registerOther(params.registrationOtherParams)
+            }
+        } catch (throwable: Throwable) {
+            throw throwable.toRegistrationFlowResponse()
+                ?.let { Failure.RegistrationFlowError(it) }
+                ?: throwable
+        }
+    }
+}
\ No newline at end of file
diff --git a/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/RegistrationOtherParams.kt b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/RegistrationOtherParams.kt
new file mode 100644
index 000000000..2b7a4104e
--- /dev/null
+++ b/matrix-sdk-android/src/main/java/org/matrix/android/sdk/internal/auth/registration/RegistrationOtherParams.kt
@@ -0,0 +1,15 @@
+package org.matrix.android.sdk.internal.auth.registration
+
+import com.squareup.moshi.Json
+import com.squareup.moshi.JsonClass
+import org.matrix.android.sdk.api.util.JsonDict
+
+/**
+ * Class to pass parameters to the custom registration types for /register.
+ */
+@JsonClass(generateAdapter = true)
+internal data class RegistrationOtherParams(
+    // authentication parameters
+    @Json(name = "auth")
+    val auth: JsonDict? = null,
+)
\ No newline at end of file
-- 
GitLab