diff --git a/java/src/com/android/inputmethod/keyboard/Key.java b/java/src/com/android/inputmethod/keyboard/Key.java
index 0d78c39f21273bd527ae6c7bcd6cdf8405d44669..4d7fe3d8e3406474aedb428e1e3600eef015cd43 100644
--- a/java/src/com/android/inputmethod/keyboard/Key.java
+++ b/java/src/com/android/inputmethod/keyboard/Key.java
@@ -33,12 +33,10 @@ import android.util.Xml;
 
 import com.android.inputmethod.keyboard.internal.KeySpecParser;
 import com.android.inputmethod.keyboard.internal.KeySpecParser.MoreKeySpec;
-import com.android.inputmethod.keyboard.internal.KeyStyles;
 import com.android.inputmethod.keyboard.internal.KeyStyles.KeyStyle;
 import com.android.inputmethod.keyboard.internal.KeyboardIconsSet;
 import com.android.inputmethod.latin.R;
 import com.android.inputmethod.latin.StringUtils;
-import com.android.inputmethod.latin.XmlParseUtils;
 
 import org.xmlpull.v1.XmlPullParser;
 import org.xmlpull.v1.XmlPullParserException;
@@ -201,7 +199,6 @@ public class Key {
      */
     public Key(Resources res, Keyboard.Params params, Keyboard.Builder.Row row,
             XmlPullParser parser) throws XmlPullParserException {
-        final KeyStyles keyStyles = params.mKeyStyles;
         final float horizontalGap = isSpacer() ? 0 : params.mHorizontalGap;
         final int keyHeight = row.mRowHeight;
         mVerticalGap = params.mVerticalGap;
@@ -210,17 +207,7 @@ public class Key {
         final TypedArray keyAttr = res.obtainAttributes(Xml.asAttributeSet(parser),
                 R.styleable.Keyboard_Key);
 
-        final KeyStyle style;
-        if (keyAttr.hasValue(R.styleable.Keyboard_Key_keyStyle)) {
-            String styleName = keyAttr.getString(R.styleable.Keyboard_Key_keyStyle);
-            style = keyStyles.getKeyStyle(styleName);
-            if (style == null) {
-                throw new XmlParseUtils.ParseException("Unknown key style: " + styleName, parser);
-            }
-        } else {
-            style = keyStyles.getEmptyKeyStyle();
-        }
-
+        final KeyStyle style = params.mKeyStyles.getKeyStyle(keyAttr, parser);
         final float keyXPos = row.getKeyX(keyAttr);
         final float keyWidth = row.getKeyWidth(keyAttr, keyXPos);
         final int keyYPos = row.getKeyY();
diff --git a/java/src/com/android/inputmethod/keyboard/internal/KeyStyles.java b/java/src/com/android/inputmethod/keyboard/internal/KeyStyles.java
index b32172ebec1995094d5e8b95e266c5c2b4b27bdd..80f4f259b1a7b325b4ad89b93d1d885ad4787b43 100644
--- a/java/src/com/android/inputmethod/keyboard/internal/KeyStyles.java
+++ b/java/src/com/android/inputmethod/keyboard/internal/KeyStyles.java
@@ -32,24 +32,19 @@ public class KeyStyles {
     private static final String TAG = KeyStyles.class.getSimpleName();
     private static final boolean DEBUG = false;
 
-    private final HashMap<String, DeclaredKeyStyle> mStyles =
-            new HashMap<String, DeclaredKeyStyle>();
+    final HashMap<String, KeyStyle> mStyles = new HashMap<String, KeyStyle>();
 
-    private final KeyboardTextsSet mTextsSet;
+    final KeyboardTextsSet mTextsSet;
     private final KeyStyle mEmptyKeyStyle;
+    private static final String EMPTY_STYLE_NAME = "<empty>";
 
     public KeyStyles(KeyboardTextsSet textsSet) {
         mTextsSet = textsSet;
-        mEmptyKeyStyle = new EmptyKeyStyle(textsSet);
+        mEmptyKeyStyle = new EmptyKeyStyle();
+        mStyles.put(EMPTY_STYLE_NAME, mEmptyKeyStyle);
     }
 
-    public static abstract class KeyStyle {
-        protected final KeyboardTextsSet mTextsSet;
-
-        public KeyStyle(KeyboardTextsSet textsSet) {
-            mTextsSet = textsSet;
-        }
-
+    public abstract class KeyStyle {
         public abstract String[] getStringArray(TypedArray a, int index);
         public abstract String getString(TypedArray a, int index);
         public abstract int getInt(TypedArray a, int index, int defaultValue);
@@ -70,11 +65,7 @@ public class KeyStyles {
         }
     }
 
-    private static class EmptyKeyStyle extends KeyStyle {
-        public EmptyKeyStyle(KeyboardTextsSet textsSet) {
-            super(textsSet);
-        }
-
+    class EmptyKeyStyle extends KeyStyle {
         @Override
         public String[] getStringArray(TypedArray a, int index) {
             return parseStringArray(a, index);
@@ -96,11 +87,12 @@ public class KeyStyles {
         }
     }
 
-    private static class DeclaredKeyStyle extends KeyStyle {
+    private class DeclaredKeyStyle extends KeyStyle {
+        private final String mParentStyleName;
         private final HashMap<Integer, Object> mStyleAttributes = new HashMap<Integer, Object>();
 
-        public DeclaredKeyStyle(KeyboardTextsSet textsSet) {
-            super(textsSet);
+        public DeclaredKeyStyle(String parentStyleName) {
+            mParentStyleName = parentStyleName;
         }
 
         @Override
@@ -108,7 +100,11 @@ public class KeyStyles {
             if (a.hasValue(index)) {
                 return parseStringArray(a, index);
             }
-            return (String[])mStyleAttributes.get(index);
+            if (mStyleAttributes.containsKey(index)) {
+                return (String[])mStyleAttributes.get(index);
+            }
+            final KeyStyle parentStyle = mStyles.get(mParentStyleName);
+            return parentStyle.getStringArray(a, index);
         }
 
         @Override
@@ -116,7 +112,11 @@ public class KeyStyles {
             if (a.hasValue(index)) {
                 return parseString(a, index);
             }
-            return (String)mStyleAttributes.get(index);
+            if (mStyleAttributes.containsKey(index)) {
+                return (String)mStyleAttributes.get(index);
+            }
+            final KeyStyle parentStyle = mStyles.get(mParentStyleName);
+            return parentStyle.getString(a, index);
         }
 
         @Override
@@ -124,15 +124,21 @@ public class KeyStyles {
             if (a.hasValue(index)) {
                 return a.getInt(index, defaultValue);
             }
-            final Integer styleValue = (Integer)mStyleAttributes.get(index);
-            return styleValue != null ? styleValue : defaultValue;
+            if (mStyleAttributes.containsKey(index)) {
+                return (Integer)mStyleAttributes.get(index);
+            }
+            final KeyStyle parentStyle = mStyles.get(mParentStyleName);
+            return parentStyle.getInt(a, index, defaultValue);
         }
 
         @Override
         public int getFlag(TypedArray a, int index) {
-            final int value = a.getInt(index, 0);
-            final Integer styleValue = (Integer)mStyleAttributes.get(index);
-            return (styleValue != null ? styleValue : 0) | value;
+            int value = a.getInt(index, 0);
+            if (mStyleAttributes.containsKey(index)) {
+                value |= (Integer)mStyleAttributes.get(index);
+            }
+            final KeyStyle parentStyle = mStyles.get(mParentStyleName);
+            return value | parentStyle.getFlag(a, index);
         }
 
         void readKeyAttributes(TypedArray keyAttr) {
@@ -177,10 +183,6 @@ public class KeyStyles {
                 mStyleAttributes.put(index, parseStringArray(a, index));
             }
         }
-
-        void addParentStyleAttributes(DeclaredKeyStyle parentStyle) {
-            mStyleAttributes.putAll(parentStyle.mStyleAttributes);
-        }
     }
 
     public void parseKeyStyleAttributes(TypedArray keyStyleAttr, TypedArray keyAttrs,
@@ -195,26 +197,28 @@ public class KeyStyles {
             }
         }
 
-        final DeclaredKeyStyle style = new DeclaredKeyStyle(mTextsSet);
+        String parentStyleName = EMPTY_STYLE_NAME;
         if (keyStyleAttr.hasValue(R.styleable.Keyboard_KeyStyle_parentStyle)) {
-            final String parentStyle = keyStyleAttr.getString(
-                    R.styleable.Keyboard_KeyStyle_parentStyle);
-            final DeclaredKeyStyle parent = mStyles.get(parentStyle);
-            if (parent == null) {
+            parentStyleName = keyStyleAttr.getString(R.styleable.Keyboard_KeyStyle_parentStyle);
+            if (!mStyles.containsKey(parentStyleName)) {
                 throw new XmlParseUtils.ParseException(
-                        "Unknown parentStyle " + parentStyle, parser);
+                        "Unknown parentStyle " + parentStyleName, parser);
             }
-            style.addParentStyleAttributes(parent);
         }
+        final DeclaredKeyStyle style = new DeclaredKeyStyle(parentStyleName);
         style.readKeyAttributes(keyAttrs);
         mStyles.put(styleName, style);
     }
 
-    public KeyStyle getKeyStyle(String styleName) {
+    public KeyStyle getKeyStyle(TypedArray keyAttr, XmlPullParser parser)
+            throws XmlParseUtils.ParseException {
+        if (!keyAttr.hasValue(R.styleable.Keyboard_Key_keyStyle)) {
+            return mEmptyKeyStyle;
+        }
+        final String styleName = keyAttr.getString(R.styleable.Keyboard_Key_keyStyle);
+        if (!mStyles.containsKey(styleName)) {
+            throw new XmlParseUtils.ParseException("Unknown key style: " + styleName, parser);
+        }
         return mStyles.get(styleName);
     }
-
-    public KeyStyle getEmptyKeyStyle() {
-        return mEmptyKeyStyle;
-    }
 }