C# でゼロから Deep Learning を実装する挑戦の続き。 この挑戦では、『ゼロから作る Deep Learning』同様に、 手書き数字認識のニューラルネットワークを実装するので、 MNIST データセットを利用する。
MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges
MNIST データセットを読み込んで、 Math.NET Numerics の Vector に変換するコードを書いてみた。
using MathNet.Numerics.LinearAlgebra; using System; using System.IO; using System.Linq; namespace MnistSample { class Program { static void Main(string[] args) { if (args.Length != 2) { Console.WriteLine("MnistSample.exe <pixelsFilePath> <labelsFilePath>"); Console.WriteLine("Press enter to quit"); Console.ReadLine(); return; } var cancel = false; Console.CancelKeyPress += (sender, e) => { cancel = true; }; // MNIST データセットをロードする var images = MnistImage.Load(pixelFilePath: args[0], labelFilePath: args[1]); Console.WriteLine($"Image count : {images.Length}"); // 入力したインデックスにある画像の情報を表示する Console.WriteLine($"Press <Ctrl> + <C> to quit"); while (cancel == false) { Console.WriteLine($"Input index (0 ~ {images.Length - 1})"); var input = Console.ReadLine(); int index; if (int.TryParse(input, out index) && (0 <= index) && (index <= images.Length)) { var image = images[index]; Console.WriteLine($"label : {image.Label}"); Console.WriteLine(image.ToVector()); } } } } /// <summary> /// MNIST データセットの画像を表します。 /// </summary> public class MnistImage { /// <summary> /// 画像の高さを取得します。 /// </summary> public int Height { get; } /// <summary> /// 画像の幅を取得します。 /// </summary> public int Width { get; } /// <summary> /// MNIST 画像のピクセルを取得します。 /// </summary> public byte[][] Pixels { get; } /// <summary> /// 0 ~ 9 までのラベルを取得します。 /// </summary> public byte Label { get; } /// <summary> /// <see cref="MnistImage"/> クラスの新しいインスタンスを初期化します。 /// </summary> /// <param name="height">画像の高さ</param> /// <param name="width">画像の幅</param> /// <param name="pixels">画像を構成するピクセル</param> /// <param name="label">ラベル</param> public MnistImage(int height, int width, byte[][] pixels, byte label) { Height = height; Width = width; Label = label; Pixels = new byte[height][]; for (var i = 0; i < height; i++) { Pixels[i] = new byte[width]; } for (var i = 0; i < height; i++) { for (var j = 0; j < width; j++) { Pixels[i][j] = pixels[i][j]; } } } // Vector<T> と Matrix<T> は byte をサポートしていない /// <summary> /// ベクトルに変換します。 /// </summary> /// <returns>ベクトル</returns> public Vector<double> ToVector() { var flatten = Pixels.SelectMany(row => row) .Select(b => Convert.ToDouble(b)); var vector = Vector<double>.Build.DenseOfEnumerable(flatten); return vector; } /// <summary> /// 行列に変換します。 /// </summary> /// <returns>行列</returns> public Matrix<double> ToMatrix() { return Matrix<double>.Build.Dense(Height, Width, (row, col) => Pixels[row][col]); } /// <summary> /// MNIST データセットとラベルをロードします。 /// </summary> /// <param name="pixelFilePath">MNIST データセットのパス</param> /// <param name="labelFilePath">ラベルのパス</param> /// <returns><see cref="MnistImage"/> の配列</returns> public static MnistImage[] Load(string pixelFilePath, string labelFilePath) { using (var imageStream = File.OpenRead(pixelFilePath)) using (var labelStream = File.OpenRead(labelFilePath)) using (var imageReader = new BinaryReader(imageStream)) using (var labelReader = new BinaryReader(labelStream)) { int magic1 = imageReader.ReadInt32(); magic1 = ReverseBytes(magic1); // 画像の枚数を取得 int imageCount = imageReader.ReadInt32(); imageCount = ReverseBytes(imageCount); // 画像の高さを取得 int imageHeight = imageReader.ReadInt32(); imageHeight = ReverseBytes(imageHeight); // 画像の幅を取得 int imageWidth = imageReader.ReadInt32(); imageWidth = ReverseBytes(imageWidth); int magic2 = labelReader.ReadInt32(); magic2 = ReverseBytes(magic2); // ラベルの個数を取得 int labelCount = labelReader.ReadInt32(); labelCount = ReverseBytes(labelCount); // 読み込んだ1枚分の画像データを格納するバッファを作成 var pixels = new byte[imageHeight][]; for (var i = 0; i < pixels.Length; i++) { pixels[i] = new byte[imageWidth]; } // 読み込んだすべての MNIST 画像を格納する配列を作成 var result = new MnistImage[imageCount]; for (int di = 0; di < imageCount; di++) { for (int i = 0; i < imageHeight; i++) // get 28x28 pixel values { for (int j = 0; j < imageWidth; j++) { byte b = imageReader.ReadByte(); pixels[i][j] = b; } } // ラベルを取得 byte label = labelReader.ReadByte(); var image = new MnistImage(imageHeight, imageWidth, pixels, label); result[di] = image; } return result; } } /// <summary> /// 整数のビットを逆順にします。 /// </summary> /// <param name="value">整数</param> /// <returns>ビットを逆順にした整数</returns> public static int ReverseBytes(int value) { byte[] intAsBytes = BitConverter.GetBytes(value); Array.Reverse(intAsBytes); return BitConverter.ToInt32(intAsBytes, 0); } } }
読み込んだデータセットのうち、指定した位置の画像情報を出力できるようにしてある。 試しに 3 を入力してみた結果がこちら。