news 2026/4/16 0:03:09

如何用机器学习解决简单问题

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
如何用机器学习解决简单问题

原文:towardsdatascience.com/how-to-solve-a-simple-problem-with-machine-learning-9efd03d0fe69

管理者和工程师的机器学习课程

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/944d3832d1e8cf7fb909a60c0e517e27.png

作者创建的图片

欢迎回到我的系列课程第二课,管理者和工程师的机器学习课程。今天,应大家的要求,我将向您展示如何在第一课中提到的解决方案进行实现。

这比最初为这个系列设计的课程更技术性,但我相信大多数专业人士都从更好地理解机器学习技术中受益。

为了尽可能保持相关性,我将主要关注底层推理,因为那里有宝贵的教训。如果您想详细研究代码,页面底部有一个 GitHub 链接。

哦,别忘了查看我的其他机器学习课程! 👇

管理者和工程师的机器学习课程


第一课的回顾

在第一课中,我解释了机器学习是简单问题的有效解决方案,即使传统方法也能解决它们。我的观点是,机器学习通常提供最直接、易于维护和健壮的替代方案,这与普遍认为它是仅在所有其他方法都失败的情况下才使用的技术的观点相矛盾。

为了证明我的观点,我提出一个用例,我想在轨道图像中检测铁轨头部。大多数工程师会创建一个基于传统计算机视觉的解决方案,根据像素值的强度和变化创建规则。

虽然这是一个有效的方法,但我使用机器学习解决了这个问题。编写代码、标注图像和训练算法花了我一个小时。由于很多人问我是否能分享一个 GitHub 仓库,我想用这个第二个教训来详细解释实现过程。

让我们继续进行演练(您可以在页面底部找到 GitHub 仓库的链接)。


铁轨检测演练

我的目标是创建一个机器学习解决方案,我可以在 MacBook 的 CPU 上训练和运行它,这限制了算法和数据分辨率的大小。我还想使用一些简单、基于常识和直接的技术。

简化问题

在处理机器学习问题时,您需要考虑如何以算法可以学习关键模式的方式表达它。根据下面的图像,我的思考过程如下。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/ab2ae81884ac83f3273f4f0cc51bf72c.png

作者创建的图片

  1. 我想创建一个算法,该算法可以检测两条铁轨的中心。

  2. 我大致知道铁轨的位置,但位置变化在 5-10%之间。

  3. 从算法的角度来看,两条轨道之间没有区别。

  4. 如果我创建作物,我可以制作出两个总是包括轨道的图像。

  5. 这样简化了问题,因为算法只需要找到一条轨道。

  6. 它还通过删除大量冗余信息使任务更容易。

  7. 在较小的作物上训练比使用整个图像快。

由于这个原因,我决定在像这两个这样的作物上训练算法,而不是在上面的整个图像上,这样既简单又快。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/92232493012c60d1cbe3eaddd02542d4.png

作者创建的图片

图像标注

当人们使用 YOLO 等工具工作时,他们通常会使用标注工具,然后直接使用输出训练模型。这可以工作,但很容易忘记你可以用相同的标注以不同的方式表达问题。在我的经验中,你如何表达训练问题是创建一个好的算法的决定性因素。

对于这个简单的任务,我想在我的 Jupyter Notebook 中做标注。我想出的最快方法是写下每个作物轨道中心的像素位置。我的“标注工具”显示图像的轨道区域,我输入两个用空格分隔的整数,描述我认为轨道中心的位置。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/b9336b5b869d100251e206a56440434a.png

作者创建的图片

创建这样一个简单的标注工具是直接的。以下三个功能构成了我为此任务设计的标注工具,并且直接在我的 Jupyter Notebook 中运行。

defload_data():withopen("rail_locations.json")asf:returnjson.load(f)defshow_rails(image,start_indexes,number_of_pixels):fig,axes=plt.subplots(nrows=1,ncols=2,figsize=(20,5))foriinrange(2):axes[i].imshow(image.crop([start_indexes[i],0,start_indexes[i]+number_of_pixels,image.size[1]]),cmap="gray")axes[i].set_xticks([iforiinrange(0,number_of_pixels,5)])axes[i].set_yticks([])plt.tight_layout()plt.show()defannotate_rail_images(start_indexes,number_of_pixels,split):training_data=load_data()np.random.shuffle(training_data)forindex,tinenumerate(training_data):if"locations"int.keys():continueclear_output()image=Image.open(t["image"])show_rails(image,start_indexes,number_of_pixels)locations=input("Rail centers:").split()locations=[int(l)+start_indexes[i]fori,linenumerate(locations)]training_data[index]["locations"]=locations training_data[index]["split"]=splitwithopen("rail_locations.json","w")asf:json.dump(training_data,f)clear_output()

表达训练问题

我最初的想法是训练一个算法给我一个介于 0 和 1 之间的值,描述每个图像中轨道的位置。那种方法的缺点是它给算法提供了一个弱训练信号,使得任务更难学习。

相反,我决定算法应该给我每个像素列的概率,告诉我该列是否包含轨道的中心。由于我将每个作物调整到 64×128 像素,这意味着我的算法输出一个包含 128 个值的向量。

以那种方式表达问题给算法提供了更好的训练信号,因为它被迫做出 128 个预测(每个列一个),而不是只有一个。我将标签稀释到一行有五个 1,以简化任务。

defcreate_label(x,size,padding):label=torch.zeros(size)start=max(0,int(round(size*x))-padding)end=min(size-1,int(round(size*x))+padding+1)label[start:end]=1returnlabel

我想强调,表达这个问题有很多人正确的办法,其中一些比我的方法更好。然而,在机器学习中,你很少需要找到最佳方法来创建一个可行的解决方案,理解什么是足够好的是至关重要的。

创建额外的训练数据

现在我已经标注了数据,我想尽可能多地从每个数据中创建训练示例。大多数人认为答案是数据增强,但您通常可以做得更好。我仍然可以访问整个图像,并且通过使用标注和翻转,我可以将一个裁剪图变成 400 多个独特的训练示例。

这里是九个示例,其中我用一个裁剪图创建了多个训练图像。线条标记了图像的列,我期望算法在这里给出高概率。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/18a57c91fd8850064a49981027346ead.png

作者创建的图像

我加载数据的代码如下。data_point是一个字典,其中image指向原始未裁剪的图像,locations包含两条轨道的中心像素。

defflip_image(crop,x):ifnp.random.rand()<0.5:crop=crop.transpose(Image.FLIP_LEFT_RIGHT)x=1-xifnp.random.rand()<0.5:crop=crop.transpose(Image.FLIP_TOP_BOTTOM)defcreate_crop(image,x,start_index,train,number_of_pixels,size,padding):x_start=x-(np.random.uniform(10,number_of_pixels-10))iftrainelsestart_index crop=image.crop([x_start,0,x_start+number_of_pixels,image.size[1]])crop=crop.resize((size,64),Image.LANCZOS)x=(x-x_start)/number_of_pixelsiftrain:flip_image(crop,x)crop=TF.adjust_brightness(crop,np.random.uniform(0.5,1.5))crop=TF.adjust_contrast(crop,np.random.uniform(0.8,1.2))label=create_label(x,size,padding)returnTF.to_tensor(crop),labeldefget_crops(data_point,train=False):left_image,x_left=create_crop(image=data_point["image"],x=data_point["locations"][0],start_index=START_INDEX_LEFT,train=train,number_of_pixels=NUMBER_OF_PIXELS,size=IMAGE_WIDTH,padding=LABEL_PADDING)right_image,x_right=create_crop(image=data_point["image"],x=data_point["locations"][1],start_index=START_INDEX_RIGHT,train=train,number_of_pixels=NUMBER_OF_PIXELS,size=IMAGE_WIDTH,padding=LABEL_PADDING)returntorch.cat([left_image,right_image]),torch.cat([x_left,x_right])

当我添加额外的数据增强,例如调整亮度和对比度时,我得到了更多的训练数据,而且由于这些技术,我可以通过标注不超过十张图片来解决这个问题。

算法

在解决像这样的简单问题时,如何设计算法是最不关键的部分。有无限多的架构可以完成这项任务,而且在这些架构之间切换几乎不会产生差异。我决定使用一个只有 14,000 个参数的简单卷积神经网络(CNN)。

classRailDetector(torch.nn.Module):def__init__(self,size=1):super(RailDetector,self).__init__()self.conv1=torch.nn.Conv2d(1,size,kernel_size=3,stride=2,padding=1)self.conv2=torch.nn.Conv2d(size,2*size,kernel_size=3,stride=2,padding=1)self.conv3=torch.nn.Conv2d(2*size,4*size,kernel_size=3,stride=2,padding=1)self.embedding=torch.nn.Linear(4*size*2,IMAGE_WIDTH)defforward(self,x):x=F.relu(F.max_pool2d(self.conv1(x),(2,2)))x=F.relu(F.max_pool2d(self.conv2(x),(2,2)))x=F.relu(F.max_pool2d(self.conv3(x),(2,2)))returnself.embedding(x.view(x.shape[0],-1))Trainingandevaluation

我使用 PyTorch 训练算法,就像您训练任何算法一样。我使用二元交叉熵损失函数(BCE loss function)和 Adam 优化器。为了评估我的验证数据上的模型,我计算它平均偏离中心像素有多远。每当这个数字提高时,我就保存模型权重。

在训练我的算法十分钟之后,这个数字达到大约 0.6 像素,这对于我的预期用途来说已经足够好了。在这里,您可以看到我的验证数据裁剪图,以及标签(红色线条)和预测(蓝色线条)。

https://github.com/OpenDocCN/towardsdatascience-blog-zh-2024/raw/master/docs/img/460dc4d7eca9c77c0b63f3f7a7878fda.png

作者创建的图像


结论

在本课中,我想通过回顾我在第一课中描述的解决方案的思考过程和方法来提高您对机器学习的理解。我希望您在工作中有一些顿悟的时刻。

链接

ML Lessons for Managers and Engineers | Oscar Leo | Substack

ml-lessons-for-managers-and-engineers/rail-head-detection at main ·…

不要害怕为简单任务使用机器学习

感谢您的阅读。别忘了分享和订阅!

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/4/16 12:50:57

STM32工业阀门控制项目:Keil5操作指南

STM32工业阀门控制实战&#xff1a;从Keil5环境搭建到系统实现 你有没有遇到过这样的场景&#xff1f; 现场的阀门响应迟钝、动作不精准&#xff0c;故障了还得派人爬高去手动排查&#xff1b;上位机发个指令&#xff0c;等半天才看到执行结果&#xff0c;还无法确认是否到位…

作者头像 李华
网站建设 2026/4/4 3:02:29

大模型推理服务灰度策略管理系统

大模型推理服务灰度策略管理系统中的 TensorRT 实践 在当前大语言模型&#xff08;LLM&#xff09;加速落地的背景下&#xff0c;推理服务的性能与稳定性直接决定了产品的用户体验和上线节奏。尤其是在需要频繁迭代、多版本并行验证的“灰度发布”场景中&#xff0c;如何在保证…

作者头像 李华
网站建设 2026/4/16 14:02:54

NVIDIA官方技术咨询预约:TensorRT专家坐诊

NVIDIA官方技术咨询预约&#xff1a;TensorRT专家坐诊 在当今AI应用爆发式增长的时代&#xff0c;一个训练完成的深度学习模型从实验室走向生产环境&#xff0c;往往面临“落地难”的困境——明明在开发阶段表现优异&#xff0c;部署后却出现延迟高、吞吐低、资源消耗大的问题。…

作者头像 李华
网站建设 2026/4/16 14:24:53

Keil5添加文件手把手教程:图文详解每一步骤

Keil5添加文件实战指南&#xff1a;从零开始搞懂工程结构与编译逻辑你有没有遇到过这样的情况&#xff1f;写好了led_driver.c和led_driver.h&#xff0c;在main.c里#include "led_driver.h"&#xff0c;结果一编译——Error: Cannot open source file ‘led_driver.…

作者头像 李华
网站建设 2026/4/16 10:56:51

NVIDIA官方技术大会演讲回放:TensorRT专场

NVIDIA TensorRT&#xff1a;从模型到生产的推理加速引擎 在当今AI应用爆发式增长的时代&#xff0c;一个训练好的深度学习模型是否真正“有用”&#xff0c;早已不再只看准确率。真正的考验在于——它能不能在真实场景中快速、稳定、低成本地跑起来。 想象这样一个画面&#x…

作者头像 李华
网站建设 2026/4/16 10:57:15

生产消费者模型

生产消费者模型概念与作用概念&#xff1a;它通过一个容器&#xff08;缓冲区&#xff09;来解决生产者和消费者之间的强耦合问题。解耦&#xff1a;生产者只管生产&#xff0c;消费者只管消费&#xff0c;它们互不认识&#xff0c;只通过缓冲区交互。支持并发&#xff1a;生产…

作者头像 李华