Index: third_party/prediction/suggest/core/policy/weighting.cpp |
diff --git a/third_party/prediction/suggest/core/policy/weighting.cpp b/third_party/prediction/suggest/core/policy/weighting.cpp |
new file mode 100644 |
index 0000000000000000000000000000000000000000..2b7eddd81060de3b511005ecb6c319957d893ca3 |
--- /dev/null |
+++ b/third_party/prediction/suggest/core/policy/weighting.cpp |
@@ -0,0 +1,222 @@ |
+/* |
+ * Copyright (C) 2013 The Android Open Source Project |
+ * |
+ * Licensed under the Apache License, Version 2.0 (the "License"); |
+ * you may not use this file except in compliance with the License. |
+ * You may obtain a copy of the License at |
+ * |
+ * http://www.apache.org/licenses/LICENSE-2.0 |
+ * |
+ * Unless required by applicable law or agreed to in writing, software |
+ * distributed under the License is distributed on an "AS IS" BASIS, |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
+ * See the License for the specific language governing permissions and |
+ * limitations under the License. |
+ */ |
+ |
+#include "third_party/prediction/suggest/core/policy/weighting.h" |
+ |
+#include "third_party/prediction/defines.h" |
+#include "third_party/prediction/suggest/core/dicnode/dic_node.h" |
+#include "third_party/prediction/suggest/core/dicnode/dic_node_profiler.h" |
+#include "third_party/prediction/suggest/core/dicnode/dic_node_utils.h" |
+#include "third_party/prediction/suggest/core/dictionary/error_type_utils.h" |
+#include "third_party/prediction/suggest/core/session/dic_traverse_session.h" |
+ |
+namespace latinime { |
+ |
+class MultiBigramMap; |
+ |
+static inline void profile(const CorrectionType correctionType, |
+ DicNode* const node) { |
+#if DEBUG_DICT |
+ switch (correctionType) { |
+ case CT_OMISSION: |
+ PROF_OMISSION(node->mProfiler); |
+ return; |
+ case CT_ADDITIONAL_PROXIMITY: |
+ PROF_ADDITIONAL_PROXIMITY(node->mProfiler); |
+ return; |
+ case CT_SUBSTITUTION: |
+ PROF_SUBSTITUTION(node->mProfiler); |
+ return; |
+ case CT_NEW_WORD_SPACE_OMISSION: |
+ PROF_NEW_WORD(node->mProfiler); |
+ return; |
+ case CT_MATCH: |
+ PROF_MATCH(node->mProfiler); |
+ return; |
+ case CT_COMPLETION: |
+ PROF_COMPLETION(node->mProfiler); |
+ return; |
+ case CT_TERMINAL: |
+ PROF_TERMINAL(node->mProfiler); |
+ return; |
+ case CT_TERMINAL_INSERTION: |
+ PROF_TERMINAL_INSERTION(node->mProfiler); |
+ return; |
+ case CT_NEW_WORD_SPACE_SUBSTITUTION: |
+ PROF_SPACE_SUBSTITUTION(node->mProfiler); |
+ return; |
+ case CT_INSERTION: |
+ PROF_INSERTION(node->mProfiler); |
+ return; |
+ case CT_TRANSPOSITION: |
+ PROF_TRANSPOSITION(node->mProfiler); |
+ return; |
+ default: |
+ // do nothing |
+ return; |
+ } |
+#else |
+// do nothing |
+#endif |
+} |
+ |
+/* static */ void Weighting::addCostAndForwardInputIndex( |
+ const Weighting* const weighting, |
+ const CorrectionType correctionType, |
+ const DicTraverseSession* const traverseSession, |
+ const DicNode* const parentDicNode, |
+ DicNode* const dicNode, |
+ MultiBigramMap* const multiBigramMap) { |
+ const int inputSize = traverseSession->getInputSize(); |
+ DicNode_InputStateG inputStateG; |
+ inputStateG.mNeedsToUpdateInputStateG = |
+ false; // Don't use input info by default |
+ const float spatialCost = |
+ Weighting::getSpatialCost(weighting, correctionType, traverseSession, |
+ parentDicNode, dicNode, &inputStateG); |
+ const float languageCost = |
+ Weighting::getLanguageCost(weighting, correctionType, traverseSession, |
+ parentDicNode, dicNode, multiBigramMap); |
+ const ErrorTypeUtils::ErrorType errorType = weighting->getErrorType( |
+ correctionType, traverseSession, parentDicNode, dicNode); |
+ profile(correctionType, dicNode); |
+ if (inputStateG.mNeedsToUpdateInputStateG) { |
+ dicNode->updateInputIndexG(&inputStateG); |
+ } else { |
+ dicNode->forwardInputIndex(0, getForwardInputCount(correctionType), |
+ (correctionType == CT_TRANSPOSITION)); |
+ } |
+ dicNode->addCost(spatialCost, languageCost, |
+ weighting->needsToNormalizeCompoundDistance(), inputSize, |
+ errorType); |
+ if (CT_NEW_WORD_SPACE_OMISSION == correctionType) { |
+ // When we are on a terminal, we save the current distance for evaluating |
+ // when to auto-commit partial suggestions. |
+ dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet(); |
+ } |
+} |
+ |
+/* static */ float Weighting::getSpatialCost( |
+ const Weighting* const weighting, |
+ const CorrectionType correctionType, |
+ const DicTraverseSession* const traverseSession, |
+ const DicNode* const parentDicNode, |
+ const DicNode* const dicNode, |
+ DicNode_InputStateG* const inputStateG) { |
+ switch (correctionType) { |
+ case CT_OMISSION: |
+ return weighting->getOmissionCost(parentDicNode, dicNode); |
+ case CT_ADDITIONAL_PROXIMITY: |
+ // only used for typing |
+ return weighting->getAdditionalProximityCost(); |
+ case CT_SUBSTITUTION: |
+ // only used for typing |
+ return weighting->getSubstitutionCost(); |
+ case CT_NEW_WORD_SPACE_OMISSION: |
+ return weighting->getNewWordSpatialCost(traverseSession, dicNode, |
+ inputStateG); |
+ case CT_MATCH: |
+ return weighting->getMatchedCost(traverseSession, dicNode, inputStateG); |
+ case CT_COMPLETION: |
+ return weighting->getCompletionCost(traverseSession, dicNode); |
+ case CT_TERMINAL: |
+ return weighting->getTerminalSpatialCost(traverseSession, dicNode); |
+ case CT_TERMINAL_INSERTION: |
+ return weighting->getTerminalInsertionCost(traverseSession, dicNode); |
+ case CT_NEW_WORD_SPACE_SUBSTITUTION: |
+ return weighting->getSpaceSubstitutionCost(traverseSession, dicNode); |
+ case CT_INSERTION: |
+ return weighting->getInsertionCost(traverseSession, parentDicNode, |
+ dicNode); |
+ case CT_TRANSPOSITION: |
+ return weighting->getTranspositionCost(traverseSession, parentDicNode, |
+ dicNode); |
+ default: |
+ return 0.0f; |
+ } |
+} |
+ |
+/* static */ float Weighting::getLanguageCost( |
+ const Weighting* const weighting, |
+ const CorrectionType correctionType, |
+ const DicTraverseSession* const traverseSession, |
+ const DicNode* const parentDicNode, |
+ const DicNode* const dicNode, |
+ MultiBigramMap* const multiBigramMap) { |
+ switch (correctionType) { |
+ case CT_OMISSION: |
+ return 0.0f; |
+ case CT_SUBSTITUTION: |
+ return 0.0f; |
+ case CT_NEW_WORD_SPACE_OMISSION: |
+ return weighting->getNewWordBigramLanguageCost( |
+ traverseSession, parentDicNode, multiBigramMap); |
+ case CT_MATCH: |
+ return 0.0f; |
+ case CT_COMPLETION: |
+ return 0.0f; |
+ case CT_TERMINAL: { |
+ const float languageImprobability = |
+ DicNodeUtils::getBigramNodeImprobability( |
+ traverseSession->getDictionaryStructurePolicy(), dicNode, |
+ multiBigramMap); |
+ return weighting->getTerminalLanguageCost(traverseSession, dicNode, |
+ languageImprobability); |
+ } |
+ case CT_TERMINAL_INSERTION: |
+ return 0.0f; |
+ case CT_NEW_WORD_SPACE_SUBSTITUTION: |
+ return weighting->getNewWordBigramLanguageCost( |
+ traverseSession, parentDicNode, multiBigramMap); |
+ case CT_INSERTION: |
+ return 0.0f; |
+ case CT_TRANSPOSITION: |
+ return 0.0f; |
+ default: |
+ return 0.0f; |
+ } |
+} |
+ |
+/* static */ int Weighting::getForwardInputCount( |
+ const CorrectionType correctionType) { |
+ switch (correctionType) { |
+ case CT_OMISSION: |
+ return 0; |
+ case CT_ADDITIONAL_PROXIMITY: |
+ return 0; /* 0 because CT_MATCH will be called */ |
+ case CT_SUBSTITUTION: |
+ return 0; /* 0 because CT_MATCH will be called */ |
+ case CT_NEW_WORD_SPACE_OMISSION: |
+ return 0; |
+ case CT_MATCH: |
+ return 1; |
+ case CT_COMPLETION: |
+ return 1; |
+ case CT_TERMINAL: |
+ return 0; |
+ case CT_TERMINAL_INSERTION: |
+ return 1; |
+ case CT_NEW_WORD_SPACE_SUBSTITUTION: |
+ return 1; |
+ case CT_INSERTION: |
+ return 2; /* look ahead + skip the current char */ |
+ case CT_TRANSPOSITION: |
+ return 2; /* look ahead + skip the current char */ |
+ default: |
+ return 0; |
+ } |
+} |
+} // namespace latinime |