diff --git a/Dockerfile b/Dockerfile
index 97a9077..ed04f0b 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,9 +1,9 @@
-FROM python:3.8-buster
-
-ENV DEBIAN_FRONTEND noninteractive
-
-COPY packages_setup.bash /
-RUN chmod +x /packages_setup.bash && /packages_setup.bash
-
-COPY base_requirements.txt /
+FROM python:3.8-buster
+
+ENV DEBIAN_FRONTEND noninteractive
+
+COPY packages_setup.bash /
+RUN chmod +x /packages_setup.bash && /packages_setup.bash
+
+COPY base_requirements.txt /
RUN pip install -r /base_requirements.txt
\ No newline at end of file
diff --git a/__init__.py b/__init__.py
index 0d45ab9..e2bf39a 100644
--- a/__init__.py
+++ b/__init__.py
@@ -4,6 +4,8 @@
import builtins as __builtin__
import datetime
+import re
+import unicodedata
from dataclasses import dataclass
from functools import partial
from pathlib import Path
@@ -14,9 +16,15 @@
import pandas as pd
from IPython.display import Markdown as md
-from .charting import (Chart, ChartEncoding, altair_sw_theme, altair_theme,
- enable_sw_charts)
-from .df_extensions import space, viz, common
+from .charting import (
+    Chart,
+    ChartEncoding,
+    ChartTitle,
+    altair_sw_theme,
+    altair_theme,
+    enable_sw_charts,
+)
+from .df_extensions import common, space, viz
from .helpers.pipe import Pipe, Pipeline, iter_format
from .management.exporters import render_to_html, render_to_markdown
from .management.settings import settings
@@ -32,7 +40,8 @@
def page_break():
- return md("""
+ return md(
+ """
```{=openxml}
@@ -40,7 +49,8 @@ def page_break():
 <w:p><w:r><w:br w:type="page"/></w:r></w:p>
 ```
-""")
+"""
+ )
def notebook_setup():
@@ -51,7 +61,20 @@ def Date(x):
return datetime.datetime.fromisoformat(x).date()
-comma_thousands = '{:,}'.format
-percentage_0dp = '{:,.0%}'.format
-percentage_1dp = '{:,.1%}'.format
-percentage_2dp = '{:,.2%}'.format
+def slugify(value):
+ """
+    Converts to lowercase, removes characters that are not alphanumerics,
+    underscores, whitespace, or hyphens, and converts spaces to hyphens.
+    Also strips leading and trailing whitespace.
+ """
+ value = (
+ unicodedata.normalize("NFKD", value).encode("ascii", "ignore").decode("ascii")
+ )
+    value = re.sub(r"[^\w\s-]", "", value).strip().lower()
+    return re.sub(r"[-\s]+", "-", value)
+
+
+comma_thousands = "{:,}".format
+percentage_0dp = "{:,.0%}".format
+percentage_1dp = "{:,.1%}".format
+percentage_2dp = "{:,.2%}".format
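(For illustration: a doctest-style sketch of the helpers added above; the input
values are invented.)

    >>> slugify("Une Écriture  & Test ")
    'une-ecriture-test'
    >>> comma_thousands(1234567)
    '1,234,567'
    >>> percentage_1dp(0.123)
    '12.3%'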
diff --git a/apis/google_api.py b/apis/google_api.py
index 786c13b..14667b5 100644
--- a/apis/google_api.py
+++ b/apis/google_api.py
@@ -1,4 +1,3 @@
-
import socket
import sys
from pathlib import Path
@@ -8,49 +7,58 @@
from googleapiclient.discovery import build
from googleapiclient.http import MediaFileUpload
-SCOPES = ['https://www.googleapis.com/auth/drive', "https://www.googleapis.com/auth/documents",
- 'https://www.googleapis.com/auth/script.projects']
+SCOPES = [
+ "https://www.googleapis.com/auth/drive",
+ "https://www.googleapis.com/auth/documents",
+ "https://www.googleapis.com/auth/script.projects",
+]
url_template = "https://docs.google.com/document/d/{0}/edit"
class DriveIntegration:
-
def __init__(self, data):
self.creds = Credentials.from_authorized_user_info(data, SCOPES)
- self.api = build('drive', 'v3', credentials=self.creds)
+ self.api = build("drive", "v3", credentials=self.creds)
def upload_file(self, file_name, file_path, folder_id, drive_id):
- body = {'name': file_name, 'driveID': drive_id, "parents": [
- folder_id], 'mimeType': 'application/vnd.google-apps.document'}
+ body = {
+ "name": file_name,
+ "driveID": drive_id,
+ "parents": [folder_id],
+ "mimeType": "application/vnd.google-apps.document",
+ }
# Now create the media file upload object and tell it what file to upload,
# in this case 'test.html'
media = MediaFileUpload(
file_path,
- mimetype='application/vnd.openxmlformats-officedocument.wordprocessingml.document')
+ mimetype="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+ )
# Now we're doing the actual post, creating a new file of the uploaded type
- uploaded = self.api.files().create(body=body, media_body=media,
- supportsTeamDrives=True).execute()
+ uploaded = (
+ self.api.files()
+ .create(body=body, media_body=media, supportsTeamDrives=True)
+ .execute()
+ )
url = url_template.format(uploaded["id"])
return url
class ScriptIntergration:
-
def __init__(self, data):
self.creds = Credentials.from_authorized_user_info(data, SCOPES)
socket.setdefaulttimeout(600) # set timeout to 10 minutes
- self.api = build('script', 'v1', credentials=self.creds)
+ self.api = build("script", "v1", credentials=self.creds)
def get_function(self, script_id, function_name):
-
def inner(*args):
request = {"function": function_name, "parameters": list(args)}
- response = self.api.scripts().run(body=request,
- scriptId=script_id).execute()
+ response = (
+ self.api.scripts().run(body=request, scriptId=script_id).execute()
+ )
return response
@@ -66,7 +74,7 @@ def trigger_log_in_flow(settings):
json_creds = creds.to_json()
print(f"GOOGLE_CLIENT_JSON={json_creds}")
raise ValueError("Add the following last line printed to the .env")
-
+
def test_settings(settings):
"""
@@ -74,7 +82,9 @@ def test_settings(settings):
"""
if "GOOGLE_APP_JSON" not in settings or settings["GOOGLE_APP_JSON"] == "":
- raise ValueError("Missing GOOGLE_APP_JSON settings. See the notebook setup page in the wiki for the correct settings.")
+ raise ValueError(
+ "Missing GOOGLE_APP_JSON settings. See the notebook setup page in the wiki for the correct settings."
+ )
if "GOOGLE_CLIENT_JSON" not in settings or settings["GOOGLE_CLIENT_JSON"] == "":
trigger_log_in_flow(settings)
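(A usage sketch for the refactored Drive integration, assuming credentials were
already captured via trigger_log_in_flow; the file name, folder ID, and drive ID
are placeholders.)

    import json

    creds = json.loads(settings["GOOGLE_CLIENT_JSON"])
    drive = DriveIntegration(creds)
    url = drive.upload_file(
        "board report", "report.docx", folder_id="FOLDER_ID", drive_id="DRIVE_ID"
    )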
diff --git a/base_requirements.txt b/base_requirements.txt
index d809100..a5c4353 100644
--- a/base_requirements.txt
+++ b/base_requirements.txt
@@ -12,8 +12,5 @@ altair==4.1.0
altair_saver==0.5.0
jupyter==1.0.0
-pylint==2.9.1
-pycodestyle==2.7.0
-autopep8==1.5.7
ruamel.yaml==0.17.10
pypandoc==1.5
flake8==3.9.2
@@ -24,4 +21,6 @@ papermill==2.3.3
google-api-python-client==2.21.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
-ptitprince==0.2.5
\ No newline at end of file
+ptitprince==0.2.5
+pylint==2.12.2
+black[jupyter]==21.12b0
\ No newline at end of file
diff --git a/charting/__init__.py b/charting/__init__.py
index c9bb0d1..a3e9a85 100644
--- a/charting/__init__.py
+++ b/charting/__init__.py
@@ -1,22 +1,21 @@
from . import theme as altair_theme
from . import sw_theme as altair_sw_theme
-from .chart import Chart, Renderer, ChartEncoding
+from .chart import Chart, Renderer, ChartEncoding, ChartTitle
from .saver import MSSaver, SWSaver, render, sw_render
+
import altair as alt
-alt.themes.register('mysoc_theme', lambda: altair_theme.mysoc_theme)
-alt.themes.enable('mysoc_theme')
+alt.themes.register("mysoc_theme", lambda: altair_theme.mysoc_theme)
+alt.themes.enable("mysoc_theme")
-alt.renderers.register('mysoc_saver', render)
-alt.renderers.enable('mysoc_saver')
+alt.renderers.register("mysoc_saver", render)
+alt.renderers.enable("mysoc_saver")
def enable_sw_charts():
- alt.themes.register('societyworks_theme', lambda: altair_sw_theme.sw_theme)
- alt.themes.enable('societyworks_theme')
-
- alt.renderers.register('sw_saver', sw_render)
- alt.renderers.enable('sw_saver')
- Renderer.default_renderer = SWSaver
-
+ alt.themes.register("societyworks_theme", lambda: altair_sw_theme.sw_theme)
+ alt.themes.enable("societyworks_theme")
+ alt.renderers.register("sw_saver", sw_render)
+ alt.renderers.enable("sw_saver")
+ Renderer.default_renderer = SWSaver
\ No newline at end of file
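(Importing the package registers and enables the mySociety theme as a side
effect; switching to the SocietyWorks variant is a one-liner, sketched below
with an assumed import path that depends on how the package is installed.)

    from charting import enable_sw_charts  # assumed import path

    enable_sw_charts()  # subsequent Chart objects use the SW theme and renderer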
diff --git a/charting/chart.py b/charting/chart.py
index 15251e3..9054527 100644
--- a/charting/chart.py
+++ b/charting/chart.py
@@ -1,6 +1,6 @@
from functools import wraps
from pathlib import Path
-from typing import List, Optional
+from typing import List, Optional, Union
import altair as alt
import pandas as pd
@@ -15,8 +15,8 @@ class Renderer:
def save_chart(chart, filename, scale_factor=1, **kwargs):
"""
- dumbed down version of altair save function that just assumes
- we're sending extra properties to the embed options
+ dumbed down version of altair save function that just assumes
+ we're sending extra properties to the embed options
"""
if isinstance(filename, Path):
# altair doesn't process paths right
@@ -24,11 +24,58 @@ def save_chart(chart, filename, scale_factor=1, **kwargs):
filename.parent.mkdir()
filename = str(filename)
- altair_save_chart(chart,
- filename,
- scale_factor=scale_factor,
- embed_options=kwargs,
- method=Renderer.default_renderer)
+ altair_save_chart(
+ chart,
+ filename,
+ scale_factor=scale_factor,
+ embed_options=kwargs,
+ method=Renderer.default_renderer,
+ )
+
+
+def split_text_to_line(text: str, cut_off: int = 60) -> List[str]:
+ """
+ Split a string to meet line limit
+ """
+ bits = text.split(" ")
+ rows = []
+ current_item = []
+ for b in bits:
+        if current_item and len(" ".join(current_item + [b])) > cut_off:
+ rows.append(" ".join(current_item))
+ current_item = []
+ current_item.append(b)
+ rows.append(" ".join(current_item))
+ return rows
+
+
+class ChartTitle(alt.TitleParams):
+ """
+    Helper class for chart titles.
+    Includes better line wrapping.
+ """
+
+ def __init__(
+ self,
+ title: Union[str, List[str]],
+ subtitle: Optional[Union[str, List[str]]] = None,
+ line_limit: int = 60,
+ **kwargs
+ ):
+
+ if isinstance(title, str):
+ title_bits = split_text_to_line(title, line_limit)
+ else:
+ title_bits = title
+
+ if isinstance(subtitle, str):
+ subtitle = [subtitle]
+
+ kwargs["text"] = title_bits
+ if subtitle:
+ kwargs["subtitle"] = subtitle
+
+ super().__init__(**kwargs)
class MSDisplayMixIn:
@@ -38,10 +85,11 @@ class MSDisplayMixIn:
"""
ignore_properties = ["_display_options"]
+ scale_factor = 1
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
- self._display_options = {}
+ self._display_options = {"scale_factor": self.__class__.scale_factor}
def display_options(self, **kwargs):
"""
@@ -81,6 +129,62 @@ def __or__(self, other):
raise ValueError("Only Chart objects can be concatenated.")
return hconcat(self, other)
+ @wraps(alt.Chart.properties)
+ def raw_properties(self, *args, **kwargs):
+ return super().properties(*args, **kwargs)
+
+ def properties(
+ self,
+ title: Optional[Union[str, list, alt.TitleParams, ChartTitle]] = "",
+ title_line_limit: Optional[int] = 60,
+ subtitle: Optional[Union[str, list]] = None,
+ width: Optional[int] = None,
+ height: Optional[int] = None,
+ aspect: Optional[tuple] = (16, 9),
+ logo: bool = False,
+ caption: Optional[str] = "",
+        scale_factor: Optional[float] = None,
+ **kwargs
+ ) -> "Chart":
+
+ args = {}
+ display_args = {"logo": logo, "caption": caption}
+ if scale_factor:
+ display_args["scale_factor"] = scale_factor
+
+        if isinstance(title, (str, list)) or subtitle is not None:
+            args["title"] = ChartTitle(
+                title=title, subtitle=subtitle, line_limit=title_line_limit
+            )
+        elif title:
+            args["title"] = title
+
+ if width and not height:
+ args["width"] = width
+ args["height"] = (width / aspect[0]) * aspect[1]
+
+ if height and not width:
+ args["height"] = height
+ args["width"] = (height / aspect[1]) * aspect[0]
+
+ if width and height:
+ args["height"] = height
+ args["width"] = width
+
+ width_offset = 0
+ height_offset = 0
+
+ if logo or caption:
+ height_offset += 100
+
+ if "width" in args:
+ args["width"] -= width_offset
+ args["height"] -= height_offset
+ args["autosize"] = alt.AutoSizeParams(type="fit", contains="padding")
+
+ kwargs.update(args)
+ return super().properties(**kwargs).display_options(**display_args)
+
class MSDataManagementMixIn:
"""
@@ -93,6 +197,7 @@ class MSDataManagementMixIn:
@classmethod
def from_url(cls, url, n=0):
from .download import get_chart_from_url
+
return get_chart_from_url(url, n)
def _get_df(self) -> pd.DataFrame:
@@ -102,7 +207,7 @@ def update_df(self, df: pd.DataFrame):
"""
take a new df and update the chart
"""
- self.datasets[self.data["name"]] = df.to_dict('records')
+ self.datasets[self.data["name"]] = df.to_dict("records")
return self
@property
@@ -119,19 +224,23 @@ def __setattribute__(self, key, value):
super().__setattribute__(key, value)
-class Chart(MSDisplayMixIn, MSDataManagementMixIn, alt.Chart):
+class MSAltair(MSDisplayMixIn, MSDataManagementMixIn):
+ pass
+
+
+class Chart(MSAltair, alt.Chart):
pass
-class LayerChart(MSDisplayMixIn, MSDataManagementMixIn, alt.LayerChart):
+class LayerChart(MSAltair, alt.LayerChart):
pass
-class HConcatChart(MSDisplayMixIn, MSDataManagementMixIn, alt.HConcatChart):
+class HConcatChart(MSAltair, alt.HConcatChart):
pass
-class VConcatChart(MSDisplayMixIn, MSDataManagementMixIn, alt.VConcatChart):
+class VConcatChart(MSAltair, alt.VConcatChart):
pass
@@ -153,6 +262,6 @@ def vconcat(*charts, **kwargs):
@wraps(Chart.encode)
def ChartEncoding(**kwargs):
"""
- Thin wrapper to specify properites we want to use multiple times
+ Thin wrapper to specify properties we want to use multiple times
"""
return kwargs
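(A sketch of the new title and sizing handling in Chart.properties; the data
and column names are invented.)

    import pandas as pd

    source = pd.DataFrame({"body": ["FOI", "EIR"], "requests": [120, 45]})
    chart = Chart(source).mark_bar().encode(x="body", y="requests")
    chart = chart.properties(
        title="A deliberately long title that ChartTitle wraps at the sixty character limit",
        subtitle="Requests by regime",
        width=700,  # height is derived from the default 16:9 aspect
        logo=True,  # height is trimmed to leave room for the logo strip
    )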
diff --git a/charting/download.py b/charting/download.py
index 65f216e..f77dc2d 100644
--- a/charting/download.py
+++ b/charting/download.py
@@ -1,10 +1,10 @@
-
from typing import List
from urllib.request import urlopen
import altair as alt
from .chart import Chart, LayerChart
import json
+
def get_chart_spec_from_url(url: str) -> List[str]:
"""
For extracting chart specs produced by the research sites framework
@@ -26,7 +26,8 @@ def json_to_chart(json_spec: str) -> alt.Chart:
del di["layer"]
del di["width"]
chart = LayerChart.from_dict(
- {"config": di["config"], "layer": [], "datasets": di["datasets"]})
+ {"config": di["config"], "layer": [], "datasets": di["datasets"]}
+ )
for n, l in enumerate(layers):
di_copy = di.copy()
di_copy.update(l)
@@ -43,11 +44,9 @@ def json_to_chart(json_spec: str) -> alt.Chart:
return chart
-def get_chart_from_url(url: str,
- n: int = 0,
- include_df: bool = False) -> alt.Chart:
+def get_chart_from_url(url: str, n: int = 0, include_df: bool = False) -> alt.Chart:
"""
- given url, a number (0 indexed), get the spec,
+ given url, a number (0 indexed), get the spec,
and reduce an altair chart instance.
if `include_df` will try and reduce the original df as well.
"""
diff --git a/charting/saver.py b/charting/saver.py
index 1f5e5a8..3f460f5 100644
--- a/charting/saver.py
+++ b/charting/saver.py
@@ -5,10 +5,15 @@
from altair_saver.types import JSONDict, Mimebundle
from altair_saver._utils import extract_format, infer_mode_from_spec
from functools import partial
-from altair_saver.savers._selenium import (CDN_URL, EXTRACT_CODE,
- HTML_TEMPLATE, JavascriptError,
- MimebundleContent, SeleniumSaver,
- get_bundled_script)
+from altair_saver.savers._selenium import (
+ CDN_URL,
+ EXTRACT_CODE,
+ HTML_TEMPLATE,
+ JavascriptError,
+ MimebundleContent,
+ SeleniumSaver,
+ get_bundled_script,
+)
import altair as alt
@@ -144,7 +149,9 @@ def get_as_base64(url):
class MSSaver(SeleniumSaver):
- logo_url = "https://research.mysociety.org/sites/foi-monitor/static/img/mysociety-logo.jpg"
+ logo_url = (
+ "https://research.mysociety.org/sites/foi-monitor/static/img/mysociety-logo.jpg"
+ )
font = "Source Sans Pro"
def __init__(self, *args, **kwargs):
@@ -156,8 +163,10 @@ def _get_font(self):
def _get_logo(self):
if self._logo is None:
- self._logo = "data:image/jpg;base64," + \
- get_as_base64(self.__class__.logo_url).decode()
+ self._logo = (
+ "data:image/jpg;base64,"
+ + get_as_base64(self.__class__.logo_url).decode()
+ )
return self._logo
def _extract(self, fmt: str) -> MimebundleContent:
@@ -208,12 +217,9 @@ def _extract(self, fmt: str) -> MimebundleContent:
opt = self._embed_options.copy()
opt["mode"] = self._mode
- extract_code = EXTRACT_CODE.replace(
- "$$BASE64LOGO$$", str(self._get_logo()))
- extract_code = extract_code.replace(
- "$$FONT$$", str(self._get_font()))
- result = driver.execute_async_script(
- extract_code, self._spec, opt, fmt)
+ extract_code = EXTRACT_CODE.replace("$$BASE64LOGO$$", str(self._get_logo()))
+ extract_code = extract_code.replace("$$FONT$$", str(self._get_font()))
+ result = driver.execute_async_script(extract_code, self._spec, opt, fmt)
if "error" in result:
raise JavascriptError(result["error"])
return result["result"]
@@ -254,8 +260,13 @@ def render(
scale_factor = embed_options["scale_factor"]
for fmt in fmts:
- saver = Saver(spec, mode=mode, embed_options=embed_options,
- scale_factor=scale_factor, **kwargs)
+ saver = Saver(
+ spec,
+ mode=mode,
+ embed_options=embed_options,
+ scale_factor=scale_factor,
+ **kwargs,
+ )
mimebundle.update(saver.mimebundle(fmt))
return mimebundle
diff --git a/charting/sw_theme.py b/charting/sw_theme.py
index 1e2554f..bc6d292 100644
--- a/charting/sw_theme.py
+++ b/charting/sw_theme.py
@@ -8,59 +8,58 @@
from typing import List, Any, Optional
# brand colours
-colours = {'colour_orange': '#f79421',
- 'colour_off_white': '#f3f1eb',
- 'colour_light_grey': '#e2dfd9',
- 'colour_mid_grey': '#959287',
- 'colour_dark_grey': '#6c6b68',
- 'colour_black': '#333333',
- 'colour_red': '#dd4e4d',
- 'colour_yellow': '#fff066',
- 'colour_violet': '#a94ca6',
- 'colour_green': '#61b252',
- 'colour_green_dark': '#53a044',
- 'colour_green_dark_2': '#388924',
- 'colour_blue': '#54b1e4',
- 'colour_blue_dark': '#2b8cdb',
- 'colour_blue_dark_2': '#207cba'}
+colours = {
+ "colour_orange": "#f79421",
+ "colour_off_white": "#f3f1eb",
+ "colour_light_grey": "#e2dfd9",
+ "colour_mid_grey": "#959287",
+ "colour_dark_grey": "#6c6b68",
+ "colour_black": "#333333",
+ "colour_red": "#dd4e4d",
+ "colour_yellow": "#fff066",
+ "colour_violet": "#a94ca6",
+ "colour_green": "#61b252",
+ "colour_green_dark": "#53a044",
+ "colour_green_dark_2": "#388924",
+ "colour_blue": "#54b1e4",
+ "colour_blue_dark": "#2b8cdb",
+ "colour_blue_dark_2": "#207cba",
+}
# based on data visualisation colour palette
-adjusted_colours = {"sw_yellow": "#fed876",
- "sw_berry": "#e02653",
- "sw_blue": "#0ba7d1",
- "sw_dark_blue": "#065a70"
- }
+adjusted_colours = {
+ "sw_yellow": "#fed876",
+ "sw_berry": "#e02653",
+ "sw_blue": "#0ba7d1",
+ "sw_dark_blue": "#065a70",
+}
-monochrome_colours = {"colour_blue_light_20": "#7ddef8",
- "colour_blue": "#0ba7d1",
- "colour_blue_dark_20": "#076d88",
- "colour_blue_dark_30": "#033340"
- }
+monochrome_colours = {
+ "colour_blue_light_20": "#7ddef8",
+ "colour_blue": "#0ba7d1",
+ "colour_blue_dark_20": "#076d88",
+ "colour_blue_dark_30": "#033340",
+}
all_colours = colours.copy()
all_colours.update(adjusted_colours)
all_colours.update(monochrome_colours)
-palette = ["sw_yellow",
- "sw_berry",
- "sw_blue",
- "sw_dark_blue"]
+palette = ["sw_yellow", "sw_berry", "sw_blue", "sw_dark_blue"]
-contrast_palette = ["sw_dark_blue",
- "sw_yellow",
- "sw_berry",
- "sw_blue"]
+contrast_palette = ["sw_dark_blue", "sw_yellow", "sw_berry", "sw_blue"]
palette = contrast_palette
-monochrome_palette = ["colour_blue_light_20",
- "colour_blue",
- "colour_blue_dark_20",
- "colour_blue_dark_30"
- ]
+monochrome_palette = [
+ "colour_blue_light_20",
+ "colour_blue",
+ "colour_blue_dark_20",
+ "colour_blue_dark_30",
+]
palette_colors = [adjusted_colours[x] for x in palette]
@@ -72,22 +71,40 @@
# set default of colours
original_palette = [
# Start with category10 color cycle:
- "#1f77b4", '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
- '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
+ "#1f77b4",
+ "#ff7f0e",
+ "#2ca02c",
+ "#d62728",
+ "#9467bd",
+ "#8c564b",
+ "#e377c2",
+ "#7f7f7f",
+ "#bcbd22",
+ "#17becf",
# Then continue with the paired lighter colors from category20:
- '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5',
- '#c49c94', '#f7b6d2', '#c7c7c7', '#dbdb8d', '#9edae5']
+ "#aec7e8",
+ "#ffbb78",
+ "#98df8a",
+ "#ff9896",
+ "#c5b0d5",
+ "#c49c94",
+ "#f7b6d2",
+ "#c7c7c7",
+ "#dbdb8d",
+ "#9edae5",
+]
# use new palette for as long as possible
-sw_palette_colors = palette_colors + original_palette[len(palette_colors):]
+sw_palette_colors = palette_colors + original_palette[len(palette_colors) :]
-def color_scale(domain: List[Any],
- monochrome: bool = False,
- reverse: bool = False,
- palette: Optional[List[Any]] = None,
- named_palette: Optional[List[Any]] = None
- ) -> alt.Scale:
+def color_scale(
+ domain: List[Any],
+ monochrome: bool = False,
+ reverse: bool = False,
+ palette: Optional[List[Any]] = None,
+ named_palette: Optional[List[Any]] = None,
+) -> alt.Scale:
if palette is None:
if monochrome:
palette = monochrome_palette_colors
@@ -98,7 +115,7 @@ def color_scale(domain: List[Any],
palette = [monochrome_colours[x] for x in named_palette]
else:
palette = [all_colours[x] for x in named_palette]
- use_palette = palette[:len(domain)]
+ use_palette = palette[: len(domain)]
if reverse:
use_palette = use_palette[::-1]
return alt.Scale(domain=domain, range=use_palette)
@@ -107,70 +124,65 @@ def color_scale(domain: List[Any],
font = "Lato"
sw_theme = {
-
- 'config': {
+ "config": {
"padding": {"left": 5, "top": 5, "right": 20, "bottom": 5},
- "title": {'font': font,
- 'fontSize': 30,
- 'anchor': "start"
- },
- 'axis': {
+ "title": {"font": font, "fontSize": 30, "anchor": "start"},
+ "axis": {
"labelFont": font,
"labelFontSize": 14,
"titleFont": font,
- 'titleFontSize': 16,
- 'offset': 0
+ "titleFontSize": 16,
+ "offset": 0,
},
- 'axisX': {
+ "axisX": {
"labelFont": font,
"labelFontSize": 14,
"titleFont": font,
- 'titleFontSize': 16,
- 'domain': True,
- 'grid': True,
+ "titleFontSize": 16,
+ "domain": True,
+ "grid": True,
"ticks": False,
"gridWidth": 0.4,
- 'labelPadding': 10,
-
+ "labelPadding": 10,
},
- 'axisY': {
+ "axisY": {
"labelFont": font,
"labelFontSize": 14,
"titleFont": font,
- 'titleFontSize': 16,
+ "titleFontSize": 16,
"titleAlign": "left",
- 'labelPadding': 10,
- 'domain': True,
+ "labelPadding": 10,
+ "domain": True,
"ticks": False,
"titleAngle": 0,
"titleY": -10,
"titleX": -50,
"gridWidth": 0.4,
},
- 'view': {
+ "view": {
"stroke": "transparent",
- 'continuousWidth': 700,
- 'continuousHeight': 400
+ "continuousWidth": 700,
+ "continuousHeight": 400,
},
"line": {
"strokeWidth": 3,
},
"bar": {"color": palette_colors[0]},
- 'mark': {"shape": "cross"},
- 'legend': {
- "orient": 'bottom',
+ "mark": {"shape": "cross"},
+ "legend": {
+ "orient": "bottom",
"labelFont": font,
"labelFontSize": 12,
"titleFont": font,
"titleFontSize": 12,
"title": "",
"offset": 18,
- "symbolType": 'square',
- }
+ "symbolType": "square",
+ },
}
}
-sw_theme.setdefault('encoding', {}).setdefault('color', {})['scale'] = {
- 'range': sw_palette_colors,
+sw_theme.setdefault("encoding", {}).setdefault("color", {})["scale"] = {
+ "range": sw_palette_colors,
}
diff --git a/charting/theme.py b/charting/theme.py
index 0e95e4f..3b2935a 100644
--- a/charting/theme.py
+++ b/charting/theme.py
@@ -8,59 +8,71 @@
from typing import List, Any, Optional
# brand colours
-colours = {'colour_orange': '#f79421',
- 'colour_off_white': '#f3f1eb',
- 'colour_light_grey': '#e2dfd9',
- 'colour_mid_grey': '#959287',
- 'colour_dark_grey': '#6c6b68',
- 'colour_black': '#333333',
- 'colour_red': '#dd4e4d',
- 'colour_yellow': '#fff066',
- 'colour_violet': '#a94ca6',
- 'colour_green': '#61b252',
- 'colour_green_dark': '#53a044',
- 'colour_green_dark_2': '#388924',
- 'colour_blue': '#54b1e4',
- 'colour_blue_dark': '#2b8cdb',
- 'colour_blue_dark_2': '#207cba'}
+colours = {
+ "colour_orange": "#f79421",
+ "colour_off_white": "#f3f1eb",
+ "colour_light_grey": "#e2dfd9",
+ "colour_mid_grey": "#959287",
+ "colour_dark_grey": "#6c6b68",
+ "colour_black": "#333333",
+ "colour_red": "#dd4e4d",
+ "colour_yellow": "#fff066",
+ "colour_violet": "#a94ca6",
+ "colour_green": "#61b252",
+ "colour_green_dark": "#53a044",
+ "colour_green_dark_2": "#388924",
+ "colour_blue": "#54b1e4",
+ "colour_blue_dark": "#2b8cdb",
+ "colour_blue_dark_2": "#207cba",
+}
# based on data visualisation colour palette
-adjusted_colours = {"colour_yellow": "#ffe269",
- "colour_orange": "#f4a140",
- "colour_berry": "#e02653",
- "colour_purple": "#a94ca6",
- "colour_blue": "#4faded",
- "colour_dark_blue": "#0a4166"}
-
-monochrome_colours = {"colour_blue_light_20": "#acd8f6",
- "colour_blue": "#4faded",
- "colour_blue_dark_20": "#147cc2",
- "colour_blue_dark_30": "#0f5e94",
- "colour_blue_dark_40": "#0a4166",
- "colour_blue_dark_50": "#062337",
- }
-
-palette = ["colour_dark_blue",
- "colour_berry",
- "colour_orange",
- "colour_blue",
- "colour_purple",
- "colour_yellow"]
-
-contrast_palette = ["colour_dark_blue",
- "colour_yellow",
- "colour_berry",
- "colour_orange",
- "colour_blue",
- "colour_purple"]
-
-monochrome_palette = ["colour_blue_light_20",
- "colour_blue",
- "colour_blue_dark_20",
- "colour_blue_dark_30",
- "colour_blue_dark_40",
- "colour_blue_dark_50",
- ]
+adjusted_colours = {
+ "colour_yellow": "#ffe269",
+ "colour_orange": "#f4a140",
+ "colour_berry": "#e02653",
+ "colour_purple": "#a94ca6",
+ "colour_blue": "#4faded",
+ "colour_dark_blue": "#0a4166",
+ "colour_mid_grey": "#959287",
+ "colour_dark_grey": "#6c6b68",
+}
+
+monochrome_colours = {
+ "colour_blue_light_20": "#acd8f6",
+ "colour_blue": "#4faded",
+ "colour_blue_dark_20": "#147cc2",
+ "colour_blue_dark_30": "#0f5e94",
+ "colour_blue_dark_40": "#0a4166",
+ "colour_blue_dark_50": "#062337",
+}
+
+palette = [
+ "colour_dark_blue",
+ "colour_berry",
+ "colour_orange",
+ "colour_blue",
+ "colour_purple",
+ "colour_yellow",
+]
+
+contrast_palette = [
+ "colour_dark_blue",
+ "colour_yellow",
+ "colour_berry",
+ "colour_orange",
+ "colour_blue",
+ "colour_purple",
+]
+
+monochrome_palette = [
+ "colour_blue_light_20",
+ "colour_blue",
+ "colour_blue_dark_20",
+ "colour_blue_dark_30",
+ "colour_blue_dark_40",
+ "colour_blue_dark_50",
+]
palette_colors = [adjusted_colours[x] for x in palette]
@@ -72,22 +84,40 @@
# set default of colours
original_palette = [
# Start with category10 color cycle:
- "#1f77b4", '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
- '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
+ "#1f77b4",
+ "#ff7f0e",
+ "#2ca02c",
+ "#d62728",
+ "#9467bd",
+ "#8c564b",
+ "#e377c2",
+ "#7f7f7f",
+ "#bcbd22",
+ "#17becf",
# Then continue with the paired lighter colors from category20:
- '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5',
- '#c49c94', '#f7b6d2', '#c7c7c7', '#dbdb8d', '#9edae5']
+ "#aec7e8",
+ "#ffbb78",
+ "#98df8a",
+ "#ff9896",
+ "#c5b0d5",
+ "#c49c94",
+ "#f7b6d2",
+ "#c7c7c7",
+ "#dbdb8d",
+ "#9edae5",
+]
# use new palette for as long as possible
-mysoc_palette_colors = palette_colors + original_palette[len(palette_colors):]
+mysoc_palette_colors = palette_colors + original_palette[len(palette_colors) :]
-def color_scale(domain: List[Any],
- monochrome: bool = False,
- reverse: bool = False,
- palette: Optional[List[Any]] = None,
- named_palette: Optional[List[Any]] = None
- ) -> alt.Scale:
+def color_scale(
+ domain: List[Any],
+ monochrome: bool = False,
+ reverse: bool = False,
+ palette: Optional[List[Any]] = None,
+ named_palette: Optional[List[Any]] = None,
+) -> alt.Scale:
if palette is None:
if monochrome:
palette = monochrome_palette_colors
@@ -98,7 +128,7 @@ def color_scale(domain: List[Any],
palette = [monochrome_colours[x] for x in named_palette]
else:
palette = [adjusted_colours[x] for x in named_palette]
- use_palette = palette[:len(domain)]
+ use_palette = palette[: len(domain)]
if reverse:
use_palette = use_palette[::-1]
return alt.Scale(domain=domain, range=use_palette)
@@ -107,70 +137,71 @@ def color_scale(domain: List[Any],
font = "Source Sans Pro"
mysoc_theme = {
-
- 'config': {
+ "config": {
"padding": {"left": 5, "top": 5, "right": 20, "bottom": 5},
- "title": {'font': font,
- 'fontSize': 30,
- 'anchor': "start"
- },
- 'axis': {
+ "title": {
+ "font": font,
+ "fontSize": 30,
+ "anchor": "start",
+ "subtitleFontSize": 20,
+ "subtitleFont": "Source Sans Pro",
+ },
+ "axis": {
"labelFont": font,
"labelFontSize": 14,
"titleFont": font,
- 'titleFontSize': 16,
- 'offset': 0
+ "titleFontSize": 16,
+ "offset": 0,
},
- 'axisX': {
+ "axisX": {
"labelFont": font,
"labelFontSize": 14,
"titleFont": font,
- 'titleFontSize': 16,
- 'domain': True,
- 'grid': True,
+ "titleFontSize": 16,
+ "domain": True,
+ "grid": True,
"ticks": False,
"gridWidth": 0.4,
- 'labelPadding': 10,
-
+ "labelPadding": 10,
},
- 'axisY': {
+ "axisY": {
"labelFont": font,
"labelFontSize": 14,
"titleFont": font,
- 'titleFontSize': 16,
+ "titleFontSize": 16,
"titleAlign": "left",
- 'labelPadding': 10,
- 'domain': True,
+ "labelPadding": 10,
+ "domain": True,
"ticks": False,
"titleAngle": 0,
"titleY": -10,
"titleX": -50,
"gridWidth": 0.4,
},
- 'view': {
+ "view": {
"stroke": "transparent",
- 'continuousWidth': 700,
- 'continuousHeight': 400
+ "continuousWidth": 700,
+ "continuousHeight": 400,
},
"line": {
"strokeWidth": 3,
},
"bar": {"color": palette_colors[0]},
- 'mark': {"shape": "cross"},
- 'legend': {
- "orient": 'bottom',
+ "mark": {"shape": "cross"},
+ "legend": {
+ "orient": "bottom",
"labelFont": font,
"labelFontSize": 12,
"titleFont": font,
"titleFontSize": 12,
"title": "",
"offset": 18,
- "symbolType": 'square',
- }
+ "symbolType": "square",
+ },
}
}
-mysoc_theme.setdefault('encoding', {}).setdefault('color', {})['scale'] = {
- 'range': mysoc_palette_colors,
+mysoc_theme.setdefault("encoding", {}).setdefault("color", {})["scale"] = {
+ "range": mysoc_palette_colors,
}
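(For illustration, building a scale from the named palette above; the domain
values are invented.)

    scale = color_scale(
        ["Yes", "No"], named_palette=["colour_dark_blue", "colour_berry"]
    )
    # then e.g. alt.Color("answer", scale=scale) inside an encode() call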
diff --git a/df_extensions/common.py b/df_extensions/common.py
index 2e11e43..663fff9 100644
--- a/df_extensions/common.py
+++ b/df_extensions/common.py
@@ -1,5 +1,6 @@
import pandas as pd
+
@pd.api.extensions.register_series_accessor("common")
class CommonAccessor(object):
"""
diff --git a/df_extensions/space.py b/df_extensions/space.py
index 2331686..60ab533 100644
--- a/df_extensions/space.py
+++ b/df_extensions/space.py
@@ -22,18 +22,17 @@
def hex_to_rgb(value):
- value = value.lstrip('#')
+ value = value.lstrip("#")
lv = len(value)
- t = tuple(int(value[i:i + lv // 3], 16) for i in range(0, lv, lv // 3))
+ t = tuple(int(value[i : i + lv // 3], 16) for i in range(0, lv, lv // 3))
return t + (0,)
def fnormalize(s):
- return (s - s.mean())/s.std()
+ return (s - s.mean()) / s.std()
class mySocMap(Colormap):
-
def __call__(self, X, alpha=None, bytes=False):
return mysoc_palette_colors[int(X)]
@@ -43,22 +42,24 @@ class Cluster:
Helper class for finding kgram clusters.
"""
- def __init__(self,
- source_df: pd.DataFrame,
- id_col: Optional[str] = None,
- cols: Optional[List[str]] = None,
- label_cols: Optional[List[str]] = None,
- normalize: bool = True,
- transform: List[Callable] = None,
- k: Optional[int] = 2):
- """
- Initalised with a dataframe, an id column in that dataframe,
- the columns that are the dimensions in question.
- A 'normalize' paramater on if those columns should be
- normalised before use.
- and 'label_cols' which are columns that contain categories
- for items.
- These can be used to help understand clusters.
+ def __init__(
+ self,
+ source_df: pd.DataFrame,
+ id_col: Optional[str] = None,
+ cols: Optional[List[str]] = None,
+ label_cols: Optional[List[str]] = None,
+ normalize: bool = True,
+ transform: List[Callable] = None,
+ k: Optional[int] = 2,
+ ):
+ """
+        Initialised with a dataframe, an id column in that dataframe,
+        and the columns that are the dimensions in question.
+        A 'normalize' parameter controls whether those columns
+        should be normalised before use.
+        'label_cols' are columns that contain categories
+        for items.
+        These can be used to help understand clusters.
"""
self.default_seed = 1221
@@ -86,7 +87,10 @@ def __init__(self,
if cols:
df = df[cols]
else:
- def t(x): return x != id_col and x not in label_cols
+
+ def t(x):
+ return x != id_col and x not in label_cols
+
cols = list(filter(t, source_df.columns))
if normalize:
@@ -102,7 +106,8 @@ def t(x): return x != id_col and x not in label_cols
not_allowed = cols + label_cols
label_df = label_df.drop(
- columns=[x for x in label_df.columns if x not in not_allowed])
+ columns=[x for x in label_df.columns if x not in not_allowed]
+ )
for c in cols:
try:
labels = ["Low", "Medium", "High"]
@@ -149,8 +154,8 @@ def get_cluster_labels(self, include_short=True) -> np.array:
labels = pd.Series(self.get_clusters(self.k).labels_)
def f(x):
- return self.get_label_name(n=x,
- include_short=include_short)
+ return self.get_label_name(n=x, include_short=include_short)
+
labels = labels.apply(f)
return np.array(labels)
@@ -179,21 +184,20 @@ def add_labels(self, labels: Dict[int, Union[str, Tuple[str, str]]]):
return new
- def assign_name(self,
- n: int,
- name: str,
- desc: Optional[str] = ""):
+ def assign_name(self, n: int, name: str, desc: Optional[str] = ""):
k = self.k
if k not in self.label_names:
self.label_names[k] = {}
self.label_descs[k] = {}
- self.label_names[k][n-1] = name
- self.label_descs[k][n-1] = desc
+ self.label_names[k][n - 1] = name
+ self.label_descs[k][n - 1] = desc
- def plot(self,
- limit_columns: Optional[List[str]] = None,
- only_one: Optional[Any] = None,
- show_legend: bool = True):
+ def plot(
+ self,
+ limit_columns: Optional[List[str]] = None,
+ only_one: Optional[Any] = None,
+ show_legend: bool = True,
+ ):
"""
Plot either all possible x, y graphs for k clusters
or just the subset with the named x_var and y_var.
@@ -207,16 +211,15 @@ def plot(self,
if limit_columns:
vars = [x for x in vars if x in limit_columns]
combos = list(combinations(vars, 2))
- rows = math.ceil(len(combos)/num_rows)
+ rows = math.ceil(len(combos) / num_rows)
- plt.rcParams["figure.figsize"] = (15, 5*rows)
+ plt.rcParams["figure.figsize"] = (15, 5 * rows)
df["labels"] = self.get_cluster_labels()
if only_one:
df["labels"] = df["labels"] == only_one
- df["labels"] = df["labels"].map(
- {True: only_one, False: "Other clusters"})
+ df["labels"] = df["labels"].map({True: only_one, False: "Other clusters"})
chart_no = 0
rgb_values = sns.color_palette("Set2", len(df["labels"].unique()))
@@ -227,8 +230,7 @@ def plot(self,
chart_no += 1
ax = fig.add_subplot(rows, num_rows, chart_no)
for c, d in df.groupby("labels"):
- scatter = ax.scatter(d[x_var], d[y_var],
- color=color_map[c], label=c)
+ scatter = ax.scatter(d[x_var], d[y_var], color=color_map[c], label=c)
ax.set_xlabel(self._axis_label(x_var))
ax.set_ylabel(self._axis_label(y_var))
@@ -238,21 +240,26 @@ def plot(self,
plt.show()
def plot_tool(self):
-
def func(cluster, show_legend, **kwargs):
if cluster == "All":
cluster = None
limit_columns = [x for x, y in kwargs.items() if y is True]
- self.plot(only_one=cluster, show_legend=show_legend,
- limit_columns=limit_columns)
+ self.plot(
+ only_one=cluster, show_legend=show_legend, limit_columns=limit_columns
+ )
cluster_options = ["All"] + self.get_label_options()
- analysis_options = {x: True if n <
- 2 else False for n, x in enumerate(self.cols)}
+ analysis_options = {
+ x: True if n < 2 else False for n, x in enumerate(self.cols)
+ }
- tool = interactive(func, cluster=cluster_options, **
- analysis_options, show_legend=False,)
+ tool = interactive(
+ func,
+ cluster=cluster_options,
+ **analysis_options,
+ show_legend=False,
+ )
display(tool)
def _get_clusters(self, k: int):
@@ -270,10 +277,7 @@ def get_clusters(self, k: int):
self.cluster_results[k] = self._get_clusters(k)
return self.cluster_results[k]
- def find_k(self,
- start: int = 15,
- stop: Optional[int] = None,
- step: int = 1):
+ def find_k(self, start: int = 15, stop: Optional[int] = None, step: int = 1):
"""
Graph the elbow and Silhouette method for finding the optimal k.
High silhouette value good.
@@ -284,9 +288,7 @@ def find_k(self,
start = 2
def s_score(kmeans):
- return silhouette_score(self.df,
- kmeans.labels_,
- metric='euclidean')
+ return silhouette_score(self.df, kmeans.labels_, metric="euclidean")
df = pd.DataFrame({"n": range(start, stop, step)})
df["k_means"] = df["n"].apply(self.get_clusters)
@@ -295,21 +297,19 @@ def s_score(kmeans):
plt.rcParams["figure.figsize"] = (10, 5)
plt.subplot(1, 2, 1)
- plt.plot(df["n"], df["sum_squares"], 'bx-')
- plt.xlabel('k')
- plt.ylabel('Sum of squared distances')
- plt.title('Elbow Method For Optimal k')
+ plt.plot(df["n"], df["sum_squares"], "bx-")
+ plt.xlabel("k")
+ plt.ylabel("Sum of squared distances")
+ plt.title("Elbow Method For Optimal k")
plt.subplot(1, 2, 2)
- plt.plot(df["n"], df["silhouette"], 'bx-')
- plt.xlabel('k')
- plt.ylabel('Silhouette score')
- plt.title('Silhouette Method For Optimal k')
+ plt.plot(df["n"], df["silhouette"], "bx-")
+ plt.xlabel("k")
+ plt.ylabel("Silhouette score")
+ plt.title("Silhouette Method For Optimal k")
plt.show()
- def stats(self,
- label_lookup: Optional[dict] = None,
- all_members: bool = False):
+ def stats(self, label_lookup: Optional[dict] = None, all_members: bool = False):
"""
Simple description of sample size
"""
@@ -321,8 +321,7 @@ def stats(self,
df.index = self.df.index
df = df.reset_index()
- pt = df.pivot_table(self.id_col,
- index="labels", aggfunc="count")
+ pt = df.pivot_table(self.id_col, index="labels", aggfunc="count")
pt = pt.rename(columns={self.id_col: "count"})
pt["%"] = (pt["count"] / len(df)).round(3) * 100
@@ -346,11 +345,13 @@ def random_set(s: List[str]) -> List[str]:
pt = pt.rename(columns={self.id_col: "random members"})
return pt
- def raincloud(self,
- column: str,
- one_value: Optional[str] = None,
- groups: Optional[str] = "Cluster",
- use_source: bool = True):
+ def raincloud(
+ self,
+ column: str,
+ one_value: Optional[str] = None,
+ groups: Optional[str] = "Cluster",
+ use_source: bool = True,
+ ):
"""
raincloud plot of a variable, grouped by different clusters
@@ -361,8 +362,12 @@ def raincloud(self,
else:
df = self.df
df["Cluster"] = self.get_cluster_labels()
- df.viz.raincloud(values=column, groups=groups, one_value=one_value,
- title=f"Raincloud plot for {column} variable.")
+ df.viz.raincloud(
+ values=column,
+ groups=groups,
+ one_value=one_value,
+ title=f"Raincloud plot for {column} variable.",
+ )
def reverse_raincloud(self, cluster_label: str):
"""
@@ -371,21 +376,24 @@ def reverse_raincloud(self, cluster_label: str):
"""
df = self.df.copy()
df["Cluster"] = self.get_cluster_labels()
- df = df.melt("Cluster")[lambda df:~(df["variable"] == " ")]
+ df = df.melt("Cluster")[lambda df: ~(df["variable"] == " ")]
df["value"] = df["value"].astype(float)
- df = df[lambda df:(df["Cluster"] == cluster_label)]
+ df = df[lambda df: (df["Cluster"] == cluster_label)]
- df.viz.raincloud(values="value",
- groups="variable",
- title=f"Raincloud plot for Cluster: {cluster_label}")
+ df.viz.raincloud(
+ values="value",
+ groups="variable",
+ title=f"Raincloud plot for Cluster: {cluster_label}",
+ )
def reverse_raincloud_tool(self):
"""
Raincloud tool to examine clusters showing the
distribution of different variables
"""
- tool = interactive(self.reverse_raincloud,
- cluster_label=self.get_label_options())
+ tool = interactive(
+ self.reverse_raincloud, cluster_label=self.get_label_options()
+ )
display(tool)
def raincloud_tool(self, reverse: bool = False):
@@ -404,14 +412,20 @@ def func(variable, comparison, use_source_values):
if comparison == "none":
groups = None
comparison = None
- self.raincloud(variable, one_value=comparison,
- groups=groups, use_source=use_source_values)
+ self.raincloud(
+ variable,
+ one_value=comparison,
+ groups=groups,
+ use_source=use_source_values,
+ )
comparison_options = ["all", "none"] + self.get_label_options()
- tool = interactive(func,
- variable=self.cols,
- use_source_values=True,
- comparison=comparison_options)
+ tool = interactive(
+ func,
+ variable=self.cols,
+ use_source_values=True,
+ comparison=comparison_options,
+ )
display(tool)
@@ -425,23 +439,27 @@ def label_tool(self):
def func(cluster, sort, include_data_labels):
if sort == "Index":
sort = None
- df = self.label_review(label=cluster,
- sort=sort,
- include_data=include_data_labels)
+ df = self.label_review(
+ label=cluster, sort=sort, include_data=include_data_labels
+ )
display(df)
return df
sort_options = ["Index", "% of cluster", "% of label"]
- tool = interactive(func,
- cluster=self.get_label_options(),
- sort=sort_options,
- include_data_labels=True)
+ tool = interactive(
+ func,
+ cluster=self.get_label_options(),
+ sort=sort_options,
+ include_data_labels=True,
+ )
display(tool)
- def label_review(self,
- label: Optional[int] = 1,
- sort: Optional[str] = None,
- include_data: bool = True):
+ def label_review(
+ self,
+ label: Optional[int] = 1,
+ sort: Optional[str] = None,
+ include_data: bool = True,
+ ):
"""
Review labeled data for a cluster
"""
@@ -451,9 +469,9 @@ def label_review(self,
def to_count_pivot(df):
mdf = df.drop(columns=["label"]).melt()
mdf["Count"] = mdf["variable"] + mdf["value"]
- return mdf.pivot_table("Count",
- index=["variable", "value"],
- aggfunc="count")
+ return mdf.pivot_table(
+ "Count", index=["variable", "value"], aggfunc="count"
+ )
df = self.label_df
if include_data is False:
@@ -464,8 +482,7 @@ def to_count_pivot(df):
pt = to_count_pivot(df).join(opt)
pt = pt.rename(columns={"Count": "cluster_count"})
pt["% of cluster"] = (pt["cluster_count"] / len(df)).round(3) * 100
- pt["% of label"] = (pt["cluster_count"] /
- pt["overall_count"]).round(3) * 100
+ pt["% of label"] = (pt["cluster_count"] / pt["overall_count"]).round(3) * 100
if sort:
pt = pt.sort_values(sort, ascending=False)
return pt
@@ -490,10 +507,12 @@ def df_with_labels(self) -> pd.DataFrame:
df["label_desc"] = self.get_cluster_descs()
return df
- def plot3d(self,
- x_var: Optional[str] = None,
- y_var: Optional[str] = None,
- z_var: Optional[str] = None):
+ def plot3d(
+ self,
+ x_var: Optional[str] = None,
+ y_var: Optional[str] = None,
+ z_var: Optional[str] = None,
+ ):
k = self.k
"""
Plot either all possible x, y, z graphs for k clusters
@@ -509,9 +528,9 @@ def plot3d(self,
combos = [x for x in combos if x[1] == y_var]
if z_var:
combos = [x for x in combos if x[1] == y_var]
- rows = math.ceil(len(combos)/2)
+ rows = math.ceil(len(combos) / 2)
- plt.rcParams["figure.figsize"] = (20, 10*rows)
+ plt.rcParams["figure.figsize"] = (20, 10 * rows)
chart_no = 0
fig = plt.figure()
@@ -522,7 +541,7 @@ def plot3d(self,
ax.set_xlabel(self._axis_label(x_var))
ax.set_ylabel(self._axis_label(y_var))
ax.set_zlabel(self._axis_label(z_var))
- plt.title(f'Data with {k} clusters')
+ plt.title(f"Data with {k} clusters")
plt.show()
@@ -537,13 +556,14 @@ def join_distance(df_label_dict: Dict[str, pd.DataFrame]) -> pd.DataFrame:
def prepare(df, label):
- return (df
- .set_index(list(df.columns[:2]))
- .rename(columns={"distance": label})
- .drop(columns=["match", "position"], errors="ignore"))
+ return (
+ df.set_index(list(df.columns[:2]))
+ .rename(columns={"distance": label})
+ .drop(columns=["match", "position"], errors="ignore")
+ )
to_join = [prepare(df, label) for label, df in df_label_dict.items()]
- df = reduce(lambda x, y: x.join(y), to_join)
+ df = reduce(lambda x, y: x.join(y), to_join)
df = df.reset_index()
return df
@@ -557,29 +577,35 @@ class SpacePDAccessor(object):
def __init__(self, pandas_obj):
self._obj = pandas_obj
- def cluster(self,
- id_col: Optional[str] = None,
- cols: Optional[List[str]] = None,
- label_cols: Optional[List[str]] = None,
- normalize: bool = True,
- transform: List[Callable] = None,
- k: Optional[int] = None) -> Cluster:
+ def cluster(
+ self,
+ id_col: Optional[str] = None,
+ cols: Optional[List[str]] = None,
+ label_cols: Optional[List[str]] = None,
+ normalize: bool = True,
+ transform: List[Callable] = None,
+ k: Optional[int] = None,
+ ) -> Cluster:
"""
returns a Cluster helper object for this dataframe
"""
- return Cluster(self._obj,
- id_col=id_col,
- cols=cols,
- label_cols=label_cols,
- normalize=normalize,
- transform=transform,
- k=k)
-
- def self_distance(self,
- id_col: Optional[str] = None,
- cols: Optional[List] = None,
- normalize: bool = False,
- transform: List[callable] = None):
+ return Cluster(
+ self._obj,
+ id_col=id_col,
+ cols=cols,
+ label_cols=label_cols,
+ normalize=normalize,
+ transform=transform,
+ k=k,
+ )
+
+ def self_distance(
+ self,
+ id_col: Optional[str] = None,
+ cols: Optional[List] = None,
+ normalize: bool = False,
+ transform: List[callable] = None,
+ ):
"""
Calculate the distance between all objects in a dataframe
in an n-dimensional space.
@@ -606,8 +632,7 @@ def self_distance(self,
a_col = id_col + "_A"
b_col = id_col + "_B"
- _ = list(product(source_df[id_col],
- source_df[id_col]))
+ _ = list(product(source_df[id_col], source_df[id_col]))
df = pd.DataFrame(_, columns=[a_col, b_col])
@@ -626,10 +651,12 @@ def self_distance(self,
df = df.loc[~(df[a_col] == df[b_col])]
return df
- def join_distance(self,
- other: Union[Dict[str, pd.DataFrame], pd.DataFrame],
- our_label: Optional[str] = "A",
- their_label: Optional[str] = "B"):
+ def join_distance(
+ self,
+ other: Union[Dict[str, pd.DataFrame], pd.DataFrame],
+ our_label: Optional[str] = "A",
+ their_label: Optional[str] = "B",
+ ):
"""
Either merges self and other
(both of whichs hould be the result of
@@ -639,8 +666,7 @@ def join_distance(self,
"""
if not isinstance(other, dict):
- df_label_dict = {our_label: self._obj,
- their_label: other}
+ df_label_dict = {our_label: self._obj, their_label: other}
else:
df_label_dict = other
@@ -656,18 +682,18 @@ def match_distance(self):
def standardise_distance(df):
df = df.copy()
# use tenth from last because the last point might be an extreme outlier (in this case london)
- tenth_from_last_score = df["distance"].sort_values().tail(
- 10).iloc[0]
+ tenth_from_last_score = df["distance"].sort_values().tail(10).iloc[0]
df["match"] = 1 - (df["distance"] / tenth_from_last_score)
df["match"] = df["match"].round(3) * 100
df["match"] = df["match"].apply(lambda x: x if x > 0 else 0)
df = df.sort_values("match", ascending=False)
return df
- return (df
- .groupby(df.columns[0], as_index=False)
- .apply(standardise_distance)
- .reset_index(drop=True))
+ return (
+ df.groupby(df.columns[0], as_index=False)
+ .apply(standardise_distance)
+ .reset_index(drop=True)
+ )
def local_rankings(self):
"""
@@ -679,10 +705,11 @@ def get_position(df):
df["position"] = df["distance"].rank(method="first")
return df
- return (df
- .groupby(df.columns[0], as_index=False)
- .apply(get_position)
- .reset_index(drop=True))
+ return (
+ df.groupby(df.columns[0], as_index=False)
+ .apply(get_position)
+ .reset_index(drop=True)
+ )
@pd.api.extensions.register_dataframe_accessor("joint_space")
@@ -735,20 +762,14 @@ def same_nearest_k(self, k: int = 5):
df = self._obj
def top_k(df, k=5):
- df = (df
- .set_index(list(df.columns[:2]))
- .rank())
+ df = df.set_index(list(df.columns[:2])).rank()
df = df <= k
- same_rank = df.sum(axis=1).reset_index(
- drop=True) == len(list(df.columns))
+ same_rank = df.sum(axis=1).reset_index(drop=True) == len(list(df.columns))
data = [[same_rank.sum() / k]]
d = pd.DataFrame(data, columns=[f"same_top_{k}"])
return d.iloc[0]
- return (df
- .groupby(df.columns[0])
- .apply(top_k, k=k)
- .reset_index())
+ return df.groupby(df.columns[0]).apply(top_k, k=k).reset_index()
def agreement(self, ks: List[int] = [1, 2, 3, 5, 10, 25]):
"""
@@ -759,10 +780,7 @@ def agreement(self, ks: List[int] = [1, 2, 3, 5, 10, 25]):
df = self._obj
def get_average(k):
- return (df
- .joint_space.same_nearest_k(k=k)
- .mean()
- .round(2)[0])
+ return df.joint_space.same_nearest_k(k=k).mean().round(2)[0]
r = pd.DataFrame({"top_k": ks})
r["agreement"] = r["top_k"].apply(get_average)
@@ -776,5 +794,4 @@ def plot(self, sample=None, kind="scatter", title="", **kwargs):
if sample:
df = df.sample(sample)
plt.rcParams["figure.figsize"] = (10, 5)
- df.plot(x=df.columns[2], y=df.columns[3],
- kind=kind, title=title, **kwargs)
+ df.plot(x=df.columns[2], y=df.columns[3], kind=kind, title=title, **kwargs)
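(A sketch of the Cluster workflow after reformatting; the column names are
invented.)

    clusters = df.space.cluster(id_col="authority", label_cols=["region"], k=4)
    clusters.find_k()  # elbow and silhouette plots to sanity-check the choice of k
    clusters.assign_name(1, "High demand")  # cluster numbers are 1-indexed here
    clusters.plot(show_legend=True)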
diff --git a/df_extensions/viz.py b/df_extensions/viz.py
index 99189be..3ef44f1 100644
--- a/df_extensions/viz.py
+++ b/df_extensions/viz.py
@@ -12,20 +12,20 @@
@pd.api.extensions.register_series_accessor("viz")
class VIZSeriesAccessor:
-
def __init__(self, pandas_obj):
self._obj = pandas_obj
- def raincloud(self,
- groups: Optional[pd.Series] = None,
- ort: Optional[str] = "h",
- pal: Optional[str] = "Set2",
- sigma: Optional[float] = .2,
- title: str = "",
- all_data_label: str = "All data",
- x_label: Optional[str] = None,
- y_label: Optional[str] = None,
- ):
+ def raincloud(
+ self,
+ groups: Optional[pd.Series] = None,
+ ort: Optional[str] = "h",
+ pal: Optional[str] = "Set2",
+ sigma: Optional[float] = 0.2,
+ title: str = "",
+ all_data_label: str = "All data",
+ x_label: Optional[str] = None,
+ y_label: Optional[str] = None,
+ ):
"""
show a raincloud plot of the values of a series
Optional split by a second series (group)
@@ -42,9 +42,17 @@ def raincloud(self,
df[" "] = all_data_label
x_col = " "
- f, ax = plt.subplots(figsize=(14, 2*df[x_col].nunique()))
- pt.RainCloud(x=df[x_col], y=df[s.name], palette=pal, bw=sigma,
- width_viol=.6, ax=ax, orient=ort, move=.3)
+ f, ax = plt.subplots(figsize=(14, 2 * df[x_col].nunique()))
+ pt.RainCloud(
+ x=df[x_col],
+ y=df[s.name],
+ palette=pal,
+ bw=sigma,
+ width_viol=0.6,
+ ax=ax,
+ orient=ort,
+ move=0.3,
+ )
if title:
plt.title(title, loc="center", fontdict={"fontsize": 30})
if x_label is not None:
@@ -56,22 +64,23 @@ def raincloud(self,
@pd.api.extensions.register_dataframe_accessor("viz")
class VIZDFAccessor:
-
def __init__(self, pandas_obj):
self._obj = pandas_obj
- def raincloud(self,
- values: str,
- groups: Optional[str] = None,
- one_value: Optional[str] = None,
- limit: Optional[List[str]] = None,
- ort: Optional[str] = "h",
- pal: Optional[str] = "Set2",
- sigma: Optional[float] = .2,
- title: Optional[str] = "",
- all_data_label: str = "All data",
- x_label: Optional[str] = None,
- y_label: Optional[str] = None,):
+ def raincloud(
+ self,
+ values: str,
+ groups: Optional[str] = None,
+ one_value: Optional[str] = None,
+ limit: Optional[List[str]] = None,
+ ort: Optional[str] = "h",
+ pal: Optional[str] = "Set2",
+ sigma: Optional[float] = 0.2,
+ title: Optional[str] = "",
+ all_data_label: str = "All data",
+ x_label: Optional[str] = None,
+ y_label: Optional[str] = None,
+ ):
"""
helper function for visualising one column against
another with raincloud plots.
@@ -88,11 +97,20 @@ def raincloud(self,
if one_value:
df[groups] = (df[groups] == one_value).map(
- {False: "Other clusters", True: one_value})
-
- f, ax = plt.subplots(figsize=(14, 2*df[groups].nunique()))
- pt.RainCloud(x=df[groups], y=df[values], palette=pal, bw=sigma,
- width_viol=.6, ax=ax, orient=ort, move=.3)
+ {False: "Other clusters", True: one_value}
+ )
+
+ f, ax = plt.subplots(figsize=(14, 2 * df[groups].nunique()))
+ pt.RainCloud(
+ x=df[groups],
+ y=df[values],
+ palette=pal,
+ bw=sigma,
+ width_viol=0.6,
+ ax=ax,
+ orient=ort,
+ move=0.3,
+ )
if title:
plt.title(title, loc="center", fontdict={"fontsize": 30})
if x_label is not None:
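(For illustration, the dataframe raincloud accessor; the column names are
invented.)

    df.viz.raincloud(
        values="response_time",
        groups="council_type",
        title="Response times by council type",
    )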
diff --git a/helpers/pipe.py b/helpers/pipe.py
index 2b7b629..58a8fbd 100644
--- a/helpers/pipe.py
+++ b/helpers/pipe.py
@@ -12,7 +12,6 @@
class PartialLibrary:
-
def __init__(self, library):
self._library = library
@@ -44,14 +43,12 @@ def amend_partial(pfunc: Callable, value: Any) -> Tuple[Callable, str]:
return pfunc, self_contained
args = [x if x != PipeValue else value for x in pfunc.args]
- kwargs = {k: (v if v != PipeValue else value)
- for k, v in pfunc.keywords.items()}
+ kwargs = {k: (v if v != PipeValue else value) for k, v in pfunc.keywords.items()}
return partial(pfunc.func, *args, **kwargs), self_contained
class PipeStart:
-
def __init__(self, value):
self.value = value
self.operations = []
@@ -91,6 +88,7 @@ class Pipe:
Starting value, then a list of functions to pass the result through.
Current value can be referred as Pipe.value if it can't be passed through one value in a partial.
"""
+
start = PipeStart
value = PipeValue
end = PipeEnd()
diff --git a/management/cli.py b/management/cli.py
index 198deec..0b332cc 100644
--- a/management/cli.py
+++ b/management/cli.py
@@ -3,7 +3,6 @@
class DocCollection:
-
def __init__(self):
self.collection = None
@@ -21,27 +20,43 @@ def cli():
@cli.command()
-@click.argument('slug', default="")
+@click.argument("slug", default="")
@click.option("-p", "--param", nargs=2, multiple=True)
-def render(slug="", param=[]):
- doc_collection = dc.collection
+@click.option("-g", "--group", nargs=1)
+@click.option("--all/--not-all", "render_all", default=False)
+@click.option("--publish/--no-publish", default=False)
+def render(slug="", param=[], group="", render_all=False, publish=False):
params = {x: y for x, y in param}
if slug:
- doc = doc_collection.get(slug)
+ docs = [dc.collection.get(slug)]
+ elif render_all:
+ docs = dc.collection.all()
+ elif group:
+ docs = dc.collection.get_group(group)
else:
- doc = doc_collection.first()
+ docs = [dc.collection.first()]
+
if params:
print("using custom params")
print(params)
- doc.render(context=params)
+
+ for doc in docs:
+ doc.render(context=params)
+ if publish:
+ print("starting publication flow")
+ doc.upload()
@cli.command()
-@click.argument('slug', default="")
-def upload(slug=""):
- doc_collection = dc.collection
+@click.argument("slug", default="")
+@click.option("--all/--not-all", "render_all", default=False)
+def upload(slug="", render_all=False):
if slug:
- doc = doc_collection.get(slug)
+ docs = [dc.collection.get(slug)]
+ elif render_all:
+ docs = dc.collection.all()
else:
- doc = doc_collection.first()
- doc.upload()
\ No newline at end of file
+ docs = [dc.collection.first()]
+
+ for doc in docs:
+ doc.upload()
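(The new render options, exercised through click's test runner; the slug and
parameter values are placeholders.)

    from click.testing import CliRunner

    runner = CliRunner()
    runner.invoke(cli, ["render", "--all", "--publish"])
    runner.invoke(cli, ["render", "my-report", "-p", "year", "2021"])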
diff --git a/management/exporters.py b/management/exporters.py
index 86e4d8e..4e4ed76 100644
--- a/management/exporters.py
+++ b/management/exporters.py
@@ -12,11 +12,13 @@
import nbformat
from ipython_genutils.text import indent as normal_indent
from nbconvert import MarkdownExporter, HTMLExporter
-from nbconvert.preprocessors import (ClearMetadataPreprocessor,
- ClearOutputPreprocessor,
- ExecutePreprocessor,
- ExtractOutputPreprocessor,
- Preprocessor)
+from nbconvert.preprocessors import (
+ ClearMetadataPreprocessor,
+ ClearOutputPreprocessor,
+ ExecutePreprocessor,
+ ExtractOutputPreprocessor,
+ Preprocessor,
+)
from traitlets.config import Config
notebook_render_dir = "_notebook_resources"
@@ -30,8 +32,10 @@ class RemoveOnContent(Preprocessor):
def preprocess(self, nb, resources):
# Filter out cells that meet the conditions
- nb.cells = [self.preprocess_cell(cell, resources, index)[0]
- for index, cell in enumerate(nb.cells)]
+ nb.cells = [
+ self.preprocess_cell(cell, resources, index)[0]
+ for index, cell in enumerate(nb.cells)
+ ]
return nb, resources
@@ -42,9 +46,7 @@ def preprocess_cell(self, cell, resources, cell_index):
if cell["source"]:
if "#HIDE" == cell["source"][:5]:
- cell.transient = {
- 'remove_source': True
- }
+ cell.transient = {"remove_source": True}
return cell, resources
@@ -89,11 +91,13 @@ class MarkdownRenderer(object):
include_input = False
markdown_tables = True
- def __init__(self,
- input_name="readme.ipynb",
- output_name=None,
- include_input=None,
- clear_and_execute=None):
+ def __init__(
+ self,
+ input_name="readme.ipynb",
+ output_name=None,
+ include_input=None,
+ clear_and_execute=None,
+ ):
if include_input is None:
include_input = self.__class__.include_input
if clear_and_execute is None:
@@ -106,8 +110,7 @@ def __init__(self,
def check_for_self_reference(self, cell):
# scope out the cell that called this function
# prevent circular call
- contains_str = check_string_in_source(
- self.__class__.self_reference, cell)
+ contains_str = check_string_in_source(self.__class__.self_reference, cell)
is_code = cell["cell_type"] == "code"
return contains_str and is_code
@@ -115,9 +118,11 @@ def get_contents(self, input_file):
with open(input_file) as f:
nb = json.load(f)
- nb["cells"] = [x for x in nb["cells"]
- if x["source"] and
- self.check_for_self_reference(x) is False]
+ nb["cells"] = [
+ x
+ for x in nb["cells"]
+ if x["source"] and self.check_for_self_reference(x) is False
+ ]
str_notebook = json.dumps(nb)
nb = nbformat.reads(str_notebook, as_version=4)
@@ -129,13 +134,13 @@ def get_config(self):
pre_processors = []
if self.clear_and_execute:
- pre_processors += [ClearMetadataPreprocessor,
- ClearOutputPreprocessor,
- ExecutePreprocessor]
+ pre_processors += [
+ ClearMetadataPreprocessor,
+ ClearOutputPreprocessor,
+ ExecutePreprocessor,
+ ]
- pre_processors += [CustomExtractOutputPreprocessor,
- RemoveOnContent
- ]
+ pre_processors += [CustomExtractOutputPreprocessor, RemoveOnContent]
c.MarkdownExporter.preprocessors = pre_processors
c.MarkdownExporter.filters = {"indent": indent}
@@ -157,8 +162,9 @@ def process(self, input_file=None, output_file=None):
output_file = self.output_name
if output_file is None:
- output_file = Path(os.path.splitext(
- input_file)[0] + self.__class__.default_ext)
+ output_file = Path(
+ os.path.splitext(input_file)[0] + self.__class__.default_ext
+ )
output_base_path = Path(output_file).parent
@@ -174,8 +180,7 @@ def process(self, input_file=None, output_file=None):
exporter = self.__class__.exporter_class(config=c)
- resources = {"output_files_dir": notebook_render_dir,
- "unique_key": base_root}
+ resources = {"output_files_dir": notebook_render_dir, "unique_key": base_root}
body, resources = exporter.from_notebook_node(nb, resources)
@@ -190,8 +195,9 @@ def process(self, input_file=None, output_file=None):
if self.__class__.markdown_tables:
            body = body.replace(
-                '\n | ', "")
-            soup = BeautifulSoup(body, 'html.parser')
+                '\n | ', ""
+            )
+            soup = BeautifulSoup(body, "html.parser")
for div in soup.find_all("pagebreak"):
div.replaceWith('')
@@ -209,14 +215,13 @@ def process(self, input_file=None, output_file=None):
body = body.replace("<br/>", "
")
body = body.replace("![png]", "![]")
body = body.replace('', "")
+ body = body.replace("", "")
# write main file
with open(output_file, "w") as f:
f.write(body)
- print("Written to {0} at {1}".format(
- output_file, datetime.now(tz=None)))
+ print("Written to {0} at {1}".format(output_file, datetime.now(tz=None)))
class HTML_Renderer(MarkdownRenderer):
@@ -232,13 +237,13 @@ def get_config(self):
pre_processors = []
if self.clear_and_execute:
- pre_processors += [ClearMetadataPreprocessor,
- ClearOutputPreprocessor,
- ExecutePreprocessor]
+ pre_processors += [
+ ClearMetadataPreprocessor,
+ ClearOutputPreprocessor,
+ ExecutePreprocessor,
+ ]
- pre_processors += [CustomExtractOutputPreprocessor,
- RemoveOnContent
- ]
+ pre_processors += [CustomExtractOutputPreprocessor, RemoveOnContent]
c.HTMLExporter.preprocessors = pre_processors
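
A hedged usage sketch for the two renderers above (the notebook name is hypothetical):

    md = MarkdownRenderer(input_name="analysis.ipynb", clear_and_execute=True)
    md.process()  # writes analysis.md, with images under _notebook_resources

    html = HTML_Renderer(input_name="analysis.ipynb")
    html.process(output_file="analysis.html")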
diff --git a/management/render_processing.py b/management/render_processing.py
index e569c38..29b8224 100644
--- a/management/render_processing.py
+++ b/management/render_processing.py
@@ -1,9 +1,11 @@
import json
+import shutil
+from copy import deepcopy
from dataclasses import dataclass
from importlib import import_module
from pathlib import Path
from typing import Iterable, Optional
-import shutil
+
import papermill as pm
import pypandoc
from jinja2 import Template
@@ -38,14 +40,15 @@ def add_tag_based_on_content(input_file: Path, tag: str, content: str):
def render(txt: str, context: dict):
- return Template(txt).render(**context)
+ t = Template(str(txt))
+ return t.render(**context)
def combine_outputs(parts, output_path):
- text_parts = [open(p, 'r').read() for p in parts]
+ text_parts = [open(p, "r").read() for p in parts]
result = "\n".join(text_parts)
result = result.replace("Notebook", "")
- with open(output_path, 'w') as f:
+ with open(output_path, "w") as f:
f.write(result)
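
The render() helper above is a thin wrapper over jinja2's Template; for example:

    print(render("Report for {{ year }}", {"year": 2021}))
    # -> Report for 2021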
@@ -54,6 +57,7 @@ class Notebook:
"""
Handle talking to and rendering one file
"""
+
name: str
_parent: "Document"
@@ -81,12 +85,9 @@ def papermill(self, slug, params, rerun: bool = True):
print("Not papermilling, just copying current file")
shutil.copy(self.raw_path(), self.papermill_path(slug))
else:
- add_tag_based_on_content(
- actual_path, "parameters", "#default-params")
+ add_tag_based_on_content(actual_path, "parameters", "#default-params")
pm.execute_notebook(
- actual_path,
- self.papermill_path(slug),
- parameters=params
+ actual_path, self.papermill_path(slug), parameters=params
)
def rendered_filename(self, slug: str, ext: str = ".md"):
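
Standalone, the papermill branch above amounts to the following sketch (paths and parameters are hypothetical):

    import papermill as pm

    pm.execute_notebook(
        "notebooks/source.ipynb",        # input notebook
        "_papermill/source-2021.ipynb",  # executed copy
        parameters={"year": 2021},       # injected into the cell tagged "parameters"
    )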
@@ -106,13 +107,17 @@ def render(self, slug: str, hide_input: bool = True):
include_input = not hide_input
input_path = self.papermill_path(slug)
exporters.render_to_markdown(
- input_path, self.rendered_filename(slug, ".md"),
+ input_path,
+ self.rendered_filename(slug, ".md"),
clear_and_execute=False,
- include_input=include_input)
+ include_input=include_input,
+ )
exporters.render_to_html(
- input_path, self.rendered_filename(slug, ".html"),
+ input_path,
+ self.rendered_filename(slug, ".html"),
clear_and_execute=False,
- include_input=include_input)
+ include_input=include_input,
+ )
class Document:
@@ -127,8 +132,7 @@ def __init__(self, name: str, data: dict, context: Optional[dict] = None):
self._data = data.copy()
self.options = {"rerun": True, "hide_input": True}
self.options.update(self._data.get("options", {}))
- self.notebooks = [Notebook(x, _parent=self)
- for x in self._data["notebooks"]]
+ self.notebooks = [Notebook(x, _parent=self) for x in self._data["notebooks"]]
self.init_rendered_values(context)
def init_rendered_values(self, context):
@@ -157,7 +161,7 @@ def get_rendered_parameters(self, context) -> dict:
"""
        render properties using jinja
"""
- raw_params = self._data["parameters"]
+ raw_params = self._data.get("parameters", {})
final_params = {}
for k, v in raw_params.items():
nv = context.get(k, render(v, context))
@@ -193,8 +197,7 @@ def render(self, context: Optional[dict] = None):
# combine for both md and html
for ext in [".md", ".html"]:
dest = self.rendered_filename(ext)
- files = [x.rendered_filename(self.slug, ext)
- for x in self.notebooks]
+ files = [x.rendered_filename(self.slug, ext) for x in self.notebooks]
combine_outputs(files, dest)
resources_dir = files[0].parent / "_notebook_resources"
dest_resources = dest.parent / "_notebook_resources"
@@ -208,9 +211,15 @@ def render(self, context: Optional[dict] = None):
if template.exists() is False:
raise ValueError("Missing Template")
reference_doc = str(template)
- pypandoc.convert_file(str(input_path_html), 'docx', outputfile=str(
- output_path_doc), extra_args=[f"--resource-path={str(render_dir)}",
- f"--reference-doc={reference_doc}"])
+ pypandoc.convert_file(
+ str(input_path_html),
+ "docx",
+ outputfile=str(output_path_doc),
+ extra_args=[
+ f"--resource-path={str(render_dir)}",
+ f"--reference-doc={reference_doc}",
+ ],
+ )
def upload(self):
"""
@@ -222,8 +231,7 @@ def upload(self):
file_path = self.rendered_filename(".docx")
g_folder_id = v["g_folder_id"]
g_drive_id = v["g_drive_id"]
- g_drive_upload_and_format(
- file_name, file_path, g_folder_id, g_drive_id)
+ g_drive_upload_and_format(file_name, file_path, g_folder_id, g_drive_id)
class DocumentCollection:
@@ -239,11 +247,33 @@ def from_yaml(cls, yaml_file: Path):
return cls(data)
def __init__(self, data: dict):
+
+ for k, v in data.items():
+ if "meta" not in v:
+ data[k]["meta"] = False
+ if "extends" in v:
+ base = deepcopy(data[v["extends"]])
+ if "meta" in base:
+ base.pop("meta")
+ base.update(v)
+ base.pop("extends")
+ data[k] = base
+
+ for k, v in data.items():
+ if "group" not in v:
+ data[k]["group"] = None
+
self.docs = {name: Document(name, data) for name, data in data.items()}
def all(self) -> Iterable:
for d in self.docs.values():
- yield d
+ if d._data["meta"] is False:
+ yield d
+
+ def get_group(self, group: str) -> Iterable:
+ for d in self.all():
+ if d._data["group"] == group:
+ yield d
def first(self) -> Document:
return list(self.docs.values())[0]
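
A hedged sketch of the new "extends"/"meta"/"group" handling, with a hypothetical documents.yaml:

    from pathlib import Path

    # documents.yaml (hypothetical):
    #
    #   base_report:
    #     meta: true              # template entry only; skipped by all()
    #     notebooks: [intro.ipynb]
    #   monthly_report:
    #     extends: base_report    # deep-copies base_report, then overrides
    #     group: monthly
    #
    collection = DocumentCollection.from_yaml(Path("documents.yaml"))
    for doc in collection.get_group("monthly"):
        doc.render()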
diff --git a/management/settings.py b/management/settings.py
index 3bf7378..ebca2b4 100644
--- a/management/settings.py
+++ b/management/settings.py
@@ -20,7 +20,7 @@ def get_settings(yaml_file: str = "settings.yaml", env_file: str = ".env"):
return {}
settings_file = Path(*top_level, yaml_file)
- with open(settings_file, 'r') as fp:
+ with open(settings_file, "r") as fp:
data = yaml.load(fp, Loader=yaml.Loader)
env_data = {}
diff --git a/management/upload.py b/management/upload.py
index 798f202..ba647b0 100644
--- a/management/upload.py
+++ b/management/upload.py
@@ -1,5 +1,8 @@
from notebook_helper.apis.google_api import (
- DriveIntegration, ScriptIntergration, test_settings)
+ DriveIntegration,
+ ScriptIntergration,
+ test_settings,
+)
from .settings import settings
@@ -20,7 +23,9 @@ def format_document(url):
Apply google sheets formatter to URL
"""
api = ScriptIntergration(settings["GOOGLE_CLIENT_JSON"])
- script_id = "AKfycbwjKpOgzKaDHahyn-7If0LzMhaNfMTTsiHf6nvgL2gaaVsgI_VvuZjHJWAzRaehENLX"
+ script_id = (
+ "AKfycbwjKpOgzKaDHahyn-7If0LzMhaNfMTTsiHf6nvgL2gaaVsgI_VvuZjHJWAzRaehENLX"
+ )
func = api.get_function(script_id, "formatWordURL")
print("formatting document, this may take a few minutes")
v = func(url)
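
The Apps Script pattern above, sketched end to end (the script id and document URL are elided/hypothetical):

    api = ScriptIntergration(settings["GOOGLE_CLIENT_JSON"])
    func = api.get_function("<script-id>", "formatWordURL")
    result = func("https://docs.google.com/document/d/<doc-id>/edit")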
diff --git a/progress.py b/progress.py
index 9a50d99..1effc8f 100644
--- a/progress.py
+++ b/progress.py
@@ -5,12 +5,15 @@
console = get_console()
-def track_progress(iterable: Iterable,
- name: Optional[str] = None,
- total: Optional[int] = None,
- update_label: bool = False,
- label_func: Optional[Callable] = lambda x: x,
- clear: Optional[bool] = True):
+
+def track_progress(
+ iterable: Iterable,
+ name: Optional[str] = None,
+ total: Optional[int] = None,
+ update_label: bool = False,
+ label_func: Optional[Callable] = lambda x: x,
+ clear: Optional[bool] = True,
+):
"""
simple tracking loop using rich progress
"""
@@ -27,4 +30,4 @@ def track_progress(iterable: Iterable,
description = f"{name}: {label_func(i)}"
else:
description = name
- progress.update(task, advance=1, description=description)
\ No newline at end of file
+ progress.update(task, advance=1, description=description)
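
A hedged usage sketch for track_progress:

    for item in track_progress(["a", "b", "c"], name="items", update_label=True):
        ...  # per-item work; the bar label shows each item via label_func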