hell0kitty 2020-06-17
纸上得来终觉浅,绝知此事要躬行。
之前在【Django】DRF源码分析之五大模块文章中没有讲到认证模块,本章就主要来谈谈认证模块中的三大认证,首先我们先回顾一下DRF请求的流程:
urls.py
中的url匹配,执行对应类视图调用as_view()
方法from django.conf.urls import url from . import views urlpatterns = [ url(r‘^v1/users/$‘, views.User.as_view()) ]
APIView
中调用父类as_view()
,并且在闭包中调用了dispatch()
方法,该方法调用的APIView
类中的(该类重写了父类)def dispatch(self, request, *args, **kwargs): ...... # 请求模块和解析模块 request = self.initialize_request(request, *args, **kwargs) ...... try: # 三大认证模块 self.initial(request, *args, **kwargs) ...... # 响应模块 response = handler(request, *args, **kwargs) except Exception as exc: # 异常模块 response = self.handle_exception(exc) # 渲染模块 self.response = self.finalize_response(request, response, *args, **kwargs) return self.response
self.initialize_request(request, *args, **kwargs)
请求模块,此步骤是rest_framework
对request进行了扩展封装和兼容def initialize_request(self, request, *args, **kwargs): """ Returns the initial request object. """ parser_context = self.get_parser_context(request) return Request( request, parsers=self.get_parsers(), authenticators=self.get_authenticators(), negotiator=self.get_content_negotiator(), parser_context=parser_context )
self.initial(request, *args, **kwargs)
,点击源码进入。def initial(self, request, *args, **kwargs): ...... # 认证组件 self.perform_authentication(request) # 权限组件 self.check_permissions(request) # 限流组件 self.check_throttles(request)
目前请求走到这里,就是我们今天需要讨论的认证模块,分为三个部分,以此来从源码进行剖析。
首先执行的就是self.perform_authentication(request)
(认证组件),点击源码查看:
def perform_authentication(self, request): request.user
我们发现该方法只有一行代码,没有返回值,也没有赋值,也不能继续点击进入(可能会出现一堆的东西),但是我们的目的就是找认证组件认证方法,所以猜想这句话就是调用方法,是不是很可能被@property
装饰了,还是通过request
对象调用的,所以我们就rest_framework/request.py
下面找request
类(因为之前的请求模块对原生request进行了扩展就是使用的该类)中的user
方法,发现源码如下:
@property def user(self): """ Returns the user associated with the current request, as authenticated by the authentication classes provided to the request. """ if not hasattr(self, ‘_user‘): with wrap_attributeerrors(): # 没用户,认证用户 self._authenticate() # 有用户,直接返回 return self._user
发现对于认证的函数只调用了self._authenticate()
,我们继续点击进入,源码分析如下图:
结合上图我们需要分析出下面几个问题:
先解决第一个问题(self.authenticators是啥?
)我们直接点击进去发现跑到了request
类的__init__
方法肯定不对,我们往回找,他是request的属性,记得之前请求模块对request进行了扩展,回去发现在APIView
类的下面有self.initialize_request(request, *args, **kwargs)
方法中有下面的代码:
Request( request, parsers=self.get_parsers(), authenticators=self.get_authenticators(), negotiator=self.get_content_negotiator(), parser_context=parser_context )
发现传入了authenticators
,他等于self.get_authenticators()
的返回值,所以我们去查找self.get_authenticators()
的源码:
def get_authenticators(self): """ Instantiates and returns the list of authenticators that this view can use. """ return [auth() for auth in self.authentication_classes]
发现结果是一个列表推导式,所以上图中可以进行遍历,而且列表中装的也都是对象,我们就去看看到底是什么类的对象,点击查看authentication_classes
,发现是通过authentication_classes = api_settings.DEFAULT_AUTHENTICATION_CLASSES
配置的,所以我们可以去api_settings
查找,结果代码如下:
‘DEFAULT_AUTHENTICATION_CLASSES‘: [ ‘rest_framework.authentication.SessionAuthentication‘, ‘rest_framework.authentication.BasicAuthentication‘ ],
默认写了两个类,ok,目前我们已经知道第一个问题(self.authenticators是啥?
),他默认就是这两个类的对象列表,下面就是解决第二个问题(authenticate(self)方法执行了什么玩意?
),该方法是这两个类的方法,所以我们去这两个默认的类去查看,源码查看文件rest_framework/authentication.py
,下面是UML图以及类继承图:
由此我们发现BaseAuthentication
是其他类的父类,而且每个类都有authenticate
方法,我们首先查看一下BasicAuthentication
类中的authenticate
方法实现:
def get_authorization_header(request): """ Return request‘s ‘Authorization:‘ header, as a bytestring. Hide some test client ickyness where the header can be unicode. """ auth = request.META.get(‘HTTP_AUTHORIZATION‘, b‘‘) if isinstance(auth, str): # Work around django test client oddness auth = auth.encode(HTTP_HEADER_ENCODING) return auth class BasicAuthentication(BaseAuthentication): """ HTTP Basic authentication against username/password. """ www_authenticate_realm = ‘api‘ def authenticate(self, request): """ Returns a `User` if a correct username and password have been supplied using HTTP Basic authentication. Otherwise returns `None`. """ auth = get_authorization_header(request).split() # 第一步:从请求头获取token信息按照空格分割 if not auth or auth[0].lower() != b‘basic‘: # 第二步:判断我们的值格式:“basic xxxxxxxx”,就是有两段,中间空格隔开 return None # 校验分割长度是不是等于2 if len(auth) == 1: msg = _(‘Invalid basic header. No credentials provided.‘) raise exceptions.AuthenticationFailed(msg) elif len(auth) > 2: msg = _(‘Invalid basic header. Credentials string should not contain spaces.‘) raise exceptions.AuthenticationFailed(msg) # 把token值按照一定规则解密 try: auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(‘:‘) except (TypeError, UnicodeDecodeError, binascii.Error): msg = _(‘Invalid basic header. Credentials not correctly base64 encoded.‘) raise exceptions.AuthenticationFailed(msg) userid, password = auth_parts[0], auth_parts[2] return self.authenticate_credentials(userid, password, request) def authenticate_credentials(self, userid, password, request=None): """ Authenticate the userid and password against username and password with optional request for context. """ credentials = { get_user_model().USERNAME_FIELD: userid, ‘password‘: password } user = authenticate(request=request, **credentials) if user is None: raise exceptions.AuthenticationFailed(_(‘Invalid username/password.‘)) if not user.is_active: raise exceptions.AuthenticationFailed(_(‘User inactive or deleted.‘)) return (user, None) def authenticate_header(self, request): return ‘Basic realm="%s"‘ % self.www_authenticate_realm
分析过程:
1. 调用get_authorization_header
从请求头获取,Authorization 的值,一般就是token信息,并且按照空格分割
2. 分割完成,判断我们的第一部分是不是basic
3. 校验分割长度是不是等于2
4. 把token值按照一定规则解密,分配
5. 调用self.authenticate_credentials(userid, password, request)
,可以看成通过解密的信息查询用户,最终返回元祖类型的数据
目第二个问题也已经解决,得知返回的结果是一个(user,None)的元祖,然后把元祖信息拆分给request.user
和request.auth
。如果其中任意一个地方发生异常都会调用self._not_authenticated()
,下面我们就来看看第三个问题(self._not_authenticated()
方法做了啥?)
def _not_authenticated(self): """ Set authenticator, user & authtoken representing an unauthenticated request. Defaults are None, AnonymousUser & None. """ self._authenticator = None if api_settings.UNAUTHENTICATED_USER: self.user = api_settings.UNAUTHENTICATED_USER() else: self.user = None if api_settings.UNAUTHENTICATED_TOKEN: self.auth = api_settings.UNAUTHENTICATED_TOKEN() else: self.auth = None
源码其实很简单,就是给self.user
和self.auth
赋值,其实就相当于给request.user
和request.auth
赋值。其中api_settings.UNAUTHENTICATED_USER()
表示的是一个匿名用户也可以理解为游客,而api_settings.UNAUTHENTICATED_TOKEN()
,默认值为None。可以在api_settings
中查看‘UNAUTHENTICATED_USER‘: ‘django.contrib.auth.models.AnonymousUser‘
,‘UNAUTHENTICATED_TOKEN‘: None,
。
整个认证的过程分析完成,我们可以知道大致流程就是:
_not_authenticated
方法,赋值为request.user和request.auth_not_authenticated
方法,赋值给request.user和request.auth也即是说我们可以通过在类视图的request对象直接获取当前访问的用户,判断他是登录用户还是游客。
通过源码的分析,我们可以知道实现自定义认证类必要条件,继承BaseAuthentication
,然后实现authenticate
方法,至于验证的逻辑可以结合业务编写,最终返回(user,auth)的元祖
#继承BaseAuthentication class MyAuthentication(BaseAuthentication): def authenticate(self, request): #重写authenticate方法 # 1. 从请求的META获取token信息 # 2. 判断信息是否合法或这不存在 # 2.1 不存在:表示游客,返回None # 2.2 存在但是错误:非法用户,抛出异常 # 2.3 存在且正确:返回 (用户, 认证信息) return (user,None)
REST_FRAMEWORK = { # 认证类配置 ‘DEFAULT_AUTHENTICATION_CLASSES‘: [ ‘rest_framework.authentication.SessionAuthentication‘, ‘rest_framework.authentication.BasicAuthentication‘, ‘xxxx.xxxxxx.MyAuthentication‘ # eg:‘utils.authentications.MyAuthentication‘ ] }
def xxxx(APIView): authentication_classes = (MyAuthentication,SessionAuthentication,BasicAuthentication) ...... def get(): ......
经过认证组件之后我们知道request对象中保存这当前请求的用户,下面执行self.check_permissions(request)
方法,点击进入源码
def check_permissions(self, request): """ Check if the request should be permitted. Raises an appropriate exception if the request is not permitted. """ for permission in self.get_permissions(): if not permission.has_permission(request, self): self.permission_denied( request, message=getattr(permission, ‘message‘, None) )
看到这个过程简直是似曾相识,对,他和认证组件一个设计模式,通过self.get_permissions()
获取权限类的对象列表,然后遍历,源码:
def get_permissions(self): """ Instantiates and returns the list of permissions that this view requires. """ return [permission() for permission in self.permission_classes]
发现同样是一个列表推导式,查看源码发现他和认证组件就是放在一起,接着点击self.permission_classes
,同样发现permission_classes = api_settings.DEFAULT_PERMISSION_CLASSES
也是由api_settings
配置得到,从rest_framework/settings.py
文件找到得到默认配置。
‘DEFAULT_PERMISSION_CLASSES‘: [ ‘rest_framework.permissions.AllowAny‘, ],
默认配置了一个AllowAny
,此时程序开始遍历权限类的对象,执行has_permission
方法的到返回值,如果为False表示他没有权限,继续执行self.permission_denied()
方法直接抛出异常,为True则遍历下一个,直到全部为True,遍历结束,什么也不做就表示拥有配置的所有权限。
所以首先,让我们去了解has_permission
到底做了什么?以及系统默认包含了那些认证类,通过rest_framework/permissions.py
可以查看所有的权限类
发现大致分为以下几个类,BasePermission
类是其他类的父类,而且每个类都实现了has_permission
方法。
class BasePermission(metaclass=BasePermissionMetaclass): """ A base class from which all permission classes should inherit. """ def has_permission(self, request, view): """ Return `True` if permission is granted, `False` otherwise. """ return True def has_object_permission(self, request, view, obj): """ Return `True` if permission is granted, `False` otherwise. """ return True class AllowAny(BasePermission): """ Allow any access. This isn‘t strictly required, since you could use an empty permission_classes list, but it‘s useful because it makes the intention more explicit. """ def has_permission(self, request, view): return True class IsAuthenticated(BasePermission): """ Allows access only to authenticated users. """ def has_permission(self, request, view): return bool(request.user and request.user.is_authenticated) class IsAdminUser(BasePermission): """ Allows access only to admin users. """ def has_permission(self, request, view): return bool(request.user and request.user.is_staff) class IsAuthenticatedOrReadOnly(BasePermission): """ The request is authenticated as a user, or is a read-only request. """ def has_permission(self, request, view): return bool( request.method in SAFE_METHODS or request.user and request.user.is_authenticated )
接下来分别解释一个每个类:
AllowAny
:直接返回True,任何用户拥有权限IsAuthenticated
:必须是认证信息通过的用户IsAdminUser
:必须是认证信息通过的用户且is_staff
为True的用户,数据库保存的结果为1IsAuthenticatedOrReadOnly
:表示通过认证用户拥有权限或者游客以及认证失败的用户只能有SAFE_METHODS
属性内定义的请求方法,默认为SAFE_METHODS = (‘GET‘, ‘HEAD‘, ‘OPTIONS‘)
权限组件相对过程比较简单,因为他是建立在认证组件基础之上,下面就让我们自定义权限组件。
通过源码的分析,我们可以同样也知道实现自定义权限类必要条件,继承BasePermission
,然后实现has_permission
方法,最终通过判断返回True或False。
from rest_framework.permissions import BasePermission class MyPermission(BasePermission): def has_permission(self, request, view): # 判断逻辑xxxxxxx # 返回True或False return True or Flase
REST_FRAMEWORK = { # 权限类配置 ‘DEFAULT_PERMISSION_CLASSES‘: [ ‘utils.permissions.MyPermission‘, ], }
def xxxx(APIView): permission_classes = (MyPermission,) ..... def get(): ....
前面的认证和权限组件处理完成之后接下来就是限流组件,代码运行到self.check_throttles(request)
,点击查看源码:
def check_throttles(self, request): """ Check if request should be throttled. Raises an appropriate exception if the request is throttled. """ throttle_durations = [] for throttle in self.get_throttles(): if not throttle.allow_request(request, self): throttle_durations.append(throttle.wait()) if throttle_durations: # Filter out `None` values which may happen in case of config / rate # changes, see #1438 durations = [ duration for duration in throttle_durations if duration is not None ] duration = max(durations, default=None) self.throttled(request, duration)
首先定义了一个throttle_durations
空列表,之后又是循环遍历self.get_throttles()
,可以想象他和认证组件、权限组件应该是一个样子,返回限流类对象列表,源码如下:
def get_throttles(self): """ Instantiates and returns the list of throttles that this view uses. """ return [throttle() for throttle in self.throttle_classes]
果然是一个列表推导式,保存的是限流类对象,同样我们也会想到它应该也是通过api_settings
配置,点击self.throttle_classes
查看throttle_classes = api_settings.DEFAULT_THROTTLE_CLASSES
,继续查看默认配置信息。
‘DEFAULT_THROTTLE_CLASSES‘: [],
发现结果是一个空列表,意思也就是,默认并没有采用任何一个类来限制用户请求频率。通过认证的类定义在rest_frameworks/authentication.py
和权限类定义在rest_framework/permission.py
,我们应该在rest_framework
下面查找类似throttle
类,即rest_frameworks/throttling.py
,并且通过源码应该不难发现他们应该都实现了allow_request
方法。
通过类的继承关系我们发现