Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion modelscan/scanners/pickle/scan.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import logging
from typing import Optional
from typing import Any, Dict, Optional

from modelscan.error import DependencyError
from modelscan.scanners.scan import ScanBase, ScanResults
from modelscan.tools.utils import _is_zipfile
from modelscan.tools.picklescanner import (
numpy_installed,
scan_numpy,
scan_pickle_bytes,
scan_pytorch,
Expand Down Expand Up @@ -53,6 +55,20 @@ def scan(
]:
return None

dep_error = self.handle_binary_dependencies()
if dep_error:
return ScanResults(
[],
[
DependencyError(
self.name(),
f"To use {self.full_name()}, please install modelscan with numpy extras. `pip install 'modelscan[ numpy ]'` if you are using pip.",
model,
)
],
[],
)

results = scan_numpy(
model=model,
settings=self._settings,
Expand All @@ -68,6 +84,13 @@ def name() -> str:
def full_name() -> str:
return "modelscan.scanners.NumpyUnsafeOpScan"

def handle_binary_dependencies(
    self, settings: Optional[Dict[str, Any]] = None
) -> Optional[str]:
    """Report whether the NumPy dependency this scanner relies on is missing.

    Args:
        settings: Accepted as part of the method signature; not consulted
            by this implementation.

    Returns:
        The value of ``DependencyError.name()`` when the imported
        ``numpy_installed`` flag is false, otherwise ``None``.
    """
    # A falsy flag means the NumPy import did not succeed at load time.
    return None if numpy_installed else DependencyError.name()


class PickleUnsafeOpScan(ScanBase):
def scan(
Expand Down
11 changes: 10 additions & 1 deletion modelscan/tools/picklescanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@
from tarfile import TarError
from typing import IO, Any, Dict, List, Set, Tuple, Union, Optional

import numpy as np
try:
import numpy as np

numpy_installed = True
except ImportError:
np = None # type: ignore[assignment]
numpy_installed = False

from modelscan.error import PickleGenopsError
from modelscan.skip import ModelScanSkipped, SkipCategories
Expand Down Expand Up @@ -200,6 +206,9 @@ def _build_scan_result_from_raw_globals(


def scan_numpy(model: Model, settings: Dict[str, Any]) -> ScanResults:
if np is None:
raise ImportError("NumPy is required to scan NumPy model files.")

scan_name = "numpy"
# Code to distinguish from NumPy binary files and pickles.
_ZIP_PREFIX = b"PK\x03\x04"
Expand Down
4 changes: 3 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ modelscan = "modelscan.cli:main"
[tool.poetry.dependencies]
python = ">=3.10,<3.13"
click = "^8.1.3"
numpy = ">=1.24.3"
numpy = { version = ">=1.24.3", optional = true }
rich = ">=13.4.2,<15.0.0"
tomlkit = ">=0.12.3,<0.14.0"
h5py = { version = "^3.9.0", optional = true }
Expand All @@ -25,6 +25,7 @@ tensorflow = { version = "^2.17", optional = true }
[tool.poetry.extras]
tensorflow = ["tensorflow"]
h5py = ["h5py"]
numpy = ["numpy"]

[tool.poetry.group.test.dependencies]
pytest = ">=7.4,<9.0"
Expand All @@ -36,6 +37,7 @@ dill = ">=0.3.7,<0.5.0"
types-requests = ">1.26"
torch = "^2.7.0"
tf-keras = "^2.20.1"
numpy = ">=1.24.3"


[tool.poetry.group.dev.dependencies]
Expand Down
20 changes: 19 additions & 1 deletion tests/test_modelscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
)

from modelscan.skip import SkipCategories
from modelscan.settings import DEFAULT_SETTINGS
from modelscan.settings import DEFAULT_SETTINGS, SupportedModelFormats
from modelscan.model import Model

settings: Dict[str, Any] = DEFAULT_SETTINGS
Expand Down Expand Up @@ -606,6 +606,24 @@ def test_scan_numpy(numpy_file_path: Any) -> None:
assert results["errors"] == []


def test_scan_numpy_reports_dependency_error_when_numpy_missing(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Scanning a NumPy-format model with NumPy flagged as absent yields one DEPENDENCY error."""
    from modelscan.scanners.pickle import scan as pickle_scan

    # Flip the flag on the scanner module itself, which is the name the
    # scanner's dependency check reads.
    monkeypatch.setattr(pickle_scan, "numpy_installed", False)

    numpy_model = Model("missing_numpy.npy", io.BytesIO(b""))
    numpy_model.set_context("formats", [SupportedModelFormats.NUMPY])

    outcome = pickle_scan.NumpyUnsafeOpScan(settings).scan(numpy_model)

    assert outcome is not None
    categories = [err.to_dict()["category"] for err in outcome.errors]
    assert categories == ["DEPENDENCY"]
    # The install hint must name the numpy extra exactly as the scanner emits it.
    assert "modelscan[ numpy ]" in outcome.errors[0].message


def test_scan_file_path(file_path: Any) -> None:
benign_pickle = ModelScan()
results = benign_pickle.scan(Path(f"{file_path}/data/benign0_v3.pkl"))
Expand Down