Data key caching breaks under concurrency: N parallel encrypts → N data keys #1665

@ybs-me

Problem:

The Node caching CMM (NodeCachingMaterialsManager) has no in-flight request coalescing. When several encryptions for the same cache entry run concurrently against a cold cache, they all miss together (the cache is only populated after each call returns), so each one independently asks the backing keyring for a brand-new data key.

With data key caching the whole point is to reuse one DEK across many messages. Under concurrency that reuse silently doesn't happen: a parallel burst of N encryptions produces N DEKs (N KMS GenerateDataKey/Encrypt calls) instead of one, defeating the cache and inflating KMS traffic.

You can see this without any KMS Decrypt call: the wrapped data key is cached with the encryption material, so a reused DEK produces a byte-identical encrypted data key in the message header, while a fresh DEK differs. Fingerprinting the aws-kms encrypted data key in the header counts distinct DEKs.

Reproduction:

Two small files, maxMessagesEncrypted: 5, one KMS key. Run with:

KMS_KEY_ARN=arn:aws:kms:...:key/... npx tsx caching-cmm-dek-reuse.ts
caching-cmm-dek-reuse.ts
import {
  buildClient,
  CommitmentPolicy,
  KmsKeyringNode,
  NodeCachingMaterialsManager,
  getLocalCryptographicMaterialsCache,
} from '@aws-crypto/client-node'
import { dekFingerprint } from './dek-fingerprint'

// Follows the AWS Encryption SDK data key caching guide:
// https://docs.aws.amazon.com/encryption-sdk/latest/developer-guide/implement-caching.html
const MAX_REUSE = 5
const CACHE_CAPACITY = 100
const MAX_AGE_MS = 1000 * 60

const kmsKeyId = process.env.KMS_KEY_ARN
if (!kmsKeyId) {
  throw new Error('Set KMS_KEY_ARN to a KMS key ARN to run this reproduction.')
}

const { encrypt } = buildClient(CommitmentPolicy.REQUIRE_ENCRYPT_REQUIRE_DECRYPT)

// A fresh caching CMM (hence a cold cache) per scenario, so each starts clean.
// One KMS key is enough: we only care how many *data* keys get generated.
const freshCachingCmm = (): NodeCachingMaterialsManager =>
  new NodeCachingMaterialsManager({
    backingMaterials: new KmsKeyringNode({ generatorKeyId: kmsKeyId }),
    cache: getLocalCryptographicMaterialsCache(CACHE_CAPACITY),
    maxAge: MAX_AGE_MS,
    maxMessagesEncrypted: MAX_REUSE,
  })

// The cache entry is keyed by suite + encryption context (not plaintext), so a
// single shared encryption context is what lets one DEK be reused across calls.
const encryptionContext = { purpose: 'caching-cmm-dek-reuse-repro' }

const encryptOnce = async (
  cmm: NodeCachingMaterialsManager,
  plaintext: string
): Promise<string> => {
  const { result } = await encrypt(cmm, plaintext, { encryptionContext })
  return dekFingerprint(new Uint8Array(result))
}

const report = (label: string, fingerprints: string[], expectation: string) => {
  const usesByDek = new Map<string, number>()
  fingerprints.forEach((fp) => usesByDek.set(fp, (usesByDek.get(fp) || 0) + 1))

  // Group DEKs by how many times each was used, e.g. "1 DEK used 4x, 6 DEKs used 1x".
  const deksByUseCount = new Map<number, number>()
  usesByDek.forEach((u) => deksByUseCount.set(u, (deksByUseCount.get(u) || 0) + 1))
  const reuse = [...deksByUseCount.entries()]
    .sort((a, b) => b[0] - a[0])
    .map(([u, deks]) => `${deks} ${deks === 1 ? 'DEK' : 'DEKs'} used ${u}x`)
    .join(', ')

  console.log(`\n${label}`)
  console.log(`  encryptions:   ${fingerprints.length}`)
  console.log(`  distinct DEKs: ${usesByDek.size}`)
  console.log(`  reuse:         ${reuse}`)
  console.log(`  expected:      ${expectation}`)
}

const tenInParallel = async () => {
  const cmm = freshCachingCmm()
  const fingerprints = await Promise.all(
    Array.from({ length: 10 }, (_, i) => encryptOnce(cmm, `parallel-${i}`))
  )
  report(
    '1) 10 encryptions in parallel (cold cache)',
    fingerprints,
    '10 distinct DEKs -- no request coalescing: every concurrent miss makes its own DEK'
  )
}

const tenOneByOne = async () => {
  const cmm = freshCachingCmm()
  const fingerprints: string[] = []
  for (let i = 0; i < 10; i++) {
    fingerprints.push(await encryptOnce(cmm, `sequential-${i}`))
  }
  report(
    '2) 10 encryptions one by one (cold cache)',
    fingerprints,
    `2 distinct DEKs reused 5 times each -- exactly maxReuse=${MAX_REUSE}`
  )
}

const warmThenTenInParallel = async () => {
  const cmm = freshCachingCmm()
  await encryptOnce(cmm, 'warm') // 1 DEK cached, used once
  const fingerprints = await Promise.all(
    Array.from({ length: 10 }, (_, i) => encryptOnce(cmm, `warm-parallel-${i}`))
  )
  report(
    '3) 1 warm-up, then 10 in parallel',
    fingerprints,
    `7 distinct DEKs -- the warm DEK (already used once) is reused by 4 of the 10 to reach maxReuse=${MAX_REUSE}; the remaining 6 each mint their own DEK because there is no coalescing`
  )
}

const main = async () => {
  await tenInParallel()
  await tenOneByOne()
  await warmThenTenInParallel()
}

main().catch((error) => {
  console.error(error)
  process.exit(1)
})
dek-fingerprint.ts
import { deserializeFactory } from '@aws-crypto/serialize'
import { NodeAlgorithmSuite } from '@aws-crypto/material-management-node'

const toUtf8 = (bytes: Uint8Array): string =>
  Buffer.from(bytes.buffer, bytes.byteOffset, bytes.byteLength).toString('utf8')

const { deserializeMessageHeader } = deserializeFactory(toUtf8, NodeAlgorithmSuite)

// The wrapped data key (the KMS-encrypted DEK) is cached together with the
// encryption material, so a *reused* DEK yields byte-identical encrypted data
// keys, while a *freshly generated* DEK is wrapped by a new KMS call and differs.
// Fingerprinting the 'aws-kms' encrypted data key therefore tells us which DEK
// encrypted a message -- by reading the message header only, with no KMS Decrypt.
export const dekFingerprint = (message: Uint8Array): string => {
  const headerInfo = deserializeMessageHeader(Buffer.from(message))
  if (!headerInfo) {
    throw new Error('Incomplete message: shorter than a full SDK header.')
  }

  const kmsDataKey = headerInfo.messageHeader.encryptedDataKeys.find(
    (edk: { providerId: string; encryptedDataKey: Uint8Array }) =>
      edk.providerId === 'aws-kms'
  )
  if (!kmsDataKey) {
    throw new Error('No aws-kms encrypted data key found in the message header.')
  }

  return Buffer.from(kmsDataKey.encryptedDataKey).toString('base64')
}

Three scenarios, all sharing one encryption context (so one DEK could serve them all). Actual output:

  1. 10 in parallel, cold cache: 10 distinct DEKs (each used once). Every concurrent miss makes its own DEK; this is the bug.
  2. 10 one by one, cold cache: 2 distinct DEKs (used 5x each). Each DEK is reused exactly maxMessagesEncrypted=5 times. Caching works fine when serialized.
  3. 1 warm-up, then 10 in parallel: 7 distinct DEKs (1 used 4x, 6 used once). The warm DEK (already used once) is reused by 4 of the 10 to reach the limit of 5; the remaining 6 then find the entry gone and each mint their own DEK, because there is no coalescing.

Scenario 2 vs scenario 1 is the proof: identical work, identical config, only the concurrency differs — serialized gets the intended 2 DEKs, parallel makes 10. Scenario 3 shows it also happens after a warm cache: once the reuse budget is spent, the concurrent tail stampedes (6 fresh DEKs instead of 1).

Solution:

Add single-flight (in-flight request coalescing) to the caching CMM, keyed by the cache key, the same way it can be done for the hierarchical keyring's branch-key lookup: the first miss for a key starts the backing fetch and stores the promise; concurrent callers for that key await the same promise instead of each calling the backing keyring. The in-flight entry is removed as soon as the promise settles, so the materials cache keeps ownership of caching and TTL, and a failed fetch is not handed to later callers.
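To make the plain single-flight idea concrete, here is a minimal sketch. `SingleFlight` and its `run` method are hypothetical names for illustration only; nothing like this is exported by @aws-crypto/client-node, and the real change would live inside the caching CMM.

```typescript
// Hypothetical helper (not an SDK API): coalesces concurrent fetches per key.
class SingleFlight<T> {
  // Promises for fetches that have started but not yet settled.
  private readonly inFlight = new Map<string, Promise<T>>()

  run(key: string, fetch: () => Promise<T>): Promise<T> {
    const existing = this.inFlight.get(key)
    if (existing) {
      // Follower: share the leader's promise instead of fetching again.
      return existing
    }
    // Leader: start the fetch and drop the entry once it settles, so neither
    // results nor failures are retained here. Long-term caching stays with
    // whatever cache sits in front of this helper.
    const promise = fetch().finally(() => {
      this.inFlight.delete(key)
    })
    this.inFlight.set(key, promise)
    return promise
  }
}
```

Ten concurrent `run('k', fetch)` calls on the same key trigger one `fetch`; a call made after that promise has settled starts a new one.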

There is a catch that makes this harder than the hierarchical keyring case. The hierarchical keyring (HKE) branch-key cache is bounded by TTL only, so plain single-flight (followers take the material straight off the shared promise) is correct there. The caching CMM's encrypt path also has per-entry usage counters (messagesEncrypted / bytesEncrypted) that bound DEK reuse, and getEncryptionMaterial increments them on every read. If coalesced followers receive the data key off the shared promise without going back through the cache, their usage is never counted, so the shared DEK gets reused beyond maxMessagesEncrypted. That is the unsafe direction: over-reuse.

So the encrypt fix must share the fetch but still charge each coalesced caller's use against the shared entry. Concretely: only the leader fetches and populates the cache; every caller (leader and followers) then reads through the cache so the counter advances and the limit is enforced. That is strictly more involved than the HKE single-flight, and it interacts with eviction (under a small cache, an evicted entry forces a re-fetch), so the "perfect" number of backing calls isn't a clean constant.
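The "leader fetches, everyone reads through the cache" shape could look roughly like the sketch below. It is a toy in-memory model, not the SDK's cache or CMM: `CoalescingDekCache`, `take`, `getDek`, and `generateDek` are hypothetical names, DEKs are plain strings, and only a message counter is modelled (no bytesEncrypted, TTL, or capacity eviction).

```typescript
type DekEntry = { dek: string; messagesEncrypted: number }

// Toy model (assumption, not SDK code): single-flight fetch plus per-use counting.
class CoalescingDekCache {
  private readonly entries = new Map<string, DekEntry>()
  private readonly inFlight = new Map<string, Promise<void>>()
  private readonly generateDek: () => Promise<string>
  private readonly maxMessagesEncrypted: number

  constructor(generateDek: () => Promise<string>, maxMessagesEncrypted: number) {
    this.generateDek = generateDek
    this.maxMessagesEncrypted = maxMessagesEncrypted
  }

  // Read through the cache: charge one use, retire the entry at the limit.
  private take(key: string): string | undefined {
    const entry = this.entries.get(key)
    if (!entry) return undefined
    entry.messagesEncrypted += 1
    if (entry.messagesEncrypted >= this.maxMessagesEncrypted) {
      this.entries.delete(key)
    }
    return entry.dek
  }

  async getDek(key: string): Promise<string> {
    for (;;) {
      const hit = this.take(key)
      if (hit !== undefined) return hit
      let pending = this.inFlight.get(key)
      if (!pending) {
        // Leader: the only caller that talks to the backing key source.
        pending = this.generateDek()
          .then((dek) => {
            this.entries.set(key, { dek, messagesEncrypted: 0 })
          })
          .finally(() => {
            this.inFlight.delete(key)
          })
        this.inFlight.set(key, pending)
      }
      // Leader and followers both wait, then loop back through take(), so
      // every caller's use is counted and the limit is still enforced.
      await pending
    }
  }
}
```

In this model, 10 concurrent `getDek` calls on a cold cache with a limit of 5 make 2 backing calls and use each DEK exactly 5 times, matching the serialized result. Callers that find the budget already spent loop and elect a new leader, which is why the number of backing calls depends on the limit rather than being a fixed constant.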
