|
...
|
...
|
@@ -12,6 +12,7 @@ import dev.langchain4j.model.output.Response; |
|
|
|
import io.minio.messages.Metadata;
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
|
import org.jeecg.modules.airag.app.entity.Embeddings;
|
|
|
|
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
|
|
|
|
import org.jeecg.modules.airag.app.utils.AiModelUtils;
|
|
|
|
import org.postgresql.util.PGobject;
|
|
...
|
...
|
@@ -20,6 +21,7 @@ import org.springframework.stereotype.Component; |
|
|
|
|
|
|
|
import java.sql.*;
|
|
|
|
import java.util.*;
|
|
|
|
import java.util.stream.Collectors;
|
|
|
|
|
|
|
|
@Component
|
|
|
|
@Slf4j
|
|
...
|
...
|
@@ -89,6 +91,23 @@ public class QuestionEmbeddingMapper { |
|
|
|
throw new RuntimeException("查询数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
|
|
|
|
// 2. 获取知识库名称映射
|
|
|
|
Map<String, String> knowledgeNameMap = getKnowledgeNameMap(results);
|
|
|
|
|
|
|
|
// 3. 设置知识库名称并处理空值
|
|
|
|
for (QuestionEmbedding record : results) {
|
|
|
|
String knowledgeId = record.getKnowledgeId();
|
|
|
|
String name = knowledgeNameMap.get(knowledgeId);
|
|
|
|
record.setKnowledgeName(name != null ? name : "");
|
|
|
|
}
|
|
|
|
|
|
|
|
// 4. 安全排序(处理空值)
|
|
|
|
results.sort(Comparator
|
|
|
|
.comparing(QuestionEmbedding::getKnowledgeName,
|
|
|
|
Comparator.nullsLast(Comparator.naturalOrder()))
|
|
|
|
.thenComparing(QuestionEmbedding::getQuestion,
|
|
|
|
Comparator.nullsLast(Comparator.naturalOrder())));
|
|
|
|
|
|
|
|
// 执行计数查询
|
|
|
|
long total = 0;
|
|
|
|
try(Connection conn = getConnection();
|
|
...
|
...
|
@@ -236,6 +255,37 @@ public class QuestionEmbeddingMapper { |
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 批量删除方法
|
|
|
|
public int deleteByIds(List<String> ids) {
|
|
|
|
if (ids == null || ids.isEmpty()) {
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
String sql = "DELETE FROM question_embedding WHERE id IN (";
|
|
|
|
StringBuilder placeholders = new StringBuilder();
|
|
|
|
for (int i = 0; i < ids.size(); i++) {
|
|
|
|
placeholders.append("?");
|
|
|
|
if (i < ids.size() - 1) {
|
|
|
|
placeholders.append(",");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
sql += placeholders.toString() + ")";
|
|
|
|
|
|
|
|
try (Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql)) {
|
|
|
|
|
|
|
|
for (int i = 0; i < ids.size(); i++) {
|
|
|
|
stmt.setString(i + 1, ids.get(i));
|
|
|
|
}
|
|
|
|
|
|
|
|
return stmt.executeUpdate();
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("批量删除向量记录失败, IDs: {}", ids, e);
|
|
|
|
throw new RuntimeException("批量删除向量数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* 向量相似度查询 (基于问题文本的向量)
|
|
|
|
* @param question 问题文本
|
|
...
|
...
|
@@ -376,4 +426,49 @@ public class QuestionEmbeddingMapper { |
|
|
|
return Collections.emptyMap();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// 获取知识库名称映射
|
|
|
|
private Map<String, String> getKnowledgeNameMap(List<QuestionEmbedding> records) {
|
|
|
|
// 提取所有知识库ID
|
|
|
|
Set<String> knowledgeIds = records.stream()
|
|
|
|
.map(QuestionEmbedding::getKnowledgeId)
|
|
|
|
.filter(Objects::nonNull)
|
|
|
|
.collect(Collectors.toSet());
|
|
|
|
|
|
|
|
if (knowledgeIds.isEmpty()) {
|
|
|
|
return Collections.emptyMap();
|
|
|
|
}
|
|
|
|
|
|
|
|
// 从 MySQL 查询知识库名称
|
|
|
|
Map<String, String> knowledgeNameMap = new HashMap<>();
|
|
|
|
try (Connection mysqlConn = getMysqlConnection()) {
|
|
|
|
String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?"));
|
|
|
|
String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders);
|
|
|
|
|
|
|
|
try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) {
|
|
|
|
int index = 1;
|
|
|
|
for (String id : knowledgeIds) {
|
|
|
|
stmt.setString(index++, id);
|
|
|
|
}
|
|
|
|
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
while (rs.next()) {
|
|
|
|
knowledgeNameMap.put(rs.getString("id"), rs.getString("name"));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("查询知识库名称失败", e);
|
|
|
|
}
|
|
|
|
|
|
|
|
return knowledgeNameMap;
|
|
|
|
}
|
|
|
|
|
|
|
|
// 获取 MySQL 连接
|
|
|
|
private Connection getMysqlConnection() throws SQLException {
|
|
|
|
String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai";
|
|
|
|
String user = "root";
|
|
|
|
String password = "123456";
|
|
|
|
return DriverManager.getConnection(url, user, password);
|
|
|
|
}
|
|
|
|
} |
|
|
\ No newline at end of file |
...
|
...
|
|