feat: 本项目基本完成
新增 rerankURL、vectorDBURL 配置,以及 rerankModel、retriveTopK、 rerankTopK 和 memory 等参数,以支持更完整的检索和重排序功能。
This commit is contained in:
18
config.json
18
config.json
@@ -1,7 +1,13 @@
|
|||||||
{
|
{
|
||||||
"queryURL":"https://api.siliconflow.cn/v1/chat/completions",
|
"queryURL":"https://api.siliconflow.cn/v1/chat/completions",
|
||||||
"embeddingURL":"https://api.siliconflow.cn/v1/embeddings",
|
"embeddingURL":"https://api.siliconflow.cn/v1/embeddings",
|
||||||
"key":"sk-xsoegkpdvqlbsoodrnaygqycdvhplkyowivkzlszqfytpvti",
|
"rerankURL":"https://api.siliconflow.cn/v1/rerank",
|
||||||
"queryModel":"Qwen/Qwen3-8B",
|
"vectorDBURL":"http://localhost:8000",
|
||||||
"embeddingModel":"Qwen/Qwen3-Embedding-0.6B"
|
"key":"sk-xsoegkpdvqlbsoodrnaygqycdvhplkyowivkzlszqfytpvti",
|
||||||
|
"queryModel":"Qwen/Qwen3-8B",
|
||||||
|
"embeddingModel":"Qwen/Qwen3-Embedding-0.6B",
|
||||||
|
"rerankModel":"Qwen/Qwen3-Reranker-0.6B",
|
||||||
|
"retriveTopK":20,
|
||||||
|
"rerankTopK":10,
|
||||||
|
"memory":true
|
||||||
}
|
}
|
||||||
File diff suppressed because one or more lines are too long
108
src/embedding.cj
108
src/embedding.cj
@@ -1,56 +1,54 @@
|
|||||||
// 本示例演示访问 DeepSeek 大模型
|
// 本示例演示访问 DeepSeek 大模型
|
||||||
package XDUMsgBot_cj
|
package XDUMsgBot_cj
|
||||||
import std.collection.{ArrayList, reduce}
|
// import std.collection.{ArrayList, reduce}
|
||||||
import std.io.StringReader
|
import std.io.StringReader
|
||||||
import stdx.encoding.json.*
|
import stdx.encoding.json.*
|
||||||
import stdx.net.http.*
|
import stdx.net.http.*
|
||||||
import stdx.net.tls.*
|
import stdx.net.tls.*
|
||||||
|
|
||||||
|
|
||||||
class Embedding {
|
class Embedding {
|
||||||
let client: Client
|
let client: Client
|
||||||
public Embedding(let url!: String, let key!: String, let model!: String) {
|
public Embedding(let url!: String, let key!: String, let model!: String) {
|
||||||
var config = TlsClientConfig()
|
var config = TlsClientConfig()
|
||||||
config.verifyMode = TrustAll
|
config.verifyMode = TrustAll
|
||||||
client = ClientBuilder()
|
client = ClientBuilder()
|
||||||
.tlsConfig(config)
|
.tlsConfig(config)
|
||||||
// AI 服务响应有时候比较慢,这里设置为无限等待
|
// AI 服务响应有时候比较慢,这里设置为无限等待
|
||||||
.readTimeout(Duration.Max)
|
.readTimeout(Duration.Max)
|
||||||
.build()
|
.build()
|
||||||
}
|
}
|
||||||
|
|
||||||
func send(input: String) {
|
func send(input: String) {
|
||||||
let content = '''
|
let content = '''
|
||||||
{ "model":"${model}",
|
{ "model":"${model}",
|
||||||
"input":"${input}",
|
"input":"${input}",
|
||||||
"encoding_format": "float",
|
"encoding_format": "float",
|
||||||
"dimensions": 1024
|
"dimensions": 1024
|
||||||
}
|
}
|
||||||
'''
|
'''
|
||||||
println(content)
|
let request = HttpRequestBuilder()
|
||||||
let request = HttpRequestBuilder()
|
.url(url)
|
||||||
.url(url)
|
.header('Authorization', 'Bearer ${key}')
|
||||||
.header('Authorization', 'Bearer ${key}')
|
.header('Content-Type', 'application/json')
|
||||||
.header('Content-Type', 'application/json')
|
.body(content)
|
||||||
.body(content)
|
.post()
|
||||||
.post()
|
.build()
|
||||||
.build()
|
client.send(request)
|
||||||
client.send(request)
|
}
|
||||||
}
|
|
||||||
|
func parse(text: String) {
|
||||||
func parse(text: String) {
|
let json = JsonValue.fromStr(text).asObject()
|
||||||
let json = JsonValue.fromStr(text).asObject()
|
let data = json.getFields()['data'].asArray()
|
||||||
println(json)
|
let embedding = data[0].asObject().getFields()['embedding'].asArray()
|
||||||
let data = json.getFields()['data'].asArray()
|
return embedding
|
||||||
let embedding = data[0].asObject().getFields()['embedding'].asArray()
|
}
|
||||||
return embedding
|
|
||||||
}
|
|
||||||
|
public func embed(input: String) {
|
||||||
|
let response = send(input)
|
||||||
public func embed(input: String) {
|
let output = StringReader(response.body).readToEnd() |> parse
|
||||||
let response = send(input)
|
return output
|
||||||
let output = StringReader(response.body).readToEnd() |> parse
|
}
|
||||||
return output
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
}
|
||||||
244
src/main.cj
244
src/main.cj
@@ -3,47 +3,12 @@ package XDUMsgBot_cj
|
|||||||
import std.fs.*
|
import std.fs.*
|
||||||
import std.io.*
|
import std.io.*
|
||||||
import stdx.encoding.json.*
|
import stdx.encoding.json.*
|
||||||
|
import stdx.net.http.*
|
||||||
|
// import stdx.net.tls.*
|
||||||
|
|
||||||
// class Config <: Serializable<Config> {
|
func prepare(){
|
||||||
// var url:String = ""
|
|
||||||
// var key:String = ""
|
|
||||||
// var embeddingModel:String = ""
|
|
||||||
// var queryModel:String = ""
|
|
||||||
// public func serialize():DataModel{
|
|
||||||
// return DataModelStruct()
|
|
||||||
// .add(field<String>("url",url))
|
|
||||||
// .add(field<String>("key",key))
|
|
||||||
// .add(field<String>("embeddingModel",embeddingModel))
|
|
||||||
// .add(field<String>("queryModel",queryModel))
|
|
||||||
// }
|
|
||||||
// public static func deserialize(dm: DataModel): Config {
|
|
||||||
// var dms = match (dm) {
|
|
||||||
// case data: DataModelStruct => data
|
|
||||||
// case _ => throw Exception("this data is not DataModelStruct")
|
|
||||||
// }
|
|
||||||
// var result = Config()
|
|
||||||
// result.url = String.deserialize(dms.get("url"))
|
|
||||||
// result.key = String.deserialize(dms.get("key"))
|
|
||||||
// result.embeddingModel = String.deserialize(dms.get("embeddingModel"))
|
|
||||||
// result.queryModel = String.deserialize(dms.get("queryModel"))
|
|
||||||
// return result
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
let config = getConfig()
|
||||||
main() {
|
|
||||||
|
|
||||||
/* prepare and chunk data */
|
|
||||||
// open and read config file
|
|
||||||
let configPath:Path = Path("./config.json")
|
|
||||||
if(!exists(configPath)){
|
|
||||||
println("Error! config.json doesn't exist")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
let configFile:File = File(configPath,Read)
|
|
||||||
let configBytes:Array<Byte> = readToEnd(configFile)
|
|
||||||
configFile.close()
|
|
||||||
let config = JsonValue.fromStr(String.fromUtf8(configBytes)).asObject()
|
|
||||||
|
|
||||||
// open and read data file
|
// open and read data file
|
||||||
let dataPath:Path = Path("./data/data.txt")
|
let dataPath:Path = Path("./data/data.txt")
|
||||||
@@ -59,63 +24,156 @@ main() {
|
|||||||
let dataArray = dataString.split("\r\n")
|
let dataArray = dataString.split("\r\n")
|
||||||
|
|
||||||
|
|
||||||
/* embedding */
|
/* embedding and store vector */
|
||||||
let embeddingModel = Embedding(url:config.getFields()['embeddingURL'].asString().getValue(),
|
let embeddingModel = Embedding(url:config.getFields()['embeddingURL'].asString().getValue(),
|
||||||
key:config.getFields()['key'].asString().getValue(),
|
key:config.getFields()['key'].asString().getValue(),
|
||||||
model:config.getFields()['embeddingModel'].asString().getValue())
|
model:config.getFields()['embeddingModel'].asString().getValue())
|
||||||
// open and write vectors
|
|
||||||
let dataEmbedPath:Path = Path("./data/data_embed.txt")
|
|
||||||
if(!exists(dataEmbedPath)){
|
|
||||||
println("Error! data/data_embed.txt doesn't exist")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
let dataEmbedFile:File = File(dataEmbedPath,Write)
|
|
||||||
dataEmbedFile.setLength(0)
|
|
||||||
var i = 0
|
|
||||||
for(data in dataArray){
|
|
||||||
|
|
||||||
let vector = embeddingModel.embed(data).toString()
|
// open and store vectors
|
||||||
dataEmbedFile.write(vector.toArray())
|
// let dataEmbedPath:Path = Path("./data/data_embed.txt")
|
||||||
i++
|
// if(!exists(dataEmbedPath)){
|
||||||
println(i)
|
// println("Error! data/data_embed.txt doesn't exist")
|
||||||
}
|
// return
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
// // 使用 SiliconFlow 提供的服务接口
|
|
||||||
// let robot = LLM(url: 'https://api.siliconflow.cn/v1/chat/completions',
|
|
||||||
// // 如果示例自带的密钥失效,请自行注册,https://cloud.siliconflow.cn/account/ak
|
|
||||||
// key: 'sk-xsoegkpdvqlbsoodrnaygqycdvhplkyowivkzlszqfytpvti',
|
|
||||||
// model: 'Qwen/Qwen3-8B',
|
|
||||||
// memory: true)
|
|
||||||
|
|
||||||
// // robot.preset('我会用林黛玉的风格回复哥哥的所有问题')
|
|
||||||
// // robot.chats('介绍李白')
|
|
||||||
// // println('\n----------\n')
|
|
||||||
|
|
||||||
// // robot.chats('他和安徽的不解情缘')
|
|
||||||
// // println('\n----------\n')
|
|
||||||
|
|
||||||
// // robot.reset()
|
|
||||||
// // robot.chat('你好') |> println
|
|
||||||
// // robot.chat('却是荷池跳雨,散了真珠还聚') |> println
|
|
||||||
|
|
||||||
// while (true) {
|
|
||||||
// let input = readln()
|
|
||||||
// if (input.startsWith('风格#')) {
|
|
||||||
// let style = input.trimStart('风格#')
|
|
||||||
// robot.switchStyle(style)
|
|
||||||
// } else {
|
|
||||||
// let reply = robot.chats(input)
|
|
||||||
// println(reply)
|
|
||||||
// }
|
|
||||||
// }
|
// }
|
||||||
|
// let dataEmbedFile:File = File(dataEmbedPath,Append)
|
||||||
|
// dataEmbedFile.setLength(0)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
for(data in dataArray){
|
||||||
|
let vector = embeddingModel.embed(data).toString()
|
||||||
|
|
||||||
|
let client = ClientBuilder().build()
|
||||||
|
let content = '''
|
||||||
|
{ "embedding":${vector},
|
||||||
|
"document":"${data}"
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
let request = HttpRequestBuilder()
|
||||||
|
.url(config.getFields()['vectorDBURL'].asString().getValue()+"/store")
|
||||||
|
.header('Content-Type', 'application/json')
|
||||||
|
.body(content)
|
||||||
|
.post()
|
||||||
|
.build()
|
||||||
|
let rsp = client.send(request)
|
||||||
|
// read response
|
||||||
|
let buf = Array<UInt8>(1024, repeat: 0)
|
||||||
|
let len = rsp.body.read(buf)
|
||||||
|
println(String.fromUtf8(buf.slice(0, len)))
|
||||||
|
client.close()
|
||||||
|
println("stored: ${data}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getKonwledge(input:String):String{
|
||||||
|
|
||||||
|
|
||||||
|
/* retrive */
|
||||||
|
// println("-------------------retrive:-----------------------")
|
||||||
|
let config = getConfig()
|
||||||
|
|
||||||
|
let embeddingModel = Embedding(url:config.getFields()['embeddingURL'].asString().getValue(),
|
||||||
|
key:config.getFields()['key'].asString().getValue(),
|
||||||
|
model:config.getFields()['embeddingModel'].asString().getValue())
|
||||||
|
let vector = embeddingModel.embed(input).toString()
|
||||||
|
|
||||||
|
let client = ClientBuilder().build()
|
||||||
|
let content = '''
|
||||||
|
{
|
||||||
|
"query_embedding":${vector},
|
||||||
|
"top_k":${config.getFields()['retriveTopK'].asInt().getValue()}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
let request = HttpRequestBuilder()
|
||||||
|
.url(config.getFields()['vectorDBURL'].asString().getValue()+"/query")
|
||||||
|
.header('Content-Type', 'application/json')
|
||||||
|
.body(content)
|
||||||
|
.post()
|
||||||
|
.build()
|
||||||
|
let rsp = client.send(request)
|
||||||
|
let output = StringReader(rsp.body).readToEnd()
|
||||||
|
let json = JsonValue.fromStr(output).asObject().getFields()['documents'].asArray()
|
||||||
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* rerank */
|
||||||
|
// println("-------------------rerank:-----------------------")
|
||||||
|
let rerankModel = Rerank(url:config.getFields()['rerankURL'].asString().getValue(),
|
||||||
|
key:config.getFields()['key'].asString().getValue(),
|
||||||
|
model:config.getFields()['rerankModel'].asString().getValue())
|
||||||
|
let result = rerankModel.rerank(input,json,config.getFields()['rerankTopK'].asInt().getValue())
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/* query*/
|
||||||
|
|
||||||
|
let baseKnowledge = StringBuilder()
|
||||||
|
|
||||||
|
for(item in result){
|
||||||
|
baseKnowledge.append("参考资料:")
|
||||||
|
baseKnowledge.append(item)
|
||||||
|
baseKnowledge.append("\r\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return baseKnowledge.toString()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
main() {
|
||||||
|
|
||||||
|
// just need to run once
|
||||||
|
// prepare()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
let config = getConfig()
|
||||||
|
let robot = Query(url: config.getFields()['queryURL'].asString().getValue(),
|
||||||
|
key: config.getFields()['key'].asString().getValue(),
|
||||||
|
model: config.getFields()['queryModel'].asString().getValue(),
|
||||||
|
memory: config.getFields()['memory'].asBool().getValue())
|
||||||
|
|
||||||
|
let input = readln()
|
||||||
|
let baseKnowledge = getKonwledge(input)
|
||||||
|
robot.preset(input,baseKnowledge)
|
||||||
|
let reply = robot.chats(input)
|
||||||
|
println(reply)
|
||||||
|
println("\n\n------------------回答结束,如果想聊聊新话题,可以输入“新对话#<新话题>”开始新的对话------------------\n\n")
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
let input = readln()
|
||||||
|
if (input.startsWith('新对话#')) {
|
||||||
|
let prompt = input.trimStart('新对话#')
|
||||||
|
let baseKnowledge = getKonwledge(prompt)
|
||||||
|
robot.preset(prompt,baseKnowledge)
|
||||||
|
println("--------新对话: 关于问题:${prompt}-------------")
|
||||||
|
let reply = robot.chats(input)
|
||||||
|
println(reply)
|
||||||
|
println("\n\n------------------回答结束,如果想聊聊新话题,可以输入“新对话#<新话题>”开始新的对话------------------\n\n")
|
||||||
|
|
||||||
|
} else {
|
||||||
|
let reply = robot.chats(input)
|
||||||
|
println(reply)
|
||||||
|
println("\n\n------------------回答结束,如果想聊聊新话题,可以输入“新对话#<新话题>”开始新的对话------------------\n\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -1,128 +1,132 @@
|
|||||||
// 本示例演示访问 DeepSeek 大模型
|
package XDUMsgBot_cj
|
||||||
package XDUMsgBot_cj
|
import std.collection.{ArrayList, reduce}
|
||||||
import std.collection.{ArrayList, reduce}
|
import std.io.StringReader
|
||||||
import std.io.StringReader
|
import stdx.encoding.json.*
|
||||||
import stdx.encoding.json.*
|
import stdx.net.http.*
|
||||||
import stdx.net.http.*
|
import stdx.net.tls.*
|
||||||
import stdx.net.tls.*
|
|
||||||
|
// AI 对话中的三类角色
|
||||||
// AI 对话中的三类角色
|
enum Role <: ToString {
|
||||||
enum Role <: ToString {
|
I | AI | System
|
||||||
I | AI | System
|
public func toString() {
|
||||||
public func toString() {
|
match (this) {
|
||||||
match (this) {
|
case I => 'user'
|
||||||
case I => 'user'
|
case AI => 'assistant'
|
||||||
case AI => 'assistant'
|
case System => 'system'
|
||||||
case System => 'system'
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
// 用 ArrayList 记录历史对话,扩展两个工具函数
|
||||||
// 用 ArrayList 记录历史对话,扩展两个工具函数
|
extend ArrayList<String> {
|
||||||
extend ArrayList<String> {
|
func add(role: Role, content: String) {
|
||||||
func add(role: Role, content: String) {
|
'{"role":"${role}","content":${JsonString(content)}}' |> add
|
||||||
'{"role":"${role}","content":${JsonString(content)}}' |> add
|
}
|
||||||
}
|
|
||||||
|
func literal() {
|
||||||
func literal() {
|
(this |> reduce { a, b =>
|
||||||
(this |> reduce { a, b =>
|
a + ',' + b
|
||||||
a + ',' + b
|
}) ?? '' // ?? 相当于简化版的 getOrDefault
|
||||||
}) ?? '' // ?? 相当于简化版的 getOrDefault
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
class Query {
|
||||||
class LLM {
|
let client: Client
|
||||||
let client: Client
|
let history = ArrayList<String>()
|
||||||
let history = ArrayList<String>()
|
public Query(let url!: String, let key!: String, let model!: String,
|
||||||
public LLM(let url!: String, let key!: String, let model!: String,
|
var memory!: Bool = false) {
|
||||||
var memory!: Bool = false) {
|
var config = TlsClientConfig()
|
||||||
var config = TlsClientConfig()
|
config.verifyMode = TrustAll
|
||||||
config.verifyMode = TrustAll
|
client = ClientBuilder()
|
||||||
client = ClientBuilder()
|
.tlsConfig(config)
|
||||||
.tlsConfig(config)
|
// AI 服务响应有时候比较慢,这里设置为无限等待
|
||||||
// AI 服务响应有时候比较慢,这里设置为无限等待
|
.readTimeout(Duration.Max)
|
||||||
.readTimeout(Duration.Max)
|
.build()
|
||||||
.build()
|
}
|
||||||
}
|
|
||||||
|
func send(input: String, stream!: Bool = false) {
|
||||||
func send(input: String, stream!: Bool = false) {
|
if (!memory) {
|
||||||
if (!memory) {
|
history.clear()
|
||||||
history.clear()
|
}
|
||||||
}
|
history.add(I, input)
|
||||||
history.add(I, input)
|
let content = '''
|
||||||
let content = '''
|
{ "model":"${model}",
|
||||||
{ "model":"${model}",
|
"messages":[${history.literal()}],
|
||||||
"messages":[${history.literal()}],
|
"stream":${stream},
|
||||||
"stream":${stream},
|
"enable_thinking": false}'''
|
||||||
"enable_thinking": false}'''
|
|
||||||
|
let request = HttpRequestBuilder()
|
||||||
let request = HttpRequestBuilder()
|
.url(url)
|
||||||
.url(url)
|
.header('Authorization', 'Bearer ${key}')
|
||||||
.header('Authorization', 'Bearer ${key}')
|
.header('Content-Type', 'application/json')
|
||||||
.header('Content-Type', 'application/json')
|
.header('Accept', if (stream) {
|
||||||
.header('Accept', if (stream) {
|
'text/event-stream'
|
||||||
'text/event-stream'
|
} else {
|
||||||
} else {
|
'application/json'
|
||||||
'application/json'
|
})
|
||||||
})
|
.body(content)
|
||||||
.body(content)
|
.post()
|
||||||
.post()
|
.build()
|
||||||
.build()
|
client.send(request)
|
||||||
client.send(request)
|
}
|
||||||
}
|
|
||||||
|
func parse(text: String, stream!: Bool = false) {
|
||||||
func parse(text: String, stream!: Bool = false) {
|
let json = JsonValue.fromStr(text).asObject()
|
||||||
let json = JsonValue.fromStr(text).asObject()
|
let choices = json.getFields()['choices'].asArray()
|
||||||
let choices = json.getFields()['choices'].asArray()
|
// 流式和非流式情况下,这个字段名称不同
|
||||||
// 流式和非流式情况下,这个字段名称不同
|
let key = if (stream) { 'delta' } else { 'message' }
|
||||||
let key = if (stream) { 'delta' } else { 'message' }
|
let message = choices[0].asObject().getFields()[key].asObject()
|
||||||
let message = choices[0].asObject().getFields()[key].asObject()
|
let content = message.getFields()['content'].asString().getValue()
|
||||||
let content = message.getFields()['content'].asString().getValue()
|
return content
|
||||||
return content
|
}
|
||||||
}
|
|
||||||
|
// 流式对话
|
||||||
// 流式对话
|
public func chats(input: String, task!: (String) -> Unit = {o => print(o)}) {
|
||||||
public func chats(input: String, task!: (String) -> Unit = {o => print(o)}) {
|
let response = send(input, stream: true)
|
||||||
let response = send(input, stream: true)
|
let output = StringBuilder()
|
||||||
let output = StringBuilder()
|
let buffer = Array<Byte>(1024 * 8, repeat: 0)
|
||||||
let buffer = Array<Byte>(1024 * 8, repeat: 0)
|
var length = response.body.read(buffer)
|
||||||
var length = response.body.read(buffer)
|
while (length != 0) {
|
||||||
while (length != 0) {
|
let text = String.fromUtf8(buffer[..length])
|
||||||
let text = String.fromUtf8(buffer[..length])
|
const INDEX = 6
|
||||||
const INDEX = 6
|
for (line in text.split('\n', removeEmpty: true)) {
|
||||||
for (line in text.split('\n', removeEmpty: true)) {
|
if (line.size > INDEX && line[INDEX] == b'{') {
|
||||||
if (line.size > INDEX && line[INDEX] == b'{') {
|
let json = line[INDEX..line.size]
|
||||||
let json = line[INDEX..line.size]
|
let slice = parse(json, stream: true)
|
||||||
let slice = parse(json, stream: true)
|
output.append(slice)
|
||||||
output.append(slice)
|
task(slice)
|
||||||
task(slice)
|
}
|
||||||
}
|
}
|
||||||
}
|
length = response.body.read(buffer)
|
||||||
length = response.body.read(buffer)
|
}
|
||||||
}
|
history.add(AI, output.toString())
|
||||||
history.add(AI, output.toString())
|
}
|
||||||
}
|
|
||||||
|
// 非流式
|
||||||
// 非流式
|
public func chat(input: String) {
|
||||||
public func chat(input: String) {
|
let response = send(input)
|
||||||
let response = send(input)
|
let output = StringReader(response.body).readToEnd() |> parse
|
||||||
let output = StringReader(response.body).readToEnd() |> parse
|
history.add(AI, output)
|
||||||
history.add(AI, output)
|
return output
|
||||||
return output
|
}
|
||||||
}
|
|
||||||
|
// 角色预设或加载历史对话
|
||||||
// 角色预设或加载历史对话
|
public func preset(query: String,baseKnowledge:String, role!: Role = System) {
|
||||||
public func preset(content: String, role!: Role = System) {
|
history.clear()
|
||||||
history.add(role, content)
|
history.add(role, """
|
||||||
memory = true
|
你是一位专业的知识助手,名字叫XDUMsgBot,你可以根据相关片段中的信息回答用户关于西安电子科技大学的问题。
|
||||||
}
|
请根据用户的问题和下列片段生成准确的回应。
|
||||||
|
|
||||||
public func reset() {
|
用户问题:${query}
|
||||||
history.clear()
|
|
||||||
}
|
相关片段:
|
||||||
|
${baseKnowledge}
|
||||||
public func switchStyle(styleName:String){
|
|
||||||
reset()
|
请基于上述内容作答,如果没有明确的信息可供参考,请回答不知道,不要编造信息。""")
|
||||||
preset("用${styleName}的风格回复问题")
|
memory = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public func reset() {
|
||||||
|
history.clear()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
56
src/rerank.cj
Normal file
56
src/rerank.cj
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package XDUMsgBot_cj
|
||||||
|
import std.collection.ArrayList
|
||||||
|
import std.io.StringReader
|
||||||
|
import stdx.encoding.json.*
|
||||||
|
import stdx.net.http.*
|
||||||
|
import stdx.net.tls.*
|
||||||
|
|
||||||
|
|
||||||
|
class Rerank {
|
||||||
|
let client: Client
|
||||||
|
public Rerank(let url!: String, let key!: String, let model!: String) {
|
||||||
|
var config = TlsClientConfig()
|
||||||
|
config.verifyMode = TrustAll
|
||||||
|
client = ClientBuilder()
|
||||||
|
.tlsConfig(config)
|
||||||
|
// AI 服务响应有时候比较慢,这里设置为无限等待
|
||||||
|
.readTimeout(Duration.Max)
|
||||||
|
.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func send(input: String,documents:JsonArray,topk:Int) {
|
||||||
|
let content = '''
|
||||||
|
{ "model":"${model}",
|
||||||
|
"query":"${input}",
|
||||||
|
"documents": ${documents},
|
||||||
|
"instruction": "Please rerank the documents based on the query.",
|
||||||
|
"top_n": ${topk},
|
||||||
|
"return_documents": true
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
let request = HttpRequestBuilder()
|
||||||
|
.url(url)
|
||||||
|
.header('Authorization', 'Bearer ${key}')
|
||||||
|
.header('Content-Type', 'application/json')
|
||||||
|
.body(content)
|
||||||
|
.post()
|
||||||
|
.build()
|
||||||
|
client.send(request)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public func rerank(input: String,documents:JsonArray,topk:Int) {
|
||||||
|
|
||||||
|
let response = send(input,documents,topk)
|
||||||
|
let output = StringReader(response.body).readToEnd()
|
||||||
|
let resultArray = JsonValue.fromStr(output).asObject().getFields()["results"].asArray().getItems()
|
||||||
|
let list = ArrayList<String>()
|
||||||
|
for(result in resultArray){
|
||||||
|
let jsonObject = result.asObject()
|
||||||
|
let text = jsonObject.getFields()["document"].asObject().getFields()["text"].asString()
|
||||||
|
list.add(text.toString())
|
||||||
|
}
|
||||||
|
return list
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
18
src/utils.cj
Normal file
18
src/utils.cj
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
package XDUMsgBot_cj
|
||||||
|
|
||||||
|
import std.fs.*
|
||||||
|
import std.io.*
|
||||||
|
import stdx.encoding.json.*
|
||||||
|
|
||||||
|
func getConfig():JsonObject {
|
||||||
|
// open and read config file
|
||||||
|
let configPath:Path = Path("./config.json")
|
||||||
|
if(!exists(configPath)){
|
||||||
|
println("Error! config.json doesn't exist")
|
||||||
|
}
|
||||||
|
let configFile:File = File(configPath,Read)
|
||||||
|
let configBytes:Array<Byte> = readToEnd(configFile)
|
||||||
|
configFile.close()
|
||||||
|
let config = JsonValue.fromStr(String.fromUtf8(configBytes)).asObject()
|
||||||
|
return config
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user