Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix false positive when iterating over Enum #312

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ vulture.egg-info/
.pytest_cache/
.tox/
.venv/
.vscode/
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
* Add `UnicodeEncodeError` exception handling to `core.py` (milanbalazs, #299).
* Add whitelist for `Enum` attributes `_name_` and `_value_` (Eugene Toder, #305).
* Fix false positive when iterating over `Enum` (anudaweerasinghe, pm3512, addykan, #304)

# 2.7 (2023-01-08)

Expand Down Expand Up @@ -301,4 +302,4 @@

# 0.1 (2012-03-17)

* First release.
* First release.
35 changes: 35 additions & 0 deletions tests/test_scavenging.py
Original file line number Diff line number Diff line change
Expand Up @@ -899,3 +899,38 @@ class Color(Enum):

check(v.unused_classes, [])
check(v.unused_vars, ["BLUE"])


def test_enum_list(v):
v.scan(
"""\
import enum
class E(enum.Enum):
A = 1
B = 2

print(list(E))
"""
)

check(v.defined_classes, ["E"])
check(v.defined_vars, ["A", "B"])
check(v.unused_vars, [])


def test_enum_for(v):
v.scan(
"""\
import enum
class E(enum.Enum):
A = 1
B = 2

for e in E:
print(e)
"""
)

check(v.defined_classes, ["E"])
check(v.defined_vars, ["A", "B", "e"])
check(v.unused_vars, [])
47 changes: 47 additions & 0 deletions vulture/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@ def get_list(typ):
self.code = []
self.found_dead_code_or_error = False

self.enum_class_vars = (
dict()
) # stores variables defined in enum classes

def scan(self, code, filename=""):
filename = Path(filename)
self.code = code.splitlines()
Expand Down Expand Up @@ -551,6 +555,18 @@ def visit_Call(self, node):
):
self._handle_new_format_string(node.func.value.s)

# handle enum.Enum members
iter_functions = ["list", "tuple", "set"]
if (
isinstance(node.func, ast.Name)
and node.func.id in iter_functions
and len(node.args) > 0
and isinstance(node.args[0], ast.Name)
):
arg = node.args[0].id
if arg in self.enum_class_vars:
self.used_names.update(self.enum_class_vars[arg])

def _handle_new_format_string(self, s):
def is_identifier(name):
return bool(re.match(r"[a-zA-Z_][a-zA-Z0-9_]*", name))
Expand Down Expand Up @@ -581,6 +597,20 @@ def _is_locals_call(node):
and not node.keywords
)

@staticmethod
def _is_subclass(node, class_name):
"""Return True if the node is a subclass of the given class."""
assert isinstance(node, ast.ClassDef)
for superclass in node.bases:
if (
isinstance(superclass, ast.Name)
and superclass.id == class_name
or isinstance(superclass, ast.Attribute)
and superclass.attr == class_name
):
return True
return False

def visit_ClassDef(self, node):
for decorator in node.decorator_list:
if _match(
Expand All @@ -594,6 +624,15 @@ def visit_ClassDef(self, node):
self._define(
self.defined_classes, node.name, node, ignore=_ignore_class
)
# if subclasses enum add class variables to enum_class_vars
if self._is_subclass(node, "Enum"):
newKey = node.name
classVariables = []
for stmt in node.body:
if isinstance(stmt, ast.Assign):
for target in stmt.targets:
classVariables.append(target.id)
self.enum_class_vars[newKey] = classVariables

def visit_FunctionDef(self, node):
decorator_names = [
Expand Down Expand Up @@ -661,6 +700,14 @@ def visit_Assign(self, node):
def visit_While(self, node):
self._handle_conditional_node(node, "while")

def visit_For(self, node):
# Handle iterating over Enum
if (
isinstance(node.iter, ast.Name)
and node.iter.id in self.enum_class_vars
):
self.used_names.update(self.enum_class_vars[node.iter.id])

def visit_MatchClass(self, node):
for kwd_attr in node.kwd_attrs:
self.used_names.add(kwd_attr)
Expand Down