| OLD | NEW |
| (Empty) | |
| 1 /* |
| 2 * Copyright (C) 2012 The Android Open Source Project |
| 3 * |
| 4 * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 * you may not use this file except in compliance with the License. |
| 6 * You may obtain a copy of the License at |
| 7 * |
| 8 * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 * |
| 10 * Unless required by applicable law or agreed to in writing, software |
| 11 * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 * See the License for the specific language governing permissions and |
| 14 * limitations under the License. |
| 15 */ |
| 16 |
| 17 #ifndef LATINIME_TYPING_WEIGHTING_H |
| 18 #define LATINIME_TYPING_WEIGHTING_H |
| 19 |
| 20 #include "third_party/prediction/defines.h" |
| 21 #include "third_party/prediction/suggest/core/dicnode/dic_node_utils.h" |
| 22 #include "third_party/prediction/suggest/core/dictionary/error_type_utils.h" |
| 23 #include "third_party/prediction/suggest/core/layout/touch_position_correction_u
tils.h" |
| 24 #include "third_party/prediction/suggest/core/policy/weighting.h" |
| 25 #include "third_party/prediction/suggest/core/session/dic_traverse_session.h" |
| 26 #include "third_party/prediction/suggest/policyimpl/typing/scoring_params.h" |
| 27 #include "third_party/prediction/utils/char_utils.h" |
| 28 |
| 29 namespace latinime { |
| 30 |
| 31 class DicNode; |
| 32 struct DicNode_InputStateG; |
| 33 class MultiBigramMap; |
| 34 |
| 35 class TypingWeighting : public Weighting { |
| 36 public: |
| 37 static const TypingWeighting* getInstance() { return &sInstance; } |
| 38 |
| 39 protected: |
| 40 float getTerminalSpatialCost(const DicTraverseSession* const traverseSession, |
| 41 const DicNode* const dicNode) const { |
| 42 float cost = 0.0f; |
| 43 if (dicNode->hasMultipleWords()) { |
| 44 cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; |
| 45 } |
| 46 if (dicNode->getProximityCorrectionCount() > 0) { |
| 47 cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST; |
| 48 } |
| 49 if (dicNode->getEditCorrectionCount() > 0) { |
| 50 cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST; |
| 51 } |
| 52 return cost; |
| 53 } |
| 54 |
| 55 float getOmissionCost(const DicNode* const parentDicNode, |
| 56 const DicNode* const dicNode) const { |
| 57 const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); |
| 58 const bool isIntentionalOmission = |
| 59 parentDicNode->canBeIntentionalOmission(); |
| 60 const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); |
| 61 // If the traversal omitted the first letter then the dicNode should now be |
| 62 // on the second. |
| 63 const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2; |
| 64 float cost = 0.0f; |
| 65 if (isZeroCostOmission) { |
| 66 cost = 0.0f; |
| 67 } else if (isIntentionalOmission) { |
| 68 cost = ScoringParams::INTENTIONAL_OMISSION_COST; |
| 69 } else if (isFirstLetterOmission) { |
| 70 cost = ScoringParams::OMISSION_COST_FIRST_CHAR; |
| 71 } else { |
| 72 cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR |
| 73 : ScoringParams::OMISSION_COST; |
| 74 } |
| 75 return cost; |
| 76 } |
| 77 |
| 78 float getMatchedCost(const DicTraverseSession* const traverseSession, |
| 79 const DicNode* const dicNode, |
| 80 DicNode_InputStateG* inputStateG) const { |
| 81 const int pointIndex = dicNode->getInputIndex(0); |
| 82 const float normalizedSquaredLength = |
| 83 traverseSession->getProximityInfoState(0)->getPointToKeyLength( |
| 84 pointIndex, |
| 85 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); |
| 86 const float normalizedDistance = |
| 87 TouchPositionCorrectionUtils::getSweetSpotFactor( |
| 88 traverseSession->isTouchPositionCorrectionEnabled(), |
| 89 normalizedSquaredLength); |
| 90 const float weightedDistance = |
| 91 ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance; |
| 92 |
| 93 const bool isFirstChar = pointIndex == 0; |
| 94 const bool isProximity = isProximityDicNode(traverseSession, dicNode); |
| 95 float cost = isProximity |
| 96 ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST |
| 97 : ScoringParams::PROXIMITY_COST) |
| 98 : 0.0f; |
| 99 if (isProximity && dicNode->getProximityCorrectionCount() == 0) { |
| 100 cost += ScoringParams::FIRST_PROXIMITY_COST; |
| 101 } |
| 102 if (dicNode->getNodeCodePointCount() == 2) { |
| 103 // At the second character of the current word, we check if the first char |
| 104 // is uppercase |
| 105 // and the word is a second or later word of a multiple word suggestion. |
| 106 // We demote it |
| 107 // if so. |
| 108 const bool isSecondOrLaterWordFirstCharUppercase = |
| 109 dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase(); |
| 110 if (isSecondOrLaterWordFirstCharUppercase) { |
| 111 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE; |
| 112 } |
| 113 } |
| 114 return weightedDistance + cost; |
| 115 } |
| 116 |
| 117 bool isProximityDicNode(const DicTraverseSession* const traverseSession, |
| 118 const DicNode* const dicNode) const { |
| 119 const int pointIndex = dicNode->getInputIndex(0); |
| 120 const int primaryCodePoint = |
| 121 CharUtils::toBaseLowerCase(traverseSession->getProximityInfoState(0) |
| 122 ->getPrimaryCodePointAt(pointIndex)); |
| 123 const int dicNodeChar = |
| 124 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()); |
| 125 return primaryCodePoint != dicNodeChar; |
| 126 } |
| 127 |
| 128 float getTranspositionCost(const DicTraverseSession* const traverseSession, |
| 129 const DicNode* const parentDicNode, |
| 130 const DicNode* const dicNode) const { |
| 131 const int16_t parentPointIndex = parentDicNode->getInputIndex(0); |
| 132 const int prevCodePoint = parentDicNode->getNodeCodePoint(); |
| 133 const float distance1 = |
| 134 traverseSession->getProximityInfoState(0)->getPointToKeyLength( |
| 135 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint)); |
| 136 const int codePoint = dicNode->getNodeCodePoint(); |
| 137 const float distance2 = |
| 138 traverseSession->getProximityInfoState(0)->getPointToKeyLength( |
| 139 parentPointIndex, CharUtils::toBaseLowerCase(codePoint)); |
| 140 const float distance = distance1 + distance2; |
| 141 const float weightedLengthDistance = |
| 142 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; |
| 143 return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; |
| 144 } |
| 145 |
| 146 float getInsertionCost(const DicTraverseSession* const traverseSession, |
| 147 const DicNode* const parentDicNode, |
| 148 const DicNode* const dicNode) const { |
| 149 const int16_t insertedPointIndex = parentDicNode->getInputIndex(0); |
| 150 const int prevCodePoint = traverseSession->getProximityInfoState(0) |
| 151 ->getPrimaryCodePointAt(insertedPointIndex); |
| 152 const int currentCodePoint = dicNode->getNodeCodePoint(); |
| 153 const bool sameCodePoint = prevCodePoint == currentCodePoint; |
| 154 const bool existsAdjacentProximityChars = |
| 155 traverseSession->getProximityInfoState(0) |
| 156 ->existsAdjacentProximityChars(insertedPointIndex); |
| 157 const float dist = |
| 158 traverseSession->getProximityInfoState(0)->getPointToKeyLength( |
| 159 insertedPointIndex + 1, |
| 160 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint())); |
| 161 const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH; |
| 162 const bool singleChar = dicNode->getNodeCodePointCount() == 1; |
| 163 float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f); |
| 164 if (sameCodePoint) { |
| 165 cost += ScoringParams::INSERTION_COST_SAME_CHAR; |
| 166 } else if (existsAdjacentProximityChars) { |
| 167 cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR; |
| 168 } else { |
| 169 cost += ScoringParams::INSERTION_COST; |
| 170 } |
| 171 return cost + weightedDistance; |
| 172 } |
| 173 |
| 174 float getNewWordSpatialCost(const DicTraverseSession* const traverseSession, |
| 175 const DicNode* const dicNode, |
| 176 DicNode_InputStateG* inputStateG) const { |
| 177 return ScoringParams::COST_NEW_WORD * |
| 178 traverseSession->getMultiWordCostMultiplier(); |
| 179 } |
| 180 |
| 181 float getNewWordBigramLanguageCost( |
| 182 const DicTraverseSession* const traverseSession, |
| 183 const DicNode* const dicNode, |
| 184 MultiBigramMap* const multiBigramMap) const { |
| 185 return DicNodeUtils::getBigramNodeImprobability( |
| 186 traverseSession->getDictionaryStructurePolicy(), dicNode, |
| 187 multiBigramMap) * |
| 188 ScoringParams::DISTANCE_WEIGHT_LANGUAGE; |
| 189 } |
| 190 |
| 191 float getCompletionCost(const DicTraverseSession* const traverseSession, |
| 192 const DicNode* const dicNode) const { |
| 193 // The auto completion starts when the input index is same as the input size |
| 194 const bool firstCompletion = |
| 195 dicNode->getInputIndex(0) == traverseSession->getInputSize(); |
| 196 // TODO: Change the cost for the first completion for the gesture? |
| 197 const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETION |
| 198 : ScoringParams::COST_COMPLETION; |
| 199 return cost; |
| 200 } |
| 201 |
| 202 float getTerminalLanguageCost( |
| 203 const DicTraverseSession* const traverseSession, |
| 204 const DicNode* const dicNode, |
| 205 const float dicNodeLanguageImprobability) const { |
| 206 return dicNodeLanguageImprobability * |
| 207 ScoringParams::DISTANCE_WEIGHT_LANGUAGE; |
| 208 } |
| 209 |
| 210 float getTerminalInsertionCost( |
| 211 const DicTraverseSession* const traverseSession, |
| 212 const DicNode* const dicNode) const { |
| 213 const int inputIndex = dicNode->getInputIndex(0); |
| 214 const int inputSize = traverseSession->getInputSize(); |
| 215 ASSERT(inputIndex < inputSize); |
| 216 // TODO: Implement more efficient logic |
| 217 return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex); |
| 218 } |
| 219 |
| 220 AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { |
| 221 return false; |
| 222 } |
| 223 |
| 224 AK_FORCE_INLINE float getAdditionalProximityCost() const { |
| 225 return ScoringParams::ADDITIONAL_PROXIMITY_COST; |
| 226 } |
| 227 |
| 228 AK_FORCE_INLINE float getSubstitutionCost() const { |
| 229 return ScoringParams::SUBSTITUTION_COST; |
| 230 } |
| 231 |
| 232 AK_FORCE_INLINE float getSpaceSubstitutionCost( |
| 233 const DicTraverseSession* const traverseSession, |
| 234 const DicNode* const dicNode) const { |
| 235 const float cost = |
| 236 ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD; |
| 237 return cost * traverseSession->getMultiWordCostMultiplier(); |
| 238 } |
| 239 |
| 240 ErrorTypeUtils::ErrorType getErrorType( |
| 241 const CorrectionType correctionType, |
| 242 const DicTraverseSession* const traverseSession, |
| 243 const DicNode* const parentDicNode, |
| 244 const DicNode* const dicNode) const; |
| 245 |
| 246 private: |
| 247 DISALLOW_COPY_AND_ASSIGN(TypingWeighting); |
| 248 static const TypingWeighting sInstance; |
| 249 |
| 250 TypingWeighting() {} |
| 251 ~TypingWeighting() {} |
| 252 }; |
| 253 } // namespace latinime |
| 254 #endif // LATINIME_TYPING_WEIGHTING_H |
| OLD | NEW |