作者 lixiang

修改为双数据源

... ... @@ -17,14 +17,12 @@ import org.jeecg.modules.airag.app.utils.JsonUtils;
import org.jeecg.modules.airag.llm.entity.AiragKnowledge;
import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.web.bind.annotation.*;
import org.springframework.web.multipart.MultipartFile;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.*;
import java.util.stream.Collectors;
@RestController
... ... @@ -44,13 +42,10 @@ public class QuestionEmbeddingController {
.collect(Collectors.toMap(AiragKnowledge::getId, AiragKnowledge::getName));
page.getRecords().forEach(item -> {
String metadata = item.getMetadata();
if (StringUtils.isNotBlank(metadata)) {
Map<String, String> jsonMap = JsonUtils.jsonUtils(metadata);
Map<String, Object> jsonMap = item.getMetadata();
if (jsonMap.containsKey("knowledgeId")) {
item.setKnowledgeName(airagKnowledgeMap.get(jsonMap.get("knowledgeId")));
item.setKnowledgeId(jsonMap.get("knowledgeId"));
}
item.setKnowledgeId(jsonMap.get("knowledgeId").toString());
}
});
... ... @@ -86,12 +81,9 @@ public class QuestionEmbeddingController {
String docId = String.valueOf(snowflakeGenerator.next());
metadata.put("docId", docId); // 自动生成唯一文档ID
metadata.put("knowledgeId", record.getKnowledgeId());
// 使用 Jackson 序列化 Map 到 JSON
ObjectMapper mapper = new ObjectMapper();
String metadataJson = mapper.writeValueAsString(metadata);
// 2. 设置到embeddings对象
record.setMetadata(metadataJson);
record.setMetadata(metadata);
record.setId(UUID.randomUUID().toString());
int result = questionEmbeddingService.insert(record);
return result > 0 ? Result.OK("添加成功!") : Result.error("添加失败");
}
... ... @@ -112,14 +104,10 @@ public class QuestionEmbeddingController {
String knowledgeName = airagKnowledgeMap.get(record.getKnowledgeId());
record.setKnowledgeName(knowledgeName);
String existMetadata = existRecord.getMetadata();
Map<String, String> jsonMap = new HashMap<>();
if (StringUtils.isNotBlank(existMetadata)) {
jsonMap = JsonUtils.jsonUtils(existMetadata);
}
Map<String, Object> metadata = existRecord.getMetadata();
jsonMap.put("knowledgeId", record.getKnowledgeId());
record.setMetadata(Json.toJson(jsonMap));
metadata.put("knowledgeId", record.getKnowledgeId());
record.setMetadata(metadata);
}
int result = questionEmbeddingService.update(record);
return result > 0 ? Result.OK("编辑成功!") : Result.error("编辑失败");
... ... @@ -144,7 +132,6 @@ public class QuestionEmbeddingController {
}
@PostMapping("/uploadZip")
@Transactional(rollbackFor = {Exception.class})
public Result<?> uploadZip(
@RequestParam("file") MultipartFile file,
@RequestParam("knowledgeId") String knowledgeId) {
... ...
... ... @@ -30,7 +30,7 @@ public class QuestionEmbedding {
/**
* 元数据
*/
private String metadata;
private Map<String, Object> metadata;
/**
* 向量
*/
... ...
package org.jeecg.modules.airag.app.mapper;
import cn.hutool.core.lang.generator.SnowflakeGenerator;
import com.alibaba.fastjson2.JSONObject;
import com.baomidou.dynamic.datasource.annotation.DS;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import io.minio.messages.Metadata;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jeecg.modules.airag.app.entity.Embeddings;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
import org.postgresql.util.PGobject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.sql.*;
import java.util.*;
import java.util.stream.Collectors;
import java.util.List;
@Component
@Slf4j
public class QuestionEmbeddingMapper {
@Mapper
@DS("pgvector")
public interface QuestionEmbeddingMapper {
Page<QuestionEmbedding> findAll(IPage<QuestionEmbedding> page, @Param("questionEmbedding") QuestionEmbedding questionEmbedding);
@Autowired
private AiModelUtils aiModelUtils;
Integer findQuestionCount(@Param("questionEmbedding") QuestionEmbedding questionEmbedding);
@Value("${jeecg.ai-chat.embedId}")
private String embedId;
// PostgreSQL连接参数(应与项目配置一致)
private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres";
private static final String USER = "postgres";
private static final String PASSWORD = "postgres";
QuestionEmbedding findById(@Param("id") String id);
@DS("pgvector")
int insert(@Param("record") QuestionEmbedding record);
// 获取数据库连接
private Connection getConnection() throws SQLException {
return DriverManager.getConnection(URL, USER, PASSWORD);
}
int update(@Param("record") QuestionEmbedding record);
// 查询所有记录
public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, int pageNo, int pageSize) {
List<QuestionEmbedding> results = new ArrayList<>();
StringBuilder sql = new StringBuilder("select * from question_embedding where 1 = 1");
StringBuilder countSql = new StringBuilder("select count(1) from question_embedding where 1 = 1");
List<Object> params = new ArrayList<>();
List<Object> countParams = new ArrayList<>();
int deleteById(@Param("id") String id);
if (StringUtils.isNotBlank(questionEmbedding.getKnowledgeId())) {
sql.append(" AND metadata ->> 'knowledgeId' = ?");
countSql.append(" AND metadata ->> 'knowledgeId' = ?");
params.add(questionEmbedding.getKnowledgeId());
countParams.add(questionEmbedding.getKnowledgeId());
}
if(StringUtils.isNotBlank(questionEmbedding.getQuestion())){
sql.append(" AND question ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
countSql.append(" AND question ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
params.add("%" + questionEmbedding.getQuestion() + "%");
countParams.add("%" + questionEmbedding.getQuestion() + "%");
}
int deleteByIds(@Param("ids") List<String> ids);
if(StringUtils.isNotBlank(questionEmbedding.getAnswer())){
sql.append(" AND answer ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
countSql.append(" AND answer ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
params.add("%" + questionEmbedding.getAnswer() + "%");
countParams.add("%" + questionEmbedding.getAnswer() + "%");
}
sql.append(" ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, question ASC");
// 添加分页
sql.append(" LIMIT ? OFFSET ?");
params.add(pageSize);
params.add((pageNo - 1) * pageSize);
try(Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql.toString())){
// 设置参数值
for (int i = 0; i < params.size(); i++) {
stmt.setObject(i + 1, params.get(i));
}
try (ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
results.add(mapRowToQuestionEmbedding(rs));
}
}
} catch (SQLException e) {
log.error("查询所有记录失败", e);
throw new RuntimeException("查询数据时发生数据库错误", e);
}
// 执行计数查询
long total = 0;
try(Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(countSql.toString())){
// 设置参数值
for (int i = 0; i < countParams.size(); i++) {
stmt.setObject(i + 1, countParams.get(i));
}
try (ResultSet rs = stmt.executeQuery()) {
if (rs.next()) {
total = rs.getLong(1); // 直接获取count值
}
}
} catch (SQLException e) {
log.error("查询记录总数失败", e);
throw new RuntimeException("查询记录总数时发生数据库错误", e);
}
Page<QuestionEmbedding> page = new Page<>();
page.setRecords(results);
page.setTotal(total);
return page;
}
/**
 * Counts rows in {@code question_embedding}, optionally restricted to an exact question match.
 *
 * @param questionEmbedding filter holder; only its question text is read, and only when non-blank
 * @return the value of the count query, or 0 when the query yields no row
 * @throws RuntimeException wrapping any {@link SQLException} raised while querying
 */
public Integer findQuestionCount(QuestionEmbedding questionEmbedding) {
    StringBuilder query = new StringBuilder("select COUNT(1) AS total_count from question_embedding where 1 = 1");
    List<Object> bindValues = new ArrayList<>();

    if (StringUtils.isNotBlank(questionEmbedding.getQuestion())) {
        query.append(" AND question = ?");
        bindValues.add(questionEmbedding.getQuestion());
    }

    try (Connection connection = getConnection();
         PreparedStatement statement = connection.prepareStatement(query.toString())) {
        // JDBC parameter positions are 1-based.
        int position = 1;
        for (Object value : bindValues) {
            statement.setObject(position++, value);
        }
        try (ResultSet rows = statement.executeQuery()) {
            if (rows.next()) {
                return rows.getInt("total_count");
            }
            return 0;
        }
    } catch (SQLException e) {
        log.error("查询所有记录失败", e);
        throw new RuntimeException("查询数据时发生数据库错误", e);
    }
}
/**
 * Loads a single {@code question_embedding} row by its id.
 *
 * @param id value compared against the {@code id} column
 * @return the mapped record, or {@code null} when no row matches
 * @throws RuntimeException wrapping any {@link SQLException} raised while querying
 */
public QuestionEmbedding findById(String id) {
    final String query = "SELECT * FROM question_embedding WHERE id = ?";
    try (Connection connection = getConnection();
         PreparedStatement statement = connection.prepareStatement(query)) {
        statement.setString(1, id);
        try (ResultSet rows = statement.executeQuery()) {
            return rows.next() ? mapRowToQuestionEmbedding(rows) : null;
        }
    } catch (SQLException e) {
        log.error("根据ID查询记录失败, ID: {}", id, e);
        throw new RuntimeException("根据ID查询时发生数据库错误", e);
    }
}
/**
 * Inserts a new row into question_embedding.
 *
 * The id column is filled with a freshly generated random UUID; the id carried
 * by {@code record} is not read here. The embedding column is computed from the
 * record's question text via aiModelUtils.
 *
 * @param record source of the text, question, answer and metadata values
 * @return the row count reported by executeUpdate()
 * @throws RuntimeException wrapping any SQLException raised while inserting
 */
public int insert(QuestionEmbedding record) {
String sql = "INSERT INTO question_embedding (id, text, question, answer, metadata,embedding) VALUES (?, ?, ?, ?, ?::jsonb,?)";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, UUID.randomUUID().toString());
stmt.setString(2, record.getText());
stmt.setString(3, record.getQuestion());
stmt.setString(4, record.getAnswer());
// Metadata is wrapped in a PGobject typed "json"; the SQL casts it to jsonb.
// NOTE(review): setValue expects a String — confirm the entity's metadata type.
PGobject jsonObject = new PGobject();
jsonObject.setType("json");
jsonObject.setValue(record.getMetadata());
stmt.setObject(5, jsonObject);
// The embedding is requested while the connection is already open.
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion());
// Bound as the raw vector() value (not wrapped in PGvector as the search methods do).
stmt.setObject(6, embedding.content().vector());
return stmt.executeUpdate();
} catch (SQLException e) {
log.error("插入记录失败: {}", record, e);
throw new RuntimeException("插入数据时发生数据库错误", e);
}
}
/**
 * Updates text, question, answer, metadata and embedding of the row whose id
 * equals {@code record.getId()}.
 *
 * The embedding is recomputed from the record's question text on every call.
 *
 * @param record carries the new column values and the id of the row to update
 * @return the row count reported by executeUpdate()
 * @throws RuntimeException wrapping any SQLException raised while updating
 */
public int update(QuestionEmbedding record) {
String sql = "UPDATE question_embedding SET text = ?, question = ?, answer = ?, metadata = ?::jsonb ,embedding = ? WHERE id = ?";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, record.getText());
stmt.setString(2, record.getQuestion());
stmt.setString(3, record.getAnswer());
// Metadata is bound directly (no PGobject wrapper, unlike insert); the SQL casts it to jsonb.
stmt.setObject(4, record.getMetadata());
// The embedding is requested while the connection is already open.
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getQuestion());
stmt.setObject(5, embedding.content().vector());
stmt.setString(6, record.getId());
return stmt.executeUpdate();
} catch (SQLException e) {
log.error("更新记录失败: {}", record, e);
throw new RuntimeException("更新数据时发生数据库错误", e);
}
}
/**
 * Deletes every {@code question_embedding} row whose id is in the given list.
 *
 * @param ids ids to delete; {@code null} or an empty list is a no-op
 * @return 0 for a null/empty list, otherwise the row count reported by executeUpdate()
 * @throws RuntimeException wrapping any {@link SQLException} raised while deleting
 */
public int deleteByIds(List<String> ids) {
    if (ids == null || ids.isEmpty()) {
        return 0;
    }

    // One "?" per id, comma separated: "?,?,?"
    String placeholderList = String.join(",", Collections.nCopies(ids.size(), "?"));
    String statementText = "DELETE FROM question_embedding WHERE id IN (" + placeholderList + ")";

    try (Connection connection = getConnection();
         PreparedStatement statement = connection.prepareStatement(statementText)) {
        // JDBC parameter positions are 1-based.
        int position = 1;
        for (String id : ids) {
            statement.setString(position++, id);
        }
        return statement.executeUpdate();
    } catch (SQLException e) {
        log.error("批量删除向量记录失败, IDs: {}", ids, e);
        throw new RuntimeException("批量删除向量数据时发生数据库错误", e);
    }
}
/**
 * Vector similarity search driven by a question text.
 *
 * The question is embedded via aiModelUtils, then rows whose {@code <->} distance to
 * that vector is below {@code 1 - minSimilarity} are returned, nearest first. Each
 * returned record's similarity is set to {@code 1 - distance}.
 *
 * @param question      question text to embed and search for
 * @param limit         maximum number of rows to return
 * @param minSimilarity minimum similarity threshold; must be a non-null number in [0, 1]
 * @return matching records ordered by ascending distance (i.e. most similar first)
 * @throws IllegalArgumentException when {@code minSimilarity} is null, NaN, or outside [0, 1]
 * @throws RuntimeException         wrapping any {@link SQLException} raised while querying
 */
public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {
    // 1. Validate the threshold. The null check avoids an unboxing NullPointerException,
    //    and the negated range test also rejects NaN (every comparison with NaN is false).
    if (minSimilarity == null || !(minSimilarity >= 0 && minSimilarity <= 1)) {
        throw new IllegalArgumentException("相似度阈值必须在0到1之间");
    }

    List<QuestionEmbedding> results = new ArrayList<>();

    // 2. Embed the question text.
    Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, question);
    float[] queryVector = embeddingResponse.content().vector();

    // 3. Translate the similarity threshold into a maximum allowed distance.
    double maxDistance = 1 - minSimilarity;

    // 4. Run the distance query; the vector is bound twice (select list and filter).
    String sql = "SELECT *, embedding <-> ? AS distance " +
            "FROM question_embedding " +
            "WHERE embedding <-> ? < ? " +
            "ORDER BY distance ASC " +
            "LIMIT ?";
    try (Connection conn = getConnection();
         PreparedStatement stmt = conn.prepareStatement(sql)) {
        PGvector vector = new PGvector(queryVector);
        stmt.setObject(1, vector);
        stmt.setObject(2, vector);
        stmt.setDouble(3, maxDistance);
        stmt.setInt(4, limit);
        try (ResultSet rs = stmt.executeQuery()) {
            while (rs.next()) {
                QuestionEmbedding record = mapRowToQuestionEmbedding(rs);
                // Convert the reported distance back into a similarity score.
                double distance = rs.getDouble("distance");
                record.setSimilarity(1 - distance);
                results.add(record);
            }
        }
    } catch (SQLException e) {
        log.error("向量相似度查询失败", e);
        throw new RuntimeException("执行向量相似度查询时发生数据库错误", e);
    }
    return results;
}
/**
 * Vector similarity search using a caller-supplied vector.
 *
 * The query's {@code similarity} column actually holds the {@code <->} distance; each
 * returned record's similarity is set to {@code 1 - distance}.
 *
 * @param vector query vector
 * @param limit  maximum number of rows to return
 * @return records ordered by ascending distance to {@code vector}
 * @throws RuntimeException wrapping any {@link SQLException} raised while querying
 */
public List<QuestionEmbedding> similaritySearch(float[] vector, int limit) {
    final String query = "SELECT *, embedding <-> ? AS similarity "
            + "FROM question_embedding "
            + "ORDER BY similarity ASC "
            + "LIMIT ?";
    List<QuestionEmbedding> matches = new ArrayList<>();

    try (Connection connection = getConnection();
         PreparedStatement statement = connection.prepareStatement(query)) {
        statement.setObject(1, new PGvector(vector));
        statement.setInt(2, limit);
        try (ResultSet rows = statement.executeQuery()) {
            while (rows.next()) {
                QuestionEmbedding match = mapRowToQuestionEmbedding(rows);
                double distance = rows.getDouble("similarity");
                match.setSimilarity(1 - distance);
                matches.add(match);
            }
        }
    } catch (SQLException e) {
        log.error("向量相似度查询失败", e);
        throw new RuntimeException("执行向量相似度查询时发生数据库错误", e);
    }
    return matches;
}
/**
 * Deletes the {@code question_embedding} row with the given id.
 *
 * @param id value compared against the {@code id} column
 * @return the row count reported by executeUpdate()
 * @throws RuntimeException wrapping any {@link SQLException} raised while deleting
 */
public int deleteById(String id) {
    final String statementText = "DELETE FROM question_embedding WHERE id = ?";
    try (Connection connection = getConnection();
         PreparedStatement statement = connection.prepareStatement(statementText)) {
        statement.setString(1, id);
        int affected = statement.executeUpdate();
        return affected;
    } catch (SQLException e) {
        log.error("删除记录失败, ID: {}", id, e);
        throw new RuntimeException("删除数据时发生数据库错误", e);
    }
}
/**
 * Builds a QuestionEmbedding from the row the given ResultSet is currently positioned on.
 *
 * Copies the id, text, question and answer columns; the metadata column is copied
 * only when its string form is non-blank.
 *
 * @param rs result set positioned on the row to map
 * @return the populated record
 * @throws SQLException when a column cannot be read
 */
private QuestionEmbedding mapRowToQuestionEmbedding(ResultSet rs) throws SQLException {
    QuestionEmbedding mapped = new QuestionEmbedding();
    mapped.setId(rs.getString("id"));
    mapped.setText(rs.getString("text"));
    mapped.setQuestion(rs.getString("question"));
    mapped.setAnswer(rs.getString("answer"));

    String rawMetadata = rs.getString("metadata");
    if (!StringUtils.isBlank(rawMetadata)) {
        mapped.setMetadata(rawMetadata);
    }
    return mapped;
}
/**
 * Serializes a map to a JSON string with Jackson.
 *
 * @param map the map to serialize
 * @return the JSON text, or {@code "{}"} when serialization fails (the failure is logged)
 */
private String toJson(Map<String, Object> map) {
    ObjectMapper serializer = new ObjectMapper();
    try {
        String json = serializer.writeValueAsString(map);
        return json;
    } catch (JsonProcessingException e) {
        log.error("元数据转换为JSON失败", e);
        return "{}";
    }
}
/**
 * Parses a JSON string into a {@code Map<String, Object>} with Jackson.
 *
 * @param json the JSON text to parse
 * @return the parsed map, or {@link Collections#emptyMap()} when parsing fails (the failure is logged)
 */
private Map<String, Object> fromJson(String json) {
    ObjectMapper parser = new ObjectMapper();
    TypeReference<Map<String, Object>> mapType = new TypeReference<Map<String, Object>>() {};
    try {
        Map<String, Object> parsed = parser.readValue(json, mapType);
        return parsed;
    } catch (JsonProcessingException e) {
        log.error("JSON转换为元数据失败", e);
        return Collections.emptyMap();
    }
}
List<QuestionEmbedding> similaritySearchByQuestion(@Param("vector") float[] vector,
@Param("limit") int limit,
@Param("minSimilarity") Double minSimilarity);
List<QuestionEmbedding> similaritySearch(@Param("vector") float[] vector,
@Param("limit") int limit);
}
\ No newline at end of file
... ...
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper">
<!--
  Row-to-entity mapping shared by every select in this mapper.
  The metadata column is read through JsonbMapTypeHandler; the similarity column
  is only present in the similarity-search queries, which add it as a computed alias.
-->
<resultMap id="questionEmbeddingResultMap" type="org.jeecg.modules.airag.app.entity.QuestionEmbedding">
<id column="id" property="id" />
<result column="text" property="text" />
<result column="question" property="question" />
<result column="answer" property="answer" />
<result column="metadata" property="metadata" typeHandler="org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler"/>
<result column="similarity" property="similarity" />
</resultMap>
<!--
  Lists question_embedding rows with optional filters:
    - knowledgeId: exact match on the metadata JSON key 'knowledgeId'
    - question / answer: case-insensitive substring match (ILIKE)
  Each filter applies only when its value is non-null and non-empty.
  Rows are ordered by knowledgeId (NULLs last), then by question.
-->
<select id="findAll" resultMap="questionEmbeddingResultMap">
    SELECT *
    FROM question_embedding
    <where>
        <if test="questionEmbedding.knowledgeId != null and questionEmbedding.knowledgeId != ''">
            AND metadata ->> 'knowledgeId' = #{questionEmbedding.knowledgeId}
        </if>
        <if test="questionEmbedding.question != null and questionEmbedding.question != ''">
            AND question ILIKE CONCAT('%', #{questionEmbedding.question}, '%')
        </if>
        <if test="questionEmbedding.answer != null and questionEmbedding.answer != ''">
            AND answer ILIKE CONCAT('%', #{questionEmbedding.answer}, '%')
        </if>
    </where>
    ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, question ASC
</select>
<!--
  Counts question_embedding rows. When a non-empty question is supplied, only rows
  whose question equals it exactly are counted; otherwise all rows are counted.
-->
<select id="findQuestionCount" resultType="int">
    SELECT COUNT(1) AS total_count
    FROM question_embedding
    <where>
        <if test="questionEmbedding.question != null and questionEmbedding.question != ''">
            AND question = #{questionEmbedding.question}
        </if>
    </where>
</select>
<!-- Fetches the single question_embedding row whose id equals the given value. -->
<select id="findById" resultMap="questionEmbeddingResultMap">
    SELECT *
    FROM question_embedding
    WHERE id = #{id}
</select>
<!--
  Inserts one question_embedding row from the "record" parameter.
  metadata is written through JsonbMapTypeHandler and cast to jsonb in SQL;
  embedding is written through PgVectorTypeHandler.
  NOTE(review): the id comes from record.id, so callers must set it before calling.
-->
<insert id="insert" parameterType="org.jeecg.modules.airag.app.entity.QuestionEmbedding">
INSERT INTO question_embedding (id, text, question, answer, metadata, embedding)
VALUES (
#{record.id, jdbcType=VARCHAR},
#{record.text, jdbcType=VARCHAR},
#{record.question, jdbcType=VARCHAR},
#{record.answer, jdbcType=VARCHAR},
#{record.metadata, jdbcType=OTHER, typeHandler=org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler}::jsonb,
#{record.embedding, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}
)
</insert>
<!--
  Overwrites text, question, answer, metadata and embedding of the row whose id
  equals record.id. All five columns are always written, so a null property on
  the record is sent to the database as-is.
  metadata goes through JsonbMapTypeHandler (cast to jsonb); embedding through PgVectorTypeHandler.
-->
<update id="update" parameterType="org.jeecg.modules.airag.app.entity.QuestionEmbedding">
UPDATE question_embedding
SET
text = #{record.text, jdbcType=VARCHAR},
question = #{record.question, jdbcType=VARCHAR},
answer = #{record.answer, jdbcType=VARCHAR},
metadata = #{record.metadata, jdbcType=OTHER, typeHandler=org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler}::jsonb,
embedding = #{record.embedding, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}
WHERE id = #{record.id}
</update>
<!-- Removes the question_embedding row whose id equals the given value. -->
<delete id="deleteById">
    DELETE
    FROM question_embedding
    WHERE id = #{id}
</delete>
<!--
  Deletes every question_embedding row whose id is in "ids".
  A null or empty list would otherwise fail (foreach over null) or render the invalid
  SQL "IN ()"; in that case the statement degrades to a delete that matches no rows,
  so the mapper returns 0 just like the JDBC implementation it replaces.
-->
<delete id="deleteByIds">
    DELETE FROM question_embedding
    <choose>
        <when test="ids != null and !ids.isEmpty()">
            WHERE id IN
            <foreach collection="ids" item="id" open="(" separator="," close=")">
                #{id}
            </foreach>
        </when>
        <otherwise>
            WHERE 1 = 0
        </otherwise>
    </choose>
</delete>
<!--
  Similarity search against a pre-computed query vector.
  "<->" yields a distance (smaller = closer). To honour the minSimilarity contract of
  the JDBC implementation this replaces:
    - similarity returned to the caller = 1 - distance
    - a row qualifies when distance < 1 - minSimilarity
  Ordering uses the raw distance expression ascending (nearest first), which is the
  form pgvector indexes can serve.
-->
<select id="similaritySearchByQuestion" resultMap="questionEmbeddingResultMap">
    <![CDATA[
    SELECT *,
           (1 - (embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}))::float AS similarity
    FROM question_embedding
    WHERE (embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}) < (1 - #{minSimilarity})
    ORDER BY embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler} ASC
    LIMIT #{limit}
    ]]>
</select>
<!--
  Nearest-neighbour search against a caller-supplied vector, without a threshold.
  The statement body was previously commented out, so the mapped method executed no query.
  As in the JDBC implementation this replaces, rows come back nearest first and the
  similarity reported to the caller is 1 - distance.
  The vector is bound through PgVectorTypeHandler, matching the other statements here.
-->
<select id="similaritySearch" resultMap="questionEmbeddingResultMap">
    <![CDATA[
    SELECT *,
           (1 - (embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler}))::float AS similarity
    FROM question_embedding
    ORDER BY embedding <-> #{vector, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler} ASC
    LIMIT #{limit}
    ]]>
</select>
</mapper>
\ No newline at end of file
... ...
... ... @@ -50,8 +50,6 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i
@Override
public void saveToQuestionLibrary(AiragLog log) throws JsonProcessingException {
// 这里实现将问题和回答存入问题库数据表的逻辑
// 假设问题库数据表的实体类为 QuestionLibrary,Mapper 接口为 QuestionLibraryMapper
QuestionEmbedding questionEmbedding = new QuestionEmbedding();
questionEmbedding.setQuestion(log.getQuestion());
questionEmbedding.setAnswer(log.getAnswer());
... ... @@ -62,11 +60,7 @@ public class AiragLogServiceImpl extends ServiceImpl<AiragLogMapper, AiragLog> i
String docId = String.valueOf(snowflakeGenerator.next());
metadata.put("docId", docId);
metadata.put("knowledgeId", questionEmbedding.getKnowledgeId());
// 使用 Jackson 序列化 Map 到 JSON
ObjectMapper mapper = new ObjectMapper();
String metadataJson = mapper.writeValueAsString(metadata);
// 2. 设置到embeddings对象
questionEmbedding.setMetadata(metadataJson);
questionEmbedding.setMetadata(metadata);
questionEmbeddingMapper.insert(questionEmbedding);
airagLogMapper.updataIsStorage(log.getIsStorage(),log.getId());
System.out.println("1");
... ...
package org.jeecg.modules.airag.app.service.impl;
import com.baomidou.dynamic.datasource.annotation.DS;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.fasterxml.jackson.core.JsonProcessingException;
import org.apache.commons.lang3.StringUtils;
import org.apache.poi.hwpf.usermodel.CharacterRun;
import org.apache.poi.hwpf.HWPFDocument;
import org.apache.poi.hwpf.usermodel.Paragraph;
... ... @@ -15,7 +17,10 @@ import dev.langchain4j.model.output.Response;
import org.apache.commons.io.FilenameUtils;
import org.apache.poi.xwpf.usermodel.*;
import org.jeecg.common.api.vo.Result;
import org.jeecg.modules.airag.app.entity.Embeddings;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.mapper.EmbeddingsMapper;
import org.jeecg.modules.airag.app.mapper.PgVectorMapper;
import org.jeecg.modules.airag.app.mapper.QuestionEmbeddingMapper;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
... ... @@ -60,7 +65,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
private AiModelUtils aiModelUtils;
@Autowired
private IAIChatHandler aiChatHandler;
private PgVectorMapper pgVectorMapper;
@Value("${jeecg.upload.path}")
private String uploadPath;
... ... @@ -68,17 +73,13 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
private String embedId;
private static final Set<String> ALLOWED_EXTENSIONS = Set.of("txt", "doc", "docx");
private static final Pattern SPECIAL_CHARS_PATTERN = Pattern.compile("[^a-zA-Z0-9\\u4e00-\\u9fa5\\s]");
private static final Pattern UUID_PATTERN = Pattern.compile("_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}");
// 数据库连接配置
private static final String DB_URL = "jdbc:postgresql://192.168.100.104:5432/postgres";
private static final String DB_USER = "postgres";
private static final String DB_PASSWORD = "postgres";
@Override
public Page<QuestionEmbedding> findAll(QuestionEmbedding questionEmbedding, Integer pageNo, Integer pageSize) {
return questionEmbeddingMapper.findAll(questionEmbedding,pageNo,pageSize);
Page<QuestionEmbedding> page = new Page<>(pageNo, pageSize);
return questionEmbeddingMapper.findAll(page,questionEmbedding);
}
@Override
... ... @@ -93,11 +94,21 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
/**
 * Inserts a record, first attaching an embedding computed from its question text.
 *
 * The embedding is only computed when the question is non-blank; otherwise the record
 * is passed to the mapper with whatever embedding it already carries.
 *
 * @param record the record to persist
 * @return the value returned by the mapper's insert
 */
@Override
public int insert(QuestionEmbedding record) {
    boolean hasQuestion = StringUtils.isNotBlank(record.getQuestion());
    if (hasQuestion) {
        record.setEmbedding(
                aiModelUtils.getEmbedding(embedId, record.getQuestion()).content().vector());
    }
    int inserted = questionEmbeddingMapper.insert(record);
    return inserted;
}
/**
 * Updates a record, first refreshing its embedding from the question text.
 *
 * The embedding is only recomputed when the question is non-blank; otherwise the record
 * is passed to the mapper with whatever embedding it already carries.
 *
 * @param record the record to update
 * @return the value returned by the mapper's update
 */
@Override
public int update(QuestionEmbedding record) {
    boolean hasQuestion = StringUtils.isNotBlank(record.getQuestion());
    if (hasQuestion) {
        record.setEmbedding(
                aiModelUtils.getEmbedding(embedId, record.getQuestion()).content().vector());
    }
    int updated = questionEmbeddingMapper.update(record);
    return updated;
}
... ... @@ -113,7 +124,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
@Override
public List<QuestionEmbedding> similaritySearchByQuestion(String question, int limit, Double minSimilarity) {
return questionEmbeddingMapper.similaritySearchByQuestion(question, limit, minSimilarity);
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, question);
return questionEmbeddingMapper.similaritySearchByQuestion(embedding.content().vector(), limit, minSimilarity);
}
@Override
... ... @@ -183,10 +195,8 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
segments = splitWordDocument(targetPath.toString());
}
// 原有逻辑:保存到question_embedding表
saveSegmentsToDatabase(segments, originalFileName, storedFileName, knowledgeId);
// 新增逻辑:同时保存到embeddings表
saveToEmbeddingsTable(segments, originalFileName, storedFileName, knowledgeId);
}
... ... @@ -196,7 +206,6 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
String displayFileName = removeUuidSuffix(originalFileName);
displayFileName = FilenameUtils.removeExtension(displayFileName);
try (Connection conn = getConnection()) {
for (String segment : segments) {
if (segment.trim().isEmpty()) continue;
... ... @@ -205,44 +214,29 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
if (parts.length < 2) continue;
String titlePath = parts[0].trim();
String answer = segment.trim(); // 整个回答段(含标题 + 内容)
// 整个回答段(标题 + 内容)
String answer = segment.trim();
// 获取 embedding
Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, answer);
float[] embeddingVector = embeddingResponse.content().vector();
// 准备 metadata
Map<String, Object> metadata = new HashMap<>();
metadata.put("docName", originalFileName);
metadata.put("storedFileName", storedFileName);
metadata.put("knowledgeId", knowledgeId);
metadata.put("title", displayFileName + ": " + titlePath);
// 插入
String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)";
try (PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, UUID.randomUUID().toString());
stmt.setObject(2, new PGvector(embeddingVector));
stmt.setString(3, answer);
PGobject jsonObject = new PGobject();
jsonObject.setType("json");
jsonObject.setValue(new ObjectMapper().writeValueAsString(metadata));
stmt.setObject(4, jsonObject);
stmt.executeUpdate();
}
}
} catch (Exception e) {
log.error("保存分段到embeddings表失败", e);
Embeddings embeddings = new Embeddings();
embeddings.setMetadata(metadata);
embeddings.setId(UUID.randomUUID().toString());
embeddings.setEmbedding(embeddingVector);
embeddings.setText(answer);
pgVectorMapper.insert(embeddings);
}
}
// 获取数据库连接
private Connection getConnection() throws SQLException {
return DriverManager.getConnection(DB_URL, DB_USER, DB_PASSWORD);
}
private String generateStoredFileName(String originalFileName) {
String baseName = FilenameUtils.removeExtension(originalFileName);
... ... @@ -359,6 +353,7 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
return 0;
}
private void saveSegmentsToDatabase(List<String> segments, String originalFileName, String storedFileName, String knowledgeId) {
if (segments.isEmpty()) return;
... ... @@ -384,24 +379,22 @@ public class QuestionEmbeddingServiceImpl implements IQuestionEmbeddingService {
record.setAnswer(titleLine + "\n" + content);
record.setText("");
Map<String, String> metadata = new LinkedHashMap<>();
Map<String, Object> metadata = new LinkedHashMap<>();
metadata.put("docId", docId);
metadata.put("docName", originalFileName);
metadata.put("storedFileName", storedFileName);
metadata.put("knowledgeId", knowledgeId);
try {
record.setMetadata(new ObjectMapper().writeValueAsString(metadata));
} catch (JsonProcessingException e) {
log.error("生成metadata JSON失败", e);
}
record.setMetadata(metadata);
log.info("保存分段: title={}, content_length={}", question, segment.length());
Response<Embedding> embeddingResponse = aiModelUtils.getEmbedding(embedId, record.getQuestion());
record.setEmbedding(embeddingResponse.content().vector());
record.setKnowledgeId(knowledgeId);
questionEmbeddingMapper.insert(record);
insert(record);
}
}
... ...
... ... @@ -133,7 +133,7 @@ public class AiragResponseServiceImpl implements AiragResponseService {
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(data)));
// 发送END事件
Map<String, String> endData = createEndData(questionEmbedding.getMetadata(), String.valueOf(questionEmbedding.getSimilarity()));
Map<String, String> endData = createEndData(objectMapper.writeValueAsString(questionEmbedding.getMetadata()), String.valueOf(questionEmbedding.getSimilarity()));
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(endData)));
emitter.complete();
}
... ...