机器之心分析师网络
作者:周宇
编辑:H4O
本文重点探讨分布式学习框架中针对随机梯度下降(SGD)算法的拜占庭问题。
分布式学习(Distributed Learning)是一种广泛应用的大规模模型训练框架。在分布式学习框架中,服务器通过聚合在分布式设备中训练的本地模型(local model)来利用各个设备的计算能力。分布式机器学习的典型架构——参数服务器架构中,包括一个服务器(称为参数服务器 - Parameter Server,PS)和多个计算节点(workers,也称为节点 nodes)[1]。其中,随机梯度下降(Stochastic Gradient Descent,SGD)是一种广泛使用的、效果较好的分布式优化算法。在每一轮中,每个计算节点根据不同的本地数据集在它的设备上训练一个本地模型,并与服务器共享最终的参数。然后,服务器聚合不同计算节点的参数,并通过与计算节点共享得到的组合参数来启动下一轮训练。关于基于 SGD 优化的分布式框架的网络结构(包括:层数、类型、大小等)在训练开始之前由所有计算节点共同商定确认。
近年来,分布式学习的安全性越来越受到人们的关注,其中,最重要的就是拜占庭威胁模型。在拜占庭威胁模型中,计算节点可以任意和恶意地行事。机器之心在前期的文章中也探讨过分布式学习中的拜占庭问题,主要针对联邦学习中的拜占庭问题。在这篇文章中,我们重点探讨的是分布式学习框架中针对随机梯度下降(SGD)算法的拜占庭问题。如图 1 所示,在 SGD 学习框架中,一些恶意节点(Malicious worker)向服务器发送拜占庭梯度(Byzantine Gradient),而不是计算得到的真实梯度,而拜占庭梯度可以是任意值。恶意节点可以控制计算节点设备本身,也可以控制节点和服务器之间的通信。以 Algorithm 1 中提出的同步 SGD(sync-SGD)协议为例 [4]。攻击者(恶意节点)在使其效果最大化的时间内(即在 Algorithm 1 的第 6 行和第 7 行之间)干扰进程。在此期间,攻击者可以将节点 i 中的参数(p_i)^(t+1) 替换为任意值,然后将此任意值发送到服务器中。攻击方法在设置参数值的方式上有所不同,而防御方法则试图识别损坏的参数并丢弃它们。Algorithm 1 使用平均值(第 8 行中的 AggregationRule( ))聚合计算节点参数。
图 1. SGD 学习框架工作流程 [3]
本文所讨论的分布式学习的核心是这样一个假设:经过训练的网络参数是独立同分布的(Independent and identically distributed,i.i.d.)