feat: 本项目基本完成
新增 rerankURL、vectorDBURL 配置,以及 rerankModel、retriveTopK、 rerankTopK 和 memory 等参数,以支持更完整的检索和重排序功能。
This commit is contained in:
@@ -1,7 +1,13 @@
|
||||
{
|
||||
"queryURL":"https://api.siliconflow.cn/v1/chat/completions",
|
||||
"embeddingURL":"https://api.siliconflow.cn/v1/embeddings",
|
||||
"rerankURL":"https://api.siliconflow.cn/v1/rerank",
|
||||
"vectorDBURL":"http://localhost:8000",
|
||||
"key":"sk-xsoegkpdvqlbsoodrnaygqycdvhplkyowivkzlszqfytpvti",
|
||||
"queryModel":"Qwen/Qwen3-8B",
|
||||
"embeddingModel":"Qwen/Qwen3-Embedding-0.6B"
|
||||
"embeddingModel":"Qwen/Qwen3-Embedding-0.6B",
|
||||
"rerankModel":"Qwen/Qwen3-Reranker-0.6B",
|
||||
"retriveTopK":20,
|
||||
"rerankTopK":10,
|
||||
"memory":true
|
||||
}
|
||||
File diff suppressed because one or more lines are too long
@@ -1,6 +1,6 @@
|
||||
// 本示例演示访问 DeepSeek 大模型
|
||||
package XDUMsgBot_cj
|
||||
import std.collection.{ArrayList, reduce}
|
||||
// import std.collection.{ArrayList, reduce}
|
||||
import std.io.StringReader
|
||||
import stdx.encoding.json.*
|
||||
import stdx.net.http.*
|
||||
@@ -27,7 +27,6 @@ class Embedding {
|
||||
"dimensions": 1024
|
||||
}
|
||||
'''
|
||||
println(content)
|
||||
let request = HttpRequestBuilder()
|
||||
.url(url)
|
||||
.header('Authorization', 'Bearer ${key}')
|
||||
@@ -40,7 +39,6 @@ class Embedding {
|
||||
|
||||
func parse(text: String) {
|
||||
let json = JsonValue.fromStr(text).asObject()
|
||||
println(json)
|
||||
let data = json.getFields()['data'].asArray()
|
||||
let embedding = data[0].asObject().getFields()['embedding'].asArray()
|
||||
return embedding
|
||||
|
||||
244
src/main.cj
244
src/main.cj
@@ -3,47 +3,12 @@ package XDUMsgBot_cj
|
||||
import std.fs.*
|
||||
import std.io.*
|
||||
import stdx.encoding.json.*
|
||||
import stdx.net.http.*
|
||||
// import stdx.net.tls.*
|
||||
|
||||
// class Config <: Serializable<Config> {
|
||||
// var url:String = ""
|
||||
// var key:String = ""
|
||||
// var embeddingModel:String = ""
|
||||
// var queryModel:String = ""
|
||||
// public func serialize():DataModel{
|
||||
// return DataModelStruct()
|
||||
// .add(field<String>("url",url))
|
||||
// .add(field<String>("key",key))
|
||||
// .add(field<String>("embeddingModel",embeddingModel))
|
||||
// .add(field<String>("queryModel",queryModel))
|
||||
// }
|
||||
// public static func deserialize(dm: DataModel): Config {
|
||||
// var dms = match (dm) {
|
||||
// case data: DataModelStruct => data
|
||||
// case _ => throw Exception("this data is not DataModelStruct")
|
||||
// }
|
||||
// var result = Config()
|
||||
// result.url = String.deserialize(dms.get("url"))
|
||||
// result.key = String.deserialize(dms.get("key"))
|
||||
// result.embeddingModel = String.deserialize(dms.get("embeddingModel"))
|
||||
// result.queryModel = String.deserialize(dms.get("queryModel"))
|
||||
// return result
|
||||
// }
|
||||
// }
|
||||
func prepare(){
|
||||
|
||||
|
||||
main() {
|
||||
|
||||
/* prepare and chunk data */
|
||||
// open and read config file
|
||||
let configPath:Path = Path("./config.json")
|
||||
if(!exists(configPath)){
|
||||
println("Error! config.json doesn't exist")
|
||||
return
|
||||
}
|
||||
let configFile:File = File(configPath,Read)
|
||||
let configBytes:Array<Byte> = readToEnd(configFile)
|
||||
configFile.close()
|
||||
let config = JsonValue.fromStr(String.fromUtf8(configBytes)).asObject()
|
||||
let config = getConfig()
|
||||
|
||||
// open and read data file
|
||||
let dataPath:Path = Path("./data/data.txt")
|
||||
@@ -59,63 +24,156 @@ main() {
|
||||
let dataArray = dataString.split("\r\n")
|
||||
|
||||
|
||||
/* embedding */
|
||||
/* embedding and store vector */
|
||||
let embeddingModel = Embedding(url:config.getFields()['embeddingURL'].asString().getValue(),
|
||||
key:config.getFields()['key'].asString().getValue(),
|
||||
model:config.getFields()['embeddingModel'].asString().getValue())
|
||||
// open and write vectors
|
||||
let dataEmbedPath:Path = Path("./data/data_embed.txt")
|
||||
if(!exists(dataEmbedPath)){
|
||||
println("Error! data/data_embed.txt doesn't exist")
|
||||
return
|
||||
}
|
||||
let dataEmbedFile:File = File(dataEmbedPath,Write)
|
||||
dataEmbedFile.setLength(0)
|
||||
var i = 0
|
||||
for(data in dataArray){
|
||||
|
||||
let vector = embeddingModel.embed(data).toString()
|
||||
dataEmbedFile.write(vector.toArray())
|
||||
i++
|
||||
println(i)
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
// // 使用 SiliconFlow 提供的服务接口
|
||||
// let robot = LLM(url: 'https://api.siliconflow.cn/v1/chat/completions',
|
||||
// // 如果示例自带的密钥失效,请自行注册,https://cloud.siliconflow.cn/account/ak
|
||||
// key: 'sk-xsoegkpdvqlbsoodrnaygqycdvhplkyowivkzlszqfytpvti',
|
||||
// model: 'Qwen/Qwen3-8B',
|
||||
// memory: true)
|
||||
|
||||
// // robot.preset('我会用林黛玉的风格回复哥哥的所有问题')
|
||||
// // robot.chats('介绍李白')
|
||||
// // println('\n----------\n')
|
||||
|
||||
// // robot.chats('他和安徽的不解情缘')
|
||||
// // println('\n----------\n')
|
||||
|
||||
// // robot.reset()
|
||||
// // robot.chat('你好') |> println
|
||||
// // robot.chat('却是荷池跳雨,散了真珠还聚') |> println
|
||||
|
||||
// while (true) {
|
||||
// let input = readln()
|
||||
// if (input.startsWith('风格#')) {
|
||||
// let style = input.trimStart('风格#')
|
||||
// robot.switchStyle(style)
|
||||
// } else {
|
||||
// let reply = robot.chats(input)
|
||||
// println(reply)
|
||||
// }
|
||||
// open and store vectors
|
||||
// let dataEmbedPath:Path = Path("./data/data_embed.txt")
|
||||
// if(!exists(dataEmbedPath)){
|
||||
// println("Error! data/data_embed.txt doesn't exist")
|
||||
// return
|
||||
// }
|
||||
// let dataEmbedFile:File = File(dataEmbedPath,Append)
|
||||
// dataEmbedFile.setLength(0)
|
||||
|
||||
|
||||
|
||||
for(data in dataArray){
|
||||
let vector = embeddingModel.embed(data).toString()
|
||||
|
||||
let client = ClientBuilder().build()
|
||||
let content = '''
|
||||
{ "embedding":${vector},
|
||||
"document":"${data}"
|
||||
}
|
||||
'''
|
||||
let request = HttpRequestBuilder()
|
||||
.url(config.getFields()['vectorDBURL'].asString().getValue()+"/store")
|
||||
.header('Content-Type', 'application/json')
|
||||
.body(content)
|
||||
.post()
|
||||
.build()
|
||||
let rsp = client.send(request)
|
||||
// read response
|
||||
let buf = Array<UInt8>(1024, repeat: 0)
|
||||
let len = rsp.body.read(buf)
|
||||
println(String.fromUtf8(buf.slice(0, len)))
|
||||
client.close()
|
||||
println("stored: ${data}")
|
||||
}
|
||||
}
|
||||
|
||||
func getKonwledge(input:String):String{
|
||||
|
||||
|
||||
/* retrive */
|
||||
// println("-------------------retrive:-----------------------")
|
||||
let config = getConfig()
|
||||
|
||||
let embeddingModel = Embedding(url:config.getFields()['embeddingURL'].asString().getValue(),
|
||||
key:config.getFields()['key'].asString().getValue(),
|
||||
model:config.getFields()['embeddingModel'].asString().getValue())
|
||||
let vector = embeddingModel.embed(input).toString()
|
||||
|
||||
let client = ClientBuilder().build()
|
||||
let content = '''
|
||||
{
|
||||
"query_embedding":${vector},
|
||||
"top_k":${config.getFields()['retriveTopK'].asInt().getValue()}
|
||||
}
|
||||
'''
|
||||
let request = HttpRequestBuilder()
|
||||
.url(config.getFields()['vectorDBURL'].asString().getValue()+"/query")
|
||||
.header('Content-Type', 'application/json')
|
||||
.body(content)
|
||||
.post()
|
||||
.build()
|
||||
let rsp = client.send(request)
|
||||
let output = StringReader(rsp.body).readToEnd()
|
||||
let json = JsonValue.fromStr(output).asObject().getFields()['documents'].asArray()
|
||||
client.close()
|
||||
|
||||
|
||||
|
||||
/* rerank */
|
||||
// println("-------------------rerank:-----------------------")
|
||||
let rerankModel = Rerank(url:config.getFields()['rerankURL'].asString().getValue(),
|
||||
key:config.getFields()['key'].asString().getValue(),
|
||||
model:config.getFields()['rerankModel'].asString().getValue())
|
||||
let result = rerankModel.rerank(input,json,config.getFields()['rerankTopK'].asInt().getValue())
|
||||
|
||||
|
||||
|
||||
/* query*/
|
||||
|
||||
let baseKnowledge = StringBuilder()
|
||||
|
||||
for(item in result){
|
||||
baseKnowledge.append("参考资料:")
|
||||
baseKnowledge.append(item)
|
||||
baseKnowledge.append("\r\n")
|
||||
}
|
||||
|
||||
return baseKnowledge.toString()
|
||||
|
||||
}
|
||||
|
||||
|
||||
main() {
|
||||
|
||||
// just need to run once
|
||||
// prepare()
|
||||
|
||||
|
||||
|
||||
let config = getConfig()
|
||||
let robot = Query(url: config.getFields()['queryURL'].asString().getValue(),
|
||||
key: config.getFields()['key'].asString().getValue(),
|
||||
model: config.getFields()['queryModel'].asString().getValue(),
|
||||
memory: config.getFields()['memory'].asBool().getValue())
|
||||
|
||||
let input = readln()
|
||||
let baseKnowledge = getKonwledge(input)
|
||||
robot.preset(input,baseKnowledge)
|
||||
let reply = robot.chats(input)
|
||||
println(reply)
|
||||
println("\n\n------------------回答结束,如果想聊聊新话题,可以输入“新对话#<新话题>”开始新的对话------------------\n\n")
|
||||
|
||||
while (true) {
|
||||
let input = readln()
|
||||
if (input.startsWith('新对话#')) {
|
||||
let prompt = input.trimStart('新对话#')
|
||||
let baseKnowledge = getKonwledge(prompt)
|
||||
robot.preset(prompt,baseKnowledge)
|
||||
println("--------新对话: 关于问题:${prompt}-------------")
|
||||
let reply = robot.chats(input)
|
||||
println(reply)
|
||||
println("\n\n------------------回答结束,如果想聊聊新话题,可以输入“新对话#<新话题>”开始新的对话------------------\n\n")
|
||||
|
||||
} else {
|
||||
let reply = robot.chats(input)
|
||||
println(reply)
|
||||
println("\n\n------------------回答结束,如果想聊聊新话题,可以输入“新对话#<新话题>”开始新的对话------------------\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
// 本示例演示访问 DeepSeek 大模型
|
||||
package XDUMsgBot_cj
|
||||
import std.collection.{ArrayList, reduce}
|
||||
import std.io.StringReader
|
||||
@@ -31,10 +30,10 @@ extend ArrayList<String> {
|
||||
}
|
||||
}
|
||||
|
||||
class LLM {
|
||||
class Query {
|
||||
let client: Client
|
||||
let history = ArrayList<String>()
|
||||
public LLM(let url!: String, let key!: String, let model!: String,
|
||||
public Query(let url!: String, let key!: String, let model!: String,
|
||||
var memory!: Bool = false) {
|
||||
var config = TlsClientConfig()
|
||||
config.verifyMode = TrustAll
|
||||
@@ -112,17 +111,22 @@ class LLM {
|
||||
}
|
||||
|
||||
// 角色预设或加载历史对话
|
||||
public func preset(content: String, role!: Role = System) {
|
||||
history.add(role, content)
|
||||
public func preset(query: String,baseKnowledge:String, role!: Role = System) {
|
||||
history.clear()
|
||||
history.add(role, """
|
||||
你是一位专业的知识助手,名字叫XDUMsgBot,你可以根据相关片段中的信息回答用户关于西安电子科技大学的问题。
|
||||
请根据用户的问题和下列片段生成准确的回应。
|
||||
|
||||
用户问题:${query}
|
||||
|
||||
相关片段:
|
||||
${baseKnowledge}
|
||||
|
||||
请基于上述内容作答,如果没有明确的信息可供参考,请回答不知道,不要编造信息。""")
|
||||
memory = true
|
||||
}
|
||||
|
||||
public func reset() {
|
||||
history.clear()
|
||||
}
|
||||
|
||||
public func switchStyle(styleName:String){
|
||||
reset()
|
||||
preset("用${styleName}的风格回复问题")
|
||||
}
|
||||
}
|
||||
56
src/rerank.cj
Normal file
56
src/rerank.cj
Normal file
@@ -0,0 +1,56 @@
|
||||
package XDUMsgBot_cj
|
||||
import std.collection.ArrayList
|
||||
import std.io.StringReader
|
||||
import stdx.encoding.json.*
|
||||
import stdx.net.http.*
|
||||
import stdx.net.tls.*
|
||||
|
||||
|
||||
class Rerank {
|
||||
let client: Client
|
||||
public Rerank(let url!: String, let key!: String, let model!: String) {
|
||||
var config = TlsClientConfig()
|
||||
config.verifyMode = TrustAll
|
||||
client = ClientBuilder()
|
||||
.tlsConfig(config)
|
||||
// AI 服务响应有时候比较慢,这里设置为无限等待
|
||||
.readTimeout(Duration.Max)
|
||||
.build()
|
||||
}
|
||||
|
||||
func send(input: String,documents:JsonArray,topk:Int) {
|
||||
let content = '''
|
||||
{ "model":"${model}",
|
||||
"query":"${input}",
|
||||
"documents": ${documents},
|
||||
"instruction": "Please rerank the documents based on the query.",
|
||||
"top_n": ${topk},
|
||||
"return_documents": true
|
||||
}
|
||||
'''
|
||||
let request = HttpRequestBuilder()
|
||||
.url(url)
|
||||
.header('Authorization', 'Bearer ${key}')
|
||||
.header('Content-Type', 'application/json')
|
||||
.body(content)
|
||||
.post()
|
||||
.build()
|
||||
client.send(request)
|
||||
}
|
||||
|
||||
|
||||
public func rerank(input: String,documents:JsonArray,topk:Int) {
|
||||
|
||||
let response = send(input,documents,topk)
|
||||
let output = StringReader(response.body).readToEnd()
|
||||
let resultArray = JsonValue.fromStr(output).asObject().getFields()["results"].asArray().getItems()
|
||||
let list = ArrayList<String>()
|
||||
for(result in resultArray){
|
||||
let jsonObject = result.asObject()
|
||||
let text = jsonObject.getFields()["document"].asObject().getFields()["text"].asString()
|
||||
list.add(text.toString())
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
}
|
||||
18
src/utils.cj
Normal file
18
src/utils.cj
Normal file
@@ -0,0 +1,18 @@
|
||||
package XDUMsgBot_cj
|
||||
|
||||
import std.fs.*
|
||||
import std.io.*
|
||||
import stdx.encoding.json.*
|
||||
|
||||
func getConfig():JsonObject {
|
||||
// open and read config file
|
||||
let configPath:Path = Path("./config.json")
|
||||
if(!exists(configPath)){
|
||||
println("Error! config.json doesn't exist")
|
||||
}
|
||||
let configFile:File = File(configPath,Read)
|
||||
let configBytes:Array<Byte> = readToEnd(configFile)
|
||||
configFile.close()
|
||||
let config = JsonValue.fromStr(String.fromUtf8(configBytes)).asObject()
|
||||
return config
|
||||
}
|
||||
Reference in New Issue
Block a user