Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(78)

Side by Side Diff: third_party/prediction/suggest/policyimpl/typing/typing_weighting.h

Issue 1247903003: Add spellcheck and word suggestion to the prediction service (Closed) Base URL: https://github.com/domokit/mojo.git@master
Patch Set: Created 5 years, 4 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
(Empty)
1 /*
2 * Copyright (C) 2012 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef LATINIME_TYPING_WEIGHTING_H
18 #define LATINIME_TYPING_WEIGHTING_H
19
20 #include "third_party/prediction/defines.h"
21 #include "third_party/prediction/suggest/core/dicnode/dic_node_utils.h"
22 #include "third_party/prediction/suggest/core/dictionary/error_type_utils.h"
23 #include "third_party/prediction/suggest/core/layout/touch_position_correction_utils.h"
24 #include "third_party/prediction/suggest/core/policy/weighting.h"
25 #include "third_party/prediction/suggest/core/session/dic_traverse_session.h"
26 #include "third_party/prediction/suggest/policyimpl/typing/scoring_params.h"
27 #include "third_party/prediction/utils/char_utils.h"
28
29 namespace latinime {
30
// Forward declarations of types named in the TypingWeighting method
// signatures in this header.
31 class DicNode;
32 struct DicNode_InputStateG;
33 class MultiBigramMap;
34
35 class TypingWeighting : public Weighting {
36 public:
37 static const TypingWeighting* getInstance() { return &sInstance; }
38
39 protected:
40 float getTerminalSpatialCost(const DicTraverseSession* const traverseSession,
41 const DicNode* const dicNode) const {
42 float cost = 0.0f;
43 if (dicNode->hasMultipleWords()) {
44 cost += ScoringParams::HAS_MULTI_WORD_TERMINAL_COST;
45 }
46 if (dicNode->getProximityCorrectionCount() > 0) {
47 cost += ScoringParams::HAS_PROXIMITY_TERMINAL_COST;
48 }
49 if (dicNode->getEditCorrectionCount() > 0) {
50 cost += ScoringParams::HAS_EDIT_CORRECTION_TERMINAL_COST;
51 }
52 return cost;
53 }
54
55 float getOmissionCost(const DicNode* const parentDicNode,
56 const DicNode* const dicNode) const {
57 const bool isZeroCostOmission = parentDicNode->isZeroCostOmission();
58 const bool isIntentionalOmission =
59 parentDicNode->canBeIntentionalOmission();
60 const bool sameCodePoint = dicNode->isSameNodeCodePoint(parentDicNode);
61 // If the traversal omitted the first letter then the dicNode should now be
62 // on the second.
63 const bool isFirstLetterOmission = dicNode->getNodeCodePointCount() == 2;
64 float cost = 0.0f;
65 if (isZeroCostOmission) {
66 cost = 0.0f;
67 } else if (isIntentionalOmission) {
68 cost = ScoringParams::INTENTIONAL_OMISSION_COST;
69 } else if (isFirstLetterOmission) {
70 cost = ScoringParams::OMISSION_COST_FIRST_CHAR;
71 } else {
72 cost = sameCodePoint ? ScoringParams::OMISSION_COST_SAME_CHAR
73 : ScoringParams::OMISSION_COST;
74 }
75 return cost;
76 }
77
78 float getMatchedCost(const DicTraverseSession* const traverseSession,
79 const DicNode* const dicNode,
80 DicNode_InputStateG* inputStateG) const {
81 const int pointIndex = dicNode->getInputIndex(0);
82 const float normalizedSquaredLength =
83 traverseSession->getProximityInfoState(0)->getPointToKeyLength(
84 pointIndex,
85 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
86 const float normalizedDistance =
87 TouchPositionCorrectionUtils::getSweetSpotFactor(
88 traverseSession->isTouchPositionCorrectionEnabled(),
89 normalizedSquaredLength);
90 const float weightedDistance =
91 ScoringParams::DISTANCE_WEIGHT_LENGTH * normalizedDistance;
92
93 const bool isFirstChar = pointIndex == 0;
94 const bool isProximity = isProximityDicNode(traverseSession, dicNode);
95 float cost = isProximity
96 ? (isFirstChar ? ScoringParams::FIRST_CHAR_PROXIMITY_COST
97 : ScoringParams::PROXIMITY_COST)
98 : 0.0f;
99 if (isProximity && dicNode->getProximityCorrectionCount() == 0) {
100 cost += ScoringParams::FIRST_PROXIMITY_COST;
101 }
102 if (dicNode->getNodeCodePointCount() == 2) {
103 // At the second character of the current word, we check if the first char
104 // is uppercase
105 // and the word is a second or later word of a multiple word suggestion.
106 // We demote it
107 // if so.
108 const bool isSecondOrLaterWordFirstCharUppercase =
109 dicNode->hasMultipleWords() && dicNode->isFirstCharUppercase();
110 if (isSecondOrLaterWordFirstCharUppercase) {
111 cost += ScoringParams::COST_SECOND_OR_LATER_WORD_FIRST_CHAR_UPPERCASE;
112 }
113 }
114 return weightedDistance + cost;
115 }
116
117 bool isProximityDicNode(const DicTraverseSession* const traverseSession,
118 const DicNode* const dicNode) const {
119 const int pointIndex = dicNode->getInputIndex(0);
120 const int primaryCodePoint =
121 CharUtils::toBaseLowerCase(traverseSession->getProximityInfoState(0)
122 ->getPrimaryCodePointAt(pointIndex));
123 const int dicNodeChar =
124 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint());
125 return primaryCodePoint != dicNodeChar;
126 }
127
128 float getTranspositionCost(const DicTraverseSession* const traverseSession,
129 const DicNode* const parentDicNode,
130 const DicNode* const dicNode) const {
131 const int16_t parentPointIndex = parentDicNode->getInputIndex(0);
132 const int prevCodePoint = parentDicNode->getNodeCodePoint();
133 const float distance1 =
134 traverseSession->getProximityInfoState(0)->getPointToKeyLength(
135 parentPointIndex + 1, CharUtils::toBaseLowerCase(prevCodePoint));
136 const int codePoint = dicNode->getNodeCodePoint();
137 const float distance2 =
138 traverseSession->getProximityInfoState(0)->getPointToKeyLength(
139 parentPointIndex, CharUtils::toBaseLowerCase(codePoint));
140 const float distance = distance1 + distance2;
141 const float weightedLengthDistance =
142 distance * ScoringParams::DISTANCE_WEIGHT_LENGTH;
143 return ScoringParams::TRANSPOSITION_COST + weightedLengthDistance;
144 }
145
146 float getInsertionCost(const DicTraverseSession* const traverseSession,
147 const DicNode* const parentDicNode,
148 const DicNode* const dicNode) const {
149 const int16_t insertedPointIndex = parentDicNode->getInputIndex(0);
150 const int prevCodePoint = traverseSession->getProximityInfoState(0)
151 ->getPrimaryCodePointAt(insertedPointIndex);
152 const int currentCodePoint = dicNode->getNodeCodePoint();
153 const bool sameCodePoint = prevCodePoint == currentCodePoint;
154 const bool existsAdjacentProximityChars =
155 traverseSession->getProximityInfoState(0)
156 ->existsAdjacentProximityChars(insertedPointIndex);
157 const float dist =
158 traverseSession->getProximityInfoState(0)->getPointToKeyLength(
159 insertedPointIndex + 1,
160 CharUtils::toBaseLowerCase(dicNode->getNodeCodePoint()));
161 const float weightedDistance = dist * ScoringParams::DISTANCE_WEIGHT_LENGTH;
162 const bool singleChar = dicNode->getNodeCodePointCount() == 1;
163 float cost = (singleChar ? ScoringParams::INSERTION_COST_FIRST_CHAR : 0.0f);
164 if (sameCodePoint) {
165 cost += ScoringParams::INSERTION_COST_SAME_CHAR;
166 } else if (existsAdjacentProximityChars) {
167 cost += ScoringParams::INSERTION_COST_PROXIMITY_CHAR;
168 } else {
169 cost += ScoringParams::INSERTION_COST;
170 }
171 return cost + weightedDistance;
172 }
173
174 float getNewWordSpatialCost(const DicTraverseSession* const traverseSession,
175 const DicNode* const dicNode,
176 DicNode_InputStateG* inputStateG) const {
177 return ScoringParams::COST_NEW_WORD *
178 traverseSession->getMultiWordCostMultiplier();
179 }
180
181 float getNewWordBigramLanguageCost(
182 const DicTraverseSession* const traverseSession,
183 const DicNode* const dicNode,
184 MultiBigramMap* const multiBigramMap) const {
185 return DicNodeUtils::getBigramNodeImprobability(
186 traverseSession->getDictionaryStructurePolicy(), dicNode,
187 multiBigramMap) *
188 ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
189 }
190
191 float getCompletionCost(const DicTraverseSession* const traverseSession,
192 const DicNode* const dicNode) const {
193 // The auto completion starts when the input index is same as the input size
194 const bool firstCompletion =
195 dicNode->getInputIndex(0) == traverseSession->getInputSize();
196 // TODO: Change the cost for the first completion for the gesture?
197 const float cost = firstCompletion ? ScoringParams::COST_FIRST_COMPLETION
198 : ScoringParams::COST_COMPLETION;
199 return cost;
200 }
201
202 float getTerminalLanguageCost(
203 const DicTraverseSession* const traverseSession,
204 const DicNode* const dicNode,
205 const float dicNodeLanguageImprobability) const {
206 return dicNodeLanguageImprobability *
207 ScoringParams::DISTANCE_WEIGHT_LANGUAGE;
208 }
209
210 float getTerminalInsertionCost(
211 const DicTraverseSession* const traverseSession,
212 const DicNode* const dicNode) const {
213 const int inputIndex = dicNode->getInputIndex(0);
214 const int inputSize = traverseSession->getInputSize();
215 ASSERT(inputIndex < inputSize);
216 // TODO: Implement more efficient logic
217 return ScoringParams::TERMINAL_INSERTION_COST * (inputSize - inputIndex);
218 }
219
220 AK_FORCE_INLINE bool needsToNormalizeCompoundDistance() const {
221 return false;
222 }
223
224 AK_FORCE_INLINE float getAdditionalProximityCost() const {
225 return ScoringParams::ADDITIONAL_PROXIMITY_COST;
226 }
227
228 AK_FORCE_INLINE float getSubstitutionCost() const {
229 return ScoringParams::SUBSTITUTION_COST;
230 }
231
232 AK_FORCE_INLINE float getSpaceSubstitutionCost(
233 const DicTraverseSession* const traverseSession,
234 const DicNode* const dicNode) const {
235 const float cost =
236 ScoringParams::SPACE_SUBSTITUTION_COST + ScoringParams::COST_NEW_WORD;
237 return cost * traverseSession->getMultiWordCostMultiplier();
238 }
239
240 ErrorTypeUtils::ErrorType getErrorType(
241 const CorrectionType correctionType,
242 const DicTraverseSession* const traverseSession,
243 const DicNode* const parentDicNode,
244 const DicNode* const dicNode) const;
245
246 private:
247 DISALLOW_COPY_AND_ASSIGN(TypingWeighting);
248 static const TypingWeighting sInstance;
249
250 TypingWeighting() {}
251 ~TypingWeighting() {}
252 };
253 } // namespace latinime
254 #endif // LATINIME_TYPING_WEIGHTING_H
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698