|
|
|
//package org.jeecg.modules.airag.app.config;
|
|
|
|
//import dev.langchain4j.data.document.Metadata;
|
|
|
|
//import dev.langchain4j.data.embedding.Embedding;
|
|
|
|
//import dev.langchain4j.data.segment.TextSegment;
|
|
|
|
//import dev.langchain4j.store.embedding.EmbeddingMatch;
|
|
|
|
//import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
|
|
|
|
//import dev.langchain4j.store.embedding.EmbeddingSearchResult;
|
|
|
|
//import dev.langchain4j.store.embedding.EmbeddingStore;
|
|
|
|
//import dev.langchain4j.store.embedding.filter.Filter;
|
|
|
|
//import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
|
|
|
|
//import lombok.extern.log4j.Log4j2;
|
|
|
|
//import org.springframework.beans.factory.annotation.Autowired;
|
|
|
|
//import org.springframework.beans.factory.annotation.Qualifier;
|
|
|
|
//import org.springframework.jdbc.core.JdbcTemplate;
|
|
|
|
//import org.springframework.stereotype.Component;
|
|
|
|
//
|
|
|
|
//import javax.sql.DataSource;
|
|
|
|
//import java.util.ArrayList;
|
|
|
|
//import java.util.Collection;
|
|
|
|
//import java.util.List;
|
|
|
|
//import java.util.Map;
|
|
|
|
//import java.util.Collection;
|
|
|
|
//import java.util.List;
|
|
|
|
//import java.util.stream.Collectors;
|
|
|
|
//
|
|
|
|
//@Component
|
|
|
|
//@Log4j2
|
|
|
|
//public class PostgreEmbeddingStore implements EmbeddingStore<TextSegment> {
|
|
|
|
//
|
|
|
|
// @Autowired
|
|
|
|
// private PgVectorEmbeddingStore pgVectorEmbeddingStore;
|
|
|
|
//
|
|
|
|
// @Autowired
|
|
|
|
// private JdbcTemplate pgJdbcTemplate;
|
|
|
|
//
|
|
|
|
// @Autowired
|
|
|
|
// public PostgreEmbeddingStore(
|
|
|
|
// PgVectorEmbeddingStore pgVectorEmbeddingStore) {
|
|
|
|
// this.pgJdbcTemplate = pgJdbcTemplate;
|
|
|
|
// this.pgVectorEmbeddingStore = pgVectorEmbeddingStore;
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public String add(Embedding embedding) {
|
|
|
|
// return "";
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public void add(String id, Embedding embedding) {
|
|
|
|
//
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public String add(Embedding embedding, TextSegment textSegment) {
|
|
|
|
// return "";
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public List<String> addAll(List<Embedding> embeddings) {
|
|
|
|
// return List.of();
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
|
|
|
|
// return List.of();
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public void remove(String id) {
|
|
|
|
// EmbeddingStore.super.remove(id);
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public void removeAll(Collection<String> ids) {
|
|
|
|
// EmbeddingStore.super.removeAll(ids);
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public void removeAll(Filter filter) {
|
|
|
|
// EmbeddingStore.super.removeAll(filter);
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public void removeAll() {
|
|
|
|
// EmbeddingStore.super.removeAll();
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
|
|
|
|
// return EmbeddingStore.super.search(request);
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
|
|
|
|
// return findRelevant(referenceEmbedding, maxResults, 0.0);
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
|
|
|
|
// try {
|
|
|
|
//// // 使用 PgVectorEmbeddingStore 进行查询
|
|
|
|
//// EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
|
|
|
|
//// .queryEmbedding(referenceEmbedding)
|
|
|
|
//// .maxResults(maxResults)
|
|
|
|
//// .minScore(minScore)
|
|
|
|
//// .build();
|
|
|
|
//
|
|
|
|
// // 构建带内存ID过滤的查询
|
|
|
|
// String sql = "SELECT id, content, metadata, embedding <=> ? AS distance " +
|
|
|
|
// "FROM embeddings " +
|
|
|
|
// "WHERE (1 - (embedding <=> ?)) >= ? " +
|
|
|
|
// "ORDER BY distance " +
|
|
|
|
// "LIMIT ?";
|
|
|
|
//
|
|
|
|
// List<Map<String, Object>> rows = pgJdbcTemplate.queryForList(
|
|
|
|
// sql,
|
|
|
|
// referenceEmbedding.vectorAsList(),
|
|
|
|
// referenceEmbedding.vectorAsList(),
|
|
|
|
// minScore,
|
|
|
|
// maxResults
|
|
|
|
// );
|
|
|
|
//
|
|
|
|
//
|
|
|
|
//
|
|
|
|
//// EmbeddingSearchResult<TextSegment> result = pgVectorEmbeddingStore.search(request);
|
|
|
|
////
|
|
|
|
////
|
|
|
|
// return convertToMatches(rows);
|
|
|
|
// } catch (Exception e) {
|
|
|
|
// log.error("向量查询失败", e);
|
|
|
|
// throw new RuntimeException("向量搜索失败: " + e.getMessage(), e);
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public List<EmbeddingMatch<TextSegment>> findRelevant(Object memoryId, Embedding referenceEmbedding, int maxResults) {
|
|
|
|
// return findRelevant(memoryId, referenceEmbedding, maxResults, 0.0);
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// @Override
|
|
|
|
// public List<EmbeddingMatch<TextSegment>> findRelevant(Object memoryId, Embedding referenceEmbedding,
|
|
|
|
// int maxResults, double minScore) {
|
|
|
|
// try {
|
|
|
|
// // 构建带内存ID过滤的查询
|
|
|
|
// String sql = "SELECT id, content, metadata, embedding <=> ? AS distance " +
|
|
|
|
// "FROM embeddings " +
|
|
|
|
// "WHERE metadata->>'memory_id' = ? " +
|
|
|
|
// "AND (1 - (embedding <=> ?)) >= ? " +
|
|
|
|
// "ORDER BY distance " +
|
|
|
|
// "LIMIT ?";
|
|
|
|
//
|
|
|
|
// List<Map<String, Object>> rows = pgJdbcTemplate.queryForList(
|
|
|
|
// sql,
|
|
|
|
// referenceEmbedding.vectorAsList(),
|
|
|
|
// memoryId.toString(),
|
|
|
|
// referenceEmbedding.vectorAsList(),
|
|
|
|
// minScore,
|
|
|
|
// maxResults
|
|
|
|
// );
|
|
|
|
//
|
|
|
|
// return convertToMatches(rows);
|
|
|
|
// } catch (Exception e) {
|
|
|
|
// log.error("带内存ID的向量查询失败", e);
|
|
|
|
// throw new RuntimeException("带内存ID的向量搜索失败: " + e.getMessage(), e);
|
|
|
|
// }
|
|
|
|
// }
|
|
|
|
//
|
|
|
|
// private List<EmbeddingMatch<TextSegment>> convertToMatches(List<Map<String, Object>> rows) {
|
|
|
|
// List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
|
|
|
|
// for (Map<String, Object> row : rows) {
|
|
|
|
// String id = row.get("id").toString();
|
|
|
|
// String content = (String) row.get("content");
|
|
|
|
//
|
|
|
|
// // 处理 Metadata
|
|
|
|
// Map<String, String> metadataMap = (Map<String, String>) row.get("metadata");
|
|
|
|
// Metadata metadata = Metadata.from(metadataMap);
|
|
|
|
//
|
|
|
|
// // 处理 Embedding 转换
|
|
|
|
// List<Float> embeddingList = (List<Float>) row.get("embedding");
|
|
|
|
// float[] embeddingArray = new float[embeddingList.size()];
|
|
|
|
// for (int i = 0; i < embeddingList.size(); i++) {
|
|
|
|
// embeddingArray[i] = embeddingList.get(i);
|
|
|
|
// }
|
|
|
|
// Embedding embedding = new Embedding(embeddingArray);
|
|
|
|
//
|
|
|
|
// double score = 1 - (double) row.get("distance");
|
|
|
|
// TextSegment textSegment = TextSegment.from(content, metadata);
|
|
|
|
//
|
|
|
|
// matches.add(new EmbeddingMatch<>(score, id, embedding, textSegment));
|
|
|
|
// }
|
|
|
|
// return matches;
|
|
|
|
// }
|
|
|
|
//} |