OLD | NEW |
(Empty) | |
| 1 /* |
| 2 * Copyright (C) 2013 The Android Open Source Project |
| 3 * |
| 4 * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 * you may not use this file except in compliance with the License. |
| 6 * You may obtain a copy of the License at |
| 7 * |
| 8 * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 * |
| 10 * Unless required by applicable law or agreed to in writing, software |
| 11 * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 * See the License for the specific language governing permissions and |
| 14 * limitations under the License. |
| 15 */ |
| 16 |
| 17 #include "third_party/prediction/suggest/core/policy/weighting.h" |
| 18 |
| 19 #include "third_party/prediction/defines.h" |
| 20 #include "third_party/prediction/suggest/core/dicnode/dic_node.h" |
| 21 #include "third_party/prediction/suggest/core/dicnode/dic_node_profiler.h" |
| 22 #include "third_party/prediction/suggest/core/dicnode/dic_node_utils.h" |
| 23 #include "third_party/prediction/suggest/core/dictionary/error_type_utils.h" |
| 24 #include "third_party/prediction/suggest/core/session/dic_traverse_session.h" |
| 25 |
| 26 namespace latinime { |
| 27 |
| 28 class MultiBigramMap; |
| 29 |
| 30 static inline void profile(const CorrectionType correctionType, |
| 31 DicNode* const node) { |
| 32 #if DEBUG_DICT |
| 33 switch (correctionType) { |
| 34 case CT_OMISSION: |
| 35 PROF_OMISSION(node->mProfiler); |
| 36 return; |
| 37 case CT_ADDITIONAL_PROXIMITY: |
| 38 PROF_ADDITIONAL_PROXIMITY(node->mProfiler); |
| 39 return; |
| 40 case CT_SUBSTITUTION: |
| 41 PROF_SUBSTITUTION(node->mProfiler); |
| 42 return; |
| 43 case CT_NEW_WORD_SPACE_OMISSION: |
| 44 PROF_NEW_WORD(node->mProfiler); |
| 45 return; |
| 46 case CT_MATCH: |
| 47 PROF_MATCH(node->mProfiler); |
| 48 return; |
| 49 case CT_COMPLETION: |
| 50 PROF_COMPLETION(node->mProfiler); |
| 51 return; |
| 52 case CT_TERMINAL: |
| 53 PROF_TERMINAL(node->mProfiler); |
| 54 return; |
| 55 case CT_TERMINAL_INSERTION: |
| 56 PROF_TERMINAL_INSERTION(node->mProfiler); |
| 57 return; |
| 58 case CT_NEW_WORD_SPACE_SUBSTITUTION: |
| 59 PROF_SPACE_SUBSTITUTION(node->mProfiler); |
| 60 return; |
| 61 case CT_INSERTION: |
| 62 PROF_INSERTION(node->mProfiler); |
| 63 return; |
| 64 case CT_TRANSPOSITION: |
| 65 PROF_TRANSPOSITION(node->mProfiler); |
| 66 return; |
| 67 default: |
| 68 // do nothing |
| 69 return; |
| 70 } |
| 71 #else |
| 72 // do nothing |
| 73 #endif |
| 74 } |
| 75 |
| 76 /* static */ void Weighting::addCostAndForwardInputIndex( |
| 77 const Weighting* const weighting, |
| 78 const CorrectionType correctionType, |
| 79 const DicTraverseSession* const traverseSession, |
| 80 const DicNode* const parentDicNode, |
| 81 DicNode* const dicNode, |
| 82 MultiBigramMap* const multiBigramMap) { |
| 83 const int inputSize = traverseSession->getInputSize(); |
| 84 DicNode_InputStateG inputStateG; |
| 85 inputStateG.mNeedsToUpdateInputStateG = |
| 86 false; // Don't use input info by default |
| 87 const float spatialCost = |
| 88 Weighting::getSpatialCost(weighting, correctionType, traverseSession, |
| 89 parentDicNode, dicNode, &inputStateG); |
| 90 const float languageCost = |
| 91 Weighting::getLanguageCost(weighting, correctionType, traverseSession, |
| 92 parentDicNode, dicNode, multiBigramMap); |
| 93 const ErrorTypeUtils::ErrorType errorType = weighting->getErrorType( |
| 94 correctionType, traverseSession, parentDicNode, dicNode); |
| 95 profile(correctionType, dicNode); |
| 96 if (inputStateG.mNeedsToUpdateInputStateG) { |
| 97 dicNode->updateInputIndexG(&inputStateG); |
| 98 } else { |
| 99 dicNode->forwardInputIndex(0, getForwardInputCount(correctionType), |
| 100 (correctionType == CT_TRANSPOSITION)); |
| 101 } |
| 102 dicNode->addCost(spatialCost, languageCost, |
| 103 weighting->needsToNormalizeCompoundDistance(), inputSize, |
| 104 errorType); |
| 105 if (CT_NEW_WORD_SPACE_OMISSION == correctionType) { |
| 106 // When we are on a terminal, we save the current distance for evaluating |
| 107 // when to auto-commit partial suggestions. |
| 108 dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet(); |
| 109 } |
| 110 } |
| 111 |
| 112 /* static */ float Weighting::getSpatialCost( |
| 113 const Weighting* const weighting, |
| 114 const CorrectionType correctionType, |
| 115 const DicTraverseSession* const traverseSession, |
| 116 const DicNode* const parentDicNode, |
| 117 const DicNode* const dicNode, |
| 118 DicNode_InputStateG* const inputStateG) { |
| 119 switch (correctionType) { |
| 120 case CT_OMISSION: |
| 121 return weighting->getOmissionCost(parentDicNode, dicNode); |
| 122 case CT_ADDITIONAL_PROXIMITY: |
| 123 // only used for typing |
| 124 return weighting->getAdditionalProximityCost(); |
| 125 case CT_SUBSTITUTION: |
| 126 // only used for typing |
| 127 return weighting->getSubstitutionCost(); |
| 128 case CT_NEW_WORD_SPACE_OMISSION: |
| 129 return weighting->getNewWordSpatialCost(traverseSession, dicNode, |
| 130 inputStateG); |
| 131 case CT_MATCH: |
| 132 return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); |
| 133 case CT_COMPLETION: |
| 134 return weighting->getCompletionCost(traverseSession, dicNode); |
| 135 case CT_TERMINAL: |
| 136 return weighting->getTerminalSpatialCost(traverseSession, dicNode); |
| 137 case CT_TERMINAL_INSERTION: |
| 138 return weighting->getTerminalInsertionCost(traverseSession, dicNode); |
| 139 case CT_NEW_WORD_SPACE_SUBSTITUTION: |
| 140 return weighting->getSpaceSubstitutionCost(traverseSession, dicNode); |
| 141 case CT_INSERTION: |
| 142 return weighting->getInsertionCost(traverseSession, parentDicNode, |
| 143 dicNode); |
| 144 case CT_TRANSPOSITION: |
| 145 return weighting->getTranspositionCost(traverseSession, parentDicNode, |
| 146 dicNode); |
| 147 default: |
| 148 return 0.0f; |
| 149 } |
| 150 } |
| 151 |
| 152 /* static */ float Weighting::getLanguageCost( |
| 153 const Weighting* const weighting, |
| 154 const CorrectionType correctionType, |
| 155 const DicTraverseSession* const traverseSession, |
| 156 const DicNode* const parentDicNode, |
| 157 const DicNode* const dicNode, |
| 158 MultiBigramMap* const multiBigramMap) { |
| 159 switch (correctionType) { |
| 160 case CT_OMISSION: |
| 161 return 0.0f; |
| 162 case CT_SUBSTITUTION: |
| 163 return 0.0f; |
| 164 case CT_NEW_WORD_SPACE_OMISSION: |
| 165 return weighting->getNewWordBigramLanguageCost( |
| 166 traverseSession, parentDicNode, multiBigramMap); |
| 167 case CT_MATCH: |
| 168 return 0.0f; |
| 169 case CT_COMPLETION: |
| 170 return 0.0f; |
| 171 case CT_TERMINAL: { |
| 172 const float languageImprobability = |
| 173 DicNodeUtils::getBigramNodeImprobability( |
| 174 traverseSession->getDictionaryStructurePolicy(), dicNode, |
| 175 multiBigramMap); |
| 176 return weighting->getTerminalLanguageCost(traverseSession, dicNode, |
| 177 languageImprobability); |
| 178 } |
| 179 case CT_TERMINAL_INSERTION: |
| 180 return 0.0f; |
| 181 case CT_NEW_WORD_SPACE_SUBSTITUTION: |
| 182 return weighting->getNewWordBigramLanguageCost( |
| 183 traverseSession, parentDicNode, multiBigramMap); |
| 184 case CT_INSERTION: |
| 185 return 0.0f; |
| 186 case CT_TRANSPOSITION: |
| 187 return 0.0f; |
| 188 default: |
| 189 return 0.0f; |
| 190 } |
| 191 } |
| 192 |
| 193 /* static */ int Weighting::getForwardInputCount( |
| 194 const CorrectionType correctionType) { |
| 195 switch (correctionType) { |
| 196 case CT_OMISSION: |
| 197 return 0; |
| 198 case CT_ADDITIONAL_PROXIMITY: |
| 199 return 0; /* 0 because CT_MATCH will be called */ |
| 200 case CT_SUBSTITUTION: |
| 201 return 0; /* 0 because CT_MATCH will be called */ |
| 202 case CT_NEW_WORD_SPACE_OMISSION: |
| 203 return 0; |
| 204 case CT_MATCH: |
| 205 return 1; |
| 206 case CT_COMPLETION: |
| 207 return 1; |
| 208 case CT_TERMINAL: |
| 209 return 0; |
| 210 case CT_TERMINAL_INSERTION: |
| 211 return 1; |
| 212 case CT_NEW_WORD_SPACE_SUBSTITUTION: |
| 213 return 1; |
| 214 case CT_INSERTION: |
| 215 return 2; /* look ahead + skip the current char */ |
| 216 case CT_TRANSPOSITION: |
| 217 return 2; /* look ahead + skip the current char */ |
| 218 default: |
| 219 return 0; |
| 220 } |
| 221 } |
| 222 } // namespace latinime |
OLD | NEW |