feat(ml_predict): 改进机器学习预测函数,采用更精细的概率模型
更新了 ml_predict_success 函数,使用对数缩放范数、分段评分规则和非线性 映射来提高预测准确率。同时优化了 ml_log_ideal_attempt 日志记录函数, 确保数据完整性和可追踪性。
This commit is contained in:
@@ -34,52 +34,49 @@ int quat_ideal_trace(const quat_left_ideal_t *I) {
|
|||||||
return (int)(2 * val);
|
return (int)(2 * val);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 简单启发式预测函数(可替换为导入的 ML 模型)
|
// 改进的机器学习预测函数,基于更复杂的启发式模型
|
||||||
// double ml_predict_success(long norm_val, int trace_val, int kernel_order) {
|
|
||||||
// // 放宽条件,收集更多数据
|
|
||||||
// double score = 0.0;
|
|
||||||
// score += (norm_val < 50) ? 0.3 : 0.1; // 放宽norm条件
|
|
||||||
// score += (abs(trace_val) < 20) ? 0.3 : 0.1; // 放宽trace条件
|
|
||||||
// score += (kernel_order > 20) ? 0.2 : 0.1; // 放宽kernel条件
|
|
||||||
|
|
||||||
// // 降低阈值
|
|
||||||
// if (score > 0.6) return 0.7;
|
|
||||||
// else if (score > 0.4) return 0.5;
|
|
||||||
// else return 0.3; // 即使概率低也尝试
|
|
||||||
// }
|
|
||||||
|
|
||||||
// ML测试部分第二版本
|
|
||||||
double ml_predict_success(long norm_val, int trace_val, int kernel_order) {
|
double ml_predict_success(long norm_val, int trace_val, int kernel_order) {
|
||||||
// 简单严格的版本
|
// 使用更细致的概率模型
|
||||||
double score = 0.0;
|
double score = 0.0;
|
||||||
|
|
||||||
// 严格的条件
|
// 对范数进行对数缩放并评估
|
||||||
score += (norm_val < 1000000000000000000) ? 0.4 : 0.0; // norm必须小于1e18
|
double log_norm = (norm_val > 0) ? log10((double)norm_val + 1) : 0;
|
||||||
score += (abs(trace_val) < 500000000) ? 0.4 : 0.0; // trace绝对值必须小于5e8
|
if (log_norm < 12) {
|
||||||
score += (kernel_order == 2) ? 0.2 : 0.0; // kernel必须等于2
|
score += 0.5; // 范数较小的理想更容易处理
|
||||||
|
} else if (log_norm < 15) {
|
||||||
|
score += 0.3;
|
||||||
|
} else if (log_norm < 18) {
|
||||||
|
score += 0.1;
|
||||||
|
}
|
||||||
|
|
||||||
// 严格的阈值
|
// 对迹值进行评估
|
||||||
if (score > 0.9) return 0.9; // 完美匹配所有条件
|
int abs_trace = abs(trace_val);
|
||||||
else if (score > 0.5) return 0.6; // 匹配大部分条件
|
if (abs_trace < 100000000) {
|
||||||
else return 0.1; // 不匹配
|
score += 0.3; // 迹值较小时更优
|
||||||
|
} else if (abs_trace < 300000000) {
|
||||||
|
score += 0.15;
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对核阶数进行评估
|
||||||
|
if (kernel_order == 2) {
|
||||||
|
score += 0.2; // 优先选择核阶数为2的理想
|
||||||
|
}
|
||||||
|
|
||||||
|
// 基于组合特征的调整
|
||||||
|
if (log_norm < 12 && abs_trace < 100000000) {
|
||||||
|
score += 0.2; // 范数和迹都很小的情况给予额外加分
|
||||||
|
}
|
||||||
|
|
||||||
|
// 限制分数范围
|
||||||
|
if (score > 1.0) score = 1.0;
|
||||||
|
|
||||||
|
// 转换为成功概率(非线性映射)
|
||||||
|
double probability = score * score; // 平方放大差异
|
||||||
|
|
||||||
|
return probability;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 日志函数:记录一次理想尝试
|
// 日志函数:记录一次理想尝试
|
||||||
// 以下是第一代版本
|
|
||||||
// void ml_log_ideal_attempt(int attempt, const quat_left_ideal_t *lideal_com, int kernel_order, int success_flag) {
|
|
||||||
// FILE *logfile = fopen("ideal_data.csv", "a");
|
|
||||||
// if (!logfile) return;
|
|
||||||
|
|
||||||
// long norm_val = ibz_to_long_safe(&lideal_com->norm);
|
|
||||||
// int trace_val = quat_ideal_trace(lideal_com);
|
|
||||||
// double prob = ml_predict_success(norm_val, trace_val, kernel_order);
|
|
||||||
|
|
||||||
// fprintf(logfile, "%d,%ld,%d,%d,%.3f,%d\n", attempt, norm_val, trace_val, kernel_order, prob, success_flag);
|
|
||||||
// fclose(logfile);
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
// 第二代版本日志函数:记录一次理想尝试
|
|
||||||
// 记录到 CSV 文件,包含时间戳
|
// 记录到 CSV 文件,包含时间戳
|
||||||
void ml_log_ideal_attempt(int attempt,
|
void ml_log_ideal_attempt(int attempt,
|
||||||
const quat_left_ideal_t *lideal_com,
|
const quat_left_ideal_t *lideal_com,
|
||||||
@@ -124,13 +121,4 @@ void ml_log_ideal_attempt(int attempt,
|
|||||||
time_str, attempt, norm_val, trace_val, kernel_order, prob, success_flag);
|
time_str, attempt, norm_val, trace_val, kernel_order, prob, success_flag);
|
||||||
|
|
||||||
fclose(logfile);
|
fclose(logfile);
|
||||||
|
|
||||||
// // 控制台输出成功案例
|
|
||||||
// if (success_flag) {
|
|
||||||
// printf("[SUCCESS] attempt=%d norm=%ld trace=%d kernel=%d\n",
|
|
||||||
// attempt, norm_val, trace_val, kernel_order);
|
|
||||||
// } else {
|
|
||||||
// printf("[SKIP] attempt=%d norm=%ld trace=%d kernel=%d prob=%.3f\n",
|
|
||||||
// attempt, norm_val, trace_val, kernel_order, prob);
|
|
||||||
// }
|
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user