A Parallel loop in Python with Joblib.Parallel

The goal of this post is to perform an embarrassingly parallel loop in Python, with the same code running on different platforms [Linux and Windows]. From Wikipedia, here is a definition of embarrassingly parallel:

In parallel computing, an embarrassingly parallel workload or problem [...] is one where little or no effort is needed to separate the problem into a number of parallel tasks. This is often the case where there is little or no dependency or need for communication between those parallel tasks, or for results between them.

In the following, we are going to parallelize a loop with independent iterations. More specifically, we have a list of natural numbers and want to check whether each number is prime.

Imports

from joblib import Parallel, delayed
from numba import jit
import numpy as np
import pandas as pd
import perfplot
import primesieve

Computations are performed on a laptop with a 4-core/8-thread Intel i7-7700HQ CPU @ 2.80GHz running Linux. Package versions:

Python implementation: CPython
Python version       : 3.9.10
IPython version      : 8.0.1
perfplot  : 0.10.1
pandas    : 1.4.0
numpy     : 1.21.5
primesieve: 2.3.0


The Parallel loop

Let's say that we have a list of k natural numbers $\left[ n_1, n_2, \ldots, n_k \right]$. We want to send roughly equal amounts of work to n_jobs processes and gather a list with all the k boolean results. For example, with k=16 and n_jobs=4, each process checks 4 numbers, and the 4 partial lists of booleans are gathered into a single list of 16 results.
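The split-and-gather idea can be sketched with the standard library, assuming a static split into contiguous chunks. The helpers split_evenly and check_chunk are hypothetical, and check_chunk uses a naive trial division as a stand-in for the real work. Joblib itself does not use this fixed split: Parallel dispatches tasks to the workers in batches whose size is set by its batch_size parameter, so this only illustrates the principle.

```python
from itertools import chain
from math import isqrt


def split_evenly(items, n_chunks):
    """Split items into n_chunks contiguous chunks whose sizes differ by at most one."""
    q, r = divmod(len(items), n_chunks)
    chunks, start = [], 0
    for i in range(n_chunks):
        size = q + (1 if i < r else 0)  # the first r chunks take one extra item
        chunks.append(items[start:start + size])
        start += size
    return chunks


def check_chunk(chunk):
    """Stand-in for the per-process work: naive primality check of each number."""
    return [n > 1 and all(n % d for d in range(2, isqrt(n) + 1)) for n in chunk]


numbers = list(range(1, 17))                        # k = 16
chunks = split_evenly(numbers, 4)                   # n_jobs = 4, i.e. 4 chunks of 4 numbers
partial = [check_chunk(chunk) for chunk in chunks]  # in the parallel loop, each call runs in its own process
results = list(chain.from_iterable(partial))        # gather the k booleans, in input order
print([n for n, flag in zip(numbers, results) if flag])
```

Because the chunks are contiguous and gathered in order, the flattened list lines up with the input list without any extra bookkeeping.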

A Simple primality test

A primality test is an algorithm for determining whether an input number is prime. We are going to proceed with an old-school trial division method. For a given positive integer $n$, we check whether it is divisible by a smaller integer $p > 1$. Also, we are going to perform two simple optimizations:

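1. If $n$ is not prime, it can be written as a product $n = p \, q$ with $1 < p \leq q$. We cannot have both $p$ and $q$ larger than $\sqrt{n}$ (otherwise $p \, q > n$), so $p \leq \sqrt{n}$, and we only need to test whether $n$ is divisible by some $p \leq \sqrt{n}$. For example, when testing 25 for primality, we would check whether 2, 3 or 5 are divisors (4 can be skipped since it is even). Similarly, any composite number between 9 and 24 has a divisor not larger than $\sqrt{24} < 5$, so it is a multiple of 2 or 3, and any composite number between 4 and 8 has a divisor not larger than $\sqrt{8} < 3$, so it is even.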
2. Any integer $p \geq 5$ can be written as a multiple of 6 plus an integer $i$ between -1 and 4: $p = 6k + i$, with $k \geq 1$ and $i \in \lbrace -1, 0, 1, 2, 3, 4 \rbrace$. If $i \in \lbrace 0, 3 \rbrace$ then $p$ is a multiple of 3, and if $i \in \lbrace 2, 4 \rbrace$ then $p$ is a multiple of 2. If we have already made sure that $n$ is not a multiple of 2 or 3, then none of its divisors is a multiple of 2 or 3 either, so we only need to check whether $n$ is divisible by $p = 6k \pm 1$.
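To make the second point concrete, the four excluded residues factor explicitly for $k \geq 1$:

$$6k = 2 \cdot 3k, \qquad 6k + 2 = 2 \, (3k + 1), \qquad 6k + 3 = 3 \, (2k + 1), \qquad 6k + 4 = 2 \, (3k + 2),$$

so only $6k - 1$ and $6k + 1$ remain as candidate divisors. For example, with $k = 2$ the candidates are 11 and 13, while 12, 14, 15 and 16 are multiples of 2 or 3.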

To summarize, we need to:

1. deal first with $n \leq 3$
2. check if $n > 3$ is a multiple of 2 or 3.
3. check if $p$ divides $n$ for $p = 6k \pm 1$ with $k \geq 1$ and $p \leq \sqrt{n}$. Note that we start here with $p=5$.
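For illustration, the divisors tested by these three steps can be listed directly. trial_divisors is a hypothetical helper written for this sketch, with math.isqrt standing in for int(np.floor(np.sqrt(n))); unlike the implementation that follows, it never lists a $6k + 1$ candidate larger than $\sqrt{n}$.

```python
from math import isqrt


def trial_divisors(n: int) -> list:
    """Candidate divisors tested for n > 3: 2 and 3, then 6k - 1 and 6k + 1 up to sqrt(n)."""
    divisors = [2, 3]  # step 2: multiples of 2 or 3
    sqrt_n = isqrt(n)
    p = 5  # step 3 starts at 6k - 1 with k = 1
    while p <= sqrt_n:
        divisors.append(p)  # 6k - 1
        if p + 2 <= sqrt_n:
            divisors.append(p + 2)  # 6k + 1
        p += 6
    return divisors


print(trial_divisors(25))          # [2, 3, 5]
print(trial_divisors(400))         # [2, 3, 5, 7, 11, 13, 17, 19]
print(len(trial_divisors(10**6)))  # 334
```

For $n = 10^6$, this amounts to 334 candidates instead of the 999 integers from 2 to 1000.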

Here is a Python implementation of this test [here is the reference]:

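def is_prime(n: int) -> bool:
    # step 1: 0 and 1 are not prime, 2 and 3 are
    if n <= 3:
        return n > 1
    # step 2: rule out multiples of 2 or 3
    if (np.mod(n, 2) == 0) or (np.mod(n, 3) == 0):
        return False
    # step 3: test p = 6k - 1 and p + 2 = 6k + 1, for p up to sqrt(n)
    sqrt_n = int(np.floor(np.sqrt(n)))
    p = 5
    while p <= sqrt_n:
        if (np.mod(n, p) == 0) or (np.mod(n, p + 2) == 0):
            return False
        p += 6
    return True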

Let's check that it is correct against the list of prime numbers up to N=1000. We use the primesieve package to compute a reference array of primes not larger than N:

N = 1_000
prime_indices_ref = np.array(primesieve.primes(N))
prime_indices_ref[-10:]
array([937, 941, 947, 953, 967, 971, 977, 983, 991, 997], dtype=uint64)


Then we also compute this list of primes using is_prime() and check it is equal to the above reference array:

is_prime_vec = list(map(is_prime, range(N + 1)))
prime_indices = np.where(is_prime_vec)[0]
np.testing.assert_array_equal(prime_indices, prime_indices_ref)
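If the primesieve package is not available, an equivalent reference list can be built with a plain sieve of Eratosthenes. This is a swapped-in technique, not what the post uses, and primes_up_to is a hypothetical helper.

```python
from math import isqrt


def primes_up_to(n: int) -> list:
    """Return all primes p <= n with a sieve of Eratosthenes."""
    if n < 2:
        return []
    flags = bytearray([1]) * (n + 1)  # flags[i] == 1 means "i is still a prime candidate"
    flags[0] = flags[1] = 0
    for p in range(2, isqrt(n) + 1):
        if flags[p]:
            # Cross out the multiples of p, starting at p * p:
            # smaller multiples have a smaller prime factor and are already crossed out.
            flags[p * p :: p] = bytearray(len(range(p * p, n + 1, p)))
    return [i for i, flag in enumerate(flags) if flag]


N = 1_000
print(primes_up_to(N)[-10:])  # [937, 941, 947, 953, 967, 971, 977, 983, 991, 997]
```

The resulting list could then replace the reference array in the check, e.g. np.testing.assert_array_equal(prime_indices, np.array(primes_up_to(N))).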

The is_prime function seems to work fine, but let's try to accelerate it with Numba!

Numba

We just add the magic jit decorator to the function and perform the compilation by calling it once:

@jit(nopython=True, fastmath=True)
def is_prime_numba(n: int) -> bool:
    if n <= 3:
        return n > 1
    if (np.mod(n, 2) == 0) or (np.mod(n, 3) == 0):
        return False
    sqrt_n = int(np.floor(np.sqrt(n)))
    p = 5
    while p <= sqrt_n:
        if (np.mod(n, p) == 0) or (np.mod(n, p + 2) == 0):
            return False
        p += 6
    return True

is_prime_numba(1234)
False


To measure the performance improvement, we create a random number generator and, for each exponent n from 3 to 12, draw an array of SIZE = 1000 random integers between $10^n$ and $10^{n+1}$. We then time is_prime and is_prime_numba on each of these arrays.

SD = 124
rng = np.random.default_rng(seed=SD)
SIZE = 1_000

out = perfplot.bench(
    setup=lambda n: rng.integers(
        np.power(10, n),
        np.power(10, n + 1),
        SIZE,
        dtype=int,
        endpoint=True,
    ),
    kernels=[
        lambda numbers: list(map(is_prime, numbers)),
        lambda numbers: list(map(is_prime_numba, numbers)),
    ],
    labels=["is_prime", "is_prime_numba"],
    n_range=range(3, 13),
)
df = pd.DataFrame(data=out.timings_s, columns=out.n_range, index=out.labels).T
ax = df.plot(figsize=(12, 12), logy=True)
_ = ax.set(
    title="Sequential acceleration with Numba",
    xlabel="log10(n)",
    ylabel=f"Runtime[s] (SIZE = {SIZE})",
)

OK, this is much faster with Numba. We are now ready to parallelize the loop over the integer array.

Parallel loop with Joblib

We are going to use joblib with the default loky backend. Loky is a cross-platform and cross-version implementation of the ProcessPoolExecutor class of concurrent.futures. One of its main features is [from Loky's GitHub repository]:

No need for if __name__ == "__main__": in scripts: thanks to the use of cloudpickle to call functions defined in the main module, it is not required to protect the code calling parallel functions under Windows.

This might be useful when writing a cross-platform library. We won't go into much detail about how new processes are created by multiprocessing on Linux and Windows systems, but here is, for example, a post describing this issue, written by Aquiles Carattino on the pythonforthelab website.

The code in this post works with both the loky and multiprocessing backends on Linux and Windows, but only because it runs in a JupyterLab notebook. In a Python script on Windows, the multiprocessing backend would require the if __name__ == "__main__": guard to avoid recursive spawning of subprocesses.
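The guard can be illustrated with the standard library's concurrent.futures.ProcessPoolExecutor, used here as a swapped-in stand-in for joblib (this is not the post's code). The helpers is_prime_py, a pure-Python copy of is_prime using math.isqrt, and is_prime_list_par are hypothetical; the chunksize argument of Executor.map plays a role similar to joblib's batch_size.

```python
from concurrent.futures import ProcessPoolExecutor
from math import isqrt


def is_prime_py(n: int) -> bool:
    """Pure-Python trial division with the same 6k +/- 1 scheme as is_prime()."""
    if n <= 3:
        return n > 1
    if n % 2 == 0 or n % 3 == 0:
        return False
    sqrt_n = isqrt(n)
    p = 5
    while p <= sqrt_n:
        if n % p == 0 or n % (p + 2) == 0:
            return False
        p += 6
    return True


def is_prime_list_par(numbers, max_workers=4, chunksize=100):
    # Numbers are sent to the workers in chunks of `chunksize`; results come back in input order.
    with ProcessPoolExecutor(max_workers=max_workers) as executor:
        return list(executor.map(is_prime_py, numbers, chunksize=chunksize))


if __name__ == "__main__":
    # Under the "spawn" start method (the default on Windows), each child process
    # re-imports this module, so the pool must only be started from the parent.
    print(is_prime_list_par(list(range(20)), max_workers=2, chunksize=5))
```

Saved as a script, this runs on both Linux and Windows. Moving the print call out of the guarded block would make each spawned child re-run the pool creation while importing the module, which multiprocessing rejects with a RuntimeError.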

Also, note that it is possible to use the dask backend to run on a distributed cluster.

Let's start by generating an array of rather large integers:

LOW = int(1e12)
HIGH = int(1e13)
numbers = rng.integers(LOW, HIGH, SIZE, dtype=int, endpoint=True)
numbers[:5]
array([4229485319711, 6682422302023, 6589525053365, 7499770743267,
3694428446242])


We define a sequential function that applies is_prime_numba to the whole array and returns a pandas DataFrame indexed and sorted by number:

def is_prime_array_seq(numbers):
    results = list(map(is_prime_numba, numbers))
    res_df_seq = pd.DataFrame(
        list(zip(numbers, results)), columns=["number", "is_prime"]
    ).set_index("number")
    res_df_seq.sort_index(inplace=True)
    return res_df_seq

res_df_seq = is_prime_array_seq(numbers)
res_df_seq.head(3)
is_prime
number
1010118466699 True
1010316874298 False
1011838077604 False

Now the parallel version is_prime_array_par:

def is_prime_array_par(numbers, n_jobs=8, batch_size=100, backend="loky"):
    results = Parallel(
        n_jobs=n_jobs, batch_size=batch_size, backend=backend, verbose=0
    )(delayed(is_prime_numba)(n) for n in numbers)
    res_df_par = pd.DataFrame(
        list(zip(numbers, results)), columns=["number", "is_prime"]
    ).set_index("number")
    res_df_par.sort_index(inplace=True)
    return res_df_par

res_df_par = is_prime_array_par(numbers)
res_df_par.head(3)
is_prime
number
1010118466699 True
1010316874298 False
1011838077604 False
pd.testing.assert_frame_equal(res_df_seq, res_df_par)

Let's compare the execution times of the sequential and parallel versions.

out = perfplot.bench(
    setup=lambda n: rng.integers(
        np.power(10, n),
        np.power(10, n + 1),
        SIZE,
        dtype=int,
        endpoint=True,
    ),
    kernels=[
        lambda numbers: is_prime_array_seq(numbers),
        lambda numbers: is_prime_array_par(numbers, n_jobs=1),
        lambda numbers: is_prime_array_par(numbers, n_jobs=2),
        lambda numbers: is_prime_array_par(numbers, n_jobs=4),
        lambda numbers: is_prime_array_par(numbers, n_jobs=8),
    ],
    labels=["seq", "par_1", "par_2", "par_4", "par_8"],
    n_range=range(12, 18),
)
df = pd.DataFrame(data=out.timings_s, columns=out.n_range, index=out.labels).T
ax = df.plot(figsize=(12, 12), logy=True)
_ = ax.set(
    title="Parallel acceleration with Joblib",
    xlabel="log10(n)",
    ylabel=f"Runtime[s] (SIZE = {SIZE})",
)

We can observe that the overhead of the parallelization is really significant: in the present case, it is only when $n \geq 10^{15}$ that the parallel version is faster than the sequential one, because each worker then has a heavier computational load. Also, n_jobs=4 is faster than n_jobs=8, probably because some other jobs were running on the laptop; another possible factor is that the i7-7700HQ has only 4 physical cores, so 8 CPU-bound processes have to share them through hyper-threading.
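The growth of the per-worker load can be quantified by counting the while-loop iterations of the trial division in the worst case, i.e. when no divisor is found (which happens when n is prime). loop_iterations is a hypothetical helper for this sketch, with math.isqrt standing in for int(np.floor(np.sqrt(n))), and powers of 10 are used as stand-ins for numbers of that magnitude.

```python
from math import isqrt


def loop_iterations(n: int) -> int:
    """Worst-case number of while-loop iterations of the trial division.

    p takes the values 5, 11, 17, ..., i.e. 5 + 6j, as long as p <= isqrt(n).
    """
    sqrt_n = isqrt(n)
    if sqrt_n < 5:
        return 0
    return (sqrt_n - 5) // 6 + 1


# Each iteration performs two modulo operations (by p and by p + 2).
for exponent in (12, 15, 17):
    iterations = loop_iterations(10**exponent)
    print(f"numbers around 1e{exponent}: up to {iterations:,} iterations, "
          f"{2 * iterations:,} modulo operations")
```

The work per number grows like $\sqrt{n}$: going from numbers around $10^{12}$ to numbers around $10^{17}$ multiplies the worst-case work per number by about $\sqrt{10^5} \approx 316$, while the cost of sending a 64-bit integer to a worker and getting a boolean back does not depend on the magnitude of the number.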

Conclusion

Thanks to Joblib with the loky backend, it is fairly easy to run an efficient embarrassingly parallel loop in Python, and the same code works on both Linux and Windows systems. However, one should make sure that the task distributed to each worker is large enough for the parallelization overhead [starting the worker processes, and pickling the tasks and results exchanged with them] to be negligible compared with the worker's computational load.