作者 dong

修复bug,更正需求

正在显示 15 个修改的文件 包含 262 行增加3 行删除
... ... @@ -19,6 +19,7 @@ import org.jeecg.modules.airag.app.service.IAiragLogService;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.jeecg.modules.airag.llm.entity.AiragKnowledge;
import org.jeecg.modules.airag.llm.entity.AiragModel;
import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService;
import org.jeecg.modules.airag.llm.service.IAiragModelService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.*;
... ... @@ -29,6 +30,8 @@ import javax.servlet.http.HttpServletResponse;
import java.sql.SQLException;
import java.util.*;
import java.util.stream.Collectors;
/**
* @Description: 日志管理
* @Author: jeecg-boot
... ... @@ -51,6 +54,9 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi
@Autowired
private IQuestionEmbeddingService questionEmbeddingService;
@Autowired
private IAiragKnowledgeService airagKnowledgeService;
/**
* 分页列表查询
*
... ... @@ -105,6 +111,31 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi
public Result<List<AiragModel>> queryAiragKnowledgeList(AiragModel airagModel, HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException {
QueryWrapper<AiragModel> queryWrapper = QueryGenerator.initQueryWrapper(airagModel, req.getParameterMap());
List<AiragModel> list = airagModelService.list(queryWrapper);
// 过滤出 model_type 为 "llm" 的记录
List<AiragModel> filteredList = list.stream()
.filter(model -> "LLM".equals(model.getModelType()))
.collect(Collectors.toList());
return Result.OK(filteredList);
}
/**
* 查询知识库名称
*
* @param airagKnowledge
* @param req
* @return
*/
@AutoLog(value = "日志管理-查询知识库名称")
@Operation(summary="日志管理-查询知识库名称")
@GetMapping(value = "/listKnowledgeName")
public Result<List<AiragKnowledge>> queryAiragKnowledgeNameList(AiragKnowledge airagKnowledge, HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException {
QueryWrapper<AiragKnowledge> queryWrapper = QueryGenerator.initQueryWrapper(airagKnowledge, req.getParameterMap());
List<AiragKnowledge> list = airagKnowledgeService.list(queryWrapper);
return Result.OK(list);
}
... ... @@ -141,6 +172,7 @@ public class AiragLogController extends JeecgController<AiragLog, IAiragLogServi
if(questionCount > 0){
return Result.error("重复问题不能存入");
}
airagLog.setIsStorage(1);
airagLogService.saveToQuestionLibrary(airagLog);
return Result.OK("存入问题库成功!");
... ...
... ... @@ -165,7 +165,7 @@ public class EmbeddingsController {
@RequiresPermissions("embeddings:embeddings:deleteBatch")
@DeleteMapping(value = "/deleteBatch")
public Result<String> deleteBatch(@RequestParam(name = "ids", required = true) String ids) {
// this.embeddingsService.removeByIds(Arrays.asList(ids.split(",")));
embeddingsService.removeByIds(Arrays.asList(ids.split(",")));
return Result.OK("批量删除成功!");
}
... ...
... ... @@ -2,9 +2,12 @@ package org.jeecg.modules.airag.app.controller;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import dev.langchain4j.internal.Json;
import io.swagger.v3.oas.annotations.Operation;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.authz.annotation.RequiresPermissions;
import org.jeecg.common.api.vo.Result;
import org.jeecg.common.aspect.annotation.AutoLog;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.jeecg.modules.airag.app.utils.JsonUtils;
... ... @@ -15,6 +18,7 @@ import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
... ... @@ -111,6 +115,18 @@ public class QuestionEmbeddingController {
return result > 0 ? Result.OK("删除成功!") : Result.error("删除失败");
}
/**
* 批量删除
*
* @param ids
* @return
*/
@DeleteMapping(value = "/deleteBatch")
public Result<String> deleteBatch(@RequestParam(name = "ids", required = true) String ids) {
questionEmbeddingService.removeByIds(Arrays.asList(ids.split(",")));
return Result.OK("批量删除成功!");
}
@PostMapping("/uploadZip")
@Transactional(rollbackFor = {Exception.class})
public Result<?> uploadZip(
... ...
... ... @@ -102,6 +102,9 @@ public class AiragLog implements Serializable {
// 新增:临时字段(非数据库字段)
@TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中
private String createTime_end;
// 新增:临时字段(非数据库字段)
@TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中
private String knowledgeId;
@TableField(exist = false) // MyBatis-Plus 标记该字段不存在于数据库表中
private String createTimeStr;
... ...
... ... @@ -24,5 +24,6 @@ public class Embeddings {
private String knowledgeId; // 新增知识库ID字段
private String docId; // 新增文档ID字段
private String index; // 新增索引位置字段
private String knowledgeName;
}
\ No newline at end of file
... ...
... ... @@ -52,4 +52,6 @@ public class QuestionEmbedding {
private String knowledgeId;
}
\ No newline at end of file
... ...
... ... @@ -19,6 +19,6 @@ public interface AiragLogMapper extends BaseMapper<AiragLog> {
IPage<AiragLog> pageList(@Param("param1") AiragLog airagLog, Page<AiragLog> page);
int updataIsStorage(@Param("param1") int isStorage);
int updataIsStorage(@Param("param1") int isStorage, @Param("param2") String id);
}
... ...
... ... @@ -11,11 +11,13 @@ import com.pgvector.PGvector;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jeecg.modules.airag.app.entity.Embeddings;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.postgresql.util.PGobject;
import org.springframework.stereotype.Component;
import java.sql.*;
import java.util.*;
import java.util.stream.Collectors;
@Component
@Slf4j
... ... @@ -79,6 +81,23 @@ public class PgVectorMapper {
}
// 2. 获取知识库名称映射
Map<String, String> knowledgeNameMap = getKnowledgeNameMap(results);
// 3. 设置知识库名称并处理空值
for (Embeddings record : results) {
String knowledgeId = record.getKnowledgeId();
String name = knowledgeNameMap.get(knowledgeId);
record.setKnowledgeName(name != null ? name : "");
}
// 4. 安全排序(处理空值)
results.sort(Comparator
.comparing(Embeddings::getKnowledgeName,
Comparator.nullsLast(Comparator.naturalOrder()))
.thenComparing(Embeddings::getDocName,
Comparator.nullsLast(Comparator.naturalOrder())));
// 执行计数查询
int total = 0;
try(Connection conn = getConnection();
... ... @@ -212,6 +231,36 @@ public class PgVectorMapper {
}
}
// 批量删除方法
public int deleteByIds(List<String> ids) {
if (ids == null || ids.isEmpty()) {
return 0;
}
String sql = "DELETE FROM embeddings WHERE embedding_id IN (";
StringBuilder placeholders = new StringBuilder();
for (int i = 0; i < ids.size(); i++) {
placeholders.append("?");
if (i < ids.size() - 1) {
placeholders.append(",");
}
}
sql += placeholders.toString() + ")";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
for (int i = 0; i < ids.size(); i++) {
stmt.setString(i + 1, ids.get(i));
}
return stmt.executeUpdate();
} catch (SQLException e) {
log.error("批量删除向量记录失败, IDs: {}", ids, e);
throw new RuntimeException("批量删除向量数据时发生数据库错误", e);
}
}
// 向量相似度搜索
public List<Embeddings> similaritySearch(float[] vector, int limit) {
String sql = "SELECT * FROM embeddings ORDER BY embedding <-> ? LIMIT ?";
... ... @@ -286,4 +335,50 @@ public class PgVectorMapper {
return embedding;
}
// 获取知识库名称映射
private Map<String, String> getKnowledgeNameMap(List<Embeddings> records) {
// 提取所有知识库ID
Set<String> knowledgeIds = records.stream()
.map(Embeddings::getKnowledgeId)
.filter(Objects::nonNull)
.collect(Collectors.toSet());
if (knowledgeIds.isEmpty()) {
return Collections.emptyMap();
}
// 从 MySQL 查询知识库名称
Map<String, String> knowledgeNameMap = new HashMap<>();
try (Connection mysqlConn = getMysqlConnection()) {
String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?"));
String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders);
try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) {
int index = 1;
for (String id : knowledgeIds) {
stmt.setString(index++, id);
}
try (ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
knowledgeNameMap.put(rs.getString("id"), rs.getString("name"));
}
}
}
} catch (SQLException e) {
log.error("查询知识库名称失败", e);
}
return knowledgeNameMap;
}
// 获取 MySQL 连接
private Connection getMysqlConnection() throws SQLException {
String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai";
String user = "root";
String password = "123456";
return DriverManager.getConnection(url, user, password);
}
}
\ No newline at end of file
... ...
... ... @@ -12,6 +12,7 @@ import dev.langchain4j.model.output.Response;
import io.minio.messages.Metadata;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jeecg.modules.airag.app.entity.Embeddings;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
import org.postgresql.util.PGobject;
... ... @@ -20,6 +21,7 @@ import org.springframework.stereotype.Component;
import java.sql.*;
import java.util.*;
import java.util.stream.Collectors;
@Component
@Slf4j
... ... @@ -89,6 +91,23 @@ public class QuestionEmbeddingMapper {
throw new RuntimeException("查询数据时发生数据库错误", e);
}
// 2. 获取知识库名称映射
Map<String, String> knowledgeNameMap = getKnowledgeNameMap(results);
// 3. 设置知识库名称并处理空值
for (QuestionEmbedding record : results) {
String knowledgeId = record.getKnowledgeId();
String name = knowledgeNameMap.get(knowledgeId);
record.setKnowledgeName(name != null ? name : "");
}
// 4. 安全排序(处理空值)
results.sort(Comparator
.comparing(QuestionEmbedding::getKnowledgeName,
Comparator.nullsLast(Comparator.naturalOrder()))
.thenComparing(QuestionEmbedding::getQuestion,
Comparator.nullsLast(Comparator.naturalOrder())));
// 执行计数查询
long total = 0;
try(Connection conn = getConnection();
... ... @@ -236,6 +255,37 @@ public class QuestionEmbeddingMapper {
}
// 批量删除方法
public int deleteByIds(List<String> ids) {
if (ids == null || ids.isEmpty()) {
return 0;
}
String sql = "DELETE FROM question_embedding WHERE id IN (";
StringBuilder placeholders = new StringBuilder();
for (int i = 0; i < ids.size(); i++) {
placeholders.append("?");
if (i < ids.size() - 1) {
placeholders.append(",");
}
}
sql += placeholders.toString() + ")";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
for (int i = 0; i < ids.size(); i++) {
stmt.setString(i + 1, ids.get(i));
}
return stmt.executeUpdate();
} catch (SQLException e) {
log.error("批量删除向量记录失败, IDs: {}", ids, e);
throw new RuntimeException("批量删除向量数据时发生数据库错误", e);
}
}
/**
* 向量相似度查询 (基于问题文本的向量)
* @param question 问题文本
... ... @@ -376,4 +426,49 @@ public class QuestionEmbeddingMapper {
return Collections.emptyMap();
}
}
// 获取知识库名称映射
private Map<String, String> getKnowledgeNameMap(List<QuestionEmbedding> records) {
// 提取所有知识库ID
Set<String> knowledgeIds = records.stream()
.map(QuestionEmbedding::getKnowledgeId)
.filter(Objects::nonNull)
.collect(Collectors.toSet());
if (knowledgeIds.isEmpty()) {
return Collections.emptyMap();
}
// 从 MySQL 查询知识库名称
Map<String, String> knowledgeNameMap = new HashMap<>();
try (Connection mysqlConn = getMysqlConnection()) {
String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?"));
String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders);
try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) {
int index = 1;
for (String id : knowledgeIds) {
stmt.setString(index++, id);
}
try (ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
knowledgeNameMap.put(rs.getString("id"), rs.getString("name"));
}
}
}
} catch (SQLException e) {
log.error("查询知识库名称失败", e);
}
return knowledgeNameMap;
}
// 获取 MySQL 连接
private Connection getMysqlConnection() throws SQLException {
String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai";
String user = "root";
String password = "123456";
return DriverManager.getConnection(url, user, password);
}
}
\ No newline at end of file
... ...
... ... @@ -36,5 +36,6 @@
<update id="updataIsStorage">
update airag_log
set is_storage = #{isStorage}
where id = #{id}
</update>
</mapper>
\ No newline at end of file
... ...
... ... @@ -5,6 +5,7 @@ package org.jeecg.modules.airag.app.service;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.jeecg.modules.airag.app.entity.Embeddings;
import java.util.ArrayList;
import java.util.List;
/**
... ... @@ -20,4 +21,5 @@ public interface IEmbeddingsService {
int insert(Embeddings record);
int update(Embeddings record);
Embeddings findById(String id);
int removeByIds(List<String> ids);
}
... ...
... ... @@ -13,6 +13,7 @@ public interface IQuestionEmbeddingService {
QuestionEmbedding findById(String id);
int insert(QuestionEmbedding record);
int update(QuestionEmbedding record);
int removeByIds(List<String> ids);
int deleteById(String id);
List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity);
List<QuestionEmbedding> similaritySearch(float[] vector, int limit);
... ...
... ... @@ -56,8 +56,9 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i
QuestionEmbedding questionEmbedding = new QuestionEmbedding();
questionEmbedding.setQuestion(log.getQuestion());
questionEmbedding.setAnswer(log.getAnswer());
questionEmbedding.setKnowledgeId(log.getKnowledgeId());
questionEmbeddingMapper.insert(questionEmbedding);
airagLogMapper.updataIsStorage(log.getIsStorage());
airagLogMapper.updataIsStorage(log.getIsStorage(),log.getId());
}
... ...
... ... @@ -42,4 +42,9 @@ public class IEmbeddingsServiceImpl implements IEmbeddingsService {
public Embeddings findById(String id) {
return pgVectorMapper.findById(id);
}
@Override
public int removeByIds(List<String> ids) {
return pgVectorMapper.deleteByIds(ids);
}
}
... ...
... ... @@ -95,6 +95,11 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
}
@Override
public int removeByIds(List<String> ids) {
return questionEmbeddingMapper.deleteByIds(ids);
}
@Override
public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {
return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity);
}
... ...