Bug description
After PR #809 (commit 8bf122f5), transformers4rec.torch.model.base.Model.load() uses a parameterized generic as the second argument to isinstance (transformers4rec/torch/model/base.py, line 917 in 8bf122f):
if isinstance(state_dict, Dict[str, torch.Tensor]):
    model.load_state_dict(state_dict, strict=strict)
else:
    raise ValueError("`state_dict` must be a dictionary of parameter (torch) tensors.")
On Python 3.9+ this raises TypeError: Subscripted generics cannot be used with class and instance checks before any of the load logic runs. Net effect: Model.load() is completely unusable on current main.
The regression is not caught by the existing test tests/unit/torch/model/test_model.py::test_save_next_item_prediction_model, which is the only place in the test suite that reaches this line. CI does not appear to run that test in a clean environment, likely because the copy-pr-bot runners are currently blocked (see #798).
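Since the test suite did not flag this, a static check could catch the pattern without running any model code. The following is only a rough sketch: the function name `find_subscripted_isinstance` is hypothetical and nothing like it exists in the repository as far as this report knows. It only detects direct `isinstance(...)` calls whose second argument is written as a subscript.

```python
import ast
from typing import List


def find_subscripted_isinstance(source: str) -> List[int]:
    """Return line numbers of isinstance() calls whose 2nd argument is subscripted.

    Hypothetical helper for illustration; not part of Transformers4Rec.
    """
    hits = []
    for node in ast.walk(ast.parse(source)):
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == "isinstance"
            and len(node.args) == 2
            # Dict[str, int], List[int], dict[str, int] all parse as Subscript
            and isinstance(node.args[1], ast.Subscript)
        ):
            hits.append(node.lineno)
    return sorted(hits)


sample = (
    "from typing import Dict\n"
    "x = {}\n"
    "if isinstance(x, Dict[str, int]):\n"
    "    pass\n"
    "if isinstance(x, dict):\n"
    "    pass\n"
)
print(find_subscripted_isinstance(sample))  # [3]
```

Only the parameterized check on line 3 of the sample is reported; the plain `dict` check on line 5 is left alone.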
Steps/Code to reproduce bug
Pure-Python repro (no T4Rec install needed):
from typing import Dict
import torch
isinstance({"a": torch.zeros(1)}, Dict[str, torch.Tensor])
# TypeError: Subscripted generics cannot be used with class and instance checks
Library-level repro:
import torch
from transformers4rec.torch.model.base import Model
Model.load({"a": torch.zeros(1)}, heads=[]) # raises TypeError before anything else
Expected behavior
Model.load() should accept a plain dict (as returned by state_dict() or torch.load) and load the weights into the provided heads.
Environment details
- Transformers4Rec: main @ 8bf122f5
- Python: any 3.9+ (documented semantics; all modern installs affected)
- PyTorch: any
Additional context
Minimal fix — replace the parameterized generic with a plain dict:
if isinstance(state_dict, dict):
    model.load_state_dict(state_dict, strict=strict)
else:
    raise TypeError("`state_dict` must be a dict of torch.Tensor.")
The tighter "dict of str → torch.Tensor" check cannot be done in a single isinstance call; if such validation is desired it has to be a loop over items. In practice load_state_dict itself will surface incompatible entries, so the bare dict check is sufficient.
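If the stricter validation were wanted, the loop over items might look like the minimal sketch below. The helper name `is_dict_of` is hypothetical and not part of Transformers4Rec. For `Model.load()` it would be called with `str` and `torch.Tensor`; built-in types are used here so the example runs without PyTorch.

```python
from typing import Any, Type


def is_dict_of(obj: Any, key_type: Type, value_type: Type) -> bool:
    """Return True if obj is a dict whose keys and values all match the given types.

    Hypothetical helper for illustration; not part of Transformers4Rec.
    """
    # isinstance() accepts the plain class `dict`, unlike Dict[str, ...].
    if not isinstance(obj, dict):
        return False
    # Element types have to be checked one item at a time.
    return all(
        isinstance(key, key_type) and isinstance(value, value_type)
        for key, value in obj.items()
    )


# Stand-in for a state dict; with PyTorch the call would be
# is_dict_of(state_dict, str, torch.Tensor).
print(is_dict_of({"a": 1.0, "b": 2.0}, str, float))  # True
print(is_dict_of({"a": 1.0, 2: 2.0}, str, float))    # False
print(is_dict_of([("a", 1.0)], str, float))          # False
```

Note that an empty dict passes this check, and that the loop touches every entry, which is one more reason the bare `dict` check is probably enough here.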
Happy to send a one-line PR.