Visualizing a 3D plot in Matplotlib
In real-world data science problems, datasets usually have multiple dimensions, or features.
In such cases we can perform univariate visualization, looking at one feature at a time, or multivariate visualization, where we plot 2 or 3 features together.
We will use the mplot3d toolkit that ships with Matplotlib.
The mplot3d toolkit adds simple 3D plotting capabilities by supplying an Axes object that renders a 2D projection of a 3D scene.
The resulting plot has the same look and feel as a regular 2D plot.
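As a minimal sketch of what the toolkit provides, the snippet below asks Matplotlib for an axes with the '3d' projection and checks what it gets back. It assumes Matplotlib 3.2 or later, where passing projection='3d' to add_subplot works without any extra setup; the non-interactive 'Agg' backend is selected only so the snippet can run without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line in a notebook or desktop session
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure()
# projection="3d" makes add_subplot return an Axes3D instead of a regular 2D Axes
ax = fig.add_subplot(projection="3d")

print(isinstance(ax, Axes3D))  # True
print(ax.name)                 # 3d
print(hasattr(ax, "set_zlabel"))  # True: the third axis gets its own label setter
plt.close(fig)
```

The Axes3D object still lives inside an ordinary Matplotlib figure, which is why titles, labels and saving work exactly as they do for 2D plots.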
Let us take an example to understand this. We will create a dataset with 3 features and draw it as a 3D plot.
Importing the libraries
import numpy as np
import matplotlib.pyplot as plt
# Importing Axes3D registers the '3d' projection (required before Matplotlib 3.2, harmless afterwards)
from mpl_toolkits.mplot3d import Axes3D
Creating the dataset
We will create a dataset of 3 feature vectors, each containing 100 random values uniformly distributed between 0 and 1, using the numpy.random module.
x1 = np.random.rand(100)
x2 = np.random.rand(100)
x3 = np.random.rand(100)
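An alternative way to build the same kind of dataset, sketched under two assumptions not in the original: a seeded numpy.random.default_rng generator (available since NumPy 1.17) so the random values are reproducible between runs, and a stacked (100, 3) array in which row i holds the point (x1[i], x2[i], x3[i]) that will later be plotted. The seed value 42 is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(42)  # arbitrary seed, chosen only for reproducibility
n_points = 100

# Three independent features, each uniformly distributed on [0, 1)
x1 = rng.random(n_points)
x2 = rng.random(n_points)
x3 = rng.random(n_points)

# Each row is one observation, i.e. one point (x1[i], x2[i], x3[i]) in 3D space
points = np.column_stack((x1, x2, x3))
print(points.shape)  # (100, 3)
```

Keeping the three features as separate 1D arrays matches what scatter() expects, while the stacked array is convenient when the same data also feeds other tools.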
Creating the plot
Finally, to create the 3D plot we use the Axes3D.scatter() method, which takes the x, y and z coordinates of the points as three array-like arguments.
We first set up the 3D axes and label them “x1”, “x2” and “x3”.
Each observation i is then drawn as a single point (x1[i], x2[i], x3[i]) in this three-dimensional space.
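fig = plt.figure()
# add_subplot(projection='3d') creates an Axes3D and attaches it to the figure;
# calling Axes3D(fig) directly no longer adds the axes to the figure in recent Matplotlib
axes = fig.add_subplot(projection='3d')
axes.set_title('3D plot in Matplotlib', size=14)
axes.set_xlabel('x1')
axes.set_ylabel('x2')
axes.set_zlabel('x3')
# Hide the numeric tick labels (the old w_xaxis/w_yaxis/w_zaxis attributes were removed in Matplotlib 3.8)
axes.xaxis.set_ticklabels([])
axes.yaxis.set_ticklabels([])
axes.zaxis.set_ticklabels([])
axes.scatter(x1, x2, x3)
plt.show()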
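For reference, here is a self-contained sketch of the whole example written against the current Matplotlib API. It assumes Matplotlib 3.2 or later; the seed, the output filename scatter3d.png and the headless 'Agg' backend are illustrative choices not taken from the original, and in an interactive session plt.show() would take the place of savefig().

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove to use an interactive window
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)     # arbitrary seed for reproducible points
x1, x2, x3 = rng.random((3, 100))  # three features with 100 values each

fig = plt.figure()
axes = fig.add_subplot(projection="3d")
axes.set_title("3D plot in Matplotlib", size=14)
axes.set_xlabel("x1")
axes.set_ylabel("x2")
axes.set_zlabel("x3")

# Hide the numeric tick labels on all three axes
axes.xaxis.set_ticklabels([])
axes.yaxis.set_ticklabels([])
axes.zaxis.set_ticklabels([])

axes.scatter(x1, x2, x3)       # one marker per point (x1[i], x2[i], x3[i])
fig.savefig("scatter3d.png")   # hypothetical output path
plt.close(fig)
```

Saving to a file rather than showing a window makes the script usable in scripts and CI jobs; the plotting calls themselves are the same in both cases.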
