diff --git a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml index a2bf72683648..ee3512a3b38d 100644 --- a/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml +++ b/examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml @@ -30,6 +30,21 @@ asr: source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer), # used with `key_phrases_file` and `key_phrases_list` boosting_tree_alpha: 0.0 # Weight of the boosting tree + beam: + # Cache-aware streaming supports MALSD beam search only (set asr.decoding.strategy: malsd_batch). + beam_size: 4 + allow_cuda_graphs: true + # n-gram LM (off by default) + ngram_lm_model: null + ngram_lm_alpha: 0.0 + # phrase boosting (off by default) + boosting_tree: + model_path: null + key_phrases_file: null + key_phrases_list: null + key_phrase_items_list: null + source_lang: "en" + boosting_tree_alpha: 0.0 # ========================================== # Inverse Text Normalization Configuration diff --git a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py index 62fb81a5e5ba..8beed8fa5621 100644 --- a/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py +++ b/nemo/collections/asr/inference/model_wrappers/cache_aware_rnnt_inference_wrapper.py @@ -18,9 +18,13 @@ from nemo.collections.asr.inference.model_wrappers.cache_aware_asr_inference_wrapper import ( CacheAwareASRInferenceWrapper, ) +from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTBeamStreamingState from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext +from nemo.collections.asr.inference.utils.per_stream_biasing import multi_biasing_ids_tensor_from_states from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, 
EncDecRNNTModel from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder +from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer +from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis @@ -76,32 +80,19 @@ def get_vocabulary(self) -> list[str]: """ return self.asr_model.joint.vocabulary - def execute_step( + def encoder_step( self, processed_signal: Tensor, processed_signal_length: Tensor, context: CacheAwareContext, - previous_hypotheses: list[Hypothesis] | None, drop_extra_pre_encoded: int | None, keep_all_outputs: bool, drop_left_context: int | None = None, valid_out_len: int | None = None, - prompt_vectors: Tensor | None = None, - ) -> tuple[list[Hypothesis], CacheAwareContext]: + ) -> tuple[Tensor, Tensor, CacheAwareContext]: """ - Executes a single streaming step. - Args: - processed_signal: (Tensor) input signal tensor. - processed_signal_length: (Tensor) input signal length tensor. - context: (CacheAwareContext) context object. - previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. - drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. - keep_all_outputs: (bool) whether to keep all outputs or not. - drop_left_context: (int | None) number of left context frames to drop. - valid_out_len: (int | None) number of valid output frames. - prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts]. - Returns: - (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. + Run the cache-aware encoder for one streaming chunk, returning the (trimmed) + encoder output and updated streaming context. Decoder is NOT invoked. 
""" ( encoded, @@ -134,11 +125,116 @@ def execute_step( encoded = encoded[:, :, :valid_out_len] encoded_len = torch.ones_like(encoded_len) * valid_out_len + return encoded, encoded_len, new_context + + def execute_step( + self, + processed_signal: Tensor, + processed_signal_length: Tensor, + context: CacheAwareContext, + previous_hypotheses: list[Hypothesis] | None, + drop_extra_pre_encoded: int | None, + keep_all_outputs: bool, + drop_left_context: int | None = None, + valid_out_len: int | None = None, + prompt_vectors: Tensor | None = None, + ) -> tuple[list[Hypothesis], CacheAwareContext]: + """ + Executes a single streaming step. + Args: + processed_signal: (Tensor) input signal tensor. + processed_signal_length: (Tensor) input signal length tensor. + context: (CacheAwareContext) context object. + previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. + drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. + keep_all_outputs: (bool) whether to keep all outputs or not. + drop_left_context: (int | None) number of left context frames to drop. + valid_out_len: (int | None) number of valid output frames. + prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts]. + Returns: + (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. 
+ """ + encoded, encoded_len, new_context = self.encoder_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + context=context, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=drop_left_context, + valid_out_len=valid_out_len, + ) + best_hyp = self.asr_model.decoding.rnnt_decoder_predictions_tensor( encoded, encoded_len, return_hypotheses=True, partial_hypotheses=previous_hypotheses ) return best_hyp, new_context + def malsd_stream_step( + self, + malsd_computer: ModifiedALSDBatchedRNNTComputer, + states: list[CacheAwareRNNTBeamStreamingState], + processed_signal: Tensor, + processed_signal_length: Tensor, + context: CacheAwareContext, + drop_extra_pre_encoded: int | None, + keep_all_outputs: bool, + drop_left_context: int | None = None, + valid_out_len: int | None = None, + ) -> tuple[list[Hypothesis], CacheAwareContext]: + """Cache-aware MALSD encode/decode step for one chunk.""" + if processed_signal.device != self.device: + processed_signal = processed_signal.to(self.device) + + if processed_signal_length.device != self.device: + processed_signal_length = processed_signal_length.to(self.device) + + carries = [state.hyp_decoding_state for state in states] + if all(c is None for c in carries): + batched_state = None + else: + batched_state = malsd_computer.merge_to_batched_state(carries) + + multi_biasing_ids = multi_biasing_ids_tensor_from_states( + states, + self.device, + per_stream_biasing_enabled=malsd_computer.per_stream_biasing_enabled, + ) + + with ( + torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp), + torch.inference_mode(), + torch.no_grad(), + ): + processed_signal = processed_signal.to(self.cast_dtype) + encoded, encoded_len, new_context = self.encoder_step( + processed_signal=processed_signal, + processed_signal_length=processed_signal_length, + context=context, + 
drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=drop_left_context, + valid_out_len=valid_out_len, + ) + encs_dim_last = encoded.transpose(1, 2).contiguous() + + best_batched_hyps, batched_state = malsd_computer( + encs_dim_last, encoded_len, batched_state, multi_biasing_ids=multi_biasing_ids + ) + + chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps) + beam_indices = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() + scores_cpu = best_batched_hyps.scores.detach().cpu() + + carry_items = malsd_computer.split_batched_state(batched_state) + for state, ct, cts, rp, best_hyp_idx, carry in zip( + states, chunk_tokens, chunk_timestamps, root_ptrs, beam_indices, carry_items + ): + state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, best_hyp_idx) + state.hyp_decoding_state = carry + + hyps = [state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) for b, state in enumerate(states)] + return hyps, new_context + def stream_step( self, processed_signal: Tensor, diff --git a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py index 39ba0446b2dd..005ec62456ec 100644 --- a/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py +++ b/nemo/collections/asr/inference/pipelines/cache_aware_rnnt_pipeline.py @@ -31,14 +31,23 @@ from nemo.collections.asr.inference.streaming.framing.multi_stream import ContinuousBatchedRequestStreamer from nemo.collections.asr.inference.streaming.framing.request import FeatureBuffer, Frame, Request from nemo.collections.asr.inference.streaming.framing.request_options import ASRRequestOptions -from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTStreamingState +from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import ( + 
CacheAwareRNNTBeamStreamingState, + CacheAwareRNNTStreamingState, +) from nemo.collections.asr.inference.utils.endpointing_utils import millisecond_to_frames from nemo.collections.asr.inference.utils.enums import RequestType +from nemo.collections.asr.inference.utils.per_stream_biasing import ( + build_multi_biasing_ids_np, + release_all_biasing_models, + release_auto_managed_stream_biasing, +) from nemo.collections.asr.inference.utils.pipeline_utils import ( check_existance_of_required_attributes, drop_trailing_features, get_confidence_utils, ) +from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis from nemo.utils import logging @@ -76,8 +85,34 @@ def __init__( self.init_endpointer() self.init_text_processor(cfg, itn_model) self.init_nmt_model(nmt_model) + self.init_decoding_computer() + strategy = str(getattr(cfg.asr.decoding, "strategy", "greedy_batch")) + if strategy not in {"greedy_batch", "malsd_batch"}: + raise ValueError( + "Cache-aware RNNT streaming supports `greedy_batch` and `malsd_batch` only; " + f"configured decoding strategy is `{strategy}`." 
+ ) + if self.beam_decoder_computer is not None and self.prompt_enabled: + raise ValueError("Cache-aware RNNT MALSD beam search does not yet support prompt vectors.") super().__init__() + def init_decoding_computer(self) -> None: + """Initialize ``decoding_computer``.""" + self.decoding_computer = None + asr_model = getattr(self.asr_model, "asr_model", None) + if asr_model is None: + return + decoding = getattr(getattr(asr_model, "decoding", None), "decoding", None) + if decoding is not None: + self.decoding_computer = getattr(decoding, "decoding_computer", None) + + @property + def beam_decoder_computer(self) -> ModifiedALSDBatchedRNNTComputer | None: + """Return ``decoding_computer`` when beam-search decoding is active.""" + if isinstance(self.decoding_computer, ModifiedALSDBatchedRNNTComputer): + return self.decoding_computer + return None + def init_parameters(self, cfg: DictConfig) -> None: """ Initialize the parameters. @@ -181,7 +216,11 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta Returns: (CacheAwareRNNTStreamingState) New empty state. 
""" - state = CacheAwareRNNTStreamingState() + state = ( + CacheAwareRNNTBeamStreamingState() + if self.beam_decoder_computer is not None + else CacheAwareRNNTStreamingState() + ) state.set_global_offset(0) new_options = options.fill_defaults( default_enable_itn=self.text_processor.itn_enabled, @@ -213,6 +252,12 @@ def create_state(self, options: ASRRequestOptions) -> CacheAwareRNNTStreamingSta return state + def close_session(self) -> None: + """Close the session and release per-stream biasing models held in the decoder.""" + if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: + release_all_biasing_models(self.decoding_computer.biasing_multi_model, self._state_pool.values()) + super().close_session() + def get_sep(self) -> str: """Return the separator for the text processor.""" return self.sep @@ -238,6 +283,82 @@ def preprocess(self, buffers: list[Tensor], right_paddings: list[int] | None = N feature_buffers = torch.cat(feature_buffers).to(self.device) return feature_buffers, feature_buffer_lens + def _streaming_step( + self, + states: list[CacheAwareRNNTStreamingState], + feature_buffers: Tensor, + feature_buffer_lens: Tensor, + context, + previous_hypotheses: list[Hypothesis | None], + drop_extra_pre_encoded: int, + keep_all_outputs: bool, + prompt_vectors: Tensor | None, + ) -> tuple[list[Hypothesis], object]: + """ + Run one cache-aware encode/decode step for the current chunk. + Returns per-stream hypotheses and the updated encoder cache context. 
+ """ + if self.beam_decoder_computer is None: + return self.asr_model.stream_step( + processed_signal=feature_buffers, + processed_signal_length=feature_buffer_lens, + context=context, + previous_hypotheses=previous_hypotheses, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=self.drop_left_context, + valid_out_len=self.valid_out_len, + prompt_vectors=prompt_vectors, + ) + return self.asr_model.malsd_stream_step( + malsd_computer=self.beam_decoder_computer, + states=states, + processed_signal=feature_buffers, + processed_signal_length=feature_buffer_lens, + context=context, + drop_extra_pre_encoded=drop_extra_pre_encoded, + keep_all_outputs=keep_all_outputs, + drop_left_context=self.drop_left_context, + valid_out_len=self.valid_out_len, + ) + + def _prepare_per_stream_biasing( + self, + states: list[CacheAwareRNNTStreamingState], + previous_hypotheses: list[Hypothesis | None], + ) -> list[Hypothesis | None]: + if self.decoding_computer is None or not self.decoding_computer.per_stream_biasing_enabled: + if any(state.has_biasing_request() for state in states): + logging.warning( + "Biasing request is not empty, but decoder does not support per-stream biasing. 
Skipping" + ) + return previous_hypotheses + + multi_biasing_ids_np = build_multi_biasing_ids_np( + states, + self.decoding_computer.biasing_multi_model, + self.asr_model.tokenizer, + ) + + if self.beam_decoder_computer is not None: + return previous_hypotheses + + for i, (state, previous_hyp) in enumerate(zip(states, previous_hypotheses)): + if multi_biasing_ids_np[i] < 0: + continue + biasing_cfg = state.options.biasing_cfg + if previous_hyp is None: + previous_hypotheses[i] = Hypothesis.empty_with_biasing_cfg(biasing_cfg) + else: + previous_hyp.biasing_cfg = biasing_cfg + return previous_hypotheses + + def _apply_beam_update_(self, state: CacheAwareRNNTBeamStreamingState, eou_detected: bool) -> None: + """After endpointing: refresh beam publish tokens and fold cumulative prefix on EOU.""" + if eou_detected and state.hyp_decoding_state is not None: + self.beam_decoder_computer.select_beam_in_state_item_(state.hyp_decoding_state, state.get_best_hyp_idx()) + state.update_(eou_detected) + def run_greedy_decoder(self, state: CacheAwareRNNTStreamingState, request: Request, hyp: Hypothesis) -> bool: """ Run the greedy RNNT decoder on the hypothesis and update the state @@ -310,34 +431,10 @@ def cache_aware_transcribe_step( previous_hypotheses = [state.get_previous_hypothesis() for state in states] - try: - decoding_computer = self.asr_model.asr_model.decoding.decoding.decoding_computer - biasing_enabled = decoding_computer.per_stream_biasing_enabled - except AttributeError: - decoding_computer = None - biasing_enabled = False - - if not biasing_enabled and any(state.has_biasing_request() for state in states): - logging.warning("Biasing request is not empty, but decoder does not support per-stream biasing. 
Skipping") - - # Handle per-stream biasing: add biasing models to multi_model if needed - if biasing_enabled: - for i, (request, state, previous_hyp) in enumerate(zip(requests, states, previous_hypotheses)): - if state.has_biasing_request(): - if state.options.biasing_cfg.multi_model_id is None: - if state.options.biasing_cfg.auto_manage_multi_model: - state.options.biasing_cfg.add_to_multi_model( - tokenizer=self.asr_model.tokenizer, - biasing_multi_model=decoding_computer.biasing_multi_model, - ) - else: - logging.warning( - "Biasing request is not empty, not auto managed and not compiled. Skipping" - ) - if previous_hyp is None: - previous_hypotheses[i] = Hypothesis.empty_with_biasing_cfg(state.options.biasing_cfg) - else: - previous_hyp.biasing_cfg = state.options.biasing_cfg + previous_hypotheses = self._prepare_per_stream_biasing( + states=states, + previous_hypotheses=previous_hypotheses, + ) context, mapping = self.context_manager.get_context(stream_ids) @@ -346,15 +443,14 @@ def cache_aware_transcribe_step( prompt_vectors = self._build_prompt_vectors(states) drop_extra_pre_encoded = 0 if not self.use_cache else self.asr_model.drop_extra_pre_encoded - best_hyp, new_context = self.asr_model.stream_step( - processed_signal=feature_buffers, - processed_signal_length=feature_buffer_lens, + best_hyp, new_context = self._streaming_step( + states=states, + feature_buffers=feature_buffers, + feature_buffer_lens=feature_buffer_lens, context=context, previous_hypotheses=previous_hypotheses, drop_extra_pre_encoded=drop_extra_pre_encoded, keep_all_outputs=keep_all_outputs, - drop_left_context=self.drop_left_context, - valid_out_len=self.valid_out_len, prompt_vectors=prompt_vectors, ) @@ -372,20 +468,24 @@ def cache_aware_transcribe_step( # run greedy decoder for each request-state-hypothesis tuple for request, state, hyp in zip(requests, states, best_hyp): eou_detected = self.run_greedy_decoder(state, request, hyp) + if self.beam_decoder_computer is not None: + 
self._apply_beam_update_(state, eou_detected) if eou_detected: self.bpe_decoder.decode_bpe_tokens(state) state.cleanup_after_eou() ready_state_ids.add(request.stream_id) # Cleanup per-stream biasing models when stream ends - if biasing_enabled: + if self.decoding_computer is not None and self.decoding_computer.per_stream_biasing_enabled: for request, state in zip(requests, states): # only the first request contains biasing options; biasing options for the stream are stored in state if request.is_last and state.has_biasing_request(): - if state.options.biasing_cfg.auto_manage_multi_model: - state.options.biasing_cfg.remove_from_multi_model( - biasing_multi_model=decoding_computer.biasing_multi_model - ) + release_auto_managed_stream_biasing(state, self.decoding_computer.biasing_multi_model) + + if self.beam_decoder_computer is not None: + for state, eos in zip(states, eos_flags): + if eos: + state.reset_beam_decoding_state_() def transcribe_step_for_feature_buffers(self, fbuffers: list[FeatureBuffer]) -> None: """ diff --git a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py index c9374c37ba26..715a3717a6e7 100644 --- a/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py +++ b/nemo/collections/asr/inference/streaming/state/cache_aware_rnnt_state.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
class CacheAwareRNNTBeamStreamingState(CacheAwareRNNTStreamingState):
    """Streaming state for cache-aware RNNT beam search.

    Holds the per-stream decoder carry together with the token bookkeeping that
    lets a single top-1 transcript be published while several beams are alive.

    Attributes (all created by the reset methods):
        hyp_decoding_state: beam carry handed from one chunk to the next
            (the pipeline collapses it to the best beam on EOU).
        cumulative_tokens / cumulative_timestamps: tokens and timestamps sealed
            at each EOU, i.e. the earlier utterances of the stream.
        partial_tokens / partial_timestamps: one list per beam with the suffix
            decoded since the last EOU; chunk-local exports are merged into
            them through the lineage pointers.
        best_hyp_idx: index into ``partial_*`` of the beam used for publishing.
    """

    def _additional_params_reset(self) -> None:
        """Reset the parent's parameters, then the beam-search specific ones."""
        super()._additional_params_reset()
        # The beam fields start from exactly the values used at end of stream.
        self.reset_beam_decoding_state_()

    def reset_beam_decoding_state_(self) -> None:
        """Drop the beam carry and every cumulative/partial token list."""
        self.hyp_decoding_state: MALSDStateItem | None = None
        self.best_hyp_idx: int | None = None
        self.partial_tokens: list[list[int]] | None = None
        self.partial_timestamps: list[list[int]] | None = None
        self.cumulative_tokens: list[int] = []
        self.cumulative_timestamps: list[int] = []
        self._cumulative_tokens_len: int = 0

    def append_chunk_beam_(
        self,
        chunk_tokens: list[list[int]],
        chunk_timestamps: list[list[int]],
        root_ptrs: list[int],
        beam_size: int,
        best_hyp_idx: int,
    ) -> None:
        """Extend the per-beam partial lists with one chunk of beam exports.

        Args:
            chunk_tokens: tokens emitted in this chunk, one list per beam.
            chunk_timestamps: timestamps matching ``chunk_tokens``.
            root_ptrs: for each beam, the index of the previous-chunk beam whose
                partial history it continues.
            beam_size: number of beams to process.
            best_hyp_idx: beam index to remember as the publish candidate.
        """
        beams = range(beam_size)
        # First chunk after a reset/EOU: every beam starts from an empty history.
        old_tokens = self.partial_tokens or [[] for _ in beams]
        old_timestamps = self.partial_timestamps or [[] for _ in beams]
        parents = [int(root_ptrs[beam]) for beam in beams]

        # Build both lists completely before touching the state.
        merged_tokens = [old_tokens[parent] + list(chunk_tokens[beam]) for beam, parent in enumerate(parents)]
        merged_timestamps = [
            old_timestamps[parent] + list(chunk_timestamps[beam]) for beam, parent in enumerate(parents)
        ]

        self.partial_tokens = merged_tokens
        self.partial_timestamps = merged_timestamps
        self.best_hyp_idx = best_hyp_idx

    def get_best_hyp_idx(self) -> int:
        """Return the beam index used for publishing.

        Uses the stored ``best_hyp_idx`` when present, otherwise the argmax of
        the carry scores.

        Raises:
            RuntimeError: if neither a stored index nor a decoding carry exists.
        """
        if self.best_hyp_idx is None:
            carry = self.hyp_decoding_state
            if carry is None:
                raise RuntimeError("Cannot resolve top-1 beam index without decoding carry.")
            return int(carry.score.argmax().item())
        return int(self.best_hyp_idx)

    def _get_tokens(self) -> tuple[list[int], list[int]]:
        """Return ``cumulative_*`` followed by the top-1 ``partial_*`` suffix.

        Both lists are empty when there is no partial suffix or no decoding carry.
        """
        if self.partial_tokens is not None and self.hyp_decoding_state is not None:
            top = self.get_best_hyp_idx()
            all_tokens = self.cumulative_tokens + list(self.partial_tokens[top])
            all_timestamps = self.cumulative_timestamps + list(self.partial_timestamps[top])
            return all_tokens, all_timestamps
        return [], []

    def get_hypothesis(self, score: float) -> Hypothesis:
        """Wrap the current cumulative top-1 tokens and timestamps in a ``Hypothesis``."""
        all_tokens, all_timestamps = self._get_tokens()
        return Hypothesis(
            score=score,
            y_sequence=all_tokens,
            timestamp=all_timestamps,
            length=len(all_tokens),
        )

    def update_(self, eou_detected: bool) -> None:
        """Refresh the published tokens; on EOU seal the utterance.

        The tokens after the already-sealed prefix are copied into ``tokens`` /
        ``timesteps`` (confidences are filled with ``0.0``). When ``eou_detected``
        is true the full sequence becomes the new cumulative prefix and the
        per-beam partial lists are cleared.
        """
        all_tokens, all_timestamps = self._get_tokens()
        has_tokens = bool(all_tokens)

        if has_tokens:
            # Clamp the sealed-prefix length into the valid slice range.
            offset = max(0, min(int(self._cumulative_tokens_len), len(all_tokens)))
            new_tokens = list(all_tokens[offset:])
            new_timesteps = list(all_timestamps[offset:])
            self.tokens = new_tokens
            self.timesteps = new_timesteps
            self.confidences = [0.0] * len(new_tokens)
            if new_tokens:
                self.last_token = new_tokens[-1]
                self.last_token_idx = new_timesteps[-1] if new_timesteps else None

        if eou_detected:
            if has_tokens:
                self._cumulative_tokens_len = len(all_tokens)
                self.cumulative_tokens = list(all_tokens)
                self.cumulative_timestamps = list(all_timestamps)
            self.partial_tokens = None
            self.partial_timestamps = None
            self.best_hyp_idx = None
def build_multi_biasing_ids_np(
    states: Sequence[Any],
    biasing_multi_model: GPUBiasingMultiModelBase,
    tokenizer: TokenizerSpec,
) -> np.ndarray:
    """Build per-stream biasing model ids; ``-1`` means no biasing for that stream.

    For every state with a biasing request this resolves the id stored in
    ``state.options.biasing_cfg.multi_model_id``. An id whose model is no longer
    active is cleared, and auto-managed configs without an id are registered in
    ``biasing_multi_model`` on the fly.

    Args:
        states: per-stream states exposing ``has_biasing_request()`` and
            ``options.biasing_cfg``.
        biasing_multi_model: multi-model holding the compiled biasing models.
        tokenizer: tokenizer forwarded to ``add_to_multi_model``.

    Returns:
        An ``int64`` array of length ``len(states)`` with the model id per stream.
    """
    ids_np = np.full([len(states)], fill_value=-1, dtype=np.int64)
    for i, state in enumerate(states):
        if not state.has_biasing_request():
            continue

        biasing_cfg = state.options.biasing_cfg
        model_id = biasing_cfg.multi_model_id
        if model_id is not None and not biasing_multi_model.model2active[model_id].item():
            # The id points to a model that is no longer active: forget it so the
            # request is treated as not yet registered.
            biasing_cfg.multi_model_id = None
            model_id = None

        if model_id is None:
            if not biasing_cfg.auto_manage_multi_model:
                logging.warning("Biasing request is not empty, not auto managed and not compiled. Skipping")
                continue
            with torch.inference_mode():
                biasing_cfg.add_to_multi_model(tokenizer=tokenizer, biasing_multi_model=biasing_multi_model)
            model_id = biasing_cfg.multi_model_id
            if model_id is None:
                # Registration is expected to assign an id; if it did not, writing
                # ``None`` into the int64 array would raise, so skip the stream.
                logging.warning("Biasing model registration did not assign a model id. Skipping")
                continue

        ids_np[i] = model_id
    return ids_np


def multi_biasing_ids_tensor_from_states(
    states: Sequence[Any],
    device: torch.device,
    *,
    per_stream_biasing_enabled: bool,
) -> Tensor | None:
    """Build decode-time biasing ids from ``state.options.biasing_cfg`` (after registration).

    Args:
        states: per-stream states exposing ``has_biasing_request()`` and
            ``options.biasing_cfg``.
        device: device the resulting tensor is moved to.
        per_stream_biasing_enabled: when false, no ids are built at all.

    Returns:
        An ``int64`` tensor with one id per stream (``-1`` for streams without a
        usable model), or ``None`` when biasing is disabled or no stream has an id.
    """
    if not per_stream_biasing_enabled:
        return None

    ids_np = np.full([len(states)], fill_value=-1, dtype=np.int64)
    for i, state in enumerate(states):
        if not state.has_biasing_request():
            continue
        model_id = state.options.biasing_cfg.multi_model_id
        if model_id is None:
            logging.warning(f"Boosting tree requested in index {i}, not compiled, skipping")
            continue
        ids_np[i] = model_id

    # Nothing to bias (this also covers an empty batch).
    if (ids_np < 0).all():
        return None
    return torch.from_numpy(ids_np).to(device=device)


def release_all_biasing_models(biasing_multi_model: GPUBiasingMultiModelBase, states: Sequence[Any]) -> None:
    """Remove every active biasing model and clear per-stream ``multi_model_id`` bookkeeping.

    Args:
        biasing_multi_model: multi-model whose active models are removed.
        states: states whose ``biasing_cfg.multi_model_id`` is reset to ``None``
            when they carry a biasing request.
    """
    # Snapshot the active ids before removing anything.
    active_model_ids = [
        model_id
        for model_id in range(biasing_multi_model.num_models)
        if biasing_multi_model.model2active[model_id].item()
    ]
    with torch.inference_mode():
        # The snapshot is ascending, so reversing it removes the highest ids first.
        for model_id in reversed(active_model_ids):
            biasing_multi_model.remove_model(model_id)
    for state in states:
        if state.has_biasing_request():
            state.options.biasing_cfg.multi_model_id = None


def release_auto_managed_stream_biasing(state: Any, biasing_multi_model: GPUBiasingMultiModelBase) -> None:
    """Drop an auto-managed biasing model when a single stream ends.

    Does nothing for states without a biasing request or whose config is not
    auto-managed.
    """
    if not state.has_biasing_request():
        return
    biasing_cfg = state.options.biasing_cfg
    if biasing_cfg.auto_manage_multi_model:
        biasing_cfg.remove_from_multi_model(biasing_multi_model)
@dataclass
class MALSDStateItem:
    """Per-stream MALSD carry for cache-aware streaming (beam-shaped tensors).

    One item holds the slice of a batched beam state that belongs to a single
    stream, so the state can be stored per stream between chunks and merged back
    into a batch later. Fields typed ``Optional`` are ``None`` when the batched
    state they were split from did not carry that tensor.
    """

    predictor_state: Any  # opaque per-stream predictor state of size beam_size
    predictor_output: torch.Tensor  # [beam_size, 1, D]
    label: torch.Tensor  # [beam_size]
    decoded_length: torch.Tensor  # scalar
    score: Optional[torch.Tensor]  # [beam_size] or None
    transcript_hash: Optional[torch.Tensor]  # [beam_size] or None
    current_lengths_nb: Optional[torch.Tensor]  # [beam_size] or None
    last_timestamp_lasts: Optional[torch.Tensor] = None  # [beam_size] or None
    transcript_prefix_hash: Optional[torch.Tensor] = None  # [beam_size] or None
    # One tensor per fusion model, each [beam_size, ...]; a fresh list per instance.
    fusion_state_list: list[torch.Tensor] = field(default_factory=list)
def _get_state_item_after_sos(self, device: torch.device | str) -> MALSDStateItem:
    """Return the per-stream state right after SOS; used to fill ``None`` entries in merge.

    Builds a batch-of-one after-SOS state on ``device`` and splits it, so the result has
    the same layout as any other item produced by :meth:`split_batched_state`.
    """
    batched = self._get_batched_state_after_sos(device=device, batch_size=1)
    return self.split_batched_state(batched)[0]


def _get_batched_state_after_sos(self, device: torch.device | str, batch_size: int) -> BatchedBeamState:
    """Build a fresh batched MALSD state after SOS (slot 0 active, others inactive).

    Args:
        device: device on which all tensors are created.
        batch_size: number of streams; every stream gets ``self.beam_size`` slots.

    Returns:
        A ``BatchedBeamState`` whose predictor was fed SOS for every slot. Only beam
        slot 0 of each stream has score ``0.0``; the rest hold ``INACTIVE_SCORE``.
    """
    beam_size = self.beam_size
    total = batch_size * beam_size

    sos_labels = torch.full([total], fill_value=self._SOS, dtype=torch.long, device=device)
    decoder_output, predictor_state, *_ = self.decoder.predict(
        sos_labels.unsqueeze(1), None, add_sos=False, batch_size=total
    )
    decoder_output = self.joint.project_prednet(decoder_output)  # [B*K, 1, D]

    scores = torch.full(
        [batch_size, beam_size], fill_value=INACTIVE_SCORE, dtype=decoder_output.dtype, device=device
    )
    scores[:, 0] = 0.0  # a single live hypothesis per stream

    fusion_states_list: list[torch.Tensor] = []
    if self.has_fusion_models:
        for fm in self._all_fusion_models():
            fs = fm.get_init_states(batch_size=total, bos=True).to(device)
            # fold the flat B*K leading dim into [B, K, ...]
            fusion_states_list.append(fs.reshape(batch_size, beam_size, *fs.shape[1:]))

    def zeros_bk() -> torch.Tensor:
        return torch.zeros([batch_size, beam_size], dtype=torch.long, device=device)

    return BatchedBeamState(
        predictor_states=predictor_state,
        predictor_outputs=decoder_output,
        labels=sos_labels.view(batch_size, beam_size),
        decoded_lengths=torch.zeros([batch_size], dtype=torch.long, device=device),
        fusion_states_list=fusion_states_list,
        time_jumps=None,
        scores=scores,
        transcript_hash=zeros_bk(),
        current_lengths_nb=zeros_bk(),
        last_timestamp_lasts=zeros_bk(),
        transcript_prefix_hash=None,
    )


def split_batched_state(self, state: Optional[BatchedBeamState]) -> list[MALSDStateItem]:
    """Split a batched MALSD state into per-stream items.

    Args:
        state: batched state, or ``None``.

    Returns:
        One ``MALSDStateItem`` per batch row (tensors are cloned); an empty list when
        ``state`` is ``None``.

    Raises:
        AssertionError: if the decoder does not yield ``batch_size * beam_size``
            per-row predictor states.
    """
    if state is None:
        return []
    batch_size = state.labels.shape[0]
    beam_size = self.beam_size

    per_row_states = self.decoder.batch_split_states(state.predictor_states)
    if len(per_row_states) != batch_size * beam_size:
        raise AssertionError(
            f"Expected predictor states with batch dim {batch_size * beam_size}, "
            f"got {len(per_row_states)} per-row items"
        )

    def row_or_none(tensor: Optional[torch.Tensor], row: int) -> Optional[torch.Tensor]:
        return tensor[row].clone() if tensor is not None else None

    items: list[MALSDStateItem] = []
    for i in range(batch_size):
        flat_rows = slice(i * beam_size, (i + 1) * beam_size)  # this stream's rows in the flat B*K layout
        items.append(
            MALSDStateItem(
                predictor_state=self.decoder.batch_unsplit_states(per_row_states[flat_rows]),
                predictor_output=state.predictor_outputs[flat_rows].clone(),
                label=state.labels[i].clone(),
                decoded_length=state.decoded_lengths[i].clone(),
                score=row_or_none(state.scores, i),
                transcript_hash=row_or_none(state.transcript_hash, i),
                current_lengths_nb=row_or_none(state.current_lengths_nb, i),
                last_timestamp_lasts=row_or_none(state.last_timestamp_lasts, i),
                transcript_prefix_hash=row_or_none(state.transcript_prefix_hash, i),
                fusion_state_list=(
                    [fs[i].clone() for fs in state.fusion_states_list] if state.fusion_states_list else []
                ),
            )
        )
    return items


def merge_to_batched_state(self, state_items: list[Optional[MALSDStateItem]]) -> BatchedBeamState:
    """Merge per-stream items into one batched state; ``None`` entries get after-SOS fillers.

    The input items are not modified: missing fusion states are padded on local copies.

    Args:
        state_items: one entry per stream. ``None`` marks a stream with no carried state.

    Returns:
        A ``BatchedBeamState`` with the streams stacked in input order.

    Raises:
        ValueError: if ``state_items`` is empty, or every entry is ``None`` (there is
            then no tensor to take the device from).
    """
    if not state_items:
        raise ValueError("merge_to_batched_state requires at least one state item")
    if any(item is None for item in state_items):
        not_none_item = next((item for item in state_items if item is not None), None)
        if not_none_item is None:
            raise ValueError(
                "merge_to_batched_state requires at least one non-None state item to infer the device"
            )
        device = not_none_item.predictor_output.device
        start_item = self._get_state_item_after_sos(device=device)
        state_items = [item if item is not None else start_item for item in state_items]

    per_row_states: list[Any] = []
    for item in state_items:
        per_row_states.extend(self.decoder.batch_split_states(item.predictor_state))
    batched_predictor_state = self.decoder.batch_unsplit_states(per_row_states)

    predictor_outputs = torch.cat([item.predictor_output for item in state_items], dim=0)
    labels = torch.stack([item.label for item in state_items], dim=0)
    decoded_lengths = torch.stack([item.decoded_length for item in state_items], dim=0)
    scores = torch.stack([item.score for item in state_items], dim=0)
    transcript_hash = torch.stack([item.transcript_hash for item in state_items], dim=0)
    current_lengths_nb = torch.stack([item.current_lengths_nb for item in state_items], dim=0)
    # Presence of the optional tensors is decided by the first item only.
    last_timestamp_lasts = (
        torch.stack([item.last_timestamp_lasts for item in state_items], dim=0)
        if state_items[0].last_timestamp_lasts is not None
        else None
    )
    transcript_prefix_hash = (
        torch.stack([item.transcript_prefix_hash for item in state_items], dim=0)
        if state_items[0].transcript_prefix_hash is not None
        else None
    )

    num_fusion = max(len(item.fusion_state_list) for item in state_items)
    fusion_states_list: list[torch.Tensor] = []
    if num_fusion > 0:
        sos_fusion_template: list[torch.Tensor] | None = None
        padded_fusion: list[list[torch.Tensor]] = []
        for item in state_items:
            fusion = item.fusion_state_list
            if len(fusion) < num_fusion:
                if sos_fusion_template is None:
                    sos_fusion_template = self._get_state_item_after_sos(
                        device=item.predictor_output.device
                    ).fusion_state_list
                # Pad a copy so the caller's item (and the shared filler) stay untouched.
                fusion = list(fusion) + [
                    sos_fusion_template[fi].clone() for fi in range(len(fusion), num_fusion)
                ]
            padded_fusion.append(fusion)
        fusion_states_list = [
            torch.stack([fusion[fi] for fusion in padded_fusion], dim=0) for fi in range(num_fusion)
        ]

    return BatchedBeamState(
        predictor_states=batched_predictor_state,
        predictor_outputs=predictor_outputs,
        labels=labels,
        decoded_lengths=decoded_lengths,
        fusion_states_list=fusion_states_list,
        time_jumps=None,
        scores=scores,
        transcript_hash=transcript_hash,
        current_lengths_nb=current_lengths_nb,
        last_timestamp_lasts=last_timestamp_lasts,
        transcript_prefix_hash=transcript_prefix_hash,
    )


def collapse_batched_state_to_beams_(
    self,
    state: BatchedBeamState,
    batched_hyps: BatchedBeamHyps,
    beam_indices: torch.Tensor,
) -> None:
    """Collapse each batch row to one beam, replicated across all slots (in-place).

    All validation happens before the first write, so a raised error leaves ``state``
    and ``batched_hyps`` unchanged.

    Args:
        state: batched state to rewrite in place.
        batched_hyps: hypotheses collapsed with the same indices via ``keep_beam_``.
        beam_indices: ``[batch_size]`` index of the beam to keep in each row.

    Raises:
        ValueError: if ``beam_indices`` does not have shape ``[batch_size]``.
        NotImplementedError: if any fusion state is not rank-2.
        AssertionError: if the decoder does not yield ``batch_size * beam_size``
            per-row predictor states.
    """
    batch_size = state.labels.shape[0]
    beam_size = self.beam_size
    if beam_indices.shape != (batch_size,):
        raise ValueError(
            f"beam_indices must have shape [batch_size={batch_size}], got {tuple(beam_indices.shape)}"
        )

    # Validate up front: failing after partial writes would leave a half-collapsed state.
    for fs in state.fusion_states_list or []:
        if fs.ndim != 2:
            raise NotImplementedError(
                f"collapse_batched_state_to_beams_ only supports rank-2 [B, K] "
                f"fusion states; got shape {tuple(fs.shape)}"
            )
    per_row = self.decoder.batch_split_states(state.predictor_states)
    if len(per_row) != batch_size * beam_size:
        raise AssertionError(
            f"Expected predictor states with batch dim {batch_size * beam_size}, "
            f"got {len(per_row)} per-row items"
        )

    device = state.labels.device
    beam_indices = beam_indices.to(dtype=torch.long, device=device)

    row_offsets = torch.arange(batch_size, device=device, dtype=torch.long) * beam_size
    chosen_flat_idx = row_offsets + beam_indices  # [B]
    flat_perm = chosen_flat_idx.unsqueeze(-1).expand(batch_size, beam_size).reshape(-1)  # [B*K]

    state.predictor_states = self.decoder.batch_unsplit_states([per_row[idx] for idx in flat_perm.tolist()])
    state.predictor_outputs = state.predictor_outputs.index_select(0, flat_perm).contiguous()

    beam_perm = beam_indices.unsqueeze(-1).expand(batch_size, beam_size)
    state.labels = torch.gather(state.labels, dim=1, index=beam_perm).contiguous()
    if state.scores is not None:
        state.scores = torch.gather(state.scores, dim=1, index=beam_perm).contiguous()
        # keep only slot 0 live so the replicated copies do not compete
        state.scores[:, 1:].fill_(INACTIVE_SCORE)
    if state.transcript_hash is not None:
        state.transcript_hash = torch.gather(state.transcript_hash, dim=1, index=beam_perm).contiguous()
    if state.current_lengths_nb is not None:
        state.current_lengths_nb = torch.gather(state.current_lengths_nb, dim=1, index=beam_perm).contiguous()
    if state.last_timestamp_lasts is not None:
        state.last_timestamp_lasts = torch.gather(state.last_timestamp_lasts, dim=1, index=beam_perm).contiguous()
    if state.transcript_prefix_hash is not None:
        state.transcript_prefix_hash = torch.gather(
            state.transcript_prefix_hash, dim=1, index=beam_perm
        ).contiguous()

    if state.fusion_states_list:
        state.fusion_states_list = [
            torch.gather(fs, dim=1, index=beam_perm).contiguous() for fs in state.fusion_states_list
        ]

    batched_hyps.keep_beam_(beam_indices)


def select_beam_in_state_item_(self, item: MALSDStateItem, beam_index: int) -> None:
    """In-place per-stream beam selection (used at EOU in streaming).

    Selects ``beam_index`` and replicates that beam's decoder carry across all
    ``beam_size`` slots. Beam width is unchanged; every slot holds the same
    predictor and fusion state, and only slot 0 keeps a live score, so the next
    decode step starts from one committed hypothesis.

    Args:
        item: per-stream state to rewrite in place.
        beam_index: beam slot to keep.

    Raises:
        ValueError: if ``beam_index`` is outside ``[0, beam_size)``.
        AssertionError: if the per-stream predictor state does not split into
            ``beam_size`` rows.
    """
    beam_size = self.beam_size
    if not 0 <= beam_index < beam_size:
        raise ValueError(f"beam_index must be in [0, {beam_size}), got {beam_index}")

    with torch.inference_mode():
        per_row = self.decoder.batch_split_states(item.predictor_state)
        if len(per_row) != beam_size:
            raise AssertionError(
                f"Expected per-stream predictor state with batch dim {beam_size}, got {len(per_row)}"
            )
        item.predictor_state = self.decoder.batch_unsplit_states([per_row[beam_index]] * beam_size)

        item.predictor_output = (
            item.predictor_output[beam_index : beam_index + 1]
            .expand(beam_size, *item.predictor_output.shape[1:])
            .contiguous()
        )

        # index_select with a constant index replicates the chosen slot beam_size times
        idx = torch.full([beam_size], fill_value=beam_index, dtype=torch.long, device=item.label.device)
        item.label = item.label.index_select(0, idx).contiguous()
        if item.score is not None:
            item.score = item.score.index_select(0, idx).contiguous()
            item.score[1:].fill_(INACTIVE_SCORE)
        if item.transcript_hash is not None:
            item.transcript_hash = item.transcript_hash.index_select(0, idx).contiguous()
        if item.current_lengths_nb is not None:
            item.current_lengths_nb = item.current_lengths_nb.index_select(0, idx).contiguous()
        if item.last_timestamp_lasts is not None:
            item.last_timestamp_lasts = item.last_timestamp_lasts.index_select(0, idx).contiguous()
        if item.transcript_prefix_hash is not None:
            item.transcript_prefix_hash = item.transcript_prefix_hash.index_select(0, idx).contiguous()

        for fi, fs in enumerate(item.fusion_state_list):
            item.fusion_state_list[fi] = fs.index_select(0, idx).contiguous()
def keep_beam_(self, beam_indices: torch.Tensor) -> None:
    """Collapse each row to one beam, replicated across all slots (in-place).

    Does nothing when ``beam_size <= 1``. Otherwise the hypotheses are flattened with a
    permutation that points every slot of a row at ``beam_indices[row]``, and all
    slots except slot 0 are marked inactive.
    """
    if self.beam_size <= 1:
        return
    chosen = beam_indices.to(dtype=torch.long, device=self.device)
    permutation = chosen.unsqueeze(-1).expand(self.batch_size, self.beam_size).contiguous()
    self._flatten_with_permutation_(permutation)
    self.scores[:, 1:].fill_(INACTIVE_SCORE)


def to_hyps_list(self, score_norm: bool = True) -> list[Hypothesis]:
    """
    Converts the batched beam search results into a list of single best hypotheses for each batch.
    Args:
        score_norm (bool): If True, normalize the scores before sorting. Defaults to True.
    Returns:
        list[Hypothesis]: A list where each element corresponds to a batch and contains best hypothesis.
    """
    scores, transcripts, timestamps, durations, _ = self._export(sort=True, score_norm=score_norm)
    best_hyps = []
    for batch_idx in range(self.batch_size):
        # after the sorted export, beam slot 0 is the top-scoring hypothesis
        best_hyps.append(self._hypothesis_from_flat(batch_idx, 0, scores, transcripts, timestamps, durations))
    return best_hyps


def to_nbest_hyps_list(self, score_norm: bool = True) -> list[NBestHypotheses]:
    """
    Converts the batched beam search results into a list of N-best hypotheses for each batch.
    Args:
        score_norm (bool): If True, normalize the scores before sorting. Defaults to True.
    Returns:
        list[NBestHypotheses]: A list where each element corresponds to a batch and contains
            N-best hypotheses. Beams whose score is ``<= INACTIVE_SCORE`` are left out.
    """
    scores, transcripts, timestamps, durations, _ = self._export(sort=True, score_norm=score_norm)
    return [
        NBestHypotheses(
            [
                self._hypothesis_from_flat(batch_idx, beam_idx, scores, transcripts, timestamps, durations)
                for beam_idx in range(self.beam_size)
                if not scores[batch_idx][beam_idx] <= INACTIVE_SCORE
            ]
        )
        for batch_idx in range(self.batch_size)
    ]


def _export(
    self, sort: bool = True, score_norm: bool = True
) -> tuple[list[list[float]], torch.Tensor, torch.Tensor, Optional[torch.Tensor], torch.Tensor]:
    """
    Flatten the prefix tree and return per-(batch, beam) views.

    Args:
        sort: if True, flatten via :meth:`flatten_sort_`; otherwise via :meth:`flatten_`.
        score_norm: passed to :meth:`flatten_sort_` when ``sort=True``.

    Returns:
        (scores, transcripts, timestamps, durations, root_ptrs). The first four
        are inputs for :meth:`_hypothesis_from_flat`; ``durations`` is ``None``
        unless the model type is TDT. ``root_ptrs`` is whatever the chosen flatten
        call returned.
    """
    root_ptrs = self.flatten_sort_(score_norm) if sort else self.flatten_()
    scores = self.scores.tolist()
    # trim the trailing positions that no hypothesis reaches
    max_idx = self.current_lengths_wb.max() - 1
    transcripts = self.transcript_wb[..., : max_idx + 1]
    timestamps = self.timestamps[..., : max_idx + 1]
    if self.model_type == ASRModelTypeEnum.TDT:
        durations = self.token_durations[..., : max_idx + 1]
    else:
        durations = None
    return scores, transcripts, timestamps, durations, root_ptrs


def _hypothesis_from_flat(
    self,
    batch_idx: int,
    beam_idx: int,
    scores: list[list[float]],
    transcripts: torch.Tensor,
    timestamps: torch.Tensor,
    durations: Optional[torch.Tensor],
) -> Hypothesis:
    """Build one ``Hypothesis`` from already-flattened per-(batch, beam) views.

    When ``durations`` is given, ``timestamp`` holds ``end_time - duration`` per token
    and ``token_duration`` the durations; otherwise ``timestamp`` holds the stored
    timestamps and ``token_duration`` is ``None``.
    """
    transcript = transcripts[batch_idx][beam_idx]
    mask = self._create_transcripts_mask(transcript)
    end_times = timestamps[batch_idx][beam_idx][mask]
    if durations is None:
        timestamp = end_times.cpu().detach().numpy()
        token_duration = None
    else:
        # TDT: report per-token start times and durations.
        masked_durations = durations[batch_idx][beam_idx][mask]
        timestamp = (end_times - masked_durations).cpu().detach().numpy()
        token_duration = masked_durations.cpu().detach().numpy()
    return Hypothesis(
        score=scores[batch_idx][beam_idx],
        y_sequence=transcript[mask].cpu().detach().numpy(),
        timestamp=timestamp,
        token_duration=token_duration,
        alignments=None,
        dec_state=None,
    )


def flatten_sort_(self, score_norm: bool = True) -> torch.Tensor:
    """
    Sorts and flattens the tree structure of hypotheses in a batched beam search decoding process.
    Args:
        score_norm (bool): If True, sort by score divided by (non-blank length + 1);
            otherwise sort by the raw score.

    Returns:
        The value returned by ``_flatten_with_permutation_`` for the descending-score
        permutation (``root_ptrs`` for the sorted ordering).
    """
    if score_norm:
        # add one for consistency with non-batched decodings, that use SOS.
        normalized_scores = self.scores / (self.current_lengths_nb.to(self.scores.dtype) + 1)
    else:
        normalized_scores = self.scores
    _, indices = torch.sort(normalized_scores, dim=-1, descending=True)
    return self._flatten_with_permutation_(indices)


def export_batched_beam_hyps_to_cpu_lists(
    bbh: BatchedBeamHyps,
) -> tuple[list[list[list[int]]], list[list[list[int]]], list[list[int]]]:
    """Export per-beam tokens/timestamps and the beam descent map to CPU lists.

    Uses the unsorted export (``sort=False``).

    Returns:
        ``(tokens, timestamps, root_ptrs)``: the first two are nested
        ``[batch][beam][position]`` lists restricted to masked-in positions; the last
        is ``root_ptrs`` converted to nested lists.
    """
    _, transcripts, timestamps, _, root_ptrs = bbh._export(sort=False)
    root_ptrs_list = root_ptrs.detach().cpu().tolist()
    transcripts_cpu = transcripts.detach().cpu()
    timestamps_cpu = timestamps.detach().cpu()

    tokens: list[list[list[int]]] = []
    timestamps_out: list[list[list[int]]] = []
    for b in range(bbh.batch_size):
        # one validity mask per beam, shared by tokens and timestamps
        beam_masks = [bbh._create_transcripts_mask(transcripts_cpu[b, k]) for k in range(bbh.beam_size)]
        tokens.append([transcripts_cpu[b, k][mask].tolist() for k, mask in enumerate(beam_masks)])
        timestamps_out.append([timestamps_cpu[b, k][mask].tolist() for k, mask in enumerate(beam_masks)])
    return tokens, timestamps_out, root_ptrs_list
def append_no_checks_(self, data: torch.Tensor, lengths: torch.Tensor | None = N indices = torch.arange(other_len, device=self.device) shifted_indices = self.lengths[:, None] + indices[None, :] # add trailing len(dim_shape) axes to shifted_indices - shifted_indices = shifted_indices[..., *[None for _ in range(len(self.dim_shape))]] + # NB: ``a[..., *unpack]`` subscript-unpacking is Python 3.11+; loop ``unsqueeze`` for 3.10. + for _ in range(len(self.dim_shape)): + shifted_indices = shifted_indices.unsqueeze(-1) self.data.scatter_(dim=1, index=shifted_indices.expand([-1, -1] + self.dim_shape), src=data) if lengths is None: self.lengths += other_len