The Mathematics Behind: Rejection Sampling

End result.

Suppose that we have a probability density function (PDF) f(x) that we cannot work with analytically. How can we still draw samples from it? Luckily, there are many techniques for this, and this time I will highlight rejection sampling: a simple-to-implement (but not always efficient) sampling method.

The goal

Suppose we have a function that is not computationally tractable. How can we still draw samples from it? One of the many methods available is rejection sampling. In this article, we will draw samples from the following distribution:

Hard function.

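The function behind this plot is f(x) = \frac{3}{5} \mathcal{N}(x|\frac{7}{20}, \frac{1}{20}) + \frac{2}{5} \mathcal{N}(x|\frac{13}{20}, \frac{2}{25}), a mixture of two Gaussians with two modes. Here the second argument of \mathcal{N} is the standard deviation, matching the scale argument of norm.pdf in the script below. (Note: this function is in fact mathematically tractable; it is used here for demonstration purposes only.)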
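To make the notation concrete, here is a minimal sketch of f(x) in Python with SciPy. It assumes, as the script further down does, that the second argument of \mathcal{N} is the standard deviation (the scale parameter of scipy.stats.norm.pdf). Integrating numerically is a quick way to confirm that the weights 3/5 and 2/5 give a properly normalized PDF.

```python
from scipy.integrate import quad
from scipy.stats import norm


def f(x):
    """Target density: a mixture of two Gaussians (second argument is the std. dev.)."""
    return 0.6 * norm.pdf(x, loc=0.35, scale=0.05) + 0.4 * norm.pdf(x, loc=0.65, scale=0.08)


# Integrate over an interval wide enough to contain essentially all of the mass;
# the 'points' hint tells quad where the narrow peaks are.
area, _ = quad(f, -1.0, 2.0, points=[0.35, 0.65])
print(f"Integral of f over [-1, 2]: {area:.6f}")  # close to 1.0
```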


Now we have to come up with a probability density function g(x) such that M g(x) \geq f(x) for all x. In other words, we need a PDF that, when multiplied by a constant M, lies above f(x) everywhere. In our example, one such function is g(x)=\mathcal{N}(x|\frac{9}{20}, \frac{1}{5}) with M=3. How did I find this function? First of all, Gaussians are easy to sample from and to evaluate. Therefore, I chose a Gaussian whose mean lies around the “middle” of f(x). I then tuned its standard deviation and the constant M by simply trying out a few values. The result is the following M \cdot g(x), which indeed lies above f(x) everywhere:
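The trial-and-error search for M can be made a bit more systematic. The envelope condition M g(x) \geq f(x) is equivalent to M \geq \max_x f(x)/g(x), so one option, sketched below, is to evaluate that ratio on a dense grid and take its maximum. This is only a numerical check under an assumption (that the grid over [-1, 2] with 30001 points is wide and fine enough for these two Gaussians); it is not a proof, and in practice you would keep a small safety margin above the value it reports.

```python
import numpy as np
from scipy.stats import norm


def f(x):
    # Target density f(x); the second argument of N is the standard deviation
    return 0.6 * norm.pdf(x, 0.35, 0.05) + 0.4 * norm.pdf(x, 0.65, 0.08)


def g(x):
    # Proposal density g(x) = N(x | 0.45, 0.2)
    return norm.pdf(x, 0.45, 0.2)


# Dense grid covering the region where f has (numerically) all of its mass
x = np.linspace(-1.0, 2.0, 30001)
ratio = f(x) / g(x)

M_min = ratio.max()  # smallest M with M * g(x) >= f(x) on this grid
M = 3
print(f"Smallest valid M on the grid: {M_min:.3f}")
print("M = 3 is a valid envelope on the grid:", bool(np.all(M * g(x) >= f(x))))
```

Picking M as small as possible matters, because (as the efficiency discussion below explains) a larger M means more rejected candidates.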

The approximation.

The mathematics

So, what next? We draw pairs (x, u), where x is sampled from the proposal Gaussian g(x) and u is sampled from the uniform distribution on the interval [0, 1]. We then check whether u < \frac{f(x)}{M \cdot g(x)}. If so, the point (x, u \cdot M g(x)), which lies under the curve M g(x), also lies below the function we want to sample from, and we accept x; otherwise we reject it. The accepted values of x are samples from f(x).
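Why does this work? A short derivation, for the case where f and g are both normalized densities (as in our example), shows both that the accepted x follow f(x) and how often a candidate is accepted:

```latex
\begin{aligned}
P(\text{accept} \mid x) &= P\!\left(u < \frac{f(x)}{M g(x)}\right) = \frac{f(x)}{M g(x)}, \\
P(\text{accept}) &= \int g(x)\,\frac{f(x)}{M g(x)}\,dx = \frac{1}{M}\int f(x)\,dx = \frac{1}{M}, \\
p(x \mid \text{accept}) &= \frac{g(x)\,P(\text{accept} \mid x)}{P(\text{accept})}
  = \frac{g(x)\,\frac{f(x)}{M g(x)}}{1/M} = f(x).
\end{aligned}
```

The first line relies on the envelope condition: because f(x) \leq M g(x), the ratio f(x)/(M g(x)) lies in [0, 1] and is therefore a valid probability. The second line says that, with M = 3, on average one in three candidates is accepted.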


The theory is summarized in the following Python script:

from scipy.stats import norm
import matplotlib.pyplot as plt
import numpy as np

# Envelope constant: M * g(x) must lie above f(x) for every x
M = 3
# Number of candidate samples to draw from the proposal density g(x)
N = 1000

# The target probability density function
f = lambda x: 0.6 * norm.pdf(x, 0.35, 0.05) + 0.4 * norm.pdf(x, 0.65, 0.08)

# The proposal density g(x) = N(x | 0.45, 0.2); norm.pdf takes the standard deviation as its third argument
g = lambda x: norm.pdf(x, 0.45, 0.2)

# N candidate samples drawn from the proposal g(x); they are NOT scaled by M (M only scales the density's height)
x_samples = np.random.normal(0.45, 0.2, (N,))

# One uniform number u in [0, 1] per candidate
u = np.random.uniform(0, 1, (N,))

# Keep candidate x only if u < f(x) / (M * g(x)); store the point (x, u * M * g(x)),
# which then lies below f(x), so that it can be plotted
samples = [
    (x_samples[i], u[i] * M * g(x_samples[i]))
    for i in range(N)
    if u[i] < f(x_samples[i]) / (M * g(x_samples[i]))
]

# Make the plots
fig, ax = plt.subplots(1, 1)

# The x coordinates
x = np.linspace(0, 1, 500)

# The target probability function
ax.plot(x, f(x), 'r-', label='$f(x)$')

# The scaled proposal density M * g(x) (raw string, so "\c" is not read as an escape sequence)
ax.plot(x, M * g(x), 'b-', label=r'$M \cdot g(x)$')

# The samples found by rejection sampling
ax.plot([sample[0] for sample in samples], [sample[1] for sample in samples], 'g.', label='Samples')

# Set the axis limits
ax.set_xlim([0, 1])
ax.set_ylim([0, 6])

# Show the legend
ax.legend()

# Set the title
ax.set_title('Rejection sampling')

# Show the plots
plt.show()

The resulting samples are the following:


These samples clearly lie below the function which we wished to sample from:

End result.

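So this method works. Note, however, that it is not efficient: we throw away many points and therefore waste a lot of computation. When f and g are both normalized, each candidate is accepted with probability 1/M, so with M = 3 only about a third of the candidates survive. Markov chain Monte Carlo (MCMC) is an example of a method that often works better. Rejection sampling is efficient for simple probability density functions, but you will run into trouble (read: you throw away a lot of points) with functions that have many peaks, because a single proposal then needs a large M to stay above all of them.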
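To put a number on the waste, here is a minimal sketch of a reusable sampler. The function name rejection_sample and its parameters (n, M, batch, seed) are hypothetical and not part of any library. It keeps proposing candidates in batches until it has collected exactly n accepted samples, and reports the empirical acceptance rate, which should come out near 1/M = 1/3 for the f, g and M of this article.

```python
import numpy as np
from scipy.stats import norm


def f(x):
    # Target density from this article (second argument is the standard deviation)
    return 0.6 * norm.pdf(x, 0.35, 0.05) + 0.4 * norm.pdf(x, 0.65, 0.08)


def g(x):
    # Proposal density N(x | 0.45, 0.2)
    return norm.pdf(x, 0.45, 0.2)


def rejection_sample(n, M=3.0, batch=1000, seed=0):
    """Draw exactly n samples from f by rejection sampling with proposal g.

    Hypothetical helper for illustration; assumes M * g(x) >= f(x) for all x.
    Returns the samples and the fraction of proposals that were accepted.
    """
    rng = np.random.default_rng(seed)
    accepted = []
    n_accepted = 0
    n_proposed = 0
    while n_accepted < n:
        x = rng.normal(0.45, 0.2, size=batch)   # candidates from g
        u = rng.uniform(0.0, 1.0, size=batch)   # one uniform per candidate
        keep = x[u < f(x) / (M * g(x))]          # accept/reject test
        accepted.append(keep)
        n_accepted += keep.size
        n_proposed += batch
    samples = np.concatenate(accepted)[:n]      # drop the surplus from the last batch
    return samples, n_accepted / n_proposed


samples, rate = rejection_sample(10_000)
print(f"Acceptance rate: {rate:.3f} (theory: 1/M = {1 / 3:.3f})")
print(f"Sample mean: {samples.mean():.3f} (true mean of f: {0.6 * 0.35 + 0.4 * 0.65:.3f})")
```

Proposing in NumPy batches rather than one candidate at a time avoids a slow Python loop; the price is a few surplus accepted samples in the last batch, which are simply discarded.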


Try to draw samples from the following PDF:

f(x) = \frac{1}{5} \mathcal{N}(x|\frac{5}{20}, \frac{4}{20}) + \frac{2}{5} \mathcal{N}(x|\frac{19}{20}, \frac{6}{25}) + \frac{2}{5} \mathcal{N}(x|\frac{24}{20}, \frac{1}{25}).

Kevin Jacobs

Kevin Jacobs is a certified Data Scientist and blog writer for Data Blogger. He is passionate about any project that involves large amounts of data and statistical data analysis. Kevin can be reached using Twitter (@kmjjacobs), LinkedIn or via e-mail: