Pytorch 计算深度模型的大小

计算模型大小的方法

卷积 时间复杂度 与 空间复杂度 的计算方式:
在这里插入图片描述

C 通道的个数,K卷积核大小,M特征图大小,C_l-1是输入通道的个数,C_l是输出通道的个数

1 模型大小 MB

计算模型的大小的原理就是计算保存模型所需要的存储空间的大小,一般以字节为单位,由于模型常常较大,通常使用 MB (million byte)为单位,在算法层面是就是空间复杂度。

NOTE: 有的地方算 参数量 or 模型大小 会x4,因为模型参数一般都是FP32存储的,FP32是单精度,占4个字节

计算方式:

# 计算了
total_params = sum(p.numel() for p in model.parameters())
total_params += sum(p.numel() for p in model.buffers())

print(f'{total_params:,} total parameters.')
print(f'{total_params/(1024*1024):.2f}M total parameters.')
total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
print(f'{total_trainable_params:,} training parameters.')
print(f'{total_trainable_params/(1024*1024):.2f}M training parameters.')

2 计算量 FLOPs

算法的时间复杂度,每秒浮点数运算的次数

  • 计算量的要求是在于芯片的floaps(指的是gpu的运算能力)
  • 参数量取决于显存大小

3 相关第三方库

3.1 torchstat

安装

pip install torchstat

用法

from torchstat import stat

model = CNN()
stat(model, (3, 224, 224))

版本报错解决:https://blog.csdn.net/u013963578/article/details/133672751

输出:

  • params: 网络的参数量
  • memory: 节点推理时候所需的内存
  • Flops: 网络完成的浮点运算
  • MAdd网络完成的乘加操作的数量。一次乘加=一次乘法+一次加法,所以可以粗略的认为Flops ≈2*MAdd
  • MemRead: 网络运行时,从内存中读取的大小
  • MemWrite: 网络运行时,写入到内存中的大小
  • MemR+W: MemR+W = MemRead + MemWrite

torchstat存在的问题:

1.torchstat bug

版本问题输入为None

解决:修改源码 torchstat/reporter.py

df = df._append(total_df)

2.不能计算含有Transformer结构的模型大小

3.2 torchsummary

安装

pip install torchsummary

使用

model.to(torch.device("cuda:0"))
torchsummary.summary(model, input_size, batch_size=-1, device="cuda")

BUG修复

修改torchsummary.py源码

# 注销源码
# summary[m_key]["input_shape"] = list(input[0].size())
# summary[m_key]["input_shape"][0] = batch_size

# input 为 None的时候等于input
if len(input) != 0:
    summary[m_key]["input_shape"] = list(input[0].size())
    summary[m_key]["input_shape"][0] = batch_size
else:
    summary[m_key]["input_shape"] = input

torchsummary支持对Transformer模型大小的计算

3.3 thop

安装:

pip install thop

使用:

from thop import profile
input_size = (1, 3, 512, 512)
a = torch.randn(input_size)
flops, params = profile(model=model, inputs=(a, ))  # 注意 逗号

print(f"flops: {flops / 1e9} GFlops")
print(f"params: {params / 1e6} MB")

参考:

计算原理:https://blog.csdn.net/hxxjxw/article/details/119043464
计算Param 和 GFlops https://blog.csdn.net/qq_41573860/article/details/116767639
torchstat 输出参数解析 https://blog.csdn.net/m0_56192771/article/details/124672273
torchsummary bug 解决 tuple out of index https://blog.csdn.net/onermb/article/details/116149599
thop:https://blog.csdn.net/qq_21539375/article/details/113936308
参数大小与计算量:https://blog.csdn.net/qq_40507857/article/details/118764782
总结:https://blog.csdn.net/qq_41573860/article/details/116767639

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/578603.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

嵌入式全栈开发学习笔记---Linux基本命令2

目录 cp 源路径 目的路径 cp -r 源路径 目的路径 mv 源路径 目的路径 mv oldname newname 接下来我们继续介绍两个常用的命令 一个是拷贝文件,一个是剪切文件 ,或者也可以用来改名字。 cp 源路径 目的路径 “cp”用来拷贝文件或者目录,…

jpa分页插件对象Pageable出现了错误异常如何解决?

jpa分页插件对象Pageable出现了错误异常如何解决?! 一般来说,遇到这种的错误异常情况,通常情况 下,都是因为程序员把传递的分页页码数字写错了。 正常情况下,分页页码起始数字应该是0;而不是1…

【研发管理】产品经理知识体系-产品创新中的市场调研

导读:在产品创新过程中,市场调研的重要性不言而喻。它不仅是产品创新的起点,也是确保产品成功推向市场的关键步骤。对于产品经理系统学习和掌握产品创新中的市场调研相关知识体系十分重要。 目录 概述:市场调研重要性 1、相关概…

Python脚本:论文必备!给PDF指定页面后添加空白页

一、代码 直接上代码 import PyPDF2,os from datetime import datetimedef add_blank_pages(pdf_path, page_numbers):pdf_writer PyPDF2.PdfWriter()timestamp datetime.now().strftime("%Y%m%d%H%M%S")# 读取PDF文件with open(pdf_path, rb) as pdf_file:pdf_r…

【Java EE】Spring核心思想(一)——IOC

文章目录 🎍Spring 是什么?🎄什么是IoC呢?🌸传统程序开发🌸传统程序开发的缺陷🌸如何解决传统程序的缺陷?🌸控制反转式程序开发🌸对比总结 🌲理解…

Kafka学习笔记01【2024最新版】

一、Kafka-课程介绍 官网地址:Apache KafkaApache Kafka: A Distributed Streaming Platform.https://kafka.apache.org/ kafka 3.6.1版本,作为经典分布式订阅、发布的消息传输中间件,kafka在实时数据处理、消息队列、流处理等领域具有广泛…

检测水箱水位传感器有哪些?

生活中很多家电中都内含一个水箱,例如电蒸锅、饮水机、蒸汽熨斗、咖啡机等等,这些内部都有水箱,或大或小。当然水箱也有很多种类型,例如生活水箱、生产水箱、消防水箱等等。 把水储存在水箱中也会遇到这些问题,水箱没…

JavaScript云LIS系统源码 前端框架JQuery+EasyUI+后端框架MVC+SQLSuga大型医院云LIS检验系统源码 可直接上项目

JavaScript云LIS系统源码 前端框架JQueryEasyUI后端框架MVCSQLSuga大型医院云LIS检验系统源码 可直接上项目 云LIS系统概述: 云LIS是为区域医疗提供临床实验室信息服务的计算机应用程序,可协助区域内所有临床实验室相互协调并完成日常检验工作&#xff…

pytest-asyncio:协程异步测试案例

简介:pytest-asyncio是一个pytest插件。它便于测试使用异步库的代码。具体来说,pytest-asyncio提供了对作为测试函数的协同程序的支持。这允许用户在测试中等待代码。 历史攻略: asyncio并发访问websocket Python:协程 - 快速创…

SVGDreamer: 文本引导矢量图形合成

现有的 Text-to-SVG 方法还存在两个限制:1.生成的矢量图缺少编辑性;2. 难以生成高质量和多样性的结果。为了解决这些限制,作者提出了一种新的文本引导矢量图形合成方法:SVGDreamer。 论文题目: SVGDreamer: Text Guid…

《Git---Windows Powershell提交信息中文乱码解决方案》

解释: Windows PowerShell中的Git乱码通常是因为字符编码不正确或Git配置不支持Windows系统的默认编码导致的。Git在处理文件时可能使用UTF-8编码,而Windows系统的命令行工具(如PowerShell)默认使用的是Windows-1252或GBK编码。 …

【库函数】Linux下动态库.so和静态库.a的生成和使用

目录 🌞1. Linux下静态库和动态库的基本概念 🌞2. 动态库 🌊2.1 动态库如何生成 🌍2.1.1 文件详情 🌍2.1.2 编译生成动态库 🌊2.2 动态库如何使用 🌍2.2.1 案例 🌍2.2.2 动态…

TCP详解

2.1TCP 由IETF的RFC793定义的传输控制协议(Transmission Control Protocol,TCP)是一种基于字节流的传输层通信协议。在传输数据前需要在发送与接收者之间建立连接,通过相应机制保证其建立连接的可靠性。 TCP协议具备以下特性&am…

云备份项目->配置环境

升级gcc到7.3版本 sudo yum install centos-release-scl-rh centos-release-scl sudo yum install devtoolset-7-gcc devtoolset-7-gcc-c source /opt/rh/devtoolset-7/enable echo "source /opt/rh/devtoolset-7/enable" >> ~/.bashrc 安装Jsoncpp库 sud…

十、多模态大语言模型(MLLM)

1 多模态大语言模型(Multimodal Large Language Models) 模态的定义 模态(modal)是事情经历和发生的方式,我们生活在一个由多种模态(Multimodal)信息构成的世界,包括视觉信息、听觉信息、文本信息、嗅觉信…

STM32与Proteus的串口仿真详细教程与源程序

资料下载地址:STM32与Proteus的串口仿真详细教程与源程序 资料内容 包含LCD1602显示,串口发送接收,完美实现。 文档内容齐全,包含使用说明,相关驱动等。 解决了STM32的Proteus串口收发问题。 注意:每输…

IP-guard getdatarecord 存在任意文件读取

声明 本文仅用于技术交流,请勿用于非法用途 由于传播、利用此文所提供的信息而造成的任何直接或者间接的后果及损失,均由使用者本人负责,文章作者不为此承担任何责任。 一、产品介绍 IP-guard是由溢信科技股份有限公司开发的一款终端安全管…

揭秘被忽视的商业模式:全民拼购助力客户实现日销千万的惊人业绩

今天,我想和大家分享一个颇具潜力的模式与玩法,尽管它在外界看来可能略显陈旧。这个模式曾被忽视,但我的一位客户却巧妙运用,实现了惊人的业绩——日销售额接近五千万,日订单量高达300万单。 值得注意的是,…

一键搞定COX回归亚组森林图!快速生成顶级SCI论文的高清图!

现在亚组分析好像越来越流行,无论是观察性研究还是RCT研究,亚组分析一般配备森林图。 比如NEJM这张图: 比如Lancet这张图: 但是在使用R语言绘制时,简单的代码画不出好看的图,好看的图又需要许多代码参数来进…

[HUBUCTF 2022 新生赛]最简单的misc

有点简单, 要用到工具lsb,qr扫码 一般杂项先binwalk,不行的话在lsb 因为头是png所以save bin出二维码,用QR扫码 即可得出flag
最新文章