Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3a6f29e
merge
lilithgrigoryan Jun 8, 2026
03937b0
add n-chunk reseting working
lilithgrigoryan Jun 8, 2026
e889eea
saving config
lilithgrigoryan Jun 8, 2026
e51ee3c
add n-chunk reseting working
lilithgrigoryan Jun 8, 2026
3765acc
add eou resetting
lilithgrigoryan Jun 8, 2026
0e11a4f
clean up debug prints
lilithgrigoryan Jun 8, 2026
913000a
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan Jun 16, 2026
6dc1423
typecast fix
lilithgrigoryan Jun 16, 2026
0a69dee
clean up
lilithgrigoryan Jun 16, 2026
e051b12
isort and black + clean up
lilithgrigoryan Jun 16, 2026
4dce1e6
clean up
lilithgrigoryan Jun 16, 2026
893f656
clean up
lilithgrigoryan Jun 16, 2026
664a246
isort and black
lilithgrigoryan Jun 16, 2026
b342a16
add per-stream biasing
lilithgrigoryan Jun 17, 2026
b9a31a4
clean up
lilithgrigoryan Jun 17, 2026
0ed3d92
clean up
lilithgrigoryan Jun 17, 2026
56816df
refactor, separate state
lilithgrigoryan Jun 18, 2026
f6da7a5
isort and black
lilithgrigoryan Jun 18, 2026
22f5f7d
clean up
lilithgrigoryan Jun 18, 2026
a299ee4
restore docstring
lilithgrigoryan Jun 18, 2026
957084f
move malsd stream step to model wrapper
lilithgrigoryan Jun 18, 2026
49eb9fe
clean up
lilithgrigoryan Jun 18, 2026
7be3088
refactor per-stream biasing, add utils
lilithgrigoryan Jun 18, 2026
629de90
add malsd-only warning
lilithgrigoryan Jun 18, 2026
c59bc00
isort and black
lilithgrigoryan Jun 18, 2026
36e1702
Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan…
lilithgrigoryan Jun 18, 2026
3656938
restore releasing biaing models
lilithgrigoryan Jun 18, 2026
7840a22
minor clean up
lilithgrigoryan Jun 18, 2026
fadef4e
clean up
lilithgrigoryan Jun 18, 2026
b2b7116
isort and black
lilithgrigoryan Jun 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions examples/asr/conf/asr_streaming_inference/cache_aware_rnnt.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,21 @@ asr:
source_lang: "en" # The source language of the context-biasing phrases (for aggregate tokenizer),
# used with `key_phrases_file` and `key_phrases_list`
boosting_tree_alpha: 0.0 # Weight of the boosting tree
beam:
# Cache-aware streaming supports MALSD beam search only (set asr.decoding.strategy: malsd_batch).
beam_size: 4
allow_cuda_graphs: true
# n-gram LM (off by default)
ngram_lm_model: null
Comment thread
lilithgrigoryan marked this conversation as resolved.
ngram_lm_alpha: 0.0
# phrase boosting (off by default)
boosting_tree:
model_path: null
key_phrases_file: null
key_phrases_list: null
key_phrase_items_list: null
source_lang: "en"
boosting_tree_alpha: 0.0

# ==========================================
# Inverse Text Normalization Configuration
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,13 @@
from nemo.collections.asr.inference.model_wrappers.cache_aware_asr_inference_wrapper import (
CacheAwareASRInferenceWrapper,
)
from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTBeamStreamingState
from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext
from nemo.collections.asr.inference.utils.per_stream_biasing import multi_biasing_ids_tensor_from_states
from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel
from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder
from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer
from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists
from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis


Expand Down Expand Up @@ -76,32 +80,19 @@ def get_vocabulary(self) -> list[str]:
"""
return self.asr_model.joint.vocabulary

def execute_step(
def encoder_step(
self,
processed_signal: Tensor,
processed_signal_length: Tensor,
context: CacheAwareContext,
previous_hypotheses: list[Hypothesis] | None,
drop_extra_pre_encoded: int | None,
keep_all_outputs: bool,
drop_left_context: int | None = None,
valid_out_len: int | None = None,
prompt_vectors: Tensor | None = None,
) -> tuple[list[Hypothesis], CacheAwareContext]:
) -> tuple[Tensor, Tensor, CacheAwareContext]:
"""
Executes a single streaming step.
Args:
processed_signal: (Tensor) input signal tensor.
processed_signal_length: (Tensor) input signal length tensor.
context: (CacheAwareContext) context object.
previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding.
drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop.
keep_all_outputs: (bool) whether to keep all outputs or not.
drop_left_context: (int | None) number of left context frames to drop.
valid_out_len: (int | None) number of valid output frames.
prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts].
Returns:
(tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context.
Run the cache-aware encoder for one streaming chunk, returning the (trimmed)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For consistency, please bring back the argument descriptions in the docstring.

encoder output and updated streaming context. Decoder is NOT invoked.
"""
(
encoded,
Expand Down Expand Up @@ -134,11 +125,116 @@ def execute_step(
encoded = encoded[:, :, :valid_out_len]
encoded_len = torch.ones_like(encoded_len) * valid_out_len

return encoded, encoded_len, new_context

def execute_step(
self,
processed_signal: Tensor,
processed_signal_length: Tensor,
context: CacheAwareContext,
previous_hypotheses: list[Hypothesis] | None,
drop_extra_pre_encoded: int | None,
keep_all_outputs: bool,
drop_left_context: int | None = None,
valid_out_len: int | None = None,
prompt_vectors: Tensor | None = None,
) -> tuple[list[Hypothesis], CacheAwareContext]:
"""
Executes a single streaming step.
Args:
processed_signal: (Tensor) input signal tensor.
processed_signal_length: (Tensor) input signal length tensor.
context: (CacheAwareContext) context object.
previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding.
drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop.
keep_all_outputs: (bool) whether to keep all outputs or not.
drop_left_context: (int | None) number of left context frames to drop.
valid_out_len: (int | None) number of valid output frames.
prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts].
Returns:
(tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context.
"""
encoded, encoded_len, new_context = self.encoder_step(
processed_signal=processed_signal,
processed_signal_length=processed_signal_length,
context=context,
drop_extra_pre_encoded=drop_extra_pre_encoded,
keep_all_outputs=keep_all_outputs,
drop_left_context=drop_left_context,
valid_out_len=valid_out_len,
)

best_hyp = self.asr_model.decoding.rnnt_decoder_predictions_tensor(
encoded, encoded_len, return_hypotheses=True, partial_hypotheses=previous_hypotheses
)
return best_hyp, new_context

def malsd_stream_step(
self,
malsd_computer: ModifiedALSDBatchedRNNTComputer,
states: list[CacheAwareRNNTBeamStreamingState],
processed_signal: Tensor,
processed_signal_length: Tensor,
context: CacheAwareContext,
drop_extra_pre_encoded: int | None,
keep_all_outputs: bool,
drop_left_context: int | None = None,
valid_out_len: int | None = None,
) -> tuple[list[Hypothesis], CacheAwareContext]:
"""Cache-aware MALSD encode/decode step for one chunk."""
if processed_signal.device != self.device:
processed_signal = processed_signal.to(self.device)

if processed_signal_length.device != self.device:
processed_signal_length = processed_signal_length.to(self.device)

carries = [state.hyp_decoding_state for state in states]
if all(c is None for c in carries):
batched_state = None
else:
batched_state = malsd_computer.merge_to_batched_state(carries)

multi_biasing_ids = multi_biasing_ids_tensor_from_states(
states,
self.device,
per_stream_biasing_enabled=malsd_computer.per_stream_biasing_enabled,
)

with (
torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp),
torch.inference_mode(),
torch.no_grad(),
):
processed_signal = processed_signal.to(self.cast_dtype)
encoded, encoded_len, new_context = self.encoder_step(
processed_signal=processed_signal,
processed_signal_length=processed_signal_length,
context=context,
drop_extra_pre_encoded=drop_extra_pre_encoded,
keep_all_outputs=keep_all_outputs,
drop_left_context=drop_left_context,
valid_out_len=valid_out_len,
)
encs_dim_last = encoded.transpose(1, 2).contiguous()

best_batched_hyps, batched_state = malsd_computer(
encs_dim_last, encoded_len, batched_state, multi_biasing_ids=multi_biasing_ids
)

chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps)
beam_indices = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist()
scores_cpu = best_batched_hyps.scores.detach().cpu()

carry_items = malsd_computer.split_batched_state(batched_state)
for state, ct, cts, rp, best_hyp_idx, carry in zip(
states, chunk_tokens, chunk_timestamps, root_ptrs, beam_indices, carry_items
):
state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, best_hyp_idx)
state.hyp_decoding_state = carry

hyps = [state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) for b, state in enumerate(states)]
return hyps, new_context

def stream_step(
self,
processed_signal: Tensor,
Expand Down
Loading
Loading