作者 lixiang

智能助手增加多轮对话流式返回接口

... ... @@ -131,7 +131,8 @@ public class ShiroConfig {
filterChainDefinitionMap.put("/**/*.wasm", "anon");
// 在ShiroConfig.java的filterChainDefinitionMap中添加:
filterChainDefinitionMap.put("/**/airag/zdyRag/send", "anon"); // 精确匹配接口路径
filterChainDefinitionMap.put("/**/airag/zdyRag/sendStream", "anon"); // 精确匹配接口路径
filterChainDefinitionMap.put("/**/airag/zdyRag/multiStageStream", "anon"); // 精确匹配接口路径
filterChainDefinitionMap.put("/public/**", "anon"); // 通配符放行所有公开路径
// 放行按钮接口(如果确实需要公开)
filterChainDefinitionMap.put("/**/airagbutton/airagButton/buttonList", "anon");
... ...
... ... @@ -110,7 +110,7 @@ public class ZdyRagController {
}
// 从知识库搜索
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 3, 0.75);
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 2, 0.78);
if (CollectionUtil.isEmpty(maps)) {
Map<String, String> data = new HashMap<>();
data.put("token", "该问题未记录在知识库中");
... ... @@ -134,7 +134,7 @@ public class ZdyRagController {
// 构建知识库内容
StringBuilder content = new StringBuilder();
for (Map<String, Object> map : maps) {
if (Double.parseDouble(map.get("score").toString()) > 0.75) {
if (Double.parseDouble(map.get("score").toString()) > 0.78) {
content.append(map.get("content").toString()).append("\n");
}
}
... ... @@ -148,11 +148,14 @@ public class ZdyRagController {
String questin = "你是一个严谨的信息处理助手,请严格按照以下要求处理用户问题:" + questionText + "\n\n" +
"处理步骤和要求:\n" +
"1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +
"2. 回答结构:\n" +
"2. 严格基于参考内容回答,禁止使用参考内容中与问题无关的内容\n" +
"3. 回答结构:\n" +
" - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +
" - 然后列出支持该答案的具体证据(可直接引用参考内容)\n" +
"3. 禁止以下行为:\n" +
"4. 禁止以下行为:\n" +
" - 添加参考内容中不存在的信息\n" +
" - 在回答中提及‘参考内容’等字样\n" +
" - 在回答中提及其他产品的功能\n" +
" - 进行任何推测性陈述\n" +
" - 使用模糊或不确定的表达\n" +
" - 参考内容为空时应该拒绝回答\n" +
... ... @@ -187,6 +190,8 @@ public class ZdyRagController {
// 记录日志 - 从知识库生成回答
logRecord.setAnswer(answerBuilder.toString())
.setAnswerType(2);
System.out.println("回答内容 = " + answerBuilder.toString());
airagLogService.save(logRecord);
emitter.complete();
... ... @@ -292,19 +297,43 @@ public class ZdyRagController {
List<ChatMessage> messages = new ArrayList<>();
String questin = "你是一个严谨的信息处理助手,请严格按照以下要求处理用户问题:" + questionText + "\n\n" +
"处理步骤和要求:\n" +
"1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +
"2. 回答结构:\n" +
" - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +
" - 然后列出支持该答案的具体证据(可直接引用参考内容)\n" +
"3. 禁止以下行为:\n" +
" - 添加参考内容中不存在的信息\n" +
" - 进行任何推测性陈述\n" +
" - 使用模糊或不确定的表达\n" +
" - 参考内容为空时应该拒绝回答\n" +
"参考内容(请严格限制回答范围于此):\n" + content;
// String questin = "你是一个严谨的信息处理助手,请严格按照以下要求回答用户问题:" + questionText + "\n\n" +
// "处理步骤和要求:\n" +
// "1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +
// "2. 回答结构:\n" +
// " - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +
// " - 然后列出支持该答案的说明,以点的方式将这些说明列出(可直接引用参考内容)\n" +
// "3. 禁止以下行为:\n" +
// " - 添加参考内容中不存在的信息\n" +
// " - 进行任何推测性陈述\n" +
// " - 使用模糊或不确定的表达\n" +
// " - 参考内容为空时应该拒绝回答\n" +
// "参考内容(请严格限制回答范围于此):\n" + content;
String questin = "你是一个严格遵循指令的信息处理助手,请按照以下规范回答用户问题:\n\n" +
"# 处理规范\n" +
"1. 回答范围:\n" +
" - 仅使用提供的参考内容进行回答\n" +
" - 禁止任何超出参考内容的推断、想象或补充\n" +
" - 当参考内容为空或不相关时,必须拒绝回答\n\n" +
"2. 回答结构要求:\n" +
" - 首行必须用「回答:」开头,给出最直接的事实性回答\n" +
" - 后续每行以「•」开头列出支持证据,每条证据必须:\n" +
" * 直接引用参考内容\n" +
" * 标注具体出处位置(如段落编号/行号)\n" +
" * 保持原句完整性,不得改写\n\n" +
"3. 禁止事项:\n" +
" - 任何形式的推测(包括\"可能\"、\"应该\"等不确定表述)\n" +
" - 回答内容不得提出\"参考内容\"、\"证据\"等字样\n" +
" - 参考内容中未明确出现的数字、事实或结论\n" +
" - 总结性陈述或观点性表达\n" +
" - 多个信息点的合并表述\n\n" +
"4. 特殊情形处理:\n" +
" - 专业术语必须保持原文表述\n" +
" - 数据必须包含原始单位和精度\n\n" +
"# 当前任务\n" +
"问题:「" + questionText + "」\n\n" +
"参考内容(严格限制回答范围):\n" +
content;
messages.add(new UserMessage("user", questin));
String chat = aiChatHandler.completions(modelId, messages, null);
... ...
package org.jeecg.modules.airag.zdyrag.controller;
import cn.hutool.core.collection.CollectionUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.service.TokenStream;
import io.swagger.v3.oas.annotations.Operation;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jeecg.ai.handler.LLMHandler;
import org.jeecg.common.api.vo.Result;
import org.jeecg.modules.airag.app.entity.AiragLog;
import org.jeecg.modules.airag.app.service.IAiragLogService;
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
import org.jeecg.modules.airag.app.utils.FileToBase64Util;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.*;
import java.util.concurrent.*;
@RestController
@RequestMapping("/airag/zdyRag")
@Slf4j
public class ZdyRagMultiStageController {
@Autowired
private EmbeddingHandler embeddingHandler;
@Autowired
private IAIChatHandler aiChatHandler;
@Autowired
private IAiragLogService airagLogService;
@Autowired
private RedisTemplate<String, Object> redisTemplate;
@Value("${jeecg.upload.path}")
private String uploadPath;
private final ExecutorService executor = Executors.newCachedThreadPool();
private final ExecutorService asyncLLMExecutor = Executors.newFixedThreadPool(5);
private static final int MAX_CONTEXT_SIZE = 10;
private static final long CONTEXT_TTL_MILLIS = 30 * 60 * 1000; // 30分钟过期
private String redisKey(String sessionId) {
return "chat:context:" + sessionId;
}
@Operation(summary = "multiStageStream with Redis context")
@GetMapping("multiStageStream")
public SseEmitter multiStageStream(@RequestParam String questionText,
@RequestParam(required = false) String sessionId) throws Exception {
SseEmitter emitter = new SseEmitter(300000L);
String modelId = "1926875898187878401";
String knowId = "1926872137990148098";
AiragLog logRecord = new AiragLog()
.setQuestion(questionText)
.setModelId(modelId)
.setCreateTime(new Date());
executor.execute(() -> {
try {
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 5, 0.75);
if (CollectionUtil.isEmpty(maps)) {
sendSimpleMessage(emitter, "该问题未记录在知识库中");
logRecord.setAnswer("该问题未记录在知识库中").setAnswerType(3).setIsStorage(0);
airagLogService.save(logRecord);
emitter.complete();
return;
}
// 多线程摘要
List<Future<String>> summaryFutures = new ArrayList<>();
for (Map<String, Object> map : maps) {
String content = map.get("content").toString();
String summaryPrompt = buildSummaryPrompt(questionText, content);
summaryFutures.add(asyncLLMExecutor.submit(() ->
aiChatHandler.completions(modelId, List.of(new UserMessage("user", summaryPrompt)), null)
));
}
List<String> summaries = new ArrayList<>();
for (Future<String> future : summaryFutures) {
try {
String summary = future.get(15, TimeUnit.SECONDS);
if (StringUtils.isNotBlank(summary)) summaries.add(summary.trim());
} catch (Exception e) {
log.warn("摘要生成失败", e);
}
}
// 多线程候选答案
List<Future<String>> answerFutures = new ArrayList<>();
for (String summary : summaries) {
String answerPrompt = buildAnswerPrompt(questionText, summary);
answerFutures.add(asyncLLMExecutor.submit(() ->
aiChatHandler.completions(modelId, List.of(new UserMessage("user", answerPrompt)), null)
));
}
List<String> candidateAnswers = new ArrayList<>();
for (Future<String> future : answerFutures) {
try {
String answer = future.get(15, TimeUnit.SECONDS);
if (StringUtils.isNotBlank(answer)) candidateAnswers.add(answer);
} catch (Exception e) {
log.warn("候选答案生成失败", e);
}
}
String mergePrompt = buildMergePrompt(questionText, candidateAnswers);
List<ChatMessage> mergeMessages = new ArrayList<>();
// 从 Redis 读取历史上下文
if (StringUtils.isNotBlank(sessionId)) {
Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
if (cached instanceof List) {
//noinspection unchecked
mergeMessages.addAll((List<ChatMessage>) cached);
}
}
mergeMessages.add(new UserMessage("user", mergePrompt));
StringBuilder answerBuilder = new StringBuilder();
Map<String, Object> firstMatch = maps.get(0);
String storedFileName = extractFieldFromMetadata(firstMatch.get("metadata"), "storedFileName");
String docName = extractFieldFromMetadata(firstMatch.get("metadata"), "docName");
String similarityScore = String.valueOf(firstMatch.get("score"));
TokenStream tokenStream = aiChatHandler.chat(modelId, mergeMessages);
tokenStream.onNext(token -> {
try {
answerBuilder.append(token);
Map<String, String> data = new HashMap<>();
data.put("token", token);
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
} catch (Exception e) {
log.error("发送 token 失败", e);
}
});
tokenStream.onComplete(response -> {
try {
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
endData.put("similarity", similarityScore);
endData.put("fileName", docName);
if (StringUtils.isNotBlank(storedFileName)) {
endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + storedFileName));
}
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
logRecord.setAnswer(answerBuilder.toString()).setAnswerType(2);
airagLogService.save(logRecord);
// 保存更新上下文到 Redis,截断最近10条
if (StringUtils.isNotBlank(sessionId)) {
Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
List<ChatMessage> context;
if (cached instanceof List) {
//noinspection unchecked
context = new ArrayList<>((List<ChatMessage>) cached);
} else {
context = new ArrayList<>();
}
context.add(new UserMessage("user", questionText));
context.add(new UserMessage("assistant", answerBuilder.toString()));
if (context.size() > MAX_CONTEXT_SIZE) {
context = context.subList(context.size() - MAX_CONTEXT_SIZE, context.size());
}
redisTemplate.opsForValue().set(redisKey(sessionId), context, CONTEXT_TTL_MILLIS, TimeUnit.MILLISECONDS);
}
emitter.complete();
} catch (Exception e) {
emitter.completeWithError(e);
}
});
tokenStream.onError(error -> {
log.error("生成答案失败", error);
logRecord.setAnswer("生成答案失败: " + error.getMessage()).setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(error);
});
tokenStream.start();
} catch (Exception e) {
log.error("多阶段处理异常", e);
logRecord.setAnswer("处理异常: " + e.getMessage()).setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(e);
}
});
return emitter;
}
private void sendSimpleMessage(SseEmitter emitter, String message) throws Exception {
Map<String, String> data = new HashMap<>();
data.put("token", message);
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
}
private String extractFieldFromMetadata(Object metadataObj, String key) throws Exception {
if (metadataObj == null) return "";
ObjectMapper objectMapper = new ObjectMapper();
Map<String, String> metadata = objectMapper.readValue(metadataObj.toString(), Map.class);
if (metadata.containsKey(key)) {
return metadata.get(key);
}
return "";
}
private String buildSummaryPrompt(String question, String content) {
return "你是一个信息摘要助手,请只针对以下内容进行摘要,严格不添加其他产品信息或无关内容:\n\n" +
"用户问题:" + question + "\n" +
"内容段落:\n" + content + "\n\n" +
"请提取与问题直接相关且仅限于该内容的关键信息,控制在200字以内。";
}
private String buildAnswerPrompt(String question, String summary) {
return "你是一个信息回答助手,请严格根据以下摘要内容回答用户问题。\n\n" +
"用户问题:" + question + "\n" +
"摘要内容:\n" + summary + "\n\n" +
"回答要求:\n- 回答必须以‘回答:’开头\n- 严格禁止添加摘要外的信息\n- 只能使用摘要中提及的内容\n- 禁止合并其他摘要的内容。";
}
private String buildMergePrompt(String question, List<String> answers) {
StringBuilder sb = new StringBuilder("你收到多个候选答案,请从中选择最准确且不交叉混淆产品信息的答案作为最终回答。\n\n");
sb.append("用户问题:").append(question).append("\n");
for (int i = 0; i < answers.size(); i++) {
sb.append("候选答案").append(i + 1).append(":\n").append(answers.get(i)).append("\n\n");
}
sb.append("请直接输出最佳答案,**禁止添加新信息**或跨产品混合。");
return sb.toString();
}
}
... ...