diff --git a/CHANGES b/CHANGES index 3517f1a7..cf9f6de4 100644 --- a/CHANGES +++ b/CHANGES @@ -118,6 +118,8 @@ Major release, unreleased - The dev server now uses threads by default. (`#2529`_) - Loading config files with ``silent=True`` will ignore ``ENOTDIR`` errors. (`#2581`_) +- Pass ``--cert`` and ``--key`` options to ``flask run`` to run the + development server over HTTPS. (`#2606`_) .. _pallets/meta#24: https://github.com/pallets/meta/issues/24 .. _#1421: https://github.com/pallets/flask/issues/1421 @@ -154,6 +156,7 @@ Major release, unreleased .. _#2450: https://github.com/pallets/flask/pull/2450 .. _#2529: https://github.com/pallets/flask/pull/2529 .. _#2581: https://github.com/pallets/flask/pull/2581 +.. _#2606: https://github.com/pallets/flask/pull/2606 Version 0.12.3 diff --git a/flask/cli.py b/flask/cli.py index 52a4f596..43e6aa5a 100644 --- a/flask/cli.py +++ b/flask/cli.py @@ -14,6 +14,7 @@ import ast import inspect import os import re +import ssl import sys import traceback from functools import update_wrapper @@ -21,9 +22,10 @@ from operator import attrgetter from threading import Lock, Thread import click +from werkzeug.utils import import_string from . import __version__ -from ._compat import getargspec, iteritems, reraise +from ._compat import getargspec, iteritems, reraise, text_type from .globals import current_app from .helpers import get_debug_flag, get_env @@ -599,25 +601,110 @@ def show_server_banner(env, debug, app_import_path): print(' * Debug mode: {0}'.format('on' if debug else 'off')) +class CertParamType(click.ParamType): + """Click option type for the ``--cert`` option. Allows either an + existing file, the string ``'adhoc'``, or an import for a + :class:`~ssl.SSLContext` object. + """ + + name = 'path' + + def __init__(self): + self.path_type = click.Path( + exists=True, dir_okay=False, resolve_path=True) + + def convert(self, value, param, ctx): + try: + return self.path_type(value, param, ctx) + except click.BadParameter: + value = click.STRING(value, param, ctx).lower() + + if value == 'adhoc': + try: + import OpenSSL + except ImportError: + raise click.BadParameter( + 'Using ad-hoc certificates requires pyOpenSSL.', + ctx, param) + + return value + + obj = import_string(value, silent=True) + + if sys.version_info < (2, 7): + if obj: + return obj + else: + if isinstance(obj, ssl.SSLContext): + return obj + + raise + + +def _validate_key(ctx, param, value): + """The ``--key`` option must be specified when ``--cert`` is a file. + Modifies the ``cert`` param to be a ``(cert, key)`` pair if needed. + """ + cert = ctx.params.get('cert') + is_adhoc = cert == 'adhoc' + + if sys.version_info < (2, 7): + is_context = cert and not isinstance(cert, (text_type, bytes)) + else: + is_context = isinstance(cert, ssl.SSLContext) + + if value is not None: + if is_adhoc: + raise click.BadParameter( + 'When "--cert" is "adhoc", "--key" is not used.', + ctx, param) + + if is_context: + raise click.BadParameter( + 'When "--cert" is an SSLContext object, "--key is not used.', + ctx, param) + + if not cert: + raise click.BadParameter( + '"--cert" must also be specified.', + ctx, param) + + ctx.params['cert'] = cert, value + + else: + if cert and not (is_adhoc or is_context): + raise click.BadParameter( + 'Required when using "--cert".', + ctx, param) + + return value + + @click.command('run', short_help='Runs a development server.') @click.option('--host', '-h', default='127.0.0.1', help='The interface to bind to.') @click.option('--port', '-p', default=5000, help='The port to bind to.') +@click.option('--cert', type=CertParamType(), + help='Specify a certificate file to use HTTPS.') +@click.option('--key', + type=click.Path(exists=True, dir_okay=False, resolve_path=True), + callback=_validate_key, expose_value=False, + help='The key file to use when specifying a certificate.') @click.option('--reload/--no-reload', default=None, - help='Enable or disable the reloader. By default the reloader ' + help='Enable or disable the reloader. By default the reloader ' 'is active if debug is enabled.') @click.option('--debugger/--no-debugger', default=None, - help='Enable or disable the debugger. By default the debugger ' + help='Enable or disable the debugger. By default the debugger ' 'is active if debug is enabled.') @click.option('--eager-loading/--lazy-loader', default=None, - help='Enable or disable eager loading. By default eager ' + help='Enable or disable eager loading. By default eager ' 'loading is enabled if the reloader is disabled.') @click.option('--with-threads/--without-threads', default=True, help='Enable or disable multithreading.') @pass_script_info def run_command(info, host, port, reload, debugger, eager_loading, - with_threads): + with_threads, cert): """Run a local development server. This server is for development purposes only. It does not provide @@ -642,7 +729,7 @@ def run_command(info, host, port, reload, debugger, eager_loading, from werkzeug.serving import run_simple run_simple(host, port, app, use_reloader=reload, use_debugger=debugger, - threaded=with_threads) + threaded=with_threads, ssl_context=cert) @click.command('shell', short_help='Runs a shell in the app context.') diff --git a/tests/test_cli.py b/tests/test_cli.py index d9216f3d..f1e5eba7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -14,7 +14,9 @@ from __future__ import absolute_import import os +import ssl import sys +import types from functools import partial import click @@ -24,8 +26,8 @@ from click.testing import CliRunner from flask import Flask, current_app from flask.cli import ( - AppGroup, FlaskGroup, NoAppException, ScriptInfo, dotenv, - find_best_app, get_version, load_dotenv, locate_app, prepare_import, + AppGroup, FlaskGroup, NoAppException, ScriptInfo, dotenv, find_best_app, + get_version, load_dotenv, locate_app, prepare_import, run_command, with_appcontext ) @@ -464,3 +466,62 @@ def test_dotenv_optional(monkeypatch): monkeypatch.chdir(test_path) load_dotenv() assert 'FOO' not in os.environ + + +def test_run_cert_path(): + # no key + with pytest.raises(click.BadParameter): + run_command.make_context('run', ['--cert', __file__]) + + # no cert + with pytest.raises(click.BadParameter): + run_command.make_context('run', ['--key', __file__]) + + ctx = run_command.make_context( + 'run', ['--cert', __file__, '--key', __file__]) + assert ctx.params['cert'] == (__file__, __file__) + + +def test_run_cert_adhoc(monkeypatch): + monkeypatch.setitem(sys.modules, 'OpenSSL', None) + + # pyOpenSSL not installed + with pytest.raises(click.BadParameter): + run_command.make_context('run', ['--cert', 'adhoc']) + + # pyOpenSSL installed + monkeypatch.setitem(sys.modules, 'OpenSSL', types.ModuleType('OpenSSL')) + ctx = run_command.make_context('run', ['--cert', 'adhoc']) + assert ctx.params['cert'] == 'adhoc' + + # no key with adhoc + with pytest.raises(click.BadParameter): + run_command.make_context('run', ['--cert', 'adhoc', '--key', __file__]) + + +def test_run_cert_import(monkeypatch): + monkeypatch.setitem(sys.modules, 'not_here', None) + + # ImportError + with pytest.raises(click.BadParameter): + run_command.make_context('run', ['--cert', 'not_here']) + + # not an SSLContext + if sys.version_info >= (2, 7): + with pytest.raises(click.BadParameter): + run_command.make_context('run', ['--cert', 'flask']) + + # SSLContext + if sys.version_info < (2, 7): + ssl_context = object() + else: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) + + monkeypatch.setitem(sys.modules, 'ssl_context', ssl_context) + ctx = run_command.make_context('run', ['--cert', 'ssl_context']) + assert ctx.params['cert'] is ssl_context + + # no --key with SSLContext + with pytest.raises(click.BadParameter): + run_command.make_context( + 'run', ['--cert', 'ssl_context', '--key', __file__])