diff --git a/native/jni/src/proximity_info_state.cpp b/native/jni/src/proximity_info_state.cpp
index 0f7e4d65f535dc4ab7a6e375b00383401fd6ec19..bbc0deedec98c84d5e38b972d26568d01460606a 100644
--- a/native/jni/src/proximity_info_state.cpp
+++ b/native/jni/src/proximity_info_state.cpp
@@ -108,6 +108,7 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
         mSearchKeysVector.clear();
         mRelativeSpeeds.clear();
         mCharProbabilities.clear();
+        mDirections.clear();
     }
     if (DEBUG_GEO_FULL) {
         AKLOGI("Init ProximityInfoState: reused points =  %d, last input size = %d",
@@ -216,6 +217,13 @@ void ProximityInfoState::initInputParams(const int pointerId, const float maxPoi
                 mRelativeSpeeds[i] = speed / averageSpeed;
             }
         }
+
+        // Direction calculation.
+        mDirections.resize(mInputSize - 1);
+        for (int i = max(0, lastSavedInputSize - 1); i < mInputSize - 1; ++i) {
+            mDirections[i] = getDirection(i, i + 1);
+        }
+
     }
 
     if (DEBUG_GEO_FULL) {
diff --git a/native/jni/src/proximity_info_state.h b/native/jni/src/proximity_info_state.h
index 927244b0243c844c137fdc5c48cc736fc5080ab6..1a3f2869d575173366f51d6702f3f1e679bb414f 100644
--- a/native/jni/src/proximity_info_state.h
+++ b/native/jni/src/proximity_info_state.h
@@ -55,8 +55,8 @@ class ProximityInfoState {
               mHasTouchPositionCorrectionData(false), mMostCommonKeyWidthSquare(0), mLocaleStr(),
               mKeyCount(0), mCellHeight(0), mCellWidth(0), mGridHeight(0), mGridWidth(0),
               mIsContinuationPossible(false), mInputXs(), mInputYs(), mTimes(), mInputIndice(),
-              mDistanceCache(), mLengthCache(), mRelativeSpeeds(), mCharProbabilities(),
-              mNearKeysVector(), mSearchKeysVector(),
+              mDistanceCache(), mLengthCache(), mRelativeSpeeds(), mDirections(),
+              mCharProbabilities(), mNearKeysVector(), mSearchKeysVector(),
               mTouchPositionCorrectionEnabled(false), mInputSize(0) {
         memset(mInputCodes, 0, sizeof(mInputCodes));
         memset(mNormalizedSquaredDistances, 0, sizeof(mNormalizedSquaredDistances));
@@ -226,6 +226,9 @@ class ProximityInfoState {
         return mRelativeSpeeds[index];
     }
 
+    float getDirection(const int index) const {
+        return mDirections[index];
+    }
     // get xy direction
     float getDirection(const int x, const int y) const;
 
@@ -306,6 +309,7 @@ class ProximityInfoState {
     std::vector<float> mDistanceCache;
     std::vector<int>  mLengthCache;
     std::vector<float> mRelativeSpeeds;
+    std::vector<float> mDirections;
     // probabilities of skipping or mapping to a key for each point.
     std::vector<hash_map_compat<int, float> > mCharProbabilities;
     // The vector for the key code set which holds nearby keys for each sampled input point