缘起

  在使用flask的时候一直比较纳闷request的实现原理:它是如何在多线程(或协程)环境下保证不同请求的参数互相隔离的。

准备知识

  在讲request之前首先需要先理解一下werkzeug.local中的几个类,因为request就是基于这几个类来搞事情的。

以下是 werkzeug.local 的部分源码:
# -*- coding: utf-8 -*-
import copy
from werkzeug._compat import PY2, implements_bool

# since each thread has its own greenlet we can just use those as identifiers
# for the context.  If greenlets are not available we fall back to the
# current thread ident depending on where it is.
try:
    from greenlet import getcurrent as get_ident
except ImportError:
    try:
        from thread import get_ident
    except ImportError:
        from _thread import get_ident


def release_local(local):
    """Forget everything ``local`` stores for the current thread/greenlet.

    Any object implementing ``__release_local__`` is accepted -- both
    :class:`Local` and :class:`LocalStack` do -- which lets locals be
    cleaned up without a manager.  Data reachable only through a proxy
    cannot be released this way; keep a reference to the underlying local
    object instead.

    Example::

        >>> loc = Local()
        >>> loc.foo = 42
        >>> release_local(loc)
        >>> hasattr(loc, 'foo')
        False

    .. versionadded:: 0.6.1
    """
    release = local.__release_local__
    release()


class Local(object):
    """Per-context attribute storage backed by one shared dict.

    Each thread (or greenlet, when greenlet is importable) that uses the
    same ``Local`` instance only sees its own attributes: ``__storage__``
    maps the identifier returned by ``__ident_func__`` to a private dict of
    attribute name -> value for that thread/greenlet.
    """
    __slots__ = ('__storage__', '__ident_func__')

    def __init__(self):
        # Go through object.__setattr__ so these land in the slots instead
        # of being routed into a per-context bucket by our own __setattr__.
        setter = object.__setattr__
        setter(self, '__storage__', {})
        setter(self, '__ident_func__', get_ident)

    def __iter__(self):
        """Iterate over ``(ident, attributes)`` pairs of every context."""
        return iter(self.__storage__.items())

    def __call__(self, proxy):
        """Create a proxy for a name."""
        return LocalProxy(self, proxy)

    def __release_local__(self):
        """Drop the whole bucket of the current context, if any."""
        storage = self.__storage__
        storage.pop(self.__ident_func__(), None)

    def __getattr__(self, name):
        bucket = self.__storage__.get(self.__ident_func__())
        if bucket is None or name not in bucket:
            raise AttributeError(name)
        return bucket[name]

    def __setattr__(self, name, value):
        # Create the current context's bucket on first write.
        self.__storage__.setdefault(self.__ident_func__(), {})[name] = value

    def __delattr__(self, name):
        bucket = self.__storage__.get(self.__ident_func__())
        if bucket is None or name not in bucket:
            raise AttributeError(name)
        del bucket[name]


class LocalStack(object):
    """A per-context stack built on top of :class:`Local`.

    Each thread/greenlet gets its own ``stack`` list stored in an internal
    ``Local``.  Note :meth:`pop`: when the stack holds exactly one item,
    popping it releases the current context's entire entry in the
    underlying ``Local`` rather than leaving an empty list behind.
    """

    def __init__(self):
        self._local = Local()

    def __release_local__(self):
        self._local.__release_local__()

    def _get__ident_func__(self):
        return self._local.__ident_func__

    def _set__ident_func__(self, value):
        # Bypass Local.__setattr__, which would store the value per context.
        object.__setattr__(self._local, '__ident_func__', value)
    __ident_func__ = property(_get__ident_func__, _set__ident_func__)
    del _get__ident_func__, _set__ident_func__

    def __call__(self):
        """Return a :class:`LocalProxy` to the current top of the stack.

        Using the proxy raises ``RuntimeError('object unbound')`` while the
        stack is empty for the current context.
        """
        def _lookup():
            rv = self.top
            if rv is None:
                raise RuntimeError('object unbound')
            return rv
        return LocalProxy(_lookup)

    def push(self, obj):
        """Pushes a new item to the stack"""
        rv = getattr(self._local, 'stack', None)
        if rv is None:
            self._local.stack = rv = []
        rv.append(obj)
        return rv

    def pop(self):
        """Removes the topmost item from the stack, will return the
        old value or `None` if the stack was already empty.
        """
        stack = getattr(self._local, 'stack', None)
        if not stack:
            # Covers both "never pushed" (None) and an empty list, so the
            # documented None is returned instead of an IndexError.
            return None
        if len(stack) == 1:
            # Last item: release this context's whole entry in the Local.
            release_local(self._local)
            return stack[-1]
        return stack.pop()

    @property
    def top(self):
        """The topmost item on the stack.  If the stack is empty,
        `None` is returned.
        """
        try:
            return self._local.stack[-1]
        except (AttributeError, IndexError):
            return None


@implements_bool
class LocalProxy(object):
    """Acts as a proxy for a werkzeug local.  Forwards all operations to
    a proxied object.  The only operations not supported for forwarding
    are right handed operands and any kind of assignment.

    This is the proxy pattern: the proxy stands in for a target object and
    controls every access to it.  The target is looked up again through
    :meth:`_get_current_object` on every forwarded operation, so when it
    lives in a :class:`Local` the same proxy resolves per thread/greenlet.

    Example usage::

        from werkzeug.local import Local
        l = Local()

        # these are proxies
        request = l('request')
        user = l('user')


        from werkzeug.local import LocalStack
        _response_local = LocalStack()

        # this is a proxy
        response = _response_local()

    Whenever something is bound to l.user / l.request the proxy objects
    will forward all operations.  If no object is bound a :exc:`RuntimeError`
    will be raised.

    To create proxies to :class:`Local` or :class:`LocalStack` objects,
    call the object as shown above.  If you want to have a proxy to an
    object looked up by a function, you can (as of Werkzeug 0.6.1) pass
    a function to the :class:`LocalProxy` constructor::

        session = LocalProxy(lambda: get_current_request().session)

    .. versionchanged:: 0.6.1
       The class can be instantiated with a callable as well now.
    """
    __slots__ = ('__local', '__dict__', '__name__', '__wrapped__')

    def __init__(self, local, name=None):
        # The private slot name __local is mangled to _LocalProxy__local,
        # so it has to be set under the mangled name here.
        object.__setattr__(self, '_LocalProxy__local', local)
        object.__setattr__(self, '__name__', name)
        if callable(local) and not hasattr(local, '__release_local__'):
            # "local" is a callable that is not an instance of Local or
            # LocalManager: mark it as a wrapped function.
            object.__setattr__(self, '__wrapped__', local)

    def _get_current_object(self):
        """Return the current object.  This is useful if you want the real
        object behind the proxy at a time for performance reasons or because
        you want to pass the object into a different context.
        """
        if not hasattr(self.__local, '__release_local__'):
            return self.__local()
        try:
            return getattr(self.__local, self.__name__)
        except AttributeError:
            raise RuntimeError('no object bound to %s' % self.__name__)

    @property
    def __dict__(self):
        try:
            return self._get_current_object().__dict__
        except RuntimeError:
            raise AttributeError('__dict__')

    def __repr__(self):
        try:
            obj = self._get_current_object()
        except RuntimeError:
            return '<%s unbound>' % self.__class__.__name__
        return repr(obj)

    def __bool__(self):
        try:
            return bool(self._get_current_object())
        except RuntimeError:
            return False

    def __unicode__(self):
        try:
            return unicode(self._get_current_object())  # noqa
        except RuntimeError:
            return repr(self)

    def __dir__(self):
        try:
            return dir(self._get_current_object())
        except RuntimeError:
            return []

    def __getattr__(self, name):
        if name == '__members__':
            return dir(self._get_current_object())
        return getattr(self._get_current_object(), name)

    def __setitem__(self, key, value):
        self._get_current_object()[key] = value

    def __delitem__(self, key):
        del self._get_current_object()[key]

    if PY2:
        __getslice__ = lambda x, i, j: x._get_current_object()[i:j]

        def __setslice__(self, i, j, seq):
            self._get_current_object()[i:j] = seq

        def __delslice__(self, i, j):
            del self._get_current_object()[i:j]

    __setattr__ = lambda x, n, v: setattr(x._get_current_object(), n, v)
    __delattr__ = lambda x, n: delattr(x._get_current_object(), n)
    __str__ = lambda x: str(x._get_current_object())
    __lt__ = lambda x, o: x._get_current_object() < o
    __le__ = lambda x, o: x._get_current_object() <= o
    __eq__ = lambda x, o: x._get_current_object() == o
    __ne__ = lambda x, o: x._get_current_object() != o
    __gt__ = lambda x, o: x._get_current_object() > o
    __ge__ = lambda x, o: x._get_current_object() >= o
    __cmp__ = lambda x, o: cmp(x._get_current_object(), o)  # noqa
    __hash__ = lambda x: hash(x._get_current_object())
    __call__ = lambda x, *a, **kw: x._get_current_object()(*a, **kw)
    __len__ = lambda x: len(x._get_current_object())
    __getitem__ = lambda x, i: x._get_current_object()[i]
    __iter__ = lambda x: iter(x._get_current_object())
    __contains__ = lambda x, i: i in x._get_current_object()
    __add__ = lambda x, o: x._get_current_object() + o
    __sub__ = lambda x, o: x._get_current_object() - o
    __mul__ = lambda x, o: x._get_current_object() * o
    __floordiv__ = lambda x, o: x._get_current_object() // o
    __mod__ = lambda x, o: x._get_current_object() % o
    __divmod__ = lambda x, o: x._get_current_object().__divmod__(o)
    __pow__ = lambda x, o: x._get_current_object() ** o
    __lshift__ = lambda x, o: x._get_current_object() << o
    __rshift__ = lambda x, o: x._get_current_object() >> o
    __and__ = lambda x, o: x._get_current_object() & o
    __xor__ = lambda x, o: x._get_current_object() ^ o
    __or__ = lambda x, o: x._get_current_object() | o
    __div__ = lambda x, o: x._get_current_object().__div__(o)
    __truediv__ = lambda x, o: x._get_current_object().__truediv__(o)
    __neg__ = lambda x: -(x._get_current_object())
    __pos__ = lambda x: +(x._get_current_object())
    __abs__ = lambda x: abs(x._get_current_object())
    __invert__ = lambda x: ~(x._get_current_object())
    __complex__ = lambda x: complex(x._get_current_object())
    __int__ = lambda x: int(x._get_current_object())
    __long__ = lambda x: long(x._get_current_object())  # noqa
    __float__ = lambda x: float(x._get_current_object())
    __oct__ = lambda x: oct(x._get_current_object())
    __hex__ = lambda x: hex(x._get_current_object())
    __index__ = lambda x: x._get_current_object().__index__()
    # __coerce__(self, other) takes one operand; passing the proxy as an
    # extra argument made every call fail with TypeError.
    __coerce__ = lambda x, o: x._get_current_object().__coerce__(o)
    __enter__ = lambda x: x._get_current_object().__enter__()
    __exit__ = lambda x, *a, **kw: x._get_current_object().__exit__(*a, **kw)
    __radd__ = lambda x, o: o + x._get_current_object()
    __rsub__ = lambda x, o: o - x._get_current_object()
    __rmul__ = lambda x, o: o * x._get_current_object()
    __rdiv__ = lambda x, o: o / x._get_current_object()
    if PY2:
        __rtruediv__ = lambda x, o: x._get_current_object().__rtruediv__(o)
    else:
        __rtruediv__ = __rdiv__
    __rfloordiv__ = lambda x, o: o // x._get_current_object()
    __rmod__ = lambda x, o: o % x._get_current_object()
    __rdivmod__ = lambda x, o: x._get_current_object().__rdivmod__(o)
    __copy__ = lambda x: copy.copy(x._get_current_object())
    __deepcopy__ = lambda x, memo: copy.deepcopy(x._get_current_object(), memo)
werkzeug.local部分源码

先来讲Local对象

1 创建一个Local对象

local = Local()

刚创建后, 这个local中负责存储局部上下文变量的storage是一个空字典

local.__storage__ = {}

我们用iden1, iden2 .... idenN 来表示n个同属于一个进程的线程(或者greenlet), 假如当前的线程(或者greenlet)的id为iden1, 我们来操作一下local

local.name = "iden1_name"

实际执行的代码是:

local.__storage__.setdefault("iden1", {})["name"] = "iden1_name"

这个local中负责存储局部上下文变量的storage就变成了这样:

local.__storage__ = {
    "iden1": {
        "name": "iden1_name"
    }
}

当我们在不同的线程(或者greenlet)中操作后,local就可能会变成这样

local.__storage__ = {
    "iden1": {...},
    "iden2": {...},
    ...
    "idenN": {...}
}

local对象有一个__release_local__方法, 执行该方法会清理掉当前线程(或者greenlet)对应的存储空间, 假如当前的线程(或者greenlet)的id为iden1,
当我们执行完__release_local__方法后, local的存储空间就会变成这样:

# 已经没有iden1了  
local.__storage__ = {
    "iden2": {...},
    ...
    "idenN": {...}
}

local还定义了__call__方法, 当我们执行local("name")后会返回一个代理local中name属性的LocalProxy对象(等价于LocalProxy(local, "name"))

LocalStack对象

LocalStack底层使用的是Local,然后在Local实例中实现了一个栈

创建一个LocalStack对象

local_stack = LocalStack()

该对象的_local属性就是一个Local实例(为了简洁, 下文的示例中写作local)

isinstance(local_stack.local, Local) is True

local_stack的栈存储在他的local属性中, 当我们调用local_stack.push(some_obj)的时候, 实际上是执行了

if hasattr(local_stack.local, "stack"):
    local_stack.local.stack.append(some_obj)
else:
    local_stack.local.stack = [some_obj]

假如当前的线程(或者greenlet)的id为iden1, 我们push一个对象request_ctx_obj, 然后又push一个对象request_ctx_obj2, 那么local_stack.local就会是这样:

local_stack.local.__storage__ = {
    "iden1": {
        "stack": [request_ctx_obj, request_ctx_obj2]
    }
}

假如当前的线程(或者greenlet)的id为iden1,我们在访问local_stack.top属性时(top是一个property, 不是方法),实际上执行的是:

return local_stack.local.stack[-1]

需要注意的是:

  如果我们当前所处的线程(或者greenlet)中之前并没有进行过push的话,那么访问local_stack.top得到的结果是None

当我们执行local_stack.pop()时, 实际上执行的是

local_stack.local.stack.pop()

需要注意两点:
  1 如果当前线程(或者greenlet)中之前没有push过, 那么pop()方法会返回None
  2 如果当前线程(或者greenlet)中的栈中只有一个对象, 那么本次pop()还会清理掉stack(实际上执行了local_stack.local.__release_local__方法),
  假如当前的线程(或者greenlet)的id为iden2的话,没有pop()之前是这样的:

local_stack.local.__storage__ = {
    "iden1": {
        "stack": [request_ctx_obj, request_ctx_obj2]
    },
    "iden2": {
        "stack": [request_ctx_obj3]
    }
}

执行pop()则会将当前线程(或者greenlet)的局部上下文存储空间清理掉, 变为这样:

local_stack.local.__storage__ = {
    "iden1": {
        "stack": [request_ctx_obj, request_ctx_obj2]
    }
}

LocalStack也提供了__call__方法, 执行该方法会生成一个LocalProxy对象

LocalProxy

LocalProxy实现了代理模式, 给目标对象提供一个代理对象,并由代理对象控制对目标对象的引用。
根据传入参数的类型以及数量的不同它会有两种表现形式.
第一种,第一个参数传入一个Local实例, 然后第二个参数传入想要代理的对象的名称:
  我们执行下面的语句:

local_1 = Local()
local_proxy1 = LocalProxy(local_1, "age")

  local_proxy1所代理的实际上就是local_1实例中的age属性了。
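  假如当前的线程(或者greenlet)的id为iden1, 先执行local_1.age = 12, 之后对local_proxy1的操作都会转发给local_1.age, 例如local_proxy1 + 1实际执行的就是local_1.age + 1, 结果为13。注意不能通过local_proxy1 = 12来赋值, 那只会让变量local_proxy1重新指向整数12。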

第二种,只传入一个函数,通过该函数可以获取到想要代理的对象:
  我们执行下面的语句:

local_2 = Local()

def _find_raw_obj():
    return local_2.name

local_proxy2 = LocalProxy(_find_raw_obj)

  local_proxy2所代理的实际上就是local_2实例中的name属性

flask源码剖析

request源码

def _lookup_req_object(name):
    """Return attribute ``name`` of the request context on top of the stack.

    Raises ``RuntimeError(_request_ctx_err_msg)`` when no request context
    has been pushed for the current thread/greenlet.
    """
    ctx = _request_ctx_stack.top
    if ctx is not None:
        return getattr(ctx, name)
    raise RuntimeError(_request_ctx_err_msg)


# context locals
# _request_ctx_stack holds the RequestContext objects pushed for each
# thread/greenlet.  `request` is a LocalProxy whose target is re-resolved on
# every access by calling _lookup_req_object('request'), i.e.
# `_request_ctx_stack.top.request`.
_request_ctx_stack = LocalStack()
request = LocalProxy(partial(_lookup_req_object, 'request'))

只要看懂了文章上半部分的local,这里实际上很简单,

  _request_ctx_stack是一个LocalStack实例,而我们每次调用request.some_attr 的时候实际上是执行_request_ctx_stack.top.request.some_attr (top是当前的RequestContext, 它的request属性才是真正的Request对象)

再来看一下当请求过来的时候,flask是如何处理的:

# flask.app  (abridged excerpt of flask's source; unrelated members omitted)
class Flask(_PackageBoundObject):

    request_class = Request

    def request_context(self, environ):
        return RequestContext(self, environ)

    def wsgi_app(self, environ, start_response):
        # self.request_context() returns a RequestContext instance
        ctx = self.request_context(environ)
        error = None
        try:
            try:
                # 2 push ctx onto _request_ctx_stack
                ctx.push()
                # 3 dispatch the request and build the response
                response = self.full_dispatch_request()
            except Exception as e:
                error = e
                response = self.handle_exception(e)
            except:
                error = sys.exc_info()[1]
                raise
            return response(environ, start_response)
        finally:
            if self.should_ignore_error(error):
                error = None
            # 4 pop ctx from _request_ctx_stack
            ctx.auto_pop(error)

    # 1 the WSGI server calls app.__call__() for every request
    def __call__(self, environ, start_response):
        """The WSGI server calls the Flask application object as the
        WSGI application. This calls :meth:`wsgi_app` which can be
        wrapped to apply middleware."""
        return self.wsgi_app(environ, start_response)


# flask.ctx
class RequestContext(object):

    def __init__(self, app, environ, request=None):
        self.app = app
        if request is None:
            # wrap the WSGI environ in a Request object
            request = app.request_class(environ)
        self.request = request
        self._implicit_app_ctx_stack = []

    def push(self):
        top = _request_ctx_stack.top
        if top is not None and top.preserved:
            top.pop(top._preserved_exc)

        # Before we push the request context we have to ensure that there
        # is an application context.
        app_ctx = _app_ctx_stack.top
        if app_ctx is None or app_ctx.app != self.app:
            app_ctx = self.app.app_context()
            app_ctx.push()
            self._implicit_app_ctx_stack.append(app_ctx)
        else:
            self._implicit_app_ctx_stack.append(None)

        if hasattr(sys, 'exc_clear'):
            sys.exc_clear()

        # 2.1 the key step: make this context the current one
        _request_ctx_stack.push(self)

    def pop(self, exc=_sentinel):
        app_ctx = self._implicit_app_ctx_stack.pop()

        try:
            if not self._implicit_app_ctx_stack:
                self.preserved = False
                self._preserved_exc = None
                if exc is _sentinel:
                    exc = sys.exc_info()[1]
                self.app.do_teardown_request(exc)

                request_close = getattr(self.request, 'close', None)
                if request_close is not None:
                    request_close()
        finally:
            rv = _request_ctx_stack.pop()

            # Get rid of the app as well if necessary.
            if app_ctx is not None:
                app_ctx.pop(exc)

            assert rv is self, 'Popped wrong request context. ' \
                '(%r instead of %r)' % (rv, self)

 当请求过来时:

  1 将请求封装为一个RequestContext实例

  2 然后将请求的environ封装成Request对象

  3 执行_request_ctx_stack.push(RequestContext实例)

  4 处理请求得到响应

  5 执行_request_ctx_stack.pop()

  6 返回结果

看到这里,大体原理我们也就懂了。

来点深入的高级用法

需求

工作中使用到flask和flask-restful,有这样的场景:

  1 首先是遵循restful

  2 我希望所有接口有统一的参数传递格式,类似于这样:    

    timestamp: int                     # 以秒为单位 
    token:  str                           # 这个就不用说了
    data: str                          # base64.b64encode(json.dumps(原始参数数据).encode())
    signature: md5(data + string(timestamp) + key)        # 签名,注:key为加密密钥

假如是登陆接口,客户端需要给我传递手机号以及验证码,我希望格式是json,所以原始参数数据大概是这样:

{"mobile": "12377776666", "code": "1234"}

  3 我希望所有接口有统一的响应格式,类似于这样:

{
    "code": 0,  # 0 表示成功, 非 0 表示错误
    "message": "",
    "data": {
     "mobile": "sdfasdf",
     "code": ""
    }
}

  4 我希望请求参数数据经过统一的参数检测之后, request的args(如果请求方式为get) 或者form(如果请求方式为post) 或者values属性变为原始参数数据.这样就可以正常使用RequestParser()

实现

以下是本人的实现源码(生产环境测试无问题,但有几处需要改进的地方):

完整代码如下:
import base64
import hmac
import json
import logging
import time
import traceback
from functools import wraps
from hashlib import md5

from flask import request
from flask.globals import _request_ctx_stack
from flask_restful import reqparse
from werkzeug.datastructures import ImmutableMultiDict, CombinedMultiDict, MultiDict
from werkzeug.exceptions import HTTPException


# Logger used by uniform_verification_wechat to record token checks,
# errors and each request/response pair.
logger = logging.getLogger("api")


def get_user_info(token):
    """Look up the user that owns ``token`` (to be implemented per project).

    The caller treats a ``None`` result as an invalid token and answers
    with ``ServerException(2)``.  This placeholder always returns ``None``.
    """
    return None


class ServerRequestParser(reqparse.RequestParser):
    """RequestParser that reports parse failures in the unified error format.

    flask-restful aborts with an HTTPException when an argument is missing
    or invalid; this subclass converts it into ``ServerException(3, detail)``
    so the error is rendered like every other business error.
    """

    def parse_args(self, req=None, strict=False):
        try:
            return super().parse_args(req, strict)
        except HTTPException as e:
            # flask-restful's abort() attaches the details as ``e.data``;
            # other werkzeug HTTPExceptions are not guaranteed to have it,
            # so fall back to ``description`` instead of raising
            # AttributeError from inside the handler.
            detail = getattr(e, "data", None)
            if detail is None:
                detail = e.description
            raise ServerException(3, str(detail)) from e


class ServerException(Exception):
    """Business error rendered as a ``{"code", "message", "data"}`` payload.

    ``code`` selects a message template from :attr:`code_msg`; the extra
    positional ``args`` are %-interpolated into it (only code 3 expects one).
    """
    code_msg = {
        -1: "未知错误",
        -2: "非法请求",
        1: "缺少token",
        2: "非法token",
        3: "参数错误: %s"
    }

    def __init__(self, code, *args):
        # BaseException.__init__ stores ``args`` as ``self.args`` (used by
        # to_json); the code is kept separately.
        super().__init__(*args)
        self.code = code

    def to_json(self):
        """Return the unified error payload for this exception.

        This is called from inside ``except`` handlers, so an unknown code
        falls back to the ``-1`` message and a template/argument mismatch
        falls back to the raw template instead of raising KeyError or
        TypeError.
        """
        template = self.code_msg.get(self.code, self.code_msg[-1])
        try:
            message = template % self.args
        except TypeError:
            message = template
        return {
            "code": self.code,
            "message": message,
            "data": {}
        }


class ServerRequest(object):
    """Stand-in for the current flask Request that carries decoded parameters.

    Used by ``uniform_verification_wechat``: once the envelope has been
    verified and decoded, an instance is built from the original parameter
    dict and installed as ``request`` on the current request context, so
    ``args`` / ``form`` / ``json`` / ``values`` (and therefore RequestParser)
    see the decoded data.  Every other attribute is forwarded to the real
    Request object.
    """

    def __init__(self, json_data):
        self._raw_request = request._get_current_object()
        self.args = ImmutableMultiDict()
        self.form = ImmutableMultiDict()
        self.json = {}
        if request.method.lower() == "get":
            self.args = json_data
        elif request.mimetype == "application/json":
            # mimetype drops parameters such as "; charset=utf-8", which a
            # comparison against the full content_type would miss.
            self.json = json_data
        else:
            self.form = json_data
        # Replace the Request on the active RequestContext (the same object
        # as _request_ctx_stack._local.stack[-1]) via the public ``top``.
        _request_ctx_stack.top.request = self

    def __getattr__(self, item):
        # Without this guard an instance whose _raw_request is not set yet
        # (e.g. created without __init__ by copy/pickle) recurses forever.
        if item == "_raw_request":
            raise AttributeError(item)
        return getattr(self._raw_request, item)

    @property
    def values(self):
        """Combined view of ``args`` and ``form`` as a CombinedMultiDict."""
        return CombinedMultiDict([
            d if isinstance(d, MultiDict) else MultiDict(d)
            for d in (self.args, self.form)
        ])


# Parser for the unified request envelope shared by every endpoint.
uniform_wechat_parser = reqparse.RequestParser()
uniform_wechat_parser.add_argument("timestamp", type=int, required=True)
uniform_wechat_parser.add_argument("signature", required=True)
uniform_wechat_parser.add_argument("data", required=True)
# token is optional at parse time: endpoints decorated with
# check_token=False (such as a login endpoint) are called before the client
# has a token, and endpoints with check_token=True already report a missing
# token as ServerException(1) in the unified format.
uniform_wechat_parser.add_argument("token", required=False)


def uniform_verification_wechat(check_token=True):
    """Decorator factory enforcing the unified request/response envelope.

    Each request must carry the fields declared on ``uniform_wechat_parser``.
    The decorated view only runs after:

    * ``timestamp`` is less than 120 seconds away from the server clock,
      evaluated for every request;
    * ``signature`` equals ``md5(data + str(timestamp) + key).hexdigest()``;
    * ``data`` base64-decodes to a JSON object, which then replaces the
      request parameters through :class:`ServerRequest`;
    * when ``check_token`` is true, ``token`` is present and
      ``get_user_info`` resolves it (the result is stored on
      ``request.user_info``).

    The view's return value is wrapped as
    ``{"code": 0, "message": "", "data": <result>}``; failures become the
    ``to_json()`` payload of a :class:`ServerException` (``-1`` for
    unexpected errors).  Every request/response pair is logged.

    :param check_token: require and validate ``token``.
    """
    max_skew = 120  # seconds a request timestamp may differ from server time

    def wrapper(func):

        @wraps(func)
        def inner(*args, **kwargs):
            request_log_info = "url: %s, method: %s" % (request.url, request.method)
            try:
                try:
                    request_data_dict = uniform_wechat_parser.parse_args()
                except HTTPException as e:
                    # Report envelope errors in the unified format instead of
                    # letting flask-restful answer with a bare 400.
                    detail = getattr(e, "data", None)
                    if detail is None:
                        detail = e.description
                    raise ServerException(3, str(detail)) from e
                request_log_info = "url: %s, method: %s, request_data: %s" % (
                    request.url, request.method, json.dumps(request_data_dict))
                signature = request_data_dict["signature"]
                timestamp = request_data_dict["timestamp"]
                data = request_data_dict["data"]
                token = request_data_dict.get("token", None)

                # Read the clock per request: reading it once at decoration
                # time let every timestamp newer than server start-up pass,
                # defeating replay protection.  abs() also rejects
                # timestamps far in the future.
                if abs(time.time() - timestamp) >= max_skew:
                    raise ServerException(-2)
                _strings = "%s%s%s" % (data, timestamp, "密钥key")
                expected = md5(_strings.encode()).hexdigest()
                # Constant-time comparison so response timing does not leak
                # how much of the signature matched.
                if not hmac.compare_digest(signature.encode(), expected.encode()):
                    raise ServerException(-2)
                try:
                    data = json.loads(base64.b64decode(data.encode()).decode())
                except Exception:
                    # Any decoding failure means the payload is malformed.
                    raise ServerException(-2) from None
                if not isinstance(data, dict):
                    # ServerRequest / RequestParser need name -> value pairs.
                    raise ServerException(-2)
                request_log_info = "url: %s, method: %s, request_data: %s" % (
                    request.url, request.method, json.dumps(data))
                user_info = {}
                if check_token:
                    if not token:
                        raise ServerException(1)
                    user_info = get_user_info(token)
                    if user_info is None:
                        raise ServerException(2)
                    logger.info("checking token... %s", user_info)

                request.user_info = user_info
                ServerRequest(data)
                result = {
                    "code": 0,
                    "message": "",
                    "data": func(*args, **kwargs)
                }
            except ServerException as e:
                result = e.to_json()
            except Exception as e:
                if hasattr(e, "to_json"):
                    result = e.to_json()
                else:
                    # logger.exception records the traceback; the previous
                    # traceback.print_exc() returned None, so only "None"
                    # reached the log.
                    logger.exception("unhandled error, %s", request_log_info)
                    result = ServerException(-1).to_json()
            # default=str keeps logging from failing on values json cannot
            # serialize; the returned result itself is untouched.
            logger.info("%s, %s", request_log_info, json.dumps(result, default=str))
            return result

        return inner

    return wrapper


from flask import Flask
from flask_restful import Resource, Api

# Flask application and the flask-restful Api that registers resources on it.
app = Flask(__name__)
api = Api(app)


class Hello(Resource):
    """Demo resource: greets the ``name`` carried in the signed payload."""

    hello_parser = ServerRequestParser()
    hello_parser.add_argument("name", required=True)

    @uniform_verification_wechat(False)
    def get(self):
        """Return ``{'hello': <name>}``; the decorator adds the envelope."""
        parsed = self.hello_parser.parse_args()
        greeting = {'hello': parsed["name"]}
        return greeting


# Serve the demo resource at "/".
api.add_resource(Hello, '/')

if __name__ == "__main__":
    # NOTE: debug=True enables Flask's interactive debugger; keep it to
    # local development only.
    app.run(debug=True)
统一参数解析处理

参考:

  https://blog.tonyseek.com/post/the-context-mechanism-of-flask/

  flask源码

  flask-restful源码

12-10 00:16