From 6540e74ee709b850280d0d5163b574e271d7b9c4 Mon Sep 17 00:00:00 2001 From: "Bjarne C. Hiller" Date: Mon, 11 Nov 2024 14:15:02 -0800 Subject: [PATCH] Fix issues with `captum/insights/attr_vis/example.py` (#1432) Summary: ## Don't set environment variable `WERKZEUG_RUN_MAIN` **Solution proposed by:** jeremyfix With Werkzeug 2.1.0, setting the environment variable `WERKZEUG_RUN_MAIN` results in `KeyError: 'WERKZEUG_SERVER_FD'`. `WERKZEUG_RUN_MAIN` is used by Werkzeug internally and is not supposed to be set by external libraries. Therefore, I removed `os.environ["WERKZEUG_RUN_MAIN"] = "true"`. As a **side effect**, the startup message is shown, but this shouldn't be a problem compared to the `KeyError.` Related Issues: - https://github.com/pallets/werkzeug/issues/2361 - https://github.com/pytorch/captum/issues/1127 ## Fix import and typing errors In `example.py`, there were some import and typing errors. I also updated the example path in the `README.md`. ## Testenvironment OS: Debian 12 Python: Tested in both 3.11 and 3.12 ```console $ pip list Package Version ----------------------------- -------------- alabaster 1.0.0 annoy 1.17.3 anyio 4.6.2.post1 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 arrow 1.3.0 asttokens 2.4.1 async-lru 2.0.4 attrs 24.2.0 babel 2.16.0 beautifulsoup4 4.12.3 black 24.10.0 bleach 6.2.0 blinker 1.8.2 Brotli 1.1.0 captum 0.7.0 certifi 2024.8.30 cffi 1.17.1 charset-normalizer 3.4.0 click 8.1.7 comm 0.2.2 contourpy 1.3.0 coverage 7.6.4 cycler 0.12.1 debugpy 1.8.7 decorator 5.1.1 defusedxml 0.7.1 docutils 0.21.2 executing 2.1.0 fastjsonschema 2.20.0 filelock 3.16.1 flake8 7.1.1 Flask 3.0.3 Flask-Compress 1.17 fonttools 4.54.1 fqdn 1.5.1 fsspec 2024.10.0 h11 0.14.0 httpcore 1.0.6 httpx 0.27.2 idna 3.10 imagesize 1.4.1 iniconfig 2.0.0 ipykernel 6.29.5 ipython 8.29.0 ipywidgets 8.1.5 isoduration 20.11.0 itsdangerous 2.2.0 jedi 0.19.1 Jinja2 3.1.4 joblib 1.4.2 json5 0.9.25 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema-specifications 2024.10.1 jupyter 1.1.1 jupyter_client 8.6.3 jupyter-console 6.6.3 jupyter_core 5.7.2 jupyter-events 0.10.0 jupyter-lsp 2.2.5 jupyter_server 2.14.2 jupyter_server_terminals 0.5.3 jupyterlab 4.2.5 jupyterlab_pygments 0.3.0 jupyterlab_server 2.27.3 jupyterlab_widgets 3.0.13 kiwisolver 1.4.7 libcst 1.5.0 MarkupSafe 3.0.2 matplotlib 3.9.2 matplotlib-inline 0.1.7 mccabe 0.7.0 mistune 3.0.2 moreorless 0.4.0 mpmath 1.3.0 mypy 1.13.0 mypy-extensions 1.0.0 nbclient 0.10.0 nbconvert 7.16.4 nbformat 5.10.4 nest-asyncio 1.6.0 networkx 3.4.2 notebook 7.2.2 notebook_shim 0.2.4 numpy 1.26.4 nvidia-cublas-cu12 12.4.5.8 nvidia-cuda-cupti-cu12 12.4.127 nvidia-cuda-nvrtc-cu12 12.4.127 nvidia-cuda-runtime-cu12 12.4.127 nvidia-cudnn-cu12 9.1.0.70 nvidia-cufft-cu12 11.2.1.3 nvidia-curand-cu12 10.3.5.147 nvidia-cusolver-cu12 11.6.1.9 nvidia-cusparse-cu12 12.3.1.170 nvidia-nccl-cu12 2.21.5 nvidia-nvjitlink-cu12 12.4.127 nvidia-nvtx-cu12 12.4.127 overrides 7.7.0 packaging 24.1 pandocfilters 1.5.1 parameterized 0.9.0 parso 0.8.4 pathspec 0.12.1 pexpect 4.9.0 pillow 11.0.0 pip 24.2 platformdirs 4.3.6 pluggy 1.5.0 prometheus_client 0.21.0 prompt_toolkit 3.0.48 psutil 6.1.0 ptyprocess 0.7.0 pure_eval 0.2.3 pycodestyle 2.12.1 pycparser 2.22 pyflakes 3.2.0 Pygments 2.18.0 pyparsing 3.2.0 pytest 8.3.3 pytest-cov 6.0.0 python-dateutil 2.9.0.post0 python-json-logger 2.0.7 PyYAML 6.0.2 pyzmq 26.2.0 referencing 0.35.1 requests 2.32.3 rfc3339-validator 0.1.4 rfc3986-validator 0.1.1 rpds-py 0.20.1 scikit-learn 1.5.2 scipy 1.14.1 Send2Trash 1.8.3 setuptools 75.1.0 six 1.16.0 sniffio 1.3.1 snowballstemmer 2.2.0 soupsieve 2.6 Sphinx 8.1.3 sphinx-autodoc-typehints 2.5.0 sphinxcontrib-applehelp 2.0.0 sphinxcontrib-devhelp 2.0.0 sphinxcontrib-htmlhelp 2.1.0 sphinxcontrib-jsmath 1.0.1 sphinxcontrib-katex 0.9.10 sphinxcontrib-qthelp 2.0.0 sphinxcontrib-serializinghtml 2.0.0 stack-data 0.6.3 stdlibs 2024.10.25 sympy 1.13.1 terminado 0.18.1 threadpoolctl 3.5.0 tinycss2 1.4.0 toml 0.10.2 tomlkit 0.13.2 torch 2.5.1 torchvision 0.20.1 tornado 6.4.1 tqdm 4.66.6 trailrunner 1.4.0 traitlets 5.14.3 triton 3.1.0 types-python-dateutil 2.9.0.20241003 typing_extensions 4.12.2 ufmt 2.8.0 uri-template 1.3.0 urllib3 2.2.3 usort 1.0.2 wcwidth 0.2.13 webcolors 24.8.0 webencodings 0.5.1 websocket-client 1.8.0 Werkzeug 3.1.2 wheel 0.44.0 widgetsnbextension 4.0.13 zstandard 0.23.0 ``` Pull Request resolved: https://github.com/pytorch/captum/pull/1432 Reviewed By: craymichael Differential Revision: D65665700 Pulled By: cyrjano fbshipit-source-id: 08028809a4ef04ba560d35490104cf11e6af1faa --- README.md | 2 +- captum/insights/attr_vis/example.py | 52 ++++++++++++++--------------- captum/insights/attr_vis/server.py | 2 -- 3 files changed, 26 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index a78f99656..44d25088e 100644 --- a/README.md +++ b/README.md @@ -394,7 +394,7 @@ access to a number of our interpretability algorithms. To analyze a sample model on CIFAR10 via Captum Insights run ``` -python -m captum.insights.example +python -m captum.insights.attr_vis.example ``` and navigate to the URL specified in the output. diff --git a/captum/insights/attr_vis/example.py b/captum/insights/attr_vis/example.py index be20e44c4..cb7c071b7 100644 --- a/captum/insights/attr_vis/example.py +++ b/captum/insights/attr_vis/example.py @@ -10,9 +10,6 @@ import torchvision.transforms as transforms from captum.insights import AttributionVisualizer, Batch -# pyre-fixme[21]: Could not find module -# `captum.insights.attr_vis.example.get_pretrained_model`. -from captum.insights.attr_vis.example.get_pretrained_model import Net from captum.insights.attr_vis.features import ImageFeature @@ -32,31 +29,32 @@ def get_classes() -> List[str]: return classes -def get_pretrained_model() -> Net: - class Net(nn.Module): - def __init__(self) -> None: - super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) - self.pool1 = nn.MaxPool2d(2, 2) - self.pool2 = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) - self.fc1 = nn.Linear(16 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - self.relu1 = nn.ReLU() - self.relu2 = nn.ReLU() - self.relu3 = nn.ReLU() - self.relu4 = nn.ReLU() - - def forward(self, x): - x = self.pool1(self.relu1(self.conv1(x))) - x = self.pool2(self.relu2(self.conv2(x))) - x = x.view(-1, 16 * 5 * 5) - x = self.relu3(self.fc1(x)) - x = self.relu4(self.fc2(x)) - x = self.fc3(x) - return x +class Net(nn.Module): + def __init__(self) -> None: + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 6, 5) + self.pool1 = nn.MaxPool2d(2, 2) + self.pool2 = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(6, 16, 5) + self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + self.relu1 = nn.ReLU() + self.relu2 = nn.ReLU() + self.relu3 = nn.ReLU() + self.relu4 = nn.ReLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.pool1(self.relu1(self.conv1(x))) + x = self.pool2(self.relu2(self.conv2(x))) + x = x.view(-1, 16 * 5 * 5) + x = self.relu3(self.fc1(x)) + x = self.relu4(self.fc2(x)) + x = self.fc3(x) + return x + +def get_pretrained_model() -> Net: net = Net() pt_path = os.path.abspath( os.path.join(os.path.dirname(__file__), "models/cifar_torchvision.pt") diff --git a/captum/insights/attr_vis/server.py b/captum/insights/attr_vis/server.py index 6d13d8183..5edbd0eb2 100644 --- a/captum/insights/attr_vis/server.py +++ b/captum/insights/attr_vis/server.py @@ -2,7 +2,6 @@ # pyre-strict import logging -import os import socket import threading from time import sleep @@ -108,7 +107,6 @@ def start_server( global port if port is None: - os.environ["WERKZEUG_RUN_MAIN"] = "true" # hides starting message if not debug: log = logging.getLogger("werkzeug") log.disabled = True