ぷるぷるした直方体

元職業エンジニアの雑記です

slick-codegenの利用例と中身の説明

Scala Advent Calendar 2017の17日目の記事です。

qiita.com

前日は@yoshiyoshifujiiさんによる「インフラストラクチャ層からのエラーについて考える」でした。

qiita.com

背景

slick-codegenは、DBのスキーマから対応するコードを自動生成するプログラムです。 同じようなコードをひたすら書く作業を無くしてくれる、頼れるやつです。 Slickを使っている人であれば、一度は使ったことがあるかと思います。

そんな便利なslick-codegenは公式のドキュメントでも紹介されています(Schema Code Generation) が、あまり説明が無いため、実際どのようなカスタマイズができるのかが分かりにくかったりします。 また、サンプルのリポジトリが紹介されていますが、ここに無い使用方法も多く存在します。 APIドキュメントを読み解けば良いですが、もっと雑に何が出来るか知りたいところ。

ということで、いくつか使えそうなサンプルを掲載して、その中でカスタマイズ方法を説明しようと思います。 なお、本記事の末尾に全てのサンプルを組み合わせたコードを貼り付けますので、実際に使ってみたい方はそちらをご利用ください。

slick-codegeのコード

とは言え、結局コードを読んだ方がよく理解できることがしばしば。 slick-codegenは5つのScalaコードからなる小さなプログラムですので、分かりにくいところは直に読むと早いです。

github.com

全体像

イメージを湧きやすくするため、少しだけ全体像の説明を。

自分で作るカスタムジェネレーターは、SourceCodeGeneratorを継承して作ります。

class CustomGenerator(model: m.Model) extends SourceCodeGenerator(model) {
  ...
}

SourceCodeGeneratorは、AbstractSourceCodeGeneratorOutputHelpersを実装しています。 https://github.com/slick/slick/blob/v3.2.1/slick-codegen/src/main/scala/slick/codegen/SourceCodeGenerator.scala

class SourceCodeGenerator(model: m.Model) extends AbstractSourceCodeGenerator(model) with OutputHelpers{
  ...
}

さらに、AbstractSourceCodeGeneratorAbstractGeneratorを使っています。 StringGeneratorHelpersは同じファイルで定義されており、本筋とは外れるので無視します。 https://github.com/slick/slick/blob/v3.2.1/slick-codegen/src/main/scala/slick/codegen/AbstractSourceCodeGenerator.scala

abstract class AbstractSourceCodeGenerator(model: m.Model) extends AbstractGenerator[String,String,String](model) with StringGeneratorHelpers{
  ...
}

AbstractGeneratorはというと、GeneratorHelpersというヘルパーを使っているだけです。 https://github.com/slick/slick/blob/v3.2.1/slick-codegen/src/main/scala/slick/codegen/AbstractGenerator.scala

abstract class AbstractGenerator[Code,TermName,TypeName](model: m.Model) extends GeneratorHelpers[Code,TermName,TypeName]{ codegen =>
  ...
}

ということで、 自分のクラス ← SourceCodeGeneratorAbstractSourceCodeGeneratorAbstractGenerator という関係になっています。 このそれぞれのクラスは同名の各ファイルで定義されていますので、順番に各コードを追ってゆくと分かりやすいかと思います。

下記サンプルでも、対応するコードを示しながら説明してゆきます。

サンプル集

サンプルとして使うスキーマですが、適当に以下のようなSQLを用意してMySQLに流しました。

create table user(
   id INT NOT NULL AUTO_INCREMENT,
   name VARCHAR(255) NOT NULL,
   created_at datetime NOT NULL DEFAULT CURRENT_TIMESTAMP,
   PRIMARY KEY ( id )
);

create table item(
   id INT NOT NULL AUTO_INCREMENT,
   name VARCHAR(255) NOT NULL,
   user_id INT,
   FOREIGN KEY (user_id) REFERENCES user(id) ON DELETE CASCADE,
   PRIMARY KEY ( id )
);

java.sql.timeをorg.joda.time.DateTimeに変換する

いきなり自前のコードではなく恐縮ですが、最初のサンプルは参考文献から。 https://qiita.com/uedashuhei/items/25d5a6e786075729d3b3 こちらの例では、以下の処理が行われています。

  • importを追加する (code)
  • Auto Incrementの値をOption型にする (Table.autoIncLastAsOption。なお、後述のサンプルコードではTable.autoIncLastColumn.asOptionを使用。こちらの記事でdeprecatedなことを知りました
  • カラムの型を変換する (Column.rawType)
  • 特定のカラムを名前で判別し除外する (modelに対する処理)

java.sql.Timestamporg.joda.time.DateTimeにしています。便利ですね。

デフォルト値を付ける

createdAtupdatedAtを作成時にいちいち入れるのはなかなか手間です。 先の例のようにまるまる消してしまう方法もありますが、それらの値を取得したい時に困ってしまいます。 そこで、下記のようなデフォルト値を付けることで便利に扱えるようにします。

case class UserRow(id: Int, name: String, createdAt: org.joda.time.DateTime = DateTime.now)

これにより、必要なパラメーターだけを指定して作成ができます。

UserRow(1, "name")

slick-codegenでは、AbstractGenerator.ColumnDef.defaultをオーバーライドすることで実現できます。 デフォルト値がない場合はNone、ある場合はSomeで返します。

override def default = rawName match {
  case "createdAt" | "updatedAt" => Some("DateTime.now")
  case _ => super.default
}

JSONへの変換を自動で行う

テスト時や単純なAPIでは、下記のようにxxxRowをそのままJSONに変換して返したい時があります。

def someAction = Action { ...
   // DBからユーザーを取り出し
   val user: UserRow = ...
   Ok(Json.toJson(user)) // これ
}

これには、UserRowのコンパニオンオブジェクトで下記のimplicitな変換を定義する必要があります。

object UserRow {
  implicit val jsonWrites = Json.writes[Entity]
}

一つずつ作るのは大変なので、これも自動生成させます。 注意点として、Rowのコンパニオンオブジェクトを作ると UserRow.tupledがコンパイルエラーになり、(UserRow.apply _).tupledにしないと動かなくなるので、その対策も必要です。

xxxRowはどうやって作られているのか

さて、slick-codegenでxxzRowの下に同名のコンパニオンオブジェクトを作るにはどうしたらよいでしょうか。 そのためにも、xxxRowが何かを知る必要があります。

まず、1つのテーブルに対応するコードはTableDef型で表されます。 TableDef.codeを見ると、完成したコードが入っているわけですね。

xxxRowはEntityと呼ばれ、EntityTypeDef型の情報により生成されます。 このEntityTypeDefTableDef.definitionsで並べられ、TableDef.codeで実際にコードになります。

def definitions = Seq[Def]( EntityType, PlainSqlMapper, TableClass, TableValue )
def code: Seq[Code] = definitions.flatMap(_.getEnabled).map(_.docWithCode)

ちなみに、EntityTypeEntityType型のインスタンスを返す関数です。

Defを作る

上記から、Def型で作ってdefinitionsに並べてあげれば、コンパニオンオブジェクトを追記することができそうです。 クラス名はEntityと同じなので、EntityTypeDefを流用してしまいましょう。 今回は中身が決まりきっているため、codeだけoverrideすれば目的を達成できます。

class EntityCompanionDef extends EntityTypeDef {
  override def doc = ""
  override def code =
    s"""object ${rawName} {
       |  import play.api.libs.json._
       |  import play.api.libs.json.JodaWrites
       |  import play.api.libs.json.JodaReads
       |
       |  val dateFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZ"
       |
       |  implicit val dateTimeWriter = JodaWrites.jodaDateWrites(dateFormat)
       |  implicit val dateTimeJsReader = JodaReads.jodaDateReads(dateFormat)
       |  implicit val jsonWrites = Json.writes[${rawName}]
       |  implicit val jsonReads = Json.reads[${rawName}]
       |}
     """.stripMargin
}

override def definitions = {
  val companion = new EntityCompanionDef
  Seq[Def](EntityType, companion, PlainSqlMapper, TableClass, TableValue)
}

ついでにJodatimeの文字列変換も入れています。 なお、docは不要なので空文字にしています。

これを実行すると、コンパニオンオブジェクトが追記された定義が出力されます。

ただし、前述したとおり、このままではコンパイルが通らないので、TableDef.codeに置換処理を追記します。

override def code = {
  super.code.toList.map(_.replaceAll(s"${EntityType.name}.tupled", s"(${EntityType.name}.apply _).tupled"))
}

独自のID型を利用する

さらに発展系として、各リソースのID型を区別し、取り違えないようにしたいと思います。 Entityはidの型がOption[Int]からOption[UserId]となります。

case class UserRow(name: String, createdAt: DateTime = DateTime.now, id: Option[UserId] = None)

こちらの記事で述べられている、各IDの型とPathBindableを自動生成してみます。 https://qiita.com/srd7/items/ee2098d7cebc50ae0e01

まずはID型の生成から始めます。

ID型のGenerator

わかりやすさのため、ID型はテーブル定義とは別ファイルにします。 クラス定義から始めたいので、自動で生成されているこれらが邪魔になります。

// AUTO-GENERATED Slick data model
/** Stand-alone Slick data model for immediate use */
object IDs extends {
  val profile = slick.driver.MySQLDriver
} with IDs

/** Slick data model trait for extension, choice of backend or usage in the cake pattern. (Make sure to initialize this late.) */
trait IDs {
  val profile: slick.jdbc.JdbcProfile
  import profile.api._
  import slick.model.ForeignKeyAction
  // NOTE: GetResult mappers for plain SQL are only generated for tables where Slick knows how to map the types of all columns.
  import slick.jdbc.{GetResult => GR}

  /** DDL for all tables. Call .create to execute. */
  lazy val schema: profile.SchemaDescription = Item.schema ++ User.schema
  @deprecated("Use .schema instead of .ddl", "3.0")
  def ddl = schema

  ...
}

まずはこれを消す所から。 コードを漁ると、trait IDsの定型文はAbstractSourceCodeGenerator.codeで処理されていることが分かります。

def code = {
  "import slick.model.ForeignKeyAction\n" +
  ...
  } else "") +
  tables.map(_.code.mkString("\n")).mkString("\n\n")
}

さらに外側、// AUTO-GENERATED Slick data modelなどはOutputHelper.packageCodeで作られています。 これをオーバーライドし、テーブル定義だけを吐き出すようにすれば良いですね。

override def code = tables.map(_.code.mkString("\n")).mkString("\n\n")
override def packageCode(profile: String, pkg: String, container: String, parentType: Option[String]) : String =
  s"""package models
     |
     |${code}
   """.stripMargin

これで下記のような出力を得ることができます。

package models

class ItemId private (private[models] val value: Int) extends AnyVal {
  override def toString = value.toString
}
object ItemId {
  ...
}


class UserId private (private[models] val value: Int) extends AnyVal {
  override def toString = value.toString
}
object UserId {
  ...
}

TableのIDカラムを書き換える

(ここからはスキーマに大きく依存するので、実際に利用される際は調整をしてください。今回は、主キーのカラム名はid、テーブルは単数形、外部キーのカラム名は(table)Idという名前になっているとします)

各テーブルのidカラム型を書き換えるのは簡単ですが、外部キー用のIDは少々厄介です。

override def Column = new Column(_) {
  override def rawType = model.tpe match {
    case _ if model.name == "id" => s"""${TableValue.name}Id""" // かんたん
    case _ => super.rawType
  }
}

今回は存在するテーブル名を列挙して、”(テーブル名)_id”に一致したら型を変更する、というゴリ押しで実装します。 まず、テーブル名と型名のマップを作ります。各種Generatorを呼び出す前に、DB情報から抜き出します。

val idColumnNameSet = (for {
  t <- model.tables
} yield s"${t.name.table}_id").toSet

これで、以下のようなセットが得られます。

Set(item_id, user_id)

これをGeneratorに渡し

class CustomTableGenerator(model: m.Model, idColumnNameSet: Set[String]) extends SourceCodeGenerator(model) {

このように判定をすれば、user_idUserId型へ、item_idItemId型へと書き変わります。

override def rawType = model.tpe match {
  case _ if idColumnNameSet.contains(model.name) => model.name.toCamelCase
  case _ if model.name == "id" => s"""${TableValue.name}Id"""
  case _ => super.rawType
}

ID用Mapperの追加

Slickが自動でIntやLongをID型にマッピング出来るよう、以下のようなimplicitを配置する必要があります。

implicit val userIdMapper = MappedColumnType.base[UserId, Long](_.value, UserId.apply)

https://qiita.com/srd7/items/ee2098d7cebc50ae0e01#slick-%E3%81%BE%E3%82%8F%E3%82%8A

これは、先程のSetを使うと簡単に実現できます。

def implicitIdMapper(name: String): String = {
  val idName = s"${name.toCamelCase}"
  val uncapitalizedIdName = idName.head.toLower + idName.tail
  s"implicit val ${uncapitalizedIdName}Mapper = MappedColumnType.base[${idName}, Int](_.value, ${idName}.apply)"
}

このような関数を用意し、SourceCodeGenerator.codeの中で呼び出すことで、全ID型に対するimplicitな変換クラスの先頭に記述できます。

PathBindableのGenerator

最後に、conf/routesで各ID型を使えるよう、PathBindableも定義します。 これはまた独立したファイルに書くこととします。

ID型の生成とほぼ同じなので、末尾のまとめコードを参照ください。

まとめ

slick-codegenのサンプルを通して、コードの中身と少し入り込んだカスタマイズ方法を記載してゆきました。

出力を簡単にカスタマイズできるので、Slick以外でも使えそうですね。 slick-codegenを活用し、面倒な記述はなるべく自動生成に任せてゆきましょう。

明日のScala Advent Calendar 2017は、@grimrose@githubさんによるsangriaの紹介です。 私はGraphQLに手を出そうとしつつ、未だできていません。記事が楽しみです!

今回作ったコード

最後に、今回説明に使ったサンプルを盛り込んだコードを掲載します。 自作ジェネレーターのテンプレートにご利用ください。

package main

import slick.driver.JdbcProfile

import scala.concurrent.ExecutionContext.Implicits.global
import slick.driver.MySQLDriver.api._
import slick.driver.MySQLDriver

import scala.collection.mutable
import slick.{model => m}
import slick.codegen.SourceCodeGenerator
import slick.model.Model
import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext}

class CustomTableGenerator(model: m.Model, idColumnNameSet: Set[String]) extends SourceCodeGenerator(model) {

  def implicitIdMapper(name: String): String = {
    val idName = s"${name.toCamelCase}"
    val uncapitalizedIdName = idName.head.toLower + idName.tail
    s"implicit val ${uncapitalizedIdName}Mapper = MappedColumnType.base[${idName}, Int](_.value, ${idName}.apply)"
  }

  // add some custom imports
  override def code =
    s"""|import com.github.tototoshi.slick.MySQLJodaSupport._
       |
       |${(idColumnNameSet.map(implicitIdMapper)).mkString("\n")}
       |
       |""".stripMargin +
      super.code

  override def Table = new Table(_) {
    override def autoIncLast = true

    override def Column = new Column(_) {
      override def asOption = autoInc
      override def rawType = model.tpe match {
        case "java.sql.Timestamp" => "DateTime"
        case "java.sql.Date" => "DateTime"
        case _ if idColumnNameSet.contains(model.name) => model.name.toCamelCase
        case _ if model.name == "id" => s"""${TableValue.name}Id"""
        case _ => super.rawType
      }

      override def default = rawName match {
        case "createdAt" | "updatedAt" => Some("DateTime.now")
        case _ => super.default
      }
    }

    class EntityCompanionDef extends EntityTypeDef {
      override def doc = ""

      override def code =
        s"""object ${rawName} {
           |  import play.api.libs.json._
           |  import play.api.libs.json.JodaWrites
           |  import play.api.libs.json.JodaReads
           |
           |  val dateFormat = "yyyy-MM-dd'T'HH:mm:ss.SSSZ"
           |
           |  implicit val dateTimeWriter = JodaWrites.jodaDateWrites(dateFormat)
           |  implicit val dateTimeJsReader = JodaReads.jodaDateReads(dateFormat)
           |  implicit val jsonWrites = Json.writes[${rawName}]
           |  implicit val jsonReads = Json.reads[${rawName}]
           |}
           |""".stripMargin
    }

    override def definitions = {
      val companion = new EntityCompanionDef
      Seq[Def](EntityType, companion, PlainSqlMapper, TableClass, TableValue)
    }

    override def code = {
      super.code.toList.map(_.replaceAll(s"${EntityType.name}.tupled", s"(${EntityType.name}.apply _).tupled"))
    }
  }
}


class CustomIDGenerator(model: m.Model) extends SourceCodeGenerator(model) {

  override def code = tables.map(_.code.mkString("\n")).mkString("\n\n")

  override def packageCode(profile: String, pkg: String, container: String, parentType: Option[String]): String =
    s"""package models
       |
       |import play.api.libs.json._
       |
       |${code}
       |""".stripMargin

  override def Table = new Table(_) {

    class IDDef extends EntityTypeDef {
      override def doc = ""

      override def code = {
        val name = TableValue.name
        val idName = s"""${name}Id"""
        val packageName = "models"
        val idType = "Int"

        s"""class ${idName} private (private[${packageName}] val value: ${idType}) extends AnyVal {
           |  override def toString = value.toString
           |}
           |object ${idName} {
           |  private[models] def apply(value: ${idType}) = new ${idName}(value)
           |  private[models] def unapply(id: ${idName}) = Some(id.value)
           |  implicit val jsonWrites = Json.writes[${idName}]
           |  implicit val jsonReads = Json.reads[${idName}]
           |  def fromString(str: String): Either[Throwable, ${idName}] = {
           |    try {
           |      Right(${idName}(str.to${idType}))
           |    } catch {
           |      case e: Throwable => Left(e)
           |    }
           |  }
           |}
           |""".stripMargin
      }
    }

    override def definitions = {
      Seq[Def](new IDDef)
    }
  }
}

class PathBindableGenerator(model: m.Model) extends SourceCodeGenerator(model) {

  override def code = tables.map(_.code.mkString("\n")).mkString("\n\n")

  override def packageCode(profile: String, pkg: String, container: String, parentType: Option[String]): String =
    s"""package models
       |
       |import play.api.mvc.PathBindable
       |
       |object PathBindableImplicits {
       |  ${indent(code)}
       |}
       |""".stripMargin

  override def Table = new Table(_) {

    class PathBindamleDef extends EntityTypeDef {
      override def doc = ""

      override def code = {
        val name = TableValue.name
        val idName = s"""${name}Id"""
        val uncapitalizedIdName = idName.head.toLower + idName.tail
        val implicitName = s"""${uncapitalizedIdName}PathBindable"""

        val packageName = "models"

        s"""implicit def ${implicitName} = new PathBindable[${idName}] {
           |  override def bind(key: String, value: String): Either[String, ${idName}] = {
           |    ${idName}.fromString(value).left.map(_.getMessage)
           |  }
           |  override def unbind(key: String, ${uncapitalizedIdName}: ${idName}): String = {
           |    ${uncapitalizedIdName}.toString
           |  }
           |}
           |""".stripMargin
      }
    }

    override def definitions = {
      Seq[Def](new PathBindamleDef)
    }
  }
}

object SlickCodegen extends App {
  val dbs = Setting.dev

  val slickDriver = dbs.slickDriver
  val profile = dbs.profile
  val jdbcDriver = dbs.jdbcDriver
  val url = dbs.url
  val outputFolder = dbs.outputFolder
  val schemas = dbs.schemas
  val pkg = dbs.pkg
  val user = dbs.user
  val password = dbs.password
  val driver: slick.jdbc.JdbcProfile = dbs.profile
  val db = {
    Database.forURL(url, driver = jdbcDriver, user = user, password = password)
  }

  import scala.concurrent.ExecutionContext.Implicits.global

  val modelFuture = db.run(driver.createModel(None, false))
  val f = modelFuture.map(model => {
    val idColumnNameSet = (for {
      t <- model.tables
    } yield s"${t.name.table}_id").toSet

    new CustomTableGenerator(model, idColumnNameSet).writeToFile(slickDriver, outputFolder, pkg, "Tables", "Tables.scala")
    new CustomIDGenerator(model).writeToFile(slickDriver, outputFolder, pkg, "IDs", "IDs.scala")
    new PathBindableGenerator(model).writeToFile(slickDriver, outputFolder, pkg, "PathBindableImplicits", "PathBindableImplicits.scala")
  })

  Await.result(f, Duration.Inf)
}