From 1dc6708a92c52545ddc2db3949ba13b3c29f5d47 Mon Sep 17 00:00:00 2001 From: Florian Mounier Date: Thu, 13 Sep 2012 12:13:33 +0200 Subject: [PATCH] Protect the config --- pygal/config.py | 12 +++- pygal/ghost.py | 12 +++- pygal/test/test_config.py | 112 +++++++++++++++++++++++++++++++++++++- 3 files changed, 132 insertions(+), 4 deletions(-) diff --git a/pygal/config.py b/pygal/config.py index b6852f3..083393c 100644 --- a/pygal/config.py +++ b/pygal/config.py @@ -20,7 +20,7 @@ """ Config module with all options """ - +from copy import deepcopy from pygal.style import DefaultStyle @@ -117,6 +117,13 @@ class Config(object): def __init__(self, **kwargs): """Can be instanciated with config kwargs""" + for k in dir(self): + v = getattr(self, k) + if (k not in self.__dict__ and not + k.startswith('_') and not + hasattr(v, '__call__')): + setattr(self, k, v) + self.css = list(self.css) self.js = list(self.js) self._update(kwargs) @@ -152,3 +159,6 @@ class Config(object): elif not hasattr(value, '__call__'): config[attr] = value return config + + def copy(self): + return deepcopy(self) diff --git a/pygal/ghost.py b/pygal/ghost.py index 1e00b84..ea746d0 100644 --- a/pygal/ghost.py +++ b/pygal/ghost.py @@ -40,8 +40,16 @@ class Ghost(object): def __init__(self, config=None, **kwargs): """Init config""" - self.config = config or Config() - self.config(**kwargs) + if config and type(config) == type: + config = config() + + if config: + config = config.copy() + else: + config = Config() + + config(**kwargs) + self.config = config self.series = [] def add(self, title, values): diff --git a/pygal/test/test_config.py b/pygal/test/test_config.py index ff9e04b..7d5073d 100644 --- a/pygal/test/test_config.py +++ b/pygal/test/test_config.py @@ -16,11 +16,121 @@ # # You should have received a copy of the GNU Lesser General Public License # along with pygal. If not, see . -from pygal import Line, Dot, Pie, Radar +from pygal import Line, Dot, Pie, Radar, Config from pygal.test.utils import texts from pygal.test import pytest_generate_tests, make_data +def test_config_behaviours(): + line1 = Line() + line1.show_legend = False + line1.fill = True + line1.pretty_print = True + line1.x_labels = ['a', 'b', 'c'] + line1.add('_', [1, 2, 3]) + l1 = line1.render() + + line2 = Line( + show_legend=False, + fill=True, + pretty_print=True, + x_labels=['a', 'b', 'c']) + line2.add('_', [1, 2, 3]) + l2 = line2.render() + assert l1 == l2 + + class LineConfig(Config): + show_legend = False + fill = True + pretty_print = True + x_labels = ['a', 'b', 'c'] + + line3 = Line(LineConfig) + line3.add('_', [1, 2, 3]) + l3 = line3.render() + assert l1 == l3 + + line4 = Line(LineConfig()) + line4.add('_', [1, 2, 3]) + l4 = line4.render() + assert l1 == l4 + + +def test_config_alterations_class(): + class LineConfig(Config): + show_legend = False + fill = True + pretty_print = True + x_labels = ['a', 'b', 'c'] + + line1 = Line(LineConfig) + line1.add('_', [1, 2, 3]) + l1 = line1.render() + + LineConfig.stroke = False + line2 = Line(LineConfig) + line2.add('_', [1, 2, 3]) + l2 = line2.render() + assert l1 != l2 + + l1bis = line1.render() + assert l1 == l1bis + + +def test_config_alterations_instance(): + class LineConfig(Config): + show_legend = False + fill = True + pretty_print = True + x_labels = ['a', 'b', 'c'] + + config = LineConfig() + line1 = Line(config) + line1.add('_', [1, 2, 3]) + l1 = line1.render() + + config.stroke = False + line2 = Line(config) + line2.add('_', [1, 2, 3]) + l2 = line2.render() + assert l1 != l2 + + l1bis = line1.render() + assert l1 == l1bis + + +def test_config_alterations_kwargs(): + class LineConfig(Config): + show_legend = False + fill = True + pretty_print = True + x_labels = ['a', 'b', 'c'] + + config = LineConfig() + + line1 = Line(config) + line1.add('_', [1, 2, 3]) + l1 = line1.render() + + line1.stroke = False + l1bis = line1.render() + assert l1 != l1bis + + line2 = Line(config) + line2.add('_', [1, 2, 3]) + l2 = line2.render() + assert l1 == l2 + assert l1bis != l2 + + line3 = Line(config, title='Title') + line3.add('_', [1, 2, 3]) + l3 = line3.render() + assert l3 != l2 + + l2bis = line2.render() + assert l2 == l2bis + + def test_logarithmic(): line = Line(logarithmic=True) line.add('_', [1, 10 ** 10, 1])