diff --git a/nemo_curator/stages/audio/asr/__init__.py b/nemo_curator/stages/audio/asr/__init__.py new file mode 100644 index 0000000000..ea29ce2176 --- /dev/null +++ b/nemo_curator/stages/audio/asr/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.stages.audio.asr.metadata import ASRMetadata +from nemo_curator.stages.audio.asr.normalization import TranscriptNormalizationStage, TranscriptStatsStage + +__all__ = ["ASRMetadata", "TranscriptNormalizationStage", "TranscriptStatsStage"] diff --git a/nemo_curator/stages/audio/asr/datasets/__init__.py b/nemo_curator/stages/audio/asr/datasets/__init__.py new file mode 100644 index 0000000000..0b2034acd4 --- /dev/null +++ b/nemo_curator/stages/audio/asr/datasets/__init__.py @@ -0,0 +1,21 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.stages.audio.asr.datasets.base import BaseASRDatasetHandlerStage +from nemo_curator.stages.audio.asr.datasets.huggingface import HuggingFaceASRDatasetHandler + +__all__ = [ + "BaseASRDatasetHandlerStage", + "HuggingFaceASRDatasetHandler", +] diff --git a/nemo_curator/stages/audio/asr/datasets/base.py b/nemo_curator/stages/audio/asr/datasets/base.py new file mode 100644 index 0000000000..58ff2a0c8f --- /dev/null +++ b/nemo_curator/stages/audio/asr/datasets/base.py @@ -0,0 +1,230 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Base interface for ASR dataset handler stages. + +A *dataset handler* is a fan-out source stage that takes a raw, already-downloaded +dataset directory, extracts/decodes the audio into the ASR-training format +(WAV, 16 kHz, mono, PCM16), and emits one :class:`AudioTask` per utterance. + +Concrete handlers (e.g. ``HuggingFaceASRDatasetHandler``) implement :meth:`process`, +reusing the shared helpers provided here (audio conversion, task construction, +and optional per-language/per-split manifest writing). Heavy extraction is +parallelized *inside* a single Xenna worker via ``extraction_workers`` (joblib), +so handlers run with ``xenna_workers=1`` by default. +""" + +from __future__ import annotations + +import json +import os +from abc import ABC +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from loguru import logger + +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask, _EmptyTask + +if TYPE_CHECKING: + from nemo_curator.backends.base import NodeInfo, WorkerMetadata + from nemo_curator.stages.audio.asr.metadata import ASRMetadata + + +@dataclass +class BaseASRDatasetHandlerStage(ProcessingStage[_EmptyTask, AudioTask], ABC): + """Base interface/protocol for ASR dataset handlers. + + Subclasses MUST implement :meth:`process`, which should: + 1. discover the raw units (e.g. HF arrow dirs / tar archives) under + ``raw_data_dir`` for each requested language and native split; + 2. extract/decode audio in parallel (use ``extraction_workers``) and + convert each clip to WAV/16 kHz/mono/PCM16 via :meth:`convert_audio`; + 3. assign dataset-specific ``split_type`` values in the concrete handler; + 4. optionally write per-split JSONL manifests via + :meth:`write_manifest_entry` when ``write_manifest`` is enabled; + 5. return one ``AudioTask`` per utterance via :meth:`build_audio_task`. + + Args: + raw_data_dir: Directory containing the already-downloaded raw dataset. + output_dir: Destination root for converted audio. + langs: Languages to process. + xenna_workers: Number of Xenna workers for this stage (kept at 1; the + stage parallelizes extraction internally). + extraction_workers: Internal joblib worker count for parallel extraction. + (Named separately from the framework ``num_workers()`` method.) + target_sample_rate: Output sample rate (Hz). + target_channels: Output channel count (1 = mono). + skip_untar: If True, reuse already-extracted WAV files when present + instead of re-decoding/writing them. + write_manifest: If True, write each emitted metadata record to + ``{output_dir}/{lang}/{split_type}.jsonl`` from this source stage. + Downstream writer stages can be used instead by leaving this False. + manifest_splits: Optional split names used by handlers with custom split + logic. Manifest files are opened lazily when a row is written, so + missing splits do not create empty JSONL files or audio directories. + """ + + raw_data_dir: str = "" + output_dir: str = "" + langs: list[str] = field(default_factory=list) + name: str = "asr_dataset_handler" + source_name: str = "asr_dataset" + xenna_workers: int = 1 + extraction_workers: int = 10 + target_sample_rate: int = 16000 + target_channels: int = 1 + skip_untar: bool = False + write_manifest: bool = False + manifest_splits: list[str] | None = None + audio_filepath_key: str = "audio_filepath" + text_key: str = "text" + batch_size: int = 1 + resources: Resources = field(default_factory=lambda: Resources(cpus=1.0)) + + def __post_init__(self) -> None: + super().__init__() + for attr in ("raw_data_dir", "output_dir"): + if not getattr(self, attr): + msg = f"{attr} is required for {type(self).__name__}" + raise ValueError(msg) + if not self.langs: + msg = f"langs is required for {type(self).__name__}" + raise ValueError(msg) + # Give the single Xenna worker enough CPUs for internal parallel extraction. + self.resources = Resources(cpus=float(max(self.extraction_workers, 1))) + + # ------------------------------------------------------------------ + # Framework wiring + # ------------------------------------------------------------------ + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [self.audio_filepath_key, self.text_key, "duration", "lang", "split_type"] + + def num_workers(self) -> int | None: + return self.xenna_workers + + def ray_stage_spec(self) -> dict[str, Any]: + return {RayStageSpecKeys.IS_FANOUT_STAGE: True} + + def xenna_stage_spec(self) -> dict[str, Any]: + return {"num_workers": self.xenna_workers} + + def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: + """Lazy-import heavy deps on the worker (not the driver).""" + import librosa + import numpy as np + import soundfile + + self._np = np + self._sf = soundfile + self._librosa = librosa + + def setup_on_node( + self, + _node_info: NodeInfo | None = None, + _worker_metadata: WorkerMetadata | None = None, + ) -> None: + self._manifest_handles = {} + if not self.write_manifest: + return + for lang in self.langs: + for split_type in self._output_splits(): + os.makedirs(self.audio_output_dir(lang, split_type), exist_ok=True) + self._manifest_handles[(lang, split_type)] = self._open_manifest(lang, split_type) + + def teardown(self) -> None: + for handle in getattr(self, "_manifest_handles", {}).values(): + handle.close() + self._manifest_handles = {} + + # ------------------------------------------------------------------ + # Shared helpers for subclasses + # ------------------------------------------------------------------ + def convert_audio(self, array: Any, sample_rate: int, orig_channels: int, dst_path: str) -> dict[str, Any]: # noqa: ANN401 + """Convert one clip to WAV/target-SR/mono/PCM16 and write it to ``dst_path``. + + ``array`` must already be decoded by the concrete dataset handler. Returns + a dict with ``duration``, ``orig_sample_rate`` and ``orig_num_channels``. + When ``skip_untar`` is set and ``dst_path`` already exists, the file is + probed instead of rewritten. + """ + os.makedirs(os.path.dirname(dst_path), exist_ok=True) + + if self.skip_untar and os.path.exists(dst_path): + info = self._sf.info(dst_path) + return { + "duration": float(info.frames / info.samplerate) if info.samplerate else 0.0, + "orig_sample_rate": int(info.samplerate), + "orig_num_channels": int(info.channels), + } + + arr = self._np.asarray(array, dtype=self._np.float32) + orig_sample_rate = int(sample_rate) + if orig_sample_rate != self.target_sample_rate: + arr = self._librosa.resample(arr, orig_sr=orig_sample_rate, target_sr=self.target_sample_rate) + + self._sf.write(dst_path, arr, self.target_sample_rate, subtype="PCM_16") + duration = float(len(arr) / self.target_sample_rate) if self.target_sample_rate else 0.0 + return { + "duration": duration, + "orig_sample_rate": orig_sample_rate, + "orig_num_channels": orig_channels, + } + + def build_audio_task(self, meta: ASRMetadata) -> AudioTask: + """Wrap an :class:`ASRMetadata` into an ``AudioTask``.""" + return AudioTask( + data=meta.to_dict(), + dataset_name=f"{self.source_name}_{meta.lang}_{meta.split_type}", + filepath_key=self.audio_filepath_key, + ) + + def audio_output_dir(self, lang: str, split_type: str) -> str: + """Standard per-language/per-split audio output directory.""" + return os.path.join(self.output_dir, lang, split_type, "audio") + + def _output_splits(self) -> list[str]: + """Return split names expected by a dataset handler.""" + return list(dict.fromkeys(self.manifest_splits or [])) + + def manifest_path(self, lang: str, split_type: str) -> str: + """Return the JSONL manifest path for one language/split pair.""" + return os.path.join(self.output_dir, lang, f"{split_type}.jsonl") + + def _open_manifest(self, lang: str, split_type: str) -> Any: # noqa: ANN401 + """Open a manifest handle for one language/split pair.""" + manifest_path = self.manifest_path(lang, split_type) + os.makedirs(os.path.dirname(manifest_path), exist_ok=True) + logger.info(f"[{self.name}] writing manifest -> {manifest_path}") + return open(manifest_path, "w", encoding="utf-8") + + def write_manifest_entry(self, meta: ASRMetadata) -> None: + """Write one ``ASRMetadata`` row to its split manifest when enabled.""" + if not self.write_manifest: + return + key = (meta.lang, meta.split_type) + if not hasattr(self, "_manifest_handles"): + self._manifest_handles = {} + if key not in self._manifest_handles: + os.makedirs(self.audio_output_dir(meta.lang, meta.split_type), exist_ok=True) + self._manifest_handles[key] = self._open_manifest(*key) + handle = self._manifest_handles[key] + handle.write(json.dumps(meta.to_dict(), ensure_ascii=False) + "\n") + handle.flush() diff --git a/nemo_curator/stages/audio/asr/datasets/huggingface.py b/nemo_curator/stages/audio/asr/datasets/huggingface.py new file mode 100644 index 0000000000..2f1e9c0926 --- /dev/null +++ b/nemo_curator/stages/audio/asr/datasets/huggingface.py @@ -0,0 +1,282 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic Hugging Face ASR dataset handler for saved Arrow datasets.""" + +from __future__ import annotations + +import hashlib +import os +import time +from dataclasses import dataclass, field, fields +from pathlib import Path +from typing import TYPE_CHECKING, Any, Literal + +from datasets import Audio, load_from_disk +from joblib import Parallel, delayed +from loguru import logger + +from nemo_curator.stages.audio.asr.datasets.base import BaseASRDatasetHandlerStage +from nemo_curator.stages.audio.asr.metadata import ASRMetadata + +if TYPE_CHECKING: + from nemo_curator.tasks import AudioTask, _EmptyTask + +_SOURCE_FIELD_MAPPING_BY_SOURCE = { + "indicvoices": { + "speaker_id": "speaker_id", + "gender": "gender", + "age_group": "age", + "scenario": "scenario", + "task_name": "task_name", + "state": "state", + "district": "district", + "normalized": "normalized", + }, + "kathbath": { + "fname": "fname", + "speaker_id": "speaker_id", + "gender": "gender", + }, + "shrutilipi": {}, +} +_SUPPORTED_SOURCE_NAMES = { + source_name.lower(): source_name for source_name in ("IndicVoices", "Kathbath", "Shrutilipi") +} +_ASR_METADATA_FIELD_NAMES = {metadata_field.name for metadata_field in fields(ASRMetadata)} - {"extra"} + + +@dataclass +class _RowResult: + meta: ASRMetadata | None + skip_reason: str | None = None + + +@dataclass +class HuggingFaceASRDatasetHandler(BaseASRDatasetHandlerStage): + """Extract saved Hugging Face ASR datasets into canonical ASR audio tasks. + + The handler expects datasets that were written with ``Dataset.save_to_disk`` + and contain an audio column compatible with ``datasets.Audio``. + """ + + name: str = "huggingface_asr_dataset_handler" + native_splits: list[str] = field(default_factory=lambda: ["train", "valid"]) + split_dir_pattern: str = "{lang}/{split}" + valid_split_strategy: Literal["keep", "map", "dev_test"] = "keep" + dev_fraction: float = 0.6 + hash_buckets: int = 100 + split_mapping: dict[str, str] | None = None + duration_key: str | None = "duration" + filename_key: str | None = "fname" + extra_keys: list[str] = field(default_factory=list) + + def __post_init__(self) -> None: + super().__post_init__() + if self.source_name.lower() not in _SUPPORTED_SOURCE_NAMES: + supported_sources = ", ".join(_SUPPORTED_SOURCE_NAMES.values()) + msg = ( + f"Unsupported source_name '{self.source_name}' for {type(self).__name__}. " + f"Supported source names: {supported_sources}" + ) + raise ValueError(msg) + + def _output_splits(self) -> list[str]: + if self.manifest_splits: + return list(dict.fromkeys(self.manifest_splits)) + splits = [] + for native_split in self.native_splits: + if self._is_validation_split(native_split) and self.valid_split_strategy == "dev_test": + splits.extend(["dev", "test"]) + else: + splits.append(self.assign_split(native_split)) + return list(dict.fromkeys(splits)) + + def assign_split(self, native_split: str, utt_id: str | None = None) -> str: + """Map native dataset split names to emitted split names.""" + if self.split_mapping and native_split in self.split_mapping: + return self.split_mapping[native_split] + if self._is_validation_split(native_split) and self.valid_split_strategy == "dev_test": + if utt_id is None: + msg = "utt_id is required when valid_split_strategy='dev_test'" + raise ValueError(msg) + bucket = int(hashlib.md5(utt_id.encode("utf-8")).hexdigest(), 16) % self.hash_buckets # noqa: S324 + return "dev" if bucket < self.dev_fraction * self.hash_buckets else "test" + return native_split + + @staticmethod + def _is_validation_split(native_split: str) -> bool: + return native_split.lower() in {"valid", "val", "validation"} + + def _source_field_mapping(self) -> dict[str, str]: + if self.extra_keys: + return {key: key for key in self.extra_keys} + return _SOURCE_FIELD_MAPPING_BY_SOURCE[self.source_name.lower()] + + def _metadata_fields_from_row(self, row: dict) -> tuple[dict[str, object], dict[str, object]]: + metadata_fields = {} + extra = {} + for source_key, output_key in self._source_field_mapping().items(): + if source_key not in row: + continue + if output_key in _ASR_METADATA_FIELD_NAMES: + metadata_fields[output_key] = row[source_key] + else: + extra[output_key] = row[source_key] + return metadata_fields, extra + + def coerce_audio(self, audio_obj: Any) -> tuple[Any, int, int]: # noqa: ANN401 + """Coerce decoded Hugging Face audio into mono ``(array, sample_rate, channels)``.""" + np = self._np + if hasattr(audio_obj, "get_all_samples"): + samples = audio_obj.get_all_samples() + arr = samples.data.detach().cpu().numpy() + sample_rate = int(samples.sample_rate) + elif isinstance(audio_obj, dict) and "array" in audio_obj: + arr = np.asarray(audio_obj["array"]) + sample_rate = int(audio_obj["sampling_rate"]) + else: + msg = f"Unsupported Hugging Face audio object type: {type(audio_obj)!r}" + raise TypeError(msg) + + arr = np.asarray(arr, dtype=np.float32) + if arr.ndim == 1: + return arr, sample_rate, 1 + if arr.shape[0] <= arr.shape[1]: + return arr.mean(axis=0), sample_rate, int(arr.shape[0]) + return arr.mean(axis=1), sample_rate, int(arr.shape[1]) + + def _audio_filename(self, row: dict, utt_id: str) -> str: + if self.filename_key and row.get(self.filename_key): + return f"{Path(str(row[self.filename_key])).stem}.wav" + return f"{utt_id}.wav" + + def _process_row(self, row: dict, index: int, lang: str, native_split: str) -> _RowResult: + if self.text_key not in row or row.get(self.text_key) is None: + return _RowResult(meta=None, skip_reason="missing_text") + if self.audio_filepath_key not in row or row.get(self.audio_filepath_key) is None: + logger.debug(f"[{self.name}] skipping {lang}/{native_split} row {index}: no audio") + return _RowResult(meta=None, skip_reason="missing_audio") + + utt_id = f"{lang}_{native_split}_{index}" + split_type = self.assign_split(native_split, utt_id) + dst_path = os.path.join(self.audio_output_dir(lang, split_type), self._audio_filename(row, utt_id)) + + try: + array, sample_rate, orig_channels = self.coerce_audio(row[self.audio_filepath_key]) + audio_info = self.convert_audio(array, sample_rate, orig_channels, dst_path) + except Exception as e: # noqa: BLE001 + logger.warning(f"[{self.name}] failed to convert {utt_id}: {e}") + return _RowResult(meta=None, skip_reason="audio_load") + + metadata_fields, extra = self._metadata_fields_from_row(row) + return _RowResult( + meta=ASRMetadata( + audio_filepath=dst_path, + text=str(row[self.text_key]), + duration=audio_info["duration"], + lang=lang, + split_type=split_type, + source=self.source_name, + sample_rate=self.target_sample_rate, + num_channels=self.target_channels, + orig_sample_rate=audio_info["orig_sample_rate"], + orig_num_channels=audio_info["orig_num_channels"], + **metadata_fields, + extra=extra, + ) + ) + + def _extract_split(self, lang: str, native_split: str) -> tuple[list[ASRMetadata], dict[str, int]]: + stats = { + "input_rows": 0, + "emitted_tasks": 0, + "skipped_missing_text": 0, + "skipped_missing_audio": 0, + "skipped_audio_load": 0, + } + data_path = os.path.join(self.raw_data_dir, self.split_dir_pattern.format(lang=lang, split=native_split)) + if not os.path.isdir(data_path): + logger.warning(f"[{self.name}] missing dataset dir, skipping: {data_path}") + return [], stats + + logger.info(f"[{self.name}] loading {data_path}") + dataset = load_from_disk(data_path) + dataset = dataset.cast_column(self.audio_filepath_key, Audio(decode=True)) + + def load_and_process(index: int) -> _RowResult: + try: + row = dataset[index] + except Exception as e: # noqa: BLE001 + logger.warning(f"[{self.name}] failed to load row {lang}/{native_split}/{index}: {e}") + return _RowResult(meta=None, skip_reason="audio_load") + return self._process_row(row, index, lang, native_split) + + start = time.perf_counter() + results = Parallel(n_jobs=self.extraction_workers, backend="threading")( + delayed(load_and_process)(i) for i in range(len(dataset)) + ) + metas = [result.meta for result in results if result.meta is not None] + for result in results: + if result.skip_reason: + stats[f"skipped_{result.skip_reason}"] += 1 + stats["input_rows"] = len(results) + stats["emitted_tasks"] = len(metas) + logger.info( + f"[{self.name}] {lang}/{native_split}: extracted {len(metas)}/{len(results)} " + f"(missing_text={stats['skipped_missing_text']}, missing_audio={stats['skipped_missing_audio']}, " + f"audio_load_failed={stats['skipped_audio_load']}) " + f"in {time.perf_counter() - start:.1f}s" + ) + return metas, stats + + def process(self, _: _EmptyTask) -> list[AudioTask]: + start = time.perf_counter() + all_tasks: list[AudioTask] = [] + total_stats = { + "input_rows": 0, + "emitted_tasks": 0, + "skipped_missing_text": 0, + "skipped_missing_audio": 0, + "skipped_audio_load": 0, + } + duration_by_split = dict.fromkeys(["train", "dev", "test", *self._output_splits()], 0.0) + for lang in self.langs: + for native_split in self.native_splits: + metas, stats = self._extract_split(lang, native_split) + for key, value in stats.items(): + total_stats[key] += value + for meta in metas: + duration_by_split[meta.split_type] = duration_by_split.get(meta.split_type, 0.0) + meta.duration + self.write_manifest_entry(meta) + # multiple languages can be processed in one go if we are not storing tasks in memory. + if not self.write_manifest: + all_tasks.extend(self.build_audio_task(meta) for meta in metas) + total_stats["emitted_tasks"] = len(all_tasks) + for split_type, duration_seconds in duration_by_split.items(): + total_stats[f"duration_{split_type}_seconds"] = duration_seconds + total_stats[f"duration_{split_type}_hours"] = duration_seconds / 3600 + total_stats["process_time"] = time.perf_counter() - start + self._log_metrics(total_stats) + duration_summary = ", ".join( + f"{split_type}={duration_by_split.get(split_type, 0.0) / 3600:.2f}h" + for split_type in ["train", "dev", "test"] + ) + logger.info( + f"[{self.name}] emitted {len(all_tasks)} AudioTasks " + f"(input_rows={total_stats['input_rows']}, skipped_missing_text={total_stats['skipped_missing_text']}, " + f"skipped_missing_audio={total_stats['skipped_missing_audio']}, " + f"skipped_audio_load={total_stats['skipped_audio_load']}, duration_by_split_hours=({duration_summary}))" + ) + return all_tasks diff --git a/nemo_curator/stages/audio/asr/io/__init__.py b/nemo_curator/stages/audio/asr/io/__init__.py new file mode 100644 index 0000000000..476b70c416 --- /dev/null +++ b/nemo_curator/stages/audio/asr/io/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.stages.audio.asr.io.split_manifest_writer import SplitAwareManifestWriter +from nemo_curator.stages.audio.asr.io.tarred_dataset_writer import TarredAudioDatasetWriterStage + +__all__ = ["SplitAwareManifestWriter", "TarredAudioDatasetWriterStage"] diff --git a/nemo_curator/stages/audio/asr/io/convert_to_tarred_audio_dataset.py b/nemo_curator/stages/audio/asr/io/convert_to_tarred_audio_dataset.py new file mode 100644 index 0000000000..bece9e6b80 --- /dev/null +++ b/nemo_curator/stages/audio/asr/io/convert_to_tarred_audio_dataset.py @@ -0,0 +1,1029 @@ +# ruff: noqa +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +# This script converts an existing audio dataset with a manifest to +# a tarred and sharded audio dataset that can be read by the +# TarredAudioToTextDataLayer. + +# Please make sure your audio_filepath DOES NOT CONTAIN '-sub'! +# Because we will use it to handle files which have duplicate filenames but with different offsets +# (see function create_shard for details) + + +# Bucketing can help to improve the training speed. You may use --buckets_num to specify the number of buckets. +# It creates multiple tarred datasets, one per bucket, based on the audio durations. +# The range of [min_duration, max_duration) is split into equal sized buckets. +# Recommend to use --sort_in_shards to speedup the training by reducing the paddings in the batches +# More info on how to use bucketing feature: https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/datasets.html + +# If valid NVIDIA DALI version is installed, will also generate the corresponding DALI index files that need to be +# supplied to the config in order to utilize webdataset for efficient large dataset handling. +# NOTE: DALI + Webdataset is NOT compatible with Bucketing support ! + +# Usage: +1) Creating a new tarfile dataset + +python convert_to_tarred_audio_dataset.py \ + --manifest_path= \ + --target_dir= \ + --num_shards= \ + --max_duration= \ + --min_duration= \ + --shuffle --shuffle_seed=1 \ + --sort_in_shards \ + --force_codec=flac \ + --workers=-1 + + +2) Concatenating more tarfiles to a pre-existing tarred dataset + +python convert_to_tarred_audio_dataset.py \ + --manifest_path= \ + --metadata_path= \ + --target_dir= \ + --max_duration= \ + --min_duration= \ + --shuffle --shuffle_seed=1 \ + --sort_in_shards \ + --workers=-1 \ + --concat_manifest_paths + + +3) Writing an empty metadata file + +python convert_to_tarred_audio_dataset.py \ + --target_dir= \ + # any other optional argument + --num_shards=8 \ + --max_duration=16.7 \ + --min_duration=0.01 \ + --shuffle \ + --workers=-1 \ + --sort_in_shards \ + --shuffle_seed=1 \ + --write_metadata + +""" + +import argparse +import copy +import json +import os +import random +import tarfile +from collections import defaultdict +from dataclasses import dataclass, field +from datetime import datetime +from io import BytesIO +from typing import Any, List, Optional, Union + +import soundfile +from joblib import Parallel, delayed +from omegaconf import DictConfig, OmegaConf, open_dict +from tqdm import tqdm + +try: + import create_dali_tarred_dataset_index as dali_index + + DALI_INDEX_SCRIPT_AVAILABLE = True +except (ImportError, ModuleNotFoundError, FileNotFoundError): + DALI_INDEX_SCRIPT_AVAILABLE = False + + +@dataclass +class ASRTarredDatasetConfig: + num_shards: int = -1 + shuffle: bool = False + max_duration: Optional[float] = None + min_duration: Optional[float] = None + shuffle_seed: Optional[int] = None + sort_in_shards: bool = True + slice_with_offset: bool = True + shard_manifests: bool = True + keep_files_together: bool = False + force_codec: Optional[str] = None + use_lhotse: bool = False + use_bucketing: bool = False + num_buckets: Optional[int] = None + bucket_duration_bins: Optional[list[float]] = None + + +@dataclass +class ASRTarredDatasetMetadata: + created_datetime: Optional[str] = None + version: int = 0 + num_samples_per_shard: Optional[int] = None + is_concatenated_manifest: bool = False + + dataset_config: Optional[ASRTarredDatasetConfig] = field(default_factory=lambda: ASRTarredDatasetConfig()) + history: Optional[List[Any]] = field(default_factory=lambda: []) + + def __post_init__(self): + self.created_datetime = self.get_current_datetime() + + def get_current_datetime(self): + return datetime.now().strftime("%m-%d-%Y %H-%M-%S") + + @classmethod + def from_config(cls, config: DictConfig): + obj = cls() + obj.__dict__.update(**config) + return obj + + @classmethod + def from_file(cls, filepath: str): + config = OmegaConf.load(filepath) + return ASRTarredDatasetMetadata.from_config(config=config) + + +class ASRTarredDatasetBuilder: + """ + Helper class that constructs a tarred dataset from scratch, or concatenates tarred datasets + together and constructs manifests for them. + """ + + def __init__(self): + self.config = None + + def configure(self, config: ASRTarredDatasetConfig): + """ + Sets the config generated from command line overrides. + + Args: + config: ASRTarredDatasetConfig dataclass object. + """ + self.config = config # type: ASRTarredDatasetConfig + + if self.config.num_shards < 0: + raise ValueError("`num_shards` must be > 0. Please fill in the metadata information correctly.") + + def create_new_dataset( + self, + manifest_path: str, + target_dir: str = "./tarred/", + num_workers: int = 0, + buckets_num: int = 1, + dynamic_buckets_num: int = 30, + only_manifests: bool = False, + dry_run: bool = False, + ): + """ + Creates a new tarred dataset from a given manifest file. + + Args: + manifest_path (str): Path to the original ASR manifest file. + target_dir (str, optional): Output directory where tarred files and manifests will be saved. Defaults to "./tarred/". + num_workers (int, optional): Number of parallel worker processes for writing tar files. Defaults to 0 (sequential processing). + buckets_num (int, optional): Number of buckets for static bucketing. Defaults to 1 (no bucketing). + dynamic_buckets_num (int, optional): Number of buckets to estimate for dynamic bucketing. Defaults to 30. + only_manifests (bool, optional): If True, performs a dry run without creating actual tar files. Defaults to False. + + Raises: + ValueError: If the configuration has not been set. + FileNotFoundError: If the manifest file does not exist. + + Output: + - Creates tar files and a tarred dataset compatible manifest file in the specified `target_dir`. + - Preserves a record of the metadata used to construct the tarred dataset in `metadata.yaml`. + - Optionally creates shard manifests if `config.shard_manifests` is enabled. + + Notes: + - The function reads the manifest, applies filtering and shuffling if specified, and creates shards of tar files. + - It generates shard manifests and the main tarred dataset manifest. + - Metadata is updated and saved based on the tarred dataset configuration. + """ + if self.config is None: + raise ValueError("Config has not been set. Please call `configure(config: ASRTarredDatasetConfig)`") + + if manifest_path is None: + raise FileNotFoundError("Manifest filepath cannot be None !") + + config = self.config # type: ASRTarredDatasetConfig + + if not os.path.exists(target_dir): + os.makedirs(target_dir) + + # Read the existing manifest + entries, total_duration, filtered_entries, filtered_duration = self._read_manifest(manifest_path, config) + + print( + f"\n Min duration: {config.min_duration} s" + f"\n Max duration: {config.max_duration} s" + f"\n Entries after filtration: {len(entries)} / {len(entries) + len(filtered_entries)}" + f"\n Duration after filtration: {total_duration:.2f} / {total_duration + filtered_duration:.2f} s" + f"\n Shards: {config.num_shards}" + f"\n Entries per shard: {len(entries) // config.num_shards}" + f"\n Remainder entries: {len(entries) % config.num_shards}" + ) + if dry_run: + return + + if len(entries) == 0: + print("No tarred dataset was created as there were 0 valid samples after filtering!") + return + if config.shuffle: + random.seed(config.shuffle_seed) + print(f"Shuffling (seed: {config.shuffle_seed})...") + if config.keep_files_together: + filename_entries = defaultdict(list) + for ent in entries: + filename_entries[ent["audio_filepath"]].append(ent) + filenames = list(filename_entries.keys()) + random.shuffle(filenames) + shuffled_entries = [] + for filename in filenames: + shuffled_entries += filename_entries[filename] + entries = shuffled_entries + else: + random.shuffle(entries) + + start_indices = [] + end_indices = [] + # Build indices + for i in range(config.num_shards): + start_idx = (len(entries) // config.num_shards) * i + end_idx = start_idx + (len(entries) // config.num_shards) + print(f"Shard {i} has entries {start_idx} ~ {end_idx}") + files = set() + for ent_id in range(start_idx, end_idx): + files.add(entries[ent_id]["audio_filepath"]) + print(f"Shard {i} contains {len(files)} files") + if i == config.num_shards - 1: + # We discard in order to have the same number of entries per shard. + print(f"Have {len(entries) - end_idx} entries left over that will be discarded.") + + start_indices.append(start_idx) + end_indices.append(end_idx) + + manifest_folder, _ = os.path.split(manifest_path) + + with Parallel(n_jobs=num_workers, verbose=config.num_shards) as parallel: + # Call parallel tarfile construction + new_entries_list = parallel( + delayed(self._create_shard)(entries[start_idx:end_idx], target_dir, i, manifest_folder, only_manifests) + for i, (start_idx, end_idx) in enumerate(zip(start_indices, end_indices)) + ) + + if config.shard_manifests: + sharded_manifests_dir = target_dir + "/sharded_manifests" + if not os.path.exists(sharded_manifests_dir): + os.makedirs(sharded_manifests_dir) + + for manifest in new_entries_list: + shard_id = manifest[0]["shard_id"] + new_manifest_shard_path = os.path.join(sharded_manifests_dir, f"manifest_{shard_id}.json") + with open(new_manifest_shard_path, "w", encoding="utf-8") as m2: + for entry in manifest: + json.dump(entry, m2, ensure_ascii=False) + m2.write("\n") + + # Flatten the list of list of entries to a list of entries + new_entries = [sample for manifest in new_entries_list for sample in manifest] + del new_entries_list + + print("Total number of entries in manifest :", len(new_entries)) + + # Write manifest + new_manifest_path = os.path.join(target_dir, "tarred_audio_manifest.json") + with open(new_manifest_path, "w", encoding="utf-8") as m2: + for entry in new_entries: + json.dump(entry, m2, ensure_ascii=False) + m2.write("\n") + + # Write metadata (default metadata for new datasets) + new_metadata_path = os.path.join(target_dir, "metadata.yaml") + metadata = ASRTarredDatasetMetadata() + + # Update metadata + metadata.dataset_config = config + metadata.num_samples_per_shard = len(new_entries) // config.num_shards + + if buckets_num <= 1: + # Estimate and update dynamic bucketing args + bucketing_kwargs = self.estimate_dynamic_bucketing_duration_bins( + new_manifest_path, num_buckets=dynamic_buckets_num + ) + for k, v in bucketing_kwargs.items(): + setattr(metadata.dataset_config, k, v) + + # Write metadata + metadata_yaml = OmegaConf.structured(metadata) + OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) + + def estimate_dynamic_bucketing_duration_bins(self, manifest_path: str, num_buckets: int = 30) -> dict: + from lhotse import CutSet + from lhotse.dataset.sampling.dynamic_bucketing import estimate_duration_buckets + + from nemo.collections.common.data.lhotse.nemo_adapters import LazyNeMoIterator + + cuts = CutSet(LazyNeMoIterator(manifest_path, metadata_only=True)) + bins = estimate_duration_buckets(cuts, num_buckets=num_buckets) + print( + f"Note: we estimated the optimal bucketing duration bins for {num_buckets} buckets. " + "You can enable dynamic bucketing by setting the following options in your training script:\n" + " use_lhotse=true\n" + " use_bucketing=true\n" + f" num_buckets={num_buckets}\n" + f" bucket_duration_bins=[{','.join(map(str, bins))}]\n" + " batch_duration=\n" + "If you'd like to use a different number of buckets, re-estimate this option manually using " + "scripts/speech_recognition/estimate_duration_bins.py" + ) + return dict( + use_lhotse=True, + use_bucketing=True, + num_buckets=num_buckets, + bucket_duration_bins=list(map(float, bins)), # np.float -> float for YAML serialization + ) + + def create_concatenated_dataset( + self, + base_manifest_path: str, + manifest_paths: List[str], + metadata: ASRTarredDatasetMetadata, + target_dir: str = "./tarred_concatenated/", + num_workers: int = 1, + only_manifests: bool = False, + dry_run: bool = False, + ): + """ + Creates a concatenated tarred dataset from the base manifest and additional manifest files. + + Args: + base_manifest_path (str): Path to the base manifest file that contains information for the original + tarred dataset (with flattened paths). + manifest_paths (List[str]): List of paths to additional manifest files that will be concatenated with + the base tarred dataset. + metadata (ASRTarredDatasetMetadata): Metadata instance containing configuration and overrides. + target_dir (str, optional): Output directory where tarred files and manifests will be saved. Defaults to "./tarred_concatenated/". + num_workers (int, optional): Number of parallel worker processes for creating tar files. Defaults to 1. + only_manifests (bool, optional): If True, performs a dry run without creating actual tar files. Defaults to False. + + Raises: + FileNotFoundError: If the base manifest file or any of the additional manifest files does not exist. + + Output: + - Creates tar files and a concatenated tarred dataset compatible manifest file in the specified `target_dir`. + - Updates metadata to reflect the concatenated dataset, including the version and historical data. + + Notes: + - The function reads the base manifest and additional manifests, filters and shuffles entries as needed, + and creates new shards of tar files. + - It generates a new concatenated dataset manifest and updates metadata with versioning and historical context. + - If `metadata` is provided, the function updates its version and includes historical data in the new metadata. + """ + if not os.path.exists(target_dir): + os.makedirs(target_dir) + + if base_manifest_path is None: + raise FileNotFoundError("Base manifest filepath cannot be None !") + + if manifest_paths is None or len(manifest_paths) == 0: + raise FileNotFoundError("List of additional manifest filepaths cannot be None !") + + config = ASRTarredDatasetConfig(**(metadata.dataset_config)) + + # Read the existing manifest (no filtering here) + base_entries, _, _, _ = self._read_manifest(base_manifest_path, config) + print(f"Read base manifest containing {len(base_entries)} samples.") + + # Precompute number of samples per shard + if metadata.num_samples_per_shard is None: + num_samples_per_shard = len(base_entries) // config.num_shards + else: + num_samples_per_shard = metadata.num_samples_per_shard + + print("Number of samples per shard :", num_samples_per_shard) + + # Compute min and max duration and update config (if no metadata passed) + print(f"Selected max duration : {config.max_duration}") + print(f"Selected min duration : {config.min_duration}") + + entries = [] + for new_manifest_idx in range(len(manifest_paths)): + new_entries, total_duration, filtered_new_entries, filtered_duration = self._read_manifest( + manifest_paths[new_manifest_idx], config + ) + + if len(filtered_new_entries) > 0: + print( + f"Filtered {len(filtered_new_entries)} files which amounts to {filtered_duration:0.2f}" + f" seconds of audio from manifest {manifest_paths[new_manifest_idx]}." + ) + print( + f"After filtering, manifest has {len(entries)} files which amounts to {total_duration} seconds of audio." + ) + + entries.extend(new_entries) + + if len(entries) == 0: + print("No tarred dataset was created as there were 0 valid samples after filtering!") + return + + if config.shuffle: + random.seed(config.shuffle_seed) + print(f"Shuffling (seed: {config.shuffle_seed})...") + random.shuffle(entries) + + # Drop last section of samples that cannot be added onto a chunk + drop_count = len(entries) % num_samples_per_shard + total_new_entries = len(entries) + entries = entries[:-drop_count] + + print( + f"Dropping {drop_count} samples from total new samples {total_new_entries} since they cannot " + f"be added into a uniformly sized chunk." + ) + + # Create shards and updated manifest entries + num_added_shards = len(entries) // num_samples_per_shard + + print(f"Number of samples in base dataset : {len(base_entries)}") + print(f"Number of samples in additional datasets : {len(entries)}") + print(f"Number of added shards : {num_added_shards}") + print(f"Remainder: {len(entries) % num_samples_per_shard}") + + if dry_run: + return + + start_indices = [] + end_indices = [] + shard_indices = [] + for i in range(num_added_shards): + start_idx = (len(entries) // num_added_shards) * i + end_idx = start_idx + (len(entries) // num_added_shards) + shard_idx = i + config.num_shards + print(f"Shard {shard_idx} has entries {start_idx + len(base_entries)} ~ {end_idx + len(base_entries)}") + + start_indices.append(start_idx) + end_indices.append(end_idx) + shard_indices.append(shard_idx) + + manifest_folder, _ = os.path.split(base_manifest_path) + + with Parallel(n_jobs=num_workers, verbose=num_added_shards) as parallel: + # Call parallel tarfile construction + new_entries_list = parallel( + delayed(self._create_shard)( + entries[start_idx:end_idx], target_dir, shard_idx, manifest_folder, only_manifests + ) + for i, (start_idx, end_idx, shard_idx) in enumerate(zip(start_indices, end_indices, shard_indices)) + ) + + if config.shard_manifests: + sharded_manifests_dir = target_dir + "/sharded_manifests" + if not os.path.exists(sharded_manifests_dir): + os.makedirs(sharded_manifests_dir) + + for manifest in new_entries_list: + shard_id = manifest[0]["shard_id"] + new_manifest_shard_path = os.path.join(sharded_manifests_dir, f"manifest_{shard_id}.json") + with open(new_manifest_shard_path, "w", encoding="utf-8") as m2: + for entry in manifest: + json.dump(entry, m2, ensure_ascii=False) + m2.write("\n") + + # Flatten the list of list of entries to a list of entries + new_entries = [sample for manifest in new_entries_list for sample in manifest] + del new_entries_list + + # Write manifest + if metadata is None: + new_version = 1 # start with `1`, where `0` indicates the base manifest + dataset + else: + new_version = metadata.version + 1 + + print("Total number of entries in manifest :", len(base_entries) + len(new_entries)) + + new_manifest_path = os.path.join(target_dir, f"tarred_audio_manifest_version_{new_version}.json") + with open(new_manifest_path, "w", encoding="utf-8") as m2: + # First write all the entries of base manifest + for entry in base_entries: + json.dump(entry, m2, ensure_ascii=False) + m2.write("\n") + + # Finally write the new entries + for entry in new_entries: + json.dump(entry, m2, ensure_ascii=False) + m2.write("\n") + + # Preserve historical metadata + base_metadata = metadata + + # Write metadata (updated metadata for concatenated datasets) + new_metadata_path = os.path.join(target_dir, f"metadata_version_{new_version}.yaml") + metadata = ASRTarredDatasetMetadata() + + # Update config + config.num_shards = config.num_shards + num_added_shards + + # Update metadata + metadata.version = new_version + metadata.dataset_config = config + metadata.num_samples_per_shard = num_samples_per_shard + metadata.is_concatenated_manifest = True + metadata.created_datetime = metadata.get_current_datetime() + + # Attach history + current_metadata = OmegaConf.structured(base_metadata.history) + metadata.history = current_metadata + + # Write metadata + metadata_yaml = OmegaConf.structured(metadata) + OmegaConf.save(metadata_yaml, new_metadata_path, resolve=True) + + def _read_manifest(self, manifest_path: Union[str, List[str]], config: ASRTarredDatasetConfig): + """Read and filters data from the manifest""" + entries = [] + total_duration = 0.0 + filtered_entries = [] + filtered_duration = 0.0 + + if isinstance(manifest_path, str): + manifest_paths = manifest_path.split(",") + else: + manifest_paths = manifest_path + + print(f"Found {len(manifest_paths)} manifest files to be processed") + for manifest_file in manifest_paths: + entries_i, total_dur_i, filtered_ent_i, filtered_dur_i = self._read_single_manifest( + str(manifest_file), config + ) + entries.extend(entries_i) + total_duration += total_dur_i + filtered_entries.extend(filtered_ent_i) + filtered_duration += filtered_dur_i + + return entries, total_duration, filtered_entries, filtered_duration + + def _read_single_manifest(self, manifest_path: str, config: ASRTarredDatasetConfig): + # Read the existing manifest + entries = [] + total_duration = 0.0 + filtered_entries = [] + filtered_duration = 0.0 + print(f"Reading manifest: {manifest_path}") + with open(manifest_path, "r", encoding="utf-8") as m: + for line in m: + line = line.strip() + if not line: + continue + entry = json.loads(line) + audio_key = "audio_filepath" if "audio_filepath" in entry else "audio_file" + if config.slice_with_offset and "offset" not in entry: + raise KeyError( + f"Manifest entry does not contain 'offset' field, but '--slice_with_offset' is enabled: {entry}" + ) + if audio_key not in entry: + raise KeyError(f"Manifest entry does not contain 'audio_filepath' or 'audio_file' key: {entry}") + audio_filepath = entry[audio_key] + if not os.path.isfile(audio_filepath) and not os.path.isabs(audio_filepath): + audio_filepath_abs = os.path.join(os.path.dirname(manifest_path), audio_filepath) + if not os.path.isfile(audio_filepath_abs): + raise FileNotFoundError(f"Could not find {audio_filepath} or {audio_filepath_abs}!") + entry[audio_key] = audio_filepath_abs + if (config.max_duration is None or entry["duration"] < config.max_duration) and ( + config.min_duration is None or entry["duration"] >= config.min_duration + ): + entries.append(entry) + total_duration += entry["duration"] + else: + filtered_entries.append(entry) + filtered_duration += entry["duration"] + + return entries, total_duration, filtered_entries, filtered_duration + + def _write_to_tar( + self, tar, audio_filepath: str, squashed_filename: str, duration: float = None, offset: float = 0 + ) -> None: + codec = self.config.force_codec + to_transcode = not (codec is None or audio_filepath.endswith(f".{codec}")) + to_crop = not (duration is None and offset == 0) + + if not to_crop and not to_transcode: + # Add existing file without transcoding, trimming, or re-encoding. + tar.add(audio_filepath, arcname=squashed_filename) + return + + # Standard processing: read, trim, and transcode the audio file + with soundfile.SoundFile(audio_filepath) as f: + sampling_rate = f.samplerate + + # Trim audio based on offset and duration. + start_sample = int(offset * sampling_rate) + num_frames = int(duration * sampling_rate) if duration else -1 + audio, sampling_rate = soundfile.read(file_path, start=start_sample, frames=num_frames) + + # Determine codec parameters. + if codec is not None: + if codec == "opus": + kwargs = {"format": "ogg", "subtype": "opus"} + else: + kwargs = {"format": codec} + else: + codec = soundfile.info(audio_filepath).format.lower() + kwargs = {"format": codec} + + # Transcode and write audio to tar. + encoded_audio = BytesIO() + soundfile.write(encoded_audio, audio, sampling_rate, closefd=False, **kwargs) + + # Generate filename with the appropriate extension. + encoded_squashed_filename = f"{squashed_filename.split('.')[0]}.{codec}" + + # Add the in-memory audio file to the tar archive. + ti = tarfile.TarInfo(encoded_squashed_filename) + encoded_audio.seek(0) + ti.size = len(encoded_audio.getvalue()) + tar.addfile(ti, encoded_audio) + + def _create_shard(self, entries, target_dir, shard_id, manifest_folder: str = None, only_manifests: bool = False): + """Creates a tarball containing the audio files from `entries`.""" + if self.config.sort_in_shards: + entries.sort(key=lambda x: x["duration"], reverse=False) + + new_entries = [] + + tar_filepath = os.path.join(target_dir, f"audio_{shard_id}.tar") + if not only_manifests: + tar = tarfile.open(tar_filepath, mode="w", dereference=True) + + count = dict() + for entry in tqdm(entries, desc="Creating shard.."): + # We squash the filename since we do not preserve directory structure of audio files in the tarball. + if os.path.exists(entry["audio_filepath"]) or only_manifests: + audio_filepath = entry["audio_filepath"] + else: + if not manifest_folder: + raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!") + + audio_filepath = os.path.join(manifest_folder, entry["audio_filepath"]) + if not os.path.exists(audio_filepath): + raise FileNotFoundError(f"Could not find {entry['audio_filepath']}!") + + base, ext = os.path.splitext(audio_filepath) + base = base.replace("/", "_") + # Need the following replacement as long as WebDataset splits on first period + base = base.replace(".", "_") + squashed_filename = f"{base}{ext}" + + if self.config.slice_with_offset: + if squashed_filename not in count: + count[squashed_filename] = 1 + + entry_offset = str(entry["offset"]).split(".") + if len(entry_offset) == 1: + # Example: offset = 12 -> becomes 12_0 + entry_offset.append("0") + elif len(entry_offset) == 2: + # Example: offset = 12.34 -> becomes 12_34 + pass + else: + raise ValueError( + f"The offset for the entry with audio_filepath '{entry['audio_filepath']}' is incorrectly provided ({entry['offset']}). " + "Expected a float-like value (e.g., 12 or 12.34)." + ) + entry_offset = "_".join(entry_offset) + + entry_duration = str(entry["duration"]).split(".") + if len(entry_duration) == 1: + entry_duration.append("0") + elif len(entry_duration) > 2: + raise ValueError( + f"The duration for the entry with audio_filepath '{entry['audio_filepath']}' is incorrectly provided ({entry['duration']})." + ) + entry_duration = "_".join(entry_duration) + + to_write = base + "_" + entry_offset + "_" + entry_duration + ext + if not only_manifests: + self._write_to_tar( + tar, audio_filepath, to_write, duration=entry["duration"], offset=entry["offset"] + ) + count[squashed_filename] += 1 + + entry["source_audio_offset"] = entry["offset"] + del entry["offset"] + else: + if squashed_filename not in count: + if not only_manifests: + self._write_to_tar(tar, audio_filepath, squashed_filename) + to_write = squashed_filename + count[squashed_filename] = 1 + else: + to_write = base + "-sub" + str(count[squashed_filename]) + ext + count[squashed_filename] += 1 + + if only_manifests: + entry["abs_audio_filepath"] = audio_filepath + + # Carry over every key in the entry, override audio_filepath and shard_id + new_entry = { + **entry, + "audio_filepath": to_write, + "shard_id": shard_id, # Keep shard ID for recordkeeping + } + new_entries.append(new_entry) + + if not only_manifests: + tar.close() + return new_entries + + @classmethod + def setup_history(cls, base_metadata: ASRTarredDatasetMetadata, history: List[Any]): + if "history" in base_metadata.keys(): + for history_val in base_metadata.history: + cls.setup_history(history_val, history) + + if base_metadata is not None: + metadata_copy = copy.deepcopy(base_metadata) + with open_dict(metadata_copy): + metadata_copy.pop("history", None) + history.append(metadata_copy) + + +def main(args): + if args.buckets_num > 1: + bucket_length = (args.max_duration - args.min_duration) / float(args.buckets_num) + for i_bucket in range(args.buckets_num): + bucket_config = copy.deepcopy(args) + bucket_config.min_duration = args.min_duration + i_bucket * bucket_length + bucket_config.max_duration = bucket_config.min_duration + bucket_length + if i_bucket == args.buckets_num - 1: + # add a small number to cover the samples with exactly duration of max_duration in the last bucket. + bucket_config.max_duration += 1e-5 + bucket_config.target_dir = os.path.join(args.target_dir, f"bucket{i_bucket + 1}") + print( + f"Creating bucket {i_bucket + 1} with min_duration={bucket_config.min_duration} and max_duration={bucket_config.max_duration} ..." + ) + print(f"Results are being saved at: {bucket_config.target_dir}.") + create_tar_datasets(**vars(bucket_config)) + if not args.dry_run: + print(f"Bucket {i_bucket + 1} is created.") + else: + create_tar_datasets(**vars(args)) + + +def create_tar_datasets( + manifest_path: str = None, + concat_manifest_paths: str = None, + target_dir: str = None, + metadata_path: str = None, + num_shards: int = -1, + max_duration: float = None, + min_duration: float = None, + shuffle: bool = False, + keep_files_together: bool = False, + sort_in_shards: bool = False, + buckets_num: int = 1, + dynamic_buckets_num: int = 30, + shuffle_seed: int = None, + write_metadata: bool = False, + no_shard_manifests: bool = False, + force_codec: str = None, + workers: int = 1, + slice_with_offset: bool = False, + only_manifests: bool = False, + dry_run: bool = False, +): + builder = ASRTarredDatasetBuilder() + + shard_manifests = False if no_shard_manifests else True + + if write_metadata: + metadata = ASRTarredDatasetMetadata() + dataset_cfg = ASRTarredDatasetConfig( + num_shards=num_shards, + shuffle=shuffle, + max_duration=max_duration, + min_duration=min_duration, + shuffle_seed=shuffle_seed, + sort_in_shards=sort_in_shards, + shard_manifests=shard_manifests, + keep_files_together=keep_files_together, + force_codec=force_codec, + slice_with_offset=slice_with_offset, + ) + metadata.dataset_config = dataset_cfg + + output_path = os.path.join(target_dir, "default_metadata.yaml") + OmegaConf.save(metadata, output_path, resolve=True) + print(f"Default metadata written to {output_path}") + exit(0) + + if concat_manifest_paths is None or len(concat_manifest_paths) == 0: + # Create a tarred dataset from scratch + config = ASRTarredDatasetConfig( + num_shards=num_shards, + shuffle=shuffle, + max_duration=max_duration, + min_duration=min_duration, + shuffle_seed=shuffle_seed, + sort_in_shards=sort_in_shards, + shard_manifests=shard_manifests, + keep_files_together=keep_files_together, + force_codec=force_codec, + slice_with_offset=slice_with_offset, + ) + builder.configure(config) + builder.create_new_dataset( + manifest_path=manifest_path, + target_dir=target_dir, + num_workers=workers, + buckets_num=buckets_num, + dynamic_buckets_num=dynamic_buckets_num, + only_manifests=only_manifests, + dry_run=dry_run, + ) + + else: + if buckets_num > 1: + raise ValueError("Concatenation feature does not support buckets_num > 1.") + print("Concatenating multiple tarred datasets ...") + + # Implicitly update config from base details + if metadata_path is not None: + metadata = ASRTarredDatasetMetadata.from_file(metadata_path) + else: + raise ValueError("`metadata` yaml file path must be provided!") + + # Preserve history + history = [] + builder.setup_history(OmegaConf.structured(metadata), history) + metadata.history = history + + # Add command line overrides (everything other than num_shards) + metadata.dataset_config.max_duration = max_duration + metadata.dataset_config.min_duration = min_duration + metadata.dataset_config.shuffle = shuffle + metadata.dataset_config.shuffle_seed = shuffle_seed + metadata.dataset_config.sort_in_shards = sort_in_shards + metadata.dataset_config.shard_manifests = shard_manifests + + builder.configure(metadata.dataset_config) + + # Concatenate a tarred dataset onto a previous one + builder.create_concatenated_dataset( + base_manifest_path=manifest_path, + manifest_paths=concat_manifest_paths, + metadata=metadata, + target_dir=target_dir, + num_workers=workers, + slice_with_offset=slice_with_offset, + only_manifests=only_manifests, + dry_run=dry_run, + ) + + if not dry_run and (DALI_INDEX_SCRIPT_AVAILABLE and dali_index.INDEX_CREATOR_AVAILABLE): + print("Constructing DALI Tarfile Index - ", target_dir) + index_config = dali_index.DALITarredIndexConfig(tar_dir=target_dir, workers=workers) + dali_index.main(index_config) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert an existing ASR dataset to tarballs compatible with TarredAudioToTextDataLayer." + ) + parser.add_argument( + "--manifest_path", default=None, type=str, required=False, help="Path to the existing dataset's manifest." + ) + + parser.add_argument( + "--concat_manifest_paths", + nargs="+", + default=None, + type=str, + required=False, + help="Path to the additional dataset's manifests that will be concatenated with base dataset.", + ) + + # Optional arguments + parser.add_argument( + "--target_dir", + default="./tarred", + type=str, + help="Target directory for resulting tarballs and manifest. Defaults to `./tarred`. Creates the path if necessary.", + ) + + parser.add_argument( + "--metadata_path", + required=False, + default=None, + type=str, + help="Path to metadata file for the dataset.", + ) + + parser.add_argument( + "--num_shards", + default=-1, + type=int, + help="Number of shards (tarballs) to create. Used for partitioning data among workers.", + ) + parser.add_argument( + "--max_duration", + default=None, + required=True, + type=float, + help="Maximum duration of audio clip in the dataset. By default, it is None and is required to be set.", + ) + parser.add_argument( + "--min_duration", + default=None, + type=float, + help="Minimum duration of audio clip in the dataset. By default, it is None and will not filter files.", + ) + parser.add_argument( + "--shuffle", + action="store_true", + help="Whether or not to randomly shuffle the samples in the manifest before tarring/sharding.", + ) + + parser.add_argument( + "--keep_files_together", + action="store_true", + help="Whether or not to keep entries from the same file (but different offsets) together when sorting before tarring/sharding.", + ) + parser.add_argument( + "--slice_with_offset", + action="store_true", + help=( + "If set, the audio will be sliced based on `offset` and `duration` fields from the manifest. " + "This is useful for creating datasets from audio segments instead of full files. " + "When unset, the entire audio file is used without slicing, regardless of the offset/duration values in the manifest." + ), + ) + parser.add_argument( + "--sort_in_shards", + action="store_true", + help="Whether or not to sort samples inside the shards based on their duration.", + ) + + parser.add_argument( + "--buckets_num", + type=int, + default=1, + help="Number of buckets to create based on duration.", + ) + + parser.add_argument( + "--dynamic_buckets_num", + type=int, + default=30, + help="Intended for dynamic (on-the-fly) bucketing; this option will not bucket your dataset during tar conversion. " + "Estimates optimal bucket duration bins for a given number of buckets.", + ) + + parser.add_argument("--shuffle_seed", type=int, default=None, help="Random seed for use if shuffling is enabled.") + parser.add_argument( + "--write_metadata", + action="store_true", + help=( + "Flag to write a blank metadata with the current call config. " + "Note that the metadata will not contain the number of shards, " + "and it must be filled out by the user." + ), + ) + parser.add_argument( + "--no_shard_manifests", + action="store_true", + help="Do not write sharded manifests along with the aggregated manifest.", + ) + parser.add_argument( + "--force_codec", + type=str, + default=None, + help=( + "If specified, transcode the audio to the given format. " + "Supports libnsndfile formats (example values: 'opus', 'flac')." + ), + ) + parser.add_argument( + "--only_manifests", + action="store_true", + help=( + "If set, only creates manifests for each shard without creating the actual tar files. " + "This allows you to verify the output structure and content before committing to the full tarball creation process. " + "Each manifest entry will also include the field `abs_audio_filepath`, which stores the absolute path to the original audio file." + ), + ) + parser.add_argument( + "--dry_run", + action="store_true", + help=( + "Run in simulation mode: calculate and display the number of shards and estimated data per shard without reading audio files or writing any output." + ), + ) + parser.add_argument("--workers", type=int, default=1, help="Number of worker processes") + args = parser.parse_args() + main(args) diff --git a/nemo_curator/stages/audio/asr/io/split_manifest_writer.py b/nemo_curator/stages/audio/asr/io/split_manifest_writer.py new file mode 100644 index 0000000000..5be5403cc0 --- /dev/null +++ b/nemo_curator/stages/audio/asr/io/split_manifest_writer.py @@ -0,0 +1,125 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Split/language-aware JSONL manifest writer for ASR pipelines. + +Routes each ``AudioTask`` to ``{output_dir}/{lang}/{output_filename_pattern}`` +so a single pipeline run produces the per-language train/dev/test manifests +directly. Runs as a single worker so the per-split files are written without +cross-worker contention. +""" + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from loguru import logger + +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.tasks import AudioTask + +if TYPE_CHECKING: + from nemo_curator.backends.base import WorkerMetadata + + +@dataclass +class SplitAwareManifestWriter(ProcessingStage[AudioTask, AudioTask]): + """Write each entry to a split-derived manifest filename. + + Args: + output_dir: Destination root for the manifests. + lang_key: Data key holding the language. + split_key: Data key holding the split type (``train``/``dev``/``test``). + output_filename_pattern: Filename pattern formatted with ``lang``, + ``split``, and ``split_type``. Defaults to ``"{split}.jsonl"``. + langs: Optional languages to pre-create empty manifests for. + splits: Optional split types to pre-create empty manifests for. + When both ``langs`` and ``splits`` are given, all combinations are + created (and truncated) up front, guaranteeing the files exist even + if a split receives no entries. + """ + + output_dir: str = "" + name: str = "split_manifest_writer" + lang_key: str = "lang" + split_key: str = "split_type" + output_filename_pattern: str = "{split}.jsonl" + langs: list[str] | None = None + splits: list[str] | None = None + is_sink_stage: bool = True + _handles: dict[tuple[str, str], Any] = field(default_factory=dict, init=False, repr=False) + _counts: dict[tuple[str, str], int] = field(default_factory=dict, init=False, repr=False) + + def __post_init__(self) -> None: + if not self.output_dir: + msg = "output_dir is required for SplitAwareManifestWriter" + raise ValueError(msg) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.lang_key, self.split_key] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def num_workers(self) -> int | None: + return 1 + + def xenna_stage_spec(self) -> dict[str, Any]: + return {"num_workers": 1} + + def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: + self._handles = {} + self._counts = {} + # Pre-create (truncate) declared manifests so all expected files exist. + if self.langs and self.splits: + for lang in self.langs: + for split in self.splits: + self._open(lang, split) + + def _path(self, lang: str, split: str) -> str: + filename = self.output_filename_pattern.format(lang=lang, split=split, split_type=split) + return os.path.join(self.output_dir, lang, filename) + + def _open(self, lang: str, split: str) -> Any: # noqa: ANN401 + key = (lang, split) + handle = self._handles.get(key) + if handle is None: + path = self._path(lang, split) + os.makedirs(os.path.dirname(path), exist_ok=True) + handle = open(path, "w", encoding="utf-8") # noqa: SIM115 + self._handles[key] = handle + self._counts[key] = 0 + logger.info(f"[{self.name}] writing manifest -> {path}") + return handle + + def process(self, task: AudioTask) -> AudioTask: + lang = str(task.data.get(self.lang_key, "unknown")) + split = str(task.data.get(self.split_key, "unknown")) + handle = self._open(lang, split) + handle.write(json.dumps(task.data, ensure_ascii=False) + "\n") + # Flush every write: the executor may terminate the worker without + # invoking teardown(), so we cannot rely on close() to flush the buffer. + handle.flush() + self._counts[(lang, split)] += 1 + return task + + def teardown(self) -> None: + for (lang, split), handle in self._handles.items(): + handle.close() + filename = self.output_filename_pattern.format(lang=lang, split=split, split_type=split) + logger.info(f"[{self.name}] {lang}/{filename}: {self._counts.get((lang, split), 0)} entries") + self._handles = {} diff --git a/nemo_curator/stages/audio/asr/io/tarred_dataset_writer.py b/nemo_curator/stages/audio/asr/io/tarred_dataset_writer.py new file mode 100644 index 0000000000..7782f106da --- /dev/null +++ b/nemo_curator/stages/audio/asr/io/tarred_dataset_writer.py @@ -0,0 +1,126 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Curator stage for creating NeMo-compatible tarred ASR datasets.""" + +from __future__ import annotations + +import time +from dataclasses import dataclass, field +from typing import Any + +from loguru import logger + +from nemo_curator.stages.audio.asr.io.convert_to_tarred_audio_dataset import create_tar_datasets +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import _EmptyTask + + +@dataclass +class TarredAudioDatasetWriterStage(ProcessingStage[_EmptyTask, _EmptyTask]): + """Create one tarred ASR dataset per manifest/target directory pair. + + This stage wraps NeMo's ``convert_to_tarred_audio_dataset.py`` helper for + YAML-driven Curator pipelines. It is a terminal stage and does not emit + downstream tasks. + """ + + manifest_paths: str | list[str] + target_dirs: str | list[str] + num_shards: int + max_duration: float | None + name: str = "tarred_audio_dataset_writer" + min_duration: float | None = None + shuffle: bool = False + keep_files_together: bool = False + sort_in_shards: bool = False + buckets_num: int = 1 + dynamic_buckets_num: int = 30 + shuffle_seed: int | None = None + no_shard_manifests: bool = False + force_codec: str | None = None + workers: int = 1 + slice_with_offset: bool = False + only_manifests: bool = False + dry_run: bool = False + is_sink_stage: bool = True + resources: Resources = field(default_factory=lambda: Resources(cpus=1.0)) + + def __post_init__(self) -> None: + self.manifest_paths = _coerce_path_list(self.manifest_paths) + self.target_dirs = _coerce_path_list(self.target_dirs) + if not self.manifest_paths: + msg = "manifest_paths is required for TarredAudioDatasetWriterStage" + raise ValueError(msg) + if not self.target_dirs: + msg = "target_dirs is required for TarredAudioDatasetWriterStage" + raise ValueError(msg) + if len(self.manifest_paths) != len(self.target_dirs): + msg = "manifest_paths and target_dirs must have the same length" + raise ValueError(msg) + if self.num_shards < 1: + msg = "num_shards must be >= 1 for TarredAudioDatasetWriterStage" + raise ValueError(msg) + self.resources = Resources(cpus=float(max(self.workers, 1))) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def num_workers(self) -> int | None: + return 1 + + def xenna_stage_spec(self) -> dict[str, Any]: + return {"num_workers": 1} + + def process(self, _: _EmptyTask) -> list[_EmptyTask]: + start = time.perf_counter() + for manifest_path, target_dir in zip(self.manifest_paths, self.target_dirs, strict=True): + logger.info(f"[{self.name}] creating tarred dataset from {manifest_path} -> {target_dir}") + create_tar_datasets( + manifest_path=manifest_path, + target_dir=target_dir, + num_shards=self.num_shards, + max_duration=self.max_duration, + min_duration=self.min_duration, + shuffle=self.shuffle, + keep_files_together=self.keep_files_together, + sort_in_shards=self.sort_in_shards, + buckets_num=self.buckets_num, + dynamic_buckets_num=self.dynamic_buckets_num, + shuffle_seed=self.shuffle_seed, + no_shard_manifests=self.no_shard_manifests, + force_codec=self.force_codec, + workers=self.workers, + slice_with_offset=self.slice_with_offset, + only_manifests=self.only_manifests, + dry_run=self.dry_run, + ) + self._log_metrics( + { + "process_time": time.perf_counter() - start, + "input_manifests": len(self.manifest_paths), + "emitted_tasks": 0, + } + ) + return [] + + +def _coerce_path_list(paths: str | list[str]) -> list[str]: + if isinstance(paths, str): + return [paths] + return [str(path) for path in paths] diff --git a/nemo_curator/stages/audio/asr/metadata.py b/nemo_curator/stages/audio/asr/metadata.py new file mode 100644 index 0000000000..417577c923 --- /dev/null +++ b/nemo_curator/stages/audio/asr/metadata.py @@ -0,0 +1,88 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Canonical metadata schema for ASR manifest entries. + +``ASRMetadata`` is the typed contract every ASR dataset handler produces. It is +intentionally flattened into the plain ``dict`` carried by ``AudioTask.data`` via +:meth:`ASRMetadata.to_dict` so that all existing dict-based audio stages +(``GetAudioDurationStage``, ``AudioToDocumentStage``, ``JsonlWriter`` ...) keep +working without modification. +""" + +from __future__ import annotations + +from dataclasses import asdict, dataclass, field, fields + + +@dataclass +class ASRMetadata: + """A single ASR manifest entry. + + Core fields map directly to the JSONL manifest used for ASR training. The + ``extra`` dict holds dataset-specific fields (e.g. ``snr``, + ``collection_source``) which are spread into the top level on serialization. + + Args: + audio_filepath: Path to the converted audio (WAV, 16 kHz, mono, PCM16). + text: Ground-truth transcript. + duration: Audio duration in seconds (after conversion). + lang: Language identifier (e.g. ``"hindi"``). + split_type: Dataset split this entry belongs to (``"train"``/``"dev"``/``"test"``). + source: Source dataset name (e.g. ``"IndicVoices"``). + sample_rate: Target sample rate of the converted audio. + num_channels: Target channel count of the converted audio. + orig_sample_rate: Source sample rate before conversion (if known). + orig_num_channels: Source channel count before conversion (if known). + gender: Speaker gender label, if provided by the source dataset. + speaker_id: Source dataset speaker identifier, if provided. + age: Speaker age or age bucket, if provided by the source dataset. + text_verbatim: Original unnormalized transcript, if provided. + extra: Dataset-specific extra fields, flattened on serialization. + """ + + audio_filepath: str + text: str + duration: float + lang: str + split_type: str + source: str + sample_rate: int = 16000 + num_channels: int = 1 + orig_sample_rate: int | None = None + orig_num_channels: int | None = None + gender: str | None = None + speaker_id: str | None = None + age: str | None = None + text_verbatim: str | None = None + extra: dict = field(default_factory=dict) + + def to_dict(self) -> dict: + """Flatten to a plain dict suitable for ``AudioTask.data`` / a JSONL line. + + Core fields are emitted at the top level and ``extra`` is spread in. + Core fields take precedence over any colliding key in ``extra``. + """ + out = {k: v for k, v in asdict(self).items() if k != "extra"} + for key, value in self.extra.items(): + out.setdefault(key, value) + return out + + @classmethod + def from_dict(cls, data: dict) -> ASRMetadata: + """Rebuild from a flattened dict; unknown keys are collected into ``extra``.""" + known = {f.name for f in fields(cls)} - {"extra"} + core = {k: data[k] for k in known if k in data} + extra = {k: v for k, v in data.items() if k not in known} + return cls(**core, extra=extra) diff --git a/nemo_curator/stages/audio/asr/normalization/__init__.py b/nemo_curator/stages/audio/asr/normalization/__init__.py new file mode 100644 index 0000000000..d3dc04d603 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.stages.audio.asr.normalization.stats import TranscriptStatsStage +from nemo_curator.stages.audio.asr.normalization.transcript import TranscriptNormalizationStage + +__all__ = ["TranscriptNormalizationStage", "TranscriptStatsStage"] diff --git a/nemo_curator/stages/audio/asr/normalization/langs/bn/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/bn/alphabet.txt new file mode 100644 index 0000000000..b495042657 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/bn/alphabet.txt @@ -0,0 +1,67 @@ + +অ +আ +ই +ঈ +উ +ঊ +ঋ +এ +ঐ +ও +ঔ +ক +খ +গ +ঘ +ঙ +চ +ছ +জ +ঝ +ঞ +ট +ঠ +ড +ঢ +ণ +ত +থ +দ +ধ +ন +প +ফ +ব +ভ +ম +য +র +ল +শ +ষ +স +হ +ড় +ঢ় +য় +ৎ +ড় +া +ি +ী +ু +ূ +ৃ +ৄ +ৢ +ৣ +ে +ৈ +ো +ৌ +্ +ং +ঁ +ঃ +় diff --git a/nemo_curator/stages/audio/asr/normalization/langs/bn/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/bn/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/langs/en/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/en/alphabet.txt new file mode 100644 index 0000000000..0edb85647c --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/en/alphabet.txt @@ -0,0 +1,26 @@ +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z diff --git a/nemo_curator/stages/audio/asr/normalization/langs/en/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/en/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/langs/gu/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/gu/alphabet.txt new file mode 100644 index 0000000000..f80efa1145 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/gu/alphabet.txt @@ -0,0 +1,64 @@ + +ઁ +ં +ઃ +અ +આ +ઇ +ઈ +ઉ +ઊ +ઋ +ઍ +એ +ઐ +ઑ +ઓ +ઔ +ક +ખ +ગ +ઘ +ઙ +ચ +છ +જ +ઝ +ઞ +ટ +ઠ +ડ +ઢ +ણ +ત +થ +દ +ધ +ન +પ +ફ +બ +ભ +મ +ય +ર +લ +વ +શ +ષ +સ +હ +ળ +ા +િ +ી +ુ +ૂ +ૃ +ૅ +ે +ૈ +ૉ +ો +ૌ +્ diff --git a/nemo_curator/stages/audio/asr/normalization/langs/gu/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/gu/pretok.jsonl new file mode 100644 index 0000000000..edda67c2c7 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/gu/pretok.jsonl @@ -0,0 +1,7 @@ +{"pattern": "", "repl": ""} +{"pattern": "\u0007|\u200d|\u200c|\u200b|\u200f", "repl": ""} +{"pattern": "\u201a", "repl": ","} +{"pattern": "_", "repl": " "} +{"pattern": "-", "repl": " "} +{"pattern": "\u2014", "repl": " "} +{"pattern": "\u2013", "repl": " "} diff --git a/nemo_curator/stages/audio/asr/normalization/langs/hi/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/hi/alphabet.txt new file mode 100644 index 0000000000..1725a54984 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/hi/alphabet.txt @@ -0,0 +1,76 @@ + +, +? +। +ँ +ं +ः +अ +आ +इ +ई +उ +ऊ +ऋ +ऍ +ए +ऐ +ऑ +ओ +औ +क +ख +ग +घ +ङ +च +छ +ज +झ +ञ +ट +ठ +ड +ढ +ण +त +थ +द +ध +न +प +फ +ब +भ +म +य +र +ल +व +श +ष +स +ह +़ +ा +ि +ी +ु +ू +ृ +ॅ +े +ै +ॉ +ो +ौ +् +ॐ +ॠ +ड़ +ज़ +ढ़ +फ़ +ख़ +क़ +ग़ diff --git a/nemo_curator/stages/audio/asr/normalization/langs/hi/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/hi/pretok.jsonl new file mode 100644 index 0000000000..78366400d1 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/hi/pretok.jsonl @@ -0,0 +1,19 @@ +{"pattern": "₹", "repl": " रुपये "} +{"pattern": "०", "repl": " "} +{"pattern": "ऎ", "repl": "ऐ"} +{"pattern": "ॊ", "repl": "ो"} +{"pattern": "य़", "repl": "य"} +{"pattern": "य़", "repl": "य"} +{"pattern": "ऱ", "repl": "र"} +{"pattern": "ऱ", "repl": "र"} +{"pattern": "ॆ", "repl": "े"} +{"pattern": "\u0007|\u200d|\u200c|\u200b|\u200f", "repl": ""} +{"pattern": "ळ", "repl": "ल"} +{"pattern": "वंâ", "repl": "क"} +{"pattern": "र्वâ", "repl": "र्क"} +{"pattern": "", "repl": ""} +{"pattern": "\u201a", "repl": ","} +{"pattern": "_", "repl": " "} +{"pattern": "-", "repl": " "} +{"pattern": "\u2014", "repl": " "} +{"pattern": "\u2013", "repl": " "} diff --git a/nemo_curator/stages/audio/asr/normalization/langs/kn/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/kn/alphabet.txt new file mode 100644 index 0000000000..ec20886158 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/kn/alphabet.txt @@ -0,0 +1,68 @@ + +ಂ +ಃ +ಅ +ಆ +ಇ +ಈ +ಉ +ಊ +ಋ +ೠ +ಎ +ಏ +ಐ +ಒ +ಓ +ಔ +ಕ +ಖ +ಗ +ಘ +ಙ +ಚ +ಛ +ಜ +ಝ +ಞ +ಟ +ಠ +ಡ +ಢ +ಣ +ತ +ಥ +ದ +ಧ +ನ +ಪ +ಫ +ಬ +ಭ +ಮ +ಯ +ರ +ಲ +ವ +ಶ +ಷ +ಸ +ಹ +ಳ +ಾ +ಿ +ೀ +ು +ೂ +ೃ +ೄ +ೆ +ೇ +ೈ +ೊ +ೋ +ೌ +್ +ಌ +ೡ +ಁ diff --git a/nemo_curator/stages/audio/asr/normalization/langs/kn/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/kn/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/langs/ml/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/ml/alphabet.txt new file mode 100644 index 0000000000..5ad046829f --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/ml/alphabet.txt @@ -0,0 +1,75 @@ + +ഁ +ം +ഃ +അ +ആ +ഇ +ഈ +ഉ +ഊ +ഋ +ൠ +എ +ഏ +ഐ +ഒ +ഓ +ഔ +ക +ഖ +ഗ +ഘ +ങ +ച +ഛ +ജ +ഝ +ഞ +ട +ഠ +ഡ +ഢ +ണ +ത +ഥ +ദ +ധ +ന +പ +ഫ +ബ +ഭ +മ +യ +ര +റ +ല +ള +ഴ +വ +ശ +ഷ +സ +ഹ +ാ +ി +ീ +ു +ൂ +ൃ +ൄ +െ +േ +ൈ +ൊ +ോ +ൌ +ൗ +് +ൺ +ൻ +ർ +ൽ +ൾ +ൿ diff --git a/nemo_curator/stages/audio/asr/normalization/langs/ml/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/ml/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/langs/mr/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/mr/alphabet.txt new file mode 100644 index 0000000000..2ccc9d3dbd --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/mr/alphabet.txt @@ -0,0 +1,73 @@ + +ँ +ं +ः +अ +आ +इ +ई +उ +ऊ +ऋ +ॠ +ऌ +ए +ऐ +ऑ +ओ +औ +ॲ +ऍ +क +ख +ग +घ +ङ +च +छ +ज +झ +ञ +ट +ठ +ड +ढ +ण +त +थ +द +ध +न +प +फ +ब +भ +म +य +र +ऱ +ल +ळ +व +श +ष +स +ह +़ +ा +ि +ी +ु +ू +ृ +ॄ +ॅ +े +ै +ॉ +ो +ौ +् +ॐ +ऽ +ऱ diff --git a/nemo_curator/stages/audio/asr/normalization/langs/mr/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/mr/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/langs/pa/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/pa/alphabet.txt new file mode 100644 index 0000000000..b089b42619 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/pa/alphabet.txt @@ -0,0 +1,66 @@ + +ਂ +ੰ +ਃ +ਅ +ਆ +ਇ +ਈ +ਉ +ਊ +ਏ +ਐ +ਓ +ਔ +ੳ +ੲ +ਸ +ਹ +ਕ +ਖ +ਗ +ਘ +ਙ +ਚ +ਛ +ਜ +ਝ +ਞ +ਟ +ਠ +ਡ +ਢ +ਣ +ਤ +ਥ +ਦ +ਧ +ਨ +ਪ +ਫ +ਬ +ਭ +ਮ +ਯ +ਰ +ਲ +ਵ +ੜ +ਸ਼ +ਖ਼ +ਗ਼ +ਜ਼ +ਫ਼ +ਲ਼ +ਾ +ਿ +ੀ +ੁ +ੂ +ੇ +ੈ +ੋ +ੌ +੍ +ੱ +ੴ diff --git a/nemo_curator/stages/audio/asr/normalization/langs/pa/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/pa/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/langs/remove_chars.txt b/nemo_curator/stages/audio/asr/normalization/langs/remove_chars.txt new file mode 100644 index 0000000000..75dfc4ed1e --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/remove_chars.txt @@ -0,0 +1 @@ +[](){}<>;:"!.॥/\#*+=&^|~·`@“”´»«$%‘’' diff --git a/nemo_curator/stages/audio/asr/normalization/langs/ta/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/ta/alphabet.txt new file mode 100644 index 0000000000..f53b8bdefb --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/ta/alphabet.txt @@ -0,0 +1,50 @@ + +அ +ஆ +இ +ஈ +உ +ஊ +எ +ஏ +ஐ +ஒ +ஓ +ஔ +க +ங +ச +ஞ +ட +ண +த +ந +ப +ம +ய +ர +ல +வ +ழ +ள +ற +ன +ஜ +ஷ +ஸ +ஹ +க்ஷ +ஶ +ஃ +ா +ி +ீ +ு +ூ +ெ +ே +ை +ொ +ோ +ௌ +் diff --git a/nemo_curator/stages/audio/asr/normalization/langs/ta/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/ta/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/langs/te/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/te/alphabet.txt new file mode 100644 index 0000000000..2ccbac7302 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/te/alphabet.txt @@ -0,0 +1,67 @@ + +ఁ +ం +ః +అ +ఆ +ఇ +ఈ +ఉ +ఊ +ఋ +ౠ +ఎ +ఏ +ఐ +ఒ +ఓ +ఔ +క +ఖ +గ +ఘ +ఙ +చ +ఛ +జ +ఝ +ఞ +ట +ఠ +డ +ఢ +ణ +త +థ +ద +ధ +న +ప +ఫ +బ +భ +మ +య +ర +ఱ +ల +ళ +వ +శ +ష +స +హ +ా +ి +ీ +ు +ూ +ృ +ౄ +ె +ే +ై +ొ +ో +ౌ +్ diff --git a/nemo_curator/stages/audio/asr/normalization/langs/te/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/te/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/langs/ur/alphabet.txt b/nemo_curator/stages/audio/asr/normalization/langs/ur/alphabet.txt new file mode 100644 index 0000000000..db113b3336 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/langs/ur/alphabet.txt @@ -0,0 +1,50 @@ +ء +آ +أ +ؤ +ئ +ا +ب +پ +ت +ٹ +ث +ج +چ +ح +خ +د +ڈ +ذ +ر +ڑ +ز +ژ +س +ش +ص +ض +ط +ظ +ع +غ +ف +ق +ک +گ +ل +م +ن +ں +و +ہ +ھ +ی +ے + +َ +ُ +ِ +ّ +ْ +ٰ diff --git a/nemo_curator/stages/audio/asr/normalization/langs/ur/pretok.jsonl b/nemo_curator/stages/audio/asr/normalization/langs/ur/pretok.jsonl new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nemo_curator/stages/audio/asr/normalization/stats.py b/nemo_curator/stages/audio/asr/normalization/stats.py new file mode 100644 index 0000000000..731168f720 --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/stats.py @@ -0,0 +1,518 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Streaming transcript-quality statistics for normalized ASR manifests.""" + +from __future__ import annotations + +import json +import os +import time +from collections import Counter, defaultdict +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from loguru import logger + +from nemo_curator.backends.utils import RayStageSpecKeys +from nemo_curator.stages.audio.asr.normalization.transcript import ( + _RESOURCE_ROOT, + _coerce_lang_list, + _load_alphabet, + resolve_lang, +) +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask + +if TYPE_CHECKING: + from nemo_curator.backends.base import NodeInfo, WorkerMetadata + + +@dataclass +class _StatsBucket: + total_transcripts: int = 0 + valid_transcripts: int = 0 + invalid_transcripts: int = 0 + dropped_invalid: int = 0 + total_duration_seconds: float = 0.0 + valid_duration_seconds: float = 0.0 + invalid_duration_seconds: float = 0.0 + total_chars: int = 0 + known_chars: Counter[str] = field(default_factory=Counter) + unknown_char_count: int = 0 + unknown_chars: Counter[str] = field(default_factory=Counter) + split_counts: dict[str, dict[str, int]] = field( + default_factory=lambda: defaultdict(lambda: {"total": 0, "valid": 0, "invalid": 0}) + ) + split_duration_seconds: dict[str, dict[str, float]] = field( + default_factory=lambda: defaultdict(lambda: {"total": 0.0, "valid": 0.0, "invalid": 0.0}) + ) + + +@dataclass(frozen=True) +class _BucketUpdate: + text: str + duration: float + known_chars: Counter[str] + unknown_chars: dict[str, int] + transcript_error: bool + split: str + dropped: int + + +@dataclass +class TranscriptStatsStage(ProcessingStage[AudioTask, AudioTask]): + """Collect transcript-validity stats while streaming ``AudioTask`` objects. + + The stage writes global aggregate statistics, full summaries for each + language/source pair under ``by_language``, and full language-level totals + under ``by_language_overall``. Each summary includes unknown-character + counts and rates so vocabulary gaps can be inspected after a run. + + Args: + code_switch_langs: Extra language resource folders whose alphabets + should be included when computing ``alpha_minus_known_chars``. + """ + + name: str = "transcript_stats" + text_key: str = "text" + lang_key: str = "lang" + source_key: str = "source" + duration_key: str = "duration" + split_key: str = "split_type" + unknown_chars_key: str = "unknown_chars" + transcript_error_key: str = "transcript_error" + drop_invalid: bool = False + log_top_n_unknown_chars: int = 50 + code_switch_langs: str | list[str] | None = field(default_factory=list) + output_summary_path: str | None = None + resources: Resources = field(default_factory=lambda: Resources(cpus=1.0)) + + def __post_init__(self) -> None: + self._summary_handle = None + self._total_transcripts = 0 + self._valid_transcripts = 0 + self._invalid_transcripts = 0 + self._dropped_invalid = 0 + self._total_duration_seconds = 0.0 + self._valid_duration_seconds = 0.0 + self._invalid_duration_seconds = 0.0 + self._total_chars = 0 + self._known_chars: Counter[str] = Counter() + self._languages: set[str] = set() + self._unknown_char_count = 0 + self._unknown_chars: Counter[str] = Counter() + self._split_counts: dict[str, dict[str, int]] = defaultdict(lambda: {"total": 0, "valid": 0, "invalid": 0}) + self._split_duration_seconds: dict[str, dict[str, float]] = defaultdict( + lambda: {"total": 0.0, "valid": 0.0, "invalid": 0.0} + ) + self._language_stats: dict[str, _StatsBucket] = defaultdict(_StatsBucket) + self._language_source_stats: dict[str, dict[str, _StatsBucket]] = defaultdict( + lambda: defaultdict(_StatsBucket) + ) + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.text_key, self.lang_key, self.unknown_chars_key, self.transcript_error_key] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [] + + def num_workers(self) -> int | None: + return 1 + + def ray_stage_spec(self) -> dict[str, Any]: + return {RayStageSpecKeys.IS_FANOUT_STAGE: False} + + def xenna_stage_spec(self) -> dict[str, Any]: + return {"num_workers": 1} + + def setup_on_node( + self, + _node_info: NodeInfo | None = None, + _worker_metadata: WorkerMetadata | None = None, + ) -> None: + if self.output_summary_path: + parent = os.path.dirname(self.output_summary_path) + if parent: + os.makedirs(parent, exist_ok=True) + self._summary_handle = open(self.output_summary_path, "w", encoding="utf-8") # noqa: SIM115 + + def process(self, task: AudioTask) -> AudioTask | None: + start = time.perf_counter() + text = str(task.data.get(self.text_key, "")) + lang = resolve_lang(str(task.data[self.lang_key])) + self._languages.add(lang) + source = str(task.data.get(self.source_key, "unknown") or "unknown") + duration = float(task.data.get(self.duration_key, 0.0) or 0.0) + unknown_chars = _coerce_unknown_chars(task.data.get(self.unknown_chars_key, {})) + unknown_char_set = set(unknown_chars) + transcript_error = bool(task.data.get(self.transcript_error_key, bool(unknown_chars))) + split = str(task.data.get(self.split_key, "unknown")) + + self._total_transcripts += 1 + self._total_duration_seconds += duration + self._total_chars += len(text) + known_chars = Counter(char for char in text if char not in unknown_char_set) + self._known_chars.update(known_chars) + self._unknown_chars.update(unknown_chars) + unknown_count = sum(unknown_chars.values()) + self._unknown_char_count += unknown_count + + split_counts = self._split_counts[split] + split_durations = self._split_duration_seconds[split] + split_counts["total"] += 1 + split_durations["total"] += duration + if transcript_error: + self._invalid_transcripts += 1 + self._invalid_duration_seconds += duration + split_counts["invalid"] += 1 + split_durations["invalid"] += duration + else: + self._valid_transcripts += 1 + self._valid_duration_seconds += duration + split_counts["valid"] += 1 + split_durations["valid"] += duration + + dropped = int(transcript_error and self.drop_invalid) + self._dropped_invalid += dropped + update = _BucketUpdate( + text=text, + duration=duration, + known_chars=known_chars, + unknown_chars=unknown_chars, + transcript_error=transcript_error, + split=split, + dropped=dropped, + ) + _update_bucket(self._language_stats[lang], update) + _update_bucket(self._language_source_stats[lang][source], update) + self._log_metrics(self._metrics_snapshot(process_time=time.perf_counter() - start)) + self._write_summary() + if dropped: + return None + return task + + def summary(self) -> dict[str, Any]: + valid_rate = self._valid_transcripts / self._total_transcripts if self._total_transcripts else 0.0 + invalid_rate = self._invalid_transcripts / self._total_transcripts if self._total_transcripts else 0.0 + unique_known_chars = len(self._known_chars) + unique_unknown_chars = len(self._unknown_chars) + unique_known_char_rate = unique_known_chars / self._total_chars if self._total_chars else 0.0 + unique_unknown_char_rate = unique_unknown_chars / self._total_chars if self._total_chars else 0.0 + alpha_minus_known_chars = [] + if len(self._languages) == 1: + language = next(iter(self._languages)) + alpha_minus_known_chars = sorted( + _load_combined_alphabet(language, self.code_switch_langs) - set(self._known_chars) + ) + split_hours = { + split: {key: seconds / 3600 for key, seconds in durations.items()} + for split, durations in self._split_duration_seconds.items() + } + return { + "total_transcripts": self._total_transcripts, + "valid_transcripts": self._valid_transcripts, + "invalid_transcripts": self._invalid_transcripts, + "dropped_invalid": self._dropped_invalid, + "emitted_transcripts": self._total_transcripts - self._dropped_invalid, + "valid_transcript_rate": valid_rate, + "invalid_transcript_rate": invalid_rate, + "total_duration_hours": self._total_duration_seconds / 3600, + "valid_duration_hours": self._valid_duration_seconds / 3600, + "invalid_duration_hours": self._invalid_duration_seconds / 3600, + "total_chars": self._total_chars, + "unique_known_chars": unique_known_chars, + "unique_known_char_rate": unique_known_char_rate, + "unique_unknown_chars": unique_unknown_chars, + "unique_unknown_char_rate": unique_unknown_char_rate, + "unknown_char_details": _unknown_char_details(self._unknown_chars, self._total_chars), + "alpha_minus_known_chars": alpha_minus_known_chars, + "split_counts": dict(self._split_counts), + "split_hours": split_hours, + "by_language": { + lang: { + source: _bucket_summary(bucket, lang, self.code_switch_langs) + for source, bucket in sorted(source_buckets.items(), key=lambda item: item[0]) + } + for lang, source_buckets in sorted(self._language_source_stats.items(), key=lambda item: item[0]) + }, + "by_language_overall": { + lang: _bucket_summary(bucket, lang, self.code_switch_langs) + for lang, bucket in sorted(self._language_stats.items(), key=lambda item: item[0]) + }, + } + + def teardown(self) -> None: + logger.info(self.format_summary()) + if self._summary_handle is not None: + self._summary_handle.close() + self._summary_handle = None + + def _write_summary(self) -> None: + if not self.output_summary_path: + return + if self._summary_handle is None: + self.setup_on_node() + self._summary_handle.seek(0) + self._summary_handle.write(json.dumps(self.summary(), ensure_ascii=False, indent=2) + "\n") + self._summary_handle.truncate() + self._summary_handle.flush() + + def format_summary(self) -> str: + return _format_summary(self.summary(), name=self.name, top_n_unknown_chars=self.log_top_n_unknown_chars) + + @staticmethod + def load_summary(path: str | os.PathLike[str] | None) -> dict[str, Any] | None: + """Load a transcript stats summary JSON file, including legacy JSONL output.""" + if not path: + return None + summary_path = Path(path) + if not summary_path.exists(): + logger.warning(f"Stats summary path does not exist: {summary_path}") + return None + raw_summary = summary_path.read_text(encoding="utf-8") + try: + return json.loads(raw_summary) + except json.JSONDecodeError: + lines = [line for line in raw_summary.splitlines() if line.strip()] + for line in reversed(lines): + try: + return json.loads(line) + except json.JSONDecodeError: + continue + logger.warning(f"Could not parse stats summary JSON: {summary_path}") + return None + + @classmethod + def format_summary_from_path( + cls, + path: str | os.PathLike[str] | None, + *, + name: str = "transcript_stats", + top_n_unknown_chars: int = 50, + ) -> str | None: + """Load and format a persisted transcript stats summary.""" + summary = cls.load_summary(path) + if not summary: + return None + return _format_summary(summary, name=name, top_n_unknown_chars=top_n_unknown_chars) + + @classmethod + def log_summary_from_path( + cls, + path: str | os.PathLike[str] | None, + *, + name: str = "transcript_stats", + top_n_unknown_chars: int = 50, + ) -> str | None: + """Load, format, and log a persisted transcript stats summary.""" + formatted = cls.format_summary_from_path(path, name=name, top_n_unknown_chars=top_n_unknown_chars) + if formatted: + logger.info("\n" + formatted) + return formatted + + def _metrics_snapshot(self, process_time: float) -> dict[str, float | int]: + summary = self.summary() + return { + "input_tasks": summary["total_transcripts"], + "emitted_tasks": summary["emitted_transcripts"], + "dropped_invalid": summary["dropped_invalid"], + "valid_transcripts": summary["valid_transcripts"], + "invalid_transcripts": summary["invalid_transcripts"], + "total_duration_hours": summary["total_duration_hours"], + "valid_duration_hours": summary["valid_duration_hours"], + "invalid_duration_hours": summary["invalid_duration_hours"], + "total_chars": summary["total_chars"], + "unique_known_chars": summary["unique_known_chars"], + "unique_known_char_rate": summary["unique_known_char_rate"], + "unique_unknown_chars": summary["unique_unknown_chars"], + "unique_unknown_char_rate": summary["unique_unknown_char_rate"], + "process_time": process_time, + } + + +def _coerce_unknown_chars(value: Any) -> dict[str, int]: # noqa: ANN401 + if isinstance(value, dict): + return {str(char): int(count) for char, count in value.items()} + return {} + + +def _format_summary(summary: dict[str, Any], *, name: str, top_n_unknown_chars: int) -> str: + lines = [ + f"[{name}] Transcript normalization summary", + " per_language_source:", + ] + for lang, source_summaries in summary.get("by_language", {}).items(): + for source, source_summary in source_summaries.items(): + lines.extend( + _format_summary_block( + f"lang={lang} source={source}", + source_summary, + indent=" ", + top_n_unknown_chars=top_n_unknown_chars, + ) + ) + lines.append(" per_language_overall:") + for lang, language_summary in summary.get("by_language_overall", {}).items(): + lines.extend( + _format_summary_block( + f"lang={lang} overall", + language_summary, + indent=" ", + top_n_unknown_chars=top_n_unknown_chars, + ) + ) + lines.append(" global:") + lines.extend( + _format_summary_block( + "all languages/sources", + summary, + indent=" ", + top_n_unknown_chars=top_n_unknown_chars, + ) + ) + return "\n".join(lines) + + +def _format_summary_block( + label: str, + summary: dict[str, Any], + *, + indent: str, + top_n_unknown_chars: int, +) -> list[str]: + split_hours = _round_nested_floats(summary["split_hours"]) + return [ + f"{indent}{label}", + ( + f"{indent} transcripts: total={summary['total_transcripts']} " + f"valid={summary['valid_transcripts']} ({summary['valid_transcript_rate']:.2%}) " + f"invalid={summary['invalid_transcripts']} ({summary['invalid_transcript_rate']:.2%})" + ), + ( + f"{indent} hours: total={summary['total_duration_hours']:.2f} " + f"valid_after_filter={summary['valid_duration_hours']:.2f} " + f"invalid_removed={summary['invalid_duration_hours']:.2f}" + ), + f"{indent} split_hours: {split_hours}", + ( + f"{indent} chars: total={summary['total_chars']} " + f"unique_known={summary['unique_known_chars']} " + f"unique_known_rate={summary['unique_known_char_rate']:.2%} " + f"unique_unknown={summary['unique_unknown_chars']} " + f"unique_unknown_rate={summary['unique_unknown_char_rate']:.2%}" + ), + f"{indent} unknown_chars: {_format_unknown_char_details(summary.get('unknown_char_details', {}), top_n_unknown_chars)}", + f"{indent} alpha_minus_known_chars: {summary['alpha_minus_known_chars']}", + f"{indent} split_counts: {summary['split_counts']}", + ] + + +def _unknown_char_details(unknown_chars: Counter[str], total_chars: int) -> dict[str, dict[str, float | int]]: + return { + char: {"count": count, "rate": count / total_chars if total_chars else 0.0} + for char, count in sorted(unknown_chars.items(), key=lambda item: (-item[1], item[0])) + } + + +def _format_unknown_char_details(details: dict[str, dict[str, float | int]], top_n: int) -> str: + if not details: + return "{}" + items = list(details.items()) + if top_n > 0: + items = items[:top_n] + formatted = ", ".join( + f"{char}=count={int(stats['count'])} rate={float(stats['rate']):.2%}" for char, stats in items + ) + if 0 < top_n < len(details): + formatted = f"{formatted} (showing top {top_n} of {len(details)})" + return formatted + + +def _update_bucket(bucket: _StatsBucket, update: _BucketUpdate) -> None: + bucket.total_transcripts += 1 + bucket.dropped_invalid += update.dropped + bucket.total_duration_seconds += update.duration + bucket.total_chars += len(update.text) + bucket.known_chars.update(update.known_chars) + bucket.unknown_chars.update(update.unknown_chars) + bucket.unknown_char_count += sum(update.unknown_chars.values()) + + split_counts = bucket.split_counts[update.split] + split_durations = bucket.split_duration_seconds[update.split] + split_counts["total"] += 1 + split_durations["total"] += update.duration + if update.transcript_error: + bucket.invalid_transcripts += 1 + bucket.invalid_duration_seconds += update.duration + split_counts["invalid"] += 1 + split_durations["invalid"] += update.duration + else: + bucket.valid_transcripts += 1 + bucket.valid_duration_seconds += update.duration + split_counts["valid"] += 1 + split_durations["valid"] += update.duration + + +def _bucket_summary(bucket: _StatsBucket, lang: str, code_switch_langs: str | list[str] | None) -> dict[str, Any]: + valid_rate = bucket.valid_transcripts / bucket.total_transcripts if bucket.total_transcripts else 0.0 + invalid_rate = bucket.invalid_transcripts / bucket.total_transcripts if bucket.total_transcripts else 0.0 + unique_known_chars = len(bucket.known_chars) + unique_unknown_chars = len(bucket.unknown_chars) + unique_known_char_rate = unique_known_chars / bucket.total_chars if bucket.total_chars else 0.0 + unique_unknown_char_rate = unique_unknown_chars / bucket.total_chars if bucket.total_chars else 0.0 + split_hours = { + split: {key: seconds / 3600 for key, seconds in durations.items()} + for split, durations in bucket.split_duration_seconds.items() + } + return { + "total_transcripts": bucket.total_transcripts, + "valid_transcripts": bucket.valid_transcripts, + "invalid_transcripts": bucket.invalid_transcripts, + "dropped_invalid": bucket.dropped_invalid, + "emitted_transcripts": bucket.total_transcripts - bucket.dropped_invalid, + "valid_transcript_rate": valid_rate, + "invalid_transcript_rate": invalid_rate, + "total_duration_hours": bucket.total_duration_seconds / 3600, + "valid_duration_hours": bucket.valid_duration_seconds / 3600, + "invalid_duration_hours": bucket.invalid_duration_seconds / 3600, + "total_chars": bucket.total_chars, + "unique_known_chars": unique_known_chars, + "unique_known_char_rate": unique_known_char_rate, + "unique_unknown_chars": unique_unknown_chars, + "unique_unknown_char_rate": unique_unknown_char_rate, + "unknown_char_details": _unknown_char_details(bucket.unknown_chars, bucket.total_chars), + "alpha_minus_known_chars": sorted(_load_combined_alphabet(lang, code_switch_langs) - set(bucket.known_chars)), + "split_counts": {split: dict(counts) for split, counts in bucket.split_counts.items()}, + "split_hours": split_hours, + } + + +def _load_combined_alphabet(lang: str, code_switch_langs: str | list[str] | None) -> set[str]: + alphabet = set() + for resource_lang in dict.fromkeys( + [resolve_lang(lang), *(resolve_lang(item) for item in _coerce_lang_list(code_switch_langs))] + ): + alphabet.update(_load_alphabet(_RESOURCE_ROOT / resource_lang / "alphabet.txt")) + return alphabet + + +def _round_nested_floats(value: Any, ndigits: int = 2) -> Any: # noqa: ANN401 + if isinstance(value, dict): + return {key: _round_nested_floats(nested_value, ndigits) for key, nested_value in value.items()} + if isinstance(value, float): + return round(value, ndigits) + return value diff --git a/nemo_curator/stages/audio/asr/normalization/transcript.py b/nemo_curator/stages/audio/asr/normalization/transcript.py new file mode 100644 index 0000000000..45699685fd --- /dev/null +++ b/nemo_curator/stages/audio/asr/normalization/transcript.py @@ -0,0 +1,252 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ASR transcript normalization for flat audio manifests.""" + +from __future__ import annotations + +import json +import re +import time +import unicodedata +from collections import Counter +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING, Any + +from loguru import logger + +from nemo_curator.stages.base import ProcessingStage +from nemo_curator.stages.resources import Resources +from nemo_curator.tasks import AudioTask + +if TYPE_CHECKING: + from collections.abc import Iterable + +_RESOURCE_ROOT = Path(__file__).parent / "langs" +_PUNCTUATION_CHARS_BY_LANG = { + "bn": ",?\u0964", + "en": ".,?", + "gu": ".,?", + "hi": ".,?", + "kn": ".,?", + "ml": ".,?!\u0964\u0965", + "mr": ".,?", + "pa": ".,?\u0964", + "ta": ".,?", + "te": ".,?", + "ur": "\u060c\u061f\u06d4", +} + + +@dataclass(frozen=True) +class NormalizationResult: + text: str + unknown_chars: dict[str, int] + + +class ResourceTranscriptNormalizer: + """Resource-driven normalizer for one ASR transcript language. + + Code-switch languages extend the base language's alphabet, pretok rules, + and punctuation set so mixed-language transcripts can be validated and + cleaned by a single normalizer. + """ + + def __init__( + self, + lang: str, + *, + remove_pnc_chars: bool = True, + lowercase_text: bool = False, + code_switch_langs: str | Iterable[str] | None = None, + ) -> None: + self.lang = resolve_lang(lang) + self.lowercase_text = lowercase_text + resource_langs = _ordered_unique( + [self.lang, *(resolve_lang(lang) for lang in _coerce_lang_list(code_switch_langs))] + ) + self.alphabet = set() + self.pretok_rules = [] + pnc_chars = "" + for resource_lang in resource_langs: + lang_dir = _RESOURCE_ROOT / resource_lang + self.alphabet.update(_load_alphabet(lang_dir / "alphabet.txt")) + self.pretok_rules.extend(_load_jsonl(lang_dir / "pretok.jsonl")) + pnc_chars += _PUNCTUATION_CHARS_BY_LANG[resource_lang] + remove_chars = _load_chars(_RESOURCE_ROOT / "remove_chars.txt") + pnc_chars = _ordered_unique_chars(pnc_chars) + if remove_pnc_chars: + self.remove_chars = remove_chars + pnc_chars + else: + non_pnc_remove_chars = "".join(char for char in remove_chars if char not in pnc_chars) + self.remove_chars = non_pnc_remove_chars + + def normalize(self, text: str) -> NormalizationResult: + normalized = unicodedata.normalize("NFKC", text) + # Language resources own punctuation normalization; rules may include + # characters such as \u2014 (em dash) and \u2013 (en dash). + for rule in self.pretok_rules: + pattern = str(rule["pattern"]) + repl = str(rule.get("repl", "")) + normalized = re.sub(pattern, repl, normalized) + if rule.get("repeat"): + while re.search(pattern, normalized): + normalized = re.sub(pattern, repl, normalized) + if self.remove_chars: + normalized = re.sub("[" + re.escape(self.remove_chars) + "]", " ", normalized) + normalized = " ".join(normalized.split()) + if self.lowercase_text: + normalized = normalized.lower() + unknown_chars = Counter(char for char in normalized if char not in self.alphabet and not char.isspace()) + return NormalizationResult(text=normalized, unknown_chars=dict(unknown_chars)) + + +@dataclass +class TranscriptNormalizationStage(ProcessingStage[AudioTask, AudioTask]): + """Normalize ASR transcript text and record unknown characters. + + Args: + lowercase_text: If True, lowercase the normalized transcript before + unknown-character detection and output assignment. + code_switch_langs: Extra language resource folders whose alphabet, + pretok rules, and punctuation characters should be combined with + each task's primary ``lang``. + """ + + name: str = "transcript_normalization" + text_key: str = "text" + lang_key: str = "lang" + output_text_key: str = "text" + output_original_text_key: str = "text_original" + unknown_chars_key: str = "unknown_chars" + transcript_error_key: str = "transcript_error" + duration_key: str = "duration" + remove_pnc_chars: bool = True + lowercase_text: bool = False + code_switch_langs: str | list[str] | None = field(default_factory=list) + resources: Resources = field(default_factory=lambda: Resources(cpus=1.0)) + + def __post_init__(self) -> None: + self._normalizers: dict[str, ResourceTranscriptNormalizer] = {} + self._unknown_char_counts: Counter[str] = Counter() + + def inputs(self) -> tuple[list[str], list[str]]: + return [], [self.text_key, self.lang_key] + + def outputs(self) -> tuple[list[str], list[str]]: + return [], [ + self.output_text_key, + self.output_original_text_key, + self.unknown_chars_key, + self.transcript_error_key, + ] + + def process(self, task: AudioTask) -> AudioTask: + start = time.perf_counter() + original_text = str(task.data[self.text_key]) + lang = resolve_lang(str(task.data[self.lang_key])) + result = self._normalizer(lang).normalize(original_text) + unknown_chars = result.unknown_chars + transcript_error = bool(unknown_chars) + duration = float(task.data.get(self.duration_key, 0.0) or 0.0) + + self._unknown_char_counts.update(unknown_chars) + metrics: dict[str, float | int] = { + "input_tasks": 1, + "emitted_tasks": 1, + "unknown_duration_seconds": duration if transcript_error else 0.0, + "process_time": time.perf_counter() - start, + } + self._log_metrics(metrics) + if unknown_chars: + logger.info( + f"[{self.name}] unknown chars for lang={lang}: " + f"{dict(sorted(unknown_chars.items(), key=lambda item: item[1], reverse=True))}; " + f"running_total={dict(self._unknown_char_counts.most_common())}" + ) + + task.data[self.output_original_text_key] = original_text + task.data[self.output_text_key] = result.text + task.data[self.unknown_chars_key] = unknown_chars + task.data[self.transcript_error_key] = transcript_error + return task + + def _normalizer(self, lang: str) -> ResourceTranscriptNormalizer: + if lang not in self._normalizers: + self._normalizers[lang] = ResourceTranscriptNormalizer( + lang, + remove_pnc_chars=self.remove_pnc_chars, + lowercase_text=self.lowercase_text, + code_switch_langs=self.code_switch_langs, + ) + return self._normalizers[lang] + + +def resolve_lang(lang: str) -> str: + normalized = lang.strip().lower() + if (_RESOURCE_ROOT / normalized).is_dir(): + return normalized + else: + msg = f"Unsupported ASR normalization language: {lang!r}" + raise ValueError(msg) + + +def _load_alphabet(path: Path) -> set[str]: + chars = _load_chars(path) + alphabet = set(chars) + alphabet.update(char.upper() for char in list(alphabet)) + return alphabet + + +def _ordered_unique(values: Iterable[str]) -> list[str]: + return list(dict.fromkeys(values)) + + +def _coerce_lang_list(value: str | Iterable[str] | None) -> list[str]: + if value is None: + return [] + if isinstance(value, str): + return [value] + return list(value) + + +def _ordered_unique_chars(value: str) -> str: + return "".join(dict.fromkeys(value)) + + +def _load_chars(path: Path) -> str: + if not path.exists(): + msg = f"Missing ASR normalization resource: {path}" + raise FileNotFoundError(msg) + chars = [] + with path.open(encoding="utf-8") as f: + for raw_line in f: + line = raw_line.rstrip("\n") + if not line or line.startswith("#"): + continue + chars.append(line) + return "".join(chars) + + +def _load_jsonl(path: Path) -> list[dict[str, Any]]: + if not path.exists(): + msg = f"Missing ASR normalization resource: {path}" + raise FileNotFoundError(msg) + rules = [] + with path.open(encoding="utf-8") as f: + for line in f: + if line.strip(): + rules.append(json.loads(line)) + return rules diff --git a/nemo_curator/stages/audio/common.py b/nemo_curator/stages/audio/common.py index 6c7c8eb837..7dbdc88685 100644 --- a/nemo_curator/stages/audio/common.py +++ b/nemo_curator/stages/audio/common.py @@ -15,6 +15,7 @@ import json import os import time +from collections.abc import Sequence from dataclasses import dataclass, field from operator import eq, ge, gt, le, lt, ne from typing import Any @@ -208,9 +209,10 @@ def __post_init__(self) -> None: raise ValueError(msg) def decompose(self) -> list[ProcessingStage]: + manifest_path = _coerce_manifest_path(self.manifest_path) return [ FilePartitioningStage( - file_paths=self.manifest_path, + file_paths=manifest_path, files_per_partition=self.files_per_partition, blocksize=self.blocksize, file_extensions=self.file_extensions, @@ -228,6 +230,15 @@ def get_description(self) -> str: return ", ".join(parts) +def _coerce_manifest_path(manifest_path: Any) -> str | list[str]: # noqa: ANN401 + """Convert Hydra/OmegaConf sequences to plain Python paths for partitioning.""" + if isinstance(manifest_path, str): + return manifest_path + if isinstance(manifest_path, Sequence): + return [str(path) for path in manifest_path] + return manifest_path + + @dataclass class ManifestWriterStage(ProcessingStage[AudioTask, AudioTask]): """Append a single AudioTask to a JSONL manifest file. diff --git a/tests/stages/audio/asr/__init__.py b/tests/stages/audio/asr/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/stages/audio/asr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/stages/audio/asr/datasets/__init__.py b/tests/stages/audio/asr/datasets/__init__.py new file mode 100644 index 0000000000..4fc25d0d3c --- /dev/null +++ b/tests/stages/audio/asr/datasets/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/stages/audio/asr/datasets/conftest.py b/tests/stages/audio/asr/datasets/conftest.py new file mode 100644 index 0000000000..9878975112 --- /dev/null +++ b/tests/stages/audio/asr/datasets/conftest.py @@ -0,0 +1,61 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import numpy as np +import pytest +import soundfile as sf +from datasets import Dataset + +NUM_REALISTIC_ROWS = 100 +INPUT_SAMPLE_RATE = 8000 +OUTPUT_SAMPLE_RATE = 16000 + + +@pytest.fixture +def indicvoices_raw_dataset(tmp_path: Path) -> tuple[Path, int]: + """Create a tiny on-disk HF dataset shaped like the raw IndicVoices split.""" + raw_root = tmp_path / "raw" + audio_dir = tmp_path / "source_audio" + valid_dir = raw_root / "valid" + audio_dir.mkdir(parents=True) + + filepaths = [] + for i in range(NUM_REALISTIC_ROWS): + samples = np.linspace(0.0, 1.0, INPUT_SAMPLE_RATE // 20, endpoint=False, dtype=np.float32) + tone = 0.1 * np.sin(2 * np.pi * (220 + i) * samples) + stereo = np.stack([tone, tone * 0.5], axis=1) + path = audio_dir / f"sample_{i}.wav" + sf.write(path, stereo, INPUT_SAMPLE_RATE, subtype="PCM_16") + filepaths.append(str(path)) + + dataset = Dataset.from_dict( + { + "audio_filepath": filepaths, + "text": ["ગુજરાતી વાક્ય" for _ in range(NUM_REALISTIC_ROWS)], + "duration": [0.05] * NUM_REALISTIC_ROWS, + "lang": ["gu"] * NUM_REALISTIC_ROWS, + "speaker_id": [f"speaker_{i % 3}" for i in range(NUM_REALISTIC_ROWS)], + "gender": ["Female" if i % 2 else "Male" for i in range(NUM_REALISTIC_ROWS)], + "age_group": ["30-45"] * NUM_REALISTIC_ROWS, + "scenario": ["Extempore"] * NUM_REALISTIC_ROWS, + "task_name": ["Unit Test"] * NUM_REALISTIC_ROWS, + "state": ["Gujarat"] * NUM_REALISTIC_ROWS, + "district": ["Ahmedabad"] * NUM_REALISTIC_ROWS, + "normalized": ["ગુજરાતી વાક્ય" for _ in range(NUM_REALISTIC_ROWS)], + } + ) + dataset.save_to_disk(str(valid_dir)) + return raw_root, NUM_REALISTIC_ROWS diff --git a/tests/stages/audio/asr/datasets/e2e/__init__.py b/tests/stages/audio/asr/datasets/e2e/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/stages/audio/asr/datasets/e2e/configs/indicvoices_pipeline.yaml b/tests/stages/audio/asr/datasets/e2e/configs/indicvoices_pipeline.yaml new file mode 100644 index 0000000000..453e3cb1f3 --- /dev/null +++ b/tests/stages/audio/asr/datasets/e2e/configs/indicvoices_pipeline.yaml @@ -0,0 +1,54 @@ +# IndicVoices ASR extraction pipeline - e2e test configuration + +raw_data_dir: ??? +output_dir: ??? +langs: + - gu +native_splits: + - valid +split_dir_pattern: "{split}" +extraction_workers: 4 +dev_fraction: 0.6 +write_manifest: false +output_filename_pattern: "{split}_normalized.jsonl" +remove_pnc_chars: true +stats_summary_path: ${output_dir}/transcript_stats_summary.json + +stages: + - _target_: nemo_curator.stages.audio.asr.datasets.huggingface.HuggingFaceASRDatasetHandler + raw_data_dir: ${raw_data_dir} + output_dir: ${output_dir} + langs: ${langs} + source_name: IndicVoices + native_splits: ${native_splits} + split_dir_pattern: ${split_dir_pattern} + valid_split_strategy: dev_test + extraction_workers: ${extraction_workers} + dev_fraction: ${dev_fraction} + write_manifest: ${write_manifest} + + - _target_: nemo_curator.stages.audio.asr.normalization.TranscriptNormalizationStage + text_key: text + lang_key: lang + output_text_key: text + output_original_text_key: text_original + remove_pnc_chars: ${remove_pnc_chars} + + - _target_: nemo_curator.stages.audio.asr.normalization.TranscriptStatsStage + text_key: text + lang_key: lang + duration_key: duration + split_key: split_type + unknown_chars_key: unknown_chars + transcript_error_key: transcript_error + drop_invalid: true + output_summary_path: ${stats_summary_path} + + - _target_: nemo_curator.stages.audio.asr.io.split_manifest_writer.SplitAwareManifestWriter + output_dir: ${output_dir} + output_filename_pattern: ${output_filename_pattern} + langs: ${langs} + splits: + - train + - dev + - test diff --git a/tests/stages/audio/asr/datasets/e2e/test_indicvoices_e2e.py b/tests/stages/audio/asr/datasets/e2e/test_indicvoices_e2e.py new file mode 100644 index 0000000000..c3b532272f --- /dev/null +++ b/tests/stages/audio/asr/datasets/e2e/test_indicvoices_e2e.py @@ -0,0 +1,114 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""End-to-end test for IndicVoices ASR dataset ingestion. + +Runs the YAML-configured pipeline: + HuggingFaceASRDatasetHandler -> SplitAwareManifestWriter +""" + +import json +from collections import Counter +from pathlib import Path + +import soundfile as sf +from omegaconf import OmegaConf + +from nemo_curator.backends.xenna import XennaExecutor +from nemo_curator.config.run import create_pipeline_from_yaml +from nemo_curator.stages.audio.asr.datasets.huggingface import HuggingFaceASRDatasetHandler +from tests.stages.audio.asr.datasets.conftest import INPUT_SAMPLE_RATE, OUTPUT_SAMPLE_RATE + +CONFIGS_DIR = Path(__file__).parent / "configs" + + +def _assert_stats_summary(output_dir: Path, total_rows: int) -> None: + summary_path = output_dir / "transcript_stats_summary.json" + assert summary_path.exists() + with summary_path.open(encoding="utf-8") as f: + stats_summary = json.load(f) + assert stats_summary["total_transcripts"] == total_rows + assert stats_summary["valid_transcripts"] == total_rows + assert stats_summary["invalid_transcripts"] == 0 + + +def test_indicvoices_pipeline_e2e( + tmp_path: Path, + indicvoices_raw_dataset: tuple[Path, int], +) -> None: + raw_root, total_rows = indicvoices_raw_dataset + output_dir = tmp_path / "out" + cfg = OmegaConf.load(CONFIGS_DIR / "indicvoices_pipeline.yaml") + cfg.raw_data_dir = str(raw_root) + cfg.output_dir = str(output_dir) + + pipeline = create_pipeline_from_yaml(cfg) + tasks = pipeline.run(XennaExecutor(config={"execution_mode": "batch"})) + + assert len(tasks) == total_rows + split_helper = HuggingFaceASRDatasetHandler( + raw_data_dir=str(raw_root), + output_dir=str(output_dir), + langs=["gu"], + source_name="IndicVoices", + valid_split_strategy="dev_test", + ) + expected_counts = Counter(split_helper.assign_split("valid", f"gu_valid_{i}") for i in range(total_rows)) + actual_counts = Counter(task.data["split_type"] for task in tasks) + assert actual_counts == expected_counts + assert actual_counts["dev"] + actual_counts["test"] == total_rows + assert 0.55 <= actual_counts["dev"] / total_rows <= 0.65 + + for task in tasks: + audio_path = Path(task.data["audio_filepath"]) + assert audio_path.parent == output_dir / "gu" / task.data["split_type"] / "audio" + wav_info = sf.info(audio_path) + assert wav_info.samplerate == OUTPUT_SAMPLE_RATE + assert wav_info.channels == 1 + assert wav_info.subtype == "PCM_16" + assert task.data["orig_sample_rate"] == INPUT_SAMPLE_RATE + assert task.data["orig_num_channels"] == 2 + assert task.data["lang"] == "gu" + assert task.data["source"] == "IndicVoices" + assert task.data["speaker_id"].startswith("speaker_") + assert task.data["text_original"] == task.data["text"] + assert task.data["unknown_chars"] == {} + assert task.data["transcript_error"] is False + + stats_metrics = [ + perf.custom_metrics for task in tasks for perf in task._stage_perf if perf.stage_name == "transcript_stats" + ] + assert len(stats_metrics) == total_rows + assert max(metric["input_tasks"] for metric in stats_metrics) == total_rows + assert max(metric["emitted_tasks"] for metric in stats_metrics) == total_rows + assert max(metric["valid_transcripts"] for metric in stats_metrics) == total_rows + assert max(metric["invalid_transcripts"] for metric in stats_metrics) == 0 + assert max(metric["unique_unknown_chars"] for metric in stats_metrics) == 0 + assert max(metric["unique_unknown_char_rate"] for metric in stats_metrics) == 0 + + _assert_stats_summary(output_dir, total_rows) + + manifest_counts = {} + for split in ["train", "dev", "test"]: + path = output_dir / "gu" / f"{split}_normalized.jsonl" + assert path.exists() + with path.open(encoding="utf-8") as f: + rows = [json.loads(line) for line in f if line.strip()] + manifest_counts[split] = len(rows) + assert all(row["split_type"] == split for row in rows) + assert all(row["text_original"] == row["text"] for row in rows) + assert all(row["unknown_chars"] == {} for row in rows) + assert all(row["transcript_error"] is False for row in rows) + + assert manifest_counts == {"train": 0, "dev": actual_counts["dev"], "test": actual_counts["test"]} diff --git a/tests/stages/audio/asr/datasets/test_base.py b/tests/stages/audio/asr/datasets/test_base.py new file mode 100644 index 0000000000..852af5c7a8 --- /dev/null +++ b/tests/stages/audio/asr/datasets/test_base.py @@ -0,0 +1,60 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from dataclasses import dataclass, field +from pathlib import Path + +from nemo_curator.stages.audio.asr.datasets.base import BaseASRDatasetHandlerStage +from nemo_curator.stages.audio.asr.metadata import ASRMetadata +from nemo_curator.tasks import AudioTask, _EmptyTask + + +@dataclass +class _DummyASRDatasetHandler(BaseASRDatasetHandlerStage): + source_name: str = "dummy" + manifest_splits: list[str] | None = field(default_factory=lambda: ["dev"]) + + def process(self, _: _EmptyTask) -> list[AudioTask]: + return [] + + +def test_base_asr_dataset_handler_writes_split_manifests(tmp_path: Path) -> None: + stage = _DummyASRDatasetHandler( + raw_data_dir=str(tmp_path / "raw"), + output_dir=str(tmp_path / "out"), + langs=["gu"], + write_manifest=True, + ) + stage.setup_on_node() + meta = ASRMetadata( + audio_filepath=str(tmp_path / "out" / "gu" / "dev" / "audio" / "sample.wav"), + text="ગુજરાતી", + duration=1.0, + lang="gu", + split_type="dev", + source="dummy", + sample_rate=16000, + num_channels=1, + orig_sample_rate=16000, + orig_num_channels=1, + ) + + stage.write_manifest_entry(meta) + stage.teardown() + + manifest_path = tmp_path / "out" / "gu" / "dev.jsonl" + with manifest_path.open(encoding="utf-8") as f: + rows = [json.loads(line) for line in f if line.strip()] + assert rows == [meta.to_dict()] diff --git a/tests/stages/audio/asr/datasets/test_huggingface_handler.py b/tests/stages/audio/asr/datasets/test_huggingface_handler.py new file mode 100644 index 0000000000..b950cedb94 --- /dev/null +++ b/tests/stages/audio/asr/datasets/test_huggingface_handler.py @@ -0,0 +1,197 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path + +import numpy as np +import pytest +import soundfile as sf +from datasets import Dataset + +from nemo_curator.stages.audio.asr.datasets import HuggingFaceASRDatasetHandler +from nemo_curator.tasks import _EmptyTask +from tests.stages.audio.asr.datasets.conftest import INPUT_SAMPLE_RATE, OUTPUT_SAMPLE_RATE + + +def _write_audio(path: Path) -> None: + samples = np.linspace(0.0, 1.0, INPUT_SAMPLE_RATE // 20, endpoint=False, dtype=np.float32) + tone = 0.1 * np.sin(2 * np.pi * 220 * samples) + stereo = np.stack([tone, tone * 0.5], axis=1) + sf.write(path, stereo, INPUT_SAMPLE_RATE, subtype="PCM_16") + + +def _save_hf_split(raw_root: Path, lang: str, split: str, rows: dict[str, list[object]]) -> Path: + split_dir = raw_root / lang / split + dataset = Dataset.from_dict(rows) + dataset.save_to_disk(str(split_dir)) + return split_dir + + +def test_huggingface_handler_writes_manifest_without_returning_tasks(tmp_path: Path) -> None: + raw_root = tmp_path / "raw" + audio_dir = tmp_path / "source_audio" + audio_dir.mkdir(parents=True) + audio_paths = [] + for i in range(3): + audio_path = audio_dir / f"sample_{i}.wav" + _write_audio(audio_path) + audio_paths.append(str(audio_path)) + _save_hf_split( + raw_root, + "gu", + "train", + { + "audio_filepath": audio_paths, + "text": ["ગુજરાતી વાક્ય"] * len(audio_paths), + "speaker_id": [f"speaker_{i}" for i in range(len(audio_paths))], + "gender": ["Female"] * len(audio_paths), + "fname": [f"utt_{i}.wav" for i in range(len(audio_paths))], + }, + ) + output_dir = tmp_path / "out" + stage = HuggingFaceASRDatasetHandler( + raw_data_dir=str(raw_root), + output_dir=str(output_dir), + langs=["gu"], + source_name="Kathbath", + native_splits=["train"], + split_dir_pattern="{lang}/{split}", + extraction_workers=2, + write_manifest=True, + manifest_splits=["train"], + ) + stage.setup_on_node() + stage.setup() + + tasks = stage.process(_EmptyTask(dataset_name="test", data=None)) + metrics = stage._consume_custom_metrics() + stage.teardown() + + assert tasks == [] + assert metrics["input_rows"] == 3 + assert metrics["emitted_tasks"] == 0 + + with (output_dir / "gu" / "train.jsonl").open(encoding="utf-8") as f: + rows = [json.loads(line) for line in f if line.strip()] + assert len(rows) == 3 + for i, row in enumerate(rows): + assert row["audio_filepath"] == str(output_dir / "gu" / "train" / "audio" / f"utt_{i}.wav") + audio_info = sf.info(row["audio_filepath"]) + assert audio_info.samplerate == OUTPUT_SAMPLE_RATE + assert audio_info.channels == 1 + assert audio_info.subtype == "PCM_16" + assert row["text"] == "ગુજરાતી વાક્ય" + assert row["duration"] == pytest.approx(audio_info.frames / OUTPUT_SAMPLE_RATE) + assert row["lang"] == "gu" + assert row["split_type"] == "train" + assert row["source"] == "Kathbath" + assert row["orig_sample_rate"] == INPUT_SAMPLE_RATE + assert row["orig_num_channels"] == 2 + assert row["speaker_id"] == f"speaker_{i}" + assert row["gender"] == "Female" + assert row["fname"] == f"utt_{i}.wav" + + +def test_huggingface_handler_reports_skipped_rows_and_maps_splits(tmp_path: Path) -> None: + raw_root = tmp_path / "raw" + audio_dir = tmp_path / "source_audio" + valid_audio = audio_dir / "valid.wav" + missing_text_audio = audio_dir / "missing_text.wav" + audio_dir.mkdir(parents=True) + _write_audio(valid_audio) + _write_audio(missing_text_audio) + _save_hf_split( + raw_root, + "gu", + "valid", + { + "audio_filepath": [ + str(valid_audio), + str(missing_text_audio), + None, + str(audio_dir / "does_not_exist.wav"), + ], + "text": ["valid text", None, "missing audio", "bad audio"], + }, + ) + stage = HuggingFaceASRDatasetHandler( + raw_data_dir=str(raw_root), + output_dir=str(tmp_path / "out"), + langs=["gu"], + source_name="Shrutilipi", + native_splits=["valid"], + split_mapping={"valid": "dev"}, + extraction_workers=1, + ) + stage.setup() + + tasks = stage.process(_EmptyTask(dataset_name="test", data=None)) + metrics = stage._consume_custom_metrics() + + assert len(tasks) == 1 + assert tasks[0].data["split_type"] == "dev" + assert metrics["input_rows"] == 4 + assert metrics["emitted_tasks"] == 1 + assert metrics["skipped_missing_text"] == 1 + assert metrics["skipped_missing_audio"] == 1 + assert metrics["skipped_audio_load"] == 1 + + +def test_huggingface_handler_accepts_dataset_source_names(tmp_path: Path) -> None: + kwargs = { + "raw_data_dir": str(tmp_path / "raw"), + "output_dir": str(tmp_path / "out"), + "langs": ["gu"], + "native_splits": ["train", "valid"], + "split_dir_pattern": "{lang}/{split}", + } + + kathbath = HuggingFaceASRDatasetHandler(source_name="Kathbath", **kwargs) + shrutilipi = HuggingFaceASRDatasetHandler(source_name="Shrutilipi", **kwargs) + + assert kathbath.source_name == "Kathbath" + assert kathbath.native_splits == ["train", "valid"] + assert kathbath.split_dir_pattern == "{lang}/{split}" + assert shrutilipi.source_name == "Shrutilipi" + assert shrutilipi.native_splits == ["train", "valid"] + assert shrutilipi.split_dir_pattern == "{lang}/{split}" + + +def test_huggingface_handler_rejects_unsupported_source_name(tmp_path: Path) -> None: + with pytest.raises(ValueError, match="Unsupported source_name"): + HuggingFaceASRDatasetHandler( + raw_data_dir=str(tmp_path / "raw"), + output_dir=str(tmp_path / "out"), + langs=["gu"], + source_name="UnitHF", + ) + + +def test_huggingface_handler_uses_known_source_field_mappings(tmp_path: Path) -> None: + kwargs = { + "raw_data_dir": str(tmp_path / "raw"), + "output_dir": str(tmp_path / "out"), + "langs": ["gu"], + } + + kathbath = HuggingFaceASRDatasetHandler(source_name="Kathbath", **kwargs) + shrutilipi = HuggingFaceASRDatasetHandler(source_name="Shrutilipi", **kwargs) + + assert kathbath._source_field_mapping() == { + "fname": "fname", + "gender": "gender", + "speaker_id": "speaker_id", + } + assert shrutilipi._source_field_mapping() == {} diff --git a/tests/stages/audio/asr/datasets/test_indicvoices_handler.py b/tests/stages/audio/asr/datasets/test_indicvoices_handler.py new file mode 100644 index 0000000000..fbf016c106 --- /dev/null +++ b/tests/stages/audio/asr/datasets/test_indicvoices_handler.py @@ -0,0 +1,180 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from collections import Counter +from pathlib import Path + +import numpy as np +import pytest +import soundfile as sf +from datasets import Dataset + +from nemo_curator.stages.audio.asr.datasets import huggingface as huggingface_module +from nemo_curator.stages.audio.asr.datasets.huggingface import HuggingFaceASRDatasetHandler +from nemo_curator.tasks import _EmptyTask +from tests.stages.audio.asr.datasets.conftest import INPUT_SAMPLE_RATE, OUTPUT_SAMPLE_RATE + + +def _indicvoices_stage(**kwargs: object) -> HuggingFaceASRDatasetHandler: + return HuggingFaceASRDatasetHandler( + source_name="IndicVoices", + valid_split_strategy="dev_test", + **kwargs, + ) + + +def test_indicvoices_handler_ingests_realistic_hf_dataset( + tmp_path: Path, + indicvoices_raw_dataset: tuple[Path, int], +) -> None: + raw_root, total_rows = indicvoices_raw_dataset + output_dir = tmp_path / "out" + stage = _indicvoices_stage( + raw_data_dir=str(raw_root), + output_dir=str(output_dir), + langs=["gu"], + native_splits=["valid"], + split_dir_pattern="{split}", + extraction_workers=2, + ) + stage.setup() + tasks = stage.process(_EmptyTask(dataset_name="test", data=None)) + + assert len(tasks) == total_rows + split_helper = _indicvoices_stage(raw_data_dir=str(raw_root), output_dir=str(output_dir), langs=["gu"]) + expected_counts = Counter(split_helper.assign_split("valid", f"gu_valid_{i}") for i in range(total_rows)) + actual_counts = Counter(task.data["split_type"] for task in tasks) + assert actual_counts == expected_counts + assert actual_counts["dev"] + actual_counts["test"] == total_rows + assert 0.55 <= actual_counts["dev"] / total_rows <= 0.65 + + for task in tasks: + audio_path = Path(task.data["audio_filepath"]) + assert audio_path.parent == output_dir / "gu" / task.data["split_type"] / "audio" + wav_info = sf.info(audio_path) + assert wav_info.samplerate == OUTPUT_SAMPLE_RATE + assert wav_info.channels == 1 + assert wav_info.subtype == "PCM_16" + assert task.data["orig_sample_rate"] == INPUT_SAMPLE_RATE + assert task.data["orig_num_channels"] == 2 + assert task.data["lang"] == "gu" + assert task.data["source"] == "IndicVoices" + assert task.data["speaker_id"].startswith("speaker_") + assert task.data["gender"] in {"Female", "Male"} + assert task.data["age"] == "30-45" + assert "age_group" not in task.data + assert task.data["scenario"] == "Extempore" + + +def test_indicvoices_handler_reports_skipped_rows_from_realistic_hf_dataset(tmp_path: Path) -> None: + raw_root = tmp_path / "raw" + audio_dir = tmp_path / "source_audio" + valid_dir = raw_root / "valid" + audio_dir.mkdir(parents=True) + + valid_audio = audio_dir / "valid.wav" + missing_text_audio = audio_dir / "missing_text.wav" + sf.write(valid_audio, np.zeros(INPUT_SAMPLE_RATE // 20, dtype=np.float32), INPUT_SAMPLE_RATE) + sf.write(missing_text_audio, np.zeros(INPUT_SAMPLE_RATE // 20, dtype=np.float32), INPUT_SAMPLE_RATE) + + dataset = Dataset.from_dict( + { + "audio_filepath": [ + str(valid_audio), + str(missing_text_audio), + None, + str(audio_dir / "does_not_exist.wav"), + ], + "text": ["valid text", None, "missing audio", "bad audio"], + "speaker_id": ["speaker_0", "speaker_1", "speaker_2", "speaker_3"], + } + ) + dataset.save_to_disk(str(valid_dir)) + + stage = _indicvoices_stage( + raw_data_dir=str(raw_root), + output_dir=str(tmp_path / "out"), + langs=["gu"], + native_splits=["valid"], + split_dir_pattern="{split}", + extraction_workers=1, + ) + stage.setup() + tasks = stage.process(_EmptyTask(dataset_name="test", data=None)) + metrics = stage._consume_custom_metrics() + + assert len(tasks) == 1 + assert metrics["input_rows"] == 4 + assert metrics["emitted_tasks"] == 1 + assert metrics["skipped_missing_text"] == 1 + assert metrics["skipped_missing_audio"] == 1 + assert metrics["skipped_audio_load"] == 1 + + +def test_indicvoices_handler_writes_manifests_when_enabled( + tmp_path: Path, + indicvoices_raw_dataset: tuple[Path, int], + monkeypatch: pytest.MonkeyPatch, +) -> None: + log_messages = [] + monkeypatch.setattr(huggingface_module.logger, "info", log_messages.append) + raw_root, total_rows = indicvoices_raw_dataset + output_dir = tmp_path / "out" + stage = _indicvoices_stage( + raw_data_dir=str(raw_root), + output_dir=str(output_dir), + langs=["gu"], + native_splits=["valid"], + split_dir_pattern="{split}", + extraction_workers=2, + write_manifest=True, + ) + stage.setup_on_node() + stage.setup() + + tasks = stage.process(_EmptyTask(dataset_name="test", data=None)) + metrics = stage._consume_custom_metrics() + stage.teardown() + + assert tasks == [] + rows = [] + for split in ["dev", "test"]: + manifest_path = output_dir / "gu" / f"{split}.jsonl" + assert manifest_path.exists() + with manifest_path.open(encoding="utf-8") as f: + rows.extend(json.loads(line) for line in f if line.strip()) + + assert len(rows) == total_rows + actual_counts = Counter(row["split_type"] for row in rows) + expected_durations = { + "train": 0.0, + "dev": sum(row["duration"] for row in rows if row["split_type"] == "dev"), + "test": sum(row["duration"] for row in rows if row["split_type"] == "test"), + } + assert metrics["emitted_tasks"] == 0 + assert metrics["duration_train_seconds"] == pytest.approx(expected_durations["train"]) + assert metrics["duration_dev_seconds"] == pytest.approx(expected_durations["dev"]) + assert metrics["duration_test_seconds"] == pytest.approx(expected_durations["test"]) + assert metrics["duration_train_hours"] == pytest.approx(expected_durations["train"] / 3600) + assert metrics["duration_dev_hours"] == pytest.approx(expected_durations["dev"] / 3600) + assert metrics["duration_test_hours"] == pytest.approx(expected_durations["test"] / 3600) + assert any("duration_by_split_hours=(train=0.00h, dev=0.00h, test=0.00h)" in msg for msg in log_messages) + assert (output_dir / "gu" / "dev" / "audio").is_dir() + assert (output_dir / "gu" / "test" / "audio").is_dir() + + for split in ["dev", "test"]: + split_rows = [row for row in rows if row["split_type"] == split] + assert len(split_rows) == actual_counts[split] + assert all(Path(row["audio_filepath"]).parent == output_dir / "gu" / split / "audio" for row in split_rows) diff --git a/tests/stages/audio/asr/io/__init__.py b/tests/stages/audio/asr/io/__init__.py new file mode 100644 index 0000000000..3e4afde2e2 --- /dev/null +++ b/tests/stages/audio/asr/io/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. diff --git a/tests/stages/audio/asr/io/test_split_manifest_writer.py b/tests/stages/audio/asr/io/test_split_manifest_writer.py new file mode 100644 index 0000000000..58567a742f --- /dev/null +++ b/tests/stages/audio/asr/io/test_split_manifest_writer.py @@ -0,0 +1,55 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path + +from nemo_curator.stages.audio.asr.io.split_manifest_writer import SplitAwareManifestWriter +from nemo_curator.tasks import AudioTask + + +def test_split_manifest_writer_defaults_to_split_jsonl(tmp_path: Path) -> None: + stage = SplitAwareManifestWriter(output_dir=str(tmp_path), langs=["gu"], splits=["dev"]) + stage.setup() + + task = AudioTask(data={"lang": "gu", "split_type": "dev", "text": "ગુજરાતી"}) + result = stage.process(task) + stage.teardown() + + assert result is task + manifest_path = tmp_path / "gu" / "dev.jsonl" + assert manifest_path.exists() + with manifest_path.open(encoding="utf-8") as f: + rows = [json.loads(line) for line in f if line.strip()] + assert rows == [task.data] + + +def test_split_manifest_writer_uses_output_filename_pattern(tmp_path: Path) -> None: + stage = SplitAwareManifestWriter( + output_dir=str(tmp_path), + langs=["gu"], + splits=["dev", "test"], + output_filename_pattern="{split}_normalized.jsonl", + ) + stage.setup() + + dev_task = AudioTask(data={"lang": "gu", "split_type": "dev", "text": "ગુજરાતી"}) + test_task = AudioTask(data={"lang": "gu", "split_type": "test", "text": "વાક્ય"}) + stage.process(dev_task) + stage.process(test_task) + stage.teardown() + + assert (tmp_path / "gu" / "dev_normalized.jsonl").exists() + assert (tmp_path / "gu" / "test_normalized.jsonl").exists() + assert not (tmp_path / "gu" / "dev.jsonl").exists() diff --git a/tests/stages/audio/asr/io/test_tarred_dataset_writer.py b/tests/stages/audio/asr/io/test_tarred_dataset_writer.py new file mode 100644 index 0000000000..4e66d3737d --- /dev/null +++ b/tests/stages/audio/asr/io/test_tarred_dataset_writer.py @@ -0,0 +1,124 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any + +import pytest + +from nemo_curator.stages.audio.asr.io import TarredAudioDatasetWriterStage as ExportedTarredAudioDatasetWriterStage +from nemo_curator.stages.audio.asr.io import tarred_dataset_writer as writer_module +from nemo_curator.stages.audio.asr.io.tarred_dataset_writer import TarredAudioDatasetWriterStage +from nemo_curator.tasks import _EmptyTask + + +def test_tarred_dataset_writer_is_exported() -> None: + assert ExportedTarredAudioDatasetWriterStage is TarredAudioDatasetWriterStage + + +def test_tarred_dataset_writer_rejects_mismatched_manifest_and_target_dirs() -> None: + with pytest.raises(ValueError, match="same length"): + TarredAudioDatasetWriterStage( + manifest_paths=["train.jsonl", "dev.jsonl"], + target_dirs=["tarred/train"], + num_shards=2, + max_duration=20.0, + ) + + +def test_tarred_dataset_writer_accepts_single_manifest_and_target_dir_strings( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[dict[str, Any]] = [] + + def fake_create_tar_datasets(**kwargs: Any) -> None: # noqa: ANN401 + calls.append(kwargs) + + monkeypatch.setattr(writer_module, "create_tar_datasets", fake_create_tar_datasets) + manifest_path = str(tmp_path / "train.jsonl") + target_dir = str(tmp_path / "tarred_train") + stage = TarredAudioDatasetWriterStage( + manifest_paths=manifest_path, + target_dirs=target_dir, + num_shards=2, + max_duration=20.0, + dry_run=True, + ) + + output = stage.process(_EmptyTask(dataset_name="test", data=None)) + metrics = stage._consume_custom_metrics() + + assert output == [] + assert [call["manifest_path"] for call in calls] == [manifest_path] + assert [call["target_dir"] for call in calls] == [target_dir] + assert metrics["input_manifests"] == 1 + assert metrics["emitted_tasks"] == 0 + + +def test_tarred_dataset_writer_runs_converter_once_per_manifest( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + calls: list[dict[str, Any]] = [] + + def fake_create_tar_datasets(**kwargs: Any) -> None: # noqa: ANN401 + calls.append(kwargs) + + monkeypatch.setattr(writer_module, "create_tar_datasets", fake_create_tar_datasets) + manifest_paths = [str(tmp_path / "train.jsonl"), str(tmp_path / "dev.jsonl")] + target_dirs = [str(tmp_path / "tarred_train"), str(tmp_path / "tarred_dev")] + stage = TarredAudioDatasetWriterStage( + manifest_paths=manifest_paths, + target_dirs=target_dirs, + num_shards=8, + max_duration=25.0, + min_duration=0.1, + shuffle=True, + keep_files_together=True, + sort_in_shards=True, + buckets_num=1, + dynamic_buckets_num=16, + shuffle_seed=123, + no_shard_manifests=True, + force_codec="flac", + workers=4, + slice_with_offset=True, + only_manifests=True, + dry_run=True, + ) + + output = stage.process(_EmptyTask(dataset_name="test", data=None)) + metrics = stage._consume_custom_metrics() + + assert output == [] + assert [call["manifest_path"] for call in calls] == manifest_paths + assert [call["target_dir"] for call in calls] == target_dirs + assert all(call["num_shards"] == 8 for call in calls) + assert all(call["max_duration"] == 25.0 for call in calls) + assert all(call["min_duration"] == 0.1 for call in calls) + assert all(call["shuffle"] is True for call in calls) + assert all(call["keep_files_together"] is True for call in calls) + assert all(call["sort_in_shards"] is True for call in calls) + assert all(call["buckets_num"] == 1 for call in calls) + assert all(call["dynamic_buckets_num"] == 16 for call in calls) + assert all(call["shuffle_seed"] == 123 for call in calls) + assert all(call["no_shard_manifests"] is True for call in calls) + assert all(call["force_codec"] == "flac" for call in calls) + assert all(call["workers"] == 4 for call in calls) + assert all(call["slice_with_offset"] is True for call in calls) + assert all(call["only_manifests"] is True for call in calls) + assert all(call["dry_run"] is True for call in calls) + assert metrics["input_manifests"] == 2 + assert metrics["emitted_tasks"] == 0 diff --git a/tests/stages/audio/asr/normalization/__init__.py b/tests/stages/audio/asr/normalization/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/stages/audio/asr/normalization/test_transcript_normalization.py b/tests/stages/audio/asr/normalization/test_transcript_normalization.py new file mode 100644 index 0000000000..d3a0cd2471 --- /dev/null +++ b/tests/stages/audio/asr/normalization/test_transcript_normalization.py @@ -0,0 +1,197 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import pytest + +from nemo_curator.stages.audio.asr.normalization import TranscriptNormalizationStage +from nemo_curator.stages.audio.asr.normalization.transcript import ( + _PUNCTUATION_CHARS_BY_LANG, + _RESOURCE_ROOT, + _load_alphabet, +) +from nemo_curator.tasks import AudioTask + + +def test_language_resources_are_flattened_without_data_subdirectory() -> None: + assert (_RESOURCE_ROOT / "remove_chars.txt").exists() + for lang in ["gu", "hi", "kn", "te", "ml", "mr", "pa", "ta", "bn", "ur", "en"]: + lang_dir = _RESOURCE_ROOT / lang + assert (lang_dir / "alphabet.txt").exists() + assert (lang_dir / "pretok.jsonl").exists() + assert not (lang_dir / "remove_chars.txt").exists() + assert not (lang_dir / "pnc_chars.txt").exists() + assert not (lang_dir / "data").exists() + assert lang in _PUNCTUATION_CHARS_BY_LANG + + +def test_standard_punctuation_dictionary_preserves_language_specific_chars() -> None: + assert _PUNCTUATION_CHARS_BY_LANG["bn"] == ",?\u0964" + assert _PUNCTUATION_CHARS_BY_LANG["ml"] == ".,?!\u0964\u0965" + assert _PUNCTUATION_CHARS_BY_LANG["pa"] == ".,?\u0964" + assert _PUNCTUATION_CHARS_BY_LANG["ur"] == "\u060c\u061f\u06d4" + + +@pytest.mark.parametrize( + ("lang", "text"), + [ + ("kn", "ಕನ್ನಡ ವಾಕ್ಯ"), + ("te", "తెలుగు వాక్యం"), + ("ml", "മലയാളം വാക്യം"), + ("mr", "मराठी वाक्य"), + ("pa", "ਪੰਜਾਬੀ ਵਾਕ"), + ("ta", "தமிழ் வாக்கியம்"), + ("bn", "বাংলা বাক্য"), + ("ur", "اردو جملہ"), + ], +) +def test_additional_indic_language_resources_are_loadable(lang: str, text: str) -> None: + stage = TranscriptNormalizationStage() + task = AudioTask(data={"text": text, "lang": lang}) + + result = stage.process(task) + + assert result is task + assert task.data["text"] == text + assert task.data["transcript_error"] is False + assert task.data["unknown_chars"] == {} + + +def test_alphabet_loader_includes_uppercase_variants_for_each_letter(tmp_path: Path) -> None: + alphabet_path = tmp_path / "alphabet.txt" + alphabet_path.write_text("a\nb\n", encoding="utf-8") + + assert _load_alphabet(alphabet_path) == {"a", "A", "b", "B"} + + +def test_gujarati_text_is_cleaned_in_place_and_marked_valid() -> None: + stage = TranscriptNormalizationStage() + task = AudioTask(data={"text": " ગુજરાતી—વાક્ય @@@ ", "lang": "gu", "duration": 1.5}) + + result = stage.process(task) + metrics = stage._consume_custom_metrics() + + assert result is task + assert task.data["text"] == "ગુજરાતી વાક્ય" + assert task.data["text_original"] == " ગુજરાતી—વાક્ય @@@ " + assert task.data["transcript_error"] is False + assert task.data["unknown_chars"] == {} + assert metrics["input_tasks"] == 1 + assert metrics["emitted_tasks"] == 1 + + +def test_lowercase_text_can_be_enabled() -> None: + stage = TranscriptNormalizationStage(lowercase_text=True) + task = AudioTask(data={"text": "Hello WORLD", "lang": "en"}) + + result = stage.process(task) + + assert result is task + assert task.data["text"] == "hello world" + assert task.data["transcript_error"] is False + assert task.data["unknown_chars"] == {} + + +def test_code_switch_languages_extend_known_alphabet() -> None: + stage = TranscriptNormalizationStage(code_switch_langs=["en"]) + task = AudioTask(data={"text": "ગુજરાતી Hello", "lang": "gu"}) + + result = stage.process(task) + + assert result is task + assert task.data["text"] == "ગુજરાતી Hello" + assert task.data["transcript_error"] is False + assert task.data["unknown_chars"] == {} + + +def test_code_switch_language_accepts_single_string() -> None: + stage = TranscriptNormalizationStage(code_switch_langs="en") + task = AudioTask(data={"text": "ગુજરાતી Hello", "lang": "gu"}) + + result = stage.process(task) + + assert result is task + assert task.data["transcript_error"] is False + assert task.data["unknown_chars"] == {} + + +def test_code_switch_languages_extend_punctuation_removal() -> None: + stage = TranscriptNormalizationStage(code_switch_langs=["en"]) + task = AudioTask(data={"text": "ગુજરાતી - Hello", "lang": "gu"}) + + result = stage.process(task) + + assert result is task + assert task.data["text"] == "ગુજરાતી Hello" + assert task.data["transcript_error"] is False + assert task.data["unknown_chars"] == {} + + +def test_hindi_pretok_replacements_are_applied() -> None: + stage = TranscriptNormalizationStage() + task = AudioTask(data={"text": "₹ ऎॊय़ऱॆ ळ शब्द", "lang": "hi"}) + + result = stage.process(task) + + assert result is task + assert task.data["text"] == "रुपये ऐोयरे ल शब्द" + assert task.data["transcript_error"] is False + + +def test_language_must_match_resource_directory_name() -> None: + stage = TranscriptNormalizationStage() + task = AudioTask(data={"text": "शब्द", "lang": "hindi"}) + + with pytest.raises(ValueError, match="Unsupported ASR normalization language"): + stage.process(task) + + +def test_unknown_chars_are_recorded_and_task_is_retained() -> None: + stage = TranscriptNormalizationStage() + task = AudioTask(data={"text": "ગુજરાતી xyz x ૧", "lang": "gu", "duration": 2.0}) + + result = stage.process(task) + metrics = stage._consume_custom_metrics() + + assert result is task + assert task.data["transcript_error"] is True + assert task.data["unknown_chars"] == {"x": 2, "y": 1, "z": 1, "૧": 1} + assert metrics["input_tasks"] == 1 + assert metrics["emitted_tasks"] == 1 + assert metrics["unknown_duration_seconds"] == pytest.approx(2.0) + + +def test_punctuation_removal_can_be_disabled() -> None: + stage = TranscriptNormalizationStage(remove_pnc_chars=False) + task = AudioTask(data={"text": "ગુજરાતી . વાક્ય", "lang": "gu", "duration": 2.0}) + + result = stage.process(task) + + assert result is task + assert task.data["text"] == "ગુજરાતી . વાક્ય" + assert task.data["transcript_error"] is True + assert task.data["unknown_chars"] == {".": 1} + + +def test_output_keys_keep_source_text_unchanged() -> None: + stage = TranscriptNormalizationStage(output_text_key="normalized_text", output_original_text_key="raw_text") + task = AudioTask(data={"text": " ગુજરાતી—વાક્ય ", "lang": "gu"}) + + result = stage.process(task) + + assert result is task + assert task.data["text"] == " ગુજરાતી—વાક્ય " + assert task.data["raw_text"] == " ગુજરાતી—વાક્ય " + assert task.data["normalized_text"] == "ગુજરાતી વાક્ય" diff --git a/tests/stages/audio/asr/normalization/test_transcript_stats.py b/tests/stages/audio/asr/normalization/test_transcript_stats.py new file mode 100644 index 0000000000..c1cd2ccb74 --- /dev/null +++ b/tests/stages/audio/asr/normalization/test_transcript_stats.py @@ -0,0 +1,289 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from pathlib import Path + +import pytest + +from nemo_curator.stages.audio.asr.normalization import TranscriptStatsStage +from nemo_curator.tasks import AudioTask + + +def _task( # noqa: PLR0913 + text: str, + duration: float, + split_type: str, + transcript_error: bool, + lang: str = "gu", + source: str = "IndicVoices", + extra: dict | None = None, +) -> AudioTask: + data = { + "text": text, + "lang": lang, + "source": source, + "duration": duration, + "split_type": split_type, + "transcript_error": transcript_error, + "unknown_chars": {}, + } + data.update(extra or {}) + return AudioTask(data=data) + + +def test_transcript_stats_aggregates_valid_invalid_unknowns_and_splits() -> None: + stage = TranscriptStatsStage() + tasks = [ + _task("abc", 1.0, "train", False), + _task("ગુજરાતી", 2.0, "dev", True, extra={"unknown_chars": {"x": 2, "y": 1}}), + _task("શબ્દ", 3.0, "test", False), + ] + + results = [stage.process(task) for task in tasks] + summary = stage.summary() + metrics = stage._consume_custom_metrics() + + assert results == tasks + assert summary["total_transcripts"] == 3 + assert summary["valid_transcripts"] == 2 + assert summary["invalid_transcripts"] == 1 + assert summary["valid_transcript_rate"] == pytest.approx(2 / 3) + assert summary["invalid_transcript_rate"] == pytest.approx(1 / 3) + assert summary["total_duration_hours"] == pytest.approx(6.0 / 3600) + assert summary["valid_duration_hours"] == pytest.approx(4.0 / 3600) + assert summary["invalid_duration_hours"] == pytest.approx(2.0 / 3600) + assert summary["total_chars"] == len("abcગુજરાતીશબ્દ") + assert summary["unique_known_chars"] == len(set("abcગુજરાતીશબ્દ")) + assert summary["unique_known_char_rate"] == pytest.approx(len(set("abcગુજરાતીશબ્દ")) / len("abcગુજરાતીશબ્દ")) + assert summary["unique_unknown_chars"] == 2 + assert summary["unique_unknown_char_rate"] == pytest.approx(2 / len("abcગુજરાતીશબ્દ")) + assert summary["unknown_char_details"] == { + "x": {"count": 2, "rate": pytest.approx(2 / len("abcગુજરાતીશબ્દ"))}, + "y": {"count": 1, "rate": pytest.approx(1 / len("abcગુજરાતીશબ્દ"))}, + } + assert "અ" in summary["alpha_minus_known_chars"] + assert "ગ" not in summary["alpha_minus_known_chars"] + assert summary["split_counts"] == { + "train": {"total": 1, "valid": 1, "invalid": 0}, + "dev": {"total": 1, "valid": 0, "invalid": 1}, + "test": {"total": 1, "valid": 1, "invalid": 0}, + } + assert summary["split_hours"] == { + "train": {"total": pytest.approx(1.0 / 3600), "valid": pytest.approx(1.0 / 3600), "invalid": 0.0}, + "dev": {"total": pytest.approx(2.0 / 3600), "valid": 0.0, "invalid": pytest.approx(2.0 / 3600)}, + "test": {"total": pytest.approx(3.0 / 3600), "valid": pytest.approx(3.0 / 3600), "invalid": 0.0}, + } + assert metrics["input_tasks"] == 3 + assert metrics["emitted_tasks"] == 3 + assert metrics["valid_transcripts"] == 2 + assert metrics["invalid_transcripts"] == 1 + assert metrics["total_duration_hours"] == pytest.approx(6.0 / 3600) + assert metrics["valid_duration_hours"] == pytest.approx(4.0 / 3600) + assert metrics["unique_known_chars"] == len(set("abcગુજરાતીશબ્દ")) + assert metrics["unique_known_char_rate"] == pytest.approx(len(set("abcગુજરાતીશબ્દ")) / len("abcગુજરાતીશબ્દ")) + assert metrics["unique_unknown_chars"] == 2 + assert metrics["unique_unknown_char_rate"] == pytest.approx(2 / len("abcગુજરાતીશબ્દ")) + + +def test_transcript_stats_formats_unknown_char_details() -> None: + stage = TranscriptStatsStage(log_top_n_unknown_chars=1) + text = "ગુજરાતીxxy" + stage.process(_task(text, 1.0, "train", True, extra={"unknown_chars": {"x": 2, "y": 1}})) + + summary = stage.summary() + assert summary["unknown_char_details"] == { + "x": {"count": 2, "rate": pytest.approx(2 / len(text))}, + "y": {"count": 1, "rate": pytest.approx(1 / len(text))}, + } + assert summary["by_language"]["gu"]["IndicVoices"]["unknown_char_details"] == summary["unknown_char_details"] + + formatted = stage.format_summary() + assert f"unknown_chars: x=count=2 rate={2 / len(text):.2%} (showing top 1 of 2)" in formatted + assert "y=count=1" not in formatted + + +def test_transcript_stats_groups_each_language_by_source() -> None: + stage = TranscriptStatsStage() + tasks = [ + _task("ગુજરાતી", 1.0, "train", False, lang="gu", source="IndicVoices"), + _task("શબ્દx", 2.0, "dev", True, lang="gu", source="IndicVoices", extra={"unknown_chars": {"x": 1}}), + _task("शब्द", 3.0, "test", False, lang="hi", source="OtherDataset"), + ] + + assert [stage.process(task) for task in tasks] == tasks + + summary = stage.summary() + grouped = summary["by_language"] + language_totals = summary["by_language_overall"] + assert set(grouped) == {"gu", "hi"} + assert set(grouped["gu"]) == {"IndicVoices"} + assert set(grouped["hi"]) == {"OtherDataset"} + assert grouped["gu"]["IndicVoices"]["total_transcripts"] == 2 + assert grouped["gu"]["IndicVoices"]["valid_transcripts"] == 1 + assert grouped["gu"]["IndicVoices"]["invalid_transcripts"] == 1 + assert grouped["gu"]["IndicVoices"]["split_counts"] == { + "train": {"total": 1, "valid": 1, "invalid": 0}, + "dev": {"total": 1, "valid": 0, "invalid": 1}, + } + assert grouped["hi"]["OtherDataset"]["total_transcripts"] == 1 + assert grouped["hi"]["OtherDataset"]["valid_duration_hours"] == pytest.approx(3.0 / 3600) + assert language_totals["gu"]["total_transcripts"] == 2 + assert language_totals["gu"]["valid_transcripts"] == 1 + assert language_totals["gu"]["invalid_transcripts"] == 1 + assert language_totals["hi"]["total_transcripts"] == 1 + + formatted = stage.format_summary() + assert "per_language_source:" in formatted + assert "lang=gu source=IndicVoices" in formatted + assert "transcripts: total=2 valid=1 (50.00%) invalid=1 (50.00%)" in formatted + assert ( + "split_counts: {'train': {'total': 1, 'valid': 1, 'invalid': 0}, 'dev': {'total': 1, 'valid': 0, 'invalid': 1}}" + in formatted + ) + assert "lang=hi source=OtherDataset" in formatted + assert "per_language_overall:" in formatted + assert "lang=gu overall" in formatted + assert "lang=hi overall" in formatted + + +def test_transcript_stats_alpha_minus_known_chars_includes_code_switch_alphabet() -> None: + stage = TranscriptStatsStage(code_switch_langs=["en"]) + stage.process(_task("ગુજરાતી Hello", 1.0, "train", False, lang="gu")) + + summary = stage.summary() + + assert "z" in summary["alpha_minus_known_chars"] + assert "Z" in summary["alpha_minus_known_chars"] + assert "H" not in summary["alpha_minus_known_chars"] + assert "z" in summary["by_language"]["gu"]["IndicVoices"]["alpha_minus_known_chars"] + assert "z" in summary["by_language_overall"]["gu"]["alpha_minus_known_chars"] + + +def test_transcript_stats_code_switch_language_accepts_single_string() -> None: + stage = TranscriptStatsStage(code_switch_langs="en") + stage.process(_task("ગુજરાતી Hello", 1.0, "train", False, lang="gu")) + + summary = stage.summary() + + assert "z" in summary["alpha_minus_known_chars"] + assert "H" not in summary["alpha_minus_known_chars"] + + +def test_transcript_stats_can_drop_invalid_after_counting() -> None: + stage = TranscriptStatsStage(drop_invalid=True) + valid = _task("abc", 1.0, "train", False) + invalid = _task("abcx", 2.0, "train", True, extra={"unknown_chars": {"x": 1}}) + + assert stage.process(valid) is valid + assert stage.process(invalid) is None + + summary = stage.summary() + metrics = stage._consume_custom_metrics() + assert summary["total_transcripts"] == 2 + assert summary["valid_transcripts"] == 1 + assert summary["invalid_transcripts"] == 1 + assert metrics["input_tasks"] == 2 + assert metrics["emitted_tasks"] == 1 + assert metrics["dropped_invalid"] == 1 + + +def test_transcript_stats_accepts_multiple_languages() -> None: + stage = TranscriptStatsStage() + assert stage.process(_task("abc", 1.0, "train", False, lang="gu")) is not None + assert stage.process(_task("शब्द", 1.0, "train", False, lang="hi")) is not None + + +def test_transcript_stats_runs_as_single_worker_for_exact_dataset_summary() -> None: + assert TranscriptStatsStage().num_workers() == 1 + + +def test_transcript_stats_format_summary_rounds_split_hours() -> None: + stage = TranscriptStatsStage() + stage.process(_task("ગુજરાતી", 3661.0, "dev", False)) + + formatted = stage.format_summary() + + assert "split_hours: {'dev': {'total': 1.02, 'valid': 1.02, 'invalid': 0.0}}" in formatted + + +def test_transcript_stats_writes_summary_during_processing(tmp_path: Path) -> None: + summary_path = tmp_path / "stats" / "summary.json" + stage = TranscriptStatsStage(output_summary_path=str(summary_path)) + stage.setup_on_node() + + assert summary_path.exists() + + stage.process(_task("ગુજરાતી", 1.0, "dev", False)) + with summary_path.open(encoding="utf-8") as f: + first_summary = json.load(f) + assert first_summary["total_transcripts"] == 1 + assert first_summary["valid_transcripts"] == 1 + + stage.process(_task("શબ્દ", 2.0, "test", False)) + + with summary_path.open(encoding="utf-8") as f: + final_summary = json.load(f) + assert final_summary["total_transcripts"] == 2 + assert final_summary["valid_transcripts"] == 2 + assert final_summary["split_counts"] == { + "dev": {"total": 1, "valid": 1, "invalid": 0}, + "test": {"total": 1, "valid": 1, "invalid": 0}, + } + stage.teardown() + + +def test_transcript_stats_formats_summary_from_path(tmp_path: Path) -> None: + summary_path = tmp_path / "summary.json" + stage = TranscriptStatsStage(log_top_n_unknown_chars=1) + stage.process(_task("ગુજરાતીxxy", 1.0, "train", True, extra={"unknown_chars": {"x": 2, "y": 1}})) + summary_path.write_text(json.dumps(stage.summary(), ensure_ascii=False), encoding="utf-8") + + formatted = TranscriptStatsStage.format_summary_from_path(str(summary_path), top_n_unknown_chars=1) + + assert "per_language_source:" in formatted + assert "lang=gu source=IndicVoices" in formatted + assert "unknown_chars: x=count=2" in formatted + assert "per_language_overall:" in formatted + assert "global:" in formatted + + +def test_transcript_stats_logs_summary_from_path(tmp_path: Path) -> None: + summary_path = tmp_path / "summary.json" + stage = TranscriptStatsStage(log_top_n_unknown_chars=1) + stage.process(_task("ગુજરાતીxxy", 1.0, "train", True, extra={"unknown_chars": {"x": 2, "y": 1}})) + summary_path.write_text(json.dumps(stage.summary(), ensure_ascii=False), encoding="utf-8") + + formatted = TranscriptStatsStage.log_summary_from_path(str(summary_path), top_n_unknown_chars=1) + + assert formatted is not None + assert "lang=gu source=IndicVoices" in formatted + assert "unknown_chars: x=count=2" in formatted + + +def test_transcript_stats_loads_summary_json(tmp_path: Path) -> None: + summary_path = tmp_path / "summary.json" + expected = {"total_transcripts": 2, "valid_transcripts": 1, "invalid_transcripts": 1} + summary_path.write_text(json.dumps(expected), encoding="utf-8") + + assert TranscriptStatsStage.load_summary(str(summary_path)) == expected + + +def test_transcript_stats_loads_last_jsonl_record(tmp_path: Path) -> None: + summary_path = tmp_path / "summary.json" + first = {"total_transcripts": 1} + second = {"total_transcripts": 2} + summary_path.write_text(f"{json.dumps(first)}\n{json.dumps(second)}\n", encoding="utf-8") + + assert TranscriptStatsStage.load_summary(str(summary_path)) == second diff --git a/tests/stages/audio/asr/test_metadata.py b/tests/stages/audio/asr/test_metadata.py new file mode 100644 index 0000000000..b90eb98c8d --- /dev/null +++ b/tests/stages/audio/asr/test_metadata.py @@ -0,0 +1,63 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nemo_curator.stages.audio.asr.metadata import ASRMetadata + + +def test_asr_metadata_includes_optional_speaker_and_verbatim_fields() -> None: + metadata = ASRMetadata( + audio_filepath="sample.wav", + text="normalized text", + duration=1.0, + lang="en", + split_type="train", + source="UnitTest", + extra={ + "gender": "extra_gender", + "speaker_id": "extra_speaker", + "age": "extra_age", + "text_verbatim": "extra verbatim", + "custom_key": "custom_value", + }, + ) + + row = metadata.to_dict() + + assert row["gender"] is None + assert row["speaker_id"] is None + assert row["age"] is None + assert row["text_verbatim"] is None + assert row["custom_key"] == "custom_value" + + +def test_asr_metadata_round_trips_optional_speaker_and_verbatim_fields() -> None: + metadata = ASRMetadata.from_dict( + { + "audio_filepath": "sample.wav", + "text": "normalized text", + "duration": 1.0, + "lang": "en", + "split_type": "train", + "source": "UnitTest", + "gender": "female", + "speaker_id": "speaker_1", + "age": "adult", + "text_verbatim": "Original Text", + } + ) + + assert metadata.gender == "female" + assert metadata.speaker_id == "speaker_1" + assert metadata.age == "adult" + assert metadata.text_verbatim == "Original Text" diff --git a/tests/stages/audio/test_common.py b/tests/stages/audio/test_common.py index 0e50ecc1cc..30620d4df0 100644 --- a/tests/stages/audio/test_common.py +++ b/tests/stages/audio/test_common.py @@ -22,6 +22,7 @@ import numpy as np import pytest import torch +from omegaconf import OmegaConf from nemo_curator.backends.xenna import XennaExecutor from nemo_curator.pipeline import Pipeline @@ -297,6 +298,16 @@ def test_composite_discovers_nested_directory(self) -> None: assert partitioner.file_paths == str(nested) assert partitioner.file_extensions == [".jsonl", ".json"] + def test_composite_accepts_hydra_list_config(self, tmp_path: Path) -> None: + manifests = [str(tmp_path / "dev.jsonl"), str(tmp_path / "test.jsonl")] + manifest_path = OmegaConf.create(manifests) + composite = ManifestReader(manifest_path=manifest_path) + stages = composite.decompose() + + partitioner = stages[0] + assert partitioner.file_paths == manifests + assert all(isinstance(path, str) for path in partitioner.file_paths) + def test_ignores_non_jsonl_files(self) -> None: nested = self._nested_dir() txt_files = list(nested.rglob("*.txt")) diff --git a/tutorials/audio/README.md b/tutorials/audio/README.md index 928f5d359d..dfd9765753 100644 --- a/tutorials/audio/README.md +++ b/tutorials/audio/README.md @@ -20,6 +20,7 @@ sudo apt-get install -y ffmpeg | Tutorial | Description | Files | |----------|-------------|-------| +| **[ASR Data Processing](asr_data_pipeline/)** | Standardize downloaded ASR datasets into normalized train/dev/test manifests | `main.py`, `README.md`, `configs/indicvoices.yaml` | | **[FLEURS Dataset](fleurs/)** | Complete pipeline for multilingual speech data | `pipeline.py`, `run.py`, `pipeline.yaml` | | **[Audio Tagging](tagging/)** | Label raw audio for TTS/ASR via diarization, alignment, and quality metrics | `main.py`, `tts_pipeline.yaml`, `asr_pipeline.yaml` | | **[ALM Data Pipeline](alm/)** | Create training windows for Audio Language Models | `main.py`, `pipeline.yaml` | diff --git a/tutorials/audio/asr_data_pipeline/README.md b/tutorials/audio/asr_data_pipeline/README.md new file mode 100644 index 0000000000..cc5b998861 --- /dev/null +++ b/tutorials/audio/asr_data_pipeline/README.md @@ -0,0 +1,487 @@ +# ASR Data Processing Pipeline + +This tutorial shows how to use NeMo Curator to turn already-downloaded ASR datasets into normalized, split-aware training manifests. + +The pipeline is designed for public ASR corpora where the user handles download access and storage, while Curator handles: + +- dataset-specific ingestion and split assignment +- audio conversion to WAV, 16 kHz, mono, PCM16 +- canonical ASR metadata creation +- language-resource-driven transcript normalization +- transcript quality statistics by language and source +- train/dev/test JSONL manifest writing + +## When to Use This Pipeline + +Use this pipeline when you already have a dataset on disk and want to convert it into a consistent ASR training format. For example, a downloaded HuggingFace Arrow dataset, a tarred audio corpus with transcripts, or a custom internal dataset can each be wrapped by an ingestion handler that emits the same `AudioTask` shape. + +This is different from the audio tagging tutorial. The tagging pipeline starts from raw unlabelled audio and creates labels using diarization and ASR models. This ASR data pipeline starts from labelled ASR data and standardizes it for training. + +## Pipeline Overview + +```text +Downloaded dataset + | + v +Dataset ingestion handler + - reads the source-specific layout + - decodes or extracts audio + - writes WAV/16 kHz/mono/PCM16 + - assigns train/dev/test split_type + - emits AudioTask with ASRMetadata fields + | + v +TranscriptNormalizationStage + - loads language resources + - applies pretok replacements + - removes configured characters + - records unknown characters + | + v +TranscriptStatsStage + - counts valid/invalid transcripts + - tracks hours and split distribution + - groups summary by language and source + - optionally drops invalid tasks + | + v +SplitAwareManifestWriter + - writes per-language, per-split JSONL manifests +``` + +The canonical task data comes from `ASRMetadata` and includes fields such as: + +```json +{ + "audio_filepath": "/data/curated/gu/dev/audio/gu_valid_0.wav", + "text": "ગુજરાતી વાક્ય", + "duration": 3.21, + "lang": "gu", + "split_type": "dev", + "source": "IndicVoices", + "sample_rate": 16000, + "num_channels": 1, + "orig_sample_rate": 48000, + "orig_num_channels": 1 +} +``` + +## Quick Start: IndicVoices + +The repository includes an IndicVoices ingestion handler and a YAML config at `configs/indicvoices.yaml`. + +Install the audio dependencies from the Curator repository root: + +```bash +uv sync --extra audio_cuda12 +source .venv/bin/activate +``` + +Run the YAML pipeline with the ASR data pipeline runner: + +```bash +python tutorials/audio/asr_data_pipeline/main.py \ + --config-path ../../../configs \ + --config-name indicvoices \ + raw_data_dir=/data/asr/IndicVoices/raw \ + output_dir=/data/asr/IndicVoices/curated \ + 'langs=[gu]' \ + 'stages.0.native_splits=[valid]' \ + 'stages.0.split_dir_pattern={lang}/{split}' \ + stages.0.extraction_workers=32 +``` + +For a single-language sample whose downloaded directory looks like `/data/asr/gu/indic_voices/valid`, use: + +```bash +python tutorials/audio/asr_data_pipeline/main.py \ + --config-path ../../../configs \ + --config-name indicvoices \ + raw_data_dir=/data/asr/gu/indic_voices \ + output_dir=/data/asr/gu/indic_voices_curated \ + 'langs=[gu]' \ + 'stages.0.native_splits=[valid]' \ + 'stages.0.split_dir_pattern={split}' \ + stages.0.extraction_workers=16 +``` + +`HuggingFaceASRDatasetHandler` passes native `train` through as `train`. For IndicVoices, set `valid_split_strategy: dev_test` to deterministically split native `valid` into `dev` and `test` using `dev_fraction`, which defaults to `0.6`. + +## Example YAML + +The key parts of `configs/indicvoices.yaml` are: + +```yaml +raw_data_dir: /data/asr/IndicVoices/raw +output_dir: /data/asr/IndicVoices/curated +langs: + - gu + +stages: + - _target_: nemo_curator.stages.audio.asr.datasets.huggingface.HuggingFaceASRDatasetHandler + raw_data_dir: ${raw_data_dir} + output_dir: ${output_dir}/uncleaned + langs: ${langs} + source_name: IndicVoices + native_splits: + - valid + split_dir_pattern: "{lang}/{split}" + valid_split_strategy: dev_test + extraction_workers: 70 + dev_fraction: 0.6 + write_manifest: true + manifest_splits: + - train + - dev + - test + + - _target_: nemo_curator.stages.audio.asr.normalization.TranscriptNormalizationStage + text_key: text + lang_key: lang + output_text_key: text + output_original_text_key: text_original + remove_pnc_chars: false + lowercase_text: false + code_switch_langs: [] + + - _target_: nemo_curator.stages.audio.asr.normalization.TranscriptStatsStage + text_key: text + duration_key: duration + split_key: split_type + unknown_chars_key: unknown_chars + transcript_error_key: transcript_error + drop_invalid: true + code_switch_langs: [] + output_summary_path: ${output_dir}/transcript_stats_summary.json + + - _target_: nemo_curator.stages.audio.asr.io.split_manifest_writer.SplitAwareManifestWriter + output_dir: ${output_dir} + output_filename_pattern: "{split}_normalized.jsonl" + langs: ${langs} + splits: + - train + - dev + - test +``` + +Set `write_manifest: true` on the ingestion handler only when you also want unnormalized manifests from the source stage. The downstream `SplitAwareManifestWriter` writes the normalized manifests after transcript cleanup and filtering. For known sources such as `IndicVoices`, `Kathbath`, and `Shrutilipi`, source-specific metadata keys are preserved by default based on `source_name`; use `extra_keys` only for a new or custom source. + +## Output Layout + +For Gujarati IndicVoices with native `valid`, the output structure looks like: + +```text +output_dir/ +├── transcript_stats_summary.json +├── gu/ +│ ├── train_normalized.jsonl +│ ├── dev_normalized.jsonl +│ └── test_normalized.jsonl +└── uncleaned/ + └── gu/ + ├── dev.jsonl + ├── test.jsonl + ├── dev/ + │ └── audio/ + │ └── gu_valid_0.wav + └── test/ + └── audio/ + └── gu_valid_3.wav +``` + +If `write_manifest` is disabled on the handler, the `uncleaned` JSONL files are not written, but converted audio is still produced. + +## Transcript Normalization Resources + +Normalization resources live under: + +```text +nemo_curator/stages/audio/asr/normalization/langs/ +├── remove_chars.txt +└── gu/ + ├── alphabet.txt + └── pretok.jsonl +``` + +Each supported language has its own folder: + +- `alphabet.txt`: characters considered valid after normalization +- `pretok.jsonl`: regex replacement rules applied before character removal + +The root-level `langs/remove_chars.txt` contains characters removed for all languages during cleanup. +Standard punctuation characters are maintained in `_PUNCTUATION_CHARS_BY_LANG` in `nemo_curator/stages/audio/asr/normalization/transcript.py`. + +If `remove_pnc_chars: true`, punctuation listed for the language in `_PUNCTUATION_CHARS_BY_LANG` is removed. If `remove_pnc_chars: false`, those punctuation characters are retained during the removal step. + +Set `lowercase_text: true` when you want the final normalized transcript lowercased. For code-switched data, set `code_switch_langs` to combine additional language alphabets, pretok rules, and punctuation with the task's primary `lang`; for example, `code_switch_langs: ["en"]` allows English text inside a Gujarati transcript. Use the same `code_switch_langs` on `TranscriptStatsStage` so `alpha_minus_known_chars` is computed against the combined vocabulary. + +Rows with characters outside `alphabet.txt` are marked with: + +```json +{ + "unknown_chars": {"x": 1}, + "transcript_error": true +} +``` + +`TranscriptNormalizationStage` always returns the task. Filtering is handled by `TranscriptStatsStage` using `drop_invalid`. + +## Statistics Summary + +`TranscriptStatsStage` writes an atomic JSON summary to `output_summary_path` and logs a readable summary when the pipeline finishes. The summary is grouped three ways: + +- `by_language`: full stats for every language and dataset source +- `by_language_overall`: full stats for each language across all sources +- top-level fields: global stats across every language and source + +Each block includes `unknown_chars`, which reports the most frequent unknown characters with both count and dataset-level character rate. Use this to decide whether a character should be added to `alphabet.txt`, normalized in `pretok.jsonl`, or removed with `remove_chars.txt`. + +Example display: + +```text +[transcript_stats] Transcript normalization summary + per_language_source: + lang=gu source=IndicVoices + transcripts: total=100 valid=95 (95.00%) invalid=5 (5.00%) + hours: total=2.40 valid_after_filter=2.31 invalid_removed=0.09 + split_hours: {'train': {'total': 1.6, 'valid': 1.55, 'invalid': 0.05}, 'dev': {'total': 0.5, 'valid': 0.48, 'invalid': 0.02}} + chars: total=8400 unique_known=52 unique_known_rate=0.62% unique_unknown=3 unique_unknown_rate=0.04% + unknown_chars: @=count=12 rate=0.14%, #=count=4 rate=0.05%, x=count=1 rate=0.01% + alpha_minus_known_chars: ['ઁ', 'ઋ'] + split_counts: {'train': {'total': 70, 'valid': 67, 'invalid': 3}, 'dev': {'total': 30, 'valid': 28, 'invalid': 2}} + + lang=hi source=IndicVoices + transcripts: total=80 valid=78 (97.50%) invalid=2 (2.50%) + hours: total=1.90 valid_after_filter=1.86 invalid_removed=0.04 + split_hours: {'train': {'total': 1.2, 'valid': 1.18, 'invalid': 0.02}, 'test': {'total': 0.7, 'valid': 0.68, 'invalid': 0.02}} + chars: total=6900 unique_known=58 unique_known_rate=0.84% unique_unknown=1 unique_unknown_rate=0.01% + unknown_chars: ॐ=count=1 rate=0.01% + alpha_minus_known_chars: ['ॠ'] + split_counts: {'train': {'total': 55, 'valid': 54, 'invalid': 1}, 'test': {'total': 25, 'valid': 24, 'invalid': 1}} + + per_language_overall: + lang=gu overall + transcripts: total=100 valid=95 (95.00%) invalid=5 (5.00%) + hours: total=2.40 valid_after_filter=2.31 invalid_removed=0.09 + split_hours: {'train': {'total': 1.6, 'valid': 1.55, 'invalid': 0.05}, 'dev': {'total': 0.5, 'valid': 0.48, 'invalid': 0.02}} + chars: total=8400 unique_known=52 unique_known_rate=0.62% unique_unknown=3 unique_unknown_rate=0.04% + unknown_chars: @=count=12 rate=0.14%, #=count=4 rate=0.05%, x=count=1 rate=0.01% + alpha_minus_known_chars: ['ઁ', 'ઋ'] + split_counts: {'train': {'total': 70, 'valid': 67, 'invalid': 3}, 'dev': {'total': 30, 'valid': 28, 'invalid': 2}} + + global: + all languages/sources + transcripts: total=180 valid=173 (96.11%) invalid=7 (3.89%) + hours: total=4.30 valid_after_filter=4.17 invalid_removed=0.13 + unknown_chars: @=count=12 rate=0.08%, #=count=4 rate=0.03%, x=count=1 rate=0.01%, ॐ=count=1 rate=0.01% +``` + +For multi-source data, make sure every emitted task has a meaningful `source` value. The statistics stage uses the `source` field by default, configurable with `source_key`. + +## Multiple Sources and Languages + +The normalizer and stats stages can process multiple languages in one stream as long as each task has: + +- `lang`: language code with a matching resource folder +- `source`: dataset/source name +- `split_type`: output split name +- `duration`: audio duration in seconds +- `text`: transcript text + +A single dataset handler can emit data from many languages and sources. For example, a combined ingestion handler may read: + +```text +raw_data_dir/ +├── IndicVoices/ +│ ├── gu/ +│ │ └── valid/ +│ └── hi/ +│ └── valid/ +└── InternalCorpus/ + ├── gu/ + │ └── train.tsv + └── hi/ + └── train.tsv +``` + +and emit `source="IndicVoices"` or `source="InternalCorpus"` per row. Avoid chaining multiple fan-out ingestion handlers in a single linear pipeline; instead, create a combined ingestion handler for one streaming run, or run source-specific pipelines and merge manifests afterwards. + +## Adding a New Ingestion Handler + +New datasets should subclass `BaseASRDatasetHandlerStage`. The handler owns source-specific details: how to find files, read transcripts, decode audio, preserve metadata, and map native splits into Curator `split_type` values. + +### Handler Responsibilities + +1. Discover raw files or dataset shards under `raw_data_dir`. +2. Decode or extract the source audio. +3. Call `convert_audio()` to write WAV, 16 kHz, mono, PCM16 output. +4. Assign `split_type`, such as `train`, `dev`, or `test`. +5. Create an `ASRMetadata` object for each utterance. +6. Call `write_manifest_entry(meta)` if handler-owned manifests are enabled. +7. Return `AudioTask` objects using `build_audio_task(meta)`. + +### Minimal Skeleton + +```python +from __future__ import annotations + +import os +from dataclasses import dataclass, field + +from nemo_curator.stages.audio.asr.datasets.base import BaseASRDatasetHandlerStage +from nemo_curator.stages.audio.asr.metadata import ASRMetadata + + +@dataclass +class MyDatasetHandler(BaseASRDatasetHandlerStage): + name: str = "my_dataset_handler" + source_name: str = "MyDataset" + native_splits: list[str] = field(default_factory=lambda: ["train", "valid"]) + + def _output_splits(self) -> list[str]: + return ["train", "dev", "test"] + + def assign_split(self, native_split: str, utterance_id: str) -> str: + if native_split == "valid": + return "dev" # or deterministic dev/test logic + return native_split + + def process(self, _) -> list: + tasks = [] + for lang in self.langs: + for native_split in self.native_splits: + for row in self._iter_rows(lang, native_split): + utterance_id = row["id"] + split_type = self.assign_split(native_split, utterance_id) + audio_dir = self.audio_output_dir(lang, split_type) + dst_path = os.path.join(audio_dir, f"{utterance_id}.wav") + + audio_info = self.convert_audio( + row["audio_array"], + row["sample_rate"], + row.get("num_channels", 1), + dst_path, + ) + + meta = ASRMetadata( + audio_filepath=dst_path, + text=row["text"], + duration=audio_info["duration"], + lang=lang, + split_type=split_type, + source=self.source_name, + sample_rate=self.target_sample_rate, + num_channels=self.target_channels, + orig_sample_rate=audio_info["orig_sample_rate"], + orig_num_channels=audio_info["orig_num_channels"], + extra={"speaker_id": row.get("speaker_id")}, + ) + self.write_manifest_entry(meta) + tasks.append(self.build_audio_task(meta)) + return tasks +``` + +Use `setup()` for heavy imports such as `datasets`, `soundfile`, or dataset-specific SDKs. This keeps driver-side imports light and matches Curator stage patterns. + +### YAML Entry + +After adding the class, point the YAML `_target_` at the new handler: + +```yaml +stages: + - _target_: nemo_curator.stages.audio.asr.datasets.my_dataset.MyDatasetHandler + raw_data_dir: /data/asr/MyDataset/raw + output_dir: /data/asr/MyDataset/curated/uncleaned + langs: + - gu + - hi + native_splits: + - train + - valid + extraction_workers: 32 + skip_untar: false + write_manifest: false + + - _target_: nemo_curator.stages.audio.asr.normalization.TranscriptNormalizationStage + remove_pnc_chars: true + + - _target_: nemo_curator.stages.audio.asr.normalization.TranscriptStatsStage + drop_invalid: true + output_summary_path: /data/asr/MyDataset/curated/transcript_stats_summary.json + + - _target_: nemo_curator.stages.audio.asr.io.split_manifest_writer.SplitAwareManifestWriter + output_dir: /data/asr/MyDataset/curated + output_filename_pattern: "{split}_normalized.jsonl" + langs: + - gu + - hi + splits: + - train + - dev + - test +``` + +## Adding a New Language + +To normalize a new language, add a folder under `nemo_curator/stages/audio/asr/normalization/langs/`: + +```text +langs/ +├── remove_chars.txt +└── xx/ + ├── alphabet.txt + └── pretok.jsonl +``` + +Then set `lang: "xx"` in the emitted `ASRMetadata` and add the language's punctuation characters to `_PUNCTUATION_CHARS_BY_LANG` in `nemo_curator/stages/audio/asr/normalization/transcript.py`. + +Start with a conservative `alphabet.txt`. Unknown characters are surfaced in the stats summary, so you can inspect the first run and decide whether the character is valid, should be normalized by `pretok.jsonl`, or should be removed by `remove_chars.txt`. + +## Performance Notes + +- Dataset handlers run with `xenna_workers=1` by default and parallelize extraction internally using `extraction_workers`. +- Increase `extraction_workers` for CPU-heavy decoding and resampling, but keep it within available CPU and I/O limits. +- Use `skip_untar: true` to reuse converted audio files from a previous run. +- `TranscriptStatsStage` runs as one worker so the dataset-level summary is exact. +- `SplitAwareManifestWriter` runs as one worker to avoid concurrent writes to the same manifest file. + +## Troubleshooting + +### `ModuleNotFoundError: nemo_curator.stages.audio.asr` + +Make sure you are running from the Curator checkout and importing the local package: + +```bash +export PYTHONPATH=$PWD:$PYTHONPATH +``` + +or reinstall Curator in editable mode: + +```bash +pip install -e . +``` + +### No Output Rows + +- Check that `raw_data_dir` and `split_dir_pattern` match the downloaded dataset layout. +- Check handler logs for `skipped_missing_text`, `skipped_missing_audio`, and `skipped_audio_load`. +- If `TranscriptStatsStage(drop_invalid=true)` is enabled, inspect `transcript_stats_summary.json` for invalid transcript counts and unknown characters. + +### Missing Normalization Resources + +If normalization fails with `Missing ASR normalization resource`, add the language folder and the required files under `normalization/langs/`. + +### Unexpected Punctuation Removal + +Set `remove_pnc_chars: false` to retain punctuation listed in `_PUNCTUATION_CHARS_BY_LANG`. Characters in `remove_chars.txt` that are not punctuation will still be removed. + +## Related Files + +- `configs/indicvoices.yaml` +- `tutorials/audio/asr_data_pipeline/main.py` +- `tutorials/audio/indicvoices/pipeline.py` +- `nemo_curator/stages/audio/asr/datasets/base.py` +- `nemo_curator/stages/audio/asr/datasets/indicvoices.py` +- `nemo_curator/stages/audio/asr/normalization/transcript.py` +- `nemo_curator/stages/audio/asr/normalization/stats.py` +- `nemo_curator/stages/audio/asr/io/split_manifest_writer.py` diff --git a/tutorials/audio/asr_data_pipeline/main.py b/tutorials/audio/asr_data_pipeline/main.py new file mode 100644 index 0000000000..78a08e856e --- /dev/null +++ b/tutorials/audio/asr_data_pipeline/main.py @@ -0,0 +1,116 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +ASR data processing pipeline runner for NeMo Curator. + +This YAML-driven runner is intended for downloaded ASR datasets that already +include transcripts. It runs ingestion, transcript normalization, transcript +statistics, and split-aware manifest writing stages. + +Usage: + python tutorials/audio/asr_data_pipeline/main.py \\ + --config-path ../../../configs \\ + --config-name indicvoices \\ + raw_data_dir=/data/asr/IndicVoices/raw \\ + output_dir=/data/asr/IndicVoices/curated \\ + 'langs=[gu]' +""" + +import importlib +import traceback + +import hydra +from loguru import logger +from omegaconf import DictConfig, OmegaConf + +from nemo_curator.config.run import create_pipeline_from_yaml +from nemo_curator.stages.audio.asr.normalization.stats import TranscriptStatsStage +from nemo_curator.tasks.utils import TaskPerfUtils + +_EXECUTOR_FACTORIES = { + "xenna": "nemo_curator.backends.xenna:XennaExecutor", + "ray_data": "nemo_curator.backends.ray_data:RayDataExecutor", +} + + +def _create_executor(backend: str, config: dict) -> object: + module_path, class_name = _EXECUTOR_FACTORIES[backend].rsplit(":", 1) + mod = importlib.import_module(module_path) + return getattr(mod, class_name)(config=config) + + +def _validate_backend(backend: str) -> None: + if backend not in _EXECUTOR_FACTORIES: + msg = f"Unknown backend '{backend}'. Choose from: {list(_EXECUTOR_FACTORIES)}" + raise ValueError(msg) + + +def _resolve_stats_summary_path(cfg: DictConfig) -> str | None: + path = OmegaConf.select(cfg, "stats_summary_path") + if path: + return str(path) + for stage_cfg in cfg.get("stages", []): + path = stage_cfg.get("output_summary_path") + if path: + return str(path) + return None + + +@hydra.main(version_base=None) +def main(cfg: DictConfig) -> None: + """Run an ASR data processing pipeline using Hydra configuration.""" + try: + pipeline = create_pipeline_from_yaml(cfg) + + logger.info(pipeline.describe()) + logger.info("\n" + "=" * 50 + "\n") + + backend = cfg.get("backend", "xenna") + _validate_backend(backend) + logger.info(f"Using backend: {backend}") + mode = cfg.get("execution_mode", "streaming") + config = {"execution_mode": mode} + executor = _create_executor(backend, config=config) + + logger.info("Starting ASR data processing pipeline...") + results = pipeline.run(executor) + except Exception: + logger.error("ASR data pipeline failed with full chained traceback:\n{}", traceback.format_exc()) + raise + + num_tasks = len(results) if results else 0 + + logger.info("\n" + "=" * 50) + logger.info("PIPELINE COMPLETE") + logger.info("=" * 50) + logger.info(f" Tasks processed: {num_tasks}") + if "final_manifest" in cfg: + logger.info(f" Output manifest: {cfg.final_manifest}") + elif "output_dir" in cfg: + logger.info(f" Output directory: {cfg.output_dir}") + + TranscriptStatsStage.log_summary_from_path(_resolve_stats_summary_path(cfg)) + + stage_metrics = TaskPerfUtils.collect_stage_metrics(results) + for stage_name, metrics in stage_metrics.items(): + logger.info(f" [{stage_name}]") + logger.info( + f" process_time: mean={metrics['process_time'].mean():.4f}s, total={metrics['process_time'].sum():.2f}s" + ) + logger.info(f" items_processed: {metrics['num_items_processed'].sum():.0f}") + + +if __name__ == "__main__": + main() diff --git a/tutorials/audio/indicvoices/pipeline.py b/tutorials/audio/indicvoices/pipeline.py new file mode 100644 index 0000000000..30496c1c69 --- /dev/null +++ b/tutorials/audio/indicvoices/pipeline.py @@ -0,0 +1,102 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""IndicVoices ASR stage-1 pipeline: extract -> split-aware manifest writing. + +Example: + python tutorials/audio/indicvoices/pipeline.py \ + --raw_data_dir /data/asr/gu/indic_voices \ + --output_dir /data/asr/gu/indic_voices_curated \ + --langs gu --split_dir_pattern "{split}" --clean +""" + +import argparse +import shutil +import sys + +from loguru import logger + +from nemo_curator.backends.ray_data import RayDataExecutor +from nemo_curator.backends.xenna import XennaExecutor +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.audio.asr.datasets.huggingface import HuggingFaceASRDatasetHandler +from nemo_curator.stages.audio.asr.io.split_manifest_writer import SplitAwareManifestWriter + + +def create_pipeline(args: argparse.Namespace) -> Pipeline: + pipeline = Pipeline(name="indicvoices_asr", description="IndicVoices extract + split-aware manifests") + pipeline.add_stage( + HuggingFaceASRDatasetHandler( + raw_data_dir=args.raw_data_dir, + output_dir=args.output_dir, + langs=args.langs, + source_name="IndicVoices", + native_splits=args.native_splits, + split_dir_pattern=args.split_dir_pattern, + valid_split_strategy="dev_test", + dev_fraction=args.dev_fraction, + extraction_workers=args.extraction_workers, + skip_untar=args.skip_untar, + ) + ) + pipeline.add_stage( + SplitAwareManifestWriter( + output_dir=args.output_dir, + langs=args.langs, + splits=["train", "dev", "test"], + ) + ) + return pipeline + + +def main(args: argparse.Namespace) -> None: + logger.remove() + logger.add(sys.stderr, level="DEBUG" if args.verbose else "INFO") + + if args.clean: + for lang in args.langs: + shutil.rmtree(f"{args.output_dir}/{lang}", ignore_errors=True) + + pipeline = create_pipeline(args) + logger.info(pipeline.describe()) + logger.info("\n" + "=" * 50 + "\n") + + executor = RayDataExecutor() if args.backend == "ray_data" else XennaExecutor() + logger.info("Starting pipeline execution...") + pipeline.run(executor) + logger.info("\nPipeline completed!") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--raw_data_dir", type=str, required=True, help="Root containing the raw arrow split dirs") + parser.add_argument("--output_dir", type=str, required=True, help="Destination root for audio + manifests") + parser.add_argument("--langs", type=str, nargs="+", default=["gu"], help="Languages to process") + parser.add_argument( + "--native_splits", type=str, nargs="+", default=["train", "valid"], help="Native splits to read" + ) + parser.add_argument( + "--split_dir_pattern", + type=str, + default="{lang}_{split}", + help="Per-split arrow dir name pattern under raw_data_dir (e.g. '{split}' or '{lang}_{split}')", + ) + parser.add_argument("--dev_fraction", type=float, default=0.6, help="Fraction of 'valid' routed to dev") + parser.add_argument("--extraction_workers", type=int, default=10, help="Internal joblib workers for extraction") + parser.add_argument("--skip_untar", action="store_true", help="Reuse already-extracted WAVs when present") + parser.add_argument("--clean", action="store_true", help="Remove existing per-language output before running") + parser.add_argument("--backend", type=str, choices=["xenna", "ray_data"], default="xenna") + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + main(args) diff --git a/tutorials/audio/tagging/main.py b/tutorials/audio/tagging/main.py index c5616be540..eff8bc87f5 100644 --- a/tutorials/audio/tagging/main.py +++ b/tutorials/audio/tagging/main.py @@ -52,12 +52,14 @@ """ import importlib +import traceback import hydra from loguru import logger -from omegaconf import DictConfig +from omegaconf import DictConfig, OmegaConf from nemo_curator.config.run import create_pipeline_from_yaml +from nemo_curator.stages.audio.asr.normalization.stats import TranscriptStatsStage from nemo_curator.tasks.utils import TaskPerfUtils _EXECUTOR_FACTORIES = { @@ -72,25 +74,44 @@ def _create_executor(backend: str, config: dict) -> object: return getattr(mod, class_name)(config=config) +def _validate_backend(backend: str) -> None: + if backend not in _EXECUTOR_FACTORIES: + msg = f"Unknown backend '{backend}'. Choose from: {list(_EXECUTOR_FACTORIES)}" + raise ValueError(msg) + + +def _resolve_stats_summary_path(cfg: DictConfig) -> str | None: + path = OmegaConf.select(cfg, "stats_summary_path") + if path: + return str(path) + for stage_cfg in cfg.get("stages", []): + path = stage_cfg.get("output_summary_path") + if path: + return str(path) + return None + + @hydra.main(version_base=None) def main(cfg: DictConfig) -> None: """Run audio tagging pipeline using Hydra configuration.""" - pipeline = create_pipeline_from_yaml(cfg) + try: + pipeline = create_pipeline_from_yaml(cfg) - logger.info(pipeline.describe()) - logger.info("\n" + "=" * 50 + "\n") + logger.info(pipeline.describe()) + logger.info("\n" + "=" * 50 + "\n") - backend = cfg.get("backend", "xenna") - if backend not in _EXECUTOR_FACTORIES: - msg = f"Unknown backend '{backend}'. Choose from: {list(_EXECUTOR_FACTORIES)}" - raise ValueError(msg) - logger.info(f"Using backend: {backend}") - mode = cfg.get("execution_mode", "streaming") - config = {"execution_mode": mode} - executor = _create_executor(backend, config=config) + backend = cfg.get("backend", "xenna") + _validate_backend(backend) + logger.info(f"Using backend: {backend}") + mode = cfg.get("execution_mode", "streaming") + config = {"execution_mode": mode} + executor = _create_executor(backend, config=config) - logger.info("Starting audio tagging pipeline...") - results = pipeline.run(executor) + logger.info("Starting audio tagging pipeline...") + results = pipeline.run(executor) + except Exception: + logger.error("Audio pipeline failed with full chained traceback:\n{}", traceback.format_exc()) + raise num_tasks = len(results) if results else 0 @@ -98,7 +119,13 @@ def main(cfg: DictConfig) -> None: logger.info("PIPELINE COMPLETE") logger.info("=" * 50) logger.info(f" Tasks processed: {num_tasks}") - logger.info(f" Output manifest: {cfg.final_manifest}") + if "final_manifest" in cfg: + logger.info(f" Output manifest: {cfg.final_manifest}") + elif "output_dir" in cfg: + logger.info(f" Output directory: {cfg.output_dir}") + + stats_summary_path = _resolve_stats_summary_path(cfg) + TranscriptStatsStage.log_summary_from_path(stats_summary_path) stage_metrics = TaskPerfUtils.collect_stage_metrics(results) for stage_name, metrics in stage_metrics.items():