作者 lixiang

删除多余代码

正在显示 17 个修改的文件 包含 677 行增加1312 行删除
... ... @@ -63,9 +63,9 @@ public class EmbeddingsController {
@RequestParam(name = "pageNo", defaultValue = "1") Integer pageNo,
@RequestParam(name = "pageSize", defaultValue = "10") Integer pageSize,
HttpServletRequest req) throws NoSuchFieldException, IllegalAccessException, SQLException {
//Response<Embedding> embedding = aiModelUtils.getEmbedding("1925730210204721154", "33333");
Page<Embeddings> records = embeddingsService.findAll(embeddings,pageNo,pageSize);
Page<Embeddings> page = new Page<>(pageNo, pageSize);
Page<Embeddings> records = embeddingsService.findAll(page,embeddings);
return Result.OK(records);
}
/**
... ... @@ -94,19 +94,13 @@ public class EmbeddingsController {
@RequiresPermissions("embeddings:embeddings:add")
@PostMapping(value = "/add")
public Result<String> add(@RequestBody Embeddings embeddings) {
// 1. 构建完整的metadata
Map<String, Object> metadata = embeddings.getMetadata();
SnowflakeGenerator snowflakeGenerator = new SnowflakeGenerator();
metadata.put("docName", embeddings.getDocName());
String docId = String.valueOf(snowflakeGenerator.next());
metadata.put("docId", docId); // 自动生成唯一文档ID
metadata.put("index", "0"); // 默认索引位置为0
// 2. 设置到embeddings对象
metadata.put("docId", docId);
metadata.put("index", "0");
embeddings.setMetadata(metadata);
System.out.println(new SnowflakeGenerator().next());
embeddingsService.insert(embeddings);
return Result.OK("添加成功!");
}
... ... @@ -122,7 +116,11 @@ public class EmbeddingsController {
@RequiresPermissions("embeddings:embeddings:edit")
@RequestMapping(value = "/edit", method = {RequestMethod.PUT, RequestMethod.POST})
public Result<String> edit(@RequestBody Embeddings embeddings) {
embeddingsService.update(embeddings);
try {
embeddingsService.update(embeddings);
} catch (SQLException e) {
throw new RuntimeException(e);
}
return Result.OK("编辑成功!");
}
... ... @@ -137,6 +135,7 @@ public class EmbeddingsController {
@RequiresPermissions("embeddings:embeddings:delete")
@DeleteMapping(value = "/delete")
public Result<String> delete(@RequestParam(name = "id", required = true) String id) {
//embeddingsService.removeById(id);
embeddingsService.deleteById(id);
return Result.OK("删除成功!");
}
... ... @@ -166,6 +165,10 @@ public class EmbeddingsController {
@Operation(summary = "Embeddings-通过id查询")
@GetMapping(value = "/queryById")
public Result<Embeddings> queryById(@RequestParam(name = "id", required = true) String id) {
// Embeddings Embeddings = embeddingsService.getById(id);
// if(Embeddings==null) {
// return Result.error("未找到对应数据");
// }
embeddingsService.findById(id);
return Result.OK();
}
... ...
... ... @@ -82,7 +82,7 @@ public class AiragLog implements Serializable {
*/
@Excel(name = "回答方式", width = 15)
@TableField("answer_type")
@Schema(description = "回答方式:1:问题库回答 2:模型回答 3:未命中")
@Schema(description = "回答方式:1:问题库回答 2:模型回答 3:未命中 4:发生异常")
private int answerType;
/**
* 提问方式
... ...
package org.jeecg.modules.airag.app.handler;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import org.postgresql.util.PGobject;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.sql.*;
import java.util.Map;
public class JsonbMapTypeHandler extends BaseTypeHandler<Map<String, Object>> {
private static final ObjectMapper objectMapper = new ObjectMapper();
@Override
public void setNonNullParameter(PreparedStatement ps, int i,
Map<String, Object> parameter, JdbcType jdbcType) throws SQLException {
PGobject pgObject = new PGobject();
pgObject.setType("jsonb");
try {
pgObject.setValue(objectMapper.writeValueAsString(parameter));
ps.setObject(i, pgObject);
} catch (JsonProcessingException e) {
throw new SQLException("Failed to convert Map to JSON", e);
}
}
@Override
public Map<String, Object> getNullableResult(ResultSet rs, String columnName) throws SQLException {
return parseJson(rs.getString(columnName));
}
@Override
public Map<String, Object> getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
return parseJson(rs.getString(columnIndex));
}
@Override
public Map<String, Object> getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
return parseJson(cs.getString(columnIndex));
}
private Map<String, Object> parseJson(String json) throws SQLException {
if (json == null) return null;
try {
return objectMapper.readValue(json, new TypeReference<Map<String, Object>>() {});
} catch (JsonProcessingException e) {
throw new SQLException("Failed to parse JSON", e);
}
}
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.app.handler;
import org.apache.ibatis.type.BaseTypeHandler;
import org.apache.ibatis.type.JdbcType;
import com.pgvector.PGvector;
import java.sql.*;
public class PgVectorTypeHandler extends BaseTypeHandler<float[]> {
@Override
public void setNonNullParameter(PreparedStatement ps, int i,
float[] parameter, JdbcType jdbcType) throws SQLException {
ps.setObject(i, new PGvector(parameter));
}
@Override
public float[] getNullableResult(ResultSet rs, String columnName) throws SQLException {
PGvector pgVector = (PGvector) rs.getObject(columnName);
return pgVector != null ? pgVector.toArray() : null;
}
@Override
public float[] getNullableResult(ResultSet rs, int columnIndex) throws SQLException {
PGvector pgVector = (PGvector) rs.getObject(columnIndex);
return pgVector != null ? pgVector.toArray() : null;
}
@Override
public float[] getNullableResult(CallableStatement cs, int columnIndex) throws SQLException {
PGvector pgVector = (PGvector) cs.getObject(columnIndex);
return pgVector != null ? pgVector.toArray() : null;
}
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.app.mapper;
import ch.qos.logback.core.net.SyslogOutputStream;
import cn.hutool.core.lang.generator.SnowflakeGenerator;
import com.alibaba.fastjson2.JSONObject;
import com.baomidou.dynamic.datasource.annotation.DS;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.pgvector.PGvector;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.ibatis.annotations.Mapper;
import org.apache.ibatis.annotations.Param;
import org.jeecg.modules.airag.app.entity.Embeddings;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
import org.postgresql.util.PGobject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import java.sql.*;
import java.util.*;
import java.util.stream.Collectors;
@Component
@Slf4j
public class PgVectorMapper {
@Autowired
private AiModelUtils aiModelUtils;
// PostgreSQL连接参数(实际项目中应从配置读取)
private static final String URL = "jdbc:postgresql://192.168.100.104:5432/postgres";
private static final String USER = "postgres";
private static final String PASSWORD = "postgres";
@Value("${jeecg.ai-chat.embedId}")
private String embedId;
// 获取数据库连接
private Connection getConnection() throws SQLException {
return DriverManager.getConnection(URL, USER, PASSWORD);
}
// 查询所有向量记录
public Page<Embeddings> findAll(Embeddings embeddings,int pageNo,int pageSize) {
List<Embeddings> results = new ArrayList<>();
StringBuilder sql = new StringBuilder("SELECT * FROM embeddings WHERE 1=1");
StringBuilder countSql = new StringBuilder("SELECT COUNT(1) FROM embeddings WHERE 1=1");
List<Object> params = new ArrayList<>(); // 存储参数值
List<Object> countParams = new ArrayList<>(); // 存储参数值
// 动态构建查询条件
if (StringUtils.isNotBlank(embeddings.getDocName())) {
sql.append(" AND metadata ->> 'docName' LIKE ?");
countSql.append(" AND metadata ->> 'docName' LIKE ?");
params.add("%" + embeddings.getDocName() + "%");
countParams.add("%" + embeddings.getDocName() + "%");
}
if (StringUtils.isNotBlank(embeddings.getKnowledgeId())) {
sql.append(" AND metadata ->> 'knowledgeId' = ?");
countSql.append(" AND metadata ->> 'knowledgeId' = ?");
params.add(embeddings.getKnowledgeId());
countParams.add(embeddings.getKnowledgeId());
}
if (StringUtils.isNotBlank(embeddings.getText())) {
sql.append(" AND text ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
countSql.append(" AND text ILIKE ?"); // 使用 ILIKE 进行不区分大小写的模糊匹配
params.add("%" + embeddings.getText() + "%");
countParams.add("%" + embeddings.getText() + "%");
}
sql.append(" ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, (metadata->>'docName') ASC");
// 添加分页
sql.append(" LIMIT ? OFFSET ?");
params.add(pageSize);
params.add((pageNo - 1) * pageSize);
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql.toString())) {
// 设置参数值
for (int i = 0; i < params.size(); i++) {
stmt.setObject(i + 1, params.get(i));
}
try (ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
results.add(mapRowToEmbeddings(rs));
}
}
} catch (SQLException e) {
log.error("查询所有向量记录失败", e);
throw new RuntimeException("查询向量数据时发生数据库错误", e);
}
// 执行计数查询
int total = 0;
try(Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(countSql.toString())){
// 设置参数值
for (int i = 0; i < countParams.size(); i++) {
stmt.setObject(i + 1, countParams.get(i));
}
try (ResultSet rs = stmt.executeQuery()) {
if (rs.next()) {
total = rs.getInt(1); // 直接获取count值
}
}
} catch (SQLException e) {
log.error("查询记录总数失败", e);
throw new RuntimeException("查询记录总数时发生数据库错误", e);
}
Page<Embeddings> page = new Page<>();
page.setRecords(results);
page.setTotal(total);
return page;
}
// 根据ID查询单个向量记录
public Embeddings findById(String id) {
String sql = "SELECT * FROM embeddings WHERE embedding_id = ?";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, id);
try (ResultSet rs = stmt.executeQuery()) {
if (rs.next()) {
return mapRowToEmbeddings(rs);
}
}
} catch (SQLException e) {
log.error("根据ID查询向量记录失败, ID: {}", id, e);
throw new RuntimeException("根据ID查询向量时发生数据库错误", e);
}
return null;
}
// 查询所有记录
public Integer findEmbeddingCount(Embeddings embeddings) {
StringBuilder sql = new StringBuilder("select COUNT(1) AS total_count from embeddings where 1 = 1");
List<Object> params = new ArrayList<>();
if(StringUtils.isNotBlank(embeddings.getText())){
sql.append(" AND text = ?");
params.add(embeddings.getText());
}
try(Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql.toString())){
// 设置参数值
for (int i = 0; i < params.size(); i++) {
stmt.setObject(i + 1, params.get(i));
}
try (ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
return rs.getInt("total_count");
}
return 0;
}
} catch (SQLException e) {
log.error("查询所有记录失败", e);
throw new RuntimeException("查询数据时发生数据库错误", e);
}
}
// 插入新向量记录
public int insert(Embeddings record) {
String sql = "INSERT INTO embeddings (embedding_id, embedding, text, metadata) VALUES (?, ?, ?, ?::jsonb)";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, UUID.randomUUID().toString());
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getText());
stmt.setObject(2, embedding.content().vector());
stmt.setObject(3, record.getText());
stmt.setObject(4, toJson(record.getMetadata()));
return stmt.executeUpdate();
} catch (SQLException e) {
log.error("插入向量记录失败: {}", record, e);
throw new RuntimeException("插入向量数据时发生数据库错误", e);
}
}
// 更新向量记录
public int update(Embeddings record) {
String sql = "UPDATE embeddings SET embedding = ?, metadata = ?::jsonb, text = ? WHERE embedding_id = ?";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
JSONObject mataData = new JSONObject();
mataData.put("knowledgeId", record.getKnowledgeId()); // 使用前端传入的知识库ID
mataData.put("docName", record.getDocName());
//获取record数据中的docId
Map<String, Object> map = record.getMetadata();
System.out.println("map = " + map);
mataData.put("docId", record.getDocId());
mataData.put("index", "0");
System.out.println("原始数据: " + mataData);
PGobject jsonObject = new PGobject();
jsonObject.setType("json");
jsonObject.setValue(mataData.toJSONString());
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getText());
stmt.setObject(1, embedding.content().vector());
stmt.setObject(2, jsonObject);
stmt.setObject(3, record.getText());
stmt.setString(4, record.getId());
return stmt.executeUpdate();
} catch (SQLException e) {
log.error("更新向量记录失败: {}", record, e);
throw new RuntimeException("更新向量数据时发生数据库错误", e);
}
}
// 根据ID删除向量记录
public int deleteById(String id) {
String sql = "DELETE FROM embeddings WHERE embedding_id = ?";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setString(1, id);
return stmt.executeUpdate();
} catch (SQLException e) {
log.error("删除向量记录失败, ID: {}", id, e);
throw new RuntimeException("删除向量数据时发生数据库错误", e);
}
}
// 批量删除方法
public int deleteByIds(List<String> ids) {
if (ids == null || ids.isEmpty()) {
return 0;
}
String sql = "DELETE FROM embeddings WHERE embedding_id IN (";
StringBuilder placeholders = new StringBuilder();
for (int i = 0; i < ids.size(); i++) {
placeholders.append("?");
if (i < ids.size() - 1) {
placeholders.append(",");
}
}
sql += placeholders.toString() + ")";
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
for (int i = 0; i < ids.size(); i++) {
stmt.setString(i + 1, ids.get(i));
}
return stmt.executeUpdate();
} catch (SQLException e) {
log.error("批量删除向量记录失败, IDs: {}", ids, e);
throw new RuntimeException("批量删除向量数据时发生数据库错误", e);
}
}
// 向量相似度搜索
public List<Embeddings> similaritySearch(float[] vector, int limit) {
String sql = "SELECT * FROM embeddings ORDER BY embedding <-> ? LIMIT ?";
List<Embeddings> results = new ArrayList<>();
try (Connection conn = getConnection();
PreparedStatement stmt = conn.prepareStatement(sql)) {
stmt.setObject(1, new PGvector(vector));
stmt.setInt(2, limit);
try (ResultSet rs = stmt.executeQuery()) {
while (rs.next()) {
results.add(mapRowToEmbeddings(rs));
}
}
} catch (SQLException e) {
log.error("向量相似度搜索失败", e);
throw new RuntimeException("执行向量相似度搜索时发生数据库错误", e);
}
return results;
}
// 将ResultSet行映射为VectorRecord对象
private Embeddings mapRowToEmbeddings(ResultSet rs) throws SQLException {
Embeddings record = new Embeddings();
record.setId(rs.getString("embedding_id"));
record.setText(rs.getString("text"));
String metadataJson = rs.getString("metadata");
if (StringUtils.isNotBlank(metadataJson)) {
record.setMetadata(fromJson(metadataJson));
}
return record;
}
// 将Map转换为JSON字符串
private String toJson(Map<String, Object> map) {
try {
return new ObjectMapper().writeValueAsString(map);
} catch (JsonProcessingException e) {
log.error("元数据转换为JSON失败", e);
return "{}";
}
}
// 将JSON字符串转换为Map
private Map<String, Object> fromJson(String json) {
try {
return new ObjectMapper().readValue(json, new TypeReference<Map<String, Object>>() {});
} catch (JsonProcessingException e) {
log.error("JSON转换为元数据失败", e);
return Collections.emptyMap();
}
}
import java.util.List;
@Mapper
@DS("pgvector")
public interface PgVectorMapper {
Page<Embeddings> findAll(IPage<Embeddings> page, @Param("embeddings") Embeddings embeddings);
Embeddings findById(@Param("id") String id);
Integer findEmbeddingCount(@Param("embeddings") Embeddings embeddings);
int insert(@Param("record") Embeddings record);
int update(@Param("record") Embeddings record);
int deleteById(@Param("id") String id);
int deleteByIds(@Param("ids") List<String> ids);
List<Embeddings> similaritySearch(@Param("vector") float[] vector,
@Param("limit") int limit);
}
\ No newline at end of file
... ...
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE mapper PUBLIC "-//mybatis.org//DTD Mapper 3.0//EN"
"http://mybatis.org/dtd/mybatis-3-mapper.dtd">
<mapper namespace="org.jeecg.modules.airag.app.mapper.PgVectorMapper">
<resultMap id="embeddingsResultMap" type="org.jeecg.modules.airag.app.entity.Embeddings">
<id column="embedding_id" property="id" />
<result column="text" property="text" />
<result column="metadata" property="metadata" typeHandler="org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler" />
</resultMap>
<select id="findAll" resultMap="embeddingsResultMap">
SELECT * FROM embeddings WHERE 1=1
<if test="embeddings.docName != null and embeddings.docName != ''">
AND metadata ->> 'docName' LIKE CONCAT('%', #{embeddings.docName}, '%')
</if>
<if test="embeddings.knowledgeId != null and embeddings.knowledgeId != ''">
AND metadata ->> 'knowledgeId' = #{embeddings.knowledgeId}
</if>
<if test="embeddings.text != null and embeddings.text != ''">
AND text ILIKE CONCAT('%', #{embeddings.text}, '%')
</if>
ORDER BY (metadata->>'knowledgeId') ASC NULLS LAST, (metadata->>'docName') ASC
</select>
<select id="findById" resultMap="embeddingsResultMap">
SELECT * FROM embeddings WHERE embedding_id = #{id}
</select>
<select id="findEmbeddingCount" resultType="int">
SELECT COUNT(1) FROM embeddings WHERE 1=1
<if test="embeddings.text != null and embeddings.text != ''">
AND text = #{embeddings.text}
</if>
</select>
<insert id="insert" parameterType="org.jeecg.modules.airag.app.entity.Embeddings">
INSERT INTO embeddings (embedding_id, embedding, text, metadata)
VALUES (
#{record.id, jdbcType=VARCHAR},
#{record.embedding, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler},
#{record.text, jdbcType=VARCHAR},
#{record.metadata, jdbcType=OTHER, typeHandler=org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler}::jsonb
)
</insert>
<update id="update" parameterType="org.jeecg.modules.airag.app.entity.Embeddings">
UPDATE embeddings
SET
embedding = #{record.embedding, jdbcType=ARRAY, typeHandler=org.jeecg.modules.airag.app.handler.PgVectorTypeHandler},
metadata = #{record.metadata, jdbcType=OTHER, typeHandler=org.jeecg.modules.airag.app.handler.JsonbMapTypeHandler}::jsonb,
text = #{record.text, jdbcType=VARCHAR}
WHERE embedding_id = #{record.id}
</update>
<delete id="deleteById">
DELETE FROM embeddings WHERE embedding_id = #{id}
</delete>
<delete id="deleteByIds">
DELETE FROM embeddings WHERE embedding_id IN
<foreach collection="ids" item="id" open="(" separator="," close=")">
#{id}
</foreach>
</delete>
<select id="similaritySearch" resultMap="embeddingsResultMap">
</select>
</mapper>
\ No newline at end of file
... ...
... ... @@ -4,9 +4,8 @@ package org.jeecg.modules.airag.app.service;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.jeecg.modules.airag.app.entity.Embeddings;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import java.util.ArrayList;
import java.sql.SQLException;
import java.util.List;
/**
... ... @@ -17,10 +16,10 @@ import java.util.List;
*/
public interface IEmbeddingsService {
Page<Embeddings> findAll(Embeddings embeddings,int pageNo,int pageSize);
Page<Embeddings> findAll(Page<Embeddings> page, Embeddings embeddings);
int deleteById(String id);
int insert(Embeddings record);
int update(Embeddings record);
int update(Embeddings record) throws SQLException;
Embeddings findById(String id);
int removeByIds(List<String> ids);
... ...
package org.jeecg.modules.airag.app.service.impl;
import com.alibaba.fastjson2.JSONObject;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import dev.langchain4j.data.embedding.Embedding;
import dev.langchain4j.model.output.Response;
import org.jeecg.modules.airag.app.entity.Embeddings;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.mapper.PgVectorMapper;
import org.jeecg.modules.airag.app.service.IEmbeddingsService;
import org.jeecg.modules.airag.app.utils.AiModelUtils;
import org.postgresql.util.PGobject;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import java.sql.SQLException;
import java.util.List;
import java.util.UUID;
/**
* @Description: test
... ... @@ -20,10 +27,14 @@ import java.util.List;
public class IEmbeddingsServiceImpl implements IEmbeddingsService {
@Autowired
private PgVectorMapper pgVectorMapper;
@Autowired
private AiModelUtils aiModelUtils;
@Value("${jeecg.ai-chat.embedId}")
private String embedId;
@Override
public Page<Embeddings> findAll(Embeddings embeddings, int pageNo, int pageSize) {
return pgVectorMapper.findAll(embeddings,pageNo,pageSize);
public Page<Embeddings> findAll(Page<Embeddings> page, Embeddings embeddings) {
return pgVectorMapper.findAll(page,embeddings);
}
@Override
... ... @@ -36,11 +47,28 @@ public class IEmbeddingsServiceImpl implements IEmbeddingsService {
}
public int insert(Embeddings record) {
record.setId(UUID.randomUUID().toString());
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getText());
record.setEmbedding(embedding.content().vector());
return pgVectorMapper.insert(record);
}
@Override
public int update(Embeddings record) {
public int update(Embeddings record) throws SQLException {
JSONObject mataData = new JSONObject();
mataData.put("knowledgeId", record.getKnowledgeId()); // 使用前端传入的知识库ID
mataData.put("docName", record.getDocName());
mataData.put("docId", record.getDocId()); // 自动生成唯一文档ID
mataData.put("index", "0");
PGobject jsonObject = new PGobject();
jsonObject.setType("json");
jsonObject.setValue(mataData.toJSONString());
record.setMetadata(mataData);
Response<Embedding> embedding = aiModelUtils.getEmbedding(embedId, record.getText());
record.setEmbedding(embedding.content().vector());
return pgVectorMapper.update(record);
}
... ...
package org.jeecg.modules.airag.zdyrag.controller;
import cn.hutool.core.collection.CollectionUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.hankcs.hanlp.summary.TextRankKeyword;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
import io.swagger.v3.oas.annotations.Operation;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.ThreadContext;
import org.jeecg.ai.handler.AIParams;
import org.jeecg.ai.handler.LLMHandler;
import org.jeecg.common.api.vo.Result;
import org.jeecg.common.system.vo.LoginUser;
import org.jeecg.modules.airag.app.entity.AiragLog;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.service.IAiragLogService;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.jeecg.modules.airag.app.utils.FileToBase64Util;
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.File;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.*;
/**
* 直接回答llm
*/
@RestController
@RequestMapping("/airag/zdyRag")
@Slf4j
public class KeyRagController {
@Autowired
private EmbeddingHandler embeddingHandler;
@Autowired
IAIChatHandler aiChatHandler;
@Autowired
private IQuestionEmbeddingService questionEmbeddingService;
@Value("${jeecg.upload.path}")
private String uploadPath;
@Autowired
private IAiragLogService airagLogService;
// 用于异步处理的线程池
private final ExecutorService executor = Executors.newCachedThreadPool();
@Operation(summary = "sendStream1")
@GetMapping("sendStream1")
public SseEmitter sendStream(String questionText) throws Exception {
SseEmitter emitter = new SseEmitter(300000L);
// 创建日志对象
String modelId = "1926875898187878401";
AiragLog logRecord = new AiragLog()
.setQuestion(questionText)
.setModelId(modelId)
.setCreateTime(new Date());
executor.execute(() -> {
String knowId = "1926872137990148098";
try {
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 2, 0.78);
// 从知识库搜索
if (CollectionUtil.isEmpty(maps)) {
Map<String, String> data = new HashMap<>();
data.put("token", "该问题未记录在知识库中");
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
// 准备END事件数据
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
// 记录日志 - 未命中任何知识库
logRecord.setAnswer("该问题未记录在知识库中")
.setAnswerType(3)
.setIsStorage(0);
airagLogService.save(logRecord);
emitter.complete();
return;
}
// 构建知识库内容
StringBuilder content = new StringBuilder();
for (Map<String, Object> map : maps) {
if (Double.parseDouble(map.get("score").toString()) > 0.78) {
content.append(map.get("content").toString()).append("\n");
}
}
TextRankKeyword textRank = new TextRankKeyword();
List<String> keyWords = textRank.getKeywords(questionText, 5);
System.out.println("关键词...:" + keyWords);
// 获取第一个匹配的元数据用于日志和文件信息
Map<String, Object> firstMatch = maps.get(0);
String fileName = generateFileDocName(firstMatch.get("metadata").toString());
String storedFileName = generateFilePath(firstMatch.get("metadata").toString());
// 构建更优化的prompt
String prompt = String.format("你是一个严谨的信息处理助手,请严格按照以下要求处理用户问题:" + questionText + "\n\n" +
"处理步骤和要求:\n" +
"1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +
"2. 严格基于参考内容回答,禁止使用参考内容中与问题无关的内容\n" +
"3. 回答结构:\n" +
" - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +
" - 然后列出支持该答案的具体内容(可直接引用参考内容)\n" +
"4. 禁止以下行为:\n" +
" - 添加参考内容中不存在的信息\n" +
" - 在回答中提及‘参考内容’等字样\n" +
" - 在回答中提及其他产品的功能\n" +
" - 进行任何推测性陈述\n" +
" - 使用模糊或不确定的表达\n" +
" - 参考内容为空时应该拒绝回答\n" +
"参考内容(请严格限制回答范围于此):\n" + content);
List<ChatMessage> messages = new ArrayList<>();
messages.add(new UserMessage("user", prompt));
StringBuilder answerBuilder = new StringBuilder();
TokenStream tokenStream = aiChatHandler.chat(modelId, messages);
tokenStream.onNext(token -> {
try {
answerBuilder.append(token);
Map<String, String> data = new HashMap<>();
data.put("token", token);
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
} catch (Exception e) {
log.error("发送token失败", e);
}
});
tokenStream.onComplete(response -> {
try {
// 准备END事件数据
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
endData.put("similarity", firstMatch.get("score").toString());
endData.put("fileName", fileName);
endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + storedFileName));
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
// 记录日志 - 从知识库生成回答
logRecord.setAnswer(answerBuilder.toString())
.setAnswerType(2);
System.out.println("回答内容 = " + answerBuilder);
airagLogService.save(logRecord);
emitter.complete();
} catch (Exception e) {
log.error("流式响应结束时发生错误", e);
}
});
tokenStream.onError(error -> {
log.error("生成答案失败", error);
// 记录日志 - 错误情况
logRecord.setAnswer("生成答案失败: " + error.getMessage())
.setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(error);
});
tokenStream.start();
} catch (Exception e) {
log.error("处理请求时发生异常", e);
// 记录日志 - 异常情况
logRecord.setAnswer("处理请求时发生异常: " + e.getMessage())
.setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(e);
}
});
return emitter;
}
private String generateFilePath(String metadataJson) throws Exception {
if (StringUtils.isEmpty(metadataJson)) {
return "";
}
ObjectMapper objectMapper = new ObjectMapper();
// 解析JSON字符串
Map<String, String> metadata = objectMapper.readValue(metadataJson, Map.class);
// 获取docName和docId
return metadata.get("storedFileName");
}
private String generateFileDocName(String metadataJson) throws Exception {
if (StringUtils.isEmpty(metadataJson)) {
return "";
}
ObjectMapper objectMapper = new ObjectMapper();
// 解析JSON字符串
Map<String, String> metadata = objectMapper.readValue(metadataJson, Map.class);
return metadata.get("docName");
}
}
package org.jeecg.modules.airag.zdyrag.controller;
import cn.hutool.core.collection.CollectionUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.service.TokenStream;
import dev.langchain4j.store.embedding.pgvector.PgVectorEmbeddingStore;
import io.swagger.v3.oas.annotations.Operation;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.apache.shiro.SecurityUtils;
import org.apache.shiro.subject.Subject;
import org.apache.shiro.util.ThreadContext;
import org.jeecg.ai.handler.AIParams;
import org.jeecg.ai.handler.LLMHandler;
import org.jeecg.common.api.vo.Result;
import org.jeecg.common.system.vo.LoginUser;
import org.jeecg.modules.airag.app.entity.AiragLog;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.service.IAiragLogService;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.jeecg.modules.airag.app.utils.FileToBase64Util;
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
import org.jeecg.modules.airag.llm.service.IAiragKnowledgeService;
import org.jeecg.modules.airag.zdyrag.service.AiragResponseService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.File;
import java.io.IOException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.*;
/**
* 直接回答llm
*/
@RestController
@RequestMapping("/airag/zdyRag")
@Slf4j
public class ZdyRagController {
@Autowired
private EmbeddingHandler embeddingHandler;
@Autowired
IAIChatHandler aiChatHandler;
@Autowired
private IQuestionEmbeddingService questionEmbeddingService;
@Value("${jeecg.upload.path}")
private String uploadPath;
@Autowired
private IAiragLogService airagLogService;
// 用于异步处理的线程池
private final ExecutorService executor = Executors.newCachedThreadPool();
@Autowired
private AiragResponseService airagResponseService;
/**
* @author lixiang
* @param questionText 问题文本
* @param code 快捷按钮code
* @param codeType 提问方式,用于记录日志,区分输入框提问还是快捷方式
* @param user 提问人
* @return 以流式返回回答结果
*
* 1、将提问文本与问题库匹配,若匹配则回答预设回答结果
* 2、若问题库中无匹配预设问题,则查询知识库,将查询到的知识提供给llm模型,生成回答结果
* 3、回答时会将当初上传的资料以参考资料的形式进行返回,可进行预览
* 4、将本次的问答结果记录日志
* todo :增加产品推荐功能?
*/
@Operation(summary = "sendStream")
@GetMapping("sendStream")
public SseEmitter sendStream(@RequestParam("questionText") String questionText,
@RequestParam("code") String code,
@RequestParam("codeType") Integer codeType,
@RequestParam("user") String user
) throws Exception {
SseEmitter emitter = new SseEmitter(300000L);
// 创建日志对象
String modelId = "1926875898187878401";
AiragLog logRecord = new AiragLog()
.setQuestion(questionText)
.setCode(code)
.setCreateBy(user)
.setCodeType(codeType)
.setModelId(modelId)
.setCreateTime(new Date());
executor.execute(() -> {
try {
String knowId = "1926872137990148098";
List<QuestionEmbedding> questionEmbeddings = questionEmbeddingService.similaritySearchByQuestion(questionText, 1, 0.8);
// 如果从问题库中找到匹配
if (!questionEmbeddings.isEmpty()) {
QuestionEmbedding questionEmbedding = questionEmbeddings.get(0);
Map<String, String> data = new HashMap<>();
data.put("token", questionEmbedding.getAnswer());
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
// 解析元数据
ObjectMapper objectMapper = new ObjectMapper();
Map<String, String> metadata = objectMapper.readValue(questionEmbedding.getMetadata(), Map.class);
// 准备END事件数据
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
endData.put("similarity", String.valueOf(questionEmbedding.getSimilarity()));
if (metadata != null) {
String docName = metadata.get("docName");
endData.put("fileName", docName);
String fileName = generateFilePath(questionEmbedding.getMetadata());
if (StringUtils.isNotBlank(fileName)) {
endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + fileName));
}
}
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(endData)));
// 记录日志 - 从问题库匹配
logRecord.setAnswer(questionEmbedding.getAnswer())
.setAnswerType(1);
airagLogService.save(logRecord);
emitter.complete();
return;
}
// 从知识库搜索
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 2, 0.78);
if (CollectionUtil.isEmpty(maps)) {
Map<String, String> data = new HashMap<>();
data.put("token", "该问题未记录在知识库中");
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
// 准备END事件数据
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
// 记录日志 - 未命中任何知识库
logRecord.setAnswer("该问题未记录在知识库中")
.setAnswerType(3)
.setIsStorage(0);
airagLogService.save(logRecord);
emitter.complete();
return;
}
// 构建知识库内容
StringBuilder content = new StringBuilder();
for (Map<String, Object> map : maps) {
if (Double.parseDouble(map.get("score").toString()) > 0.78) {
content.append(map.get("content").toString()).append("\n");
}
}
// 获取第一个匹配的元数据用于日志和文件信息
Map<String, Object> firstMatch = maps.get(0);
String fileName = generateFileDocName(firstMatch.get("metadata").toString());
String storedFileName = generateFilePath(firstMatch.get("metadata").toString());
// 构建问题提示
String questin = "你是一个严谨的信息处理助手,请严格按照以下要求处理用户问题:" + questionText + "\n\n" +
"处理步骤和要求:\n" +
"1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +
"2. 严格基于参考内容回答,禁止使用参考内容中与问题无关的内容\n" +
"3. 回答结构:\n" +
" - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +
" - 然后列出支持该答案的具体内容(可直接引用参考内容)\n" +
"4. 禁止以下行为:\n" +
" - 添加参考内容中不存在的信息\n" +
" - 在回答中提及‘参考内容’等字样\n" +
" - 在回答中提及其他产品的功能\n" +
" - 进行任何推测性陈述\n" +
" - 使用模糊或不确定的表达\n" +
" - 参考内容为空时应该拒绝回答\n" +
"参考内容(请严格限制回答范围于此):\n" + content;
List<ChatMessage> messages = new ArrayList<>();
messages.add(new UserMessage("user", questin));
StringBuilder answerBuilder = new StringBuilder();
TokenStream tokenStream = aiChatHandler.chat(modelId, messages);
tokenStream.onNext(token -> {
try {
answerBuilder.append(token);
Map<String, String> data = new HashMap<>();
data.put("token", token);
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
} catch (Exception e) {
log.error("发送token失败", e);
}
});
tokenStream.onComplete(response -> {
try {
// 准备END事件数据
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
endData.put("similarity", firstMatch.get("score").toString());
endData.put("fileName", fileName);
endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + storedFileName));
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
// 记录日志 - 从知识库生成回答
logRecord.setAnswer(answerBuilder.toString())
.setAnswerType(2);
System.out.println("回答内容 = " + answerBuilder.toString());
airagLogService.save(logRecord);
emitter.complete();
} catch (Exception e) {
log.error("流式响应结束时发生错误", e);
}
});
tokenStream.onError(error -> {
log.error("生成答案失败", error);
// 记录日志 - 错误情况
logRecord.setAnswer("生成答案失败: " + error.getMessage())
.setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(error);
});
tokenStream.start();
} catch (Exception e) {
log.error("处理请求时发生异常", e);
// 记录日志 - 异常情况
logRecord.setAnswer("处理请求时发生异常: " + e.getMessage())
.setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(e);
}
});
return emitter;
@RequestParam("user") String user) {
return airagResponseService.handleStreamRequest(questionText, code, codeType, user);
}
@Operation(summary = "send")
@GetMapping("send")
public Result<Map<String, Object>> send(String questionText) throws Exception {
String knowId = "1926872137990148098";
String modelId = "1926875898187878401";
Integer topNumber = 1;
Double similarity = 0.8;
// 创建日志对象
AiragLog logRecord = new AiragLog()
.setQuestion(questionText)
.setModelId(modelId)
.setCreateTime(new Date());
HashMap<String, Object> resMap = new HashMap<>();
//根据问题相似度进行查询
List<QuestionEmbedding> questionEmbeddings = questionEmbeddingService.similaritySearchByQuestion(questionText, 1,0.8);
for (QuestionEmbedding questionEmbedding : questionEmbeddings) {
resMap.put("question", questionText);
resMap.put("answer", questionEmbedding.getAnswer());
resMap.put("similarity", questionEmbedding.getSimilarity());
ObjectMapper objectMapper = new ObjectMapper();
Map<String, String> metadata = objectMapper.readValue(questionEmbedding.getMetadata(), Map.class);
// 获取docName和docId
if (metadata != null) {
String docName = metadata.get("docName");
resMap.put("fileName", docName);
String fileName = generateFilePath(questionEmbedding.getMetadata());
if (StringUtils.isNotBlank(fileName)) {
resMap.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + fileName));
}
}
// 记录日志 - 从问题库匹配
logRecord.setAnswer(questionEmbedding.getAnswer());
logRecord.setAnswerType(1);
airagLogService.save(logRecord);
log.info("questionEmbedding.getMetadata() = " + questionEmbedding.getMetadata());
log.info("questionEmbedding.getQuestion() = " + questionEmbedding.getQuestion());
log.info("questionEmbedding.getAnswer() = " + questionEmbedding.getAnswer());
log.info("questionEmbedding.getSimilarity() = " + questionEmbedding.getSimilarity());
log.info("-------------------------------------------------------------");
}
//返回问题库命中的问题
if (!questionEmbeddings.isEmpty()) {
return Result.OK(resMap);
}
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 3, 0.75);
if (CollectionUtil.isEmpty(maps)) {
resMap.put("answer", "该问题未记录在知识库中");
// 记录日志 - 未命中任何知识库
logRecord.setAnswer("该问题未记录在知识库中");
logRecord.setAnswerType(3);
logRecord.setIsStorage(0);
airagLogService.save(logRecord);
return Result.OK(resMap);
}
StringBuilder content = new StringBuilder();
for (Map<String, Object> map : maps) {
if (Double.parseDouble(map.get("score").toString()) > 0.78){
log.info("score = " + map.get("score").toString());
log.info("content = " + map.get("content").toString());
content.append(map.get("content").toString()).append("\n");
}
}
List<ChatMessage> messages = new ArrayList<>();
String questin = "你是一个严格遵循指令的信息处理助手,请按照以下规范回答用户问题:\n\n" +
"# 处理规范\n" +
"1. 回答范围:\n" +
" - 仅使用提供的参考内容进行回答\n" +
" - 禁止任何超出参考内容的推断、想象或补充\n" +
" - 当参考内容为空或不相关时,必须拒绝回答\n\n" +
"2. 回答结构要求:\n" +
" - 首行必须用「回答:」开头,给出最直接的事实性回答\n" +
" - 后续每行以「•」开头列出支持证据,每条证据必须:\n" +
" * 直接引用参考内容\n" +
" * 标注具体出处位置(如段落编号/行号)\n" +
" * 保持原句完整性,不得改写\n\n" +
"3. 禁止事项:\n" +
" - 任何形式的推测(包括\"可能\"、\"应该\"等不确定表述)\n" +
" - 回答内容不得提出\"参考内容\"、\"证据\"等字样\n" +
" - 参考内容中未明确出现的数字、事实或结论\n" +
" - 总结性陈述或观点性表达\n" +
" - 多个信息点的合并表述\n\n" +
"4. 特殊情形处理:\n" +
" - 专业术语必须保持原文表述\n" +
" - 数据必须包含原始单位和精度\n\n" +
"# 当前任务\n" +
"问题:「" + questionText + "」\n\n" +
"参考内容(严格限制回答范围):\n" +
content;
messages.add(new UserMessage("user", questin));
String chat = aiChatHandler.completions(modelId, messages, null);
resMap.put("question", questionText);
resMap.put("answer", chat);
resMap.put("similarity", maps.get(0).get("score").toString());
String fileName = generateFileDocName(maps.get(0).get("metadata").toString());
String storedFileName = generateFilePath(maps.get(0).get("metadata").toString());
resMap.put("fileName", fileName);
resMap.put("fileBase64",FileToBase64Util.fileToBase64(uploadPath + storedFileName));
// 记录日志 - 从知识库生成回答
logRecord.setAnswer(chat);
logRecord.setAnswerType(2);
airagLogService.save(logRecord);
return Result.OK(resMap);
}
private String generateFilePath(String metadataJson) throws Exception {
if (StringUtils.isEmpty(metadataJson)) {
return "";
}
ObjectMapper objectMapper = new ObjectMapper();
// 解析JSON字符串
Map<String, String> metadata = objectMapper.readValue(metadataJson, Map.class);
// 获取docName和docId
return metadata.get("storedFileName");
}
private String generateFileDocName(String metadataJson) throws Exception {
if (StringUtils.isEmpty(metadataJson)) {
return "";
}
ObjectMapper objectMapper = new ObjectMapper();
// 解析JSON字符串
Map<String, String> metadata = objectMapper.readValue(metadataJson, Map.class);
return metadata.get("docName");
}
public static void main(String[] args) {
String s = "学生户口复印_efde055d-1207-4b6f-8d46-79eb557ca711.docx";
String s1 = StringUtils.substringBefore(s, ".");
log.info("s1 = " + s1);
String[] split = s.split("_");
for (String string : split) {
log.info("string = " + string);
}
}
}
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.zdyrag.controller;
import cn.hutool.core.collection.CollectionUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.service.TokenStream;
import io.swagger.v3.oas.annotations.Operation;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.jeecg.ai.handler.LLMHandler;
import org.jeecg.common.api.vo.Result;
import org.jeecg.modules.airag.app.entity.AiragLog;
import org.jeecg.modules.airag.app.service.IAiragLogService;
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
import org.jeecg.modules.airag.app.utils.FileToBase64Util;
import org.jeecg.modules.airag.zdyrag.helper.MultiTurnContextHelper;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.util.*;
import java.util.concurrent.*;
@Slf4j
@RestController
@RequestMapping("/airag/zdyRag")
public class ZdyRagMultiStageController {
@Autowired
private EmbeddingHandler embeddingHandler;
@Autowired
private IAIChatHandler aiChatHandler;
@Autowired
private IAiragLogService airagLogService;
@Autowired
private RedisTemplate<String, Object> redisTemplate;
@Value("${jeecg.upload.path}")
private String uploadPath;
private final ExecutorService executor = Executors.newCachedThreadPool();
private final ExecutorService asyncLLMExecutor = Executors.newFixedThreadPool(5);
@Operation(summary = "multiStageStream with Redis context")
@GetMapping("multiStageStream")
public SseEmitter multiStageStream(@RequestParam String questionText,
@RequestParam(required = false) String sessionId) throws Exception {
SseEmitter emitter = new SseEmitter(300000L);
String modelId = "1926875898187878401";
String knowId = "1926872137990148098";
AiragLog logRecord = new AiragLog()
.setQuestion(questionText)
.setModelId(modelId)
.setCreateTime(new Date());
executor.execute(() -> {
try {
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 5, 0.75);
// ========================== 知识库为空时,尝试使用历史上下文回答 ==========================
if (CollectionUtil.isEmpty(maps)) {
List<ChatMessage> historyContext = MultiTurnContextHelper.loadHistory(sessionId, redisTemplate);
if (!historyContext.isEmpty()) {
log.info("知识库为空,尝试使用历史上下文回答问题");
String prompt = MultiTurnContextHelper.buildPromptFromHistory(historyContext, questionText);
String answer = aiChatHandler.completions(modelId, List.of(new UserMessage("user", prompt)), null);
if (StringUtils.isBlank(answer) || MultiTurnContextHelper.containsRefusalKeywords(answer)) {
sendSimpleMessage(emitter, "该问题未记录在知识库或历史中,无法回答");
logRecord.setAnswer("该问题未记录在知识库或历史中,无法回答").setAnswerType(3).setIsStorage(0);
} else {
sendSimpleMessage(emitter, answer);
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
endData.put("similarity", "0.0");
endData.put("fileName", "历史上下文");
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
logRecord.setAnswer(answer).setAnswerType(2);
MultiTurnContextHelper.saveHistory(sessionId, redisTemplate, historyContext, questionText, answer);
}
airagLogService.save(logRecord);
emitter.complete();
return;
} else {
sendSimpleMessage(emitter, "该问题未记录在知识库中,且无历史内容可参考");
logRecord.setAnswer("该问题未记录在知识库中,且无历史内容可参考").setAnswerType(3).setIsStorage(0);
airagLogService.save(logRecord);
emitter.complete();
return;
}
}
// ========================== 多线程摘要生成 ==========================
List<Future<String>> summaryFutures = new ArrayList<>();
for (Map<String, Object> map : maps) {
String content = map.get("content").toString();
String summaryPrompt = buildSummaryPrompt(questionText, content);
summaryFutures.add(asyncLLMExecutor.submit(() ->
aiChatHandler.completions(modelId, List.of(new UserMessage("user", summaryPrompt)), null)
));
}
List<String> summaries = new ArrayList<>();
for (Future<String> future : summaryFutures) {
try {
String summary = future.get(15, TimeUnit.SECONDS);
if (StringUtils.isNotBlank(summary)) summaries.add(summary.trim());
} catch (Exception e) {
log.warn("摘要生成失败", e);
}
}
// ========================== 多线程候选答案生成 ==========================
List<Future<String>> answerFutures = new ArrayList<>();
for (String summary : summaries) {
String answerPrompt = buildAnswerPrompt(questionText, summary);
answerFutures.add(asyncLLMExecutor.submit(() ->
aiChatHandler.completions(modelId, List.of(new UserMessage("user", answerPrompt)), null)
));
}
List<String> candidateAnswers = new ArrayList<>();
for (Future<String> future : answerFutures) {
try {
String answer = future.get(15, TimeUnit.SECONDS);
if (StringUtils.isNotBlank(answer)) candidateAnswers.add(answer);
} catch (Exception e) {
log.warn("候选答案生成失败", e);
}
}
// ========================== 合并答案生成最终回答 ==========================
String mergePrompt = buildMergePrompt(questionText, candidateAnswers);
List<ChatMessage> mergeMessages = new ArrayList<>();
if (StringUtils.isNotBlank(sessionId)) {
Object cached = redisTemplate.opsForValue().get(MultiTurnContextHelper.redisKey(sessionId));
if (cached instanceof List) {
mergeMessages.addAll((List<ChatMessage>) cached);
}
}
mergeMessages.add(new UserMessage("user", mergePrompt));
StringBuilder answerBuilder = new StringBuilder();
Map<String, Object> firstMatch = maps.get(0);
String storedFileName = extractFieldFromMetadata(firstMatch.get("metadata"), "storedFileName");
String docName = extractFieldFromMetadata(firstMatch.get("metadata"), "docName");
String similarityScore = String.valueOf(firstMatch.get("score"));
TokenStream tokenStream = aiChatHandler.chat(modelId, mergeMessages);
tokenStream.onNext(token -> {
try {
answerBuilder.append(token);
Map<String, String> data = new HashMap<>();
data.put("token", token);
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
} catch (Exception e) {
log.error("发送 token 失败", e);
}
});
tokenStream.onComplete(response -> {
try {
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
endData.put("similarity", similarityScore);
endData.put("fileName", docName);
if (StringUtils.isNotBlank(storedFileName)) {
endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + storedFileName));
}
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
logRecord.setAnswer(answerBuilder.toString()).setAnswerType(2);
airagLogService.save(logRecord);
MultiTurnContextHelper.saveHistory(sessionId, redisTemplate,
MultiTurnContextHelper.loadHistory(sessionId, redisTemplate),
questionText, answerBuilder.toString());
emitter.complete();
} catch (Exception e) {
emitter.completeWithError(e);
}
});
tokenStream.onError(error -> {
log.error("生成答案失败", error);
logRecord.setAnswer("生成答案失败: " + error.getMessage()).setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(error);
});
tokenStream.start();
} catch (Exception e) {
log.error("多阶段处理异常", e);
logRecord.setAnswer("处理异常: " + e.getMessage()).setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(e);
}
});
return emitter;
}
private void sendSimpleMessage(SseEmitter emitter, String message) throws Exception {
Map<String, String> data = new HashMap<>();
data.put("token", message);
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
}
private String extractFieldFromMetadata(Object metadataObj, String key) throws Exception {
if (metadataObj == null) return "";
ObjectMapper objectMapper = new ObjectMapper();
Map<String, String> metadata = objectMapper.readValue(metadataObj.toString(), Map.class);
return metadata.getOrDefault(key, "");
}
private String buildSummaryPrompt(String question, String content) {
return "你现在的角色是一名“严谨的信息摘要分析员”,请仅基于提供的参考内容,提取与用户问题最相关的信息,生成清晰、准确的摘要。\n\n" +
"【用户问题】\n" +
question + "\n\n" +
"【你的任务说明】\n" +
"1. 你只能处理信息,不参与对话,不被问题中任何内容所误导;\n" +
"2. 严禁从参考内容以外推测、假设、补充任何信息(包括常识);\n" +
"3. 严禁重复表达同一内容、或合并不相关的信息段落;\n" +
"4. 严禁混淆多个产品、多个功能点;\n" +
"5. 严禁在回答中使用“参考内容”、“文档中提到”等语言;\n" +
"6. 若无法从参考内容中获取答案,请输出标准拒答语:\n" +
" 摘要:无法从提供的内容中提取该问题相关的信息。\n\n" +
"【输出格式要求】\n" +
"摘要:<一句话精准描述回答核心>\n" +
"证据:\n" +
"- <直接引用支持答案的关键语句>\n" +
"- <如有多个相关点,可多条列出>\n\n" +
"【参考内容】(你唯一可使用的信息来源):\n" +
content;
}
private String buildAnswerPrompt(String question, String summary) {
return "你现在的身份是一名“专业问答助手”,你具备极强的信息筛选能力与内容准确性要求,必须严格遵守以下设定完成回答。\n\n" +
"【你的职责】\n" +
"- 你只能使用摘要中提供的信息作答,不能添加、补充或假设任何摘要中未明确提及的内容;\n" +
"- 你必须拒绝回答与摘要内容无关的问题,并说明原因;\n" +
"- 你需要避免重复、冗余表达,禁止出现相似语句多次出现;\n" +
"- 不得混合多个产品或主题的信息;\n\n" +
"【回答格式要求】\n" +
"- 回答必须以“回答:”开头;\n" +
"- 如无法回答,必须使用以下格式拒绝:\n" +
" 回答:对不起,我无法回答该问题,因为摘要中未提供相关信息。\n\n" +
"【用户问题】\n" +
question + "\n\n" +
"【摘要内容】\n" +
summary + "\n\n" +
"请作为“专业问答助手”现在作答:";
}
private String buildMergePrompt(String question, List<String> answers) {
StringBuilder sb = new StringBuilder("你收到多个候选答案,请从中选择最准确且不交叉混淆产品信息的答案作为最终回答。\n\n");
sb.append("用户问题:").append(question).append("\n");
for (int i = 0; i < answers.size(); i++) {
sb.append("候选答案").append(i + 1).append(":\n").append(answers.get(i)).append("\n\n");
}
sb.append("请直接输出最佳答案,**禁止添加新信息**或跨产品混合。");
return sb.toString();
}
}
package org.jeecg.modules.airag.zdyrag.helper;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.data.redis.core.RedisTemplate;
import java.util.*;
import java.util.concurrent.TimeUnit;
@Slf4j
public class MultiTurnContextHelper {
private static final int MAX_CONTEXT_SIZE = 10;
private static final long CONTEXT_TTL_MILLIS = 30 * 60 * 1000; // 30分钟
public static String redisKey(String sessionId) {
return "chat:context:" + sessionId;
}
public static List<ChatMessage> loadHistory(String sessionId, RedisTemplate<String, Object> redisTemplate) {
if (StringUtils.isBlank(sessionId)) return new ArrayList<>();
Object cached = redisTemplate.opsForValue().get(redisKey(sessionId));
if (cached instanceof List) {
return new ArrayList<>((List<ChatMessage>) cached);
}
return new ArrayList<>();
}
public static String buildPromptFromHistory(List<ChatMessage> history, String currentQuestion) {
StringBuilder sb = new StringBuilder("你是一个对话助手,请根据以下历史对话内容回答用户当前问题:\n\n");
sb.append("限制要求:\n");
sb.append("1. 严格只能使用历史对话中明确提到的信息\n");
sb.append("2. 禁止任何基于常识或主观推断的补充\n");
sb.append("3. 若无法从历史内容中明确回答,应直接拒绝回答\n");
sb.append("4. 回答必须以“回答:”开头\n\n");
sb.append("历史对话如下(最多展示最近5轮):\n");
int count = 0;
for (int i = Math.max(0, history.size() - 10); i < history.size(); i++) {
ChatMessage msg = history.get(i);
if (msg instanceof UserMessage) {
sb.append("用户:").append(msg.text()).append("\n");
} else {
sb.append("助手:").append(msg.text()).append("\n");
}
count++;
if (count >= 10) break;
}
sb.append("\n当前用户问题:").append(currentQuestion).append("\n");
return sb.toString();
}
public static void saveHistory(String sessionId, RedisTemplate<String, Object> redisTemplate,
List<ChatMessage> history, String question, String answer) {
if (StringUtils.isBlank(sessionId)) return;
history.add(new UserMessage("user", question));
history.add(new UserMessage("assistant", answer));
if (history.size() > MAX_CONTEXT_SIZE) {
history = history.subList(history.size() - MAX_CONTEXT_SIZE, history.size());
}
redisTemplate.opsForValue().set(redisKey(sessionId), history, CONTEXT_TTL_MILLIS, TimeUnit.MILLISECONDS);
}
public static boolean containsRefusalKeywords(String answer) {
List<String> refusalKeywords = List.of("无法", "不知道", "未提及", "没有相关信息", "参考内容为空", "不能回答");
return refusalKeywords.stream().anyMatch(answer::contains);
}
}
package org.jeecg.modules.airag.zdyrag.service;
import org.jeecg.modules.airag.app.entity.AiragLog;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.List;
import java.util.Map;
/**
* Service interface for handling AI-RAG responses
*/
public interface AiragResponseService {
SseEmitter handleStreamRequest(String questionText, String code, Integer codeType, String user);
String generateFilePath(String metadataJson) throws IOException, IOException;
String generateFileDocName(String metadataJson) throws IOException;
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.zdyrag.service;
public interface ProductExtractor {
String extractProduct(String questionText);
}
... ...
package org.jeecg.modules.airag.zdyrag.service.impl;
import cn.hutool.core.collection.CollectionUtil;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.SystemMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.service.TokenStream;
import lombok.extern.log4j.Log4j2;
import org.apache.commons.lang3.StringUtils;
import org.jeecg.modules.airag.app.entity.AiragLog;
import org.jeecg.modules.airag.app.entity.QuestionEmbedding;
import org.jeecg.modules.airag.app.service.IAiragLogService;
import org.jeecg.modules.airag.app.service.IQuestionEmbeddingService;
import org.jeecg.modules.airag.app.utils.FileToBase64Util;
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
import org.jeecg.modules.airag.llm.handler.EmbeddingHandler;
import org.jeecg.modules.airag.zdyrag.service.AiragResponseService;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
import java.io.IOException;
import java.util.*;
import java.util.regex.Pattern;
@Service
@Log4j2
public class AiragResponseServiceImpl implements AiragResponseService {
@Autowired
private EmbeddingHandler embeddingHandler;
@Autowired
private IAIChatHandler aiChatHandler;
@Autowired
private IQuestionEmbeddingService questionEmbeddingService;
@Value("${jeecg.upload.path}")
private String uploadPath;
@Autowired
private IAiragLogService airagLogService;
@Override
public SseEmitter handleStreamRequest(String questionText, String code, Integer codeType, String user) {
SseEmitter emitter = new SseEmitter(300000L);
String modelId = "1926875898187878401";
AiragLog logRecord = createLogRecord(questionText, code, codeType, user, modelId);
String cleanedQuestionText = cleanQuestionText(questionText);
try {
// 处理问题库匹配
if (handleQuestionEmbeddingMatch(emitter, cleanedQuestionText, logRecord)) {
return emitter;
}
// 处理知识库搜索
handleKnowledgeBaseSearch(emitter, cleanedQuestionText, logRecord, modelId);
} catch (Exception e) {
handleError(emitter, logRecord, e);
}
return emitter;
}
@Override
public String generateFilePath(String metadataJson) throws IOException {
return extractMetadataValue(metadataJson, "storedFileName");
}
@Override
public String generateFileDocName(String metadataJson) throws IOException {
return extractMetadataValue(metadataJson, "docName");
}
private String extractMetadataValue(String metadataJson, String key) throws IOException {
if (StringUtils.isEmpty(metadataJson)) {
return "";
}
ObjectMapper objectMapper = new ObjectMapper();
Map<String, String> metadata = objectMapper.readValue(metadataJson, Map.class);
return metadata.get(key);
}
/**
* 创建日志对象
* @param questionText 问题原文本
* @param code 快捷按钮code
* @param codeType 提问方式,用于记录日志,区分输入框提问还是快捷方式
* @param user 提问人
* @param modelId 模型id
* @return 返回日志对象
*/
private AiragLog createLogRecord(String questionText, String code, Integer codeType, String user, String modelId) {
return new AiragLog()
.setQuestion(questionText)
.setCode(code)
.setCreateBy(user)
.setCodeType(codeType)
.setModelId(modelId)
.setCreateTime(new Date());
}
/**
* 匹配问题库
* @param emitter 流式返回
* @param questionText 问题原文本
* @param logRecord 日志对象
* @return 返回是否匹配成功
*/
private boolean handleQuestionEmbeddingMatch(SseEmitter emitter, String questionText, AiragLog logRecord) throws Exception {
List<QuestionEmbedding> questionEmbeddings = questionEmbeddingService.similaritySearchByQuestion(questionText, 1, 0.8);
if (questionEmbeddings.isEmpty()) {
return false;
}
QuestionEmbedding questionEmbedding = questionEmbeddings.get(0);
sendQuestionEmbeddingResponse(emitter, questionEmbedding);
logRecord.setAnswer(questionEmbedding.getAnswer()).setAnswerType(1);
airagLogService.save(logRecord);
return true;
}
/**
* 发送token
* @param emitter 流式返回
* @param questionEmbedding 问题向量
*/
private void sendQuestionEmbeddingResponse(SseEmitter emitter, QuestionEmbedding questionEmbedding) throws Exception {
ObjectMapper objectMapper = new ObjectMapper();
// 发送token
Map<String, String> data = new HashMap<>();
data.put("token", questionEmbedding.getAnswer());
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(data)));
// 发送END事件
Map<String, String> endData = createEndData(questionEmbedding.getMetadata(), String.valueOf(questionEmbedding.getSimilarity()));
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(endData)));
emitter.complete();
}
/**
* 知识库匹配
* @param emitter 流式返回
* @param questionText 问题原文本
* @param logRecord 日志对象
* @param modelId 模型id
*/
private void handleKnowledgeBaseSearch(SseEmitter emitter, String questionText, AiragLog logRecord, String modelId) throws Exception {
String knowId = "1926872137990148098";
List<Map<String, Object>> maps = embeddingHandler.searchEmbedding(knowId, questionText, 2, 0.78);
if (CollectionUtil.isEmpty(maps)) {
handleNoKnowledgeFound(emitter, logRecord);
return;
}
String content = buildKnowledgeContent(maps);
Map<String, Object> firstMatch = maps.get(0);
String fileName = generateFileDocName(firstMatch.get("metadata").toString());
String storedFileName = generateFilePath(firstMatch.get("metadata").toString());
String systemPrompt = createSystemPrompt();
String userPrompt = createUserPrompt(questionText, content);
processLLMResponse(emitter, logRecord, modelId, systemPrompt, userPrompt, firstMatch, fileName, storedFileName);
}
/**
* 构建参考内容
* @param maps 匹配到的参考内容
* @return 返回参考内容
*/
private String buildKnowledgeContent(List<Map<String, Object>> maps) {
StringBuilder content = new StringBuilder();
for (Map<String, Object> map : maps) {
if (Double.parseDouble(map.get("score").toString()) > 0.78) {
content.append(map.get("content").toString()).append("\n");
}
}
return content.toString();
}
/**
* 知识库中未记录
* @param emitter 流式返回
* @param logRecord 日志对象
*/
private void handleNoKnowledgeFound(SseEmitter emitter, AiragLog logRecord) throws Exception {
ObjectMapper objectMapper = new ObjectMapper();
Map<String, String> data = new HashMap<>();
data.put("token", "该问题未记录在知识库中");
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(data)));
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
emitter.send(SseEmitter.event().data(objectMapper.writeValueAsString(endData)));
logRecord.setAnswer("该问题未记录在知识库中")
.setAnswerType(3)
.setIsStorage(0);
airagLogService.save(logRecord);
emitter.complete();
}
/**
* 构建系统提示词
* @return 系统提示词
*/
private String createSystemPrompt() {
return "你是一个严谨的信息处理助手,必须严格遵守以下规则:\n" +
"1. 严格基于参考内容回答,禁止任何超出参考内容的推断或想象\n" +
"2. 严格基于参考内容回答,禁止使用参考内容中与问题无关的内容\n" +
"3. 回答结构:\n" +
" - 首先用一句话直接回答问题核心(仅限参考内容中明确包含的信息)\n" +
" - 然后列出支持该答案的具体内容(可直接引用参考内容)\n" +
"4. 禁止以下行为:\n" +
" - 添加参考内容中不存在的信息\n" +
" - 在回答中提及'参考内容'等字样\n" +
" - 在回答中提及其他产品的功能\n" +
" - 进行任何推测性陈述\n" +
" - 使用模糊或不确定的表达\n" +
" - 参考内容为空时应该拒绝回答";
}
/**
* 构建用户提示
* @param questionText 问题文本
* @param content 参考内容
* @return 用户提示词
*/
private String createUserPrompt(String questionText, String content) {
return "用户问题:\n" +
questionText + "\n\n" +
"参考内容(请严格限制回答范围于此):\n" +
content;
}
/**
* 对llm模型进行提问
* @param emitter 流式返回
* @param logRecord 日志对象
* @param modelId 模型id
* @param systemPrompt 系统提示词
* @param userPrompt 用户提示词
* @param firstMatch 最相似的数据
* @param fileName 文件名称
* @param storedFileName 本地存储文件名称
*/
private void processLLMResponse(SseEmitter emitter, AiragLog logRecord, String modelId,
String systemPrompt, String userPrompt,
Map<String, Object> firstMatch, String fileName, String storedFileName) {
StringBuilder answerBuilder = new StringBuilder();
List<ChatMessage> messages = new ArrayList<>();
messages.add(new SystemMessage(systemPrompt));
messages.add(new UserMessage(userPrompt));
TokenStream tokenStream = aiChatHandler.chat(modelId, messages);
tokenStream.onNext(token -> {
try {
answerBuilder.append(token);
Map<String, String> data = new HashMap<>();
data.put("token", token);
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(data)));
} catch (Exception e) {
log.error("发送token失败", e);
}
});
tokenStream.onComplete(response -> {
try {
Map<String, String> endData = createEndData(firstMatch.get("metadata").toString(),
firstMatch.get("score").toString());
endData.put("fileName", fileName);
endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + storedFileName));
emitter.send(SseEmitter.event().data(new ObjectMapper().writeValueAsString(endData)));
logRecord.setAnswer(answerBuilder.toString()).setAnswerType(2);
airagLogService.save(logRecord);
emitter.complete();
} catch (Exception e) {
log.error("流式响应结束时发生错误", e);
}
});
tokenStream.onError(error -> {
log.error("生成答案失败", error);
logRecord.setAnswer("生成答案失败: " + error.getMessage()).setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(error);
});
tokenStream.start();
}
/**
* 创建结束标志,发送结束token
* @param metadataJson 元数据
* @param similarity 相似度
*/
private Map<String, String> createEndData(String metadataJson, String similarity) throws IOException {
Map<String, String> endData = new HashMap<>();
endData.put("event", "END");
endData.put("similarity", similarity);
if (StringUtils.isNotBlank(metadataJson)) {
ObjectMapper objectMapper = new ObjectMapper();
Map<String, String> metadata = objectMapper.readValue(metadataJson, Map.class);
String docName = metadata.get("docName");
endData.put("fileName", docName);
String fileName = generateFilePath(metadataJson);
if (StringUtils.isNotBlank(fileName)) {
endData.put("fileBase64", FileToBase64Util.fileToBase64(uploadPath + fileName));
}
}
return endData;
}
/**
* 异常日志记录
* @param emitter 流式返回
* @param logRecord 日志记录
* @param e 异常处理
*/
private void handleError(SseEmitter emitter, AiragLog logRecord, Exception e) {
log.error("处理请求时发生异常", e);
logRecord.setAnswer("处理请求时发生异常: " + e.getMessage()).setAnswerType(4);
airagLogService.save(logRecord);
emitter.completeWithError(e);
}
/**
* 清理问题文本中的标点符号
* @param questionText 原始问题文本
* @return 清理后的文本(无标点符号)
*/
private String cleanQuestionText(String questionText) {
if (StringUtils.isBlank(questionText)) {
return questionText;
}
// 定义要移除的标点符号正则表达式
Pattern punctuationPattern = Pattern.compile("[,,.。??!!;;、]");
// 替换所有匹配的标点符号为空字符串
return punctuationPattern.matcher(questionText).replaceAll(" ");
}
}
\ No newline at end of file
... ...
package org.jeecg.modules.airag.zdyrag.service.impl;
import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.model.openai.OpenAiChatModel;
import org.jeecg.modules.airag.common.handler.IAIChatHandler;
import org.jeecg.modules.airag.zdyrag.service.ProductExtractor;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;
/**
* 用于实现产品推荐,该功能暂未实现
*/
@Component
public class ProductExtractorImpl implements ProductExtractor {
@Autowired
IAIChatHandler aiChatHandler;
@Override
public String extractProduct(String questionText) {
String modelId = "1926875898187878401";
String prompt =
"请从下列问题中提取涉及的产品名称,返回JSON格式,示例:\n" +
"{ \"products\": [\"产品A\", \"产品B\"] }\n" +
"如果没有产品,请返回:{ \"products\": [] }\n\n" +
"问题:" + questionText;
List<ChatMessage> messages = new ArrayList<>();
messages.add(new UserMessage("user", prompt));
return aiChatHandler.completions(modelId,messages);
}
}
... ...
... ... @@ -162,17 +162,20 @@ spring:
password: 1234
driver-class-name: com.mysql.cj.jdbc.Driver
# 多数据源配置
# pgvector:
# jdbc-url: jdbc:postgresql://192.168.100.103:5432/postgres
# username: postgres
# password: postgres
# driver-class-name: org.postgresql.Driver
pgvector:
url: jdbc:postgresql://192.168.100.104:5432/postgres
username: postgres
password: postgres
driver-class-name: org.postgresql.Driver
#redis 配置
redis:
database: 0
host: 127.0.0.1
port: 6379
password:
mybatis:
type-handlers-package: org.jeecg.modules.airag.app.handler
#mybatis plus 设置
mybatis-plus:
mapper-locations: classpath*:org/jeecg/**/xml/*Mapper.xml
... ...