diff --git a/java/src/com/android/inputmethod/latin/LatinIME.java b/java/src/com/android/inputmethod/latin/LatinIME.java
index e7d1c53bd98bbf79eb41c767778d102a9e1a5fe8..80dda9c191d405a7506787cc7e6dfe415340ab29 100644
--- a/java/src/com/android/inputmethod/latin/LatinIME.java
+++ b/java/src/com/android/inputmethod/latin/LatinIME.java
@@ -144,7 +144,7 @@ public class LatinIME extends InputMethodService implements KeyboardActionListen
 
     private LastComposedWord mLastComposedWord = LastComposedWord.NOT_A_COMPOSED_WORD;
     private WordComposer mWordComposer = new WordComposer();
-    private RichInputConnection mConnection = new RichInputConnection();
+    private RichInputConnection mConnection = new RichInputConnection(this);
 
     // Keep track of the last selection range to decide if we need to show word alternatives
     private static final int NOT_A_CURSOR_POSITION = -1;
@@ -537,7 +537,7 @@ public class LatinIME extends InputMethodService implements KeyboardActionListen
         if (mDisplayOrientation != conf.orientation) {
             mDisplayOrientation = conf.orientation;
             mHandler.startOrientationChanging();
-            mConnection.beginBatchEdit(getCurrentInputConnection());
+            mConnection.beginBatchEdit();
             commitTyped(LastComposedWord.NOT_A_SEPARATOR);
             mConnection.finishComposingText();
             mConnection.endBatchEdit();
@@ -1213,7 +1213,7 @@ public class LatinIME extends InputMethodService implements KeyboardActionListen
             mDeleteCount = 0;
         }
         mLastKeyTime = when;
-        mConnection.beginBatchEdit(getCurrentInputConnection());
+        mConnection.beginBatchEdit();
 
         if (ProductionFlag.IS_EXPERIMENTAL) {
             ResearchLogger.latinIME_onCodeInput(primaryCode, x, y);
@@ -1298,7 +1298,7 @@ public class LatinIME extends InputMethodService implements KeyboardActionListen
 
     @Override
     public void onTextInput(CharSequence text) {
-        mConnection.beginBatchEdit(getCurrentInputConnection());
+        mConnection.beginBatchEdit();
         commitTyped(LastComposedWord.NOT_A_SEPARATOR);
         text = specificTldProcessingOnTextInput(text);
         if (SPACE_STATE_PHANTOM == mSpaceState) {
@@ -1819,7 +1819,7 @@ public class LatinIME extends InputMethodService implements KeyboardActionListen
             mKeyboardSwitcher.updateShiftState();
             resetComposingState(true /* alsoResetLastComposedWord */);
             final CompletionInfo completionInfo = mApplicationSpecifiedCompletions[index];
-            mConnection.beginBatchEdit(getCurrentInputConnection());
+            mConnection.beginBatchEdit();
             mConnection.commitCompletion(completionInfo);
             mConnection.endBatchEdit();
             if (ProductionFlag.IS_EXPERIMENTAL) {
diff --git a/java/src/com/android/inputmethod/latin/RichInputConnection.java b/java/src/com/android/inputmethod/latin/RichInputConnection.java
index 40d327ebb62ec29e7fcbcfb7d7a0853c4c958aec..a37f480b711d7675ff2cd0e5595d40d563ec8fe7 100644
--- a/java/src/com/android/inputmethod/latin/RichInputConnection.java
+++ b/java/src/com/android/inputmethod/latin/RichInputConnection.java
@@ -16,6 +16,7 @@
 
 package com.android.inputmethod.latin;
 
+import android.inputmethodservice.InputMethodService;
 import android.text.TextUtils;
 import android.util.Log;
 import android.view.KeyEvent;
@@ -41,16 +42,18 @@ public class RichInputConnection {
     private static final Pattern spaceRegex = Pattern.compile("\\s+");
     private static final int INVALID_CURSOR_POSITION = -1;
 
+    private final InputMethodService mParent;
     InputConnection mIC;
     int mNestLevel;
-    public RichInputConnection() {
+    public RichInputConnection(final InputMethodService parent) {
+        mParent = parent;
         mIC = null;
         mNestLevel = 0;
     }
 
-    public void beginBatchEdit(final InputConnection newInputConnection) {
+    public void beginBatchEdit() {
         if (++mNestLevel == 1) {
-            mIC = newInputConnection;
+            mIC = mParent.getCurrentInputConnection();
             if (null != mIC) mIC.beginBatchEdit();
         } else {
             if (DBG) {
@@ -84,16 +87,19 @@ public class RichInputConnection {
     }
 
     public int getCursorCapsMode(final int inputType) {
+        mIC = mParent.getCurrentInputConnection();
         if (null == mIC) return Constants.TextUtils.CAP_MODE_OFF;
         return mIC.getCursorCapsMode(inputType);
     }
 
     public CharSequence getTextBeforeCursor(final int i, final int j) {
+        mIC = mParent.getCurrentInputConnection();
         if (null != mIC) return mIC.getTextBeforeCursor(i, j);
         return null;
     }
 
     public CharSequence getTextAfterCursor(final int i, final int j) {
+        mIC = mParent.getCurrentInputConnection();
         if (null != mIC) return mIC.getTextAfterCursor(i, j);
         return null;
     }
@@ -104,6 +110,7 @@ public class RichInputConnection {
     }
 
     public void performEditorAction(final int actionId) {
+        mIC = mParent.getCurrentInputConnection();
         if (null != mIC) mIC.performEditorAction(actionId);
     }
 
@@ -133,6 +140,7 @@ public class RichInputConnection {
     }
 
     public CharSequence getPreviousWord(final String sentenceSeperators) {
+        mIC = mParent.getCurrentInputConnection();
         //TODO: Should fix this. This could be slow!
         if (null == mIC) return null;
         CharSequence prev = mIC.getTextBeforeCursor(LOOKBACK_CHARACTER_NUM, 0);
@@ -194,6 +202,7 @@ public class RichInputConnection {
     }
 
     public CharSequence getThisWord(String sentenceSeperators) {
+        mIC = mParent.getCurrentInputConnection();
         if (null == mIC) return null;
         final CharSequence prev = mIC.getTextBeforeCursor(LOOKBACK_CHARACTER_NUM, 0);
         return getThisWord(prev, sentenceSeperators);
@@ -233,6 +242,7 @@ public class RichInputConnection {
     }
 
     private int getCursorPosition() {
+        mIC = mParent.getCurrentInputConnection();
         if (null == mIC) return INVALID_CURSOR_POSITION;
         final ExtractedText extracted = mIC.getExtractedText(new ExtractedTextRequest(), 0);
         if (extracted == null) {
@@ -250,6 +260,7 @@ public class RichInputConnection {
      * @return a range containing the text surrounding the cursor
      */
     public Range getWordRangeAtCursor(String sep, int additionalPrecedingWordsCount) {
+        mIC = mParent.getCurrentInputConnection();
         if (mIC == null || sep == null) {
             return null;
         }
diff --git a/tests/src/com/android/inputmethod/latin/RichInputConnectionTests.java b/tests/src/com/android/inputmethod/latin/RichInputConnectionTests.java
index 9ce581df8c458b65e4d55cf55af56f389fa8aadf..7bd7b0e5a9da00ede5f806618b67f817c560a79c 100644
--- a/tests/src/com/android/inputmethod/latin/RichInputConnectionTests.java
+++ b/tests/src/com/android/inputmethod/latin/RichInputConnectionTests.java
@@ -16,6 +16,7 @@
 
 package com.android.inputmethod.latin;
 
+import android.inputmethodservice.InputMethodService;
 import android.test.AndroidTestCase;
 import android.view.inputmethod.ExtractedText;
 import android.view.inputmethod.ExtractedTextRequest;
@@ -83,6 +84,17 @@ public class RichInputConnectionTests extends AndroidTestCase {
         }
     }
 
+    private class MockInputMethodService extends InputMethodService {
+        InputConnection mInputConnection;
+        public void setInputConnection(final InputConnection inputConnection) {
+            mInputConnection = inputConnection;
+        }
+        @Override
+        public InputConnection getCurrentInputConnection() {
+            return mInputConnection;
+        }
+    }
+
     /************************** Tests ************************/
 
     /**
@@ -122,14 +134,14 @@ public class RichInputConnectionTests extends AndroidTestCase {
      */
     public void testGetWordRangeAtCursor() {
         ExtractedText et = new ExtractedText();
-        final RichInputConnection ic = new RichInputConnection();
-        InputConnection mockConnection;
-        mockConnection = new MockConnection("word wo", "rd", et);
+        final MockInputMethodService mockInputMethodService = new MockInputMethodService();
+        final RichInputConnection ic = new RichInputConnection(mockInputMethodService);
+        mockInputMethodService.setInputConnection(new MockConnection("word wo", "rd", et));
         et.startOffset = 0;
         et.selectionStart = 7;
         Range r;
 
-        ic.beginBatchEdit(mockConnection);
+        ic.beginBatchEdit();
         // basic case
         r = ic.getWordRangeAtCursor(" ", 0);
         assertEquals("word", r.mWord);
@@ -140,37 +152,38 @@ public class RichInputConnectionTests extends AndroidTestCase {
         ic.endBatchEdit();
 
         // tab character instead of space
-        mockConnection = new MockConnection("one\tword\two", "rd", et);
-        ic.beginBatchEdit(mockConnection);
+        mockInputMethodService.setInputConnection(new MockConnection("one\tword\two", "rd", et));
+        ic.beginBatchEdit();
         r = ic.getWordRangeAtCursor("\t", 1);
         ic.endBatchEdit();
         assertEquals("word\tword", r.mWord);
 
         // only one word doesn't go too far
-        mockConnection = new MockConnection("one\tword\two", "rd", et);
-        ic.beginBatchEdit(mockConnection);
+        mockInputMethodService.setInputConnection(new MockConnection("one\tword\two", "rd", et));
+        ic.beginBatchEdit();
         r = ic.getWordRangeAtCursor("\t", 1);
         ic.endBatchEdit();
         assertEquals("word\tword", r.mWord);
 
         // tab or space
-        mockConnection = new MockConnection("one word\two", "rd", et);
-        ic.beginBatchEdit(mockConnection);
+        mockInputMethodService.setInputConnection(new MockConnection("one word\two", "rd", et));
+        ic.beginBatchEdit();
         r = ic.getWordRangeAtCursor(" \t", 1);
         ic.endBatchEdit();
         assertEquals("word\tword", r.mWord);
 
         // tab or space multiword
-        mockConnection = new MockConnection("one word\two", "rd", et);
-        ic.beginBatchEdit(mockConnection);
+        mockInputMethodService.setInputConnection(new MockConnection("one word\two", "rd", et));
+        ic.beginBatchEdit();
         r = ic.getWordRangeAtCursor(" \t", 2);
         ic.endBatchEdit();
         assertEquals("one word\tword", r.mWord);
 
         // splitting on supplementary character
         final String supplementaryChar = "\uD840\uDC8A";
-        mockConnection = new MockConnection("one word" + supplementaryChar + "wo", "rd", et);
-        ic.beginBatchEdit(mockConnection);
+        mockInputMethodService.setInputConnection(
+                new MockConnection("one word" + supplementaryChar + "wo", "rd", et));
+        ic.beginBatchEdit();
         r = ic.getWordRangeAtCursor(supplementaryChar, 0);
         ic.endBatchEdit();
         assertEquals("word", r.mWord);