|
|
|
package org.jeecg.modules.airag.app.mapper;
|
|
|
|
|
|
|
|
import cn.hutool.core.lang.generator.SnowflakeGenerator;
|
|
|
|
import com.alibaba.fastjson2.JSONObject;
|
|
|
|
import com.baomidou.dynamic.datasource.annotation.DS;
|
|
|
|
import com.baomidou.mybatisplus.core.metadata.IPage;
|
|
|
|
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
|
|
|
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
|
|
|
import com.fasterxml.jackson.core.type.TypeReference;
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
|
import com.pgvector.PGvector;
|
|
|
|
import dev.langchain4j.data.embedding.Embedding;
|
|
|
|
import dev.langchain4j.model.output.Response;
|
|
|
|
import io.minio.messages.Metadata;
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
|
import org.jeecg.modules.airag.app.entity.Embeddings;
|
|
|
|
import org.apache.ibatis.annotations.Mapper;
|
|
|
|
import org.apache.ibatis.annotations.Param;
|
|
|
|
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
|
|
|
|
import org.jeecg.modules.airag.app.utils.AiModelUtils;
|
|
|
|
import org.postgresql.util.PGobject;
|
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
|
import org.springframework.beans.factory.annotation.Value;
|
|
|
|
import org.springframework.stereotype.Component;
|
|
|
|
|
|
|
|
import java.sql.*;
|
|
|
|
import java.util.*;
|
|
|
|
import java.util.stream.Collectors;
|
|
|
|
import java.util.List;
|
|
|
|
|
|
|
|
@Component
|
|
|
|
@Slf4j
|
|
|
|
public class QuestionEmbeddingMapper {
|
|
|
|
@Mapper
|
|
|
|
@DS("pgvector")
|
|
|
|
public interface QuestionEmbeddingMapper {
|
|
|
|
/**
 * Mapper statement returning a page of {@link QuestionEmbedding} rows.
 *
 * <p>NOTE(review): the SQL bound to this statement is not visible in this file, so the
 * way {@code questionEmbedding} narrows the result is presumably defined in the mapper
 * XML — confirm there.
 *
 * @param page pagination request handed to the statement
 * @param questionEmbedding criteria object, exposed to the statement as {@code questionEmbedding}
 * @return the page of rows produced by the statement
 */
Page<QuestionEmbedding> findAll(IPage<QuestionEmbedding> page, @Param("questionEmbedding") QuestionEmbedding questionEmbedding);
|
|
|
|
|
|
|
|
@Autowired
|
|
|
|
private AiModelUtils aiModelUtils;
|
|
|
|
/**
 * Mapper statement returning a row count for the given criteria.
 *
 * <p>NOTE(review): which fields of {@code questionEmbedding} take part in the count is
 * decided by SQL that is not visible here — verify against the mapper XML.
 *
 * @param questionEmbedding criteria object, exposed to the statement as {@code questionEmbedding}
 * @return the count produced by the statement
 */
Integer findQuestionCount(@Param("questionEmbedding") QuestionEmbedding questionEmbedding);
|
|
|
|
|
|
|
|
@Value("${jeecg.ai-chat.embedId}")
|
|
|
|
private String embedId;
|
|
|
|
// PostgreSQL连接参数(应与项目配置一致)
|
|
|
|
private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres";
|
|
|
|
private static final String USER = "postgres";
|
|
|
|
private static final String PASSWORD = "postgres";
|
|
|
|
/**
 * Mapper statement looking up a single {@link QuestionEmbedding} by identifier.
 *
 * @param id identifier, exposed to the statement as {@code id}
 * @return the mapped row; presumably {@code null} when nothing matches (MyBatis default — confirm)
 */
QuestionEmbedding findById(@Param("id") String id);
|
|
|
|
/**
 * Mapper statement inserting one {@link QuestionEmbedding}; routed to the
 * {@code pgvector} data source by the {@code @DS} annotation.
 *
 * @param record row to insert, exposed to the statement as {@code record}
 * @return the value reported by the statement (presumably the affected-row count — confirm)
 */
@DS("pgvector")
|
|
|
|
int insert(@Param("record") QuestionEmbedding record);
|
|
|
|
|
|
|
|
// 获取数据库连接
|
|
|
|
private Connection getConnection() throws SQLException {
|
|
|
|
return DriverManager.getConnection(URL, USER, PASSWORD);
|
|
|
|
}
|
|
|
|
/**
 * Mapper statement updating one {@link QuestionEmbedding}.
 *
 * @param record row carrying the new values, exposed to the statement as {@code record}
 * @return the value reported by the statement (presumably the affected-row count — confirm)
 */
int update(@Param("record") QuestionEmbedding record);
|
|
|
|
|
|
|
|
// 查询所有记录
|
|
|
|
public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, int pageNo, int pageSize) {
|
|
|
|
List<QuestionEmbedding> results = new ArrayList<>();
|
|
|
|
StringBuilder sql = new StringBuilder("select * from question_embedding where 1 = 1");
|
|
|
|
StringBuilder countSql = new StringBuilder("select count(1) from question_embedding where 1 = 1");
|
|
|
|
List<Object> params = new ArrayList<>();
|
|
|
|
List<Object> countParams = new ArrayList<>();
|
|
|
|
int deleteById(@Param("id") String id);
|
|
|
|
|
|
|
|
if (StringUtils.isNotBlank(questionEmbedding.getKnowledgeId())) {
|
|
|
|
sql.append(" AND metadata ->> 'knowledgeId' = ?");
|
|
|
|
countSql.append(" AND metadata ->> 'knowledgeId' = ?");
|
|
|
|
params.add(questionEmbedding.getKnowledgeId());
|
|
|
|
countParams.add(questionEmbedding.getKnowledgeId());
|
|
|
|
}
|
|
|
|
if(StringUtils.isNotBlank(questionEmbedding.getQuestion())){
|
|
|
|
sql.append(" AND question ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
|
|
|
|
countSql.append(" AND question ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
|
|
|
|
params.add("%" + questionEmbedding.getQuestion() + "%");
|
|
|
|
countParams.add("%" + questionEmbedding.getQuestion() + "%");
|
|
|
|
}
|
|
|
|
int deleteByIds(@Param("ids") List<String> ids);
|
|
|
|
|
|
|
|
if(StringUtils.isNotBlank(questionEmbedding.getAnswer())){
|
|
|
|
sql.append(" AND answer ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
|
|
|
|
countSql.append(" AND answer ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
|
|
|
|
params.add("%" + questionEmbedding.getAnswer() + "%");
|
|
|
|
countParams.add("%" + questionEmbedding.getAnswer() + "%");
|
|
|
|
}
|
|
|
|
|
|
|
|
sql.append(" ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, question ASC");
|
|
|
|
|
|
|
|
// 添加分页
|
|
|
|
sql.append(" LIMIT ? OFFSET ?");
|
|
|
|
params.add(pageSize);
|
|
|
|
params.add((pageNo - 1) * pageSize);
|
|
|
|
|
|
|
|
|
|
|
|
try(Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql.toString())){
|
|
|
|
// 设置参数值
|
|
|
|
for (int i = 0; i < params.size(); i++) {
|
|
|
|
stmt.setObject(i + 1, params.get(i));
|
|
|
|
}
|
|
|
|
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
while (rs.next()) {
|
|
|
|
results.add(mapRowToQuestionEmbedding(rs));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("查询所有记录失败", e);
|
|
|
|
throw new RuntimeException("查询数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
|
|
|
|
// 执行计数查询
|
|
|
|
long total = 0;
|
|
|
|
try(Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(countSql.toString())){
|
|
|
|
// 设置参数值
|
|
|
|
for (int i = 0; i < countParams.size(); i++) {
|
|
|
|
stmt.setObject(i + 1, countParams.get(i));
|
|
|
|
}
|
|
|
|
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
if (rs.next()) {
|
|
|
|
total = rs.getLong(1); // 直接获取count值
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("查询记录总数失败", e);
|
|
|
|
throw new RuntimeException("查询记录总数时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
|
|
|
|
Page<QuestionEmbedding> page = new Page<>();
|
|
|
|
page.setRecords(results);
|
|
|
|
page.setTotal(total);
|
|
|
|
return page;
|
|
|
|
}
|
|
|
|
|
|
|
|
// 查询所有记录
|
|
|
|
public Integer findQuestionCount(QuestionEmbedding questionEmbedding) {
|
|
|
|
|
|
|
|
StringBuilder sql = new StringBuilder("select COUNT(1) AS total_count from question_embedding where 1 = 1");
|
|
|
|
List<Object> params = new ArrayList<>();
|
|
|
|
|
|
|
|
if(StringUtils.isNotBlank(questionEmbedding.getQuestion())){
|
|
|
|
sql.append(" AND question = ?");
|
|
|
|
params.add(questionEmbedding.getQuestion());
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
try(Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql.toString())){
|
|
|
|
// 设置参数值
|
|
|
|
for (int i = 0; i < params.size(); i++) {
|
|
|
|
stmt.setObject(i + 1, params.get(i));
|
|
|
|
}
|
|
|
|
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
while (rs.next()) {
|
|
|
|
return rs.getInt("total_count");
|
|
|
|
}
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("查询所有记录失败", e);
|
|
|
|
throw new RuntimeException("查询数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// 根据ID查询单个记录
|
|
|
|
public QuestionEmbedding findById(String id) {
|
|
|
|
String sql = "SELECT * FROM question_embedding WHERE id = ?";
|
|
|
|
|
|
|
|
try (Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql)) {
|
|
|
|
|
|
|
|
stmt.setString(1, id);
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
if (rs.next()) {
|
|
|
|
return mapRowToQuestionEmbedding(rs);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("根据ID查询记录失败, ID: {}", id, e);
|
|
|
|
throw new RuntimeException("根据ID查询时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
return null;
|
|
|
|
}
|
|
|
|
|
|
|
|
// 插入新记录
|
|
|
|
public int insert(QuestionEmbedding record) {
|
|
|
|
String sql = "INSERT INTO question_embedding (id, text, question, answer, metadata,embedding) VALUES (?, ?, ?, ?, ?::jsonb,?)";
|
|
|
|
|
|
|
|
|
|
|
|
try (Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql)) {
|
|
|
|
stmt.setString(1, UUID.randomUUID().toString());
|
|
|
|
stmt.setString(2, record.getText());
|
|
|
|
stmt.setString(3, record.getQuestion());
|
|
|
|
stmt.setString(4, record.getAnswer());
|
|
|
|
PGobject jsonObject = new PGobject();
|
|
|
|
jsonObject.setType("json");
|
|
|
|
jsonObject.setValue(record.getMetadata());
|
|
|
|
stmt.setObject(5, jsonObject);
|
|
|
|
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion());
|
|
|
|
stmt.setObject(6, embedding.content().vector());
|
|
|
|
return stmt.executeUpdate();
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("插入记录失败: {}", record, e);
|
|
|
|
throw new RuntimeException("插入数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// 更新记录
|
|
|
|
public int update(QuestionEmbedding record) {
|
|
|
|
String sql = "UPDATE question_embedding SET text = ?, question = ?, answer = ?, metadata = ?::jsonb ,embedding = ? WHERE id = ?";
|
|
|
|
|
|
|
|
try (Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql)) {
|
|
|
|
|
|
|
|
|
|
|
|
stmt.setString(1, record.getText());
|
|
|
|
stmt.setString(2, record.getQuestion());
|
|
|
|
stmt.setString(3, record.getAnswer());
|
|
|
|
stmt.setObject(4, record.getMetadata());
|
|
|
|
|
|
|
|
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion());
|
|
|
|
stmt.setObject(5, embedding.content().vector());
|
|
|
|
|
|
|
|
stmt.setString(6, record.getId());
|
|
|
|
|
|
|
|
return stmt.executeUpdate();
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("更新记录失败: {}", record, e);
|
|
|
|
throw new RuntimeException("更新数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// 批量删除方法
|
|
|
|
public int deleteByIds(List<String> ids) {
|
|
|
|
if (ids == null || ids.isEmpty()) {
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
|
|
|
|
String sql = "DELETE FROM question_embedding WHERE id IN (";
|
|
|
|
StringBuilder placeholders = new StringBuilder();
|
|
|
|
for (int i = 0; i < ids.size(); i++) {
|
|
|
|
placeholders.append("?");
|
|
|
|
if (i < ids.size() - 1) {
|
|
|
|
placeholders.append(",");
|
|
|
|
}
|
|
|
|
}
|
|
|
|
sql += placeholders.toString() + ")";
|
|
|
|
|
|
|
|
try (Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql)) {
|
|
|
|
|
|
|
|
for (int i = 0; i < ids.size(); i++) {
|
|
|
|
stmt.setString(i + 1, ids.get(i));
|
|
|
|
}
|
|
|
|
|
|
|
|
return stmt.executeUpdate();
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("批量删除向量记录失败, IDs: {}", ids, e);
|
|
|
|
throw new RuntimeException("批量删除向量数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
* 向量相似度查询 (基于问题文本的向量)
|
|
|
|
* @param question 问题文本
|
|
|
|
* @param limit 返回结果数量
|
|
|
|
* @param minSimilarity 最小相似度阈值(0-1)
|
|
|
|
* @return 相似问答列表(按相似度降序)
|
|
|
|
*/
|
|
|
|
public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {
|
|
|
|
List<QuestionEmbedding> results = new ArrayList<>();
|
|
|
|
|
|
|
|
// 1. 参数校验
|
|
|
|
if (minSimilarity < 0 || minSimilarity > 1) {
|
|
|
|
throw new IllegalArgumentException("相似度阈值必须在0到1之间");
|
|
|
|
}
|
|
|
|
|
|
|
|
// 2. 获取问题的嵌入向量
|
|
|
|
Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, question);
|
|
|
|
float[] queryVector = embeddingResponse.content().vector();
|
|
|
|
// 3. 计算最大允许距离(1 - 相似度阈值)
|
|
|
|
double maxDistance = 1 - minSimilarity;
|
|
|
|
|
|
|
|
// 4. 执行向量相似度查询
|
|
|
|
String sql = "SELECT *, embedding <-> ? AS distance " +
|
|
|
|
"FROM question_embedding " +
|
|
|
|
"WHERE embedding <-> ? < ? " + // 距离小于阈值
|
|
|
|
"ORDER BY distance ASC " + // 按距离升序
|
|
|
|
"LIMIT ?";
|
|
|
|
|
|
|
|
try (Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql)) {
|
|
|
|
|
|
|
|
// 设置参数
|
|
|
|
PGvector vector = new PGvector(queryVector);
|
|
|
|
stmt.setObject(1, vector);
|
|
|
|
stmt.setObject(2, vector);
|
|
|
|
stmt.setDouble(3, maxDistance);
|
|
|
|
stmt.setInt(4, limit);
|
|
|
|
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
while (rs.next()) {
|
|
|
|
QuestionEmbedding record = mapRowToQuestionEmbedding(rs);
|
|
|
|
// 计算相似度(1 - 距离)
|
|
|
|
double distance = rs.getDouble("distance");
|
|
|
|
double similarity = 1 - distance;
|
|
|
|
record.setSimilarity(similarity);
|
|
|
|
results.add(record);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("向量相似度查询失败", e);
|
|
|
|
throw new RuntimeException("执行向量相似度查询时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
return results;
|
|
|
|
}
|
|
|
|
|
|
|
|
/**
|
|
|
|
* 向量相似度查询 (直接使用向量)
|
|
|
|
* @param vector 查询向量
|
|
|
|
* @param limit 返回结果数量
|
|
|
|
* @return 相似问答列表(按相似度降序)
|
|
|
|
*/
|
|
|
|
public List<QuestionEmbedding> similaritySearch(float[] vector, int limit) {
|
|
|
|
List<QuestionEmbedding> results = new ArrayList<>();
|
|
|
|
String sql = "SELECT *, embedding <-> ? AS similarity " +
|
|
|
|
"FROM question_embedding " +
|
|
|
|
"ORDER BY similarity ASC " +
|
|
|
|
"LIMIT ?";
|
|
|
|
|
|
|
|
try (Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql)) {
|
|
|
|
|
|
|
|
stmt.setObject(1, new PGvector(vector));
|
|
|
|
stmt.setInt(2, limit);
|
|
|
|
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
while (rs.next()) {
|
|
|
|
QuestionEmbedding record = mapRowToQuestionEmbedding(rs);
|
|
|
|
double similarity = 1 - rs.getDouble("similarity");
|
|
|
|
record.setSimilarity(similarity);
|
|
|
|
results.add(record);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("向量相似度查询失败", e);
|
|
|
|
throw new RuntimeException("执行向量相似度查询时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
return results;
|
|
|
|
}
|
|
|
|
|
|
|
|
// 根据ID删除记录
|
|
|
|
public int deleteById(String id) {
|
|
|
|
String sql = "DELETE FROM question_embedding WHERE id = ?";
|
|
|
|
|
|
|
|
try (Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql)) {
|
|
|
|
|
|
|
|
stmt.setString(1, id);
|
|
|
|
return stmt.executeUpdate();
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("删除记录失败, ID: {}", id, e);
|
|
|
|
throw new RuntimeException("删除数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// 将ResultSet行映射为QuestionEmbedding对象
|
|
|
|
private QuestionEmbedding mapRowToQuestionEmbedding(ResultSet rs) throws SQLException {
|
|
|
|
QuestionEmbedding record = new QuestionEmbedding();
|
|
|
|
record.setId(rs.getString("id"));
|
|
|
|
record.setText(rs.getString("text"));
|
|
|
|
record.setQuestion(rs.getString("question"));
|
|
|
|
record.setAnswer(rs.getString("answer"));
|
|
|
|
|
|
|
|
String metadataJson = rs.getString("metadata");
|
|
|
|
if (StringUtils.isNotBlank(metadataJson)) {
|
|
|
|
record.setMetadata(metadataJson);
|
|
|
|
}
|
|
|
|
|
|
|
|
return record;
|
|
|
|
}
|
|
|
|
|
|
|
|
// 将Map转换为JSON字符串
|
|
|
|
private String toJson(Map<String, Object> map) {
|
|
|
|
try {
|
|
|
|
return new ObjectMapper().writeValueAsString(map);
|
|
|
|
} catch (JsonProcessingException e) {
|
|
|
|
log.error("元数据转换为JSON失败", e);
|
|
|
|
return "{}";
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// 将JSON字符串转换为Map
|
|
|
|
private Map<String, Object> fromJson(String json) {
|
|
|
|
try {
|
|
|
|
return new ObjectMapper().readValue(json, new TypeReference<Map<String, Object>>() {});
|
|
|
|
} catch (JsonProcessingException e) {
|
|
|
|
log.error("JSON转换为元数据失败", e);
|
|
|
|
return Collections.emptyMap();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
/**
 * Mapper statement for a vector search with a similarity threshold.
 *
 * <p>NOTE(review): the distance operator and how {@code minSimilarity} is turned into
 * a filter are defined by SQL not visible in this file — confirm in the mapper XML.
 *
 * @param vector query vector, exposed to the statement as {@code vector}
 * @param limit maximum number of rows, exposed as {@code limit}
 * @param minSimilarity threshold value, exposed as {@code minSimilarity}
 * @return the rows produced by the statement
 */
List<QuestionEmbedding> similaritySearchByQuestion(@Param("vector") float[] vector,
|
|
|
|
@Param("limit") int limit,
|
|
|
|
@Param("minSimilarity") Double minSimilarity);
|
|
|
|
|
|
|
|
/**
 * Mapper statement for a vector search without a threshold.
 *
 * <p>NOTE(review): ordering and the distance operator are defined by SQL not visible in
 * this file — confirm in the mapper XML.
 *
 * @param vector query vector, exposed to the statement as {@code vector}
 * @param limit maximum number of rows, exposed as {@code limit}
 * @return the rows produced by the statement
 */
List<QuestionEmbedding> similaritySearch(@Param("vector") float[] vector,
|
|
|
|
@Param("limit") int limit);
|
|
|
|
} |
|
|
\ No newline at end of file |
...
|
...
|
|