diff --git a/java/src/com/android/inputmethod/keyboard/ProximityInfo.java b/java/src/com/android/inputmethod/keyboard/ProximityInfo.java
index 57d3fede4c3fceca5dd7eee745b2bd3ee4c2710c..0e86b051de898deb72a49c51c9b6e6fdb2ac4d6e 100644
--- a/java/src/com/android/inputmethod/keyboard/ProximityInfo.java
+++ b/java/src/com/android/inputmethod/keyboard/ProximityInfo.java
@@ -240,28 +240,121 @@ public class ProximityInfo {
 
     private void computeNearestNeighbors() {
         final int defaultWidth = mMostCommonKeyWidth;
-        final Key[] keys = mKeys;
-        final int thresholdBase = (int) (defaultWidth * SEARCH_DISTANCE);
-        final int threshold = thresholdBase * thresholdBase;
+        final int keyCount = mKeys.length;
+        final int gridSize = mGridNeighbors.length;
+        final int threshold = (int) (defaultWidth * SEARCH_DISTANCE);
+        final int thresholdSquared = threshold * threshold;
         // Round-up so we don't have any pixels outside the grid
-        final Key[] neighborKeys = new Key[keys.length];
-        final int gridWidth = mGridWidth * mCellWidth;
-        final int gridHeight = mGridHeight * mCellHeight;
-        for (int x = 0; x < gridWidth; x += mCellWidth) {
-            for (int y = 0; y < gridHeight; y += mCellHeight) {
-                final int centerX = x + mCellWidth / 2;
-                final int centerY = y + mCellHeight / 2;
-                int count = 0;
-                for (final Key key : keys) {
-                    if (key.isSpacer()) continue;
-                    if (key.squaredDistanceToEdge(centerX, centerY) < threshold) {
-                        neighborKeys[count++] = key;
+        final int fullGridWidth = mGridWidth * mCellWidth;
+        final int fullGridHeight = mGridHeight * mCellHeight;
+
+        // For large layouts, 'neighborsFlatBuffer' is about 80k of memory: gridSize is usually 512,
+        // keycount is about 40 and a pointer to a Key is 4 bytes. This contains, for each cell,
+        // enough space for as many keys as there are on the keyboard. Hence, every
+        // keycount'th element is the start of a new cell, and each of these virtual subarrays
+        // start empty with keycount spaces available. This fills up gradually in the loop below.
+        // Since in the practice each cell does not have a lot of neighbors, most of this space is
+        // actually just empty padding in this fixed-size buffer.
+        final Key[] neighborsFlatBuffer = new Key[gridSize * keyCount];
+        final int[] neighborCountPerCell = new int[gridSize];
+        final int halfCellWidth = mCellWidth / 2;
+        final int halfCellHeight = mCellHeight / 2;
+        for (final Key key : mKeys) {
+            if (key.isSpacer()) continue;
+
+/* HOW WE PRE-SELECT THE CELLS (iterate over only the relevant cells, instead of all of them)
+
+  We want to compute the distance for keys that are in the cells that are close enough to the
+  key border, as this method is performance-critical. These keys are represented with 'star'
+  background on the diagram below. Let's consider the Y case first.
+
+  We want to select the cells which center falls between the top of the key minus the threshold,
+  and the bottom of the key plus the threshold.
+  topPixelWithinThreshold is key.mY - threshold, and bottomPixelWithinThreshold is
+  key.mY + key.mHeight + threshold.
+
+  Then we need to compute the center of the top row that we need to evaluate, as we'll iterate
+  from there.
+
+(0,0)----> x
+| .-------------------------------------------.
+| |   |   |   |   |   |   |   |   |   |   |   |
+| |---+---+---+---+---+---+---+---+---+---+---|   .- top of top cell (aligned on the grid)
+| |   |   |   |   |   |   |   |   |   |   |   |   |
+| |-----------+---+---+---+---+---+---+---+---|---'                          v
+| |   |   |   |***|***|*_________________________ topPixelWithinThreshold    | yDeltaToGrid
+| |---+---+---+-----^-+-|-+---+---+---+---+---|                              ^
+| |   |   |   |***|*|*|*|*|***|***|   |   |   |           ______________________________________
+v |---+---+--threshold--|-+---+---+---+---+---|          |
+  |   |   |   |***|*|*|*|*|***|***|   |   |   |          | Starting from key.mY, we substract
+y |---+---+---+---+-v-+-|-+---+---+---+---+---|          | thresholdBase and get the top pixel
+  |   |   |   |***|**########------------------- key.mY  | within the threshold. We align that on
+  |---+---+---+---+--#+---+-#-+---+---+---+---|          | the grid by computing the delta to the
+  |   |   |   |***|**#|***|*#*|***|   |   |   |          | grid, and get the top of the top cell.
+  |---+---+---+---+--#+---+-#-+---+---+---+---|          |
+  |   |   |   |***|**########*|***|   |   |   |          | Adding half the cell height to the top
+  |---+---+---+---+---+-|-+---+---+---+---+---|          | of the top cell, we get the middle of
+  |   |   |   |***|***|*|*|***|***|   |   |   |          | the top cell (yMiddleOfTopCell).
+  |---+---+---+---+---+-|-+---+---+---+---+---|          |
+  |   |   |   |***|***|*|*|***|***|   |   |   |          |
+  |---+---+---+---+---+-|________________________ yEnd   | Since we only want to add the key to
+  |   |   |   |   |   |   | (bottomPixelWithinThreshold) | the proximity if it's close enough to
+  |---+---+---+---+---+---+---+---+---+---+---|          | the center of the cell, we only need
+  |   |   |   |   |   |   |   |   |   |   |   |          | to compute for these cells where
+  '---'---'---'---'---'---'---'---'---'---'---'          | topPixelWithinThreshold is above the
+                                        (positive x,y)   | center of the cell. This is the case
+                                                         | when yDeltaToGrid is less than half
+  [Zoomed in diagram]                                    | the height of the cell.
+  +-------+-------+-------+-------+-------+              |
+  |       |       |       |       |       |              | On the zoomed in diagram, on the right
+  |       |       |       |       |       |              | the topPixelWithinThreshold (represented
+  |       |       |       |       |       |      top of  | with an = sign) is below and we can skip
+  +-------+-------+-------+--v----+-------+ .. top cell  | this cell, while on the left it's above
+  |       | = topPixelWT  |  |  yDeltaToGrid             | and we need to compute for this cell.
+  |..yStart.|.....|.......|..|....|.......|... y middle  | Thus, if yDeltaToGrid is more than half
+  |   (left)|     |       |  ^ =  |       | of top cell  | the height of the cell, we start the
+  +-------+-|-----+-------+----|--+-------+              | iteration one cell below the top cell,
+  |       | |     |       |    |  |       |              | else we start it on the top cell. This
+  |.......|.|.....|.......|....|..|.....yStart (right)   | is stored in yStart.
+
+  Since we only want to go up to bottomPixelWithinThreshold, and we only iterate on the center
+  of the keys, we can stop as soon as the y value exceeds bottomPixelThreshold, so we don't
+  have to align this on the center of the key. Hence, we don't need a separate value for
+  bottomPixelWithinThreshold and call this yEnd right away.
+*/
+            final int topPixelWithinThreshold = key.mY - threshold;
+            final int yDeltaToGrid = topPixelWithinThreshold % mCellHeight;
+            final int yMiddleOfTopCell = topPixelWithinThreshold - yDeltaToGrid + halfCellHeight;
+            final int yStart = Math.max(halfCellHeight,
+                    yMiddleOfTopCell + (yDeltaToGrid <= halfCellHeight ? 0 : mCellHeight));
+            final int yEnd = Math.min(fullGridHeight, key.mY + key.mHeight + threshold);
+
+            final int leftPixelWithinThreshold = key.mX - threshold;
+            final int xDeltaToGrid = leftPixelWithinThreshold % mCellWidth;
+            final int xMiddleOfLeftCell = leftPixelWithinThreshold - xDeltaToGrid + halfCellWidth;
+            final int xStart = Math.max(halfCellWidth,
+                    xMiddleOfLeftCell + (xDeltaToGrid <= halfCellWidth ? 0 : mCellWidth));
+            final int xEnd = Math.min(fullGridWidth, key.mX + key.mWidth + threshold);
+
+            int baseIndexOfCurrentRow = (yStart / mCellHeight) * mGridWidth + (xStart / mCellWidth);
+            for (int centerY = yStart; centerY <= yEnd; centerY += mCellHeight) {
+                int index = baseIndexOfCurrentRow;
+                for (int centerX = xStart; centerX <= xEnd; centerX += mCellWidth) {
+                    if (key.squaredDistanceToEdge(centerX, centerY) < thresholdSquared) {
+                        neighborsFlatBuffer[index * keyCount + neighborCountPerCell[index]] = key;
+                        ++neighborCountPerCell[index];
                     }
+                    ++index;
                 }
-                mGridNeighbors[(y / mCellHeight) * mGridWidth + (x / mCellWidth)] =
-                        Arrays.copyOfRange(neighborKeys, 0, count);
+                baseIndexOfCurrentRow += mGridWidth;
             }
         }
+
+        for (int i = 0; i < gridSize; ++i) {
+            final int base = i * keyCount;
+            mGridNeighbors[i] =
+                    Arrays.copyOfRange(neighborsFlatBuffer, base, base + neighborCountPerCell[i]);
+        }
     }
 
     public void fillArrayWithNearestKeyCodes(final int x, final int y, final int primaryKeyCode,