GraphQL.NET の DataLoader

GraphQL でネストしたデータを取得するクエリを発行するとき、 N+1 問題を回避するために、 ネストしたデータの先読み込みか遅延読み込みを行う必要がある。

先読み込みの場合は、クエリのリゾルバ内で Entity Framework Core の Include を使って実装できるけど、 せっかく取得しても使わず無駄になってしまうことがありそう。 使うかどうかわからないデータを先読みするのは非効率か。

まとめて遅延読み込みできれば効率的。 遅延読み込みの場合は、DataLoader という仕組みを使う。 GraphQL.NET は DataLoader の機能を提供しているので、追加のライブラリは不要。 早速サンプルを書いてみた。

using System;
using System.Collections.Generic;
using System.Linq;
using GraphQL;
using GraphQL.DataLoader;
using GraphQL.Server;
using GraphQL.Server.Ui.Playground;
using GraphQL.Types;
using GraphQL.Types.Relay.DataObjects;
using Microsoft.AspNetCore;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;

namespace DataLoaderSample
{
    public class Team
    {
        public int Id { get; set; }
        public string Name { get; set; }

        public IList<Player> Players { get; set; }
    }

    public class Player
    {
        public int Id { get; set; }
        public string Name { get; set; }
        public string Position { get; set; }
        public int Number { get; set; }
        public int TeamId { get; set; }

        public Team Team { get; set; }
    }

    public class ApplicationDbContext : DbContext
    {
        public ApplicationDbContext(DbContextOptions<ApplicationDbContext> options)
            : base(options)
        {
        }

        public DbSet<Team> Teams => Set<Team>();

        public DbSet<Player> Players => Set<Player>();

        protected override void OnModelCreating(ModelBuilder modelBuilder)
        {
            modelBuilder.Entity<Team>(e =>
            {
                e.Property(x => x.Id).IsRequired();
                e.Property(x => x.Name).IsRequired();
                e.HasKey(x => x.Id);
                e.HasMany(x => x.Players);
            });
            modelBuilder.Entity<Player>(e =>
            {
                e.Property(x => x.Id).IsRequired();
                e.Property(x => x.Name).IsRequired();
                e.Property(x => x.Number).IsRequired();
                e.Property(x => x.Position).IsRequired();
                e.Property(x => x.TeamId).IsRequired();
                e.HasKey(x => x.Id);
                e.HasOne(x => x.Team);
            });
        }

        // テスト用のデータを登録する
        public void EnsureSeedData()
        {
            if (!Teams.Any())
            {
                Teams.Add(new Team
                {
                    Name = "バルセロナ",
                    Players = new List<Player>
                    {
                        new Player
                        {
                            Name = "メッシ",
                            Number = 10,
                            Position = "FW",
                        },
                        new Player
                        {
                            Name = "スアレス",
                            Number = 9,
                            Position = "FW",
                        },
                    },
                });
                Teams.Add(new Team
                {
                    Name = "レアルマドリード",
                    Players = new List<Player>
                    {
                        new Player
                        {
                            Name = "ベイル",
                            Number = 11,
                            Position = "FW",
                        },
                        new Player
                        {
                            Name = "モドリッチ",
                            Number = 10,
                            Position = "MF",
                        },
                    },
                });

                SaveChanges();
            }
        }
    }

    public class TeamType : ObjectGraphType<Team>
    {
        readonly IDataLoaderContextAccessor accessor;

        public TeamType(IDataLoaderContextAccessor accessor)
            : base()
        {
            this.accessor = accessor;

            Field(x => x.Id);
            Field(x => x.Name);

            // チームに所属する選手を取得できる
            Connection<PlayerType>()
                .Name("players")
                .ResolveAsync(async context =>
                {
                    // チームの選手を取得する部分は DataLoader を使って
                    // まとめて取得できるようにする。
                    var dbContext = (ApplicationDbContext)context.UserContext;
                    var loader = accessor.Context.GetOrAddBatchLoader<int, Connection<Player>>(
                        loaderKey: "team_players",
                        fetchFunc: async (teamIdList) =>
                        {
                            // チームでグループ化し、各グループを Connection に変換
                            return await dbContext.Players
                                .Where(p => teamIdList.Contains(p.TeamId))
                                .OrderBy(p => p.TeamId)
                                .ThenBy(p => p.Id)
                                .GroupBy(p => p.TeamId)
                                .ToDictionaryAsync(
                                    keySelector: players => players.Key,
                                    elementSelector: players =>
                                    {
                                        var connection = new Connection<Player>
                                        {
                                            PageInfo = new PageInfo
                                            {
                                                StartCursor = players.FirstOrDefault()?.Id.ToString(),
                                                EndCursor = players.LastOrDefault()?.Id.ToString(),
                                            },
                                            Edges = players.Select(xx => new Edge<Player>
                                            {
                                                Cursor = xx.Id.ToString(),
                                                Node = xx,
                                            }).ToList(),
                                        };
                                        return connection;
                                    });
                        });
                    return await loader.LoadAsync(context.Source.Id);
                });
        }
    }

    public class PlayerType : ObjectGraphType<Player>
    {
        readonly IDataLoaderContextAccessor accessor;

        public PlayerType(IDataLoaderContextAccessor accessor)
            : base()
        {
            this.accessor = accessor;

            Field(x => x.Id);
            Field(x => x.Name);
            Field(x => x.Position);
            Field(x => x.Number);
            Field(x => x.TeamId);

            // 選手が所属するチームを取得できる
            FieldAsync<TeamType>(
                name: "team",
                resolve: async context =>
                {
                    // 選手が所属するチームを取得する部分も DataLoader を使って
                    // まとめて取得できるようにする。
                    var dbContext = (ApplicationDbContext)context.UserContext;
                    var loader = accessor.Context.GetOrAddBatchLoader<int, Team>(
                        loaderKey: "player_team",
                        fetchFunc: async (teamIdList) =>
                        {
                            return await dbContext.Teams
                                .Where(x => teamIdList.Contains(x.Id))
                                .ToDictionaryAsync(x => x.Id);
                        });
                    return await loader.LoadAsync(context.Source.TeamId);
                });
        }
    }

    // クエリを表す型
    public class SampleQuery : ObjectGraphType
    {
        public SampleQuery()
            : base()
        {
            Connection<TeamType>()
                .Name("teams")
                .ResolveAsync(async context =>
                {
                    // クエリのフィールドは DataLoader を使わなくてもまとめて取得できる
                    var dbContext = (ApplicationDbContext)context.UserContext;
                    var teams = await dbContext.Teams.ToListAsync();
                    return new Connection<Team>
                    {
                        PageInfo = new PageInfo
                        {
                            StartCursor = teams.FirstOrDefault()?.Id.ToString(),
                            EndCursor = teams.LastOrDefault()?.Id.ToString(),
                        },
                        Edges = teams.Select(x => new Edge<Team>
                        {
                            Cursor = x.Id.ToString(),
                            Node = x,
                        }).ToList(),
                    };
                });
            Connection<PlayerType>()
                .Name("players")
                .ResolveAsync(async context =>
                {
                    // クエリのフィールドは DataLoader を使わなくてもまとめて取得できる
                    var dbContext = (ApplicationDbContext)context.UserContext;
                    var players = await dbContext.Players.ToListAsync();
                    return new Connection<Player>
                    {
                        PageInfo = new PageInfo
                        {
                            StartCursor = players.FirstOrDefault()?.Id.ToString(),
                            EndCursor = players.LastOrDefault()?.Id.ToString(),
                        },
                        Edges = players.Select(x => new Edge<Player>
                        {
                            Cursor = x.Id.ToString(),
                            Node = x,
                        }).ToList(),
                    };
                });
        }
    }

    // サンプルの GraphQL スキーマ。
    public class SampleSchema : Schema
    {
        public SampleSchema(IDependencyResolver dependencyResolver)
            : base(dependencyResolver)
        {
            // 今回はクエリだけ。
            Query = dependencyResolver.Resolve<SampleQuery>();
        }
    }

    public class Startup
    {
        public Startup(IConfiguration configuration)
        {
            Configuration = configuration;
        }

        public IConfiguration Configuration { get; }

        public void ConfigureServices(IServiceCollection services)
        {
            services.AddDbContext<ApplicationDbContext>(options =>
                options.UseSqlServer(Configuration.GetConnectionString("DefaultConnection")));

            // GraphQL Server を使用する
            services.AddGraphQL()
                .AddUserContextBuilder(httpContext =>
                {
                    return httpContext.RequestServices
                        .GetService<ApplicationDbContext>();
                })
                .AddRelayGraphTypes() // ページネーションで使う型を DI コンテナに登録
                .AddDataLoader();  // DataLoader 関連の型を DI コンテナに登録

            // ↑でDependencyResolver を登録してくれないので、自前で登録する必要がある。
            // DefaultDependencyResolver はコンストラクタインジェクションに対応していないため、
            // FuncDependencyResolver を使わないといけない。
            services.AddSingleton<IDependencyResolver>(x => new FuncDependencyResolver(x.GetService));

            // 自前で定義した GraphQL 用の型を登録
            services.AddSingleton<TeamType>();
            services.AddSingleton<PlayerType>();
            services.AddSingleton<SampleQuery>();
            services.AddSingleton<SampleSchema>();
        }

        public void Configure(IApplicationBuilder app, IHostingEnvironment env, ApplicationDbContext dbContext)
        {
            // GraphQL Server を使う
            app.UseGraphQL<SampleSchema>("/graphql");

            // GraphQL Playground を使う
            app.UseGraphQLPlayground(new GraphQLPlaygroundOptions
            {
                Path = "/ui/playground",
            });

            // テスト用データベースが無ければ作る
            dbContext.Database.EnsureCreated();
            dbContext.EnsureSeedData();
        }
    }

    public class Program
    {
        public static void Main(string[] args)
        {
            CreateWebHostBuilder(args).Build().Run();
        }

        public static IWebHostBuilder CreateWebHostBuilder(string[] args) =>
            WebHost.CreateDefaultBuilder(args)
                .UseStartup<Startup>();
    }
}

Visual Studioブレークポイントを設定した状態で、 GraphQL Playground からネストしたデータを取得するクエリを実行すると、 ネストしたデータをまとめて取得できていることが確認できた。

f:id:griefworker:20190122100019p:plain