pdf

Categorical Reparameterization with Gumbel-Softmax

这篇文章的内容已经固化为了pytorch的一个函数

其作用是允许 Stochastic, Differentiable, Probabilistic Weighted, Indexing.

先用代码解释 gumbel采样可以如何用均匀随机采样表达:

def gumbel(*shape):
    u = np.random.rand(*shape)
    return -np.log(-np.log(u))

def gumbelsoftmax(weights, lmbda=1, N=10000):
    d = len(weights)
    logits = np.log(weights.reshape(d, 1))
    gumbel_noise = gumbel(d*N).reshape(d, N)
    return softmax(( logits + gumbel_noise)/ lmbda, axis=0)

通过采样 gumbelsoftmax,得到的分布近似于

其中变量理解为温度超参,其作用在于控制采样系统的随机性.

pytorch functional 的代码如下(不必复制使用,这内置于pytorch中):

def gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1):
    # type: (Tensor, float, bool, float, int) -> Tensor
    r"""
    Examples::
        >>> logits = torch.randn(20, 32)
        >>> # Sample soft categorical using reparametrization trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=False)
        >>> # Sample hard categorical using "Straight-through" trick:
        >>> F.gumbel_softmax(logits, tau=1, hard=True)
    """

    if eps != 1e-10:
        warnings.warn("`eps` parameter is deprecated and has no effect.")

    gumbels = -torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_().log()  # ~Gumbel(0,1)
    gumbels = (logits + gumbels) / tau  # ~Gumbel(logits,tau)
    y_soft = gumbels.softmax(dim)

    if hard:
        # Straight through.
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        # Reparametrization trick.
        ret = y_soft
    return ret

值得注意的是这里使用了

ret = y_hard - y_soft.detach() + y_soft

这一个trick使得one-hot的在forward时是indexing量,但是backward的时候用的是的梯度。

这篇文章与很多其他内容相关,比如在强化学习中作为一个可以exploit又可以explore的hard indexing. 在网络剪枝中可以作为一个可以学习使用的参量。

比较奇妙的是这篇paper在ICLR发布的时候只是marginally accepted,也只是poster,主要是可能原作者以及Reviewer当时只是在考虑使用在Generative Model上,提升没有那么显著,而没有预知这么多后来的应用。不过后来大家对它的引用以及发挥是很巨大的.