Computer Vision
Computer vision is one of the most significant application areas for CNNs, which are widely used for image classification, object detection, and image segmentation. In image classification, for example, a CNN is trained on a dataset of labeled images to recognize classes of objects, such as cats and dogs.
Example:
Let's build a simple CNN for image classification using the MNIST dataset of handwritten digits:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
# Load the MNIST dataset, add a channel axis, scale pixels to [0, 1],
# and one-hot encode the digit labels
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 28, 28, 1).astype('float32') / 255
X_test = X_test.reshape(-1, 28, 28, 1).astype('float32') / 255
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
# Define the model
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(128, activation='relu'),
    Dense(10, activation='softmax')
])
# Compile the model
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# Train the model (for simplicity, the test set doubles as validation data here)
history = model.fit(X_train, y_train, epochs=10, validation_data=(X_test, y_test), verbose=2)
# Evaluate the model
loss, accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f'Test accuracy: {accuracy}')
This code snippet defines a CNN with one convolution and pooling stage followed by two dense layers, trains it for 10 epochs on the MNIST dataset, and prints the test accuracy.
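To see how the layers in the model fit together, it can help to work out the tensor shapes and parameter counts by hand. The following is a minimal sketch in plain Python, assuming the Keras defaults for the layers used above ('valid' padding and stride 1 for Conv2D, pooling stride equal to the pool size). The helper function names are illustrative and are not part of Keras.

```python
# Illustrative sketch: reproduce the layer shapes and parameter counts by hand.
# Assumes Keras defaults: 'valid' padding and stride 1 for Conv2D, and a
# pooling stride equal to the pool size. Helper names are hypothetical.

def conv2d_output(height, width, kernel, filters):
    """Output shape of a 'valid', stride-1 convolution."""
    return height - kernel + 1, width - kernel + 1, filters

def maxpool_output(height, width, channels, pool):
    """Output shape of non-overlapping max pooling (floor division)."""
    return height // pool, width // pool, channels

def conv2d_params(kernel, in_channels, filters):
    """Kernel weights plus one bias per filter."""
    return kernel * kernel * in_channels * filters + filters

def dense_params(inputs, units):
    """Weights plus one bias per unit."""
    return inputs * units + units

conv_shape = conv2d_output(28, 28, kernel=3, filters=32)   # (26, 26, 32)
pool_shape = maxpool_output(*conv_shape, pool=2)           # (13, 13, 32)
flat_size = pool_shape[0] * pool_shape[1] * pool_shape[2]  # 5408

total_params = (
    conv2d_params(3, 1, 32)          # convolution layer
    + dense_params(flat_size, 128)   # hidden dense layer
    + dense_params(128, 10)          # softmax output layer
)
print(conv_shape, pool_shape, flat_size, total_params)
```

Under these assumptions, almost all of the parameters sit in the first dense layer, because Flatten turns the 13 x 13 x 32 feature map into a vector of 5408 values. Calling model.summary() on the model above should report the same shapes and totals.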
Image Recognition
Image recognition is another area where CNNs excel. Beyond simple classification, CNNs can identify and localize multiple objects within an image. Architectures such as ResNet and Inception (image classifiers that are often used as feature-extraction backbones) and YOLO (an object detector) are used in applications ranging from autonomous driving to medical image analysis.
Example:
Let's visualize the performance of our CNN on the MNIST dataset:
import matplotlib.pyplot as plt
# Plot the training and validation accuracy
plt.plot(history.history['accuracy'], label='train_accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)
plt.show()
# Plot the training and validation loss
plt.plot(history.history['loss'], label='train_loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.title('Model Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
This code snippet plots the training and validation accuracy and loss of the CNN model.
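The same curves can also be summarized numerically, which is useful when comparing several runs. The sketch below assumes a dictionary shaped like history.history, with the keys 'accuracy', 'val_accuracy', 'loss', and 'val_loss'. The values in example_history are made up for illustration and are not results from the model above, and summarize_history is a hypothetical helper.

```python
# Hypothetical stand-in for history.history; the numbers are invented for
# illustration and are not results from training the model above.
example_history = {
    'accuracy':     [0.95, 0.98, 0.99, 0.995],
    'val_accuracy': [0.97, 0.98, 0.985, 0.984],
    'loss':         [0.17, 0.06, 0.03, 0.02],
    'val_loss':     [0.08, 0.06, 0.05, 0.055],
}

def summarize_history(hist):
    """Return the 1-based epoch with the lowest validation loss and the
    gap between final training and validation accuracy."""
    val_loss = hist['val_loss']
    best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__) + 1
    final_gap = hist['accuracy'][-1] - hist['val_accuracy'][-1]
    return best_epoch, final_gap

best_epoch, final_gap = summarize_history(example_history)
print(f'Lowest validation loss at epoch {best_epoch}')
print(f'Final train/validation accuracy gap: {final_gap:.3f}')
```

A validation loss that bottoms out well before the last epoch, together with a growing gap between training and validation accuracy, is a common sign of overfitting. With a real run, passing history.history to summarize_history should work the same way.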