正在显示
15 个修改的文件
包含
262 行增加
和
3 行删除
| @@ -19,6 +19,7 @@ import org.jeecg.modules.airag.app.service.IAiragLogService; | @@ -19,6 +19,7 @@ import org.jeecg.modules.airag.app.service.IAiragLogService; | ||
| 19 | import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; | 19 | import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; |
| 20 | import org.jeecg.modules.airag.llm.entity.AiragKnowledge; | 20 | import org.jeecg.modules.airag.llm.entity.AiragKnowledge; |
| 21 | import org.jeecg.modules.airag.llm.entity.AiragModel; | 21 | import org.jeecg.modules.airag.llm.entity.AiragModel; |
| 22 | +import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService; | ||
| 22 | import org.jeecg.modules.airag.llm.service.IAiragModelService; | 23 | import org.jeecg.modules.airag.llm.service.IAiragModelService; |
| 23 | import org.springframework.beans.factory.annotation.Autowired; | 24 | import org.springframework.beans.factory.annotation.Autowired; |
| 24 | import org.springframework.web.bind.annotation.*; | 25 | import org.springframework.web.bind.annotation.*; |
| @@ -29,6 +30,8 @@ import javax.servlet.http.HttpServletResponse; | @@ -29,6 +30,8 @@ import javax.servlet.http.HttpServletResponse; | ||
| 29 | import java.sql.SQLException; | 30 | import java.sql.SQLException; |
| 30 | import java.util.*; | 31 | import java.util.*; |
| 31 | 32 | ||
| 33 | +import java.util.stream.Collectors; | ||
| 34 | + | ||
| 32 | /** | 35 | /** |
| 33 | * @Description: 日志管理 | 36 | * @Description: 日志管理 |
| 34 | * @Author: jeecg-boot | 37 | * @Author: jeecg-boot |
| @@ -51,6 +54,9 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi | @@ -51,6 +54,9 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi | ||
| 51 | @Autowired | 54 | @Autowired |
| 52 | private IQuestionEmbeddingService questionEmbeddingService; | 55 | private IQuestionEmbeddingService questionEmbeddingService; |
| 53 | 56 | ||
| 57 | + @Autowired | ||
| 58 | + private IAiragKnowledgeService airagKnowledgeService; | ||
| 59 | + | ||
| 54 | /** | 60 | /** |
| 55 | * 分页列表查询 | 61 | * 分页列表查询 |
| 56 | * | 62 | * |
| @@ -105,6 +111,31 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi | @@ -105,6 +111,31 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi | ||
| 105 | public Result<List<AiragModel>> queryAiragKnowledgeList(AiragModel airagModel, HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException { | 111 | public Result<List<AiragModel>> queryAiragKnowledgeList(AiragModel airagModel, HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException { |
| 106 | QueryWrapper<AiragModel> queryWrapper = QueryGenerator.initQueryWrapper(airagModel, req.getParameterMap()); | 112 | QueryWrapper<AiragModel> queryWrapper = QueryGenerator.initQueryWrapper(airagModel, req.getParameterMap()); |
| 107 | List<AiragModel> list = airagModelService.list(queryWrapper); | 113 | List<AiragModel> list = airagModelService.list(queryWrapper); |
| 114 | + | ||
| 115 | + // 过滤出 model_type 为 "llm" 的记录 | ||
| 116 | + List<AiragModel> filteredList = list.stream() | ||
| 117 | + .filter(model -> "LLM".equals(model.getModelType())) | ||
| 118 | + .collect(Collectors.toList()); | ||
| 119 | + | ||
| 120 | + | ||
| 121 | + return Result.OK(filteredList); | ||
| 122 | + } | ||
| 123 | + | ||
| 124 | + | ||
| 125 | + /** | ||
| 126 | + * 查询知识库名称 | ||
| 127 | + * | ||
| 128 | + * @param airagKnowledge | ||
| 129 | + * @param req | ||
| 130 | + * @return | ||
| 131 | + */ | ||
| 132 | + @AutoLog(value = "日志管理-查询知识库名称") | ||
| 133 | + @Operation(summary="日志管理-查询知识库名称") | ||
| 134 | + @GetMapping(value = "/listKnowledgeName") | ||
| 135 | + public Result<List<AiragKnowledge>> queryAiragKnowledgeNameList(AiragKnowledge airagKnowledge, HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException { | ||
| 136 | + QueryWrapper<AiragKnowledge> queryWrapper = QueryGenerator.initQueryWrapper(airagKnowledge, req.getParameterMap()); | ||
| 137 | + List<AiragKnowledge> list = airagKnowledgeService.list(queryWrapper); | ||
| 138 | + | ||
| 108 | return Result.OK(list); | 139 | return Result.OK(list); |
| 109 | } | 140 | } |
| 110 | 141 | ||
| @@ -141,6 +172,7 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi | @@ -141,6 +172,7 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi | ||
| 141 | if(questionCount > 0){ | 172 | if(questionCount > 0){ |
| 142 | return Result.error("重复问题不能存入"); | 173 | return Result.error("重复问题不能存入"); |
| 143 | } | 174 | } |
| 175 | + | ||
| 144 | airagLog.setIsStorage(1); | 176 | airagLog.setIsStorage(1); |
| 145 | airagLogService.saveToQuestionLibrary(airagLog); | 177 | airagLogService.saveToQuestionLibrary(airagLog); |
| 146 | return Result.OK("存入问题库成功!"); | 178 | return Result.OK("存入问题库成功!"); |
| @@ -165,7 +165,7 @@ public class EmbeddingsController { | @@ -165,7 +165,7 @@ public class EmbeddingsController { | ||
| 165 | @RequiresPermissions("embeddings:embeddings:deleteBatch") | 165 | @RequiresPermissions("embeddings:embeddings:deleteBatch") |
| 166 | @DeleteMapping(value = "/deleteBatch") | 166 | @DeleteMapping(value = "/deleteBatch") |
| 167 | public Result<String> deleteBatch(@RequestParam(name = "ids", required = true) String ids) { | 167 | public Result<String> deleteBatch(@RequestParam(name = "ids", required = true) String ids) { |
| 168 | -// this.embeddingsService.removeByIds(Arrays.asList(ids.split(","))); | 168 | + embeddingsService.removeByIds(Arrays.asList(ids.split(","))); |
| 169 | return Result.OK("批量删除成功!"); | 169 | return Result.OK("批量删除成功!"); |
| 170 | } | 170 | } |
| 171 | 171 |
| @@ -2,9 +2,12 @@ package org.jeecg.modules.airag.app.controller; | @@ -2,9 +2,12 @@ package org.jeecg.modules.airag.app.controller; | ||
| 2 | 2 | ||
| 3 | import com.baomidou.mybatisplus.extension.plugins.pagination.Page; | 3 | import com.baomidou.mybatisplus.extension.plugins.pagination.Page; |
| 4 | import dev.langchain4j.internal.Json; | 4 | import dev.langchain4j.internal.Json; |
| 5 | +import io.swagger.v3.oas.annotations.Operation; | ||
| 5 | import lombok.extern.slf4j.Slf4j; | 6 | import lombok.extern.slf4j.Slf4j; |
| 6 | import org.apache.commons.lang3.StringUtils; | 7 | import org.apache.commons.lang3.StringUtils; |
| 8 | +import org.apache.shiro.authz.annotation.RequiresPermissions; | ||
| 7 | import org.jeecg.common.api.vo.Result; | 9 | import org.jeecg.common.api.vo.Result; |
| 10 | +import org.jeecg.common.aspect.annotation.AutoLog; | ||
| 8 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | 11 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; |
| 9 | import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; | 12 | import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; |
| 10 | import org.jeecg.modules.airag.app.utils.JsonUtils; | 13 | import org.jeecg.modules.airag.app.utils.JsonUtils; |
| @@ -15,6 +18,7 @@ import org.springframework.transaction.annotation.Transactional; | @@ -15,6 +18,7 @@ import org.springframework.transaction.annotation.Transactional; | ||
| 15 | import org.springframework.web.bind.annotation.*; | 18 | import org.springframework.web.bind.annotation.*; |
| 16 | import org.springframework.web.multipart.MultipartFile; | 19 | import org.springframework.web.multipart.MultipartFile; |
| 17 | 20 | ||
| 21 | +import java.util.Arrays; | ||
| 18 | import java.util.HashMap; | 22 | import java.util.HashMap; |
| 19 | import java.util.Map; | 23 | import java.util.Map; |
| 20 | import java.util.stream.Collectors; | 24 | import java.util.stream.Collectors; |
| @@ -111,6 +115,18 @@ public class QuestionEmbeddingController { | @@ -111,6 +115,18 @@ public class QuestionEmbeddingController { | ||
| 111 | return result > 0 ? Result.OK("删除成功!") : Result.error("删除失败"); | 115 | return result > 0 ? Result.OK("删除成功!") : Result.error("删除失败"); |
| 112 | } | 116 | } |
| 113 | 117 | ||
| 118 | + /** | ||
| 119 | + * 批量删除 | ||
| 120 | + * | ||
| 121 | + * @param ids | ||
| 122 | + * @return | ||
| 123 | + */ | ||
| 124 | + @DeleteMapping(value = "/deleteBatch") | ||
| 125 | + public Result<String> deleteBatch(@RequestParam(name = "ids", required = true) String ids) { | ||
| 126 | + questionEmbeddingService.removeByIds(Arrays.asList(ids.split(","))); | ||
| 127 | + return Result.OK("批量删除成功!"); | ||
| 128 | + } | ||
| 129 | + | ||
| 114 | @PostMapping("/uploadZip") | 130 | @PostMapping("/uploadZip") |
| 115 | @Transactional(rollbackFor = {Exception.class}) | 131 | @Transactional(rollbackFor = {Exception.class}) |
| 116 | public Result<?> uploadZip( | 132 | public Result<?> uploadZip( |
| @@ -102,6 +102,9 @@ public class AiragLog implements Serializable { | @@ -102,6 +102,9 @@ public class AiragLog implements Serializable { | ||
| 102 | // 新增:临时字段(非数据库字段) | 102 | // 新增:临时字段(非数据库字段) |
| 103 | @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中 | 103 | @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中 |
| 104 | private String createTime_end; | 104 | private String createTime_end; |
| 105 | + // 新增:临时字段(非数据库字段) | ||
| 106 | + @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中 | ||
| 107 | + private String knowledgeId; | ||
| 105 | 108 | ||
| 106 | @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中 | 109 | @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中 |
| 107 | private String createTimeStr; | 110 | private String createTimeStr; |
| @@ -24,5 +24,6 @@ public class Embeddings { | @@ -24,5 +24,6 @@ public class Embeddings { | ||
| 24 | private String knowledgeId; // 新增知识库ID字段 | 24 | private String knowledgeId; // 新增知识库ID字段 |
| 25 | private String docId; // 新增文档ID字段 | 25 | private String docId; // 新增文档ID字段 |
| 26 | private String index; // 新增索引位置字段 | 26 | private String index; // 新增索引位置字段 |
| 27 | + private String knowledgeName; | ||
| 27 | 28 | ||
| 28 | } | 29 | } |
| @@ -19,6 +19,6 @@ public interface AiragLogMapper extends BaseMapper<AiragLog> { | @@ -19,6 +19,6 @@ public interface AiragLogMapper extends BaseMapper<AiragLog> { | ||
| 19 | 19 | ||
| 20 | IPage<AiragLog> pageList(@Param("param1") AiragLog airagLog, Page<AiragLog> page); | 20 | IPage<AiragLog> pageList(@Param("param1") AiragLog airagLog, Page<AiragLog> page); |
| 21 | 21 | ||
| 22 | - int updataIsStorage(@Param("param1") int isStorage); | 22 | + int updataIsStorage(@Param("param1") int isStorage, @Param("param2") String id); |
| 23 | 23 | ||
| 24 | } | 24 | } |
| @@ -11,11 +11,13 @@ import com.pgvector.PGvector; | @@ -11,11 +11,13 @@ import com.pgvector.PGvector; | ||
| 11 | import lombok.extern.slf4j.Slf4j; | 11 | import lombok.extern.slf4j.Slf4j; |
| 12 | import org.apache.commons.lang3.StringUtils; | 12 | import org.apache.commons.lang3.StringUtils; |
| 13 | import org.jeecg.modules.airag.app.entity.Embeddings; | 13 | import org.jeecg.modules.airag.app.entity.Embeddings; |
| 14 | +import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | ||
| 14 | import org.postgresql.util.PGobject; | 15 | import org.postgresql.util.PGobject; |
| 15 | import org.springframework.stereotype.Component; | 16 | import org.springframework.stereotype.Component; |
| 16 | 17 | ||
| 17 | import java.sql.*; | 18 | import java.sql.*; |
| 18 | import java.util.*; | 19 | import java.util.*; |
| 20 | +import java.util.stream.Collectors; | ||
| 19 | 21 | ||
| 20 | @Component | 22 | @Component |
| 21 | @Slf4j | 23 | @Slf4j |
| @@ -79,6 +81,23 @@ public class PgVectorMapper { | @@ -79,6 +81,23 @@ public class PgVectorMapper { | ||
| 79 | } | 81 | } |
| 80 | 82 | ||
| 81 | 83 | ||
| 84 | + // 2. 获取知识库名称映射 | ||
| 85 | + Map<String, String> knowledgeNameMap = getKnowledgeNameMap(results); | ||
| 86 | + | ||
| 87 | + // 3. 设置知识库名称并处理空值 | ||
| 88 | + for (Embeddings record : results) { | ||
| 89 | + String knowledgeId = record.getKnowledgeId(); | ||
| 90 | + String name = knowledgeNameMap.get(knowledgeId); | ||
| 91 | + record.setKnowledgeName(name != null ? name : ""); | ||
| 92 | + } | ||
| 93 | + | ||
| 94 | + // 4. 安全排序(处理空值) | ||
| 95 | + results.sort(Comparator | ||
| 96 | + .comparing(Embeddings::getKnowledgeName, | ||
| 97 | + Comparator.nullsLast(Comparator.naturalOrder())) | ||
| 98 | + .thenComparing(Embeddings::getDocName, | ||
| 99 | + Comparator.nullsLast(Comparator.naturalOrder()))); | ||
| 100 | + | ||
| 82 | // 执行计数查询 | 101 | // 执行计数查询 |
| 83 | int total = 0; | 102 | int total = 0; |
| 84 | try(Connection conn = getConnection(); | 103 | try(Connection conn = getConnection(); |
| @@ -212,6 +231,36 @@ public class PgVectorMapper { | @@ -212,6 +231,36 @@ public class PgVectorMapper { | ||
| 212 | } | 231 | } |
| 213 | } | 232 | } |
| 214 | 233 | ||
| 234 | + // 批量删除方法 | ||
| 235 | + public int deleteByIds(List<String> ids) { | ||
| 236 | + if (ids == null || ids.isEmpty()) { | ||
| 237 | + return 0; | ||
| 238 | + } | ||
| 239 | + | ||
| 240 | + String sql = "DELETE FROM embeddings WHERE embedding_id IN ("; | ||
| 241 | + StringBuilder placeholders = new StringBuilder(); | ||
| 242 | + for (int i = 0; i < ids.size(); i++) { | ||
| 243 | + placeholders.append("?"); | ||
| 244 | + if (i < ids.size() - 1) { | ||
| 245 | + placeholders.append(","); | ||
| 246 | + } | ||
| 247 | + } | ||
| 248 | + sql += placeholders.toString() + ")"; | ||
| 249 | + | ||
| 250 | + try (Connection conn = getConnection(); | ||
| 251 | + PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 252 | + | ||
| 253 | + for (int i = 0; i < ids.size(); i++) { | ||
| 254 | + stmt.setString(i + 1, ids.get(i)); | ||
| 255 | + } | ||
| 256 | + | ||
| 257 | + return stmt.executeUpdate(); | ||
| 258 | + } catch (SQLException e) { | ||
| 259 | + log.error("批量删除向量记录失败, IDs: {}", ids, e); | ||
| 260 | + throw new RuntimeException("批量删除向量数据时发生数据库错误", e); | ||
| 261 | + } | ||
| 262 | + } | ||
| 263 | + | ||
| 215 | // 向量相似度搜索 | 264 | // 向量相似度搜索 |
| 216 | public List<Embeddings> similaritySearch(float[] vector, int limit) { | 265 | public List<Embeddings> similaritySearch(float[] vector, int limit) { |
| 217 | String sql = "SELECT * FROM embeddings ORDER BY embedding <-> ? LIMIT ?"; | 266 | String sql = "SELECT * FROM embeddings ORDER BY embedding <-> ? LIMIT ?"; |
| @@ -286,4 +335,50 @@ public class PgVectorMapper { | @@ -286,4 +335,50 @@ public class PgVectorMapper { | ||
| 286 | return embedding; | 335 | return embedding; |
| 287 | } | 336 | } |
| 288 | 337 | ||
| 338 | + | ||
| 339 | + // 获取知识库名称映射 | ||
| 340 | + private Map<String, String> getKnowledgeNameMap(List<Embeddings> records) { | ||
| 341 | + // 提取所有知识库ID | ||
| 342 | + Set<String> knowledgeIds = records.stream() | ||
| 343 | + .map(Embeddings::getKnowledgeId) | ||
| 344 | + .filter(Objects::nonNull) | ||
| 345 | + .collect(Collectors.toSet()); | ||
| 346 | + | ||
| 347 | + if (knowledgeIds.isEmpty()) { | ||
| 348 | + return Collections.emptyMap(); | ||
| 349 | + } | ||
| 350 | + | ||
| 351 | + // 从 MySQL 查询知识库名称 | ||
| 352 | + Map<String, String> knowledgeNameMap = new HashMap<>(); | ||
| 353 | + try (Connection mysqlConn = getMysqlConnection()) { | ||
| 354 | + String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?")); | ||
| 355 | + String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders); | ||
| 356 | + | ||
| 357 | + try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) { | ||
| 358 | + int index = 1; | ||
| 359 | + for (String id : knowledgeIds) { | ||
| 360 | + stmt.setString(index++, id); | ||
| 361 | + } | ||
| 362 | + | ||
| 363 | + try (ResultSet rs = stmt.executeQuery()) { | ||
| 364 | + while (rs.next()) { | ||
| 365 | + knowledgeNameMap.put(rs.getString("id"), rs.getString("name")); | ||
| 366 | + } | ||
| 367 | + } | ||
| 368 | + } | ||
| 369 | + } catch (SQLException e) { | ||
| 370 | + log.error("查询知识库名称失败", e); | ||
| 371 | + } | ||
| 372 | + | ||
| 373 | + return knowledgeNameMap; | ||
| 374 | + } | ||
| 375 | + | ||
| 376 | + // 获取 MySQL 连接 | ||
| 377 | + private Connection getMysqlConnection() throws SQLException { | ||
| 378 | + String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai"; | ||
| 379 | + String user = "root"; | ||
| 380 | + String password = "123456"; | ||
| 381 | + return DriverManager.getConnection(url, user, password); | ||
| 382 | + } | ||
| 383 | + | ||
| 289 | } | 384 | } |
| @@ -12,6 +12,7 @@ import dev.langchain4j.model.output.Response; | @@ -12,6 +12,7 @@ import dev.langchain4j.model.output.Response; | ||
| 12 | import io.minio.messages.Metadata; | 12 | import io.minio.messages.Metadata; |
| 13 | import lombok.extern.slf4j.Slf4j; | 13 | import lombok.extern.slf4j.Slf4j; |
| 14 | import org.apache.commons.lang3.StringUtils; | 14 | import org.apache.commons.lang3.StringUtils; |
| 15 | +import org.jeecg.modules.airag.app.entity.Embeddings; | ||
| 15 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | 16 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; |
| 16 | import org.jeecg.modules.airag.app.utils.AiModelUtils; | 17 | import org.jeecg.modules.airag.app.utils.AiModelUtils; |
| 17 | import org.postgresql.util.PGobject; | 18 | import org.postgresql.util.PGobject; |
| @@ -20,6 +21,7 @@ import org.springframework.stereotype.Component; | @@ -20,6 +21,7 @@ import org.springframework.stereotype.Component; | ||
| 20 | 21 | ||
| 21 | import java.sql.*; | 22 | import java.sql.*; |
| 22 | import java.util.*; | 23 | import java.util.*; |
| 24 | +import java.util.stream.Collectors; | ||
| 23 | 25 | ||
| 24 | @Component | 26 | @Component |
| 25 | @Slf4j | 27 | @Slf4j |
| @@ -89,6 +91,23 @@ public class QuestionEmbeddingMapper { | @@ -89,6 +91,23 @@ public class QuestionEmbeddingMapper { | ||
| 89 | throw new RuntimeException("查询数据时发生数据库错误", e); | 91 | throw new RuntimeException("查询数据时发生数据库错误", e); |
| 90 | } | 92 | } |
| 91 | 93 | ||
| 94 | + // 2. 获取知识库名称映射 | ||
| 95 | + Map<String, String> knowledgeNameMap = getKnowledgeNameMap(results); | ||
| 96 | + | ||
| 97 | + // 3. 设置知识库名称并处理空值 | ||
| 98 | + for (QuestionEmbedding record : results) { | ||
| 99 | + String knowledgeId = record.getKnowledgeId(); | ||
| 100 | + String name = knowledgeNameMap.get(knowledgeId); | ||
| 101 | + record.setKnowledgeName(name != null ? name : ""); | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + // 4. 安全排序(处理空值) | ||
| 105 | + results.sort(Comparator | ||
| 106 | + .comparing(QuestionEmbedding::getKnowledgeName, | ||
| 107 | + Comparator.nullsLast(Comparator.naturalOrder())) | ||
| 108 | + .thenComparing(QuestionEmbedding::getQuestion, | ||
| 109 | + Comparator.nullsLast(Comparator.naturalOrder()))); | ||
| 110 | + | ||
| 92 | // 执行计数查询 | 111 | // 执行计数查询 |
| 93 | long total = 0; | 112 | long total = 0; |
| 94 | try(Connection conn = getConnection(); | 113 | try(Connection conn = getConnection(); |
| @@ -236,6 +255,37 @@ public class QuestionEmbeddingMapper { | @@ -236,6 +255,37 @@ public class QuestionEmbeddingMapper { | ||
| 236 | } | 255 | } |
| 237 | 256 | ||
| 238 | 257 | ||
| 258 | + // 批量删除方法 | ||
| 259 | + public int deleteByIds(List<String> ids) { | ||
| 260 | + if (ids == null || ids.isEmpty()) { | ||
| 261 | + return 0; | ||
| 262 | + } | ||
| 263 | + | ||
| 264 | + String sql = "DELETE FROM question_embedding WHERE id IN ("; | ||
| 265 | + StringBuilder placeholders = new StringBuilder(); | ||
| 266 | + for (int i = 0; i < ids.size(); i++) { | ||
| 267 | + placeholders.append("?"); | ||
| 268 | + if (i < ids.size() - 1) { | ||
| 269 | + placeholders.append(","); | ||
| 270 | + } | ||
| 271 | + } | ||
| 272 | + sql += placeholders.toString() + ")"; | ||
| 273 | + | ||
| 274 | + try (Connection conn = getConnection(); | ||
| 275 | + PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 276 | + | ||
| 277 | + for (int i = 0; i < ids.size(); i++) { | ||
| 278 | + stmt.setString(i + 1, ids.get(i)); | ||
| 279 | + } | ||
| 280 | + | ||
| 281 | + return stmt.executeUpdate(); | ||
| 282 | + } catch (SQLException e) { | ||
| 283 | + log.error("批量删除向量记录失败, IDs: {}", ids, e); | ||
| 284 | + throw new RuntimeException("批量删除向量数据时发生数据库错误", e); | ||
| 285 | + } | ||
| 286 | + } | ||
| 287 | + | ||
| 288 | + | ||
| 239 | /** | 289 | /** |
| 240 | * 向量相似度查询 (基于问题文本的向量) | 290 | * 向量相似度查询 (基于问题文本的向量) |
| 241 | * @param question 问题文本 | 291 | * @param question 问题文本 |
| @@ -376,4 +426,49 @@ public class QuestionEmbeddingMapper { | @@ -376,4 +426,49 @@ public class QuestionEmbeddingMapper { | ||
| 376 | return Collections.emptyMap(); | 426 | return Collections.emptyMap(); |
| 377 | } | 427 | } |
| 378 | } | 428 | } |
| 429 | + | ||
| 430 | + // 获取知识库名称映射 | ||
| 431 | + private Map<String, String> getKnowledgeNameMap(List<QuestionEmbedding> records) { | ||
| 432 | + // 提取所有知识库ID | ||
| 433 | + Set<String> knowledgeIds = records.stream() | ||
| 434 | + .map(QuestionEmbedding::getKnowledgeId) | ||
| 435 | + .filter(Objects::nonNull) | ||
| 436 | + .collect(Collectors.toSet()); | ||
| 437 | + | ||
| 438 | + if (knowledgeIds.isEmpty()) { | ||
| 439 | + return Collections.emptyMap(); | ||
| 440 | + } | ||
| 441 | + | ||
| 442 | + // 从 MySQL 查询知识库名称 | ||
| 443 | + Map<String, String> knowledgeNameMap = new HashMap<>(); | ||
| 444 | + try (Connection mysqlConn = getMysqlConnection()) { | ||
| 445 | + String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?")); | ||
| 446 | + String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders); | ||
| 447 | + | ||
| 448 | + try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) { | ||
| 449 | + int index = 1; | ||
| 450 | + for (String id : knowledgeIds) { | ||
| 451 | + stmt.setString(index++, id); | ||
| 452 | + } | ||
| 453 | + | ||
| 454 | + try (ResultSet rs = stmt.executeQuery()) { | ||
| 455 | + while (rs.next()) { | ||
| 456 | + knowledgeNameMap.put(rs.getString("id"), rs.getString("name")); | ||
| 457 | + } | ||
| 458 | + } | ||
| 459 | + } | ||
| 460 | + } catch (SQLException e) { | ||
| 461 | + log.error("查询知识库名称失败", e); | ||
| 462 | + } | ||
| 463 | + | ||
| 464 | + return knowledgeNameMap; | ||
| 465 | + } | ||
| 466 | + | ||
| 467 | + // 获取 MySQL 连接 | ||
| 468 | + private Connection getMysqlConnection() throws SQLException { | ||
| 469 | + String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai"; | ||
| 470 | + String user = "root"; | ||
| 471 | + String password = "123456"; | ||
| 472 | + return DriverManager.getConnection(url, user, password); | ||
| 473 | + } | ||
| 379 | } | 474 | } |
| @@ -5,6 +5,7 @@ package org.jeecg.modules.airag.app.service; | @@ -5,6 +5,7 @@ package org.jeecg.modules.airag.app.service; | ||
| 5 | import com.baomidou.mybatisplus.extension.plugins.pagination.Page; | 5 | import com.baomidou.mybatisplus.extension.plugins.pagination.Page; |
| 6 | import org.jeecg.modules.airag.app.entity.Embeddings; | 6 | import org.jeecg.modules.airag.app.entity.Embeddings; |
| 7 | 7 | ||
| 8 | +import java.util.ArrayList; | ||
| 8 | import java.util.List; | 9 | import java.util.List; |
| 9 | 10 | ||
| 10 | /** | 11 | /** |
| @@ -20,4 +21,5 @@ public interface IEmbeddingsService { | @@ -20,4 +21,5 @@ public interface IEmbeddingsService { | ||
| 20 | int insert(Embeddings record); | 21 | int insert(Embeddings record); |
| 21 | int update(Embeddings record); | 22 | int update(Embeddings record); |
| 22 | Embeddings findById(String id); | 23 | Embeddings findById(String id); |
| 24 | + int removeByIds(List<String> ids); | ||
| 23 | } | 25 | } |
| @@ -13,6 +13,7 @@ public interface IQuestionEmbeddingService { | @@ -13,6 +13,7 @@ public interface IQuestionEmbeddingService { | ||
| 13 | QuestionEmbedding findById(String id); | 13 | QuestionEmbedding findById(String id); |
| 14 | int insert(QuestionEmbedding record); | 14 | int insert(QuestionEmbedding record); |
| 15 | int update(QuestionEmbedding record); | 15 | int update(QuestionEmbedding record); |
| 16 | + int removeByIds(List<String> ids); | ||
| 16 | int deleteById(String id); | 17 | int deleteById(String id); |
| 17 | List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity); | 18 | List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity); |
| 18 | List<QuestionEmbedding> similaritySearch(float[] vector, int limit); | 19 | List<QuestionEmbedding> similaritySearch(float[] vector, int limit); |
| @@ -56,8 +56,9 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i | @@ -56,8 +56,9 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i | ||
| 56 | QuestionEmbedding questionEmbedding = new QuestionEmbedding(); | 56 | QuestionEmbedding questionEmbedding = new QuestionEmbedding(); |
| 57 | questionEmbedding.setQuestion(log.getQuestion()); | 57 | questionEmbedding.setQuestion(log.getQuestion()); |
| 58 | questionEmbedding.setAnswer(log.getAnswer()); | 58 | questionEmbedding.setAnswer(log.getAnswer()); |
| 59 | + questionEmbedding.setKnowledgeId(log.getKnowledgeId()); | ||
| 59 | questionEmbeddingMapper.insert(questionEmbedding); | 60 | questionEmbeddingMapper.insert(questionEmbedding); |
| 60 | - airagLogMapper.updataIsStorage(log.getIsStorage()); | 61 | + airagLogMapper.updataIsStorage(log.getIsStorage(),log.getId()); |
| 61 | 62 | ||
| 62 | } | 63 | } |
| 63 | 64 |
| @@ -42,4 +42,9 @@ public class IEmbeddingsServiceImpl implements IEmbeddingsService { | @@ -42,4 +42,9 @@ public class IEmbeddingsServiceImpl implements IEmbeddingsService { | ||
| 42 | public Embeddings findById(String id) { | 42 | public Embeddings findById(String id) { |
| 43 | return pgVectorMapper.findById(id); | 43 | return pgVectorMapper.findById(id); |
| 44 | } | 44 | } |
| 45 | + | ||
| 46 | + @Override | ||
| 47 | + public int removeByIds(List<String> ids) { | ||
| 48 | + return pgVectorMapper.deleteByIds(ids); | ||
| 49 | + } | ||
| 45 | } | 50 | } |
| @@ -95,6 +95,11 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | @@ -95,6 +95,11 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 95 | } | 95 | } |
| 96 | 96 | ||
| 97 | @Override | 97 | @Override |
| 98 | + public int removeByIds(List<String> ids) { | ||
| 99 | + return questionEmbeddingMapper.deleteByIds(ids); | ||
| 100 | + } | ||
| 101 | + | ||
| 102 | + @Override | ||
| 98 | public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { | 103 | public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { |
| 99 | return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity); | 104 | return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity); |
| 100 | } | 105 | } |
-
请 注册 或 登录 后发表评论