// Program.cs
using System;
using System.IO;
using System.Linq;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Transforms;
using BetterConsoleTables;
namespace Mnist
{
/// <summary>
/// The Digit class represents one mnist digit: a numeric label plus a
/// fixed-size vector of 784 pixel values.
/// </summary>
class Digit
{
    /// <summary>
    /// The pixel values of the digit. The LoadColumn range lets the attribute-driven
    /// text loader fill this vector from CSV columns 1 through 784.
    /// </summary>
    [LoadColumn(1, 784)]
    [ColumnName("PixelValues")]
    [VectorType(784)]
    public float[] PixelValues;

    /// <summary>
    /// The digit label, read from CSV column 0.
    /// </summary>
    [LoadColumn(0)]
    public float Number;
}
/// <summary>
/// Output row produced by the trained model for a single digit.
/// </summary>
internal class DigitPrediction
{
    /// <summary>
    /// The values of the model's "Score" output column; the program displays
    /// entries 0 through 9 as the columns P0..P9.
    /// </summary>
    [ColumnName("Score")]
    public float[] Score;
}
/// <summary>
/// The main program class. Trains a multiclass classifier on the mnist training
/// file, evaluates it on the test file, and prints the per-class scores for a few
/// sample digits.
/// </summary>
class Program
{
    // number of score columns shown per digit (P0..P9)
    private const int ClassCount = 10;

    // zero-based row indices of the test digits whose predictions are displayed
    private static readonly int[] sampleIndices = { 5, 16, 28, 63, 129 };

    // filenames for data set
    private static readonly string trainDataPath = Path.Combine(Environment.CurrentDirectory, "mnist_train.csv");
    private static readonly string testDataPath = Path.Combine(Environment.CurrentDirectory, "mnist_test.csv");

    /// <summary>
    /// The main program entry point.
    /// </summary>
    /// <param name="args">The command line arguments.</param>
    static void Main(string[] args)
    {
        // create a machine learning context
        var context = new MLContext();

        // load data
        Console.WriteLine("Loading data....");
        var trainDataView = LoadData(context, trainDataPath);
        var testDataView = LoadData(context, testDataPath);

        // build a training pipeline
        var pipeline = BuildPipeline(context);

        // train the model
        Console.WriteLine("Training model....");
        var model = pipeline.Fit(trainDataView);

        // use the model to make predictions on the test data
        Console.WriteLine("Evaluating model....");
        var predictions = model.Transform(testDataView);

        // evaluate the predictions
        var metrics = context.MulticlassClassification.Evaluate(
            data: predictions,
            labelColumnName: "Number",
            scoreColumnName: "Score");

        // show evaluation metrics
        PrintMetrics(metrics);

        // show the per-class scores for the sample digits
        ShowSamplePredictions(context, model, testDataView);

        // ReadKey throws when stdin is redirected (e.g. piped input or CI),
        // so only wait for a key press on an interactive console
        if (!Console.IsInputRedirected)
        {
            Console.ReadKey();
        }
    }

    /// <summary>
    /// Loads one comma-separated mnist file with a header row: column 0 is the
    /// label, columns 1-784 are the pixel values.
    /// </summary>
    /// <param name="context">The machine learning context.</param>
    /// <param name="path">The path of the CSV file to load.</param>
    /// <returns>A data view over the file.</returns>
    private static IDataView LoadData(MLContext context, string path)
    {
        var columnDef = new TextLoader.Column[]
        {
            new TextLoader.Column(nameof(Digit.PixelValues), DataKind.Single, 1, 784),
            new TextLoader.Column("Number", DataKind.Single, 0)
        };

        return context.Data.LoadFromTextFile(
            path: path,
            columns: columnDef,
            hasHeader: true,
            separatorChar: ',');
    }

    /// <summary>
    /// Builds the training pipeline.
    /// </summary>
    /// <param name="context">The machine learning context.</param>
    /// <returns>The untrained estimator chain.</returns>
    private static IEstimator<ITransformer> BuildPipeline(MLContext context)
    {
        // step 1: map the number column to a key value and store in the label column
        return context.Transforms.Conversion.MapValueToKey(
                outputColumnName: "Label",
                inputColumnName: "Number",
                keyOrdinality: ValueToKeyMappingEstimator.KeyOrdinality.ByValue)

            // step 2: concatenate all feature columns
            .Append(context.Transforms.Concatenate(
                "Features",
                nameof(Digit.PixelValues)))

            // step 3: cache data to speed up training
            .AppendCacheCheckpoint(context)

            // step 4: train the model with SDCA
            .Append(context.MulticlassClassification.Trainers.SdcaMaximumEntropy(
                labelColumnName: "Label",
                featureColumnName: "Features"))

            // step 5: map the label key value back to a number
            .Append(context.Transforms.Conversion.MapKeyToValue(
                outputColumnName: "Number",
                inputColumnName: "Label"));
    }

    /// <summary>
    /// Writes the evaluation metrics to the console.
    /// </summary>
    /// <param name="metrics">The metrics returned by the evaluator.</param>
    private static void PrintMetrics(MulticlassClassificationMetrics metrics)
    {
        // "0.###" keeps the leading zero; "#.###" would print ".123" for 0.123
        // and an empty string for 0
        Console.WriteLine("Evaluation metrics");
        Console.WriteLine($" MicroAccuracy: {metrics.MicroAccuracy:0.###}");
        Console.WriteLine($" MacroAccuracy: {metrics.MacroAccuracy:0.###}");
        Console.WriteLine($" LogLoss: {metrics.LogLoss:0.###}");
        Console.WriteLine($" LogLossReduction: {metrics.LogLossReduction:0.###}");
        Console.WriteLine();
    }

    /// <summary>
    /// Predicts the sample test digits and prints a table with one row per digit
    /// and one score column per class.
    /// </summary>
    /// <param name="context">The machine learning context.</param>
    /// <param name="model">The trained model.</param>
    /// <param name="testDataView">The test data to take the sample digits from.</param>
    private static void ShowSamplePredictions(MLContext context, ITransformer model, IDataView testDataView)
    {
        // read only as many rows as the highest sample index requires instead
        // of materializing the whole test set
        var rowsNeeded = sampleIndices.Max() + 1;
        var digits = context.Data
            .CreateEnumerable<Digit>(testDataView, reuseRowObject: false)
            .Take(rowsNeeded)
            .ToArray();

        // skip sample indices that lie beyond the end of a short test file
        var testDigits = sampleIndices
            .Where(index => index < digits.Length)
            .Select(index => digits[index])
            .ToArray();

        // create a prediction engine
        var engine = context.Model.CreatePredictionEngine<Digit, DigitPrediction>(model);

        // set up a table to show the predictions
        var table = new Table(TableConfiguration.Unicode());
        table.AddColumn("Digit");
        for (var i = 0; i < ClassCount; i++)
        {
            table.AddColumn($"P{i}");
        }

        // predict each test digit
        foreach (var digit in testDigits)
        {
            var prediction = engine.Predict(digit);

            // first cell is the true digit, followed by one percentage per class
            var row = new object[ClassCount + 1];
            row[0] = digit.Number;
            for (var i = 0; i < ClassCount; i++)
            {
                row[i + 1] = prediction.Score[i].ToString("P2");
            }

            table.AddRow(row);
        }

        // show results
        Console.WriteLine(table.ToString());
    }
}
}