正在显示
1 个修改的文件
包含
45 行增加
和
85 行删除
| @@ -8,11 +8,16 @@ import com.fasterxml.jackson.core.JsonProcessingException; | @@ -8,11 +8,16 @@ import com.fasterxml.jackson.core.JsonProcessingException; | ||
| 8 | import com.fasterxml.jackson.core.type.TypeReference; | 8 | import com.fasterxml.jackson.core.type.TypeReference; |
| 9 | import com.fasterxml.jackson.databind.ObjectMapper; | 9 | import com.fasterxml.jackson.databind.ObjectMapper; |
| 10 | import com.pgvector.PGvector; | 10 | import com.pgvector.PGvector; |
| 11 | +import dev.langchain4j.data.embedding.Embedding; | ||
| 12 | +import dev.langchain4j.model.output.Response; | ||
| 11 | import lombok.extern.slf4j.Slf4j; | 13 | import lombok.extern.slf4j.Slf4j; |
| 12 | import org.apache.commons.lang3.StringUtils; | 14 | import org.apache.commons.lang3.StringUtils; |
| 13 | import org.jeecg.modules.airag.app.entity.Embeddings; | 15 | import org.jeecg.modules.airag.app.entity.Embeddings; |
| 14 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | 16 | import org.jeecg.modules.airag.app.entity.QuestionEmbedding; |
| 17 | +import org.jeecg.modules.airag.app.utils.AiModelUtils; | ||
| 15 | import org.postgresql.util.PGobject; | 18 | import org.postgresql.util.PGobject; |
| 19 | +import org.springframework.beans.factory.annotation.Autowired; | ||
| 20 | +import org.springframework.beans.factory.annotation.Value; | ||
| 16 | import org.springframework.stereotype.Component; | 21 | import org.springframework.stereotype.Component; |
| 17 | 22 | ||
| 18 | import java.sql.*; | 23 | import java.sql.*; |
| @@ -22,12 +27,18 @@ import java.util.stream.Collectors; | @@ -22,12 +27,18 @@ import java.util.stream.Collectors; | ||
| 22 | @Component | 27 | @Component |
| 23 | @Slf4j | 28 | @Slf4j |
| 24 | public class PgVectorMapper { | 29 | public class PgVectorMapper { |
| 30 | + @Autowired | ||
| 31 | + private AiModelUtils aiModelUtils; | ||
| 25 | 32 | ||
| 26 | // PostgreSQL连接参数(实际项目中应从配置读取) | 33 | // PostgreSQL连接参数(实际项目中应从配置读取) |
| 27 | private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres"; | 34 | private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres"; |
| 28 | private static final String USER = "postgres"; | 35 | private static final String USER = "postgres"; |
| 29 | private static final String PASSWORD = "postgres"; | 36 | private static final String PASSWORD = "postgres"; |
| 30 | 37 | ||
| 38 | + @Value("${jeecg.ai-chat.embedId}") | ||
| 39 | + private String embedId; | ||
| 40 | + | ||
| 41 | + | ||
| 31 | // 获取数据库连接 | 42 | // 获取数据库连接 |
| 32 | private Connection getConnection() throws SQLException { | 43 | private Connection getConnection() throws SQLException { |
| 33 | return DriverManager.getConnection(URL, USER, PASSWORD); | 44 | return DriverManager.getConnection(URL, USER, PASSWORD); |
| @@ -133,22 +144,40 @@ public class PgVectorMapper { | @@ -133,22 +144,40 @@ public class PgVectorMapper { | ||
| 133 | } | 144 | } |
| 134 | return null; | 145 | return null; |
| 135 | } | 146 | } |
| 147 | + // Count rows in embeddings, optionally filtered by an exact match on text | ||
| 148 | + public Integer findEmbeddingCount(Embeddings embeddings) { | ||
| 136 | 149 | ||
| 137 | - // 插入新向量记录 | ||
| 138 | - public int insert(Embeddings record) { | ||
| 139 | - /*Map<String, Object> metadata = new LinkedHashMap<>(); | 150 | + StringBuilder sql = new StringBuilder("select COUNT(1) AS total_count from embeddings where 1 = 1"); |
| 151 | + List<Object> params = new ArrayList<>(); | ||
| 152 | + | ||
| 153 | + if(StringUtils.isNotBlank(embeddings.getText())){ | ||
| 154 | + sql.append(" AND text = ?"); // exact equality match on text (not a fuzzy ILIKE match) | ||
| 155 | + params.add(embeddings.getText()); | ||
| 156 | + } | ||
| 157 | + | ||
| 158 | + | ||
| 159 | + try(Connection conn = getConnection(); | ||
| 160 | + PreparedStatement stmt = conn.prepareStatement(sql.toString())){ | ||
| 161 | + // 设置参数值 | ||
| 162 | + for (int i = 0; i < params.size(); i++) { | ||
| 163 | + stmt.setObject(i + 1, params.get(i)); | ||
| 164 | + } | ||
| 165 | + | ||
| 166 | + try (ResultSet rs = stmt.executeQuery()) { | ||
| 167 | + while (rs.next()) { | ||
| 168 | + return rs.getInt("total_count"); | ||
| 169 | + } | ||
| 170 | + return 0; | ||
| 171 | + } | ||
| 172 | + } catch (SQLException e) { | ||
| 173 | + log.error("查询所有记录失败", e); | ||
| 174 | + throw new RuntimeException("查询数据时发生数据库错误", e); | ||
| 175 | + } | ||
| 140 | 176 | ||
| 141 | - // 按固定顺序添加字段 | ||
| 142 | - metadata.put("docId", UUID.randomUUID().toString()); | ||
| 143 | - metadata.put("knowledgeId", getKnowledgeId(record)); // 使用统一方法获取 | ||
| 144 | - metadata.put("docName", record.getDocName()); | ||
| 145 | - metadata.put("index", 0); // 确保是整数 | ||
| 146 | - record.setMetadata(metadata); | ||
| 147 | -*/ | ||
| 148 | - // 自动生成向量(这里需要调用嵌入模型) | ||
| 149 | - float[] embedding = generateEmbedding(record.getText()); | ||
| 150 | - record.setEmbedding(embedding); | 177 | + } |
| 151 | 178 | ||
| 179 | + // 插入新向量记录 | ||
| 180 | + public int insert(Embeddings record) { | ||
| 152 | 181 | ||
| 153 | String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)"; | 182 | String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)"; |
| 154 | 183 | ||
| @@ -157,7 +186,8 @@ public class PgVectorMapper { | @@ -157,7 +186,8 @@ public class PgVectorMapper { | ||
| 157 | 186 | ||
| 158 | stmt.setString(1, UUID.randomUUID().toString()); | 187 | stmt.setString(1, UUID.randomUUID().toString()); |
| 159 | // stmt.setObject(2, new PGvector(record.getEmbedding())); | 188 | // stmt.setObject(2, new PGvector(record.getEmbedding())); |
| 160 | - stmt.setObject(2, new PGvector(record.getEmbedding())); | 189 | + Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getText()); |
| 190 | + stmt.setObject(2, embedding.content().vector()); | ||
| 161 | stmt.setObject(3, record.getText()); | 191 | stmt.setObject(3, record.getText()); |
| 162 | stmt.setObject(4, toJson(record.getMetadata())); | 192 | stmt.setObject(4, toJson(record.getMetadata())); |
| 163 | 193 | ||
| @@ -167,7 +197,6 @@ public class PgVectorMapper { | @@ -167,7 +197,6 @@ public class PgVectorMapper { | ||
| 167 | throw new RuntimeException("插入向量数据时发生数据库错误", e); | 197 | throw new RuntimeException("插入向量数据时发生数据库错误", e); |
| 168 | } | 198 | } |
| 169 | } | 199 | } |
| 170 | - | ||
| 171 | // 更新向量记录 | 200 | // 更新向量记录 |
| 172 | public int update(Embeddings record) { | 201 | public int update(Embeddings record) { |
| 173 | String sql = "UPDATE embeddings SET embedding = ?, metadata = ?::jsonb, text = ? WHERE embedding_id = ?"; | 202 | String sql = "UPDATE embeddings SET embedding = ?, metadata = ?::jsonb, text = ? WHERE embedding_id = ?"; |
| @@ -187,12 +216,7 @@ public class PgVectorMapper { | @@ -187,12 +216,7 @@ public class PgVectorMapper { | ||
| 187 | System.out.println("原始数据: " + mataData); | 216 | System.out.println("原始数据: " + mataData); |
| 188 | 217 | ||
| 189 | 218 | ||
| 190 | - PGobject jsonObject = new PGobject();/* | ||
| 191 | - System.out.println("原始数据: " + mataData); // 检查原始对象 | ||
| 192 | - String jsonStr = mataData.toJSONString(); | ||
| 193 | - System.out.println("JSON字符串: " + jsonStr); // 检查序列化后的JSON | ||
| 194 | - jsonObject.setValue(jsonStr); | ||
| 195 | - System.out.println("存入后的值: " + jsonObject.getValue());*/ // 检查存入后的值 | 219 | + PGobject jsonObject = new PGobject(); |
| 196 | jsonObject.setType("json"); | 220 | jsonObject.setType("json"); |
| 197 | jsonObject.setValue(mataData.toJSONString()); | 221 | jsonObject.setValue(mataData.toJSONString()); |
| 198 | stmt.setObject(1, new PGvector(record.getEmbedding())); | 222 | stmt.setObject(1, new PGvector(record.getEmbedding())); |
| @@ -308,68 +332,4 @@ public class PgVectorMapper { | @@ -308,68 +332,4 @@ public class PgVectorMapper { | ||
| 308 | return Collections.emptyMap(); | 332 | return Collections.emptyMap(); |
| 309 | } | 333 | } |
| 310 | } | 334 | } |
| 311 | - | ||
| 312 | - // 自动生成嵌入向量的方法(需根据您的嵌入模型实现) | ||
| 313 | - private float[] generateEmbedding(String text) { | ||
| 314 | - // 改为生成 768 维向量 | ||
| 315 | - float[] embedding = new float[1024]; // OpenAI 标准维度是 1536,这里改为 768 | ||
| 316 | - | ||
| 317 | - // 实际项目中应调用嵌入模型 API | ||
| 318 | - // 例如:return embeddingClient.generate(text, 768); | ||
| 319 | - | ||
| 320 | - // 临时实现:生成随机向量(仅用于演示) | ||
| 321 | - log.warn("使用随机向量生成 - 实际项目中应替换为真实模型调用"); | ||
| 322 | - Random random = new Random(); | ||
| 323 | - for (int i = 0; i < embedding.length; i++) { | ||
| 324 | - embedding[i] = random.nextFloat() * 2 - 1; | ||
| 325 | - } | ||
| 326 | - return embedding; | ||
| 327 | - } | ||
| 328 | - | ||
| 329 | - | ||
| 330 | - // 获取知识库名称映射 | ||
| 331 | - private Map<String, String> getKnowledgeNameMap(List<Embeddings> records) { | ||
| 332 | - // 提取所有知识库ID | ||
| 333 | - Set<String> knowledgeIds = records.stream() | ||
| 334 | - .map(Embeddings::getKnowledgeId) | ||
| 335 | - .filter(Objects::nonNull) | ||
| 336 | - .collect(Collectors.toSet()); | ||
| 337 | - | ||
| 338 | - if (knowledgeIds.isEmpty()) { | ||
| 339 | - return Collections.emptyMap(); | ||
| 340 | - } | ||
| 341 | - | ||
| 342 | - // 从 MySQL 查询知识库名称 | ||
| 343 | - Map<String, String> knowledgeNameMap = new HashMap<>(); | ||
| 344 | - try (Connection mysqlConn = getMysqlConnection()) { | ||
| 345 | - String placeholders = String.join(",", Collections.nCopies(knowledgeIds.size(), "?")); | ||
| 346 | - String sql = String.format("SELECT id, name FROM airag_knowledge WHERE id IN (%s)", placeholders); | ||
| 347 | - | ||
| 348 | - try (PreparedStatement stmt = mysqlConn.prepareStatement(sql)) { | ||
| 349 | - int index = 1; | ||
| 350 | - for (String id : knowledgeIds) { | ||
| 351 | - stmt.setString(index++, id); | ||
| 352 | - } | ||
| 353 | - | ||
| 354 | - try (ResultSet rs = stmt.executeQuery()) { | ||
| 355 | - while (rs.next()) { | ||
| 356 | - knowledgeNameMap.put(rs.getString("id"), rs.getString("name")); | ||
| 357 | - } | ||
| 358 | - } | ||
| 359 | - } | ||
| 360 | - } catch (SQLException e) { | ||
| 361 | - log.error("查询知识库名称失败", e); | ||
| 362 | - } | ||
| 363 | - | ||
| 364 | - return knowledgeNameMap; | ||
| 365 | - } | ||
| 366 | - | ||
| 367 | - // 获取 MySQL 连接 | ||
| 368 | - private Connection getMysqlConnection() throws SQLException { | ||
| 369 | - String url = "jdbc:mysql://localhost:3306/jeecg-boot-dev?characterEncoding=UTF-8&useUnicode=true&useSSL=false&tinyInt1isBit=false&allowPublicKeyRetrieval=true&serverTimezone=Asia/Shanghai"; | ||
| 370 | - String user = "root"; | ||
| 371 | - String password = "123456"; | ||
| 372 | - return DriverManager.getConnection(url, user, password); | ||
| 373 | - } | ||
| 374 | - | ||
| 375 | } | 335 | } |
-
请 注册 或 登录 后发表评论