Last active
August 5, 2020 15:33
-
-
Save deusaquilus/dfb42880656df12779a0afd4f20ef1bb to your computer and use it in GitHub Desktop.
Better Typed ExpandJoin which uses FlatMap/FlatJoin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package io.getquill.sql.norm | |
import io.getquill.ast._ | |
import io.getquill.ast.Implicits._ | |
import io.getquill.norm.BetaReduction | |
import io.getquill.norm.TypeBehavior.{ ReplaceWithReduction => RWR } | |
import io.getquill.norm.Normalize | |
import io.getquill.quat.Quat | |
// Example query exercising nested joins over a renamed entity (kept for manual testing):
// def nestedWithRenames(): Unit = {
//   case class Ent(name: String)
//   case class Foo(fame: String)
//   case class Bar(bame: String)
//
//   implicit val entSchema = schemaMeta[Ent]("TheEnt", _.name -> "theName")
//
//   val q = quote {
//     query[Foo]
//       .join(query[Ent]).on((f, e) => f.fame == e.name)         // (Foo, Ent)
//       .distinct
//       .join(query[Bar]).on((fe, b) => (fe._1.fame == b.bame))  // ((Foo, Ent), Bar)
//       .distinct
//       .map(feb => (feb._1._2, feb._2))                         // feb: ((Foo, Ent), Bar)
//       .distinct
//       .map(eb => (eb._1.name, eb._2.bame))                     // eb: (Ent, Bar)
//   }
//   println(run(q))
// }
// nestedWithRenames()
/**
 * Rewrites `Join` nodes (including nested ones) into a chain of `FlatMap`/`FlatJoin`
 * nodes that ends in a `Map` rebuilding the (nested) tuple shape of the original join.
 *
 * [[expandedTuple]] builds the chain bottom-up and leaves two placeholders in it:
 *  - [[MoreTables]]: the body of the innermost `FlatMap`, where further tables are spliced in;
 *  - [[MoreCond]]: a conjunct inside a `FlatJoin` condition where an enclosing join's
 *    condition may be substituted.
 * [[expand]] then turns the innermost `FlatMap(_, _, MoreTables)` into a `Map` producing the
 * tuple, replaces every leftover [[MoreCond]] with an always-true condition and verifies that
 * no placeholder remains.
 */
object ExpandJoin {

  /**
   * Placeholder for "more tables go here": the body of the innermost `FlatMap`.
   *
   * The two placeholders MUST have distinct names. They are matched as stable identifiers
   * (`FlatMap(_, _, MoreTables)`) and used as substitution keys in `BetaReduction`, and two
   * `Ident`s with the same name and Quat compare equal. With identical names, substituting
   * one placeholder also substitutes the other (e.g. a join condition ends up as the body
   * of a `FlatMap`).
   */
  val MoreTables = Ident("<PH_TABLES>", Quat.Generic)

  /** Placeholder for "more join conditions go here", conjoined into `FlatJoin` conditions. */
  val MoreCond = Ident("<PH_COND>", Quat.Generic)

  // Replacement for MoreCond placeholders that never received a condition.
  // TODO reduce this out i.e. something && 1==1 should be just something
  private val AlwaysTrue: Ast = Constant(1) +==+ Constant(1)

  def apply(q: Ast) = expand(q, None)

  /**
   * Expands every `Join` found in `q` into its `FlatMap`/`FlatJoin` form.
   *
   * @param q  the AST to transform
   * @param id not used by the expansion (NOTE(review): kept for interface compatibility)
   * @throws IllegalArgumentException if the expansion produces an unexpected innermost
   *         construct or leaves placeholders behind
   */
  def expand(q: Ast, id: Option[Ident]) =
    Transform(q) {
      case q @ Join(_, _, _, Ident(a, _), Ident(b, _), _) => // Ident a and Ident b should have the same Quat, could add an assertion for that
        val (qr, tuple) = expandedTuple(q)
        val innermostOpt =
          CollectAst(qr) {
            case fm @ FlatMap(_, _, MoreTables) => fm
          }.headOption
        innermostOpt match {
          case Some(innermost) =>
            val newInnermost =
              innermost match {
                case FlatMap(fj: FlatJoin, alias, MoreTables) =>
                  Map(fj, alias, tuple)
                case other =>
                  throw new IllegalArgumentException(s"Flat Join Expansion created Illegal FlatJoin Construct:\n${io.getquill.util.Messages.qprint(other).plainText}")
              }
            val spliced = BetaReduction(qr, RWR, innermost -> newInnermost)
            // No enclosing join remains to supply conditions, so EVERY remaining MoreCond is
            // neutralized, not only the one in the innermost FlatJoin. Left-nested chains of
            // joins carry a MoreCond in more than one FlatJoin.
            val output = BetaReduction(spliced, RWR, MoreCond -> AlwaysTrue)
            verifyNoPlaceholders(output)
            output
          case None =>
            // NOTE(review): expected to be unreachable. Every case of expandedTuple keeps a
            // MoreTables placeholder as the body of its innermost FlatMap.
            qr
        }
    }

  /** Throws if a [[MoreCond]] or [[MoreTables]] placeholder is still present in `output`. */
  private def verifyNoPlaceholders(output: Ast): Unit = {
    val checker = new StatelessTransformer {
      override def applyIdent(id: Ident): Ident =
        id match {
          case `MoreCond` | `MoreTables` =>
            throw new IllegalArgumentException(s"Flat Join Expansion could not succeed due to placeholders remaining in the AST:\n${io.getquill.util.Messages.qprint(output).plainText}")
          case _ =>
            id
        }
    }
    checker(output)
    ()
  }

  /** Narrows a reduction result back to a `FlatMap`, failing with a descriptive error otherwise. */
  private def asFlatMap(ast: Ast): FlatMap =
    ast match {
      case fm: FlatMap => fm
      case other =>
        throw new IllegalArgumentException(s"Flat Join Expansion expected a FlatMap after reduction but got:\n${io.getquill.util.Messages.qprint(other).plainText}")
    }

  /**
   * Recursively expands a `Join` into a `FlatMap`/`FlatJoin` chain that still contains
   * placeholders. It also returns the tuple of aliases, nested like the joins themselves,
   * that the final `Map` should produce.
   */
  private def expandedTuple(q: Join): (FlatMap, Tuple) =
    q match {
      case Join(_, a: Join, b: Join, tA, tB, o) =>
        val (ar, at) = expandedTuple(a)
        val (br, bt) = expandedTuple(b)
        val or = BetaReduction(o, RWR, tA -> at, tB -> bt)
        // Splice the right-hand chain into the innermost body of the left-hand one, then
        // substitute the outer condition for the MoreCond placeholders.
        // NOTE(review): the outer join type is not used in this case; confirm that is intended.
        val arbrFlat = BetaReduction(ar, RWR, MoreTables -> br)
        val arbr = asFlatMap(BetaReduction(arbrFlat, RWR, MoreCond -> or))
        (arbr, Tuple(List(at, bt)))
      case Join(t, a: Join, b, tA, tB, o) =>
        val (ar, at) = expandedTuple(a)
        val or = BetaReduction(o, RWR, tA -> at)
        // The new FlatJoin keeps a MoreCond conjunct so an enclosing join can add its condition.
        val br =
          FlatMap(
            FlatJoin(t, b, tB, or +&&+ MoreCond),
            tB, MoreTables
          )
        val arbr = asFlatMap(BetaReduction(ar, RWR, MoreTables -> br))
        (arbr, Tuple(List(at, tB)))
      case Join(t, a, b: Join, tA, tB, o) =>
        val (br, bt) = expandedTuple(b)
        val or = BetaReduction(o, RWR, tB -> bt)
        // NOTE(review): unlike the base case, `a` is not passed through nestedExpand here. The
        // left-hand table is also wrapped in a FlatJoin whose condition refers to aliases that
        // are only bound later inside `br`. Confirm this is the intended shape.
        val arbr =
          FlatMap(
            FlatJoin(t, a, tA, or),
            tA,
            BetaReduction(br, RWR, MoreCond -> or)
          )
        (arbr, Tuple(List(tA, bt)))
      case Join(t, a, b, tA, tB, on) =>
        // Base case: both sides are plain sources; expand any joins nested inside them.
        val ar = nestedExpand(a, tA)
        val br = nestedExpand(b, tB)
        val ab =
          FlatMap(
            ar, tA,
            FlatMap(FlatJoin(t, br, tB, on), tB, MoreTables)
          )
        (ab, Tuple(List(tA, tB)))
    }

  /**
   * Expands joins inside one side of a join, normalizes the result and unwraps a top-level
   * `Map`, returning the `Map`'s source.
   * NOTE(review): unwrapping discards the `Map`'s projection body; confirm no caller relies on it.
   */
  private def nestedExpand(q: Ast, id: Ident) =
    Normalize(expand(q, Some(id))) match {
      case Map(inner, _, _) => inner
      case other            => other
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment