MNIST データセットを読み込んでベクトルに変換

C# でゼロから Deep Learning を実装する挑戦の続き。 この挑戦では、『ゼロから作る Deep Learning』同様に、 手書き数字認識のニューラルネットワークを実装するので、 MNIST データセットを利用する。

MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges

MNIST データセットを読み込んで、 Math.NET Numerics の Vector に変換するコードを書いてみた。

using MathNet.Numerics.LinearAlgebra;
using System;
using System.IO;
using System.Linq;

namespace MnistSample
{
    class Program
    {
        static void Main(string[] args)
        {
            if (args.Length != 2)
            {
                Console.WriteLine("MnistSample.exe <pixelsFilePath> <labelsFilePath>");
                Console.WriteLine("Press enter to quit");
                Console.ReadLine();
                return;
            }


            var cancel = false;
            Console.CancelKeyPress += (sender, e) =>
            {
                cancel = true;
            };

            // MNIST データセットをロードする
            var images = MnistImage.Load(pixelFilePath: args[0], labelFilePath: args[1]);
            Console.WriteLine($"Image count : {images.Length}");

            // 入力したインデックスにある画像の情報を表示する
            Console.WriteLine($"Press <Ctrl> + <C> to quit");
            while (cancel == false)
            {
                Console.WriteLine($"Input index (0 ~ {images.Length - 1})");

                var input = Console.ReadLine();
                int index;
                if (int.TryParse(input, out index) &&
                    (0 <= index) &&
                    (index <= images.Length))
                {
                    var image = images[index];
                    Console.WriteLine($"label : {image.Label}");
                    Console.WriteLine(image.ToVector());
                }
            }
        }
    }

    /// <summary>
    /// MNIST データセットの画像を表します。
    /// </summary>
    public class MnistImage
    {
        /// <summary>
        /// 画像の高さを取得します。
        /// </summary>
        public int Height { get; }

        /// <summary>
        /// 画像の幅を取得します。
        /// </summary>
        public int Width { get; }

        /// <summary>
        /// MNIST 画像のピクセルを取得します。
        /// </summary>
        public byte[][] Pixels { get; }

        /// <summary>
        /// 0 ~ 9 までのラベルを取得します。
        /// </summary>
        public byte Label { get; }

        /// <summary>
        /// <see cref="MnistImage"/> クラスの新しいインスタンスを初期化します。
        /// </summary>
        /// <param name="height">画像の高さ</param>
        /// <param name="width">画像の幅</param>
        /// <param name="pixels">画像を構成するピクセル</param>
        /// <param name="label">ラベル</param>
        public MnistImage(int height, int width, byte[][] pixels, byte label)
        {
            Height = height;
            Width = width;
            Label = label;
            Pixels = new byte[height][];
            for (var i = 0; i < height; i++)
            {
                Pixels[i] = new byte[width];
            }
            for (var i = 0; i < height; i++)
            {
                for (var j = 0; j < width; j++)
                {
                    Pixels[i][j] = pixels[i][j];
                }
            }
        }

        // Vector<T> と Matrix<T> は byte をサポートしていない

        /// <summary>
        /// ベクトルに変換します。
        /// </summary>
        /// <returns>ベクトル</returns>
        public Vector<double> ToVector()
        {
            var flatten = Pixels.SelectMany(row => row)
                .Select(b => Convert.ToDouble(b));
            var vector = Vector<double>.Build.DenseOfEnumerable(flatten);
            return vector;
        }

        /// <summary>
        /// 行列に変換します。
        /// </summary>
        /// <returns>行列</returns>
        public Matrix<double> ToMatrix()
        {
            return Matrix<double>.Build.Dense(Height, Width, (row, col) => Pixels[row][col]);
        }

        /// <summary>
        /// MNIST データセットとラベルをロードします。
        /// </summary>
        /// <param name="pixelFilePath">MNIST データセットのパス</param>
        /// <param name="labelFilePath">ラベルのパス</param>
        /// <returns><see cref="MnistImage"/> の配列</returns>
        public static MnistImage[] Load(string pixelFilePath, string labelFilePath)
        {
            using (var imageStream = File.OpenRead(pixelFilePath))
            using (var labelStream = File.OpenRead(labelFilePath))
            using (var imageReader = new BinaryReader(imageStream))
            using (var labelReader = new BinaryReader(labelStream))
            {
                int magic1 = imageReader.ReadInt32();
                magic1 = ReverseBytes(magic1);

                // 画像の枚数を取得
                int imageCount = imageReader.ReadInt32();
                imageCount = ReverseBytes(imageCount);

                // 画像の高さを取得
                int imageHeight = imageReader.ReadInt32();
                imageHeight = ReverseBytes(imageHeight);

                // 画像の幅を取得
                int imageWidth = imageReader.ReadInt32();
                imageWidth = ReverseBytes(imageWidth);

                int magic2 = labelReader.ReadInt32();
                magic2 = ReverseBytes(magic2);

                // ラベルの個数を取得
                int labelCount = labelReader.ReadInt32();
                labelCount = ReverseBytes(labelCount);

                // 読み込んだ1枚分の画像データを格納するバッファを作成
                var pixels = new byte[imageHeight][];
                for (var i = 0; i < pixels.Length; i++)
                {
                    pixels[i] = new byte[imageWidth];
                }

                // 読み込んだすべての MNIST 画像を格納する配列を作成
                var result = new MnistImage[imageCount];

                for (int di = 0; di < imageCount; di++)
                {
                    for (int i = 0; i < imageHeight; i++) // get 28x28 pixel values
                    {
                        for (int j = 0; j < imageWidth; j++)
                        {
                            byte b = imageReader.ReadByte();
                            pixels[i][j] = b;
                        }
                    }

                    // ラベルを取得
                    byte label = labelReader.ReadByte();

                    var image = new MnistImage(imageHeight, imageWidth, pixels, label);
                    result[di] = image;
                }

                return result;
            }
        }

        /// <summary>
        /// 整数のビットを逆順にします。
        /// </summary>
        /// <param name="value">整数</param>
        /// <returns>ビットを逆順にした整数</returns>
        public static int ReverseBytes(int value)
        {
            byte[] intAsBytes = BitConverter.GetBytes(value);
            Array.Reverse(intAsBytes);
            return BitConverter.ToInt32(intAsBytes, 0);
        }
    }
}

読み込んだデータセットのうち、指定した位置の画像情報を出力できるようにしてある。 試しに 3 を入力してみた結果がこちら。