V2EX = way to explore
V2EX 是一个关于分享和探索的地方
现在注册
已注册用户请  登录
推荐学习书目
Learn Python the Hard Way
Python Sites
PyPI - Python Package Index
http://diveintopython.org/toc/index.html
Pocoo
值得关注的项目
PyPy
Celery
Jinja2
Read the Docs
gevent
pyenv
virtualenv
Stackless Python
Beautiful Soup
结巴中文分词
Green Unicorn
Sentry
Shovel
Pyflakes
pytest
Python 编程
pep8 Checker
Styles
PEP 8
Google Python Style Guide
Code Style from The Hitchhiker's Guide
canxun
V2EX  ›  Python

求助 Python 大佬 怎样给他修改成直接输入图片

  •  
  •   canxun · 2021-11-26 11:29:33 +08:00 · 1561 次点击
    这是一个创建于 1090 天前的主题,其中的信息可能已经有所发展或是发生改变。
    现在是随机调用库的图片 能不能改成指定图片
    比如说("2.jpg")
    这样



    """

    ****************** 实现 MNIST 手写数字识别 ************************


    ****************************************************************

    """

    # -*- coding: utf-8 -*-

    import cv2
    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import torchvision
    from torchvision import datasets, transforms



    # 默认预测四张含有数字的图片

    BATCH_SIZE = 4
    # 默认使用 cpu 加速
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")



    # 构建数据转换列表

    tsfrm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1037,), (0.3081,))
    ])

    # 测试集

    test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root = 'data', train = False, download = True,
    transform = tsfrm),
    batch_size = BATCH_SIZE, shuffle = True)



    # 定义图片可视化函数

    def imshow(images):
    img = torchvision.utils.make_grid(images)
    img = img.numpy().transpose(1, 2, 0)
    std = [0.5, 0.5, 0.5]
    mean = [0.5, 0.5, 0.5]
    img = img * std + mean
    # 将图片高和宽分别赋值给 x1,y1
    x1, y1 = img.shape[0:2]
    # 图片放大到原来的 5 倍,输出尺寸格式为(宽,高)
    enlarge_img = cv2.resize(img, (int(y1*5), int(x1*5)))
    cv2.imshow('image', enlarge_img)
    cv2.waitKey(0)



    # 定义一个 LeNet-5 网络,包含两个卷积层 conv1 和 conv2 ,两个线性层作为输出,最后输出 10 个维度

    # 这 10 个维度作为 0-9 的标识来确定识别出的是哪个数字。

    class ConvNet(nn.Module):
    def __init__(self):
    super().__init__()
    # 1*1*28*28
    # 1 个输入图片通道,10 个输出通道,5x5 卷积核
    self.conv1 = nn.Conv2d(1, 10, 5)
    self.conv2 = nn.Conv2d(10, 20, 3)
    # 全连接层、输出层 softmax,10 个维度
    self.fc1 = nn.Linear(20 * 10 * 10, 500)
    self.fc2 = nn.Linear(500, 10)


    # 正向传播
    def forward(self, x):
    in_size = x.size(0)
    out = self.conv1(x) # 1* 10 * 24 *24
    out = F.relu(out)
    out = F.max_pool2d(out, 2, 2) # 1* 10 * 12 * 12
    out = self.conv2(out) # 1* 20 * 10 * 10
    out = F.relu(out)
    out = out.view(in_size, -1) # 1 * 2000
    out = self.fc1(out) # 1 * 500
    out = F.relu(out)
    out = self.fc2(out) # 1 * 10
    out = F.log_softmax(out, dim=1)
    return out



    # 主程序入口
    if __name__ == "__main__":
    model_eval = ConvNet()
    # 加载训练模型
    model_eval.load_state_dict(torch.load('./MNISTModel.pkl', map_location=DEVICE))
    model_eval.eval()
    # 从测试集里面拿出几张图片
    images,labels = next(iter(test_loader))
    # 显示图片
    imshow(images)
    # 输入
    inputs = images.to(DEVICE)
    # 输出
    outputs = model_eval(inputs)
    # 找到概率最大的下标
    _, preds = torch.max(outputs, 1)
    # 打印预测结果
    numlist = []
    for i in range(len(preds)):
    label = preds.numpy()[i]
    numlist.append(label)
    List = ' '.join(repr(s) for s in numlist)

    print('当前预测的数字为: ',List)
    5 条回复    2021-12-05 13:40:27 +08:00
    canxun
        1
    canxun  
    OP
       2021-11-26 11:42:00 +08:00
    wuhu
    coderluan
        2
    coderluan  
       2021-11-26 11:47:42 +08:00
    # 从测试集里面拿出几张图片
    images,labels = next(iter(test_loader))

    改成

    images=cv2.imread("2.jpg")
    canxun
        3
    canxun  
    OP
       2021-11-26 11:53:05 +08:00
    @coderluan File "d:/pycode/.vscode/Untitled-1.py", line 107, in <module>
    imshow(images)
    File "d:/pycode/.vscode/Untitled-1.py", line 51, in imshow
    img = torchvision.utils.make_grid(images)
    File "C:\ruanjian\python\lib\site-packages\torch\autograd\grad_mode.py", line 28, in decorate_context
    return func(*args, **kwargs)
    File "C:\ruanjian\python\lib\site-packages\torchvision\utils.py", line 46, in make_grid
    raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
    TypeError: tensor or list of tensors expected, got <class 'NoneType'>
    canxun
        4
    canxun  
    OP
       2021-11-28 23:43:23 +08:00
    11
    imn1
        5
    imn1  
       2021-12-05 13:40:27 +08:00
    你要将
    test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(root = 'data', train = False, download = True,
    transform = tsfrm),
    batch_size = BATCH_SIZE, shuffle = True)
    这句改成函数,参数就是里面的'data',并返回 test_loader

    这句 是从 data 目录获取文件的,如果你要改成单文件,就要看手册 torch.utils.data.DataLoader 是否提供这个功能
    不提供的话,你就需要把指定文件扔进某个目录,并把目录路径传给 root 这个位置
    关于   ·   帮助文档   ·   博客   ·   API   ·   FAQ   ·   实用小工具   ·   3300 人在线   最高记录 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 28ms · UTC 11:48 · PVG 19:48 · LAX 03:48 · JFK 06:48
    Developed with CodeLauncher
    ♥ Do have faith in what you're doing.