diff --git a/torchTextClassifiers/torchTextClassifiers.py b/torchTextClassifiers/torchTextClassifiers.py index 1d99987..cd0ce05 100644 --- a/torchTextClassifiers/torchTextClassifiers.py +++ b/torchTextClassifiers/torchTextClassifiers.py @@ -784,7 +784,7 @@ def predict( for k in range(top_k): attributions = lig.attribute( (encoded_text, attention_mask, categorical_vars), - target=integer_predictions[:, k], + target=integer_predictions[:, k].to(device), ) # (batch_size, seq_len) attributions = attributions.sum(dim=-1) captum_attributions.append(attributions.detach().cpu())