作者 lixiang

修改为双数据源

@@ -17,14 +17,12 @@ import org.jeecg.modules.airag.app.utils.JsonUtils; @@ -17,14 +17,12 @@ import org.jeecg.modules.airag.app.utils.JsonUtils;
17 import org.jeecg.modules.airag.llm.entity.AiragKnowledge; 17 import org.jeecg.modules.airag.llm.entity.AiragKnowledge;
18 import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService; 18 import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService;
19 import org.springframework.beans.factory.annotation.Autowired; 19 import org.springframework.beans.factory.annotation.Autowired;
  20 +import org.springframework.transaction.annotation.Propagation;
20 import org.springframework.transaction.annotation.Transactional; 21 import org.springframework.transaction.annotation.Transactional;
21 import org.springframework.web.bind.annotation.*; 22 import org.springframework.web.bind.annotation.*;
22 import org.springframework.web.multipart.MultipartFile; 23 import org.springframework.web.multipart.MultipartFile;
23 24
24 -import java.util.Arrays;  
25 -import java.util.HashMap;  
26 -import java.util.LinkedHashMap;  
27 -import java.util.Map; 25 +import java.util.*;
28 import java.util.stream.Collectors; 26 import java.util.stream.Collectors;
29 27
30 @RestController 28 @RestController
@@ -44,13 +42,10 @@ public class QuestionEmbeddingController { @@ -44,13 +42,10 @@ public class QuestionEmbeddingController {
44 .collect(Collectors.toMap(AiragKnowledge::getId, AiragKnowledge::getName)); 42 .collect(Collectors.toMap(AiragKnowledge::getId, AiragKnowledge::getName));
45 43
46 page.getRecords().forEach(item -> { 44 page.getRecords().forEach(item -> {
47 - String metadata = item.getMetadata();  
48 - if (StringUtils.isNotBlank(metadata)) {  
49 - Map<String, String> jsonMap = JsonUtils.jsonUtils(metadata); 45 + Map<String, Object> jsonMap = item.getMetadata();
50 if (jsonMap.containsKey("knowledgeId")) { 46 if (jsonMap.containsKey("knowledgeId")) {
51 item.setKnowledgeName(airagKnowledgeMap.get(jsonMap.get("knowledgeId"))); 47 item.setKnowledgeName(airagKnowledgeMap.get(jsonMap.get("knowledgeId")));
52 - item.setKnowledgeId(jsonMap.get("knowledgeId"));  
53 - } 48 + item.setKnowledgeId(jsonMap.get("knowledgeId").toString());
54 } 49 }
55 50
56 }); 51 });
@@ -86,12 +81,9 @@ public class QuestionEmbeddingController { @@ -86,12 +81,9 @@ public class QuestionEmbeddingController {
86 String docId = String.valueOf(snowflakeGenerator.next()); 81 String docId = String.valueOf(snowflakeGenerator.next());
87 metadata.put("docId", docId); // 自动生成唯一文档ID 82 metadata.put("docId", docId); // 自动生成唯一文档ID
88 metadata.put("knowledgeId", record.getKnowledgeId()); 83 metadata.put("knowledgeId", record.getKnowledgeId());
89 - // 使用 Jackson 序列化 Map 到 JSON  
90 - ObjectMapper mapper = new ObjectMapper();  
91 - String metadataJson = mapper.writeValueAsString(metadata);  
92 - // 2. 设置到embeddings对象  
93 - record.setMetadata(metadataJson);  
94 84
  85 + record.setMetadata(metadata);
  86 + record.setId(UUID.randomUUID().toString());
95 int result = questionEmbeddingService.insert(record); 87 int result = questionEmbeddingService.insert(record);
96 return result > 0 ? Result.OK("添加成功!") : Result.error("添加失败"); 88 return result > 0 ? Result.OK("添加成功!") : Result.error("添加失败");
97 } 89 }
@@ -112,14 +104,10 @@ public class QuestionEmbeddingController { @@ -112,14 +104,10 @@ public class QuestionEmbeddingController {
112 String knowledgeName = airagKnowledgeMap.get(record.getKnowledgeId()); 104 String knowledgeName = airagKnowledgeMap.get(record.getKnowledgeId());
113 record.setKnowledgeName(knowledgeName); 105 record.setKnowledgeName(knowledgeName);
114 106
115 - String existMetadata = existRecord.getMetadata();  
116 - Map<String, String> jsonMap = new HashMap<>();  
117 - if (StringUtils.isNotBlank(existMetadata)) {  
118 - jsonMap = JsonUtils.jsonUtils(existMetadata);  
119 - } 107 + Map<String, Object> metadata = existRecord.getMetadata();
120 108
121 - jsonMap.put("knowledgeId", record.getKnowledgeId());  
122 - record.setMetadata(Json.toJson(jsonMap)); 109 + metadata.put("knowledgeId", record.getKnowledgeId());
  110 + record.setMetadata(metadata);
123 } 111 }
124 int result = questionEmbeddingService.update(record); 112 int result = questionEmbeddingService.update(record);
125 return result > 0 ? Result.OK("编辑成功!") : Result.error("编辑失败"); 113 return result > 0 ? Result.OK("编辑成功!") : Result.error("编辑失败");
@@ -144,7 +132,6 @@ public class QuestionEmbeddingController { @@ -144,7 +132,6 @@ public class QuestionEmbeddingController {
144 } 132 }
145 133
146 @PostMapping("/uploadZip") 134 @PostMapping("/uploadZip")
147 - @Transactional(rollbackFor = {Exception.class})  
148 public Result<?> uploadZip( 135 public Result<?> uploadZip(
149 @RequestParam("file") MultipartFile file, 136 @RequestParam("file") MultipartFile file,
150 @RequestParam("knowledgeId") String knowledgeId) { 137 @RequestParam("knowledgeId") String knowledgeId) {
@@ -30,7 +30,7 @@ public class QuestionEmbedding { @@ -30,7 +30,7 @@ public class QuestionEmbedding {
30 /** 30 /**
31 * 元数据 31 * 元数据
32 */ 32 */
33 - private String metadata; 33 + private Map<String, Object> metadata;
34 /** 34 /**
35 * 向量 35 * 向量
36 */ 36 */
1 package org.jeecg.modules.airag.app.mapper; 1 package org.jeecg.modules.airag.app.mapper;
2 2
3 -import cn.hutool.core.lang.generator.SnowflakeGenerator;  
4 -import com.alibaba.fastjson2.JSONObject; 3 +import com.baomidou.dynamic.datasource.annotation.DS;
  4 +import com.baomidou.mybatisplus.core.metadata.IPage;
5 import com.baomidou.mybatisplus.extension.plugins.pagination.Page; 5 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
6 -import com.fasterxml.jackson.core.JsonProcessingException;  
7 -import com.fasterxml.jackson.core.type.TypeReference;  
8 -import com.fasterxml.jackson.databind.ObjectMapper;  
9 -import com.pgvector.PGvector;  
10 -import dev.langchain4j.data.embedding.Embedding;  
11 -import dev.langchain4j.model.output.Response;  
12 -import io.minio.messages.Metadata;  
13 -import lombok.extern.slf4j.Slf4j;  
14 -import org.apache.commons.lang3.StringUtils;  
15 -import org.jeecg.modules.airag.app.entity.Embeddings; 6 +import org.apache.ibatis.annotations.Mapper;
  7 +import org.apache.ibatis.annotations.Param;
16 import org.jeecg.modules.airag.app.entity.QuestionEmbedding; 8 import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
17 -import org.jeecg.modules.airag.app.utils.AiModelUtils;  
18 -import org.postgresql.util.PGobject;  
19 -import org.springframework.beans.factory.annotation.Autowired;  
20 -import org.springframework.beans.factory.annotation.Value;  
21 -import org.springframework.stereotype.Component;  
22 9
23 -import java.sql.*;  
24 -import java.util.*;  
25 -import java.util.stream.Collectors; 10 +import java.util.List;
26 11
27 -@Component  
28 -@Slf4j  
29 -public class QuestionEmbeddingMapper { 12 +@Mapper
  13 +@DS("pgvector")
  14 +public interface QuestionEmbeddingMapper {
  15 + Page<QuestionEmbedding> findAll(IPage<QuestionEmbedding> page, @Param("questionEmbedding") QuestionEmbedding questionEmbedding);
30 16
31 - @Autowired  
32 - private AiModelUtils aiModelUtils; 17 + Integer findQuestionCount(@Param("questionEmbedding") QuestionEmbedding questionEmbedding);
33 18
34 - @Value("${jeecg.ai-chat.embedId}")  
35 - private String embedId;  
36 - // PostgreSQL连接参数(应与项目配置一致)  
37 - private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres";  
38 - private static final String USER = "postgres";  
39 - private static final String PASSWORD = "postgres"; 19 + QuestionEmbedding findById(@Param("id") String id);
  20 + @DS("pgvector")
  21 + int insert(@Param("record") QuestionEmbedding record);
40 22
41 - // 获取数据库连接  
42 - private Connection getConnection() throws SQLException {  
43 - return DriverManager.getConnection(URL, USER, PASSWORD);  
44 - } 23 + int update(@Param("record") QuestionEmbedding record);
45 24
46 - // 查询所有记录  
47 - public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, int pageNo, int pageSize) {  
48 - List<QuestionEmbedding> results = new ArrayList<>();  
49 - StringBuilder sql = new StringBuilder("select * from question_embedding where 1 = 1");  
50 - StringBuilder countSql = new StringBuilder("select count(1) from question_embedding where 1 = 1");  
51 - List<Object> params = new ArrayList<>();  
52 - List<Object> countParams = new ArrayList<>(); 25 + int deleteById(@Param("id") String id);
53 26
54 - if (StringUtils.isNotBlank(questionEmbedding.getKnowledgeId())) {  
55 - sql.append(" AND metadata ->> 'knowledgeId' = ?");  
56 - countSql.append(" AND metadata ->> 'knowledgeId' = ?");  
57 - params.add(questionEmbedding.getKnowledgeId());  
58 - countParams.add(questionEmbedding.getKnowledgeId());  
59 - }  
60 - if(StringUtils.isNotBlank(questionEmbedding.getQuestion())){  
61 - sql.append(" AND question ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配  
62 - countSql.append(" AND question ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配  
63 - params.add("%" + questionEmbedding.getQuestion() + "%");  
64 - countParams.add("%" + questionEmbedding.getQuestion() + "%");  
65 - } 27 + int deleteByIds(@Param("ids") List<String> ids);
66 28
67 - if(StringUtils.isNotBlank(questionEmbedding.getAnswer())){  
68 - sql.append(" AND answer ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配  
69 - countSql.append(" AND answer ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配  
70 - params.add("%" + questionEmbedding.getAnswer() + "%");  
71 - countParams.add("%" + questionEmbedding.getAnswer() + "%");  
72 - }  
73 -  
74 - sql.append(" ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, question ASC");  
75 -  
76 - // 添加分页  
77 - sql.append(" LIMIT ? OFFSET ?");  
78 - params.add(pageSize);  
79 - params.add((pageNo - 1) * pageSize);  
80 -  
81 -  
82 - try(Connection conn = getConnection();  
83 - PreparedStatement stmt = conn.prepareStatement(sql.toString())){  
84 - // 设置参数值  
85 - for (int i = 0; i < params.size(); i++) {  
86 - stmt.setObject(i + 1, params.get(i));  
87 - }  
88 -  
89 - try (ResultSet rs = stmt.executeQuery()) {  
90 - while (rs.next()) {  
91 - results.add(mapRowToQuestionEmbedding(rs));  
92 - }  
93 - }  
94 - } catch (SQLException e) {  
95 - log.error("查询所有记录失败", e);  
96 - throw new RuntimeException("查询数据时发生数据库错误", e);  
97 - }  
98 -  
99 - // 执行计数查询  
100 - long total = 0;  
101 - try(Connection conn = getConnection();  
102 - PreparedStatement stmt = conn.prepareStatement(countSql.toString())){  
103 - // 设置参数值  
104 - for (int i = 0; i < countParams.size(); i++) {  
105 - stmt.setObject(i + 1, countParams.get(i));  
106 - }  
107 -  
108 - try (ResultSet rs = stmt.executeQuery()) {  
109 - if (rs.next()) {  
110 - total = rs.getLong(1); // 直接获取count值  
111 - }  
112 - }  
113 - } catch (SQLException e) {  
114 - log.error("查询记录总数失败", e);  
115 - throw new RuntimeException("查询记录总数时发生数据库错误", e);  
116 - }  
117 -  
118 - Page<QuestionEmbedding> page = new Page<>();  
119 - page.setRecords(results);  
120 - page.setTotal(total);  
121 - return page;  
122 - }  
123 -  
124 - // 查询所有记录  
125 - public Integer findQuestionCount(QuestionEmbedding questionEmbedding) {  
126 -  
127 - StringBuilder sql = new StringBuilder("select COUNT(1) AS total_count from question_embedding where 1 = 1");  
128 - List<Object> params = new ArrayList<>();  
129 -  
130 - if(StringUtils.isNotBlank(questionEmbedding.getQuestion())){  
131 - sql.append(" AND question = ?");  
132 - params.add(questionEmbedding.getQuestion());  
133 - }  
134 -  
135 -  
136 - try(Connection conn = getConnection();  
137 - PreparedStatement stmt = conn.prepareStatement(sql.toString())){  
138 - // 设置参数值  
139 - for (int i = 0; i < params.size(); i++) {  
140 - stmt.setObject(i + 1, params.get(i));  
141 - }  
142 -  
143 - try (ResultSet rs = stmt.executeQuery()) {  
144 - while (rs.next()) {  
145 - return rs.getInt("total_count");  
146 - }  
147 - return 0;  
148 - }  
149 - } catch (SQLException e) {  
150 - log.error("查询所有记录失败", e);  
151 - throw new RuntimeException("查询数据时发生数据库错误", e);  
152 - }  
153 -  
154 - }  
155 -  
156 - // 根据ID查询单个记录  
157 - public QuestionEmbedding findById(String id) {  
158 - String sql = "SELECT * FROM question_embedding WHERE id = ?";  
159 -  
160 - try (Connection conn = getConnection();  
161 - PreparedStatement stmt = conn.prepareStatement(sql)) {  
162 -  
163 - stmt.setString(1, id);  
164 - try (ResultSet rs = stmt.executeQuery()) {  
165 - if (rs.next()) {  
166 - return mapRowToQuestionEmbedding(rs);  
167 - }  
168 - }  
169 - } catch (SQLException e) {  
170 - log.error("根据ID查询记录失败, ID: {}", id, e);  
171 - throw new RuntimeException("根据ID查询时发生数据库错误", e);  
172 - }  
173 - return null;  
174 - }  
175 -  
176 - // 插入新记录  
177 - public int insert(QuestionEmbedding record) {  
178 - String sql = "INSERT INTO question_embedding (id, text, question, answer, metadata,embedding) VALUES (?, ?, ?, ?, ?::jsonb,?)";  
179 -  
180 -  
181 - try (Connection conn = getConnection();  
182 - PreparedStatement stmt = conn.prepareStatement(sql)) {  
183 - stmt.setString(1, UUID.randomUUID().toString());  
184 - stmt.setString(2, record.getText());  
185 - stmt.setString(3, record.getQuestion());  
186 - stmt.setString(4, record.getAnswer());  
187 - PGobject jsonObject = new PGobject();  
188 - jsonObject.setType("json");  
189 - jsonObject.setValue(record.getMetadata());  
190 - stmt.setObject(5, jsonObject);  
191 - Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion());  
192 - stmt.setObject(6, embedding.content().vector());  
193 - return stmt.executeUpdate();  
194 - } catch (SQLException e) {  
195 - log.error("插入记录失败: {}", record, e);  
196 - throw new RuntimeException("插入数据时发生数据库错误", e);  
197 - }  
198 - }  
199 -  
200 - // 更新记录  
201 - public int update(QuestionEmbedding record) {  
202 - String sql = "UPDATE question_embedding SET text = ?, question = ?, answer = ?, metadata = ?::jsonb ,embedding = ? WHERE id = ?";  
203 -  
204 - try (Connection conn = getConnection();  
205 - PreparedStatement stmt = conn.prepareStatement(sql)) {  
206 -  
207 -  
208 - stmt.setString(1, record.getText());  
209 - stmt.setString(2, record.getQuestion());  
210 - stmt.setString(3, record.getAnswer());  
211 - stmt.setObject(4, record.getMetadata());  
212 -  
213 - Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion());  
214 - stmt.setObject(5, embedding.content().vector());  
215 -  
216 - stmt.setString(6, record.getId());  
217 -  
218 - return stmt.executeUpdate();  
219 - } catch (SQLException e) {  
220 - log.error("更新记录失败: {}", record, e);  
221 - throw new RuntimeException("更新数据时发生数据库错误", e);  
222 - }  
223 - }  
224 -  
225 -  
226 - // 批量删除方法  
227 - public int deleteByIds(List<String> ids) {  
228 - if (ids == null || ids.isEmpty()) {  
229 - return 0;  
230 - }  
231 -  
232 - String sql = "DELETE FROM question_embedding WHERE id IN (";  
233 - StringBuilder placeholders = new StringBuilder();  
234 - for (int i = 0; i < ids.size(); i++) {  
235 - placeholders.append("?");  
236 - if (i < ids.size() - 1) {  
237 - placeholders.append(",");  
238 - }  
239 - }  
240 - sql += placeholders.toString() + ")";  
241 -  
242 - try (Connection conn = getConnection();  
243 - PreparedStatement stmt = conn.prepareStatement(sql)) {  
244 -  
245 - for (int i = 0; i < ids.size(); i++) {  
246 - stmt.setString(i + 1, ids.get(i));  
247 - }  
248 -  
249 - return stmt.executeUpdate();  
250 - } catch (SQLException e) {  
251 - log.error("批量删除向量记录失败, IDs: {}", ids, e);  
252 - throw new RuntimeException("批量删除向量数据时发生数据库错误", e);  
253 - }  
254 - }  
255 -  
256 -  
257 - /**  
258 - * 向量相似度查询 (基于问题文本的向量)  
259 - * @param question 问题文本  
260 - * @param limit 返回结果数量  
261 - * @param minSimilarity 最小相似度阈值(0-1)  
262 - * @return 相似问答列表(按相似度降序)  
263 - */  
264 - public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {  
265 - List<QuestionEmbedding> results = new ArrayList<>();  
266 -  
267 - // 1. 参数校验  
268 - if (minSimilarity < 0 || minSimilarity > 1) {  
269 - throw new IllegalArgumentException("相似度阈值必须在0到1之间");  
270 - }  
271 -  
272 - // 2. 获取问题的嵌入向量  
273 - Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, question);  
274 - float[] queryVector = embeddingResponse.content().vector();  
275 - // 3. 计算最大允许距离(1 - 相似度阈值)  
276 - double maxDistance = 1 - minSimilarity;  
277 -  
278 - // 4. 执行向量相似度查询  
279 - String sql = "SELECT *, embedding <-> ? AS distance " +  
280 - "FROM question_embedding " +  
281 - "WHERE embedding <-> ? < ? " + // 距离小于阈值  
282 - "ORDER BY distance ASC " + // 按距离升序  
283 - "LIMIT ?";  
284 -  
285 - try (Connection conn = getConnection();  
286 - PreparedStatement stmt = conn.prepareStatement(sql)) {  
287 -  
288 - // 设置参数  
289 - PGvector vector = new PGvector(queryVector);  
290 - stmt.setObject(1, vector);  
291 - stmt.setObject(2, vector);  
292 - stmt.setDouble(3, maxDistance);  
293 - stmt.setInt(4, limit);  
294 -  
295 - try (ResultSet rs = stmt.executeQuery()) {  
296 - while (rs.next()) {  
297 - QuestionEmbedding record = mapRowToQuestionEmbedding(rs);  
298 - // 计算相似度(1 - 距离)  
299 - double distance = rs.getDouble("distance");  
300 - double similarity = 1 - distance;  
301 - record.setSimilarity(similarity);  
302 - results.add(record);  
303 - }  
304 - }  
305 - } catch (SQLException e) {  
306 - log.error("向量相似度查询失败", e);  
307 - throw new RuntimeException("执行向量相似度查询时发生数据库错误", e);  
308 - }  
309 - return results;  
310 - }  
311 -  
312 - /**  
313 - * 向量相似度查询 (直接使用向量)  
314 - * @param vector 查询向量  
315 - * @param limit 返回结果数量  
316 - * @return 相似问答列表(按相似度降序)  
317 - */  
318 - public List<QuestionEmbedding> similaritySearch(float[] vector, int limit) {  
319 - List<QuestionEmbedding> results = new ArrayList<>();  
320 - String sql = "SELECT *, embedding <-> ? AS similarity " +  
321 - "FROM question_embedding " +  
322 - "ORDER BY similarity ASC " +  
323 - "LIMIT ?";  
324 -  
325 - try (Connection conn = getConnection();  
326 - PreparedStatement stmt = conn.prepareStatement(sql)) {  
327 -  
328 - stmt.setObject(1, new PGvector(vector));  
329 - stmt.setInt(2, limit);  
330 -  
331 - try (ResultSet rs = stmt.executeQuery()) {  
332 - while (rs.next()) {  
333 - QuestionEmbedding record = mapRowToQuestionEmbedding(rs);  
334 - double similarity = 1 - rs.getDouble("similarity");  
335 - record.setSimilarity(similarity);  
336 - results.add(record);  
337 - }  
338 - }  
339 - } catch (SQLException e) {  
340 - log.error("向量相似度查询失败", e);  
341 - throw new RuntimeException("执行向量相似度查询时发生数据库错误", e);  
342 - }  
343 - return results;  
344 - }  
345 -  
346 - // 根据ID删除记录  
347 - public int deleteById(String id) {  
348 - String sql = "DELETE FROM question_embedding WHERE id = ?";  
349 -  
350 - try (Connection conn = getConnection();  
351 - PreparedStatement stmt = conn.prepareStatement(sql)) {  
352 -  
353 - stmt.setString(1, id);  
354 - return stmt.executeUpdate();  
355 - } catch (SQLException e) {  
356 - log.error("删除记录失败, ID: {}", id, e);  
357 - throw new RuntimeException("删除数据时发生数据库错误", e);  
358 - }  
359 - }  
360 -  
361 - // 将ResultSet行映射为QuestionEmbedding对象  
362 - private QuestionEmbedding mapRowToQuestionEmbedding(ResultSet rs) throws SQLException {  
363 - QuestionEmbedding record = new QuestionEmbedding();  
364 - record.setId(rs.getString("id"));  
365 - record.setText(rs.getString("text"));  
366 - record.setQuestion(rs.getString("question"));  
367 - record.setAnswer(rs.getString("answer"));  
368 -  
369 - String metadataJson = rs.getString("metadata");  
370 - if (StringUtils.isNotBlank(metadataJson)) {  
371 - record.setMetadata(metadataJson);  
372 - }  
373 -  
374 - return record;  
375 - }  
376 -  
377 - // 将Map转换为JSON字符串  
378 - private String toJson(Map<String, Object> map) {  
379 - try {  
380 - return new ObjectMapper().writeValueAsString(map);  
381 - } catch (JsonProcessingException e) {  
382 - log.error("元数据转换为JSON失败", e);  
383 - return "{}";  
384 - }  
385 - }  
386 -  
387 - // 将JSON字符串转换为Map  
388 - private Map<String, Object> fromJson(String json) {  
389 - try {  
390 - return new ObjectMapper().readValue(json, new TypeReference<Map<String, Object>>() {});  
391 - } catch (JsonProcessingException e) {  
392 - log.error("JSON转换为元数据失败", e);  
393 - return Collections.emptyMap();  
394 - }  
395 - } 29 + List<QuestionEmbedding> similaritySearchByQuestion(@Param("vector") float[] vector,
  30 + @Param("limit") int limit,
  31 + @Param("minSimilarity") Double minSimilarity);
396 32
  33 + List<QuestionEmbedding> similaritySearch(@Param("vector") float[] vector,
  34 + @Param("limit") int limit);
397 } 35 }
  1 +<?xml version="1.0" encoding="UTF-8"?>
  2 +<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
  3 + "http://mybatis.org/dtd/mybatis-3-mapper.dtd">
  4 +<mapper namespace="org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper">
  5 +
  6 + <resultMap id="questionEmbeddingResultMap" type="org.jeecg.modules.airag.app.entity.QuestionEmbedding">
  7 + <id column="id" property="id" />
  8 + <result column="text" property="text" />
  9 + <result column="question" property="question" />
  10 + <result column="answer" property="answer" />
  11 + <result column="metadata" property="metadata" typeHandler="org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler"/>
  12 + <result column="similarity" property="similarity" />
  13 + </resultMap>
  14 +
  15 + <select id="findAll" resultMap="questionEmbeddingResultMap">
  16 + SELECT * FROM question_embedding WHERE 1 = 1
  17 + <if test="questionEmbedding.knowledgeId != null and questionEmbedding.knowledgeId != ''">
  18 + AND metadata ->> 'knowledgeId' = #{questionEmbedding.knowledgeId}
  19 + </if>
  20 + <if test="questionEmbedding.question != null and questionEmbedding.question != ''">
  21 + AND question ILIKE CONCAT('%', #{questionEmbedding.question}, '%')
  22 + </if>
  23 + <if test="questionEmbedding.answer != null and questionEmbedding.answer != ''">
  24 + AND answer ILIKE CONCAT('%', #{questionEmbedding.answer}, '%')
  25 + </if>
  26 + ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, question ASC
  27 + </select>
  28 +
  29 + <select id="findQuestionCount" resultType="int">
  30 + SELECT COUNT(1) AS total_count FROM question_embedding WHERE 1 = 1
  31 + <if test="questionEmbedding.question != null and questionEmbedding.question != ''">
  32 + AND question = #{questionEmbedding.question}
  33 + </if>
  34 + </select>
  35 +
  36 + <select id="findById" resultMap="questionEmbeddingResultMap">
  37 + SELECT * FROM question_embedding WHERE id = #{id}
  38 + </select>
  39 +
  40 + <insert id="insert" parameterType="org.jeecg.modules.airag.app.entity.QuestionEmbedding">
  41 + INSERT INTO question_embedding (id, text, question, answer, metadata, embedding)
  42 + VALUES (
  43 + #{record.id, jdbcType=VARCHAR},
  44 + #{record.text, jdbcType=VARCHAR},
  45 + #{record.question, jdbcType=VARCHAR},
  46 + #{record.answer, jdbcType=VARCHAR},
  47 + #{record.metadata, jdbcType=OTHER, typeHandler=org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler}::jsonb,
  48 + #{record.embedding, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}
  49 + )
  50 + </insert>
  51 +
  52 + <update id="update" parameterType="org.jeecg.modules.airag.app.entity.QuestionEmbedding">
  53 + UPDATE question_embedding
  54 + SET
  55 + text = #{record.text, jdbcType=VARCHAR},
  56 + question = #{record.question, jdbcType=VARCHAR},
  57 + answer = #{record.answer, jdbcType=VARCHAR},
  58 + metadata = #{record.metadata, jdbcType=OTHER, typeHandler=org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler}::jsonb,
  59 + embedding = #{record.embedding, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}
  60 + WHERE id = #{record.id}
  61 + </update>
  62 +
  63 + <delete id="deleteById">
  64 + DELETE FROM question_embedding WHERE id = #{id}
  65 + </delete>
  66 +
  67 + <delete id="deleteByIds">
  68 + DELETE FROM question_embedding WHERE id IN
  69 + <foreach collection="ids" item="id" open="(" separator="," close=")">
  70 + #{id}
  71 + </foreach>
  72 + </delete>
  73 +
  74 + <select id="similaritySearchByQuestion" resultMap="questionEmbeddingResultMap">
  75 + <![CDATA[
  76 + SELECT *,
  77 + (embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler})::float AS similarity
  78 + FROM question_embedding
  79 + WHERE (embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}) < #{minSimilarity}
  80 + ORDER BY similarity ASC
  81 + LIMIT #{limit}
  82 + ]]>
  83 + </select>
  84 +
  85 + <select id="similaritySearch" resultMap="questionEmbeddingResultMap">
  86 +<!-- SELECT *, embedding <-> #{vector} AS similarity-->
  87 +<!-- FROM question_embedding-->
  88 +<!-- ORDER BY similarity ASC-->
  89 +<!-- LIMIT #{limit}-->
  90 + </select>
  91 +</mapper>
@@ -50,8 +50,6 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i @@ -50,8 +50,6 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i
50 50
51 @Override 51 @Override
52 public void saveToQuestionLibrary(AiragLog log) throws JsonProcessingException { 52 public void saveToQuestionLibrary(AiragLog log) throws JsonProcessingException {
53 - // 这里实现将问题和回答存入问题库数据表的逻辑  
54 - // 假设问题库数据表的实体类为 QuestionLibrary,Mapper 接口为 QuestionLibraryMapper  
55 QuestionEmbedding questionEmbedding = new QuestionEmbedding(); 53 QuestionEmbedding questionEmbedding = new QuestionEmbedding();
56 questionEmbedding.setQuestion(log.getQuestion()); 54 questionEmbedding.setQuestion(log.getQuestion());
57 questionEmbedding.setAnswer(log.getAnswer()); 55 questionEmbedding.setAnswer(log.getAnswer());
@@ -62,11 +60,7 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i @@ -62,11 +60,7 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i
62 String docId = String.valueOf(snowflakeGenerator.next()); 60 String docId = String.valueOf(snowflakeGenerator.next());
63 metadata.put("docId", docId); 61 metadata.put("docId", docId);
64 metadata.put("knowledgeId", questionEmbedding.getKnowledgeId()); 62 metadata.put("knowledgeId", questionEmbedding.getKnowledgeId());
65 - // 使用 Jackson 序列化 Map 到 JSON  
66 - ObjectMapper mapper = new ObjectMapper();  
67 - String metadataJson = mapper.writeValueAsString(metadata);  
68 - // 2. 设置到embeddings对象  
69 - questionEmbedding.setMetadata(metadataJson); 63 + questionEmbedding.setMetadata(metadata);
70 questionEmbeddingMapper.insert(questionEmbedding); 64 questionEmbeddingMapper.insert(questionEmbedding);
71 airagLogMapper.updataIsStorage(log.getIsStorage(),log.getId()); 65 airagLogMapper.updataIsStorage(log.getIsStorage(),log.getId());
72 System.out.println("1"); 66 System.out.println("1");
1 package org.jeecg.modules.airag.app.service.impl; 1 package org.jeecg.modules.airag.app.service.impl;
2 2
  3 +import com.baomidou.dynamic.datasource.annotation.DS;
3 import com.baomidou.mybatisplus.extension.plugins.pagination.Page; 4 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
4 import com.fasterxml.jackson.core.JsonProcessingException; 5 import com.fasterxml.jackson.core.JsonProcessingException;
  6 +import org.apache.commons.lang3.StringUtils;
5 import org.apache.poi.hwpf.usermodel.CharacterRun; 7 import org.apache.poi.hwpf.usermodel.CharacterRun;
6 import org.apache.poi.hwpf.HWPFDocument; 8 import org.apache.poi.hwpf.HWPFDocument;
7 import org.apache.poi.hwpf.usermodel.Paragraph; 9 import org.apache.poi.hwpf.usermodel.Paragraph;
@@ -15,7 +17,10 @@ import dev.langchain4j.model.output.Response; @@ -15,7 +17,10 @@ import dev.langchain4j.model.output.Response;
15 import org.apache.commons.io.FilenameUtils; 17 import org.apache.commons.io.FilenameUtils;
16 import org.apache.poi.xwpf.usermodel.*; 18 import org.apache.poi.xwpf.usermodel.*;
17 import org.jeecg.common.api.vo.Result; 19 import org.jeecg.common.api.vo.Result;
  20 +import org.jeecg.modules.airag.app.entity.Embeddings;
18 import org.jeecg.modules.airag.app.entity.QuestionEmbedding; 21 import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
  22 +import org.jeecg.modules.airag.app.mapper.EmbeddingsMapper;
  23 +import org.jeecg.modules.airag.app.mapper.PgVectorMapper;
19 import org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper; 24 import org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper;
20 import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; 25 import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
21 import org.jeecg.modules.airag.app.utils.AiModelUtils; 26 import org.jeecg.modules.airag.app.utils.AiModelUtils;
@@ -60,7 +65,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -60,7 +65,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
60 private AiModelUtils aiModelUtils; 65 private AiModelUtils aiModelUtils;
61 66
62 @Autowired 67 @Autowired
63 - private IAIChatHandler aiChatHandler; 68 + private PgVectorMapper pgVectorMapper;
64 69
65 @Value("${jeecg.upload.path}") 70 @Value("${jeecg.upload.path}")
66 private String uploadPath; 71 private String uploadPath;
@@ -68,17 +73,13 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -68,17 +73,13 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
68 private String embedId; 73 private String embedId;
69 74
70 private static final Set<String> ALLOWED_EXTENSIONS = Set.of("txt", "doc", "docx"); 75 private static final Set<String> ALLOWED_EXTENSIONS = Set.of("txt", "doc", "docx");
71 - private static final Pattern SPECIAL_CHARS_PATTERN = Pattern.compile("[^a-zA-Z0-9\\u4e00-\\u9fa5\\s]");  
72 private static final Pattern UUID_PATTERN = Pattern.compile("_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"); 76 private static final Pattern UUID_PATTERN = Pattern.compile("_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}");
73 77
74 - // 数据库连接配置  
75 - private static final String DB_URL = "jdbc:postgresql://192.168.100.104:5432/postgres";  
76 - private static final String DB_USER = "postgres";  
77 - private static final String DB_PASSWORD = "postgres";  
78 78
79 @Override 79 @Override
80 public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, Integer pageNo, Integer pageSize) { 80 public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, Integer pageNo, Integer pageSize) {
81 - return questionEmbeddingMapper.findAll(questionEmbedding,pageNo,pageSize); 81 + Page<QuestionEmbedding> page = new Page<>(pageNo, pageSize);
  82 + return questionEmbeddingMapper.findAll(page,questionEmbedding);
82 } 83 }
83 84
84 @Override 85 @Override
@@ -93,11 +94,21 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -93,11 +94,21 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
93 94
94 @Override 95 @Override
95 public int insert(QuestionEmbedding record) { 96 public int insert(QuestionEmbedding record) {
  97 + if (StringUtils.isNotBlank(record.getQuestion())){
  98 + Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion());
  99 + record.setEmbedding(embedding.content().vector());
  100 + }
  101 +
  102 +
96 return questionEmbeddingMapper.insert(record); 103 return questionEmbeddingMapper.insert(record);
97 } 104 }
98 105
99 @Override 106 @Override
100 public int update(QuestionEmbedding record) { 107 public int update(QuestionEmbedding record) {
  108 + if (StringUtils.isNotBlank(record.getQuestion())){
  109 + Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion());
  110 + record.setEmbedding(embedding.content().vector());
  111 + }
101 return questionEmbeddingMapper.update(record); 112 return questionEmbeddingMapper.update(record);
102 } 113 }
103 114
@@ -113,7 +124,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -113,7 +124,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
113 124
114 @Override 125 @Override
115 public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { 126 public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {
116 - return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity); 127 + Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, question);
  128 + return questionEmbeddingMapper.similaritySearchByQuestion(embedding.content().vector(), limit, minSimilarity);
117 } 129 }
118 130
119 @Override 131 @Override
@@ -183,10 +195,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -183,10 +195,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
183 segments = splitWordDocument(targetPath.toString()); 195 segments = splitWordDocument(targetPath.toString());
184 } 196 }
185 197
186 - // 原有逻辑:保存到question_embedding表  
187 saveSegmentsToDatabase(segments, originalFileName, storedFileName, knowledgeId); 198 saveSegmentsToDatabase(segments, originalFileName, storedFileName, knowledgeId);
188 199
189 - // 新增逻辑:同时保存到embeddings表  
190 saveToEmbeddingsTable(segments, originalFileName, storedFileName, knowledgeId); 200 saveToEmbeddingsTable(segments, originalFileName, storedFileName, knowledgeId);
191 201
192 } 202 }
@@ -196,7 +206,6 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -196,7 +206,6 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
196 String displayFileName = removeUuidSuffix(originalFileName); 206 String displayFileName = removeUuidSuffix(originalFileName);
197 displayFileName = FilenameUtils.removeExtension(displayFileName); 207 displayFileName = FilenameUtils.removeExtension(displayFileName);
198 208
199 - try (Connection conn = getConnection()) {  
200 for (String segment : segments) { 209 for (String segment : segments) {
201 if (segment.trim().isEmpty()) continue; 210 if (segment.trim().isEmpty()) continue;
202 211
@@ -205,44 +214,29 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -205,44 +214,29 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
205 if (parts.length < 2) continue; 214 if (parts.length < 2) continue;
206 215
207 String titlePath = parts[0].trim(); 216 String titlePath = parts[0].trim();
208 - String answer = segment.trim(); // 整个回答段(含标题 + 内容) 217 + // 整个回答段(标题 + 内容)
  218 + String answer = segment.trim();
209 219
210 // 获取 embedding 220 // 获取 embedding
211 Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, answer); 221 Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, answer);
212 float[] embeddingVector = embeddingResponse.content().vector(); 222 float[] embeddingVector = embeddingResponse.content().vector();
213 223
214 - // 准备 metadata  
215 Map<String, Object> metadata = new HashMap<>(); 224 Map<String, Object> metadata = new HashMap<>();
216 metadata.put("docName", originalFileName); 225 metadata.put("docName", originalFileName);
217 metadata.put("storedFileName", storedFileName); 226 metadata.put("storedFileName", storedFileName);
218 metadata.put("knowledgeId", knowledgeId); 227 metadata.put("knowledgeId", knowledgeId);
219 metadata.put("title", displayFileName + ": " + titlePath); 228 metadata.put("title", displayFileName + ": " + titlePath);
220 -  
221 - // 插入  
222 - String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)";  
223 - try (PreparedStatement stmt = conn.prepareStatement(sql)) {  
224 - stmt.setString(1, UUID.randomUUID().toString());  
225 - stmt.setObject(2, new PGvector(embeddingVector));  
226 - stmt.setString(3, answer);  
227 -  
228 - PGobject jsonObject = new PGobject();  
229 - jsonObject.setType("json");  
230 - jsonObject.setValue(new ObjectMapper().writeValueAsString(metadata));  
231 - stmt.setObject(4, jsonObject);  
232 -  
233 - stmt.executeUpdate();  
234 - }  
235 - }  
236 - } catch (Exception e) {  
237 - log.error("保存分段到embeddings表失败", e); 229 + Embeddings embeddings = new Embeddings();
  230 + embeddings.setMetadata(metadata);
  231 + embeddings.setId(UUID.randomUUID().toString());
  232 + embeddings.setEmbedding(embeddingVector);
  233 + embeddings.setText(answer);
  234 + pgVectorMapper.insert(embeddings);
238 } 235 }
239 } 236 }
240 237
241 238
242 - // 获取数据库连接  
243 - private Connection getConnection() throws SQLException {  
244 - return DriverManager.getConnection(DB_URL, DB_USER, DB_PASSWORD);  
245 - } 239 +
246 240
247 private String generateStoredFileName(String originalFileName) { 241 private String generateStoredFileName(String originalFileName) {
248 String baseName = FilenameUtils.removeExtension(originalFileName); 242 String baseName = FilenameUtils.removeExtension(originalFileName);
@@ -359,6 +353,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -359,6 +353,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
359 353
360 return 0; 354 return 0;
361 } 355 }
  356 +
362 private void saveSegmentsToDatabase(List<String> segments, String originalFileName, String storedFileName, String knowledgeId) { 357 private void saveSegmentsToDatabase(List<String> segments, String originalFileName, String storedFileName, String knowledgeId) {
363 if (segments.isEmpty()) return; 358 if (segments.isEmpty()) return;
364 359
@@ -384,24 +379,22 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -384,24 +379,22 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
384 record.setAnswer(titleLine + "\n" + content); 379 record.setAnswer(titleLine + "\n" + content);
385 record.setText(""); 380 record.setText("");
386 381
387 - Map<String, String> metadata = new LinkedHashMap<>(); 382 + Map<String, Object> metadata = new LinkedHashMap<>();
388 metadata.put("docId", docId); 383 metadata.put("docId", docId);
389 metadata.put("docName", originalFileName); 384 metadata.put("docName", originalFileName);
390 metadata.put("storedFileName", storedFileName); 385 metadata.put("storedFileName", storedFileName);
391 metadata.put("knowledgeId", knowledgeId); 386 metadata.put("knowledgeId", knowledgeId);
392 387
393 - try {  
394 - record.setMetadata(new ObjectMapper().writeValueAsString(metadata));  
395 - } catch (JsonProcessingException e) {  
396 - log.error("生成metadata JSON失败", e);  
397 - } 388 +
  389 + record.setMetadata(metadata);
  390 +
398 391
399 log.info("保存分段: title={}, content_length={}", question, segment.length()); 392 log.info("保存分段: title={}, content_length={}", question, segment.length());
400 393
401 Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, record.getQuestion()); 394 Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, record.getQuestion());
402 record.setEmbedding(embeddingResponse.content().vector()); 395 record.setEmbedding(embeddingResponse.content().vector());
403 record.setKnowledgeId(knowledgeId); 396 record.setKnowledgeId(knowledgeId);
404 - questionEmbeddingMapper.insert(record); 397 + insert(record);
405 } 398 }
406 } 399 }
407 400
@@ -133,7 +133,7 @@ public class AiragResponseServiceImpl implements AiragResponseService { @@ -133,7 +133,7 @@ public class AiragResponseServiceImpl implements AiragResponseService {
133 emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(data))); 133 emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(data)));
134 134
135 // 发送END事件 135 // 发送END事件
136 - Map<String, String> endData = createEndData(questionEmbedding.getMetadata(), String.valueOf(questionEmbedding.getSimilarity())); 136 + Map<String, String> endData = createEndData(objectMapper.writeValueAsString(questionEmbedding.getMetadata()), String.valueOf(questionEmbedding.getSimilarity()));
137 emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(endData))); 137 emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(endData)));
138 emitter.complete(); 138 emitter.complete();
139 } 139 }