[BUG] Model.load raises TypeError on Python 3.9+ (parameterized generic in isinstance) #810

@shaun0927

Bug description

After PR #809 (commit 8bf122f5), transformers4rec.torch.model.base.Model.load() uses a parameterized generic as the second argument to isinstance:

if isinstance(state_dict, Dict[str, torch.Tensor]):
    model.load_state_dict(state_dict, strict=strict)
else:
    raise ValueError("`state_dict` must be a dictionary of parameter (torch) tensors.")

On Python 3.9+ this raises TypeError: Subscripted generics cannot be used with class and instance checks. The error is raised by the check itself, for every input, so none of the load logic runs and the ValueError branch is unreachable. Net effect: Model.load() is completely unusable on current main.

The regression was not caught by the existing test tests/unit/torch/model/test_model.py::test_save_next_item_prediction_model, even though it is the only test in the suite that reaches this line. CI does not appear to have run it in a clean environment (likely because the copy-pr-bot runners are currently blocked; see #798).

Steps/Code to reproduce bug

Minimal repro (only PyTorch needed, no T4Rec install):

from typing import Dict
import torch

isinstance({"a": torch.zeros(1)}, Dict[str, torch.Tensor])
# TypeError: Subscripted generics cannot be used with class and instance checks
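The failure does not depend on the value being checked: both the typing.Dict[...] alias and the builtin dict[...] form (PEP 585, Python 3.9+) reject isinstance for any object, dict or not. The sketch below assumes Python 3.9+ and uses int as a stand-in for torch.Tensor so it runs without PyTorch; isinstance_raises is a hypothetical helper written only for this demonstration.

```python
from typing import Dict


def isinstance_raises(obj, cls) -> bool:
    """Hypothetical helper: True if isinstance(obj, cls) raises TypeError."""
    try:
        isinstance(obj, cls)
    except TypeError:
        return True
    return False


# int stands in for torch.Tensor so the sketch runs without PyTorch.
candidates = [{"a": 1}, {}, ["not", "a", "dict"]]

# typing.Dict with parameters rejects every object, including non-dicts,
# so an `else` branch guarded by this check can never be reached.
typing_results = [isinstance_raises(obj, Dict[str, int]) for obj in candidates]

# The PEP 585 builtin form (Python 3.9+) is rejected the same way.
builtin_results = [isinstance_raises(obj, dict[str, int]) for obj in candidates]

# The unparameterized builtin works and separates dicts from non-dicts.
plain_results = [isinstance(obj, dict) for obj in candidates]

print(typing_results)   # [True, True, True]
print(builtin_results)  # [True, True, True]
print(plain_results)    # [True, True, False]
```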

Library-level repro:

import torch
from transformers4rec.torch.model.base import Model

Model.load({"a": torch.zeros(1)}, heads=[])   # raises TypeError before anything else

Expected behavior

Model.load() should accept a plain dict (as returned by state_dict() or torch.load) and load the weights into the provided heads.
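A bare dict check covers both sources: in current PyTorch releases nn.Module.state_dict() builds a collections.OrderedDict, which is a dict subclass, and torch.load of a saved state dict yields a dict as well. The sketch below uses a hypothetical check_state_dict helper that mirrors a bare-dict check, with ints standing in for tensors so it runs without PyTorch.

```python
from collections import OrderedDict


def check_state_dict(state_dict):
    """Hypothetical helper mirroring a bare-dict check; returns the input unchanged."""
    if not isinstance(state_dict, dict):
        raise TypeError("`state_dict` must be a dict of torch.Tensor.")
    return state_dict


# OrderedDict (what state_dict() builds) is a dict subclass, so it passes;
# a plain dict passes too. Ints stand in for tensors.
from_module = OrderedDict([("linear.weight", 0), ("linear.bias", 1)])
from_disk = {"linear.weight": 0, "linear.bias": 1}

assert check_state_dict(from_module) is from_module
assert check_state_dict(from_disk) is from_disk

# A non-dict input is rejected with a clear TypeError.
error_message = None
try:
    check_state_dict([("linear.weight", 0)])
except TypeError as exc:
    error_message = str(exc)
print(error_message)  # `state_dict` must be a dict of torch.Tensor.
```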

Environment details

  • Transformers4Rec: main @ 8bf122f5
  • Python: any 3.9+ (documented semantics; all modern installs affected)
  • PyTorch: any

Additional context

Minimal fix: replace the parameterized generic with the plain dict type (and raise TypeError, which fits a wrong argument type better than ValueError):

if isinstance(state_dict, dict):
    model.load_state_dict(state_dict, strict=strict)
else:
    raise TypeError("`state_dict` must be a dict of torch.Tensor.")

The tighter "dict of str → torch.Tensor" check cannot be done in a single isinstance call; if such validation is desired, it has to loop over the items. In practice load_state_dict itself reports incompatible entries (missing or unexpected keys when strict=True, shape mismatches), so the bare dict check is sufficient.
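If the stricter validation were wanted anyway, a loop over the items could look like the sketch below. validate_state_dict is a hypothetical name, not an existing T4Rec function; in Model.load() the value_type argument would be torch.Tensor, while the usage here passes float so it runs without PyTorch.

```python
from typing import Any


def validate_state_dict(state_dict: Any, value_type: type) -> None:
    """Hypothetical helper: raise TypeError unless state_dict maps str keys to value_type values."""
    if not isinstance(state_dict, dict):
        raise TypeError(f"`state_dict` must be a dict, got {type(state_dict).__name__}.")
    # isinstance cannot look inside a dict, so check each key/value pair.
    for key, value in state_dict.items():
        if not isinstance(key, str):
            raise TypeError(f"state_dict key {key!r} is not a str.")
        if not isinstance(value, value_type):
            raise TypeError(
                f"state_dict[{key!r}] is {type(value).__name__}, "
                f"expected {value_type.__name__}."
            )


# float stands in for torch.Tensor so this runs without PyTorch.
validate_state_dict({"weight": 1.0, "bias": 0.0}, float)  # passes silently
try:
    validate_state_dict({"weight": "oops"}, float)
except TypeError as exc:
    print(exc)  # state_dict['weight'] is str, expected float.
```

With value_type=torch.Tensor, nn.Parameter values would also pass, since Parameter subclasses Tensor.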

Happy to send a one-line PR.
