谷歌提出协同机器学习:通过分散的手机更新同一个云端模型

yanglin 2017-04-07

选自Google Blog

作者:Brendan McMahan、Daniel Ramage

机器之心编译

参与:微胖、Ellen Han


标准的机器学习方法要求在一个机器或者数据中心集中训练数据。谷歌已经打造出用于数据处理的最安全、最强健的云基础架构之一。现在,为了通过用户与移动设备交互来训练模型,我们推出了另一种办法:联合学习(federated learning)。

联合学习可以让移动手机协同(collaboratively)学习一个共享的预测模型,与此同时所有训练数据仍保留在设备上,将机器学习与数据储存在云端的需求脱钩。通过在设备上进行模型训练,这一方法超越了对在移动设备端进行预测的本地模型的使用方式(比如移动视觉 API 以及设备智能回复)。

工作原理是这样的:你的设备会下载一个当前模型,通过从手机数据中学习不断改善模型,然后将这些变化总结为一个小的重点更新。只有这一重点更新才以加密方式会被传到云端,在云端,这一更新会迅速被其他用户针对共享模型的更新平均化(averaged)。所有训练数据仍然留在你的设备上,而且个别更新不会存储到云端。

谷歌提出协同机器学习:通过分散的手机更新同一个云端模型

手机在本地个性化模型,以你的手机使用方式为基础(A)。许多用户的更新会集中(B)起来,形成针对共享模型的协同一致的变更(C),然后重复这个过程。

联合学习考虑到了让模型更聪明、延迟更低、更节能而不让隐私受到威胁。而且,这一方法还有一个间接好处:除了提供共享模型的更新,你还能立刻使用改善后的模型,根据你使用手机的方式不同,体验也会不同。

我们目前正在安卓的 Gboard(谷歌键盘)上测试联合学习。当键盘提出一个建议问询时,手机就会在本地存储相关信息,比如当前文本,以及你是否点击了相关建议。联合学习在设备上处理这一过程,并对键盘问询建议的迭代提出改善建议。

谷歌提出协同机器学习:通过分散的手机更新同一个云端模型

为了实现联合学习,我们不得不克服许多算法和技术难题。在一个典型的机器学习系统中,一个优化算法,比如随机梯度下降(SGD)通常运行在一个大型数据组上,这个大型数据组通常在跨服务器均质分区。这类高度迭代算法需要低延迟,高通量的数据连接。但是,在联合学习中,数据分布跨越数百万设备,而这些设备的分布高度不均匀。另外,这些设备还存在明显的更高延迟、更低通量的连接情况,而且间歇适合于训练。

这些带宽和延迟局限性,激发我们研究出联合平均算法(Federated Averaging algorithm (https://arxiv.org/abs/1602.05629)),这一办法仅用少于朴素 SGD 联合版 10 到 100 倍的连接,就能训练深度网络。关键思想就是:利用现代移动设备上的强大处理器来计算更高质量的更新,而不是简单的梯度步。既然更少的高质量更新迭代就能生成一个好的模型,那么,训练所用的连接也要少的多。由于上传速度通常比下载速度慢很多,因此,我们也找到一个新办法来减少上传连接成本,通过使用随机轮转及量化(random rotations and quantization)来压缩更新,成本压缩了 100 倍。虽然这些方法聚焦的是训练深度神经网络,但是,我们也设计了用于高维解析凸模型的算法,更擅长解决诸如预测点击到达率之类的问题。

将这一技术部署到数以百万计、使用了谷歌键盘的异构手机上,需要成熟的技术堆栈。设备训练使用了一个迷你版的 TF。仔细安排日程能确保训练仅仅发生在设备闲置、处在插电状态时,并且使用的是无线网路,因此,对手机用户体验没啥影响。

谷歌提出协同机器学习:通过分散的手机更新同一个云端模型

仅当手机不会对你的体验产生负面影响时,它才会参与到联合学习中去。

然后,系统需要以一种安全、高效、可扩展以及容错的方式来联通并聚合模型更新。只有将研究与这一基础架构结合起来才能让联合学习发挥作用。

不需要将用户数据存入云端就能进行联合学习,但这还不够。我们已经开发了一个使用密码技术的安全聚合协议(Secure Aggregation protocol (http://eprint.iacr.org/2017/281)),只有当几万或几十万用户参与进来,一个协同服务器才能解密被平均的更新——在平均化前,个体手机更新是无法被监测到的。在用来解决深度网络大小以及真实世界连接限制问题的这类协议中,这是首例。我们也设计了联合平均(Federated Averaging),这样,协同服务器仅需要平均后的更新,就能使用安全聚合协议;不过,协议是通用的,它还能应用到其他问题上。我们正致力于将这一协议推行到产品中,也期待在不久的将来将其部署到联合学习的应用中。

我们的工作仅仅触及了可能性的表层。联合学习无法解决所有机器学习难题(例如,通过在仔细标注过的样本上训练,学会识别不同种类的狗),而且对于许多其他模型来说,必要的训练数据已经存储于云端(例如,为 Gmail 训练垃圾邮件过滤器)。所以,谷歌将继续推进最新的基于云的机器学习研究,但是,我们也承诺继续研究扩大联合学习解决问题的范围。比如,除了谷歌键盘问询建议,我们希望根据你在手机上的真实输入,改善驱动键盘的语言模型,以及根据人们查看、分享以及删除的图片内容来改善照片排列。

应用联合学习需要机器学习实践者采用新工具和新的思维方式:在无法直接接触或标记初始数据,通信成本有限(communication cost)的情况下,进行模型研发、训练以及评估。我们相信,联合学习为用户带来的好处让解决技术挑战是有价值的,我们也满怀与机器学习研究社区进行广泛对话的希望发表自己的研究成果。

谷歌提出协同机器学习:通过分散的手机更新同一个云端模型

相关推荐