Index: third_party/prediction/suggest/core/dicnode/internal/dic_node_state_scoring.h |
diff --git a/third_party/prediction/suggest/core/dicnode/internal/dic_node_state_scoring.h b/third_party/prediction/suggest/core/dicnode/internal/dic_node_state_scoring.h |
new file mode 100644 |
index 0000000000000000000000000000000000000000..756b4fcc6da0f393dccb3f07f8b57822510c2a8a |
--- /dev/null |
+++ b/third_party/prediction/suggest/core/dicnode/internal/dic_node_state_scoring.h |
@@ -0,0 +1,217 @@ |
+/* |
+ * Copyright (C) 2012 The Android Open Source Project |
+ * |
+ * Licensed under the Apache License, Version 2.0 (the "License"); |
+ * you may not use this file except in compliance with the License. |
+ * You may obtain a copy of the License at |
+ * |
+ * http://www.apache.org/licenses/LICENSE-2.0 |
+ * |
+ * Unless required by applicable law or agreed to in writing, software |
+ * distributed under the License is distributed on an "AS IS" BASIS, |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
+ * See the License for the specific language governing permissions and |
+ * limitations under the License. |
+ */ |
+ |
+#ifndef LATINIME_DIC_NODE_STATE_SCORING_H |
+#define LATINIME_DIC_NODE_STATE_SCORING_H |
+ |
+#include <algorithm> |
+#include <cstdint> |
+ |
+#include "third_party/prediction/defines.h" |
+#include "third_party/prediction/suggest/core/dictionary/digraph_utils.h" |
+#include "third_party/prediction/suggest/core/dictionary/error_type_utils.h" |
+ |
+namespace latinime { |
+ |
+class DicNodeStateScoring { |
+ public: |
+ AK_FORCE_INLINE DicNodeStateScoring() |
+ : mDoubleLetterLevel(NOT_A_DOUBLE_LETTER), |
+ mDigraphIndex(DigraphUtils::NOT_A_DIGRAPH_INDEX), |
+ mEditCorrectionCount(0), |
+ mProximityCorrectionCount(0), |
+ mCompletionCount(0), |
+ mNormalizedCompoundDistance(0.0f), |
+ mSpatialDistance(0.0f), |
+ mLanguageDistance(0.0f), |
+ mRawLength(0.0f), |
+ mContainedErrorTypes(ErrorTypeUtils::NOT_AN_ERROR), |
+ mNormalizedCompoundDistanceAfterFirstWord(MAX_VALUE_FOR_WEIGHTING) {} |
+ |
+ ~DicNodeStateScoring() {} |
+ |
+ void init() { |
+ mEditCorrectionCount = 0; |
+ mProximityCorrectionCount = 0; |
+ mCompletionCount = 0; |
+ mNormalizedCompoundDistance = 0.0f; |
+ mSpatialDistance = 0.0f; |
+ mLanguageDistance = 0.0f; |
+ mRawLength = 0.0f; |
+ mDoubleLetterLevel = NOT_A_DOUBLE_LETTER; |
+ mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; |
+ mNormalizedCompoundDistanceAfterFirstWord = MAX_VALUE_FOR_WEIGHTING; |
+ mContainedErrorTypes = ErrorTypeUtils::NOT_AN_ERROR; |
+ } |
+ |
+ AK_FORCE_INLINE void initByCopy(const DicNodeStateScoring* const scoring) { |
+ mEditCorrectionCount = scoring->mEditCorrectionCount; |
+ mProximityCorrectionCount = scoring->mProximityCorrectionCount; |
+ mCompletionCount = scoring->mCompletionCount; |
+ mNormalizedCompoundDistance = scoring->mNormalizedCompoundDistance; |
+ mSpatialDistance = scoring->mSpatialDistance; |
+ mLanguageDistance = scoring->mLanguageDistance; |
+ mRawLength = scoring->mRawLength; |
+ mDoubleLetterLevel = scoring->mDoubleLetterLevel; |
+ mDigraphIndex = scoring->mDigraphIndex; |
+ mContainedErrorTypes = scoring->mContainedErrorTypes; |
+ mNormalizedCompoundDistanceAfterFirstWord = |
+ scoring->mNormalizedCompoundDistanceAfterFirstWord; |
+ } |
+ |
+ void addCost(const float spatialCost, |
+ const float languageCost, |
+ const bool doNormalization, |
+ const int inputSize, |
+ const int totalInputIndex, |
+ const ErrorTypeUtils::ErrorType errorType) { |
+ addDistance(spatialCost, languageCost, doNormalization, inputSize, |
+ totalInputIndex); |
+ mContainedErrorTypes = mContainedErrorTypes | errorType; |
+ if (ErrorTypeUtils::isEditCorrectionError(errorType)) { |
+ ++mEditCorrectionCount; |
+ } |
+ if (ErrorTypeUtils::isProximityCorrectionError(errorType)) { |
+ ++mProximityCorrectionCount; |
+ } |
+ if (ErrorTypeUtils::isCompletion(errorType)) { |
+ ++mCompletionCount; |
+ } |
+ } |
+ |
+ // Saves the current normalized distance for space-aware gestures. |
+ // See getNormalizedCompoundDistanceAfterFirstWord for details. |
+ void saveNormalizedCompoundDistanceAfterFirstWordIfNoneYet() { |
+ // We get called here after each word. We only want to store the distance |
+ // after |
+ // the first word, so if we already have a distance we skip saving -- hence |
+ // "IfNoneYet" |
+ // in the method name. |
+ if (mNormalizedCompoundDistanceAfterFirstWord >= MAX_VALUE_FOR_WEIGHTING) { |
+ mNormalizedCompoundDistanceAfterFirstWord = |
+ getNormalizedCompoundDistance(); |
+ } |
+ } |
+ |
+ void addRawLength(const float rawLength) { mRawLength += rawLength; } |
+ |
+ float getCompoundDistance() const { return getCompoundDistance(1.0f); } |
+ |
+ float getCompoundDistance(const float languageWeight) const { |
+ return mSpatialDistance + mLanguageDistance * languageWeight; |
+ } |
+ |
+ float getNormalizedCompoundDistance() const { |
+ return mNormalizedCompoundDistance; |
+ } |
+ |
+ // For space-aware gestures, we store the normalized distance at the char |
+ // index |
+ // that ends the first word of the suggestion. We call this the distance after |
+ // first word. |
+ float getNormalizedCompoundDistanceAfterFirstWord() const { |
+ return mNormalizedCompoundDistanceAfterFirstWord; |
+ } |
+ |
+ float getSpatialDistance() const { return mSpatialDistance; } |
+ |
+ float getLanguageDistance() const { return mLanguageDistance; } |
+ |
+ int16_t getEditCorrectionCount() const { return mEditCorrectionCount; } |
+ |
+ int16_t getProximityCorrectionCount() const { |
+ return mProximityCorrectionCount; |
+ } |
+ |
+ int16_t getCompletionCount() const { return mCompletionCount; } |
+ |
+ float getRawLength() const { return mRawLength; } |
+ |
+ DoubleLetterLevel getDoubleLetterLevel() const { return mDoubleLetterLevel; } |
+ |
+ void setDoubleLetterLevel(DoubleLetterLevel doubleLetterLevel) { |
+ switch (doubleLetterLevel) { |
+ case NOT_A_DOUBLE_LETTER: |
+ break; |
+ case A_DOUBLE_LETTER: |
+ if (mDoubleLetterLevel != A_STRONG_DOUBLE_LETTER) { |
+ mDoubleLetterLevel = doubleLetterLevel; |
+ } |
+ break; |
+ case A_STRONG_DOUBLE_LETTER: |
+ mDoubleLetterLevel = doubleLetterLevel; |
+ break; |
+ } |
+ } |
+ |
+ DigraphUtils::DigraphCodePointIndex getDigraphIndex() const { |
+ return mDigraphIndex; |
+ } |
+ |
+ void advanceDigraphIndex() { |
+ switch (mDigraphIndex) { |
+ case DigraphUtils::NOT_A_DIGRAPH_INDEX: |
+ mDigraphIndex = DigraphUtils::FIRST_DIGRAPH_CODEPOINT; |
+ break; |
+ case DigraphUtils::FIRST_DIGRAPH_CODEPOINT: |
+ mDigraphIndex = DigraphUtils::SECOND_DIGRAPH_CODEPOINT; |
+ break; |
+ case DigraphUtils::SECOND_DIGRAPH_CODEPOINT: |
+ mDigraphIndex = DigraphUtils::NOT_A_DIGRAPH_INDEX; |
+ break; |
+ } |
+ } |
+ |
+ ErrorTypeUtils::ErrorType getContainedErrorTypes() const { |
+ return mContainedErrorTypes; |
+ } |
+ |
+ private: |
+ DISALLOW_COPY_AND_ASSIGN(DicNodeStateScoring); |
+ |
+ DoubleLetterLevel mDoubleLetterLevel; |
+ DigraphUtils::DigraphCodePointIndex mDigraphIndex; |
+ |
+ int16_t mEditCorrectionCount; |
+ int16_t mProximityCorrectionCount; |
+ int16_t mCompletionCount; |
+ |
+ float mNormalizedCompoundDistance; |
+ float mSpatialDistance; |
+ float mLanguageDistance; |
+ float mRawLength; |
+ // All accumulated error types so far |
+ ErrorTypeUtils::ErrorType mContainedErrorTypes; |
+ float mNormalizedCompoundDistanceAfterFirstWord; |
+ |
+ AK_FORCE_INLINE void addDistance(float spatialDistance, |
+ float languageDistance, |
+ bool doNormalization, |
+ int inputSize, |
+ int totalInputIndex) { |
+ mSpatialDistance += spatialDistance; |
+ mLanguageDistance += languageDistance; |
+ if (!doNormalization) { |
+ mNormalizedCompoundDistance = mSpatialDistance + mLanguageDistance; |
+ } else { |
+ mNormalizedCompoundDistance = |
+ (mSpatialDistance + mLanguageDistance) / |
+ static_cast<float>(std::max(1, totalInputIndex)); |
+ } |
+ } |
+}; |
+} // namespace latinime |
+#endif // LATINIME_DIC_NODE_STATE_SCORING_H |