diff --git a/README.md b/README.md index 59735ce..b64e492 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,8 @@ plt.style.use(["cosmoplots.default"]) ``` ### Muliple subfigures -To make a figure with multiple rows or columns, use `cosmoplots.figure_multiple_rows_columns`: +To make a figure with multiple rows or columns, use `cosmoplots.figure_multiple_rows_columns`. +By default, the labels are $\mathrm{(a)}$, $\mathrm{(b)}$, $\mathrm{(c)}$, ..., but they may be replaced using the `labels` argument. ```python import matplotlib.pyplot as plt import cosmoplots diff --git a/assets/multifig.png b/assets/multifig.png index eeb3bad..0d495d7 100644 Binary files a/assets/multifig.png and b/assets/multifig.png differ diff --git a/cosmoplots/axes.py b/cosmoplots/axes.py index a32c81a..a830496 100644 --- a/cosmoplots/axes.py +++ b/cosmoplots/axes.py @@ -91,7 +91,10 @@ def change_log_axis_base( ) return axes -def figure_multiple_rows_columns(rows: int, columns: int) -> Tuple[Figure, List[Axes]]: +def figure_multiple_rows_columns(rows: int, columns: int, + labels: Union[List[str], None] = None, + label_x: float = -0.2, label_y: float = 0.95, + **kwargs) -> Tuple[Figure, List[Axes]]: """Returns a figure with axes which is appropriate for (rows, columns) subfigures. Parameters @@ -100,16 +103,23 @@ def figure_multiple_rows_columns(rows: int, columns: int) -> Tuple[Figure, List[ The number of rows in the figure columns : int The number of columns in the figure - + labels : List[str] | None + The labels to be applied to each subfigure. Defaults to (a), (b), (c), ... + label_x and label_y: float + x- and y- positions of the labels relative to each Axes object. + **kwargs: + Additional keyword arguments to be passed to Axes.text. + Returns ------- plt.Figure The figure object - plt.Axes + List[plt.Axes] A list of all the axes objects owned by the figure """ fig = plt.figure(figsize = (columns*3.37, rows*2.08277)) axes = [] + labels = labels or [r"$\mathrm{{({})}}$".format(chr(97+l)) for l in range(rows*columns)] for c in range(columns): for r in range(rows): left = (0.2)/columns + c/columns @@ -117,6 +127,7 @@ def figure_multiple_rows_columns(rows: int, columns: int) -> Tuple[Figure, List[ width = 0.75/columns height = 0.75/rows axes.append(fig.add_axes((left, bottom, width, height))) + axes[-1].text(label_x, label_y, labels[columns*r+c], transform=axes[-1].transAxes, **kwargs) return fig, axes