Chromium Code Reviews
chromiumcodereview-hr@appspot.gserviceaccount.com (chromiumcodereview-hr) | Please choose your nickname with Settings | Help | Chromium Project | Gerrit Changes | Sign out
(92)

Side by Side Diff: third_party/prediction/suggest/policyimpl/dictionary/utils/forgetting_curve_utils.cpp

Issue 1247903003: Add spellcheck and word suggestion to the prediction service (Closed) Base URL: https://github.com/domokit/mojo.git@master
Patch Set: Created 5 years, 4 months ago
Use n/p to move between diff chunks; N/P to move between comments. Draft comments are only viewable by you.
Jump to:
View unified diff | Download patch
OLDNEW
(Empty)
1 /*
2 * Copyright (C) 2013, The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "third_party/prediction/suggest/policyimpl/dictionary/utils/forgetting_ curve_utils.h"
18
19 #include <algorithm>
20 #include <cmath>
21 #include <stdlib.h>
22
23 #include "third_party/prediction/suggest/policyimpl/dictionary/header/header_pol icy.h"
24 #include "third_party/prediction/suggest/policyimpl/dictionary/utils/probability _utils.h"
25 #include "third_party/prediction/utils/time_keeper.h"
26
27 namespace latinime {
28
29 const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8;
30 const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60;
31
32 const int ForgettingCurveUtils::MAX_LEVEL = 3;
33 const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1;
34 const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15;
35 const int
36 ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD =
37 14;
38
39 const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
40 const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2;
41
42 const ForgettingCurveUtils::ProbabilityTable
43 ForgettingCurveUtils::sProbabilityTable;
44
45 // TODO: Revise the logic to decide the initial probability depending on the
46 // given probability.
47 /* static */ const HistoricalInfo
48 ForgettingCurveUtils::createUpdatedHistoricalInfo(
49 const HistoricalInfo* const originalHistoricalInfo,
50 const int newProbability,
51 const HistoricalInfo* const newHistoricalInfo,
52 const HeaderPolicy* const headerPolicy) {
53 const int timestamp = newHistoricalInfo->getTimeStamp();
54 if (newProbability != NOT_A_PROBABILITY &&
55 originalHistoricalInfo->getLevel() == 0) {
56 // Add entry as a valid word.
57 const int level =
58 clampToVisibleEntryLevelRange(newHistoricalInfo->getLevel());
59 const int count =
60 clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
61 return HistoricalInfo(timestamp, level, count);
62 } else if (!originalHistoricalInfo->isValid() ||
63 originalHistoricalInfo->getLevel() <
64 newHistoricalInfo->getLevel() ||
65 (originalHistoricalInfo->getLevel() ==
66 newHistoricalInfo->getLevel() &&
67 originalHistoricalInfo->getCount() <
68 newHistoricalInfo->getCount())) {
69 // Initial information.
70 const int level = clampToValidLevelRange(newHistoricalInfo->getLevel());
71 const int count =
72 clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy);
73 return HistoricalInfo(timestamp, level, count);
74 } else {
75 const int updatedCount = originalHistoricalInfo->getCount() + 1;
76 if (updatedCount >=
77 headerPolicy->getForgettingCurveOccurrencesToLevelUp()) {
78 // The count exceeds the max value the level can be incremented.
79 if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) {
80 // The level is already max.
81 return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(),
82 originalHistoricalInfo->getCount());
83 } else {
84 // Level up.
85 return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel() + 1,
86 0 /* count */);
87 }
88 } else {
89 return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(),
90 updatedCount);
91 }
92 }
93 }
94
95 /* static */ int ForgettingCurveUtils::decodeProbability(
96 const HistoricalInfo* const historicalInfo,
97 const HeaderPolicy* const headerPolicy) {
98 const int elapsedTimeStepCount = getElapsedTimeStepCount(
99 historicalInfo->getTimeStamp(),
100 headerPolicy->getForgettingCurveDurationToLevelDown());
101 return sProbabilityTable.getProbability(
102 headerPolicy->getForgettingCurveProbabilityValuesTableId(),
103 clampToValidLevelRange(historicalInfo->getLevel()),
104 clampToValidTimeStepCountRange(elapsedTimeStepCount));
105 }
106
107 /* static */ int ForgettingCurveUtils::getProbability(
108 const int unigramProbability,
109 const int bigramProbability) {
110 if (unigramProbability == NOT_A_PROBABILITY) {
111 return NOT_A_PROBABILITY;
112 } else if (bigramProbability == NOT_A_PROBABILITY) {
113 return std::min(backoff(unigramProbability), MAX_PROBABILITY);
114 } else {
115 // TODO: Investigate better way to handle bigram probability.
116 return std::min(
117 std::max(unigramProbability,
118 bigramProbability + MULTIPLIER_TWO_IN_PROBABILITY_SCALE),
119 MAX_PROBABILITY);
120 }
121 }
122
123 /* static */ bool ForgettingCurveUtils::needsToKeep(
124 const HistoricalInfo* const historicalInfo,
125 const HeaderPolicy* const headerPolicy) {
126 return historicalInfo->getLevel() > 0 ||
127 getElapsedTimeStepCount(
128 historicalInfo->getTimeStamp(),
129 headerPolicy->getForgettingCurveDurationToLevelDown()) <
130 DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD;
131 }
132
133 /* static */ const HistoricalInfo
134 ForgettingCurveUtils::createHistoricalInfoToSave(
135 const HistoricalInfo* const originalHistoricalInfo,
136 const HeaderPolicy* const headerPolicy) {
137 if (originalHistoricalInfo->getTimeStamp() == NOT_A_TIMESTAMP) {
138 return HistoricalInfo();
139 }
140 const int durationToLevelDownInSeconds =
141 headerPolicy->getForgettingCurveDurationToLevelDown();
142 const int elapsedTimeStep = getElapsedTimeStepCount(
143 originalHistoricalInfo->getTimeStamp(), durationToLevelDownInSeconds);
144 if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) {
145 // No need to update historical info.
146 return *originalHistoricalInfo;
147 }
148 // Level down.
149 const int maxLevelDownAmonut =
150 elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1);
151 const int levelDownAmount =
152 (maxLevelDownAmonut >= originalHistoricalInfo->getLevel())
153 ? originalHistoricalInfo->getLevel()
154 : maxLevelDownAmonut;
155 const int adjustedTimestampInSeconds =
156 originalHistoricalInfo->getTimeStamp() +
157 levelDownAmount * durationToLevelDownInSeconds;
158 return HistoricalInfo(adjustedTimestampInSeconds,
159 originalHistoricalInfo->getLevel() - levelDownAmount,
160 0 /* count */);
161 }
162
163 /* static */ bool ForgettingCurveUtils::needsToDecay(
164 const bool mindsBlockByDecay,
165 const int unigramCount,
166 const int bigramCount,
167 const HeaderPolicy* const headerPolicy) {
168 if (unigramCount >=
169 getUnigramCountHardLimit(headerPolicy->getMaxUnigramCount())) {
170 // Unigram count exceeds the limit.
171 return true;
172 } else if (bigramCount >=
173 getBigramCountHardLimit(headerPolicy->getMaxBigramCount())) {
174 // Bigram count exceeds the limit.
175 return true;
176 }
177 if (mindsBlockByDecay) {
178 return false;
179 }
180 if (headerPolicy->getLastDecayedTime() + DECAY_INTERVAL_SECONDS <
181 TimeKeeper::peekCurrentTime()) {
182 // Time to decay.
183 return true;
184 }
185 return false;
186 }
187
188 // See comments in ProbabilityUtils::backoff().
189 /* static */ int ForgettingCurveUtils::backoff(const int unigramProbability) {
190 // See TODO comments in ForgettingCurveUtils::getProbability().
191 return unigramProbability;
192 }
193
194 /* static */ int ForgettingCurveUtils::getElapsedTimeStepCount(
195 const int timestamp,
196 const int durationToLevelDownInSeconds) {
197 const int elapsedTimeInSeconds = TimeKeeper::peekCurrentTime() - timestamp;
198 const int timeStepDurationInSeconds =
199 durationToLevelDownInSeconds / (MAX_ELAPSED_TIME_STEP_COUNT + 1);
200 return elapsedTimeInSeconds / timeStepDurationInSeconds;
201 }
202
203 /* static */ int ForgettingCurveUtils::clampToVisibleEntryLevelRange(
204 const int level) {
205 return std::min(std::max(level, MIN_VISIBLE_LEVEL), MAX_LEVEL);
206 }
207
208 /* static */ int ForgettingCurveUtils::clampToValidCountRange(
209 const int count,
210 const HeaderPolicy* const headerPolicy) {
211 return std::min(std::max(count, 0),
212 headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1);
213 }
214
215 /* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) {
216 return std::min(std::max(level, 0), MAX_LEVEL);
217 }
218
219 /* static */ int ForgettingCurveUtils::clampToValidTimeStepCountRange(
220 const int timeStepCount) {
221 return std::min(std::max(timeStepCount, 0), MAX_ELAPSED_TIME_STEP_COUNT);
222 }
223
224 const int ForgettingCurveUtils::ProbabilityTable::PROBABILITY_TABLE_COUNT = 4;
225 const int ForgettingCurveUtils::ProbabilityTable::WEAK_PROBABILITY_TABLE_ID = 0;
226 const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID =
227 1;
228 const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID =
229 2;
230 const int
231 ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3;
232 const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127;
233 const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 32;
234 const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 35;
235 const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY =
236 40;
237
238 ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() {
239 mTables.resize(PROBABILITY_TABLE_COUNT);
240 for (int tableId = 0; tableId < PROBABILITY_TABLE_COUNT; ++tableId) {
241 mTables[tableId].resize(MAX_LEVEL + 1);
242 for (int level = 0; level <= MAX_LEVEL; ++level) {
243 mTables[tableId][level].resize(MAX_ELAPSED_TIME_STEP_COUNT + 1);
244 const float initialProbability =
245 getBaseProbabilityForLevel(tableId, level);
246 const float endProbability =
247 getBaseProbabilityForLevel(tableId, level - 1);
248 for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT;
249 ++timeStepCount) {
250 if (level == 0) {
251 mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY;
252 continue;
253 }
254 const float probability =
255 initialProbability *
256 powf(initialProbability / endProbability,
257 -1.0f * static_cast<float>(timeStepCount) /
258 static_cast<float>(MAX_ELAPSED_TIME_STEP_COUNT + 1));
259 mTables[tableId][level][timeStepCount] = std::min(
260 std::max(static_cast<int>(probability), 1), MAX_PROBABILITY);
261 }
262 }
263 }
264 }
265
266 /* static */ int
267 ForgettingCurveUtils::ProbabilityTable::getBaseProbabilityForLevel(
268 const int tableId,
269 const int level) {
270 if (tableId == WEAK_PROBABILITY_TABLE_ID) {
271 // Max probability is 127.
272 return static_cast<float>(WEAK_MAX_PROBABILITY /
273 (1 << (MAX_LEVEL - level)));
274 } else if (tableId == MODEST_PROBABILITY_TABLE_ID) {
275 // Max probability is 128.
276 return static_cast<float>(MODEST_BASE_PROBABILITY * (level + 1));
277 } else if (tableId == STRONG_PROBABILITY_TABLE_ID) {
278 // Max probability is 140.
279 return static_cast<float>(STRONG_BASE_PROBABILITY * (level + 1));
280 } else if (tableId == AGGRESSIVE_PROBABILITY_TABLE_ID) {
281 // Max probability is 160.
282 return static_cast<float>(AGGRESSIVE_BASE_PROBABILITY * (level + 1));
283 } else {
284 return NOT_A_PROBABILITY;
285 }
286 }
287
288 } // namespace latinime
OLDNEW

Powered by Google App Engine
This is Rietveld 408576698