| Index: third_party/android_prediction/suggest/policyimpl/typing/typing_weighting.h
|
| diff --git a/third_party/android_prediction/suggest/policyimpl/typing/typing_weighting.h b/third_party/android_prediction/suggest/policyimpl/typing/typing_weighting.h
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..f432444f55c2522128f5852d589ee7fce8a93d1c
|
| --- /dev/null
|
| +++ b/third_party/android_prediction/suggest/policyimpl/typing/typing_weighting.h
|
| @@ -0,0 +1,221 @@
|
| +/*
|
| + * Copyright (C) 2012 The Android Open Source Project
|
| + *
|
| + * Licensed under the Apache License, Version 2.0 (the "License");
|
| + * you may not use this file except in compliance with the License.
|
| + * You may obtain a copy of the License at
|
| + *
|
| + * http://www.apache.org/licenses/LICENSE-2.0
|
| + *
|
| + * Unless required by applicable law or agreed to in writing, software
|
| + * distributed under the License is distributed on an "AS IS" BASIS,
|
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| + * See the License for the specific language governing permissions and
|
| + * limitations under the License.
|
| + */
|
| +
|
| +#ifndef LATINIME_TYPING_WEIGHTING_H
|
| +#define LATINIME_TYPING_WEIGHTING_H
|
| +
|
| +#include "third_party/android_prediction/defines.h"
|
| +#include "third_party/android_prediction/suggest/core/dicnode/dic_node_utils.h"
|
| +#include "third_party/android_prediction/suggest/core/dictionary/error_type_utils.h"
|
| +#include "third_party/android_prediction/suggest/core/layout/touch_position_correction_utils.h"
|
| +#include "third_party/android_prediction/suggest/core/policy/weighting.h"
|
| +#include "third_party/android_prediction/suggest/core/session/dic_traverse_session.h"
|
| +#include "third_party/android_prediction/suggest/policyimpl/typing/scoring_params.h"
|
| +#include "third_party/android_prediction/utils/char_utils.h"
|
| +
|
| +namespace latinime {
|
| +
|
| +class DicNode;  // Node type the TypingWeighting cost methods receive as const pointers.
|
| +struct DicNode_InputStateG;  // Only accepted as a pointer parameter; never dereferenced in this header.
|
| +class MultiBigramMap;  // Only forwarded, by pointer, to DicNodeUtils::getBigramNodeImprobability().
|
| +
|
| +// Weighting subclass whose cost terms are assembled from ScoringParams constants and, where
|
| +// relevant, key distances read from the traverse session's proximity info state.
|
| +// NOTE(review): the protected methods presumably implement hooks declared by Weighting --
|
| +// confirm against weighting.h.
|
| +class TypingWeighting : public Weighting {
|
| + public:
|
| +    // Hands out the address of the single static instance declared in the private section.
|
| +    static const TypingWeighting *getInstance() { return &sInstance; }
|
| +
|
| + protected:
|
| +    // Adds up the terminal penalties that apply to dicNode: one when it has multiple words, one
|
| +    // when its proximity correction count is positive and one when its edit correction count is
|
| +    // positive. traverseSession is not consulted.
|
| +    float getTerminalSpatialCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode) const {
|
| +        float terminalCost = 0.0f;
|
| +        terminalCost += dicNode->hasMultipleWords()
|
| +                ? ScoringParams::HAS_MULTI_WORD_TERMINAL_COST : 0.0f;
|
| +        terminalCost += (dicNode->getProximityCorrectionCount() > 0)
|
| +                ? ScoringParams::HAS_PROXIMITY_TERMINAL_COST : 0.0f;
|
| +        terminalCost += (dicNode->getEditCorrectionCount() > 0)
|
| +                ? ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST : 0.0f;
|
| +        return terminalCost;
|
| +    }
|
| +
|
| +    // Picks the omission penalty. Precedence: a zero-cost omission yields 0.0f, then an
|
| +    // intentional omission, then a first-letter omission; anything else is charged the
|
| +    // same-character cost or the generic omission cost.
|
| +    float getOmissionCost(const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
| +        const bool zeroCost = parentDicNode->isZeroCostOmission();
|
| +        const bool intentional = parentDicNode->canBeIntentionalOmission();
|
| +        const bool repeatsParentCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
|
| +        // If the traversal omitted the first letter then the dicNode should now be on the second.
|
| +        const bool firstLetterOmitted = dicNode->getNodeCodePointCount() == 2;
|
| +        if (zeroCost) {
|
| +            return 0.0f;
|
| +        }
|
| +        if (intentional) {
|
| +            return ScoringParams::INTENTIONAL_OMISSION_COST;
|
| +        }
|
| +        if (firstLetterOmitted) {
|
| +            return ScoringParams::OMISSION_COST_FIRST_CHAR;
|
| +        }
|
| +        return repeatsParentCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
|
| +                : ScoringParams::OMISSION_COST;
|
| +    }
|
| +
|
| +    // Cost of matching dicNode's code point against the touch point at its input index: the
|
| +    // value returned by TouchPositionCorrectionUtils::getSweetSpotFactor() scaled by
|
| +    // DISTANCE_WEIGHT_LENGTH, plus proximity penalties and a penalty for an uppercase first
|
| +    // character in a second-or-later word. inputStateG is not used.
|
| +    float getMatchedCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
|
| +        const int inputIndex = dicNode->getInputIndex(0);
|
| +        const float squaredLength = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
| +                inputIndex, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
|
| +        const float sweetSpotDistance = TouchPositionCorrectionUtils::getSweetSpotFactor(
|
| +                traverseSession->isTouchPositionCorrectionEnabled(), squaredLength);
|
| +        const float spatialCost = ScoringParams::DISTANCE_WEIGHT_LENGTH * sweetSpotDistance;
|
| +
|
| +        float penalty = 0.0f;
|
| +        if (isProximityDicNode(traverseSession, dicNode)) {
|
| +            // A proximity match on the very first input point has its own constant.
|
| +            penalty = (inputIndex == 0) ? ScoringParams::FIRST_CHAR_PROXIMITY_COST
|
| +                    : ScoringParams::PROXIMITY_COST;
|
| +            // Extra charge while the node has no proximity correction counted yet.
|
| +            if (dicNode->getProximityCorrectionCount() == 0) {
|
| +                penalty += ScoringParams::FIRST_PROXIMITY_COST;
|
| +            }
|
| +        }
|
| +        // At the second character of the current word, we check if the first char is uppercase
|
| +        // and the word is a second or later word of a multiple word suggestion. We demote it
|
| +        // if so.
|
| +        if (dicNode->getNodeCodePointCount() == 2 && dicNode->hasMultipleWords()
|
| +                && dicNode->isFirstCharUppercase()) {
|
| +            penalty += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
|
| +        }
|
| +        return spatialCost + penalty;
|
| +    }
|
| +
|
| +    // True when the base-lowercase form of the primary code point at dicNode's input index
|
| +    // differs from the base-lowercase form of dicNode's own code point.
|
| +    bool isProximityDicNode(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode) const {
|
| +        const int inputIndex = dicNode->getInputIndex(0);
|
| +        const int typedCodePoint = CharUtils::toBaseLowerCase(
|
| +                traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(inputIndex));
|
| +        const int nodeCodePoint = CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint());
|
| +        return !(typedCodePoint == nodeCodePoint);
|
| +    }
|
| +
|
| +    // TRANSPOSITION_COST plus the DISTANCE_WEIGHT_LENGTH-weighted sum of two key distances: the
|
| +    // parent's code point measured at the point after the parent's input index, and dicNode's
|
| +    // code point measured at the parent's input index.
|
| +    float getTranspositionCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
| +        const int16_t parentInputIndex = parentDicNode->getInputIndex(0);
|
| +        const int parentCodePoint = parentDicNode->getNodeCodePoint();
|
| +        const float parentKeyDistance =
|
| +                traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
| +                        parentInputIndex + 1, CharUtils::toBaseLowerCase(parentCodePoint));
|
| +        const int nodeCodePoint = dicNode->getNodeCodePoint();
|
| +        const float nodeKeyDistance =
|
| +                traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
| +                        parentInputIndex, CharUtils::toBaseLowerCase(nodeCodePoint));
|
| +        const float summedDistance = parentKeyDistance + nodeKeyDistance;
|
| +        const float weightedDistance = summedDistance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
|
| +        return ScoringParams::TRANSPOSITION_COST + weightedDistance;
|
| +    }
|
| +
|
| +    // Insertion penalty plus a weighted key distance. The penalty depends on whether the primary
|
| +    // code point at the parent's input index equals dicNode's code point, or else whether
|
| +    // adjacent proximity chars exist at that index; INSERTION_COST_FIRST_CHAR is added on top
|
| +    // when dicNode holds a single code point.
|
| +    float getInsertionCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const parentDicNode, const DicNode *const dicNode) const {
|
| +        const int16_t insertedIndex = parentDicNode->getInputIndex(0);
|
| +        const int typedCodePoint =
|
| +                traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(insertedIndex);
|
| +        const int nodeCodePoint = dicNode->getNodeCodePoint();
|
| +        const bool repeatsTypedCodePoint = typedCodePoint == nodeCodePoint;
|
| +        const bool hasAdjacentProximityChars =
|
| +                traverseSession->getProximityInfoState(0)->existsAdjacentProximityChars(
|
| +                        insertedIndex);
|
| +        // Distance of dicNode's code point from the point that follows the inserted one.
|
| +        const float keyDistance = traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
| +                insertedIndex + 1, CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
|
| +        const float weightedDistance = keyDistance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
|
| +        const bool onlyOneCodePoint = dicNode->getNodeCodePointCount() == 1;
|
| +        float penalty = onlyOneCodePoint ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f;
|
| +        penalty += repeatsTypedCodePoint ? ScoringParams::INSERTION_COST_SAME_CHAR
|
| +                : (hasAdjacentProximityChars ? ScoringParams::INSERTION_COST_PROXIMITY_CHAR
|
| +                        : ScoringParams::INSERTION_COST);
|
| +        return penalty + weightedDistance;
|
| +    }
|
| +
|
| +    // COST_NEW_WORD scaled by the session's multi-word cost multiplier. dicNode and inputStateG
|
| +    // are not used.
|
| +    float getNewWordSpatialCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode, DicNode_InputStateG *inputStateG) const {
|
| +        return ScoringParams::COST_NEW_WORD
|
| +                * traverseSession->getMultiWordCostMultiplier();
|
| +    }
|
| +
|
| +    // Bigram improbability reported by DicNodeUtils for dicNode, scaled by
|
| +    // DISTANCE_WEIGHT_LANGUAGE.
|
| +    float getNewWordBigramLanguageCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode, MultiBigramMap *const multiBigramMap) const {
|
| +        return DicNodeUtils::getBigramNodeImprobability(
|
| +                       traverseSession->getDictionaryStructurePolicy(), dicNode, multiBigramMap)
|
| +                * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
| +    }
|
| +
|
| +    // COST_FIRST_COMPLETION when dicNode's input index equals the session's input size,
|
| +    // COST_COMPLETION otherwise.
|
| +    float getCompletionCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode) const {
|
| +        // The auto completion starts when the input index is same as the input size
|
| +        // TODO: Change the cost for the first completion for the gesture?
|
| +        if (dicNode->getInputIndex(0) == traverseSession->getInputSize()) {
|
| +            return ScoringParams::COST_FIRST_COMPLETION;
|
| +        }
|
| +        return ScoringParams::COST_COMPLETION;
|
| +    }
|
| +
|
| +    // The supplied language improbability scaled by DISTANCE_WEIGHT_LANGUAGE. traverseSession
|
| +    // and dicNode are not used.
|
| +    float getTerminalLanguageCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode, const float dicNodeLanguageImprobability) const {
|
| +        return dicNodeLanguageImprobability
|
| +                * ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
| +    }
|
| +
|
| +    // TERMINAL_INSERTION_COST multiplied by the gap between the session's input size and
|
| +    // dicNode's input index. The ASSERT requires the index to be below the size.
|
| +    float getTerminalInsertionCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode) const {
|
| +        const int inputIndex = dicNode->getInputIndex(0);
|
| +        const int inputSize = traverseSession->getInputSize();
|
| +        ASSERT(inputIndex < inputSize);
|
| +        // TODO: Implement more efficient logic
|
| +        const int remainingInputCount = inputSize - inputIndex;
|
| +        return ScoringParams::TERMINAL_INSERTION_COST * remainingInputCount;
|
| +    }
|
| +
|
| +    // Always false for this weighting.
|
| +    AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
|
| +        return false;
|
| +    }
|
| +
|
| +    // Exposes ScoringParams::ADDITIONAL_PROXIMITY_COST.
|
| +    AK_FORCE_INLINE float getAdditionalProximityCost() const {
|
| +        return ScoringParams::ADDITIONAL_PROXIMITY_COST;
|
| +    }
|
| +
|
| +    // Exposes ScoringParams::SUBSTITUTION_COST.
|
| +    AK_FORCE_INLINE float getSubstitutionCost() const {
|
| +        return ScoringParams::SUBSTITUTION_COST;
|
| +    }
|
| +
|
| +    // (SPACE_SUBSTITUTION_COST + COST_NEW_WORD) scaled by the session's multi-word cost
|
| +    // multiplier. dicNode is not used.
|
| +    AK_FORCE_INLINE float getSpaceSubstitutionCost(const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const dicNode) const {
|
| +        const float unscaledCost =
|
| +                ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
|
| +        return unscaledCost * traverseSession->getMultiWordCostMultiplier();
|
| +    }
|
| +
|
| +    // Declared here only; the definition is not part of this header.
|
| +    ErrorTypeUtils::ErrorType getErrorType(const CorrectionType correctionType,
|
| +            const DicTraverseSession *const traverseSession,
|
| +            const DicNode *const parentDicNode, const DicNode *const dicNode) const;
|
| +
|
| + private:
|
| +    DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
|
| +    // The instance returned by getInstance(); its definition is not in this header.
|
| +    static const TypingWeighting sInstance;
|
| +
|
| +    // Construction and destruction are private, so outside code cannot create instances.
|
| +    TypingWeighting() {}
|
| +    ~TypingWeighting() {}
|
| +};
|
| +} // namespace latinime
|
| +#endif // LATINIME_TYPING_WEIGHTING_H
|
|
|