Bekhruz Tuychiev I am a data science content writer, spilling every bit of knowledge I have through a series of blog posts, articles, and tutorials. Trying to fulfill my never-satisfied desire of teaching AI and data science to as many people as possible.

Mastering data visualization in Python with Matplotlib


Matplotlib Logo Over a Bar Graph

According to the 2020 Kaggle Machine Learning and Data Science survey, Matplotlib is the number one data visualization library among Kagglers, leading by a significant margin.

Kagglers Visualization Libraries

Many courses and tutorials have recently drawn beginner data scientists’ attention to new, shiny, interactive libraries like Plotly, but Matplotlib remains the king of data visualization libraries and, I suspect, will remain so for the foreseeable future.

Because of this, I highly recommend that you learn it and go beyond the basics, since the power of Matplotlib becomes more evident when you tap into its more advanced features.

In this tutorial, we will cover some of them and give a solid introduction to the object-oriented (OO) interface of Matplotlib.

What is the object-oriented interface in Matplotlib?

When you first learn Matplotlib, you probably start using the library through its PyPlot interface, which is popular with beginners because it’s user-friendly and requires less code to create visuals.

However, its features fall short when you want to perform advanced customizations on your graphs. That’s where the object-oriented API comes into play.

Under the hood, everything you see in a Matplotlib plot is an artist: an instance of a class derived from Matplotlib’s base Artist class.

Anatomy Figure

Having unique classes for each element in a visual gives Matplotlib users a ton of flexibility. Each circle-annotated component in the above graph is an instance of a separate class that inherits from the base Artist class. This means that you can tweak every little line, dot, text, or object visible on the plot.
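To see this for yourself, the sketch below (our own illustration) checks that a figure, an axes, a plotted line, an axis, and the axes title are all instances of matplotlib.artist.Artist, then lists the child artists an axes holds using its get_children method. The exact class names printed can differ between Matplotlib versions (older releases report AxesSubplot rather than Axes, for example).

```python
import matplotlib.pyplot as plt
from matplotlib.artist import Artist

fig, ax = plt.subplots()
(line,) = ax.plot([1, 2, 3], [1, 4, 9])  # ax.plot returns a list of Line2D objects

# Every visible component is an object whose class derives from Artist
components = [fig, ax, line, ax.xaxis, ax.title]
for obj in components:
    print(type(obj).__name__, isinstance(obj, Artist))

# An axes holds its own child artists (lines, spines, axis objects, texts, ...)
for child in ax.get_children():
    print(type(child).__name__)
```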

In the following sections, we will learn about the most important of these classes, starting with figure and axes objects.

Figure and axes objects in Matplotlib

Let’s first import Matplotlib and its submodules:

import matplotlib as mpl  # pip install matplotlib
import matplotlib.pyplot as plt

Next, we create a figure and an axes object using the subplots function:

>>> fig, ax = plt.subplots()

Subplots Function

Now, let’s explain what these objects do.

fig (figure) is the highest-level artist, an object that contains everything. Think of it as the canvas you can draw on. The axes object (ax) represents a single XY coordinate system. All Matplotlib plots require a coordinate system, so you must create at least one figure and one axes object to draw charts.
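The figure and axes also keep references to each other. The sketch below (our own illustration) checks this and shows that PyPlot’s “current” figure and axes, returned by plt.gcf and plt.gca, are the very same objects, which is how the PyPlot and OO interfaces connect:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# The axes keeps a reference to the figure that contains it
print(ax.figure is fig)

# PyPlot's "current" figure and axes are these same objects
print(plt.gcf() is fig)
print(plt.gca() is ax)

# The figure, in turn, keeps a list of every axes drawn on it
print(fig.axes)
```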

plt.subplots is a shorthand for doing this: it creates a single figure and one or more axes objects in a single line of code. A more verbose version of this would be:

>>> fig = plt.figure()
>>> ax = fig.add_subplot()

Because this requires more code, people usually stick to using subplots. Besides, you can pass extra arguments to it to create multiple axes objects simultaneously:

>>> fig, axes = plt.subplots(nrows=1, ncols=3)

Multiple Axes

By changing the nrows and ncols arguments, you create a set of subplots: multiple axes objects stored in axes. You can access each one by using a loop or indexing operators.
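A minimal sketch of both access patterns (the plotted data, titles, and figsize are arbitrary choices for illustration): with a single row, axes is a 1D NumPy array, while a grid with several rows and columns gives a 2D array indexed by [row, column].

```python
import matplotlib.pyplot as plt

# One row of three axes: `axes` is a 1D NumPy array of Axes objects
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(9, 3))
print(axes.shape)  # (3,)

# Loop over the axes to give each one its own data and title
for i, ax in enumerate(axes):
    ax.plot([0, 1], [0, i + 1])
    ax.set_title(f"Subplot {i}")

# A 2x2 grid gives a 2D array, indexed by [row, column]
fig2, grid = plt.subplots(nrows=2, ncols=2)
grid[0, 1].set_title("Top right")

# .flat iterates over a 2D grid as if it were 1D
for ax in grid.flat:
    ax.set_xticks([])
```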



You can learn how to use the subplots function in depth in its documentation.

Plotting functions in Matplotlib

When you switch from PyPlot to the OO API, the function names for plots do not change. You call them on the axes object instead:

import seaborn as sns

tips = sns.load_dataset("tips")

fig, ax = plt.subplots()

ax.scatter(tips["tip"], tips["total_bill"])
ax.set(
    title="Tip vs. Total Bill amount in a restaurant",
    xlabel="Tip ($)",
    ylabel="Total bill ($)",
);

Tip vs Total Bill

Here, I introduce the set method, which you can call on any Matplotlib artist to tweak its properties.
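As a sketch of what set does, one call with several keyword arguments has the same effect as calling the matching set_* methods one by one (the titles and labels here are placeholders):

```python
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(ncols=2)

# One call to `set` with several properties...
ax1.set(title="Title", xlabel="Tip ($)", ylabel="Total bill ($)")

# ...does the same as calling each setter separately
ax2.set_title("Title")
ax2.set_xlabel("Tip ($)")
ax2.set_ylabel("Total bill ($)")

print(ax1.get_title() == ax2.get_title())
print(ax1.get_xlabel(), "|", ax2.get_xlabel())
```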

The above plot is a bit bland and in no way compares to default scatterplots created by Seaborn:

>>> sns.scatterplot(x=tips["tip"], y=tips["total_bill"]);

Seaborn Scatterplot

For this reason, the next sections cover how to extract individual plot components and two extremely flexible functions you can use to customize them.

get_* functions in Matplotlib

Remember how Matplotlib has separate classes for each plot component? In the next couple of sections, we’ll take advantage of this feature.

While customizing my plots, I generally use this workflow:

  1. Create the basic plot
  2. Identify weaknesses of the plot that need customizations
  3. Extract those weak objects
  4. Customize them using the setp function (more on this later)
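Here is a minimal sketch of the whole workflow on a made-up example (the category labels, values, and the choice to rotate tick labels are our own illustration). The weakness is a row of long, overlapping x tick labels; we extract them with get_xticklabels and customize them all at once with plt.setp:

```python
import matplotlib.pyplot as plt

# 1. Create the basic plot (hypothetical example data with long labels)
labels = [f"Category {i}" for i in range(8)]
values = [3, 7, 2, 5, 8, 4, 6, 1]
fig, ax = plt.subplots()
ax.bar(labels, values)

# 2. Weakness: the long x tick labels overlap each other
# 3. Extract the weak objects: a list of Text artists
ticklabels = ax.get_xticklabels()

# 4. Customize them all at once with setp ("ha" is short for horizontalalignment)
plt.setp(ticklabels, rotation=45, ha="right")
```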

Here, we will discuss the third step: how to extract different components of the plot.

First, let’s create a simple plot:

import numpy as np

fig, ax = plt.subplots()

# Create the data to plot
X = np.linspace(0.5, 3.5, 100)
Y1 = 3 + np.cos(X)
Y2 = 1 + np.cos(1 + X / 0.75) / 2
# Random points between the two curves (Y2 always lies below Y1)
Y3 = np.random.uniform(Y2, Y1, len(X))

ax.scatter(X, Y3)
ax.plot(X, Y1)
ax.plot(X, Y2);

Simple Scatterplot

We used the subplots function to create the figure and axes objects, but let’s assume we don’t have the axes object. How do we find it?

Remember, the figure object is the highest-level artist that contains everything in the plot. So, we will call dir on the fig object to see what methods it has:

>>> dir(fig)

[
...
'gca',
'get_agg_filter',
'get_alpha',
'get_animated',
'get_axes',
'get_dpi',
'get_edgecolor',
'get_facecolor',
'get_figheight',
'get_figure',
'get_figwidth',
'get_frameon',
'get_gid',
'get_in_layout'
...
]

In the list, we see the get_axes method, which is what we need:

axes = fig.get_axes()

>>> type(axes)
list

>>> len(axes)
1

The result of get_axes is a list containing the single axes object we created in the plot above.

The axes example shows that everything in Matplotlib is an object. A single plot contains several components implemented as separate classes, and each of those components can contain one or more child components of its own.


They all have one thing in common: you can extract those objects using the relevant get_* methods. You just have to know their names.
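As a sketch of this idea, the snippet below rebuilds the plot from above and drills into it with a few get_* methods that Axes provides (get_lines, get_xaxis, get_xticklabels). Note that the scatter points are not Line2D objects; they live in a PathCollection, which we reach through the collections attribute rather than a get_* method.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PathCollection
from matplotlib.lines import Line2D

# Rebuild the plot from the previous example
X = np.linspace(0.5, 3.5, 100)
Y1 = 3 + np.cos(X)
Y2 = 1 + np.cos(1 + X / 0.75) / 2
Y3 = np.random.uniform(Y2, Y1, len(X))

fig, ax = plt.subplots()
ax.scatter(X, Y3)
ax.plot(X, Y1)
ax.plot(X, Y2)

# Walk down the hierarchy with get_* methods
ax = fig.get_axes()[0]             # the Axes inside the Figure
lines = ax.get_lines()             # the two Line2D objects from ax.plot
xaxis = ax.get_xaxis()             # the XAxis object
ticklabels = ax.get_xticklabels()  # Text objects for the x tick labels

# The scatter points are stored as a PathCollection, not as Line2D objects
scatter = ax.collections[0]

print(len(lines), all(isinstance(line, Line2D) for line in lines))
print(type(xaxis).__name__, isinstance(scatter, PathCollection))
```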

What do you do once you extract those objects? You tweak them!

plt.getp and plt.setp functions

To tweak the properties of any component, you have to know what arguments it has and what values each argument receives. You will be working with many objects, so visiting the documentation every time can become tiresome.

Fortunately, the Matplotlib developers anticipated this issue. Once you extract the relevant object, you can list its properties and their current values using the plt.getp function. For example, let’s see the properties of the axes object:

fig, _ = plt.subplots()

ax = fig.get_axes()[0]

>>> plt.getp(ax)
   ...
   xlabel =
   xlim = (0.0, 1.0)
   xmajorticklabels = [Text(0, 0, ''), Text(0, 0, ''), Text(0, 0, ''), T...
   xminorticklabels = []
   xscale = linear
   xticklabels = [Text(0, 0, ''), Text(0, 0, ''), Text(0, 0, ''), T...
   xticklines = <a list of 12 Line2D ticklines objects>
   xticks = [0.  0.2 0.4 0.6 0.8 1. ]
   yaxis = YAxis(54.0,36.0)
   yaxis_transform = BlendedGenericTransform(     BboxTransformTo(     ...
   ybound = (0.0, 1.0)
   ygridlines = <a list of 6 Line2D gridline objects>
   ylabel =
   ylim = (0.0, 1.0)
   ymajorticklabels = [Text(0, 0, ''), Text(0, 0, ''), Text(0, 0, ''), T...
   yminorticklabels = []
   yscale = linear
   ...

As you can see, the getp function lists all properties of the object it was called on, displaying their current or default values. We can do the same for the fig object:

>>> plt.getp(fig)
   ...
   constrained_layout_pads = (0.04167, 0.04167, 0.02, 0.02)
   contains = None
   default_bbox_extra_artists = [<AxesSubplot:>, <matplotlib.spines.Spine object a...
   dpi = 72.0
   edgecolor = (1.0, 1.0, 1.0, 0.0)
   facecolor = (1.0, 1.0, 1.0, 0.0)
   figheight = 4.0
   figure = Figure(432x288)
   figwidth = 6.0
   frameon = True
   gid = None
   in_layout = True
   label =
   linewidth = 0.0
   path_effects = []
   ...
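getp can also fetch a single property: pass the property name as the second argument and it returns that value instead of printing the full list. In this sketch, this is equivalent to calling the matching get_* method directly:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# With a property name, getp returns one value (here ax.get_xlim())
xlim = plt.getp(ax, "xlim")
print(xlim)

# The same mechanism works for any property, e.g., the y-axis scale
yscale = plt.getp(ax, "yscale")
print(yscale)  # linear
```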

Once you identify which properties you want to change, you need to know what values they accept. For this, you can use the plt.setp function.

Let’s say we want to change the yscale property of the axes object. To see the values it accepts, we pass both the axes object and the name of the property to plt.setp:

>>> plt.setp(ax, "yscale")
yscale: {"linear", "log", "symlog", "logit", ...} or `.ScaleBase`

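As we see, yscale accepts one of several named scales ("linear", "log", "symlog", "logit", and others, as the ellipsis indicates) or a ScaleBase instance. That’s much faster than digging through Matplotlib’s extensive documentation.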

The setp function is very flexible. Passing just the object, without any other arguments, lists all of that object’s settable properties along with their accepted values:

>>> plt.setp(ax)
 ...
 xlabel: str
 xlim: (bottom: float, top: float)
 xmargin: float greater than -0.5
 xscale: {"linear", "log", "symlog", "logit", ...} or `.ScaleBase`
 xticklabels: unknown
 xticks: unknown
 ybound: unknown
 ylabel: str
 ylim: (bottom: float, top: float)
 ymargin: float greater than -0.5
 yscale: {"linear", "log", "symlog", "logit", ...} or `.ScaleBase`
 yticklabels: unknown
 yticks: unknown
 zorder: float
 ...

Now that we know what parameters we want to change and what values we want to pass to them, we can use the set or plt.setp functions:

fig, ax = plt.subplots()

# Using `set`
ax.set(yscale="log", xlabel="X Axis", ylabel="Y Axis", title="Large Title")

# Using setp
plt.setp(ax, yscale="log", xlabel="X Axis", ylabel="Y Axis", title="Large Title")

plt.setp(fig, size_inches=(10, 10));

Large Scatterplot Axes
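plt.setp also accepts a list of artists and applies the same properties to each one, which saves you a loop. In this sketch (made-up data), one call restyles every line on the axes, and property aliases such as lw and ls work as well:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([1, 2, 3], [1, 2, 3])
ax.plot([1, 2, 3], [3, 2, 1])

# setp accepts a list of artists and applies the same properties to each
plt.setp(ax.get_lines(), linewidth=3, linestyle="--")

# Aliases such as lw (linewidth) and ls (linestyle) work too
plt.setp(ax.get_lines()[0], color="tab:green", lw=1)
```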

Using Line2D objects

The most common elements in any plot are lines and dots. Line charts, box plots, tick marks, and gridlines are all drawn with lines.

Matplotlib implements a dedicated class for drawing lines, the Line2D class. You rarely create one directly in practice, but Matplotlib creates Line2D objects whenever it draws a line with plot, as well as for many built-in components such as tick marks, gridlines, and box plot whiskers.

Because so many plot elements are Line2D objects, it’s beneficial to learn its properties:

from matplotlib.lines import Line2D
xs = [1, 2, 3, 4]
ys = [1, 2, 3, 4]

>>> plt.setp(Line2D(xs, ys))
 ...
 dash_capstyle: `.CapStyle` or {'butt', 'projecting', 'round'}
 dash_joinstyle: `.JoinStyle` or {'miter', 'round', 'bevel'}
 dashes: sequence of floats (on/off ink in points) or (None, None)
 data: (2, N) array or two 1D arrays
 drawstyle or ds: {'default', 'steps', 'steps-pre', 'steps-mid', 'steps-post'}, default: 'default'
 figure: `.Figure`
 fillstyle: {'full', 'left', 'right', 'bottom', 'top', 'none'}
 gid: str
 in_layout: bool
 label: object
 linestyle or ls: {'-', '--', '-.', ':', '', (offset, on-off-seq), ...}
 linewidth or lw: float
 ...

I recommend paying attention to the linestyle, linewidth, and color properties, which are the most commonly used.
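ax.plot returns the Line2D objects it creates, so you can restyle them with the properties listed above. The sketch below (arbitrary data and styles) also checks that the tick marks and gridlines of an axis are Line2D objects:

```python
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

fig, ax = plt.subplots()

# ax.plot returns a list of Line2D objects, one per plotted line
(line,) = ax.plot([1, 2, 3, 4], [1, 4, 2, 3])
print(isinstance(line, Line2D))

# Tweak the three most-used properties
plt.setp(line, linestyle=":", linewidth=2.5, color="darkorange")

# Tick marks and gridlines are Line2D objects as well
tick = ax.xaxis.get_major_ticks()[0]
print(type(tick.tick1line).__name__, type(tick.gridline).__name__)
```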

Customizing axis ticks in Matplotlib

Axis ticks are one of the essential parts of every Matplotlib plot. They don’t draw much attention, but they silently control how the data is displayed, so their effect on the plot is substantial.

Fortunately, Matplotlib makes it a breeze to customize axis ticks using the tick_params method of the axes object. Let’s look at its parameters:

Change the appearance of ticks, tick labels, and gridlines.

Tick properties that are not explicitly set using the keyword
arguments remain unchanged unless *reset* is True.

Parameters
----------
axis : {'x', 'y', 'both'}, default: 'both'
   The axis to which the parameters are applied.
which : {'major', 'minor', 'both'}, default: 'major'
   The group of ticks to which the parameters are applied.
reset : bool, default: False
   Whether to reset the ticks to defaults before updating them.

Other Parameters
----------------
direction : {'in', 'out', 'inout'}
   Puts ticks inside the axes, outside the axes, or both.
length : float
   Tick length in points.
width : float
   Tick width in points.

Above is a snippet from its documentation.

The first and most important argument is axis. It accepts three possible values and specifies which axis’s ticks you want to change. Most of the time, you choose both.

Next, the which argument directs the changes to the major ticks, the minor ticks, or both. If minor ticks are not visible on your plot, you can turn them on using ax.minorticks_on():

fig, ax = plt.subplots(figsize=(10, 10))

ax.minorticks_on()

Large Scatterplot

The rest are fairly self-explanatory. Let’s put all the concepts together in an example:

fig, ax = plt.subplots(figsize=(6, 6))

ax.tick_params(axis="both", which="major", direction="out", width=4, length=10, color="r")
ax.minorticks_on()
ax.tick_params(axis="both", which="minor", direction="in", width=2, length=8, color="b")

Example Plot
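The documentation excerpt above lists only a few of the parameters. tick_params also accepts label- and grid-related keywords such as labelsize, labelcolor, labelrotation, pad, grid_color, and grid_linestyle (all described in the full tick_params reference). A sketch with arbitrary values:

```python
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(6, 6))

# Besides tick marks, tick_params also styles tick labels and gridlines
ax.tick_params(
    axis="x",
    which="major",
    labelsize=14,         # font size of the tick labels
    labelcolor="purple",  # color of the tick labels
    labelrotation=30,     # rotate the labels, in degrees
    pad=8,                # gap between tick and label, in points
)

# Turn the grid on, then style the gridlines through tick_params
ax.grid(True)
ax.tick_params(axis="both", grid_color="gray", grid_linestyle=":")
```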

While we are here, you can tweak the spines as well. For example, let’s play around with the top and right spines:

fig, ax = plt.subplots(figsize=(6, 6))

ax.tick_params(axis="both", which="major", direction="out", width=4, length=10, color="r")
ax.minorticks_on()
ax.tick_params(axis="both", which="minor", direction="in", width=2, length=8, color="b")

for spine in ["top", "right"]:
    plt.setp(ax.spines[spine], ls="--", color="brown", hatch="x", lw=4)

Example Plot With Red Dotted Line

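You can access the spines through the spines attribute of the axes object, and the rest is easy. Note that a spine is not a Line2D object but a Patch, which is why it accepts patch properties such as hatch; common ones like linestyle (ls), linewidth (lw), and color share their names with the Line2D properties.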
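Under the hood, each spine is an instance of matplotlib.spines.Spine, a subclass of Patch. The sketch below confirms this with isinstance and shows set_visible, a common way to hide the top and right spines entirely (which spines to hide is just an example):

```python
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

fig, ax = plt.subplots()

# ax.spines is a dict-like container keyed by "left", "right", "bottom", "top"
print(list(ax.spines.keys()))

# Each spine is a Patch, so it takes patch properties (ls, lw, color, hatch, ...)
print(isinstance(ax.spines["top"], Patch))

# A common tweak: hide the top and right spines entirely
for name in ["top", "right"]:
    ax.spines[name].set_visible(False)
```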

Conclusion

The key to a great plot is in the details. Matplotlib defaults are rarely up to professional standards, so it falls on you to customize them. In this article, we have tapped into the core of Matplotlib to teach you the internals so that you have a better handle on more advanced concepts.

Once you start implementing the ideas in the tutorial, you will hopefully see a dramatic change in how you create your plots and customize them. Thanks for reading.
