
Sending Samples Without Bits-Back

Posted on 2020/03/08

\gdef\logpinv{\log p^{-1}} \gdef\klqp{\mathrm{KL}[q,p]} \gdef\paccz{p(\mathrm{accept}\ z)} \gdef\pacczn{p(\mathrm{accept}\ z^n)} \gdef\logfracn{\log \frac{q^n(z^n)}{p^n(z^n)}} \gdef\qtil{\tilde{q}} \gdef\minv{\frac{1}{M}} \gdef\Pr{\operatorname{Pr}}

Intro

I’ll describe a fun little problem in information theory, and a solution–a compression algorithm–based on rejection sampling. This problem is motivated by the coding interpretation of the variational bound, which seems to be a valuable source of intuition. The compression algorithm I describe is distinct from the well-known idea of bits-back coding and gives a more direct interpretation of the variational bound objective. It’s terribly computationally inefficient, but I think it’s interesting as a proof of principle.


The Problem. Alice and Bob initially agree on a “prior” distribution p(z), and they have a shared random number generator (RNG). Later, Alice is given a different distribution q(z). How long of a message does Alice need to send to Bob so that by combining the message with the RNG, he can produce a sample z \sim q(z)?

More precisely, Alice and Bob agree on a deterministic function f(\omega, m) of the RNG state \omega and the message m. Alice computes the message m as a function of q, and the distribution of f(\omega, m) must equal q(z).


Let’s analyze the problem for a couple of simple cases:

  1. If q(z)=p(z), then the message length is zero: Alice can just tell Bob to take his first sample from p.
  2. If q(z)=I[z=z_0], i.e. an indicator on one value z_0, then the best Alice can do is to send z_0 with message length \logpinv(z_0).
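
One way to make the setup concrete is as a testable harness: a protocol is a pair of functions, and both sides build their RNG from the same shared state \omega. The API below (the name check_protocol, the argument conventions, and the toy prior) is purely illustrative; nothing in the problem statement fixes it. Case 1 above serves as a sanity check.

```python
import numpy as np

def check_protocol(alice, bob, q, trials=100_000):
    """Empirically check that Bob's output is distributed as q.

    alice(q, rng) -> message and bob(message, rng) -> z both receive an RNG
    built from the same shared state omega; the prior p is assumed to be
    baked into both functions.
    """
    counts = np.zeros(len(q))
    for omega in range(trials):                      # a fresh shared state each round
        m = alice(q, np.random.default_rng(omega))   # Alice sees q and the shared RNG
        z = bob(m, np.random.default_rng(omega))     # Bob sees only m and the shared RNG
        counts[z] += 1
    return counts / trials                           # should be close to q

# Sanity check of case 1 above (q = p): empty message, Bob takes his first sample from p.
p = np.array([0.5, 0.3, 0.2])
print(check_protocol(alice=lambda q, rng: None,
                     bob=lambda m, rng: rng.choice(len(p), p=p),
                     q=p))
```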

Note that this problem is subtly different from the problem where Alice samples z \sim q (using a non-shared RNG) and then must send it to Bob. If Alice samples z\sim q and sends it to Bob using the code derived from p(z), it requires an expected code-length of E_{q}[\logpinv(z)], which is more bits than necessary. In particular, it would take S[p] bits (the entropy of p) in the case that q(z)=p(z), rather than zero.

You might guess that the general answer is \klqp, which gives the correct answer in examples (1) and (2) above. That guess is correct! I’ll prove it below (after adding some pesky details). But first, I’ll explain the motivation for this communication problem.

Variational Upper Bound

Many concepts of probability have a corresponding interpretation in terms of codes and compression. A key idea, needed for models with latent variables (like the variational autoencoder (VAE)), is the variational upper bound (VUB). It’s usually called the variational lower bound, but we’ll flip the sign so it corresponds to code-length.

The VUB is the objective used for fitting probabilistic models with latent variables. Given a model p(x,z)=p(z)p(x|z), we typically want to maximize \log p(x), but there’s an intractable sum over z. The VUB introduces a sampling distribution q(z), giving an upper bound on the log-loss: \logpinv(x) \le \underbrace{\klqp}_{(*)} + \underbrace{E_{z \sim q}[\logpinv(x|z)]}_{(**)} Equality occurs at q(z)=p(z|x), and training (e.g., for the VAE) involves jointly minimizing the RHS with respect to p and q.
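
To see where the bound comes from, write p(z)p(x|z)=p(x)p(z|x) and expand the RHS: \klqp + E_{z \sim q}[\logpinv(x|z)] = E_{z \sim q}\left[\log \frac{q(z)}{p(z)p(x|z)}\right] = E_{z \sim q}\left[\log \frac{q(z)}{p(x)p(z|x)}\right] = \logpinv(x) + \mathrm{KL}[q(z), p(z|x)] \ge \logpinv(x). The slack is exactly \mathrm{KL}[q(z), p(z|x)], which is why equality holds at q(z)=p(z|x).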

The LHS of the inequality reads “the number of bits Alice needs to send to Bob to transmit x, given that they previously agreed on distribution p(x)”. Can the RHS be interpreted as a concrete compression scheme for x, involving a code z that partially encodes x?

We’d like to say something like this: the RHS is the expected length of a two-part code for transmitting x via a latent code z.

  1. First, Alice sends Bob a latent code z \sim q(z), at a cost of \klqp bits: the term (*).
  2. Then, Alice sends x using the code derived from p(x|z), at a cost of \logpinv(x|z) bits, i.e. E_{z \sim q}[\logpinv(x|z)] bits in expectation: the term (**).

The second point, interpreting E_{z \sim q}[\logpinv(x|z)] as the code-length of x given z, is clearly true. The first point is non-obvious, and it’s precisely the problem stated above.

The variational upper bound is indeed the code-length under a well-known compression scheme, called bits-back coding. However, bits-back coding doesn’t quite match the simple two-part-code interpretation given above. In bits-back coding, Alice samples z using some auxiliary data as the entropy source, then sends the whole sample to Bob at an expected cost of E_q[\logpinv(z)]. Then, she sends x at a cost of \logpinv(x | z). Finally, using x to infer the distribution q, Bob recovers the E_q[\log q^{-1}(z)] bits of auxiliary data, giving a net code-length of E_q[\logpinv(z)]+E_q[\logpinv(x | z)]-E_q[\log q^{-1}(z)]=\klqp + E_q[\logpinv(x | z)].

Bits-back is a pretty slick idea, but I’ve always wondered if the interpretation as a two-part code can be directly implemented. In particular, can the code-word z be communicated using cost \klqp, without using x at all?

Naive Rejection Sampling (Suboptimal)

Let’s return to the problem, where Alice needs to send Bob a message that lets him sample z \sim q. One natural idea is to use rejection sampling. Rejection sampling allows you to sample from q(z) by (stochastically) filtering samples from a different distribution p(z). Alice uses her RNG to generate a sequence of IID samples from p(z), but applies the rejection criterion so that the first accepted sample is a sample from q(z). Then she sends to Bob the index n of the first accepted sample. Bob runs the same process on his end and takes the nth sample, which is the same as Alice’s nth sample due to the shared RNG.
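
Here’s a minimal sketch of this protocol in Python over a finite alphabet. The function names, toy distributions, and seed handling are illustrative rather than a fixed spec; the shared RNG is modeled as a seed that both parties know.

```python
import numpy as np

def alice_encode(p, q, seed, max_trials=10**6):
    """Return the index of the first sample from p accepted by rejection sampling."""
    rng = np.random.default_rng(seed)           # shared RNG, seeded with omega
    M = np.max(q / p)                           # M = max_z q(z)/p(z)
    for n in range(1, max_trials + 1):
        z = rng.choice(len(p), p=p)             # z ~ p
        if rng.random() < q[z] / (M * p[z]):    # accept with probability (1/M) q(z)/p(z)
            return n
    raise RuntimeError("no sample accepted")

def bob_decode(p, seed, n):
    """Replay the same RNG stream and return Alice's n-th draw from p, which is ~ q."""
    rng = np.random.default_rng(seed)
    for _ in range(n):
        z = rng.choice(len(p), p=p)
        rng.random()                            # consume the accept/reject draw too
    return z

p = np.array([0.5, 0.5])                        # shared prior
q = np.array([0.9, 0.1])                        # Alice's target distribution
n = alice_encode(p, q, seed=0)
print(n, bob_decode(p, seed=0, n=n))            # Bob reproduces Alice's accepted sample
```

The only requirement for correctness is that Bob consumes the RNG stream in exactly the same order as Alice, so that his nth draw from p matches hers.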

Now let’s look at the expected code-length of this protocol. In rejection sampling, we sample z \sim p(z) and accept with probability \paccz = \minv\frac{q(z)}{p(z)}, where M=\max_z \frac{q(z)}{p(z)}.

The probability of accepting a given sample is E_{z \sim p}[\paccz]=\minv. Given that an event has probability \epsilon, the expected number of trials until it occurs is 1/\epsilon. Hence, the expected number of trials of the rejection sampling process is just M, and the code-length of this integer is about \log M=\log \max_z \frac{q(z)}{p(z)} = \max_z \log \frac{q(z)}{p(z)}.

Hence, this rejection sampling procedure attains a code-length of \max_z \log \frac{q(z)}{p(z)}. But we claimed above that the optimal code-length is \klqp=E_{z\sim q}[\log \frac{q(z)}{p(z)}]. So rejection sampling is suboptimal in general: it replaces the expectation E_{z \sim q} with a max over z.
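
For a quick sense of the size of the gap, here is the comparison for an arbitrarily chosen toy pair of distributions (the numbers are just for illustration):

```python
import numpy as np

# Naive cost is max_z log q(z)/p(z); the claimed optimum is E_{z~q} log q(z)/p(z) = KL[q, p].
p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])
log_ratio = np.log(q / p)
print("naive cost, log M =", float(np.max(log_ratio)))   # ~0.588 nats
print("KL[q, p]          =", float(q @ log_ratio))       # ~0.368 nats
```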

Modified Rejection Sampling

The rejection sampling approach almost works, but it gives a suboptimal code-length. Like many ideas in coding theory, we can fix the problem by grouping together a bunch of messages and using the law of large numbers. We’ll group together n samples and do rejection sampling with \log M \approx n\klqp.

Let’s modify the communication problem to have Alice send Bob n samples at a time. In the modified problem, Alice and Bob agree on (p_1, p_2, \dots, p_n), and Alice needs to send Bob a sample (z_1, z_2, \dots, z_n) from (q_1, q_2, \dots, q_n). To simplify the argument, let all of the distributions be equal: p_1 = p_2 = \dots = p_n and q_1 = q_2 = \dots = q_n. The argument can easily be modified for the case where these distributions are not equal.

As for notation, let z^n denote an n-tuple of samples, and let p^n(z^n) and q^n(z^n) denote the joint distributions over n-tuples of samples.

We’ll define the communication protocol as follows. Alice repeatedly samples z^n \sim p^n, accepting with probability \min\left(1, \minv\frac{q^n(z^n)}{p^n(z^n)}\right), where \log M=n(\klqp+\epsilon), and \epsilon is a small number that goes to 0 as n \rightarrow \infty. Then she sends Bob an integer k, the number of trials until acceptance, and he takes the kth sample from p^n.
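
Continuing the earlier sketch, the change is just to draw n-tuples and use \log M = n(\klqp + \epsilon). The toy distributions, the choice \epsilon=0.05, and the function names are again illustrative; note the expected number of trials is about M = e^{n(\klqp+\epsilon)}, so this is only runnable for small n.

```python
import numpy as np

def encode_block(p, q, n, eps, seed, max_trials=10**7):
    """Alice: return the index of the first accepted n-tuple z^n ~ p^n."""
    rng = np.random.default_rng(seed)                       # shared RNG
    kl = float(np.sum(q * np.log(q / p)))                   # KL[q, p] in nats
    logM = n * (kl + eps)                                   # log M = n(KL + eps)
    for k in range(1, max_trials + 1):
        zn = rng.choice(len(p), size=n, p=p)                # z^n ~ p^n
        log_ratio = float(np.sum(np.log(q[zn] / p[zn])))    # log q^n(z^n)/p^n(z^n)
        if rng.random() < np.exp(min(0.0, log_ratio - logM)):  # accept w.p. min(1, q^n/(M p^n))
            return k
    raise RuntimeError("no tuple accepted within max_trials")

def decode_block(p, n, seed, k):
    """Bob: replay the shared RNG stream and return the k-th tuple."""
    rng = np.random.default_rng(seed)
    for _ in range(k):
        zn = rng.choice(len(p), size=n, p=p)
        rng.random()                                        # keep the streams aligned
    return zn                                               # approximately distributed as q^n

p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])
k = encode_block(p, q, n=20, eps=0.05, seed=0)
print(k, decode_block(p, n=20, seed=0, k=k))
```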

To show that this protocol works, we will prove the following two statements:

  1. The expected message length per sample z_i approaches \klqp as n \rightarrow \infty.
  2. The total variation divergence between each decoded z_i and q(z) approaches zero as n \rightarrow \infty.

The proof will be based on the idea of typical sets introduced by Shannon. We’ll also explain how to slightly modify the protocol to send exactly q instead of an approximation (at the cost of some extra bits).

Consider the log ratio \log \frac{q^n(z^n)}{p^n(z^n)} = \sum_{i=1}^n \log \frac{q(z_i)}{p(z_i)}. For z sampled from q, the expectation of each of these terms is E_q[\log \frac{q(z)}{p(z)}]=\klqp. Informally speaking, this sum will probably be around n\klqp \pm O(\sqrt{n}). Let’s state this more formally.

Let \epsilon>0 be a small number. As n \rightarrow \infty, the sample average of \log \frac{q(z)}{p(z)} approaches its mean value, \klqp, so for large enough n we get \Pr \left(\frac{1}{n}\logfracn \le \klqp+\epsilon\right) \ge 1-\epsilon for z^n \sim q^n. Let S denote the set of z^n satisfying \frac{1}{n}\logfracn \le \klqp + \epsilon. S satisfies \Pr(z^n \sim q^n \in S) \ge 1-\epsilon.

Alice does rejection sampling by sampling z^n \sim p^n and then accepting with probability \Pr(\text{accept } z^n) = \min(1, \minv\frac{q^n(z^n)}{p^n(z^n)}), where \log M=n(\klqp+\epsilon). \Pr(\text{sample } z^n \sim p^n \text{ and accept}) = \begin{cases} \frac{q^n(z^n)}{M} \qquad z^n \in S \\ <\frac{q^n(z^n)}{M} \qquad z^n \notin S \\ \end{cases} (For z^n \in S we have \frac{q^n(z^n)}{p^n(z^n)} \le M, so the \min is inactive.) Now let’s compute the probability of acceptance: \begin{aligned} \Pr(\text{accept})&=\sum_{z^n}\Pr(\text{sample } z^n \sim p^n \text{ and accept})\\ &=\sum_{z^n \in S} \frac{q^n(z^n)}{M} + \sum_{z^n \notin S} \text{[positive value]}\\ &\ge \sum_{z^n \in S} \frac{q^n(z^n)}{M}\\ &\ge (1 - \epsilon) / M \end{aligned} The message length is \log\frac{1}{\Pr(\text{accept})} \le \log M + \log\frac{1}{1-\epsilon} = n(\klqp+\epsilon) + O(\epsilon), so the message length per sample z_i approaches \klqp as n \rightarrow \infty and \epsilon \rightarrow 0, proving the first part of the proposition.
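
Here’s a small numerical check of this analysis for a binary alphabet (the toy p, q, and \epsilon are arbitrary choices, not part of the argument). Since \logfracn depends only on how many times each symbol appears, \Pr(\text{accept}) can be computed exactly by summing over counts:

```python
import numpy as np
from math import lgamma, log, exp

def log_comb(n, k):
    """log of the binomial coefficient C(n, k)."""
    return lgamma(n + 1) - lgamma(k + 1) - lgamma(n - k + 1)

p = np.array([0.5, 0.5])
q = np.array([0.9, 0.1])
eps = 0.05
kl = float(np.sum(q * np.log(q / p)))                        # KL[q, p] in nats

for n in [10, 50, 200, 1000]:
    logM = n * (kl + eps)
    log_terms = []                                           # log Pr(draw z^n with k zeros and accept)
    for k in range(n + 1):
        log_pn = k * log(p[0]) + (n - k) * log(p[1])         # log p^n(z^n)
        log_ratio = k * log(q[0] / p[0]) + (n - k) * log(q[1] / p[1])  # log q^n/p^n
        log_acc = min(0.0, log_ratio - logM)                 # log min(1, q^n/(M p^n))
        log_terms.append(log_comb(n, k) + log_pn + log_acc)
    top = max(log_terms)
    log_p_accept = top + log(sum(exp(t - top) for t in log_terms))   # log-sum-exp
    print(f"n={n:5d}  per-sample message length {-log_p_accept / n:.4f} nats  "
          f"(KL = {kl:.4f}, KL + eps = {kl + eps:.4f})")
```

For large n, the per-sample message length approaches \klqp+\epsilon from above, consistent with \log M \le \log\frac{1}{\Pr(\text{accept})}.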

For the second part of the proposition, let’s define \qtil^n to be the decoded distribution over z^n when following the rejection sampling protocol: \qtil^n(z^n) \propto p^n(z^n)\Pr(\text{accept } z^n). Define q_S to be the distribution of samples from q^n, conditioned on membership in S. For z^n \in S, both q^n and \qtil^n are proportional to q_S, as follows: q^n(z^n) = q_S(z^n) P_{S|q} \quad\text{where}\quad P_{S|q}=\Pr(z^n \sim q^n \in S)\\ \qtil^n(z^n) = q_S(z^n) P_{S|\qtil} \quad\text{where}\quad P_{S|\qtil} =\Pr(z^n \sim \qtil^n \in S) Furthermore, 1 \ge P_{S|\qtil} \ge P_{S|q} \ge 1- \epsilon. A routine calculation then shows that the total variation divergence between q^n and \qtil^n is O(\epsilon). This proves the second part of the proposition.
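
To spell out that calculation: 2\,\mathrm{TV}(q^n, \qtil^n) = \sum_{z^n}|q^n(z^n)-\qtil^n(z^n)|. The terms with z^n \in S contribute \sum_{z^n \in S} q_S(z^n)(P_{S|\qtil} - P_{S|q}) = P_{S|\qtil} - P_{S|q} \le \epsilon, and the terms with z^n \notin S contribute at most q^n(S^c) + \qtil^n(S^c) = (1-P_{S|q}) + (1-P_{S|\qtil}) \le 2\epsilon, so \mathrm{TV}(q^n, \qtil^n) \le \frac{3\epsilon}{2} = O(\epsilon). Since the total variation between marginals is at most that between the joints, the same bound holds for each decoded z_i versus q.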

Finally, it’s a bit unsatisfying that Alice doesn’t send exactly q^n; she sends an approximation \qtil^n. We can easily fix this issue and have Alice send exactly q^n at the cost of some extra bits. Here’s a sketch. With probability P_{S|q}, we perform the protocol above. With probability 1-P_{S|q}, Alice directly sends z^n, sampled from the complement of S, at a cost of -\log p^n(z^n). Overall, the extra cost is O(\epsilon).

Discussion

This procedure is computationally intractable, since it requires Alice to generate a sequence of samples from an exponentially large set of tuples (z_1, z_2, \dots, z_n). This contrasts with bits-back coding, which can be implemented efficiently. In fact, a recent paper showed how to implement bits-back coding with VAEs, cleverly using ANS (a relative of arithmetic coding).

It’s possible that there’s a procedure like arithmetic coding that solves our problem, giving an efficient algorithm in the case that z lives in a small discrete set. If z is high-dimensional, then it seems unlikely that we can solve the transmission problem efficiently without additional assumptions–all we can do is enumerate samples from p and index into them.

Finally, I wouldn’t be surprised if this problem is well-known–it seems like a natural way of formalizing the idea of lossy data transmission. If so, please send me a pointer.

Thanks to Nik Tezak and Beth Barnes for helpful feedback.