Seaborn is a Python library for creating plots. It is based on matplotlib and provides a high-level interface for drawing statistical graphics.

Seaborn integrates nicely with pandas: it operates on DataFrames and arrays and handles aggregation and semantic mapping automatically, which makes it a quick, convenient option for data visualization in your data projects. Once you understand the basic concepts, you can create plots easily without searching Stack Overflow too often.

The basic steps for creating plots with seaborn are:

  0. Import libraries
  1. Prepare the data
  2. [optional] Control figure aesthetics (grids, colors etc.)
  3. Create plot
  4. [optional] Further customization of your plot (title, labels, tick marks etc.), which can also override the figure aesthetics defined in step 2
  5. Show or save the plot

Steps

Let’s walk through the steps of creating plots.

0. Import libraries

As mentioned above, seaborn is built on top of matplotlib but does not expose its full functionality. Therefore, we should always import both matplotlib (with the pyplot module) and seaborn. It is likely that you will need matplotlib at some point (for example to change the plot title or the tick labels, and to show the plots).

import matplotlib.pyplot as plt
import seaborn as sns

# make sure plots are shown in a Jupyter notebook; otherwise you will need plt.show()
# (keep the comment on its own line: the magic command treats trailing text as arguments)
%matplotlib inline

1. Prepare Data

You can work directly with a pandas DataFrame. In general, you do not have to pre-aggregate your data. For count plots, for example, seaborn calculates the counts automatically.
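To illustrate the automatic aggregation, here is a minimal sketch. The toy DataFrame and the non-interactive Agg backend are assumptions made so that the example runs anywhere; they are not part of the original walkthrough.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import pandas as pd
import seaborn as sns

# Hypothetical toy data: one row per passenger, not aggregated
df = pd.DataFrame({"class": ["First", "Third", "Third", "Second", "Third", "First"]})

# The manual aggregation you would otherwise need
manual_counts = df["class"].value_counts()

# seaborn counts the rows per category itself
ax = sns.countplot(data=df, x="class")
bar_heights = sorted(int(p.get_height()) for p in ax.patches if p.get_height() > 0)

print(manual_counts.to_dict())
print(bar_heights)
```

The bar heights drawn by countplot() match the result of value_counts(), so the raw rows can be passed as they are.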

2. Figure Aesthetics

This is an optional step. You can just go with the default configuration here, or you can define aesthetics such as grid, axis, tick marks, color and font sizes.

Axis and Grid Style: seaborn's default theme (activated with sns.set_theme() or sns.set()) uses a dark grid. You can choose between styles with a grid (darkgrid, whitegrid), without a grid (dark, white), and with tick marks (ticks).

sns.set_style("white")

Remove the top and right axis lines (spines) by calling the following after you have created the plot:

sns.despine()
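Because despine() works on axes that already exist, the order of the calls matters. A small sketch, assuming made-up data points and the Agg backend so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style("white")

fig, ax = plt.subplots()
sns.scatterplot(x=[1, 2, 3], y=[2, 4, 3], ax=ax)  # made-up points

# Call despine() after the plot exists; here it targets one Axes explicitly
sns.despine(ax=ax)

visible = {side: ax.spines[side].get_visible()
           for side in ("top", "right", "left", "bottom")}
print(visible)
```

By default only the top and right spines are hidden; the left and bottom ones stay visible.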

Color: Choose a color palette and set it as the default.

sns.set_palette(palette="rocket", n_colors=3)

Scaling: You can choose between four contexts, notebook (default), paper, talk, and poster, which scale fonts and lines according to the use case. You can also override individual settings.

sns.set_context("notebook", font_scale=1.5, rc={"lines.linewidth": 2.5})

Or set multiple parameters in one step with set(), e.g. sns.set(context='notebook', style='white', font='Helvetica', font_scale=1, color_codes=True, rc=None). Since seaborn 0.11, set() is an alias of set_theme(), which is the preferred name.
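The following sketch shows how these settings can be inspected and applied. It assumes seaborn 0.11 or later (for set_theme()) and uses the Agg backend only so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import seaborn as sns

# axes_style() returns the parameters of a style without applying them
white = sns.axes_style("white")
whitegrid = sns.axes_style("whitegrid")
print(white["axes.grid"], whitegrid["axes.grid"])

# set_theme() applies context, style and palette globally in one call
sns.set_theme(context="talk", style="whitegrid", palette="rocket")
grid_enabled = plt.rcParams["axes.grid"]

# A context manager switches the style temporarily, e.g. for a single figure
with sns.axes_style("white"):
    fig, ax = plt.subplots()
    grid_inside = plt.rcParams["axes.grid"]

print(grid_enabled, grid_inside)
```

The global functions change matplotlib's rcParams for every plot that follows, while the with-block only affects axes created inside it.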

3. Plot

In this step, you define what kind of plot you would like. There are two important concepts you need to understand: figure and axes.

A Figure refers to the whole figure that you see. A figure may consist of a single plot, but it is possible to have multiple subplots in the same figure. A specific subplot is called an Axes. Some functions affect the entire figure (figure-level functions), while others affect only the specified subplot (axes-level functions).

For each type of plot (distributions, categorical, relational), there is a figure-level function, which takes the kind of plot as a parameter. For example, displot() is the figure-level function for plotting distributions. By passing the parameter kind="hist", it draws a histogram.

# df is your DataFrame, "age" the name of the column to plot
sns.displot(data=df, x="age", kind="hist")

Alternatively, you can choose the axes-level function histplot() to plot a histogram.

sns.histplot(data=df, x="age")

The table gives an overview of the figure-level functions and their axes-level functions.

type of plot  | figure-level function | axes-level functions
distributions | displot               | e.g. histplot, rugplot
categorical   | catplot               | e.g. barplot, boxplot
relational    | relplot               | e.g. scatterplot, lineplot

You can use either figure-level or axes-level functions to create a figure with a single plot.

4. Customize

For customization, you will use some functions from matplotlib.pyplot. Frequently used customizations are adding labels to the axes (plt.xlabel and plt.ylabel) or setting the start and end of an axis (plt.ylim(0,100), plt.xlim(0,10)). These functions act on the current axes, i.e. the plot that was drawn last.

plt.title("A Title") # add a plot title
plt.ylabel("Survived") # add a y-axis label
plt.xlabel("Sex") # add an x-axis label
plt.ylim(0,100) # start and end of y-axis
plt.xlim(0,10) # start and end of x-axis
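The same customizations can also be applied through the Axes object, which avoids any ambiguity about which plot is being changed. A sketch with made-up data points, using the Agg backend as an assumption so it runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots()
sns.scatterplot(x=[1, 2, 3], y=[10, 40, 90], ax=ax)  # made-up points

# Axes.set() bundles several setters into one call
ax.set(title="A Title", xlabel="x value", ylabel="y value",
       xlim=(0, 10), ylim=(0, 100))

print(ax.get_title(), ax.get_xlim(), ax.get_ylim())
```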

5. Save or Show

If you use %matplotlib inline in your Jupyter notebook, your plot shows up in the notebook without any further function call. Otherwise, you have to call plt.show() to show a plot. Use plt.savefig("foo.png") to save your figure, and call it before plt.show(), because the figure may be empty once it has been shown.

plt.savefig("foo.png") # save first
plt.show() # then show

Examples: One Plot per figure

Let’s look at some examples for plots that are useful for understanding features of your dataset. First, we’ll tackle descriptives (count plots and distributions) and then relationships between features. We’ll use seaborn's titanic example dataset, which sns.load_dataset() downloads on first use:

# (0) import packages and datasets
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

titanic = sns.load_dataset('titanic')

#(2) Figure aesthetics
sns.set(context='notebook', style='white', palette='rocket', font='Helvetica', font_scale=1, color_codes=True)

titanic.head(5)
   survived  pclass  sex     age  sibsp  parch  fare     embarked  class  who    adult_male  deck  embark_town  alive  alone
0  0         3       male    22   1      0      7.25     S         Third  man    True        nan   Southampton  no     False
1  1         1       female  38   1      0      71.2833  C         First  woman  False       C     Cherbourg    yes    False
2  1         3       female  26   0      0      7.925    S         Third  woman  False       nan   Southampton  yes    True
3  1         1       female  35   1      0      53.1     S         First  woman  False       C     Southampton  yes    False
4  0         3       male    35   0      0      8.05     S         Third  man    True        nan   Southampton  no     True

Plot distribution of categorical features with a countplot

Count plots automatically count the number of occurrences of each value of a categorical feature and display them in a bar plot. You just pass the dataset and the variable you want to plot as parameters, so there is no need to aggregate the data up front.

sns.countplot(x = 'class', data = titanic)

plt.title("Count passengers by class")
plt.xlabel("Class")
plt.ylabel("Number of passengers")

You can annotate the bars with their counts using the annotate() method of the Axes object that countplot() returns. It takes as arguments a text (the count, which equals the height of the bar) and xy (the coordinates of the annotated point); xytext and textcoords shift the label by a few points. We loop over each bar (c.patches) in the plot and add the annotation:

c = sns.countplot(x = 'class', data = titanic)

for p in c.patches:
    c.annotate(f"{p.get_height():.0f}", # count as text, without decimals
               (p.get_x() + p.get_width() / 2.0, # horizontal center of the bar
                p.get_height()), # top of the bar
               ha = 'center',
               va = 'center',
               xytext = (0, 5), # shift the label 5 points upwards
               textcoords = 'offset points')
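As an alternative to looping over the patches, matplotlib 3.4 and later provide Axes.bar_label(), which labels every bar of a container. The sketch assumes such a matplotlib version and uses a made-up DataFrame instead of the titanic data.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import pandas as pd
import seaborn as sns

# Hypothetical toy data
df = pd.DataFrame({"class": ["First", "Third", "Third", "Second", "Third", "First"]})

ax = sns.countplot(data=df, x="class")

# Each group of bars is stored as a container on the Axes
for container in ax.containers:
    ax.bar_label(container, padding=3)  # write the bar height above each bar

labels = sorted(t.get_text() for t in ax.texts)
print(labels)
```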

Plot distribution of numeric features in a histogram

Plotting the distribution of one feature is simple: you can use histplot() and pass the dataset and the feature you want to plot as parameters.

sns.histplot(data=titanic, x="age")
plt.title("Age distribution")
plt.xlabel("Age of passenger")
plt.ylabel("Count")

You can also plot the distributions and compare across multiple groups with hue.

sns.histplot(data=titanic, x="age", hue="survived", multiple="stack")
plt.title("Age distribution")
plt.xlabel("Age of passenger")
plt.ylabel("Count")

Plot relationships between features with a scatterplot

You can create a simple scatter plot with 2 variables like this:

sns.scatterplot(data = titanic,
               x = 'age', 
               y = 'fare')
plt.title("Age vs. Fare")
plt.xlabel("Age")
plt.ylabel("Fare")

You can include more than 2 variables by adding grouping variables with the parameters hue (color) and style (marker). Each can encode a different variable. If you pass the same variable to both hue and style, its groups become easier to distinguish.

sns.scatterplot(data = titanic,
               x = 'age', 
               y = 'fare',
               hue = 'sex',
               style = 'sex')
plt.title("Age vs. fare with sex")
plt.xlabel("Age")
plt.ylabel("Fare")
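hue and style can also encode two different variables, which brings four variables into one scatter plot. A sketch with a small made-up DataFrame (data and Agg backend are assumptions):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import pandas as pd
import seaborn as sns

# Hypothetical toy data with two numeric and two categorical columns
df = pd.DataFrame({
    "age": [22, 38, 26, 35, 54, 2],
    "fare": [7.25, 71.28, 7.93, 53.10, 51.86, 21.08],
    "sex": ["male", "female", "female", "female", "male", "male"],
    "class": ["Third", "First", "Third", "First", "First", "Third"],
})

# Color encodes sex, marker shape encodes class
ax = sns.scatterplot(data=df, x="age", y="fare", hue="sex", style="class")

legend_labels = [t.get_text() for t in ax.get_legend().get_texts()]
print(legend_labels)
```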

Multiple Plots in one figure

If you want multiple, maybe even different types of plots in one figure, you can use subplots() from the matplotlib.pyplot library. plt.subplots(nrows, ncols) creates a single figure with nrows * ncols axes and returns the created figure and axes. The axes are returned as a NumPy array of shape (nrows, ncols); if there is only one row or one column, the array is one-dimensional by default. For example, passing nrows=2, ncols=2 creates a grid with 2 rows and 2 columns. With the ax parameter of the axes-level functions, we specify the position of the plot. Let’s assume we want to draw 2 plots side by side: we create a grid with 1 row and 2 columns and put the plots at the correct positions.

# instantiate figure and axes instances
fig, axes = plt.subplots(1, 2)

# first plot
sns.histplot(data=titanic, x="age", hue="survived", multiple="stack", ax=axes[0]) # first position
# label the subplot through its Axes; plt.xlabel() would only affect the last created axes
axes[0].set_xlabel("Age")
axes[0].set_ylabel("Counts")

# second plot
sns.scatterplot(data = titanic,
                x = 'age',
                y = 'fare',
                hue = 'survived',
                style = 'survived', ax=axes[1]) # second position
axes[1].set_xlabel("Age")
axes[1].set_ylabel("Fare")
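How the axes are indexed depends on the shape of the array that plt.subplots() returns. This short sketch prints the shapes for a few grids; the squeeze parameter belongs to matplotlib, and the Agg backend is an assumption so the code runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend
import matplotlib.pyplot as plt

fig1, axes1 = plt.subplots(1, 2)                 # one row: 1-D array, index with axes1[0]
fig2, axes2 = plt.subplots(2, 2)                 # grid: 2-D array, index with axes2[0, 1]
fig3, axes3 = plt.subplots(1, 2, squeeze=False)  # always 2-D, index with axes3[0, 1]

print(axes1.shape, axes2.shape, axes3.shape)  # (2,) (2, 2) (1, 2)
```

Passing squeeze=False can be handy when the number of rows or columns is computed at runtime, because the indexing then stays the same for every grid size.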

You can plot as many axes as you want into one figure. You pass the position of the plot as ax=axes[row, column]. If you add more plots, you may need to increase the size of the figure so that the plots remain readable.

fig,axes = plt.subplots(2,2) 

# increase figure size
fig.set_figheight(10) 
fig.set_figwidth(10)

sns.histplot(data=titanic, x="age", ax=axes[0,0])
sns.histplot(data=titanic, x="fare", ax=axes[0,1])
sns.histplot(data=titanic, x="pclass", ax=axes[1,0])
sns.histplot(data=titanic, x="survived", ax=axes[1,1])

Facetgrids

The figure-level functions relplot(), catplot() and displot() are built on a FacetGrid. This means that it is easy to add faceting variables to visualize higher-dimensional relationships:

sns.displot(data=titanic, x="age", hue="survived", col="survived")

The FacetGrid class is useful when you want to visualize the distribution of a feature or the relationship between multiple features separately within subsets of your dataset. You can use three parameters, row, col, and hue, to specify the dimensions. row and col define the rows and columns of the array of axes. You can use hue to differentiate between levels by color, as you did in the scatter plot examples above.

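with sns.axes_style("white"):
    g = sns.FacetGrid(titanic, row="sex", col="pclass", margin_titles=True, height=2.5)
# pass order explicitly so that every facet draws the same categories in the same place
g.map(sns.barplot, "survived", "fare", order=[0, 1])
g.fig.subplots_adjust(wspace=.02, hspace=.02)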

Quickly explore numerical features with a pairplot

One quick and easy way to start exploring a dataset is with sns.pairplot(). It automatically draws a scatter plot for each pair of numeric variables. The diagonal shows the distribution of each numeric feature.

For example, for the first 4 columns of the titanic dataset we get a 3 x 3 grid, because 3 of these 4 columns are numeric.

sns.pairplot(titanic.iloc[:,:4])