-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathclassifier_gbt.cpp
More file actions
56 lines (51 loc) · 1.94 KB
/
classifier_gbt.cpp
File metadata and controls
56 lines (51 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
#include "classifier_gbt.h"
// Default constructor: installs the initial gradient-boosted-trees settings.
// These values stay in effect until setParameters() overrides them:
//   loss type         = gbt_type_idx[0]
//   weak learners     = 100
//   shrinkage         = 0.1
//   subsample portion = 0.2
//   max tree depth    = 2
//   surrogate splits  = disabled
ClassifierGBT::ClassifierGBT()
{
    // Per-tree settings.
    par_MaxDepth = 2;
    par_UseSurrogates = false;

    // Boosting settings.
    par_Type = gbt_type_idx[0];
    par_WeakCount = 100;
    par_Shrinkage = 0.1;
    par_SubsamplePortion = 0.2;
}
// Trains the gradient-boosted-trees model on a set of labelled 2-D points.
//
// data   - sample coordinates, handed to loadData(). loadData() presumably fills
//          pData (one row per sample, CV_ROW_SAMPLE layout) and lData (responses);
//          TODO confirm against the base class.
// labels - class label for each entry of `data`, also handed to loadData().
//
// Any previously trained model is discarded first. isTrainedFlag reflects
// whether cls.train() reported success, so a failed run no longer leaves the
// classifier claiming to be trained.
void ClassifierGBT::trainData(const std::vector<cv::Point> &data, const std::vector<int> &labels)
{
    cls.clear();
    // The old model is gone. Stay "untrained" unless training below succeeds,
    // including when loadData() or train() throws.
    isTrainedFlag = false;
    loadData(data, labels);

    // One type entry per feature column plus one trailing entry for the response.
    // Feature columns are ordered; the response column is marked categorical.
    cv::Mat var_types(1, pData.cols + 1, CV_8UC1, cv::Scalar(CV_VAR_ORDERED));
    var_types.at<uchar>(pData.cols) = CV_VAR_CATEGORICAL;

    cv::GradientBoostingTreeParams params;
    params.loss_function_type = par_Type;
    params.weak_count = par_WeakCount;
    params.shrinkage = par_Shrinkage;
    params.subsample_portion = par_SubsamplePortion;
    params.max_depth = par_MaxDepth;
    params.use_surrogates = par_UseSurrogates;

    // train() returns false when it could not build a model; propagate that
    // instead of unconditionally marking the classifier as trained.
    isTrainedFlag = cls.train(pData, CV_ROW_SAMPLE, lData, cv::Mat(), cv::Mat(), var_types, cv::Mat(), params);
}
// Predicts the class of the point (x, y) with the trained model.
// The coordinates are written, as floats, into elements 0 and 1 of the member
// testSample. testSample is presumably pre-allocated as a single 2-feature
// CV_32F sample elsewhere; TODO confirm. The model's floating-point response
// is rounded to the nearest integer label.
int ClassifierGBT::classify(int x, int y)
{
    const float coords[2] = { static_cast<float>(x), static_cast<float>(y) };
    for (int i = 0; i < 2; ++i)
    {
        testSample.at<float>(i) = coords[i];
    }

    const float prediction = cls.predict(testSample);
    return cvRound(prediction);
}
// Returns a one-line, human-readable summary of the current GBT parameters,
// for example "GBT{loss_fun_type=..., weak_count=100, ...}".
// NOTE(review): par_Type holds the value passed directly to OpenCV as
// loss_function_type, but here it is also used as an index into gbt_type_name.
// Confirm that gbt_type_name is indexed by that value and not by a UI/list
// position, otherwise this lookup may go out of range.
QString ClassifierGBT::toQString() const
{
    const QString pattern("GBT{loss_fun_type=%1, weak_count=%2, shrinkage=%3, subsample_portion=%4, max_depth=%5, use_surrogates=%6}");

    // Each arg() call replaces the lowest-numbered remaining placeholder.
    QString summary = pattern.arg(gbt_type_name[par_Type]);
    summary = summary.arg(par_WeakCount);
    summary = summary.arg(par_Shrinkage);
    summary = summary.arg(par_SubsamplePortion);
    summary = summary.arg(par_MaxDepth);
    summary = summary.arg(par_UseSurrogates);
    return summary;
}
void ClassifierGBT::setParameters(int type, int weakCount, float shrinkage, float subsamplePortion, int maxDepth, bool useSurrogates)
{
par_Type = type;
par_WeakCount = weakCount;
par_Shrinkage = shrinkage;
par_SubsamplePortion = subsamplePortion;
par_MaxDepth = maxDepth;
par_UseSurrogates = useSurrogates;
}