| OLD | NEW |
| (Empty) | |
| 1 /* |
| 2 * Copyright (C) 2013, The Android Open Source Project |
| 3 * |
| 4 * Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 * you may not use this file except in compliance with the License. |
| 6 * You may obtain a copy of the License at |
| 7 * |
| 8 * http://www.apache.org/licenses/LICENSE-2.0 |
| 9 * |
| 10 * Unless required by applicable law or agreed to in writing, software |
| 11 * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 * See the License for the specific language governing permissions and |
| 14 * limitations under the License. |
| 15 */ |
| 16 |
| 17 #include "third_party/prediction/suggest/policyimpl/dictionary/utils/forgetting_
curve_utils.h" |
| 18 |
| 19 #include <algorithm> |
| 20 #include <cmath> |
| 21 #include <stdlib.h> |
| 22 |
| 23 #include "third_party/prediction/suggest/policyimpl/dictionary/header/header_pol
icy.h" |
| 24 #include "third_party/prediction/suggest/policyimpl/dictionary/utils/probability
_utils.h" |
| 25 #include "third_party/prediction/utils/time_keeper.h" |
| 26 |
| 27 namespace latinime { |
| 28 |
| 29 const int ForgettingCurveUtils::MULTIPLIER_TWO_IN_PROBABILITY_SCALE = 8; |
| 30 const int ForgettingCurveUtils::DECAY_INTERVAL_SECONDS = 2 * 60 * 60; |
| 31 |
| 32 const int ForgettingCurveUtils::MAX_LEVEL = 3; |
| 33 const int ForgettingCurveUtils::MIN_VISIBLE_LEVEL = 1; |
| 34 const int ForgettingCurveUtils::MAX_ELAPSED_TIME_STEP_COUNT = 15; |
| 35 const int |
| 36 ForgettingCurveUtils::DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD = |
| 37 14; |
| 38 |
| 39 const float ForgettingCurveUtils::UNIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2; |
| 40 const float ForgettingCurveUtils::BIGRAM_COUNT_HARD_LIMIT_WEIGHT = 1.2; |
| 41 |
| 42 const ForgettingCurveUtils::ProbabilityTable |
| 43 ForgettingCurveUtils::sProbabilityTable; |
| 44 |
| 45 // TODO: Revise the logic to decide the initial probability depending on the |
| 46 // given probability. |
| 47 /* static */ const HistoricalInfo |
| 48 ForgettingCurveUtils::createUpdatedHistoricalInfo( |
| 49 const HistoricalInfo* const originalHistoricalInfo, |
| 50 const int newProbability, |
| 51 const HistoricalInfo* const newHistoricalInfo, |
| 52 const HeaderPolicy* const headerPolicy) { |
| 53 const int timestamp = newHistoricalInfo->getTimeStamp(); |
| 54 if (newProbability != NOT_A_PROBABILITY && |
| 55 originalHistoricalInfo->getLevel() == 0) { |
| 56 // Add entry as a valid word. |
| 57 const int level = |
| 58 clampToVisibleEntryLevelRange(newHistoricalInfo->getLevel()); |
| 59 const int count = |
| 60 clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy); |
| 61 return HistoricalInfo(timestamp, level, count); |
| 62 } else if (!originalHistoricalInfo->isValid() || |
| 63 originalHistoricalInfo->getLevel() < |
| 64 newHistoricalInfo->getLevel() || |
| 65 (originalHistoricalInfo->getLevel() == |
| 66 newHistoricalInfo->getLevel() && |
| 67 originalHistoricalInfo->getCount() < |
| 68 newHistoricalInfo->getCount())) { |
| 69 // Initial information. |
| 70 const int level = clampToValidLevelRange(newHistoricalInfo->getLevel()); |
| 71 const int count = |
| 72 clampToValidCountRange(newHistoricalInfo->getCount(), headerPolicy); |
| 73 return HistoricalInfo(timestamp, level, count); |
| 74 } else { |
| 75 const int updatedCount = originalHistoricalInfo->getCount() + 1; |
| 76 if (updatedCount >= |
| 77 headerPolicy->getForgettingCurveOccurrencesToLevelUp()) { |
| 78 // The count exceeds the max value the level can be incremented. |
| 79 if (originalHistoricalInfo->getLevel() >= MAX_LEVEL) { |
| 80 // The level is already max. |
| 81 return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), |
| 82 originalHistoricalInfo->getCount()); |
| 83 } else { |
| 84 // Level up. |
| 85 return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel() + 1, |
| 86 0 /* count */); |
| 87 } |
| 88 } else { |
| 89 return HistoricalInfo(timestamp, originalHistoricalInfo->getLevel(), |
| 90 updatedCount); |
| 91 } |
| 92 } |
| 93 } |
| 94 |
| 95 /* static */ int ForgettingCurveUtils::decodeProbability( |
| 96 const HistoricalInfo* const historicalInfo, |
| 97 const HeaderPolicy* const headerPolicy) { |
| 98 const int elapsedTimeStepCount = getElapsedTimeStepCount( |
| 99 historicalInfo->getTimeStamp(), |
| 100 headerPolicy->getForgettingCurveDurationToLevelDown()); |
| 101 return sProbabilityTable.getProbability( |
| 102 headerPolicy->getForgettingCurveProbabilityValuesTableId(), |
| 103 clampToValidLevelRange(historicalInfo->getLevel()), |
| 104 clampToValidTimeStepCountRange(elapsedTimeStepCount)); |
| 105 } |
| 106 |
| 107 /* static */ int ForgettingCurveUtils::getProbability( |
| 108 const int unigramProbability, |
| 109 const int bigramProbability) { |
| 110 if (unigramProbability == NOT_A_PROBABILITY) { |
| 111 return NOT_A_PROBABILITY; |
| 112 } else if (bigramProbability == NOT_A_PROBABILITY) { |
| 113 return std::min(backoff(unigramProbability), MAX_PROBABILITY); |
| 114 } else { |
| 115 // TODO: Investigate better way to handle bigram probability. |
| 116 return std::min( |
| 117 std::max(unigramProbability, |
| 118 bigramProbability + MULTIPLIER_TWO_IN_PROBABILITY_SCALE), |
| 119 MAX_PROBABILITY); |
| 120 } |
| 121 } |
| 122 |
| 123 /* static */ bool ForgettingCurveUtils::needsToKeep( |
| 124 const HistoricalInfo* const historicalInfo, |
| 125 const HeaderPolicy* const headerPolicy) { |
| 126 return historicalInfo->getLevel() > 0 || |
| 127 getElapsedTimeStepCount( |
| 128 historicalInfo->getTimeStamp(), |
| 129 headerPolicy->getForgettingCurveDurationToLevelDown()) < |
| 130 DISCARD_LEVEL_ZERO_ENTRY_TIME_STEP_COUNT_THRESHOLD; |
| 131 } |
| 132 |
| 133 /* static */ const HistoricalInfo |
| 134 ForgettingCurveUtils::createHistoricalInfoToSave( |
| 135 const HistoricalInfo* const originalHistoricalInfo, |
| 136 const HeaderPolicy* const headerPolicy) { |
| 137 if (originalHistoricalInfo->getTimeStamp() == NOT_A_TIMESTAMP) { |
| 138 return HistoricalInfo(); |
| 139 } |
| 140 const int durationToLevelDownInSeconds = |
| 141 headerPolicy->getForgettingCurveDurationToLevelDown(); |
| 142 const int elapsedTimeStep = getElapsedTimeStepCount( |
| 143 originalHistoricalInfo->getTimeStamp(), durationToLevelDownInSeconds); |
| 144 if (elapsedTimeStep <= MAX_ELAPSED_TIME_STEP_COUNT) { |
| 145 // No need to update historical info. |
| 146 return *originalHistoricalInfo; |
| 147 } |
| 148 // Level down. |
| 149 const int maxLevelDownAmonut = |
| 150 elapsedTimeStep / (MAX_ELAPSED_TIME_STEP_COUNT + 1); |
| 151 const int levelDownAmount = |
| 152 (maxLevelDownAmonut >= originalHistoricalInfo->getLevel()) |
| 153 ? originalHistoricalInfo->getLevel() |
| 154 : maxLevelDownAmonut; |
| 155 const int adjustedTimestampInSeconds = |
| 156 originalHistoricalInfo->getTimeStamp() + |
| 157 levelDownAmount * durationToLevelDownInSeconds; |
| 158 return HistoricalInfo(adjustedTimestampInSeconds, |
| 159 originalHistoricalInfo->getLevel() - levelDownAmount, |
| 160 0 /* count */); |
| 161 } |
| 162 |
| 163 /* static */ bool ForgettingCurveUtils::needsToDecay( |
| 164 const bool mindsBlockByDecay, |
| 165 const int unigramCount, |
| 166 const int bigramCount, |
| 167 const HeaderPolicy* const headerPolicy) { |
| 168 if (unigramCount >= |
| 169 getUnigramCountHardLimit(headerPolicy->getMaxUnigramCount())) { |
| 170 // Unigram count exceeds the limit. |
| 171 return true; |
| 172 } else if (bigramCount >= |
| 173 getBigramCountHardLimit(headerPolicy->getMaxBigramCount())) { |
| 174 // Bigram count exceeds the limit. |
| 175 return true; |
| 176 } |
| 177 if (mindsBlockByDecay) { |
| 178 return false; |
| 179 } |
| 180 if (headerPolicy->getLastDecayedTime() + DECAY_INTERVAL_SECONDS < |
| 181 TimeKeeper::peekCurrentTime()) { |
| 182 // Time to decay. |
| 183 return true; |
| 184 } |
| 185 return false; |
| 186 } |
| 187 |
| 188 // See comments in ProbabilityUtils::backoff(). |
| 189 /* static */ int ForgettingCurveUtils::backoff(const int unigramProbability) { |
| 190 // See TODO comments in ForgettingCurveUtils::getProbability(). |
| 191 return unigramProbability; |
| 192 } |
| 193 |
| 194 /* static */ int ForgettingCurveUtils::getElapsedTimeStepCount( |
| 195 const int timestamp, |
| 196 const int durationToLevelDownInSeconds) { |
| 197 const int elapsedTimeInSeconds = TimeKeeper::peekCurrentTime() - timestamp; |
| 198 const int timeStepDurationInSeconds = |
| 199 durationToLevelDownInSeconds / (MAX_ELAPSED_TIME_STEP_COUNT + 1); |
| 200 return elapsedTimeInSeconds / timeStepDurationInSeconds; |
| 201 } |
| 202 |
| 203 /* static */ int ForgettingCurveUtils::clampToVisibleEntryLevelRange( |
| 204 const int level) { |
| 205 return std::min(std::max(level, MIN_VISIBLE_LEVEL), MAX_LEVEL); |
| 206 } |
| 207 |
| 208 /* static */ int ForgettingCurveUtils::clampToValidCountRange( |
| 209 const int count, |
| 210 const HeaderPolicy* const headerPolicy) { |
| 211 return std::min(std::max(count, 0), |
| 212 headerPolicy->getForgettingCurveOccurrencesToLevelUp() - 1); |
| 213 } |
| 214 |
| 215 /* static */ int ForgettingCurveUtils::clampToValidLevelRange(const int level) { |
| 216 return std::min(std::max(level, 0), MAX_LEVEL); |
| 217 } |
| 218 |
| 219 /* static */ int ForgettingCurveUtils::clampToValidTimeStepCountRange( |
| 220 const int timeStepCount) { |
| 221 return std::min(std::max(timeStepCount, 0), MAX_ELAPSED_TIME_STEP_COUNT); |
| 222 } |
| 223 |
| 224 const int ForgettingCurveUtils::ProbabilityTable::PROBABILITY_TABLE_COUNT = 4; |
| 225 const int ForgettingCurveUtils::ProbabilityTable::WEAK_PROBABILITY_TABLE_ID = 0; |
| 226 const int ForgettingCurveUtils::ProbabilityTable::MODEST_PROBABILITY_TABLE_ID = |
| 227 1; |
| 228 const int ForgettingCurveUtils::ProbabilityTable::STRONG_PROBABILITY_TABLE_ID = |
| 229 2; |
| 230 const int |
| 231 ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_PROBABILITY_TABLE_ID = 3; |
| 232 const int ForgettingCurveUtils::ProbabilityTable::WEAK_MAX_PROBABILITY = 127; |
| 233 const int ForgettingCurveUtils::ProbabilityTable::MODEST_BASE_PROBABILITY = 32; |
| 234 const int ForgettingCurveUtils::ProbabilityTable::STRONG_BASE_PROBABILITY = 35; |
| 235 const int ForgettingCurveUtils::ProbabilityTable::AGGRESSIVE_BASE_PROBABILITY = |
| 236 40; |
| 237 |
| 238 ForgettingCurveUtils::ProbabilityTable::ProbabilityTable() : mTables() { |
| 239 mTables.resize(PROBABILITY_TABLE_COUNT); |
| 240 for (int tableId = 0; tableId < PROBABILITY_TABLE_COUNT; ++tableId) { |
| 241 mTables[tableId].resize(MAX_LEVEL + 1); |
| 242 for (int level = 0; level <= MAX_LEVEL; ++level) { |
| 243 mTables[tableId][level].resize(MAX_ELAPSED_TIME_STEP_COUNT + 1); |
| 244 const float initialProbability = |
| 245 getBaseProbabilityForLevel(tableId, level); |
| 246 const float endProbability = |
| 247 getBaseProbabilityForLevel(tableId, level - 1); |
| 248 for (int timeStepCount = 0; timeStepCount <= MAX_ELAPSED_TIME_STEP_COUNT; |
| 249 ++timeStepCount) { |
| 250 if (level == 0) { |
| 251 mTables[tableId][level][timeStepCount] = NOT_A_PROBABILITY; |
| 252 continue; |
| 253 } |
| 254 const float probability = |
| 255 initialProbability * |
| 256 powf(initialProbability / endProbability, |
| 257 -1.0f * static_cast<float>(timeStepCount) / |
| 258 static_cast<float>(MAX_ELAPSED_TIME_STEP_COUNT + 1)); |
| 259 mTables[tableId][level][timeStepCount] = std::min( |
| 260 std::max(static_cast<int>(probability), 1), MAX_PROBABILITY); |
| 261 } |
| 262 } |
| 263 } |
| 264 } |
| 265 |
| 266 /* static */ int |
| 267 ForgettingCurveUtils::ProbabilityTable::getBaseProbabilityForLevel( |
| 268 const int tableId, |
| 269 const int level) { |
| 270 if (tableId == WEAK_PROBABILITY_TABLE_ID) { |
| 271 // Max probability is 127. |
| 272 return static_cast<float>(WEAK_MAX_PROBABILITY / |
| 273 (1 << (MAX_LEVEL - level))); |
| 274 } else if (tableId == MODEST_PROBABILITY_TABLE_ID) { |
| 275 // Max probability is 128. |
| 276 return static_cast<float>(MODEST_BASE_PROBABILITY * (level + 1)); |
| 277 } else if (tableId == STRONG_PROBABILITY_TABLE_ID) { |
| 278 // Max probability is 140. |
| 279 return static_cast<float>(STRONG_BASE_PROBABILITY * (level + 1)); |
| 280 } else if (tableId == AGGRESSIVE_PROBABILITY_TABLE_ID) { |
| 281 // Max probability is 160. |
| 282 return static_cast<float>(AGGRESSIVE_BASE_PROBABILITY * (level + 1)); |
| 283 } else { |
| 284 return NOT_A_PROBABILITY; |
| 285 } |
| 286 } |
| 287 |
| 288 } // namespace latinime |
| OLD | NEW |