Skip to content

Commit

Permalink
Fix issues with captum/insights/attr_vis/example.py (#1432)
Browse files Browse the repository at this point in the history
Summary:
## Don't set environment variable `WERKZEUG_RUN_MAIN`

**Solution proposed by:** jeremyfix

With Werkzeug 2.1.0, setting the environment variable `WERKZEUG_RUN_MAIN` results in `KeyError: 'WERKZEUG_SERVER_FD'`. `WERKZEUG_RUN_MAIN` is used by Werkzeug internally and is not supposed to be set by external libraries. Therefore, I removed `os.environ["WERKZEUG_RUN_MAIN"] = "true"`. As a **side effect**, the startup message is shown, but this shouldn't be a problem compared to the `KeyError.`

Related Issues:
- pallets/werkzeug#2361
- #1127

## Fix import and typing errors

In `example.py`, there were some import and typing errors. I also updated the example path in the `README.md`.

## Testenvironment

OS: Debian 12
Python: Tested in both 3.11 and 3.12

```console
$ pip list
Package                       Version
----------------------------- --------------
alabaster                     1.0.0
annoy                         1.17.3
anyio                         4.6.2.post1
argon2-cffi                   23.1.0
argon2-cffi-bindings          21.2.0
arrow                         1.3.0
asttokens                     2.4.1
async-lru                     2.0.4
attrs                         24.2.0
babel                         2.16.0
beautifulsoup4                4.12.3
black                         24.10.0
bleach                        6.2.0
blinker                       1.8.2
Brotli                        1.1.0
captum                        0.7.0
certifi                       2024.8.30
cffi                          1.17.1
charset-normalizer            3.4.0
click                         8.1.7
comm                          0.2.2
contourpy                     1.3.0
coverage                      7.6.4
cycler                        0.12.1
debugpy                       1.8.7
decorator                     5.1.1
defusedxml                    0.7.1
docutils                      0.21.2
executing                     2.1.0
fastjsonschema                2.20.0
filelock                      3.16.1
flake8                        7.1.1
Flask                         3.0.3
Flask-Compress                1.17
fonttools                     4.54.1
fqdn                          1.5.1
fsspec                        2024.10.0
h11                           0.14.0
httpcore                      1.0.6
httpx                         0.27.2
idna                          3.10
imagesize                     1.4.1
iniconfig                     2.0.0
ipykernel                     6.29.5
ipython                       8.29.0
ipywidgets                    8.1.5
isoduration                   20.11.0
itsdangerous                  2.2.0
jedi                          0.19.1
Jinja2                        3.1.4
joblib                        1.4.2
json5                         0.9.25
jsonpointer                   3.0.0
jsonschema                    4.23.0
jsonschema-specifications     2024.10.1
jupyter                       1.1.1
jupyter_client                8.6.3
jupyter-console               6.6.3
jupyter_core                  5.7.2
jupyter-events                0.10.0
jupyter-lsp                   2.2.5
jupyter_server                2.14.2
jupyter_server_terminals      0.5.3
jupyterlab                    4.2.5
jupyterlab_pygments           0.3.0
jupyterlab_server             2.27.3
jupyterlab_widgets            3.0.13
kiwisolver                    1.4.7
libcst                        1.5.0
MarkupSafe                    3.0.2
matplotlib                    3.9.2
matplotlib-inline             0.1.7
mccabe                        0.7.0
mistune                       3.0.2
moreorless                    0.4.0
mpmath                        1.3.0
mypy                          1.13.0
mypy-extensions               1.0.0
nbclient                      0.10.0
nbconvert                     7.16.4
nbformat                      5.10.4
nest-asyncio                  1.6.0
networkx                      3.4.2
notebook                      7.2.2
notebook_shim                 0.2.4
numpy                         1.26.4
nvidia-cublas-cu12            12.4.5.8
nvidia-cuda-cupti-cu12        12.4.127
nvidia-cuda-nvrtc-cu12        12.4.127
nvidia-cuda-runtime-cu12      12.4.127
nvidia-cudnn-cu12             9.1.0.70
nvidia-cufft-cu12             11.2.1.3
nvidia-curand-cu12            10.3.5.147
nvidia-cusolver-cu12          11.6.1.9
nvidia-cusparse-cu12          12.3.1.170
nvidia-nccl-cu12              2.21.5
nvidia-nvjitlink-cu12         12.4.127
nvidia-nvtx-cu12              12.4.127
overrides                     7.7.0
packaging                     24.1
pandocfilters                 1.5.1
parameterized                 0.9.0
parso                         0.8.4
pathspec                      0.12.1
pexpect                       4.9.0
pillow                        11.0.0
pip                           24.2
platformdirs                  4.3.6
pluggy                        1.5.0
prometheus_client             0.21.0
prompt_toolkit                3.0.48
psutil                        6.1.0
ptyprocess                    0.7.0
pure_eval                     0.2.3
pycodestyle                   2.12.1
pycparser                     2.22
pyflakes                      3.2.0
Pygments                      2.18.0
pyparsing                     3.2.0
pytest                        8.3.3
pytest-cov                    6.0.0
python-dateutil               2.9.0.post0
python-json-logger            2.0.7
PyYAML                        6.0.2
pyzmq                         26.2.0
referencing                   0.35.1
requests                      2.32.3
rfc3339-validator             0.1.4
rfc3986-validator             0.1.1
rpds-py                       0.20.1
scikit-learn                  1.5.2
scipy                         1.14.1
Send2Trash                    1.8.3
setuptools                    75.1.0
six                           1.16.0
sniffio                       1.3.1
snowballstemmer               2.2.0
soupsieve                     2.6
Sphinx                        8.1.3
sphinx-autodoc-typehints      2.5.0
sphinxcontrib-applehelp       2.0.0
sphinxcontrib-devhelp         2.0.0
sphinxcontrib-htmlhelp        2.1.0
sphinxcontrib-jsmath          1.0.1
sphinxcontrib-katex           0.9.10
sphinxcontrib-qthelp          2.0.0
sphinxcontrib-serializinghtml 2.0.0
stack-data                    0.6.3
stdlibs                       2024.10.25
sympy                         1.13.1
terminado                     0.18.1
threadpoolctl                 3.5.0
tinycss2                      1.4.0
toml                          0.10.2
tomlkit                       0.13.2
torch                         2.5.1
torchvision                   0.20.1
tornado                       6.4.1
tqdm                          4.66.6
trailrunner                   1.4.0
traitlets                     5.14.3
triton                        3.1.0
types-python-dateutil         2.9.0.20241003
typing_extensions             4.12.2
ufmt                          2.8.0
uri-template                  1.3.0
urllib3                       2.2.3
usort                         1.0.2
wcwidth                       0.2.13
webcolors                     24.8.0
webencodings                  0.5.1
websocket-client              1.8.0
Werkzeug                      3.1.2
wheel                         0.44.0
widgetsnbextension            4.0.13
zstandard                     0.23.0
```

Pull Request resolved: #1432

Reviewed By: craymichael

Differential Revision: D65665700

Pulled By: cyrjano

fbshipit-source-id: 08028809a4ef04ba560d35490104cf11e6af1faa
  • Loading branch information
chillerb authored and facebook-github-bot committed Nov 11, 2024
1 parent 69d6939 commit 6540e74
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 30 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ access to a number of our interpretability algorithms.
To analyze a sample model on CIFAR10 via Captum Insights run

```
python -m captum.insights.example
python -m captum.insights.attr_vis.example
```

and navigate to the URL specified in the output.
Expand Down
52 changes: 25 additions & 27 deletions captum/insights/attr_vis/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
import torchvision.transforms as transforms
from captum.insights import AttributionVisualizer, Batch

# pyre-fixme[21]: Could not find module
# `captum.insights.attr_vis.example.get_pretrained_model`.
from captum.insights.attr_vis.example.get_pretrained_model import Net
from captum.insights.attr_vis.features import ImageFeature


Expand All @@ -32,31 +29,32 @@ def get_classes() -> List[str]:
return classes


def get_pretrained_model() -> Net:
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.pool2 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()
self.relu3 = nn.ReLU()
self.relu4 = nn.ReLU()

def forward(self, x):
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = self.relu3(self.fc1(x))
x = self.relu4(self.fc2(x))
x = self.fc3(x)
return x
class Net(nn.Module):
def __init__(self) -> None:
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.pool2 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
self.relu1 = nn.ReLU()
self.relu2 = nn.ReLU()
self.relu3 = nn.ReLU()
self.relu4 = nn.ReLU()

def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.pool1(self.relu1(self.conv1(x)))
x = self.pool2(self.relu2(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = self.relu3(self.fc1(x))
x = self.relu4(self.fc2(x))
x = self.fc3(x)
return x


def get_pretrained_model() -> Net:
net = Net()
pt_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "models/cifar_torchvision.pt")
Expand Down
2 changes: 0 additions & 2 deletions captum/insights/attr_vis/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# pyre-strict
import logging
import os
import socket
import threading
from time import sleep
Expand Down Expand Up @@ -108,7 +107,6 @@ def start_server(

global port
if port is None:
os.environ["WERKZEUG_RUN_MAIN"] = "true" # hides starting message
if not debug:
log = logging.getLogger("werkzeug")
log.disabled = True
Expand Down

0 comments on commit 6540e74

Please sign in to comment.