| Index: third_party/android_prediction/suggest/core/policy/weighting.cpp
|
| diff --git a/third_party/android_prediction/suggest/core/policy/weighting.cpp b/third_party/android_prediction/suggest/core/policy/weighting.cpp
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..1db2625eddfff180a320b023a1116e74e86247e9
|
| --- /dev/null
|
| +++ b/third_party/android_prediction/suggest/core/policy/weighting.cpp
|
| @@ -0,0 +1,202 @@
|
| +/*
|
| + * Copyright (C) 2013 The Android Open Source Project
|
| + *
|
| + * Licensed under the Apache License, Version 2.0 (the "License");
|
| + * you may not use this file except in compliance with the License.
|
| + * You may obtain a copy of the License at
|
| + *
|
| + * http://www.apache.org/licenses/LICENSE-2.0
|
| + *
|
| + * Unless required by applicable law or agreed to in writing, software
|
| + * distributed under the License is distributed on an "AS IS" BASIS,
|
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| + * See the License for the specific language governing permissions and
|
| + * limitations under the License.
|
| + */
|
| +
|
| +#include "third_party/android_prediction/suggest/core/policy/weighting.h"
|
| +
|
| +#include "third_party/android_prediction/defines.h"
|
| +#include "third_party/android_prediction/suggest/core/dicnode/dic_node.h"
|
| +#include "third_party/android_prediction/suggest/core/dicnode/dic_node_profiler.h"
|
| +#include "third_party/android_prediction/suggest/core/dicnode/dic_node_utils.h"
|
| +#include "third_party/android_prediction/suggest/core/dictionary/error_type_utils.h"
|
| +#include "third_party/android_prediction/suggest/core/session/dic_traverse_session.h"
|
| +
|
| +namespace latinime {
|
| +
|
| +class MultiBigramMap; // Forward declaration: the code below only passes MultiBigramMap pointers along.
|
| +
|
| +// Debug-only hook: invokes the PROF_* macro that matches correctionType on node->mProfiler.
|
| +// When DEBUG_DICT is not enabled the body is empty and both arguments go unused.
|
| +static inline void profile(const CorrectionType correctionType, DicNode *const node) {
|
| +#if DEBUG_DICT
|
| +    switch (correctionType) {
|
| +    case CT_OMISSION:
|
| +        PROF_OMISSION(node->mProfiler);
|
| +        break;
|
| +    case CT_ADDITIONAL_PROXIMITY:
|
| +        PROF_ADDITIONAL_PROXIMITY(node->mProfiler);
|
| +        break;
|
| +    case CT_SUBSTITUTION:
|
| +        PROF_SUBSTITUTION(node->mProfiler);
|
| +        break;
|
| +    case CT_NEW_WORD_SPACE_OMISSION:
|
| +        PROF_NEW_WORD(node->mProfiler);
|
| +        break;
|
| +    case CT_MATCH:
|
| +        PROF_MATCH(node->mProfiler);
|
| +        break;
|
| +    case CT_COMPLETION:
|
| +        PROF_COMPLETION(node->mProfiler);
|
| +        break;
|
| +    case CT_TERMINAL:
|
| +        PROF_TERMINAL(node->mProfiler);
|
| +        break;
|
| +    case CT_TERMINAL_INSERTION:
|
| +        PROF_TERMINAL_INSERTION(node->mProfiler);
|
| +        break;
|
| +    case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
| +        PROF_SPACE_SUBSTITUTION(node->mProfiler);
|
| +        break;
|
| +    case CT_INSERTION:
|
| +        PROF_INSERTION(node->mProfiler);
|
| +        break;
|
| +    case CT_TRANSPOSITION:
|
| +        PROF_TRANSPOSITION(node->mProfiler);
|
| +        break;
|
| +    default:
|
| +        // Any other correction type has no profiling macro here; nothing is recorded.
|
| +        break;
|
| +    }
|
| +#else
|
| +    // Profiling is compiled out: do nothing.
|
| +#endif
|
| +}
|
| +
|
| +// Applies one correction step to dicNode: computes the spatial cost, the language cost and the
|
| +// error type for correctionType, moves the node's input index forward, and then adds the
|
| +// costs to the node. All cost computations happen before the node is modified.
|
| +/* static */ void Weighting::addCostAndForwardInputIndex(const Weighting *const weighting,
|
| +        const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
|
| +        const DicNode *const parentDicNode, DicNode *const dicNode,
|
| +        MultiBigramMap *const multiBigramMap) {
|
| +    const int totalInputSize = traverseSession->getInputSize();
|
| +    // The flag starts out false (input info is not used by default). getSpatialCost() receives
|
| +    // a pointer to this state, and the flag is consulted again once the costs are known.
|
| +    DicNode_InputStateG inputState;
|
| +    inputState.mNeedsToUpdateInputStateG = false;
|
| +    const float spatial = Weighting::getSpatialCost(weighting, correctionType,
|
| +            traverseSession, parentDicNode, dicNode, &inputState);
|
| +    const float language = Weighting::getLanguageCost(weighting, correctionType,
|
| +            traverseSession, parentDicNode, dicNode, multiBigramMap);
|
| +    const ErrorTypeUtils::ErrorType error = weighting->getErrorType(correctionType,
|
| +            traverseSession, parentDicNode, dicNode);
|
| +    profile(correctionType, dicNode);
|
| +    if (!inputState.mNeedsToUpdateInputStateG) {
|
| +        // Advance by the fixed per-correction-type count.
|
| +        const int forwardCount = getForwardInputCount(correctionType);
|
| +        const bool isTransposition = (correctionType == CT_TRANSPOSITION);
|
| +        dicNode->forwardInputIndex(0, forwardCount, isTransposition);
|
| +    } else {
|
| +        // The input state was flagged for update; hand it to the node instead.
|
| +        dicNode->updateInputIndexG(&inputState);
|
| +    }
|
| +    dicNode->addCost(spatial, language, weighting->needsToNormalizeCompoundDistance(),
|
| +            totalInputSize, error);
|
| +    if (correctionType == CT_NEW_WORD_SPACE_OMISSION) {
|
| +        // When we are on a terminal, we save the current distance for evaluating
|
| +        // when to auto-commit partial suggestions.
|
| +        dicNode->saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet();
|
| +    }
|
| +}
|
| +
|
| +// Dispatches to the weighting policy's spatial cost function for correctionType.
|
| +// Correction types that are not listed yield a cost of 0.0f. inputStateG is only forwarded
|
| +// to the new-word and match cost functions.
|
| +/* static */ float Weighting::getSpatialCost(const Weighting *const weighting,
|
| +        const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
|
| +        const DicNode *const parentDicNode, const DicNode *const dicNode,
|
| +        DicNode_InputStateG *const inputStateG) {
|
| +    float cost = 0.0f;
|
| +    switch (correctionType) {
|
| +    case CT_OMISSION:
|
| +        cost = weighting->getOmissionCost(parentDicNode, dicNode);
|
| +        break;
|
| +    case CT_ADDITIONAL_PROXIMITY:
|
| +        // only used for typing
|
| +        cost = weighting->getAdditionalProximityCost();
|
| +        break;
|
| +    case CT_SUBSTITUTION:
|
| +        // only used for typing
|
| +        cost = weighting->getSubstitutionCost();
|
| +        break;
|
| +    case CT_NEW_WORD_SPACE_OMISSION:
|
| +        cost = weighting->getNewWordSpatialCost(traverseSession, dicNode, inputStateG);
|
| +        break;
|
| +    case CT_MATCH:
|
| +        cost = weighting->getMatchedCost(traverseSession, dicNode, inputStateG);
|
| +        break;
|
| +    case CT_COMPLETION:
|
| +        cost = weighting->getCompletionCost(traverseSession, dicNode);
|
| +        break;
|
| +    case CT_TERMINAL:
|
| +        cost = weighting->getTerminalSpatialCost(traverseSession, dicNode);
|
| +        break;
|
| +    case CT_TERMINAL_INSERTION:
|
| +        cost = weighting->getTerminalInsertionCost(traverseSession, dicNode);
|
| +        break;
|
| +    case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
| +        cost = weighting->getSpaceSubstitutionCost(traverseSession, dicNode);
|
| +        break;
|
| +    case CT_INSERTION:
|
| +        cost = weighting->getInsertionCost(traverseSession, parentDicNode, dicNode);
|
| +        break;
|
| +    case CT_TRANSPOSITION:
|
| +        cost = weighting->getTranspositionCost(traverseSession, parentDicNode, dicNode);
|
| +        break;
|
| +    default:
|
| +        // Unlisted correction type: keep the 0.0f initial value.
|
| +        break;
|
| +    }
|
| +    return cost;
|
| +}
|
| +
|
| +// Returns the language cost for correctionType. Only the two new-word corrections and
|
| +// CT_TERMINAL consult the weighting policy; every other correction type costs 0.0f.
|
| +/* static */ float Weighting::getLanguageCost(const Weighting *const weighting,
|
| +        const CorrectionType correctionType, const DicTraverseSession *const traverseSession,
|
| +        const DicNode *const parentDicNode, const DicNode *const dicNode,
|
| +        MultiBigramMap *const multiBigramMap) {
|
| +    switch (correctionType) {
|
| +    case CT_NEW_WORD_SPACE_OMISSION:
|
| +    case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
| +        // Both new-word corrections use the bigram cost computed from the parent node.
|
| +        return weighting->getNewWordBigramLanguageCost(
|
| +                traverseSession, parentDicNode, multiBigramMap);
|
| +    case CT_TERMINAL: {
|
| +        // The terminal cost is derived from the bigram improbability of dicNode.
|
| +        const float improbability = DicNodeUtils::getBigramNodeImprobability(
|
| +                traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap);
|
| +        return weighting->getTerminalLanguageCost(traverseSession, dicNode, improbability);
|
| +    }
|
| +    case CT_OMISSION:
|
| +    case CT_SUBSTITUTION:
|
| +    case CT_MATCH:
|
| +    case CT_COMPLETION:
|
| +    case CT_TERMINAL_INSERTION:
|
| +    case CT_INSERTION:
|
| +    case CT_TRANSPOSITION:
|
| +    default:
|
| +        // No language cost; correction types without a label of their own end up here too.
|
| +        return 0.0f;
|
| +    }
|
| +}
|
| +
|
| +// Returns the count that addCostAndForwardInputIndex() passes to DicNode::forwardInputIndex()
|
| +// for correctionType: 0, 1 or 2. Unlisted correction types yield 0.
|
| +/* static */ int Weighting::getForwardInputCount(const CorrectionType correctionType) {
|
| +    switch (correctionType) {
|
| +    case CT_MATCH:
|
| +    case CT_COMPLETION:
|
| +    case CT_TERMINAL_INSERTION:
|
| +    case CT_NEW_WORD_SPACE_SUBSTITUTION:
|
| +        return 1;
|
| +    case CT_INSERTION:
|
| +    case CT_TRANSPOSITION:
|
| +        return 2; /* look ahead + skip the current char */
|
| +    case CT_ADDITIONAL_PROXIMITY:
|
| +    case CT_SUBSTITUTION:
|
| +        /* 0 because CT_MATCH will be called */
|
| +    case CT_OMISSION:
|
| +    case CT_NEW_WORD_SPACE_OMISSION:
|
| +    case CT_TERMINAL:
|
| +    default:
|
| +        return 0;
|
| +    }
|
| +}
|
| +} // namespace latinime
|
|
|