作者 dong

修复bug,更正需求

正在显示 15 个修改的文件 包含 262 行增加3 行删除
@@ -19,6 +19,7 @@ import org.jeecg.modules.airag.app.service.IAiragLogService; @@ -19,6 +19,7 @@ import org.jeecg.modules.airag.app.service.IAiragLogService;
19 import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; 19 import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
20 import org.jeecg.modules.airag.llm.entity.AiragKnowledge; 20 import org.jeecg.modules.airag.llm.entity.AiragKnowledge;
21 import org.jeecg.modules.airag.llm.entity.AiragModel; 21 import org.jeecg.modules.airag.llm.entity.AiragModel;
  22 +import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService;
22 import org.jeecg.modules.airag.llm.service.IAiragModelService; 23 import org.jeecg.modules.airag.llm.service.IAiragModelService;
23 import org.springframework.beans.factory.annotation.Autowired; 24 import org.springframework.beans.factory.annotation.Autowired;
24 import org.springframework.web.bind.annotation.*; 25 import org.springframework.web.bind.annotation.*;
@@ -29,6 +30,8 @@ import javax.servlet.http.HttpServletResponse; @@ -29,6 +30,8 @@ import javax.servlet.http.HttpServletResponse;
29 import java.sql.SQLException; 30 import java.sql.SQLException;
30 import java.util.*; 31 import java.util.*;
31 32
  33 +import java.util.stream.Collectors;
  34 +
32 /** 35 /**
33 * @Description: 日志管理 36 * @Description: 日志管理
34 * @Author: jeecg-boot 37 * @Author: jeecg-boot
@@ -51,6 +54,9 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi @@ -51,6 +54,9 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi
51 @Autowired 54 @Autowired
52 private IQuestionEmbeddingService questionEmbeddingService; 55 private IQuestionEmbeddingService questionEmbeddingService;
53 56
  57 + @Autowired
  58 + private IAiragKnowledgeService airagKnowledgeService;
  59 +
54 /** 60 /**
55 * 分页列表查询 61 * 分页列表查询
56 * 62 *
@@ -105,6 +111,31 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi @@ -105,6 +111,31 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi
105 public Result<List<AiragModel>> queryAiragKnowledgeList(AiragModel airagModel, HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException { 111 public Result<List<AiragModel>> queryAiragKnowledgeList(AiragModel airagModel, HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException {
106 QueryWrapper<AiragModel> queryWrapper = QueryGenerator.initQueryWrapper(airagModel, req.getParameterMap()); 112 QueryWrapper<AiragModel> queryWrapper = QueryGenerator.initQueryWrapper(airagModel, req.getParameterMap());
107 List<AiragModel> list = airagModelService.list(queryWrapper); 113 List<AiragModel> list = airagModelService.list(queryWrapper);
  114 +
  115 + // 过滤出 model_type 为 "llm" 的记录
  116 + List<AiragModel> filteredList = list.stream()
  117 + .filter(model -> "LLM".equals(model.getModelType()))
  118 + .collect(Collectors.toList());
  119 +
  120 +
  121 + return Result.OK(filteredList);
  122 + }
  123 +
  124 +
  125 + /**
  126 + * 查询知识库名称
  127 + *
  128 + * @param airagKnowledge
  129 + * @param req
  130 + * @return
  131 + */
  132 + @AutoLog(value = "日志管理-查询知识库名称")
  133 + @Operation(summary="日志管理-查询知识库名称")
  134 + @GetMapping(value = "/listKnowledgeName")
  135 + public Result<List<AiragKnowledge>> queryAiragKnowledgeNameList(AiragKnowledge airagKnowledge, HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException {
  136 + QueryWrapper<AiragKnowledge> queryWrapper = QueryGenerator.initQueryWrapper(airagKnowledge, req.getParameterMap());
  137 + List<AiragKnowledge> list = airagKnowledgeService.list(queryWrapper);
  138 +
108 return Result.OK(list); 139 return Result.OK(list);
109 } 140 }
110 141
@@ -141,6 +172,7 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi @@ -141,6 +172,7 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi
141 if(questionCount > 0){ 172 if(questionCount > 0){
142 return Result.error("重复问题不能存入"); 173 return Result.error("重复问题不能存入");
143 } 174 }
  175 +
144 airagLog.setIsStorage(1); 176 airagLog.setIsStorage(1);
145 airagLogService.saveToQuestionLibrary(airagLog); 177 airagLogService.saveToQuestionLibrary(airagLog);
146 return Result.OK("存入问题库成功!"); 178 return Result.OK("存入问题库成功!");
@@ -165,7 +165,7 @@ public class EmbeddingsController { @@ -165,7 +165,7 @@ public class EmbeddingsController {
165 @RequiresPermissions("embeddings:embeddings:deleteBatch") 165 @RequiresPermissions("embeddings:embeddings:deleteBatch")
166 @DeleteMapping(value = "/deleteBatch") 166 @DeleteMapping(value = "/deleteBatch")
167 public Result<String> deleteBatch(@RequestParam(name = "ids", required = true) String ids) { 167 public Result<String> deleteBatch(@RequestParam(name = "ids", required = true) String ids) {
168 -// this.embeddingsService.removeByIds(Arrays.asList(ids.split(","))); 168 + embeddingsService.removeByIds(Arrays.asList(ids.split(",")));
169 return Result.OK("批量删除成功!"); 169 return Result.OK("批量删除成功!");
170 } 170 }
171 171
@@ -2,9 +2,12 @@ package org.jeecg.modules.airag.app.controller; @@ -2,9 +2,12 @@ package org.jeecg.modules.airag.app.controller;
2 2
3 import com.baomidou.mybatisplus.extension.plugins.pagination.Page; 3 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
4 import dev.langchain4j.internal.Json; 4 import dev.langchain4j.internal.Json;
  5 +import io.swagger.v3.oas.annotations.Operation;
5 import lombok.extern.slf4j.Slf4j; 6 import lombok.extern.slf4j.Slf4j;
6 import org.apache.commons.lang3.StringUtils; 7 import org.apache.commons.lang3.StringUtils;
  8 +import org.apache.shiro.authz.annotation.RequiresPermissions;
7 import org.jeecg.common.api.vo.Result; 9 import org.jeecg.common.api.vo.Result;
  10 +import org.jeecg.common.aspect.annotation.AutoLog;
8 import org.jeecg.modules.airag.app.entity.QuestionEmbedding; 11 import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
9 import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; 12 import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
10 import org.jeecg.modules.airag.app.utils.JsonUtils; 13 import org.jeecg.modules.airag.app.utils.JsonUtils;
@@ -15,6 +18,7 @@ import org.springframework.transaction.annotation.Transactional; @@ -15,6 +18,7 @@ import org.springframework.transaction.annotation.Transactional;
15 import org.springframework.web.bind.annotation.*; 18 import org.springframework.web.bind.annotation.*;
16 import org.springframework.web.multipart.MultipartFile; 19 import org.springframework.web.multipart.MultipartFile;
17 20
  21 +import java.util.Arrays;
18 import java.util.HashMap; 22 import java.util.HashMap;
19 import java.util.Map; 23 import java.util.Map;
20 import java.util.stream.Collectors; 24 import java.util.stream.Collectors;
@@ -111,6 +115,18 @@ public class QuestionEmbeddingController { @@ -111,6 +115,18 @@ public class QuestionEmbeddingController {
111 return result > 0 ? Result.OK("删除成功!") : Result.error("删除失败"); 115 return result > 0 ? Result.OK("删除成功!") : Result.error("删除失败");
112 } 116 }
113 117
  118 + /**
  119 + * 批量删除
  120 + *
  121 + * @param ids
  122 + * @return
  123 + */
  124 + @DeleteMapping(value = "/deleteBatch")
  125 + public Result<String> deleteBatch(@RequestParam(name = "ids", required = true) String ids) {
  126 + questionEmbeddingService.removeByIds(Arrays.asList(ids.split(",")));
  127 + return Result.OK("批量删除成功!");
  128 + }
  129 +
114 @PostMapping("/uploadZip") 130 @PostMapping("/uploadZip")
115 @Transactional(rollbackFor = {Exception.class}) 131 @Transactional(rollbackFor = {Exception.class})
116 public Result<?> uploadZip( 132 public Result<?> uploadZip(
@@ -102,6 +102,9 @@ public class AiragLog implements Serializable { @@ -102,6 +102,9 @@ public class AiragLog implements Serializable {
102 // 新增:临时字段(非数据库字段) 102 // 新增:临时字段(非数据库字段)
103 @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中 103 @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中
104 private String createTime_end; 104 private String createTime_end;
  105 + // 新增:临时字段(非数据库字段)
  106 + @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中
  107 + private String knowledgeId;
105 108
106 @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中 109 @TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中
107 private String createTimeStr; 110 private String createTimeStr;
@@ -24,5 +24,6 @@ public class Embeddings { @@ -24,5 +24,6 @@ public class Embeddings {
24 private String knowledgeId; // 新增知识库ID字段 24 private String knowledgeId; // 新增知识库ID字段
25 private String docId; // 新增文档ID字段 25 private String docId; // 新增文档ID字段
26 private String index; // 新增索引位置字段 26 private String index; // 新增索引位置字段
  27 + private String knowledgeName;
27 28
28 } 29 }
@@ -52,4 +52,6 @@ public class QuestionEmbedding { @@ -52,4 +52,6 @@ public class QuestionEmbedding {
52 private String knowledgeId; 52 private String knowledgeId;
53 53
54 54
  55 +
  56 +
55 } 57 }
@@ -19,6 +19,6 @@ public interface AiragLogMapper extends BaseMapper<AiragLog> { @@ -19,6 +19,6 @@ public interface AiragLogMapper extends BaseMapper<AiragLog> {
19 19
20 IPage<AiragLog> pageList(@Param("param1") AiragLog airagLog, Page<AiragLog> page); 20 IPage<AiragLog> pageList(@Param("param1") AiragLog airagLog, Page<AiragLog> page);
21 21
22 - int updataIsStorage(@Param("param1") int isStorage); 22 + int updataIsStorage(@Param("param1") int isStorage, @Param("param2") String id);
23 23
24 } 24 }
@@ -11,11 +11,13 @@ import com.pgvector.PGvector; @@ -11,11 +11,13 @@ import com.pgvector.PGvector;
11 import lombok.extern.slf4j.Slf4j; 11 import lombok.extern.slf4j.Slf4j;
12 import org.apache.commons.lang3.StringUtils; 12 import org.apache.commons.lang3.StringUtils;
13 import org.jeecg.modules.airag.app.entity.Embeddings; 13 import org.jeecg.modules.airag.app.entity.Embeddings;
  14 +import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
14 import org.postgresql.util.PGobject; 15 import org.postgresql.util.PGobject;
15 import org.springframework.stereotype.Component; 16 import org.springframework.stereotype.Component;
16 17
17 import java.sql.*; 18 import java.sql.*;
18 import java.util.*; 19 import java.util.*;
  20 +import java.util.stream.Collectors;
19 21
20 @Component 22 @Component
21 @Slf4j 23 @Slf4j
@@ -79,6 +81,23 @@ public class PgVectorMapper { @@ -79,6 +81,23 @@ public class PgVectorMapper {
79 } 81 }
80 82
81 83
  84 + // 2. 获取知识库名称映射
  85 + Map<String, String> knowledgeNameMap = getKnowledgeNameMap(results);
  86 +
  87 + // 3. 设置知识库名称并处理空值
  88 + for (Embeddings record : results) {
  89 + String knowledgeId = record.getKnowledgeId();
  90 + String name = knowledgeNameMap.get(knowledgeId);
  91 + record.setKnowledgeName(name != null ? name : "");
  92 + }
  93 +
  94 + // 4. 安全排序(处理空值)
  95 + results.sort(Comparator
  96 + .comparing(Embeddings::getKnowledgeName,
  97 + Comparator.nullsLast(Comparator.naturalOrder()))
  98 + .thenComparing(Embeddings::getDocName,
  99 + Comparator.nullsLast(Comparator.naturalOrder())));
  100 +
82 // 执行计数查询 101 // 执行计数查询
83 int total = 0; 102 int total = 0;
84 try(Connection conn = getConnection(); 103 try(Connection conn = getConnection();
@@ -212,6 +231,36 @@ public class PgVectorMapper { @@ -212,6 +231,36 @@ public class PgVectorMapper {
212 } 231 }
213 } 232 }
214 233
  234 + // 批量删除方法
  235 + public int deleteByIds(List<String> ids) {
  236 + if (ids == null || ids.isEmpty()) {
  237 + return 0;
  238 + }
  239 +
  240 + String sql = "DELETE FROM embeddings WHERE embedding_id IN (";
  241 + StringBuilder placeholders = new StringBuilder();
  242 + for (int i = 0; i < ids.size(); i++) {
  243 + placeholders.append("?");
  244 + if (i < ids.size() - 1) {
  245 + placeholders.append(",");
  246 + }
  247 + }
  248 + sql += placeholders.toString() + ")";
  249 +
  250 + try (Connection conn = getConnection();
  251 + PreparedStatement stmt = conn.prepareStatement(sql)) {
  252 +
  253 + for (int i = 0; i < ids.size(); i++) {
  254 + stmt.setString(i + 1, ids.get(i));
  255 + }
  256 +
  257 + return stmt.executeUpdate();
  258 + } catch (SQLException e) {
  259 + log.error("批量删除向量记录失败, IDs: {}", ids, e);
  260 + throw new RuntimeException("批量删除向量数据时发生数据库错误", e);
  261 + }
  262 + }
  263 +
215 // 向量相似度搜索 264 // 向量相似度搜索
216 public List<Embeddings> similaritySearch(float[] vector, int limit) { 265 public List<Embeddings> similaritySearch(float[] vector, int limit) {
217 String sql = "SELECT * FROM embeddings ORDER BY embedding <-> ? LIMIT ?"; 266 String sql = "SELECT * FROM embeddings ORDER BY embedding <-> ? LIMIT ?";
@@ -286,4 +335,50 @@ public class PgVectorMapper { @@ -286,4 +335,50 @@ public class PgVectorMapper {
286 return embedding; 335 return embedding;
287 } 336 }
288 337
  338 +
  339 + // 获取知识库名称映射
  340 + private Map<String, String> getKnowledgeNameMap(List<Embeddings> records) {
  341 + // 提取所有知识库ID
  342 + Set<String> knowledgeIds = records.stream()
  343 + .map(Embeddings::getKnowledgeId)
  344 + .filter(Objects::nonNull)
  345 + .collect(Collectors.toSet());
  346 +
  347 + if (knowledgeIds.isEmpty()) {
  348 + return Collections.emptyMap();
  349 + }
  350 +
  351 + // 从 MySQL 查询知识库名称
  352 + Map<String, String> knowledgeNameMap = new HashMap<>();
  353 + try (Connection mysqlConn = getMysqlConnection()) {
  354 + String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?"));
  355 + String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders);
  356 +
  357 + try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) {
  358 + int index = 1;
  359 + for (String id : knowledgeIds) {
  360 + stmt.setString(index++, id);
  361 + }
  362 +
  363 + try (ResultSet rs = stmt.executeQuery()) {
  364 + while (rs.next()) {
  365 + knowledgeNameMap.put(rs.getString("id"), rs.getString("name"));
  366 + }
  367 + }
  368 + }
  369 + } catch (SQLException e) {
  370 + log.error("查询知识库名称失败", e);
  371 + }
  372 +
  373 + return knowledgeNameMap;
  374 + }
  375 +
  376 + // 获取 MySQL 连接
  377 + private Connection getMysqlConnection() throws SQLException {
  378 + String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai";
  379 + String user = "root";
  380 + String password = "123456";
  381 + return DriverManager.getConnection(url, user, password);
  382 + }
  383 +
289 } 384 }
@@ -12,6 +12,7 @@ import dev.langchain4j.model.output.Response; @@ -12,6 +12,7 @@ import dev.langchain4j.model.output.Response;
12 import io.minio.messages.Metadata; 12 import io.minio.messages.Metadata;
13 import lombok.extern.slf4j.Slf4j; 13 import lombok.extern.slf4j.Slf4j;
14 import org.apache.commons.lang3.StringUtils; 14 import org.apache.commons.lang3.StringUtils;
  15 +import org.jeecg.modules.airag.app.entity.Embeddings;
15 import org.jeecg.modules.airag.app.entity.QuestionEmbedding; 16 import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
16 import org.jeecg.modules.airag.app.utils.AiModelUtils; 17 import org.jeecg.modules.airag.app.utils.AiModelUtils;
17 import org.postgresql.util.PGobject; 18 import org.postgresql.util.PGobject;
@@ -20,6 +21,7 @@ import org.springframework.stereotype.Component; @@ -20,6 +21,7 @@ import org.springframework.stereotype.Component;
20 21
21 import java.sql.*; 22 import java.sql.*;
22 import java.util.*; 23 import java.util.*;
  24 +import java.util.stream.Collectors;
23 25
24 @Component 26 @Component
25 @Slf4j 27 @Slf4j
@@ -89,6 +91,23 @@ public class QuestionEmbeddingMapper { @@ -89,6 +91,23 @@ public class QuestionEmbeddingMapper {
89 throw new RuntimeException("查询数据时发生数据库错误", e); 91 throw new RuntimeException("查询数据时发生数据库错误", e);
90 } 92 }
91 93
  94 + // 2. 获取知识库名称映射
  95 + Map<String, String> knowledgeNameMap = getKnowledgeNameMap(results);
  96 +
  97 + // 3. 设置知识库名称并处理空值
  98 + for (QuestionEmbedding record : results) {
  99 + String knowledgeId = record.getKnowledgeId();
  100 + String name = knowledgeNameMap.get(knowledgeId);
  101 + record.setKnowledgeName(name != null ? name : "");
  102 + }
  103 +
  104 + // 4. 安全排序(处理空值)
  105 + results.sort(Comparator
  106 + .comparing(QuestionEmbedding::getKnowledgeName,
  107 + Comparator.nullsLast(Comparator.naturalOrder()))
  108 + .thenComparing(QuestionEmbedding::getQuestion,
  109 + Comparator.nullsLast(Comparator.naturalOrder())));
  110 +
92 // 执行计数查询 111 // 执行计数查询
93 long total = 0; 112 long total = 0;
94 try(Connection conn = getConnection(); 113 try(Connection conn = getConnection();
@@ -236,6 +255,37 @@ public class QuestionEmbeddingMapper { @@ -236,6 +255,37 @@ public class QuestionEmbeddingMapper {
236 } 255 }
237 256
238 257
  258 + // 批量删除方法
  259 + public int deleteByIds(List<String> ids) {
  260 + if (ids == null || ids.isEmpty()) {
  261 + return 0;
  262 + }
  263 +
  264 + String sql = "DELETE FROM question_embedding WHERE id IN (";
  265 + StringBuilder placeholders = new StringBuilder();
  266 + for (int i = 0; i < ids.size(); i++) {
  267 + placeholders.append("?");
  268 + if (i < ids.size() - 1) {
  269 + placeholders.append(",");
  270 + }
  271 + }
  272 + sql += placeholders.toString() + ")";
  273 +
  274 + try (Connection conn = getConnection();
  275 + PreparedStatement stmt = conn.prepareStatement(sql)) {
  276 +
  277 + for (int i = 0; i < ids.size(); i++) {
  278 + stmt.setString(i + 1, ids.get(i));
  279 + }
  280 +
  281 + return stmt.executeUpdate();
  282 + } catch (SQLException e) {
  283 + log.error("批量删除向量记录失败, IDs: {}", ids, e);
  284 + throw new RuntimeException("批量删除向量数据时发生数据库错误", e);
  285 + }
  286 + }
  287 +
  288 +
239 /** 289 /**
240 * 向量相似度查询 (基于问题文本的向量) 290 * 向量相似度查询 (基于问题文本的向量)
241 * @param question 问题文本 291 * @param question 问题文本
@@ -376,4 +426,49 @@ public class QuestionEmbeddingMapper { @@ -376,4 +426,49 @@ public class QuestionEmbeddingMapper {
376 return Collections.emptyMap(); 426 return Collections.emptyMap();
377 } 427 }
378 } 428 }
  429 +
  430 + // 获取知识库名称映射
  431 + private Map<String, String> getKnowledgeNameMap(List<QuestionEmbedding> records) {
  432 + // 提取所有知识库ID
  433 + Set<String> knowledgeIds = records.stream()
  434 + .map(QuestionEmbedding::getKnowledgeId)
  435 + .filter(Objects::nonNull)
  436 + .collect(Collectors.toSet());
  437 +
  438 + if (knowledgeIds.isEmpty()) {
  439 + return Collections.emptyMap();
  440 + }
  441 +
  442 + // 从 MySQL 查询知识库名称
  443 + Map<String, String> knowledgeNameMap = new HashMap<>();
  444 + try (Connection mysqlConn = getMysqlConnection()) {
  445 + String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?"));
  446 + String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders);
  447 +
  448 + try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) {
  449 + int index = 1;
  450 + for (String id : knowledgeIds) {
  451 + stmt.setString(index++, id);
  452 + }
  453 +
  454 + try (ResultSet rs = stmt.executeQuery()) {
  455 + while (rs.next()) {
  456 + knowledgeNameMap.put(rs.getString("id"), rs.getString("name"));
  457 + }
  458 + }
  459 + }
  460 + } catch (SQLException e) {
  461 + log.error("查询知识库名称失败", e);
  462 + }
  463 +
  464 + return knowledgeNameMap;
  465 + }
  466 +
  467 + // 获取 MySQL 连接
  468 + private Connection getMysqlConnection() throws SQLException {
  469 + String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai";
  470 + String user = "root";
  471 + String password = "123456";
  472 + return DriverManager.getConnection(url, user, password);
  473 + }
379 } 474 }
@@ -36,5 +36,6 @@ @@ -36,5 +36,6 @@
36 <update id="updataIsStorage"> 36 <update id="updataIsStorage">
37 update airag_log 37 update airag_log
38 set is_storage = #{isStorage} 38 set is_storage = #{isStorage}
  39 + where id = #{id}
39 </update> 40 </update>
40 </mapper> 41 </mapper>
@@ -5,6 +5,7 @@ package org.jeecg.modules.airag.app.service; @@ -5,6 +5,7 @@ package org.jeecg.modules.airag.app.service;
5 import com.baomidou.mybatisplus.extension.plugins.pagination.Page; 5 import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
6 import org.jeecg.modules.airag.app.entity.Embeddings; 6 import org.jeecg.modules.airag.app.entity.Embeddings;
7 7
  8 +import java.util.ArrayList;
8 import java.util.List; 9 import java.util.List;
9 10
10 /** 11 /**
@@ -20,4 +21,5 @@ public interface IEmbeddingsService { @@ -20,4 +21,5 @@ public interface IEmbeddingsService {
20 int insert(Embeddings record); 21 int insert(Embeddings record);
21 int update(Embeddings record); 22 int update(Embeddings record);
22 Embeddings findById(String id); 23 Embeddings findById(String id);
  24 + int removeByIds(List<String> ids);
23 } 25 }
@@ -13,6 +13,7 @@ public interface IQuestionEmbeddingService { @@ -13,6 +13,7 @@ public interface IQuestionEmbeddingService {
13 QuestionEmbedding findById(String id); 13 QuestionEmbedding findById(String id);
14 int insert(QuestionEmbedding record); 14 int insert(QuestionEmbedding record);
15 int update(QuestionEmbedding record); 15 int update(QuestionEmbedding record);
  16 + int removeByIds(List<String> ids);
16 int deleteById(String id); 17 int deleteById(String id);
17 List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity); 18 List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity);
18 List<QuestionEmbedding> similaritySearch(float[] vector, int limit); 19 List<QuestionEmbedding> similaritySearch(float[] vector, int limit);
@@ -56,8 +56,9 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i @@ -56,8 +56,9 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i
56 QuestionEmbedding questionEmbedding = new QuestionEmbedding(); 56 QuestionEmbedding questionEmbedding = new QuestionEmbedding();
57 questionEmbedding.setQuestion(log.getQuestion()); 57 questionEmbedding.setQuestion(log.getQuestion());
58 questionEmbedding.setAnswer(log.getAnswer()); 58 questionEmbedding.setAnswer(log.getAnswer());
  59 + questionEmbedding.setKnowledgeId(log.getKnowledgeId());
59 questionEmbeddingMapper.insert(questionEmbedding); 60 questionEmbeddingMapper.insert(questionEmbedding);
60 - airagLogMapper.updataIsStorage(log.getIsStorage()); 61 + airagLogMapper.updataIsStorage(log.getIsStorage(),log.getId());
61 62
62 } 63 }
63 64
@@ -42,4 +42,9 @@ public class IEmbeddingsServiceImpl implements IEmbeddingsService { @@ -42,4 +42,9 @@ public class IEmbeddingsServiceImpl implements IEmbeddingsService {
42 public Embeddings findById(String id) { 42 public Embeddings findById(String id) {
43 return pgVectorMapper.findById(id); 43 return pgVectorMapper.findById(id);
44 } 44 }
  45 +
  46 + @Override
  47 + public int removeByIds(List<String> ids) {
  48 + return pgVectorMapper.deleteByIds(ids);
  49 + }
45 } 50 }
@@ -95,6 +95,11 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { @@ -95,6 +95,11 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
95 } 95 }
96 96
97 @Override 97 @Override
  98 + public int removeByIds(List<String> ids) {
  99 + return questionEmbeddingMapper.deleteByIds(ids);
  100 + }
  101 +
  102 + @Override
98 public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { 103 public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {
99 return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity); 104 return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity);
100 } 105 }