torch_book.data.cv.plot 源代码

from dataclasses import dataclass
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

@dataclass
class GridFrame:
    """Draw a sequence of images on a regular ``num_rows x num_cols`` grid.

    A matplotlib figure using the "constrained" layout engine is created at
    construction time and kept on ``self.figure``; calling the instance adds
    one subplot per grid cell to that figure and draws the frames into them.
    """

    num_rows: int  # number of grid rows
    num_cols: int  # number of grid columns
    scale: float = 1.5  # figure inches per grid cell, used for width and height

    def __post_init__(self):
        """Create the figure, sized ``scale`` inches per cell in each direction."""
        width = self.scale * self.num_cols
        height = self.scale * self.num_rows
        self.figure = plt.figure(figsize=(width, height), layout="constrained")

    def update_axes(self, axes, frames):
        """Show each frame on its paired axis and hide that axis's x/y axis artists.

        ``axes`` and ``frames`` are paired positionally; pairing stops at the
        shorter sequence, so surplus axes are left untouched.

        Returns:
            The same ``axes`` object that was passed in.
        """
        for axis_obj, image in zip(axes, frames):
            axis_obj.imshow(image)
            # Hide x first, then y (same order as a pair of explicit calls).
            for getter in (axis_obj.axes.get_xaxis, axis_obj.axes.get_yaxis):
                getter().set_visible(False)
        return axes

    def __call__(self, frames, **kwargs):
        """Draw ``frames`` into a freshly built grid of subplots.

        Extra keyword arguments are forwarded to ``GridSpec``. A subplot is
        added for every grid cell, even when fewer frames are supplied.

        Returns:
            ``(gridspec, axes)``: the ``GridSpec`` and the list of created axes.
        """
        grid = GridSpec(self.num_rows, self.num_cols, self.figure, **kwargs)
        cells = [self.figure.add_subplot(cell) for cell in grid]
        return grid, self.update_axes(cells, frames)
@dataclass
class CompareGridFrame:
    """Show two lists of images next to each other for comparison.

    Each list is drawn on its own ``num_rows x num_cols`` sub-grid. With
    ``layout="col"`` the two sub-grids sit side by side (left / right); with
    ``layout="row"`` they are stacked (top / bottom). The figure, using the
    "constrained" layout engine, is created at construction time and kept on
    ``self.figure``.
    """

    num_rows: int  # rows in each of the two sub-grids
    num_cols: int  # columns in each of the two sub-grids
    scale: float = 1.5  # figure inches per grid cell
    layout: str = "col"  # 'col': panels side by side; 'row': panels stacked

    def __post_init__(self):
        """Create the figure.

        Raises:
            ValueError: if ``layout`` is neither ``'col'`` nor ``'row'``.
        """
        # Validate the layout before any figure is created.
        figsize = self._figsize()
        self.figure = plt.figure(figsize=figsize, layout="constrained")

    def _figsize(self):
        """Return ``(width, height)`` in inches; the stacking direction is doubled."""
        if self.layout == "col":
            return (self.num_cols * 2 * self.scale, self.num_rows * self.scale)
        if self.layout == "row":
            return (self.num_cols * self.scale, self.num_rows * 2 * self.scale)
        raise ValueError("layout must be 'col' or 'row'")

    def _main_grid_args(self):
        """Return ``(nrows, ncols, kwargs)`` for the outer two-panel ``GridSpec``."""
        if self.layout == "col":
            # Panels side by side (1 x 2): equal widths; the gap between them
            # is horizontal, so it is controlled by ``wspace``.
            return 1, 2, {"width_ratios": [1, 1], "wspace": 0.01}
        if self.layout == "row":
            # Panels stacked (2 x 1): equal heights; the gap between them is
            # vertical, so it is controlled by ``hspace``.
            return 2, 1, {"height_ratios": [1, 1], "hspace": 0.01}
        raise ValueError("layout must be 'col' or 'row'")

    def update_axes(self, axes, frames):
        """Show each frame on its paired axis and hide that axis's x/y axis artists.

        ``axes`` and ``frames`` are paired positionally; pairing stops at the
        shorter sequence, so surplus axes are left untouched.

        Returns:
            The same ``axes`` object that was passed in.
        """
        for ax, img in zip(axes, frames):
            ax.imshow(img)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
        return axes

    def _fill_panel(self, panel_spec, frames, **kwargs):
        """Build a ``num_rows x num_cols`` sub-grid in ``panel_spec`` and draw ``frames``."""
        sub_grid = GridSpecFromSubplotSpec(
            self.num_rows, self.num_cols, subplot_spec=panel_spec, **kwargs
        )
        axes = [self.figure.add_subplot(cell) for cell in sub_grid]
        return self.update_axes(axes, frames)

    def __call__(self, frames1, frames2, **kwargs):
        """Draw ``frames1`` and ``frames2`` in two panels arranged per ``layout``.

        With ``layout="col"`` the first list is on the left and the second on
        the right; with ``layout="row"`` the first is on top. Extra keyword
        arguments are forwarded to both ``GridSpecFromSubplotSpec`` sub-grids.

        Returns:
            ``(left_axes, right_axes)``: the axes lists of the first and second
            panel respectively.

        Raises:
            ValueError: if ``layout`` is neither ``'col'`` nor ``'row'``.
        """
        nrows, ncols, main_kwargs = self._main_grid_args()
        gs_main = GridSpec(nrows, ncols, figure=self.figure, **main_kwargs)
        left_axes = self._fill_panel(gs_main[0], frames1, **kwargs)
        right_axes = self._fill_panel(gs_main[1], frames2, **kwargs)
        return left_axes, right_axes