// NOTE(review): stray web-page banner ("代码拉取完成,页面将自动刷新" — "code pull
// complete, page will refresh automatically") was not part of the source;
// commented out so the translation unit compiles.
/*
* Copyright 2024 KylinSoft Co., Ltd.
*
* This program is free software: you can redistribute it and/or modify it under
* the terms of the GNU General Public License as published by the Free Software
* Foundation, either version 3 of the License, or (at your option) any later
* version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License along with
* this program. If not, see <https://www.gnu.org/licenses/>.
*/
#include "ondevicenlpengine.h"

#include <atomic>
#include <chrono>
#include <mutex>
#include <string>
#include <thread>

#include <jsoncpp/json/json.h>

#include "nlpserver/nlpserver.h"
namespace ai_engine::lm::nlp {
// Caches a pointer to the process-wide NlpServer singleton; the engine does
// not own the server and must never delete it.
OndeviceNlpEngine::OndeviceNlpEngine()
    : nlpServer_(&NlpServer::getInstance()) {}
// Default destructor: no owned resources beyond RAII members. NOTE(review):
// a detached watchdog thread from chat() may still capture `this` — confirm
// the engine outlives any pending timer before destruction.
OndeviceNlpEngine::~OndeviceNlpEngine() = default;
/// Stores the chat-result callback and, when a server connection exists,
/// registers a forwarding completion callback for the current session so that
/// raw server payloads flow through onChatResult().
void OndeviceNlpEngine::setChatResultCallback(
    nlp::ChatResultCallback callback) {
    chatResultCallback_ = std::move(callback);
    if (nlpServer_ == nullptr) {
        return;
    }
    // Forward every completion chunk into this engine for re-formatting.
    nlpServer_->setCompletionResultCallback(
        [this](const std::string &payload) { onChatResult(payload); },
        sessionId_);
}
/// The on-device backend fixes its own context size; the requested value is
/// deliberately ignored.
void OndeviceNlpEngine::setContextSize(int size) {
    static_cast<void>(size); // intentionally unused
}
// Intentional no-op: context lives server-side per session, so there is
// nothing for the engine to clear locally.
void OndeviceNlpEngine::clearContext() {}
/// Sends `message` to the local NLP server for completion.
///
/// Lazily creates the chat session on first use and (re)arms a watchdog that
/// unloads the model after five minutes without a newer chat request.
///
/// @param message user prompt forwarded to the completion endpoint.
/// @param error   filled with connection/initialization details on failure.
/// @return true when the completion request was accepted by the server.
bool OndeviceNlpEngine::chat(const std::string &message,
                             ai_engine::lm::EngineError &error) {
    chatStopped_ = false;
    if (nlpServer_ == nullptr) {
        error = {AiCapability::Nlp, EngineErrorCategory::Initialization,
                 (int)NlpEngineErrorCode::FailedToConnectServer,
                 "[OndeviceNlpEngine]: NlpServer is not initialized"};
        return false;
    }
    if (!inited_) {
        if (!initChatModule(error)) {
            return false;
        }
        // Re-register the stored callback against the freshly created session.
        setChatResultCallback(chatResultCallback_);
    }
    // Each chat bumps the generation. A watchdog that wakes (or times out)
    // and sees a newer generation must not release the model. This closes the
    // lost-wakeup race in the original code: notify_all() could fire before a
    // previously spawned watchdog reached wait_for(), so that watchdog later
    // timed out and unloaded the model while chats were still active.
    static std::atomic<unsigned long long> timerGeneration{0};
    const unsigned long long myGeneration = ++timerGeneration;
    releaseTimerCondition_.notify_all();
    std::thread watchdog([this, myGeneration]() {
        std::unique_lock<std::mutex> lock(releaseTimerMutex_);
        // 5 minutes without a newer chat => unload the model.
        const bool superseded = releaseTimerCondition_.wait_for(
            lock, std::chrono::seconds(5 * 60),
            [myGeneration] { return timerGeneration.load() != myGeneration; });
        if (!superseded) {
            releaseTimerCallback();
        }
    });
    watchdog.detach();
    return nlpServer_->completion(message, sessionId_, taskId_, slotId_, error);
}
/// Marks the current chat as stopped (onChatResult() will drop any further
/// chunks) and asks the server to cancel the in-flight completion task.
void OndeviceNlpEngine::stopChat() {
    chatStopped_ = true;
    if (nlpServer_ != nullptr) {
        nlpServer_->cancelCompletion(taskId_);
    }
}
/// Translates a raw server completion payload (JSON text) into the engine's
/// chat-result schema and forwards it to the registered callback.
/// Chunks arriving after stopChat(), or with no callback set, are dropped.
void OndeviceNlpEngine::onChatResult(const std::string &result) {
    if (chatResultCallback_ == nullptr || chatStopped_) {
        return;
    }
    Json::Value body;
    Json::Reader reader;
    if (!reader.parse(result, body)) {
        // Malformed payload: drop it silently.
        return;
    }
    // Remember which server slot handled this turn for follow-up requests.
    slotId_ = body["id_slot"].asUInt();
    // Map the server's field names onto the engine's result schema.
    body["is_end"] = body["stop"];
    body["result"] = body["content"];
    ChatResult res{body.toStyledString(), EngineError()};
    chatResultCallback_(res);
}
/// Creates the server-side chat session on first use.
///
/// @param error filled by initSession() (or the null-server guard) on failure.
/// @return true when the session exists (already initialized or created now).
bool OndeviceNlpEngine::initChatModule(ai_engine::lm::EngineError &error) {
    if (inited_) {
        return true;
    }
    if (nlpServer_ == nullptr) {
        error = {AiCapability::Nlp, EngineErrorCategory::Initialization,
                 (int)NlpEngineErrorCode::FailedToConnectServer,
                 "[OndeviceNlpEngine]: NlpServer is not initialized"};
        return false;
    }
    // Only mark the module initialized when session creation succeeds. The
    // original set inited_ = true unconditionally, so a failed initSession()
    // made every later chat() skip initialization and fail permanently.
    inited_ = nlpServer_->initSession(sessionId_, error);
    return inited_;
}
/// Tears down the server-side chat session and wakes any pending release
/// watchdog so it exits instead of firing later.
bool OndeviceNlpEngine::destroyChatModule(ai_engine::lm::EngineError &error) {
    static_cast<void>(error); // destroySession() reports no error details
    if (!inited_) {
        return true;
    }
    inited_ = false;
    releaseTimerCondition_.notify_all();
    return nlpServer_->destroySession(sessionId_);
}
// Watchdog action: unload the model by destroying the chat session once the
// idle timeout expires. The EngineError is intentionally discarded — the
// timer thread has no caller to report it to.
void OndeviceNlpEngine::releaseTimerCallback() {
    ai_engine::lm::EngineError error;
    destroyChatModule(error);
}
} // namespace ai_engine::lm::nlp
// NOTE(review): stray web-page moderation notice (content-review boilerplate
// from the hosting site) was not part of the source; commented out so the
// translation unit compiles.