# -*- coding: utf-8 -*-
"""
    tests.regression
    ~~~~~~~~~~~~~~~~

    Regression tests.

    :copyright: (c) 2015 by Armin Ronacher.
    :license: BSD, see LICENSE for more details.
"""

import pytest

import os
import gc
import sys
import flask
import threading
from werkzeug.exceptions import NotFound


# Serializes leak checks: ``assert_no_leak`` toggles the process-wide garbage
# collector, so two checks must never overlap.
_gc_lock = threading.Lock()


class assert_no_leak(object):
    """Context manager that fails the test if its body grows the gc heap.

    On entry the lock is taken, the garbage collector is disabled and the
    number of gc-tracked objects is recorded.  On exit the count is taken
    again and the test fails if it went up.  The leak check is skipped when
    the body itself raised, so the original error is not masked.  The lock
    is always released and gc always re-enabled, even if the check fails.
    """

    def __enter__(self):
        # Acquire before touching global gc state: otherwise a waiting
        # thread could end up measuring after the previous holder has
        # already re-enabled gc.
        _gc_lock.acquire()
        gc.disable()
        try:
            loc = flask._request_ctx_stack._local

            # Force Python to track this dictionary at all times.
            # This is necessary since Python only starts tracking
            # dicts if they contain mutable objects.  It's a horrible,
            # horrible hack but makes this kinda testable.
            loc.__storage__['FOOO'] = [1, 2, 3]

            gc.collect()
            self.old_objects = len(gc.get_objects())
        except BaseException:
            # __exit__ is not called when __enter__ raises, so undo here.
            gc.enable()
            _gc_lock.release()
            raise

    def __exit__(self, exc_type, exc_value, tb):
        try:
            if exc_type is None:
                # Without sys.getrefcount the interpreter is presumably not
                # refcounted (e.g. PyPy), so garbage is only reclaimed by an
                # explicit collection.
                if not hasattr(sys, 'getrefcount'):
                    gc.collect()
                new_objects = len(gc.get_objects())
                if new_objects > self.old_objects:
                    pytest.fail('Example code leaked')
        finally:
            gc.enable()
            _gc_lock.release()


# XXX: untitaker: These tests need to be revised. They broke around the time we
# ported Flask to Python 3.
@pytest.mark.skipif(os.environ.get('RUN_FLASK_MEMORY_TESTS') != '1',
                    reason='Turned off due to envvar.')
def test_memory_consumption():
    """Repeated requests that render a template must not leak objects.

    Opt-in: runs only when the ``RUN_FLASK_MEMORY_TESTS`` environment
    variable is ``'1'``.
    """
    app = flask.Flask(__name__)

    @app.route('/')
    def index():
        return flask.render_template('simple_template.html', whiskey=42)

    def fire():
        with app.test_client() as c:
            rv = c.get('/')
            assert rv.status_code == 200
            assert rv.data == b'<h1>42</h1>'

    # Warm up caches so that one-time allocations are not counted as leaks.
    fire()

    # The object-count check only works on CPython 2.7 or later.  Report a
    # skip elsewhere instead of passing without checking anything.
    if sys.version_info < (2, 7) or hasattr(sys, 'pypy_translation_info'):
        pytest.skip('Leak check requires CPython 2.7 or later.')

    with assert_no_leak():
        for _ in range(10):
            fire()


def test_safe_join_toplevel_pardir():
    """``safe_join`` must reject a bare ``'..'`` that climbs above the base."""
    from flask.helpers import safe_join as join_safely

    pytest.raises(NotFound, join_safely, '/foo', '..')


def test_aborting():
    """``abort`` with a response object, and an error handler for a custom
    exception, both produce the expected responses."""
    class Foo(Exception):
        whatever = 42
    app = flask.Flask(__name__)
    app.testing = True

    @app.errorhandler(Foo)
    def handle_foo(e):
        return str(e.whatever)

    @app.route('/')
    def index():
        # abort() raises by itself; it never returns a value to re-raise.
        flask.abort(flask.redirect(flask.url_for('test')))

    @app.route('/test')
    def test():
        raise Foo()

    with app.test_client() as c:
        rv = c.get('/')
        assert rv.headers['Location'] == 'http://localhost/test'
        rv = c.get('/test')
        assert rv.data == b'42'