The post Mastering data visualization in Python with Matplotlib appeared first on LogRocket Blog.

Many courses and tutorials have recently drawn beginner data scientists’ attention to new, shiny, interactive libraries like Plotly, but Matplotlib remains the king of data visualization libraries and, I suspect, will likely continue to be for the foreseeable future.

Because of this, I highly recommend that you learn it and go beyond the basics because the power of Matplotlib becomes more evident when you tap into its more advanced features.

In this tutorial, we will cover some of them and give a solid introduction to the object-oriented (OO) interface of Matplotlib.

When you first learn Matplotlib, you probably start using the library through its PyPlot interface, which is specifically designed for beginners because it’s user-friendly and requires less code to create visuals.

However, its features fall short when you want to perform advanced customizations on your graphs. That’s where the object-oriented API comes into play.

Under the hood, Matplotlib consists of base classes called artists.

Having unique classes for each element in a visual gives Matplotlib users a ton of flexibility. Each circle-annotated component in the above graph is a separate class that inherits from the base artists. This means that you can tweak every little line, dot, text, or object visible on the plot.

In the following sections, we will learn about the most important of these classes, starting with figure and axes objects.

Let’s first import Matplotlib and its submodules:

import matplotlib as mpl  # pip install matplotlib
import matplotlib.pyplot as plt

Next, we create a figure and an axes object using the `subplots` function:

>>> fig, ax = plt.subplots()

Now, let’s explain what these objects do.

`fig` (figure) is the highest-level artist, an object that contains everything. Think of it as the canvas you can draw on. The axes object (`ax`) represents a single XY coordinate system. All Matplotlib plots require a coordinate system, so you must create at least one figure and one axes object to draw charts.

`plt.subplots` is a shorthand for doing this: it creates a single figure and one or more axes objects in a single line of code. A more verbose version of this would be:

>>> fig = plt.figure()
>>> ax1 = fig.add_subplot()

Because this requires more code, people usually stick to using `subplots`. Besides, you can pass extra arguments to it to create multiple axes objects simultaneously:

>>> fig, axes = plt.subplots(nrows=1, ncols=3)

By changing the `nrows` and `ncols` arguments, you create a set of subplots, that is, multiple axes objects stored in `axes`. You can access each one by using a loop or indexing operators.
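As a minimal sketch of both access styles, the snippet below indexes the first axes object and loops over the rest. The sine, cosine, and tanh curves are made-up placeholder data, and the `Agg` backend line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 2 * np.pi, 50)
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(9, 3))

# Indexing: with one row, `axes` is a 1D NumPy array of axes objects
axes[0].plot(x, np.sin(x))
axes[0].set_title("sin")

# Looping: pair each remaining axes object with a function to draw
for ax, func in zip(axes[1:], (np.cos, np.tanh)):
    ax.plot(x, func(x))
    ax.set_title(func.__name__)
```

With more than one row and more than one column, `axes` becomes a 2D array, so you would index it as `axes[row, col]` or loop over `axes.flat`.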

Learn how to use the subplots function in-depth in its documentation.

When you switch from PyPlot to the OO API, the function names for plots do not change. You call them using the axes object:

import seaborn as sns

tips = sns.load_dataset("tips")

fig, ax = plt.subplots()
ax.scatter(tips["tip"], tips["total_bill"])
ax.set(
    title="Tip vs. Total Bill amount in a restaurant",
    xlabel="Tip ($)",
    ylabel="Total bill ($)",
);

Here, I introduce the `set` function, which you can use on any Matplotlib object to tweak its properties.

The above plot is a bit bland and in no way compares to default scatterplots created by Seaborn:

>>> sns.scatterplot(x=tips["tip"], y=tips["total_bill"]);

For this reason, let’s discuss two extremely flexible functions you can use to customize your plots in the next section.

Remember how Matplotlib has separate classes for each plot component? In the next couple of sections, we’ll take advantage of this feature.

While customizing my plots, I generally use this workflow:

- Create the basic plot
- Identify weaknesses of the plot that need customizations
- Extract those weak objects
- Customize them using the `setp` function (more on this later)

Here, we will discuss the third step — how to extract different components of the plot.

First, let’s create a simple plot:

import numpy as np

fig, ax = plt.subplots()

# Create the data to plot
X = np.linspace(0.5, 3.5, 100)
Y1 = 3 + np.cos(X)
Y2 = 1 + np.cos(1 + X / 0.75) / 2
Y3 = np.random.uniform(Y1, Y2, len(X))

ax.scatter(X, Y3)
ax.plot(X, Y1)
ax.plot(X, Y2);

We used the `subplots` function to create the figure and axes objects, but let’s assume we don’t have the axes object. How do we find it?

Remember, the figure object is the highest-level artist that contains everything in the plot. So, we will call `dir` on the `fig` object to see what methods it has:

>>> dir(fig)
[
 ...
 'gca', 'get_agg_filter', 'get_alpha', 'get_animated', 'get_axes', 'get_dpi',
 'get_edgecolor', 'get_facecolor', 'get_figheight', 'get_figure', 'get_figwidth',
 'get_frameon', 'get_gid', 'get_in_layout'
 ...
]

In the list, we see the `get_axes` method, which is what we need:

>>> axes = fig.get_axes()
>>> type(axes)
list
>>> len(axes)
1

The result from `get_axes` is a list containing the single axes object we created in the above plot.

The axes example serves as proof that everything in Matplotlib is just a class. A single plot contains several components implemented as separate classes, and each of those components can have one or more subclasses.

They all have one thing in common: you can extract those classes or subclasses using the relevant `get_*` functions. You just have to know their names.
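To illustrate the `get_*` pattern beyond `get_axes`, here is a small sketch that pulls the plotted lines and the x-axis label out of an axes object and modifies them. The data is a made-up placeholder, and the `Agg` backend line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
x = np.linspace(0, 1, 20)
ax.plot(x, x)
ax.plot(x, x ** 2)

# Extract artists from the axes object with get_* methods
lines = ax.get_lines()               # the Line2D objects added by ax.plot
x_label = ax.get_xaxis().get_label() # the Text object holding the x-axis label

# Tweak the extracted objects directly
lines[0].set_linewidth(3)
x_label.set_text("x values")
```

Each extracted object is the same one the figure draws, so changes made to it show up the next time the figure is rendered.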

What do you do once you extract those objects? You tweak them!

`plt.getp` and `plt.setp` functions

To tweak the properties of any component, you have to know what arguments it has and what values each argument receives. You will be working with many objects, so visiting the documentation every time can become tiresome.

Fortunately, Matplotlib creators thought of this issue. Once you extract the relevant object, you can see what parameters it accepts using the `plt.getp` function. For example, let’s see the properties of the axes object:

>>> fig, _ = plt.subplots()
>>> ax = fig.get_axes()[0]
>>> plt.getp(ax)
    ...
    xlabel =
    xlim = (0.0, 1.0)
    xmajorticklabels = [Text(0, 0, ''), Text(0, 0, ''), Text(0, 0, ''), T...
    xminorticklabels = []
    xscale = linear
    xticklabels = [Text(0, 0, ''), Text(0, 0, ''), Text(0, 0, ''), T...
    xticklines = <a list of 12 Line2D ticklines objects>
    xticks = [0. 0.2 0.4 0.6 0.8 1. ]
    yaxis = YAxis(54.0,36.0)
    yaxis_transform = BlendedGenericTransform( BboxTransformTo( ...
    ybound = (0.0, 1.0)
    ygridlines = <a list of 6 Line2D gridline objects>
    ylabel =
    ylim = (0.0, 1.0)
    ymajorticklabels = [Text(0, 0, ''), Text(0, 0, ''), Text(0, 0, ''), T...
    yminorticklabels = []
    yscale = linear
    ...

As you can see, the `getp` function lists all properties of the object it was called on, displaying their current or default values. We can do the same for the fig object:

>>> plt.getp(fig)
    ...
    constrained_layout_pads = (0.04167, 0.04167, 0.02, 0.02)
    contains = None
    default_bbox_extra_artists = [<AxesSubplot:>, <matplotlib.spines.Spine object a...
    dpi = 72.0
    edgecolor = (1.0, 1.0, 1.0, 0.0)
    facecolor = (1.0, 1.0, 1.0, 0.0)
    figheight = 4.0
    figure = Figure(432x288)
    figwidth = 6.0
    frameon = True
    gid = None
    in_layout = True
    label =
    linewidth = 0.0
    path_effects = []
    ...
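When you only need one value rather than the full listing, `plt.getp` also accepts a property name as its second argument and returns that property. A minimal sketch, assuming a freshly created empty figure (the `Agg` backend line is only an assumption so the code runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit this line in a notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Ask for a single property by name instead of printing all of them
xlim = plt.getp(ax, "xlim")      # same value as ax.get_xlim()
yscale = plt.getp(ax, "yscale")  # same value as ax.get_yscale()
dpi = plt.getp(fig, "dpi")       # same value as fig.get_dpi()
```

The property name is the getter name without its `get_` prefix, which is why the names match the listing printed by `plt.getp(ax)`.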

Once you identify which parameters you want to change, you must know what range of values they receive. For this, you can use the `plt.setp` function.

Let’s say we want to change the `yscale` parameter of the axes object. To see the possible values it accepts, we pass both the axes object and the name of the parameter to `plt.setp`:

>>> plt.setp(ax, "yscale")
yscale: {"linear", "log", "symlog", "logit", ...} or `.ScaleBase`

As we see, `yscale` accepts several named scales (`linear`, `log`, `symlog`, `logit`, and others) or a `ScaleBase` instance. That’s much faster than digging through the large docs of Matplotlib.

The `setp` function is very flexible. Passing just the object without any other parameters lists all of that object’s settable parameters along with their possible values:

>>> plt.setp(ax)
    ...
    xlabel: str
    xlim: (bottom: float, top: float)
    xmargin: float greater than -0.5
    xscale: {"linear", "log", "symlog", "logit", ...} or `.ScaleBase`
    xticklabels: unknown
    xticks: unknown
    ybound: unknown
    ylabel: str
    ylim: (bottom: float, top: float)
    ymargin: float greater than -0.5
    yscale: {"linear", "log", "symlog", "logit", ...} or `.ScaleBase`
    yticklabels: unknown
    yticks: unknown
    zorder: float
    ...

Now that we know what parameters we want to change and what values we want to pass to them, we can use the `set` or `plt.setp` functions:

fig, ax = plt.subplots()

# Using `set`
ax.set(yscale="log", xlabel="X Axis", ylabel="Y Axis", title="Large Title")

# Using setp
plt.setp(ax, yscale="log", xlabel="X Axis", ylabel="Y Axis", title="Large Title")
plt.setp(fig, size_inches=(10, 10));

The most common shapes in any plot are lines and dots. Many plot types, such as bar charts, box plots, and histograms, are built from rectangles, which are themselves bounded by lines.

Matplotlib implements a dedicated class for drawing lines, the Line2D class. You rarely create it directly in practice, but it is used every time Matplotlib draws a line, either as a plot or as part of some other element such as a tick or gridline.

Because so many plot elements are lines, it’s beneficial to learn its properties:

from matplotlib.lines import Line2D

xs = [1, 2, 3, 4]
ys = [1, 2, 3, 4]

>>> plt.setp(Line2D(xs, ys))
    ...
    dash_capstyle: `.CapStyle` or {'butt', 'projecting', 'round'}
    dash_joinstyle: `.JoinStyle` or {'miter', 'round', 'bevel'}
    dashes: sequence of floats (on/off ink in points) or (None, None)
    data: (2, N) array or two 1D arrays
    drawstyle or ds: {'default', 'steps', 'steps-pre', 'steps-mid', 'steps-post'}, default: 'default'
    figure: `.Figure`
    fillstyle: {'full', 'left', 'right', 'bottom', 'top', 'none'}
    gid: str
    in_layout: bool
    label: object
    linestyle or ls: {'-', '--', '-.', ':', '', (offset, on-off-seq), ...}
    linewidth or lw: float
    ...

I recommend paying attention to the `linestyle`, `linewidth`, and `color` arguments, which are used the most.
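As a short sketch of those three properties in use: `ax.plot` returns the Line2D objects it creates, so you can keep a reference and restyle the line afterwards. The data and the chosen style values are arbitrary, and the `Agg` backend line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit this line in a notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# ax.plot returns a list of Line2D objects; unpack the single line
(line,) = ax.plot([1, 2, 3, 4], [1, 4, 9, 16])

# Restyle the existing line using the three most common properties
plt.setp(line, linestyle="--", linewidth=3, color="tab:red")
```

The same keywords can also be passed straight to `ax.plot`, for example `ax.plot(xs, ys, ls="--", lw=3, color="tab:red")`.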

One of the essential aspects of all Matplotlib plots is axis ticks. They don’t draw much attention but silently control how the data is displayed on the plot, making their effect on the plot substantial.

Fortunately, Matplotlib makes it a breeze to customize the axis ticks using the `tick_params` method of the axes object. Let’s learn about its parameters:

Change the appearance of ticks, tick labels, and gridlines.

Tick properties that are not explicitly set using the keyword
arguments remain unchanged unless *reset* is True.

Parameters
----------
axis : {'x', 'y', 'both'}, default: 'both'
    The axis to which the parameters are applied.
which : {'major', 'minor', 'both'}, default: 'major'
    The group of ticks to which the parameters are applied.
reset : bool, default: False
    Whether to reset the ticks to defaults before updating them.

Other Parameters
----------------
direction : {'in', 'out', 'inout'}
    Puts ticks inside the axes, outside the axes, or both.
length : float
    Tick length in points.
width : float
    Tick width in points.

Above is a snippet from its documentation.

The first and most important argument is `axis`. It accepts three possible values and specifies which axis’ ticks you want to change. Most of the time, you choose `both`.

Next, you have `which`, which directs the tick changes to either minor or major ticks. If minor ticks are not visible on your plot, you can turn them on using `ax.minorticks_on()`:

fig, ax = plt.subplots(figsize=(10, 10))
ax.minorticks_on()

The rest are fairly self-explanatory. Let’s put all the concepts together in an example:

fig, ax = plt.subplots(figsize=(6, 6))

ax.tick_params(axis="both", which="major", direction="out", width=4, length=10, color="r")
ax.minorticks_on()
ax.tick_params(axis="both", which="minor", direction="in", width=2, length=8, color="b")

While we are here, you can tweak the spines as well. For example, let’s play around with the top and right spines:

fig, ax = plt.subplots(figsize=(6, 6))

ax.tick_params(axis="both", which="major", direction="out", width=4, length=10, color="r")
ax.minorticks_on()
ax.tick_params(axis="both", which="minor", direction="in", width=2, length=8, color="b")

for spine in ["top", "right"]:
    plt.setp(ax.spines[spine], ls="--", color="brown", hatch="x", lw=4)

You can access the spines using the `spines` attribute of the axes object, and the rest is easy. A spine is implemented as a patch rather than a Line2D object, but it accepts the familiar line-style properties such as `linestyle`, `linewidth`, and `color`.
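A common use of the `spines` attribute is to remove the top and right borders for a cleaner look. The sketch below assumes that is the goal; the plotted data is a placeholder, and the `Agg` backend line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit this line in a notebook
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot([0, 1, 2], [0, 1, 4])

# ax.spines behaves like a dictionary keyed by "left", "right", "top", "bottom"
for side in ("top", "right"):
    ax.spines[side].set_visible(False)
```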

The key to a great plot is in the details. Matplotlib defaults are rarely up to professional standards, so it falls on you to customize them. In this article, we have tapped into the core of Matplotlib to teach you the internals so that you have a better handle on more advanced concepts.

Once you start implementing the ideas in the tutorial, you will hopefully see a dramatic change in how you create your plots and customize them. Thanks for reading.


The post Data visualization in Python using Seaborn appeared first on LogRocket Blog.

The majority of data visuals created by data scientists are created with Python and its twin visualization libraries: Matplotlib and Seaborn. Matplotlib and Seaborn are widely used to create graphs that enable individuals and companies to make sense of terabytes of data.

So, what are these two libraries, exactly?

Matplotlib is the king of Python data visualization libraries and makes it a breeze to explore tabular data visually.

Seaborn is another Python data visualization library built on top of Matplotlib that introduces some features that weren’t previously available, and, in this tutorial, we’ll use Seaborn.

To follow along with this project, you’ll also need to know about Pandas, a powerful library that manipulates and analyzes tabular data.

In this blog post, we’ll learn how to perform data analysis through visualizations created with Seaborn. You will be introduced to histograms, KDEs, bar charts, and more. By the end, you’ll have a solid understanding of how to visualize data.

We will start by installing the libraries and importing our data. Running the below command will install the Pandas, Matplotlib, and Seaborn libraries for data visualization:

pip install pandas matplotlib seaborn

Now, let’s import the libraries under their standard aliases:

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

Next, load in the data to be analyzed. The dataset contains physical measurements of 54,000 diamonds and their prices. You can download the original dataset as a CSV file from here on Kaggle, but we will be using a shortcut:

diamonds = sns.load_dataset("diamonds")

Because the dataset is already built into Seaborn, we can load it as a `pandas.DataFrame` using the `load_dataset` function.

>>> type(diamonds)
pandas.core.frame.DataFrame

Before we dive head-first into visuals, let’s ensure we have a high-level understanding of our dataset:

>>> diamonds.head()

We have used the handy `head` function of Pandas, which prints out the first five rows of the data frame. `head` should be the first function you use when you load a dataset into your environment for the first time.

Notice the dataset has ten variables — three categorical and seven numeric.

- **Carat**: weight of a diamond
- **Cut**: the cut quality, with five possible values in increasing order: Fair, Good, Very Good, Premium, Ideal
- **Color**: the color of a diamond, with color codes from D (the best) to J (the worst)
- **Clarity**: the clarity of a diamond, with eight clarity codes
- **X**: length of a diamond (mm)
- **Y**: width of a diamond (mm)
- **Z**: depth of a diamond (mm)
- **Depth**: total depth percentage, calculated as Z / average(X, Y)
- **Table**: the width of the top of a diamond relative to its widest point
- **Price**: diamond price in dollars

Instead of counting all variables one by one, we can use the `shape` attribute of the data frame:

>>> diamonds.shape
(53940, 10)

There are 53,940 diamonds recorded, along with their ten different features. Now, let’s print summary statistics of the dataset:

>>> diamonds.describe()

The `describe` function displays some critical metrics of each numeric variable in a data frame. Here are some observations from the above output:

- The cheapest diamond in the dataset costs $326, while the most expensive costs almost 60 times more, $18,823
- The minimum weight of a diamond is 0.2 carats, while the max is 5.01. The average weight is ~0.8
- Looking at the means of the X and Y features, we see that diamonds, on average, measure about the same along both dimensions

Now that we are comfortable with the features in our dataset, we can start plotting them to uncover more insights.

In the previous section, we started something called “Exploratory Data Analysis” (EDA), which is the basis for any data-related project.

The goal of EDA is simple: get to know your dataset at the deepest level possible. Becoming intimate with the data and learning the relationships between its variables is an absolute must.

Completing a successful and thorough EDA lays the groundwork for future stages of your data project.

We have already performed the first stage of EDA, which was a simple “get acquainted” step. Now, let’s go deeper, starting with univariate analysis.

As the name suggests, we’ll explore variables one at a time, not the relationships between them just yet. Before we start plotting, we take a small sample of the dataset: 54,000 points is more than we need, 3,000 are enough to learn about the data, and a smaller sample helps prevent overplotting.

sample = diamonds.sample(3000)

To take a sample, we use the `sample` function of pandas, passing in the number of random data points to include in a sample.
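Because `sample` draws rows at random, every run produces different plots. If you want the same rows each time, you can pass a seed through the `random_state` parameter. The sketch below uses a made-up stand-in data frame rather than the real diamonds data, and the seed value 42 is arbitrary.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the diamonds data frame (values are made up)
toy = pd.DataFrame({
    "price": np.arange(100),
    "carat": np.linspace(0.2, 5.0, 100),
})

# The same seed yields the same rows on every call
first = toy.sample(n=10, random_state=42)
second = toy.sample(n=10, random_state=42)
```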

Now, we create our first plot, which is a histogram:

sns.histplot(x=sample["price"])

Histograms only work on numeric variables. They divide the data into a number of equal-width bins and display how many diamonds fall into each bin. Here, we can estimate that nearly 800 diamonds are priced between 0 and 1000.

Each bin contains the count of diamonds. Instead, we might want to see what percentage of the diamonds falls into each bin. For that, we will set the `stat` argument of the `histplot` function to `percent`:

>>> sns.histplot(sample["price"], stat="percent")

Now, the height of each bar/bin shows the percentage of the diamonds. Let’s do the same for the carat of the diamonds:

sns.histplot(sample["carat"], stat="percent")

Looking at the first few bars, we can conclude that the majority of the diamonds weigh less than 0.5 carats. Histograms aim to take a numeric variable and show what its shape generally looks like. Statisticians call this shape the distribution of the variable.
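To make the difference between counts and percentages concrete, the arithmetic can be sketched with NumPy alone. This is not Seaborn’s implementation, only the same calculation done by hand: the prices are made-up random values and the choice of 20 bins is arbitrary.

```python
import numpy as np

# Made-up, right-skewed values standing in for diamond prices
rng = np.random.default_rng(0)
prices = rng.exponential(scale=4000, size=3000)

# Split the value range into 20 equal-width bins and count values per bin
counts, edges = np.histogram(prices, bins=20)

# Convert each count into a share of the whole sample, in percent
percent = counts / counts.sum() * 100

# Width of each bin; all of them are (numerically) identical
widths = np.diff(edges)
```

The bar heights change from `counts` to `percent`, but the shape of the histogram stays the same because every bar is divided by the same total.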

However, histograms aren’t the only plots that do the job. There is also a plot called KDE Plot (Kernel Density Estimate), which uses some fancy math under the hood to draw curves like this:

sns.kdeplot(sample["table"])

Creating the KDE plot of the table variable shows us that the majority of diamonds measure between 55.0 and 60.0. At this point, I will leave it to you to plot the KDEs and histograms of other numeric variables because we have to move on to categorical features.

The most common plot for categorical features is a countplot. Passing the name of a categorical feature in our dataset to Seaborn’s `countplot` draws a bar chart, with each bar height representing the number of diamonds in each category. Below is a countplot of diamond cuts:

sns.countplot(x=sample["cut"])

We can see that our dataset contains many more Ideal diamonds than Premium or Very Good diamonds. Here is a countplot of colors for the interested:

sns.countplot(x=sample["color"])
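The bar heights in a countplot are simply category frequencies, so you can get the same numbers as a table with the pandas `value_counts` method. The six cut labels below are made up for illustration.

```python
import pandas as pd

# A tiny made-up column of cut labels
cuts = pd.Series(
    ["Ideal", "Premium", "Ideal", "Good", "Ideal", "Premium"], name="cut"
)

# Frequency of each category, sorted from most to least common
counts = cuts.value_counts()

# The same frequencies as fractions of the total
shares = cuts.value_counts(normalize=True)
```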

This concludes the univariate analysis section of the EDA.

Now, let’s look at the relationships between two variables at a time. Let’s start with the connection between diamond carats and price.

We already know that diamonds with higher carats cost more. Let’s see if we can visually capture this trend:

sns.scatterplot(x=sample["carat"], y=sample["price"])

Here, we are using another Seaborn function that plots a scatter plot. Scatterplots are one of the most widely-used charts because they accurately show the relationships between two variables by using a cloud of dots.

Above, each dot represents a single diamond. The dots’ positions are determined by their carat and price measurements, which we passed to X and Y parameters of the scatterplot function.

The plot confirms our assumptions — heavier diamonds tend to be more expensive. We are drawing this conclusion based on the curvy upward trend of the dots.

Let’s try plotting depth against the table:

sns.scatterplot(x=sample["depth"], y=sample["table"])

Frankly, this scatterplot is disappointing because we can’t draw a tangible conclusion as we did with the previous one.

Another typical bivariate plot is a boxplot, which plots the distribution of a variable against another based on their five-number summary:

sns.boxplot(x=sample["color"], y=sample["price"])

The boxplot above shows the relationship between each color category and its prices. The short horizontal caps at the ends of each vertical line (the whiskers) mark the smallest and largest values in that category that are not flagged as outliers. The bottom and top edges of each box represent the 25th and 75th percentiles.

In other words, the bottom edge of the first box tells us that 25% of D-colored diamonds cost less than about $1,250, while the top edge says that 75% of diamonds cost less than about $4,500. The little horizontal line in the middle denotes the median, the 50% mark.
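The numbers behind a box can be sketched with NumPy. The example assumes the common 1.5 × IQR rule for the whiskers, which is the default `whis=1.5` setting in Matplotlib and Seaborn boxplots; the nine values are made up so that one of them is an obvious outlier.

```python
import numpy as np

# Made-up values with one obvious outlier
values = np.array([1, 2, 3, 4, 5, 6, 7, 8, 100], dtype=float)

# Box edges and the median line
q1, median, q3 = np.percentile(values, [25, 50, 75])
iqr = q3 - q1  # interquartile range: the height of the box

# Fences: points beyond these are drawn as individual outlier dots
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Whiskers end at the most extreme points that are still inside the fences
inside = values[(values >= lower_fence) & (values <= upper_fence)]
whisker_low, whisker_high = inside.min(), inside.max()
outliers = values[(values < lower_fence) | (values > upper_fence)]
```

Here the upper whisker stops at 8 rather than at the true maximum of 100, which is drawn as a separate dot.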

The dark dots above are outliers. Let’s plot a boxplot of diamond clarities and their relationship with carat:

sns.boxplot(x=diamonds["clarity"], y=diamonds["carat"])

Here we see an interesting trend. The diamond clarities are displayed from best to worst, and we can see that lower-clarity diamonds weigh more in the dataset. The last box shows that the lowest-clarity (I1) diamonds weigh about a carat on average.

Finally, it’s time to look at multiple variables at the same time.

The most common multivariate plot you will encounter is a pair plot of Seaborn. Pair plots take several numeric variables and plot every single combination of them against each other. Below, we are creating a pair plot of price, carat, table, and depth features to keep things manageable:

sns.pairplot(sample[["price", "carat", "table", "depth"]])

Every variable is plotted against others, resulting in plot doubles across the diagonal. The diagonal itself contains histograms because each one is a variable plotted against itself.

A pair plot is a compact and single-line version of creating multiple scatter plots and histograms simultaneously.
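To see what that single line saves you, here is a rough sketch of the same grid built by hand with plain Matplotlib: histograms on the diagonal and scatter plots everywhere else. It is not how Seaborn implements `pairplot`; the three columns are made-up random data, and the `Agg` backend line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

# Made-up columns standing in for the diamonds sample
rng = np.random.default_rng(1)
carat = rng.uniform(0.2, 2.5, 300)
data = {
    "carat": carat,
    "price": 4000 * carat + rng.normal(0, 500, 300),
    "depth": rng.normal(61.7, 1.4, 300),
}

names = list(data)
n = len(names)
fig, axes = plt.subplots(n, n, figsize=(7, 7))

for i, row_name in enumerate(names):
    for j, col_name in enumerate(names):
        ax = axes[i, j]
        if i == j:
            # A variable against itself: show its distribution instead
            ax.hist(data[row_name], bins=20)
        else:
            ax.scatter(data[col_name], data[row_name], s=5)
        if i == n - 1:
            ax.set_xlabel(col_name)  # label only the bottom row
        if j == 0:
            ax.set_ylabel(row_name)  # label only the left column
```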

So far, we have solely relied on our visual intuition to decipher the relationships between different features. However, many analysts and statisticians require mathematical or statistical methods that quantify these relationships to back our “eyeball estimates.” One of these statistical methods is computing a correlation coefficient between features.

The correlation coefficient, often denoted as R, measures how strongly a numeric variable is linearly connected to another. It ranges from -1 to 1, and values close to the range limits denote strong relationships.

In other words, if the absolute value of the coefficient is between 0 and 0.3, it is considered a weak (or no) relationship. If it is between 0.3-0.7, the strength of the relationship is considered moderate, while greater than 0.7 correlation represents a strong connection.
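The coefficient and the rule-of-thumb labels above can be sketched in a few lines of NumPy. The helper `correlation_strength` is a hypothetical name introduced only for this illustration, and the data is made-up random noise.

```python
import numpy as np


def correlation_strength(r):
    """Label |r| with the rule-of-thumb cut-offs from the text (0.3 and 0.7)."""
    size = abs(r)
    if size < 0.3:
        return "weak"
    if size <= 0.7:
        return "moderate"
    return "strong"


rng = np.random.default_rng(2)
x = rng.normal(size=500)
y_linear = 2 * x + 1           # an exact linear function of x
y_noise = rng.normal(size=500)  # generated independently of x

# np.corrcoef returns a 2x2 matrix; the off-diagonal entry is r
r_linear = np.corrcoef(x, y_linear)[0, 1]
r_noise = np.corrcoef(x, y_noise)[0, 1]
```

An exact linear relationship gives a coefficient of 1 (up to floating-point error), while two independently generated variables give a value near 0.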

Pandas makes it easy to compute the correlation coefficient between every single feature pair. By calling the `corr` method on our data frame, we get a correlation matrix:

# numeric_only=True skips the categorical columns (required in pandas 2.0+)
correlation_matrix = diamonds.corr(numeric_only=True)

>>> correlation_matrix

>>> correlation_matrix.shape
(7, 7)

Looking closely, we see a diagonal of 1s. These are perfect relationships because the diagonal contains the correlation between a feature and itself.

However, looking at the raw correlation matrix doesn’t reveal much. Once again, we will use another Seaborn plot called a heatmap to solve this:

>>> sns.heatmap(correlation_matrix)

Passing our correlation matrix to the heatmap function displays a plot that colors each cell of the matrix based on its magnitude. The color bar at the right serves as a legend of what shades of color denote which magnitudes.

But we can do much better. Instead of leaving the viewer to guess the numbers, we can annotate the heatmap so that each cell contains its magnitude:

sns.heatmap(correlation_matrix, square=True, annot=True, linewidths=3)

For this, we set the `annot` parameter to `True`, which displays the original correlation on the plot. We also set `square` to `True` to make the heatmap square-shaped and, thus, more visually appealing. We also increased the line widths so that each cell in the heatmap is more distinct.

Interpreting this heatmap, we can learn that the strongest relations are among the X, Y, and Z features. They all have >0.8 correlation. We also see that the table and depth are negatively correlated but weakly. We can also confirm our assumptions from the scatterplots — the correlation between carat and price is relatively high at 0.92.

Another approach we can use to explore multivariate relationships is to use scatter plots with more variables. Take a look at the one below:

sns.scatterplot(x=sample["carat"], y=sample["price"], hue=sample["cut"])

Now, each dot is colored based on its cut category. We achieved this by passing the `cut` column to the `hue` parameter of the `scatterplot` function. We can pass numeric variables to `hue` as well:

sns.scatterplot(x=sample["carat"], y=sample["price"], hue=sample["x"])

In the above example, we plot carat against price and color each diamond based on its length (the X feature).

Here we can make two observations:

- Heavier diamonds cost more
- Heavier diamonds are also longer

Instead of encoding the third variable with color, we could have increased the dot size:

sns.scatterplot(x=sample["carat"], y=sample["price"], size=sample["y"])

This time, we passed the Y variable to the `size` argument, which scales the size of the dots based on the magnitude of Y for each diamond. Finally, we can plot four variables at the same time by passing separate columns to both `hue` and `size`:

sns.scatterplot(x=sample["carat"], y=sample["price"], hue=sample["cut"], size=sample["z"])

Now the plot encodes the diamond cut categories as color and their depth as the size of the dots.
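Seaborn also lets you pass the data frame once through `data` and refer to columns by name, which keeps calls with several encodings shorter. The sketch below uses a made-up toy data frame instead of the real sample, and the `Agg` backend line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit this line in a notebook
import numpy as np
import pandas as pd
import seaborn as sns

# Toy stand-in for the diamonds sample (values are made up)
rng = np.random.default_rng(3)
carat = rng.uniform(0.2, 2.5, 200)
toy = pd.DataFrame({
    "carat": carat,
    "price": 4000 * carat + rng.normal(0, 500, 200),
    "cut": rng.choice(["Fair", "Good", "Ideal"], size=200),
    "z": 3 + carat,
})

# Column names are looked up in `data`; axis labels come from those names
ax = sns.scatterplot(data=toy, x="carat", y="price", hue="cut", size="z")
```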

Let’s see a few more complex visuals you can create with Seaborn, such as subplots. We have already seen an example of subplots when we used the `pairplot` function:

g = sns.pairplot(sample[["price", "carat", "depth"]])

>>> type(g)
seaborn.axisgrid.PairGrid

The `pairplot` function is shorthand for creating a set of subplots called a `PairGrid`. Fortunately, we are not limited to the `pairplot` function. We can create custom `PairGrid` objects:

g = sns.PairGrid(sample[["price", "carat", "depth"]])

Passing a dataframe to the `PairGrid` class returns a set of empty subplots like above. Now, we will use the `map` function to populate each:

g = sns.PairGrid(sample[["price", "carat", "depth"]])
g.map(sns.scatterplot)

`map` accepts a Seaborn plotting function and applies it to all subplots. Here, we don’t need scatterplots on the diagonal, so we can populate it with histograms:

g = sns.PairGrid(sample[["price", "carat", "depth"]])
g.map_offdiag(sns.scatterplot)
g.map_diag(sns.histplot);

Using the `map_offdiag` and `map_diag` functions, we ended up with the same result as `pairplot`. But we can improve the above chart even further. For example, we can plot different charts in the upper and lower triangles using `map_lower` and `map_upper`:

g = sns.PairGrid(sample[["price", "carat", "depth"]])
g.map_lower(sns.scatterplot)
g.map_upper(sns.kdeplot)
g.map_diag(sns.histplot);

The upper triangle KDE Plots turn into contours because of their 2D nature.

Finally, we can also use the `hue` parameter to encode a third variable in every subplot:

g = sns.PairGrid(sample[["price", "carat", "depth", "cut"]], hue="cut")
g.map_diag(sns.histplot)
g.map_offdiag(sns.scatterplot)
g.add_legend();

The `hue` parameter is specified while calling the `PairGrid` class. We also call the `add_legend` function on the grid to make the legend visible.

But, there is a problem with the above subplots. The dots are completely overplotted, so we can’t reasonably distinguish any patterns between each diamond cut.

To solve this, we can use a different set of subplots called `FacetGrid`. A `FacetGrid` can be created just like a `PairGrid` but with different parameters:

g = sns.FacetGrid(sample, col="cut")

Passing the cut column to the `col` parameter creates a `FacetGrid` with five subplots, one for each diamond cut category. Let’s populate them with `map`:

g = sns.FacetGrid(sample, col="cut")
g.map(sns.scatterplot, "price", "carat");

This time, we have separate scatter plots in separate subplots for each diamond cut category. As you can see, FacetGrid is smart enough to put the relevant axis labels as well.

We can also introduce another categorical variable as a row by passing a column name to the `row` parameter:

g = sns.FacetGrid(sample, col="cut", row="color")
g.map(sns.scatterplot, "price", "carat");

The resulting plot is humongous because there is a subplot for every diamond cut/color combination. There are many other ways you can customize these FacetGrids and PairGrids, so review the docs to learn more.

We have used Seaborn exclusively, but you could consider using Matplotlib.

We used Seaborn because of its simplicity, and, because Seaborn was built on top of Matplotlib, it was designed to complement the weaknesses of Matplotlib, making it more user-friendly.

Another primary reason is the default styles of plots. By default, Seaborn creates more easy-on-the-eye plots. On the other hand, the default styles of Matplotlib plots, well, suck. For example, here is the same histogram of diamond prices:

fig, ax = plt.subplots()
ax.hist(sample["price"])

It is vastly different. While Seaborn picks the number of bins automatically based on the data, Matplotlib defaults to ten bins (although you can change it manually). Another example is the carat vs. price scatterplot:

fig, ax = plt.subplots()
ax.scatter(sample["carat"], sample["price"])
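Returning to the histogram comparison, the bin count in Matplotlib is controlled through the `bins` argument of `ax.hist`. The sketch below draws the default ten bins next to `bins="auto"`, which asks NumPy to choose the bin edges from the data. The prices are made-up random values, and the `Agg` backend line is only an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless run; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np

# Made-up, right-skewed values standing in for diamond prices
rng = np.random.default_rng(4)
prices = rng.exponential(scale=4000, size=3000)

fig, (ax_default, ax_auto) = plt.subplots(1, 2, figsize=(9, 3))

# Default behaviour: ten equal-width bins
counts_default, _, _ = ax_default.hist(prices)
ax_default.set_title("default bins")

# Let NumPy estimate a suitable number of bins from the data
counts_auto, edges_auto, _ = ax_auto.hist(prices, bins="auto")
ax_auto.set_title('bins="auto"')
```

You can also pass an integer, such as `bins=30`, or an explicit sequence of bin edges.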

Generally, Seaborn suits developers looking to create beautiful charts using less code.

However, the key to a masterpiece visual is in the customization, and that’s where Matplotlib really shines. Though it has a steeper learning curve, once you master it, you can create jaw-dropping visuals such as these.

This tutorial only served as a glimpse of what a real-world EDA might look like. Even though we learned about many different types of plots, there are still more you can create.

From here, you can learn each introduced plot function in-depth. Each one has many parameters, and reading the documentation and trying out the examples should be enough to satisfy your needs to plot finer charts.

I also recommend reading the Matplotlib documentation to learn about more advanced methods in data visualization. Thank you for reading!
