|
|
|
1
|
+package org.jeecg.modules.airag.zdyrag.controller;
|
|
|
|
2
|
+
|
|
|
|
3
|
+import cn.hutool.core.collection.CollectionUtil;
|
|
|
|
4
|
+import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
|
5
|
+import dev.langchain4j.data.message.ChatMessage;
|
|
|
|
6
|
+import dev.langchain4j.data.message.UserMessage;
|
|
|
|
7
|
+import dev.langchain4j.service.TokenStream;
|
|
|
|
8
|
+import io.swagger.v3.oas.annotations.Operation;
|
|
|
|
9
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
|
10
|
+import org.apache.commons.lang3.StringUtils;
|
|
|
|
11
|
+import org.jeecg.ai.handler.LLMHandler;
|
|
|
|
12
|
+import org.jeecg.common.api.vo.Result;
|
|
|
|
13
|
+import org.jeecg.modules.airag.app.entity.AiragLog;
|
|
|
|
14
|
+import org.jeecg.modules.airag.app.service.IAiragLogService;
|
|
|
|
15
|
+import org.jeecg.modules.airag.common.handler.IAIChatHandler;
|
|
|
|
16
|
+import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
|
|
|
|
17
|
+import org.jeecg.modules.airag.app.utils.FileToBase64Util;
|
|
|
|
18
|
+import org.springframework.beans.factory.annotation.Autowired;
|
|
|
|
19
|
+import org.springframework.beans.factory.annotation.Value;
|
|
|
|
20
|
+import org.springframework.data.redis.core.RedisTemplate;
|
|
|
|
21
|
+import org.springframework.web.bind.annotation.GetMapping;
|
|
|
|
22
|
+import org.springframework.web.bind.annotation.RequestMapping;
|
|
|
|
23
|
+import org.springframework.web.bind.annotation.RequestParam;
|
|
|
|
24
|
+import org.springframework.web.bind.annotation.RestController;
|
|
|
|
25
|
+import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
|
|
|
26
|
+
|
|
|
|
27
|
+import java.util.*;
|
|
|
|
28
|
+import java.util.concurrent.*;
|
|
|
|
29
|
+
|
|
|
|
30
|
+@RestController
|
|
|
|
31
|
+@RequestMapping("/airag/zdyRag")
|
|
|
|
32
|
+@Slf4j
|
|
|
|
33
|
+public class ZdyRagMultiStageController {
|
|
|
|
34
|
+
|
|
|
|
35
|
+ @Autowired
|
|
|
|
36
|
+ private EmbeddingHandler embeddingHandler;
|
|
|
|
37
|
+
|
|
|
|
38
|
+ @Autowired
|
|
|
|
39
|
+ private IAIChatHandler aiChatHandler;
|
|
|
|
40
|
+
|
|
|
|
41
|
+ @Autowired
|
|
|
|
42
|
+ private IAiragLogService airagLogService;
|
|
|
|
43
|
+
|
|
|
|
44
|
+ @Autowired
|
|
|
|
45
|
+ private RedisTemplate<String, Object> redisTemplate;
|
|
|
|
46
|
+
|
|
|
|
47
|
+ @Value("${jeecg.upload.path}")
|
|
|
|
48
|
+ private String uploadPath;
|
|
|
|
49
|
+
|
|
|
|
50
|
+ private final ExecutorService executor = Executors.newCachedThreadPool();
|
|
|
|
51
|
+ private final ExecutorService asyncLLMExecutor = Executors.newFixedThreadPool(5);
|
|
|
|
52
|
+
|
|
|
|
53
|
+ private static final int MAX_CONTEXT_SIZE = 10;
|
|
|
|
54
|
+ private static final long CONTEXT_TTL_MILLIS = 30 * 60 * 1000; // 30分钟过期
|
|
|
|
55
|
+
|
|
|
|
56
|
+ private String redisKey(String sessionId) {
|
|
|
|
57
|
+ return "chat:context:" + sessionId;
|
|
|
|
58
|
+ }
|
|
|
|
59
|
+
|
|
|
|
60
|
+ @Operation(summary = "multiStageStream with Redis context")
|
|
|
|
61
|
+ @GetMapping("multiStageStream")
|
|
|
|
62
|
+ public SseEmitter multiStageStream(@RequestParam String questionText,
|
|
|
|
63
|
+ @RequestParam(required = false) String sessionId) throws Exception {
|
|
|
|
64
|
+ SseEmitter emitter = new SseEmitter(300000L);
|
|
|
|
65
|
+ String modelId = "1926875898187878401";
|
|
|
|
66
|
+ String knowId = "1926872137990148098";
|
|
|
|
67
|
+
|
|
|
|
68
|
+ AiragLog logRecord = new AiragLog()
|
|
|
|
69
|
+ .setQuestion(questionText)
|
|
|
|
70
|
+ .setModelId(modelId)
|
|
|
|
71
|
+ .setCreateTime(new Date());
|
|
|
|
72
|
+
|
|
|
|
73
|
+ executor.execute(() -> {
|
|
|
|
74
|
+ try {
|
|
|
|
75
|
+ List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 5, 0.75);
|
|
|
|
76
|
+
|
|
|
|
77
|
+ if (CollectionUtil.isEmpty(maps)) {
|
|
|
|
78
|
+ sendSimpleMessage(emitter, "该问题未记录在知识库中");
|
|
|
|
79
|
+ logRecord.setAnswer("该问题未记录在知识库中").setAnswerType(3).setIsStorage(0);
|
|
|
|
80
|
+ airagLogService.save(logRecord);
|
|
|
|
81
|
+ emitter.complete();
|
|
|
|
82
|
+ return;
|
|
|
|
83
|
+ }
|
|
|
|
84
|
+
|
|
|
|
85
|
+ // 多线程摘要
|
|
|
|
86
|
+ List<Future<String>> summaryFutures = new ArrayList<>();
|
|
|
|
87
|
+ for (Map<String, Object> map : maps) {
|
|
|
|
88
|
+ String content = map.get("content").toString();
|
|
|
|
89
|
+ String summaryPrompt = buildSummaryPrompt(questionText, content);
|
|
|
|
90
|
+ summaryFutures.add(asyncLLMExecutor.submit(() ->
|
|
|
|
91
|
+ aiChatHandler.completions(modelId, List.of(new UserMessage("user", summaryPrompt)), null)
|
|
|
|
92
|
+ ));
|
|
|
|
93
|
+ }
|
|
|
|
94
|
+
|
|
|
|
95
|
+ List<String> summaries = new ArrayList<>();
|
|
|
|
96
|
+ for (Future<String> future : summaryFutures) {
|
|
|
|
97
|
+ try {
|
|
|
|
98
|
+ String summary = future.get(15, TimeUnit.SECONDS);
|
|
|
|
99
|
+ if (StringUtils.isNotBlank(summary)) summaries.add(summary.trim());
|
|
|
|
100
|
+ } catch (Exception e) {
|
|
|
|
101
|
+ log.warn("摘要生成失败", e);
|
|
|
|
102
|
+ }
|
|
|
|
103
|
+ }
|
|
|
|
104
|
+
|
|
|
|
105
|
+ // 多线程候选答案
|
|
|
|
106
|
+ List<Future<String>> answerFutures = new ArrayList<>();
|
|
|
|
107
|
+ for (String summary : summaries) {
|
|
|
|
108
|
+ String answerPrompt = buildAnswerPrompt(questionText, summary);
|
|
|
|
109
|
+ answerFutures.add(asyncLLMExecutor.submit(() ->
|
|
|
|
110
|
+ aiChatHandler.completions(modelId, List.of(new UserMessage("user", answerPrompt)), null)
|
|
|
|
111
|
+ ));
|
|
|
|
112
|
+ }
|
|
|
|
113
|
+
|
|
|
|
114
|
+ List<String> candidateAnswers = new ArrayList<>();
|
|
|
|
115
|
+ for (Future<String> future : answerFutures) {
|
|
|
|
116
|
+ try {
|
|
|
|
117
|
+ String answer = future.get(15, TimeUnit.SECONDS);
|
|
|
|
118
|
+ if (StringUtils.isNotBlank(answer)) candidateAnswers.add(answer);
|
|
|
|
119
|
+ } catch (Exception e) {
|
|
|
|
120
|
+ log.warn("候选答案生成失败", e);
|
|
|
|
121
|
+ }
|
|
|
|
122
|
+ }
|
|
|
|
123
|
+
|
|
|
|
124
|
+ String mergePrompt = buildMergePrompt(questionText, candidateAnswers);
|
|
|
|
125
|
+ List<ChatMessage> mergeMessages = new ArrayList<>();
|
|
|
|
126
|
+
|
|
|
|
127
|
+ // 从 Redis 读取历史上下文
|
|
|
|
128
|
+ if (StringUtils.isNotBlank(sessionId)) {
|
|
|
|
129
|
+ Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
|
|
|
|
130
|
+ if (cached instanceof List) {
|
|
|
|
131
|
+ //noinspection unchecked
|
|
|
|
132
|
+ mergeMessages.addAll((List<ChatMessage>) cached);
|
|
|
|
133
|
+ }
|
|
|
|
134
|
+ }
|
|
|
|
135
|
+ mergeMessages.add(new UserMessage("user", mergePrompt));
|
|
|
|
136
|
+
|
|
|
|
137
|
+ StringBuilder answerBuilder = new StringBuilder();
|
|
|
|
138
|
+
|
|
|
|
139
|
+ Map<String, Object> firstMatch = maps.get(0);
|
|
|
|
140
|
+ String storedFileName = extractFieldFromMetadata(firstMatch.get("metadata"), "storedFileName");
|
|
|
|
141
|
+ String docName = extractFieldFromMetadata(firstMatch.get("metadata"), "docName");
|
|
|
|
142
|
+ String similarityScore = String.valueOf(firstMatch.get("score"));
|
|
|
|
143
|
+
|
|
|
|
144
|
+ TokenStream tokenStream = aiChatHandler.chat(modelId, mergeMessages);
|
|
|
|
145
|
+
|
|
|
|
146
|
+ tokenStream.onNext(token -> {
|
|
|
|
147
|
+ try {
|
|
|
|
148
|
+ answerBuilder.append(token);
|
|
|
|
149
|
+ Map<String, String> data = new HashMap<>();
|
|
|
|
150
|
+ data.put("token", token);
|
|
|
|
151
|
+ emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
|
|
|
|
152
|
+ } catch (Exception e) {
|
|
|
|
153
|
+ log.error("发送 token 失败", e);
|
|
|
|
154
|
+ }
|
|
|
|
155
|
+ });
|
|
|
|
156
|
+
|
|
|
|
157
|
+ tokenStream.onComplete(response -> {
|
|
|
|
158
|
+ try {
|
|
|
|
159
|
+ Map<String, String> endData = new HashMap<>();
|
|
|
|
160
|
+ endData.put("event", "END");
|
|
|
|
161
|
+ endData.put("similarity", similarityScore);
|
|
|
|
162
|
+ endData.put("fileName", docName);
|
|
|
|
163
|
+ if (StringUtils.isNotBlank(storedFileName)) {
|
|
|
|
164
|
+ endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + storedFileName));
|
|
|
|
165
|
+ }
|
|
|
|
166
|
+ emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
|
|
|
|
167
|
+
|
|
|
|
168
|
+ logRecord.setAnswer(answerBuilder.toString()).setAnswerType(2);
|
|
|
|
169
|
+ airagLogService.save(logRecord);
|
|
|
|
170
|
+
|
|
|
|
171
|
+ // 保存更新上下文到 Redis,截断最近10条
|
|
|
|
172
|
+ if (StringUtils.isNotBlank(sessionId)) {
|
|
|
|
173
|
+ Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
|
|
|
|
174
|
+ List<ChatMessage> context;
|
|
|
|
175
|
+ if (cached instanceof List) {
|
|
|
|
176
|
+ //noinspection unchecked
|
|
|
|
177
|
+ context = new ArrayList<>((List<ChatMessage>) cached);
|
|
|
|
178
|
+ } else {
|
|
|
|
179
|
+ context = new ArrayList<>();
|
|
|
|
180
|
+ }
|
|
|
|
181
|
+ context.add(new UserMessage("user", questionText));
|
|
|
|
182
|
+ context.add(new UserMessage("assistant", answerBuilder.toString()));
|
|
|
|
183
|
+ if (context.size() > MAX_CONTEXT_SIZE) {
|
|
|
|
184
|
+ context = context.subList(context.size() - MAX_CONTEXT_SIZE, context.size());
|
|
|
|
185
|
+ }
|
|
|
|
186
|
+ redisTemplate.opsForValue().set(redisKey(sessionId), context, CONTEXT_TTL_MILLIS, TimeUnit.MILLISECONDS);
|
|
|
|
187
|
+ }
|
|
|
|
188
|
+
|
|
|
|
189
|
+ emitter.complete();
|
|
|
|
190
|
+ } catch (Exception e) {
|
|
|
|
191
|
+ emitter.completeWithError(e);
|
|
|
|
192
|
+ }
|
|
|
|
193
|
+ });
|
|
|
|
194
|
+
|
|
|
|
195
|
+ tokenStream.onError(error -> {
|
|
|
|
196
|
+ log.error("生成答案失败", error);
|
|
|
|
197
|
+ logRecord.setAnswer("生成答案失败: " + error.getMessage()).setAnswerType(4);
|
|
|
|
198
|
+ airagLogService.save(logRecord);
|
|
|
|
199
|
+ emitter.completeWithError(error);
|
|
|
|
200
|
+ });
|
|
|
|
201
|
+
|
|
|
|
202
|
+ tokenStream.start();
|
|
|
|
203
|
+
|
|
|
|
204
|
+ } catch (Exception e) {
|
|
|
|
205
|
+ log.error("多阶段处理异常", e);
|
|
|
|
206
|
+ logRecord.setAnswer("处理异常: " + e.getMessage()).setAnswerType(4);
|
|
|
|
207
|
+ airagLogService.save(logRecord);
|
|
|
|
208
|
+ emitter.completeWithError(e);
|
|
|
|
209
|
+ }
|
|
|
|
210
|
+ });
|
|
|
|
211
|
+
|
|
|
|
212
|
+ return emitter;
|
|
|
|
213
|
+ }
|
|
|
|
214
|
+
|
|
|
|
215
|
+ private void sendSimpleMessage(SseEmitter emitter, String message) throws Exception {
|
|
|
|
216
|
+ Map<String, String> data = new HashMap<>();
|
|
|
|
217
|
+ data.put("token", message);
|
|
|
|
218
|
+ emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
|
|
|
|
219
|
+ }
|
|
|
|
220
|
+
|
|
|
|
221
|
+ private String extractFieldFromMetadata(Object metadataObj, String key) throws Exception {
|
|
|
|
222
|
+ if (metadataObj == null) return "";
|
|
|
|
223
|
+ ObjectMapper objectMapper = new ObjectMapper();
|
|
|
|
224
|
+ Map<String, String> metadata = objectMapper.readValue(metadataObj.toString(), Map.class);
|
|
|
|
225
|
+ if (metadata.containsKey(key)) {
|
|
|
|
226
|
+ return metadata.get(key);
|
|
|
|
227
|
+ }
|
|
|
|
228
|
+ return "";
|
|
|
|
229
|
+ }
|
|
|
|
230
|
+
|
|
|
|
231
|
+ private String buildSummaryPrompt(String question, String content) {
|
|
|
|
232
|
+ return "你是一个信息摘要助手,请只针对以下内容进行摘要,严格不添加其他产品信息或无关内容:\n\n" +
|
|
|
|
233
|
+ "用户问题:" + question + "\n" +
|
|
|
|
234
|
+ "内容段落:\n" + content + "\n\n" +
|
|
|
|
235
|
+ "请提取与问题直接相关且仅限于该内容的关键信息,控制在200字以内。";
|
|
|
|
236
|
+ }
|
|
|
|
237
|
+
|
|
|
|
238
|
+ private String buildAnswerPrompt(String question, String summary) {
|
|
|
|
239
|
+ return "你是一个信息回答助手,请严格根据以下摘要内容回答用户问题。\n\n" +
|
|
|
|
240
|
+ "用户问题:" + question + "\n" +
|
|
|
|
241
|
+ "摘要内容:\n" + summary + "\n\n" +
|
|
|
|
242
|
+ "回答要求:\n- 回答必须以‘回答:’开头\n- 严格禁止添加摘要外的信息\n- 只能使用摘要中提及的内容\n- 禁止合并其他摘要的内容。";
|
|
|
|
243
|
+ }
|
|
|
|
244
|
+ private String buildMergePrompt(String question, List<String> answers) {
|
|
|
|
245
|
+ StringBuilder sb = new StringBuilder("你收到多个候选答案,请从中选择最准确且不交叉混淆产品信息的答案作为最终回答。\n\n");
|
|
|
|
246
|
+ sb.append("用户问题:").append(question).append("\n");
|
|
|
|
247
|
+ for (int i = 0; i < answers.size(); i++) {
|
|
|
|
248
|
+ sb.append("候选答案").append(i + 1).append(":\n").append(answers.get(i)).append("\n\n");
|
|
|
|
249
|
+ }
|
|
|
|
250
|
+ sb.append("请直接输出最佳答案,**禁止添加新信息**或跨产品混合。");
|
|
|
|
251
|
+ return sb.toString();
|
|
|
|
252
|
+ }
|
|
|
|
253
|
+
|
|
|
|
254
|
+} |