|
...
|
...
|
@@ -8,11 +8,16 @@ import com.fasterxml.jackson.core.JsonProcessingException; |
|
|
|
import com.fasterxml.jackson.core.type.TypeReference;
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
|
|
|
import com.pgvector.PGvector;
|
|
|
|
import dev.langchain4j.data.embedding.Embedding;
|
|
|
|
import dev.langchain4j.model.output.Response;
|
|
|
|
import lombok.extern.slf4j.Slf4j;
|
|
|
|
import org.apache.commons.lang3.StringUtils;
|
|
|
|
import org.jeecg.modules.airag.app.entity.Embeddings;
|
|
|
|
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
|
|
|
|
import org.jeecg.modules.airag.app.utils.AiModelUtils;
|
|
|
|
import org.postgresql.util.PGobject;
|
|
|
|
import org.springframework.beans.factory.annotation.Autowired;
|
|
|
|
import org.springframework.beans.factory.annotation.Value;
|
|
|
|
import org.springframework.stereotype.Component;
|
|
|
|
|
|
|
|
import java.sql.*;
|
|
...
|
...
|
@@ -22,12 +27,18 @@ import java.util.stream.Collectors; |
|
|
|
@Component
|
|
|
|
@Slf4j
|
|
|
|
public class PgVectorMapper {
|
|
|
|
@Autowired
|
|
|
|
private AiModelUtils aiModelUtils;
|
|
|
|
|
|
|
|
// PostgreSQL connection parameters (in a real project these should be read from configuration)
|
|
|
|
private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres";
|
|
|
|
private static final String USER = "postgres";
|
|
|
|
private static final String PASSWORD = "postgres";
|
|
|
|
|
|
|
|
@Value("${jeecg.ai-chat.embedId}")
|
|
|
|
private String embedId;
|
|
|
|
|
|
|
|
|
|
|
|
// 获取数据库连接
|
|
|
|
private Connection getConnection() throws SQLException {
|
|
|
|
return DriverManager.getConnection(URL, USER, PASSWORD);
|
|
...
|
...
|
@@ -133,22 +144,40 @@ public class PgVectorMapper { |
|
|
|
}
|
|
|
|
return null;
|
|
|
|
}
|
|
|
|
// 查询所有记录
|
|
|
|
public Integer findEmbeddingCount(Embeddings embeddings) {
|
|
|
|
|
|
|
|
// 插入新向量记录
|
|
|
|
public int insert(Embeddings record) {
|
|
|
|
/*Map<String, Object> metadata = new LinkedHashMap<>();
|
|
|
|
StringBuilder sql = new StringBuilder("select COUNT(1) AS total_count from embeddings where 1 = 1");
|
|
|
|
List<Object> params = new ArrayList<>();
|
|
|
|
|
|
|
|
if(StringUtils.isNotBlank(embeddings.getText())){
|
|
|
|
sql.append(" AND text = ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
|
|
|
|
params.add(embeddings.getText());
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
try(Connection conn = getConnection();
|
|
|
|
PreparedStatement stmt = conn.prepareStatement(sql.toString())){
|
|
|
|
// 设置参数值
|
|
|
|
for (int i = 0; i < params.size(); i++) {
|
|
|
|
stmt.setObject(i + 1, params.get(i));
|
|
|
|
}
|
|
|
|
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
while (rs.next()) {
|
|
|
|
return rs.getInt("total_count");
|
|
|
|
}
|
|
|
|
return 0;
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("查询所有记录失败", e);
|
|
|
|
throw new RuntimeException("查询数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
|
|
|
|
// 按固定顺序添加字段
|
|
|
|
metadata.put("docId", UUID.randomUUID().toString());
|
|
|
|
metadata.put("knowledgeId", getKnowledgeId(record)); // 使用统一方法获取
|
|
|
|
metadata.put("docName", record.getDocName());
|
|
|
|
metadata.put("index", 0); // 确保是整数
|
|
|
|
record.setMetadata(metadata);
|
|
|
|
*/
|
|
|
|
// 自动生成向量(这里需要调用嵌入模型)
|
|
|
|
float[] embedding = generateEmbedding(record.getText());
|
|
|
|
record.setEmbedding(embedding);
|
|
|
|
}
|
|
|
|
|
|
|
|
// 插入新向量记录
|
|
|
|
public int insert(Embeddings record) {
|
|
|
|
|
|
|
|
String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)";
|
|
|
|
|
|
...
|
...
|
@@ -157,7 +186,8 @@ public class PgVectorMapper { |
|
|
|
|
|
|
|
stmt.setString(1, UUID.randomUUID().toString());
|
|
|
|
// stmt.setObject(2, new PGvector(record.getEmbedding()));
|
|
|
|
stmt.setObject(2, new PGvector(record.getEmbedding()));
|
|
|
|
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getText());
|
|
|
|
stmt.setObject(2, embedding.content().vector());
|
|
|
|
stmt.setObject(3, record.getText());
|
|
|
|
stmt.setObject(4, toJson(record.getMetadata()));
|
|
|
|
|
|
...
|
...
|
@@ -167,7 +197,6 @@ public class PgVectorMapper { |
|
|
|
throw new RuntimeException("插入向量数据时发生数据库错误", e);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// 更新向量记录
|
|
|
|
public int update(Embeddings record) {
|
|
|
|
String sql = "UPDATE embeddings SET embedding = ?, metadata = ?::jsonb, text = ? WHERE embedding_id = ?";
|
|
...
|
...
|
@@ -187,12 +216,7 @@ public class PgVectorMapper { |
|
|
|
System.out.println("原始数据: " + mataData);
|
|
|
|
|
|
|
|
|
|
|
|
PGobject jsonObject = new PGobject();/*
|
|
|
|
System.out.println("原始数据: " + mataData); // 检查原始对象
|
|
|
|
String jsonStr = mataData.toJSONString();
|
|
|
|
System.out.println("JSON字符串: " + jsonStr); // 检查序列化后的JSON
|
|
|
|
jsonObject.setValue(jsonStr);
|
|
|
|
System.out.println("存入后的值: " + jsonObject.getValue());*/ // 检查存入后的值
|
|
|
|
PGobject jsonObject = new PGobject();
|
|
|
|
jsonObject.setType("json");
|
|
|
|
jsonObject.setValue(mataData.toJSONString());
|
|
|
|
stmt.setObject(1, new PGvector(record.getEmbedding()));
|
|
...
|
...
|
@@ -308,68 +332,4 @@ public class PgVectorMapper { |
|
|
|
return Collections.emptyMap();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
// Generates an embedding vector automatically (must be implemented against your embedding model)
|
|
|
|
private float[] generateEmbedding(String text) {
|
|
|
|
// 改为生成 768 维向量
|
|
|
|
float[] embedding = new float[1024]; // OpenAI 标准维度是 1536,这里改为 768
|
|
|
|
|
|
|
|
// 实际项目中应调用嵌入模型 API
|
|
|
|
// 例如:return embeddingClient.generate(text, 768);
|
|
|
|
|
|
|
|
// 临时实现:生成随机向量(仅用于演示)
|
|
|
|
log.warn("使用随机向量生成 - 实际项目中应替换为真实模型调用");
|
|
|
|
Random random = new Random();
|
|
|
|
for (int i = 0; i < embedding.length; i++) {
|
|
|
|
embedding[i] = random.nextFloat() * 2 - 1;
|
|
|
|
}
|
|
|
|
return embedding;
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// Build the knowledge-base id -> name mapping
|
|
|
|
private Map<String, String> getKnowledgeNameMap(List<Embeddings> records) {
|
|
|
|
// 提取所有知识库ID
|
|
|
|
Set<String> knowledgeIds = records.stream()
|
|
|
|
.map(Embeddings::getKnowledgeId)
|
|
|
|
.filter(Objects::nonNull)
|
|
|
|
.collect(Collectors.toSet());
|
|
|
|
|
|
|
|
if (knowledgeIds.isEmpty()) {
|
|
|
|
return Collections.emptyMap();
|
|
|
|
}
|
|
|
|
|
|
|
|
// 从 MySQL 查询知识库名称
|
|
|
|
Map<String, String> knowledgeNameMap = new HashMap<>();
|
|
|
|
try (Connection mysqlConn = getMysqlConnection()) {
|
|
|
|
String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?"));
|
|
|
|
String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders);
|
|
|
|
|
|
|
|
try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) {
|
|
|
|
int index = 1;
|
|
|
|
for (String id : knowledgeIds) {
|
|
|
|
stmt.setString(index++, id);
|
|
|
|
}
|
|
|
|
|
|
|
|
try (ResultSet rs = stmt.executeQuery()) {
|
|
|
|
while (rs.next()) {
|
|
|
|
knowledgeNameMap.put(rs.getString("id"), rs.getString("name"));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
} catch (SQLException e) {
|
|
|
|
log.error("查询知识库名称失败", e);
|
|
|
|
}
|
|
|
|
|
|
|
|
return knowledgeNameMap;
|
|
|
|
}
|
|
|
|
|
|
|
|
// Obtain a MySQL connection
|
|
|
|
private Connection getMysqlConnection() throws SQLException {
|
|
|
|
String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai";
|
|
|
|
String user = "root";
|
|
|
|
String password = "123456";
|
|
|
|
return DriverManager.getConnection(url, user, password);
|
|
|
|
}
|
|
|
|
|
|
|
|
} |
|
|
\ No newline at end of file |
...
|
...
|
|