Introduction to Matplotlib
Matplotlib is the most widely used plotting library in Python and a standard tool for visualizing data in machine learning. Visualizing your data helps you understand patterns, identify outliers, and communicate results effectively.
In machine learning, visualization matters at three stages: exploring data before modeling, understanding model performance, and presenting results to others.
Basic Line Plot
The simplest plot is a line plot, created with plt.plot():
import matplotlib.pyplot as plt
import numpy as np
x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 4, 6, 8, 10])
plt.plot(x, y)
plt.xlabel('X values')
plt.ylabel('Y values')
plt.title('Simple Line Plot')
plt.show()
print("Line plot created!")
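Line plots are also a common way to inspect model performance, for example by drawing two curves on the same axes. The sketch below assumes made-up loss values (they are not real training results) and uses the object-oriented interface, where ax.plot() and ax.set_xlabel() play the same role as plt.plot() and plt.xlabel():

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical loss values, invented purely for illustration.
epochs = np.arange(1, 11)
train_loss = 1.0 / epochs                 # keeps decreasing
val_loss = 1.0 / epochs + 0.02 * epochs   # decreases, then rises again

fig, ax = plt.subplots()
# Each call to ax.plot() adds one line; `label` feeds the legend.
ax.plot(epochs, train_loss, marker='o', label='Training loss')
ax.plot(epochs, val_loss, marker='s', label='Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training vs. Validation Loss (made-up data)')
ax.legend()
plt.show()

# Epoch with the lowest validation loss in this synthetic example.
best_epoch = int(epochs[np.argmin(val_loss)])
print("Lowest validation loss at epoch", best_epoch)
```

A widening gap between the two curves is one visual hint of overfitting, which is much easier to see in a plot than in a table of numbers.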
Scatter Plots
Scatter plots show the relationship between two variables by drawing one marker per data point, which makes trends and outliers easy to spot:
import matplotlib.pyplot as plt
import numpy as np
x = np.array([1, 2, 3, 4, 5, 6, 7, 8])
y = np.array([2, 4, 5, 7, 8, 9, 11, 12])
plt.scatter(x, y, color='blue', marker='o')
plt.xlabel('Feature X')
plt.ylabel('Target Y')
plt.title('Scatter Plot: X vs Y')
plt.grid(True)
plt.show()
print("Scatter plot shows the relationship between X and Y")
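In classification problems it often helps to color the points by class. The following is a minimal sketch under stated assumptions: the two-class data is synthetic, and the names class0, class1, points, and labels are chosen only for this example.

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=0)  # fixed seed so the figure is reproducible

# Synthetic two-class data (assumption: two Gaussian blobs, made up here).
class0 = rng.normal(loc=[2.0, 2.0], scale=0.5, size=(50, 2))
class1 = rng.normal(loc=[4.0, 4.0], scale=0.5, size=(50, 2))
points = np.vstack([class0, class1])
labels = np.array([0] * 50 + [1] * 50)

fig, ax = plt.subplots()
# One scatter call per class so that each class gets its own legend entry.
for cls, color in [(0, 'tab:blue'), (1, 'tab:orange')]:
    mask = labels == cls
    ax.scatter(points[mask, 0], points[mask, 1],
               color=color, alpha=0.7, label=f'Class {cls}')
ax.set_xlabel('Feature 1')
ax.set_ylabel('Feature 2')
ax.set_title('Scatter Plot Colored by Class (synthetic data)')
ax.grid(True)
ax.legend()
plt.show()
```

Plotting each class separately keeps the legend simple; passing the labels through the c= argument of a single scatter call is an alternative when there are many classes.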
Histograms
Histograms show how the values in a dataset are distributed by counting how many fall into each bin, which is essential for understanding your dataset:
import matplotlib.pyplot as plt
import numpy as np
data = np.random.normal(100, 15, 1000)  # 1000 samples, mean 100, standard deviation 15
plt.hist(data, bins=30, color='skyblue', edgecolor='black')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Data Distribution Histogram')
plt.show()
print("Histogram shows how data is distributed")
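The hist() call also returns the bin counts and bin edges, which can be useful for checking what the picture shows. Here is a small sketch, assuming a seeded random generator for reproducibility and adding a vertical line at the sample mean:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=42)  # seeded so repeated runs match
data = rng.normal(loc=100, scale=15, size=1000)

fig, ax = plt.subplots()
# hist() returns the per-bin counts, the bin edges, and the drawn bars.
counts, bin_edges, _ = ax.hist(data, bins=30, color='skyblue', edgecolor='black')
# Mark the sample mean with a dashed vertical line.
ax.axvline(data.mean(), color='red', linestyle='--',
           label=f'Mean = {data.mean():.1f}')
ax.set_xlabel('Value')
ax.set_ylabel('Frequency')
ax.set_title('Histogram with Sample Mean')
ax.legend()
plt.show()

print("Number of bins:", len(counts))
print("Total samples counted:", int(counts.sum()))
```

With 30 bins there are 31 bin edges, and the counts add up to the number of samples because the default range spans the minimum to the maximum of the data.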
Multiple Plots
plt.subplots() creates a figure with a grid of subplots and returns the figure together with an array of Axes objects, indexed by row and column:
import matplotlib.pyplot as plt
import numpy as np
fig, axes = plt.subplots(2, 2, figsize=(10, 8))  # 2x2 grid of Axes
# Top left: line plot
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
axes[0, 0].plot(x, y1)
axes[0, 0].set_title('Sine Wave')
# Top right: scatter plot
x2 = np.random.rand(50)
y2 = np.random.rand(50)
axes[0, 1].scatter(x2, y2)
axes[0, 1].set_title('Random Scatter')
# Bottom left: histogram
data = np.random.normal(0, 1, 1000)
axes[1, 0].hist(data, bins=30)
axes[1, 0].set_title('Normal Distribution')
# Bottom right: bar chart
categories = ['A', 'B', 'C', 'D']
values = [23, 45, 56, 78]
axes[1, 1].bar(categories, values)
axes[1, 1].set_title('Bar Chart')
plt.tight_layout()
plt.show()
print("Multiple plots in one figure!")
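When every panel shows the same kind of plot, looping over the Axes array can be shorter than indexing each cell by hand. The sketch below assumes an illustrative list of bin counts and uses axes.flat to visit the four panels in row-major order:

```python
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(seed=1)  # seeded for reproducibility
data = rng.normal(0, 1, 1000)

# Bin counts picked arbitrarily for this comparison.
bin_counts = [5, 15, 30, 60]

# sharex=True gives all panels the same x-axis range.
fig, axes = plt.subplots(2, 2, figsize=(10, 8), sharex=True)
for ax, bins in zip(axes.flat, bin_counts):
    ax.hist(data, bins=bins, color='skyblue', edgecolor='black')
    ax.set_title(f'{bins} bins')

fig.suptitle('Same Data, Different Bin Counts')
fig.tight_layout()
plt.show()
```

Seeing the same data side by side with different bin counts shows how much the apparent shape of a histogram depends on that one parameter.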
Customizing Plots
Matplotlib lets you control figure size, line width, color and style, font sizes, grid transparency, and legends through keyword arguments:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 10, 100)
y = np.sin(x)
plt.figure(figsize=(10, 6))
plt.plot(x, y, linewidth=2, color='red', linestyle='--', label='Sine')
plt.xlabel('X Axis', fontsize=12)
plt.ylabel('Y Axis', fontsize=12)
plt.title('Customized Plot', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()
print("Customized plot with labels, colors, and styling!")
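A customized figure is usually worth saving to a file so it can go into a report or presentation. The following sketch rebuilds the same plot with the object-oriented interface and writes it with savefig(); the filename customized_plot.png is an assumption made for this example.

```python
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(x, y, linewidth=2, color='red', linestyle='--', label='Sine')
ax.set_xlabel('X Axis', fontsize=12)
ax.set_ylabel('Y Axis', fontsize=12)
ax.set_title('Customized Plot', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
ax.legend()

# Hypothetical output filename; change it to suit your project.
output_path = Path('customized_plot.png')
# dpi sets the resolution; bbox_inches='tight' trims surrounding whitespace.
fig.savefig(output_path, dpi=150, bbox_inches='tight')
print(f'Saved {output_path} ({output_path.stat().st_size} bytes)')
plt.show()
```

Saving before calling plt.show() is the safer order, since with some interactive backends the figure may already be closed by the time show() returns.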
💡 Visualization Best Practices
Always label your axes, add titles, and choose appropriate plot types. Good visualizations make your data insights clear and easy to understand. In ML, visualization is often the first step in understanding your problem!
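One way to make the labeling advice a habit is a small check you run before sharing a figure. The helper below, missing_labels, is hypothetical and not part of Matplotlib; it only relies on the standard get_xlabel(), get_ylabel(), and get_title() methods of an Axes object.

```python
import matplotlib.pyplot as plt


def missing_labels(ax):
    """Return the names of the basic labels that are still empty on `ax`.

    Hypothetical helper written for this example; not a Matplotlib function.
    """
    checks = {
        'xlabel': ax.get_xlabel(),
        'ylabel': ax.get_ylabel(),
        'title': ax.get_title(),
    }
    return [name for name, text in checks.items() if not text.strip()]


fig, ax = plt.subplots()
ax.plot([1, 2, 3], [2, 4, 6])
before = missing_labels(ax)
print(before)  # ['xlabel', 'ylabel', 'title']

ax.set_xlabel('X values')
ax.set_ylabel('Y values')
ax.set_title('Simple Line Plot')
after = missing_labels(ax)
print(after)  # []
```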