Decoding a video with VideoDecoder#

In this example, we'll learn how to decode a video with the torchcodec.decoders.VideoDecoder class.

Setup#

First, we download an example video from the web and define a small plotting utility. If you only care about the decoding steps, you can skip straight to "Creating a decoder" below.

from typing import Optional
import torch
import httpx

# Video source: https://www.pexels.com/video/dog-eating-854132/
# License: CC0. Author: Coverr.
url = "https://videos.pexels.com/video-files/854132/854132-sd_640_360_25fps.mp4"
response = httpx.get(url, headers={"User-Agent": ""})
if response.status_code != 200:
    raise RuntimeError(f"Failed to download video. {response.status_code = }.")

raw_video_bytes = response.content

# Small plotting helper: arranges one or more frames into a grid and displays them.
def plot(frames: torch.Tensor, title: Optional[str] = None):
    try:
        from torchvision.utils import make_grid
        from torchvision.transforms.v2.functional import to_pil_image
        import matplotlib.pyplot as plt
    except ImportError:
        print("Cannot plot, please run `pip install torchvision matplotlib`")
        return

    plt.rcParams["savefig.bbox"] = 'tight'
    fig, ax = plt.subplots()
    ax.imshow(to_pil_image(make_grid(frames)))
    ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])
    if title is not None:
        ax.set_title(title)
    plt.tight_layout()

Creating a decoder#

We can create a decoder directly from the raw (encoded) video bytes. A path to a local video file can also be used as input.

from torchcodec.decoders import VideoDecoder

decoder = VideoDecoder(raw_video_bytes)
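
As a quick illustration of the file-path variant mentioned above, here is a minimal sketch; the path below is a placeholder rather than a file shipped with this example:

local_decoder = VideoDecoder("/path/to/video.mp4")  # hypothetical path to a local file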

The video has not actually been decoded yet, but we can already read some metadata through the metadata attribute, which is a torchcodec.decoders.VideoStreamMetadata object.

print(decoder.metadata)
VideoStreamMetadata:
  duration_seconds_from_header: 13.8
  begin_stream_seconds_from_header: 0.0
  bit_rate: 505790.0
  codec: h264
  stream_index: 0
  begin_stream_seconds_from_content: 0.0
  end_stream_seconds_from_content: 13.8
  width: 640
  height: 360
  num_frames_from_header: 345
  num_frames_from_content: 345
  average_fps_from_header: 25.0
  pixel_aspect_ratio: 1
  duration_seconds: 13.8
  begin_stream_seconds: 0.0
  end_stream_seconds: 13.8
  num_frames: 345
  average_fps: 25.0
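
The fields in the printed output are also available as plain attributes of the metadata object, so they can be used directly; for example, using only fields that appear above:

metadata = decoder.metadata
print(metadata.num_frames, metadata.average_fps, metadata.duration_seconds)
# -> 345 25.0 13.8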

Decoding frames by indexing the decoder#

Indexing the decoder returns frames as torch.Tensor objects. By default, their shape is N, C, H, W (batch size, number of channels, height, width); the batch dimension N is only present when more than one frame is decoded. The layout can be changed to N, H, W, C by setting the dimension_order parameter when creating the decoder. Frames always have dtype torch.uint8.

If you need to decode several frames at once, prefer the batch APIs, which are faster: VideoDecoder.get_frames_at, VideoDecoder.get_frames_in_range, VideoDecoder.get_frames_played_at, and VideoDecoder.get_frames_played_in_range.

first_frame = decoder[0]
every_twenty_frame = decoder[0:-1:20]

print(f"{first_frame.shape = }")
print(f"{first_frame.dtype = }")
print(f"{every_twenty_frame.shape = }")
print(f"{every_twenty_frame.dtype = }")
first_frame.shape = torch.Size([3, 360, 640])
first_frame.dtype = torch.uint8
every_twenty_frame.shape = torch.Size([18, 3, 360, 640])
every_twenty_frame.dtype = torch.uint8
plot(first_frame, "First frame")
plot(every_twenty_frame, "Every 20 frames")
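
To make the two notes above concrete, here is a small sketch combining them. It assumes that dimension_order accepts the string "NHWC" and that get_frames_in_range takes (start, stop, step) arguments analogous to a slice:

# Sketch only: decode the same frames as decoder[0:-1:20], but through the
# batch API and with frames laid out as N, H, W, C instead of N, C, H, W.
nhwc_decoder = VideoDecoder(raw_video_bytes, dimension_order="NHWC")  # "NHWC" value is an assumption
batch = nhwc_decoder.get_frames_in_range(0, len(nhwc_decoder) - 1, 20)
print(batch.data.shape)  # expected: torch.Size([18, 360, 640, 3])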

Iterating over frames#

The decoder is an iterable object and can be iterated over directly:

for frame in decoder:
    assert (
        isinstance(frame, torch.Tensor)
        and frame.shape == (3, decoder.metadata.height, decoder.metadata.width)
    )
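
Iteration yields every frame in order, so as a quick sanity check (a sketch, not part of the original example) the number of frames seen should match both len(decoder) and the metadata:

num_iterated = sum(1 for _ in decoder)
assert num_iterated == len(decoder) == decoder.metadata.num_frames  # 345 for this video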

Retrieving pts and duration of frames#

Indexing returns plain torch.Tensor objects. If you also need the presentation timestamp (pts) and duration of the frames, use VideoDecoder.get_frame_at and VideoDecoder.get_frames_at, which return torchcodec.Frame and torchcodec.FrameBatch objects respectively.

last_frame = decoder.get_frame_at(len(decoder) - 1)
print(f"{type(last_frame) = }")
print(last_frame)
type(last_frame) = <class 'torchcodec._frame.Frame'>
Frame:
  data (shape): torch.Size([3, 360, 640])
  pts_seconds: 13.76
  duration_seconds: 0.04
other_frames = decoder.get_frames_at([10, 0, 50])
print(f"{type(other_frames) = }")
print(other_frames)
type(other_frames) = <class 'torchcodec._frame.FrameBatch'>
FrameBatch:
  data (shape): torch.Size([3, 3, 360, 640])
  pts_seconds: tensor([0.4000, 0.0000, 2.0000], dtype=torch.float64)
  duration_seconds: tensor([0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(last_frame.data, "Last frame")
plot(other_frames.data, "Other frames")

Both Frame and FrameBatch have a data field containing the decoded tensor data, as well as pts_seconds and duration_seconds fields: these are single values for a Frame, and 1-D torch.Tensor objects (one value per frame in the batch) for a FrameBatch.
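
Continuing with the objects decoded above, these fields can be read directly; a short illustrative snippet:

print(last_frame.pts_seconds, last_frame.duration_seconds)  # single values, e.g. 13.76 0.04
print(other_frames.pts_seconds)   # 1-D tensor with one pts per frame in the batch
print(other_frames.data.shape)    # torch.Size([3, 3, 360, 640]), as in the output above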

Retrieving frames based on playback time#

So far we have retrieved frames by their index. We can also retrieve frames based on when they are played, using VideoDecoder.get_frame_played_at and VideoDecoder.get_frames_played_at, which again return a Frame and a FrameBatch respectively.

frame_at_2_seconds = decoder.get_frame_played_at(seconds=2)
print(f"{type(frame_at_2_seconds) = }")
print(frame_at_2_seconds)
type(frame_at_2_seconds) = <class 'torchcodec._frame.Frame'>
Frame:
  data (shape): torch.Size([3, 360, 640])
  pts_seconds: 2.0
  duration_seconds: 0.04
other_frames = decoder.get_frames_played_at(seconds=[10.1, 0.3, 5])
print(f"{type(other_frames) = }")
print(other_frames)
type(other_frames) = <class 'torchcodec._frame.FrameBatch'>
FrameBatch:
  data (shape): torch.Size([3, 3, 360, 640])
  pts_seconds: tensor([10.0800,  0.2800,  5.0000], dtype=torch.float64)
  duration_seconds: tensor([0.0400, 0.0400, 0.0400], dtype=torch.float64)
plot(frame_at_2_seconds.data, "Frame played at 2 seconds")
plot(other_frames.data, "Other frames")
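
The one batch method from the earlier list that we have not used yet is VideoDecoder.get_frames_played_in_range, which retrieves all frames played within a time interval. A minimal sketch, assuming (start_seconds, stop_seconds) parameters and a half-open interval:

# Sketch only: every frame played between 2 s (inclusive) and 3 s (exclusive).
window = decoder.get_frames_played_in_range(start_seconds=2.0, stop_seconds=3.0)
print(window.data.shape)  # expected: torch.Size([25, 3, 360, 640]) at 25 fps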