|
|
|
package org.jeecg.modules.airag.zdyrag.controller;
|
|
|
|
|
|
|
|
import cn.hutool.core.collection.CollectionUtil;
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
|
import dev.langchain4j.data.message.ChatMessage;
|
|
|
|
import dev.langchain4j.data.message.UserMessage;
|
|
|
|
import dev.langchain4j.service.TokenStream;
|
|
|
|
import io.swagger.v3.oas.annotations.Operation;
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
|
import org.jeecg.ai.handler.LLMHandler;
|
|
|
|
import org.jeecg.common.api.vo.Result;
|
|
|
|
import org.jeecg.modules.airag.app.entity.AiragLog;
|
|
|
|
import org.jeecg.modules.airag.app.service.IAiragLogService;
|
|
|
|
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
|
|
|
|
import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
|
|
|
|
import org.jeecg.modules.airag.app.utils.FileToBase64Util;
|
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
|
import org.springframework.beans.factory.annotation.Value;
|
|
|
|
import org.springframework.data.redis.core.RedisTemplate;
|
|
|
|
import org.springframework.web.bind.annotation.GetMapping;
|
|
|
|
import org.springframework.web.bind.annotation.RequestMapping;
|
|
|
|
import org.springframework.web.bind.annotation.RequestParam;
|
|
|
|
import org.springframework.web.bind.annotation.RestController;
|
|
|
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
|
|
|
|
|
import java.util.*;
|
|
|
|
import java.util.concurrent.*;
|
|
|
|
|
|
|
|
@RestController
|
|
|
|
@RequestMapping("/airag/zdyRag")
|
|
|
|
@Slf4j
|
|
|
|
public class ZdyRagMultiStageController {
|
|
|
|
|
|
|
|
@Autowired
|
|
|
|
private EmbeddingHandler embeddingHandler;
|
|
|
|
|
|
|
|
@Autowired
|
|
|
|
private IAIChatHandler aiChatHandler;
|
|
|
|
|
|
|
|
@Autowired
|
|
|
|
private IAiragLogService airagLogService;
|
|
|
|
|
|
|
|
@Autowired
|
|
|
|
private RedisTemplate<String, Object> redisTemplate;
|
|
|
|
|
|
|
|
@Value("${jeecg.upload.path}")
|
|
|
|
private String uploadPath;
|
|
|
|
|
|
|
|
private final ExecutorService executor = Executors.newCachedThreadPool();
|
|
|
|
private final ExecutorService asyncLLMExecutor = Executors.newFixedThreadPool(5);
|
|
|
|
|
|
|
|
private static final int MAX_CONTEXT_SIZE = 10;
|
|
|
|
private static final long CONTEXT_TTL_MILLIS = 30 * 60 * 1000; // 30分钟过期
|
|
|
|
|
|
|
|
private String redisKey(String sessionId) {
|
|
|
|
return "chat:context:" + sessionId;
|
|
|
|
}
|
|
|
|
|
|
|
|
@Operation(summary = "multiStageStream with Redis context")
|
|
|
|
@GetMapping("multiStageStream")
|
|
|
|
public SseEmitter multiStageStream(@RequestParam String questionText,
|
|
|
|
@RequestParam(required = false) String sessionId) throws Exception {
|
|
|
|
SseEmitter emitter = new SseEmitter(300000L);
|
|
|
|
String modelId = "1926875898187878401";
|
|
|
|
String knowId = "1926872137990148098";
|
|
|
|
|
|
|
|
AiragLog logRecord = new AiragLog()
|
|
|
|
.setQuestion(questionText)
|
|
|
|
.setModelId(modelId)
|
|
|
|
.setCreateTime(new Date());
|
|
|
|
|
|
|
|
executor.execute(() -> {
|
|
|
|
try {
|
|
|
|
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 5, 0.75);
|
|
|
|
|
|
|
|
if (CollectionUtil.isEmpty(maps)) {
|
|
|
|
sendSimpleMessage(emitter, "该问题未记录在知识库中");
|
|
|
|
logRecord.setAnswer("该问题未记录在知识库中").setAnswerType(3).setIsStorage(0);
|
|
|
|
airagLogService.save(logRecord);
|
|
|
|
emitter.complete();
|
|
|
|
return;
|
|
|
|
}
|
|
|
|
|
|
|
|
// 多线程摘要
|
|
|
|
List<Future<String>> summaryFutures = new ArrayList<>();
|
|
|
|
for (Map<String, Object> map : maps) {
|
|
|
|
String content = map.get("content").toString();
|
|
|
|
String summaryPrompt = buildSummaryPrompt(questionText, content);
|
|
|
|
summaryFutures.add(asyncLLMExecutor.submit(() ->
|
|
|
|
aiChatHandler.completions(modelId, List.of(new UserMessage("user", summaryPrompt)), null)
|
|
|
|
));
|
|
|
|
}
|
|
|
|
|
|
|
|
List<String> summaries = new ArrayList<>();
|
|
|
|
for (Future<String> future : summaryFutures) {
|
|
|
|
try {
|
|
|
|
String summary = future.get(15, TimeUnit.SECONDS);
|
|
|
|
if (StringUtils.isNotBlank(summary)) summaries.add(summary.trim());
|
|
|
|
} catch (Exception e) {
|
|
|
|
log.warn("摘要生成失败", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// 多线程候选答案
|
|
|
|
List<Future<String>> answerFutures = new ArrayList<>();
|
|
|
|
for (String summary : summaries) {
|
|
|
|
String answerPrompt = buildAnswerPrompt(questionText, summary);
|
|
|
|
answerFutures.add(asyncLLMExecutor.submit(() ->
|
|
|
|
aiChatHandler.completions(modelId, List.of(new UserMessage("user", answerPrompt)), null)
|
|
|
|
));
|
|
|
|
}
|
|
|
|
|
|
|
|
List<String> candidateAnswers = new ArrayList<>();
|
|
|
|
for (Future<String> future : answerFutures) {
|
|
|
|
try {
|
|
|
|
String answer = future.get(15, TimeUnit.SECONDS);
|
|
|
|
if (StringUtils.isNotBlank(answer)) candidateAnswers.add(answer);
|
|
|
|
} catch (Exception e) {
|
|
|
|
log.warn("候选答案生成失败", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
String mergePrompt = buildMergePrompt(questionText, candidateAnswers);
|
|
|
|
List<ChatMessage> mergeMessages = new ArrayList<>();
|
|
|
|
|
|
|
|
// 从 Redis 读取历史上下文
|
|
|
|
if (StringUtils.isNotBlank(sessionId)) {
|
|
|
|
Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
|
|
|
|
if (cached instanceof List) {
|
|
|
|
//noinspection unchecked
|
|
|
|
mergeMessages.addAll((List<ChatMessage>) cached);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
mergeMessages.add(new UserMessage("user", mergePrompt));
|
|
|
|
|
|
|
|
StringBuilder answerBuilder = new StringBuilder();
|
|
|
|
|
|
|
|
Map<String, Object> firstMatch = maps.get(0);
|
|
|
|
String storedFileName = extractFieldFromMetadata(firstMatch.get("metadata"), "storedFileName");
|
|
|
|
String docName = extractFieldFromMetadata(firstMatch.get("metadata"), "docName");
|
|
|
|
String similarityScore = String.valueOf(firstMatch.get("score"));
|
|
|
|
|
|
|
|
TokenStream tokenStream = aiChatHandler.chat(modelId, mergeMessages);
|
|
|
|
|
|
|
|
tokenStream.onNext(token -> {
|
|
|
|
try {
|
|
|
|
answerBuilder.append(token);
|
|
|
|
Map<String, String> data = new HashMap<>();
|
|
|
|
data.put("token", token);
|
|
|
|
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
|
|
|
|
} catch (Exception e) {
|
|
|
|
log.error("发送 token 失败", e);
|
|
|
|
}
|
|
|
|
});
|
|
|
|
|
|
|
|
tokenStream.onComplete(response -> {
|
|
|
|
try {
|
|
|
|
Map<String, String> endData = new HashMap<>();
|
|
|
|
endData.put("event", "END");
|
|
|
|
endData.put("similarity", similarityScore);
|
|
|
|
endData.put("fileName", docName);
|
|
|
|
if (StringUtils.isNotBlank(storedFileName)) {
|
|
|
|
endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + storedFileName));
|
|
|
|
}
|
|
|
|
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
|
|
|
|
|
|
|
|
logRecord.setAnswer(answerBuilder.toString()).setAnswerType(2);
|
|
|
|
airagLogService.save(logRecord);
|
|
|
|
|
|
|
|
// 保存更新上下文到 Redis,截断最近10条
|
|
|
|
if (StringUtils.isNotBlank(sessionId)) {
|
|
|
|
Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
|
|
|
|
List<ChatMessage> context;
|
|
|
|
if (cached instanceof List) {
|
|
|
|
//noinspection unchecked
|
|
|
|
context = new ArrayList<>((List<ChatMessage>) cached);
|
|
|
|
} else {
|
|
|
|
context = new ArrayList<>();
|
|
|
|
}
|
|
|
|
context.add(new UserMessage("user", questionText));
|
|
|
|
context.add(new UserMessage("assistant", answerBuilder.toString()));
|
|
|
|
if (context.size() > MAX_CONTEXT_SIZE) {
|
|
|
|
context = context.subList(context.size() - MAX_CONTEXT_SIZE, context.size());
|
|
|
|
}
|
|
|
|
redisTemplate.opsForValue().set(redisKey(sessionId), context, CONTEXT_TTL_MILLIS, TimeUnit.MILLISECONDS);
|
|
|
|
}
|
|
|
|
|
|
|
|
emitter.complete();
|
|
|
|
} catch (Exception e) {
|
|
|
|
emitter.completeWithError(e);
|
|
|
|
}
|
|
|
|
});
|
|
|
|
|
|
|
|
tokenStream.onError(error -> {
|
|
|
|
log.error("生成答案失败", error);
|
|
|
|
logRecord.setAnswer("生成答案失败: " + error.getMessage()).setAnswerType(4);
|
|
|
|
airagLogService.save(logRecord);
|
|
|
|
emitter.completeWithError(error);
|
|
|
|
});
|
|
|
|
|
|
|
|
tokenStream.start();
|
|
|
|
|
|
|
|
} catch (Exception e) {
|
|
|
|
log.error("多阶段处理异常", e);
|
|
|
|
logRecord.setAnswer("处理异常: " + e.getMessage()).setAnswerType(4);
|
|
|
|
airagLogService.save(logRecord);
|
|
|
|
emitter.completeWithError(e);
|
|
|
|
}
|
|
|
|
});
|
|
|
|
|
|
|
|
return emitter;
|
|
|
|
}
|
|
|
|
|
|
|
|
private void sendSimpleMessage(SseEmitter emitter, String message) throws Exception {
|
|
|
|
Map<String, String> data = new HashMap<>();
|
|
|
|
data.put("token", message);
|
|
|
|
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
|
|
|
|
}
|
|
|
|
|
|
|
|
private String extractFieldFromMetadata(Object metadataObj, String key) throws Exception {
|
|
|
|
if (metadataObj == null) return "";
|
|
|
|
ObjectMapper objectMapper = new ObjectMapper();
|
|
|
|
Map<String, String> metadata = objectMapper.readValue(metadataObj.toString(), Map.class);
|
|
|
|
if (metadata.containsKey(key)) {
|
|
|
|
return metadata.get(key);
|
|
|
|
}
|
|
|
|
return "";
|
|
|
|
}
|
|
|
|
|
|
|
|
private String buildSummaryPrompt(String question, String content) {
|
|
|
|
return "你是一个信息摘要助手,请只针对以下内容进行摘要,严格不添加其他产品信息或无关内容:\n\n" +
|
|
|
|
"用户问题:" + question + "\n" +
|
|
|
|
"内容段落:\n" + content + "\n\n" +
|
|
|
|
"请提取与问题直接相关且仅限于该内容的关键信息,控制在200字以内。";
|
|
|
|
}
|
|
|
|
|
|
|
|
private String buildAnswerPrompt(String question, String summary) {
|
|
|
|
return "你是一个信息回答助手,请严格根据以下摘要内容回答用户问题。\n\n" +
|
|
|
|
"用户问题:" + question + "\n" +
|
|
|
|
"摘要内容:\n" + summary + "\n\n" +
|
|
|
|
"回答要求:\n- 回答必须以‘回答:’开头\n- 严格禁止添加摘要外的信息\n- 只能使用摘要中提及的内容\n- 禁止合并其他摘要的内容。";
|
|
|
|
}
|
|
|
|
private String buildMergePrompt(String question, List<String> answers) {
|
|
|
|
StringBuilder sb = new StringBuilder("你收到多个候选答案,请从中选择最准确且不交叉混淆产品信息的答案作为最终回答。\n\n");
|
|
|
|
sb.append("用户问题:").append(question).append("\n");
|
|
|
|
for (int i = 0; i < answers.size(); i++) {
|
|
|
|
sb.append("候选答案").append(i + 1).append(":\n").append(answers.get(i)).append("\n\n");
|
|
|
|
}
|
|
|
|
sb.append("请直接输出最佳答案,**禁止添加新信息**或跨产品混合。");
|
|
|
|
return sb.toString();
|
|
|
|
}
|
|
|
|
|
|
|
|
} |
...
|
...
|
|