正在显示
7 个修改的文件
包含
159 行增加
和
456 行删除
| @@ -17,14 +17,12 @@ import org.jeecg.modules.airag.app.utils.JsonUtils; | @@ -17,14 +17,12 @@ import org.jeecg.modules.airag.app.utils.JsonUtils; | ||
| 17 | import org.jeecg.modules.airag.llm.entity.AiragKnowledge; | 17 | import org.jeecg.modules.airag.llm.entity.AiragKnowledge; |
| 18 | import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService; | 18 | import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService; |
| 19 | import org.springframework.beans.factory.annotation.Autowired; | 19 | import org.springframework.beans.factory.annotation.Autowired; |
| 20 | +import org.springframework.transaction.annotation.Propagation; | ||
| 20 | import org.springframework.transaction.annotation.Transactional; | 21 | import org.springframework.transaction.annotation.Transactional; |
| 21 | import org.springframework.web.bind.annotation.*; | 22 | import org.springframework.web.bind.annotation.*; |
| 22 | import org.springframework.web.multipart.MultipartFile; | 23 | import org.springframework.web.multipart.MultipartFile; |
| 23 | 24 | ||
| 24 | -import java.util.Arrays; | ||
| 25 | -import java.util.HashMap; | ||
| 26 | -import java.util.LinkedHashMap; | ||
| 27 | -import java.util.Map; | 25 | +import java.util.*; |
| 28 | import java.util.stream.Collectors; | 26 | import java.util.stream.Collectors; |
| 29 | 27 | ||
| 30 | @RestController | 28 | @RestController |
| @@ -44,13 +42,10 @@ public class QuestionEmbeddingController { | @@ -44,13 +42,10 @@ public class QuestionEmbeddingController { | ||
| 44 | .collect(Collectors.toMap(AiragKnowledge::getId, AiragKnowledge::getName)); | 42 | .collect(Collectors.toMap(AiragKnowledge::getId, AiragKnowledge::getName)); |
| 45 | 43 | ||
| 46 | page.getRecords().forEach(item -> { | 44 | page.getRecords().forEach(item -> { |
| 47 | - String metadata = item.getMetadata(); | ||
| 48 | - if (StringUtils.isNotBlank(metadata)) { | ||
| 49 | - Map<String, String> jsonMap = JsonUtils.jsonUtils(metadata); | 45 | + Map<String, Object> jsonMap = item.getMetadata(); |
| 50 | if (jsonMap.containsKey("knowledgeId")) { | 46 | if (jsonMap.containsKey("knowledgeId")) { |
| 51 | item.setKnowledgeName(airagKnowledgeMap.get(jsonMap.get("knowledgeId"))); | 47 | item.setKnowledgeName(airagKnowledgeMap.get(jsonMap.get("knowledgeId"))); |
| 52 | - item.setKnowledgeId(jsonMap.get("knowledgeId")); | ||
| 53 | - } | 48 | + item.setKnowledgeId(jsonMap.get("knowledgeId").toString()); |
| 54 | } | 49 | } |
| 55 | 50 | ||
| 56 | }); | 51 | }); |
| @@ -86,12 +81,9 @@ public class QuestionEmbeddingController { | @@ -86,12 +81,9 @@ public class QuestionEmbeddingController { | ||
| 86 | String docId = String.valueOf(snowflakeGenerator.next()); | 81 | String docId = String.valueOf(snowflakeGenerator.next()); |
| 87 | metadata.put("docId", docId); // 自动生成唯一文档ID | 82 | metadata.put("docId", docId); // 自动生成唯一文档ID |
| 88 | metadata.put("knowledgeId", record.getKnowledgeId()); | 83 | metadata.put("knowledgeId", record.getKnowledgeId()); |
| 89 | - // 使用 Jackson 序列化 Map 到 JSON | ||
| 90 | - ObjectMapper mapper = new ObjectMapper(); | ||
| 91 | - String metadataJson = mapper.writeValueAsString(metadata); | ||
| 92 | - // 2. 设置到embeddings对象 | ||
| 93 | - record.setMetadata(metadataJson); | ||
| 94 | 84 | ||
| 85 | + record.setMetadata(metadata); | ||
| 86 | + record.setId(UUID.randomUUID().toString()); | ||
| 95 | int result = questionEmbeddingService.insert(record); | 87 | int result = questionEmbeddingService.insert(record); |
| 96 | return result > 0 ? Result.OK("添加成功!") : Result.error("添加失败"); | 88 | return result > 0 ? Result.OK("添加成功!") : Result.error("添加失败"); |
| 97 | } | 89 | } |
| @@ -112,14 +104,10 @@ public class QuestionEmbeddingController { | @@ -112,14 +104,10 @@ public class QuestionEmbeddingController { | ||
| 112 | String knowledgeName = airagKnowledgeMap.get(record.getKnowledgeId()); | 104 | String knowledgeName = airagKnowledgeMap.get(record.getKnowledgeId()); |
| 113 | record.setKnowledgeName(knowledgeName); | 105 | record.setKnowledgeName(knowledgeName); |
| 114 | 106 | ||
| 115 | - String existMetadata = existRecord.getMetadata(); | ||
| 116 | - Map<String, String> jsonMap = new HashMap<>(); | ||
| 117 | - if (StringUtils.isNotBlank(existMetadata)) { | ||
| 118 | - jsonMap = JsonUtils.jsonUtils(existMetadata); | ||
| 119 | - } | 107 | + Map<String, Object> metadata = existRecord.getMetadata(); |
| 120 | 108 | ||
| 121 | - jsonMap.put("knowledgeId", record.getKnowledgeId()); | ||
| 122 | - record.setMetadata(Json.toJson(jsonMap)); | 109 | + metadata.put("knowledgeId", record.getKnowledgeId()); |
| 110 | + record.setMetadata(metadata); | ||
| 123 | } | 111 | } |
| 124 | int result = questionEmbeddingService.update(record); | 112 | int result = questionEmbeddingService.update(record); |
| 125 | return result > 0 ? Result.OK("编辑成功!") : Result.error("编辑失败"); | 113 | return result > 0 ? Result.OK("编辑成功!") : Result.error("编辑失败"); |
| @@ -144,7 +132,6 @@ public class QuestionEmbeddingController { | @@ -144,7 +132,6 @@ public class QuestionEmbeddingController { | ||
| 144 | } | 132 | } |
| 145 | 133 | ||
| 146 | @PostMapping("/uploadZip") | 134 | @PostMapping("/uploadZip") |
| 147 | - @Transactional(rollbackFor = {Exception.class}) | ||
| 148 | public Result<?> uploadZip( | 135 | public Result<?> uploadZip( |
| 149 | @RequestParam("file") MultipartFile file, | 136 | @RequestParam("file") MultipartFile file, |
| 150 | @RequestParam("knowledgeId") String knowledgeId) { | 137 | @RequestParam("knowledgeId") String knowledgeId) { |
| 1 | package org.jeecg.modules.airag.app.mapper; | 1 | package org.jeecg.modules.airag.app.mapper; |
| 2 | 2 | ||
| 3 | -import cn.hutool.core.lang.generator.SnowflakeGenerator; | ||
| 4 | -import com.alibaba.fastjson2.JSONObject; | 3 | +import com.baomidou.dynamic.datasource.annotation.DS; |
| 4 | +import com.baomidou.mybatisplus.core.metadata.IPage; | ||
| 5 | import com.baomidou.mybatisplus.extension.plugins.pagination.Page; | 5 | import com.baomidou.mybatisplus.extension.plugins.pagination.Page; |
| 6 | -import com.fasterxml.jackson.core.JsonProcessingException; | ||
| 7 | -import com.fasterxml.jackson.core.type.TypeReference; | ||
| 8 | -import com.fasterxml.jackson.databind.ObjectMapper; | ||
| 9 | -import com.pgvector.PGvector; | ||
| 10 | -import dev.langchain4j.data.embedding.Embedding; | ||
| 11 | -import dev.langchain4j.model.output.Response; | ||
| 12 | -import io.minio.messages.Metadata; | ||
| 13 | -import lombok.extern.slf4j.Slf4j; | ||
| 14 | -import org.apache.commons.lang3.StringUtils; | ||
| 15 | -import org.jeecg.modules.airag.app.entity.Embeddings; | 6 | +import org.apache.ibatis.annotations.Mapper; |
| 7 | +import org.apache.ibatis.annotations.Param; | ||
| 16 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | 8 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; |
| 17 | -import org.jeecg.modules.airag.app.utils.AiModelUtils; | ||
| 18 | -import org.postgresql.util.PGobject; | ||
| 19 | -import org.springframework.beans.factory.annotation.Autowired; | ||
| 20 | -import org.springframework.beans.factory.annotation.Value; | ||
| 21 | -import org.springframework.stereotype.Component; | ||
| 22 | 9 | ||
| 23 | -import java.sql.*; | ||
| 24 | -import java.util.*; | ||
| 25 | -import java.util.stream.Collectors; | 10 | +import java.util.List; |
| 26 | 11 | ||
| 27 | -@Component | ||
| 28 | -@Slf4j | ||
| 29 | -public class QuestionEmbeddingMapper { | 12 | +@Mapper |
| 13 | +@DS("pgvector") | ||
| 14 | +public interface QuestionEmbeddingMapper { | ||
| 15 | + Page<QuestionEmbedding> findAll(IPage<QuestionEmbedding> page, @Param("questionEmbedding") QuestionEmbedding questionEmbedding); | ||
| 30 | 16 | ||
| 31 | - @Autowired | ||
| 32 | - private AiModelUtils aiModelUtils; | 17 | + Integer findQuestionCount(@Param("questionEmbedding") QuestionEmbedding questionEmbedding); |
| 33 | 18 | ||
| 34 | - @Value("${jeecg.ai-chat.embedId}") | ||
| 35 | - private String embedId; | ||
| 36 | - // PostgreSQL连接参数(应与项目配置一致) | ||
| 37 | - private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres"; | ||
| 38 | - private static final String USER = "postgres"; | ||
| 39 | - private static final String PASSWORD = "postgres"; | 19 | + QuestionEmbedding findById(@Param("id") String id); |
| 20 | + @DS("pgvector") | ||
| 21 | + int insert(@Param("record") QuestionEmbedding record); | ||
| 40 | 22 | ||
| 41 | - // 获取数据库连接 | ||
| 42 | - private Connection getConnection() throws SQLException { | ||
| 43 | - return DriverManager.getConnection(URL, USER, PASSWORD); | ||
| 44 | - } | 23 | + int update(@Param("record") QuestionEmbedding record); |
| 45 | 24 | ||
| 46 | - // 查询所有记录 | ||
| 47 | - public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, int pageNo, int pageSize) { | ||
| 48 | - List<QuestionEmbedding> results = new ArrayList<>(); | ||
| 49 | - StringBuilder sql = new StringBuilder("select * from question_embedding where 1 = 1"); | ||
| 50 | - StringBuilder countSql = new StringBuilder("select count(1) from question_embedding where 1 = 1"); | ||
| 51 | - List<Object> params = new ArrayList<>(); | ||
| 52 | - List<Object> countParams = new ArrayList<>(); | 25 | + int deleteById(@Param("id") String id); |
| 53 | 26 | ||
| 54 | - if (StringUtils.isNotBlank(questionEmbedding.getKnowledgeId())) { | ||
| 55 | - sql.append(" AND metadata ->> 'knowledgeId' = ?"); | ||
| 56 | - countSql.append(" AND metadata ->> 'knowledgeId' = ?"); | ||
| 57 | - params.add(questionEmbedding.getKnowledgeId()); | ||
| 58 | - countParams.add(questionEmbedding.getKnowledgeId()); | ||
| 59 | - } | ||
| 60 | - if(StringUtils.isNotBlank(questionEmbedding.getQuestion())){ | ||
| 61 | - sql.append(" AND question ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配 | ||
| 62 | - countSql.append(" AND question ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配 | ||
| 63 | - params.add("%" + questionEmbedding.getQuestion() + "%"); | ||
| 64 | - countParams.add("%" + questionEmbedding.getQuestion() + "%"); | ||
| 65 | - } | 27 | + int deleteByIds(@Param("ids") List<String> ids); |
| 66 | 28 | ||
| 67 | - if(StringUtils.isNotBlank(questionEmbedding.getAnswer())){ | ||
| 68 | - sql.append(" AND answer ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配 | ||
| 69 | - countSql.append(" AND answer ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配 | ||
| 70 | - params.add("%" + questionEmbedding.getAnswer() + "%"); | ||
| 71 | - countParams.add("%" + questionEmbedding.getAnswer() + "%"); | ||
| 72 | - } | ||
| 73 | - | ||
| 74 | - sql.append(" ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, question ASC"); | ||
| 75 | - | ||
| 76 | - // 添加分页 | ||
| 77 | - sql.append(" LIMIT ? OFFSET ?"); | ||
| 78 | - params.add(pageSize); | ||
| 79 | - params.add((pageNo - 1) * pageSize); | ||
| 80 | - | ||
| 81 | - | ||
| 82 | - try(Connection conn = getConnection(); | ||
| 83 | - PreparedStatement stmt = conn.prepareStatement(sql.toString())){ | ||
| 84 | - // 设置参数值 | ||
| 85 | - for (int i = 0; i < params.size(); i++) { | ||
| 86 | - stmt.setObject(i + 1, params.get(i)); | ||
| 87 | - } | ||
| 88 | - | ||
| 89 | - try (ResultSet rs = stmt.executeQuery()) { | ||
| 90 | - while (rs.next()) { | ||
| 91 | - results.add(mapRowToQuestionEmbedding(rs)); | ||
| 92 | - } | ||
| 93 | - } | ||
| 94 | - } catch (SQLException e) { | ||
| 95 | - log.error("查询所有记录失败", e); | ||
| 96 | - throw new RuntimeException("查询数据时发生数据库错误", e); | ||
| 97 | - } | ||
| 98 | - | ||
| 99 | - // 执行计数查询 | ||
| 100 | - long total = 0; | ||
| 101 | - try(Connection conn = getConnection(); | ||
| 102 | - PreparedStatement stmt = conn.prepareStatement(countSql.toString())){ | ||
| 103 | - // 设置参数值 | ||
| 104 | - for (int i = 0; i < countParams.size(); i++) { | ||
| 105 | - stmt.setObject(i + 1, countParams.get(i)); | ||
| 106 | - } | ||
| 107 | - | ||
| 108 | - try (ResultSet rs = stmt.executeQuery()) { | ||
| 109 | - if (rs.next()) { | ||
| 110 | - total = rs.getLong(1); // 直接获取count值 | ||
| 111 | - } | ||
| 112 | - } | ||
| 113 | - } catch (SQLException e) { | ||
| 114 | - log.error("查询记录总数失败", e); | ||
| 115 | - throw new RuntimeException("查询记录总数时发生数据库错误", e); | ||
| 116 | - } | ||
| 117 | - | ||
| 118 | - Page<QuestionEmbedding> page = new Page<>(); | ||
| 119 | - page.setRecords(results); | ||
| 120 | - page.setTotal(total); | ||
| 121 | - return page; | ||
| 122 | - } | ||
| 123 | - | ||
| 124 | - // 查询所有记录 | ||
| 125 | - public Integer findQuestionCount(QuestionEmbedding questionEmbedding) { | ||
| 126 | - | ||
| 127 | - StringBuilder sql = new StringBuilder("select COUNT(1) AS total_count from question_embedding where 1 = 1"); | ||
| 128 | - List<Object> params = new ArrayList<>(); | ||
| 129 | - | ||
| 130 | - if(StringUtils.isNotBlank(questionEmbedding.getQuestion())){ | ||
| 131 | - sql.append(" AND question = ?"); | ||
| 132 | - params.add(questionEmbedding.getQuestion()); | ||
| 133 | - } | ||
| 134 | - | ||
| 135 | - | ||
| 136 | - try(Connection conn = getConnection(); | ||
| 137 | - PreparedStatement stmt = conn.prepareStatement(sql.toString())){ | ||
| 138 | - // 设置参数值 | ||
| 139 | - for (int i = 0; i < params.size(); i++) { | ||
| 140 | - stmt.setObject(i + 1, params.get(i)); | ||
| 141 | - } | ||
| 142 | - | ||
| 143 | - try (ResultSet rs = stmt.executeQuery()) { | ||
| 144 | - while (rs.next()) { | ||
| 145 | - return rs.getInt("total_count"); | ||
| 146 | - } | ||
| 147 | - return 0; | ||
| 148 | - } | ||
| 149 | - } catch (SQLException e) { | ||
| 150 | - log.error("查询所有记录失败", e); | ||
| 151 | - throw new RuntimeException("查询数据时发生数据库错误", e); | ||
| 152 | - } | ||
| 153 | - | ||
| 154 | - } | ||
| 155 | - | ||
| 156 | - // 根据ID查询单个记录 | ||
| 157 | - public QuestionEmbedding findById(String id) { | ||
| 158 | - String sql = "SELECT * FROM question_embedding WHERE id = ?"; | ||
| 159 | - | ||
| 160 | - try (Connection conn = getConnection(); | ||
| 161 | - PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 162 | - | ||
| 163 | - stmt.setString(1, id); | ||
| 164 | - try (ResultSet rs = stmt.executeQuery()) { | ||
| 165 | - if (rs.next()) { | ||
| 166 | - return mapRowToQuestionEmbedding(rs); | ||
| 167 | - } | ||
| 168 | - } | ||
| 169 | - } catch (SQLException e) { | ||
| 170 | - log.error("根据ID查询记录失败, ID: {}", id, e); | ||
| 171 | - throw new RuntimeException("根据ID查询时发生数据库错误", e); | ||
| 172 | - } | ||
| 173 | - return null; | ||
| 174 | - } | ||
| 175 | - | ||
| 176 | - // 插入新记录 | ||
| 177 | - public int insert(QuestionEmbedding record) { | ||
| 178 | - String sql = "INSERT INTO question_embedding (id, text, question, answer, metadata,embedding) VALUES (?, ?, ?, ?, ?::jsonb,?)"; | ||
| 179 | - | ||
| 180 | - | ||
| 181 | - try (Connection conn = getConnection(); | ||
| 182 | - PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 183 | - stmt.setString(1, UUID.randomUUID().toString()); | ||
| 184 | - stmt.setString(2, record.getText()); | ||
| 185 | - stmt.setString(3, record.getQuestion()); | ||
| 186 | - stmt.setString(4, record.getAnswer()); | ||
| 187 | - PGobject jsonObject = new PGobject(); | ||
| 188 | - jsonObject.setType("json"); | ||
| 189 | - jsonObject.setValue(record.getMetadata()); | ||
| 190 | - stmt.setObject(5, jsonObject); | ||
| 191 | - Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion()); | ||
| 192 | - stmt.setObject(6, embedding.content().vector()); | ||
| 193 | - return stmt.executeUpdate(); | ||
| 194 | - } catch (SQLException e) { | ||
| 195 | - log.error("插入记录失败: {}", record, e); | ||
| 196 | - throw new RuntimeException("插入数据时发生数据库错误", e); | ||
| 197 | - } | ||
| 198 | - } | ||
| 199 | - | ||
| 200 | - // 更新记录 | ||
| 201 | - public int update(QuestionEmbedding record) { | ||
| 202 | - String sql = "UPDATE question_embedding SET text = ?, question = ?, answer = ?, metadata = ?::jsonb ,embedding = ? WHERE id = ?"; | ||
| 203 | - | ||
| 204 | - try (Connection conn = getConnection(); | ||
| 205 | - PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 206 | - | ||
| 207 | - | ||
| 208 | - stmt.setString(1, record.getText()); | ||
| 209 | - stmt.setString(2, record.getQuestion()); | ||
| 210 | - stmt.setString(3, record.getAnswer()); | ||
| 211 | - stmt.setObject(4, record.getMetadata()); | ||
| 212 | - | ||
| 213 | - Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion()); | ||
| 214 | - stmt.setObject(5, embedding.content().vector()); | ||
| 215 | - | ||
| 216 | - stmt.setString(6, record.getId()); | ||
| 217 | - | ||
| 218 | - return stmt.executeUpdate(); | ||
| 219 | - } catch (SQLException e) { | ||
| 220 | - log.error("更新记录失败: {}", record, e); | ||
| 221 | - throw new RuntimeException("更新数据时发生数据库错误", e); | ||
| 222 | - } | ||
| 223 | - } | ||
| 224 | - | ||
| 225 | - | ||
| 226 | - // 批量删除方法 | ||
| 227 | - public int deleteByIds(List<String> ids) { | ||
| 228 | - if (ids == null || ids.isEmpty()) { | ||
| 229 | - return 0; | ||
| 230 | - } | ||
| 231 | - | ||
| 232 | - String sql = "DELETE FROM question_embedding WHERE id IN ("; | ||
| 233 | - StringBuilder placeholders = new StringBuilder(); | ||
| 234 | - for (int i = 0; i < ids.size(); i++) { | ||
| 235 | - placeholders.append("?"); | ||
| 236 | - if (i < ids.size() - 1) { | ||
| 237 | - placeholders.append(","); | ||
| 238 | - } | ||
| 239 | - } | ||
| 240 | - sql += placeholders.toString() + ")"; | ||
| 241 | - | ||
| 242 | - try (Connection conn = getConnection(); | ||
| 243 | - PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 244 | - | ||
| 245 | - for (int i = 0; i < ids.size(); i++) { | ||
| 246 | - stmt.setString(i + 1, ids.get(i)); | ||
| 247 | - } | ||
| 248 | - | ||
| 249 | - return stmt.executeUpdate(); | ||
| 250 | - } catch (SQLException e) { | ||
| 251 | - log.error("批量删除向量记录失败, IDs: {}", ids, e); | ||
| 252 | - throw new RuntimeException("批量删除向量数据时发生数据库错误", e); | ||
| 253 | - } | ||
| 254 | - } | ||
| 255 | - | ||
| 256 | - | ||
| 257 | - /** | ||
| 258 | - * 向量相似度查询 (基于问题文本的向量) | ||
| 259 | - * @param question 问题文本 | ||
| 260 | - * @param limit 返回结果数量 | ||
| 261 | - * @param minSimilarity 最小相似度阈值(0-1) | ||
| 262 | - * @return 相似问答列表(按相似度降序) | ||
| 263 | - */ | ||
| 264 | - public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { | ||
| 265 | - List<QuestionEmbedding> results = new ArrayList<>(); | ||
| 266 | - | ||
| 267 | - // 1. 参数校验 | ||
| 268 | - if (minSimilarity < 0 || minSimilarity > 1) { | ||
| 269 | - throw new IllegalArgumentException("相似度阈值必须在0到1之间"); | ||
| 270 | - } | ||
| 271 | - | ||
| 272 | - // 2. 获取问题的嵌入向量 | ||
| 273 | - Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, question); | ||
| 274 | - float[] queryVector = embeddingResponse.content().vector(); | ||
| 275 | - // 3. 计算最大允许距离(1 - 相似度阈值) | ||
| 276 | - double maxDistance = 1 - minSimilarity; | ||
| 277 | - | ||
| 278 | - // 4. 执行向量相似度查询 | ||
| 279 | - String sql = "SELECT *, embedding <-> ? AS distance " + | ||
| 280 | - "FROM question_embedding " + | ||
| 281 | - "WHERE embedding <-> ? < ? " + // 距离小于阈值 | ||
| 282 | - "ORDER BY distance ASC " + // 按距离升序 | ||
| 283 | - "LIMIT ?"; | ||
| 284 | - | ||
| 285 | - try (Connection conn = getConnection(); | ||
| 286 | - PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 287 | - | ||
| 288 | - // 设置参数 | ||
| 289 | - PGvector vector = new PGvector(queryVector); | ||
| 290 | - stmt.setObject(1, vector); | ||
| 291 | - stmt.setObject(2, vector); | ||
| 292 | - stmt.setDouble(3, maxDistance); | ||
| 293 | - stmt.setInt(4, limit); | ||
| 294 | - | ||
| 295 | - try (ResultSet rs = stmt.executeQuery()) { | ||
| 296 | - while (rs.next()) { | ||
| 297 | - QuestionEmbedding record = mapRowToQuestionEmbedding(rs); | ||
| 298 | - // 计算相似度(1 - 距离) | ||
| 299 | - double distance = rs.getDouble("distance"); | ||
| 300 | - double similarity = 1 - distance; | ||
| 301 | - record.setSimilarity(similarity); | ||
| 302 | - results.add(record); | ||
| 303 | - } | ||
| 304 | - } | ||
| 305 | - } catch (SQLException e) { | ||
| 306 | - log.error("向量相似度查询失败", e); | ||
| 307 | - throw new RuntimeException("执行向量相似度查询时发生数据库错误", e); | ||
| 308 | - } | ||
| 309 | - return results; | ||
| 310 | - } | ||
| 311 | - | ||
| 312 | - /** | ||
| 313 | - * 向量相似度查询 (直接使用向量) | ||
| 314 | - * @param vector 查询向量 | ||
| 315 | - * @param limit 返回结果数量 | ||
| 316 | - * @return 相似问答列表(按相似度降序) | ||
| 317 | - */ | ||
| 318 | - public List<QuestionEmbedding> similaritySearch(float[] vector, int limit) { | ||
| 319 | - List<QuestionEmbedding> results = new ArrayList<>(); | ||
| 320 | - String sql = "SELECT *, embedding <-> ? AS similarity " + | ||
| 321 | - "FROM question_embedding " + | ||
| 322 | - "ORDER BY similarity ASC " + | ||
| 323 | - "LIMIT ?"; | ||
| 324 | - | ||
| 325 | - try (Connection conn = getConnection(); | ||
| 326 | - PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 327 | - | ||
| 328 | - stmt.setObject(1, new PGvector(vector)); | ||
| 329 | - stmt.setInt(2, limit); | ||
| 330 | - | ||
| 331 | - try (ResultSet rs = stmt.executeQuery()) { | ||
| 332 | - while (rs.next()) { | ||
| 333 | - QuestionEmbedding record = mapRowToQuestionEmbedding(rs); | ||
| 334 | - double similarity = 1 - rs.getDouble("similarity"); | ||
| 335 | - record.setSimilarity(similarity); | ||
| 336 | - results.add(record); | ||
| 337 | - } | ||
| 338 | - } | ||
| 339 | - } catch (SQLException e) { | ||
| 340 | - log.error("向量相似度查询失败", e); | ||
| 341 | - throw new RuntimeException("执行向量相似度查询时发生数据库错误", e); | ||
| 342 | - } | ||
| 343 | - return results; | ||
| 344 | - } | ||
| 345 | - | ||
| 346 | - // 根据ID删除记录 | ||
| 347 | - public int deleteById(String id) { | ||
| 348 | - String sql = "DELETE FROM question_embedding WHERE id = ?"; | ||
| 349 | - | ||
| 350 | - try (Connection conn = getConnection(); | ||
| 351 | - PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 352 | - | ||
| 353 | - stmt.setString(1, id); | ||
| 354 | - return stmt.executeUpdate(); | ||
| 355 | - } catch (SQLException e) { | ||
| 356 | - log.error("删除记录失败, ID: {}", id, e); | ||
| 357 | - throw new RuntimeException("删除数据时发生数据库错误", e); | ||
| 358 | - } | ||
| 359 | - } | ||
| 360 | - | ||
| 361 | - // 将ResultSet行映射为QuestionEmbedding对象 | ||
| 362 | - private QuestionEmbedding mapRowToQuestionEmbedding(ResultSet rs) throws SQLException { | ||
| 363 | - QuestionEmbedding record = new QuestionEmbedding(); | ||
| 364 | - record.setId(rs.getString("id")); | ||
| 365 | - record.setText(rs.getString("text")); | ||
| 366 | - record.setQuestion(rs.getString("question")); | ||
| 367 | - record.setAnswer(rs.getString("answer")); | ||
| 368 | - | ||
| 369 | - String metadataJson = rs.getString("metadata"); | ||
| 370 | - if (StringUtils.isNotBlank(metadataJson)) { | ||
| 371 | - record.setMetadata(metadataJson); | ||
| 372 | - } | ||
| 373 | - | ||
| 374 | - return record; | ||
| 375 | - } | ||
| 376 | - | ||
| 377 | - // 将Map转换为JSON字符串 | ||
| 378 | - private String toJson(Map<String, Object> map) { | ||
| 379 | - try { | ||
| 380 | - return new ObjectMapper().writeValueAsString(map); | ||
| 381 | - } catch (JsonProcessingException e) { | ||
| 382 | - log.error("元数据转换为JSON失败", e); | ||
| 383 | - return "{}"; | ||
| 384 | - } | ||
| 385 | - } | ||
| 386 | - | ||
| 387 | - // 将JSON字符串转换为Map | ||
| 388 | - private Map<String, Object> fromJson(String json) { | ||
| 389 | - try { | ||
| 390 | - return new ObjectMapper().readValue(json, new TypeReference<Map<String, Object>>() {}); | ||
| 391 | - } catch (JsonProcessingException e) { | ||
| 392 | - log.error("JSON转换为元数据失败", e); | ||
| 393 | - return Collections.emptyMap(); | ||
| 394 | - } | ||
| 395 | - } | 29 | + List<QuestionEmbedding> similaritySearchByQuestion(@Param("vector") float[] vector, |
| 30 | + @Param("limit") int limit, | ||
| 31 | + @Param("minSimilarity") Double minSimilarity); | ||
| 396 | 32 | ||
| 33 | + List<QuestionEmbedding> similaritySearch(@Param("vector") float[] vector, | ||
| 34 | + @Param("limit") int limit); | ||
| 397 | } | 35 | } |
| 1 | +<?xml version="1.0" encoding="UTF-8"?> | ||
| 2 | +<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN" | ||
| 3 | + "http://mybatis.org/dtd/mybatis-3-mapper.dtd"> | ||
| 4 | +<mapper namespace="org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper"> | ||
| 5 | + | ||
| 6 | + <resultMap id="questionEmbeddingResultMap" type="org.jeecg.modules.airag.app.entity.QuestionEmbedding"> | ||
| 7 | + <id column="id" property="id" /> | ||
| 8 | + <result column="text" property="text" /> | ||
| 9 | + <result column="question" property="question" /> | ||
| 10 | + <result column="answer" property="answer" /> | ||
| 11 | + <result column="metadata" property="metadata" typeHandler="org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler"/> | ||
| 12 | + <result column="similarity" property="similarity" /> | ||
| 13 | + </resultMap> | ||
| 14 | + | ||
| 15 | + <select id="findAll" resultMap="questionEmbeddingResultMap"> | ||
| 16 | + SELECT * FROM question_embedding WHERE 1 = 1 | ||
| 17 | + <if test="questionEmbedding.knowledgeId != null and questionEmbedding.knowledgeId != ''"> | ||
| 18 | + AND metadata ->> 'knowledgeId' = #{questionEmbedding.knowledgeId} | ||
| 19 | + </if> | ||
| 20 | + <if test="questionEmbedding.question != null and questionEmbedding.question != ''"> | ||
| 21 | + AND question ILIKE CONCAT('%', #{questionEmbedding.question}, '%') | ||
| 22 | + </if> | ||
| 23 | + <if test="questionEmbedding.answer != null and questionEmbedding.answer != ''"> | ||
| 24 | + AND answer ILIKE CONCAT('%', #{questionEmbedding.answer}, '%') | ||
| 25 | + </if> | ||
| 26 | + ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, question ASC | ||
| 27 | + </select> | ||
| 28 | + | ||
| 29 | + <select id="findQuestionCount" resultType="int"> | ||
| 30 | + SELECT COUNT(1) AS total_count FROM question_embedding WHERE 1 = 1 | ||
| 31 | + <if test="questionEmbedding.question != null and questionEmbedding.question != ''"> | ||
| 32 | + AND question = #{questionEmbedding.question} | ||
| 33 | + </if> | ||
| 34 | + </select> | ||
| 35 | + | ||
| 36 | + <select id="findById" resultMap="questionEmbeddingResultMap"> | ||
| 37 | + SELECT * FROM question_embedding WHERE id = #{id} | ||
| 38 | + </select> | ||
| 39 | + | ||
| 40 | + <insert id="insert" parameterType="org.jeecg.modules.airag.app.entity.QuestionEmbedding"> | ||
| 41 | + INSERT INTO question_embedding (id, text, question, answer, metadata, embedding) | ||
| 42 | + VALUES ( | ||
| 43 | + #{record.id, jdbcType=VARCHAR}, | ||
| 44 | + #{record.text, jdbcType=VARCHAR}, | ||
| 45 | + #{record.question, jdbcType=VARCHAR}, | ||
| 46 | + #{record.answer, jdbcType=VARCHAR}, | ||
| 47 | + #{record.metadata, jdbcType=OTHER, typeHandler=org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler}::jsonb, | ||
| 48 | + #{record.embedding, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler} | ||
| 49 | + ) | ||
| 50 | + </insert> | ||
| 51 | + | ||
| 52 | + <update id="update" parameterType="org.jeecg.modules.airag.app.entity.QuestionEmbedding"> | ||
| 53 | + UPDATE question_embedding | ||
| 54 | + SET | ||
| 55 | + text = #{record.text, jdbcType=VARCHAR}, | ||
| 56 | + question = #{record.question, jdbcType=VARCHAR}, | ||
| 57 | + answer = #{record.answer, jdbcType=VARCHAR}, | ||
| 58 | + metadata = #{record.metadata, jdbcType=OTHER, typeHandler=org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler}::jsonb, | ||
| 59 | + embedding = #{record.embedding, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler} | ||
| 60 | + WHERE id = #{record.id} | ||
| 61 | + </update> | ||
| 62 | + | ||
| 63 | + <delete id="deleteById"> | ||
| 64 | + DELETE FROM question_embedding WHERE id = #{id} | ||
| 65 | + </delete> | ||
| 66 | + | ||
| 67 | + <delete id="deleteByIds"> | ||
| 68 | + DELETE FROM question_embedding WHERE id IN | ||
| 69 | + <foreach collection="ids" item="id" open="(" separator="," close=")"> | ||
| 70 | + #{id} | ||
| 71 | + </foreach> | ||
| 72 | + </delete> | ||
| 73 | + | ||
| 74 | + <select id="similaritySearchByQuestion" resultMap="questionEmbeddingResultMap"> | ||
| 75 | + <![CDATA[ | ||
| 76 | + SELECT *, | ||
| 77 | + (embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler})::float AS similarity | ||
| 78 | + FROM question_embedding | ||
| 79 | + WHERE (embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}) < #{minSimilarity} | ||
| 80 | + ORDER BY similarity ASC | ||
| 81 | + LIMIT #{limit} | ||
| 82 | + ]]> | ||
| 83 | + </select> | ||
| 84 | + | ||
| 85 | + <select id="similaritySearch" resultMap="questionEmbeddingResultMap"> | ||
| 86 | +<!-- SELECT *, embedding <-> #{vector} AS similarity--> | ||
| 87 | +<!-- FROM question_embedding--> | ||
| 88 | +<!-- ORDER BY similarity ASC--> | ||
| 89 | +<!-- LIMIT #{limit}--> | ||
| 90 | + </select> | ||
| 91 | +</mapper> |
| @@ -50,8 +50,6 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i | @@ -50,8 +50,6 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i | ||
| 50 | 50 | ||
| 51 | @Override | 51 | @Override |
| 52 | public void saveToQuestionLibrary(AiragLog log) throws JsonProcessingException { | 52 | public void saveToQuestionLibrary(AiragLog log) throws JsonProcessingException { |
| 53 | - // 这里实现将问题和回答存入问题库数据表的逻辑 | ||
| 54 | - // 假设问题库数据表的实体类为 QuestionLibrary,Mapper 接口为 QuestionLibraryMapper | ||
| 55 | QuestionEmbedding questionEmbedding = new QuestionEmbedding(); | 53 | QuestionEmbedding questionEmbedding = new QuestionEmbedding(); |
| 56 | questionEmbedding.setQuestion(log.getQuestion()); | 54 | questionEmbedding.setQuestion(log.getQuestion()); |
| 57 | questionEmbedding.setAnswer(log.getAnswer()); | 55 | questionEmbedding.setAnswer(log.getAnswer()); |
| @@ -62,11 +60,7 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i | @@ -62,11 +60,7 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i | ||
| 62 | String docId = String.valueOf(snowflakeGenerator.next()); | 60 | String docId = String.valueOf(snowflakeGenerator.next()); |
| 63 | metadata.put("docId", docId); | 61 | metadata.put("docId", docId); |
| 64 | metadata.put("knowledgeId", questionEmbedding.getKnowledgeId()); | 62 | metadata.put("knowledgeId", questionEmbedding.getKnowledgeId()); |
| 65 | - // 使用 Jackson 序列化 Map 到 JSON | ||
| 66 | - ObjectMapper mapper = new ObjectMapper(); | ||
| 67 | - String metadataJson = mapper.writeValueAsString(metadata); | ||
| 68 | - // 2. 设置到embeddings对象 | ||
| 69 | - questionEmbedding.setMetadata(metadataJson); | 63 | + questionEmbedding.setMetadata(metadata); |
| 70 | questionEmbeddingMapper.insert(questionEmbedding); | 64 | questionEmbeddingMapper.insert(questionEmbedding); |
| 71 | airagLogMapper.updataIsStorage(log.getIsStorage(),log.getId()); | 65 | airagLogMapper.updataIsStorage(log.getIsStorage(),log.getId()); |
| 72 | System.out.println("1"); | 66 | System.out.println("1"); |
| 1 | package org.jeecg.modules.airag.app.service.impl; | 1 | package org.jeecg.modules.airag.app.service.impl; |
| 2 | 2 | ||
| 3 | +import com.baomidou.dynamic.datasource.annotation.DS; | ||
| 3 | import com.baomidou.mybatisplus.extension.plugins.pagination.Page; | 4 | import com.baomidou.mybatisplus.extension.plugins.pagination.Page; |
| 4 | import com.fasterxml.jackson.core.JsonProcessingException; | 5 | import com.fasterxml.jackson.core.JsonProcessingException; |
| 6 | +import org.apache.commons.lang3.StringUtils; | ||
| 5 | import org.apache.poi.hwpf.usermodel.CharacterRun; | 7 | import org.apache.poi.hwpf.usermodel.CharacterRun; |
| 6 | import org.apache.poi.hwpf.HWPFDocument; | 8 | import org.apache.poi.hwpf.HWPFDocument; |
| 7 | import org.apache.poi.hwpf.usermodel.Paragraph; | 9 | import org.apache.poi.hwpf.usermodel.Paragraph; |
| @@ -15,7 +17,10 @@ import dev.langchain4j.model.output.Response; | @@ -15,7 +17,10 @@ import dev.langchain4j.model.output.Response; | ||
| 15 | import org.apache.commons.io.FilenameUtils; | 17 | import org.apache.commons.io.FilenameUtils; |
| 16 | import org.apache.poi.xwpf.usermodel.*; | 18 | import org.apache.poi.xwpf.usermodel.*; |
| 17 | import org.jeecg.common.api.vo.Result; | 19 | import org.jeecg.common.api.vo.Result; |
| 20 | +import org.jeecg.modules.airag.app.entity.Embeddings; | ||
| 18 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | 21 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; |
| 22 | +import org.jeecg.modules.airag.app.mapper.EmbeddingsMapper; | ||
| 23 | +import org.jeecg.modules.airag.app.mapper.PgVectorMapper; | ||
| 19 | import org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper; | 24 | import org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper; |
| 20 | import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; | 25 | import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; |
| 21 | import org.jeecg.modules.airag.app.utils.AiModelUtils; | 26 | import org.jeecg.modules.airag.app.utils.AiModelUtils; |
| @@ -60,7 +65,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -60,7 +65,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 60 | private AiModelUtils aiModelUtils; | 65 | private AiModelUtils aiModelUtils; |
| 61 | 66 | ||
| 62 | @Autowired | 67 | @Autowired |
| 63 | - private IAIChatHandler aiChatHandler; | 68 | + private PgVectorMapper pgVectorMapper; |
| 64 | 69 | ||
| 65 | @Value("${jeecg.upload.path}") | 70 | @Value("${jeecg.upload.path}") |
| 66 | private String uploadPath; | 71 | private String uploadPath; |
| @@ -68,17 +73,13 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -68,17 +73,13 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 68 | private String embedId; | 73 | private String embedId; |
| 69 | 74 | ||
| 70 | private static final Set<String> ALLOWED_EXTENSIONS = Set.of("txt", "doc", "docx"); | 75 | private static final Set<String> ALLOWED_EXTENSIONS = Set.of("txt", "doc", "docx"); |
| 71 | - private static final Pattern SPECIAL_CHARS_PATTERN = Pattern.compile("[^a-zA-Z0-9\\u4e00-\\u9fa5\\s]"); | ||
| 72 | private static final Pattern UUID_PATTERN = Pattern.compile("_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"); | 76 | private static final Pattern UUID_PATTERN = Pattern.compile("_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"); |
| 73 | 77 | ||
| 74 | - // 数据库连接配置 | ||
| 75 | - private static final String DB_URL = "jdbc:postgresql://192.168.100.104:5432/postgres"; | ||
| 76 | - private static final String DB_USER = "postgres"; | ||
| 77 | - private static final String DB_PASSWORD = "postgres"; | ||
| 78 | 78 | ||
| 79 | @Override | 79 | @Override |
| 80 | public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, Integer pageNo, Integer pageSize) { | 80 | public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, Integer pageNo, Integer pageSize) { |
| 81 | - return questionEmbeddingMapper.findAll(questionEmbedding,pageNo,pageSize); | 81 | + Page<QuestionEmbedding> page = new Page<>(pageNo, pageSize); |
| 82 | + return questionEmbeddingMapper.findAll(page,questionEmbedding); | ||
| 82 | } | 83 | } |
| 83 | 84 | ||
| 84 | @Override | 85 | @Override |
| @@ -93,11 +94,21 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -93,11 +94,21 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 93 | 94 | ||
| 94 | @Override | 95 | @Override |
| 95 | public int insert(QuestionEmbedding record) { | 96 | public int insert(QuestionEmbedding record) { |
| 97 | + if (StringUtils.isNotBlank(record.getQuestion())){ | ||
| 98 | + Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion()); | ||
| 99 | + record.setEmbedding(embedding.content().vector()); | ||
| 100 | + } | ||
| 101 | + | ||
| 102 | + | ||
| 96 | return questionEmbeddingMapper.insert(record); | 103 | return questionEmbeddingMapper.insert(record); |
| 97 | } | 104 | } |
| 98 | 105 | ||
| 99 | @Override | 106 | @Override |
| 100 | public int update(QuestionEmbedding record) { | 107 | public int update(QuestionEmbedding record) { |
| 108 | + if (StringUtils.isNotBlank(record.getQuestion())){ | ||
| 109 | + Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion()); | ||
| 110 | + record.setEmbedding(embedding.content().vector()); | ||
| 111 | + } | ||
| 101 | return questionEmbeddingMapper.update(record); | 112 | return questionEmbeddingMapper.update(record); |
| 102 | } | 113 | } |
| 103 | 114 | ||
| @@ -113,7 +124,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -113,7 +124,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 113 | 124 | ||
| 114 | @Override | 125 | @Override |
| 115 | public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { | 126 | public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { |
| 116 | - return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity); | 127 | + Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, question); |
| 128 | + return questionEmbeddingMapper.similaritySearchByQuestion(embedding.content().vector(), limit, minSimilarity); | ||
| 117 | } | 129 | } |
| 118 | 130 | ||
| 119 | @Override | 131 | @Override |
| @@ -183,10 +195,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -183,10 +195,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 183 | segments = splitWordDocument(targetPath.toString()); | 195 | segments = splitWordDocument(targetPath.toString()); |
| 184 | } | 196 | } |
| 185 | 197 | ||
| 186 | - // 原有逻辑:保存到question_embedding表 | ||
| 187 | saveSegmentsToDatabase(segments, originalFileName, storedFileName, knowledgeId); | 198 | saveSegmentsToDatabase(segments, originalFileName, storedFileName, knowledgeId); |
| 188 | 199 | ||
| 189 | - // 新增逻辑:同时保存到embeddings表 | ||
| 190 | saveToEmbeddingsTable(segments, originalFileName, storedFileName, knowledgeId); | 200 | saveToEmbeddingsTable(segments, originalFileName, storedFileName, knowledgeId); |
| 191 | 201 | ||
| 192 | } | 202 | } |
| @@ -196,7 +206,6 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -196,7 +206,6 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 196 | String displayFileName = removeUuidSuffix(originalFileName); | 206 | String displayFileName = removeUuidSuffix(originalFileName); |
| 197 | displayFileName = FilenameUtils.removeExtension(displayFileName); | 207 | displayFileName = FilenameUtils.removeExtension(displayFileName); |
| 198 | 208 | ||
| 199 | - try (Connection conn = getConnection()) { | ||
| 200 | for (String segment : segments) { | 209 | for (String segment : segments) { |
| 201 | if (segment.trim().isEmpty()) continue; | 210 | if (segment.trim().isEmpty()) continue; |
| 202 | 211 | ||
| @@ -205,44 +214,29 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -205,44 +214,29 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 205 | if (parts.length < 2) continue; | 214 | if (parts.length < 2) continue; |
| 206 | 215 | ||
| 207 | String titlePath = parts[0].trim(); | 216 | String titlePath = parts[0].trim(); |
| 208 | - String answer = segment.trim(); // 整个回答段(含标题 + 内容) | 217 | + // 整个回答段(标题 + 内容) |
| 218 | + String answer = segment.trim(); | ||
| 209 | 219 | ||
| 210 | // 获取 embedding | 220 | // 获取 embedding |
| 211 | Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, answer); | 221 | Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, answer); |
| 212 | float[] embeddingVector = embeddingResponse.content().vector(); | 222 | float[] embeddingVector = embeddingResponse.content().vector(); |
| 213 | 223 | ||
| 214 | - // 准备 metadata | ||
| 215 | Map<String, Object> metadata = new HashMap<>(); | 224 | Map<String, Object> metadata = new HashMap<>(); |
| 216 | metadata.put("docName", originalFileName); | 225 | metadata.put("docName", originalFileName); |
| 217 | metadata.put("storedFileName", storedFileName); | 226 | metadata.put("storedFileName", storedFileName); |
| 218 | metadata.put("knowledgeId", knowledgeId); | 227 | metadata.put("knowledgeId", knowledgeId); |
| 219 | metadata.put("title", displayFileName + ": " + titlePath); | 228 | metadata.put("title", displayFileName + ": " + titlePath); |
| 220 | - | ||
| 221 | - // 插入 | ||
| 222 | - String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)"; | ||
| 223 | - try (PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 224 | - stmt.setString(1, UUID.randomUUID().toString()); | ||
| 225 | - stmt.setObject(2, new PGvector(embeddingVector)); | ||
| 226 | - stmt.setString(3, answer); | ||
| 227 | - | ||
| 228 | - PGobject jsonObject = new PGobject(); | ||
| 229 | - jsonObject.setType("json"); | ||
| 230 | - jsonObject.setValue(new ObjectMapper().writeValueAsString(metadata)); | ||
| 231 | - stmt.setObject(4, jsonObject); | ||
| 232 | - | ||
| 233 | - stmt.executeUpdate(); | ||
| 234 | - } | ||
| 235 | - } | ||
| 236 | - } catch (Exception e) { | ||
| 237 | - log.error("保存分段到embeddings表失败", e); | 229 | + Embeddings embeddings = new Embeddings(); |
| 230 | + embeddings.setMetadata(metadata); | ||
| 231 | + embeddings.setId(UUID.randomUUID().toString()); | ||
| 232 | + embeddings.setEmbedding(embeddingVector); | ||
| 233 | + embeddings.setText(answer); | ||
| 234 | + pgVectorMapper.insert(embeddings); | ||
| 238 | } | 235 | } |
| 239 | } | 236 | } |
| 240 | 237 | ||
| 241 | 238 | ||
| 242 | - // 获取数据库连接 | ||
| 243 | - private Connection getConnection() throws SQLException { | ||
| 244 | - return DriverManager.getConnection(DB_URL, DB_USER, DB_PASSWORD); | ||
| 245 | - } | 239 | + |
| 246 | 240 | ||
| 247 | private String generateStoredFileName(String originalFileName) { | 241 | private String generateStoredFileName(String originalFileName) { |
| 248 | String baseName = FilenameUtils.removeExtension(originalFileName); | 242 | String baseName = FilenameUtils.removeExtension(originalFileName); |
| @@ -359,6 +353,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -359,6 +353,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 359 | 353 | ||
| 360 | return 0; | 354 | return 0; |
| 361 | } | 355 | } |
| 356 | + | ||
| 362 | private void saveSegmentsToDatabase(List<String> segments, String originalFileName, String storedFileName, String knowledgeId) { | 357 | private void saveSegmentsToDatabase(List<String> segments, String originalFileName, String storedFileName, String knowledgeId) { |
| 363 | if (segments.isEmpty()) return; | 358 | if (segments.isEmpty()) return; |
| 364 | 359 | ||
| @@ -384,24 +379,22 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -384,24 +379,22 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 384 | record.setAnswer(titleLine + "\n" + content); | 379 | record.setAnswer(titleLine + "\n" + content); |
| 385 | record.setText(""); | 380 | record.setText(""); |
| 386 | 381 | ||
| 387 | - Map<String, String> metadata = new LinkedHashMap<>(); | 382 | + Map<String, Object> metadata = new LinkedHashMap<>(); |
| 388 | metadata.put("docId", docId); | 383 | metadata.put("docId", docId); |
| 389 | metadata.put("docName", originalFileName); | 384 | metadata.put("docName", originalFileName); |
| 390 | metadata.put("storedFileName", storedFileName); | 385 | metadata.put("storedFileName", storedFileName); |
| 391 | metadata.put("knowledgeId", knowledgeId); | 386 | metadata.put("knowledgeId", knowledgeId); |
| 392 | 387 | ||
| 393 | - try { | ||
| 394 | - record.setMetadata(new ObjectMapper().writeValueAsString(metadata)); | ||
| 395 | - } catch (JsonProcessingException e) { | ||
| 396 | - log.error("生成metadata JSON失败", e); | ||
| 397 | - } | 388 | + |
| 389 | + record.setMetadata(metadata); | ||
| 390 | + | ||
| 398 | 391 | ||
| 399 | log.info("保存分段: title={}, content_length={}", question, segment.length()); | 392 | log.info("保存分段: title={}, content_length={}", question, segment.length()); |
| 400 | 393 | ||
| 401 | Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, record.getQuestion()); | 394 | Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, record.getQuestion()); |
| 402 | record.setEmbedding(embeddingResponse.content().vector()); | 395 | record.setEmbedding(embeddingResponse.content().vector()); |
| 403 | record.setKnowledgeId(knowledgeId); | 396 | record.setKnowledgeId(knowledgeId); |
| 404 | - questionEmbeddingMapper.insert(record); | 397 | + insert(record); |
| 405 | } | 398 | } |
| 406 | } | 399 | } |
| 407 | 400 |
| @@ -133,7 +133,7 @@ public class AiragResponseServiceImpl implements AiragResponseService { | @@ -133,7 +133,7 @@ public class AiragResponseServiceImpl implements AiragResponseService { | ||
| 133 | emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(data))); | 133 | emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(data))); |
| 134 | 134 | ||
| 135 | // 发送END事件 | 135 | // 发送END事件 |
| 136 | - Map<String, String> endData = createEndData(questionEmbedding.getMetadata(), String.valueOf(questionEmbedding.getSimilarity())); | 136 | + Map<String, String> endData = createEndData(objectMapper.writeValueAsString(questionEmbedding.getMetadata()), String.valueOf(questionEmbedding.getSimilarity())); |
| 137 | emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(endData))); | 137 | emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(endData))); |
| 138 | emitter.complete(); | 138 | emitter.complete(); |
| 139 | } | 139 | } |
-
请 注册 或 登录 后发表评论