news 2026/6/24 0:27:32

tensorflow 零基础吃透:tf.data 中 RaggedTensor 的核心用法(数据集流水线)

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
tensorflow 零基础吃透:tf.data 中 RaggedTensor 的核心用法(数据集流水线)

零基础吃透:tf.data中RaggedTensor的核心用法(数据集流水线)

这份内容会拆解tf.data.Dataset与 RaggedTensor 结合的四大核心场景——构建数据集、批处理/取消批处理、非规则张量转Ragged批处理、数据集转换,全程用「通俗解释+代码拆解+原理+结果解读」,帮你理解“可变长度数据”在TF输入流水线中的最优处理方式。

核心背景(先理清)

tf.data.Dataset是TensorFlow的输入流水线核心工具,负责数据加载、预处理、批处理、迭代等全流程;RaggedTensor 则是处理“可变长度数据”的原生类型,两者结合能完美解决“可变长度数据”在流水线中的处理问题(无需补0,保留原生结构)。

先补全示例依赖(确保代码可运行):

importtensorflowastfimportgoogle.protobuf.text_formataspbtext# 先重建之前的feature_tensors(tf.Example解析后的RaggedTensor字典)defbuild_tf_example(s):returnpbtext.Merge(s,tf.train.Example()).SerializeToString()example_batch=[build_tf_example(r'''features {feature {key: "colors" value {bytes_list {value: ["red", "blue"]} } } feature {key: "lengths" value {int64_list {value: [7]} } } }'''),build_tf_example(r'''features {feature {key: "colors" value {bytes_list {value: ["orange"]} } } feature {key: "lengths" value {int64_list {value: []} } } }'''),build_tf_example(r'''features {feature {key: "colors" value {bytes_list {value: ["black", "yellow"]} } } feature {key: "lengths" value {int64_list {value: [1, 3]} } } }'''),build_tf_example(r'''features {feature {key: "colors" value {bytes_list {value: ["green"]} } } feature {key: "lengths" value {int64_list {value: [3, 5, 2]} } } }''')]feature_specification={'colors':tf.io.RaggedFeature(tf.string),'lengths':tf.io.RaggedFeature(tf.int64),}feature_tensors=tf.io.parse_example(example_batch,feature_specification)# 文档中的辅助打印函数defprint_dictionary_dataset(dataset):fori,elementinenumerate(dataset):print("Element {}:".format(i))for(feature_name,feature_value)inelement.items():print('{:>14} = {}'.format(feature_name,feature_value))

场景1:使用RaggedTensor构建数据集

核心逻辑

tf.data.Dataset.from_tensor_slices是构建数据集的核心方法,对RaggedTensor的支持和普通张量完全一致——按“第一个维度(样本维度)”切分,每个元素对应一个样本的RaggedTensor(保留原始可变长度)。

代码+解析

# 从RaggedTensor字典构建数据集(feature_tensors是{colors: RaggedTensor, lengths: RaggedTensor})dataset=tf.data.Dataset.from_tensor_slices(feature_tensors)# 打印数据集元素print("=== 构建的RaggedTensor数据集 ===")print_dictionary_dataset(dataset)

运行结果+解读

Element 0: colors = [b'red' b'blue'] lengths = [7] Element 1: colors = [b'orange'] lengths = [] Element 2: colors = [b'black' b'yellow'] lengths = [1 3] Element 3: colors = [b'green'] lengths = [3 5 2]
  • 每个Element对应一个样本,colors/lengths保留该样本的原始长度(比如样本1的lengths为空列表,样本3的lengths有3个元素);
  • 对比普通张量:如果是补0的密集张量,样本1的lengths会是[0,0,0](补到最长长度),而RaggedTensor无冗余。

关键原理

from_tensor_slices对RaggedTensor的切分规则:

  • 只切分最外层的均匀维度(样本维度),内层的不规则维度保持不变;
  • 比如feature_tensors['lengths']是形状[4, None]的RaggedTensor,切分后每个元素是形状[None]的RaggedTensor(单个样本的长度列表)。

场景2:批处理/取消批处理RaggedTensor数据集

2.1 批处理(Dataset.batch)

核心逻辑

Dataset.batch(n)n个连续样本合并成一个批次,批次内的RaggedTensor会自动合并为更高维的RaggedTensor(批次维度是均匀的,内部维度仍不规则)。

代码+解析
# 按2个样本为一批进行批处理batched_dataset=dataset.batch(2)print("\n=== 批处理后的RaggedTensor数据集(batch=2) ===")print_dictionary_dataset(batched_dataset)
运行结果+解读
Element 0: colors = <tf.RaggedTensor [[b'red', b'blue'], [b'orange']]> lengths = <tf.RaggedTensor [[7], []]> Element 1: colors = <tf.RaggedTensor [[b'black', b'yellow'], [b'green']]> lengths = <tf.RaggedTensor [[1, 3], [3, 5, 2]]>
  • 每个Element是一个批次(2个样本),colors/lengths变成二维RaggedTensor(第一维是批次内的样本索引,第二维是样本内的元素);
  • 对比密集张量批处理:无需补0到“批次内最长长度”,比如批次0的colors中,第一个样本2个元素、第二个样本1个元素,直接保留原始长度。

2.2 取消批处理(Dataset.unbatch)

核心逻辑

Dataset.unbatch()把批处理后的数据集拆回“单个样本”的形式,完全恢复批处理前的结构。

代码+解析
# 取消批处理unbatched_dataset=batched_dataset.unbatch()print("\n=== 取消批处理后的数据集 ===")print_dictionary_dataset(unbatched_dataset)
运行结果

和场景1的原始数据集完全一致(4个单个样本,保留原始长度)。

关键对比(Ragged批处理 vs 密集张量批处理)

方式特点冗余性
RaggedTensor.batch合并为高维RaggedTensor,保留原始长度无冗余
密集张量.batch补0到批次内最长长度,生成固定形状张量有冗余

场景3:非Ragged张量(可变长度)转Ragged批处理

核心场景

如果数据集的元素是长度不同的密集张量(不是RaggedTensor),直接用batch会报错(长度不匹配),此时用dense_to_ragged_batch把每个批次转成RaggedTensor,避免补0。

代码+解析

# 步骤1:构建“长度不同的密集张量”数据集# 原始数据:[1,5,3,2,8] → 每个元素用tf.range生成长度不同的密集张量non_ragged_dataset=tf.data.Dataset.from_tensor_slices([1,5,3,2,8])non_ragged_dataset=non_ragged_dataset.map(tf.range)# 映射后:[0], [0,1,2,3,4], [0,1,2], [0,1], [0-7]# 步骤2:用dense_to_ragged_batch批处理(每2个样本为一批,转成RaggedTensor)batched_non_ragged_dataset=non_ragged_dataset.apply(tf.data.experimental.dense_to_ragged_batch(2))# 打印结果print("\n=== 非Ragged张量转Ragged批处理 ===")forelementinbatched_non_ragged_dataset:print(element)

运行结果+解读

<tf.RaggedTensor [[0], [0, 1, 2, 3, 4]]> <tf.RaggedTensor [[0, 1, 2], [0, 1]]> <tf.RaggedTensor [[0, 1, 2, 3, 4, 5, 6, 7]]>
  • 第一批:2个样本[0][0,1,2,3,4]→ 合并为二维RaggedTensor;
  • 第二批:2个样本[0,1,2][0,1]→ 合并为二维RaggedTensor;
  • 第三批:只剩1个样本[0-7]→ 一维RaggedTensor;
  • 核心价值:不用补0,直接按原始长度合并为RaggedTensor,解决“长度不同的密集张量无法直接batch”的问题。

关键原理

tf.data.experimental.dense_to_ragged_batch(n)

  • 每次取n个长度不同的密集张量;
  • 自动将其转换为一个n行的RaggedTensor(每行对应一个样本的原始长度);
  • 替代方案:如果不用这个方法,需要先把每个元素转成RaggedTensor,再batch,步骤更繁琐。

场景4:转换RaggedTensor数据集(Dataset.map)

核心逻辑

Dataset.map可以对数据集中的每个元素(RaggedTensor)进行任意转换(比如计算均值、生成新的RaggedTensor),TF原生支持RaggedTensor的运算。

代码+解析

# 定义转换函数:处理每个样本的features字典deftransform_lengths(features):return{# 计算lengths的均值(空列表的均值为0)'mean_length':tf.math.reduce_mean(features['lengths']),# 对lengths中的每个值,生成0到该值-1的序列(返回RaggedTensor)'length_ranges':tf.ragged.range(features['lengths'])}# 应用转换transformed_dataset=dataset.map(transform_lengths)# 打印结果print("\n=== 转换后的RaggedTensor数据集 ===")print_dictionary_dataset(transformed_dataset)

运行结果+解读

Element 0: mean_length = 7 length_ranges = <tf.RaggedTensor [[0, 1, 2, 3, 4, 5, 6]]> Element 1: mean_length = 0 length_ranges = <tf.RaggedTensor []> Element 2: mean_length = 2 length_ranges = <tf.RaggedTensor [[0], [0, 1, 2]]> Element 3: mean_length = 3 length_ranges = <tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [0, 1]]>
关键转换逻辑解读
  1. tf.math.reduce_mean(features['lengths'])

    • 样本0的lengths=[7] → 均值=7;
    • 样本1的lengths=[] → 均值=0(TF对空RaggedTensor的reduce_mean默认返回0);
    • 样本2的lengths=[1,3] → 均值=(1+3)/2=2;
    • 样本3的lengths=[3,5,2] → 均值=(3+5+2)/3=3。
  2. tf.ragged.range(features['lengths'])

    • 对lengths中的每个数值L,生成[0,1,...,L-1]的序列;
    • 样本0的lengths=[7] → 生成[0-6]→ RaggedTensor[[0,1,2,3,4,5,6]]
    • 样本2的lengths=[1,3] → 生成[0][0,1,2]→ RaggedTensor[[0], [0,1,2]]

关键优势

Dataset.map处理RaggedTensor时:

  • 无需转换为密集张量,直接运算;
  • 所有TF内置运算(reduce_mean/range/concat等)都原生支持RaggedTensor;
  • 转换后的结果仍可保留RaggedTensor结构,无缝接入后续流水线。

核心总结(tf.data+RaggedTensor关键要点)

操作核心价值关键API
构建数据集直接切分RaggedTensor,保留样本原始长度tf.data.Dataset.from_tensor_slices
批处理合并为高维RaggedTensor,无冗余补0Dataset.batch(n)
取消批处理恢复单个样本的RaggedTensor结构Dataset.unbatch()
非Ragged批处理解决长度不同的密集张量无法batch的问题tf.data.experimental.dense_to_ragged_batch
数据集转换原生支持RaggedTensor运算,无需转密集张量Dataset.map(转换函数)

避坑关键

  1. Dataset.from_generator暂不支持RaggedTensor(文档提示后续会支持),如需生成器构建,需先将RaggedTensor转成密集张量+Mask;
  2. 批处理后的RaggedTensor可直接传入Keras模型(需Input层设置ragged=True);
  3. 所有RaggedTensor的运算都遵循“只处理有效元素”的规则,无冗余计算。

这套组合是TF处理“可变长度数据”(文本、序列特征等)的最优流水线方案,既保证数据结构的原生性,又兼顾流水线的高效性。

版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/23 19:06:09

为啥网站跳转重定向是307 而不是 301 呢?

文章目录为啥网站跳转重定向是307 而不是 301 呢&#xff1f;为什么出现307 状态码呢&#xff1f;一 HSTS 是什么&#xff1f;二 HSTS 如何生效&#xff1f;三、Chrome 浏览器如何支持 HSTS&#xff1f;四、注意事项五 总结六 Chrome 博客 default for navigation https七 解释…

作者头像 李华
网站建设 2026/6/23 19:06:19

Zabbix监控模板实战指南:从零构建企业级监控体系

在当今数字化时代&#xff0c;企业IT系统的稳定运行至关重要。Zabbix作为一款功能强大的开源监控解决方案&#xff0c;其丰富的社区模板库为各类设备和应用提供了即插即用的监控能力。无论你是刚接触Zabbix的新手&#xff0c;还是希望优化现有监控体系的管理员&#xff0c;本文…

作者头像 李华
网站建设 2026/6/23 11:52:36

RulersGuides.js:网页设计中的Photoshop式标尺与辅助线终极指南

RulersGuides.js&#xff1a;网页设计中的Photoshop式标尺与辅助线终极指南 【免费下载链接】RulersGuides.js Creates Photoshop-like guides and rulers interface on a web page 项目地址: https://gitcode.com/gh_mirrors/ru/RulersGuides.js 你是否曾经在网页设计时…

作者头像 李华
网站建设 2026/6/23 19:15:54

如何快速掌握MagicEdit:高保真视频编辑的终极指南

如何快速掌握MagicEdit&#xff1a;高保真视频编辑的终极指南 【免费下载链接】magic-edit MagicEdit - 一个高保真和时间连贯的视频编辑工具&#xff0c;支持视频风格化、局部编辑、视频混合和视频外绘等应用。 项目地址: https://gitcode.com/gh_mirrors/ma/magic-edit …

作者头像 李华
网站建设 2026/6/23 10:37:41

基于STM32的辅助病床智慧监护系统设计(有完整资料)

资料查找方式&#xff1a;特纳斯电子&#xff08;电子校园网&#xff09;&#xff1a;搜索下面编号即可编号&#xff1a;T4102310M设计简介&#xff1a;以STM32单片机为核心&#xff0c;结合体温、血氧、心率等生理特征参数的监测&#xff0c;并可按需设定点滴时间定时参数&…

作者头像 李华
网站建设 2026/6/23 19:06:38

AI音频分离技术深度解析:Ultimate Vocal Remover的多轨处理革命

AI音频分离技术深度解析&#xff1a;Ultimate Vocal Remover的多轨处理革命 【免费下载链接】ultimatevocalremovergui 使用深度神经网络的声音消除器的图形用户界面。 项目地址: https://gitcode.com/GitHub_Trending/ul/ultimatevocalremovergui 在数字音频处理领域&a…

作者头像 李华