Skip to content

Commit

Permalink
add filter method to LinkGraph (#269)
Browse files Browse the repository at this point in the history
* add examples to docstring

* add filter method to LinkGraph

* use python3.10 for code format check

* add stub for networkx to github action

* fix bug of accessing non-existent objects in link graph
  • Loading branch information
CunliangGeng authored Jul 15, 2024
1 parent 5d1ba6f commit 1e23e57
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 3 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/format-typing-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.10
uses: actions/setup-python@v3
with:
python-version: '3.10'
- name: Install ruff and mypy
run: |
pip install ruff mypy typing_extensions types-Deprecated types-beautifulsoup4 types-jsonschema pandas-stubs
pip install ruff mypy typing_extensions \
types-Deprecated types-beautifulsoup4 types-jsonschema types-networkx pandas-stubs
- name: Get all changed python files
id: changed-python-files
uses: tj-actions/changed-files@v44
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ dev = [
# static typing
"mypy",
"typing_extensions",
# stub packages
# stub packages. Update the `format-typing-check.yml` too if you add more.
"types-Deprecated",
"types-beautifulsoup4",
"types-jsonschema",
Expand Down
80 changes: 79 additions & 1 deletion src/nplinker/scoring/link_graph.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections.abc import Sequence
from functools import wraps
from typing import Union
from networkx import Graph
Expand Down Expand Up @@ -79,7 +80,7 @@ def __init__(self) -> None:
>>> lg[gcf]
{spectrum: {"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})}}
Get all links:
Get all links in the LinkGraph:
>>> lg.links
[(gcf, spectrum, {"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})})]
Expand Down Expand Up @@ -129,6 +130,10 @@ def links(
Returns:
A list of tuples containing the links between objects.
Examples:
>>> lg.links
[(gcf, spectrum, {"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})})]
"""
return list(self._g.edges(data=True))

Expand All @@ -150,6 +155,9 @@ def add_link(
data: keyword arguments. At least one scoring method and its data must be provided.
The key must be the name of the scoring method defined in `ScoringMethod`, and the
value is a `Score` object, e.g. `metcalf=Score("metcalf", 1.0, {"cutoff": 0.5})`.
Examples:
>>> lg.add_link(gcf, spectrum, metcalf=Score("metcalf", 1.0, {"cutoff": 0.5}))
"""
# validate the data
if not data:
Expand All @@ -174,6 +182,10 @@ def has_link(self, u: Entity, v: Entity) -> bool:
Returns:
True if there is a link between the two objects, False otherwise
Examples:
>>> lg.has_link(gcf, spectrum)
True
"""
return self._g.has_edge(u, v)

Expand All @@ -192,5 +204,71 @@ def get_link_data(
Returns:
A dictionary of scoring methods and their data for the link between the two objects, or
None if there is no link between the two objects.
Examples:
>>> lg.get_link_data(gcf, spectrum)
{"metcalf": Score("metcalf", 1.0, {"cutoff": 0.5})}
"""
return self._g.get_edge_data(u, v) # type: ignore

def filter(self, u_nodes: Sequence[Entity], v_nodes: Sequence[Entity] = [], /) -> LinkGraph:
"""Return a new LinkGraph object with the filtered links between the given objects.
The new LinkGraph object will only contain the links between `u_nodes` and `v_nodes`.
If `u_nodes` or `v_nodes` is empty, the new LinkGraph object will contain the links for
the given objects in `v_nodes` or `u_nodes`, respectively. If both are empty, return an
empty LinkGraph object.
Note that not all objects in `u_nodes` and `v_nodes` need to be present in the original
LinkGraph.
Args:
u_nodes: a sequence of objects used as the first object in the links
v_nodes: a sequence of objects used as the second object in the links
Returns:
A new LinkGraph object with the filtered links between the given objects.
Examples:
Filter the links for `gcf1` and `gcf2`:
>>> new_lg = lg.filter([gcf1, gcf2])
Filter the links for `spectrum1` and `spectrum2`:
>>> new_lg = lg.filter([spectrum1, spectrum2])
Filter the links between two lists of objects:
>>> new_lg = lg.filter([gcf1, gcf2], [spectrum1, spectrum2])
"""
lg = LinkGraph()

# exchange u_nodes and v_nodes if u_nodes is empty but v_nodes not
if len(u_nodes) == 0 and len(v_nodes) != 0:
u_nodes = v_nodes
v_nodes = []

if len(v_nodes) == 0:
for u in u_nodes:
self._filter_one_node(u, lg)

for u in u_nodes:
for v in v_nodes:
self._filter_two_nodes(u, v, lg)

return lg

@validate_u
def _filter_one_node(self, u: Entity, lg: LinkGraph) -> None:
"""Filter the links for a given object and add them to the new LinkGraph object."""
try:
links = self[u]
except KeyError:
pass
else:
for node2, value in links.items():
lg.add_link(u, node2, **value)

@validate_uv
def _filter_two_nodes(self, u: Entity, v: Entity, lg: LinkGraph) -> None:
"""Filter the links between two objects and add them to the new LinkGraph object."""
link_data = self.get_link_data(u, v)
if link_data is not None:
lg.add_link(u, v, **link_data)
29 changes: 29 additions & 0 deletions tests/unit/scoring/test_link_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,32 @@ def test_has_link(lg, gcfs, spectra):
def test_get_link_data(lg, gcfs, spectra, score):
assert lg.get_link_data(gcfs[0], spectra[0]) == {"metcalf": score}
assert lg.get_link_data(gcfs[0], spectra[1]) is None


def test_filter(gcfs, spectra, score):
lg = LinkGraph()
lg.add_link(gcfs[0], spectra[0], metcalf=score)
lg.add_link(gcfs[1], spectra[1], metcalf=score)

u_nodes = [gcfs[0], gcfs[1], gcfs[2]]
v_nodes = [spectra[0], spectra[1], spectra[2]]

# test filtering with GCFs
lg_filtered = lg.filter(u_nodes)
assert len(lg_filtered) == 4 # number of nodes

# test filtering with Spectra
lg_filtered = lg.filter(v_nodes)
assert len(lg_filtered) == 4

# test empty `u_nodes` argument
lg_filtered = lg.filter([], v_nodes)
assert len(lg_filtered) == 4

# test empty `u_nodes` and `v_nodes` arguments
lg_filtered = lg.filter([], [])
assert len(lg_filtered) == 0

# test filtering with GCFs and Spectra
lg_filtered = lg.filter(u_nodes, v_nodes)
assert len(lg_filtered) == 4

0 comments on commit 1e23e57

Please sign in to comment.