Visualizing a 3D plot in Matplotlib

Posted on Aug. 14, 2020

In real-world data science problems, datasets usually have multiple dimensions or features.

In such cases, we can perform univariate visualization, which looks at one attribute at a time, or multivariate visualization, where we plot 2 or 3 attributes together.
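As a rough illustration of the difference, the sketch below draws a histogram of one feature (univariate) next to a scatter plot of two features (multivariate, here bivariate). The seed and the names feature_a and feature_b are made up for this example and are not part of the dataset used later.

```python
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)   # arbitrary seed, only for reproducibility
feature_a = rng.random(100)      # illustrative feature, not from the tutorial dataset
feature_b = rng.random(100)      # second illustrative feature

fig, (ax_uni, ax_multi) = plt.subplots(1, 2, figsize=(8, 3))

# Univariate: distribution of a single feature
ax_uni.hist(feature_a, bins=10)
ax_uni.set_title('Univariate: histogram')

# Multivariate (here bivariate): two features plotted against each other
ax_multi.scatter(feature_a, feature_b)
ax_multi.set_title('Multivariate: scatter')

fig.tight_layout()
```

Call plt.show() afterwards to display the figure in an interactive session.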

We will use the mplot3d toolkit provided by matplotlib.

The mplot3d toolkit adds capabilities for creating a simple 3D plot by supplying an axes object that can create a 2D projection of a 3D scene.
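A minimal sketch of how such an axes object can be obtained, assuming a Matplotlib version where passing projection='3d' to add_subplot() returns a 3D axes. The variable names fig and ax are only illustrative.

```python
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')  # 3D axes attached to fig

print(isinstance(ax, Axes3D))  # the axes object is an Axes3D instance
print(ax.name)                 # name of the projection in use
```

The object behaves like a regular axes in most respects, so familiar calls such as set_title() and set_xlabel() keep working, with set_zlabel() added for the third axis.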

The resulting graph will have the same look and feel as regular 2D plots.

Let us take an example to understand this. We will create a dataset of 3 features to create a 3D plot.

Importing the libraries

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # provides the 3D axes class used below


Creating the dataset

We will create a dataset of 3 feature vectors, each containing 100 random values drawn uniformly from [0, 1), using the numpy.random module.

x1 = np.random.rand(100)
x2 = np.random.rand(100)
x3 = np.random.rand(100)
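Because np.random.rand() returns different values on every run, the plot will look slightly different each time. If reproducible output is wanted, one option is a seeded generator, sketched below. The seed value 42 and the names rng and data are assumptions made for this example.

```python
import numpy as np

rng = np.random.default_rng(42)  # seed chosen arbitrarily for this sketch

# Three feature vectors of 100 uniform random values in [0, 1)
x1 = rng.random(100)
x2 = rng.random(100)
x3 = rng.random(100)

# Optional: stack the features column-wise, one row per sample
data = np.column_stack((x1, x2, x3))
print(data.shape)  # (100, 3)
```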


Creating the plot

Finally, to create the 3D plot we will use the Axes3D.scatter() method. Its first three arguments are the x, y and z coordinates of the points.

We first prepare the 3D axes and label them “x1”, “x2” and “x3”.

Each observation is then drawn as a point (x1, x2, x3) in this three-dimensional space.

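fig = plt.figure()
axes = fig.add_subplot(111, projection='3d')  # Axes3D(fig) is no longer added to the figure automatically in recent Matplotlib
axes.set_title('3D plot in Matplotlib', size=14)
axes.set_xlabel('x1')
axes.set_ylabel('x2')
axes.set_zlabel('x3')
axes.scatter(x1, x2, x3)
plt.show()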
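A possible variation on the plot above, sketched under a few assumptions: points are colored by their x3 value to make depth easier to read, a colorbar is added, and the camera angle is changed with view_init(). The seed, the colormap and the angles are arbitrary choices for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

rng = np.random.default_rng(0)       # arbitrary seed for this sketch
x1, x2, x3 = rng.random((3, 100))    # three features, 100 values each

fig = plt.figure()
axes = fig.add_subplot(111, projection='3d')

# Color each point by its x3 value so that depth is easier to read
points = axes.scatter(x1, x2, x3, c=x3, cmap='viridis')
fig.colorbar(points, ax=axes, label='x3')

axes.set_xlabel('x1')
axes.set_ylabel('x2')
axes.set_zlabel('x3')

# Move the camera: elevation and azimuth are given in degrees
axes.view_init(elev=20, azim=35)
```

Call plt.show() to display the result. In an interactive window the plot can also be rotated with the mouse.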