Tensorflow:作为占位符的元组列表

tquod 发布于 2019-10-09 tensorflow 最后更新 2019-10-09 22:59 10 浏览

我想使用compute_gradients并生成本地渐变。这些梯度将与来自其他机器的多个本地梯度平均,之后会调用apply_gradients。我在第二个接受渐变的session.runs中使用了PLACEHOLDER_FOR_CODE_2。由于apply_gradients需要元组列表,因此我正在寻找一种有效的方法来完成此操作。 这就是我生成元组占位符列表的方法:

grads  = cifar10.train_part1(loss, global_step)
xx = [tf.placeholder(tf.float32, shape=grads[0][0].shape) for i in range(10)]
yy = [tf.placeholder(tf.float32, shape=grads[0][0].shape) for i in range(10)]
xyz = zip(xx,yy)
train_op = cifar10.train_part2(loss,global_step, xyz)
我收到以下错误:
NotImplementedError: ('Trying to optimize unsupported type ', tf.Tensor 'Placeholder_10:0' shape=(5, 5, 3, 64) dtype=float32)
已邀请: