Multiple Axis in Matplotlib with Different Scales

multiple axis in matplotlib with different scales

If I understand the question correctly, you may be interested in this example in the Matplotlib gallery.


Yann's comment above provides a similar example.


Edit: The link above has been fixed. Here is the corresponding code, copied from the Matplotlib gallery:

from mpl_toolkits.axes_grid1 import host_subplot
import mpl_toolkits.axisartist as AA
import matplotlib.pyplot as plt

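# Host Axes created with the axisartist Axes class, which provides the
# axis["left"], axis["right"], ... interface used below
host = host_subplot(111, axes_class=AA.Axes)
# Leave room on the right-hand side for the extra y-axis
plt.subplots_adjust(right=0.75)

# Two parasite Axes sharing the host's x-axis, each with its own y-scale
par1 = host.twinx()
par2 = host.twinx()

# Replace par2's right axis with one shifted 60 points further right
offset = 60
new_fixed_axis = par2.get_grid_helper().new_fixed_axis
par2.axis["right"] = new_fixed_axis(loc="right", axes=par2,
                                    offset=(offset, 0))

# Make its ticks, tick labels and axis label visible
par2.axis["right"].toggle(all=True)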

host.set_xlim(0, 2)
host.set_ylim(0, 2)

host.set_xlabel("Distance")
host.set_ylabel("Density")
par1.set_ylabel("Temperature")
par2.set_ylabel("Velocity")

p1, = host.plot([0, 1, 2], [0, 1, 2], label="Density")
p2, = par1.plot([0, 1, 2], [0, 3, 2], label="Temperature")
p3, = par2.plot([0, 1, 2], [50, 30, 15], label="Velocity")

par1.set_ylim(0, 4)
par2.set_ylim(1, 65)

host.legend()

host.axis["left"].label.set_color(p1.get_color())
par1.axis["right"].label.set_color(p2.get_color())
par2.axis["right"].label.set_color(p3.get_color())

plt.draw()
plt.show()

#plt.savefig("Test")

How to add additional x-axes but with different scale and color (matplotlib)

You should use 3 different axes, one for each line you need to plot.

The first one can be:

fig, ax_full = plt.subplots()

full = ax_full.plot(x_full, y_full, color = 'red', label = 'full')

Then you can generate the others with:

ax_enn = ax_full.twiny()

And plot each line on the respective axis:

enn = ax_enn.plot(x_enn, y_enn, color = 'blue', label = 'enn')

Then you can move the axis to the bottom with:

ax_enn.xaxis.set_ticks_position('bottom')
ax_enn.xaxis.set_label_position('bottom')
ax_enn.spines['bottom'].set_position(('axes', -0.15))

And finally customize the colors:

ax_enn.spines['bottom'].set_color('blue')
ax_enn.tick_params(axis='x', colors='blue')
ax_enn.xaxis.label.set_color('blue')

Complete Code

import numpy as np
import matplotlib.pyplot as plt

x_full = np.linspace(0.001, 0.02, 20)
x_enn = np.linspace(0.05, 1.95, 20)
x_knn = np.linspace(2, 40, 20)

y_full = np.random.rand(len(x_full))
y_enn = np.random.rand(len(x_enn))
y_knn = np.random.rand(len(x_knn))

fig, ax_full = plt.subplots()

full = ax_full.plot(x_full, y_full, color = 'red', label = 'full')
ax_full.spines['bottom'].set_color('red')
ax_full.tick_params(axis='x', colors='red')
ax_full.xaxis.label.set_color('red')

ax_enn = ax_full.twiny()
enn = ax_enn.plot(x_enn, y_enn, color = 'blue', label = 'enn')
ax_enn.xaxis.set_ticks_position('bottom')
ax_enn.xaxis.set_label_position('bottom')
ax_enn.spines['bottom'].set_position(('axes', -0.15))
ax_enn.spines['bottom'].set_color('blue')
ax_enn.tick_params(axis='x', colors='blue')
ax_enn.xaxis.label.set_color('blue')

ax_knn = ax_full.twiny()
knn = ax_knn.plot(x_knn, y_knn, color = 'green', label = 'knn')
ax_knn.xaxis.set_ticks_position('bottom')
ax_knn.xaxis.set_label_position('bottom')
ax_knn.spines['bottom'].set_position(('axes', -0.3))
ax_knn.spines['bottom'].set_color('green')
ax_knn.tick_params(axis='x', colors='green')
ax_knn.xaxis.label.set_color('green')

lines = full + enn + knn
labels = [l.get_label() for l in lines]
ax_full.legend(lines, labels)

plt.tight_layout()

plt.show()
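The three steps above (twin, move to the bottom, recolor) are repeated once per extra x-axis, so they can be folded into a small helper. This is a minimal sketch; add_offset_xaxis is a hypothetical name, not a Matplotlib function, and the random data only stands in for the real curves:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def add_offset_xaxis(ax, offset, color):
    """Hypothetical helper: twin `ax` and show the new x-axis below the plot.

    `offset` is in axes-fraction units, e.g. -0.15 puts the spine 15% of the
    axes height below the bottom edge.
    """
    twin = ax.twiny()
    # Move the ticks, label and spine of the new x-axis to the bottom.
    twin.xaxis.set_ticks_position('bottom')
    twin.xaxis.set_label_position('bottom')
    twin.spines['bottom'].set_position(('axes', offset))
    # Color the spine, the ticks/tick labels and the axis label.
    twin.spines['bottom'].set_color(color)
    twin.tick_params(axis='x', colors=color)
    twin.xaxis.label.set_color(color)
    return twin


fig, ax_full = plt.subplots()
ax_full.plot(np.linspace(0.001, 0.02, 20), np.random.rand(20), color='red')

ax_enn = add_offset_xaxis(ax_full, -0.15, 'blue')
ax_enn.plot(np.linspace(0.05, 1.95, 20), np.random.rand(20), color='blue')

ax_knn = add_offset_xaxis(ax_full, -0.3, 'green')
ax_knn.plot(np.linspace(2, 40, 20), np.random.rand(20), color='green')

# Leave room at the bottom for the two offset axes.
fig.subplots_adjust(bottom=0.35)
```

Adding a fourth scale is then one more call with a larger negative offset.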


matplotlib: overlay plots with different scales?

It sounds like what you want is subplots. What you're doing now doesn't make much sense (or I'm very confused by your code snippet, at any rate).

Try something more like this:

import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(nrows=3)

colors = ('k', 'r', 'b')
for ax, color in zip(axes, colors):
    data = np.random.random(1) * np.random.random(10)
    ax.plot(data, marker='o', linestyle='none', color=color)

plt.show()
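If the rows should stay aligned along x while each keeps its own y-scale, passing sharex=True to plt.subplots links their x-axes. A sketch under that assumption; the scales tuple is made-up data chosen only to give each row a very different magnitude:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)  # reproducible illustration

# sharex=True links the x-axes (zooming/panning one moves all three),
# while each row keeps its own independent y-scale.
fig, axes = plt.subplots(nrows=3, sharex=True)

colors = ('k', 'r', 'b')
scales = (1, 100, 10000)  # illustrative magnitudes, one per row
for ax, color, scale in zip(axes, colors, scales):
    data = scale * np.random.random(10)
    ax.plot(data, marker='o', linestyle='none', color=color)

axes[-1].set_xlabel('Sample index')
```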


Edit:

If you don't want subplots, your code snippet makes a lot more sense.

You're trying to add three axes right on top of each other. In the Matplotlib version you're using, Matplotlib recognizes that there's already a subplot with exactly that size and location on the figure, so it returns the same axes object each time. In other words, if you look at your list ax, you'll see that all its entries are the same object.

If you really want to do that, you'll need to reset the private fig._seen attribute to an empty dict each time you add an axes. You probably don't want to do that, however.
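A less invasive way to get separate, overlapping axes is to make their arguments differ by giving each one a distinct label, the same trick used with fig.add_subplot(111, label="1") elsewhere on this page. A minimal sketch, assuming three stacked axes are really what you want:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt

fig = plt.figure()
# A distinct label per call makes the arguments differ, so Matplotlib
# creates a new Axes each time instead of handing back an existing one.
axes = [fig.add_subplot(111, label=str(i)) for i in range(3)]

# Make the upper Axes transparent so the ones underneath stay visible.
for ax in axes[1:]:
    ax.patch.set_visible(False)
```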

Instead of putting three independent plots over each other, have a look at using twinx instead.

E.g.

import matplotlib.pyplot as plt
import numpy as np
# To make things reproducible...
np.random.seed(1977)

fig, ax = plt.subplots()

# Twin the x-axis twice to make independent y-axes.
axes = [ax, ax.twinx(), ax.twinx()]

# Make some space on the right side for the extra y-axis.
fig.subplots_adjust(right=0.75)

# Move the last y-axis spine over to the right by 20% of the width of the axes
axes[-1].spines['right'].set_position(('axes', 1.2))

# To make the border of the right-most axis visible, we need to turn the frame
# on. This hides the other plots, however, so we need to turn its fill off.
axes[-1].set_frame_on(True)
axes[-1].patch.set_visible(False)

# And finally we get to plot things...
colors = ('Green', 'Red', 'Blue')
for ax, color in zip(axes, colors):
    data = np.random.random(1) * np.random.random(10)
    ax.plot(data, marker='o', linestyle='none', color=color)
    ax.set_ylabel('%s Thing' % color, color=color)
    ax.tick_params(axis='y', colors=color)
axes[0].set_xlabel('X-axis')

plt.show()
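Each twinned axes keeps its own legend entries, so calling legend() on one of them only lists that axes' lines. One way to get a single legend is to collect the handles from every axes with get_legend_handles_labels. A sketch reusing the setup above, with label= added to each plot call (an addition, not part of the original snippet):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(1977)
fig, ax = plt.subplots()
axes = [ax, ax.twinx(), ax.twinx()]
fig.subplots_adjust(right=0.75)
axes[-1].spines['right'].set_position(('axes', 1.2))

colors = ('Green', 'Red', 'Blue')
for a, color in zip(axes, colors):
    a.plot(np.random.random(10), marker='o', linestyle='none',
           color=color, label='%s Thing' % color)

# Gather the entries of every Axes into one handle/label list.
handles, labels = [], []
for a in axes:
    h, l = a.get_legend_handles_labels()
    handles += h
    labels += l

# Draw the legend on the last (top-most) twin so the other Axes,
# which are drawn later than the host, do not cover it.
legend = axes[-1].legend(handles, labels, loc='upper left')
```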


Dual x-axis in python: same data, different scale

In your code example, you plot the same data twice (albeit transformed using E=h*c/wl). I think it would be sufficient to only plot the data once, but create two x-axes: one displaying the wavelength in nm and one displaying the corresponding energy in eV.

Consider the adjusted code below:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter
import scipy.constants as constants
from sys import float_info

# Function to prevent zero values in an array
def preventDivisionByZero(some_array):
    corrected_array = some_array.copy()
    for i, entry in enumerate(some_array):
        # If element is zero, set to some small value
        if abs(entry) < float_info.epsilon:
            corrected_array[i] = float_info.epsilon

    return corrected_array

# Converting wavelength (nm) to energy (eV)
def WLtoE(wl):
    # Prevent division by zero error
    wl = preventDivisionByZero(wl)

    # E = h*c/wl
    h = constants.h  # Planck constant
    c = constants.c  # Speed of light
    J_eV = constants.e  # Joule-electronvolt relationship

    wl_m = wl * 10**(-9)  # convert wl from nm to m
    E_J = (h*c) / wl_m  # energy in units of J
    E_eV = E_J / J_eV  # energy in units of eV

    return E_eV


# Converting energy (eV) to wavelength (nm)
def EtoWL(E):
    # Prevent division by zero error
    E = preventDivisionByZero(E)

    # Calculates the wavelength in nm
    return constants.h * constants.c / (constants.e * E) * 10**9

x = np.arange(200,2001,5)
y = 2*x + 3

fig, ax1 = plt.subplots()

ax1.plot(x, y, color='black')

ax1.set_xlabel('Wavelength (nm)', fontsize = 'large')
ax1.set_ylabel('Absorbance (a.u.)', fontsize = 'large')

# Invert the wavelength axis
ax1.invert_xaxis()

# Create the second x-axis on which the energy in eV will be displayed
ax2 = ax1.secondary_xaxis('top', functions=(WLtoE, EtoWL))
ax2.set_xlabel('Energy (eV)', fontsize='large')

# Get ticks from ax1 (wavelengths)
wl_ticks = ax1.get_xticks()
wl_ticks = preventDivisionByZero(wl_ticks)

# Based on the ticks from ax1 (wavelengths), calculate the corresponding
# energies in eV
E_ticks = WLtoE(wl_ticks)

# Set the ticks for ax2 (Energy)
ax2.set_xticks(E_ticks)

# Allow for two decimal places on ax2 (Energy)
ax2.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))

plt.tight_layout()
plt.show()

First of all, I define the preventDivisionByZero utility function. This function takes an array as input and checks for values that are (approximately) equal to zero. Subsequently, it will replace these values with a small number (sys.float_info.epsilon) that is not equal to zero. This function will be used in a few places to prevent division by zero. I will come back to why this is important later.
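The same guard can be written without a Python loop using np.where. This is a vectorized sketch of preventDivisionByZero (prevent_division_by_zero is a hypothetical name); like the original, it returns a new array and leaves the input untouched:

```python
import numpy as np
from sys import float_info


def prevent_division_by_zero(values):
    """Replace entries whose magnitude is below machine epsilon with epsilon."""
    arr = np.asarray(values, dtype=float)
    # np.where builds a new array, so the caller's data is not modified.
    return np.where(np.abs(arr) < float_info.epsilon, float_info.epsilon, arr)


print(prevent_division_by_zero([0.0, 250.0, -1e-20]))
```

As in the loop version, tiny negative values are also replaced by the positive epsilon.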

After this function, your WLtoE function is defined. Note that I added a call to preventDivisionByZero at the top of your function. In addition, I defined an EtoWL function, which performs the inverse conversion (energy in eV to wavelength in nm).
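Because E = hc/wl, the product E * wl is the constant hc, so the forward and inverse conversions are actually the same formula. A sketch with hypothetical names wl_to_e/e_to_wl that hard-codes the exact SI constants (the same values scipy.constants provides) and, for brevity, omits the zero guard:

```python
import numpy as np

# Exact SI values (2019 redefinition of the SI units).
H = 6.62607015e-34      # Planck constant, J*s
C = 299792458.0         # speed of light, m/s
Q_E = 1.602176634e-19   # elementary charge, C (1 eV = Q_E joules)

HC_EV_NM = H * C / Q_E * 1e9  # h*c expressed in eV*nm, about 1239.84


def wl_to_e(wl_nm):
    """Wavelength in nm -> photon energy in eV, E = h*c / wl."""
    # No zero guard here: wl = 0 gives inf with a NumPy RuntimeWarning.
    return HC_EV_NM / np.asarray(wl_nm, dtype=float)


def e_to_wl(e_ev):
    """Photon energy in eV -> wavelength in nm, wl = h*c / E."""
    return HC_EV_NM / np.asarray(e_ev, dtype=float)


wl = np.array([250.0, 500.0, 1000.0, 2000.0])
energies = wl_to_e(wl)
print(energies)
```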

Then, you generate your dummy data and plot it on ax1, whose x-axis shows the wavelength. After setting some labels, the x-axis of ax1 is inverted (as requested in your original post).

Now, we create the second axis for the energy using ax2 = ax1.secondary_xaxis('top', functions=(WLtoE, EtoWL)). The first argument indicates that the axis should be placed at the top of the axes. The second (keyword) argument takes a tuple containing two functions: the first is the forward transform (wavelength to energy), while the second is the inverse transform (energy to wavelength). See Axes.secondary_xaxis for more information. Note that Matplotlib calls these two functions itself whenever necessary. As the values passed in can be equal to zero, it is important to handle those cases; hence the preventDivisionByZero function. After creating the second axis, its label is set.
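The functions=(forward, inverse) mechanism is easiest to see with a transform that never divides by zero. A minimal sketch using degrees and radians (np.deg2rad/np.rad2deg are real NumPy functions; the sine data is only for illustration):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 360, 100)
fig, ax = plt.subplots()
ax.plot(x, np.sin(np.deg2rad(x)))
ax.set_xlabel('Angle (degrees)')

# forward maps parent units (degrees) to secondary units (radians);
# inverse maps back. Both must accept and return arrays.
secax = ax.secondary_xaxis('top', functions=(np.deg2rad, np.rad2deg))
secax.set_xlabel('Angle (radians)')
```

The secondary axis has no data of its own; at draw time its limits are derived from the parent's limits through the forward function.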

Now we have two x-axes, but the ticks on the two axes are at different locations. To 'solve' this, we store the tick locations of the wavelength x-axis in wl_ticks. After ensuring there are no zero elements using the preventDivisionByZero function, we calculate the corresponding energy values using the WLtoE function and store them in E_ticks. Finally, we set the tick locations of the second x-axis to the values in E_ticks using ax2.set_xticks(E_ticks), so each energy tick sits directly above a wavelength tick.
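A variant of the alignment step is to pin the wavelength ticks explicitly, so 0 nm can never appear among them, and then place the energy ticks with SecondaryAxis.set_ticks. A sketch under those assumptions; wl_to_e is a hypothetical stand-in for WLtoE and, because E * wl = hc, it also serves as the inverse:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

HC_EV_NM = 1239.8419843320026  # h*c in eV*nm (exact SI constants)


def wl_to_e(wl):
    """E = hc / wl; since E * wl = hc, it also maps eV back to nm."""
    wl = np.asarray(wl, dtype=float)
    # Replace exact zeros with a tiny value before dividing.
    return HC_EV_NM / np.where(wl == 0, np.finfo(float).eps, wl)


x = np.arange(200, 2001, 5)
fig, ax1 = plt.subplots()
ax1.plot(x, 2 * x + 3, color='black')

# Explicit wavelength ticks, none of them at 0 nm.
wl_ticks = np.array([250.0, 500.0, 1000.0, 1500.0, 2000.0])
ax1.set_xticks(wl_ticks)
ax1.invert_xaxis()

ax2 = ax1.secondary_xaxis('top', functions=(wl_to_e, wl_to_e))
# Put each energy tick exactly where a wavelength tick is.
e_ticks = wl_to_e(wl_ticks)
ax2.set_ticks(e_ticks)
```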

To allow for two decimal places on the second x-axis (energy), we use ax2.xaxis.set_major_formatter(FormatStrFormatter('%.2f')). Of course, you can choose the desired number of decimal places yourself.
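FormatStrFormatter applies an old-style % format to each tick value, so changing the format string is all it takes to change the precision. A quick check of what the formatter returns for one energy value (StrMethodFormatter, the str.format-style equivalent, is shown for comparison):

```python
from matplotlib.ticker import FormatStrFormatter, StrMethodFormatter

two_decimals = FormatStrFormatter('%.2f')
three_decimals = FormatStrFormatter('%.3f')
new_style = StrMethodFormatter('{x:.2f}')

energy = 1239.84198 / 500  # energy in eV at a 500 nm tick
print(two_decimals(energy), three_decimals(energy), new_style(energy))
```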

The complete code at the beginning of this answer produces the following graph:
(Figure: output of that code)

How to have 2 different scales on same Y axis in Python using Matplotlib

On your fourth axes (twin4), call tick_left() and move its left spine to the right-hand side of the plot:

twin4.yaxis.tick_left()
twin4.spines['left'].set_position(('axes', 1.0))
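The question's twin1 to twin3 are not shown here, so this is a self-contained sketch of the same idea with hypothetical names host, twin_out and twin_in: an ordinary twin puts its scale outside the right edge, and a second twin whose left spine is moved onto that edge puts another scale just inside it.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 50)
fig, host = plt.subplots()
host.plot(x, x, color='C0')

# Ordinary twin: ticks and labels outside the right edge.
twin_out = host.twinx()
twin_out.plot(x, 100 * x, color='C1')

# Second twin: move its *left* spine onto the right edge (x = 1.0 in axes
# coordinates); tick_left() keeps its ticks and labels on the left side
# of that spine, i.e. just inside the plot.
twin_in = host.twinx()
twin_in.plot(x, 0.01 * x, color='C2')
twin_in.yaxis.tick_left()
twin_in.spines['left'].set_position(('axes', 1.0))
```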


Adding a y-axis label to secondary y-axis in matplotlib

The best way is to interact with the axes objects directly:

import numpy as np
import matplotlib.pyplot as plt
x = np.arange(0, 10, 0.1)
y1 = 0.05 * x**2
y2 = -1 *y1

fig, ax1 = plt.subplots()

ax2 = ax1.twinx()
ax1.plot(x, y1, 'g-')
ax2.plot(x, y2, 'b-')

ax1.set_xlabel('X data')
ax1.set_ylabel('Y1 data', color='g')
ax2.set_ylabel('Y2 data', color='b')

plt.show()
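One reason to prefer the axes methods: twinx() adds a new Axes to the figure and makes it the current one, so pyplot-style calls such as plt.ylabel then act on the twin rather than on ax1. A short sketch demonstrating this:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for the sketch
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(0, 10, 0.1)
fig, ax1 = plt.subplots()
ax2 = ax1.twinx()
ax1.plot(x, 0.05 * x**2, 'g-')
ax2.plot(x, -0.05 * x**2, 'b-')

# After twinx(), the twin is pyplot's "current" Axes.
current = plt.gca()
plt.ylabel('Y2 data')      # labels ax2, the current Axes
ax1.set_ylabel('Y1 data')  # the object-oriented call is unambiguous
```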


two (or more) graphs in one plot with different x-axis AND y-axis scales in python

The idea is to create three subplots at the same position. To make sure they are recognized as different plots, their properties need to differ; the easiest way to achieve this is to give each one a different label, e.g. ax=fig.add_subplot(111, label="1").

The rest is simply adjusting all the axes parameters, such that the resulting plot looks appealing.
It's a little bit of work to set all the parameters, but the following should do what you need.


import matplotlib.pyplot as plt

x_values1=[1,2,3,4,5]
y_values1=[1,2,2,4,1]

x_values2=[-1000,-800,-600,-400,-200]
y_values2=[10,20,39,40,50]

x_values3=[150,200,250,300,350]
y_values3=[10,20,30,40,50]

fig=plt.figure()
ax=fig.add_subplot(111, label="1")
ax2=fig.add_subplot(111, label="2", frame_on=False)
ax3=fig.add_subplot(111, label="3", frame_on=False)

ax.plot(x_values1, y_values1, color="C0")
ax.set_xlabel("x label 1", color="C0")
ax.set_ylabel("y label 1", color="C0")
ax.tick_params(axis='x', colors="C0")
ax.tick_params(axis='y', colors="C0")

ax2.scatter(x_values2, y_values2, color="C1")
ax2.xaxis.tick_top()
ax2.yaxis.tick_right()
ax2.set_xlabel('x label 2', color="C1")
ax2.set_ylabel('y label 2', color="C1")
ax2.xaxis.set_label_position('top')
ax2.yaxis.set_label_position('right')
ax2.tick_params(axis='x', colors="C1")
ax2.tick_params(axis='y', colors="C1")

ax3.plot(x_values3, y_values3, color="C3")
ax3.set_xticks([])
ax3.set_yticks([])

plt.show()

