add streaming beam search for cache aware models to NeMo inference #15768
Open
lilithgrigoryan wants to merge 30 commits into main from lgrigoryan/streaming-beam-search-niva-cache-aware
Commits
3a6f29e merge (lilithgrigoryan)
03937b0 add n-chunk reseting working (lilithgrigoryan)
e889eea saving config (lilithgrigoryan)
e51ee3c add n-chunk reseting working (lilithgrigoryan)
3765acc add eou resetting (lilithgrigoryan)
0e11a4f clean up debug prints (lilithgrigoryan)
913000a Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan… (lilithgrigoryan)
6dc1423 typecast fix (lilithgrigoryan)
0a69dee clean up (lilithgrigoryan)
e051b12 isort and black + clean up (lilithgrigoryan)
4dce1e6 clean up (lilithgrigoryan)
893f656 clean up (lilithgrigoryan)
664a246 isort and black (lilithgrigoryan)
b342a16 add per-stream biasing (lilithgrigoryan)
b9a31a4 clean up (lilithgrigoryan)
0ed3d92 clean up (lilithgrigoryan)
56816df refactor, separate state (lilithgrigoryan)
f6da7a5 isort and black (lilithgrigoryan)
22f5f7d clean up (lilithgrigoryan)
a299ee4 restore docstring (lilithgrigoryan)
957084f move malsd stream step to model wrapper (lilithgrigoryan)
49eb9fe clean up (lilithgrigoryan)
7be3088 refactor per-stream biasing, add utils (lilithgrigoryan)
629de90 add malsd-only warning (lilithgrigoryan)
c59bc00 isort and black (lilithgrigoryan)
36e1702 Merge branch 'main' of https://github.com/NVIDIA/NeMo into lgrigoryan… (lilithgrigoryan)
3656938 restore releasing biaing models (lilithgrigoryan)
7840a22 minor clean up (lilithgrigoryan)
fadef4e clean up (lilithgrigoryan)
b2b7116 isort and black (lilithgrigoryan)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,9 +18,13 @@ | |
| from nemo.collections.asr.inference.model_wrappers.cache_aware_asr_inference_wrapper import ( | ||
| CacheAwareASRInferenceWrapper, | ||
| ) | ||
| from nemo.collections.asr.inference.streaming.state.cache_aware_rnnt_state import CacheAwareRNNTBeamStreamingState | ||
| from nemo.collections.asr.inference.utils.context_manager import CacheAwareContext | ||
| from nemo.collections.asr.inference.utils.per_stream_biasing import multi_biasing_ids_tensor_from_states | ||
| from nemo.collections.asr.models import EncDecHybridRNNTCTCModel, EncDecRNNTModel | ||
| from nemo.collections.asr.parts.mixins.streaming import StreamingEncoder | ||
| from nemo.collections.asr.parts.submodules.rnnt_malsd_batched_computer import ModifiedALSDBatchedRNNTComputer | ||
| from nemo.collections.asr.parts.utils.batched_beam_decoding_utils import export_batched_beam_hyps_to_cpu_lists | ||
| from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis | ||
|
|
||
|
|
||
|
|
@@ -76,32 +80,19 @@ def get_vocabulary(self) -> list[str]: | |
| """ | ||
| return self.asr_model.joint.vocabulary | ||
|
|
||
| def execute_step( | ||
| def encoder_step( | ||
| self, | ||
| processed_signal: Tensor, | ||
| processed_signal_length: Tensor, | ||
| context: CacheAwareContext, | ||
| previous_hypotheses: list[Hypothesis] | None, | ||
| drop_extra_pre_encoded: int | None, | ||
| keep_all_outputs: bool, | ||
| drop_left_context: int | None = None, | ||
| valid_out_len: int | None = None, | ||
| prompt_vectors: Tensor | None = None, | ||
| ) -> tuple[list[Hypothesis], CacheAwareContext]: | ||
| ) -> tuple[Tensor, Tensor, CacheAwareContext]: | ||
| """ | ||
| Executes a single streaming step. | ||
| Args: | ||
| processed_signal: (Tensor) input signal tensor. | ||
| processed_signal_length: (Tensor) input signal length tensor. | ||
| context: (CacheAwareContext) context object. | ||
| previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. | ||
| drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. | ||
| keep_all_outputs: (bool) whether to keep all outputs or not. | ||
| drop_left_context: (int | None) number of left context frames to drop. | ||
| valid_out_len: (int | None) number of valid output frames. | ||
| prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts]. | ||
| Returns: | ||
| (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. | ||
| Run the cache-aware encoder for one streaming chunk, returning the (trimmed) | ||
|
Collaborator
For consistency, please bring back the argument descriptions in the docstring.
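One possible shape for the restored docstring is sketched below. It assumes the argument descriptions carry over unchanged from `execute_step`, minus `previous_hypotheses` and `prompt_vectors`, which the `encoder_step` call sites in this diff do not pass. The function is written as a free-standing stub (no `self`) and the tensor and context types are replaced by `Any` so the snippet runs without torch or NeMo; both choices are illustrative assumptions, not the PR's code.

```python
from __future__ import annotations

from typing import Any


def encoder_step(
    processed_signal: Any,  # Tensor in the real wrapper
    processed_signal_length: Any,  # Tensor in the real wrapper
    context: Any,  # CacheAwareContext in the real wrapper
    drop_extra_pre_encoded: int | None,
    keep_all_outputs: bool,
    drop_left_context: int | None = None,
    valid_out_len: int | None = None,
) -> tuple[Any, Any, Any]:
    """
    Run the cache-aware encoder for one streaming chunk, returning the (trimmed)
    encoder output and updated streaming context. Decoder is NOT invoked.
    Args:
        processed_signal: (Tensor) input signal tensor.
        processed_signal_length: (Tensor) input signal length tensor.
        context: (CacheAwareContext) context object.
        drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop.
        keep_all_outputs: (bool) whether to keep all outputs or not.
        drop_left_context: (int | None) number of left context frames to drop.
        valid_out_len: (int | None) number of valid output frames.
    Returns:
        (tuple[Tensor, Tensor, CacheAwareContext]) encoder output, encoder output lengths, and new context.
    """
    # Stub only: the body lives in the wrapper class and is not reproduced here.
    raise NotImplementedError("signature and docstring sketch only")
```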
||
| encoder output and updated streaming context. Decoder is NOT invoked. | ||
| """ | ||
| ( | ||
| encoded, | ||
|
|
@@ -134,11 +125,116 @@ def execute_step( | |
| encoded = encoded[:, :, :valid_out_len] | ||
| encoded_len = torch.ones_like(encoded_len) * valid_out_len | ||
|
|
||
| return encoded, encoded_len, new_context | ||
|
|
||
| def execute_step( | ||
| self, | ||
| processed_signal: Tensor, | ||
| processed_signal_length: Tensor, | ||
| context: CacheAwareContext, | ||
| previous_hypotheses: list[Hypothesis] | None, | ||
| drop_extra_pre_encoded: int | None, | ||
| keep_all_outputs: bool, | ||
| drop_left_context: int | None = None, | ||
| valid_out_len: int | None = None, | ||
| prompt_vectors: Tensor | None = None, | ||
| ) -> tuple[list[Hypothesis], CacheAwareContext]: | ||
| """ | ||
| Executes a single streaming step. | ||
| Args: | ||
| processed_signal: (Tensor) input signal tensor. | ||
| processed_signal_length: (Tensor) input signal length tensor. | ||
| context: (CacheAwareContext) context object. | ||
| previous_hypotheses: (list[Hypothesis] | None) list of previous hypotheses for RNNT decoding. | ||
| drop_extra_pre_encoded: (int | None) number of extra pre-encoded frames to drop. | ||
| keep_all_outputs: (bool) whether to keep all outputs or not. | ||
| drop_left_context: (int | None) number of left context frames to drop. | ||
| valid_out_len: (int | None) number of valid output frames. | ||
| prompt_vectors: (Tensor | None) Optional prompt vectors of shape [B, num_prompts]. | ||
| Returns: | ||
| (tuple[list[Hypothesis], CacheAwareContext]) best hypothesis and new context. | ||
| """ | ||
| encoded, encoded_len, new_context = self.encoder_step( | ||
| processed_signal=processed_signal, | ||
| processed_signal_length=processed_signal_length, | ||
| context=context, | ||
| drop_extra_pre_encoded=drop_extra_pre_encoded, | ||
| keep_all_outputs=keep_all_outputs, | ||
| drop_left_context=drop_left_context, | ||
| valid_out_len=valid_out_len, | ||
| ) | ||
|
|
||
| best_hyp = self.asr_model.decoding.rnnt_decoder_predictions_tensor( | ||
| encoded, encoded_len, return_hypotheses=True, partial_hypotheses=previous_hypotheses | ||
| ) | ||
| return best_hyp, new_context | ||
|
|
||
| def malsd_stream_step( | ||
| self, | ||
| malsd_computer: ModifiedALSDBatchedRNNTComputer, | ||
| states: list[CacheAwareRNNTBeamStreamingState], | ||
| processed_signal: Tensor, | ||
| processed_signal_length: Tensor, | ||
| context: CacheAwareContext, | ||
| drop_extra_pre_encoded: int | None, | ||
| keep_all_outputs: bool, | ||
| drop_left_context: int | None = None, | ||
| valid_out_len: int | None = None, | ||
| ) -> tuple[list[Hypothesis], CacheAwareContext]: | ||
| """Cache-aware MALSD encode/decode step for one chunk.""" | ||
| if processed_signal.device != self.device: | ||
| processed_signal = processed_signal.to(self.device) | ||
|
|
||
| if processed_signal_length.device != self.device: | ||
| processed_signal_length = processed_signal_length.to(self.device) | ||
|
|
||
| carries = [state.hyp_decoding_state for state in states] | ||
| if all(c is None for c in carries): | ||
| batched_state = None | ||
| else: | ||
| batched_state = malsd_computer.merge_to_batched_state(carries) | ||
|
|
||
| multi_biasing_ids = multi_biasing_ids_tensor_from_states( | ||
| states, | ||
| self.device, | ||
| per_stream_biasing_enabled=malsd_computer.per_stream_biasing_enabled, | ||
| ) | ||
|
|
||
| with ( | ||
| torch.amp.autocast(device_type=self.device_str, dtype=self.compute_dtype, enabled=self.use_amp), | ||
| torch.inference_mode(), | ||
| torch.no_grad(), | ||
| ): | ||
| processed_signal = processed_signal.to(self.cast_dtype) | ||
| encoded, encoded_len, new_context = self.encoder_step( | ||
| processed_signal=processed_signal, | ||
| processed_signal_length=processed_signal_length, | ||
| context=context, | ||
| drop_extra_pre_encoded=drop_extra_pre_encoded, | ||
| keep_all_outputs=keep_all_outputs, | ||
| drop_left_context=drop_left_context, | ||
| valid_out_len=valid_out_len, | ||
| ) | ||
| encs_dim_last = encoded.transpose(1, 2).contiguous() | ||
|
|
||
| best_batched_hyps, batched_state = malsd_computer( | ||
| encs_dim_last, encoded_len, batched_state, multi_biasing_ids=multi_biasing_ids | ||
| ) | ||
|
|
||
| chunk_tokens, chunk_timestamps, root_ptrs = export_batched_beam_hyps_to_cpu_lists(best_batched_hyps) | ||
| beam_indices = best_batched_hyps.scores.argmax(dim=-1).detach().cpu().tolist() | ||
| scores_cpu = best_batched_hyps.scores.detach().cpu() | ||
|
|
||
| carry_items = malsd_computer.split_batched_state(batched_state) | ||
| for state, ct, cts, rp, best_hyp_idx, carry in zip( | ||
| states, chunk_tokens, chunk_timestamps, root_ptrs, beam_indices, carry_items | ||
| ): | ||
| state.append_chunk_beam_(ct, cts, rp, best_batched_hyps.beam_size, best_hyp_idx) | ||
| state.hyp_decoding_state = carry | ||
|
|
||
| hyps = [state.get_hypothesis(float(scores_cpu[b, beam_indices[b]].item())) for b, state in enumerate(states)] | ||
| return hyps, new_context | ||
|
|
||
| def stream_step( | ||
| self, | ||
| processed_signal: Tensor, | ||
|
|
||
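To make the per-stream selection at the end of `malsd_stream_step` easier to follow, here is a minimal pure-Python sketch of the same indexing: take the arg-max over the beam dimension of a `[batch, beam]` score matrix, then read each stream's score at its winning beam. The helper name `pick_best_beams` and the nested-list input are hypothetical stand-ins for `best_batched_hyps.scores.argmax(dim=-1)` and `scores_cpu[b, beam_indices[b]]`; no torch or NeMo call is reproduced, and tie-breaking here (first maximum wins) is a property of this sketch only.

```python
from __future__ import annotations


def pick_best_beams(scores: list[list[float]]) -> tuple[list[int], list[float]]:
    """Return the best beam index and its score for every stream.

    scores[b][k] is the score of beam k for stream b, i.e. a [batch, beam] matrix.
    """
    # Arg-max over the beam dimension, one index per stream.
    beam_indices = [max(range(len(row)), key=row.__getitem__) for row in scores]
    # Look up each stream's score at its winning beam.
    best_scores = [row[idx] for row, idx in zip(scores, beam_indices)]
    return beam_indices, best_scores


if __name__ == "__main__":
    # Two streams, three beams each.
    indices, best = pick_best_beams([[-1.5, -0.2, -3.0], [-0.7, -0.9, -0.1]])
    print(indices, best)  # [1, 2] [-0.2, -0.1]
```

In the diff, the resulting index is also passed to `state.append_chunk_beam_` so that each stream records which beam was best for the chunk, and the matching score is handed to `state.get_hypothesis`.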