讯飞星辰 MaaS RL 奖励函数 - 强化学习奖励计算指南
1. 字符串相等匹配奖励函数
核心逻辑:完全匹配给满分,关键词重叠达标给半额奖励,无匹配给0分,适用于标准答案固定的生成任务打分。
## 导入正则表达式库,用于文本分词和关键词提取
import re
def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
"""
强化学习奖励得分计算主函数(字符串相等匹配逻辑)
:param data_source: 数据来源标识,用于区分不同任务/数据集,本逻辑暂未使用该参数
:param solution: 模型生成的回答文本(待打分内容)
:param ground_truth: 标准答案文本(基准参考)
:param extra_info: 额外扩展信息字典,可用于传入自定义阈值、规则等配置,本逻辑暂未使用该参数
:return: 0.0-1.0之间的奖励得分,分值越高匹配度越好
"""
# 规则1:模型输出与标准答案完全一致,给满分1.0
if solution == ground_truth:
reward = 1.0 # Exact match
# 规则2:未完全匹配,但关键词重叠率达标,给半额奖励0.5
elif has_keyword_overlap(solution, ground_truth):
reward = 0.5 # Partial match based on keyword overlap
# 规则3:无有效匹配,给0分
else:
reward = 0.0 # No meaningful match
# 返回最终奖励得分
return reward
def has_keyword_overlap(text1: str, text2: str) -> bool:
"""
计算两个文本的关键词重叠率,判断是否达到有效匹配阈值
:param text1: 待比对文本1(模型输出)
:param text2: 待比对文本2(标准答案)
:return: 重叠率是否超过预设阈值,返回布尔值
"""
# 正则提取文本中的单词/关键词,转为小写,去重后生成集合
# \b\w+\b 正则规则:匹配完整的单词边界,避免拆分单词
keywords1 = set(re.findall(r'\b\w+\b', text1.lower()))
keywords2 = set(re.findall(r'\b\w+\b', text2.lower()))
# 计算两个关键词集合的交集(共同出现的关键词)
overlap = keywords1.intersection(keywords2)
# 计算重叠率:交集数量 / 两个集合中更大的关键词数量,避免短文本虚高
overlap_ratio = len(overlap) / max(len(keywords1), len(keywords2))
# 判断重叠率是否超过30%的阈值,阈值可根据业务场景调整
return overlap_ratio > 0.3 # Threshold can be adjusted
2. 字符串包含匹配奖励函数
核心逻辑:基于Jaccard(杰卡德)相似度计算关键词交集与并集的占比,输出0-1的连续分值,适用于关键词命中类的生成任务打分。
## 导入正则表达式库,用于文本分词
import re
## 补全原文缺失的类型注解导入,用于标注列表返回值类型
from typing import List
def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
"""
强化学习奖励得分计算主函数(字符串包含匹配逻辑)
:param data_source: 数据来源标识,本逻辑暂未使用该参数
:param solution: 模型生成的回答文本(待打分内容)
:param ground_truth: 标准答案文本(基准参考)
:param extra_info: 额外扩展信息字典,本逻辑暂未使用该参数
:return: 0.0-1.0之间的杰卡德相似度得分,分值越高关键词重合度越高
"""
# 对模型输出和标准答案分别分词,转为去重的集合
set_solution = set(tokenize(solution))
set_ground_truth = set(tokenize(ground_truth))
# 边界处理:两个文本都无有效关键词,视为完全匹配,相似度1.0
if not set_solution and not set_ground_truth:
similarity = 1.0
else:
# 【修正原文bug】原文直接传入ground_truth字符串,改为传入分词后的标准答案集合
# 计算两个集合的交集:同时出现在模型输出和标准答案中的关键词
intersection = set_solution.intersection(set_ground_truth)
# 计算两个集合的并集:模型输出和标准答案中所有不重复的关键词
union = set_solution.union(set_ground_truth)
# 杰卡德相似度核心公式:交集数量 / 并集数量
similarity = len(intersection) / len(union)
# 数值截断:确保相似度最终落在0.0-1.0的合法区间,避免异常值
similarity = max(0.0, min(similarity, 1.0))
# 返回最终相似度(奖励得分)
return similarity
def tokenize(text: str) -> List[str]:
"""
文本分词函数:提取文本中的有效关键词,统一转为小写
:param text: 待分词的原始文本
:return: 分词后的关键词列表
"""
# 正则匹配完整单词,转为小写,返回分词列表
return re.findall(r'\b\w+\b', text.lower())
3. 字符串相似度比较奖励函数
核心逻辑:基于Levenshtein(莱文斯坦/编辑距离)计算文本差异,归一化后输出0-1的连续分值,适用于文本整体语义相似度、句式匹配类的生成任务打分。
def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
"""
强化学习奖励得分计算主函数(编辑距离相似度逻辑)
:param data_source: 数据来源标识,本逻辑暂未使用该参数
:param solution: 模型生成的回答文本(待打分内容)
:param ground_truth: 标准答案文本(基准参考)
:param extra_info: 额外扩展信息字典,本逻辑暂未使用该参数
:return: 0.0-1.0之间的文本相似度得分,分值越高文本差异越小
"""
# 计算两个文本的编辑距离:把一个文本转为另一个文本所需的最少单字符操作次数
distance = levenshtein_distance(solution, ground_truth)
# 取两个文本的最大长度,用于归一化处理
max_len = max(len(solution), len(ground_truth))
# 边界处理:两个文本都是空字符串,视为完全匹配,相似度1.0
if max_len == 0:
similarity = 1.0
else:
# 核心公式:归一化相似度 = 1 - (编辑距离/文本最大长度)
# 编辑距离越小,相似度越接近1.0
similarity = 1 - (distance / max_len)
# 数值截断:确保相似度最终落在0.0-1.0的合法区间
similarity = max(0.0, min(similarity, 1.0))
# 返回最终相似度(奖励得分)
return similarity
def levenshtein_distance(s1: str, s2: str) -> int:
"""
莱文斯坦距离(编辑距离)计算函数,基于动态规划实现
编辑距离定义:将字符串s1转为s2所需的最少操作次数,操作包含「插入、删除、替换」单个字符
:param s1: 待比对字符串1(模型输出)
:param s2: 待比对字符串2(标准答案)
:return: 两个字符串的编辑距离,数值越大差异越大
"""
# 优化:始终让长字符串在前,短字符串在后,减少循环次数,提升计算效率
if len(s1) < len(s2):
return levenshtein_distance(s2, s1)
# 初始化动态规划上一行:代表s1为空时,转为s2前j个字符所需的插入次数
previous_row = list(range(len(s2) + 1))
# 遍历s1的每一个字符,逐行计算动态规划表
for i, c1 in enumerate(s1):
# 初始化当前行的第一个元素:代表s2为空时,转为s1前i个字符所需的删除次数
current_row = [i + 1]
# 遍历s2的每一个字符,计算当前位置的最小操作次数
for j, c2 in enumerate(s2):
# 插入操作:上一行当前列的数值+1(给s1插入一个字符匹配s2的当前字符)
insertions = previous_row[j + 1] + 1 # insertion
# 删除操作:当前行上一列的数值+1(删除s1的当前字符)
deletions = current_row[j] + 1 # deletion
# 替换操作:上一行上一列的数值 + 字符是否不同(相同则0成本,不同则1成本)
substitutions = previous_row[j] + (c1 != c2) # substitution
# 取三种操作的最小成本,加入当前行
current_row.append(min(insertions, deletions, substitutions))
# 更新上一行为当前行,进入下一轮循环
previous_row = current_row
# 动态规划表的最后一个元素,就是两个字符串的最小编辑距离
return previous_row[-1]