Scala inference rules-of-thumb

Scala has, let’s say, temperamental type inference. But if you keep a mental model of how scalac performs inference, more or less, it’s a lot easier to write code that infers, and, perhaps more importantly, to design APIs that infer conveniently.

Inference happens left-to-right

Inference proceeds from left to right, except where there is an expected type. This is the theme of most of the remaining rules; examples of those are also examples of this.

So if you only remember one rule, remember this one.

The method receiver cut

The main source of left-to-rightness is the method-receiver cut: when calling a method, as in receiver.m(arg), Scala must fully determine the type of receiver before it can look up m. So the expected type of the whole expression is irrelevant to determining receiver's type.

For example, Scalaz provides a method into such that x.into(f) = f.apply(x). However,

f.apply(x)  // f's type must be fully determined without context
x.into(f)   // x's type must be fully determined without context
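The receiver cut is easiest to see with a lambda, whose parameter type can only be inferred from an expected type. Here is a minimal, self-contained sketch (the object and method names are hypothetical, and `IntoOps` is a stand-in for Scalaz’s `into` syntax):

```scala
object IntoDemo {
  // hypothetical stand-in for Scalaz's `into` extension method
  implicit class IntoOps[A](private val a: A) extends AnyVal {
    def into[B](f: A => B): B = f(a)
  }

  def twice(f: Int => Int): Int = f(f(0))

  def main(args: Array[String]): Unit = {
    // compiles: in argument position, the lambda gets expected type Int => Int
    println(twice(_ + 1)) // 2

    // would NOT compile: as a receiver, the lambda has no expected type,
    // so its parameter type cannot be inferred:
    //   (_ + 1).into(twice)  // error: missing parameter type
  }
}
```

The lambda is the same in both positions; only the direction inference flows differs.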

Local val and def cut

The second main cut is for local vals and defs. Later uses of a local val or def are never used to infer its type; the type is inferred from the right-hand side alone.

// This might compile,
f(some.long(expr))
// But this might not
val x = some.long(expr)
f(x)
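A concrete case where the val cut bites is, again, a lambda: inline, it picks up an expected type from the argument position, but bound to a val it has none. A minimal sketch (hypothetical names):

```scala
object LocalValCut {
  def applyToOne(f: Int => Int): Int = f(1)

  def main(args: Array[String]): Unit = {
    // compiles: the expected type Int => Int flows into the argument position
    println(applyToOne(n => n + 1)) // 2

    // would NOT compile: a val's type comes from its right-hand side
    // alone, so the lambda has no expected type:
    //   val g = _ + 1   // error: missing parameter type
    //   applyToOne(g)
  }
}
```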

The method selection cut

The final main cut is “method selection”: once Scala selects a method, it will not back up and try a different one. If there is a direct method of a given name (not one added by implicit class), it will always win, even if the call doesn’t compile. That’s why we can’t replace the Option#fold method in Scala with a better implicit one in Scalaz: the implicitly added fold cannot win, even though it takes a different number of arguments.

Type parameters can be inferred from expected type

When it seems like you need to pass type parameters explicitly, it is often more convenient (and yields clearer code) to add an ascription to impose an expected type, giving inference a little more information to pick types. It’s clearer because it expresses what instead of how, and isn’t sensitive to type parameter order or the introduction of new method type parameters.

// two ways to produce a Gen[Party]
 
val gp1 = Tag.subst[String, Gen, PartyTag](Gen.identifier)
//                  A     ,   F,        T
 
// all 3 type params, yes, even the higher-kinded one,
// can be inferred by matching the expected type to the
// return type of Tag.subst, F[A @@ T]
val gp2: Gen[Party] = Tag.subst(Gen.identifier)

Expected type flows into ordinary arguments in a very convenient way. If Gen.identifier (or another expression of your choosing) required an expected type to choose type parameters of its own, it could get those from the fact that the expected type of subst's argument, in both cases above, is Gen[String].
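The same trade-off shows up with plain standard-library calls; a minimal sketch (the object name is hypothetical):

```scala
object AscriptionDemo {
  // explicit type parameter:
  val a = List.empty[Int]
  // ascription achieves the same thing via the expected type:
  val b: List[Int] = List.empty

  def main(args: Array[String]): Unit =
    println(a == b) // true
}
```

The ascribed form says *what* we want (`List[Int]`) and leaves the type parameter to inference.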

Type parameters are inferred arglist by arglist

This is usually why method signatures are curried. For example,

// in List
def foldLeft[B](z: B)(op: (B, A) ⇒ B): B
 
// consider the call
xs.foldLeft(z)((acc, elt) => expr)

Here, B will be solely determined as the inferred type of z. Thus, the lambda can be inferred, with acc having whatever type was chosen for B. Since B is already determined, expr has that as an expected type, and cannot change that choice.
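The arglist-by-arglist rule explains a classic foldLeft pitfall: passing Nil as z fixes B before the lambda is ever examined. A runnable sketch (hypothetical object name):

```scala
object FoldLeftDemo {
  val xs = List(1, 2, 3)

  // would NOT compile: the first argument list fixes B = List[Nothing],
  // and the lambda cannot change that choice:
  //   xs.foldLeft(Nil)((acc, n) => n :: acc)

  // fix: make the first argument list pin B = List[Int]
  val reversed = xs.foldLeft(List.empty[Int])((acc, n) => n :: acc)

  def main(args: Array[String]): Unit =
    println(reversed) // List(3, 2, 1)
}
```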

Implicit parameters are a special case; when not explicitly specified, they are resolved one by one, left to right. You can use this to chain multi-parameter type classes (MPTCs) with effective functional dependencies in a single implicit parameter list. This is very important to take advantage of when working with Shapeless’s typeclasses, most of which are such MPTCs.
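The left-to-right resolution of implicits can be sketched with a hypothetical Shapeless-style typeclass that has an output type member, exposed through the usual Aux alias (all names here are invented for illustration):

```scala
object ChainedImplicits {
  // hypothetical typeclass with a fundep-like output type member
  trait Sized[A] { type Out; def size(a: A): Out }
  object Sized {
    type Aux[A, O] = Sized[A] { type Out = O }
    implicit val stringSized: Aux[String, Int] =
      new Sized[String] { type Out = Int; def size(a: String): Int = a.length }
  }

  trait Show[A] { def show(a: A): String }
  object Show {
    implicit val intShow: Show[Int] =
      new Show[Int] { def show(a: Int): String = a.toString }
  }

  // resolved left to right: `s` fixes O = Int, and only then is
  // Show[O] = Show[Int] searched for
  def showSize[A, O](a: A)(implicit s: Sized.Aux[A, O], sh: Show[O]): String =
    sh.show(s.size(a))

  def main(args: Array[String]): Unit =
    println(showSize("hello")) // 5
}
```

Had implicits not resolved in order, `Show[O]` could not be searched, because O would still be undetermined.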

To infer a lambda, you need to expect an argument type

To fully infer a lambda, the expected type must be a function type with a known argument type; the result type can then be determined from the body, in true left-to-right fashion. Scala 2.12’s map illustrates.

// supposing
val xs: List[Int]
// it has the method with signature
def map[B, That](f: Int ⇒ B)(implicit bf: CanBuildFrom[List[Int], B, That]): That
 
// and this works
xs.map(n => (n, n + 1))

Here’s the complete chain of events, using several rules we’ve talked about already.

  1. xs is fully determined, independent of context.
  2. map is looked up; the above signature is found.
  3. The expected type of the argument is Int => B where B is yet to be determined.
  4. That means the lambda can be fully inferred; by inferring its body, we determine that B = (Int, Int).
  5. Implicit resolution happens; we’ve determined B but not That. This is no problem for implicit resolution, which happily searches for a CanBuildFrom[List[Int], (Int, Int), That] where That is yet to be determined.
  6. That resolution will yield a value of type CanBuildFrom[List[Int], (Int, Int), List[(Int, Int)]]; thus we finally determine That = List[(Int, Int)].
  7. And that’s the inferred result type.

To review, the first argument list determines what B is, and the second argument list takes account of that when determining what That is, even during implicit resolution.

Method overloading breaks all kinds of inference

If another overload of map existed above, step 3 could not happen, because the presence of overloads means that ordinary arguments must have fully determined types, independent of context. That would also make the implications of subst described above no longer apply. Many more things break in the presence of overloading. Introducing a new overload for a method can easily break a great deal of existing Scala code and rule out many convenient forms of expression in new code, so it is typically safer to just use a different method name.

Scala 2.13 fixes a few method-overloading problems, as was necessary to support the new collections library design, but this was not a comprehensive “overload fix”, so overloads should still be avoided, for much the same reasons as in 2.12.
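To make the breakage concrete, here is a minimal sketch (hypothetical names) of how adding one overload disables lambda-parameter inference at every call site:

```scala
object OverloadDemo {
  def f(g: Int => Int): Int = g(1)
  // adding this overload is what breaks inference below:
  def f(g: String => String): String = g("a")

  // with overloads present, this no longer compiles, because the
  // argument's type must be fully determined without context:
  //   f(_ + 1)   // error: missing parameter type

  def main(args: Array[String]): Unit =
    println(f((n: Int) => n + 1)) // 2: a full annotation is now required
}
```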

Implicit resolution searches the type-parts’ companions

The non-orphan typeclass instance space consists of the companions of a type’s “parts”. For example, to find the implicit Monoid[List[Blah]], objects Monoid, List, and Blah will be searched, as well as the companion objects of all of Monoid's, List's, and Blah's superclasses. “Parts” is defined by specification, though you must specify -Xsource:2.13 (or use Scala 2.13+) to get a fully correct implementation.

For coherent typeclass definitions, placing instances in “parts” instead of requiring import statements to use particular instances is the best default choice. That’s because, unlike imports, it gives you a set of globally well-defined places to look for instances when browsing code, and therefore a more predictable implicit resolution. It is also not subject to name-shadowing, a frequent source of implicit lookup failures.
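A minimal sketch of an instance living in a “part” (all names hypothetical): the Monoid[Sum] instance is found with no import at the call site, because Sum's companion is in the implicit scope of Monoid[Sum].

```scala
object CompanionScope {
  trait Monoid[A] { def empty: A; def combine(x: A, y: A): A }

  final case class Sum(n: Int)
  object Sum {
    // the instance lives in a "part" (Sum's companion); no import needed
    implicit val sumMonoid: Monoid[Sum] = new Monoid[Sum] {
      def empty: Sum = Sum(0)
      def combine(x: Sum, y: Sum): Sum = Sum(x.n + y.n)
    }
  }

  def fold[A](as: List[A])(implicit m: Monoid[A]): A =
    as.foldLeft(m.empty)(m.combine)

  def main(args: Array[String]): Unit =
    println(fold(List(Sum(1), Sum(2), Sum(3))).n) // 6
}
```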

Some examples of good and bad design

The introduction of Option#fold was very frustrating to us Scalazzi, because the signature is not well chosen.

def fold[Z](none: => Z)(some: A => Z): Z  // bad
def cata[Z](some: A => Z, none: => Z): Z  // good
 
// attempts to implement toList:
oa.fold(Nil)(_ :: Nil)  // will not compile
oa.cata(_ :: Nil, Nil)  // will compile just fine

That’s because in the fold case, only none is used to determine Z, whereas cata's signature uses both some and none to determine Z. By contrast, currying is appropriate for List#foldLeft, but not merely because there’s a function argument.
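The contrast can be made runnable with a small stand-in for cata, written as an extension method (the names are hypothetical, modeled on Scalaz’s Option syntax):

```scala
object FoldCata {
  // hypothetical cata extension: both continuations sit in one
  // argument list, so both contribute to determining Z
  implicit class CataOps[A](private val oa: Option[A]) extends AnyVal {
    def cata[Z](some: A => Z, none: => Z): Z = oa.fold(none)(some)
  }

  // the fold version would NOT compile: `Nil` alone fixes Z = List[Nothing]
  //   oa.fold(Nil)(_ :: Nil)

  def toList[A](oa: Option[A]): List[A] = oa.cata(_ :: Nil, Nil)

  def main(args: Array[String]): Unit =
    println(toList(Some(1))) // List(1)
}
```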

The standard library supplies both sides of the isomorphism between A => Option[B] and PartialFunction[A, B], both of which can be expressed by lambdas. These are good functions to supply, because the most convenient lambda expression can fit one or the other of those types depending on circumstance. However, only the conversion from the former to the latter, Function.unlift, is convenient to use. That’s because its counterpart, lift, is defined as a method on PartialFunction. To illustrate, these should do the same thing:

oa.collect(Function.unlift(a => Some(a))) // will compile
oa.flatMap({case a => a}.lift)            // will not compile

Here’s what happens with unlift.

  1. Supposing oa: Option[Qux], collect expects an argument of type PartialFunction[Qux, B], where B is not yet determined.
  2. unlift’s first type parameter is inferred to be Qux from this fact.
  3. Therefore, unlift expects an argument of type Qux => Option[B], where B is not yet determined.
  4. From the lambda body we infer B = Qux, and the determination flows all the way out to collect, and the result type of the whole expression.

By contrast, lift meets its demise quickly. That’s because the occurrence of the literal PartialFunction is in method receiver position, which cuts it off from expected type. lift is still useful when implementing higher-order functions, but it could be so much more so.
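A runnable variant of the unlift case (hypothetical names), showing the expected type reaching all the way into the lambda’s parameter:

```scala
object UnliftDemo {
  val oa: Option[Int] = Some(2)

  // in argument position, collect's expected PartialFunction[Int, B]
  // determines unlift's T = Int, so `n` needs no annotation
  val doubled = oa.collect(Function.unlift(n => if (n % 2 == 0) Some(n * 10) else None))

  // would NOT compile: the partial-function literal is a receiver,
  // cut off from any expected type:
  //   oa.flatMap({ case a => Some(a) }.lift)

  def main(args: Array[String]): Unit =
    println(doubled) // Some(20)
}
```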

Designing a type-safe equals function

It’s a common mistake to write an expression like a == b where a and b have incompatible types; scalac has some warnings about this, but they’re very unreliable and cover only a few cases. If we keep in mind that inference happens left to right, though, we can design a safer replacement for this function.

Here are three attempts at designing such a replacement.

def eq1[A](l: A, r: A): Boolean
 
def eq2[A](l: A)(r: A): Boolean
 
import scalaz.Leibniz.===
def eq3[A, B](l: A, r: B)(implicit ev: A === B): Boolean

eq1 fails completely. That’s because, for any two types, related or not, a least upper bound is defined, and that is what will be chosen for A. So if you pass an Int and a String, it will just choose A = Any; if you pass a List and an Option, it will just choose A = Product with Serializable; and so on. Normally, you could catch this later in the flow, but the return type doesn’t contain A.

eq2 is better, and might be good enough. A is determined only from l's type, and then the r expression has an expected type of whatever that is. r could still be a strict subtype, though; that might be undesirable if you’re worried about l's type being accidentally too wide, as can happen when you append lists and such. (If you define an implicit class to do this, it will likely have similar inference properties to eq2. Scalaz’s Equal extension method has this behavior.) For example, this mistake is possible for eq2, though it is caught by the wartremover Any wart (if you haven’t disabled it!):

def foo(xs: List[Int], ys: List[String], zs: List[Int]) =
  eq2(xs ++ ys)(zs) // A inferred to be List[Any], a supertype of zs's type

eq3 is the strictest form, even though it looks like a simple expansion of eq1. (=== is like Haskell’s ~; it represents a type-equality constraint.) That’s because no “least upper-bounding” happens; the types A and B are determined from the two arguments, independently of each other, and then we require a proof that the types are equal. The necessary implicit value is defined in a “part” of ===, so resolution will succeed if the types are equal, no matter where we are in the code. Since we can’t walk backwards (that wouldn’t be “left to right”!) and change either A or B to complete that implicit resolution, they must have been independently inferred as precisely the same type. (This also works in the curried or implicit-class form; the type parameter inference behaves exactly the same.)
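A runnable sketch of eq3, substituting the standard library’s =:= for Scalaz’s Leibniz === (the inference behavior is the same; object and method names are hypothetical):

```scala
object SafeEq {
  // =:= stands in for Scalaz's ===: its only instance proves A and B equal
  def eq3[A, B](l: A, r: B)(implicit ev: A =:= B): Boolean = ev(l) == r

  // would NOT compile: A = List[Int] and B = List[String] are inferred
  // independently, and no List[Int] =:= List[String] instance exists:
  //   eq3(List(1), List("a"))

  def main(args: Array[String]): Unit =
    println(eq3(List(1, 2), List(1, 2))) // true
}
```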

The key to this trickery is that it takes advantage of what is usually a weakness of Scala inference: that there are “cuts” determined by the left-to-right trend of inference. I think this is what great Scala APIs have in common: they work around these weaknesses in their design, turning them into strengths for type-safety wherever possible.
