From 5996d0569cc98a468af62cb608795642773993ac Mon Sep 17 00:00:00 2001 From: Florian Mounier Date: Wed, 6 Jan 2016 10:28:38 +0100 Subject: [PATCH] Confidence intervals (thanks @chartique). Fix #292. Use metadata to specify CI. Missing auto viewport. (Computing ymin ymax from ci) --- demo/moulinrouge/tests.py | 54 +++++++++++++++++++++++++++++ docs/changelog.rst | 2 ++ pygal/css/style.css | 6 ++-- pygal/graph/bar.py | 11 ++++-- pygal/graph/graph.py | 19 ++++++++++- pygal/graph/line.py | 5 +++ pygal/graph/stackedbar.py | 1 - pygal/stats.py | 72 +++++++++++++++++++++++++++++++++++++++ pygal/style.py | 13 ++++++- pygal/svg.py | 24 ++++++++++++- pygal/util.py | 1 - 11 files changed, 199 insertions(+), 9 deletions(-) create mode 100644 pygal/stats.py diff --git a/demo/moulinrouge/tests.py b/demo/moulinrouge/tests.py index 3ef2c7a..1036bf4 100644 --- a/demo/moulinrouge/tests.py +++ b/demo/moulinrouge/tests.py @@ -24,6 +24,7 @@ except ImportError: from flask import abort from pygal.style import styles, Style, RotateStyle from pygal.colors import rotate +from pygal import stats from pygal.graph.horizontal import HorizontalGraph from random import randint, choice from datetime import datetime, date @@ -1023,6 +1024,59 @@ def get_test_routes(app): chart.interpolate = 'cubic' return chart.render_response() + @app.route('/test/erfinv/approx') + def test_erfinv(): + from scipy import stats as sstats + chart = Line(show_dots=False) + chart.add('scipy', [ + sstats.norm.ppf(x/1000) for x in range(1, 999)]) + chart.add('approx', [stats.ppf(x/1000) for x in range(1, 999)]) + chart.add('scipy', [ + sstats.norm.ppf(x/1000) for x in range(1, 999)]) + + # chart.add('approx', [special.erfinv(x/1000) - erfinv(x/1000) for x in range(-999, 1000)]) + return chart.render_response() + + @app.route('/test/ci/') + def test_ci_for(chart): + chart = CHARTS_BY_NAME[chart]( + confidence_interval_proportion=True, + style=styles['default']( + value_font_family='googlefont:Raleway', + value_colors=(None, None, 'blue', 'red', 'green'), + ci_colors=(None, 'magenta') + )) + chart.add('Series 1', [ + {'value': 127.3, 'ci': { + 'type': 'continuous', 'sample_size': 3534, 'stddev': 19, + 'confidence': .99}}, + {'value': 127.3, 'ci': { + 'type': 'continuous', 'sample_size': 3534, 'stddev': 19}}, + {'value': 127.3, 'ci': { + 'type': 'continuous', 'sample_size': 3534, 'stddev': 19, + 'confidence': .90}}, + {'value': 127.3, 'ci': { + 'type': 'continuous', 'sample_size': 3534, 'stddev': 19, + 'confidence': .75}}, + # {'value': 73, 'ci': {'sample_size': 200}}, + # {'value': 54, 'ci': {'type': 'dichotomous', 'sample_size': 250}}, + # {'value': 67, 'ci': {'sample_size': 100}}, + # {'value': 61, 'ci': {'sample_size': 750}} + ]) + chart.add('Series 2', [ + {'value': 34.5, 'ci': { + 'type': 'dichotomous', 'sample_size': 3532}}, + ]) + chart.add('Series 3', [ + {'value': 100, 'ci': {'low': 50, 'high': 150}}, + {'value': 100, 'ci': {'low': 75, 'high': 175}}, + {'value': 50, 'ci': {'low': 50, 'high': 100}}, + {'value': 125, 'ci': {'low': 120, 'high': 130}}, + ]) + chart.range = (30, 200) + # chart.range = (32, 37) + return chart.render_response() + return list(sorted(filter( lambda x: x.startswith('test') and not x.endswith('_for'), locals())) ) + list(sorted(filter( diff --git a/docs/changelog.rst b/docs/changelog.rst index 1a5bd02..3225b38 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -6,6 +6,8 @@ Changelog ====== * Bar print value positioning with `print_values_position`. Can be `top`, `center` or `bottom` (thanks @chartique #291) +* Confidence intervals (thanks @chartique #292) + 2.0.12 ====== diff --git a/pygal/css/style.css b/pygal/css/style.css index f7e9c69..d2016a3 100644 --- a/pygal/css/style.css +++ b/pygal/css/style.css @@ -94,6 +94,10 @@ fill-opacity: {{ style.opacity }}; } +{{ id }}.ci { + stroke: {{ style.foreground }}; +} + {{ id }}.reactive.active, {{ id }}.active .reactive { fill-opacity: {{ style.opacity_hover }}; @@ -151,5 +155,3 @@ {{ colors }} {{ strokes }} - - diff --git a/pygal/graph/bar.py b/pygal/graph/bar.py index 83e42d3..3ba0a08 100644 --- a/pygal/graph/bar.py +++ b/pygal/graph/bar.py @@ -116,11 +116,18 @@ class Bar(Graph): self.svg.node(bars, class_='bar'), metadata) - bounds = self._bar( + x_, y_, width, height = self._bar( serie, bar, x, y, i, self.zero, secondary=rescale) + print(y_) + + self._confidence_interval( + serie_node['overlay'], x_ + width / 2, y_, serie.values[i], + metadata) + self._tooltip_and_print_values( - serie_node, serie, bar, i, val, metadata, *bounds) + serie_node, serie, bar, i, val, metadata, + x_, y_, width, height) def _compute(self): """Compute y min and max and y scale and set labels""" diff --git a/pygal/graph/graph.py b/pygal/graph/graph.py index e40b5f9..38096fa 100644 --- a/pygal/graph/graph.py +++ b/pygal/graph/graph.py @@ -27,6 +27,7 @@ from math import ceil, cos, sin, sqrt from pygal._compat import is_list_like, is_str, to_str from pygal.graph.public import PublicApi from pygal.interpolate import INTERPOLATIONS +from pygal import stats from pygal.util import ( cached_property, compute_scale, cut, decorate, get_text_box, get_texts_box, humanize, majorize, rad, reverse_text_len, @@ -684,7 +685,6 @@ class Graph(PublicApi): # Inner margin if self.print_values_position == 'top': - gw = self.width - self.margin_box.x gh = self.height - self.margin_box.y alpha = 1.1 * (self.style.value_font_size / gh) * self._box.height if self._max > 0: @@ -692,6 +692,23 @@ class Graph(PublicApi): if self._min < 0: self._box.ymin -= alpha + def _confidence_interval(self, node, x, y, value, metadata): + if not metadata or 'ci' not in metadata: + return + ci = metadata['ci'] + ci['point_estimate'] = value + + low, high = getattr( + stats, + 'confidence_interval_%s' % ci.get('type', 'manual') + )(**ci) + + self.svg.confidence_interval( + node, x, + # Respect some charts y modifications (pyramid, stackbar) + y + (self.view.y(low) - self.view.y(value)), + y + (self.view.y(high) - self.view.y(value))) + @cached_property def _legends(self): """Getter for series title""" diff --git a/pygal/graph/line.py b/pygal/graph/line.py index 2fe7527..baf17d8 100644 --- a/pygal/graph/line.py +++ b/pygal/graph/line.py @@ -111,10 +111,15 @@ class Line(Graph): if y > self.view.height / 2: classes.append('top') classes = ' '.join(classes) + + self._confidence_interval( + serie_node['overlay'], x, y, serie.values[i], metadata) + dots = decorate( self.svg, self.svg.node(serie_node['overlay'], class_="dots"), metadata) + val = self._get_value(serie.points, i) alter(self.svg.transposable_node( dots, 'circle', cx=x, cy=y, r=serie.dots_size, diff --git a/pygal/graph/stackedbar.py b/pygal/graph/stackedbar.py index 0e29ba7..13848b0 100644 --- a/pygal/graph/stackedbar.py +++ b/pygal/graph/stackedbar.py @@ -25,7 +25,6 @@ from __future__ import division from pygal.adapters import none_to_zero from pygal.graph.bar import Bar -from pygal.util import ident, swap class StackedBar(Bar): diff --git a/pygal/stats.py b/pygal/stats.py new file mode 100644 index 0000000..77f6390 --- /dev/null +++ b/pygal/stats.py @@ -0,0 +1,72 @@ +from math import log, sqrt, pi +try: + from scipy import stats +except ImportError: + stats = None + + +def erfinv(x, a=.147): + """Approximation of the inverse error function + https://en.wikipedia.org/wiki/Error_function + #Approximation_with_elementary_functions + """ + lnx = log(1 - x * x) + part1 = (2 / (a * pi) + lnx / 2) + part2 = lnx / a + sgn = 1 if x > 0 else -1 + return sgn * sqrt(sqrt(part1 * part1 - part2) - part1) + + +def norm_ppf(x): + if not 0 < x < 1: + raise ValueError("Can't compute the percentage point for value %d" % x) + return sqrt(2) * erfinv(2 * x - 1) + + +def ppf(x, n): + if stats: + if n < 30: + return stats.t.ppf(x, n) + return stats.norm.ppf(x) + else: + if n < 30: + # TODO: implement power series: + # http://eprints.maths.ox.ac.uk/184/1/tdist.pdf + raise ImportError( + 'You must have scipy installed to use t-student ' + 'when sample_size is below 30') + return norm_ppf(x) + +# According to http://sphweb.bumc.bu.edu/otlt/MPH-Modules/BS/ +# BS704_Confidence_Intervals/BS704_Confidence_Intervals_print.html + + +def confidence_interval_continuous( + point_estimate, stddev, sample_size, confidence=.95, **kwargs): + """Continuous confidence interval from sample size and standard error""" + alpha = ppf((confidence + 1) / 2, sample_size - 1) + + margin = stddev / sqrt(sample_size) + return (point_estimate - alpha * margin, point_estimate + alpha * margin) + + +def confidence_interval_dichotomous( + point_estimate, sample_size, confidence=.95, bias=False, + percentage=True, **kwargs): + """Dichotomous confidence interval from sample size and maybe a bias""" + alpha = ppf((confidence + 1) / 2, sample_size - 1) + p = point_estimate + if percentage: + p /= 100 + + margin = sqrt(p * (1 - p) / sample_size) + if bias: + margin += .5 / sample_size + if percentage: + margin *= 100 + + return (point_estimate - alpha * margin, point_estimate + alpha * margin) + + +def confidence_interval_manual(point_estimate, low, high): + return (low, high) diff --git a/pygal/style.py b/pygal/style.py index 6eb4a86..f09c55e 100644 --- a/pygal/style.py +++ b/pygal/style.py @@ -88,6 +88,7 @@ class Style(object): ) value_colors = () + ci_colors = () def __init__(self, **kwargs): """Create the style""" @@ -124,6 +125,15 @@ class Style(object): ' fill: {1};\n' '}}\n') % (prefix,)).format(*tupl) + def ci_color(tupl): + """Make a value color css""" + if not tupl[1]: + return '' + return (( + '%s .color-{0} .ci {{\n' + ' stroke: {1};\n' + '}}\n') % (prefix,)).format(*tupl) + if len(self.colors) < len_: missing = len_ - len(self.colors) cycles = 1 + missing // len(self.colors) @@ -150,7 +160,8 @@ class Style(object): return '\n'.join(chain( map(color, enumerate(colors)), - map(value_color, enumerate(value_colors)))) + map(value_color, enumerate(value_colors)), + map(ci_color, enumerate(self.ci_colors)))) def to_dict(self): """Convert instance to a serializable mapping.""" diff --git a/pygal/svg.py b/pygal/svg.py index c485545..f567912 100644 --- a/pygal/svg.py +++ b/pygal/svg.py @@ -232,7 +232,7 @@ class Svg(object): def line(self, node, coords, close=False, **kwargs): """Draw a svg line""" line_len = len(coords) - if line_len < 2: + if len([c for c in coords if c[1] is not None]) < 2: return root = 'M%s L%s Z' if close else 'M%s L%s' origin_index = 0 @@ -296,6 +296,28 @@ class Svg(object): self.graph._static_value(serie_node, val, x, y, metadata) return rv + def confidence_interval(self, node, x, low, high, width=7): + if self.graph.horizontal: + coord_format = lambda xy: '%f %f' % (xy[1], xy[0]) + else: + coord_format = lambda xy: '%f %f' % xy + + shr = lambda xy: (xy[0] + width, xy[1]) + shl = lambda xy: (xy[0] - width, xy[1]) + + top = (x, high) + bottom = (x, low) + + ci = self.node(node, class_="ci") + self.node( + ci, 'path', d="M%s L%s M%s L%s M%s L%s L%s M%s L%s" % tuple( + map(coord_format, ( + top, shr(top), top, shl(top), top, + bottom, shr(bottom), bottom, shl(bottom) + )) + ), class_='nofill' + ) + def pre_render(self): """Last things to do before rendering""" self.add_styles() diff --git a/pygal/util.py b/pygal/util.py index 0c83ba6..65f41a0 100644 --- a/pygal/util.py +++ b/pygal/util.py @@ -250,7 +250,6 @@ def decorate(svg, node, metadata): if 'label' in metadata: svg.node(node, 'desc', class_='label').text = to_unicode( metadata['label']) - return node