Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import { ContextToken } from "./context-token.js";
import { CorrectionSearchable, PathResult } from "./correction-searchable.js";
import { ContextTokenization } from "./context-tokenization.js";
import { QuotientNodeFinalizer } from "./quotient-node-finalizer.js";
import { TokenizationResultMapping } from "./tokenization-result-mapping.js";
import { TokenizationResult, TokenizationResultMapping } from "./tokenization-result-mapping.js";
import { EDIT_DISTANCE_COST_SCALE } from "./distance-modeler.js";
import { MAX_EDIT_THRESHOLD_FACTOR } from "./search-quotient-spur.js";

Expand Down Expand Up @@ -46,7 +46,7 @@ export type TokenResult = {
* all correctable tokens, generating corrections for the full represented
* range.
*/
export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray<TokenResult>, TokenizationResultMapping> {
export class TokenizationCorrector implements CorrectionSearchable<TokenizationResult, TokenizationResultMapping> {
public readonly tokenization: ContextTokenization;
private readonly tailCorrectionLength: number;

Expand All @@ -56,6 +56,8 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
private _predictable?: QuotientNodeFinalizer;
private _generatedTokenResults: Map<number, TokenResult>;
private _previousResults: TokenizationResultMapping[] = [];
private _correctableCodepoints: number = 0;
private _correctablesMatched = 0;

// fully private
public readonly modelsCorrectables: boolean;
Expand All @@ -65,6 +67,7 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
private lastTotalCost: number;
private handleHasBeenCalled: boolean = false;
private predictableMatchFound: boolean = false;
private matchableTokenCount = 0;

get currentCost(): number {
const correctable = this.selectionQueue.peek();
Expand Down Expand Up @@ -106,6 +109,10 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
return this._correctables.map((c) => this.tokenLookupMap.get(c.spaceId));
}

/**
* Gets the current running total of correctable codepoints.
*
* The total is increased by the codepoint length of each token that passes
* the constructor's filter, and is reduced by a correctable's codepoint length
* when that correctable is updated during the search.
*/
get correctableCodepoints(): number {
return this._correctableCodepoints;
}

/**
* Returns the token, if it exists, that is considered "predictable".
*
Expand Down Expand Up @@ -142,6 +149,10 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
return this._previousResults;
};

/**
* Gets the number of tokens for which a match has been found so far: the count
* of matched correctables other than the predictable token, plus one if a
* match has been found for the predictable token.
*/
get matchedTokenCount(): number {
return this._correctablesMatched + (this.predictableMatchFound ? 1 : 0);
}

/**
* Constructs an instance of TokenizationCorrector for finding corrections for
* correctable tokens within the specified section of an existing
Expand All @@ -156,7 +167,7 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
constructor(
tokenization: ContextTokenization,
tailCorrectionLength: number,
filterClosure: (token: ContextToken) => boolean
filterClosure: (token: ContextToken, index?: number) => boolean
) {
this.tokenization = tokenization;
this.tailCorrectionLength = tailCorrectionLength;
Expand All @@ -175,16 +186,23 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
this.tokenLookupMap = new Map();
let modelsCorrectables = false;

// 0 index: the first index in range to be modeled, as split off from the main tokenization.
orderedTokens.forEach((token, index) => {
// New issue: this mangles the space IDs! We almost certainly need some
// sort of proper map to the source token.
const searchModule = new QuotientNodeFinalizer(token.searchModule, index == orderedTokens.length - 1);
this.tokenLookupMap.set(searchModule.spaceId, token);
const passesFilter = filterClosure(token);
// Index within the token subset being examined.
const passesFilter = filterClosure(token, index);
modelsCorrectables ||= passesFilter;
if(!passesFilter) {
this._uncorrectables.push(searchModule);
} else if(index == tailCorrectionLength - 1) {
return;
}

this.matchableTokenCount++;
this._correctableCodepoints += searchModule.codepointLength;
if(index == tailCorrectionLength - 1) {
// The sole assignment case for this field. It may only be assigned for
// the final token, and only if its text is of a form considered
// correctable by the filter.
Expand Down Expand Up @@ -270,13 +288,19 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
// It is possible that the editable tokenization range consists entirely of
// tokens considered to be uncorrectable.
this.handleHasBeenCalled = true;
const results = this.collateResults();
this._previousResults.push(results);
return {
'type': 'complete',
cost: this.lastTotalCost,
mapping: results
};

// If no tokens have been matched, there is nothing to predict from; report 'none' rather than a completed result.
if(this.matchedTokenCount > 0) {
const results = this.collateResults();
this._previousResults.push(results);
return {
'type': 'complete',
cost: this.lastTotalCost,
mapping: results
};
} else {
return { type: 'none' };
}
}
}

Expand Down Expand Up @@ -314,6 +338,8 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
});
}

this._correctableCodepoints -= correctableToUpdate.codepointLength;

// We can make no further predictions if we've exhausted all search options.
// If we've reached this case, we're likely at the end of the search
// (unless correction for a correctable is still possible).
Expand All @@ -331,6 +357,8 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray

if(correctionIsThePredictable) {
this.predictableMatchFound = true;
} else {
this._correctablesMatched++;
}

// Either way, update the token -> correction-string map with the obtained result.
Expand Down Expand Up @@ -363,8 +391,8 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
this.selectionQueue.enqueue(this._predictable);
}

const correctionResults = this.collateResults();
if(correctionResults.matchedResult.findIndex((c) => c == undefined) != -1) {
// If any token lacks a matching lookup value, abort.
if([...this.tokenLookupMap.keys()].find((k) => !this._generatedTokenResults.has(k))) {

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this going to be an expensive operation?

@jahorton jahorton Jun 22, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// O(N) - iterates over all N tokens being corrected
[...this.tokenLookupMap.keys()].find( 
  // O(1) per call - is a hashmap lookup
  (k) => !this._generatedTokenResults.has(k) 
)

Therefore, the call should be O(N) worst-case.

return {
type: 'intermediate',
cost: tokenizationCost
Expand All @@ -376,12 +404,19 @@ export class TokenizationCorrector implements CorrectionSearchable<ReadonlyArray
// If there was no result obtained from the predictable and a result was previously found,
// that indicates no further predictions may be found.
if(tokenResult.type != 'none' || !correctionIsThePredictable || !this.predictableMatchFound) {
this._previousResults.push(correctionResults);
return {
type: 'complete',
cost: tokenizationCost,
mapping: correctionResults
};
if(this.matchedTokenCount > 0) {
const correctionResults = this.collateResults();
this._previousResults.push(correctionResults);
return {
type: 'complete',
cost: tokenizationCost,
mapping: correctionResults
};
} else {
return {
type: 'none'
}
}
} else {
return {
type: 'none'
Expand Down
Original file line number Diff line number Diff line change
@@ -1,37 +1,51 @@
import { CorrectionResultMapping } from "./correction-result-mapping.js";
import { TokenizationCorrector, TokenResult } from './tokenization-corrector.js';

export class TokenizationResultMapping implements CorrectionResultMapping<ReadonlyArray<TokenResult>> {
/**
* Aggregate result of correcting a tokenization, as populated by
* `TokenizationResultMapping`'s constructor.
*/
export interface TokenizationResult {
/** The per-token correction results, in the order they were supplied. */
tokenCorrections: ReadonlyArray<TokenResult>,
/** The sum of `knownCost` across all entries of `tokenCorrections`. */
totalEditCount: number,
/**
* The corrector's `correctableCodepoints` value at the time the result was
* built, or 0 when no corrector was supplied.
*/
totalEditableCodepoints: number
}

export class TokenizationResultMapping implements CorrectionResultMapping<TokenizationResult> {
readonly matchingSpace: TokenizationCorrector;
readonly matchedResult: ReadonlyArray<TokenResult>;
readonly matchedResult: TokenizationResult;

constructor(tokenization: TokenResult[], corrector?: TokenizationCorrector) {
this.matchingSpace = corrector;
this.matchedResult = tokenization;

this.matchedResult = {
tokenCorrections: tokenization,
totalEditCount: tokenization.reduce((accum, curr) => accum + curr.knownCost, 0),
// If based on a legacy/custom model not using traversals, we don't
// support edit operations (for correction) beyond the direct results of
// the most recent input distribution.
totalEditableCodepoints: corrector?.correctableCodepoints ?? 0
}
}

/**
* Gets the `spaceId` of the tokenization held by the corrector that produced
* this mapping (`matchingSpace`). Evaluates to `undefined` when the mapping
* was constructed without a corrector.
*/
get spaceId(): number {
return this.matchingSpace?.tokenization.spaceId;
}

// /**
// * Gets the number of Damerau-Levenshtein edits needed to reach the node's
// * matchString from the output induced by the input sequence used to reach it.
// *
// * (This is scaled by `SearchSpace.EDIT_DISTANCE_COST_SCALE` when included in
// * `totalCost`.)
// */
// get knownCost(): number {
// return this.node.editCount;
// }

// /**
// * Gets the "input sampling cost" of the edge, which should be considered as the
// * negative log-likelihood of the input path taken to reach the node.
// */
// get inputSamplingCost(): number {
// return this.node.inputSamplingCost;
// }
/**
* Gets the total number of Damerau-Levenshtein edits across all token
* corrections in this result, i.e. the sum of each token correction's
* `knownCost` as computed when this mapping was constructed.
*
* (This is scaled by `SearchSpace.EDIT_DISTANCE_COST_SCALE` when included in
* `totalCost`.)
*/
get knownCost(): number {
return this.matchedResult.totalEditCount;
}

/**
* Gets the combined "input sampling cost" for this result: the sum of the
* `inputSamplingCost` of every token correction it contains. Each per-token
* value should be considered the negative log-likelihood of the input path
* taken for that token.
*/
get inputSamplingCost(): number {
let combinedCost = 0;
this.matchedResult.tokenCorrections.forEach((tokenCorrection) => {
combinedCost += tokenCorrection.inputSamplingCost;
});
return combinedCost;
}

/**
* Gets the "total cost" of the edge, which should be considered as the
Expand All @@ -40,6 +54,6 @@ export class TokenizationResultMapping implements CorrectionResultMapping<Readon
* to the resulting output.
*/
get totalCost(): number {
return this.matchedResult.reduce((total, curr) => total + curr.totalCost, 0);
return this.matchedResult.tokenCorrections.reduce((total, curr) => total + curr.totalCost, 0);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -544,8 +544,8 @@ export function buildCorrectionSequence(
const orderedTokens = tokenizationCorrection.matchingSpace?.orderedTokens;
const tokens: PredictionParameters['tokens'] = [];

for(let i = 0; i < tokenizationCorrection.matchedResult.length; i++) {
const correction = tokenizationCorrection.matchedResult[i];
for(let i = 0; i < tokenizationCorrection.matchedResult.tokenCorrections.length; i++) {
const correction = tokenizationCorrection.matchedResult.tokenCorrections[i];
/* If we're dealing with the FIRST keystroke of a new sequence, we'll **dramatically** boost
* the exponent to ensure only VERY nearby corrections have a chance of winning, and only if
* there are significantly more likely words. We only need this to allow very minor fat-finger
Expand Down Expand Up @@ -817,9 +817,18 @@ export function predictFromCorrectionSequence(

const predictionComponents = correctionTokens.map((correctionToken, i) => {
const correctionTransform = correctionToken.correction.sample;
const predictions = lexicalModel.predict(correctionTransform, currentContext);
let predictions = lexicalModel.predict(correctionTransform, currentContext);
const transitionId = correctionTransform.id;

// Ensure codepointLength == prediction codepoint length if i does not match the tail!
// Filter out cases that do not conform to this condition.
if(i != correctionTokens.length - 1) {
predictions = predictions.filter((p) => {
const codepointLength = KMWString.length(correctionToken.correction.sample.insert);
return KMWString.length(p.sample.transform.insert) == codepointLength;
});
}

// Failsafe: if there are no matching predictions, create a fake prediction
// matching the original text.
if(predictions.length != 0) {
Expand Down
Loading
Loading