作者 dong

pgvector修正

@@ -8,11 +8,16 @@ import com.fasterxml.jackson.core.JsonProcessingException; @@ -8,11 +8,16 @@ import com.fasterxml.jackson.core.JsonProcessingException;
8 import com.fasterxml.jackson.core.type.TypeReference; 8 import com.fasterxml.jackson.core.type.TypeReference;
9 import com.fasterxml.jackson.databind.ObjectMapper; 9 import com.fasterxml.jackson.databind.ObjectMapper;
10 import com.pgvector.PGvector; 10 import com.pgvector.PGvector;
  11 +import dev.langchain4j.data.embedding.Embedding;
  12 +import dev.langchain4j.model.output.Response;
11 import lombok.extern.slf4j.Slf4j; 13 import lombok.extern.slf4j.Slf4j;
12 import org.apache.commons.lang3.StringUtils; 14 import org.apache.commons.lang3.StringUtils;
13 import org.jeecg.modules.airag.app.entity.Embeddings; 15 import org.jeecg.modules.airag.app.entity.Embeddings;
14 import org.jeecg.modules.airag.app.entity.QuestionEmbedding; 16 import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
  17 +import org.jeecg.modules.airag.app.utils.AiModelUtils;
15 import org.postgresql.util.PGobject; 18 import org.postgresql.util.PGobject;
  19 +import org.springframework.beans.factory.annotation.Autowired;
  20 +import org.springframework.beans.factory.annotation.Value;
16 import org.springframework.stereotype.Component; 21 import org.springframework.stereotype.Component;
17 22
18 import java.sql.*; 23 import java.sql.*;
@@ -22,12 +27,18 @@ import java.util.stream.Collectors; @@ -22,12 +27,18 @@ import java.util.stream.Collectors;
22 @Component 27 @Component
23 @Slf4j 28 @Slf4j
24 public class PgVectorMapper { 29 public class PgVectorMapper {
  30 + @Autowired
  31 + private AiModelUtils aiModelUtils;
25 32
26 // PostgreSQL连接参数(实际项目中应从配置读取) 33 // PostgreSQL连接参数(实际项目中应从配置读取)
27 private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres"; 34 private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres";
28 private static final String USER = "postgres"; 35 private static final String USER = "postgres";
29 private static final String PASSWORD = "postgres"; 36 private static final String PASSWORD = "postgres";
30 37
  38 + @Value("${jeecg.ai-chat.embedId}")
  39 + private String embedId;
  40 +
  41 +
31 // 获取数据库连接 42 // 获取数据库连接
32 private Connection getConnection() throws SQLException { 43 private Connection getConnection() throws SQLException {
33 return DriverManager.getConnection(URL, USER, PASSWORD); 44 return DriverManager.getConnection(URL, USER, PASSWORD);
@@ -133,22 +144,40 @@ public class PgVectorMapper { @@ -133,22 +144,40 @@ public class PgVectorMapper {
133 } 144 }
134 return null; 145 return null;
135 } 146 }
  147 + // Count the rows in embeddings that match the optional text filter
  148 + public Integer findEmbeddingCount(Embeddings embeddings) {
136 149
137 - // 插入新向量记录  
138 - public int insert(Embeddings record) {  
139 - /*Map<String, Object> metadata = new LinkedHashMap<>(); 150 + StringBuilder sql = new StringBuilder("select COUNT(1) AS total_count from embeddings where 1 = 1");
  151 + List<Object> params = new ArrayList<>();
  152 +
  153 + if(StringUtils.isNotBlank(embeddings.getText())){
  154 + sql.append(" AND text = ?"); // exact equality match on text (not an ILIKE fuzzy match)
  155 + params.add(embeddings.getText());
  156 + }
  157 +
  158 +
  159 + try(Connection conn = getConnection();
  160 + PreparedStatement stmt = conn.prepareStatement(sql.toString())){
  161 + // 设置参数值
  162 + for (int i = 0; i < params.size(); i++) {
  163 + stmt.setObject(i + 1, params.get(i));
  164 + }
  165 +
  166 + try (ResultSet rs = stmt.executeQuery()) {
  167 + while (rs.next()) {
  168 + return rs.getInt("total_count");
  169 + }
  170 + return 0;
  171 + }
  172 + } catch (SQLException e) {
  173 + log.error("查询所有记录失败", e);
  174 + throw new RuntimeException("查询数据时发生数据库错误", e);
  175 + }
140 176
141 - // 按固定顺序添加字段  
142 - metadata.put("docId", UUID.randomUUID().toString());  
143 - metadata.put("knowledgeId", getKnowledgeId(record)); // 使用统一方法获取  
144 - metadata.put("docName", record.getDocName());  
145 - metadata.put("index", 0); // 确保是整数  
146 - record.setMetadata(metadata);  
147 -*/  
148 - // 自动生成向量(这里需要调用嵌入模型)  
149 - float[] embedding = generateEmbedding(record.getText());  
150 - record.setEmbedding(embedding); 177 + }
151 178
  179 + // 插入新向量记录
  180 + public int insert(Embeddings record) {
152 181
153 String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)"; 182 String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)";
154 183
@@ -157,7 +186,8 @@ public class PgVectorMapper { @@ -157,7 +186,8 @@ public class PgVectorMapper {
157 186
158 stmt.setString(1, UUID.randomUUID().toString()); 187 stmt.setString(1, UUID.randomUUID().toString());
159 // stmt.setObject(2, new PGvector(record.getEmbedding())); 188 // stmt.setObject(2, new PGvector(record.getEmbedding()));
160 - stmt.setObject(2, new PGvector(record.getEmbedding())); 189 + Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getText());
  190 + stmt.setObject(2, embedding.content().vector());
161 stmt.setObject(3, record.getText()); 191 stmt.setObject(3, record.getText());
162 stmt.setObject(4, toJson(record.getMetadata())); 192 stmt.setObject(4, toJson(record.getMetadata()));
163 193
@@ -167,7 +197,6 @@ public class PgVectorMapper { @@ -167,7 +197,6 @@ public class PgVectorMapper {
167 throw new RuntimeException("插入向量数据时发生数据库错误", e); 197 throw new RuntimeException("插入向量数据时发生数据库错误", e);
168 } 198 }
169 } 199 }
170 -  
171 // 更新向量记录 200 // 更新向量记录
172 public int update(Embeddings record) { 201 public int update(Embeddings record) {
173 String sql = "UPDATE embeddings SET embedding = ?, metadata = ?::jsonb, text = ? WHERE embedding_id = ?"; 202 String sql = "UPDATE embeddings SET embedding = ?, metadata = ?::jsonb, text = ? WHERE embedding_id = ?";
@@ -187,12 +216,7 @@ public class PgVectorMapper { @@ -187,12 +216,7 @@ public class PgVectorMapper {
187 System.out.println("原始数据: " + mataData); 216 System.out.println("原始数据: " + mataData);
188 217
189 218
190 - PGobject jsonObject = new PGobject();/*  
191 - System.out.println("原始数据: " + mataData); // 检查原始对象  
192 - String jsonStr = mataData.toJSONString();  
193 - System.out.println("JSON字符串: " + jsonStr); // 检查序列化后的JSON  
194 - jsonObject.setValue(jsonStr);  
195 - System.out.println("存入后的值: " + jsonObject.getValue());*/ // 检查存入后的值 219 + PGobject jsonObject = new PGobject();
196 jsonObject.setType("json"); 220 jsonObject.setType("json");
197 jsonObject.setValue(mataData.toJSONString()); 221 jsonObject.setValue(mataData.toJSONString());
198 stmt.setObject(1, new PGvector(record.getEmbedding())); 222 stmt.setObject(1, new PGvector(record.getEmbedding()));
@@ -308,68 +332,4 @@ public class PgVectorMapper { @@ -308,68 +332,4 @@ public class PgVectorMapper {
308 return Collections.emptyMap(); 332 return Collections.emptyMap();
309 } 333 }
310 } 334 }
311 -  
312 - // 自动生成嵌入向量的方法(需根据您的嵌入模型实现)  
313 - private float[] generateEmbedding(String text) {  
314 - // 改为生成 768 维向量  
315 - float[] embedding = new float[1024]; // OpenAI 标准维度是 1536,这里改为 768  
316 -  
317 - // 实际项目中应调用嵌入模型 API  
318 - // 例如:return embeddingClient.generate(text, 768);  
319 -  
320 - // 临时实现:生成随机向量(仅用于演示)  
321 - log.warn("使用随机向量生成 - 实际项目中应替换为真实模型调用");  
322 - Random random = new Random();  
323 - for (int i = 0; i < embedding.length; i++) {  
324 - embedding[i] = random.nextFloat() * 2 - 1;  
325 - }  
326 - return embedding;  
327 - }  
328 -  
329 -  
330 - // 获取知识库名称映射  
331 - private Map<String, String> getKnowledgeNameMap(List<Embeddings> records) {  
332 - // 提取所有知识库ID  
333 - Set<String> knowledgeIds = records.stream()  
334 - .map(Embeddings::getKnowledgeId)  
335 - .filter(Objects::nonNull)  
336 - .collect(Collectors.toSet());  
337 -  
338 - if (knowledgeIds.isEmpty()) {  
339 - return Collections.emptyMap();  
340 - }  
341 -  
342 - // 从 MySQL 查询知识库名称  
343 - Map<String, String> knowledgeNameMap = new HashMap<>();  
344 - try (Connection mysqlConn = getMysqlConnection()) {  
345 - String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?"));  
346 - String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders);  
347 -  
348 - try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) {  
349 - int index = 1;  
350 - for (String id : knowledgeIds) {  
351 - stmt.setString(index++, id);  
352 - }  
353 -  
354 - try (ResultSet rs = stmt.executeQuery()) {  
355 - while (rs.next()) {  
356 - knowledgeNameMap.put(rs.getString("id"), rs.getString("name"));  
357 - }  
358 - }  
359 - }  
360 - } catch (SQLException e) {  
361 - log.error("查询知识库名称失败", e);  
362 - }  
363 -  
364 - return knowledgeNameMap;  
365 - }  
366 -  
367 - // 获取 MySQL 连接  
368 - private Connection getMysqlConnection() throws SQLException {  
369 - String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai";  
370 - String user = "root";  
371 - String password = "123456";  
372 - return DriverManager.getConnection(url, user, password);  
373 - }  
374 -  
375 } 335 }