作者 lixiang

问题库以及自定义问答

//package org.jeecg.modules.airag.app.config;
//
//import org.springframework.beans.factory.annotation.Qualifier;
//import org.springframework.boot.context.properties.ConfigurationProperties;
//import org.springframework.boot.jdbc.DataSourceBuilder;
//import org.springframework.context.annotation.Bean;
//import org.springframework.context.annotation.Configuration;
//import org.springframework.jdbc.core.JdbcTemplate;
//
//import javax.sql.DataSource;
//
//@Configuration
//public class PgVectorDataSourceConfig {
//
// @Bean(name = "pgVectorDataSource")
// @ConfigurationProperties(prefix = "spring.datasource.dynamic.datasource.pgvector")
// public DataSource pgVectorDataSource() {
// return DataSourceBuilder.create().build();
// }
//
// @Bean(name = "pgVectorJdbcTemplate")
// public JdbcTemplate pgVectorJdbcTemplate(
// @Qualifier("pgVectorDataSource") DataSource dataSource) {
// return new JdbcTemplate(dataSource);
// }
//}
\ No newline at end of file
... ...
//package org.jeecg.modules.airag.app.config;
//import dev.langchain4j.data.document.Metadata;
//import dev.langchain4j.data.embedding.Embedding;
//import dev.langchain4j.data.segment.TextSegment;
//import dev.langchain4j.store.embedding.EmbeddingMatch;
//import dev.langchain4j.store.embedding.EmbeddingSearchRequest;
//import dev.langchain4j.store.embedding.EmbeddingSearchResult;
//import dev.langchain4j.store.embedding.EmbeddingStore;
//import dev.langchain4j.store.embedding.filter.Filter;
//import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
//import lombok.extern.log4j.Log4j2;
//import org.springframework.beans.factory.annotation.Autowired;
//import org.springframework.beans.factory.annotation.Qualifier;
//import org.springframework.jdbc.core.JdbcTemplate;
//import org.springframework.stereotype.Component;
//
//import javax.sql.DataSource;
//import java.util.ArrayList;
//import java.util.Collection;
//import java.util.List;
//import java.util.Map;
//import java.util.Collection;
//import java.util.List;
//import java.util.stream.Collectors;
//
//@Component
//@Log4j2
//public class PostgreEmbeddingStore implements EmbeddingStore<TextSegment> {
//
// @Autowired
// private PgVectorEmbeddingStore pgVectorEmbeddingStore;
//
// @Autowired
// private JdbcTemplate pgJdbcTemplate;
//
// @Autowired
// public PostgreEmbeddingStore(
// PgVectorEmbeddingStore pgVectorEmbeddingStore) {
// this.pgJdbcTemplate = pgJdbcTemplate;
// this.pgVectorEmbeddingStore = pgVectorEmbeddingStore;
// }
//
//
// @Override
// public String add(Embedding embedding) {
// return "";
// }
//
// @Override
// public void add(String id, Embedding embedding) {
//
// }
//
// @Override
// public String add(Embedding embedding, TextSegment textSegment) {
// return "";
// }
//
// @Override
// public List<String> addAll(List<Embedding> embeddings) {
// return List.of();
// }
//
// @Override
// public List<String> addAll(List<Embedding> embeddings, List<TextSegment> embedded) {
// return List.of();
// }
//
// @Override
// public void remove(String id) {
// EmbeddingStore.super.remove(id);
// }
//
// @Override
// public void removeAll(Collection<String> ids) {
// EmbeddingStore.super.removeAll(ids);
// }
//
// @Override
// public void removeAll(Filter filter) {
// EmbeddingStore.super.removeAll(filter);
// }
//
// @Override
// public void removeAll() {
// EmbeddingStore.super.removeAll();
// }
//
// @Override
// public EmbeddingSearchResult<TextSegment> search(EmbeddingSearchRequest request) {
// return EmbeddingStore.super.search(request);
// }
//
// @Override
// public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults) {
// return findRelevant(referenceEmbedding, maxResults, 0.0);
// }
//
// @Override
// public List<EmbeddingMatch<TextSegment>> findRelevant(Embedding referenceEmbedding, int maxResults, double minScore) {
// try {
//// // 使用 PgVectorEmbeddingStore 进行查询
//// EmbeddingSearchRequest request = EmbeddingSearchRequest.builder()
//// .queryEmbedding(referenceEmbedding)
//// .maxResults(maxResults)
//// .minScore(minScore)
//// .build();
//
// // 构建带内存ID过滤的查询
// String sql = "SELECT id, content, metadata, embedding <=> ? AS distance " +
// "FROM embeddings " +
// "WHERE (1 - (embedding <=> ?)) >= ? " +
// "ORDER BY distance " +
// "LIMIT ?";
//
// List<Map<String, Object>> rows = pgJdbcTemplate.queryForList(
// sql,
// referenceEmbedding.vectorAsList(),
// referenceEmbedding.vectorAsList(),
// minScore,
// maxResults
// );
//
//
//
//// EmbeddingSearchResult<TextSegment> result = pgVectorEmbeddingStore.search(request);
////
////
// return convertToMatches(rows);
// } catch (Exception e) {
// log.error("向量查询失败", e);
// throw new RuntimeException("向量搜索失败: " + e.getMessage(), e);
// }
// }
//
// @Override
// public List<EmbeddingMatch<TextSegment>> findRelevant(Object memoryId, Embedding referenceEmbedding, int maxResults) {
// return findRelevant(memoryId, referenceEmbedding, maxResults, 0.0);
// }
//
// @Override
// public List<EmbeddingMatch<TextSegment>> findRelevant(Object memoryId, Embedding referenceEmbedding,
// int maxResults, double minScore) {
// try {
// // 构建带内存ID过滤的查询
// String sql = "SELECT id, content, metadata, embedding <=> ? AS distance " +
// "FROM embeddings " +
// "WHERE metadata->>'memory_id' = ? " +
// "AND (1 - (embedding <=> ?)) >= ? " +
// "ORDER BY distance " +
// "LIMIT ?";
//
// List<Map<String, Object>> rows = pgJdbcTemplate.queryForList(
// sql,
// referenceEmbedding.vectorAsList(),
// memoryId.toString(),
// referenceEmbedding.vectorAsList(),
// minScore,
// maxResults
// );
//
// return convertToMatches(rows);
// } catch (Exception e) {
// log.error("带内存ID的向量查询失败", e);
// throw new RuntimeException("带内存ID的向量搜索失败: " + e.getMessage(), e);
// }
// }
//
// private List<EmbeddingMatch<TextSegment>> convertToMatches(List<Map<String, Object>> rows) {
// List<EmbeddingMatch<TextSegment>> matches = new ArrayList<>();
// for (Map<String, Object> row : rows) {
// String id = row.get("id").toString();
// String content = (String) row.get("content");
//
// // 处理 Metadata
// Map<String, String> metadataMap = (Map<String, String>) row.get("metadata");
// Metadata metadata = Metadata.from(metadataMap);
//
// // 处理 Embedding 转换
// List<Float> embeddingList = (List<Float>) row.get("embedding");
// float[] embeddingArray = new float[embeddingList.size()];
// for (int i = 0; i < embeddingList.size(); i++) {
// embeddingArray[i] = embeddingList.get(i);
// }
// Embedding embedding = new Embedding(embeddingArray);
//
// double score = 1 - (double) row.get("distance");
// TextSegment textSegment = TextSegment.from(content, metadata);
//
// matches.add(new EmbeddingMatch<>(score, id, embedding, textSegment));
// }
// return matches;
// }
//}
... ...
//package org.jeecg.modules.airag.app.config;
//
//import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
//import org.springframework.beans.factory.annotation.Value;
//import org.springframework.context.annotation.Bean;
//import org.springframework.context.annotation.Configuration;
//
//@Configuration
//public class VectorStoreConfig {
// @Value("${jeecg.ai-rag.embed-store.host}")
// private String host;
// @Value("${jeecg.ai-rag.embed-store.port}")
// private Integer port;
// @Value("${jeecg.ai-rag.embed-store.database}")
// private String database;
// @Value("${jeecg.ai-rag.embed-store.user}")
// private String user;
// @Value("${jeecg.ai-rag.embed-store.password}")
// private String password;
//// @Value("${spring.datasource.vector.url}")
//// private String url;
////
//// @Value("${spring.datasource.vector.username}")
//// private String username;
////
//// @Value("${spring.datasource.vector.password}")
//// private String password;
//
// @Bean
// public PgVectorEmbeddingStore pgVectorEmbeddingStore() {
// return PgVectorEmbeddingStore.builder()
// .host(host) // 如果使用独立参数
// .port(port)
// .user(user)
// .password(password)
// .database(database)
// .table("embeddings")
// .dimension(1536) // 根据你的模型维度设置
// .useIndex(true)
// .createTable(true)
// .build();
//
//// return PgVectorEmbeddingStore.builder()
//// .dataSource(dataSource) // 注入配置好的DataSource
//// .table("embeddings")
//// .dimension(1536)
//// .build();
// }
//}
\ No newline at end of file
... ...
... ... @@ -42,6 +42,7 @@ import org.jeecg.modules.airag.app.service.IEmbeddingsService;
import org.jeecg.common.system.base.controller.JeecgController;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
import org.jeecg.modules.airag.llm.entity.AiragModel;
import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.jdbc.DataSourceBuilder;
... ... @@ -69,6 +70,9 @@ public class EmbeddingsController {
@Autowired
private AiModelUtils aiModelUtils;
@Autowired
private EmbeddingHandler embeddingHandler;
/**
* 分页列表查询
... ... @@ -86,6 +90,13 @@ public class EmbeddingsController {
@RequestParam(name="pageNo", defaultValue="1") Integer pageNo,
@RequestParam(name="pageSize", defaultValue="10") Integer pageSize,
HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException {
// AiragModel airagModel = new AiragModel();
// airagModel.setId("1925730210204721154");
// airagModel.setProvider("OLLAMA");
// airagModel.setModelName("nomic-embed-text");
// airagModel.setBaseUrl("http://localhost:11434");
// EmbeddingStore<TextSegment> embedStore = embeddingHandler.getEmbedStore(airagModel);
// embeddingHandler.searchEmbedding()
Response<Embedding> embedding = aiModelUtils.getEmbedding("1925730210204721154", "33333");
List<Embeddings> records = embeddingsService.findAll();
... ...
package org.jeecg.modules.airag.app.controller;
import org.jeecg.common.api.vo.Result;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
@RestController
@RequestMapping("/question/embedding")
public class QuestionEmbeddingController {
@Autowired
private IQuestionEmbeddingService questionEmbeddingService;
@GetMapping("/list")
public Result<List<QuestionEmbedding>> findAll() {
List<QuestionEmbedding> list = questionEmbeddingService.findAll();
return Result.OK(list);
}
@GetMapping("/queryById")
public Result<QuestionEmbedding> findById(@RequestParam String id) {
QuestionEmbedding record = questionEmbeddingService.findById(id);
if (record == null) {
return Result.error("未找到对应数据");
}
return Result.OK(record);
}
@PostMapping("/add")
public Result<String> insert(@RequestBody QuestionEmbedding record) {
int result = questionEmbeddingService.insert(record);
return result > 0 ? Result.OK("添加成功!") : Result.error("添加失败");
}
@RequestMapping(value = "/edit", method = {RequestMethod.PUT,RequestMethod.POST})
public Result<String> update(@RequestBody QuestionEmbedding record) {
int result = questionEmbeddingService.update(record);
return result > 0 ? Result.OK("编辑成功!") : Result.error("编辑失败");
}
@DeleteMapping("/delete")
public Result<String> deleteById(@RequestParam String id) {
int result = questionEmbeddingService.deleteById(id);
return result > 0 ? Result.OK("删除成功!") : Result.error("删除失败");
}
@PostMapping("/uploadZip")
@Transactional(rollbackFor = {Exception.class})
public Result<?> uploadZip(@RequestParam("file") MultipartFile file) {
return questionEmbeddingService.processZipUpload(file);
}
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.app.entity;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.util.Map;
/**
 * A question/answer entry of the question library, as exchanged between the
 * controller, the service and {@code QuestionEmbeddingMapper}.
 *
 * <p>The field order is significant: {@code @AllArgsConstructor} generates a
 * constructor whose parameters follow the declaration order below.
 */
@NoArgsConstructor
@AllArgsConstructor
@Data
public class QuestionEmbedding {

    /** Record identifier. */
    private String id;

    /** Free text attached to the entry. */
    private String text;

    /** The question; embeddings are computed from this text. */
    private String question;

    /** The answer returned when the question matches. */
    private String answer;

    /** Metadata kept as a JSON string. */
    private String metadata;

    /** Embedding vector of the question. */
    private float[] embedding;

    /** Similarity score filled in by similarity searches. */
    private Double similarity;
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.app.mapper;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
import org.postgresql.util.PGobject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.sql.*;
import java.util.*;
/**
 * Plain-JDBC data access for the {@code question_embedding} table (PostgreSQL + pgvector).
 *
 * <p>Rows are mapped to {@link QuestionEmbedding}. The {@code embedding} column is written
 * and queried but deliberately not read back into the entity, so list responses stay small.
 *
 * <p>Similarity is defined as {@code 1 - cosine distance} (pgvector operator {@code <=>}).
 */
@Component
@Slf4j
public class QuestionEmbeddingMapper {

    /** Model id passed to {@link AiModelUtils#getEmbedding} for every embedding computed here. */
    private static final String EMBEDDING_MODEL_ID = "1925730210204721154";

    /** JSON stored when a record carries no metadata of its own. */
    private static final String EMPTY_METADATA = "{}";

    /** Shared JSON mapper, reused instead of allocating one per call. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    @Autowired
    private AiModelUtils aiModelUtils;

    // NOTE(review): connection settings and credentials are hard-coded here and a new
    // physical connection is opened per call; they should come from the application's
    // datasource configuration (ideally a pooled DataSource).
    private static final String URL = "jdbc:postgresql://192.168.100.103:5432/postgres";
    private static final String USER = "postgres";
    private static final String PASSWORD = "postgres";

    /** Opens a new connection using the constants above. */
    private Connection getConnection() throws SQLException {
        return DriverManager.getConnection(URL, USER, PASSWORD);
    }

    /**
     * Returns every row of {@code question_embedding}.
     *
     * @return all records, possibly empty
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    public List<QuestionEmbedding> findAll() {
        List<QuestionEmbedding> results = new ArrayList<>();
        String sql = "SELECT * FROM question_embedding";
        try (Connection conn = getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql);
             ResultSet rs = stmt.executeQuery()) {
            while (rs.next()) {
                results.add(mapRowToQuestionEmbedding(rs));
            }
        } catch (SQLException e) {
            log.error("查询所有记录失败", e);
            throw new RuntimeException("查询数据时发生数据库错误", e);
        }
        return results;
    }

    /**
     * Looks up one record by id.
     *
     * @param id record identifier
     * @return the record, or {@code null} when no row matches
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    public QuestionEmbedding findById(String id) {
        String sql = "SELECT * FROM question_embedding WHERE id = ?";
        try (Connection conn = getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, id);
            try (ResultSet rs = stmt.executeQuery()) {
                if (rs.next()) {
                    return mapRowToQuestionEmbedding(rs);
                }
            }
        } catch (SQLException e) {
            log.error("根据ID查询记录失败, ID: {}", id, e);
            throw new RuntimeException("根据ID查询时发生数据库错误", e);
        }
        return null;
    }

    /**
     * Inserts a record under a freshly generated UUID.
     *
     * <p>The embedding is always computed here from {@code record.getQuestion()}; any id or
     * embedding already present on {@code record} is ignored. The record's metadata JSON is
     * stored as-is ({@code {}} when blank).
     *
     * @param record values to store
     * @return number of inserted rows
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    public int insert(QuestionEmbedding record) {
        // Compute the embedding first so no connection is held during the model call.
        float[] vector = embed(record.getQuestion());
        String sql = "INSERT INTO question_embedding (id, text, question, answer, metadata,embedding) VALUES (?, ?, ?, ?, ?::jsonb,?)";
        try (Connection conn = getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, UUID.randomUUID().toString());
            stmt.setString(2, record.getText());
            stmt.setString(3, record.getQuestion());
            stmt.setString(4, record.getAnswer());
            stmt.setObject(5, metadataParameter(record));
            stmt.setObject(6, new PGvector(vector));
            return stmt.executeUpdate();
        } catch (SQLException e) {
            log.error("插入记录失败: {}", record, e);
            throw new RuntimeException("插入数据时发生数据库错误", e);
        }
    }

    /**
     * Updates the row identified by {@code record.getId()}.
     *
     * <p>All value columns are overwritten, the embedding is recomputed from the question,
     * and blank metadata is stored as {@code {}}.
     *
     * @param record new values plus the id of the row to change
     * @return number of updated rows (0 when the id does not exist)
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    public int update(QuestionEmbedding record) {
        float[] vector = embed(record.getQuestion());
        String sql = "UPDATE question_embedding SET text = ?, question = ?, answer = ?, metadata = ?::jsonb ,embedding = ? WHERE id = ?";
        try (Connection conn = getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, record.getText());
            stmt.setString(2, record.getQuestion());
            stmt.setString(3, record.getAnswer());
            stmt.setObject(4, metadataParameter(record));
            stmt.setObject(5, new PGvector(vector));
            stmt.setString(6, record.getId());
            return stmt.executeUpdate();
        } catch (SQLException e) {
            log.error("更新记录失败: {}", record, e);
            throw new RuntimeException("更新数据时发生数据库错误", e);
        }
    }

    /**
     * Finds the stored questions most similar to the given question text.
     *
     * @param question      question text; it is embedded with the configured model
     * @param limit         maximum number of rows to return
     * @param minSimilarity minimum similarity in [0, 1]; {@code null} is treated as 0
     * @return matches ordered from most to least similar, each with its similarity set
     * @throws IllegalArgumentException if {@code minSimilarity} is outside [0, 1]
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {
        // 1. Validate the threshold (null-safe: unboxing a null Double would throw).
        double threshold = minSimilarity == null ? 0.0 : minSimilarity;
        if (threshold < 0 || threshold > 1) {
            throw new IllegalArgumentException("相似度阈值必须在0到1之间");
        }
        // 2. Embed the question.
        float[] queryVector = embed(question);
        // 3. similarity = 1 - cosine distance, so the largest accepted distance is:
        double maxDistance = 1 - threshold;
        // 4. "<=>" is pgvector's cosine distance; "<->" (Euclidean) would not map onto a
        //    0..1 similarity scale.
        String sql = "SELECT *, embedding <=> ? AS distance " +
                "FROM question_embedding " +
                "WHERE (embedding <=> ?) <= ? " +
                "ORDER BY distance ASC " +
                "LIMIT ?";
        try (Connection conn = getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            PGvector vector = new PGvector(queryVector);
            stmt.setObject(1, vector);
            stmt.setObject(2, vector);
            stmt.setDouble(3, maxDistance);
            stmt.setInt(4, limit);
            return readMatches(stmt);
        } catch (SQLException e) {
            log.error("向量相似度查询失败", e);
            throw new RuntimeException("执行向量相似度查询时发生数据库错误", e);
        }
    }

    /**
     * Finds the stored questions closest to an already computed vector.
     *
     * @param vector query vector; must not be {@code null}
     * @param limit  maximum number of rows to return
     * @return matches ordered from most to least similar, each with its similarity set
     * @throws NullPointerException if {@code vector} is {@code null}
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    public List<QuestionEmbedding> similaritySearch(float[] vector, int limit) {
        Objects.requireNonNull(vector, "vector");
        String sql = "SELECT *, embedding <=> ? AS distance " +
                "FROM question_embedding " +
                "ORDER BY distance ASC " +
                "LIMIT ?";
        try (Connection conn = getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setObject(1, new PGvector(vector));
            stmt.setInt(2, limit);
            return readMatches(stmt);
        } catch (SQLException e) {
            log.error("向量相似度查询失败", e);
            throw new RuntimeException("执行向量相似度查询时发生数据库错误", e);
        }
    }

    /**
     * Deletes one record by id.
     *
     * @param id record identifier
     * @return number of deleted rows
     * @throws RuntimeException wrapping the {@link SQLException} on database failure
     */
    public int deleteById(String id) {
        String sql = "DELETE FROM question_embedding WHERE id = ?";
        try (Connection conn = getConnection();
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, id);
            return stmt.executeUpdate();
        } catch (SQLException e) {
            log.error("删除记录失败, ID: {}", id, e);
            throw new RuntimeException("删除数据时发生数据库错误", e);
        }
    }

    /** Executes a similarity query and maps each row, deriving similarity from its {@code distance} column. */
    private List<QuestionEmbedding> readMatches(PreparedStatement stmt) throws SQLException {
        List<QuestionEmbedding> matches = new ArrayList<>();
        try (ResultSet rs = stmt.executeQuery()) {
            while (rs.next()) {
                QuestionEmbedding record = mapRowToQuestionEmbedding(rs);
                record.setSimilarity(1 - rs.getDouble("distance"));
                matches.add(record);
            }
        }
        return matches;
    }

    /** Asks the embedding model for the vector of {@code text}, failing loudly when none comes back. */
    private float[] embed(String text) {
        Response<Embedding> response = aiModelUtils.getEmbedding(EMBEDDING_MODEL_ID, text);
        if (response == null || response.content() == null) {
            throw new IllegalStateException("嵌入模型未返回向量, modelId: " + EMBEDDING_MODEL_ID);
        }
        return response.content().vector();
    }

    /** Wraps the record's metadata JSON (or {@code {}} when blank) for the {@code ?::jsonb} parameter. */
    private PGobject metadataParameter(QuestionEmbedding record) throws SQLException {
        PGobject jsonObject = new PGobject();
        jsonObject.setType("json");
        jsonObject.setValue(StringUtils.isNotBlank(record.getMetadata()) ? record.getMetadata() : EMPTY_METADATA);
        return jsonObject;
    }

    /** Maps the current row to an entity; metadata stays {@code null} when the column is blank. */
    private QuestionEmbedding mapRowToQuestionEmbedding(ResultSet rs) throws SQLException {
        QuestionEmbedding record = new QuestionEmbedding();
        record.setId(rs.getString("id"));
        record.setText(rs.getString("text"));
        record.setQuestion(rs.getString("question"));
        record.setAnswer(rs.getString("answer"));
        String metadataJson = rs.getString("metadata");
        if (StringUtils.isNotBlank(metadataJson)) {
            record.setMetadata(metadataJson);
        }
        return record;
    }

    /** Serializes a map to JSON, falling back to {@code {}} when serialization fails. */
    private String toJson(Map<String, Object> map) {
        try {
            return OBJECT_MAPPER.writeValueAsString(map);
        } catch (JsonProcessingException e) {
            log.error("元数据转换为JSON失败", e);
            return "{}";
        }
    }

    /** Parses a JSON object into a map, falling back to an empty map when parsing fails. */
    private Map<String, Object> fromJson(String json) {
        try {
            return OBJECT_MAPPER.readValue(json, new TypeReference<Map<String, Object>>() {});
        } catch (JsonProcessingException e) {
            log.error("JSON转换为元数据失败", e);
            return Collections.emptyMap();
        }
    }
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.app.service;
import org.jeecg.common.api.vo.Result;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.springframework.web.multipart.MultipartFile;
import java.util.List;
/**
 * Service contract for the question library: CRUD on {@link QuestionEmbedding}
 * records, similarity lookup, and bulk import from a ZIP archive.
 */
public interface IQuestionEmbeddingService {

    /**
     * @return every stored record
     */
    List<QuestionEmbedding> findAll();

    /**
     * @param id record identifier
     * @return the matching record; the controller treats {@code null} as "not found"
     */
    QuestionEmbedding findById(String id);

    /**
     * @param record record to create
     * @return number of affected rows
     */
    int insert(QuestionEmbedding record);

    /**
     * @param record record carrying the id to update and the new values
     * @return number of affected rows
     */
    int update(QuestionEmbedding record);

    /**
     * @param id record identifier
     * @return number of affected rows
     */
    int deleteById(String id);

    /**
     * Searches by question text.
     *
     * @param question      question text to compare against stored questions
     * @param limit         maximum number of results
     * @param minSimilarity similarity threshold, expected in the range 0 to 1
     * @return matching records
     */
    List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity);

    /**
     * Searches by an already computed vector.
     *
     * @param vector query vector
     * @param limit  maximum number of results
     * @return matching records
     */
    List<QuestionEmbedding> similaritySearch(float[] vector, int limit);

    /**
     * Imports the documents contained in an uploaded ZIP archive.
     *
     * @param file uploaded archive
     * @return result describing success or failure of the import
     */
    Result<?> processZipUpload(MultipartFile file);
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.app.service.impl;
import org.apache.poi.hwpf.usermodel.CharacterRun;
import org.apache.poi.hwpf.HWPFDocument;
import org.apache.poi.hwpf.usermodel.Paragraph;
import org.apache.poi.hwpf.usermodel.Range;
import dev.langchain4j.data.document.Document;
import dev.langchain4j.data.document.DocumentSplitter;
import dev.langchain4j.data.document.loader.FileSystemDocumentLoader;
import dev.langchain4j.data.document.parser.TextDocumentParser;
import dev.langchain4j.data.document.splitter.DocumentByParagraphSplitter;
import dev.langchain4j.data.document.splitter.DocumentSplitters;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.data.segment.TextSegment;
import dev.langchain4j.model.embedding.EmbeddingModel;
import dev.langchain4j.model.output.Response;
import org.apache.commons.io.FilenameUtils;
import org.apache.poi.xwpf.usermodel.IBodyElement;
import org.apache.poi.xwpf.usermodel.XWPFDocument;
import org.apache.poi.xwpf.usermodel.XWPFParagraph;
import org.apache.poi.xwpf.usermodel.XWPFTable;
import org.jeecg.common.api.vo.Result;
import org.jeecg.common.util.CommonUtils;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.multipart.MultipartFile;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.StandardCopyOption;
import java.util.*;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
@Service
public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
private static final Logger log = LoggerFactory.getLogger(QuestionEmbeddingServiceImpl.class);
@Autowired
private QuestionEmbeddingMapper questionEmbeddingMapper;
@Autowired
private AiModelUtils aiModelUtils;
@Autowired
private IAIChatHandler aiChatHandler;
@Value("${jeecg.upload.path}")
private String uploadPath;
private static final Set<String> ALLOWED_EXTENSIONS = Set.of("txt", "doc", "docx");
private static final Pattern SPECIAL_CHARS_PATTERN = Pattern.compile("[^a-zA-Z0-9\\u4e00-\\u9fa5\\s]");
@Override
public List<QuestionEmbedding> findAll() {
return questionEmbeddingMapper.findAll();
}
@Override
public QuestionEmbedding findById(String id) {
return questionEmbeddingMapper.findById(id);
}
@Override
public int insert(QuestionEmbedding record) {
return questionEmbeddingMapper.insert(record);
}
@Override
public int update(QuestionEmbedding record) {
return questionEmbeddingMapper.update(record);
}
@Override
public int deleteById(String id) {
return questionEmbeddingMapper.deleteById(id);
}
@Override
public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {
return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity);
}
@Override
public List<QuestionEmbedding> similaritySearch(float[] vector, int limit) {
return questionEmbeddingMapper.similaritySearch(vector, limit);
}
public Result<?> processZipUpload(MultipartFile zipFile) {
try {
Path tempDir = Files.createTempDirectory("zip_upload_");
List<Path> validFiles = extractAndFilterZip(zipFile, tempDir);
if (validFiles.isEmpty()) {
return Result.error("ZIP文件中没有有效的TXT或Word文档");
}
for (Path filePath : validFiles) {
processSingleFile(filePath);
}
return Result.OK("文件上传和处理成功");
} catch (Exception e) {
log.error("处理ZIP文件上传失败", e);
return Result.error("处理ZIP文件失败: " + e.getMessage());
}
}
private List<Path> extractAndFilterZip(MultipartFile zipFile, Path tempDir) throws IOException {
List<Path> validFiles = new ArrayList<>();
try (ZipInputStream zipIn = new ZipInputStream(zipFile.getInputStream())) {
ZipEntry entry;
while ((entry = zipIn.getNextEntry()) != null) {
if (!entry.isDirectory()) {
String fileName = entry.getName();
String ext = FilenameUtils.getExtension(fileName).toLowerCase();
if (ALLOWED_EXTENSIONS.contains(ext)) {
String safeFileName = new File(fileName).getName();
Path outputPath = tempDir.resolve(safeFileName);
Files.copy(zipIn, outputPath, StandardCopyOption.REPLACE_EXISTING);
validFiles.add(outputPath);
}
}
zipIn.closeEntry();
}
}
return validFiles;
}
private void processSingleFile(Path filePath) throws Exception {
String originalFileName = filePath.getFileName().toString();
String fileExt = FilenameUtils.getExtension(originalFileName);
String newFileName = FilenameUtils.removeExtension(originalFileName) + "_" + UUID.randomUUID() + "." + fileExt;
Path targetPath = Paths.get(uploadPath, newFileName);
Files.move(filePath, targetPath, StandardCopyOption.REPLACE_EXISTING);
List<String> segments;
if (fileExt.equalsIgnoreCase("txt")) {
String fileContent = readFileContent(targetPath);
String cleanedContent = cleanText(fileContent);
segments = splitTxtDocument(cleanedContent);
} else {
segments = splitWordDocument(targetPath.toString());
}
saveSegmentsToDatabase(segments, originalFileName, newFileName);
}
private String readFileContent(Path filePath) throws IOException {
return new String(Files.readAllBytes(filePath));
}
private String cleanText(String text) {
text = SPECIAL_CHARS_PATTERN.matcher(text).replaceAll("");
return text.replaceAll("\\s+", " ").trim();
}
private List<String> splitTxtDocument(String content) {
DocumentSplitter splitter = new DocumentByParagraphSplitter(1000, 200);
Document document = Document.from(content);
return splitter.split(document).stream()
.map(TextSegment::text)
.map(this::cleanText)
.collect(Collectors.toList());
}
public List<String> splitWordDocument(String filePath) throws Exception {
List<String> result = new ArrayList<>();
String ext = FilenameUtils.getExtension(filePath).toLowerCase();
StringBuilder fullContent = new StringBuilder();
String fileName = new File(filePath).getName();
fileName = fileName.substring(0, fileName.lastIndexOf('.')); // 去掉后缀
if (ext.equals("docx")) {
try (XWPFDocument doc = new XWPFDocument(new FileInputStream(filePath))) {
StringBuilder currentSection = new StringBuilder();
boolean isTableSection = false;
for (IBodyElement element : doc.getBodyElements()) {
if (element instanceof XWPFParagraph) {
XWPFParagraph para = (XWPFParagraph) element;
String text = cleanText(para.getText());
fullContent.append(text).append("\n");
if (isTableSection) {
result.add(currentSection.toString().trim());
currentSection = new StringBuilder();
isTableSection = false;
}
String style = para.getStyle();
if (style != null && style.matches("Heading\\d")) {
if (currentSection.length() > 0) {
result.add(currentSection.toString().trim());
}
currentSection = new StringBuilder(text).append("\n");
} else {
currentSection.append(text).append("\n");
}
} else if (element instanceof XWPFTable) {
String tableContent = extractTableContent((XWPFTable) element);
fullContent.append(tableContent).append("\n");
if (!isTableSection) {
if (currentSection.length() > 0) {
result.add(currentSection.toString().trim());
}
currentSection = new StringBuilder();
isTableSection = true;
}
currentSection.append(tableContent).append("\n");
}
}
if (currentSection.length() > 0) {
result.add(currentSection.toString().trim());
}
}
} else if (ext.equals("doc")) {
try (HWPFDocument doc = new HWPFDocument(new FileInputStream(filePath))) {
Range range = doc.getRange();
StringBuilder currentSection = new StringBuilder();
boolean isTableSection = false;
for (int i = 0; i < range.numParagraphs(); i++) {
Paragraph para = range.getParagraph(i);
String text = cleanText(para.text());
fullContent.append(text).append("\n");
if (para.isInTable()) {
if (!isTableSection) {
if (currentSection.length() > 0) {
result.add(currentSection.toString().trim());
}
currentSection = new StringBuilder();
isTableSection = true;
}
currentSection.append(text).append("\n");
} else {
if (isTableSection) {
result.add(currentSection.toString().trim());
currentSection = new StringBuilder();
isTableSection = false;
}
if (isHeading(para, range)) {
if (currentSection.length() > 0) {
result.add(currentSection.toString().trim());
}
currentSection = new StringBuilder(text).append("\n");
} else {
currentSection.append(text).append("\n");
}
}
}
if (currentSection.length() > 0) {
result.add(currentSection.toString().trim());
}
}
}
if (fullContent.length() < 1000) {
return Collections.singletonList(fileName + "\n" + fullContent.toString().trim());
}
return result;
}
private String extractTableContent(XWPFTable table) {
StringBuilder tableContent = new StringBuilder();
table.getRows().forEach(row -> {
row.getTableCells().forEach(cell -> {
tableContent.append("| ").append(cleanText(cell.getText())).append(" ");
});
tableContent.append("|\n");
});
return tableContent.toString();
}
private static boolean isHeading(Paragraph para, Range range) {
int styleIndex = para.getStyleIndex();
if (styleIndex >= 1 && styleIndex <= 9) {
return true;
}
try {
CharacterRun run = para.getCharacterRun(0);
if (run.isBold() || run.getFontSize() > 12) {
return true;
}
} catch (Exception e) {
log.warn("获取字符格式失败", e);
}
String text = para.text().trim();
return text.toUpperCase().equals(text) &&
text.length() < 100 &&
!text.contains(".") &&
!text.contains("\t");
}
private void saveSegmentsToDatabase(List<String> segments, String originalFileName, String storedFileName) {
if (segments.isEmpty()) {
return;
}
String fileNameWithoutExt = originalFileName.substring(0, originalFileName.lastIndexOf('.'));
String question = segments.size() == 1 ? fileNameWithoutExt : null;
for (String segment : segments) {
if (segment.trim().isEmpty()) {
continue;
}
QuestionEmbedding record = new QuestionEmbedding();
record.setId(UUID.randomUUID().toString());
if (question != null) {
record.setQuestion(question);
} else {
String firstLine = segment.lines().findFirst().orElse("未命名问题");
record.setQuestion(cleanText(firstLine));
}
record.setAnswer(segment.trim());
record.setText("");
record.setMetadata("{\"fileName\":\"" + storedFileName + "\"}");
Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding("1925730210204721154", record.getQuestion());
record.setEmbedding(embeddingResponse.content().vector());
questionEmbeddingMapper.insert(record);
}
}
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.zdyrag.controller;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
import io.swagger.v3.oas.annotations.Operation;
import org.jeecg.ai.handler.AIParams;
import org.jeecg.ai.handler.LLMHandler;
import org.jeecg.common.api.vo.Result;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@RestController
@RequestMapping("/airag/zdyRag")
public class ZdyRagController {
@Autowired
private EmbeddingHandler embeddingHandler;
@Autowired
IAIChatHandler aiChatHandler;
@Autowired
LLMHandler llmHandler;
@Autowired
private IQuestionEmbeddingService questionEmbeddingService;
@Operation(summary = "send")
@GetMapping("send")
public Result<Map<String, Object>> send(String questionText) {
String knowId = "1926872137990148098";
// String text = "身份证丢失办理流程";
Integer topNumber = 3;
Double similarity = 0.8;
HashMap<String, Object> resMap = new HashMap<>();
//根据问题相似度进行查询
List<QuestionEmbedding> questionEmbeddings = questionEmbeddingService.similaritySearchByQuestion(questionText, 1,0.8);
for (QuestionEmbedding questionEmbedding : questionEmbeddings) {
resMap.put("question", questionEmbedding.getQuestion());
resMap.put("answer", questionEmbedding.getAnswer());
resMap.put("similarity", similarity);
System.out.println("questionEmbedding.getQuestion() = " + questionEmbedding.getQuestion());
System.out.println("questionEmbedding.getAnswer() = " + questionEmbedding.getAnswer());
System.out.println("questionEmbedding.getSimilarity() = " + questionEmbedding.getSimilarity());
System.out.println("-------------------------------------------------------------");
}
//返回问题库命中的问题
if (!questionEmbeddings.isEmpty()) {
return Result.OK(resMap);
}
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, topNumber, similarity);
StringBuilder content = new StringBuilder();
for (Map<String, Object> map : maps) {
if (Double.parseDouble(map.get("score").toString()) > similarity){
System.out.println("score = " + map.get("score").toString());
System.out.println("content = " + map.get("content").toString());
content.append(map.get("content").toString()).append("\n");
}
}
List<ChatMessage> messages = new ArrayList<>();
String questin = "请整理出与用户所提出的问题相关的信息,舍弃掉与问题无关的内容,进行整理,回答用户的问题" +
"问题如下:" + questionText +
"文本信息如下:" + content
;
messages.add(new UserMessage("user", questin));
// AIParams aiParams = new AIParams();
// aiParams.setBaseUrl("http://localhost:11434");
// aiParams.setModelName("EntropyYue/chatglm3");
// aiParams.setProvider("OLLAMA");
String chat = aiChatHandler.completions("1926875898187878401", messages, null);
resMap.put("question", questionText);
resMap.put("answer", chat);
return Result.OK(resMap);
}
public static void main(String[] args) {
List<ChatMessage> messages = new ArrayList<>();
messages.add(new UserMessage("user", "你好,你是谁?"));
LLMHandler llmHandler1 = new LLMHandler();
AIParams aiParams = new AIParams();
aiParams.setBaseUrl("http://localhost:11434");
aiParams.setModelName("EntropyYue/chatglm3");
aiParams.setProvider("OLLAMA");
TokenStream chat = llmHandler1.chat(messages, aiParams);
System.out.println("chat = " + chat);
}
}
... ...
... ... @@ -188,6 +188,8 @@ mybatis-plus:
minidao:
base-package: org.jeecg.modules.jmreport.*,org.jeecg.modules.drag.*
jeecg:
upload:
path: D:\\upload\\
# AI集成
ai-chat:
enabled: true
... ...