diff --git a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts index 98b3086709f..b2bd4297fa6 100644 --- a/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts +++ b/web/src/engine/predictive-text/worker-thread/src/main/predict-helpers.ts @@ -13,6 +13,7 @@ import { ContextTransition } from './correction/context-transition.js'; import { ExecutionTimer } from './correction/execution-timer.js'; import { ModelCompositor } from './model-compositor.js'; import { EDIT_DISTANCE_COST_SCALE, getBestTokenMatches } from './correction/distance-modeler.js'; +import { TokenResult } from './correction/tokenization-corrector.js'; import CasingForm = LexicalModelTypes.CasingForm; import Context = LexicalModelTypes.Context; @@ -25,7 +26,6 @@ import Reversion = LexicalModelTypes.Reversion; import Suggestion = LexicalModelTypes.Suggestion; import SuggestionTag = LexicalModelTypes.SuggestionTag; import Transform = LexicalModelTypes.Transform; -import { TokenResult } from './correction/tokenization-corrector.js'; /* * The functions in this file exist to provide unit-testable stateless components for the @@ -470,6 +470,34 @@ export interface PredictionParameters { applyInPost: (entry: CorrectionPredictionTuple) => void } +export function buildCorrectionSequence( + transitionEffects: ReturnType, + context: Context, + match: Readonly, + costFactor: number +) { + const { deleteLeft } = transitionEffects; + + const rootContext = models.applyTransform({insert: '', deleteLeft}, context); + + // Replace the existing context with the correction. + const correctionTransform: Transform = { + insert: match.matchString, // insert correction string + deleteLeft: 0, + } + + const rootCost = match.totalCost; + const predictionRoot = { + sample: correctionTransform, + p: Math.exp(-rootCost * costFactor) + }; + + return { + rootContext, + tokenizedCorrection: [predictionRoot] + }; +} + /** * This function takes in metadata about generated corrections (for models that * implement Traversals) and uses that to produce the corresponding parameters @@ -491,31 +519,20 @@ export function determineTokenizedCorrectionSequence( costFactor: number ): PredictionParameters { const applicationTarget = transition.base.displayTokenization; - const { deleteLeft } = determineSuggestionRange(applicationTarget.tokens, tokenization.tokens, (a, b) => a.spaceId == b.spaceId); - - const rootContext = models.applyTransform({insert: '', deleteLeft}, transition.base.context); + const transitionParams = determineSuggestionRange(applicationTarget.tokens, tokenization.tokens, (a, b) => a.spaceId == b.spaceId); - // Replace the existing context with the correction. - const correctionTransform: Transform = { - insert: match.matchString, // insert correction string - deleteLeft: 0, - } + const suggestionParams = buildCorrectionSequence(transitionParams, transition.base.context, match, costFactor); // The correction should always be based on the most recent external // transform/transcription ID. if(transition.transitionId !== undefined) { - correctionTransform.id = transition.transitionId; + suggestionParams.tokenizedCorrection.forEach((t) => t.sample.id = transition.transitionId); } - const rootCost = match.totalCost; - const predictionRoot = { - sample: correctionTransform, - p: Math.exp(-rootCost * costFactor) - }; + const { deleteLeft } = transitionParams; return { - rootContext, - tokenizedCorrection: [predictionRoot], + ...suggestionParams, applyInPost: (entry: CorrectionPredictionTuple) => { entry.preservationTransform = tokenization.taillessTrueKeystroke; // // Will need an extra lookup layer if the suggestion is generated from within a cluster. diff --git a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts index 756a09c6559..ddf838d4586 100644 --- a/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts +++ b/web/src/test/auto/headless/engine/predictive-text/worker-thread/prediction-helpers/determine-tokenized-correction-sequence.tests.ts @@ -286,7 +286,7 @@ describe('determineTokenizedCorrectionSequence', () => { ]); }); - it(`properly analyzes conplex transition - multi-token replacement`, () => { + it(`properly analyzes complex transition - multi-token replacement`, () => { const context: Context = { left: 'the quick brown f', right: '',