作者 lixiang

智能助手增加多轮对话流式返回接口

@@ -131,7 +131,8 @@ public class ShiroConfig { @@ -131,7 +131,8 @@ public class ShiroConfig {
131 filterChainDefinitionMap.put("/**/*.wasm", "anon"); 131 filterChainDefinitionMap.put("/**/*.wasm", "anon");
132 132
133 // 在ShiroConfig.java的filterChainDefinitionMap中添加: 133 // 在ShiroConfig.java的filterChainDefinitionMap中添加:
134 - filterChainDefinitionMap.put("/**/airag/zdyRag/send", "anon"); // 精确匹配接口路径 134 + filterChainDefinitionMap.put("/**/airag/zdyRag/sendStream", "anon"); // 精确匹配接口路径
  135 + filterChainDefinitionMap.put("/**/airag/zdyRag/multiStageStream", "anon"); // 精确匹配接口路径
135 filterChainDefinitionMap.put("/public/**", "anon"); // 通配符放行所有公开路径 136 filterChainDefinitionMap.put("/public/**", "anon"); // 通配符放行所有公开路径
136 // 放行按钮接口(如果确实需要公开) 137 // 放行按钮接口(如果确实需要公开)
137 filterChainDefinitionMap.put("/**/airagbutton/airagButton/buttonList", "anon"); 138 filterChainDefinitionMap.put("/**/airagbutton/airagButton/buttonList", "anon");
@@ -110,7 +110,7 @@ public class ZdyRagController { @@ -110,7 +110,7 @@ public class ZdyRagController {
110 } 110 }
111 111
112 // 从知识库搜索 112 // 从知识库搜索
113 - List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 3, 0.75); 113 + List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 2, 0.78);
114 if (CollectionUtil.isEmpty(maps)) { 114 if (CollectionUtil.isEmpty(maps)) {
115 Map<String, String> data = new HashMap<>(); 115 Map<String, String> data = new HashMap<>();
116 data.put("token", "该问题未记录在知识库中"); 116 data.put("token", "该问题未记录在知识库中");
@@ -134,7 +134,7 @@ public class ZdyRagController { @@ -134,7 +134,7 @@ public class ZdyRagController {
134 // 构建知识库内容 134 // 构建知识库内容
135 StringBuilder content = new StringBuilder(); 135 StringBuilder content = new StringBuilder();
136 for (Map<String, Object> map : maps) { 136 for (Map<String, Object> map : maps) {
137 - if (Double.parseDouble(map.get("score").toString()) > 0.75) { 137 + if (Double.parseDouble(map.get("score").toString()) > 0.78) {
138 content.append(map.get("content").toString()).append("\n"); 138 content.append(map.get("content").toString()).append("\n");
139 } 139 }
140 } 140 }
@@ -148,11 +148,14 @@ public class ZdyRagController { @@ -148,11 +148,14 @@ public class ZdyRagController {
148 String questin = "你是一个严谨的信息处理助手,请严格按照以下要求处理用户问题:" + questionText + "\n\n" + 148 String questin = "你是一个严谨的信息处理助手,请严格按照以下要求处理用户问题:" + questionText + "\n\n" +
149 "处理步骤和要求:\n" + 149 "处理步骤和要求:\n" +
150 "1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" + 150 "1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +
151 - "2. 回答结构:\n" + 151 + "2. 严格基于参考内容回答,禁止使用参考内容中与问题无关的内容\n" +
  152 + "3. 回答结构:\n" +
152 " - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" + 153 " - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +
153 " - 然后列出支持该答案的具体证据(可直接引用参考内容)\n" + 154 " - 然后列出支持该答案的具体证据(可直接引用参考内容)\n" +
154 - "3. 禁止以下行为:\n" + 155 + "4. 禁止以下行为:\n" +
155 " - 添加参考内容中不存在的信息\n" + 156 " - 添加参考内容中不存在的信息\n" +
  157 + " - 在回答中提及‘参考内容’等字样\n" +
  158 + " - 在回答中提及其他产品的功能\n" +
156 " - 进行任何推测性陈述\n" + 159 " - 进行任何推测性陈述\n" +
157 " - 使用模糊或不确定的表达\n" + 160 " - 使用模糊或不确定的表达\n" +
158 " - 参考内容为空时应该拒绝回答\n" + 161 " - 参考内容为空时应该拒绝回答\n" +
@@ -187,6 +190,8 @@ public class ZdyRagController { @@ -187,6 +190,8 @@ public class ZdyRagController {
187 // 记录日志 - 从知识库生成回答 190 // 记录日志 - 从知识库生成回答
188 logRecord.setAnswer(answerBuilder.toString()) 191 logRecord.setAnswer(answerBuilder.toString())
189 .setAnswerType(2); 192 .setAnswerType(2);
  193 +
  194 + System.out.println("回答内容 = " + answerBuilder.toString());
190 airagLogService.save(logRecord); 195 airagLogService.save(logRecord);
191 196
192 emitter.complete(); 197 emitter.complete();
@@ -292,19 +297,43 @@ public class ZdyRagController { @@ -292,19 +297,43 @@ public class ZdyRagController {
292 297
293 298
294 List<ChatMessage> messages = new ArrayList<>(); 299 List<ChatMessage> messages = new ArrayList<>();
295 - String questin = "你是一个严谨的信息处理助手,请严格按照以下要求处理用户问题:" + questionText + "\n\n" +  
296 - "处理步骤和要求:\n" +  
297 - "1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +  
298 - "2. 回答结构:\n" +  
299 - " - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +  
300 - " - 然后列出支持该答案的具体证据(可直接引用参考内容)\n" +  
301 - "3. 禁止以下行为:\n" +  
302 - " - 添加参考内容中不存在的信息\n" +  
303 - " - 进行任何推测性陈述\n" +  
304 - " - 使用模糊或不确定的表达\n" +  
305 - " - 参考内容为空时应该拒绝回答\n" +  
306 - "参考内容(请严格限制回答范围于此):\n" + content;  
307 - 300 +// String questin = "你是一个严谨的信息处理助手,请严格按照以下要求回答用户问题:" + questionText + "\n\n" +
  301 +// "处理步骤和要求:\n" +
  302 +// "1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +
  303 +// "2. 回答结构:\n" +
  304 +// " - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +
  305 +// " - 然后列出支持该答案的说明,以点的方式将这些说明列出(可直接引用参考内容)\n" +
  306 +// "3. 禁止以下行为:\n" +
  307 +// " - 添加参考内容中不存在的信息\n" +
  308 +// " - 进行任何推测性陈述\n" +
  309 +// " - 使用模糊或不确定的表达\n" +
  310 +// " - 参考内容为空时应该拒绝回答\n" +
  311 +// "参考内容(请严格限制回答范围于此):\n" + content;
  312 + String questin = "你是一个严格遵循指令的信息处理助手,请按照以下规范回答用户问题:\n\n" +
  313 + "# 处理规范\n" +
  314 + "1. 回答范围:\n" +
  315 + " - 仅使用提供的参考内容进行回答\n" +
  316 + " - 禁止任何超出参考内容的推断、想象或补充\n" +
  317 + " - 当参考内容为空或不相关时,必须拒绝回答\n\n" +
  318 + "2. 回答结构要求:\n" +
  319 + " - 首行必须用「回答:」开头,给出最直接的事实性回答\n" +
  320 + " - 后续每行以「•」开头列出支持证据,每条证据必须:\n" +
  321 + " * 直接引用参考内容\n" +
  322 + " * 标注具体出处位置(如段落编号/行号)\n" +
  323 + " * 保持原句完整性,不得改写\n\n" +
  324 + "3. 禁止事项:\n" +
  325 + " - 任何形式的推测(包括\"可能\"、\"应该\"等不确定表述)\n" +
  326 + " - 回答内容不得提出\"参考内容\"、\"证据\"等字样\n" +
  327 + " - 参考内容中未明确出现的数字、事实或结论\n" +
  328 + " - 总结性陈述或观点性表达\n" +
  329 + " - 多个信息点的合并表述\n\n" +
  330 + "4. 特殊情形处理:\n" +
  331 + " - 专业术语必须保持原文表述\n" +
  332 + " - 数据必须包含原始单位和精度\n\n" +
  333 + "# 当前任务\n" +
  334 + "问题:「" + questionText + "」\n\n" +
  335 + "参考内容(严格限制回答范围):\n" +
  336 + content;
308 337
309 messages.add(new UserMessage("user", questin)); 338 messages.add(new UserMessage("user", questin));
310 String chat = aiChatHandler.completions(modelId, messages, null); 339 String chat = aiChatHandler.completions(modelId, messages, null);
  1 +package org.jeecg.modules.airag.zdyrag.controller;
  2 +
  3 +import cn.hutool.core.collection.CollectionUtil;
  4 +import com.fasterxml.jackson.databind.ObjectMapper;
  5 +import dev.langchain4j.data.message.ChatMessage;
  6 +import dev.langchain4j.data.message.UserMessage;
  7 +import dev.langchain4j.service.TokenStream;
  8 +import io.swagger.v3.oas.annotations.Operation;
  9 +import lombok.extern.slf4j.Slf4j;
  10 +import org.apache.commons.lang3.StringUtils;
  11 +import org.jeecg.ai.handler.LLMHandler;
  12 +import org.jeecg.common.api.vo.Result;
  13 +import org.jeecg.modules.airag.app.entity.AiragLog;
  14 +import org.jeecg.modules.airag.app.service.IAiragLogService;
  15 +import org.jeecg.modules.airag.common.handler.IAIChatHandler;
  16 +import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
  17 +import org.jeecg.modules.airag.app.utils.FileToBase64Util;
  18 +import org.springframework.beans.factory.annotation.Autowired;
  19 +import org.springframework.beans.factory.annotation.Value;
  20 +import org.springframework.data.redis.core.RedisTemplate;
  21 +import org.springframework.web.bind.annotation.GetMapping;
  22 +import org.springframework.web.bind.annotation.RequestMapping;
  23 +import org.springframework.web.bind.annotation.RequestParam;
  24 +import org.springframework.web.bind.annotation.RestController;
  25 +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
  26 +
  27 +import java.util.*;
  28 +import java.util.concurrent.*;
  29 +
  30 +@RestController
  31 +@RequestMapping("/airag/zdyRag")
  32 +@Slf4j
  33 +public class ZdyRagMultiStageController {
  34 +
  35 + @Autowired
  36 + private EmbeddingHandler embeddingHandler;
  37 +
  38 + @Autowired
  39 + private IAIChatHandler aiChatHandler;
  40 +
  41 + @Autowired
  42 + private IAiragLogService airagLogService;
  43 +
  44 + @Autowired
  45 + private RedisTemplate<String, Object> redisTemplate;
  46 +
  47 + @Value("${jeecg.upload.path}")
  48 + private String uploadPath;
  49 +
  50 + private final ExecutorService executor = Executors.newCachedThreadPool();
  51 + private final ExecutorService asyncLLMExecutor = Executors.newFixedThreadPool(5);
  52 +
  53 + private static final int MAX_CONTEXT_SIZE = 10;
  54 + private static final long CONTEXT_TTL_MILLIS = 30 * 60 * 1000; // 30分钟过期
  55 +
  56 + private String redisKey(String sessionId) {
  57 + return "chat:context:" + sessionId;
  58 + }
  59 +
  60 + @Operation(summary = "multiStageStream with Redis context")
  61 + @GetMapping("multiStageStream")
  62 + public SseEmitter multiStageStream(@RequestParam String questionText,
  63 + @RequestParam(required = false) String sessionId) throws Exception {
  64 + SseEmitter emitter = new SseEmitter(300000L);
  65 + String modelId = "1926875898187878401";
  66 + String knowId = "1926872137990148098";
  67 +
  68 + AiragLog logRecord = new AiragLog()
  69 + .setQuestion(questionText)
  70 + .setModelId(modelId)
  71 + .setCreateTime(new Date());
  72 +
  73 + executor.execute(() -> {
  74 + try {
  75 + List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 5, 0.75);
  76 +
  77 + if (CollectionUtil.isEmpty(maps)) {
  78 + sendSimpleMessage(emitter, "该问题未记录在知识库中");
  79 + logRecord.setAnswer("该问题未记录在知识库中").setAnswerType(3).setIsStorage(0);
  80 + airagLogService.save(logRecord);
  81 + emitter.complete();
  82 + return;
  83 + }
  84 +
  85 + // 多线程摘要
  86 + List<Future<String>> summaryFutures = new ArrayList<>();
  87 + for (Map<String, Object> map : maps) {
  88 + String content = map.get("content").toString();
  89 + String summaryPrompt = buildSummaryPrompt(questionText, content);
  90 + summaryFutures.add(asyncLLMExecutor.submit(() ->
  91 + aiChatHandler.completions(modelId, List.of(new UserMessage("user", summaryPrompt)), null)
  92 + ));
  93 + }
  94 +
  95 + List<String> summaries = new ArrayList<>();
  96 + for (Future<String> future : summaryFutures) {
  97 + try {
  98 + String summary = future.get(15, TimeUnit.SECONDS);
  99 + if (StringUtils.isNotBlank(summary)) summaries.add(summary.trim());
  100 + } catch (Exception e) {
  101 + log.warn("摘要生成失败", e);
  102 + }
  103 + }
  104 +
  105 + // 多线程候选答案
  106 + List<Future<String>> answerFutures = new ArrayList<>();
  107 + for (String summary : summaries) {
  108 + String answerPrompt = buildAnswerPrompt(questionText, summary);
  109 + answerFutures.add(asyncLLMExecutor.submit(() ->
  110 + aiChatHandler.completions(modelId, List.of(new UserMessage("user", answerPrompt)), null)
  111 + ));
  112 + }
  113 +
  114 + List<String> candidateAnswers = new ArrayList<>();
  115 + for (Future<String> future : answerFutures) {
  116 + try {
  117 + String answer = future.get(15, TimeUnit.SECONDS);
  118 + if (StringUtils.isNotBlank(answer)) candidateAnswers.add(answer);
  119 + } catch (Exception e) {
  120 + log.warn("候选答案生成失败", e);
  121 + }
  122 + }
  123 +
  124 + String mergePrompt = buildMergePrompt(questionText, candidateAnswers);
  125 + List<ChatMessage> mergeMessages = new ArrayList<>();
  126 +
  127 + // 从 Redis 读取历史上下文
  128 + if (StringUtils.isNotBlank(sessionId)) {
  129 + Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
  130 + if (cached instanceof List) {
  131 + //noinspection unchecked
  132 + mergeMessages.addAll((List<ChatMessage>) cached);
  133 + }
  134 + }
  135 + mergeMessages.add(new UserMessage("user", mergePrompt));
  136 +
  137 + StringBuilder answerBuilder = new StringBuilder();
  138 +
  139 + Map<String, Object> firstMatch = maps.get(0);
  140 + String storedFileName = extractFieldFromMetadata(firstMatch.get("metadata"), "storedFileName");
  141 + String docName = extractFieldFromMetadata(firstMatch.get("metadata"), "docName");
  142 + String similarityScore = String.valueOf(firstMatch.get("score"));
  143 +
  144 + TokenStream tokenStream = aiChatHandler.chat(modelId, mergeMessages);
  145 +
  146 + tokenStream.onNext(token -> {
  147 + try {
  148 + answerBuilder.append(token);
  149 + Map<String, String> data = new HashMap<>();
  150 + data.put("token", token);
  151 + emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
  152 + } catch (Exception e) {
  153 + log.error("发送 token 失败", e);
  154 + }
  155 + });
  156 +
  157 + tokenStream.onComplete(response -> {
  158 + try {
  159 + Map<String, String> endData = new HashMap<>();
  160 + endData.put("event", "END");
  161 + endData.put("similarity", similarityScore);
  162 + endData.put("fileName", docName);
  163 + if (StringUtils.isNotBlank(storedFileName)) {
  164 + endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + storedFileName));
  165 + }
  166 + emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
  167 +
  168 + logRecord.setAnswer(answerBuilder.toString()).setAnswerType(2);
  169 + airagLogService.save(logRecord);
  170 +
  171 + // 保存更新上下文到 Redis,截断最近10条
  172 + if (StringUtils.isNotBlank(sessionId)) {
  173 + Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
  174 + List<ChatMessage> context;
  175 + if (cached instanceof List) {
  176 + //noinspection unchecked
  177 + context = new ArrayList<>((List<ChatMessage>) cached);
  178 + } else {
  179 + context = new ArrayList<>();
  180 + }
  181 + context.add(new UserMessage("user", questionText));
  182 + context.add(new UserMessage("assistant", answerBuilder.toString()));
  183 + if (context.size() > MAX_CONTEXT_SIZE) {
  184 + context = context.subList(context.size() - MAX_CONTEXT_SIZE, context.size());
  185 + }
  186 + redisTemplate.opsForValue().set(redisKey(sessionId), context, CONTEXT_TTL_MILLIS, TimeUnit.MILLISECONDS);
  187 + }
  188 +
  189 + emitter.complete();
  190 + } catch (Exception e) {
  191 + emitter.completeWithError(e);
  192 + }
  193 + });
  194 +
  195 + tokenStream.onError(error -> {
  196 + log.error("生成答案失败", error);
  197 + logRecord.setAnswer("生成答案失败: " + error.getMessage()).setAnswerType(4);
  198 + airagLogService.save(logRecord);
  199 + emitter.completeWithError(error);
  200 + });
  201 +
  202 + tokenStream.start();
  203 +
  204 + } catch (Exception e) {
  205 + log.error("多阶段处理异常", e);
  206 + logRecord.setAnswer("处理异常: " + e.getMessage()).setAnswerType(4);
  207 + airagLogService.save(logRecord);
  208 + emitter.completeWithError(e);
  209 + }
  210 + });
  211 +
  212 + return emitter;
  213 + }
  214 +
  215 + private void sendSimpleMessage(SseEmitter emitter, String message) throws Exception {
  216 + Map<String, String> data = new HashMap<>();
  217 + data.put("token", message);
  218 + emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
  219 + }
  220 +
  221 + private String extractFieldFromMetadata(Object metadataObj, String key) throws Exception {
  222 + if (metadataObj == null) return "";
  223 + ObjectMapper objectMapper = new ObjectMapper();
  224 + Map<String, String> metadata = objectMapper.readValue(metadataObj.toString(), Map.class);
  225 + if (metadata.containsKey(key)) {
  226 + return metadata.get(key);
  227 + }
  228 + return "";
  229 + }
  230 +
  231 + private String buildSummaryPrompt(String question, String content) {
  232 + return "你是一个信息摘要助手,请只针对以下内容进行摘要,严格不添加其他产品信息或无关内容:\n\n" +
  233 + "用户问题:" + question + "\n" +
  234 + "内容段落:\n" + content + "\n\n" +
  235 + "请提取与问题直接相关且仅限于该内容的关键信息,控制在200字以内。";
  236 + }
  237 +
  238 + private String buildAnswerPrompt(String question, String summary) {
  239 + return "你是一个信息回答助手,请严格根据以下摘要内容回答用户问题。\n\n" +
  240 + "用户问题:" + question + "\n" +
  241 + "摘要内容:\n" + summary + "\n\n" +
  242 + "回答要求:\n- 回答必须以‘回答:’开头\n- 严格禁止添加摘要外的信息\n- 只能使用摘要中提及的内容\n- 禁止合并其他摘要的内容。";
  243 + }
  244 + private String buildMergePrompt(String question, List<String> answers) {
  245 + StringBuilder sb = new StringBuilder("你收到多个候选答案,请从中选择最准确且不交叉混淆产品信息的答案作为最终回答。\n\n");
  246 + sb.append("用户问题:").append(question).append("\n");
  247 + for (int i = 0; i < answers.size(); i++) {
  248 + sb.append("候选答案").append(i + 1).append(":\n").append(answers.get(i)).append("\n\n");
  249 + }
  250 + sb.append("请直接输出最佳答案,**禁止添加新信息**或跨产品混合。");
  251 + return sb.toString();
  252 + }
  253 +
  254 +}