news 2026/6/23 18:27:58

tensorflow 零基础吃透:tf.function 与 RaggedTensor 的结合使用

作者头像

张小明

前端开发工程师

1.2k 24
文章封面图
tensorflow 零基础吃透:tf.function 与 RaggedTensor 的结合使用

零基础吃透:tf.function与RaggedTensor的结合使用

核心背景(先理清)

  • tf.function:TensorFlow的核心装饰器,能把Python函数编译成TensorFlow计算图(而非逐行执行的Eager模式),大幅提升代码执行效率(尤其是重复调用/部署场景);
  • 关键特性:RaggedTensor可透明兼容tf.function——无需修改函数逻辑,同时支持密集张量(普通tf.Tensor)和RaggedTensor输入,TF会自动适配计算图。

先准备基础运行环境:

importtensorflowastfprint(f"TensorFlow版本:{tf.__version__}")# 建议2.3+,具体函数需此版本支持

场景1:tf.function对RaggedTensor的“透明支持”(无需改代码)

核心逻辑

@tf.function装饰的函数,对密集张量和RaggedTensor的处理逻辑完全一致——TF会自动识别输入类型,调用适配RaggedTensor的算子(如tf.concat有专门的Ragged处理逻辑),无需额外修改代码。

代码+逐行解析

# 1. 定义编译成计算图的函数(生成回文序列)@tf.function# 核心装饰器:转计算图defmake_palindrome(x,axis):# 逻辑:拼接原张量 + 反转后的张量(生成回文)reversed_x=tf.reverse(x,[axis])# 反转张量(支持Ragged)returntf.concat([x,reversed_x],axis)# 拼接(支持Ragged)# 2. 测试1:传入密集张量(普通tf.Tensor)dense_x=tf.constant([[1,2],[3,4],[5,6]])dense_result=make_palindrome(dense_x,axis=1)print("=== 密集张量执行结果 ===")print(dense_result)# 3. 测试2:传入RaggedTensor(无需改函数)ragged_x=tf.ragged.constant([[1,2],[3],[4,5,6]])ragged_result=make_palindrome(ragged_x,axis=1)print("\n=== RaggedTensor执行结果 ===")print(ragged_result)

运行结果+解读

=== 密集张量执行结果 === tf.Tensor( [[1 2 2 1] [3 4 4 3] [5 6 6 5]], shape=(3, 4), dtype=int32) === RaggedTensor执行结果 === 2022-12-14 22:26:12.602591: W tensorflow/core/grappler/optimizers/loop_optimizer.cc:907] Skipping loop optimization for Merge node with control input: RaggedConcat/assert_equal_1/Assert/AssertGuard/branch_executed/_9 <tf.RaggedTensor [[1, 2, 2, 1], [3, 3], [4, 5, 6, 6, 5, 4]]>
关键解读
  1. 函数逻辑通用:

    • 密集张量:每行[1,2]反转后[2,1],拼接成[1,2,2,1]
    • RaggedTensor:每行[3]反转后[3],拼接成[3,3][4,5,6]反转后[6,5,4],拼接成[4,5,6,6,5,4]——完全符合回文逻辑,无需改代码。
  2. 警告说明(非错误!):

    • 警告内容:Skipping loop optimization for Merge node...
    • 原因:TF的Grappler优化器(计算图优化工具)对RaggedTensor的复杂节点跳过了循环优化(Ragged的行长度不规则,部分优化不适用);
    • 影响:仅跳过优化,不影响计算结果和功能,可直接忽略。

核心原理

tf.function对RaggedTensor的“透明支持”:

  • TF会自动识别输入是RaggedTensor,调用Ragged版本的算子(如tf.concat内部会判断输入类型,选择密集/Ragged拼接逻辑);
  • 计算图会保留RaggedTensor的“行分区规则”(记录每行长度),保证运算结果符合可变长度的逻辑。

场景2:为tf.function指定input_signature(RaggedTensorSpec)

核心背景

input_signaturetf.function的参数,作用是限定输入的类型/形状

  • 提升性能:避免tf.function为不同输入类型/形状重复生成计算图;
  • 部署安全:明确输入规范,防止传入不兼容的输入;
  • 针对RaggedTensor:需用tf.RaggedTensorSpec替代普通的tf.TensorSpec

代码+解析

# 装饰器:指定input_signature为RaggedTensorSpec(限定输入规范)@tf.function(# 输入签名:二维RaggedTensor,shape=[None, None](两个维度都可变),dtype=int32input_signature=[tf.RaggedTensorSpec(shape=[None,None],dtype=tf.int32)])defmax_and_min(rt):# 计算最后一维的最大值/最小值(原生支持Ragged)max_vals=tf.math.reduce_max(rt,axis=-1)min_vals=tf.math.reduce_min(rt,axis=-1)return(max_vals,min_vals)# 测试:传入符合签名的RaggedTensorragged_x=tf.ragged.constant([[1,2],[3],[4,5,6]])max_vals,min_vals=max_and_min(ragged_x)print("\n=== 指定input_signature后的执行结果 ===")print("每行最大值:",max_vals)print("每行最小值:",min_vals)

运行结果+解读

=== 指定input_signature后的执行结果 === 每行最大值: tf.Tensor([2 3 6], shape=(3,), dtype=int32) 每行最小值: tf.Tensor([1 3 4], shape=(3,), dtype=int32)
  • 计算逻辑:对每行(最后一维)求最大/最小值,完全适配Ragged的可变长度:
    • 第一行[1,2]→ 最大2、最小1;
    • 第二行[3]→ 最大3、最小3;
    • 第三行[4,5,6]→ 最大6、最小4。

关键API:tf.RaggedTensorSpec

tf.RaggedTensorSpec是描述RaggedTensor的“输入签名类”,核心参数如下:

参数含义
shapeRaggedTensor的形状,None表示可变维度(如[None, None]=二维,两个维度都可变);
均匀维度可指定具体值(如[3, None]=固定3行,每行元素数可变)
dtypeRaggedTensor的元素类型(如tf.int32/tf.string
ragged_rank可选,不规则维度的数量(如ragged_rank=1表示只有最后1个维度是不规则的)
示例:不同的RaggedTensorSpec
# 三维RaggedTensor:固定2个样本,后两维可变,且后两维都是不规则的spec=tf.RaggedTensorSpec(shape=[2,None,None],dtype=tf.int32,ragged_rank=2)print("自定义RaggedTensorSpec:",spec)

场景3:具体函数(Concrete Function)与RaggedTensor

核心背景

  • 具体函数(Concrete Function):tf.function编译后生成的具体计算图实例(绑定了特定输入类型/形状),比普通tf.function更快(无需动态跟踪),是部署的首选;
  • 版本要求:TF 2.3+ 开始原生支持RaggedTensor与具体函数结合,低版本会报错。

代码+解析

# 1. 定义编译成计算图的函数(元素+1)@tf.functiondefincrement(x):returnx+1# 对RaggedTensor的每个元素+1,保留原始结构# 2. 构建RaggedTensorrt=tf.ragged.constant([[1,2],[3],[4,5,6]])# 3. 获取具体函数(绑定RaggedTensor的输入类型/形状)cf=increment.get_concrete_function(rt)# 4. 执行具体函数(性能更高)cf_result=cf(rt)print("\n=== 具体函数执行结果 ===")print(cf_result)

运行结果+解读

=== 具体函数执行结果 === <tf.RaggedTensor [[2, 3], [4], [5, 6, 7]]>
  • 逻辑:对RaggedTensor的每个元素+1,完全保留原始可变长度结构
    • [1,2][2,3][3][4][4,5,6][5,6,7]
  • 优势:具体函数只需编译一次,后续调用直接执行计算图,性能比普通tf.function更高。

版本兼容写法(可选)

若需兼容低版本TF,可加异常捕获:

try:cf=increment.get_concrete_function(rt)print(cf(rt))exceptExceptionase:print(f"TF版本过低不支持:{type(e).__name__}:{e}")

核心总结(tf.function+RaggedTensor关键要点)

场景核心用法关键API/参数
透明支持直接传入RaggedTensor,无需改函数逻辑@tf.function+ 普通TF算子(concat/reduce_max等)
指定输入签名RaggedTensorSpec限定RaggedTensor的形状/类型tf.RaggedTensorSpec(shape, dtype)
具体函数TF2.3+直接调用get_concrete_function(rt)tf.function.get_concrete_function

避坑关键

  1. 警告不是错误:Grappler优化器的跳过警告不影响结果,可忽略;
  2. 版本兼容:具体函数对RaggedTensor的支持从TF2.3开始,低版本需升级;
  3. 算子兼容:所有TF内置算子(reduce_max/concat/range等)都原生支持RaggedTensor,可直接在tf.function中使用;
  4. input_signatureshape:RaggedTensor的不规则维度必须用None表示,均匀维度可指定具体值(如[5, None]=固定5行)。

性能优化建议

  1. 若函数需重复调用同一类型的RaggedTensor,建议指定input_signature,避免重复编译计算图;
  2. 部署时优先使用具体函数(Concrete Function),性能更高;
  3. 避免在tf.function内动态创建RaggedTensor(如tf.ragged.constant),尽量把数据预处理放在函数外。
版权声明: 本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若内容造成侵权/违法违规/事实不符,请联系邮箱:809451989@qq.com进行投诉反馈,一经查实,立即删除!
网站建设 2026/6/22 21:19:28

57、外设总线概述

外设总线概述 即插即用规范 一些新的 ISA 设备板遵循特殊的设计规则,需要特殊的初始化序列,旨在简化附加接口板的安装和配置。这种板卡设计规范称为即插即用(PnP),它包含了一套用于构建和配置无跳线 ISA 设备的繁琐规则集。PnP 设备实现了可重定位的 I/O 区域,PC 的 BIO…

作者头像 李华
网站建设 2026/6/23 7:56:18

60、Linux内核源代码物理布局解析

Linux内核源代码物理布局解析 1. 内核源码目录结构概述 对大量的内核源代码进行结构化组织并非易事,开发者们也未遵循严格的规则。最初 drivers/char 和 drivers/block 的划分如今已效率低下,为满足不同需求,创建了更多的目录。不过,最通用的字符和块设备驱动仍位于 …

作者头像 李华
网站建设 2026/6/23 15:56:50

Google Apps Script OAuth2 库完整指南:轻松实现第三方服务集成

Google Apps Script OAuth2 库完整指南&#xff1a;轻松实现第三方服务集成 【免费下载链接】apps-script-oauth2 An OAuth2 library for Google Apps Script. 项目地址: https://gitcode.com/gh_mirrors/ap/apps-script-oauth2 Google Apps Script OAuth2 库是一个专门…

作者头像 李华
网站建设 2026/6/23 17:37:40

PySceneDetect完整指南:零基础掌握视频智能分割技术

PySceneDetect完整指南&#xff1a;零基础掌握视频智能分割技术 【免费下载链接】PySceneDetect :movie_camera: Python and OpenCV-based scene cut/transition detection program & library. 项目地址: https://gitcode.com/gh_mirrors/py/PySceneDetect PySceneD…

作者头像 李华
网站建设 2026/6/22 23:48:22

24、结合psad和fwsnort保障网络安全

结合psad和fwsnort保障网络安全 1. 网络攻击与响应机制 在网络环境中,我们经常会面临各种攻击。通过 tcpdump 工具可以捕获网络数据包,例如: [iptablesfw]# tcpdump -i eth0 -l -nn port 80 13:32:24.839585 IP 144.202.X.X.59651 > 71.157.X.X.80: S 653660994:65…

作者头像 李华
网站建设 2026/6/23 17:36:17

32、网络攻击欺骗与 fwsnort 脚本详解

网络攻击欺骗与 fwsnort 脚本详解 1. 攻击欺骗技术 在网络安全领域,攻击欺骗是一种重要的测试和攻击手段。通过 snortspoof.pl 脚本,我们可以利用 exploit.rules 文件中描述的规则来发送攻击。例如,使用以下命令进行数据包捕获: [spoofer]# tcpdump -i eth1 -l -n…

作者头像 李华