sklearn 源码分析系列:neighbors(2)

by DemonSonggithub(https://github.com/demonSong/DML)

我起初一直在纠结是否需要把kd_tree的实现也放在这一篇中讲,如果讲算法实现,就违背了源码分析的初衷,过早钻入细节,是阅读源码的大忌。算法和框架的分析应属两部分内容,所以最终决定,所有sklearn源码分析系列不涉及具体算法,而是保证每个方法调用的连通性,重点关注架构,以及一些必要的python实现细节。

Note:

这篇文章主要分析Neighbors包中的Unsupervised Nearest Neighbors相关接口,对应于官方文档1.6.1章节,详见文档

Finding the Nearest Neighbors实操

详细实操代码可参考Github kaggle项目,详见链接

在实现最近邻算法时,常用的算法有”kd_tree”,”ball_tree”,”brute”三种,它们对应于不同的应用场景,这里不再赘述。

数据生成与可视化

# 1.6.1 Unsupervised Nearest Neighbors

from sklearn.neighbors import NearestNeighbors
import numpy as np
import matplotlib.pyplot as plt


# 1.6.1.1 Finding the Nearest Neighbors
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])

plt.figure()
plt.scatter(X[:,0],X[:,1])
plt.xlim(X[:,0].min()-1,X[:,0].max()+1)
plt.ylim(X[:,1].min()-1,X[:,1].max()+1)
plt.title("Unsupervised nearest neighbors")
plt.show()

# k个最近的点中包含自己
nbrs = NearestNeighbors(n_neighbors=3, algorithm='ball_tree').fit(X)

distances,indices = nbrs.kneighbors(X)

# k个最近点的下标,按升序排列
indices

alt text

输出:

array([[0, 1, 2],
       [1, 0, 2],
       [2, 1, 0],
       [3, 4, 5],
       [4, 3, 5],
       [5, 4, 3]], dtype=int64)
# k个最近点的最短距离,按升序排列
distances

Out[2]:
array([[ 0.        ,  1.        ,  2.23606798],
       [ 0.        ,  1.        ,  1.41421356],
       [ 0.        ,  1.41421356,  2.23606798],
       [ 0.        ,  1.        ,  2.23606798],
       [ 0.        ,  1.        ,  1.41421356],
       [ 0.        ,  1.41421356,  2.23606798]])

kneighbors(X)默认返回两个参数,其中k个最近邻中还包含了自己,距离和下标均按照升序排列。

# k个最近点生成的邻接矩阵
nbrs.kneighbors_graph(X).toarray()

Out [3]:
array([[ 1.,  1.,  1.,  0.,  0.,  0.],
       [ 1.,  1.,  1.,  0.,  0.,  0.],
       [ 1.,  1.,  1.,  0.,  0.,  0.],
       [ 0.,  0.,  0.,  1.,  1.,  1.],
       [ 0.,  0.,  0.,  1.,  1.,  1.],
       [ 0.,  0.,  0.,  1.,  1.,  1.]])

# 1.6.1.2 KD Tree and Ball Tree Classes
from sklearn.neighbors import KDTree
import numpy as np

# 可直接用KDtree实现最近邻查找
X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
kdt = KDTree(X, leaf_size=30, metric='euclidean')
kdt.query(X,k = 3,return_distance = False)

Out [4]:
array([[0, 1, 2],
       [1, 0, 2],
       [2, 1, 0],
       [3, 4, 5],
       [4, 3, 5],
       [5, 4, 3]], dtype=int64)

源码剖析

我们先从整体上来看看,实现NearestNeighbors所需关联到的python文件及对应的文件结构是什么样子的。

alt text

相比于Neighbors(1)中的内容,它多了unsupervised.py文件而已。所以,我们直接顺藤摸瓜开始分析。

unsupervised.py

class NearestNeighbors(NeighborsBase, KNeighborsMixin,
                       RadiusNeighborsMixin, UnsupervisedMixin):
    def __init__(self, n_neighbors=5, radius=1.0,
                     algorithm='auto', leaf_size=30, metric='minkowski',
                     p=2, metric_params=None, n_jobs=1, **kwargs):
            self._init_params(n_neighbors=n_neighbors,
                              radius=radius,
                              algorithm=algorithm,
                              leaf_size=leaf_size, metric=metric, p=p,
                              metric_params=metric_params, n_jobs=n_jobs, **kwargs)

这是一个明显的子类继承多个父类的情况,其中KNeighborsMixinRadiusNeighborsMixin属于功能相同,但具体实现细节有所差异,只单独分析一例。

先来看看它的构造方法吧,构造方法中传入了,9个参数,都是带默认值的。但令人奇怪的是,它同样是空有型而无内容的【初始化类】,该类只与客户端打交道,而真正的参数初始化都交给了其中的某个父类的__init__params()方法。为什么要这么做?不急,先看看到底是哪个父类完成了参数初始化。

所有父类集中在neighbors包下的base.py文件中。
alt text

经过一番寻找总算找到了初始化参数方法,在类neighborsBase

class NeighborsBase(six.with_metaclass(ABCMeta, BaseEstimator)):
    """Base class for nearest neighbors estimators."""

    @abstractmethod
    def __init__(self):
        pass

    def _init_params(self, n_neighbors=None, radius=None,
                     algorithm='auto', leaf_size=30, metric='minkowski',
                     p=2, metric_params=None, n_jobs=1):

        self.n_neighbors = n_neighbors
        self.radius = radius
        self.algorithm = algorithm
        self.leaf_size = leaf_size
        self.metric = metric
        self.metric_params = metric_params
        self.p = p
        self.n_jobs = n_jobs

        if algorithm not in ['auto', 'brute',
                             'kd_tree', 'ball_tree']:
            raise ValueError("unrecognized algorithm: '%s'" % algorithm)

        if algorithm == 'auto':
            if metric == 'precomputed':
                alg_check = 'brute'
            else:
                alg_check = 'ball_tree'
        else:
            alg_check = algorithm

        if callable(metric):
            if algorithm == 'kd_tree':
                # callable metric is only valid for brute force and ball_tree
                raise ValueError(
                    "kd_tree algorithm does not support callable metric '%s'"
                    % metric)
        elif metric not in VALID_METRICS[alg_check]:
            raise ValueError("Metric '%s' not valid for algorithm '%s'"
                             % (metric, algorithm))

        if self.metric_params is not None and 'p' in self.metric_params:
            warnings.warn("Parameter p is found in metric_params. "
                          "The corresponding parameter from __init__ "
                          "is ignored.", SyntaxWarning, stacklevel=3)
            effective_p = metric_params['p']
        else:
            effective_p = self.p

        if self.metric in ['wminkowski', 'minkowski'] and effective_p < 1:
            raise ValueError("p must be greater than one for minkowski metric")

        # 重点关注
        self._fit_X = None
        self._tree = None
        self._fit_method = None

喔,原来NeighborsBase是要作为整个Neighbors最具领导力的类?起码这家伙拿到了全局信息吧,我的一个猜测是,除了unsupervised需要用到这些参数之外,其他类也同样需要用这些参数做些有趣的事吧?所以既然大家都要复用这些参数!那就放在一个基类中吧,此处就叫NeighborsBase吧。(待检验)

我们关注下方法本身中的参数:
1. self.n_neighbors = n_neighbors ## k近邻中的k
2. self.radius = radius ## 不知
3. self.algorithm = algorithm ## 使用何种k近邻算法,如’kd_tree’
4. self.leaf_size = leaf_size ## 生成’kd_tree’树需要传入的参数
5. self.metric = metric ## 计算其他各种形式的两点间距离
6. self.metric_params = metric_params ## 不知
7. self.p = p ## 不知
8. self.n_jobs = n_jobs ## 并发创建的线程数

除此之外,在初始化最后,还占了三个位:
1. self._fit_X = None ## fit_X 和传入的X之间有何关系?
2. self._tree = None ## _tree表示返回的树结构
3. self._fit_method = None ## fit传入的算法

NeighborsBase就这些内容,它还有一个_fit()方法,稍后分析。总的来说,当客户端调用诸如nbrs = NearestNeighbors(n_neighbors=3, algorithm='kd_tree',leaf_size=30)的构造方法时,NearestNeighbors什么都没做,把参数初始化任务交给了它的父类NeighborsBase(该小组的老大!),而这老大具体也没做什么具体的事,把该初始化的参数初始化,并做一些参数合法性的检查,完工。

模型参数初始完毕之后,自然到了fit步骤,正如,客户端调用那样nbrs = NearestNeighbors(n_neighbors=3, algorithm='kd_tree',leaf_size=30).fit(X)我把数据X,传给了谁?谁来拟合这些数据呢?

记得NearestNeighbors中的几个父类吧,完成fit操作的是UnsupervisedMixin类,接着来看看它的代码。

class UnsupervisedMixin(object):
    def fit(self, X, y=None):
        """Fit the model using X as training data

        Parameters
        ----------
        X : {array-like, sparse matrix, BallTree, KDTree}
            Training data. If array or matrix, shape [n_samples, n_features],
            or [n_samples, n_samples] if metric='precomputed'.
        """
        return self._fit(X)

非常简短,针对非监督的数据,全部交给了自己的self._fit(X)方法,所以它又是个代理类?这个代理类更狠,什么都没做,直接转交给NearestNeighbors中的某个父类来完成。调用_fit()方法后,就又回到了NeighborsBase中去了,所以当客户端要调用fit方法时,先交给了NeighborsBase的手下UnsupervisedMixin做一些前期的处理操作,但这手下学会了偷懒,什么都没做直接交给了领导,直接让领导来处理咯,真坏。那领导真的有功夫,有能力处理这个fit任务?领导也不傻,我们看看领导怎么做的。

def _fit(self, X):

        ......
        # 做些必要的检查
        X = check_array(X, accept_sparse='csr')

        # 还是在做检查
        n_samples = X.shape[0]
        if n_samples == 0:
            raise ValueError("n_samples must be greater than 0")
        ......

        #前面占的位子给补上
        self._fit_method = self.algorithm
        self._fit_X = X

        ......

        # 嘿,领导开始派发任务了
        if self._fit_method == 'ball_tree':
            self._tree = BallTree(X, self.leaf_size,
                                  metric=self.effective_metric_,
                                  **self.effective_metric_params_)
        # 看到了熟悉的kd_tree了                       
        elif self._fit_method == 'kd_tree':
            self._tree = KDTree(X, self.leaf_size,
                                metric=self.effective_metric_,
                                **self.effective_metric_params_)
        elif self._fit_method == 'brute':
            self._tree = None
        else:
            raise ValueError("algorithm = '%s' not recognized"
                             % self.algorithm)

        # 检查,为什么不放在一开始做?
        if self.n_neighbors is not None:
            if self.n_neighbors <= 0:
                raise ValueError(
                    "Expected n_neighbors > 0. Got %d" %
                    self.n_neighbors
                )

        return self

唉,领导也没有干活啊,做了一些检查,根据来的参数,交给对应的具体执行者去做!但返回的还是自己,因为我要和客户端打交道。我们来分析下具体的执行者做了些什么操作。看如下代码,

elif self._fit_method == 'kd_tree':
            self._tree = KDTree(X, self.leaf_size,                               metric=self.effective_metric_,                          **self.effective_metric_params_)

NeighborsBasefit()方法中,并不是返回某个模型对象,而是把模型对象内嵌到了NeighborsBase中的self._tree中去,这是为什么?kd_tree模型本身有查询最近邻的方法,为什么不直接暴露给客户端呢?在这里我并不理解它这样做的用意是什么。(待解决)

所以对于数据真正的fit()是交给具体算法来完成的,咱们接下来就看看kd_tree.py吧。关于kd_tree的算法细节,可以参考之前我的一篇博文【K近邻法学习笔记】。关于sklearn中kd_tree的具体分析,不作为本文内容,日后单独开辟一章来讲解。本文重点关注各接口的实现与内在联系。

alt text

所以当NeighborsBase构造了kd_tree时,就调用了它的构造方法,走。

def __init__(self, data, leafsize=10):
        self.data = np.asarray(data)
        self.n, self.m = np.shape(self.data)
        self.leafsize = int(leafsize)
        if self.leafsize < 1:
            raise ValueError("leafsize must be at least 1")
        self.maxes = np.amax(self.data,axis=0)
        self.mins = np.amin(self.data,axis=0)

        # 关键步骤
        self.tree = self.__build(np.arange(self.n), self.maxes, self.mins)

前面也是做了一些初始化操作,接着开始构建kd_tree的数据结构了。调用__build()方法,由传入的数据的生成了对应的数据结构。到这里,数据到结构的映射完成了。

Created with Raphaël 2.1.0 数据X到结构的映射 Client Client NearestNeighbors NearestNeighbors NeighborsBase NeighborsBase UnsupervisedMixin UnsupervisedMixin KDTree KDTree __init__() _init_params() fit() _fit() __init__() _build()

总结下,NearsetNeighbors和客户端打交到,而NeighborsBase统筹规划所有调度。

既然有了数据X到结构的映射,那自然要做真正的查询操作了(k近邻查询),我们继续来看看,客户端调用如下distances,indices = nbrs.kneighbors(X),在NearestNeighbors中只要初始化方法,并没有kneighbors(X)方法,该方法在它的另外一个父类KNeighborsMixin中。

class KNeighborsMixin(object):
    def kneighbors(self, X=None, n_neighbors=None, return_distance=True):
    ......
     n_samples, _ = X.shape
        sample_range = np.arange(n_samples)[:, None]
    ......
    elif self._fit_method in ['ball_tree', 'kd_tree']:
            if issparse(X):
                raise ValueError(
                    "%s does not work with sparse matrices. Densify the data, "
                    "or set algorithm='brute'" % self._fit_method)
            result = Parallel(n_jobs, backend='threading')(
                delayed(self._tree.query, check_pickle=False)(
                    X[s], n_neighbors, return_distance)
                for s in gen_even_slices(X.shape[0], n_jobs)
            )
            if return_distance:
                dist, neigh_ind = tuple(zip(*result))
                result = np.vstack(dist), np.vstack(neigh_ind)
            else:
                result = np.vstack(result)

很多东西都可以忽略不看,只需要关注一行代码就可以了。

result = Parallel(n_jobs, backend='threading')(
                delayed(self._tree.query, check_pickle=False)(
                    X[s], n_neighbors, return_distance)
                for s in gen_even_slices(X.shape[0], n_jobs)
            )

前面它包了一个并发的类,咱们不去研究,在delay方法中,传入了self._tree.query这是一个方法名,在之前KDTree类的接口中,有相应的实现,也就是说KNeighborsMixin类也不做任何查询操作,同样把查询交给了KDTree来完成,的确如此,只有KDTree中存放了相应的数据结构,不是它做查询谁来做查询,KNeighborsMixin只是简单的把KDTree返回的查询结果交给客户端就可以了,别无其他。

Created with Raphaël 2.1.0 查询结果的返回过程 Client Client NearestNeighbors NearestNeighbors KNeighborsMixin KNeighborsMixin KDTree KDTree k近邻查询 kneighbors(X) query(X) 查询结果 查询结果

综上,整个关于数据X到kd_tree的结构映射调用就完成了,也没有太多东西,理清各个类之间的关系就可以了。同样的,当要进行k近邻查询时,交给了NearestNeighbors中的父类KNeighborsMixin来代理查询,真正的查询操作还是kd_tree来完成,前期都是些琐碎的调用流程,而算法的核心在于kd_tree,起码数据在到kd_tree之前,能够做很多前期处理,保证了算法对数据的要求。看来是时候研究下kd_tree的核心算法了。

Logo

旨在为数千万中国开发者提供一个无缝且高效的云端环境,以支持学习、使用和贡献开源项目。

更多推荐