Chapter 4: Data Visualization / Lesson 16

Matplotlib Basics

Introduction to Matplotlib

Matplotlib is Python's most widely used plotting library and a core tool for visualizing data in machine learning. Visualizing your data helps you understand patterns, identify outliers, and communicate results effectively.

In machine learning, visualization is crucial at three stages: exploring data before modeling, understanding model performance, and presenting results to others.

Why Visualize Data?

Visualization helps you:

  • Understand your data: See distributions, patterns, and relationships
  • Identify problems: Spot outliers, missing data patterns, or data quality issues
  • Communicate findings: Share insights with others visually
  • Debug models: Plot predicted values against actual values to see where a model fails
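A common way to debug a model is to plot its predictions against the actual values. Points on the diagonal y = x are perfect predictions, and points far from it show where the model errs. The sketch below assumes hypothetical arrays `y_true` and `y_pred` standing in for your model's outputs; the numbers are made up for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical actual targets and model predictions (made-up values for illustration)
y_true = np.array([3.0, 5.0, 7.5, 9.0, 11.0, 14.0])
y_pred = np.array([2.8, 5.4, 7.0, 9.6, 10.1, 15.2])

# Residuals: how far each prediction is from the actual value
residuals = y_true - y_pred

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(y_true, y_pred, color='blue', label='Predictions')

# Reference diagonal: a perfect model would put every point on this line
lims = [min(y_true.min(), y_pred.min()), max(y_true.max(), y_pred.max())]
ax.plot(lims, lims, color='gray', linestyle='--', label='Perfect prediction (y = x)')

ax.set_xlabel('Actual value')
ax.set_ylabel('Predicted value')
ax.set_title('Predicted vs Actual')
ax.legend()
plt.show()

# The largest absolute residual points to the example the model got most wrong
worst = np.argmax(np.abs(residuals))
print(f"Largest error at index {worst}: {residuals[worst]:.2f}")
```

Points above the diagonal are over-predictions and points below it are under-predictions, so the plot also reveals whether errors are biased in one direction.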

Basic Line Plot

The simplest plot is a line plot, created with plt.plot(x, y), which connects the (x, y) points in order with straight line segments:

line_plot.py
```python
import matplotlib.pyplot as plt
import numpy as np

# Create data
x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 4, 6, 8, 10])

# Create plot
plt.plot(x, y)
plt.xlabel('X values')
plt.ylabel('Y values')
plt.title('Simple Line Plot')
plt.show()

print("Line plot created!")
```
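The plt.plot() style draws on whichever figure is "current". Matplotlib also has an object-oriented interface, used later in this lesson with plt.subplots(), where you call methods on an Axes object directly; the labeling functions become ax.set_xlabel(), ax.set_ylabel(), and ax.set_title(). The sketch below redraws the line plot that way and adds a second, illustrative series (x squared) with labels and a legend so the two lines can be told apart.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.array([1, 2, 3, 4, 5])
y = np.array([2, 4, 6, 8, 10])
y_squared = x ** 2  # illustrative second series

# Create a Figure and a single Axes, then draw on the Axes directly
fig, ax = plt.subplots()
ax.plot(x, y, label='y = 2x')
ax.plot(x, y_squared, label='y = x^2')
ax.set_xlabel('X values')
ax.set_ylabel('Y values')
ax.set_title('Simple Line Plot (object-oriented)')
ax.legend()
plt.show()

print("Lines drawn:", len(ax.get_lines()))
```

Both styles produce the same kind of figure; the object-oriented form becomes more convenient once a figure contains more than one plot.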

Scatter Plots

Scatter plots draw each (x, y) pair as a separate point without connecting lines, which makes them well suited to showing the relationship between two variables:

scatter_plot.py
```python
import matplotlib.pyplot as plt
import numpy as np

# Generate sample data
x = np.array([1, 2, 3, 4, 5, 6, 7, 8])
y = np.array([2, 4, 5, 7, 8, 9, 11, 12])

# Create scatter plot
plt.scatter(x, y, color='blue', marker='o')
plt.xlabel('Feature X')
plt.ylabel('Target Y')
plt.title('Scatter Plot: X vs Y')
plt.grid(True)
plt.show()

print("Scatter plot shows the relationship between X and Y")
```
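A scatter plot shows a relationship visually, and you can also summarize it with numbers. The sketch below reuses the same data, computes the Pearson correlation with np.corrcoef(), and overlays a least-squares straight line from np.polyfit(). Fitting a straight line is an assumption that only makes sense when the points look roughly linear, as they do here.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.array([1, 2, 3, 4, 5, 6, 7, 8])
y = np.array([2, 4, 5, 7, 8, 9, 11, 12])

# Pearson correlation: close to 1 means a strong positive linear relationship
r = np.corrcoef(x, y)[0, 1]

# Least-squares straight line (degree-1 polynomial): returns [slope, intercept]
slope, intercept = np.polyfit(x, y, 1)

plt.scatter(x, y, color='blue', marker='o', label='Data')
plt.plot(x, slope * x + intercept, color='red',
         label=f'Fit: y = {slope:.2f}x + {intercept:.2f}')
plt.xlabel('Feature X')
plt.ylabel('Target Y')
plt.title(f'Scatter Plot with Trend Line (r = {r:.3f})')
plt.legend()
plt.grid(True)
plt.show()

print(f"Correlation r = {r:.3f}")
```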

Histograms

Histograms show how your data is distributed by counting how many values fall into each range (bin), which is essential for understanding your dataset:

histogram.py
```python
import matplotlib.pyplot as plt
import numpy as np

# Generate sample data
data = np.random.normal(100, 15, 1000)  # Mean=100, Std=15, 1000 points

# Create histogram
plt.hist(data, bins=30, color='skyblue', edgecolor='black')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Data Distribution Histogram')
plt.show()

print("Histogram shows how data is distributed")
```
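plt.hist() does more than draw: it also returns the bin counts and bin edges it computed, which you can inspect directly. The sketch below assumes a seeded generator (np.random.default_rng) so the sample is reproducible; the seed value 42 is an arbitrary choice.

```python
import matplotlib.pyplot as plt
import numpy as np

# Seeded random generator so the sample is identical on every run
rng = np.random.default_rng(42)
data = rng.normal(loc=100, scale=15, size=1000)  # Mean=100, Std=15, 1000 points

# plt.hist returns the count in each bin, the bin edges, and the drawn bars
counts, edges, bars = plt.hist(data, bins=30, color='skyblue', edgecolor='black')
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Data Distribution Histogram')
plt.show()

# 30 bins have 31 edges, and every point falls into exactly one bin
print("Number of bins:", len(counts))
print("Number of edges:", len(edges))
print("Total count:", int(counts.sum()))
tallest = counts.argmax()
print("Tallest bin covers", edges[tallest], "to", edges[tallest + 1])
```

Changing bins changes the picture: too few bins hide the shape of the distribution, while too many make it look noisy.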

Multiple Plots

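You can arrange several plots (subplots) in one figure with plt.subplots(), which returns a Figure and an array of Axes objects that you index by row and column: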

subplots.py
```python
import matplotlib.pyplot as plt
import numpy as np

# Create figure with a 2x2 grid of subplots
fig, axes = plt.subplots(2, 2, figsize=(10, 8))

# Plot 1: Line plot
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
axes[0, 0].plot(x, y1)
axes[0, 0].set_title('Sine Wave')

# Plot 2: Scatter
x2 = np.random.rand(50)
y2 = np.random.rand(50)
axes[0, 1].scatter(x2, y2)
axes[0, 1].set_title('Random Scatter')

# Plot 3: Histogram
data = np.random.normal(0, 1, 1000)
axes[1, 0].hist(data, bins=30)
axes[1, 0].set_title('Normal Distribution')

# Plot 4: Bar chart
categories = ['A', 'B', 'C', 'D']
values = [23, 45, 56, 78]
axes[1, 1].bar(categories, values)
axes[1, 1].set_title('Bar Chart')

plt.tight_layout()  # adjust spacing so titles and labels don't overlap
plt.show()

print("Multiple plots in one figure!")
```
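The shape of the axes array depends on the grid: a 2x2 grid gives a two-dimensional array indexed as axes[row, col], a single row or column gives a one-dimensional array, and a 1x1 grid gives a single Axes object rather than an array. The sketch below shows the one-row case with sharey=True, so every panel uses the same y-axis range and the panels are easier to compare; the three frequencies plotted are arbitrary.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
frequencies = [1, 2, 3]  # arbitrary values, one per panel

# One row of three panels: axes is a 1-D array of shape (3,)
fig, axes = plt.subplots(1, len(frequencies), figsize=(12, 4), sharey=True)

for ax, freq in zip(axes, frequencies):
    ax.plot(x, np.sin(freq * x))
    ax.set_title(f'sin({freq}x)')
    ax.set_xlabel('x')

axes[0].set_ylabel('y')  # shared y-axis, so label only the first panel
plt.tight_layout()
plt.show()

print("axes shape:", axes.shape)

# A 2x2 grid gives a 2-D array; .flat iterates over it in row order
fig2, grid = plt.subplots(2, 2)
print("grid shape:", grid.shape, "- panels:", len(list(grid.flat)))
```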

Customizing Plots

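Matplotlib lets you customize nearly every element of a plot, including figure size, line width, color and style, font sizes, grid transparency, and legends: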

customization.py
```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)
y = np.sin(x)

# Create customized plot
plt.figure(figsize=(10, 6))
plt.plot(x, y, linewidth=2, color='red', linestyle='--', label='Sine')
plt.xlabel('X Axis', fontsize=12)
plt.ylabel('Y Axis', fontsize=12)
plt.title('Customized Plot', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)  # faint grid lines
plt.legend()               # uses the label= given to plt.plot
plt.show()

print("Customized plot with labels, colors, and styling!")
```
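Several of these options have a compact shorthand: a format string such as 'r--' sets the color (red) and line style (dashed) in one argument, equivalent to color='red', linestyle='--'. The sketch below draws two labeled curves this way and saves the figure with plt.savefig(); the filename custom_plot.png and the dpi value are illustrative choices.

```python
import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 10, 100)

plt.figure(figsize=(10, 6))
# Format strings: color letter + line style ('r--' = red dashed, 'b-' = blue solid).
# plt.plot returns a list of Line2D objects; unpack the single line it draws.
sine_line, = plt.plot(x, np.sin(x), 'r--', linewidth=2, label='Sine')
cosine_line, = plt.plot(x, np.cos(x), 'b-', linewidth=2, label='Cosine')
plt.xlabel('X Axis', fontsize=12)
plt.ylabel('Y Axis', fontsize=12)
plt.title('Sine and Cosine', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.legend(loc='upper right')  # fixed legend position instead of the default 'best'

# Save before plt.show(): after the plot window is closed, pyplot may have
# discarded the figure and savefig() would write a blank image
plt.savefig('custom_plot.png', dpi=150, bbox_inches='tight')
plt.show()
```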

💡 Visualization Best Practices

Always label your axes, add titles, and choose appropriate plot types. Good visualizations make your data insights clear and easy to understand. In ML, visualization is often the first step in understanding your problem!
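One way to make labeling a habit is to route your plotting through a helper that requires a title and axis labels. make_plot below is a hypothetical helper written for this lesson, not part of Matplotlib; it uses ax.plot for ordered data (such as values over time) and ax.scatter for the relationship between two variables. The sales numbers are made up for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np

def make_plot(ax, x, y, kind, title, xlabel, ylabel):
    """Hypothetical helper: draw on ax and refuse to leave labels blank."""
    for name, text in [('title', title), ('xlabel', xlabel), ('ylabel', ylabel)]:
        if not text:
            raise ValueError(f"{name} must not be empty")
    if kind == 'line':        # ordered data, e.g. values over time
        ax.plot(x, y)
    elif kind == 'scatter':   # relationship between two variables
        ax.scatter(x, y)
    else:
        raise ValueError(f"unsupported kind: {kind!r}")
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    return ax

fig, (left, right) = plt.subplots(1, 2, figsize=(10, 4))
days = np.arange(1, 8)
make_plot(left, days, days * 3 + 2, 'line', 'Sales per Day', 'Day', 'Units sold')
make_plot(right, np.array([1, 2, 3, 4]), np.array([2, 4, 5, 8]), 'scatter',
          'Feature vs Target', 'Feature X', 'Target Y')
plt.tight_layout()
plt.show()
```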
