正在显示
11 个修改的文件
包含
1104 行增加
和
0 行删除
| 1 | +//package org.jeecg.modules.airag.app.config; | ||
| 2 | +// | ||
| 3 | +//import org.springframework.beans.factory.annotation.Qualifier; | ||
| 4 | +//import org.springframework.boot.context.properties.ConfigurationProperties; | ||
| 5 | +//import org.springframework.boot.jdbc.DataSourceBuilder; | ||
| 6 | +//import org.springframework.context.annotation.Bean; | ||
| 7 | +//import org.springframework.context.annotation.Configuration; | ||
| 8 | +//import org.springframework.jdbc.core.JdbcTemplate; | ||
| 9 | +// | ||
| 10 | +//import javax.sql.DataSource; | ||
| 11 | +// | ||
| 12 | +//@Configuration | ||
| 13 | +//public class PgVectorDataSourceConfig { | ||
| 14 | +// | ||
| 15 | +// @Bean(name = "pgVectorDataSource") | ||
| 16 | +// @ConfigurationProperties(prefix = "spring.datasource.dynamic.datasource.pgvector") | ||
| 17 | +// public DataSource pgVectorDataSource() { | ||
| 18 | +// return DataSourceBuilder.create().build(); | ||
| 19 | +// } | ||
| 20 | +// | ||
| 21 | +// @Bean(name = "pgVectorJdbcTemplate") | ||
| 22 | +// public JdbcTemplate pgVectorJdbcTemplate( | ||
| 23 | +// @Qualifier("pgVectorDataSource") DataSource dataSource) { | ||
| 24 | +// return new JdbcTemplate(dataSource); | ||
| 25 | +// } | ||
| 26 | +//} |
| 1 | +//package org.jeecg.modules.airag.app.config; | ||
| 2 | +//import dev.langchain4j.data.document.Metadata; | ||
| 3 | +//import dev.langchain4j.data.embedding.Embedding; | ||
| 4 | +//import dev.langchain4j.data.segment.TextSegment; | ||
| 5 | +//import dev.langchain4j.store.embedding.EmbeddingMatch; | ||
| 6 | +//import dev.langchain4j.store.embedding.EmbeddingSearchRequest; | ||
| 7 | +//import dev.langchain4j.store.embedding.EmbeddingSearchResult; | ||
| 8 | +//import dev.langchain4j.store.embedding.EmbeddingStore; | ||
| 9 | +//import dev.langchain4j.store.embedding.filter.Filter; | ||
| 10 | +//import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; | ||
| 11 | +//import lombok.extern.log4j.Log4j2; | ||
| 12 | +//import org.springframework.beans.factory.annotation.Autowired; | ||
| 13 | +//import org.springframework.beans.factory.annotation.Qualifier; | ||
| 14 | +//import org.springframework.jdbc.core.JdbcTemplate; | ||
| 15 | +//import org.springframework.stereotype.Component; | ||
| 16 | +// | ||
| 17 | +//import javax.sql.DataSource; | ||
| 18 | +//import java.util.ArrayList; | ||
| 19 | +//import java.util.Collection; | ||
| 20 | +//import java.util.List; | ||
| 21 | +//import java.util.Map; | ||
| 22 | +//import java.util.Collection; | ||
| 23 | +//import java.util.List; | ||
| 24 | +//import java.util.stream.Collectors; | ||
| 25 | +// | ||
| 26 | +//@Component | ||
| 27 | +//@Log4j2 | ||
| 28 | +//public class PostgreEmbeddingStore implements EmbeddingStore<TextSegment> { | ||
| 29 | +// | ||
| 30 | +// @Autowired | ||
| 31 | +// private PgVectorEmbeddingStore pgVectorEmbeddingStore; | ||
| 32 | +// | ||
| 33 | +// @Autowired | ||
| 34 | +// private JdbcTemplate pgJdbcTemplate; | ||
| 35 | +// | ||
| 36 | +// @Autowired | ||
| 37 | +// public PostgreEmbeddingStore( | ||
| 38 | +// PgVectorEmbeddingStore pgVectorEmbeddingStore) { | ||
| 39 | +// this.pgJdbcTemplate = pgJdbcTemplate; | ||
| 40 | +// this.pgVectorEmbeddingStore = pgVectorEmbeddingStore; | ||
| 41 | +// } | ||
| 42 | +// | ||
| 43 | +// | ||
| 44 | +// @Override | ||
| 45 | +// public String add(Embedding embedding) { | ||
| 46 | +// return ""; | ||
| 47 | +// } | ||
| 48 | +// | ||
| 49 | +// @Override | ||
| 50 | +// public void add(String id, Embedding embedding) { | ||
| 51 | +// | ||
| 52 | +// } | ||
| 53 | +// | ||
| 54 | +// @Override | ||
| 55 | +// public String add(Embedding embedding, TextSegment textSegment) { | ||
| 56 | +// return ""; | ||
| 57 | +// } | ||
| 58 | +// | ||
| 59 | +// @Override | ||
| 60 | +// public List<String> addAll(List<Embedding> embeddings) { | ||
| 61 | +// return List.of(); | ||
| 62 | +// } | ||
| 63 | +// | ||
| 64 | +// @Override | ||
| 65 | +// public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) { | ||
| 66 | +// return List.of(); | ||
| 67 | +// } | ||
| 68 | +// | ||
| 69 | +// @Override | ||
| 70 | +// public void remove(String id) { | ||
| 71 | +// EmbeddingStore.super.remove(id); | ||
| 72 | +// } | ||
| 73 | +// | ||
| 74 | +// @Override | ||
| 75 | +// public void removeAll(Collection<String> ids) { | ||
| 76 | +// EmbeddingStore.super.removeAll(ids); | ||
| 77 | +// } | ||
| 78 | +// | ||
| 79 | +// @Override | ||
| 80 | +// public void removeAll(Filter filter) { | ||
| 81 | +// EmbeddingStore.super.removeAll(filter); | ||
| 82 | +// } | ||
| 83 | +// | ||
| 84 | +// @Override | ||
| 85 | +// public void removeAll() { | ||
| 86 | +// EmbeddingStore.super.removeAll(); | ||
| 87 | +// } | ||
| 88 | +// | ||
| 89 | +// @Override | ||
| 90 | +// public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) { | ||
| 91 | +// return EmbeddingStore.super.search(request); | ||
| 92 | +// } | ||
| 93 | +// | ||
| 94 | +// @Override | ||
| 95 | +// public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) { | ||
| 96 | +// return findRelevant(referenceEmbedding, maxResults, 0.0); | ||
| 97 | +// } | ||
| 98 | +// | ||
| 99 | +// @Override | ||
| 100 | +// public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) { | ||
| 101 | +// try { | ||
| 102 | +//// // 使用 PgVectorEmbeddingStore 进行查询 | ||
| 103 | +//// EmbeddingSearchRequest request = EmbeddingSearchRequest.builder() | ||
| 104 | +//// .queryEmbedding(referenceEmbedding) | ||
| 105 | +//// .maxResults(maxResults) | ||
| 106 | +//// .minScore(minScore) | ||
| 107 | +//// .build(); | ||
| 108 | +// | ||
| 109 | +// // 构建带内存ID过滤的查询 | ||
| 110 | +// String sql = "SELECT id, content, metadata, embedding <=> ? AS distance " + | ||
| 111 | +// "FROM embeddings " + | ||
| 112 | +// "WHERE (1 - (embedding <=> ?)) >= ? " + | ||
| 113 | +// "ORDER BY distance " + | ||
| 114 | +// "LIMIT ?"; | ||
| 115 | +// | ||
| 116 | +// List<Map<String, Object>> rows = pgJdbcTemplate.queryForList( | ||
| 117 | +// sql, | ||
| 118 | +// referenceEmbedding.vectorAsList(), | ||
| 119 | +// referenceEmbedding.vectorAsList(), | ||
| 120 | +// minScore, | ||
| 121 | +// maxResults | ||
| 122 | +// ); | ||
| 123 | +// | ||
| 124 | +// | ||
| 125 | +// | ||
| 126 | +//// EmbeddingSearchResult<TextSegment> result = pgVectorEmbeddingStore.search(request); | ||
| 127 | +//// | ||
| 128 | +//// | ||
| 129 | +// return convertToMatches(rows); | ||
| 130 | +// } catch (Exception e) { | ||
| 131 | +// log.error("向量查询失败", e); | ||
| 132 | +// throw new RuntimeException("向量搜索失败: " + e.getMessage(), e); | ||
| 133 | +// } | ||
| 134 | +// } | ||
| 135 | +// | ||
| 136 | +// @Override | ||
| 137 | +// public List<EmbeddingMatch<TextSegment>> findRelevant(Object memoryId, Embedding referenceEmbedding, int maxResults) { | ||
| 138 | +// return findRelevant(memoryId, referenceEmbedding, maxResults, 0.0); | ||
| 139 | +// } | ||
| 140 | +// | ||
| 141 | +// @Override | ||
| 142 | +// public List<EmbeddingMatch<TextSegment>> findRelevant(Object memoryId, Embedding referenceEmbedding, | ||
| 143 | +// int maxResults, double minScore) { | ||
| 144 | +// try { | ||
| 145 | +// // 构建带内存ID过滤的查询 | ||
| 146 | +// String sql = "SELECT id, content, metadata, embedding <=> ? AS distance " + | ||
| 147 | +// "FROM embeddings " + | ||
| 148 | +// "WHERE metadata->>'memory_id' = ? " + | ||
| 149 | +// "AND (1 - (embedding <=> ?)) >= ? " + | ||
| 150 | +// "ORDER BY distance " + | ||
| 151 | +// "LIMIT ?"; | ||
| 152 | +// | ||
| 153 | +// List<Map<String, Object>> rows = pgJdbcTemplate.queryForList( | ||
| 154 | +// sql, | ||
| 155 | +// referenceEmbedding.vectorAsList(), | ||
| 156 | +// memoryId.toString(), | ||
| 157 | +// referenceEmbedding.vectorAsList(), | ||
| 158 | +// minScore, | ||
| 159 | +// maxResults | ||
| 160 | +// ); | ||
| 161 | +// | ||
| 162 | +// return convertToMatches(rows); | ||
| 163 | +// } catch (Exception e) { | ||
| 164 | +// log.error("带内存ID的向量查询失败", e); | ||
| 165 | +// throw new RuntimeException("带内存ID的向量搜索失败: " + e.getMessage(), e); | ||
| 166 | +// } | ||
| 167 | +// } | ||
| 168 | +// | ||
| 169 | +// private List<EmbeddingMatch<TextSegment>> convertToMatches(List<Map<String, Object>> rows) { | ||
| 170 | +// List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>(); | ||
| 171 | +// for (Map<String, Object> row : rows) { | ||
| 172 | +// String id = row.get("id").toString(); | ||
| 173 | +// String content = (String) row.get("content"); | ||
| 174 | +// | ||
| 175 | +// // 处理 Metadata | ||
| 176 | +// Map<String, String> metadataMap = (Map<String, String>) row.get("metadata"); | ||
| 177 | +// Metadata metadata = Metadata.from(metadataMap); | ||
| 178 | +// | ||
| 179 | +// // 处理 Embedding 转换 | ||
| 180 | +// List<Float> embeddingList = (List<Float>) row.get("embedding"); | ||
| 181 | +// float[] embeddingArray = new float[embeddingList.size()]; | ||
| 182 | +// for (int i = 0; i < embeddingList.size(); i++) { | ||
| 183 | +// embeddingArray[i] = embeddingList.get(i); | ||
| 184 | +// } | ||
| 185 | +// Embedding embedding = new Embedding(embeddingArray); | ||
| 186 | +// | ||
| 187 | +// double score = 1 - (double) row.get("distance"); | ||
| 188 | +// TextSegment textSegment = TextSegment.from(content, metadata); | ||
| 189 | +// | ||
| 190 | +// matches.add(new EmbeddingMatch<>(score, id, embedding, textSegment)); | ||
| 191 | +// } | ||
| 192 | +// return matches; | ||
| 193 | +// } | ||
| 194 | +//} |
| 1 | +//package org.jeecg.modules.airag.app.config; | ||
| 2 | +// | ||
| 3 | +//import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; | ||
| 4 | +//import org.springframework.beans.factory.annotation.Value; | ||
| 5 | +//import org.springframework.context.annotation.Bean; | ||
| 6 | +//import org.springframework.context.annotation.Configuration; | ||
| 7 | +// | ||
| 8 | +//@Configuration | ||
| 9 | +//public class VectorStoreConfig { | ||
| 10 | +// @Value("${jeecg.ai-rag.embed-store.host}") | ||
| 11 | +// private String host; | ||
| 12 | +// @Value("${jeecg.ai-rag.embed-store.port}") | ||
| 13 | +// private Integer port; | ||
| 14 | +// @Value("${jeecg.ai-rag.embed-store.database}") | ||
| 15 | +// private String database; | ||
| 16 | +// @Value("${jeecg.ai-rag.embed-store.user}") | ||
| 17 | +// private String user; | ||
| 18 | +// @Value("${jeecg.ai-rag.embed-store.password}") | ||
| 19 | +// private String password; | ||
| 20 | +//// @Value("${spring.datasource.vector.url}") | ||
| 21 | +//// private String url; | ||
| 22 | +//// | ||
| 23 | +//// @Value("${spring.datasource.vector.username}") | ||
| 24 | +//// private String username; | ||
| 25 | +//// | ||
| 26 | +//// @Value("${spring.datasource.vector.password}") | ||
| 27 | +//// private String password; | ||
| 28 | +// | ||
| 29 | +// @Bean | ||
| 30 | +// public PgVectorEmbeddingStore pgVectorEmbeddingStore() { | ||
| 31 | +// return PgVectorEmbeddingStore.builder() | ||
| 32 | +// .host(host) // 如果使用独立参数 | ||
| 33 | +// .port(port) | ||
| 34 | +// .user(user) | ||
| 35 | +// .password(password) | ||
| 36 | +// .database(database) | ||
| 37 | +// .table("embeddings") | ||
| 38 | +// .dimension(1536) // 根据你的模型维度设置 | ||
| 39 | +// .useIndex(true) | ||
| 40 | +// .createTable(true) | ||
| 41 | +// .build(); | ||
| 42 | +// | ||
| 43 | +//// return PgVectorEmbeddingStore.builder() | ||
| 44 | +//// .dataSource(dataSource) // 注入配置好的DataSource | ||
| 45 | +//// .table("embeddings") | ||
| 46 | +//// .dimension(1536) | ||
| 47 | +//// .build(); | ||
| 48 | +// } | ||
| 49 | +//} |
| @@ -42,6 +42,7 @@ import org.jeecg.modules.airag.app.service.IEmbeddingsService; | @@ -42,6 +42,7 @@ import org.jeecg.modules.airag.app.service.IEmbeddingsService; | ||
| 42 | import org.jeecg.common.system.base.controller.JeecgController; | 42 | import org.jeecg.common.system.base.controller.JeecgController; |
| 43 | import org.jeecg.modules.airag.app.utils.AiModelUtils; | 43 | import org.jeecg.modules.airag.app.utils.AiModelUtils; |
| 44 | import org.jeecg.modules.airag.llm.entity.AiragModel; | 44 | import org.jeecg.modules.airag.llm.entity.AiragModel; |
| 45 | +import org.jeecg.modules.airag.llm.handler.EmbeddingHandler; | ||
| 45 | import org.springframework.beans.factory.annotation.Autowired; | 46 | import org.springframework.beans.factory.annotation.Autowired; |
| 46 | import org.springframework.beans.factory.annotation.Qualifier; | 47 | import org.springframework.beans.factory.annotation.Qualifier; |
| 47 | import org.springframework.boot.jdbc.DataSourceBuilder; | 48 | import org.springframework.boot.jdbc.DataSourceBuilder; |
| @@ -69,6 +70,9 @@ public class EmbeddingsController { | @@ -69,6 +70,9 @@ public class EmbeddingsController { | ||
| 69 | @Autowired | 70 | @Autowired |
| 70 | private AiModelUtils aiModelUtils; | 71 | private AiModelUtils aiModelUtils; |
| 71 | 72 | ||
| 73 | + @Autowired | ||
| 74 | + private EmbeddingHandler embeddingHandler; | ||
| 75 | + | ||
| 72 | 76 | ||
| 73 | /** | 77 | /** |
| 74 | * 分页列表查询 | 78 | * 分页列表查询 |
| @@ -86,6 +90,13 @@ public class EmbeddingsController { | @@ -86,6 +90,13 @@ public class EmbeddingsController { | ||
| 86 | @RequestParam(name="pageNo", defaultValue="1") Integer pageNo, | 90 | @RequestParam(name="pageNo", defaultValue="1") Integer pageNo, |
| 87 | @RequestParam(name="pageSize", defaultValue="10") Integer pageSize, | 91 | @RequestParam(name="pageSize", defaultValue="10") Integer pageSize, |
| 88 | HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException { | 92 | HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException { |
| 93 | +// AiragModel airagModel = new AiragModel(); | ||
| 94 | +// airagModel.setId("1925730210204721154"); | ||
| 95 | +// airagModel.setProvider("OLLAMA"); | ||
| 96 | +// airagModel.setModelName("nomic-embed-text"); | ||
| 97 | +// airagModel.setBaseUrl("http://localhost:11434"); | ||
| 98 | +// EmbeddingStore<TextSegment> embedStore = embeddingHandler.getEmbedStore(airagModel); | ||
| 99 | +// embeddingHandler.searchEmbedding() | ||
| 89 | Response<Embedding> embedding = aiModelUtils.getEmbedding("1925730210204721154", "33333"); | 100 | Response<Embedding> embedding = aiModelUtils.getEmbedding("1925730210204721154", "33333"); |
| 90 | 101 | ||
| 91 | List<Embeddings> records = embeddingsService.findAll(); | 102 | List<Embeddings> records = embeddingsService.findAll(); |
| 1 | +package org.jeecg.modules.airag.app.controller; | ||
| 2 | + | ||
| 3 | +import org.jeecg.common.api.vo.Result; | ||
| 4 | +import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | ||
| 5 | +import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; | ||
| 6 | +import org.springframework.beans.factory.annotation.Autowired; | ||
| 7 | +import org.springframework.transaction.annotation.Transactional; | ||
| 8 | +import org.springframework.web.bind.annotation.*; | ||
| 9 | +import org.springframework.web.multipart.MultipartFile; | ||
| 10 | + | ||
| 11 | +import java.util.List; | ||
| 12 | + | ||
| 13 | +@RestController | ||
| 14 | +@RequestMapping("/question/embedding") | ||
| 15 | +public class QuestionEmbeddingController { | ||
| 16 | + @Autowired | ||
| 17 | + private IQuestionEmbeddingService questionEmbeddingService; | ||
| 18 | + | ||
| 19 | + @GetMapping("/list") | ||
| 20 | + public Result<List<QuestionEmbedding>> findAll() { | ||
| 21 | + List<QuestionEmbedding> list = questionEmbeddingService.findAll(); | ||
| 22 | + return Result.OK(list); | ||
| 23 | + } | ||
| 24 | + | ||
| 25 | + @GetMapping("/queryById") | ||
| 26 | + public Result<QuestionEmbedding> findById(@RequestParam String id) { | ||
| 27 | + QuestionEmbedding record = questionEmbeddingService.findById(id); | ||
| 28 | + if (record == null) { | ||
| 29 | + return Result.error("未找到对应数据"); | ||
| 30 | + } | ||
| 31 | + return Result.OK(record); | ||
| 32 | + } | ||
| 33 | + | ||
| 34 | + @PostMapping("/add") | ||
| 35 | + public Result<String> insert(@RequestBody QuestionEmbedding record) { | ||
| 36 | + int result = questionEmbeddingService.insert(record); | ||
| 37 | + return result > 0 ? Result.OK("添加成功!") : Result.error("添加失败"); | ||
| 38 | + } | ||
| 39 | + | ||
| 40 | + @RequestMapping(value = "/edit", method = {RequestMethod.PUT,RequestMethod.POST}) | ||
| 41 | + public Result<String> update(@RequestBody QuestionEmbedding record) { | ||
| 42 | + int result = questionEmbeddingService.update(record); | ||
| 43 | + return result > 0 ? Result.OK("编辑成功!") : Result.error("编辑失败"); | ||
| 44 | + } | ||
| 45 | + | ||
| 46 | + @DeleteMapping("/delete") | ||
| 47 | + public Result<String> deleteById(@RequestParam String id) { | ||
| 48 | + int result = questionEmbeddingService.deleteById(id); | ||
| 49 | + return result > 0 ? Result.OK("删除成功!") : Result.error("删除失败"); | ||
| 50 | + } | ||
| 51 | + | ||
| 52 | + @PostMapping("/uploadZip") | ||
| 53 | + @Transactional(rollbackFor = {Exception.class}) | ||
| 54 | + public Result<?> uploadZip(@RequestParam("file") MultipartFile file) { | ||
| 55 | + return questionEmbeddingService.processZipUpload(file); | ||
| 56 | + } | ||
| 57 | + | ||
| 58 | +} |
| 1 | +package org.jeecg.modules.airag.app.entity; | ||
| 2 | + | ||
| 3 | +import lombok.AllArgsConstructor; | ||
| 4 | +import lombok.Data; | ||
| 5 | +import lombok.NoArgsConstructor; | ||
| 6 | + | ||
| 7 | +import java.util.Map; | ||
| 8 | + | ||
| 9 | +@Data | ||
| 10 | +@AllArgsConstructor | ||
| 11 | +@NoArgsConstructor | ||
| 12 | +public class QuestionEmbedding { | ||
| 13 | + private String id; | ||
| 14 | + private String text; | ||
| 15 | + private String question; | ||
| 16 | + private String answer; | ||
| 17 | + private String metadata; | ||
| 18 | + private float[] embedding; | ||
| 19 | + private Double similarity; | ||
| 20 | +} |
| 1 | +package org.jeecg.modules.airag.app.mapper; | ||
| 2 | + | ||
| 3 | +import com.fasterxml.jackson.core.JsonProcessingException; | ||
| 4 | +import com.fasterxml.jackson.core.type.TypeReference; | ||
| 5 | +import com.fasterxml.jackson.databind.ObjectMapper; | ||
| 6 | +import com.pgvector.PGvector; | ||
| 7 | +import dev.langchain4j.data.embedding.Embedding; | ||
| 8 | +import dev.langchain4j.model.output.Response; | ||
| 9 | +import lombok.extern.slf4j.Slf4j; | ||
| 10 | +import org.apache.commons.lang3.StringUtils; | ||
| 11 | +import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | ||
| 12 | +import org.jeecg.modules.airag.app.utils.AiModelUtils; | ||
| 13 | +import org.postgresql.util.PGobject; | ||
| 14 | +import org.springframework.beans.factory.annotation.Autowired; | ||
| 15 | +import org.springframework.stereotype.Component; | ||
| 16 | + | ||
| 17 | +import java.sql.*; | ||
| 18 | +import java.util.*; | ||
| 19 | + | ||
| 20 | +@Component | ||
| 21 | +@Slf4j | ||
| 22 | +public class QuestionEmbeddingMapper { | ||
| 23 | + | ||
| 24 | + @Autowired | ||
| 25 | + private AiModelUtils aiModelUtils; | ||
| 26 | + | ||
| 27 | + // PostgreSQL连接参数(应与项目配置一致) | ||
| 28 | + private static final String URL = "jdbc:postgresql://192.168.100.103:5432/postgres"; | ||
| 29 | + private static final String USER = "postgres"; | ||
| 30 | + private static final String PASSWORD = "postgres"; | ||
| 31 | + | ||
| 32 | + // 获取数据库连接 | ||
| 33 | + private Connection getConnection() throws SQLException { | ||
| 34 | + return DriverManager.getConnection(URL, USER, PASSWORD); | ||
| 35 | + } | ||
| 36 | + | ||
| 37 | + // 查询所有记录 | ||
| 38 | + public List<QuestionEmbedding> findAll() { | ||
| 39 | + List<QuestionEmbedding> results = new ArrayList<>(); | ||
| 40 | + String sql = "SELECT * FROM question_embedding"; | ||
| 41 | + | ||
| 42 | + try (Connection conn = getConnection(); | ||
| 43 | + PreparedStatement stmt = conn.prepareStatement(sql); | ||
| 44 | + ResultSet rs = stmt.executeQuery()) { | ||
| 45 | + | ||
| 46 | + while (rs.next()) { | ||
| 47 | + results.add(mapRowToQuestionEmbedding(rs)); | ||
| 48 | + } | ||
| 49 | + } catch (SQLException e) { | ||
| 50 | + log.error("查询所有记录失败", e); | ||
| 51 | + throw new RuntimeException("查询数据时发生数据库错误", e); | ||
| 52 | + } | ||
| 53 | + return results; | ||
| 54 | + } | ||
| 55 | + | ||
| 56 | + // 根据ID查询单个记录 | ||
| 57 | + public QuestionEmbedding findById(String id) { | ||
| 58 | + String sql = "SELECT * FROM question_embedding WHERE id = ?"; | ||
| 59 | + | ||
| 60 | + try (Connection conn = getConnection(); | ||
| 61 | + PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 62 | + | ||
| 63 | + stmt.setString(1, id); | ||
| 64 | + try (ResultSet rs = stmt.executeQuery()) { | ||
| 65 | + if (rs.next()) { | ||
| 66 | + return mapRowToQuestionEmbedding(rs); | ||
| 67 | + } | ||
| 68 | + } | ||
| 69 | + } catch (SQLException e) { | ||
| 70 | + log.error("根据ID查询记录失败, ID: {}", id, e); | ||
| 71 | + throw new RuntimeException("根据ID查询时发生数据库错误", e); | ||
| 72 | + } | ||
| 73 | + return null; | ||
| 74 | + } | ||
| 75 | + | ||
| 76 | + // 插入新记录 | ||
| 77 | + public int insert(QuestionEmbedding record) { | ||
| 78 | + String sql = "INSERT INTO question_embedding (id, text, question, answer, metadata,embedding) VALUES (?, ?, ?, ?, ?::jsonb,?)"; | ||
| 79 | + | ||
| 80 | + try (Connection conn = getConnection(); | ||
| 81 | + PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 82 | + | ||
| 83 | + stmt.setString(1, UUID.randomUUID().toString()); | ||
| 84 | + stmt.setString(2, record.getText()); | ||
| 85 | + stmt.setString(3, record.getQuestion()); | ||
| 86 | + stmt.setString(4, record.getAnswer()); | ||
| 87 | + PGobject jsonObject = new PGobject(); | ||
| 88 | + jsonObject.setType("json"); | ||
| 89 | + jsonObject.setValue("{\"name\":\"John\", \"age\":30}"); | ||
| 90 | + stmt.setObject(5, jsonObject); | ||
| 91 | + Response<Embedding> embedding = aiModelUtils.getEmbedding("1925730210204721154", record.getQuestion()); | ||
| 92 | + stmt.setObject(6, embedding.content().vector()); | ||
| 93 | + return stmt.executeUpdate(); | ||
| 94 | + } catch (SQLException e) { | ||
| 95 | + log.error("插入记录失败: {}", record, e); | ||
| 96 | + throw new RuntimeException("插入数据时发生数据库错误", e); | ||
| 97 | + } | ||
| 98 | + } | ||
| 99 | + | ||
| 100 | + // 更新记录 | ||
| 101 | + public int update(QuestionEmbedding record) { | ||
| 102 | + String sql = "UPDATE question_embedding SET text = ?, question = ?, answer = ?, metadata = ?::jsonb ,embedding = ? WHERE id = ?"; | ||
| 103 | + | ||
| 104 | + try (Connection conn = getConnection(); | ||
| 105 | + PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 106 | + | ||
| 107 | + stmt.setString(1, record.getText()); | ||
| 108 | + stmt.setString(2, record.getQuestion()); | ||
| 109 | + stmt.setString(3, record.getAnswer()); | ||
| 110 | + PGobject jsonObject = new PGobject(); | ||
| 111 | + jsonObject.setType("json"); | ||
| 112 | + jsonObject.setValue("{\"name\":\"John\", \"age\":30}"); | ||
| 113 | + stmt.setObject(4, jsonObject); | ||
| 114 | + | ||
| 115 | + Response<Embedding> embedding = aiModelUtils.getEmbedding("1925730210204721154", record.getQuestion()); | ||
| 116 | + stmt.setObject(5, embedding.content().vector()); | ||
| 117 | + | ||
| 118 | + stmt.setString(6, record.getId()); | ||
| 119 | + | ||
| 120 | + return stmt.executeUpdate(); | ||
| 121 | + } catch (SQLException e) { | ||
| 122 | + log.error("更新记录失败: {}", record, e); | ||
| 123 | + throw new RuntimeException("更新数据时发生数据库错误", e); | ||
| 124 | + } | ||
| 125 | + } | ||
| 126 | + | ||
| 127 | + | ||
| 128 | + /** | ||
| 129 | + * 向量相似度查询 (基于问题文本的向量) | ||
| 130 | + * @param question 问题文本 | ||
| 131 | + * @param limit 返回结果数量 | ||
| 132 | + * @param minSimilarity 最小相似度阈值(0-1) | ||
| 133 | + * @return 相似问答列表(按相似度降序) | ||
| 134 | + */ | ||
| 135 | + public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { | ||
| 136 | + List<QuestionEmbedding> results = new ArrayList<>(); | ||
| 137 | + | ||
| 138 | + // 1. 参数校验 | ||
| 139 | + if (minSimilarity < 0 || minSimilarity > 1) { | ||
| 140 | + throw new IllegalArgumentException("相似度阈值必须在0到1之间"); | ||
| 141 | + } | ||
| 142 | + | ||
| 143 | + // 2. 获取问题的嵌入向量 | ||
| 144 | + Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding("1925730210204721154", question); | ||
| 145 | + float[] queryVector = embeddingResponse.content().vector(); | ||
| 146 | + | ||
| 147 | + // 3. 计算最大允许距离(1 - 相似度阈值) | ||
| 148 | + double maxDistance = 1 - minSimilarity; | ||
| 149 | + | ||
| 150 | + // 4. 执行向量相似度查询 | ||
| 151 | + String sql = "SELECT *, embedding <-> ? AS distance " + | ||
| 152 | + "FROM question_embedding " + | ||
| 153 | + "WHERE embedding <-> ? < ? " + // 距离小于阈值 | ||
| 154 | + "ORDER BY distance ASC " + // 按距离升序 | ||
| 155 | + "LIMIT ?"; | ||
| 156 | + | ||
| 157 | + try (Connection conn = getConnection(); | ||
| 158 | + PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 159 | + | ||
| 160 | + // 设置参数 | ||
| 161 | + PGvector vector = new PGvector(queryVector); | ||
| 162 | + stmt.setObject(1, vector); | ||
| 163 | + stmt.setObject(2, vector); | ||
| 164 | + stmt.setDouble(3, maxDistance); | ||
| 165 | + stmt.setInt(4, limit); | ||
| 166 | + | ||
| 167 | + try (ResultSet rs = stmt.executeQuery()) { | ||
| 168 | + while (rs.next()) { | ||
| 169 | + QuestionEmbedding record = mapRowToQuestionEmbedding(rs); | ||
| 170 | + // 计算相似度(1 - 距离) | ||
| 171 | + double distance = rs.getDouble("distance"); | ||
| 172 | + double similarity = 1 - distance; | ||
| 173 | + record.setSimilarity(similarity); | ||
| 174 | + results.add(record); | ||
| 175 | + } | ||
| 176 | + } | ||
| 177 | + } catch (SQLException e) { | ||
| 178 | + log.error("向量相似度查询失败", e); | ||
| 179 | + throw new RuntimeException("执行向量相似度查询时发生数据库错误", e); | ||
| 180 | + } | ||
| 181 | + return results; | ||
| 182 | + } | ||
| 183 | + | ||
| 184 | + /** | ||
| 185 | + * 向量相似度查询 (直接使用向量) | ||
| 186 | + * @param vector 查询向量 | ||
| 187 | + * @param limit 返回结果数量 | ||
| 188 | + * @return 相似问答列表(按相似度降序) | ||
| 189 | + */ | ||
| 190 | + public List<QuestionEmbedding> similaritySearch(float[] vector, int limit) { | ||
| 191 | + List<QuestionEmbedding> results = new ArrayList<>(); | ||
| 192 | + String sql = "SELECT *, embedding <-> ? AS similarity " + | ||
| 193 | + "FROM question_embedding " + | ||
| 194 | + "ORDER BY similarity ASC " + | ||
| 195 | + "LIMIT ?"; | ||
| 196 | + | ||
| 197 | + try (Connection conn = getConnection(); | ||
| 198 | + PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 199 | + | ||
| 200 | + stmt.setObject(1, new PGvector(vector)); | ||
| 201 | + stmt.setInt(2, limit); | ||
| 202 | + | ||
| 203 | + try (ResultSet rs = stmt.executeQuery()) { | ||
| 204 | + while (rs.next()) { | ||
| 205 | + QuestionEmbedding record = mapRowToQuestionEmbedding(rs); | ||
| 206 | + double similarity = 1 - rs.getDouble("similarity"); | ||
| 207 | + record.setSimilarity(similarity); | ||
| 208 | + results.add(record); | ||
| 209 | + } | ||
| 210 | + } | ||
| 211 | + } catch (SQLException e) { | ||
| 212 | + log.error("向量相似度查询失败", e); | ||
| 213 | + throw new RuntimeException("执行向量相似度查询时发生数据库错误", e); | ||
| 214 | + } | ||
| 215 | + return results; | ||
| 216 | + } | ||
| 217 | + | ||
| 218 | + // 根据ID删除记录 | ||
| 219 | + public int deleteById(String id) { | ||
| 220 | + String sql = "DELETE FROM question_embedding WHERE id = ?"; | ||
| 221 | + | ||
| 222 | + try (Connection conn = getConnection(); | ||
| 223 | + PreparedStatement stmt = conn.prepareStatement(sql)) { | ||
| 224 | + | ||
| 225 | + stmt.setString(1, id); | ||
| 226 | + return stmt.executeUpdate(); | ||
| 227 | + } catch (SQLException e) { | ||
| 228 | + log.error("删除记录失败, ID: {}", id, e); | ||
| 229 | + throw new RuntimeException("删除数据时发生数据库错误", e); | ||
| 230 | + } | ||
| 231 | + } | ||
| 232 | + | ||
| 233 | + // 将ResultSet行映射为QuestionEmbedding对象 | ||
| 234 | + private QuestionEmbedding mapRowToQuestionEmbedding(ResultSet rs) throws SQLException { | ||
| 235 | + QuestionEmbedding record = new QuestionEmbedding(); | ||
| 236 | + record.setId(rs.getString("id")); | ||
| 237 | + record.setText(rs.getString("text")); | ||
| 238 | + record.setQuestion(rs.getString("question")); | ||
| 239 | + record.setAnswer(rs.getString("answer")); | ||
| 240 | + | ||
| 241 | + String metadataJson = rs.getString("metadata"); | ||
| 242 | + if (StringUtils.isNotBlank(metadataJson)) { | ||
| 243 | + record.setMetadata(""); | ||
| 244 | + } | ||
| 245 | + | ||
| 246 | + return record; | ||
| 247 | + } | ||
| 248 | + | ||
| 249 | + // 将Map转换为JSON字符串 | ||
| 250 | + private String toJson(Map<String, Object> map) { | ||
| 251 | + try { | ||
| 252 | + return new ObjectMapper().writeValueAsString(map); | ||
| 253 | + } catch (JsonProcessingException e) { | ||
| 254 | + log.error("元数据转换为JSON失败", e); | ||
| 255 | + return "{}"; | ||
| 256 | + } | ||
| 257 | + } | ||
| 258 | + | ||
| 259 | + // 将JSON字符串转换为Map | ||
| 260 | + private Map<String, Object> fromJson(String json) { | ||
| 261 | + try { | ||
| 262 | + return new ObjectMapper().readValue(json, new TypeReference<Map<String, Object>>() {}); | ||
| 263 | + } catch (JsonProcessingException e) { | ||
| 264 | + log.error("JSON转换为元数据失败", e); | ||
| 265 | + return Collections.emptyMap(); | ||
| 266 | + } | ||
| 267 | + } | ||
| 268 | +} |
| 1 | +package org.jeecg.modules.airag.app.service; | ||
| 2 | + | ||
| 3 | +import org.jeecg.common.api.vo.Result; | ||
| 4 | +import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | ||
| 5 | +import org.springframework.web.multipart.MultipartFile; | ||
| 6 | + | ||
| 7 | +import java.util.List; | ||
| 8 | + | ||
| 9 | +public interface IQuestionEmbeddingService { | ||
| 10 | + List<QuestionEmbedding> findAll(); | ||
| 11 | + QuestionEmbedding findById(String id); | ||
| 12 | + int insert(QuestionEmbedding record); | ||
| 13 | + int update(QuestionEmbedding record); | ||
| 14 | + int deleteById(String id); | ||
| 15 | + List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity); | ||
| 16 | + List<QuestionEmbedding> similaritySearch(float[] vector, int limit); | ||
| 17 | + | ||
| 18 | + Result<?> processZipUpload(MultipartFile file); | ||
| 19 | +} |
| 1 | +package org.jeecg.modules.airag.app.service.impl; | ||
| 2 | + | ||
| 3 | +import org.apache.poi.hwpf.usermodel.CharacterRun; | ||
| 4 | +import org.apache.poi.hwpf.HWPFDocument; | ||
| 5 | +import org.apache.poi.hwpf.usermodel.Paragraph; | ||
| 6 | +import org.apache.poi.hwpf.usermodel.Range; | ||
| 7 | +import dev.langchain4j.data.document.Document; | ||
| 8 | +import dev.langchain4j.data.document.DocumentSplitter; | ||
| 9 | +import dev.langchain4j.data.document.loader.FileSystemDocumentLoader; | ||
| 10 | +import dev.langchain4j.data.document.parser.TextDocumentParser; | ||
| 11 | +import dev.langchain4j.data.document.splitter.DocumentByParagraphSplitter; | ||
| 12 | +import dev.langchain4j.data.document.splitter.DocumentSplitters; | ||
| 13 | +import dev.langchain4j.data.embedding.Embedding; | ||
| 14 | +import dev.langchain4j.data.segment.TextSegment; | ||
| 15 | +import dev.langchain4j.model.embedding.EmbeddingModel; | ||
| 16 | +import dev.langchain4j.model.output.Response; | ||
| 17 | +import org.apache.commons.io.FilenameUtils; | ||
| 18 | +import org.apache.poi.xwpf.usermodel.IBodyElement; | ||
| 19 | +import org.apache.poi.xwpf.usermodel.XWPFDocument; | ||
| 20 | +import org.apache.poi.xwpf.usermodel.XWPFParagraph; | ||
| 21 | +import org.apache.poi.xwpf.usermodel.XWPFTable; | ||
| 22 | +import org.jeecg.common.api.vo.Result; | ||
| 23 | +import org.jeecg.common.util.CommonUtils; | ||
| 24 | +import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | ||
| 25 | +import org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper; | ||
| 26 | +import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; | ||
| 27 | +import org.jeecg.modules.airag.app.utils.AiModelUtils; | ||
| 28 | +import org.jeecg.modules.airag.common.handler.IAIChatHandler; | ||
| 29 | +import org.slf4j.Logger; | ||
| 30 | +import org.slf4j.LoggerFactory; | ||
| 31 | +import org.springframework.beans.factory.annotation.Autowired; | ||
| 32 | +import org.springframework.beans.factory.annotation.Value; | ||
| 33 | +import org.springframework.stereotype.Service; | ||
| 34 | +import org.springframework.web.multipart.MultipartFile; | ||
| 35 | + | ||
| 36 | +import java.io.File; | ||
| 37 | +import java.io.FileInputStream; | ||
| 38 | +import java.io.IOException; | ||
| 39 | +import java.nio.file.Files; | ||
| 40 | +import java.nio.file.Path; | ||
| 41 | +import java.nio.file.Paths; | ||
| 42 | +import java.nio.file.StandardCopyOption; | ||
| 43 | +import java.util.*; | ||
| 44 | +import java.util.regex.Pattern; | ||
| 45 | +import java.util.stream.Collectors; | ||
| 46 | +import java.util.zip.ZipEntry; | ||
| 47 | +import java.util.zip.ZipInputStream; | ||
| 48 | + | ||
| 49 | +@Service | ||
| 50 | +public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService { | ||
| 51 | + | ||
| 52 | + private static final Logger log = LoggerFactory.getLogger(QuestionEmbeddingServiceImpl.class); | ||
| 53 | + | ||
| 54 | + @Autowired | ||
| 55 | + private QuestionEmbeddingMapper questionEmbeddingMapper; | ||
| 56 | + | ||
| 57 | + @Autowired | ||
| 58 | + private AiModelUtils aiModelUtils; | ||
| 59 | + | ||
| 60 | + @Autowired | ||
| 61 | + private IAIChatHandler aiChatHandler; | ||
| 62 | + | ||
| 63 | + @Value("${jeecg.upload.path}") | ||
| 64 | + private String uploadPath; | ||
| 65 | + | ||
| 66 | + private static final Set<String> ALLOWED_EXTENSIONS = Set.of("txt", "doc", "docx"); | ||
| 67 | + private static final Pattern SPECIAL_CHARS_PATTERN = Pattern.compile("[^a-zA-Z0-9\\u4e00-\\u9fa5\\s]"); | ||
| 68 | + | ||
| 69 | + @Override | ||
| 70 | + public List<QuestionEmbedding> findAll() { | ||
| 71 | + return questionEmbeddingMapper.findAll(); | ||
| 72 | + } | ||
| 73 | + | ||
| 74 | + @Override | ||
| 75 | + public QuestionEmbedding findById(String id) { | ||
| 76 | + return questionEmbeddingMapper.findById(id); | ||
| 77 | + } | ||
| 78 | + | ||
| 79 | + @Override | ||
| 80 | + public int insert(QuestionEmbedding record) { | ||
| 81 | + return questionEmbeddingMapper.insert(record); | ||
| 82 | + } | ||
| 83 | + | ||
| 84 | + @Override | ||
| 85 | + public int update(QuestionEmbedding record) { | ||
| 86 | + return questionEmbeddingMapper.update(record); | ||
| 87 | + } | ||
| 88 | + | ||
| 89 | + @Override | ||
| 90 | + public int deleteById(String id) { | ||
| 91 | + return questionEmbeddingMapper.deleteById(id); | ||
| 92 | + } | ||
| 93 | + | ||
| 94 | + @Override | ||
| 95 | + public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) { | ||
| 96 | + return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity); | ||
| 97 | + } | ||
| 98 | + | ||
| 99 | + @Override | ||
| 100 | + public List<QuestionEmbedding> similaritySearch(float[] vector, int limit) { | ||
| 101 | + return questionEmbeddingMapper.similaritySearch(vector, limit); | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | + public Result<?> processZipUpload(MultipartFile zipFile) { | ||
| 105 | + try { | ||
| 106 | + Path tempDir = Files.createTempDirectory("zip_upload_"); | ||
| 107 | + List<Path> validFiles = extractAndFilterZip(zipFile, tempDir); | ||
| 108 | + | ||
| 109 | + if (validFiles.isEmpty()) { | ||
| 110 | + return Result.error("ZIP文件中没有有效的TXT或Word文档"); | ||
| 111 | + } | ||
| 112 | + | ||
| 113 | + for (Path filePath : validFiles) { | ||
| 114 | + processSingleFile(filePath); | ||
| 115 | + } | ||
| 116 | + | ||
| 117 | + return Result.OK("文件上传和处理成功"); | ||
| 118 | + } catch (Exception e) { | ||
| 119 | + log.error("处理ZIP文件上传失败", e); | ||
| 120 | + return Result.error("处理ZIP文件失败: " + e.getMessage()); | ||
| 121 | + } | ||
| 122 | + } | ||
| 123 | + | ||
| 124 | + private List<Path> extractAndFilterZip(MultipartFile zipFile, Path tempDir) throws IOException { | ||
| 125 | + List<Path> validFiles = new ArrayList<>(); | ||
| 126 | + | ||
| 127 | + try (ZipInputStream zipIn = new ZipInputStream(zipFile.getInputStream())) { | ||
| 128 | + ZipEntry entry; | ||
| 129 | + while ((entry = zipIn.getNextEntry()) != null) { | ||
| 130 | + if (!entry.isDirectory()) { | ||
| 131 | + String fileName = entry.getName(); | ||
| 132 | + String ext = FilenameUtils.getExtension(fileName).toLowerCase(); | ||
| 133 | + | ||
| 134 | + if (ALLOWED_EXTENSIONS.contains(ext)) { | ||
| 135 | + String safeFileName = new File(fileName).getName(); | ||
| 136 | + Path outputPath = tempDir.resolve(safeFileName); | ||
| 137 | + Files.copy(zipIn, outputPath, StandardCopyOption.REPLACE_EXISTING); | ||
| 138 | + validFiles.add(outputPath); | ||
| 139 | + } | ||
| 140 | + } | ||
| 141 | + zipIn.closeEntry(); | ||
| 142 | + } | ||
| 143 | + } | ||
| 144 | + return validFiles; | ||
| 145 | + } | ||
| 146 | + | ||
| 147 | + private void processSingleFile(Path filePath) throws Exception { | ||
| 148 | + String originalFileName = filePath.getFileName().toString(); | ||
| 149 | + String fileExt = FilenameUtils.getExtension(originalFileName); | ||
| 150 | + String newFileName = FilenameUtils.removeExtension(originalFileName) + "_" + UUID.randomUUID() + "." + fileExt; | ||
| 151 | + Path targetPath = Paths.get(uploadPath, newFileName); | ||
| 152 | + Files.move(filePath, targetPath, StandardCopyOption.REPLACE_EXISTING); | ||
| 153 | + | ||
| 154 | + List<String> segments; | ||
| 155 | + if (fileExt.equalsIgnoreCase("txt")) { | ||
| 156 | + String fileContent = readFileContent(targetPath); | ||
| 157 | + String cleanedContent = cleanText(fileContent); | ||
| 158 | + segments = splitTxtDocument(cleanedContent); | ||
| 159 | + } else { | ||
| 160 | + segments = splitWordDocument(targetPath.toString()); | ||
| 161 | + } | ||
| 162 | + | ||
| 163 | + saveSegmentsToDatabase(segments, originalFileName, newFileName); | ||
| 164 | + } | ||
| 165 | + | ||
| 166 | + private String readFileContent(Path filePath) throws IOException { | ||
| 167 | + return new String(Files.readAllBytes(filePath)); | ||
| 168 | + } | ||
| 169 | + | ||
| 170 | + private String cleanText(String text) { | ||
| 171 | + text = SPECIAL_CHARS_PATTERN.matcher(text).replaceAll(""); | ||
| 172 | + return text.replaceAll("\\s+", " ").trim(); | ||
| 173 | + } | ||
| 174 | + | ||
| 175 | + private List<String> splitTxtDocument(String content) { | ||
| 176 | + DocumentSplitter splitter = new DocumentByParagraphSplitter(1000, 200); | ||
| 177 | + Document document = Document.from(content); | ||
| 178 | + return splitter.split(document).stream() | ||
| 179 | + .map(TextSegment::text) | ||
| 180 | + .map(this::cleanText) | ||
| 181 | + .collect(Collectors.toList()); | ||
| 182 | + } | ||
| 183 | + | ||
| 184 | + public List<String> splitWordDocument(String filePath) throws Exception { | ||
| 185 | + List<String> result = new ArrayList<>(); | ||
| 186 | + String ext = FilenameUtils.getExtension(filePath).toLowerCase(); | ||
| 187 | + StringBuilder fullContent = new StringBuilder(); | ||
| 188 | + String fileName = new File(filePath).getName(); | ||
| 189 | + fileName = fileName.substring(0, fileName.lastIndexOf('.')); // 去掉后缀 | ||
| 190 | + | ||
| 191 | + if (ext.equals("docx")) { | ||
| 192 | + try (XWPFDocument doc = new XWPFDocument(new FileInputStream(filePath))) { | ||
| 193 | + StringBuilder currentSection = new StringBuilder(); | ||
| 194 | + boolean isTableSection = false; | ||
| 195 | + | ||
| 196 | + for (IBodyElement element : doc.getBodyElements()) { | ||
| 197 | + if (element instanceof XWPFParagraph) { | ||
| 198 | + XWPFParagraph para = (XWPFParagraph) element; | ||
| 199 | + String text = cleanText(para.getText()); | ||
| 200 | + fullContent.append(text).append("\n"); | ||
| 201 | + | ||
| 202 | + if (isTableSection) { | ||
| 203 | + result.add(currentSection.toString().trim()); | ||
| 204 | + currentSection = new StringBuilder(); | ||
| 205 | + isTableSection = false; | ||
| 206 | + } | ||
| 207 | + | ||
| 208 | + String style = para.getStyle(); | ||
| 209 | + if (style != null && style.matches("Heading\\d")) { | ||
| 210 | + if (currentSection.length() > 0) { | ||
| 211 | + result.add(currentSection.toString().trim()); | ||
| 212 | + } | ||
| 213 | + currentSection = new StringBuilder(text).append("\n"); | ||
| 214 | + } else { | ||
| 215 | + currentSection.append(text).append("\n"); | ||
| 216 | + } | ||
| 217 | + } else if (element instanceof XWPFTable) { | ||
| 218 | + String tableContent = extractTableContent((XWPFTable) element); | ||
| 219 | + fullContent.append(tableContent).append("\n"); | ||
| 220 | + | ||
| 221 | + if (!isTableSection) { | ||
| 222 | + if (currentSection.length() > 0) { | ||
| 223 | + result.add(currentSection.toString().trim()); | ||
| 224 | + } | ||
| 225 | + currentSection = new StringBuilder(); | ||
| 226 | + isTableSection = true; | ||
| 227 | + } | ||
| 228 | + currentSection.append(tableContent).append("\n"); | ||
| 229 | + } | ||
| 230 | + } | ||
| 231 | + | ||
| 232 | + if (currentSection.length() > 0) { | ||
| 233 | + result.add(currentSection.toString().trim()); | ||
| 234 | + } | ||
| 235 | + } | ||
| 236 | + } else if (ext.equals("doc")) { | ||
| 237 | + try (HWPFDocument doc = new HWPFDocument(new FileInputStream(filePath))) { | ||
| 238 | + Range range = doc.getRange(); | ||
| 239 | + StringBuilder currentSection = new StringBuilder(); | ||
| 240 | + boolean isTableSection = false; | ||
| 241 | + | ||
| 242 | + for (int i = 0; i < range.numParagraphs(); i++) { | ||
| 243 | + Paragraph para = range.getParagraph(i); | ||
| 244 | + String text = cleanText(para.text()); | ||
| 245 | + fullContent.append(text).append("\n"); | ||
| 246 | + | ||
| 247 | + if (para.isInTable()) { | ||
| 248 | + if (!isTableSection) { | ||
| 249 | + if (currentSection.length() > 0) { | ||
| 250 | + result.add(currentSection.toString().trim()); | ||
| 251 | + } | ||
| 252 | + currentSection = new StringBuilder(); | ||
| 253 | + isTableSection = true; | ||
| 254 | + } | ||
| 255 | + currentSection.append(text).append("\n"); | ||
| 256 | + } else { | ||
| 257 | + if (isTableSection) { | ||
| 258 | + result.add(currentSection.toString().trim()); | ||
| 259 | + currentSection = new StringBuilder(); | ||
| 260 | + isTableSection = false; | ||
| 261 | + } | ||
| 262 | + | ||
| 263 | + if (isHeading(para, range)) { | ||
| 264 | + if (currentSection.length() > 0) { | ||
| 265 | + result.add(currentSection.toString().trim()); | ||
| 266 | + } | ||
| 267 | + currentSection = new StringBuilder(text).append("\n"); | ||
| 268 | + } else { | ||
| 269 | + currentSection.append(text).append("\n"); | ||
| 270 | + } | ||
| 271 | + } | ||
| 272 | + } | ||
| 273 | + | ||
| 274 | + if (currentSection.length() > 0) { | ||
| 275 | + result.add(currentSection.toString().trim()); | ||
| 276 | + } | ||
| 277 | + } | ||
| 278 | + } | ||
| 279 | + | ||
| 280 | + if (fullContent.length() < 1000) { | ||
| 281 | + return Collections.singletonList(fileName + "\n" + fullContent.toString().trim()); | ||
| 282 | + } | ||
| 283 | + | ||
| 284 | + return result; | ||
| 285 | + } | ||
| 286 | + | ||
| 287 | + private String extractTableContent(XWPFTable table) { | ||
| 288 | + StringBuilder tableContent = new StringBuilder(); | ||
| 289 | + table.getRows().forEach(row -> { | ||
| 290 | + row.getTableCells().forEach(cell -> { | ||
| 291 | + tableContent.append("| ").append(cleanText(cell.getText())).append(" "); | ||
| 292 | + }); | ||
| 293 | + tableContent.append("|\n"); | ||
| 294 | + }); | ||
| 295 | + return tableContent.toString(); | ||
| 296 | + } | ||
| 297 | + | ||
| 298 | + private static boolean isHeading(Paragraph para, Range range) { | ||
| 299 | + int styleIndex = para.getStyleIndex(); | ||
| 300 | + if (styleIndex >= 1 && styleIndex <= 9) { | ||
| 301 | + return true; | ||
| 302 | + } | ||
| 303 | + | ||
| 304 | + try { | ||
| 305 | + CharacterRun run = para.getCharacterRun(0); | ||
| 306 | + if (run.isBold() || run.getFontSize() > 12) { | ||
| 307 | + return true; | ||
| 308 | + } | ||
| 309 | + } catch (Exception e) { | ||
| 310 | + log.warn("获取字符格式失败", e); | ||
| 311 | + } | ||
| 312 | + | ||
| 313 | + String text = para.text().trim(); | ||
| 314 | + return text.toUpperCase().equals(text) && | ||
| 315 | + text.length() < 100 && | ||
| 316 | + !text.contains(".") && | ||
| 317 | + !text.contains("\t"); | ||
| 318 | + } | ||
| 319 | + | ||
| 320 | + private void saveSegmentsToDatabase(List<String> segments, String originalFileName, String storedFileName) { | ||
| 321 | + if (segments.isEmpty()) { | ||
| 322 | + return; | ||
| 323 | + } | ||
| 324 | + | ||
| 325 | + String fileNameWithoutExt = originalFileName.substring(0, originalFileName.lastIndexOf('.')); | ||
| 326 | + String question = segments.size() == 1 ? fileNameWithoutExt : null; | ||
| 327 | + | ||
| 328 | + for (String segment : segments) { | ||
| 329 | + if (segment.trim().isEmpty()) { | ||
| 330 | + continue; | ||
| 331 | + } | ||
| 332 | + | ||
| 333 | + QuestionEmbedding record = new QuestionEmbedding(); | ||
| 334 | + record.setId(UUID.randomUUID().toString()); | ||
| 335 | + | ||
| 336 | + if (question != null) { | ||
| 337 | + record.setQuestion(question); | ||
| 338 | + } else { | ||
| 339 | + String firstLine = segment.lines().findFirst().orElse("未命名问题"); | ||
| 340 | + record.setQuestion(cleanText(firstLine)); | ||
| 341 | + } | ||
| 342 | + | ||
| 343 | + record.setAnswer(segment.trim()); | ||
| 344 | + record.setText(""); | ||
| 345 | + record.setMetadata("{\"fileName\":\"" + storedFileName + "\"}"); | ||
| 346 | + | ||
| 347 | + Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding("1925730210204721154", record.getQuestion()); | ||
| 348 | + record.setEmbedding(embeddingResponse.content().vector()); | ||
| 349 | + | ||
| 350 | + questionEmbeddingMapper.insert(record); | ||
| 351 | + } | ||
| 352 | + } | ||
| 353 | +} |
| 1 | +package org.jeecg.modules.airag.zdyrag.controller; | ||
| 2 | + | ||
| 3 | +import dev.langchain4j.data.message.ChatMessage; | ||
| 4 | +import dev.langchain4j.data.message.UserMessage; | ||
| 5 | +import dev.langchain4j.service.TokenStream; | ||
| 6 | +import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore; | ||
| 7 | +import io.swagger.v3.oas.annotations.Operation; | ||
| 8 | +import org.jeecg.ai.handler.AIParams; | ||
| 9 | +import org.jeecg.ai.handler.LLMHandler; | ||
| 10 | +import org.jeecg.common.api.vo.Result; | ||
| 11 | +import org.jeecg.modules.airag.app.entity.QuestionEmbedding; | ||
| 12 | +import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService; | ||
| 13 | +import org.jeecg.modules.airag.common.handler.IAIChatHandler; | ||
| 14 | +import org.jeecg.modules.airag.llm.handler.EmbeddingHandler; | ||
| 15 | +import org.springframework.beans.factory.annotation.Autowired; | ||
| 16 | +import org.springframework.stereotype.Component; | ||
| 17 | +import org.springframework.web.bind.annotation.GetMapping; | ||
| 18 | +import org.springframework.web.bind.annotation.RequestMapping; | ||
| 19 | +import org.springframework.web.bind.annotation.RestController; | ||
| 20 | + | ||
| 21 | +import java.util.ArrayList; | ||
| 22 | +import java.util.HashMap; | ||
| 23 | +import java.util.List; | ||
| 24 | +import java.util.Map; | ||
| 25 | + | ||
| 26 | +@RestController | ||
| 27 | +@RequestMapping("/airag/zdyRag") | ||
| 28 | +public class ZdyRagController { | ||
| 29 | + @Autowired | ||
| 30 | + private EmbeddingHandler embeddingHandler; | ||
| 31 | + @Autowired | ||
| 32 | + IAIChatHandler aiChatHandler; | ||
| 33 | + @Autowired | ||
| 34 | + LLMHandler llmHandler; | ||
| 35 | + @Autowired | ||
| 36 | + private IQuestionEmbeddingService questionEmbeddingService; | ||
| 37 | + | ||
| 38 | + | ||
| 39 | + @Operation(summary = "send") | ||
| 40 | + @GetMapping("send") | ||
| 41 | + public Result<Map<String, Object>> send(String questionText) { | ||
| 42 | + String knowId = "1926872137990148098"; | ||
| 43 | +// String text = "身份证丢失办理流程"; | ||
| 44 | + Integer topNumber = 3; | ||
| 45 | + Double similarity = 0.8; | ||
| 46 | + HashMap<String, Object> resMap = new HashMap<>(); | ||
| 47 | + //根据问题相似度进行查询 | ||
| 48 | + List<QuestionEmbedding> questionEmbeddings = questionEmbeddingService.similaritySearchByQuestion(questionText, 1,0.8); | ||
| 49 | + for (QuestionEmbedding questionEmbedding : questionEmbeddings) { | ||
| 50 | + resMap.put("question", questionEmbedding.getQuestion()); | ||
| 51 | + resMap.put("answer", questionEmbedding.getAnswer()); | ||
| 52 | + resMap.put("similarity", similarity); | ||
| 53 | + System.out.println("questionEmbedding.getQuestion() = " + questionEmbedding.getQuestion()); | ||
| 54 | + System.out.println("questionEmbedding.getAnswer() = " + questionEmbedding.getAnswer()); | ||
| 55 | + System.out.println("questionEmbedding.getSimilarity() = " + questionEmbedding.getSimilarity()); | ||
| 56 | + System.out.println("-------------------------------------------------------------"); | ||
| 57 | + } | ||
| 58 | + //返回问题库命中的问题 | ||
| 59 | + if (!questionEmbeddings.isEmpty()) { | ||
| 60 | + return Result.OK(resMap); | ||
| 61 | + } | ||
| 62 | + | ||
| 63 | + List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, topNumber, similarity); | ||
| 64 | + StringBuilder content = new StringBuilder(); | ||
| 65 | + for (Map<String, Object> map : maps) { | ||
| 66 | + if (Double.parseDouble(map.get("score").toString()) > similarity){ | ||
| 67 | + System.out.println("score = " + map.get("score").toString()); | ||
| 68 | + System.out.println("content = " + map.get("content").toString()); | ||
| 69 | + content.append(map.get("content").toString()).append("\n"); | ||
| 70 | + } | ||
| 71 | + } | ||
| 72 | + | ||
| 73 | + List<ChatMessage> messages = new ArrayList<>(); | ||
| 74 | + String questin = "请整理出与用户所提出的问题相关的信息,舍弃掉与问题无关的内容,进行整理,回答用户的问题" + | ||
| 75 | + "问题如下:" + questionText + | ||
| 76 | + "文本信息如下:" + content | ||
| 77 | + ; | ||
| 78 | + | ||
| 79 | + | ||
| 80 | + messages.add(new UserMessage("user", questin)); | ||
| 81 | +// AIParams aiParams = new AIParams(); | ||
| 82 | +// aiParams.setBaseUrl("http://localhost:11434"); | ||
| 83 | +// aiParams.setModelName("EntropyYue/chatglm3"); | ||
| 84 | +// aiParams.setProvider("OLLAMA"); | ||
| 85 | + String chat = aiChatHandler.completions("1926875898187878401", messages, null); | ||
| 86 | + resMap.put("question", questionText); | ||
| 87 | + resMap.put("answer", chat); | ||
| 88 | + return Result.OK(resMap); | ||
| 89 | + } | ||
| 90 | + | ||
| 91 | + public static void main(String[] args) { | ||
| 92 | + List<ChatMessage> messages = new ArrayList<>(); | ||
| 93 | + messages.add(new UserMessage("user", "你好,你是谁?")); | ||
| 94 | + LLMHandler llmHandler1 = new LLMHandler(); | ||
| 95 | + AIParams aiParams = new AIParams(); | ||
| 96 | + aiParams.setBaseUrl("http://localhost:11434"); | ||
| 97 | + aiParams.setModelName("EntropyYue/chatglm3"); | ||
| 98 | + aiParams.setProvider("OLLAMA"); | ||
| 99 | + TokenStream chat = llmHandler1.chat(messages, aiParams); | ||
| 100 | + System.out.println("chat = " + chat); | ||
| 101 | + | ||
| 102 | + } | ||
| 103 | + | ||
| 104 | +} |
| @@ -188,6 +188,8 @@ mybatis-plus: | @@ -188,6 +188,8 @@ mybatis-plus: | ||
| 188 | minidao: | 188 | minidao: |
| 189 | base-package: org.jeecg.modules.jmreport.*,org.jeecg.modules.drag.* | 189 | base-package: org.jeecg.modules.jmreport.*,org.jeecg.modules.drag.* |
| 190 | jeecg: | 190 | jeecg: |
| 191 | + upload: | ||
| 192 | + path: D:\\upload\\ | ||
| 191 | # AI集成 | 193 | # AI集成 |
| 192 | ai-chat: | 194 | ai-chat: |
| 193 | enabled: true | 195 | enabled: true |
-
请 注册 或 登录 后发表评论