Rejection Sampling: A Simple Mathematical Inspection

Written by Kevin Jacobs

I'm Kevin, a Data Scientist, PhD student in NLP and Law and blog writer for Data Blogger.

What is rejection sampling, and why would you need it? Suppose that we have a probability density function (PDF) f(x) that we can evaluate but cannot easily sample from directly. This is where rejection sampling comes in: a sampling method that is simple to implement, though not always efficient.

Introduction

Suppose we have a density function that is not computationally tractable. How can we draw samples from it? Rejection sampling is one of the many methods available. In this article, we will draw samples from the following distribution:

Hard function.

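The secret function used for this plot is f(x) = \frac{3}{5} \mathcal{N}(x \mid \frac{7}{20}, \frac{1}{20}) + \frac{2}{5} \mathcal{N}(x \mid \frac{13}{20}, \frac{2}{25}), a mixture of two Gaussians (a bimodal density). Here the second argument of \mathcal{N} is the standard deviation, matching the scale parameter used in the script below. (Note: this function is actually tractable; it is used only for demonstration purposes.)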
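As a quick sanity check, the sketch below rebuilds f(x) with SciPy, assuming (as in the script further down) that the second argument of \mathcal{N} is the standard deviation. Because the weights \frac{3}{5} and \frac{2}{5} sum to 1, the mixture should integrate to 1, and because the two components are well separated, it should have two local maxima.

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm

# Target density: a mixture of two Gaussians with weights 3/5 and 2/5.
# The second argument of norm.pdf is the standard deviation (scale).
def f(x):
    return 0.6 * norm.pdf(x, 0.35, 0.05) + 0.4 * norm.pdf(x, 0.65, 0.08)

# The weights sum to 1, so the mixture should integrate to 1.
# Finite limits far out in the tails; `points` tells quad where the narrow peaks are.
total, _ = quad(f, -2.0, 3.0, points=[0.35, 0.65])
print(f"integral of f: {total:.6f}")

# Find local maxima on a fine grid to confirm the density is bimodal.
x = np.linspace(0.0, 1.0, 10001)
y = f(x)
is_peak = (y[1:-1] > y[:-2]) & (y[1:-1] > y[2:])
peaks = x[1:-1][is_peak]
print("local maxima near:", np.round(peaks, 3))
```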

Approximation

Next, we need a probability density function g(x), called the proposal distribution, that is easy to sample from and satisfies M g(x) \geq f(x) for all x. In other words, once multiplied by a constant M, g(x) must lie on or above f(x) everywhere. In our example, one such choice is g(x) = \mathcal{N}(x \mid \frac{9}{20}, \frac{1}{5}) with M = 3. How did I find this function? First of all, Gaussians are easy to evaluate and to sample from. Therefore, I chose a Gaussian with its mean around the “middle” of f(x). I then tuned its standard deviation and the constant M by trying out some values. The result is the following M \cdot g(x), which indeed lies above f(x) everywhere:
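The trial-and-error step can be partly automated. The condition M g(x) \geq f(x) for all x is the same as M \geq f(x) / g(x) for all x, so the smallest valid constant is the maximum of that ratio. A minimal sketch, assuming it is enough to scan a finite grid on [-1, 2] (outside it, the wider proposal g dominates both narrow components of f), and keeping in mind that a grid can slightly underestimate the true maximum:

```python
import numpy as np
from scipy.stats import norm

# Target density f(x) and proposal density g(x) from the post.
def f(x):
    return 0.6 * norm.pdf(x, 0.35, 0.05) + 0.4 * norm.pdf(x, 0.65, 0.08)

def g(x):
    return norm.pdf(x, 0.45, 0.2)

# M * g(x) >= f(x) everywhere  <=>  M >= f(x) / g(x) everywhere.
# Approximate the supremum of the ratio on a fine grid (an assumption, see above).
x = np.linspace(-1.0, 2.0, 30001)
ratio = f(x) / g(x)
M_min = ratio.max()
x_star = x[ratio.argmax()]

print(f"smallest valid M is about {M_min:.3f}, attained near x = {x_star:.3f}")
print("M = 3 is a valid envelope on the grid:", bool(np.all(3 * g(x) >= f(x))))
```

In this example the maximum lies near the first, narrower peak of f, so M = 3 leaves some slack. Choosing M closer to this minimum means fewer proposals get rejected.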

The approximation.

The mathematics

So, what next? We sample pairs (x, u), where x is drawn from the proposal Gaussian g(x) and u is drawn from the uniform distribution on the interval [0, 1]. We then accept the pair if u < \frac{f(x)}{M \cdot g(x)} and reject it otherwise. This condition is equivalent to u \cdot M \cdot g(x) < f(x), so an accepted point (x, u \cdot M \cdot g(x)) lies below the graph of the function we want to sample from! The accepted points are spread uniformly under that graph, which means their x-coordinates are samples from f(x).
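Why do the accepted x-values follow f? A short derivation, using the same f, g and M, with x \sim g and u \sim \mathcal{U}(0, 1) independent, and assuming both f and g integrate to 1:

```latex
\begin{aligned}
\Pr(\text{accept} \mid x) &= \Pr\!\left(u < \frac{f(x)}{M\,g(x)}\right) = \frac{f(x)}{M\,g(x)},\\
p(x, \text{accept}) &= g(x)\,\frac{f(x)}{M\,g(x)} = \frac{f(x)}{M},\\
\Pr(\text{accept}) &= \int \frac{f(x)}{M}\,dx = \frac{1}{M},\\
p(x \mid \text{accept}) &= \frac{f(x)/M}{1/M} = f(x).
\end{aligned}
```

The first line relies on M g(x) \geq f(x), which keeps the ratio at most 1 so that it is a valid probability. The third line also gives the expected fraction of accepted proposals, 1/M.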

Results

The theory is summarized in the following Python script:

```python
from scipy.stats import norm
import matplotlib.pyplot as plt
import numpy as np

# The multiplication constant that makes M * g(x) an upper bound of f(x)
M = 3
# Number of candidate samples to draw from the proposal density g(x)
N = 1000

# The target probability density function
f = lambda x: 0.6 * norm.pdf(x, 0.35, 0.05) + 0.4 * norm.pdf(x, 0.65, 0.08)

# The proposal (approximating) probability density function
g = lambda x: norm.pdf(x, 0.45, 0.2)

# Candidate samples drawn from g(x) itself (they must not be scaled by M)
x_samples = np.random.normal(0.45, 0.2, (N,))

# A number of samples in the interval [0, 1]
u = np.random.uniform(0, 1, (N,))

# Keep only the accepted candidates; each accepted point (x, u * M * g(x)) lies below f(x)
samples = [(x_samples[i], u[i] * M * g(x_samples[i])) for i in range(N) if u[i] < f(x_samples[i]) / (M * g(x_samples[i]))]

# Make the plots
fig, ax = plt.subplots(1, 1)

# The x coordinates
x = np.linspace(0, 1, 500)

# The target probability density function
ax.plot(x, f(x), 'r-', label='$f(x)$')
# The original script is cut off here; the plotting lines below are an assumed completion
ax.plot(x, M * g(x), 'b--', label=r'$M \cdot g(x)$')
ax.scatter([s[0] for s in samples], [s[1] for s in samples], s=4, label='accepted samples')
ax.legend()
plt.show()
```
Rejection sampling

The resulting samples are the following:

Samples.

These samples clearly lie below the function which we wished to sample from:

End result.
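The plot checks the samples by eye. As a numerical complement, here is a minimal sketch (the names `rejection_sample` and `f_cdf` are hypothetical, introduced only for this example) that keeps proposing until it has exactly n accepted samples, using NumPy's `default_rng` instead of the global `np.random` functions, and then compares the accepted x-values with the CDF of the mixture using SciPy's Kolmogorov-Smirnov test `scipy.stats.kstest`. If the accepted x-coordinates follow f(x), the KS distance to f should be small and the distance to the proposal g clearly larger.

```python
import numpy as np
from scipy.stats import kstest, norm

def f(x):
    """Target density: the two-component Gaussian mixture from the post."""
    return 0.6 * norm.pdf(x, 0.35, 0.05) + 0.4 * norm.pdf(x, 0.65, 0.08)

def f_cdf(x):
    """CDF of the same mixture, used only for the goodness-of-fit test."""
    return 0.6 * norm.cdf(x, 0.35, 0.05) + 0.4 * norm.cdf(x, 0.65, 0.08)

def g(x):
    """Proposal density N(0.45, 0.2), as in the post."""
    return norm.pdf(x, 0.45, 0.2)

def rejection_sample(n, M=3.0, rng=None):
    """Draw exactly n samples from f by rejection sampling with proposal g."""
    rng = np.random.default_rng() if rng is None else rng
    accepted = []
    count = 0
    while count < n:
        # Propose a batch; roughly a fraction 1/M of it will be accepted.
        x = rng.normal(0.45, 0.2, size=int(M * (n - count)) + 1)
        u = rng.uniform(0.0, 1.0, size=x.shape)
        keep = x[u < f(x) / (M * g(x))]
        accepted.append(keep)
        count += keep.size
    # Concatenate the batches and drop any surplus beyond n samples.
    return np.concatenate(accepted)[:n]

rng = np.random.default_rng(0)
xs = rejection_sample(5000, rng=rng)

# KS distance to f (should be small) and, for contrast, to the proposal g.
d_f = kstest(xs, f_cdf).statistic
d_g = kstest(xs, norm(0.45, 0.2).cdf).statistic
print(f"KS distance to f: {d_f:.4f}, to g: {d_g:.4f}")
```

Sizing each batch as roughly M times the number of samples still needed keeps the number of loop iterations small.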

So this method works. However, note that it is not always efficient: every rejected proposal is wasted computation, and with M = 3 roughly two out of every three proposals are rejected. Methods such as Markov Chain Monte Carlo (MCMC) are often a better choice. Rejection sampling works well for simple densities that a proposal can cover tightly, but it runs into trouble (read: you throw away a lot of points) for functions with many sharp peaks, because the constant M then has to be large.
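The cost can be quantified. Because f and g both integrate to 1, each proposal is accepted with probability 1/M, so on average about M proposals are needed per accepted sample. The sketch below measures the empirical acceptance rate for the post's M = 3 and for a deliberately loose M = 10 (an arbitrary value chosen only for illustration; any M above the minimum is still a valid envelope).

```python
import numpy as np
from scipy.stats import norm

def f(x):
    return 0.6 * norm.pdf(x, 0.35, 0.05) + 0.4 * norm.pdf(x, 0.65, 0.08)

def g(x):
    return norm.pdf(x, 0.45, 0.2)

def acceptance_rate(M, N=100_000, seed=0):
    """Fraction of N proposals from g that pass the test u < f(x) / (M g(x))."""
    rng = np.random.default_rng(seed)
    x = rng.normal(0.45, 0.2, size=N)
    u = rng.uniform(0.0, 1.0, size=N)
    return np.mean(u < f(x) / (M * g(x)))

rates = {M: acceptance_rate(M) for M in (3, 10)}
for M, rate in rates.items():
    print(f"M = {M:>2}: accepted {rate:.3f} of proposals (theory: {1 / M:.3f})")
```

A looser envelope is still correct, it just throws away more points, which is exactly the inefficiency described above.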
