新冠疫情对全球造成了深远的影响,其中对呼吸系统的影响尤为严重。因此,准确检测新冠感染变得至关重要。本文将探讨如何利用深度学习技术来识别新冠感染。将从数据的导入开始,逐步深入到模型的训练和评估。
在进行新冠检测之前,需要准备相应的工具和数据。以下是需要导入的Python库和数据集。
import pandas as pd
import numpy as np
from keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import ResNet50
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Dense, Flatten, GlobalAveragePooling2D
import tensorflow as tf
import matplotlib.pyplot as plt
首先读取训练数据集,该数据集包含了图像文件名和对应的标签。
train_data = pd.read_csv("Training_set_covid.csv")
print(train_data)
为了能够加载图像,需要为数据集添加一个包含文件路径的列。
train_data["filepath"] = "train/" + train_data["filename"]
print(train_data)
由于数据量较小,通过数据增强来扩充训练集,并创建训练和验证图像的存储变量。
train_datagen = ImageDataGenerator(validation_split=0.2, zoom_range=0.2, rescale=1./255., horizontal_flip=True)
train_data["label"] = train_data["label"].astype(str)
接下来,加载图像,并为训练和验证创建两个数据流。
train_images = train_datagen.flow_from_dataframe(train_data, x_col="filepath", batch_size=8, target_size=(255,255), class_mode="binary", shuffle=True, subset='training', y_col="label")
valid_images = train_datagen.flow_from_dataframe(train_data, x_col="filepath", batch_size=8, target_size=(255,255), class_mode="binary", shuffle=True, subset='validation', y_col="label")
将使用ResNet50模型作为基础模型,并添加自定义的输出层。
base_model = ResNet50(input_shape=(225, 225, 3), include_top=False, weights="imagenet")
for layer in base_model.layers:
layer.trainable = False
base_model = Sequential()
base_model.add(ResNet50(include_top=False, weights='imagenet', pooling='max'))
base_model.add(Dense(1, activation='sigmoid'))
base_model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.0001), loss='binary_crossentropy', metrics=['acc'])
base_model.summary()
模型构建完成后,开始训练模型,并使用验证集来评估模型的性能。
resnet_history = base_model.fit(train_images, validation_data=valid_images, steps_per_epoch=int(train_images.n/8), epochs=20)
plt.plot(resnet_history.history["acc"], label="train")
plt.plot(resnet_history.history["val_acc"], label="val")
plt.title("训练准确率和验证准确率")
plt.legend()
plt.plot(resnet_history.history["loss"], label="train")
plt.plot(resnet_history.history["val_loss"], label="val")
plt.title("训练损失和验证损失")
plt.legend()