if name == "main": # 如果作为主程序运行
model = AlexNet() # 实例化AlexNet模型
print(model) # 打印模型结构
summary(
model, input_size=(1, 227, 227), device="cpu"
) # 打印模型摘要,输入尺寸为(1, 227, 227),单通道
训练代码(train.py)
import os # 导入os模块,用于与操作系统交互
import sys # 导入sys模块,用于操作Python运行时环境
sys.path.append(os.getcwd()) # 将当前工作目录添加到sys.path,方便模块导入
import time # 导入time模块,用于计时
from torchvision.datasets import FashionMNIST # 导入FashionMNIST数据集
from torchvision import transforms # 导入transforms用于数据预处理
from torch.utils.data import (
DataLoader, # 导入DataLoader用于批量加载数据
random_split, # 导入random_split用于划分数据集
)
import numpy as np # 导入numpy用于数值计算
import matplotlib.pyplot as plt # 导入matplotlib用于绘图
import torch # 导入PyTorch主库
from torch import nn, optim # 导入神经网络模块和优化器
import copy # 导入copy模块用于深拷贝
import pandas as pd # 导入pandas用于数据处理
from AlexNet_model.model import AlexNet # 从自定义模块导入AlexNet模型
def train_val_date_load(): # 定义函数用于加载训练集和验证集
train_dataset = FashionMNIST(
root="./data", # 数据存储路径
train=True, # 加载训练集
download=True, # 如果数据不存在则下载
transform=transforms.Compose(
[
transforms.Resize(size=227), # 将图片缩放到227x227
transforms.ToTensor(), # 转换为Tensor
]
),
)