【Pytorch框架】python绘图,指尖玩转TensorBoard基本用法
介绍TensorBoard的主要功能TensorBoard 是一组用于数据可视化的工具。它包含在流行的开源机器学习库 Tensorflow 中。可视化模型的网络架构跟踪模型指标,如损失和准确性等检查机器学习工作流程中权重、偏差和其他组件的直方图显示非表格数据,包括图像、文本和音频将高维嵌入投影到低维空间虽然TensorBoard包含在Tensorflow框架中,但是可以独立安装,并且服务于Pyto
本文主要介绍TensorBoard的基本使用,主要是add_scalar()和add_image()两个函数的使用。
一、什么是TensorBoard?
介绍TensorBoard的主要功能
TensorBoard 是一组用于 数据可视化 的工具。它包含在流行的开源机器学习库 Tensorflow 中。TensorBoard 的主要功能包括:
- 可视化模型的网络架构
- 跟踪模型指标,如损失和准确性等
- 检查机器学习工作流程中权重、偏差和其他组件的直方图
- 显示非表格数据,包括图像、文本和音频
- 将高维嵌入投影到低维空间
虽然TensorBoard包含在Tensorflow框架中,但是可以独立安装,并且服务于Pytorch等其他框架。
二、怎么激活TensorBoard?
介绍TensorBoard的激活方法
这里介绍使用anaconda激活TensorBoard。
1、 首先,需要下载tensorboard,使用conda命令:
conda install tensorboard
2、 初始化一个SummaryWriter对象:
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
writer.close()
SummaryWriter在当前文件目录下创建一个名为logs的文件夹,当然,这里的logs可以替换成任何你喜欢的名字,用于保存可视化数据。在程序运行完后,记得将文件关闭,使用writer.close命令。
3、 使用Anaconda Prompt,激活项目环境(上一篇提到我们使用的pytorch框架主要使用anaconda环境,所以项目的解释器也是anaconda环境,如何通过anaconda安装pytorch,并加载数据集),将anaconda环境切换到项目目录下:
cd /d [Path to Your Project]
使用命令激活Tensorboard:
tensorboard --logdir=logs
这里--logdir指定tensorboard加载的路径,logs即初始化时赋予的标签,也就是保存可视化数据的文件夹名称。
回车后:

这里说明tensorboard默认打开的端口是http://localhost:6006/,当然你也可以自定义打开的端口位置,使用命令:
tensorboard --logdir=logs --port=6007
指定端口为6007。

打开后为这个样子,因为还没有写入任何可视化数据,所以页面中没有数据。
三、add_scalar函数
add_scalar 是 PyTorch 的 SummaryWriter 类中的一个方法,常用于将标量数据记录到 TensorBoard 中。这可以帮助用户在训练期间可视化模型的性能指标,例如损失值、准确率等。
函数定义
def add_scalar(
self,
tag,
scalar_value,
global_step=None,
walltime=None,
new_style=False,
double_precision=False,
):
"""Add scalar data to summary.
Args:
tag (str): Data identifier
scalar_value (float or string/blobname): Value to save
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time())
with seconds after epoch of event
new_style (boolean): Whether to use new style (tensor field) or old
style (simple_value field). New style could lead to faster data loading.
"""
我们主要用到三个参数:
- tag:标识记录可视化数据标量的字符串标签,可以是任何你喜欢的字符串。例如
'Loss/train'、'Accuracy/test'等等。 - scalar_value:需要记录的值,直观的表示就是记录的图中的y轴。
- global_step:表示记录标量的全局步骤,直观的表示就是记录的图中的x轴。
示例
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter("logs")
for i in range(100):
writer.add_scalar("y=x^2",i * i,i)
writer.close()
这段代码描述了向tensorboard写入了一个x从0到100的 x 2 x^2 x2 函数,运行后刷新tensorboard即可看到记录的可视化数据标量。
tag就是图中的y=x^2,用于确定图像的唯一标识,在记录下一次数据时要切换不同的tag,否则会出现数据紊乱。
四、add_image函数
add_image 是 PyTorch 中 SummaryWriter 类的方法,用于将图像数据记录到 TensorBoard 中,以便于可视化模型的输入、输出或中间特征图等。这对检查和调试模型非常有帮助,尤其是在图像处理领域。
函数定义
def add_image(
self, tag, img_tensor, global_step=None, walltime=None, dataformats="CHW"
):
"""Add image data to summary.
Note that this requires the ``pillow`` package.
Args:
tag (str): Data identifier
img_tensor (torch.Tensor, numpy.ndarray, or string/blobname): Image data
global_step (int): Global step value to record
walltime (float): Optional override default walltime (time.time())
seconds after epoch of event
dataformats (str): Image data format specification of the form
CHW, HWC, HW, WH, etc.
Shape:
img_tensor: Default is :math:`(3, H, W)`. You can use ``torchvision.utils.make_grid()`` to
convert a batch of tensor into 3xHxW format or call ``add_images`` and let us do the job.
Tensor with :math:`(1, H, W)`, :math:`(H, W)`, :math:`(H, W, 3)` is also suitable as long as
corresponding ``dataformats`` argument is passed, e.g. ``CHW``, ``HWC``, ``HW``.
"""
这里主要用到四个参数,其中tag和global_step和add_scalar函数用法一样,而img_tensor要求接受的是一个torch.Tensor类型的张量、numpy.ndarray类型的图像或者string/blobname类型的文件名。
第四个参数为dataformats,这个需要根据传入的img_tensor的形状变化,可选项有'CHW'、'HWC'、'HW'、'WH'。
'C':通道'H':高度'W':宽度
这个参数使用在示例中有说明。
示例
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
import numpy as np
writer = SummaryWriter("logs")
img_path="mnist-20/test/test_0_7.jpg"
img_PIL = Image.open(img_path)
img_array = np.array(img_PIL)
print(img_array.shape)
writer.add_image("test",img_array,1,dataformats='HW')
writer.close()
img_path为需要保存到tensorboard中的图像路径,使用Image.open后img_PIL的格式为JepgImageFile,再使用np.array将格式转为numpy.ndarray。若不清楚当前变量是什么格式可以使用print(type())打印查看。- 这里我们查看
img_array.shape,发现是(28,28),表示图像的高度和宽度均为28,且没有通道,所以我们dataformats选择的是'HW'。
运行后刷新页面(每次写入新的数据,都需要刷新),选择Image,就可以看到打开的图像:

好了,今天关于tensorboard基本使用的分享就到这里啦!
更多推荐


所有评论(0)