| Index: third_party/prediction/suggest/policyimpl/typing/typing_weighting.h
|
| diff --git a/third_party/prediction/suggest/policyimpl/typing/typing_weighting.h b/third_party/prediction/suggest/policyimpl/typing/typing_weighting.h
|
| new file mode 100644
|
| index 0000000000000000000000000000000000000000..05ebc8db12a7225dd160a74545699b2fbed15655
|
| --- /dev/null
|
| +++ b/third_party/prediction/suggest/policyimpl/typing/typing_weighting.h
|
| @@ -0,0 +1,254 @@
|
| +/*
|
| + * Copyright (C) 2012 The Android Open Source Project
|
| + *
|
| + * Licensed under the Apache License, Version 2.0 (the "License");
|
| + * you may not use this file except in compliance with the License.
|
| + * You may obtain a copy of the License at
|
| + *
|
| + * http://www.apache.org/licenses/LICENSE-2.0
|
| + *
|
| + * Unless required by applicable law or agreed to in writing, software
|
| + * distributed under the License is distributed on an "AS IS" BASIS,
|
| + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| + * See the License for the specific language governing permissions and
|
| + * limitations under the License.
|
| + */
|
| +
|
| +#ifndef LATINIME_TYPING_WEIGHTING_H
|
| +#define LATINIME_TYPING_WEIGHTING_H
|
| +
|
| +#include "third_party/prediction/defines.h"
|
| +#include "third_party/prediction/suggest/core/dicnode/dic_node_utils.h"
|
| +#include "third_party/prediction/suggest/core/dictionary/error_type_utils.h"
|
| +#include "third_party/prediction/suggest/core/layout/touch_position_correction_utils.h"
|
| +#include "third_party/prediction/suggest/core/policy/weighting.h"
|
| +#include "third_party/prediction/suggest/core/session/dic_traverse_session.h"
|
| +#include "third_party/prediction/suggest/policyimpl/typing/scoring_params.h"
|
| +#include "third_party/prediction/utils/char_utils.h"
|
| +
|
| +namespace latinime {
|
| +
|
| +class DicNode;
|
| +struct DicNode_InputStateG;
|
| +class MultiBigramMap;
|
| +
|
| +class TypingWeighting : public Weighting {
|
| + public:
|
| + static const TypingWeighting* getInstance() { return &sInstance; }
|
| +
|
| + protected:
|
| + float getTerminalSpatialCost(const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode) const {
|
| + float cost = 0.0f;
|
| + if (dicNode->hasMultipleWords()) {
|
| + cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
|
| + }
|
| + if (dicNode->getProximityCorrectionCount() > 0) {
|
| + cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST;
|
| + }
|
| + if (dicNode->getEditCorrectionCount() > 0) {
|
| + cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST;
|
| + }
|
| + return cost;
|
| + }
|
| +
|
| + float getOmissionCost(const DicNode* const parentDicNode,
|
| + const DicNode* const dicNode) const {
|
| + const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
|
| + const bool isIntentionalOmission =
|
| + parentDicNode->canBeIntentionalOmission();
|
| + const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
|
| + // If the traversal omitted the first letter then the dicNode should now be
|
| + // on the second.
|
| + const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
|
| + float cost = 0.0f;
|
| + if (isZeroCostOmission) {
|
| + cost = 0.0f;
|
| + } else if (isIntentionalOmission) {
|
| + cost = ScoringParams::INTENTIONAL_OMISSION_COST;
|
| + } else if (isFirstLetterOmission) {
|
| + cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
|
| + } else {
|
| + cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
|
| + : ScoringParams::OMISSION_COST;
|
| + }
|
| + return cost;
|
| + }
|
| +
|
| + float getMatchedCost(const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode,
|
| + DicNode_InputStateG* inputStateG) const {
|
| + const int pointIndex = dicNode->getInputIndex(0);
|
| + const float normalizedSquaredLength =
|
| + traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
| + pointIndex,
|
| + CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
|
| + const float normalizedDistance =
|
| + TouchPositionCorrectionUtils::getSweetSpotFactor(
|
| + traverseSession->isTouchPositionCorrectionEnabled(),
|
| + normalizedSquaredLength);
|
| + const float weightedDistance =
|
| + ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance;
|
| +
|
| + const bool isFirstChar = pointIndex == 0;
|
| + const bool isProximity = isProximityDicNode(traverseSession, dicNode);
|
| + float cost = isProximity
|
| + ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST
|
| + : ScoringParams::PROXIMITY_COST)
|
| + : 0.0f;
|
| + if (isProximity && dicNode->getProximityCorrectionCount() == 0) {
|
| + cost += ScoringParams::FIRST_PROXIMITY_COST;
|
| + }
|
| + if (dicNode->getNodeCodePointCount() == 2) {
|
| + // At the second character of the current word, we check if the first char
|
| + // is uppercase
|
| + // and the word is a second or later word of a multiple word suggestion.
|
| + // We demote it
|
| + // if so.
|
| + const bool isSecondOrLaterWordFirstCharUppercase =
|
| + dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase();
|
| + if (isSecondOrLaterWordFirstCharUppercase) {
|
| + cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
|
| + }
|
| + }
|
| + return weightedDistance + cost;
|
| + }
|
| +
|
| + bool isProximityDicNode(const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode) const {
|
| + const int pointIndex = dicNode->getInputIndex(0);
|
| + const int primaryCodePoint =
|
| + CharUtils::toBaseLowerCase(traverseSession->getProximityInfoState(0)
|
| + ->getPrimaryCodePointAt(pointIndex));
|
| + const int dicNodeChar =
|
| + CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint());
|
| + return primaryCodePoint != dicNodeChar;
|
| + }
|
| +
|
| + float getTranspositionCost(const DicTraverseSession* const traverseSession,
|
| + const DicNode* const parentDicNode,
|
| + const DicNode* const dicNode) const {
|
| + const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
|
| + const int prevCodePoint = parentDicNode->getNodeCodePoint();
|
| + const float distance1 =
|
| + traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
| + parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint));
|
| + const int codePoint = dicNode->getNodeCodePoint();
|
| + const float distance2 =
|
| + traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
| + parentPointIndex, CharUtils::toBaseLowerCase(codePoint));
|
| + const float distance = distance1 + distance2;
|
| + const float weightedLengthDistance =
|
| + distance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
|
| + return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
|
| + }
|
| +
|
| + float getInsertionCost(const DicTraverseSession* const traverseSession,
|
| + const DicNode* const parentDicNode,
|
| + const DicNode* const dicNode) const {
|
| + const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);
|
| + const int prevCodePoint = traverseSession->getProximityInfoState(0)
|
| + ->getPrimaryCodePointAt(insertedPointIndex);
|
| + const int currentCodePoint = dicNode->getNodeCodePoint();
|
| + const bool sameCodePoint = prevCodePoint == currentCodePoint;
|
| + const bool existsAdjacentProximityChars =
|
| + traverseSession->getProximityInfoState(0)
|
| + ->existsAdjacentProximityChars(insertedPointIndex);
|
| + const float dist =
|
| + traverseSession->getProximityInfoState(0)->getPointToKeyLength(
|
| + insertedPointIndex + 1,
|
| + CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
|
| + const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH;
|
| + const bool singleChar = dicNode->getNodeCodePointCount() == 1;
|
| + float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f);
|
| + if (sameCodePoint) {
|
| + cost += ScoringParams::INSERTION_COST_SAME_CHAR;
|
| + } else if (existsAdjacentProximityChars) {
|
| + cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR;
|
| + } else {
|
| + cost += ScoringParams::INSERTION_COST;
|
| + }
|
| + return cost + weightedDistance;
|
| + }
|
| +
|
| + float getNewWordSpatialCost(const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode,
|
| + DicNode_InputStateG* inputStateG) const {
|
| + return ScoringParams::COST_NEW_WORD *
|
| + traverseSession->getMultiWordCostMultiplier();
|
| + }
|
| +
|
| + float getNewWordBigramLanguageCost(
|
| + const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode,
|
| + MultiBigramMap* const multiBigramMap) const {
|
| + return DicNodeUtils::getBigramNodeImprobability(
|
| + traverseSession->getDictionaryStructurePolicy(), dicNode,
|
| + multiBigramMap) *
|
| + ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
| + }
|
| +
|
| + float getCompletionCost(const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode) const {
|
| + // The auto completion starts when the input index is same as the input size
|
| + const bool firstCompletion =
|
| + dicNode->getInputIndex(0) == traverseSession->getInputSize();
|
| + // TODO: Change the cost for the first completion for the gesture?
|
| + const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETION
|
| + : ScoringParams::COST_COMPLETION;
|
| + return cost;
|
| + }
|
| +
|
| + float getTerminalLanguageCost(
|
| + const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode,
|
| + const float dicNodeLanguageImprobability) const {
|
| + return dicNodeLanguageImprobability *
|
| + ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
|
| + }
|
| +
|
| + float getTerminalInsertionCost(
|
| + const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode) const {
|
| + const int inputIndex = dicNode->getInputIndex(0);
|
| + const int inputSize = traverseSession->getInputSize();
|
| + ASSERT(inputIndex < inputSize);
|
| + // TODO: Implement more efficient logic
|
| + return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
|
| + }
|
| +
|
| + AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
|
| + return false;
|
| + }
|
| +
|
| + AK_FORCE_INLINE float getAdditionalProximityCost() const {
|
| + return ScoringParams::ADDITIONAL_PROXIMITY_COST;
|
| + }
|
| +
|
| + AK_FORCE_INLINE float getSubstitutionCost() const {
|
| + return ScoringParams::SUBSTITUTION_COST;
|
| + }
|
| +
|
| + AK_FORCE_INLINE float getSpaceSubstitutionCost(
|
| + const DicTraverseSession* const traverseSession,
|
| + const DicNode* const dicNode) const {
|
| + const float cost =
|
| + ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
|
| + return cost * traverseSession->getMultiWordCostMultiplier();
|
| + }
|
| +
|
| + ErrorTypeUtils::ErrorType getErrorType(
|
| + const CorrectionType correctionType,
|
| + const DicTraverseSession* const traverseSession,
|
| + const DicNode* const parentDicNode,
|
| + const DicNode* const dicNode) const;
|
| +
|
| + private:
|
| + DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
|
| + static const TypingWeighting sInstance;
|
| +
|
| + TypingWeighting() {}
|
| + ~TypingWeighting() {}
|
| +};
|
| +} // namespace latinime
|
| +#endif // LATINIME_TYPING_WEIGHTING_H
|
|
|