diff --git a/Dockerfile b/Dockerfile index 97a9077..ed04f0b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,9 +1,9 @@ -FROM python:3.8-buster - -ENV DEBIAN_FRONTEND noninteractive - -COPY packages_setup.bash / -RUN chmod +x /packages_setup.bash && /packages_setup.bash - -COPY base_requirements.txt / +FROM python:3.8-buster + +ENV DEBIAN_FRONTEND noninteractive + +COPY packages_setup.bash / +RUN chmod +x /packages_setup.bash && /packages_setup.bash + +COPY base_requirements.txt / RUN pip install -r /base_requirements.txt \ No newline at end of file diff --git a/__init__.py b/__init__.py index 0d45ab9..e2bf39a 100644 --- a/__init__.py +++ b/__init__.py @@ -4,6 +4,8 @@ import builtins as __builtin__ import datetime +import re +import unicodedata from dataclasses import dataclass from functools import partial from pathlib import Path @@ -14,9 +16,15 @@ import pandas as pd from IPython.display import Markdown as md -from .charting import (Chart, ChartEncoding, altair_sw_theme, altair_theme, - enable_sw_charts) -from .df_extensions import space, viz, common +from .charting import ( + Chart, + ChartEncoding, + altair_sw_theme, + altair_theme, + enable_sw_charts, + ChartTitle +) +from .df_extensions import common, space, viz from .helpers.pipe import Pipe, Pipeline, iter_format from .management.exporters import render_to_html, render_to_markdown from .management.settings import settings @@ -32,7 +40,8 @@ def page_break(): - return md(""" + return md( + """ ```{=openxml} @@ -40,7 +49,8 @@ def page_break(): ``` -""") +""" + ) def notebook_setup(): @@ -51,7 +61,20 @@ def Date(x): return datetime.datetime.fromisoformat(x).date() -comma_thousands = '{:,}'.format -percentage_0dp = '{:,.0%}'.format -percentage_1dp = '{:,.1%}'.format -percentage_2dp = '{:,.2%}'.format +def slugify(value): + """ + Converts to lowercase, removes non-word characters (alphanumerics and + underscores) and converts spaces to hyphens. Also strips leading and + trailing whitespace. + """ + value = ( + unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii") + ) + value = re.sub("[^\w\s-]", "", value).strip().lower() + return re.sub("[-\s]+", "-", value) + + +comma_thousands = "{:,}".format +percentage_0dp = "{:,.0%}".format +percentage_1dp = "{:,.1%}".format +percentage_2dp = "{:,.2%}".format diff --git a/apis/google_api.py b/apis/google_api.py index 786c13b..14667b5 100644 --- a/apis/google_api.py +++ b/apis/google_api.py @@ -1,4 +1,3 @@ - import socket import sys from pathlib import Path @@ -8,49 +7,58 @@ from googleapiclient.discovery import build from googleapiclient.http import MediaFileUpload -SCOPES = ['https://www.googleapis.com/auth/drive', "https://www.googleapis.com/auth/documents", - 'https://www.googleapis.com/auth/script.projects'] +SCOPES = [ + "https://www.googleapis.com/auth/drive", + "https://www.googleapis.com/auth/documents", + "https://www.googleapis.com/auth/script.projects", +] url_template = "https://docs.google.com/document/d/{0}/edit" class DriveIntegration: - def __init__(self, data): self.creds = Credentials.from_authorized_user_info(data, SCOPES) - self.api = build('drive', 'v3', credentials=self.creds) + self.api = build("drive", "v3", credentials=self.creds) def upload_file(self, file_name, file_path, folder_id, drive_id): - body = {'name': file_name, 'driveID': drive_id, "parents": [ - folder_id], 'mimeType': 'application/vnd.google-apps.document'} + body = { + "name": file_name, + "driveID": drive_id, + "parents": [folder_id], + "mimeType": "application/vnd.google-apps.document", + } # Now create the media file upload object and tell it what file to upload, # in this case 'test.html' media = MediaFileUpload( file_path, - mimetype='application/vnd.openxmlformats-officedocument.wordprocessingml.document') + mimetype="application/vnd.openxmlformats-officedocument.wordprocessingml.document", + ) # Now we're doing the actual post, creating a new file of the uploaded type - uploaded = self.api.files().create(body=body, media_body=media, - supportsTeamDrives=True).execute() + uploaded = ( + self.api.files() + .create(body=body, media_body=media, supportsTeamDrives=True) + .execute() + ) url = url_template.format(uploaded["id"]) return url class ScriptIntergration: - def __init__(self, data): self.creds = Credentials.from_authorized_user_info(data, SCOPES) socket.setdefaulttimeout(600) # set timeout to 10 minutes - self.api = build('script', 'v1', credentials=self.creds) + self.api = build("script", "v1", credentials=self.creds) def get_function(self, script_id, function_name): - def inner(*args): request = {"function": function_name, "parameters": list(args)} - response = self.api.scripts().run(body=request, - scriptId=script_id).execute() + response = ( + self.api.scripts().run(body=request, scriptId=script_id).execute() + ) return response @@ -66,7 +74,7 @@ def trigger_log_in_flow(settings): json_creds = creds.to_json() print(f"GOOGLE_CLIENT_JSON={json_creds}") raise ValueError("Add the following last line printed to the .env") - + def test_settings(settings): """ @@ -74,7 +82,9 @@ def test_settings(settings): """ if "GOOGLE_APP_JSON" not in settings or settings["GOOGLE_APP_JSON"] == "": - raise ValueError("Missing GOOGLE_APP_JSON settings. See the notebook setup page in the wiki for the correct settings.") + raise ValueError( + "Missing GOOGLE_APP_JSON settings. See the notebook setup page in the wiki for the correct settings." + ) if "GOOGLE_CLIENT_JSON" not in settings or settings["GOOGLE_CLIENT_JSON"] == "": trigger_log_in_flow(settings) diff --git a/base_requirements.txt b/base_requirements.txt index d809100..a5c4353 100644 --- a/base_requirements.txt +++ b/base_requirements.txt @@ -12,8 +12,6 @@ altair==4.1.0 altair_saver==0.5.0 jupyter==1.0.0 pylint==2.9.1 -pycodestyle==2.7.0 -autopep8==1.5.7 ruamel.yaml==0.17.10 pypandoc==1.5 flake8==3.9.2 @@ -24,4 +22,6 @@ papermill==2.3.3 google-api-python-client==2.21.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==0.4.6 -ptitprince==0.2.5 \ No newline at end of file +ptitprince==0.2.5 +pylint==2.12.2 +black[jupyter]==21.12b0 \ No newline at end of file diff --git a/charting/__init__.py b/charting/__init__.py index c9bb0d1..a3e9a85 100644 --- a/charting/__init__.py +++ b/charting/__init__.py @@ -1,22 +1,21 @@ from . import theme as altair_theme from . import sw_theme as altair_sw_theme -from .chart import Chart, Renderer, ChartEncoding +from .chart import Chart, Renderer, ChartEncoding, ChartTitle from .saver import MSSaver, SWSaver, render, sw_render + import altair as alt -alt.themes.register('mysoc_theme', lambda: altair_theme.mysoc_theme) -alt.themes.enable('mysoc_theme') +alt.themes.register("mysoc_theme", lambda: altair_theme.mysoc_theme) +alt.themes.enable("mysoc_theme") -alt.renderers.register('mysoc_saver', render) -alt.renderers.enable('mysoc_saver') +alt.renderers.register("mysoc_saver", render) +alt.renderers.enable("mysoc_saver") def enable_sw_charts(): - alt.themes.register('societyworks_theme', lambda: altair_sw_theme.sw_theme) - alt.themes.enable('societyworks_theme') - - alt.renderers.register('sw_saver', sw_render) - alt.renderers.enable('sw_saver') - Renderer.default_renderer = SWSaver - + alt.themes.register("societyworks_theme", lambda: altair_sw_theme.sw_theme) + alt.themes.enable("societyworks_theme") + alt.renderers.register("sw_saver", sw_render) + alt.renderers.enable("sw_saver") + Renderer.default_renderer = SWSaver \ No newline at end of file diff --git a/charting/chart.py b/charting/chart.py index 15251e3..9054527 100644 --- a/charting/chart.py +++ b/charting/chart.py @@ -1,6 +1,6 @@ from functools import wraps from pathlib import Path -from typing import List, Optional +from typing import List, Optional, Union import altair as alt import pandas as pd @@ -15,8 +15,8 @@ class Renderer: def save_chart(chart, filename, scale_factor=1, **kwargs): """ - dumbed down version of altair save function that just assumes - we're sending extra properties to the embed options + dumbed down version of altair save function that just assumes + we're sending extra properties to the embed options """ if isinstance(filename, Path): # altair doesn't process paths right @@ -24,11 +24,58 @@ def save_chart(chart, filename, scale_factor=1, **kwargs): filename.parent.mkdir() filename = str(filename) - altair_save_chart(chart, - filename, - scale_factor=scale_factor, - embed_options=kwargs, - method=Renderer.default_renderer) + altair_save_chart( + chart, + filename, + scale_factor=scale_factor, + embed_options=kwargs, + method=Renderer.default_renderer, + ) + + +def split_text_to_line(text: str, cut_off: int = 60) -> List[str]: + """ + Split a string to meet line limit + """ + bits = text.split(" ") + rows = [] + current_item = [] + for b in bits: + if len(" ".join(current_item + [b])) > cut_off: + rows.append(" ".join(current_item)) + current_item = [] + current_item.append(b) + rows.append(" ".join(current_item)) + return rows + + +class ChartTitle(alt.TitleParams): + """ + Helper function for chart title + Includes better line wrapping + """ + + def __init__( + self, + title: Union[str, List[str]], + subtitle: Optional[Union[str, List[str]]] = None, + line_limit: int = 60, + **kwargs + ): + + if isinstance(title, str): + title_bits = split_text_to_line(title, line_limit) + else: + title_bits = title + + if isinstance(subtitle, str): + subtitle = [subtitle] + + kwargs["text"] = title_bits + if subtitle: + kwargs["subtitle"] = subtitle + + super().__init__(**kwargs) class MSDisplayMixIn: @@ -38,10 +85,11 @@ class MSDisplayMixIn: """ ignore_properties = ["_display_options"] + scale_factor = 1 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._display_options = {} + self._display_options = {"scale_factor": self.__class__.scale_factor} def display_options(self, **kwargs): """ @@ -81,6 +129,60 @@ def __or__(self, other): raise ValueError("Only Chart objects can be concatenated.") return hconcat(self, other) + @wraps(alt.Chart.properties) + def raw_properties(self, *args, **kwargs): + return super().properties(*args, **kwargs) + + def properties( + self, + title: Optional[Union[str, list, alt.TitleParams, ChartTitle]] = "", + title_line_limit: Optional[int] = 60, + subtitle: Optional[Union[str, list]] = None, + width: Optional[int] = None, + height: Optional[int] = None, + aspect: Optional[tuple] = (16, 9), + logo: bool = False, + caption: Optional[str] = "", + scale_factor: Optional[str] = None, + **kwargs + ) -> "Chart": + + args = {} + display_args = {"logo": logo, "caption": caption} + if scale_factor: + display_args["scale_factor"] = scale_factor + + if isinstance(title, str) or isinstance(title, list) or subtitle is not None: + args["title"] = ChartTitle( + title=title, subtitle=subtitle, line_limit=title_line_limit + ) + + if width and not height: + args["width"] = width + args["height"] = (width / aspect[0]) * aspect[1] + + if height and not width: + args["height"] = height + args["width"] = (height / aspect[1]) * aspect[0] + + if width and height: + args["height"] = height + args["width"] = width + + width_offset = 0 + height_offset = 0 + + if logo or caption: + height_offset += 100 + + if "width" in args: + args["width"] -= width_offset + args["height"] -= height_offset + args["autosize"] = alt.AutoSizeParams(type="fit", contains="padding") + + kwargs.update(args) + return super().properties(**kwargs).display_options(**display_args) + class MSDataManagementMixIn: """ @@ -93,6 +195,7 @@ class MSDataManagementMixIn: @classmethod def from_url(cls, url, n=0): from .download import get_chart_from_url + return get_chart_from_url(url, n) def _get_df(self) -> pd.DataFrame: @@ -102,7 +205,7 @@ def update_df(self, df: pd.DataFrame): """ take a new df and update the chart """ - self.datasets[self.data["name"]] = df.to_dict('records') + self.datasets[self.data["name"]] = df.to_dict("records") return self @property @@ -119,19 +222,23 @@ def __setattribute__(self, key, value): super().__setattribute__(key, value) -class Chart(MSDisplayMixIn, MSDataManagementMixIn, alt.Chart): +class MSAltair(MSDisplayMixIn, MSDataManagementMixIn): + pass + + +class Chart(MSAltair, alt.Chart): pass -class LayerChart(MSDisplayMixIn, MSDataManagementMixIn, alt.LayerChart): +class LayerChart(MSAltair, alt.LayerChart): pass -class HConcatChart(MSDisplayMixIn, MSDataManagementMixIn, alt.HConcatChart): +class HConcatChart(MSAltair, alt.HConcatChart): pass -class VConcatChart(MSDisplayMixIn, MSDataManagementMixIn, alt.VConcatChart): +class VConcatChart(MSAltair, alt.VConcatChart): pass @@ -153,6 +260,6 @@ def vconcat(*charts, **kwargs): @wraps(Chart.encode) def ChartEncoding(**kwargs): """ - Thin wrapper to specify properites we want to use multiple times + Thin wrapper to specify properties we want to use multiple times """ return kwargs diff --git a/charting/download.py b/charting/download.py index 65f216e..f77dc2d 100644 --- a/charting/download.py +++ b/charting/download.py @@ -1,10 +1,10 @@ - from typing import List from urllib.request import urlopen import altair as alt from .chart import Chart, LayerChart import json + def get_chart_spec_from_url(url: str) -> List[str]: """ For extracting chart specs produced by the research sites framework @@ -26,7 +26,8 @@ def json_to_chart(json_spec: str) -> alt.Chart: del di["layer"] del di["width"] chart = LayerChart.from_dict( - {"config": di["config"], "layer": [], "datasets": di["datasets"]}) + {"config": di["config"], "layer": [], "datasets": di["datasets"]} + ) for n, l in enumerate(layers): di_copy = di.copy() di_copy.update(l) @@ -43,11 +44,9 @@ def json_to_chart(json_spec: str) -> alt.Chart: return chart -def get_chart_from_url(url: str, - n: int = 0, - include_df: bool = False) -> alt.Chart: +def get_chart_from_url(url: str, n: int = 0, include_df: bool = False) -> alt.Chart: """ - given url, a number (0 indexed), get the spec, + given url, a number (0 indexed), get the spec, and reduce an altair chart instance. if `include_df` will try and reduce the original df as well. """ diff --git a/charting/saver.py b/charting/saver.py index 1f5e5a8..3f460f5 100644 --- a/charting/saver.py +++ b/charting/saver.py @@ -5,10 +5,15 @@ from altair_saver.types import JSONDict, Mimebundle from altair_saver._utils import extract_format, infer_mode_from_spec from functools import partial -from altair_saver.savers._selenium import (CDN_URL, EXTRACT_CODE, - HTML_TEMPLATE, JavascriptError, - MimebundleContent, SeleniumSaver, - get_bundled_script) +from altair_saver.savers._selenium import ( + CDN_URL, + EXTRACT_CODE, + HTML_TEMPLATE, + JavascriptError, + MimebundleContent, + SeleniumSaver, + get_bundled_script, +) import altair as alt @@ -144,7 +149,9 @@ def get_as_base64(url): class MSSaver(SeleniumSaver): - logo_url = "https://research.mysociety.org/sites/foi-monitor/static/img/mysociety-logo.jpg" + logo_url = ( + "https://research.mysociety.org/sites/foi-monitor/static/img/mysociety-logo.jpg" + ) font = "Source Sans Pro" def __init__(self, *args, **kwargs): @@ -156,8 +163,10 @@ def _get_font(self): def _get_logo(self): if self._logo is None: - self._logo = "data:image/jpg;base64," + \ - get_as_base64(self.__class__.logo_url).decode() + self._logo = ( + "data:image/jpg;base64," + + get_as_base64(self.__class__.logo_url).decode() + ) return self._logo def _extract(self, fmt: str) -> MimebundleContent: @@ -208,12 +217,9 @@ def _extract(self, fmt: str) -> MimebundleContent: opt = self._embed_options.copy() opt["mode"] = self._mode - extract_code = EXTRACT_CODE.replace( - "$$BASE64LOGO$$", str(self._get_logo())) - extract_code = extract_code.replace( - "$$FONT$$", str(self._get_font())) - result = driver.execute_async_script( - extract_code, self._spec, opt, fmt) + extract_code = EXTRACT_CODE.replace("$$BASE64LOGO$$", str(self._get_logo())) + extract_code = extract_code.replace("$$FONT$$", str(self._get_font())) + result = driver.execute_async_script(extract_code, self._spec, opt, fmt) if "error" in result: raise JavascriptError(result["error"]) return result["result"] @@ -254,8 +260,13 @@ def render( scale_factor = embed_options["scale_factor"] for fmt in fmts: - saver = Saver(spec, mode=mode, embed_options=embed_options, - scale_factor=scale_factor, **kwargs) + saver = Saver( + spec, + mode=mode, + embed_options=embed_options, + scale_factor=scale_factor, + **kwargs, + ) mimebundle.update(saver.mimebundle(fmt)) return mimebundle diff --git a/charting/sw_theme.py b/charting/sw_theme.py index 1e2554f..bc6d292 100644 --- a/charting/sw_theme.py +++ b/charting/sw_theme.py @@ -8,59 +8,58 @@ from typing import List, Any, Optional # brand colours -colours = {'colour_orange': '#f79421', - 'colour_off_white': '#f3f1eb', - 'colour_light_grey': '#e2dfd9', - 'colour_mid_grey': '#959287', - 'colour_dark_grey': '#6c6b68', - 'colour_black': '#333333', - 'colour_red': '#dd4e4d', - 'colour_yellow': '#fff066', - 'colour_violet': '#a94ca6', - 'colour_green': '#61b252', - 'colour_green_dark': '#53a044', - 'colour_green_dark_2': '#388924', - 'colour_blue': '#54b1e4', - 'colour_blue_dark': '#2b8cdb', - 'colour_blue_dark_2': '#207cba'} +colours = { + "colour_orange": "#f79421", + "colour_off_white": "#f3f1eb", + "colour_light_grey": "#e2dfd9", + "colour_mid_grey": "#959287", + "colour_dark_grey": "#6c6b68", + "colour_black": "#333333", + "colour_red": "#dd4e4d", + "colour_yellow": "#fff066", + "colour_violet": "#a94ca6", + "colour_green": "#61b252", + "colour_green_dark": "#53a044", + "colour_green_dark_2": "#388924", + "colour_blue": "#54b1e4", + "colour_blue_dark": "#2b8cdb", + "colour_blue_dark_2": "#207cba", +} # based on data visualisation colour palette -adjusted_colours = {"sw_yellow": "#fed876", - "sw_berry": "#e02653", - "sw_blue": "#0ba7d1", - "sw_dark_blue": "#065a70" - } +adjusted_colours = { + "sw_yellow": "#fed876", + "sw_berry": "#e02653", + "sw_blue": "#0ba7d1", + "sw_dark_blue": "#065a70", +} -monochrome_colours = {"colour_blue_light_20": "#7ddef8", - "colour_blue": "#0ba7d1", - "colour_blue_dark_20": "#076d88", - "colour_blue_dark_30": "#033340" - } +monochrome_colours = { + "colour_blue_light_20": "#7ddef8", + "colour_blue": "#0ba7d1", + "colour_blue_dark_20": "#076d88", + "colour_blue_dark_30": "#033340", +} all_colours = colours.copy() all_colours.update(adjusted_colours) all_colours.update(monochrome_colours) -palette = ["sw_yellow", - "sw_berry", - "sw_blue", - "sw_dark_blue"] +palette = ["sw_yellow", "sw_berry", "sw_blue", "sw_dark_blue"] -contrast_palette = ["sw_dark_blue", - "sw_yellow", - "sw_berry", - "sw_blue"] +contrast_palette = ["sw_dark_blue", "sw_yellow", "sw_berry", "sw_blue"] palette = contrast_palette -monochrome_palette = ["colour_blue_light_20", - "colour_blue", - "colour_blue_dark_20", - "colour_blue_dark_30" - ] +monochrome_palette = [ + "colour_blue_light_20", + "colour_blue", + "colour_blue_dark_20", + "colour_blue_dark_30", +] palette_colors = [adjusted_colours[x] for x in palette] @@ -72,22 +71,40 @@ # set default of colours original_palette = [ # Start with category10 color cycle: - "#1f77b4", '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', - '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", # Then continue with the paired lighter colors from category20: - '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5', - '#c49c94', '#f7b6d2', '#c7c7c7', '#dbdb8d', '#9edae5'] + "#aec7e8", + "#ffbb78", + "#98df8a", + "#ff9896", + "#c5b0d5", + "#c49c94", + "#f7b6d2", + "#c7c7c7", + "#dbdb8d", + "#9edae5", +] # use new palette for as long as possible -sw_palette_colors = palette_colors + original_palette[len(palette_colors):] +sw_palette_colors = palette_colors + original_palette[len(palette_colors) :] -def color_scale(domain: List[Any], - monochrome: bool = False, - reverse: bool = False, - palette: Optional[List[Any]] = None, - named_palette: Optional[List[Any]] = None - ) -> alt.Scale: +def color_scale( + domain: List[Any], + monochrome: bool = False, + reverse: bool = False, + palette: Optional[List[Any]] = None, + named_palette: Optional[List[Any]] = None, +) -> alt.Scale: if palette is None: if monochrome: palette = monochrome_palette_colors @@ -98,7 +115,7 @@ def color_scale(domain: List[Any], palette = [monochrome_colours[x] for x in named_palette] else: palette = [all_colours[x] for x in named_palette] - use_palette = palette[:len(domain)] + use_palette = palette[: len(domain)] if reverse: use_palette = use_palette[::-1] return alt.Scale(domain=domain, range=use_palette) @@ -107,70 +124,65 @@ def color_scale(domain: List[Any], font = "Lato" sw_theme = { - - 'config': { + "config": { "padding": {"left": 5, "top": 5, "right": 20, "bottom": 5}, - "title": {'font': font, - 'fontSize': 30, - 'anchor': "start" - }, - 'axis': { + "title": {"font": font, "fontSize": 30, "anchor": "start"}, + "axis": { "labelFont": font, "labelFontSize": 14, "titleFont": font, - 'titleFontSize': 16, - 'offset': 0 + "titleFontSize": 16, + "offset": 0, }, - 'axisX': { + "axisX": { "labelFont": font, "labelFontSize": 14, "titleFont": font, - 'titleFontSize': 16, - 'domain': True, - 'grid': True, + "titleFontSize": 16, + "domain": True, + "grid": True, "ticks": False, "gridWidth": 0.4, - 'labelPadding': 10, - + "labelPadding": 10, }, - 'axisY': { + "axisY": { "labelFont": font, "labelFontSize": 14, "titleFont": font, - 'titleFontSize': 16, + "titleFontSize": 16, "titleAlign": "left", - 'labelPadding': 10, - 'domain': True, + "labelPadding": 10, + "domain": True, "ticks": False, "titleAngle": 0, "titleY": -10, "titleX": -50, "gridWidth": 0.4, }, - 'view': { + "view": { "stroke": "transparent", - 'continuousWidth': 700, - 'continuousHeight': 400 + "continuousWidth": 700, + "continuousHeight": 400, }, "line": { "strokeWidth": 3, }, "bar": {"color": palette_colors[0]}, - 'mark': {"shape": "cross"}, - 'legend': { - "orient": 'bottom', + "mark": {"shape": "cross"}, + "legend": { + "orient": "bottom", "labelFont": font, "labelFontSize": 12, "titleFont": font, "titleFontSize": 12, "title": "", "offset": 18, - "symbolType": 'square', - } + "symbolType": "square", + }, } } -sw_theme.setdefault('encoding', {}).setdefault('color', {})['scale'] = { - 'range': sw_palette_colors, +sw_theme.setdefault("encoding", {}).setdefault("color", {})["scale"] = { + "range": sw_palette_colors, } diff --git a/charting/theme.py b/charting/theme.py index 0e95e4f..3b2935a 100644 --- a/charting/theme.py +++ b/charting/theme.py @@ -8,59 +8,71 @@ from typing import List, Any, Optional # brand colours -colours = {'colour_orange': '#f79421', - 'colour_off_white': '#f3f1eb', - 'colour_light_grey': '#e2dfd9', - 'colour_mid_grey': '#959287', - 'colour_dark_grey': '#6c6b68', - 'colour_black': '#333333', - 'colour_red': '#dd4e4d', - 'colour_yellow': '#fff066', - 'colour_violet': '#a94ca6', - 'colour_green': '#61b252', - 'colour_green_dark': '#53a044', - 'colour_green_dark_2': '#388924', - 'colour_blue': '#54b1e4', - 'colour_blue_dark': '#2b8cdb', - 'colour_blue_dark_2': '#207cba'} +colours = { + "colour_orange": "#f79421", + "colour_off_white": "#f3f1eb", + "colour_light_grey": "#e2dfd9", + "colour_mid_grey": "#959287", + "colour_dark_grey": "#6c6b68", + "colour_black": "#333333", + "colour_red": "#dd4e4d", + "colour_yellow": "#fff066", + "colour_violet": "#a94ca6", + "colour_green": "#61b252", + "colour_green_dark": "#53a044", + "colour_green_dark_2": "#388924", + "colour_blue": "#54b1e4", + "colour_blue_dark": "#2b8cdb", + "colour_blue_dark_2": "#207cba", +} # based on data visualisation colour palette -adjusted_colours = {"colour_yellow": "#ffe269", - "colour_orange": "#f4a140", - "colour_berry": "#e02653", - "colour_purple": "#a94ca6", - "colour_blue": "#4faded", - "colour_dark_blue": "#0a4166"} - -monochrome_colours = {"colour_blue_light_20": "#acd8f6", - "colour_blue": "#4faded", - "colour_blue_dark_20": "#147cc2", - "colour_blue_dark_30": "#0f5e94", - "colour_blue_dark_40": "#0a4166", - "colour_blue_dark_50": "#062337", - } - -palette = ["colour_dark_blue", - "colour_berry", - "colour_orange", - "colour_blue", - "colour_purple", - "colour_yellow"] - -contrast_palette = ["colour_dark_blue", - "colour_yellow", - "colour_berry", - "colour_orange", - "colour_blue", - "colour_purple"] - -monochrome_palette = ["colour_blue_light_20", - "colour_blue", - "colour_blue_dark_20", - "colour_blue_dark_30", - "colour_blue_dark_40", - "colour_blue_dark_50", - ] +adjusted_colours = { + "colour_yellow": "#ffe269", + "colour_orange": "#f4a140", + "colour_berry": "#e02653", + "colour_purple": "#a94ca6", + "colour_blue": "#4faded", + "colour_dark_blue": "#0a4166", + "colour_mid_grey": "#959287", + "colour_dark_grey": "#6c6b68", +} + +monochrome_colours = { + "colour_blue_light_20": "#acd8f6", + "colour_blue": "#4faded", + "colour_blue_dark_20": "#147cc2", + "colour_blue_dark_30": "#0f5e94", + "colour_blue_dark_40": "#0a4166", + "colour_blue_dark_50": "#062337", +} + +palette = [ + "colour_dark_blue", + "colour_berry", + "colour_orange", + "colour_blue", + "colour_purple", + "colour_yellow", +] + +contrast_palette = [ + "colour_dark_blue", + "colour_yellow", + "colour_berry", + "colour_orange", + "colour_blue", + "colour_purple", +] + +monochrome_palette = [ + "colour_blue_light_20", + "colour_blue", + "colour_blue_dark_20", + "colour_blue_dark_30", + "colour_blue_dark_40", + "colour_blue_dark_50", +] palette_colors = [adjusted_colours[x] for x in palette] @@ -72,22 +84,40 @@ # set default of colours original_palette = [ # Start with category10 color cycle: - "#1f77b4", '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', - '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf', + "#1f77b4", + "#ff7f0e", + "#2ca02c", + "#d62728", + "#9467bd", + "#8c564b", + "#e377c2", + "#7f7f7f", + "#bcbd22", + "#17becf", # Then continue with the paired lighter colors from category20: - '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5', - '#c49c94', '#f7b6d2', '#c7c7c7', '#dbdb8d', '#9edae5'] + "#aec7e8", + "#ffbb78", + "#98df8a", + "#ff9896", + "#c5b0d5", + "#c49c94", + "#f7b6d2", + "#c7c7c7", + "#dbdb8d", + "#9edae5", +] # use new palette for as long as possible -mysoc_palette_colors = palette_colors + original_palette[len(palette_colors):] +mysoc_palette_colors = palette_colors + original_palette[len(palette_colors) :] -def color_scale(domain: List[Any], - monochrome: bool = False, - reverse: bool = False, - palette: Optional[List[Any]] = None, - named_palette: Optional[List[Any]] = None - ) -> alt.Scale: +def color_scale( + domain: List[Any], + monochrome: bool = False, + reverse: bool = False, + palette: Optional[List[Any]] = None, + named_palette: Optional[List[Any]] = None, +) -> alt.Scale: if palette is None: if monochrome: palette = monochrome_palette_colors @@ -98,7 +128,7 @@ def color_scale(domain: List[Any], palette = [monochrome_colours[x] for x in named_palette] else: palette = [adjusted_colours[x] for x in named_palette] - use_palette = palette[:len(domain)] + use_palette = palette[: len(domain)] if reverse: use_palette = use_palette[::-1] return alt.Scale(domain=domain, range=use_palette) @@ -107,70 +137,71 @@ def color_scale(domain: List[Any], font = "Source Sans Pro" mysoc_theme = { - - 'config': { + "config": { "padding": {"left": 5, "top": 5, "right": 20, "bottom": 5}, - "title": {'font': font, - 'fontSize': 30, - 'anchor': "start" - }, - 'axis': { + "title": { + "font": font, + "fontSize": 30, + "anchor": "start", + "subtitleFontSize": 20, + "subtitleFont": "Source Sans Pro", + }, + "axis": { "labelFont": font, "labelFontSize": 14, "titleFont": font, - 'titleFontSize': 16, - 'offset': 0 + "titleFontSize": 16, + "offset": 0, }, - 'axisX': { + "axisX": { "labelFont": font, "labelFontSize": 14, "titleFont": font, - 'titleFontSize': 16, - 'domain': True, - 'grid': True, + "titleFontSize": 16, + "domain": True, + "grid": True, "ticks": False, "gridWidth": 0.4, - 'labelPadding': 10, - + "labelPadding": 10, }, - 'axisY': { + "axisY": { "labelFont": font, "labelFontSize": 14, "titleFont": font, - 'titleFontSize': 16, + "titleFontSize": 16, "titleAlign": "left", - 'labelPadding': 10, - 'domain': True, + "labelPadding": 10, + "domain": True, "ticks": False, "titleAngle": 0, "titleY": -10, "titleX": -50, "gridWidth": 0.4, }, - 'view': { + "view": { "stroke": "transparent", - 'continuousWidth': 700, - 'continuousHeight': 400 + "continuousWidth": 700, + "continuousHeight": 400, }, "line": { "strokeWidth": 3, }, "bar": {"color": palette_colors[0]}, - 'mark': {"shape": "cross"}, - 'legend': { - "orient": 'bottom', + "mark": {"shape": "cross"}, + "legend": { + "orient": "bottom", "labelFont": font, "labelFontSize": 12, "titleFont": font, "titleFontSize": 12, "title": "", "offset": 18, - "symbolType": 'square', - } + "symbolType": "square", + }, } } -mysoc_theme.setdefault('encoding', {}).setdefault('color', {})['scale'] = { - 'range': mysoc_palette_colors, +mysoc_theme.setdefault("encoding", {}).setdefault("color", {})["scale"] = { + "range": mysoc_palette_colors, } diff --git a/df_extensions/common.py b/df_extensions/common.py index 2e11e43..663fff9 100644 --- a/df_extensions/common.py +++ b/df_extensions/common.py @@ -1,5 +1,6 @@ import pandas as pd + @pd.api.extensions.register_series_accessor("common") class CommonAccessor(object): """ diff --git a/df_extensions/space.py b/df_extensions/space.py index 2331686..60ab533 100644 --- a/df_extensions/space.py +++ b/df_extensions/space.py @@ -22,18 +22,17 @@ def hex_to_rgb(value): - value = value.lstrip('#') + value = value.lstrip("#") lv = len(value) - t = tuple(int(value[i:i + lv // 3], 16) for i in range(0, lv, lv // 3)) + t = tuple(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3)) return t + (0,) def fnormalize(s): - return (s - s.mean())/s.std() + return (s - s.mean()) / s.std() class mySocMap(Colormap): - def __call__(self, X, alpha=None, bytes=False): return mysoc_palette_colors[int(X)] @@ -43,22 +42,24 @@ class Cluster: Helper class for finding kgram clusters. """ - def __init__(self, - source_df: pd.DataFrame, - id_col: Optional[str] = None, - cols: Optional[List[str]] = None, - label_cols: Optional[List[str]] = None, - normalize: bool = True, - transform: List[Callable] = None, - k: Optional[int] = 2): - """ - Initalised with a dataframe, an id column in that dataframe, - the columns that are the dimensions in question. - A 'normalize' paramater on if those columns should be - normalised before use. - and 'label_cols' which are columns that contain categories - for items. - These can be used to help understand clusters. + def __init__( + self, + source_df: pd.DataFrame, + id_col: Optional[str] = None, + cols: Optional[List[str]] = None, + label_cols: Optional[List[str]] = None, + normalize: bool = True, + transform: List[Callable] = None, + k: Optional[int] = 2, + ): + """ + Initalised with a dataframe, an id column in that dataframe, + the columns that are the dimensions in question. + A 'normalize' paramater on if those columns should be + normalised before use. + and 'label_cols' which are columns that contain categories + for items. + These can be used to help understand clusters. """ self.default_seed = 1221 @@ -86,7 +87,10 @@ def __init__(self, if cols: df = df[cols] else: - def t(x): return x != id_col and x not in label_cols + + def t(x): + return x != id_col and x not in label_cols + cols = list(filter(t, source_df.columns)) if normalize: @@ -102,7 +106,8 @@ def t(x): return x != id_col and x not in label_cols not_allowed = cols + label_cols label_df = label_df.drop( - columns=[x for x in label_df.columns if x not in not_allowed]) + columns=[x for x in label_df.columns if x not in not_allowed] + ) for c in cols: try: labels = ["Low", "Medium", "High"] @@ -149,8 +154,8 @@ def get_cluster_labels(self, include_short=True) -> np.array: labels = pd.Series(self.get_clusters(self.k).labels_) def f(x): - return self.get_label_name(n=x, - include_short=include_short) + return self.get_label_name(n=x, include_short=include_short) + labels = labels.apply(f) return np.array(labels) @@ -179,21 +184,20 @@ def add_labels(self, labels: Dict[int, Union[str, Tuple[str, str]]]): return new - def assign_name(self, - n: int, - name: str, - desc: Optional[str] = ""): + def assign_name(self, n: int, name: str, desc: Optional[str] = ""): k = self.k if k not in self.label_names: self.label_names[k] = {} self.label_descs[k] = {} - self.label_names[k][n-1] = name - self.label_descs[k][n-1] = desc + self.label_names[k][n - 1] = name + self.label_descs[k][n - 1] = desc - def plot(self, - limit_columns: Optional[List[str]] = None, - only_one: Optional[Any] = None, - show_legend: bool = True): + def plot( + self, + limit_columns: Optional[List[str]] = None, + only_one: Optional[Any] = None, + show_legend: bool = True, + ): """ Plot either all possible x, y graphs for k clusters or just the subset with the named x_var and y_var. @@ -207,16 +211,15 @@ def plot(self, if limit_columns: vars = [x for x in vars if x in limit_columns] combos = list(combinations(vars, 2)) - rows = math.ceil(len(combos)/num_rows) + rows = math.ceil(len(combos) / num_rows) - plt.rcParams["figure.figsize"] = (15, 5*rows) + plt.rcParams["figure.figsize"] = (15, 5 * rows) df["labels"] = self.get_cluster_labels() if only_one: df["labels"] = df["labels"] == only_one - df["labels"] = df["labels"].map( - {True: only_one, False: "Other clusters"}) + df["labels"] = df["labels"].map({True: only_one, False: "Other clusters"}) chart_no = 0 rgb_values = sns.color_palette("Set2", len(df["labels"].unique())) @@ -227,8 +230,7 @@ def plot(self, chart_no += 1 ax = fig.add_subplot(rows, num_rows, chart_no) for c, d in df.groupby("labels"): - scatter = ax.scatter(d[x_var], d[y_var], - color=color_map[c], label=c) + scatter = ax.scatter(d[x_var], d[y_var], color=color_map[c], label=c) ax.set_xlabel(self._axis_label(x_var)) ax.set_ylabel(self._axis_label(y_var)) @@ -238,21 +240,26 @@ def plot(self, plt.show() def plot_tool(self): - def func(cluster, show_legend, **kwargs): if cluster == "All": cluster = None limit_columns = [x for x, y in kwargs.items() if y is True] - self.plot(only_one=cluster, show_legend=show_legend, - limit_columns=limit_columns) + self.plot( + only_one=cluster, show_legend=show_legend, limit_columns=limit_columns + ) cluster_options = ["All"] + self.get_label_options() - analysis_options = {x: True if n < - 2 else False for n, x in enumerate(self.cols)} + analysis_options = { + x: True if n < 2 else False for n, x in enumerate(self.cols) + } - tool = interactive(func, cluster=cluster_options, ** - analysis_options, show_legend=False,) + tool = interactive( + func, + cluster=cluster_options, + **analysis_options, + show_legend=False, + ) display(tool) def _get_clusters(self, k: int): @@ -270,10 +277,7 @@ def get_clusters(self, k: int): self.cluster_results[k] = self._get_clusters(k) return self.cluster_results[k] - def find_k(self, - start: int = 15, - stop: Optional[int] = None, - step: int = 1): + def find_k(self, start: int = 15, stop: Optional[int] = None, step: int = 1): """ Graph the elbow and Silhouette method for finding the optimal k. High silhouette value good. @@ -284,9 +288,7 @@ def find_k(self, start = 2 def s_score(kmeans): - return silhouette_score(self.df, - kmeans.labels_, - metric='euclidean') + return silhouette_score(self.df, kmeans.labels_, metric="euclidean") df = pd.DataFrame({"n": range(start, stop, step)}) df["k_means"] = df["n"].apply(self.get_clusters) @@ -295,21 +297,19 @@ def s_score(kmeans): plt.rcParams["figure.figsize"] = (10, 5) plt.subplot(1, 2, 1) - plt.plot(df["n"], df["sum_squares"], 'bx-') - plt.xlabel('k') - plt.ylabel('Sum of squared distances') - plt.title('Elbow Method For Optimal k') + plt.plot(df["n"], df["sum_squares"], "bx-") + plt.xlabel("k") + plt.ylabel("Sum of squared distances") + plt.title("Elbow Method For Optimal k") plt.subplot(1, 2, 2) - plt.plot(df["n"], df["silhouette"], 'bx-') - plt.xlabel('k') - plt.ylabel('Silhouette score') - plt.title('Silhouette Method For Optimal k') + plt.plot(df["n"], df["silhouette"], "bx-") + plt.xlabel("k") + plt.ylabel("Silhouette score") + plt.title("Silhouette Method For Optimal k") plt.show() - def stats(self, - label_lookup: Optional[dict] = None, - all_members: bool = False): + def stats(self, label_lookup: Optional[dict] = None, all_members: bool = False): """ Simple description of sample size """ @@ -321,8 +321,7 @@ def stats(self, df.index = self.df.index df = df.reset_index() - pt = df.pivot_table(self.id_col, - index="labels", aggfunc="count") + pt = df.pivot_table(self.id_col, index="labels", aggfunc="count") pt = pt.rename(columns={self.id_col: "count"}) pt["%"] = (pt["count"] / len(df)).round(3) * 100 @@ -346,11 +345,13 @@ def random_set(s: List[str]) -> List[str]: pt = pt.rename(columns={self.id_col: "random members"}) return pt - def raincloud(self, - column: str, - one_value: Optional[str] = None, - groups: Optional[str] = "Cluster", - use_source: bool = True): + def raincloud( + self, + column: str, + one_value: Optional[str] = None, + groups: Optional[str] = "Cluster", + use_source: bool = True, + ): """ raincloud plot of a variable, grouped by different clusters @@ -361,8 +362,12 @@ def raincloud(self, else: df = self.df df["Cluster"] = self.get_cluster_labels() - df.viz.raincloud(values=column, groups=groups, one_value=one_value, - title=f"Raincloud plot for {column} variable.") + df.viz.raincloud( + values=column, + groups=groups, + one_value=one_value, + title=f"Raincloud plot for {column} variable.", + ) def reverse_raincloud(self, cluster_label: str): """ @@ -371,21 +376,24 @@ def reverse_raincloud(self, cluster_label: str): """ df = self.df.copy() df["Cluster"] = self.get_cluster_labels() - df = df.melt("Cluster")[lambda df:~(df["variable"] == " ")] + df = df.melt("Cluster")[lambda df: ~(df["variable"] == " ")] df["value"] = df["value"].astype(float) - df = df[lambda df:(df["Cluster"] == cluster_label)] + df = df[lambda df: (df["Cluster"] == cluster_label)] - df.viz.raincloud(values="value", - groups="variable", - title=f"Raincloud plot for Cluster: {cluster_label}") + df.viz.raincloud( + values="value", + groups="variable", + title=f"Raincloud plot for Cluster: {cluster_label}", + ) def reverse_raincloud_tool(self): """ Raincloud tool to examine clusters showing the distribution of different variables """ - tool = interactive(self.reverse_raincloud, - cluster_label=self.get_label_options()) + tool = interactive( + self.reverse_raincloud, cluster_label=self.get_label_options() + ) display(tool) def raincloud_tool(self, reverse: bool = False): @@ -404,14 +412,20 @@ def func(variable, comparison, use_source_values): if comparison == "none": groups = None comparison = None - self.raincloud(variable, one_value=comparison, - groups=groups, use_source=use_source_values) + self.raincloud( + variable, + one_value=comparison, + groups=groups, + use_source=use_source_values, + ) comparison_options = ["all", "none"] + self.get_label_options() - tool = interactive(func, - variable=self.cols, - use_source_values=True, - comparison=comparison_options) + tool = interactive( + func, + variable=self.cols, + use_source_values=True, + comparison=comparison_options, + ) display(tool) @@ -425,23 +439,27 @@ def label_tool(self): def func(cluster, sort, include_data_labels): if sort == "Index": sort = None - df = self.label_review(label=cluster, - sort=sort, - include_data=include_data_labels) + df = self.label_review( + label=cluster, sort=sort, include_data=include_data_labels + ) display(df) return df sort_options = ["Index", "% of cluster", "% of label"] - tool = interactive(func, - cluster=self.get_label_options(), - sort=sort_options, - include_data_labels=True) + tool = interactive( + func, + cluster=self.get_label_options(), + sort=sort_options, + include_data_labels=True, + ) display(tool) - def label_review(self, - label: Optional[int] = 1, - sort: Optional[str] = None, - include_data: bool = True): + def label_review( + self, + label: Optional[int] = 1, + sort: Optional[str] = None, + include_data: bool = True, + ): """ Review labeled data for a cluster """ @@ -451,9 +469,9 @@ def label_review(self, def to_count_pivot(df): mdf = df.drop(columns=["label"]).melt() mdf["Count"] = mdf["variable"] + mdf["value"] - return mdf.pivot_table("Count", - index=["variable", "value"], - aggfunc="count") + return mdf.pivot_table( + "Count", index=["variable", "value"], aggfunc="count" + ) df = self.label_df if include_data is False: @@ -464,8 +482,7 @@ def to_count_pivot(df): pt = to_count_pivot(df).join(opt) pt = pt.rename(columns={"Count": "cluster_count"}) pt["% of cluster"] = (pt["cluster_count"] / len(df)).round(3) * 100 - pt["% of label"] = (pt["cluster_count"] / - pt["overall_count"]).round(3) * 100 + pt["% of label"] = (pt["cluster_count"] / pt["overall_count"]).round(3) * 100 if sort: pt = pt.sort_values(sort, ascending=False) return pt @@ -490,10 +507,12 @@ def df_with_labels(self) -> pd.DataFrame: df["label_desc"] = self.get_cluster_descs() return df - def plot3d(self, - x_var: Optional[str] = None, - y_var: Optional[str] = None, - z_var: Optional[str] = None): + def plot3d( + self, + x_var: Optional[str] = None, + y_var: Optional[str] = None, + z_var: Optional[str] = None, + ): k = self.k """ Plot either all possible x, y, z graphs for k clusters @@ -509,9 +528,9 @@ def plot3d(self, combos = [x for x in combos if x[1] == y_var] if z_var: combos = [x for x in combos if x[1] == y_var] - rows = math.ceil(len(combos)/2) + rows = math.ceil(len(combos) / 2) - plt.rcParams["figure.figsize"] = (20, 10*rows) + plt.rcParams["figure.figsize"] = (20, 10 * rows) chart_no = 0 fig = plt.figure() @@ -522,7 +541,7 @@ def plot3d(self, ax.set_xlabel(self._axis_label(x_var)) ax.set_ylabel(self._axis_label(y_var)) ax.set_zlabel(self._axis_label(z_var)) - plt.title(f'Data with {k} clusters') + plt.title(f"Data with {k} clusters") plt.show() @@ -537,13 +556,14 @@ def join_distance(df_label_dict: Dict[str, pd.DataFrame]) -> pd.DataFrame: def prepare(df, label): - return (df - .set_index(list(df.columns[:2])) - .rename(columns={"distance": label}) - .drop(columns=["match", "position"], errors="ignore")) + return ( + df.set_index(list(df.columns[:2])) + .rename(columns={"distance": label}) + .drop(columns=["match", "position"], errors="ignore") + ) to_join = [prepare(df, label) for label, df in df_label_dict.items()] - df = reduce(lambda x, y: x.join(y), to_join) + df = reduce(lambda x, y: x.join(y), to_join) df = df.reset_index() return df @@ -557,29 +577,35 @@ class SpacePDAccessor(object): def __init__(self, pandas_obj): self._obj = pandas_obj - def cluster(self, - id_col: Optional[str] = None, - cols: Optional[List[str]] = None, - label_cols: Optional[List[str]] = None, - normalize: bool = True, - transform: List[Callable] = None, - k: Optional[int] = None) -> Cluster: + def cluster( + self, + id_col: Optional[str] = None, + cols: Optional[List[str]] = None, + label_cols: Optional[List[str]] = None, + normalize: bool = True, + transform: List[Callable] = None, + k: Optional[int] = None, + ) -> Cluster: """ returns a Cluster helper object for this dataframe """ - return Cluster(self._obj, - id_col=id_col, - cols=cols, - label_cols=label_cols, - normalize=normalize, - transform=transform, - k=k) - - def self_distance(self, - id_col: Optional[str] = None, - cols: Optional[List] = None, - normalize: bool = False, - transform: List[callable] = None): + return Cluster( + self._obj, + id_col=id_col, + cols=cols, + label_cols=label_cols, + normalize=normalize, + transform=transform, + k=k, + ) + + def self_distance( + self, + id_col: Optional[str] = None, + cols: Optional[List] = None, + normalize: bool = False, + transform: List[callable] = None, + ): """ Calculate the distance between all objects in a dataframe in an n-dimensional space. @@ -606,8 +632,7 @@ def self_distance(self, a_col = id_col + "_A" b_col = id_col + "_B" - _ = list(product(source_df[id_col], - source_df[id_col])) + _ = list(product(source_df[id_col], source_df[id_col])) df = pd.DataFrame(_, columns=[a_col, b_col]) @@ -626,10 +651,12 @@ def self_distance(self, df = df.loc[~(df[a_col] == df[b_col])] return df - def join_distance(self, - other: Union[Dict[str, pd.DataFrame], pd.DataFrame], - our_label: Optional[str] = "A", - their_label: Optional[str] = "B"): + def join_distance( + self, + other: Union[Dict[str, pd.DataFrame], pd.DataFrame], + our_label: Optional[str] = "A", + their_label: Optional[str] = "B", + ): """ Either merges self and other (both of whichs hould be the result of @@ -639,8 +666,7 @@ def join_distance(self, """ if not isinstance(other, dict): - df_label_dict = {our_label: self._obj, - their_label: other} + df_label_dict = {our_label: self._obj, their_label: other} else: df_label_dict = other @@ -656,18 +682,18 @@ def match_distance(self): def standardise_distance(df): df = df.copy() # use tenth from last because the last point might be an extreme outlier (in this case london) - tenth_from_last_score = df["distance"].sort_values().tail( - 10).iloc[0] + tenth_from_last_score = df["distance"].sort_values().tail(10).iloc[0] df["match"] = 1 - (df["distance"] / tenth_from_last_score) df["match"] = df["match"].round(3) * 100 df["match"] = df["match"].apply(lambda x: x if x > 0 else 0) df = df.sort_values("match", ascending=False) return df - return (df - .groupby(df.columns[0], as_index=False) - .apply(standardise_distance) - .reset_index(drop=True)) + return ( + df.groupby(df.columns[0], as_index=False) + .apply(standardise_distance) + .reset_index(drop=True) + ) def local_rankings(self): """ @@ -679,10 +705,11 @@ def get_position(df): df["position"] = df["distance"].rank(method="first") return df - return (df - .groupby(df.columns[0], as_index=False) - .apply(get_position) - .reset_index(drop=True)) + return ( + df.groupby(df.columns[0], as_index=False) + .apply(get_position) + .reset_index(drop=True) + ) @pd.api.extensions.register_dataframe_accessor("joint_space") @@ -735,20 +762,14 @@ def same_nearest_k(self, k: int = 5): df = self._obj def top_k(df, k=5): - df = (df - .set_index(list(df.columns[:2])) - .rank()) + df = df.set_index(list(df.columns[:2])).rank() df = df <= k - same_rank = df.sum(axis=1).reset_index( - drop=True) == len(list(df.columns)) + same_rank = df.sum(axis=1).reset_index(drop=True) == len(list(df.columns)) data = [[same_rank.sum() / k]] d = pd.DataFrame(data, columns=[f"same_top_{k}"]) return d.iloc[0] - return (df - .groupby(df.columns[0]) - .apply(top_k, k=k) - .reset_index()) + return df.groupby(df.columns[0]).apply(top_k, k=k).reset_index() def agreement(self, ks: List[int] = [1, 2, 3, 5, 10, 25]): """ @@ -759,10 +780,7 @@ def agreement(self, ks: List[int] = [1, 2, 3, 5, 10, 25]): df = self._obj def get_average(k): - return (df - .joint_space.same_nearest_k(k=k) - .mean() - .round(2)[0]) + return df.joint_space.same_nearest_k(k=k).mean().round(2)[0] r = pd.DataFrame({"top_k": ks}) r["agreement"] = r["top_k"].apply(get_average) @@ -776,5 +794,4 @@ def plot(self, sample=None, kind="scatter", title="", **kwargs): if sample: df = df.sample(sample) plt.rcParams["figure.figsize"] = (10, 5) - df.plot(x=df.columns[2], y=df.columns[3], - kind=kind, title=title, **kwargs) + df.plot(x=df.columns[2], y=df.columns[3], kind=kind, title=title, **kwargs) diff --git a/df_extensions/viz.py b/df_extensions/viz.py index 99189be..3ef44f1 100644 --- a/df_extensions/viz.py +++ b/df_extensions/viz.py @@ -12,20 +12,20 @@ @pd.api.extensions.register_series_accessor("viz") class VIZSeriesAccessor: - def __init__(self, pandas_obj): self._obj = pandas_obj - def raincloud(self, - groups: Optional[pd.Series] = None, - ort: Optional[str] = "h", - pal: Optional[str] = "Set2", - sigma: Optional[float] = .2, - title: str = "", - all_data_label: str = "All data", - x_label: Optional[str] = None, - y_label: Optional[str] = None, - ): + def raincloud( + self, + groups: Optional[pd.Series] = None, + ort: Optional[str] = "h", + pal: Optional[str] = "Set2", + sigma: Optional[float] = 0.2, + title: str = "", + all_data_label: str = "All data", + x_label: Optional[str] = None, + y_label: Optional[str] = None, + ): """ show a raincloud plot of the values of a series Optional split by a second series (group) @@ -42,9 +42,17 @@ def raincloud(self, df[" "] = all_data_label x_col = " " - f, ax = plt.subplots(figsize=(14, 2*df[x_col].nunique())) - pt.RainCloud(x=df[x_col], y=df[s.name], palette=pal, bw=sigma, - width_viol=.6, ax=ax, orient=ort, move=.3) + f, ax = plt.subplots(figsize=(14, 2 * df[x_col].nunique())) + pt.RainCloud( + x=df[x_col], + y=df[s.name], + palette=pal, + bw=sigma, + width_viol=0.6, + ax=ax, + orient=ort, + move=0.3, + ) if title: plt.title(title, loc="center", fontdict={"fontsize": 30}) if x_label is not None: @@ -56,22 +64,23 @@ def raincloud(self, @pd.api.extensions.register_dataframe_accessor("viz") class VIZDFAccessor: - def __init__(self, pandas_obj): self._obj = pandas_obj - def raincloud(self, - values: str, - groups: Optional[str] = None, - one_value: Optional[str] = None, - limit: Optional[List[str]] = None, - ort: Optional[str] = "h", - pal: Optional[str] = "Set2", - sigma: Optional[float] = .2, - title: Optional[str] = "", - all_data_label: str = "All data", - x_label: Optional[str] = None, - y_label: Optional[str] = None,): + def raincloud( + self, + values: str, + groups: Optional[str] = None, + one_value: Optional[str] = None, + limit: Optional[List[str]] = None, + ort: Optional[str] = "h", + pal: Optional[str] = "Set2", + sigma: Optional[float] = 0.2, + title: Optional[str] = "", + all_data_label: str = "All data", + x_label: Optional[str] = None, + y_label: Optional[str] = None, + ): """ helper function for visualising one column against another with raincloud plots. @@ -88,11 +97,20 @@ def raincloud(self, if one_value: df[groups] = (df[groups] == one_value).map( - {False: "Other clusters", True: one_value}) - - f, ax = plt.subplots(figsize=(14, 2*df[groups].nunique())) - pt.RainCloud(x=df[groups], y=df[values], palette=pal, bw=sigma, - width_viol=.6, ax=ax, orient=ort, move=.3) + {False: "Other clusters", True: one_value} + ) + + f, ax = plt.subplots(figsize=(14, 2 * df[groups].nunique())) + pt.RainCloud( + x=df[groups], + y=df[values], + palette=pal, + bw=sigma, + width_viol=0.6, + ax=ax, + orient=ort, + move=0.3, + ) if title: plt.title(title, loc="center", fontdict={"fontsize": 30}) if x_label is not None: diff --git a/helpers/pipe.py b/helpers/pipe.py index 2b7b629..58a8fbd 100644 --- a/helpers/pipe.py +++ b/helpers/pipe.py @@ -12,7 +12,6 @@ class PartialLibrary: - def __init__(self, library): self._library = library @@ -44,14 +43,12 @@ def amend_partial(pfunc: Callable, value: Any) -> Tuple[Callable, str]: return pfunc, self_contained args = [x if x != PipeValue else value for x in pfunc.args] - kwargs = {k: (v if v != PipeValue else value) - for k, v in pfunc.keywords.items()} + kwargs = {k: (v if v != PipeValue else value) for k, v in pfunc.keywords.items()} return partial(pfunc.func, *args, **kwargs), self_contained class PipeStart: - def __init__(self, value): self.value = value self.operations = [] @@ -91,6 +88,7 @@ class Pipe: Starting value, then a list of functions to pass the result through. Current value can be referred as Pipe.value if it can't be passed through one value in a partial. """ + start = PipeStart value = PipeValue end = PipeEnd() diff --git a/management/cli.py b/management/cli.py index 198deec..0b332cc 100644 --- a/management/cli.py +++ b/management/cli.py @@ -3,7 +3,6 @@ class DocCollection: - def __init__(self): self.collection = None @@ -21,27 +20,48 @@ def cli(): @cli.command() -@click.argument('slug', default="") +@click.argument("slug", default="") @click.option("-p", "--param", nargs=2, multiple=True) -def render(slug="", param=[]): - doc_collection = dc.collection +@click.option("-g", "--group", nargs=1) +@click.option("--all/--not-all", "render_all", default=False) +@click.option("--publish/--no-publish", default=False) +def render(slug="", param=[], group="", render_all=False, publish=False): params = {x: y for x, y in param} if slug: - doc = doc_collection.get(slug) + docs = [dc.collection.get(slug)] + elif render_all: + docs = dc.collection.all() + elif group: + docs = dc.collection.get_group(group) else: - doc = doc_collection.first() + docs = [dc.collection.first()] + if params: print("using custom params") print(params) - doc.render(context=params) + + for doc in docs: + doc.render(context=params) + if publish: + print("starting publication flow") + doc.upload() @cli.command() -@click.argument('slug', default="") -def upload(slug=""): - doc_collection = dc.collection +@click.argument("slug", default="") +@click.option("--all/--not-all", "render_all", default=False) +def upload(slug="", param=[], render_all=False): + params = {x: y for x, y in param} if slug: - doc = doc_collection.get(slug) + docs = [dc.collection.get(slug)] + elif render_all: + docs = dc.collection.all() else: - doc = doc_collection.first() - doc.upload() \ No newline at end of file + docs = [dc.collection.first()] + + if params: + print("using custom params") + print(params) + + for doc in docs: + doc.upload() diff --git a/management/exporters.py b/management/exporters.py index 86e4d8e..4e4ed76 100644 --- a/management/exporters.py +++ b/management/exporters.py @@ -12,11 +12,13 @@ import nbformat from ipython_genutils.text import indent as normal_indent from nbconvert import MarkdownExporter, HTMLExporter -from nbconvert.preprocessors import (ClearMetadataPreprocessor, - ClearOutputPreprocessor, - ExecutePreprocessor, - ExtractOutputPreprocessor, - Preprocessor) +from nbconvert.preprocessors import ( + ClearMetadataPreprocessor, + ClearOutputPreprocessor, + ExecutePreprocessor, + ExtractOutputPreprocessor, + Preprocessor, +) from traitlets.config import Config notebook_render_dir = "_notebook_resources" @@ -30,8 +32,10 @@ class RemoveOnContent(Preprocessor): def preprocess(self, nb, resources): # Filter out cells that meet the conditions - nb.cells = [self.preprocess_cell(cell, resources, index)[0] - for index, cell in enumerate(nb.cells)] + nb.cells = [ + self.preprocess_cell(cell, resources, index)[0] + for index, cell in enumerate(nb.cells) + ] return nb, resources @@ -42,9 +46,7 @@ def preprocess_cell(self, cell, resources, cell_index): if cell["source"]: if "#HIDE" == cell["source"][:5]: - cell.transient = { - 'remove_source': True - } + cell.transient = {"remove_source": True} return cell, resources @@ -89,11 +91,13 @@ class MarkdownRenderer(object): include_input = False markdown_tables = True - def __init__(self, - input_name="readme.ipynb", - output_name=None, - include_input=None, - clear_and_execute=None): + def __init__( + self, + input_name="readme.ipynb", + output_name=None, + include_input=None, + clear_and_execute=None, + ): if include_input is None: include_input = self.__class__.include_input if clear_and_execute is None: @@ -106,8 +110,7 @@ def __init__(self, def check_for_self_reference(self, cell): # scope out the cell that called this function # prevent circular call - contains_str = check_string_in_source( - self.__class__.self_reference, cell) + contains_str = check_string_in_source(self.__class__.self_reference, cell) is_code = cell["cell_type"] == "code" return contains_str and is_code @@ -115,9 +118,11 @@ def get_contents(self, input_file): with open(input_file) as f: nb = json.load(f) - nb["cells"] = [x for x in nb["cells"] - if x["source"] and - self.check_for_self_reference(x) is False] + nb["cells"] = [ + x + for x in nb["cells"] + if x["source"] and self.check_for_self_reference(x) is False + ] str_notebook = json.dumps(nb) nb = nbformat.reads(str_notebook, as_version=4) @@ -129,13 +134,13 @@ def get_config(self): pre_processors = [] if self.clear_and_execute: - pre_processors += [ClearMetadataPreprocessor, - ClearOutputPreprocessor, - ExecutePreprocessor] + pre_processors += [ + ClearMetadataPreprocessor, + ClearOutputPreprocessor, + ExecutePreprocessor, + ] - pre_processors += [CustomExtractOutputPreprocessor, - RemoveOnContent - ] + pre_processors += [CustomExtractOutputPreprocessor, RemoveOnContent] c.MarkdownExporter.preprocessors = pre_processors c.MarkdownExporter.filters = {"indent": indent} @@ -157,8 +162,9 @@ def process(self, input_file=None, output_file=None): output_file = self.output_name if output_file is None: - output_file = Path(os.path.splitext( - input_file)[0] + self.__class__.default_ext) + output_file = Path( + os.path.splitext(input_file)[0] + self.__class__.default_ext + ) output_base_path = Path(output_file).parent @@ -174,8 +180,7 @@ def process(self, input_file=None, output_file=None): exporter = self.__class__.exporter_class(config=c) - resources = {"output_files_dir": notebook_render_dir, - "unique_key": base_root} + resources = {"output_files_dir": notebook_render_dir, "unique_key": base_root} body, resources = exporter.from_notebook_node(nb, resources) @@ -190,8 +195,9 @@ def process(self, input_file=None, output_file=None): if self.__class__.markdown_tables: body = body.replace( - '\n ', "") - soup = BeautifulSoup(body, 'html.parser') + '\n ', "" + ) + soup = BeautifulSoup(body, "html.parser") for div in soup.find_all("pagebreak"): div.replaceWith('
') @@ -209,14 +215,13 @@ def process(self, input_file=None, output_file=None): body = body.replace("<br/>", "
") body = body.replace("![png]", "![]") body = body.replace('', "") + body = body.replace("", "") # write main file with open(output_file, "w") as f: f.write(body) - print("Written to {0} at {1}".format( - output_file, datetime.now(tz=None))) + print("Written to {0} at {1}".format(output_file, datetime.now(tz=None))) class HTML_Renderer(MarkdownRenderer): @@ -232,13 +237,13 @@ def get_config(self): pre_processors = [] if self.clear_and_execute: - pre_processors += [ClearMetadataPreprocessor, - ClearOutputPreprocessor, - ExecutePreprocessor] + pre_processors += [ + ClearMetadataPreprocessor, + ClearOutputPreprocessor, + ExecutePreprocessor, + ] - pre_processors += [CustomExtractOutputPreprocessor, - RemoveOnContent - ] + pre_processors += [CustomExtractOutputPreprocessor, RemoveOnContent] c.HTMLExporter.preprocessors = pre_processors diff --git a/management/render_processing.py b/management/render_processing.py index e569c38..29b8224 100644 --- a/management/render_processing.py +++ b/management/render_processing.py @@ -1,9 +1,11 @@ import json +import shutil +from copy import deepcopy from dataclasses import dataclass from importlib import import_module from pathlib import Path from typing import Iterable, Optional -import shutil + import papermill as pm import pypandoc from jinja2 import Template @@ -38,14 +40,15 @@ def add_tag_based_on_content(input_file: Path, tag: str, content: str): def render(txt: str, context: dict): - return Template(txt).render(**context) + t = Template(str(txt)) + return t.render(**context) def combine_outputs(parts, output_path): - text_parts = [open(p, 'r').read() for p in parts] + text_parts = [open(p, "r").read() for p in parts] result = "\n".join(text_parts) result = result.replace("Notebook", "") - with open(output_path, 'w') as f: + with open(output_path, "w") as f: f.write(result) @@ -54,6 +57,7 @@ class Notebook: """ Handle talking to and rendering one file """ + name: str _parent: "Document" @@ -81,12 +85,9 @@ def papermill(self, slug, params, rerun: bool = True): print("Not papermilling, just copying current file") shutil.copy(self.raw_path(), self.papermill_path(slug)) else: - add_tag_based_on_content( - actual_path, "parameters", "#default-params") + add_tag_based_on_content(actual_path, "parameters", "#default-params") pm.execute_notebook( - actual_path, - self.papermill_path(slug), - parameters=params + actual_path, self.papermill_path(slug), parameters=params ) def rendered_filename(self, slug: str, ext: str = ".md"): @@ -106,13 +107,17 @@ def render(self, slug: str, hide_input: bool = True): include_input = not hide_input input_path = self.papermill_path(slug) exporters.render_to_markdown( - input_path, self.rendered_filename(slug, ".md"), + input_path, + self.rendered_filename(slug, ".md"), clear_and_execute=False, - include_input=include_input) + include_input=include_input, + ) exporters.render_to_html( - input_path, self.rendered_filename(slug, ".html"), + input_path, + self.rendered_filename(slug, ".html"), clear_and_execute=False, - include_input=include_input) + include_input=include_input, + ) class Document: @@ -127,8 +132,7 @@ def __init__(self, name: str, data: dict, context: Optional[dict] = None): self._data = data.copy() self.options = {"rerun": True, "hide_input": True} self.options.update(self._data.get("options", {})) - self.notebooks = [Notebook(x, _parent=self) - for x in self._data["notebooks"]] + self.notebooks = [Notebook(x, _parent=self) for x in self._data["notebooks"]] self.init_rendered_values(context) def init_rendered_values(self, context): @@ -157,7 +161,7 @@ def get_rendered_parameters(self, context) -> dict: """ render properties using jinga """ - raw_params = self._data["parameters"] + raw_params = self._data.get("parameters", {}) final_params = {} for k, v in raw_params.items(): nv = context.get(k, render(v, context)) @@ -193,8 +197,7 @@ def render(self, context: Optional[dict] = None): # combine for both md and html for ext in [".md", ".html"]: dest = self.rendered_filename(ext) - files = [x.rendered_filename(self.slug, ext) - for x in self.notebooks] + files = [x.rendered_filename(self.slug, ext) for x in self.notebooks] combine_outputs(files, dest) resources_dir = files[0].parent / "_notebook_resources" dest_resources = dest.parent / "_notebook_resources" @@ -208,9 +211,15 @@ def render(self, context: Optional[dict] = None): if template.exists() is False: raise ValueError("Missing Template") reference_doc = str(template) - pypandoc.convert_file(str(input_path_html), 'docx', outputfile=str( - output_path_doc), extra_args=[f"--resource-path={str(render_dir)}", - f"--reference-doc={reference_doc}"]) + pypandoc.convert_file( + str(input_path_html), + "docx", + outputfile=str(output_path_doc), + extra_args=[ + f"--resource-path={str(render_dir)}", + f"--reference-doc={reference_doc}", + ], + ) def upload(self): """ @@ -222,8 +231,7 @@ def upload(self): file_path = self.rendered_filename(".docx") g_folder_id = v["g_folder_id"] g_drive_id = v["g_drive_id"] - g_drive_upload_and_format( - file_name, file_path, g_folder_id, g_drive_id) + g_drive_upload_and_format(file_name, file_path, g_folder_id, g_drive_id) class DocumentCollection: @@ -239,11 +247,33 @@ def from_yaml(cls, yaml_file: Path): return cls(data) def __init__(self, data: dict): + + for k, v in data.items(): + if "meta" not in v: + data[k]["meta"] = False + if "extends" in v: + base = deepcopy(data[v["extends"]]) + if "meta" in base: + base.pop("meta") + base.update(v) + base.pop("extends") + data[k] = base + + for k, v in data.items(): + if "group" not in v: + data[k]["group"] = None + self.docs = {name: Document(name, data) for name, data in data.items()} def all(self) -> Iterable: for d in self.docs.values(): - yield d + if d._data["meta"] is False: + yield d + + def get_group(self, group: str) -> Iterable: + for d in self.all(): + if d._data["group"] == group: + yield d def first(self) -> Document: return list(self.docs.values())[0] diff --git a/management/settings.py b/management/settings.py index 3bf7378..ebca2b4 100644 --- a/management/settings.py +++ b/management/settings.py @@ -20,7 +20,7 @@ def get_settings(yaml_file: str = "settings.yaml", env_file: str = ".env"): return {} settings_file = Path(*top_level, yaml_file) - with open(settings_file, 'r') as fp: + with open(settings_file, "r") as fp: data = yaml.load(fp, Loader=yaml.Loader) env_data = {} diff --git a/management/upload.py b/management/upload.py index 798f202..ba647b0 100644 --- a/management/upload.py +++ b/management/upload.py @@ -1,5 +1,8 @@ from notebook_helper.apis.google_api import ( - DriveIntegration, ScriptIntergration, test_settings) + DriveIntegration, + ScriptIntergration, + test_settings, +) from .settings import settings @@ -20,7 +23,9 @@ def format_document(url): Apply google sheets formatter to URL """ api = ScriptIntergration(settings["GOOGLE_CLIENT_JSON"]) - script_id = "AKfycbwjKpOgzKaDHahyn-7If0LzMhaNfMTTsiHf6nvgL2gaaVsgI_VvuZjHJWAzRaehENLX" + script_id = ( + "AKfycbwjKpOgzKaDHahyn-7If0LzMhaNfMTTsiHf6nvgL2gaaVsgI_VvuZjHJWAzRaehENLX" + ) func = api.get_function(script_id, "formatWordURL") print("formatting document, this may take a few minutes") v = func(url) diff --git a/progress.py b/progress.py index 9a50d99..1effc8f 100644 --- a/progress.py +++ b/progress.py @@ -5,12 +5,15 @@ console = get_console() -def track_progress(iterable: Iterable, - name: Optional[str] = None, - total: Optional[int] = None, - update_label: bool = False, - label_func: Optional[Callable] = lambda x: x, - clear: Optional[bool] = True): + +def track_progress( + iterable: Iterable, + name: Optional[str] = None, + total: Optional[int] = None, + update_label: bool = False, + label_func: Optional[Callable] = lambda x: x, + clear: Optional[bool] = True, +): """ simple tracking loop using rich progress """ @@ -27,4 +30,4 @@ def track_progress(iterable: Iterable, description = f"{name}: {label_func(i)}" else: description = name - progress.update(task, advance=1, description=description) \ No newline at end of file + progress.update(task, advance=1, description=description)