Why do many examples use `fig, ax = plt.subplots()` in Matplotlib/pyplot/Python

plt.subplots() is a function that returns a tuple containing a figure and axes object(s). Thus, when using fig, ax = plt.subplots(), you unpack this tuple into the variables fig and ax. Having fig is useful if you want to change figure-level attributes or save the figure as an image file later (e.g. with fig.savefig('yourfilename.png')). You certainly don't have to use the returned figure object, but many people do use it later, so it's common to see. Also, all axes objects (the objects that have plotting methods) have a parent figure object anyway, thus:

fig, ax = plt.subplots()

is more concise than this:

fig = plt.figure()
ax = fig.add_subplot(111)
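To make the figure/axes relationship concrete, here is a minimal sketch. The Agg backend and the in-memory buffer are only my assumptions so that it runs without a display or a file on disk; in a normal script you would just call fig.savefig('yourfilename.png').

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2], [5, 8])

# Every Axes keeps a reference to the Figure it belongs to.
same_figure = ax.figure is fig

# Figure-level settings go through fig rather than ax.
fig.set_size_inches(6, 4)
fig.suptitle("Figure-level title")

# Save the figure; a BytesIO buffer stands in for a file name here.
buffer = io.BytesIO()
fig.savefig(buffer, format="png")
```

So even if you only keep ax, you can still reach the figure through ax.figure.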

fig, ax = plt.subplots() meaning

plt.subplots() is basically a (very nice) shortcut for initializing a figure and subplot axes. See the documentation for matplotlib.pyplot.subplots. In particular,

>>> fig, ax = plt.subplots(1, 1)

is essentially equivalent to

>>> fig = plt.figure()
>>> ax = fig.add_subplot(1, 1, 1)

But plt.subplots() is most useful for constructing several axes at once, for example,

>>> fig, axes = plt.subplots(2, 3)

makes a figure with 2 rows and 3 columns of subplots, essentially equivalent to

>>> fig = plt.figure()
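>>> axes = np.empty((2, 3), dtype=object)  # object array, so it can hold Axes
>>> for i in range(2):
...     for j in range(3):
...         # subplot numbers are 1-based and run row by row
...         axes[i, j] = fig.add_subplot(2, 3, i*3 + j + 1)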

I say "essentially" because plt.subplots() also has some nice features: for example, sharex=True makes all of the subplots share the same x axis (i.e., the same axis limits, scale, etc.). This is my favorite way to initialize a figure because it gives you the figure and all of the axes handles in one line.
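A quick way to see what sharex=True does is to plot data with different x ranges and compare the limits afterwards. This is just a sketch; the data values are made up, and the Agg backend is only there so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt

# Two stacked subplots that share one x axis.
fig, axes = plt.subplots(2, 1, sharex=True)
axes[0].plot([0, 1, 2], [1, 4, 9])    # x runs from 0 to 2
axes[1].plot([0, 5, 10], [3, 2, 1])   # x runs from 0 to 10

# Both subplots report the same x limits, wide enough for both data sets.
limits_top = axes[0].get_xlim()
limits_bottom = axes[1].get_xlim()
```

Without sharex=True, each subplot would autoscale to its own data.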

Understanding of fig, ax, and plt when combining Matplotlib and Pandas

plt.xticks() only works on the "current" ax. You should use ax.set_xticks(), ax.set_xticklabels() and ax.tick_params() instead.

plt.xticks() is a rather old function that is still supported. It mimics similar MATLAB code and dates from a time when people mostly drew a single plot per figure. The newer ax methods are more general and offer more options.

In short: you don't need to call plt directly, you are invited to use the ax functions instead. When calling plt.xticks(), it gets rerouted to the currently active ax (often the last one created).
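Here is a small sketch of the difference, assuming two side-by-side subplots (the names left and right and the tick values are mine, purely for illustration). The ax methods touch only the axes you call them on, while plt.xticks() lands on whichever axes is current.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt

fig, (left, right) = plt.subplots(1, 2)
left.plot([0, 1, 2], [0, 1, 4])
right.plot([0, 1, 2], [4, 1, 0])

# Explicit: only the left subplot is affected.
left.set_xticks([0, 1, 2])
left.set_xticklabels(["a", "b", "c"])
left.tick_params(axis="x", labelrotation=45)

# Implicit: plt.xticks() is rerouted to the current axes,
# which here is the last one created (right).
current = plt.gca()
plt.xticks([0, 2])
```

After this, the left subplot still has its three labelled ticks and the right one has ticks at 0 and 2 only.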

basic questions about fig and axes in matplotlib

It is possible to plot without setting any variables. Example:

plt.figure()
plt.plot([1, 2], [5, 8])
plt.show()

You must initialize your figure somewhere, hence plt.figure(). As you pointed out, you can also use plt.subplots():

fig, ax = plt.subplots()
ax.plot([1, 2], [5, 8])
plt.show()

Note that we did not set the ncols and nrows keyword arguments. As the default value for both is 1, we get a single Axes object (which is why I chose the variable name ax). For n × 1 or 1 × n subplots (n > 1), the second variable returned by plt.subplots() is a one-dimensional array.

fig, axes = plt.subplots(ncols=1, nrows=2)
axes[0].plot([1, 2], [5, 8])
axes[1].plot([2, 3], [9, 3])

In the case of m × n subplots (m, n > 1), axes is a two-dimensional array.

fig, axes = plt.subplots(ncols=2, nrows=2)
axes[0][0].plot([1, 2], [5, 8])
axes[1][0].plot([2, 3], [9, 3])
axes[0][1].plot([3, 4], [5, 8])
axes[1][1].plot([4, 5], [9, 3])

Whether you use ax or axes as the name for the second variable is your own choice, but axes suggests that there are multiple axes and ax that there is only one. Of course, there are also other ways to construct subplots, as shown in your linked post.
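The three cases above can be checked directly by looking at what plt.subplots() hands back. The sketch below also uses the squeeze=False keyword of plt.subplots(), which is not mentioned above; it is handy if you want a two-dimensional array regardless of the grid size.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt
from matplotlib.axes import Axes

# 1 x 1: a single Axes object, not an array.
fig1, ax = plt.subplots()
single_is_axes = isinstance(ax, Axes)

# 1 x n: a one-dimensional array of Axes.
fig2, row = plt.subplots(nrows=1, ncols=3)

# m x n: a two-dimensional array of Axes.
fig3, grid = plt.subplots(nrows=2, ncols=2)

# squeeze=False: always two-dimensional, even for a single row.
fig4, always_2d = plt.subplots(nrows=1, ncols=3, squeeze=False)

shapes = (row.shape, grid.shape, always_2d.shape)
```

With squeeze=False you can always write axes[i][j], which helps when the number of rows or columns is only known at run time.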

Axes from plt.subplots() is a numpy.ndarray object and has no attribute plot

If you debug your program by simply printing ax, you'll quickly find out that ax is a two-dimensional array: one dimension for the rows, one for the columns.

Thus, you need two indices to index ax to retrieve the actual AxesSubplot instance, like:

ax[1,1].plot(...)

If you want to iterate through the subplots the way you do now, flatten ax first:

ax = ax.flatten()

and now ax is a one-dimensional array. By default, flatten() steps through the array in row-major order, so the whole first row comes first, then the second row, and so on. If you would rather go column by column, use the transpose:

ax = ax.T.flatten()
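If you want to convince yourself which order you get, compare a few elements of the flattened arrays with the original grid. A minimal sketch, assuming a 2 × 3 grid (the names by_rows and by_columns are mine):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 3)

# Row by row: (0,0), (0,1), (0,2), (1,0), ...
by_rows = ax.flatten()

# Column by column: (0,0), (1,0), (0,1), (1,1), ...
by_columns = ax.T.flatten()

# Label each subplot with its position in the row-by-row order.
for index, axes in enumerate(by_rows):
    axes.set_title(f"index {index}")
```

Here by_rows[3] is the first subplot of the second row, while by_columns[1] is that same subplot.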

Of course, by now it makes more sense to simply create each subplot on the fly, because that already has an index, and the other two numbers are fixed:

for x in range(plots_tot):
    ax = plt.subplot(nrows, ncols, x+1)

Note: you have x <= plots_tot, but with x starting at 0, you'll get an IndexError next with your current code (after flattening your array). Matplotlib is (unfortunately) 1-indexed for subplots. I prefer using a 0-indexed variable (Python style), and just add +1 for the subplot index (like above).
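As a runnable sketch of the loop above: the values of nrows, ncols and plots_tot are made up here, and the list created is only there to keep a handle on each subplot.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt

nrows, ncols, plots_tot = 2, 3, 5  # hypothetical values: 5 plots on a 2 x 3 grid

fig = plt.figure()
created = []
for x in range(plots_tot):             # x is 0-indexed, Python style
    ax = plt.subplot(nrows, ncols, x + 1)  # subplot numbers start at 1
    ax.set_title(f"plot {x}")
    created.append(ax)

# Subplot number 4 (x = 3) is the first cell of the second row.
row_of_fourth = created[3].get_subplotspec().rowspan.start
```

Because only plots_tot subplots are created, the sixth cell of the grid simply stays empty.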

Understanding how to use fig, ax = plt to put into existing code?

You have to create an axes object first, then plot the dataframe on it. This lets you access the axes object before and after plotting, and do whatever you want with it:

import pandas as pd
import matplotlib.pyplot as plt

#some arbitrary dataframe
df = pd.DataFrame({'x':[100,200], 'y':[200,300]})
fig, ax = plt.subplots()
df.plot(ax=ax)

#now edit the tick labels
labels = [item.get_text() for item in ax.get_xticklabels()]
labels[0] = 'Testing'
ax.set_xticklabels(labels)


What is the relevance of the first line fig,ax=plt.subplots(), when creating a matplotlib plot?

I tried to test if it was a way to assign 2 variables to the same thing by writing x,y=1, I ended up getting an error "int object is not iterable".

You are almost right. That is the syntax for assigning multiple variables at the same time, but what you are missing is that plt.subplots() returns a tuple: two values paired together.

If you want to better understand it you can run:

a, b = (1, 4)

or

a,b = 1, 4

(It's the same as far as Python is concerned: it packs/unpacks values to and from a tuple if multiple values are used or returned.)

I tried to delete it and run the code, but I got this error:

'tuple' object has no attribute 'scatter'

This is also related to why you got this error. The figure is indeed not used in your code snippet, but you need the first variable so that Python understands you want one part of the tuple and not the tuple itself.
For example, a = (1, 2) will result in a holding a tuple, but in a, b = 1, 2 each of the created variables will hold an integer.

In your case, the Axes object has a method scatter, which the tuple object does not have, hence your error.
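Both errors from the question can be reproduced in a few lines. This is only a sketch; the variable names (result, unpack_error and so on) are mine.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt

# Unpacking needs something iterable on the right-hand side; a bare int is not.
try:
    x, y = 1
except TypeError as error:
    unpack_error = error

# Without unpacking, you hold the whole (figure, axes) tuple.
result = plt.subplots()
is_tuple = isinstance(result, tuple)
tuple_has_scatter = hasattr(result, "scatter")  # False: tuples cannot plot

# With unpacking, ax is the Axes object, which does have scatter().
fig, ax = result
ax.scatter([1, 2, 3], [4, 5, 6])

# If you really don't need the figure, a throwaway name still unpacks the tuple.
_, ax2 = plt.subplots()
```

Writing ax = plt.subplots() is the version that leads to "'tuple' object has no attribute 'scatter'".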

return values of subplot

The documentation says that matplotlib.pyplot.subplots returns an instance of Figure and an array of (or a single) Axes (whether it is an array depends on the number of subplots).

Common use is:

import matplotlib.pyplot as plt
import numpy as np
f, axes = plt.subplots(1,2) # 1 row containing 2 subplots.

# Plot random points on one subplot.
axes[0].scatter(np.random.randn(10), np.random.randn(10))

# Plot histogram on the other one.
axes[1].hist(np.random.randn(100))

# Adjust the size and layout through the Figure-object.
f.set_size_inches(10, 5)
f.tight_layout()

Turn off axes in subplots

You can turn the axes off by following the advice in Veedrac's comment, with one small modification.

Rather than using plt.axis('off'), you should use ax.axis('off'), where ax is a matplotlib Axes object. To do this for your code, you simply need to add axarr[0,0].axis('off') and so on for each of your subplots.

The code below shows the result. (I've removed the prune_matrix part because I don't have access to that function. In the future, please submit fully working code.)

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.cm as cm

img = mpimg.imread("stewie.jpg")

f, axarr = plt.subplots(2, 2)
axarr[0,0].imshow(img, cmap = cm.Greys_r)
axarr[0,0].set_title("Rank = 512")
axarr[0,0].axis('off')

axarr[0,1].imshow(img, cmap = cm.Greys_r)
axarr[0,1].set_title("Rank = %s" % 128)
axarr[0,1].axis('off')

axarr[1,0].imshow(img, cmap = cm.Greys_r)
axarr[1,0].set_title("Rank = %s" % 32)
axarr[1,0].axis('off')

axarr[1,1].imshow(img, cmap = cm.Greys_r)
axarr[1,1].set_title("Rank = %s" % 16)
axarr[1,1].axis('off')

plt.show()


Note: To turn off only the x or y axis you can use set_visible() e.g.:

axarr[0,0].xaxis.set_visible(False) # Hide only x axis
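The four near-identical blocks in the code above could also be written as a loop over the array of axes. A sketch under a couple of assumptions: a small NumPy array stands in for the image (I don't have stewie.jpg either), and the rank values are just the ones used in the titles above.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no window is needed
import matplotlib.pyplot as plt
import numpy as np

# Stand-in for mpimg.imread("stewie.jpg"): an 8 x 8 grayscale gradient.
img = np.arange(64).reshape(8, 8)

f, axarr = plt.subplots(2, 2)

# axarr.flat walks the 2 x 2 grid row by row.
for ax, rank in zip(axarr.flat, [512, 128, 32, 16]):
    ax.imshow(img, cmap="Greys_r")
    ax.set_title(f"Rank = {rank}")
    ax.axis("off")  # hide ticks, labels and frame for this subplot only
```

This keeps the per-subplot ax.axis('off') call but avoids repeating it by hand.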

