diff --git a/nemo_curator/stages/text/deduplication/removal.py b/nemo_curator/stages/text/deduplication/removal.py index 45d48d9a4a..9e11c28cb6 100644 --- a/nemo_curator/stages/text/deduplication/removal.py +++ b/nemo_curator/stages/text/deduplication/removal.py @@ -43,6 +43,7 @@ class TextDuplicatesRemovalStage(ProcessingStage[DocumentBatch, DocumentBatch]): id_field: Field to use for deduplication within the input dataframe. Defaults to CURATOR_DEDUP_ID_STR. duplicate_id_field: Field to use for deduplication within the removal dataframe. Defaults to "id". read_kwargs: Additional arguments for reading parquet files + drop_id_field: Whether to drop the deduplication ID field from the output batch. """ ids_to_remove_path: str @@ -51,6 +52,7 @@ class TextDuplicatesRemovalStage(ProcessingStage[DocumentBatch, DocumentBatch]): # Optional parameters read_kwargs: dict[str, Any] | None = None + drop_id_field: bool = False def __post_init__(self): """Initialize parent class after dataclass initialization.""" @@ -84,6 +86,8 @@ def process(self, task: DocumentBatch) -> DocumentBatch: time_to_remove_t0 = time.perf_counter() removal_ids = set(removal_df[self.duplicate_id_field].tolist()) df = df[~df[self.id_field].isin(removal_ids)] + if self.drop_id_field: + df = df.drop(columns=[self.id_field]) removal_ids_time = time.perf_counter() - time_to_remove_t0 self._log_metrics( { diff --git a/nemo_curator/stages/text/deduplication/removal_workflow.py b/nemo_curator/stages/text/deduplication/removal_workflow.py index 7684bafcea..b928c64730 100644 --- a/nemo_curator/stages/text/deduplication/removal_workflow.py +++ b/nemo_curator/stages/text/deduplication/removal_workflow.py @@ -62,6 +62,7 @@ class TextDuplicatesRemovalWorkflow(WorkflowBase): output_kwargs: dict[str, Any] | None = None output_fields: list[str] | None = None output_mode: Literal["ignore", "overwrite", "append", "error"] | None = None + drop_id_field: bool = False def __post_init__(self): """Initialize parent class after dataclass initialization.""" @@ -69,6 +70,9 @@ def __post_init__(self): logger.warning( f"Using {CURATOR_DEDUP_ID_STR} as id_field for removal stage, even though we are not using id generator." ) + if self.drop_id_field and self.output_fields and self.id_field in self.output_fields: + msg = f"Cannot drop id_field {self.id_field!r} when it is included in output_fields." + raise ValueError(msg) def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) -> list[ProcessingStage]: stages = [] @@ -125,6 +129,7 @@ def _generate_stages(self, initial_tasks: list[FileGroupTask] | None = None) -> id_field=self.id_field, duplicate_id_field=self.duplicate_id_field, read_kwargs=self.duplicate_id_read_kwargs, + drop_id_field=self.drop_id_field, ) ) diff --git a/nemo_curator/stages/text/deduplication/semantic.py b/nemo_curator/stages/text/deduplication/semantic.py index acb333da2f..92c16b0175 100644 --- a/nemo_curator/stages/text/deduplication/semantic.py +++ b/nemo_curator/stages/text/deduplication/semantic.py @@ -374,6 +374,7 @@ def _run_duplicate_removal(self, executor: BaseExecutor) -> WorkflowRunResult | output_kwargs=self.write_kwargs, output_fields=self.output_fields, output_mode="ignore", + drop_id_field=self.use_id_generator and self.output_fields is None, ) return workflow.run(executor=executor) diff --git a/tests/stages/text/deduplication/test_removal_workflow.py b/tests/stages/text/deduplication/test_removal_workflow.py index c45abcee61..b034f10d13 100644 --- a/tests/stages/text/deduplication/test_removal_workflow.py +++ b/tests/stages/text/deduplication/test_removal_workflow.py @@ -281,7 +281,38 @@ def test_initial_tasks_partitioning(self, test_config: "TestTextDuplicateRemoval assert workflow_output.get_metadata("num_duplicates_removed") == expected_removed + +def test_removal_stage_can_drop_id_field(tmp_path: Path): + ids_to_remove_path = tmp_path / "ids_to_remove.parquet" + pd.DataFrame({"id": [1]}).to_parquet(ids_to_remove_path, index=False) + task = DocumentBatch( + dataset_name="dataset", + data=pd.DataFrame({CURATOR_DEDUP_ID_STR: [1, 2], "text": ["drop", "keep"]}), + ) + + stage = TextDuplicatesRemovalStage( + ids_to_remove_path=str(ids_to_remove_path), + id_field=CURATOR_DEDUP_ID_STR, + drop_id_field=True, + ) + + result = stage.process(task).to_pandas() + + assert result.to_dict(orient="list") == {"text": ["keep"]} + assert CURATOR_DEDUP_ID_STR not in result.columns + + class TestTextDuplicatesRemovalWorkflowGenerateStages: + def test_drop_id_field_conflicts_with_output_fields(self): + with pytest.raises(ValueError, match="Cannot drop id_field"): + TextDuplicatesRemovalWorkflow( + input_path="input_path", + ids_to_remove_path="ids_to_remove_path", + output_path="output_path", + output_fields=["text", CURATOR_DEDUP_ID_STR], + drop_id_field=True, + ) + def test_invalid_filetypes(self): read_invalid_file_type_workflow = TextDuplicatesRemovalWorkflow( input_path="input_path", @@ -347,6 +378,7 @@ def test_reader_stage( assert stages[2].id_field == CURATOR_DEDUP_ID_STR assert stages[2].duplicate_id_field == "id" assert stages[2].read_kwargs == {} + assert not stages[2].drop_id_field # test for writer stage (stages[3]) - default output_filetype is parquet assert isinstance(stages[3], ParquetWriter) @@ -373,6 +405,7 @@ def test_writer_stage(self, output_filetype: str): output_path="output_path", output_filetype=output_filetype, id_generator_path=None, + drop_id_field=True, ) stages = workflow._generate_stages(initial_tasks=None) assert len(stages) == 4 @@ -380,6 +413,7 @@ def test_writer_stage(self, output_filetype: str): # reader stage assert isinstance(stages[1], ParquetReaderStage) # Default input_filetype is parquet assert isinstance(stages[2], TextDuplicatesRemovalStage) + assert stages[2].drop_id_field expected_write_stage = ParquetWriter if output_filetype == "parquet" else JsonlWriter assert isinstance(stages[3], expected_write_stage) diff --git a/tutorials/text/deduplication/semantic/semantic_e2e.ipynb b/tutorials/text/deduplication/semantic/semantic_e2e.ipynb index 73f8ce5165..1947ee686c 100644 --- a/tutorials/text/deduplication/semantic/semantic_e2e.ipynb +++ b/tutorials/text/deduplication/semantic/semantic_e2e.ipynb @@ -785,7 +785,7 @@ "\n", "We can control the schema of this by specifying the `output_fields` argument in the workflow definition.\n", "\n", - "If you had set `use_id_generator=True` then you'd see `_curator_dedup_id` here as well." + "If you had set `use_id_generator=True` with the default output schema, the deduplicated output now drops the generated `_curator_dedup_id` before writing. If you want it preserved, include it explicitly in `output_fields`." ] }, {