Understanding Matplotlib.Subplots Python


The different return types are due to the squeeze keyword argument to plt.subplots(), which is set to True by default.
Let's enhance the documentation with the respective unpackings:

squeeze : bool, optional, default: True

  • If True, extra dimensions are squeezed out from the returned Axes object:

    • if only one subplot is constructed (nrows=ncols=1), the resulting single Axes object is returned as a scalar.

      fig, ax = plt.subplots()
    • for Nx1 or 1xN subplots, the Axes objects are returned as a 1D numpy object array.

      fig, (ax1, ..., axN) = plt.subplots(nrows=N, ncols=1) (for Nx1)

      fig, (ax1, ..., axN) = plt.subplots(nrows=1, ncols=N) (for 1xN)
    • for NxM subplots with N>1 and M>1, the Axes objects are returned as a 2D numpy array.

      fig, ((ax11, .., ax1M),..,(axN1, .., axNM)) = plt.subplots(nrows=N, ncols=M)
  • If False, no squeezing at all is done: the returned Axes object is always a 2D array containing Axes instances, even if it ends up being 1x1.

    fig, ((ax,),) = plt.subplots(nrows=1, ncols=1, squeeze=False)

    fig, ((ax1,), ..., (axN,)) = plt.subplots(nrows=N, ncols=1, squeeze=False) (for Nx1)

    fig, ((ax1, ..., axN),) = plt.subplots(nrows=1, ncols=N, squeeze=False) (for 1xN)

    fig, ((ax11, .., ax1M),..,(axN1, .., axNM)) = plt.subplots(nrows=N, ncols=M, squeeze=False)
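To see the squeeze rules in action, the sketch below asks plt.subplots() for a few grid sizes and reports the shape of the returned axes container (None when a bare Axes comes back). The helper name axes_shape is hypothetical, and the matplotlib.use("Agg") line only makes the sketch run without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np


def axes_shape(nrows, ncols, squeeze=True):
    """Return the shape of what plt.subplots() hands back as `ax`, or None for a bare Axes."""
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols, squeeze=squeeze)
    plt.close(fig)  # only the shape matters here, so free the figure right away
    return ax.shape if isinstance(ax, np.ndarray) else None


for nrows, ncols in [(1, 1), (3, 1), (1, 3), (2, 3)]:
    print(f"{nrows}x{ncols}:",
          "squeeze=True ->", axes_shape(nrows, ncols),
          "| squeeze=False ->", axes_shape(nrows, ncols, squeeze=False))
```

With squeeze=True a 1x1 grid gives a single Axes, Nx1 and 1xN grids give 1D arrays, and NxM grids give 2D arrays; with squeeze=False every case is 2D.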

Alternatively, you may always use the packed version

fig, ax_arr = plt.subplots(nrows=N, ncols=M, squeeze=False)

and index the array to obtain the axes, e.g. ax_arr[1, 2].plot(...).

So for a 2 x 3 grid, it doesn't actually matter whether you set squeeze to False: the result is always a 2D array. You may unpack it as

fig, ((ax1, ax2, ax3),(ax4, ax5, ax6)) = plt.subplots(nrows=2, ncols=3)

to have ax{i} as the matplotlib axes objects, or you may use the packed version

fig, ax_arr = plt.subplots(nrows=2, ncols=3)
ax_arr[0,0].plot(..) # plot to the top-left axes
ax_arr[1,2].plot(..) # plot to the bottom-right axes
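The unpacked names and the packed array are two handles on the same Axes objects. This minimal sketch (assuming a 2 x 3 grid, as above) checks that identity and also walks the packed array row by row with the numpy .flat iterator.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

fig, ax_arr = plt.subplots(nrows=2, ncols=3)

# Unpacking the 2x3 array row by row gives names for the very same Axes objects
(ax1, ax2, ax3), (ax4, ax5, ax6) = ax_arr
print(ax1 is ax_arr[0, 0], ax6 is ax_arr[1, 2])

# The packed array can also be visited in row-major order with .flat
for i, ax in enumerate(ax_arr.flat):
    ax.set_title(f"subplot {i}")
```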

How to plot in multiple subplots

There are several ways to do it. The subplots method creates the figure along with the subplots that are then stored in the ax array. For example:

import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig, ax = plt.subplots(nrows=2, ncols=2)

for row in ax:
    for col in row:
        col.plot(x, y)

plt.show()
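When each subplot should show different data, a common variant is to zip the flattened axes with the data. In this sketch the per-panel series (powers of x) are made up purely for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

x = range(10)
# Hypothetical per-subplot data: a different power of x in each panel
series = {f"x**{p}": [xi ** p for xi in x] for p in range(1, 5)}

fig, axes = plt.subplots(nrows=2, ncols=2)

# axes.flat visits the 2x2 grid row by row, so zip pairs each Axes with one series
for ax, (label, y) in zip(axes.flat, series.items()):
    ax.plot(x, y)
    ax.set_title(label)

fig.tight_layout()
```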


Something like this will also work, although it is not as "clean", since you create an empty figure first and then add each subplot to it with a separate plt.subplot() call:

fig = plt.figure()

plt.subplot(2, 2, 1)
plt.plot(x, y)

plt.subplot(2, 2, 2)
plt.plot(x, y)

plt.subplot(2, 2, 3)
plt.plot(x, y)

plt.subplot(2, 2, 4)
plt.plot(x, y)

plt.show()
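This version works because each plt.subplot() call creates the requested grid cell and makes it the current Axes, and plt.plot() always draws on the current Axes. The loop below is a minimal sketch of that mechanism; the list name created is only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

x = range(10)
y = range(10)

fig = plt.figure()
created = []
for k in range(1, 5):
    ax = plt.subplot(2, 2, k)   # creates the k-th cell and makes it the current Axes
    assert plt.gca() is ax      # ...which is why the next plt.plot() lands in this cell
    plt.plot(x, y)
    created.append(ax)

print(len(fig.axes))  # one Axes per plt.subplot() call
```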


Why do many examples use `fig, ax = plt.subplots()` in Matplotlib/pyplot/python

plt.subplots() is a function that returns a tuple containing a figure and axes object(s). Thus, when using fig, ax = plt.subplots(), you unpack this tuple into the variables fig and ax. Having fig is useful if you want to change figure-level attributes or save the figure as an image file later (e.g. with fig.savefig('yourfilename.png')). You certainly don't have to use the returned figure object, but many people do use it later, so it's common to see. Also, all axes objects (the objects that have plotting methods) have a parent figure object anyway, thus:

fig, ax = plt.subplots()

is more concise than this:

fig = plt.figure()
ax = fig.add_subplot(111)
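Both spellings end up in the same place: an Axes whose parent figure is the one you would save. The sketch below checks that for each form and saves one figure to an in-memory buffer, which here stands in for a real file name.

```python
import io

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

# One call: figure and Axes together
fig1, ax1 = plt.subplots()

# Two calls: figure first, then a 1x1 subplot added to it
fig2 = plt.figure()
ax2 = fig2.add_subplot(111)

ax1.plot([0, 1], [0, 1])
ax2.plot([0, 1], [1, 0])

# Either way the Axes keeps a reference to its parent figure...
print(ax1.figure is fig1, ax2.figure is fig2)

# ...and the figure is the object you save (an in-memory buffer instead of a file here)
buf = io.BytesIO()
fig1.savefig(buf, format="png")
```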

Subplotting subplots

  • As mentioned, in your case you only need one level of subplots, e.g., nrows=1, ncols=2.
  • However, in matplotlib 3.4+ there is such a thing as "subplotting subplots" called subfigures, which makes it easier to implement nested layouts, e.g.:
    • How to create row titles for subplots
    • How to share colorbars within some subplots
    • How to share xlabels within some subplots

Subplots

For your simpler use case, create 1x2 subplots with ax1 on the left and ax2 on the right:


import matplotlib.pyplot as plt

# create 1x2 subplots (dfSat and dfSun are built in the sample data below)
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 4))

# plot saturdays on the left
dfSat.plot(ax=ax1, x='date', y='temp_min')
dfSat.plot(ax=ax1, x='date', y='temp_max')
ax1.set_ylim(-20, 50)
ax1.set_title('Saturdays')

# plot sundays on the right
dfSun.plot(ax=ax2, x='date', y='temp_min')
dfSun.plot(ax=ax2, x='date', y='temp_max')
ax2.set_ylim(-20, 50)
ax2.set_title('Sundays')
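Since both panels use the same y range, one possible variant is to pass sharey=True to plt.subplots() so that a single set_ylim() call applies to both. This sketch swaps in synthetic random temperatures and plain ax.plot() calls instead of dfSat/dfSun, so the data is purely illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
days = np.arange(52)  # hypothetical: one point per weekend day over a year
# Synthetic stand-ins for the temp_min / temp_max columns of dfSat and dfSun
panels = {
    "Saturdays": (rng.normal(5, 3, days.size), rng.normal(15, 4, days.size)),
    "Sundays": (rng.normal(6, 3, days.size), rng.normal(16, 4, days.size)),
}

# sharey=True links the y-axes, so one set_ylim() call applies to both panels
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(16, 4), sharey=True)
for ax, (title, (t_min, t_max)) in zip((ax1, ax2), panels.items()):
    ax.plot(days, t_min, label="temp_min")
    ax.plot(days, t_max, label="temp_max")
    ax.set_title(title)
    ax.legend()
ax1.set_ylim(-20, 50)
```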

Subfigures

Say you want something more complicated, like having the left side show 2012 with its own suptitle and the right side show 2015 with its own suptitle.

Create 1x2 subfigures (left subfig_l and right subfig_r) with 2x1 subplots on the left (top ax_lt and bottom ax_lb) and 2x1 subplots on the right (top ax_rt and bottom ax_rb):


# create 1x2 subfigures
fig = plt.figure(constrained_layout=True, figsize=(12, 5))
(subfig_l, subfig_r) = fig.subfigures(nrows=1, ncols=2, wspace=0.07)

# create top/bottom axes in left subfig
(ax_lt, ax_lb) = subfig_l.subplots(nrows=2, ncols=1)

# plot 2012 saturdays on left-top axes
dfSat12 = dfSat.loc[dfSat['date'].dt.year.eq(2012)]
dfSat12.plot(ax=ax_lt, x='date', y='temp_min')
dfSat12.plot(ax=ax_lt, x='date', y='temp_max')
ax_lt.set_ylim(-20, 50)
ax_lt.set_ylabel('Saturdays')

# plot 2012 sundays on left-bottom axes
dfSun12 = dfSun.loc[dfSun['date'].dt.year.eq(2012)]
dfSun12.plot(ax=ax_lb, x='date', y='temp_min')
dfSun12.plot(ax=ax_lb, x='date', y='temp_max')
ax_lb.set_ylim(-20, 50)
ax_lb.set_ylabel('Sundays')

# set suptitle for left subfig
subfig_l.suptitle('2012', size='x-large', weight='bold')

# create top/bottom axes in right subfig
(ax_rt, ax_rb) = subfig_r.subplots(nrows=2, ncols=1)

# plot 2015 saturdays on right-top axes
dfSat15 = dfSat.loc[dfSat['date'].dt.year.eq(2015)]
dfSat15.plot(ax=ax_rt, x='date', y='temp_min')
dfSat15.plot(ax=ax_rt, x='date', y='temp_max')
ax_rt.set_ylim(-20, 50)
ax_rt.set_ylabel('Saturdays')

# plot 2015 sundays on right-bottom axes
dfSun15 = dfSun.loc[dfSun['date'].dt.year.eq(2015)]
dfSun15.plot(ax=ax_rb, x='date', y='temp_min')
dfSun15.plot(ax=ax_rb, x='date', y='temp_max')
ax_rb.set_ylim(-20, 50)
ax_rb.set_ylabel('Sundays')

# set suptitle for right subfig
subfig_r.suptitle('2015', size='x-large', weight='bold')
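The left and right halves above repeat the same steps, so they can be folded into a loop over (subfigure, year) pairs. This is a sketch, assuming matplotlib 3.4+ for subfigures; the random temperatures are synthetic stand-ins for the filtered dfSat/dfSun frames.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
weeks = np.arange(52)

fig = plt.figure(constrained_layout=True, figsize=(12, 5))
subfigs = fig.subfigures(nrows=1, ncols=2, wspace=0.07)

# One subfigure per year, each holding a 2x1 column of Axes (Saturdays on top, Sundays below)
for subfig, year in zip(subfigs, (2012, 2015)):
    axes = subfig.subplots(nrows=2, ncols=1)
    for ax, day in zip(axes, ("Saturdays", "Sundays")):
        # Synthetic stand-ins for the real temp_min / temp_max columns
        ax.plot(weeks, rng.normal(5, 3, weeks.size), label="temp_min")
        ax.plot(weeks, rng.normal(15, 4, weeks.size), label="temp_max")
        ax.set_ylim(-20, 50)
        ax.set_ylabel(day)
    subfig.suptitle(str(year), size="x-large", weight="bold")
```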

Sample data for reference:

import pandas as pd
from vega_datasets import data

df = data.seattle_weather()
df['date'] = pd.to_datetime(df['date'])
dfSat = df.loc[df['date'].dt.weekday.eq(5)]
dfSun = df.loc[df['date'].dt.weekday.eq(6)]

Python - matplotlib - differences between subplot() and subplots()

1. matplotlib.pyplot.subplots()

From the documentation page on matplotlib.pyplot.subplots():

This utility wrapper makes it convenient to create common layouts of subplots, including the enclosing figure object, in a single call.

That means you can use this single function to create a figure with several subplots with only one line of code. For example, the code below will return both fig which is the figure object, and axes which is a 2x3 array of axes objects which allows you to easily access each subplot:

fig, axes = plt.subplots(nrows=2, ncols=3)

2. matplotlib.pyplot.subplot()

In contrast, matplotlib.pyplot.subplot() creates only a single subplot axes at a specified grid position. This means it takes several lines of code to achieve the same result as matplotlib.pyplot.subplots() did in a single line above:

# first you have to make the figure
fig = plt.figure(1)

# now you have to create each subplot individually
ax1 = plt.subplot(231)
ax2 = plt.subplot(232)
ax3 = plt.subplot(233)
ax4 = plt.subplot(234)
ax5 = plt.subplot(235)
ax6 = plt.subplot(236)

Or you can use the figure's own add_subplot() method:

ax1 = fig.add_subplot(231)
ax2 = fig.add_subplot(232)
ax3 = fig.add_subplot(233)
ax4 = fig.add_subplot(234)
ax5 = fig.add_subplot(235)
ax6 = fig.add_subplot(236)

Conclusion

The code above can be condensed with a loop, but it is still considerably more tedious to use. I'd therefore recommend you use matplotlib.pyplot.subplots() since it is more concise and easy to use.
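As a sketch of that loop, the six add_subplot() calls can become one list comprehension using the three-argument form add_subplot(2, 3, k), which is equivalent to add_subplot(23k). Note that the result is a plain list, not the 2D numpy array that subplots() returns.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

fig = plt.figure()

# k runs 1..6, filling the 2x3 grid row by row, exactly like 231..236 above
axes = [fig.add_subplot(2, 3, k) for k in range(1, 7)]

print(len(fig.axes))
```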

Understanding matplotlib: plt, figure, ax(arr)?

pyplot is the 'scripting' level API in matplotlib (its highest-level API, which lets you do a lot with little code). It lets you use matplotlib through a procedural interface, much as you would in MATLAB. pyplot has a notion of 'current figure' and 'current axes' that all the functions delegate to (@tacaswell dixit). So, when you use the functions available in the pyplot module, you are plotting to the 'current figure' and 'current axes'.

If you want fine-grained control over where and what you are plotting, you should use the object-oriented API, working with instances of Figure and Axes.

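Most plotting functions available in pyplot have an equivalent method on Axes (for example plt.plot() and ax.plot()), although some setters are named differently (plt.title() versus ax.set_title()).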
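A small sketch of that delegation: right after plt.subplots(), the new Axes is the current one, so pyplot calls and Axes methods both land on the same object.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

plt.plot([0, 1, 2], [0, 1, 4])   # pyplot: draws on the current Axes, plt.gca()
ax.plot([0, 1, 2], [0, 2, 3])    # object-oriented: draws on this specific Axes

plt.title("pyplot")              # pyplot setter...
print(ax.get_title())            # ...shows up on the same Axes object
print(plt.gca() is ax, len(ax.lines))
```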

From the anatomy of matplotlib repository:

  • The Figure is the top-level container in this hierarchy. It is the overall window/page that everything is drawn on. You can have multiple independent figures and Figures can contain multiple Axes.

But...

  • Most plotting occurs on an Axes. The axes is effectively the area that we plot data on and any ticks/labels/etc associated with it. Usually we'll set up an Axes with a call to subplot (which places Axes on a regular grid), so in most cases, Axes and Subplot are synonymous.

  • Each Axes has an XAxis and a YAxis. These contain the ticks, tick locations, labels, etc.
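The Figure, Axes, XAxis/YAxis hierarchy described above can be walked directly from the objects. This minimal sketch uses a hypothetical x-label "time" just to show where axis labels live.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from matplotlib.axis import XAxis, YAxis

fig, (ax_left, ax_right) = plt.subplots(nrows=1, ncols=2)

# Figure -> Axes: the figure keeps a list of every Axes it contains
print(len(fig.axes))

# Axes -> XAxis / YAxis: each Axes owns one of each, holding ticks and labels
ax_left.set_xlabel("time")
print(isinstance(ax_left.xaxis, XAxis), isinstance(ax_left.yaxis, YAxis))
print(ax_left.xaxis.get_label().get_text())
```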

If you want to know the anatomy of a plot you can visit this link.

Plotting matplotlib subplots with functions

Remove the plt.show() from your methods normal_dist_hist(...) and normal_test_QQplots(...). Add plt.show() at the end of your normality_report(...).

def normal_dist_hist(data, ax):
    ...
    plt.show()  # Remove this

# Probability plot
def normal_test_QQplots(data, ax):
    ...
    plt.show()  # Remove this

def normality_report(df):
    ...
    for col in df.columns[1:]:
        ax_x = 0
        normal_dist_hist(df[col], ax=axes[ax_y, ax_x])
        ax_x = 1
        normal_test_QQplots(df[col], ax=axes[ax_y, ax_x])
        ax_y += 1
    plt.show()  # Add it here.
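Below is a self-contained sketch of the same pattern. The plotting helpers only draw on the Axes they receive, and plt.show() is called once at the end. The bodies of normal_dist_hist and normal_test_QQplots are hypothetical: the histogram and the hand-rolled normal Q-Q plot (using statistics.NormalDist from the standard library) stand in for whatever the originals contained. Returning fig and axes from normality_report is also an addition, and so is the synthetic DataFrame.

```python
import statistics

import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get a window from plt.show()
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def normal_dist_hist(data, ax):
    # Hypothetical body: draw on the Axes passed in, and do NOT call plt.show() here
    ax.hist(data, bins=20)
    ax.set_title(f"{data.name}: histogram")


def normal_test_QQplots(data, ax):
    # Hypothetical body: sorted sample vs. standard-normal quantiles (a basic Q-Q plot)
    n = len(data)
    quantiles = [statistics.NormalDist().inv_cdf((i + 0.5) / n) for i in range(n)]
    ax.scatter(quantiles, np.sort(data.to_numpy()), s=8)
    ax.set_title(f"{data.name}: Q-Q plot")


def normality_report(df):
    cols = df.columns[1:]  # skip the first column, as the original loop does
    # squeeze=False keeps `axes` 2D even when there is only one column to report on
    fig, axes = plt.subplots(nrows=len(cols), ncols=2,
                             figsize=(10, 4 * len(cols)), squeeze=False)
    for ax_y, col in enumerate(cols):
        normal_dist_hist(df[col], ax=axes[ax_y, 0])
        normal_test_QQplots(df[col], ax=axes[ax_y, 1])
    fig.tight_layout()
    return fig, axes


rng = np.random.default_rng(0)
df = pd.DataFrame({"id": range(200),
                   "normal": rng.normal(size=200),
                   "skewed": rng.exponential(size=200)})
fig, axes = normality_report(df)
plt.show()  # a single show() once every subplot has been drawn
```

Using enumerate() for the row index also removes the need to initialise and increment ax_y by hand.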

matplotlib combining subplots into a single plot with no axis and no gaps

You have to specify the grid parameters:

  • 2 rows
  • 3 columns
  • 0 width space
  • 0 height space

with matplotlib.pyplot.subplots:

fig, axes = plt.subplots(nrows=2, ncols=3, gridspec_kw={'wspace': 0, 'hspace': 0})

Then you can loop over the created axes and, for each one of them, show the image and set the axis first to 'tight' and then to 'off':

for ax in axes.flatten():
    ax.imshow(img)
    ax.axis('tight')
    ax.axis('off')

Your code will be slightly different, since you are plotting a different image in each ax.

Complete Code

import matplotlib.pyplot as plt

img = plt.imread('img.jpeg')

fig, axes = plt.subplots(nrows=2, ncols=3, gridspec_kw={'wspace': 0, 'hspace': 0})

for ax in axes.flatten():
    ax.imshow(img)
    ax.axis('tight')
    ax.axis('off')

plt.show()
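For the case of a different image in each ax, one way is to zip the flattened axes with a list of images. In this sketch the six random RGB arrays are hypothetical stand-ins for your own images (for example, several plt.imread() results).

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for six different images: 32x32 random RGB arrays
rng = np.random.default_rng(0)
images = [rng.random((32, 32, 3)) for _ in range(6)]

fig, axes = plt.subplots(nrows=2, ncols=3, gridspec_kw={'wspace': 0, 'hspace': 0})

# Pair each Axes (row by row) with its own image
for ax, img in zip(axes.flat, images):
    ax.imshow(img)
    ax.axis('tight')
    ax.axis('off')
```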
