Friday, August 14, 2015

Lists, Comprehensions, and Generators

In A Student’s Guide to Python for Physical Modeling, we emphasized NumPy arrays and paid less attention to Python lists. The reason is simple: In most scientific computing applications, NumPy arrays store data more efficiently and speed up mathematical calculations, sometimes a thousandfold.

However, there are some applications where a Python list is the better choice. There are also times when the choice between a list and an array has little or no effect on performance. In such cases a list can make your code easier to read and understand, and that is always a good thing.

In this post, I will describe Python lists and explain a special Python construct for creating lists called a list comprehension. I will also describe a similar construct called a generator expression.


A list is an ordered collection of items. You may have made a “To Do” list this morning or a grocery list for a recent trip to the store. In computer science, a list is a data structure that supports a few basic methods like insert, remove, append, and find. You probably used several of these operations with your own list. Perhaps you penciled in a new task later in the day (append), then crossed tasks off the list as you completed them (remove).

An array is a rigid data structure that stores a fixed number of elements of the same type (all integers, all floats, all eight-character strings, etc.). If operations like insert, remove, or append are important parts of a computational task, a more flexible data structure like a list may be appropriate. The type of data may also suggest that a Python list is a better choice than a NumPy array. For instance, how would you initialize an array to store a grocery list? Furthermore, if the number of items to be stored is not known at the outset, it may be easier to store the data in a list and convert it to an array later. (Perhaps you are taking data at regular intervals but do not know how long an experiment will run.) Finally, if you are not worried about performance and scaling, a Python list might be a simpler option than a NumPy array. If you just need the first 20 perfect cubes, do you really want to import all of NumPy?

Let’s use the example of a grocery list to create and transform a simple list. A more useful example in scientific computing might be managing a collection of data files to process into stunning figures for your latest report, but the principles are the same.

To create a Python list, just enclose a comma-separated list of elements inside square brackets.

groceries = ['milk', 'eggs', 'orange juice']

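Some functions and methods also return a list. For example, I can create a similar list (with apples added) from a string using the split method. Splitting on a comma followed by a space keeps stray spaces out of the items:

groceries = "milk, eggs, apples, orange juice".split(', ')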

To find an item in a list, use the index method. It returns the index of the first occurrence of the item you request if it is in the list, and it raises a ValueError if it is not.
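Here is a short sketch of index in action. It uses the three-item list from above. A try/except block is one way to handle an item that might be missing. The helper find_item is hypothetical and exists only for illustration.

```python
groceries = ['milk', 'eggs', 'orange juice']

print(groceries.index('eggs'))          # Position of 'eggs', counting from 0.

def find_item(items, item):
    """Hypothetical helper: return the index of item, or None if it is absent."""
    try:
        return items.index(item)
    except ValueError:                  # index raises ValueError for a missing item.
        return None

print(find_item(groceries, 'bread'))    # 'bread' is not on the list yet.
```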


It looks like I forgot the bread! I can add it to the list using the append method:
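The example below applies append to the three-item list from the start of this section. The list is redefined so that the snippet runs on its own. append adds its argument as a single new item at the end of the list.

```python
groceries = ['milk', 'eggs', 'orange juice']
groceries.append('bread')      # Add 'bread' as one new item at the end.
print(groceries)               # ['milk', 'eggs', 'orange juice', 'bread']
```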


Later, I see some orange juice in the back of the refrigerator (not yet expired …), so I will delete that item from the list using the remove method:

groceries.remove('orange juice')

Since I am headed to ABC Grocery, where everything is organized alphabetically, I will sort the list and review it:
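The next sketch is self-contained and covers the remove and sort steps. It assumes the list holds the three original items plus bread.

```python
groceries = ['milk', 'eggs', 'orange juice', 'bread']
groceries.remove('orange juice')   # Remove the first matching item.
groceries.sort()                   # Sort in place; strings sort alphabetically.
print(groceries)                   # ['bread', 'eggs', 'milk']
```

The sort method modifies the list in place and returns None. The built-in function sorted(groceries) returns a new sorted list and leaves the original alone.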


One more useful operation is joining lists, or concatenation. In Python, the addition operator for lists is defined to join two lists. The extend method of a list will also join two lists. Calling append with a list as its argument will not add each element to the original list. It will make the argument list the final element of the calling list. I.e., you will have a list within a list, not a concatenation of the two lists.

If I had two separate grocery lists, I could join them into a single list in a variety of ways:

old_list = ['bananas', 'coffee']
new_list = groceries + old_list         # Addition creates a new list.
groceries += old_list                   # In-place addition extends original list.
old_list.extend(groceries)              # extend also extends original list.

bad_list = ['bananas', 'coffee']
bad_list.append(groceries)              # append does NOT concatenate lists.

After this set of commands, groceries and new_list contain the same elements. old_list contains 'bananas' and 'coffee' twice. The commands first add old_list to the end of groceries, and then they add the enlarged groceries list to the end of the original old_list. As you can see, bad_list did not merge the two lists properly: its last element is the entire groceries list.
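The sketch below replays those commands and checks each claim. It assumes groceries starts as ['bread', 'eggs', 'milk']; any starting list would give the same pattern.

```python
groceries = ['bread', 'eggs', 'milk']      # Assumed starting list.

old_list = ['bananas', 'coffee']
new_list = groceries + old_list            # New list; groceries is unchanged here.
groceries += old_list                      # groceries now ends with bananas, coffee.
old_list.extend(groceries)                 # old_list gains all five groceries.

bad_list = ['bananas', 'coffee']
bad_list.append(groceries)                 # One new element: the groceries list itself.

print(new_list == groceries)               # True
print(old_list.count('bananas'))           # 2
print(len(bad_list), bad_list[-1] is groceries)   # 3 True
```

append stores a reference to groceries, not a copy. Any later change to groceries therefore also shows up inside bad_list.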

In case you are skeptical of the usefulness of Python lists in scientific endeavors, here is a function that creates a list of the Fibonacci numbers from the 0th through the Nth (N + 1 numbers in all, starting with 1, 1).

def fibonacci(N):
    if N == 0: return [1]       # Handle unusual request for 0th number.
    fib = [1, 1]                # Create a list for all other values of N.
    for k in range(1,N):
        # Next Fibonacci number is the sum of the previous two.
        fib.append(fib[-1] + fib[-2])
    return fib

If you are still skeptical of the possible utility of a Python list, try writing the same function using NumPy arrays and using it to compute fibonacci(100). I’ve included a solution at the end of the post.
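The list version handles fibonacci(100) without trouble, because Python integers have arbitrary precision. The function below is the one defined above. The checks use the standard value of the 100th Fibonacci number. With the indexing used here (starting from 1, 1), that number sits at position 99.

```python
def fibonacci(N):
    if N == 0: return [1]       # Handle unusual request for 0th number.
    fib = [1, 1]                # Create a list for all other values of N.
    for k in range(1, N):
        fib.append(fib[-1] + fib[-2])
    return fib

fib = fibonacci(100)
print(len(fib))                 # 101 values: indices 0 through 100.
print(fib[100])                 # Exact, even though it exceeds 2**63.
```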

The command inside the loop could also have been written using list addition. Either of the following commands will work:

fib += [fib[-1] + fib[-2]]
fib = fib + [fib[-1] + fib[-2]]

The second approach is less efficient because it makes a copy of the list every time a new element is added, but it is syntactically correct. Be aware that you can only use list addition to join two lists — not a list and some other object you would like to put inside it. The following alternatives would result in errors:

fib += fib[-1] + fib[-2]
fib += (fib[-1] + fib[-2])
fib = fib + (fib[-1] + fib[-2])
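A quick way to confirm the failure is to catch the exception. The sketch also shows that the parentheses in the second and third lines do not create a tuple. A one-element tuple needs a trailing comma.

```python
fib = [1, 1]
try:
    fib += fib[-1] + fib[-2]          # Right-hand side is the int 2, not a list.
except TypeError as err:
    print('TypeError:', err)          # fib is left unchanged.

fib += [fib[-1] + fib[-2]]            # Wrap the new value in a list: [1, 1, 2].
fib += (fib[-1] + fib[-2],)           # Trailing comma makes a tuple: [1, 1, 2, 3].
print(fib)
```

The in-place += accepts any iterable on the right, so the tuple works there. The form fib = fib + (x,) still raises a TypeError, because + only joins a list to another list.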

List Comprehensions

Sometimes a list is not a collection of random elements; it has a logical structure instead. For instance, suppose you want to find the sum of the first 20 perfect cubes. You could create an array or a list of cubes, then add them up. The familiar procedure using a NumPy array is

import numpy as np
cubes = np.arange(1,21)**3

This does not require too much typing, and the purpose of the code is fairly clear. However, compare it with the following code:

cubes = [n**3 for n in range(1,21)]

This is an example of a list comprehension: a Python expression inside of the square brackets that denote a list. The statement is similar to the notation used in mathematics to define a set. It also clearly describes what the list contains. Note the similarity of the list comprehension to the following loop, which creates the same list:

cubes = []                  # Initialize empty list.
for n in range(1,21):
    cubes.append(n**3)      # Append perfect cubes.

The list comprehension effectively compresses all of this into a single Python statement.
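To connect this back to the original task, the sketch below builds the list both ways, confirms that the two versions match, and adds up the cubes.

```python
cubes_comp = [n**3 for n in range(1, 21)]   # List comprehension.

cubes_loop = []                             # Equivalent loop.
for n in range(1, 21):
    cubes_loop.append(n**3)

print(cubes_comp == cubes_loop)             # True: same values, same order.
print(sum(cubes_comp))                      # 44100, the sum of the first 20 cubes.
```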

A list comprehension defines a new list from another collection of objects via an expression like “for n in ...” Rather than using range, you can build one list from another:

poly = [(x+1)*(x-1) for x in cubes]

You can also apply conditional statements in a list comprehension:

even_squares = [n**2 for n in range(1,51) if n%2 == 0]
odd_squares  = [n**2 for n in range(1,51) if n%2 == 1]

You can cram quite a lot of code into a list comprehension, but it is not always advisable:

pythagoras = [(a,b,c)   for a in range(1,31) for b in range(a,31) \
                        for c in range(1,31) if a**2 + b**2 == c**2]

Despite the length and complexity of this single expression, its meaning is still fairly clear.
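One way to read a comprehension with several for clauses is as nested loops, written in the same left-to-right order. The sketch below rewrites the pythagoras comprehension that way and checks that both versions agree.

```python
pythagoras = [(a, b, c) for a in range(1, 31) for b in range(a, 31)
                        for c in range(1, 31) if a**2 + b**2 == c**2]

triples = []
for a in range(1, 31):                  # First for clause: outermost loop.
    for b in range(a, 31):              # Second for clause.
        for c in range(1, 31):          # Third for clause: innermost loop.
            if a**2 + b**2 == c**2:     # The if clause filters items.
                triples.append((a, b, c))

print(triples == pythagoras)            # True
print(pythagoras[:3])
```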

Returning to our original task, we can do even better than adding up the first 20 perfect cubes. Using a nested list comprehension, we can make a list of sums of cubes!

sums_of_cubes = [sum([n**3 for n in range(1,N+1)]) for N in range(1,21)]
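As a sanity check, the sketch below compares the nested comprehension with the standard closed form 1**3 + 2**3 + ... + N**3 = (N(N+1)/2)**2. The inner list is only ever passed to sum(), so the sketch also tries a version that omits the inner square brackets. That version uses a generator expression, the construct covered next.

```python
sums_of_cubes = [sum([n**3 for n in range(1, N+1)]) for N in range(1, 21)]

# Same result without building each inner list.
sums_gen = [sum(n**3 for n in range(1, N+1)) for N in range(1, 21)]

# Closed form for the sum of the first N cubes.
closed_form = [(N*(N+1)//2)**2 for N in range(1, 21)]

print(sums_of_cubes == sums_gen == closed_form)   # True
```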


A list comprehension creates a Python list that stores all of the elements in a single data structure. Sometimes this is exactly what you need. Other times, you simply want to iterate over all of the items in a list. If you never need all of the items in the list at once, you can use a generator expression instead. A generator expression looks like a list comprehension, except that you enclose the expression in round parentheses instead of square brackets — (...) instead of [...]. Despite the round parentheses, a generator expression does not create a tuple, and there is no such thing as a “tuple comprehension”.

cube_list = [n**3 for n in range(1,101)]
cube_generator = (n**3 for n in range(1,101))

A generator is simpler than a list. You cannot insert, remove, or append items, nor can you search or sort a generator. A generator knows how to produce the next item in a sequence, and little else. Once it has reached the end of its sequence, it does even less.
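The built-in next function asks a generator for its next item. Once the sequence is used up, next raises StopIteration. A for loop catches this exception internally, which is why the second loop below prints nothing.

```python
cube_generator = (n**3 for n in range(1, 4))

print(next(cube_generator))   # 1
print(next(cube_generator))   # 8
print(next(cube_generator))   # 27
try:
    next(cube_generator)      # The sequence is exhausted.
except StopIteration:
    print('Generator is finished.')
```

If you want a tuple after all, pass a generator expression to tuple(): tuple(n**3 for n in range(1, 4)) gives (1, 8, 27).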

for x in cube_list: print(x)            # Prints numbers stored in list.
for x in cube_list: print(x)            # Prints numbers stored in list again.

for x in cube_generator: print(x)       # Prints numbers provided by generator.
for x in cube_generator: print(x)       # Prints nothing.  Generator is finished.

The advantages of a generator over a list are size and speed. Compare the output of the __sizeof__() method for the following lists and generators. This method returns the size of the object in bytes.

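cube_list = [n**3 for n in range(1,10)]
cube_generator = (n**3 for n in range(1,10))
print(cube_list.__sizeof__(), cube_generator.__sizeof__())

cube_list = [n**3 for n in range(1,10**3)]
cube_generator = (n**3 for n in range(1,10**3))
print(cube_list.__sizeof__(), cube_generator.__sizeof__())

cube_list = [n**3 for n in range(1,10**6)]
cube_generator = (n**3 for n in range(1,10**6))
print(cube_list.__sizeof__(), cube_generator.__sizeof__())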

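When I ran these commands, the list grew from 168 bytes to 9 kB to 8.7 MB, while the generator remained a constant 48 bytes. The exact figures depend on your Python version and platform. A list's __sizeof__ also counts only its internal array of references, not the integer objects themselves, so the true memory footprint of the list is even larger. You may also have noticed a delay while Python created the large list during the last set of commands.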
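The sketch below gives a rough picture of the memory the list's elements occupy. It adds sys.getsizeof for every element. This is only an estimate, because CPython shares some small integer objects.

```python
import sys

cube_list = [n**3 for n in range(1, 10**3)]
cube_generator = (n**3 for n in range(1, 10**3))

container = cube_list.__sizeof__()                    # The list's array of references only.
contents = sum(sys.getsizeof(x) for x in cube_list)   # The int objects it refers to.
print(container, contents, cube_generator.__sizeof__())
```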

I generally prefer a generator when I iterate over a large sequence of items once — especially if the program might exit the loop before reaching the end of the sequence.
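The example below shows an early exit. It finds the first perfect cube larger than one million. The generator computes cubes only until the condition is met. A list built from the same range would try to store a billion numbers before the search even began.

```python
# Nothing is computed when the generator is created.
cubes = (n**3 for n in range(1, 10**9))

# next() pulls items until the condition is met, then stops.
first_big = next(c for c in cubes if c > 10**6)
print(first_big)   # 1030301, i.e., 101**3
```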


NumPy arrays are often the most efficient data structure for numerical work in Python. However, there are some tasks for which a Python list is a better choice — often when organizing data rather than processing data. Python offers a compact syntax for creating lists called a list comprehension. A generator expression is similar, but creates an object that can produce a sequence without storing all of its elements. A generator is often a better choice than a list or an array when iterating over a large sequence of items.

NumPy version of fibonacci(N)

Here is a version of the fibonacci(N) function above that uses NumPy arrays.

import numpy as np

def fibonacci(N):
    if N == 0: return np.array([1])       # Handle unusual request for 0th number.
    fib = np.zeros(N+1, dtype=np.int64)   # Initialize array for all other values of N.
    fib[0], fib[1] = 1, 1
    for k in range(2,N+1):
        # Next Fibonacci number is the sum of the previous two.
        fib[k] = fib[k-1] + fib[k-2]
    return fib

Perhaps you came up with a more elegant solution. I find this version more difficult to code and more confusing to read. Plus, using a NumPy array forces a compromise: Either use floating point numbers and lose significant digits for N > 78, or use 64-bit integers and overflow for N > 91. (NumPy integers wrap around silently or with only a warning, so the overflow produces wrong values rather than an error.) In either case, you cannot generate the 100th Fibonacci number!
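One workaround, offered here as an alternative rather than as the approach above, is a NumPy array with dtype=object. Such an array stores ordinary Python integers, which never overflow. The price is speed: each element is a full Python object, so NumPy cannot use its fast fixed-size arithmetic. The name fibonacci_object is hypothetical.

```python
import numpy as np

def fibonacci_object(N):
    """NumPy version that stores Python integers (dtype=object) to avoid overflow."""
    if N == 0:
        return np.array([1], dtype=object)
    fib = np.zeros(N+1, dtype=object)   # Elements start as the Python int 0.
    fib[0], fib[1] = 1, 1
    for k in range(2, N+1):
        fib[k] = fib[k-1] + fib[k-2]    # Python int addition: exact, no overflow.
    return fib

print(fibonacci_object(100)[100])
```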

Wednesday, August 5, 2015

Function Arguments: *args and **kwargs

In Python, functions can have positional arguments and named arguments. In this post, I will describe both types and explain how to use special syntax to simplify repetitive function calls with nearly the same arguments. This extends the discussion in section 5.1.3 of A Student’s Guide to Python for Physical Modeling.

First, let’s look at np.savetxt, which has a straightforward declaration:

$ import numpy as np
$ import matplotlib.pyplot as plt
$ from mpl_toolkits.mplot3d import Axes3D

$ np.savetxt?
Signature: np.savetxt(  fname, X, fmt='%.18e', delimiter=' ', newline='\n',
                        header='', footer='', comments='# ')
Docstring: Save an array to a text file.

We see the function has two required arguments, followed by several optional arguments with default values. Next, let’s look at something more exotic:

$ Axes3D.plot_surface?
Signature: Axes3D.plot_surface(X, Y, Z, *args, **kwargs)
Docstring: Create a surface plot.

The first three arguments seem obvious enough: These are the arrays that specify the points on the surface. The last two — *args and **kwargs — look strange. Let’s examine one more function:

$ plt.plot?
Signature: plt.plot(*args, **kwargs)
Docstring: Plot lines and/or markers to the :class:`~matplotlib.axes.Axes`.

*args and **kwargs are the only arguments for the familiar plotting function! They are the focus of this post.

Positional Arguments

In a Python function, positional arguments are Python expressions assigned to function variables based on their position in the function call.

Suppose I create a surface plot of topographic data with the command

ax = plt.figure().add_subplot(1, 1, 1, projection='3d')
ax.plot_surface(latitude, longitude, elevation)

Python evaluates ax.plot_surface with the substitutions

X = latitude
Y = longitude
Z = elevation

I.e., the substitutions are based on the positions of the arguments. I would get a different (meaningless) surface if I shuffled the order around:

ax.plot_surface(elevation, longitude, latitude)


The *args argument in a function definition allows the function to process an unspecified number of positional arguments. Let’s look at a simple example:

def get_stats(*args):
    from numpy import mean, std
    return mean(args), std(args)

This function computes the descriptive statistics (mean and standard deviation) of any sequence of values passed to it. Try the following command:

get_stats(1, 2, 3, 4, 5)

You can type any number of arguments when calling the function. You can also pass it a single sequence of values (an array, a tuple, or a list). In that case, *args holds the sequence as its only element, and NumPy's mean and std treat the whole collection as one set of numbers.
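The helper show_args below is hypothetical. It simply returns the tuple that *args collects, which shows exactly what the function receives in each case.

```python
def show_args(*args):
    """Hypothetical helper: return the tuple that *args collects."""
    return args

print(show_args(1, 2, 3))      # (1, 2, 3): three positional arguments.
print(show_args([1, 2, 3]))    # ([1, 2, 3],): one argument, which is a list.
print(show_args(*[1, 2, 3]))   # (1, 2, 3): the list is unpacked first.
```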

This ability to process any number of arguments is what makes it possible to call plt.plot in a variety of ways. All of these commands are valid:

t = np.linspace(-1, 1, 101)
plt.plot(t, t**2 - 1)
plt.plot(t, t**3 - t, 'k--')
plt.plot(t, t**4 - t**2, t, t**5 - t**3 + t)

How can one function process so many different kinds of input, including mixtures of variable names, expressions, and strings? The plot function has several subroutines that determine exactly what is in the series of arguments you supply and what to do with those objects. This can make a function very flexible, but it is also likely to be complex — both to write and to interpret.

You can use the *args notation to “unpack” a sequence into a series of positional arguments for any function. For example, suppose the three topographic data arrays mentioned earlier had been packaged as a single tuple:

data = (latitude, longitude, elevation)

The surface plot function does not know what to do with this tuple, but I can use the *args notation to assign the three arrays to X, Y, and Z.

ax.plot_surface(data)       # Raises an exception.
ax.plot_surface(*data)      # Creates surface plot.

The *data notation instructs Python to assign the items in data to the positional arguments of ax.plot_surface.
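The same behavior can be seen without Matplotlib. The function plot_surface_stub below is a hypothetical stand-in with three required positional arguments. The strings in data are placeholders for the arrays.

```python
def plot_surface_stub(X, Y, Z):
    """Hypothetical stand-in for a function with three required positional arguments."""
    return X, Y, Z

data = ('latitude', 'longitude', 'elevation')   # Placeholders for the data arrays.

try:
    plot_surface_stub(data)          # One argument where three are required.
except TypeError as err:
    print('TypeError:', err)

print(plot_surface_stub(*data))      # X, Y, Z receive the three items in order.
```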

This method of passing positional arguments to functions can be convenient when you wish to automate calculations using various combinations of input parameters or to ensure that several functions use the same data:

data = (x, y, z)

If I want to perform the same analysis on a different set of data later, I only need to change the data variable.
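The sketch below passes one data tuple to several functions. Both analysis functions are hypothetical, and the small lists stand in for real data arrays. To analyze a different data set, only the line that defines data needs to change.

```python
def mean_of_all(*arrays):
    """Hypothetical analysis step: mean of every value in every array."""
    values = [v for a in arrays for v in a]
    return sum(values) / len(values)

def count_points(x, y, z):
    """Hypothetical analysis step: number of (x, y, z) points."""
    return len(list(zip(x, y, z)))

data = ([0, 1, 2], [3, 4, 5], [6, 7, 8])   # Change only this line to switch data sets.

print(mean_of_all(*data))    # 4.0
print(count_points(*data))   # 3
```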

Named Arguments

In a Python function, named arguments are Python expressions whose value in the function is specified by a keyword-value pair.

For example, this function call from the Illuminating Surface Plots post uses named arguments to specify several options:

ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0, antialiased=False,
                facecolors=illuminated_surface)

Functions like plt.plot and Axes3D.plot_surface, whose definitions include **kwargs, can accept any number of keyword arguments.
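Inside the function, the keyword arguments arrive as a dictionary, a data structure introduced below. The helper show_kwargs is hypothetical. It simply returns what **kwargs collects.

```python
def show_kwargs(**kwargs):
    """Hypothetical helper: return the dictionary that **kwargs collects."""
    return kwargs

print(show_kwargs(rstride=1, cstride=1, linewidth=0))
# {'rstride': 1, 'cstride': 1, 'linewidth': 0}
```

Since Python 3.7, dictionaries keep insertion order, so the keywords appear in the order they were passed.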


Similar to the *args notation, you can use the **kwargs notation to pass a collection of named arguments to a function. To do this, you must package the keyword-value pairs in a dictionary.

A dictionary is a Python data structure that associates an immutable object called a key with a value, which can be mutable or immutable. A future post will discuss dictionaries in more detail. For this post, only the syntax for creating a dictionary is important. Enclose the contents of a dictionary between curly braces: { ... }. Each entry of the dictionary is a key, followed by a colon, followed by a value:

definitions = { 'cat':"n. A feline.", 'dog':"n. A canine."}


This command creates a dictionary. To access one of its members, put its key in square brackets: definitions['cat'] returns "n. A feline."
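The sketch below shows both ideas: looking up a value by key, and using ** to turn the same dictionary into named arguments. The function describe is hypothetical. For the unpacking to work, its parameter names must match the dictionary's keys.

```python
definitions = {'cat': "n. A feline.", 'dog': "n. A canine."}
print(definitions['cat'])          # n. A feline.

def describe(cat, dog):
    """Hypothetical function whose parameter names match the dictionary keys."""
    return 'cat: ' + cat + '; dog: ' + dog

print(describe(**definitions))     # Same as describe(cat=..., dog=...).
```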

In Illuminating Surface Plots, I used the same set of plotting options many times. This led to a lot of typing and, for the commands in the script, a lot of retyping every time I decided to change one of these options. A more efficient approach is to define once and reuse often. I can put all of the data arrays into a tuple and most of the plotting options into a dictionary:

data = (X, Y, Z)
plot_options = {'rstride':1, 'cstride':1, 'linewidth':0, 'antialiased':False}

Most surface plot commands were identical except for the value of the facecolors option. I could create two surface plots with different values of this argument as follows:

ax1.plot_surface(*data, facecolors=green_surface, **plot_options)
ax2.plot_surface(*data, facecolors=illuminated_surface, **plot_options)

This is easier to type and ensures that both plots use the same data set and plotting options.
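The pattern can be tried without plotting anything. plot_surface_stub below is a hypothetical stand-in that returns the options it receives, and the strings in data are placeholders. The sketch also shows one pitfall: if the same option is given both by name and inside the dictionary, Python raises a TypeError.

```python
def plot_surface_stub(X, Y, Z, **kwargs):
    """Hypothetical stand-in that returns the keyword options it receives."""
    return kwargs

data = ('X', 'Y', 'Z')                 # Placeholders for the data arrays.
plot_options = {'rstride': 1, 'cstride': 1, 'linewidth': 0, 'antialiased': False}

opts = plot_surface_stub(*data, facecolors='green', **plot_options)
print(opts['facecolors'], opts['rstride'])

try:
    plot_surface_stub(*data, rstride=2, **plot_options)   # rstride given twice.
except TypeError as err:
    print('TypeError:', err)

# To override one option, build a new dictionary instead.
opts2 = plot_surface_stub(*data, **{**plot_options, 'rstride': 2})
```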


Python functions accept positional and named arguments. Functions whose definitions include the arguments *args and **kwargs will accept an unspecified number of either type. You can use this notation to unpack a sequence (list, tuple, or array) into a series of positional arguments or a dictionary into a series of named arguments. This provides a convenient method for calling the same function with slightly different inputs and options, or calling different functions with the same inputs and options.

The rules for supplying arguments to functions are as follows:

  1. Positional arguments (if any) must come first.
  2. Named arguments (if any) and/or positional arguments in the form of *args (if any) must come next.
  3. Named arguments in the form of **kwargs (if any) must come last.
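All three rules appear in the single call below. The function describe_call is hypothetical. It reports where each argument ends up.

```python
def describe_call(a, b, *args, **kwargs):
    """Hypothetical function that reports how its arguments were assigned."""
    return a, b, args, kwargs

extra = (3, 4)
options = {'color': 'k'}

# Rule 1: plain positional arguments; rule 2: *extra and a named argument;
# rule 3: **options last.
print(describe_call(1, 2, *extra, linewidth=2, **options))
# (1, 2, (3, 4), {'linewidth': 2, 'color': 'k'})
```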

Saturday, August 1, 2015

Illuminating Surface Plots

Matplotlib provides functions for visualizing three-dimensional data sets. One useful tool is a surface plot. A surface plot is a two-dimensional projection of a three-dimensional object. Much like a sketch artist, Python uses techniques like perspective and shading to give the illusion of a three-dimensional object in space. In this post, I describe how you can control the lighting of a surface plot.

Surface Plots

First, let’s look at some of the options available with the default three-dimensional plotting tools. This script will create a surface plot of a Bessel function. Its ripples will emphasize the effects of lighting later.

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D     # Import 3D plotting tools.
from scipy.special import jn                # Import Bessel function.

# Define grid of points.
points = np.linspace(-10, 10, 51)
X, Y = np.meshgrid(points, points)
R = np.sqrt(X**2 + Y**2)
Z = jn(0,R)

# Create 3D surface plot.
ax = plt.figure().add_subplot(1, 1, 1, projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1)

The default surface plot is a single color with grid lines and some shading. To get rid of the grid lines use the following plot commands instead:

ax = plt.figure().add_subplot(1, 1, 1, projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0, antialiased=False)

We can use a colormap to assign colors to the figure based on the height of the surface:

from matplotlib import cm                   # Import colormaps.
ax = plt.figure().add_subplot(1, 1, 1, projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0, antialiased=False,
                cmap=cm.coolwarm)

Assigning colors according to height may not be what you want, and there is no shading when using a color map. Furthermore, when you specify a single color, you cannot adjust the lighting angle to produce different shading effects. Sometimes, you may want to control these lighting effects, and Matplotlib provides a way.

Turn on the Lights

Tucked away in Matplotlib is an object called LightSource. It allows you to simulate illuminating a surface using a virtual light source placed at a location of your choosing. (LightSource creates an “illuminated intensity map.” You can find the details of the model in the source code.) It does not provide the same control or as many features as the lighting tools of commercial packages like MATLAB or Mathematica, but it is sufficient to produce some nice plots.

To use the object, import it from matplotlib.colors and then apply its shade method to the data set:

# Get lighting object for shading surface plots.
from matplotlib.colors import LightSource

# Get colormaps to use with lighting object.
from matplotlib import cm

# Create an instance of a LightSource and use it to illuminate the surface.
light = LightSource(90, 45)
illuminated_surface = light.shade(Z, cmap=cm.coolwarm)

The two arguments to LightSource are the azimuth and elevation angles of the light source. (0,0) corresponds to a light placed along the x-axis. As the name implies, the elevation is the angle above the xy-plane in degrees. The virtual light source position is then rotated about the vertical axis by the azimuth angle (also in degrees). (Matplotlib's documentation describes the azimuth as measured clockwise from north, so it is worth checking the orientation on a simple test surface. Don’t confuse these parameters with the similarly named parameters specifying the observer’s position!)

The shade method takes two arguments: a two-dimensional array (here, Z) and a colormap (here, cm.coolwarm). LightSource interprets each data point as the height of a surface above a point in the xy-plane. It also assumes these points have the same spacing in the x and y directions. If you are not using Cartesian coordinates and uniform spacing, you may be surprised by the result.

The object returned by light.shade is a NumPy array of RGBA values for each data point. (For each point in the input array, light.shade returns a 4-element array of the Red, Green, Blue, and Alpha value for that point. Alpha controls the transparency of the point.) Other plotting tools can use this data to draw a shaded surface. To use this array instead of the color or cmap options of the surface plotting command, pass the array with a keyword argument:
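The sketch below checks what shade returns without drawing anything. To avoid needing SciPy, it uses np.cos(R) as a stand-in for the Bessel function jn(0, R); this substitution is an assumption made only for this sketch, and any two-dimensional array of heights would do.

```python
import numpy as np
from matplotlib import cm
from matplotlib.colors import LightSource

# Grid of points, as in the script above.
points = np.linspace(-10, 10, 51)
X, Y = np.meshgrid(points, points)
R = np.sqrt(X**2 + Y**2)
Z = np.cos(R)                         # Assumption: stand-in for jn(0, R).

# Light from azimuth 90 degrees and elevation 45 degrees, as in the text.
light = LightSource(90, 45)
illuminated_surface = light.shade(Z, cmap=cm.coolwarm)

print(Z.shape)                        # (51, 51): one height per grid point.
print(illuminated_surface.shape)      # (51, 51, 4): one RGBA color per grid point.
```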

ax = plt.figure().add_subplot(1, 1, 1, projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0, antialiased=False,
                facecolors=illuminated_surface)

If you prefer to shade a surface of uniform color instead of using a color map, or if you have a colored surface that you wish to shade, a LightSource object offers a second method called shade_rgb. You have to pass the function two arguments: an array of RGB values and a data set giving the height of each point.

As an example, let’s transform a white surface so we can see the shading effects independent of any coloring. The RGB values for white are [1.0, 1.0, 1.0]. (Red, green, and blue values are all at maximum.) To create a uniform white surface, we need to create an array with three elements for every element of the data set Z, with each entry set to 1.0. The following code will create the RGB array, shade it, and plot it:

rgb = np.ones((Z.shape[0], Z.shape[1], 3))
illuminated_surface = light.shade_rgb(rgb, Z)
ax = plt.figure().add_subplot(1, 1, 1, projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0, antialiased=False,
                facecolors=illuminated_surface)

To change the color of the shaded surface, we can use NumPy array math. Just make a three-element array with the RGB values of your target color and multiply rgb by this array before shading.

# Create a shaded green surface.
green = np.array([0,1.0,0])
green_surface = light.shade_rgb(rgb * green, Z)
ax = plt.figure().add_subplot(1, 1, 1, projection='3d')
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0, antialiased=False,
                facecolors=green_surface)
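The color trick relies on NumPy broadcasting. Multiplying an (M, N, 3) array by a 3-element array scales the last axis of every grid point. The sketch below checks the shapes at each step. As before, np.cos(R) stands in for jn(0, R); this is an assumption made only so that SciPy is not needed.

```python
import numpy as np
from matplotlib.colors import LightSource

points = np.linspace(-10, 10, 51)
X, Y = np.meshgrid(points, points)
Z = np.cos(np.sqrt(X**2 + Y**2))      # Assumption: stand-in for jn(0, R).

rgb = np.ones((Z.shape[0], Z.shape[1], 3))   # Uniform white surface.
green = np.array([0, 1.0, 0])
green_rgb = rgb * green                      # Broadcast over the last axis.
print(green_rgb.shape)                       # (51, 51, 3)

light = LightSource(90, 45)
green_surface = light.shade_rgb(green_rgb, Z)
print(green_surface.shape)                   # (51, 51, 3)
```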

The figure below illustrates the techniques described here. The same surface is shown with four different color and lighting configurations. The code that produced the figures is also included below.

Shading of 3D surfaces.

# =========================================================================
# Author:   Jesse M. Kinder
# Created:  2015 Jul 27
# Modified: 2015 Jul 31
# -------------------------------------------------------------------------
# Demonstrate shading of surface plots using Matplotlib's LightSource.
# ------------------------------------------------------------------------- 
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Import Bessel function.
from scipy.special import jn

# Import colormaps.
from matplotlib import cm

# Import lighting object for shading surface plots.
from matplotlib.colors import LightSource

# Define grid of points.
points = np.linspace(-10, 10, 101)
X, Y = np.meshgrid(points, points)
R = np.sqrt(X**2 + Y**2)
Z = jn(0,R)

# Create an rgb array for single-color surfaces.
white = np.ones((Z.shape[0], Z.shape[1], 3))
red = white * np.array([1,0,0])
green = white * np.array([0,1,0])
blue = white * np.array([0,0,1])

# Set view parameters for all subplots.
azimuth = 45
altitude = 60

# Create empty figure.
fig = plt.figure(figsize=(18,12))

# -------------------------------------------------------------------------
# Generate first subplot.
# ------------------------------------------------------------------------- 
# Create a light source object for light from
# 0 degrees azimuth, 0 degrees elevation.
light = LightSource(0, 0)

# Generate face colors for a shaded surface using either
# a color map or the uniform rgb color specified above.

illuminated_surface = light.shade_rgb(red, Z)

# Create a subplot with 3d plotting capabilities.
# In older versions of Matplotlib, this command fails if Axes3D was not imported.
ax = fig.add_subplot(2,2,1, projection='3d')
ax.view_init(altitude, azimuth)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0,
                antialiased=False, facecolors=illuminated_surface)

# -------------------------------------------------------------------------
# Repeat the commands above for the other three subplots, but use different
# illumination angles and colors.
# ------------------------------------------------------------------------- 
light = LightSource(90, 0)
illuminated_surface = light.shade_rgb(green, Z)

ax = fig.add_subplot(2,2,2, projection='3d')
ax.view_init(altitude, azimuth)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0,
                antialiased=False, facecolors=illuminated_surface)

# ------------------------------------------------------------------------- 
light = LightSource(90, 45)
illuminated_surface = light.shade_rgb(blue, Z)

ax = fig.add_subplot(2,2,3, projection='3d')
ax.view_init(altitude, azimuth)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0,
                antialiased=False, facecolors=illuminated_surface)

# ------------------------------------------------------------------------- 
light = LightSource(180, 45)
illuminated_surface = light.shade(Z, cmap=cm.coolwarm)

ax = fig.add_subplot(2,2,4, projection='3d')
ax.view_init(altitude, azimuth)
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, linewidth=0,
                antialiased=False, facecolors=illuminated_surface)

# -------------------------------------------------------------------------