codecamp

讯飞星辰 MaaS RL 奖励函数 - 强化学习奖励计算指南

1. 字符串相等匹配奖励函数

核心逻辑:完全匹配给满分,关键词重叠达标给半额奖励,无匹配给0分,适用于标准答案固定的生成任务打分。

## 导入正则表达式库,用于文本分词和关键词提取
import re


def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """
    强化学习奖励得分计算主函数(字符串相等匹配逻辑)
    :param data_source: 数据来源标识,用于区分不同任务/数据集,本逻辑暂未使用该参数
    :param solution: 模型生成的回答文本(待打分内容)
    :param ground_truth: 标准答案文本(基准参考)
    :param extra_info: 额外扩展信息字典,可用于传入自定义阈值、规则等配置,本逻辑暂未使用该参数
    :return: 0.0-1.0之间的奖励得分,分值越高匹配度越好
    """
    # 规则1:模型输出与标准答案完全一致,给满分1.0
    if solution == ground_truth:
        reward = 1.0  # Exact match
    # 规则2:未完全匹配,但关键词重叠率达标,给半额奖励0.5
    elif has_keyword_overlap(solution, ground_truth):
        reward = 0.5  # Partial match based on keyword overlap
    # 规则3:无有效匹配,给0分
    else:
        reward = 0.0  # No meaningful match
    # 返回最终奖励得分
    return reward




def has_keyword_overlap(text1: str, text2: str) -> bool:
    """
    计算两个文本的关键词重叠率,判断是否达到有效匹配阈值
    :param text1: 待比对文本1(模型输出)
    :param text2: 待比对文本2(标准答案)
    :return: 重叠率是否超过预设阈值,返回布尔值
    """
    # 正则提取文本中的单词/关键词,转为小写,去重后生成集合
    # \b\w+\b 正则规则:匹配完整的单词边界,避免拆分单词
    keywords1 = set(re.findall(r'\b\w+\b', text1.lower()))
    keywords2 = set(re.findall(r'\b\w+\b', text2.lower()))

    
    # 计算两个关键词集合的交集(共同出现的关键词)
    overlap = keywords1.intersection(keywords2)
    # 计算重叠率:交集数量 / 两个集合中更大的关键词数量,避免短文本虚高
    overlap_ratio = len(overlap) / max(len(keywords1), len(keywords2))

    
    # 判断重叠率是否超过30%的阈值,阈值可根据业务场景调整
    return overlap_ratio > 0.3  # Threshold can be adjusted

2. 字符串包含匹配奖励函数

核心逻辑:基于Jaccard(杰卡德)相似度计算关键词交集与并集的占比,输出0-1的连续分值,适用于关键词命中类的生成任务打分。

## 导入正则表达式库,用于文本分词
import re
## 补全原文缺失的类型注解导入,用于标注列表返回值类型
from typing import List




def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """
    强化学习奖励得分计算主函数(字符串包含匹配逻辑)
    :param data_source: 数据来源标识,本逻辑暂未使用该参数
    :param solution: 模型生成的回答文本(待打分内容)
    :param ground_truth: 标准答案文本(基准参考)
    :param extra_info: 额外扩展信息字典,本逻辑暂未使用该参数
    :return: 0.0-1.0之间的杰卡德相似度得分,分值越高关键词重合度越高
    """
    # 对模型输出和标准答案分别分词,转为去重的集合
    set_solution = set(tokenize(solution))
    set_ground_truth = set(tokenize(ground_truth))


    # 边界处理:两个文本都无有效关键词,视为完全匹配,相似度1.0
    if not set_solution and not set_ground_truth:
        similarity = 1.0
    else:
        # 【修正原文bug】原文直接传入ground_truth字符串,改为传入分词后的标准答案集合
        # 计算两个集合的交集:同时出现在模型输出和标准答案中的关键词
        intersection = set_solution.intersection(set_ground_truth)
        # 计算两个集合的并集:模型输出和标准答案中所有不重复的关键词
        union = set_solution.union(set_ground_truth)
        # 杰卡德相似度核心公式:交集数量 / 并集数量
        similarity = len(intersection) / len(union)

    
    # 数值截断:确保相似度最终落在0.0-1.0的合法区间,避免异常值
    similarity = max(0.0, min(similarity, 1.0))
    # 返回最终相似度(奖励得分)
    return similarity




def tokenize(text: str) -> List[str]:
    """
    文本分词函数:提取文本中的有效关键词,统一转为小写
    :param text: 待分词的原始文本
    :return: 分词后的关键词列表
    """
    # 正则匹配完整单词,转为小写,返回分词列表
    return re.findall(r'\b\w+\b', text.lower())

3. 字符串相似度比较奖励函数

核心逻辑:基于Levenshtein(莱文斯坦/编辑距离)计算文本差异,归一化后输出0-1的连续分值,适用于文本整体语义相似度、句式匹配类的生成任务打分。

def compute_score(data_source: str, solution: str, ground_truth: str, extra_info: dict) -> float:
    """
    强化学习奖励得分计算主函数(编辑距离相似度逻辑)
    :param data_source: 数据来源标识,本逻辑暂未使用该参数
    :param solution: 模型生成的回答文本(待打分内容)
    :param ground_truth: 标准答案文本(基准参考)
    :param extra_info: 额外扩展信息字典,本逻辑暂未使用该参数
    :return: 0.0-1.0之间的文本相似度得分,分值越高文本差异越小
    """
    # 计算两个文本的编辑距离:把一个文本转为另一个文本所需的最少单字符操作次数
    distance = levenshtein_distance(solution, ground_truth)
    # 取两个文本的最大长度,用于归一化处理
    max_len = max(len(solution), len(ground_truth))


    # 边界处理:两个文本都是空字符串,视为完全匹配,相似度1.0
    if max_len == 0:
        similarity = 1.0
    else:
        # 核心公式:归一化相似度 = 1 - (编辑距离/文本最大长度)
        # 编辑距离越小,相似度越接近1.0
        similarity = 1 - (distance / max_len)

    
    # 数值截断:确保相似度最终落在0.0-1.0的合法区间
    similarity = max(0.0, min(similarity, 1.0))
    # 返回最终相似度(奖励得分)
    return similarity




def levenshtein_distance(s1: str, s2: str) -> int:
    """
    莱文斯坦距离(编辑距离)计算函数,基于动态规划实现
    编辑距离定义:将字符串s1转为s2所需的最少操作次数,操作包含「插入、删除、替换」单个字符
    :param s1: 待比对字符串1(模型输出)
    :param s2: 待比对字符串2(标准答案)
    :return: 两个字符串的编辑距离,数值越大差异越大
    """
    # 优化:始终让长字符串在前,短字符串在后,减少循环次数,提升计算效率
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)


    # 初始化动态规划上一行:代表s1为空时,转为s2前j个字符所需的插入次数
    previous_row = list(range(len(s2) + 1))


    # 遍历s1的每一个字符,逐行计算动态规划表
    for i, c1 in enumerate(s1):
        # 初始化当前行的第一个元素:代表s2为空时,转为s1前i个字符所需的删除次数
        current_row = [i + 1]
        # 遍历s2的每一个字符,计算当前位置的最小操作次数
        for j, c2 in enumerate(s2):
            # 插入操作:上一行当前列的数值+1(给s1插入一个字符匹配s2的当前字符)
            insertions = previous_row[j + 1] + 1  # insertion
            # 删除操作:当前行上一列的数值+1(删除s1的当前字符)
            deletions = current_row[j] + 1        # deletion
            # 替换操作:上一行上一列的数值 + 字符是否不同(相同则0成本,不同则1成本)
            substitutions = previous_row[j] + (c1 != c2)  # substitution
            # 取三种操作的最小成本,加入当前行
            current_row.append(min(insertions, deletions, substitutions))
        # 更新上一行为当前行,进入下一轮循环
        previous_row = current_row


    # 动态规划表的最后一个元素,就是两个字符串的最小编辑距离
    return previous_row[-1]
讯飞星辰 MaaS 用户使用须知 - 平台使用规则与注意事项指南
讯飞星辰 MaaS Linux 推送模型示例 - 模型部署与推送指南
温馨提示
下载编程狮App,免费阅读超1000+编程语言教程
取消
确定
目录

关闭

MIP.setData({ 'pageTheme' : getCookie('pageTheme') || {'day':true, 'night':false}, 'pageFontSize' : getCookie('pageFontSize') || 20 }); MIP.watch('pageTheme', function(newValue){ setCookie('pageTheme', JSON.stringify(newValue)) }); MIP.watch('pageFontSize', function(newValue){ setCookie('pageFontSize', newValue) }); function setCookie(name, value){ var days = 1; var exp = new Date(); exp.setTime(exp.getTime() + days*24*60*60*1000); document.cookie = name + '=' + value + ';expires=' + exp.toUTCString(); } function getCookie(name){ var reg = new RegExp('(^| )' + name + '=([^;]*)(;|$)'); return document.cookie.match(reg) ? JSON.parse(document.cookie.match(reg)[2]) : null; }