作者 dong

pgvector修正

... ... @@ -8,11 +8,16 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jeecg.modules.airag.app.entity.Embeddings;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
import org.postgresql.util.PGobject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.sql.*;
... ... @@ -22,12 +27,18 @@ import java.util.stream.Collectors;
@Component
@Slf4j
public class PgVectorMapper {
// Spring-injected helper; elsewhere in this class it is called as
// aiModelUtils.getEmbedding(embedId, text) to obtain an embedding for a text.
@Autowired
private AiModelUtils aiModelUtils;
// PostgreSQL connection parameters (in a real project these should be read from configuration).
// NOTE(review): host, user and password are hard-coded in source; consider externalising them.
private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres";
private static final String USER = "postgres";
private static final String PASSWORD = "postgres";
// Value of the "jeecg.ai-chat.embedId" property; passed as the first argument
// to aiModelUtils.getEmbedding when an embedding is requested.
@Value("${jeecg.ai-chat.embedId}")
private String embedId;
// 获取数据库连接
private Connection getConnection() throws SQLException {
return DriverManager.getConnection(URL, USER, PASSWORD);
... ... @@ -133,22 +144,40 @@ public class PgVectorMapper {
}
return null;
}
// 查询所有记录
public Integer findEmbeddingCount(Embeddings embeddings) {
// 插入新向量记录
public int insert(Embeddings record) {
/*Map<String, Object> metadata = new LinkedHashMap<>();
StringBuilder sql = new StringBuilder("select COUNT(1) AS total_count from embeddings where 1 = 1");
List<Object> params = new ArrayList<>();
if(StringUtils.isNotBlank(embeddings.getText())){
sql.append(" AND text = ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
params.add(embeddings.getText());
}
try(Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql.toString())){
// 设置参数值
for (int i = 0; i < params.size(); i++) {
stmt.setObject(i + 1, params.get(i));
}
try (ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
return rs.getInt("total_count");
}
return 0;
}
} catch (SQLException e) {
log.error("查询所有记录失败", e);
throw new RuntimeException("查询数据时发生数据库错误", e);
}
// 按固定顺序添加字段
metadata.put("docId", UUID.randomUUID().toString());
metadata.put("knowledgeId", getKnowledgeId(record)); // 使用统一方法获取
metadata.put("docName", record.getDocName());
metadata.put("index", 0); // 确保是整数
record.setMetadata(metadata);
*/
// 自动生成向量(这里需要调用嵌入模型)
float[] embedding = generateEmbedding(record.getText());
record.setEmbedding(embedding);
}
// 插入新向量记录
public int insert(Embeddings record) {
String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)";
... ... @@ -157,7 +186,8 @@ public class PgVectorMapper {
stmt.setString(1, UUID.randomUUID().toString());
// stmt.setObject(2, new PGvector(record.getEmbedding()));
stmt.setObject(2, new PGvector(record.getEmbedding()));
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getText());
stmt.setObject(2, embedding.content().vector());
stmt.setObject(3, record.getText());
stmt.setObject(4, toJson(record.getMetadata()));
... ... @@ -167,7 +197,6 @@ public class PgVectorMapper {
throw new RuntimeException("插入向量数据时发生数据库错误", e);
}
}
// 更新向量记录
public int update(Embeddings record) {
String sql = "UPDATE embeddings SET embedding = ?, metadata = ?::jsonb, text = ? WHERE embedding_id = ?";
... ... @@ -187,12 +216,7 @@ public class PgVectorMapper {
System.out.println("原始数据: " + mataData);
PGobject jsonObject = new PGobject();/*
System.out.println("原始数据: " + mataData); // 检查原始对象
String jsonStr = mataData.toJSONString();
System.out.println("JSON字符串: " + jsonStr); // 检查序列化后的JSON
jsonObject.setValue(jsonStr);
System.out.println("存入后的值: " + jsonObject.getValue());*/ // 检查存入后的值
PGobject jsonObject = new PGobject();
jsonObject.setType("json");
jsonObject.setValue(mataData.toJSONString());
stmt.setObject(1, new PGvector(record.getEmbedding()));
... ... @@ -308,68 +332,4 @@ public class PgVectorMapper {
return Collections.emptyMap();
}
}
/**
 * Produces a placeholder embedding vector.
 *
 * <p>The result always has 1024 components, each a pseudo-random value in
 * the range [-1, 1). The {@code text} argument is not consulted: this is a
 * demo stand-in until a real embedding-model call is wired in, and a warning
 * is logged on every invocation to make that visible.
 *
 * @param text the text to embed (currently ignored)
 * @return a newly allocated 1024-element array of random values
 */
private float[] generateEmbedding(String text) {
    // Dimension of the generated vector.
    final int dimension = 1024;
    final float[] vector = new float[dimension];
    log.warn("使用随机向量生成 - 实际项目中应替换为真实模型调用");
    final Random rng = new Random();
    int pos = 0;
    while (pos < vector.length) {
        // nextFloat() is in [0, 1); scale and shift to [-1, 1).
        vector[pos] = rng.nextFloat() * 2 - 1;
        pos++;
    }
    return vector;
}
/**
 * Builds a map from knowledge-base id to knowledge-base name for the given records.
 *
 * <p>The distinct non-null ids returned by {@code Embeddings::getKnowledgeId} are
 * collected and looked up in the {@code airag_knowledge} table through a
 * connection obtained from {@link #getMysqlConnection()}.
 *
 * <p>If the lookup fails with a {@link SQLException}, the error is logged and
 * whatever entries were read before the failure are returned; the exception is
 * not propagated.
 *
 * @param records the records whose knowledge-base ids should be resolved; a
 *                {@code null} list and {@code null} elements are tolerated
 * @return id-to-name map; an empty map when there is nothing to look up
 */
private Map<String, String> getKnowledgeNameMap(List<Embeddings> records) {
    // Guard against a null/empty input instead of failing with a NullPointerException.
    if (records == null || records.isEmpty()) {
        return Collections.emptyMap();
    }
    // Collect every distinct knowledge-base id, skipping null records and null ids.
    Set<String> knowledgeIds = records.stream()
            .filter(Objects::nonNull)
            .map(Embeddings::getKnowledgeId)
            .filter(Objects::nonNull)
            .collect(Collectors.toSet());
    if (knowledgeIds.isEmpty()) {
        return Collections.emptyMap();
    }
    // One "?" placeholder per id so the IN (...) list stays fully parameterised.
    String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?"));
    String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders);
    Map<String, String> knowledgeNameMap = new HashMap<>();
    try (Connection mysqlConn = getMysqlConnection();
         PreparedStatement stmt = mysqlConn.prepareStatement(sql)) {
        int index = 1;
        for (String id : knowledgeIds) {
            stmt.setString(index++, id);
        }
        try (ResultSet rs = stmt.executeQuery()) {
            while (rs.next()) {
                knowledgeNameMap.put(rs.getString("id"), rs.getString("name"));
            }
        }
    } catch (SQLException e) {
        // Best effort: name resolution failing must not break the caller.
        log.error("查询知识库名称失败", e);
    }
    return knowledgeNameMap;
}
/**
 * Opens a new JDBC connection to the MySQL database via {@link DriverManager}.
 *
 * <p>NOTE(review): the JDBC URL, user name and password are hard-coded here;
 * they should be moved to external configuration.
 *
 * @return a newly opened connection; the caller is responsible for closing it
 * @throws SQLException if {@code DriverManager.getConnection} fails
 */
private Connection getMysqlConnection() throws SQLException {
    final String jdbcUrl = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai";
    final String account = "root";
    final String secret = "123456";
    return DriverManager.getConnection(jdbcUrl, account, secret);
}
}
\ No newline at end of file
... ...