From 5c61cd55f52a84f20bde1a0e88b1d3d638df6a6a Mon Sep 17 00:00:00 2001 From: julber95 Date: Wed, 10 Jun 2026 14:30:32 +0000 Subject: [PATCH] Fix device mismatch in captum attribution target --- torchTextClassifiers/torchTextClassifiers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 1d99987..cd0ce05 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -784,7 +784,7 @@ def predict( for k in range(top_k): attributions = lig.attribute( (encoded_text, attention_mask, categorical_vars), - target=integer_predictions[:, k], + target=integer_predictions[:, k].to(device), ) # (batch_size, seq_len) attributions = attributions.sum(dim=-1) captum_attributions.append(attributions.detach().cpu())