Index: third_party/android_prediction/suggest/core/suggest.cpp |
diff --git a/third_party/android_prediction/suggest/core/suggest.cpp b/third_party/android_prediction/suggest/core/suggest.cpp |
new file mode 100644 |
index 0000000000000000000000000000000000000000..0aa1c33eed576a3fecbafc1d031c89c0c9ddd4e5 |
--- /dev/null |
+++ b/third_party/android_prediction/suggest/core/suggest.cpp |
@@ -0,0 +1,435 @@ |
+/* |
+ * Copyright (C) 2012 The Android Open Source Project |
+ * |
+ * Licensed under the Apache License, Version 2.0 (the "License"); |
+ * you may not use this file except in compliance with the License. |
+ * You may obtain a copy of the License at |
+ * |
+ * http://www.apache.org/licenses/LICENSE-2.0 |
+ * |
+ * Unless required by applicable law or agreed to in writing, software |
+ * distributed under the License is distributed on an "AS IS" BASIS, |
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
+ * See the License for the specific language governing permissions and |
+ * limitations under the License. |
+ */ |
+ |
+#include "third_party/android_prediction/suggest/core/suggest.h" |
+ |
+#include "third_party/android_prediction/suggest/core/dicnode/dic_node.h" |
+#include "third_party/android_prediction/suggest/core/dicnode/dic_node_priority_queue.h" |
+#include "third_party/android_prediction/suggest/core/dicnode/dic_node_vector.h" |
+#include "third_party/android_prediction/suggest/core/dictionary/dictionary.h" |
+#include "third_party/android_prediction/suggest/core/dictionary/digraph_utils.h" |
+#include "third_party/android_prediction/suggest/core/layout/proximity_info.h" |
+#include "third_party/android_prediction/suggest/core/policy/dictionary_structure_with_buffer_policy.h" |
+#include "third_party/android_prediction/suggest/core/policy/traversal.h" |
+#include "third_party/android_prediction/suggest/core/policy/weighting.h" |
+#include "third_party/android_prediction/suggest/core/result/suggestions_output_utils.h" |
+#include "third_party/android_prediction/suggest/core/session/dic_traverse_session.h" |
+ |
+namespace latinime { |
+ |
+// Initialization of class constants. |
+const int Suggest::MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE = 2; |
+ |
+/** |
+ * Returns a set of suggestions for the given input touch points. The commitPoint argument indicates |
+ * whether to prematurely commit the suggested words up to the given point for sentence-level |
+ * suggestion. |
+ * |
+ * Note: Currently does not support concurrent calls across threads. Continuous suggestion is |
+ * automatically activated for sequential calls that share the same starting input. |
+ * TODO: Stop detecting continuous suggestion. Start using traverseSession instead. |
+ */ |
+void Suggest::getSuggestions(ProximityInfo *pInfo, void *traverseSession, |
+ int *inputXs, int *inputYs, int *times, int *pointerIds, int *inputCodePoints, |
+ int inputSize, const float languageWeight, |
+ SuggestionResults *const outSuggestionResults) const { |
+ PROF_OPEN; |
+ PROF_START(0); |
+ const float maxSpatialDistance = TRAVERSAL->getMaxSpatialDistance(); |
+ DicTraverseSession *tSession = static_cast<DicTraverseSession *>(traverseSession); |
+ tSession->setupForGetSuggestions(pInfo, inputCodePoints, inputSize, inputXs, inputYs, times, |
+ pointerIds, maxSpatialDistance, TRAVERSAL->getMaxPointerCount()); |
+ // TODO: Add the way to evaluate cache |
+ |
+ initializeSearch(tSession); |
+ PROF_END(0); |
+ PROF_START(1); |
+ |
+ // keep expanding search dicNodes until all have terminated. |
+ while (tSession->getDicTraverseCache()->activeSize() > 0) { |
+ expandCurrentDicNodes(tSession); |
+ tSession->getDicTraverseCache()->advanceActiveDicNodes(); |
+ tSession->getDicTraverseCache()->advanceInputIndex(inputSize); |
+ } |
+ PROF_END(1); |
+ PROF_START(2); |
+ SuggestionsOutputUtils::outputSuggestions( |
+ SCORING, tSession, languageWeight, outSuggestionResults); |
+ PROF_END(2); |
+ PROF_CLOSE; |
+} |
+ |
+/** |
+ * Initializes the search at the root of the lexicon trie. Note that when possible the search will |
+ * continue suggestion from where it left off during the last call. |
+ */ |
+void Suggest::initializeSearch(DicTraverseSession *traverseSession) const { |
+ if (!traverseSession->getProximityInfoState(0)->isUsed()) { |
+ return; |
+ } |
+ |
+ if (traverseSession->getInputSize() > MIN_CONTINUOUS_SUGGESTION_INPUT_SIZE |
+ && traverseSession->isContinuousSuggestionPossible()) { |
+ // Continue suggestion |
+ traverseSession->getDicTraverseCache()->continueSearch(); |
+ } else { |
+ // Restart recognition at the root. |
+ traverseSession->resetCache(TRAVERSAL->getMaxCacheSize(traverseSession->getInputSize()), |
+ TRAVERSAL->getTerminalCacheSize()); |
+ // Create a new dic node here |
+ DicNode rootNode; |
+ DicNodeUtils::initAsRoot(traverseSession->getDictionaryStructurePolicy(), |
+ traverseSession->getPrevWordsPtNodePos(), &rootNode); |
+ traverseSession->getDicTraverseCache()->copyPushActive(&rootNode); |
+ } |
+} |
+ |
+/** |
+ * Expands the dicNodes in the current search priority queue by advancing to the possible child |
+ * nodes based on the next touch point(s) (or no touch points for lookahead) |
+ */ |
+void Suggest::expandCurrentDicNodes(DicTraverseSession *traverseSession) const { |
+ const int inputSize = traverseSession->getInputSize(); |
+ DicNodeVector childDicNodes(TRAVERSAL->getDefaultExpandDicNodeSize()); |
+ DicNode correctionDicNode; |
+ |
+ // TODO: Find more efficient caching |
+ const bool shouldDepthLevelCache = TRAVERSAL->shouldDepthLevelCache(traverseSession); |
+ if (shouldDepthLevelCache) { |
+ traverseSession->getDicTraverseCache()->updateLastCachedInputIndex(); |
+ } |
+ if (DEBUG_CACHE) { |
+ AKLOGI("expandCurrentDicNodes depth level cache = %d, inputSize = %d", |
+ shouldDepthLevelCache, inputSize); |
+ } |
+ while (traverseSession->getDicTraverseCache()->activeSize() > 0) { |
+ DicNode dicNode; |
+ traverseSession->getDicTraverseCache()->popActive(&dicNode); |
+ if (dicNode.isTotalInputSizeExceedingLimit()) { |
+ return; |
+ } |
+ childDicNodes.clear(); |
+ const int point0Index = dicNode.getInputIndex(0); |
+ const bool canDoLookAheadCorrection = |
+ TRAVERSAL->canDoLookAheadCorrection(traverseSession, &dicNode); |
+ const bool isLookAheadCorrection = canDoLookAheadCorrection |
+ && traverseSession->getDicTraverseCache()-> |
+ isLookAheadCorrectionInputIndex(static_cast<int>(point0Index)); |
+ const bool isCompletion = dicNode.isCompletion(inputSize); |
+ |
+ const bool shouldNodeLevelCache = |
+ TRAVERSAL->shouldNodeLevelCache(traverseSession, &dicNode); |
+ if (shouldDepthLevelCache || shouldNodeLevelCache) { |
+ if (DEBUG_CACHE) { |
+ dicNode.dump("PUSH_CACHE"); |
+ } |
+ traverseSession->getDicTraverseCache()->copyPushContinue(&dicNode); |
+ dicNode.setCached(); |
+ } |
+ |
+ if (dicNode.isInDigraph()) { |
+ // Finish digraph handling if the node is in the middle of a digraph expansion. |
+ processDicNodeAsDigraph(traverseSession, &dicNode); |
+ } else if (isLookAheadCorrection) { |
+ // The algorithm maintains a small set of "deferred" nodes that have not consumed the |
+ // latest touch point yet. These are needed to apply look-ahead correction operations |
+ // that require special handling of the latest touch point. For example, with insertions |
+ // (e.g., "thiis" -> "this") the latest touch point should not be consumed at all. |
+ processDicNodeAsTransposition(traverseSession, &dicNode); |
+ processDicNodeAsInsertion(traverseSession, &dicNode); |
+ } else { // !isLookAheadCorrection |
+ // Only consider typing error corrections if the normalized compound distance is |
+ // below a spatial distance threshold. |
+ // NOTE: the threshold may need to be updated if scoring model changes. |
+ // TODO: Remove. Do not prune node here. |
+ const bool allowsErrorCorrections = TRAVERSAL->allowsErrorCorrections(&dicNode); |
+ // Process for handling space substitution (e.g., hevis => he is) |
+ if (allowsErrorCorrections |
+ && TRAVERSAL->isSpaceSubstitutionTerminal(traverseSession, &dicNode)) { |
+ createNextWordDicNode(traverseSession, &dicNode, true /* spaceSubstitution */); |
+ } |
+ |
+ DicNodeUtils::getAllChildDicNodes( |
+ &dicNode, traverseSession->getDictionaryStructurePolicy(), &childDicNodes); |
+ |
+ const int childDicNodesSize = childDicNodes.getSizeAndLock(); |
+ for (int i = 0; i < childDicNodesSize; ++i) { |
+ DicNode *const childDicNode = childDicNodes[i]; |
+ if (isCompletion) { |
+ // Handle forward lookahead when the lexicon letter exceeds the input size. |
+ processDicNodeAsMatch(traverseSession, childDicNode); |
+ continue; |
+ } |
+ if (DigraphUtils::hasDigraphForCodePoint( |
+ traverseSession->getDictionaryStructurePolicy() |
+ ->getHeaderStructurePolicy(), |
+ childDicNode->getNodeCodePoint())) { |
+ correctionDicNode.initByCopy(childDicNode); |
+ correctionDicNode.advanceDigraphIndex(); |
+ processDicNodeAsDigraph(traverseSession, &correctionDicNode); |
+ } |
+ if (TRAVERSAL->isOmission(traverseSession, &dicNode, childDicNode, |
+ allowsErrorCorrections)) { |
+ // TODO: (Gesture) Change weight between omission and substitution errors |
+ // TODO: (Gesture) Terminal node should not be handled as omission |
+ correctionDicNode.initByCopy(childDicNode); |
+ processDicNodeAsOmission(traverseSession, &correctionDicNode); |
+ } |
+ const ProximityType proximityType = TRAVERSAL->getProximityType( |
+ traverseSession, &dicNode, childDicNode); |
+ switch (proximityType) { |
+ // TODO: Consider the difference of proximityType here |
+ case MATCH_CHAR: |
+ case PROXIMITY_CHAR: |
+ processDicNodeAsMatch(traverseSession, childDicNode); |
+ break; |
+ case ADDITIONAL_PROXIMITY_CHAR: |
+ if (allowsErrorCorrections) { |
+ processDicNodeAsAdditionalProximityChar(traverseSession, &dicNode, |
+ childDicNode); |
+ } |
+ break; |
+ case SUBSTITUTION_CHAR: |
+ if (allowsErrorCorrections) { |
+ processDicNodeAsSubstitution(traverseSession, &dicNode, childDicNode); |
+ } |
+ break; |
+ case UNRELATED_CHAR: |
+ // Just drop this dicNode and do nothing. |
+ break; |
+ default: |
+ // Just drop this dicNode and do nothing. |
+ break; |
+ } |
+ } |
+ |
+ // Push the dicNode for look-ahead correction |
+ if (allowsErrorCorrections && canDoLookAheadCorrection) { |
+ traverseSession->getDicTraverseCache()->copyPushNextActive(&dicNode); |
+ } |
+ } |
+ } |
+} |
+ |
+void Suggest::processTerminalDicNode( |
+ DicTraverseSession *traverseSession, DicNode *dicNode) const { |
+ if (dicNode->getCompoundDistance() >= static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { |
+ return; |
+ } |
+ if (!dicNode->isTerminalDicNode()) { |
+ return; |
+ } |
+ if (dicNode->shouldBeFilteredBySafetyNetForBigram()) { |
+ return; |
+ } |
+ if (!dicNode->hasMatchedOrProximityCodePoints()) { |
+ return; |
+ } |
+ // Create a non-cached node here. |
+ DicNode terminalDicNode(*dicNode); |
+ if (TRAVERSAL->needsToTraverseAllUserInput() |
+ && dicNode->getInputIndex(0) < traverseSession->getInputSize()) { |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL_INSERTION, traverseSession, 0, |
+ &terminalDicNode, traverseSession->getMultiBigramMap()); |
+ } |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TERMINAL, traverseSession, 0, |
+ &terminalDicNode, traverseSession->getMultiBigramMap()); |
+ traverseSession->getDicTraverseCache()->copyPushTerminal(&terminalDicNode); |
+} |
+ |
+/** |
+ * Adds the expanded dicNode to the next search priority queue. Also creates an additional next word |
+ * (by the space omission error correction) search path if input dicNode is on a terminal. |
+ */ |
+void Suggest::processExpandedDicNode( |
+ DicTraverseSession *traverseSession, DicNode *dicNode) const { |
+ processTerminalDicNode(traverseSession, dicNode); |
+ if (dicNode->getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { |
+ if (TRAVERSAL->isSpaceOmissionTerminal(traverseSession, dicNode)) { |
+ createNextWordDicNode(traverseSession, dicNode, false /* spaceSubstitution */); |
+ } |
+ const int allowsLookAhead = !(dicNode->hasMultipleWords() |
+ && dicNode->isCompletion(traverseSession->getInputSize())); |
+ if (dicNode->hasChildren() && allowsLookAhead) { |
+ traverseSession->getDicTraverseCache()->copyPushNextActive(dicNode); |
+ } |
+ } |
+} |
+ |
+void Suggest::processDicNodeAsMatch(DicTraverseSession *traverseSession, |
+ DicNode *childDicNode) const { |
+ weightChildNode(traverseSession, childDicNode); |
+ processExpandedDicNode(traverseSession, childDicNode); |
+} |
+ |
+void Suggest::processDicNodeAsAdditionalProximityChar(DicTraverseSession *traverseSession, |
+ DicNode *dicNode, DicNode *childDicNode) const { |
+ // Note: Most types of corrections don't need to look up the bigram information since they do |
+ // not treat the node as a terminal. There is no need to pass the bigram map in these cases. |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_ADDITIONAL_PROXIMITY, |
+ traverseSession, dicNode, childDicNode, 0 /* multiBigramMap */); |
+ weightChildNode(traverseSession, childDicNode); |
+ processExpandedDicNode(traverseSession, childDicNode); |
+} |
+ |
+void Suggest::processDicNodeAsSubstitution(DicTraverseSession *traverseSession, |
+ DicNode *dicNode, DicNode *childDicNode) const { |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_SUBSTITUTION, traverseSession, |
+ dicNode, childDicNode, 0 /* multiBigramMap */); |
+ weightChildNode(traverseSession, childDicNode); |
+ processExpandedDicNode(traverseSession, childDicNode); |
+} |
+ |
+// Process the DicNode codepoint as a digraph. This means that composite glyphs like the German |
+// u-umlaut is expanded to the transliteration "ue". Note that this happens in parallel with |
+// the normal non-digraph traversal, so both "uber" and "ueber" can be corrected to "[u-umlaut]ber". |
+void Suggest::processDicNodeAsDigraph(DicTraverseSession *traverseSession, |
+ DicNode *childDicNode) const { |
+ weightChildNode(traverseSession, childDicNode); |
+ childDicNode->advanceDigraphIndex(); |
+ processExpandedDicNode(traverseSession, childDicNode); |
+} |
+ |
+/** |
+ * Handle the dicNode as an omission error (e.g., ths => this). Skip the current letter and consider |
+ * matches for all possible next letters. Note that just skipping the current letter without any |
+ * other conditions tends to flood the search DicNodes cache with omission DicNodes. Instead, check |
+ * the possible *next* letters after the omission to better limit search to plausible omissions. |
+ * Note that apostrophes are handled as omissions. |
+ */ |
+void Suggest::processDicNodeAsOmission( |
+ DicTraverseSession *traverseSession, DicNode *dicNode) const { |
+ DicNodeVector childDicNodes; |
+ DicNodeUtils::getAllChildDicNodes( |
+ dicNode, traverseSession->getDictionaryStructurePolicy(), &childDicNodes); |
+ |
+ const int size = childDicNodes.getSizeAndLock(); |
+ for (int i = 0; i < size; i++) { |
+ DicNode *const childDicNode = childDicNodes[i]; |
+ // Treat this word as omission |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_OMISSION, traverseSession, |
+ dicNode, childDicNode, 0 /* multiBigramMap */); |
+ weightChildNode(traverseSession, childDicNode); |
+ if (!TRAVERSAL->isPossibleOmissionChildNode(traverseSession, dicNode, childDicNode)) { |
+ continue; |
+ } |
+ processExpandedDicNode(traverseSession, childDicNode); |
+ } |
+} |
+ |
+/** |
+ * Handle the dicNode as an insertion error (e.g., thiis => this). Skip the current touch point and |
+ * consider matches for the next touch point. |
+ */ |
+void Suggest::processDicNodeAsInsertion(DicTraverseSession *traverseSession, |
+ DicNode *dicNode) const { |
+ const int16_t pointIndex = dicNode->getInputIndex(0); |
+ DicNodeVector childDicNodes; |
+ DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getDictionaryStructurePolicy(), |
+ &childDicNodes); |
+ const int size = childDicNodes.getSizeAndLock(); |
+ for (int i = 0; i < size; i++) { |
+ if (traverseSession->getProximityInfoState(0)->getPrimaryCodePointAt(pointIndex + 1) |
+ != childDicNodes[i]->getNodeCodePoint()) { |
+ continue; |
+ } |
+ DicNode *const childDicNode = childDicNodes[i]; |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_INSERTION, traverseSession, |
+ dicNode, childDicNode, 0 /* multiBigramMap */); |
+ processExpandedDicNode(traverseSession, childDicNode); |
+ } |
+} |
+ |
+/** |
+ * Handle the dicNode as a transposition error (e.g., thsi => this). Swap the next two touch points. |
+ */ |
+void Suggest::processDicNodeAsTransposition(DicTraverseSession *traverseSession, |
+ DicNode *dicNode) const { |
+ const int16_t pointIndex = dicNode->getInputIndex(0); |
+ DicNodeVector childDicNodes1; |
+ DicNodeVector childDicNodes2; |
+ DicNodeUtils::getAllChildDicNodes(dicNode, traverseSession->getDictionaryStructurePolicy(), |
+ &childDicNodes1); |
+ const int childSize1 = childDicNodes1.getSizeAndLock(); |
+ for (int i = 0; i < childSize1; i++) { |
+ const ProximityType matchedId1 = traverseSession->getProximityInfoState(0) |
+ ->getProximityType(pointIndex + 1, childDicNodes1[i]->getNodeCodePoint(), |
+ true /* checkProximityChars */); |
+ if (!ProximityInfoUtils::isMatchOrProximityChar(matchedId1)) { |
+ continue; |
+ } |
+ if (childDicNodes1[i]->hasChildren()) { |
+ childDicNodes2.clear(); |
+ DicNodeUtils::getAllChildDicNodes(childDicNodes1[i], |
+ traverseSession->getDictionaryStructurePolicy(), &childDicNodes2); |
+ const int childSize2 = childDicNodes2.getSizeAndLock(); |
+ for (int j = 0; j < childSize2; j++) { |
+ DicNode *const childDicNode2 = childDicNodes2[j]; |
+ const ProximityType matchedId2 = traverseSession->getProximityInfoState(0) |
+ ->getProximityType(pointIndex, childDicNode2->getNodeCodePoint(), |
+ true /* checkProximityChars */); |
+ if (!ProximityInfoUtils::isMatchOrProximityChar(matchedId2)) { |
+ continue; |
+ } |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_TRANSPOSITION, |
+ traverseSession, childDicNodes1[i], childDicNode2, 0 /* multiBigramMap */); |
+ processExpandedDicNode(traverseSession, childDicNode2); |
+ } |
+ } |
+ } |
+} |
+ |
+/** |
+ * Weight child dicNode by aligning it to the key |
+ */ |
+void Suggest::weightChildNode(DicTraverseSession *traverseSession, DicNode *dicNode) const { |
+ const int inputSize = traverseSession->getInputSize(); |
+ if (dicNode->isCompletion(inputSize)) { |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_COMPLETION, traverseSession, |
+ 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); |
+ } else { // completion |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, CT_MATCH, traverseSession, |
+ 0 /* parentDicNode */, dicNode, 0 /* multiBigramMap */); |
+ } |
+} |
+ |
+/** |
+ * Creates a new dicNode that represents a space insertion at the end of the input dicNode. Also |
+ * incorporates the unigram / bigram score for the ending word into the new dicNode. |
+ */ |
+void Suggest::createNextWordDicNode(DicTraverseSession *traverseSession, DicNode *dicNode, |
+ const bool spaceSubstitution) const { |
+ if (!TRAVERSAL->isGoodToTraverseNextWord(dicNode)) { |
+ return; |
+ } |
+ |
+ // Create a non-cached node here. |
+ DicNode newDicNode; |
+ DicNodeUtils::initAsRootWithPreviousWord( |
+ traverseSession->getDictionaryStructurePolicy(), dicNode, &newDicNode); |
+ const CorrectionType correctionType = spaceSubstitution ? |
+ CT_NEW_WORD_SPACE_SUBSTITUTION : CT_NEW_WORD_SPACE_OMISSION; |
+ Weighting::addCostAndForwardInputIndex(WEIGHTING, correctionType, traverseSession, dicNode, |
+ &newDicNode, traverseSession->getMultiBigramMap()); |
+ if (newDicNode.getCompoundDistance() < static_cast<float>(MAX_VALUE_FOR_WEIGHTING)) { |
+ // newDicNode is worth continuing to traverse. |
+ // CAVEAT: This pruning is important for speed. Remove this when we can afford not to prune |
+ // here because here is not the right place to do pruning. Pruning should take place only |
+ // in DicNodePriorityQueue. |
+ traverseSession->getDicTraverseCache()->copyPushNextActive(&newDicNode); |
+ } |
+} |
+} // namespace latinime |