OLD | NEW |
(Empty) | |
| 1 /* |
| 2 * Copyright (C) 2012 The Android Open Source Project |
| 3 * |
| 4 * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 * you may not use this file except in compliance with the License. |
| 6 * You may obtain a copy of the License at |
| 7 * |
| 8 * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 * |
| 10 * Unless required by applicable law or agreed to in writing, software |
| 11 * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 * See the License for the specific language governing permissions and |
| 14 * limitations under the License. |
| 15 */ |
| 16 |
| 17 #ifndef LATINIME_TYPING_WEIGHTING_H |
| 18 #define LATINIME_TYPING_WEIGHTING_H |
| 19 |
| 20 #include "third_party/android_prediction/defines.h" |
| 21 #include "third_party/android_prediction/suggest/core/dicnode/dic_node_utils.h" |
| 22 #include "third_party/android_prediction/suggest/core/dictionary/error_type_util
s.h" |
| 23 #include "third_party/android_prediction/suggest/core/layout/touch_position_corr
ection_utils.h" |
| 24 #include "third_party/android_prediction/suggest/core/policy/weighting.h" |
| 25 #include "third_party/android_prediction/suggest/core/session/dic_traverse_sessi
on.h" |
| 26 #include "third_party/android_prediction/suggest/policyimpl/typing/scoring_param
s.h" |
| 27 #include "third_party/android_prediction/utils/char_utils.h" |
| 28 |
| 29 namespace latinime { |
| 30 |
| 31 class DicNode; |
| 32 struct DicNode_InputStateG; |
| 33 class MultiBigramMap; |
| 34 |
| 35 class TypingWeighting : public Weighting { |
| 36 public: |
| 37 static const TypingWeighting *getInstance() { return &sInstance; } |
| 38 |
| 39 protected: |
| 40 float getTerminalSpatialCost(const DicTraverseSession *const traverseSession
, |
| 41 const DicNode *const dicNode) const { |
| 42 float cost = 0.0f; |
| 43 if (dicNode->hasMultipleWords()) { |
| 44 cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST; |
| 45 } |
| 46 if (dicNode->getProximityCorrectionCount() > 0) { |
| 47 cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST; |
| 48 } |
| 49 if (dicNode->getEditCorrectionCount() > 0) { |
| 50 cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST; |
| 51 } |
| 52 return cost; |
| 53 } |
| 54 |
| 55 float getOmissionCost(const DicNode *const parentDicNode, const DicNode *con
st dicNode) const { |
| 56 const bool isZeroCostOmission = parentDicNode->isZeroCostOmission(); |
| 57 const bool isIntentionalOmission = parentDicNode->canBeIntentionalOmissi
on(); |
| 58 const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode); |
| 59 // If the traversal omitted the first letter then the dicNode should now
be on the second. |
| 60 const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2
; |
| 61 float cost = 0.0f; |
| 62 if (isZeroCostOmission) { |
| 63 cost = 0.0f; |
| 64 } else if (isIntentionalOmission) { |
| 65 cost = ScoringParams::INTENTIONAL_OMISSION_COST; |
| 66 } else if (isFirstLetterOmission) { |
| 67 cost = ScoringParams::OMISSION_COST_FIRST_CHAR; |
| 68 } else { |
| 69 cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR |
| 70 : ScoringParams::OMISSION_COST; |
| 71 } |
| 72 return cost; |
| 73 } |
| 74 |
| 75 float getMatchedCost(const DicTraverseSession *const traverseSession, |
| 76 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) cons
t { |
| 77 const int pointIndex = dicNode->getInputIndex(0); |
| 78 const float normalizedSquaredLength = traverseSession->getProximityInfoS
tate(0) |
| 79 ->getPointToKeyLength(pointIndex, |
| 80 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()))
; |
| 81 const float normalizedDistance = TouchPositionCorrectionUtils::getSweetS
potFactor( |
| 82 traverseSession->isTouchPositionCorrectionEnabled(), normalizedS
quaredLength); |
| 83 const float weightedDistance = ScoringParams::DISTANCE_WEIGHT_LENGTH * n
ormalizedDistance; |
| 84 |
| 85 const bool isFirstChar = pointIndex == 0; |
| 86 const bool isProximity = isProximityDicNode(traverseSession, dicNode); |
| 87 float cost = isProximity ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROX
IMITY_COST |
| 88 : ScoringParams::PROXIMITY_COST) : 0.0f; |
| 89 if (isProximity && dicNode->getProximityCorrectionCount() == 0) { |
| 90 cost += ScoringParams::FIRST_PROXIMITY_COST; |
| 91 } |
| 92 if (dicNode->getNodeCodePointCount() == 2) { |
| 93 // At the second character of the current word, we check if the firs
t char is uppercase |
| 94 // and the word is a second or later word of a multiple word suggest
ion. We demote it |
| 95 // if so. |
| 96 const bool isSecondOrLaterWordFirstCharUppercase = |
| 97 dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase
(); |
| 98 if (isSecondOrLaterWordFirstCharUppercase) { |
| 99 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPE
RCASE; |
| 100 } |
| 101 } |
| 102 return weightedDistance + cost; |
| 103 } |
| 104 |
| 105 bool isProximityDicNode(const DicTraverseSession *const traverseSession, |
| 106 const DicNode *const dicNode) const { |
| 107 const int pointIndex = dicNode->getInputIndex(0); |
| 108 const int primaryCodePoint = CharUtils::toBaseLowerCase( |
| 109 traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt
(pointIndex)); |
| 110 const int dicNodeChar = CharUtils::toBaseLowerCase(dicNode->getNodeCodeP
oint()); |
| 111 return primaryCodePoint != dicNodeChar; |
| 112 } |
| 113 |
| 114 float getTranspositionCost(const DicTraverseSession *const traverseSession, |
| 115 const DicNode *const parentDicNode, const DicNode *const dicNode) co
nst { |
| 116 const int16_t parentPointIndex = parentDicNode->getInputIndex(0); |
| 117 const int prevCodePoint = parentDicNode->getNodeCodePoint(); |
| 118 const float distance1 = traverseSession->getProximityInfoState(0)->getPo
intToKeyLength( |
| 119 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint))
; |
| 120 const int codePoint = dicNode->getNodeCodePoint(); |
| 121 const float distance2 = traverseSession->getProximityInfoState(0)->getPo
intToKeyLength( |
| 122 parentPointIndex, CharUtils::toBaseLowerCase(codePoint)); |
| 123 const float distance = distance1 + distance2; |
| 124 const float weightedLengthDistance = |
| 125 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH; |
| 126 return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance; |
| 127 } |
| 128 |
| 129 float getInsertionCost(const DicTraverseSession *const traverseSession, |
| 130 const DicNode *const parentDicNode, const DicNode *const dicNode) co
nst { |
| 131 const int16_t insertedPointIndex = parentDicNode->getInputIndex(0); |
| 132 const int prevCodePoint = traverseSession->getProximityInfoState(0)->get
PrimaryCodePointAt( |
| 133 insertedPointIndex); |
| 134 const int currentCodePoint = dicNode->getNodeCodePoint(); |
| 135 const bool sameCodePoint = prevCodePoint == currentCodePoint; |
| 136 const bool existsAdjacentProximityChars = traverseSession->getProximityI
nfoState(0) |
| 137 ->existsAdjacentProximityChars(insertedPointIndex); |
| 138 const float dist = traverseSession->getProximityInfoState(0)->getPointTo
KeyLength( |
| 139 insertedPointIndex + 1, CharUtils::toBaseLowerCase(dicNode->getN
odeCodePoint())); |
| 140 const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LEN
GTH; |
| 141 const bool singleChar = dicNode->getNodeCodePointCount() == 1; |
| 142 float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.
0f); |
| 143 if (sameCodePoint) { |
| 144 cost += ScoringParams::INSERTION_COST_SAME_CHAR; |
| 145 } else if (existsAdjacentProximityChars) { |
| 146 cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR; |
| 147 } else { |
| 148 cost += ScoringParams::INSERTION_COST; |
| 149 } |
| 150 return cost + weightedDistance; |
| 151 } |
| 152 |
| 153 float getNewWordSpatialCost(const DicTraverseSession *const traverseSession, |
| 154 const DicNode *const dicNode, DicNode_InputStateG *inputStateG) cons
t { |
| 155 return ScoringParams::COST_NEW_WORD * traverseSession->getMultiWordCostM
ultiplier(); |
| 156 } |
| 157 |
| 158 float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseS
ession, |
| 159 const DicNode *const dicNode, |
| 160 MultiBigramMap *const multiBigramMap) const { |
| 161 return DicNodeUtils::getBigramNodeImprobability( |
| 162 traverseSession->getDictionaryStructurePolicy(), |
| 163 dicNode, multiBigramMap) * ScoringParams::DISTANCE_WEIGHT_LANGUA
GE; |
| 164 } |
| 165 |
| 166 float getCompletionCost(const DicTraverseSession *const traverseSession, |
| 167 const DicNode *const dicNode) const { |
| 168 // The auto completion starts when the input index is same as the input
size |
| 169 const bool firstCompletion = dicNode->getInputIndex(0) |
| 170 == traverseSession->getInputSize(); |
| 171 // TODO: Change the cost for the first completion for the gesture? |
| 172 const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETIO
N |
| 173 : ScoringParams::COST_COMPLETION; |
| 174 return cost; |
| 175 } |
| 176 |
| 177 float getTerminalLanguageCost(const DicTraverseSession *const traverseSessio
n, |
| 178 const DicNode *const dicNode, const float dicNodeLanguageImprobabili
ty) const { |
| 179 return dicNodeLanguageImprobability * ScoringParams::DISTANCE_WEIGHT_LAN
GUAGE; |
| 180 } |
| 181 |
| 182 float getTerminalInsertionCost(const DicTraverseSession *const traverseSessi
on, |
| 183 const DicNode *const dicNode) const { |
| 184 const int inputIndex = dicNode->getInputIndex(0); |
| 185 const int inputSize = traverseSession->getInputSize(); |
| 186 ASSERT(inputIndex < inputSize); |
| 187 // TODO: Implement more efficient logic |
| 188 return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex
); |
| 189 } |
| 190 |
| 191 AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const { |
| 192 return false; |
| 193 } |
| 194 |
| 195 AK_FORCE_INLINE float getAdditionalProximityCost() const { |
| 196 return ScoringParams::ADDITIONAL_PROXIMITY_COST; |
| 197 } |
| 198 |
| 199 AK_FORCE_INLINE float getSubstitutionCost() const { |
| 200 return ScoringParams::SUBSTITUTION_COST; |
| 201 } |
| 202 |
| 203 AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *con
st traverseSession, |
| 204 const DicNode *const dicNode) const { |
| 205 const float cost = ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParam
s::COST_NEW_WORD; |
| 206 return cost * traverseSession->getMultiWordCostMultiplier(); |
| 207 } |
| 208 |
| 209 ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType, |
| 210 const DicTraverseSession *const traverseSession, |
| 211 const DicNode *const parentDicNode, const DicNode *const dicNode) co
nst; |
| 212 |
| 213 private: |
| 214 DISALLOW_COPY_AND_ASSIGN(TypingWeighting); |
| 215 static const TypingWeighting sInstance; |
| 216 |
| 217 TypingWeighting() {} |
| 218 ~TypingWeighting() {} |
| 219 }; |
| 220 } // namespace latinime |
| 221 #endif // LATINIME_TYPING_WEIGHTING_H |
OLD | NEW |