# @浙大疏锦行
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import warnings

# Fix RNG seeds so training and sampling are reproducible.
torch.manual_seed(42)
np.random.seed(42)

# ToTensor() gives [0, 1] per channel; Normalize(0.5, 0.5) maps that to [-1, 1].
# tensor_to_np() below applies the inverse mapping for display.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

# CIFAR-10 class display names, indexed by label id.
classes = ('飞机', '汽车', '鸟', '猫', '鹿', '狗', '青蛙', '马', '船', '卡车')


class SimpleCNN(nn.Module):
    """Three conv -> ReLU -> 2x2 max-pool stages followed by two linear layers.

    For a 3x32x32 input the three poolings shrink the feature map to 128x4x4,
    which is why ``fc1`` expects ``128 * 4 * 4`` input features.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return raw class scores (logits), one row of 10 per sample."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # flatten(x, 1) keeps the batch dimension intact. With view(-1, 2048),
        # a non-32x32 input would silently be reshaped into extra "samples"
        # instead of failing loudly in fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


# Initialise the model
model = SimpleCNN()
print("模型已创建")

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)


def train_model(model, epochs=1, num_workers=2):
    """Train ``model`` in place on the CIFAR-10 training split.

    Uses Adam (lr=0.001) and cross-entropy loss with mini-batches of 64, and
    prints the mean loss of every 100 mini-batches.

    Args:
        model: network to train; it must already be on the module-level
            ``device``, since batches are moved there.
        epochs: number of passes over the training set.
        num_workers: DataLoader worker processes. NOTE(review): any value > 0
            starts subprocesses. On platforms that spawn them (e.g. Windows),
            a script calling this needs an ``if __name__ == "__main__":`` guard.
    """
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=64, shuffle=True, num_workers=num_workers)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()  # make sure dropout/batch-norm style layers are in training mode
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 100 == 99:
                print(f'[{epoch + 1},{i + 1}] 损失:{running_loss / 100:.3f}')
                running_loss = 0.0
    print("训练完成")


# Reuse the saved checkpoint if possible; otherwise train briefly and save.
try:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    # Without it the load fails, and the fallback below would retrain and
    # overwrite a perfectly good checkpoint.
    state_dict = torch.load('cifar10_cnn.pth', map_location=device)
    model.load_state_dict(state_dict)
except Exception as err:  # best-effort: any load failure falls back to training
    print("无法加载预训练模型,使用未训练模型或训练新模型")
    print(f"  ({type(err).__name__}: {err})")
    train_model(model, epochs=1)
    torch.save(model.state_dict(), 'cifar10_cnn.pth')
else:
    print("已加载预训练模型")

model.eval()


class GradCAM:
    """Grad-CAM for one convolutional layer of ``model``.

    A forward hook on ``target_layer`` stores the layer's output (the
    activations). It also attaches a tensor hook to that output, which stores
    the gradient flowing back into it when ``backward`` runs.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self._handles = []  # RemovableHandle objects, so hooks can be detached
        self.register_hooks()

    def register_hooks(self):
        """Attach the activation/gradient capture hook to ``target_layer``."""

        def forward_hook(_module, _inputs, output):
            self.activations = output.detach()

            def save_gradient(grad):
                self.gradients = grad.detach()

            # A tensor hook receives exactly d(loss)/d(output). The deprecated
            # Module.register_backward_hook may report the gradients of an
            # internal op instead for modules made of several autograd ops.
            # Under no_grad the output has no graph, so there is nothing to hook.
            if output.requires_grad:
                output.register_hook(save_gradient)

        self._handles.append(self.target_layer.register_forward_hook(forward_hook))

    def remove_hooks(self):
        """Detach every hook this instance registered on ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()

    def generate_cam(self, input_image, target_class=None):
        """Compute the Grad-CAM heatmap for sample 0 of ``input_image``.

        Args:
            input_image: batched model input. Presumably a batch of one: the
                default ``target_class`` lookup calls ``.item()`` on the
                per-sample argmax. Only sample 0 gets a gradient.
            target_class: class index to explain. Defaults to the model's
                top-scoring class.

        Returns:
            ``(cam, target_class)``. ``cam`` is a numpy array upsampled to the
            input's spatial size and min-max scaled to [0, 1]. It is all zeros
            if the map is flat after ReLU.

        Raises:
            RuntimeError: if the hooks captured nothing, e.g. because
                ``target_layer`` was not used in the forward pass.
        """
        # Clear captures from a previous call so stale tensors are never reused.
        self.gradients = None
        self.activations = None

        model_output = self.model(input_image)
        if target_class is None:
            target_class = torch.argmax(model_output, dim=1).item()

        self.model.zero_grad()
        one_hot = torch.zeros_like(model_output)
        one_hot[0, target_class] = 1
        model_output.backward(gradient=one_hot)

        gradients = self.gradients
        activations = self.activations
        if gradients is None or activations is None:
            raise RuntimeError(
                "GradCAM: target layer captured no activations/gradients")

        # Channel weights = spatially averaged gradients.
        weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
        cam = torch.sum(weights * activations, dim=1, keepdim=True)
        cam = F.relu(cam)  # keep only regions that increase the class score
        cam = F.interpolate(cam, size=tuple(input_image.shape[-2:]),
                            mode='bilinear', align_corners=False)
        cam = cam - cam.min()
        cam_max = cam.max()
        if cam_max > 0:
            cam = cam / cam_max
        return cam.cpu().squeeze().numpy(), target_class


# NOTE: this silences *all* warnings for the rest of the process, not only the
# plotting-related ones.
warnings.filterwarnings("ignore")

# SimHei is a Chinese font, so the Chinese titles can render (assumes it is
# installed). Presumably unicode_minus is disabled because the font lacks the
# U+2212 minus glyph.
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False

idx = 102
image, label = testset[idx]
print(f"选择的图像类别:{classes[label]}")


def tensor_to_np(tensor):
    """Convert a normalised CHW image tensor to an HWC numpy array in [0, 1].

    Undoes ``Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))`` from ``transform``.
    """
    img = tensor.cpu().numpy().transpose(1, 2, 0)
    mean = np.array([0.5, 0.5, 0.5])
    std = np.array([0.5, 0.5, 0.5])
    img = std * img + mean
    return np.clip(img, 0, 1)
# --- Grad-CAM visualisation --------------------------------------------------
# Explain the selected test image with Grad-CAM on the last conv layer. Then
# show three panels side by side: the original image, the raw heatmap, and
# the heatmap blended over the image. The figure is saved to disk and shown.

input_tensor = image.unsqueeze(0).to(device)  # add a batch dimension of 1
grad_cam = GradCAM(model, model.conv3)
heatmap, pred_class = grad_cam.generate_cam(input_tensor)

# Build the overlay. Scale the [0, 1] heatmap to 0..255 integers, which index
# the jet colormap's lookup table directly. Keep RGB (drop alpha), then blend
# 40% heatmap with 60% de-normalised image.
img = tensor_to_np(image)
heatmap_resized = np.uint8(255 * heatmap)
heatmap_colored = plt.cm.jet(heatmap_resized)[:, :, :3]
superimposed_img = heatmap_colored * 0.4 + img * 0.6

# One entry per panel: (array to draw, colormap or None for default, title).
panels = (
    (img, None, f"原始图像:{classes[label]}"),
    (heatmap, 'jet', f"Grad-CAM热力图:{classes[pred_class]}"),
    (superimposed_img, None, "叠加热力图"),
)

plt.figure(figsize=(12, 4))
for position, (picture, colormap, caption) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(picture, cmap=colormap)
    plt.title(caption)
    plt.axis('off')

plt.tight_layout()
plt.savefig('grad_cam_result.png')
plt.show()