Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fmt cannot accept multiple fns #327

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 12 additions & 10 deletions great_tables/_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,23 +99,25 @@ def fmt(
The GT object is returned. This is the same object that the method is called on so that we
can facilitate method chaining.
"""

# If a single function is supplied to `fns` then
# repackage that into a list as the `default` function
if isinstance(fns, Callable):
fns = FormatFns(default=fns)

row_res = resolve_rows_i(self, rows)
row_pos = [name_pos[1] for name_pos in row_res]

col_res = resolve_cols_c(self, columns)

formatter = FormatInfo(fns, col_res, row_pos)
# If a single function is supplied to `fns` then
# repackage that into a list as the `default` function
formatter: list[FormatInfo] = []
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think unfortunately GT.fmt never expects a list of functions as inputs, but uses the maybe trickily named FormatFns to mean, a dataclass that holds functions for different renderers (e.g. html, in the future latex, etc..).

It's a good point though that GT.fmt is an important entry point, so it's worth thinking carefully about its behavior. I think if it comes down to it, people could always pass lambda x: formatter2(formatter1(x))? Or use some implementation of a compose() function?

These docs:

    Parameters
    ----------
    fns
        Either a single formatting function or a named list of functions.

Unfortunately are copied over from R, and so named list is referring to something more like a dictionary (which we're currently using the FormatFns dataclass for)

Copy link
Collaborator Author

@jrycw jrycw May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious about the possibility of implementing a compose() function. In my imagination, the code might look something like this:

from functools import partial
from typing import Any, Callable
from great_tables import GT, exibble


def compose(funcs: list[Callable[Any, Any]]) -> Callable:
    def _compose(x: Any, funcs: list[Callable[Any, Any]]) -> Any:
        for f in funcs:
            x = f(x)
        return x
    return partial(_compose, funcs=funcs)


(
    GT(exibble[["fctr"]]).fmt(
        compose([lambda x: f"_{x}_", lambda x: f"*{x}*", lambda x: f"<{x}>"]),
        columns="fctr",
    )
)

image

if isinstance(fns, Callable):
_fmtfs = FormatFns(default=fns)
formatter.append(FormatInfo(_fmtfs, col_res, row_pos))
elif isinstance(fns, list):
for _fns in fns:
_fmtfs = FormatFns(default=_fns)
formatter.append(FormatInfo(_fmtfs, col_res, row_pos))

if is_substitution:
return self._replace(_substitutions=[*self._substitutions, formatter])
return self._replace(_substitutions=[*self._substitutions, *formatter])

return self._replace(_formats=[*self._formats, formatter])
return self._replace(_formats=[*self._formats, *formatter])


def fmt_number(
Expand Down
12 changes: 7 additions & 5 deletions great_tables/_gt_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,15 +159,17 @@ def render_formats(self, data_tbl: TblData, formats: list[FormatInfo], context:
eval_func = getattr(fmt.func, context, fmt.func.default)
if eval_func is None:
raise Exception("Internal Error")
cell_info = []
for col, row in fmt.cells.resolve():
result = eval_func(_get_cell(data_tbl, row, col))
if isinstance(result, FormatterSkipElement):
continue

# TODO: I think that this is very inefficient with polars, so
# we could either accumulate results and set them per column, or
# could always use a pandas DataFrame inside Body?
_set_cell(self.body, row, col, result)
cell_info.append((row, col, result))
# TODO: I think that this is very inefficient with polars, so
# we could either accumulate results and set them per column, or
# could always use a pandas DataFrame inside Body?
for _cell_info in cell_info:
data_tbl = _set_cell(self.body, *_cell_info)

return self

Expand Down
3 changes: 3 additions & 0 deletions great_tables/_tbl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,11 +197,13 @@ def _(data, row: int, column: str, value: Any) -> None:
# if this is violated, get_loc will return a mask
col_indx = data.columns.get_loc(column)
data.iloc[row, col_indx] = value
return data
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since _set_cell() mutates data, it seems useful to continue not returning data (sometimes called Command query separation).



@_set_cell.register(PlDataFrame)
def _(data, row: int, column: str, value: Any) -> None:
data[row, column] = value
return data


# _get_column_dtype ----
Expand Down Expand Up @@ -297,6 +299,7 @@ def _(
elif callable(expr):
# TODO: currently, we call on each string, but we could be calling on
# pd.DataFrame.columns instead (which would let us use pandas .str methods)

col_pos = {k: ii for ii, k in enumerate(list(data.columns))}
return [(col, col_pos[col]) for col in data.columns if expr(col)]

Expand Down
16 changes: 16 additions & 0 deletions tests/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,22 @@ def test_format_fns():
assert res == ["2", "3"]


def test_format_multi_fns():
df = pd.DataFrame({"x": [1, 2]})
gt = GT(df)
new_gt = fmt(gt, fns=[lambda x: str(x + 1), lambda x: str(x + 2)], columns=["x"])

formats_fn = new_gt._formats[0]

res = list(map(formats_fn.func.default, df["x"]))
assert res == ["2", "3"]

formats_fn = new_gt._formats[1]

res = list(map(formats_fn.func.default, df["x"]))
assert res == ["3", "4"]


def test_format_snap(snapshot):
new_gt = (
GT(exibble)
Expand Down
Loading