Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement GT.pipe() and GT.pipes() #363

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ quartodoc:
contents:
- GT.save
- GT.as_raw_html
- title: Pipeline
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about putting pipe in the helper functions section. @rich-iannone WDYT?

desc: >
Sometimes, you might want to programmatically manipulate the table while still benefiting
from the chained API that **Great Tables** offers. `pipe()` is designed to tackle this
issue.
contents:
- GT.pipe
- title: Value formatting functions
desc: >
If you have single values (or lists of them) in need of formatting, we have a set of
Expand Down
96 changes: 96 additions & 0 deletions great_tables/_pipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Any

if TYPE_CHECKING:
from .gt import GT

Check warning on line 6 in great_tables/_pipe.py

View check run for this annotation

Codecov / codecov/patch

great_tables/_pipe.py#L6

Added line #L6 was not covered by tests


def pipe(self: "GT", func: Callable[..., "GT"], *args: Any, **kwargs: Any) -> "GT":
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we could use ParamSpec here? Something like...

from typing import ParamSpec, Callable

P = ParamSpec("P")
def punt(f: Callable[P, int], *args: P.args, **kwargs: P.kwargs) -> int:
    return f(*args, **kwargs)


def add(x: int, y: int) -> int:
    return x + y

# flagged in static check as z is not an arg to add
punt(add, 1, y = 2, z = 3)

https://mypy-play.net/?mypy=latest&python=3.12&gist=23dd2d435071c6e9f5639cfcadf7dd16

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Of course! This is somewhat new to me, so please feel free to correct me if I've misunderstood anything.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@machow and @rich-iannone, this PR has been pending for a while. Could we schedule it for merging, or are there any remaining concerns?

"""
Provide a structured way to chain a function for a GT object.

This function accepts a function that receives a GT object along with optional positional and
keyword arguments, returning a GT object. This allows users to easily integrate a function
into the chained API offered by **Great Tables**.

Parameters
----------
func
A function that receives a GT object along with optional positional and keyword arguments,
returning a GT object.

*args
Optional positional arguments to be passed to the function.

**kwargs
Optional keyword arguments to be passed to the function.

Returns
-------
gt
A GT object.

Examples:
------
Let's use the `name`, `land_area_km2`, and `density_2021` columns of the `towny` dataset to
create a table. First, we'll demonstrate using two consecutive calls to the `.tab_style()`
method to highlight the maximum value of the `land_area_km2` column with `"lightgray"` and the
maximum value of the `density_2021` column with `"lightblue"`.

```{python}
import polars as pl
from great_tables import GT, loc, style
from great_tables.data import towny


towny_mini = pl.from_pandas(towny).head(10)

(
GT(
towny_mini[["name", "land_area_km2", "density_2021"]],
rowname_col="name",
)
.tab_style(
style=style.fill(color="lightgray"),
locations=loc.body(
columns="land_area_km2",
rows=pl.col("land_area_km2").eq(pl.col("land_area_km2").max()),
),
)
.tab_style(
style=style.fill(color="lightblue"),
locations=loc.body(
columns="density_2021",
rows=pl.col("density_2021").eq(pl.col("density_2021").max()),
),
)
)
```

Next, we'll demonstrate how to achieve the same result using the `.pipe()` method to
programmatically style each column.

```{python}
columns = ["land_area_km2", "density_2021"]
colors = ["lightgray", "lightblue"]


def tbl_style(gtbl: GT, columns: list[str], colors: list[str]) -> GT:
for column, color in zip(columns, colors):
gtbl = gtbl.tab_style(
style=style.fill(color=color),
locations=loc.body(columns=column, rows=pl.col(column).eq(pl.col(column).max())),
)
return gtbl


(
GT(
towny_mini[["name", "land_area_km2", "density_2021"]],
rowname_col="name",
).pipe(tbl_style, columns, colors)
)
```
"""
return func(self, *args, **kwargs)
3 changes: 3 additions & 0 deletions great_tables/gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from great_tables._boxhead import cols_align, cols_label
from great_tables._data_color import data_color
from great_tables._export import as_raw_html, save
from great_tables._pipe import pipe
from great_tables._formats import (
fmt,
fmt_bytes,
Expand Down Expand Up @@ -255,6 +256,8 @@ def __init__(
save = save
as_raw_html = as_raw_html

pipe = pipe

# -----

def _repr_html_(self):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_pipe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import polars as pl
from great_tables import GT, loc, style


def test_pipe():
columns = ["x", "y"]
colors = ["lightgray", "lightblue"]
df = pl.DataFrame(dict(zip(columns, [[1, 2, 3], [3, 2, 1]])))

gt1 = (
GT(df)
.tab_style(
style=style.fill(color=colors[0]),
locations=loc.body(
columns=columns[0], rows=pl.col(columns[0]).eq(pl.col(columns[0]).max())
),
)
.tab_style(
style=style.fill(color=colors[1]),
locations=loc.body(
columns=columns[1], rows=pl.col(columns[1]).eq(pl.col(columns[1]).max())
),
)
)

def tbl_style(gtbl: GT, columns: list[str], colors: list[str]) -> GT:
for column, color in zip(columns, colors):
gtbl = gtbl.tab_style(
style=style.fill(color=color),
locations=loc.body(columns=column, rows=pl.col(column).eq(pl.col(column).max())),
)
return gtbl

gt2 = GT(df).pipe(tbl_style, columns, colors) # check *args
gt3 = GT(df).pipe(tbl_style, columns=columns, colors=colors) # check **kwargs

assert gt1._styles == gt2._styles == gt3._styles