Why Do We Need to Use Flatmap

Why do we need to use flatMap?

When I started looking at RxJS I also stumbled on that stone. What helped me was the following:

  • documentation from reactivex.io. For instance, for flatMap: http://reactivex.io/documentation/operators/flatmap.html
  • documentation from rxmarbles: http://rxmarbles.com/. You will not find flatMap there; you must look at mergeMap instead (another name for it).
  • the introduction to Rx that you have been missing: https://gist.github.com/staltz/868e7e9bc2a7b8c1f754. It addresses a very similar example. In particular it addresses the fact that a promise is akin to an observable emitting only one value.
  • finally, looking at the type information from RxJava. JavaScript not being typed does not help here. Basically, if Observable<T> denotes an observable object which pushes values of type T, then flatMap, applied to an Observable<T'>, takes a function of type T' -> Observable<T> as its argument and returns Observable<T>. map, applied to the same Observable<T'>, takes a function of type T' -> T and returns Observable<T>.

    Going back to your example, you have a function which produces promises from a URL string. So T' : string, and T : promise. And from what we said before, promise : Observable<T''>, so T : Observable<T''>, with T'' : html. If you put that promise-producing function in map, you get Observable<Observable<T''>> when what you want is Observable<T''>: you want the observable to emit the html values. flatMap is named that way because it flattens (removes an observable layer from) the result of map. Depending on your background, this might be Chinese to you, but everything became crystal clear to me with the typing info and the drawing from here: http://reactivex.io/documentation/operators/flatmap.html.
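Scala's Future plays a role similar to the promise here (a single value that arrives later), so the same type difference can be sketched in Scala. This is a minimal sketch, not RxJS code: fetchHtml is a hypothetical stand-in that returns a fake page instead of doing real I/O, and the URL is only illustrative.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Hypothetical promise-producing function: string -> Future[html].
// It fabricates a page instead of performing a real HTTP request.
def fetchHtml(url: String): Future[String] =
  Future(s"<html>$url</html>")

val urlF: Future[String] = Future("http://example.com")

// map keeps the inner Future: two layers, like Observable<Observable<html>>
val nested: Future[Future[String]] = urlF.map(fetchHtml)

// flatMap removes one layer, like Observable<html>
val flat: Future[String] = urlF.flatMap(fetchHtml)

// Only the flattened version yields the html after a single wait
val page: String = Await.result(flat, 5.seconds)
```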

Why do we need flatMap (in general)?

flatMap, known as "bind" in some other languages, is, as you said yourself, for function composition.

Imagine for a moment that you have some functions like these:

def foo(x: Int): Option[Int] = Some(x + 2)
def bar(x: Int): Option[Int] = Some(x * 3)

The functions work great, calling foo(3) returns Some(5), and calling bar(3) returns Some(9), and we're all happy.

But now you've run into a situation that requires you to do the operation more than once.

foo(3).map(x => foo(x)) // or just foo(3).map(foo) for short

Job done, right?

Except not really. The output of the expression above is Some(Some(7)), not Some(7), and if you now want to chain another map on the end you can't because foo and bar take an Int, and not an Option[Int].

Enter flatMap

foo(3).flatMap(foo)

Will return Some(7), and

foo(3).flatMap(foo).flatMap(bar)

Returns Some(21).

This is great! Using flatMap lets you chain functions of the shape A => M[B] to oblivion (in the previous example A and B are Int, and M is Option).

More technically speaking, flatMap and bind have the signature M[A] => (A => M[B]) => M[B], meaning they take a "wrapped" value, such as Some(3), Right("foo"), or List(1,2,3), and shove it through a function that would normally take an unwrapped value, such as the aforementioned foo and bar. They do this by first "unwrapping" the value and then passing it through the function, which itself returns a wrapped result.
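To make the "unwrap, then apply" step concrete, here is a minimal sketch of bind for Option written out by hand with pattern matching. The name bind is hypothetical; in real code you would just call Option's own flatMap, which behaves the same way.

```scala
// Hypothetical hand-rolled bind for Option: M[A] => (A => M[B]) => M[B]
def bind[A, B](m: Option[A])(f: A => Option[B]): Option[B] =
  m match {
    case Some(a) => f(a) // unwrap, then hand the bare value to f, which re-wraps it
    case None    => None // nothing to unwrap, so f is never called
  }

def foo(x: Int): Option[Int] = Some(x + 2)
def bar(x: Int): Option[Int] = Some(x * 3)

// Same chain written with the hand-rolled bind and with the built-in flatMap
val viaBind = bind(bind(foo(3))(foo))(bar)
val viaFlatMap = foo(3).flatMap(foo).flatMap(bar)
```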

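I've seen the box analogy used for this: the value is taken out of its box and handed to the function, and the function returns a new box, so you never end up with a box inside a box.
[Illustration: box analogy for flatMap, drawn in MSPaint]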

This unwrapping and re-wrapping behavior means that if I were to introduce a third function that doesn't return an Option[Int] and tried to flatMap it onto the sequence, it wouldn't work, because flatMap expects the function to return a monad (in this case an Option).

def baz(x: Int): String = s"$x is a number"

foo(3).flatMap(foo).flatMap(bar).flatMap(baz) // <<< ERROR

To get around this, if your function doesn't return a monad, you just use the regular map function:

foo(3).flatMap(foo).flatMap(bar).map(baz)

Which would then return Some("21 is a number")
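The whole walkthrough can be checked as one small Scala script. This sketch reuses foo, bar and baz from above and also shows that, on Option, flatMap behaves like map followed by flatten.

```scala
def foo(x: Int): Option[Int] = Some(x + 2)
def bar(x: Int): Option[Int] = Some(x * 3)
def baz(x: Int): String = s"$x is a number"

// map leaves two layers of Option; flatten removes one
val nested: Option[Option[Int]] = foo(3).map(foo)
val flattened: Option[Int] = nested.flatten

// flatMap does the map and the flatten in a single step
val chained: Option[Int] = foo(3).flatMap(foo)

// flatMap for the Option-returning steps, map for the plain one at the end
val described: Option[String] = foo(3).flatMap(foo).flatMap(bar).map(baz)
```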

What's the difference between map() and flatMap() methods in Java 8?

Both map and flatMap can be applied to a Stream<T> and they both return a Stream<R>. The difference is that the map operation produces one output value for each input value, whereas the flatMap operation produces an arbitrary number (zero or more) of values for each input value.

This is reflected in the arguments to each operation.

The map operation takes a Function, which is called for each value in the input stream and produces one result value, which is sent to the output stream.

The flatMap operation takes a function that conceptually wants to consume one value and produce an arbitrary number of values. However, in Java, it's cumbersome for a method to return an arbitrary number of values, since methods can return only zero or one value. One could imagine an API where the mapper function for flatMap takes a value and returns an array or a List of values, which are then sent to the output. Given that this is the streams library, a particularly apt way to represent an arbitrary number of return values is for the mapper function itself to return a stream! The values from the stream returned by the mapper are drained from the stream and are passed to the output stream. The "clumps" of values returned by each call to the mapper function are not distinguished at all in the output stream, thus the output is said to have been "flattened."

Typical use is for the mapper function of flatMap to return Stream.empty() if it wants to send zero values, or something like Stream.of(a, b, c) if it wants to return several values. But of course any stream can be returned.
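The same "zero or more values per input" idea can be sketched with Scala's List.flatMap, where the mapper returns a List instead of a Stream. expand is a hypothetical mapper for illustration: it returns an empty list for even numbers (the Stream.empty() case) and two values for odd ones (the Stream.of(a, b) case).

```scala
// Hypothetical mapper: zero outputs for even n, two outputs for odd n
def expand(n: Int): List[Int] =
  if (n % 2 == 0) List.empty // like returning Stream.empty()
  else List(n, n * 10)       // like returning Stream.of(a, b)

val input = List(1, 2, 3)

// map: exactly one result per input, so the "clumps" stay visible
val mapped: List[List[Int]] = input.map(expand)

// flatMap: the clumps are merged, so the output is "flattened"
val flat: List[Int] = input.flatMap(expand)
```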

What does flatMap do exactly?

Functors define map, which has type

trait Functor[F[_]] {
  def map[A, B](f: A => B)(v: F[A]): F[B]
}

Monads are functors which support two additional operations:

trait Monad[M[_]] extends Functor[M] {
  def pure[A](v: A): M[A]
  def join[A](m: M[M[A]]): M[A]
}

join flattens nested values; e.g. if M is List, then join has type

def joinList[A](l: List[List[A]]): List[A]

If you have a monad M and you map over it with a function whose result type B is itself the same monadic type, what happens? For example:

def replicate[A](i: Int, value: A): List[A] = List.fill(i)(value) // i copies of value
val f = new Functor[List] {
  def map[A, B](f: A => B)(v: List[A]): List[B] = v.map(f)
}

then

f.map((x: Int) => replicate(x, x))(List(1,2,3)) == List(List(1), List(2,2), List(3,3,3))

This has type List[List[Int]], while the input is a List[Int]. In a chain of operations it's fairly common to want each step to return the same type as its input. Since List can also be made into a monad, you can easily flatten the nested list using join:

listMonad.join(List(List(1), List(2,2), List(3,3,3))) == List(1,2,2,3,3,3)

Now you might want to write a function to combine these two operations into one:

trait Monad[M[_]] extends Functor[M] {
  def pure[A](v: A): M[A]
  def join[A](m: M[M[A]]): M[A]
  def flatMap[A, B](f: A => M[B])(m: M[A]): M[B] = join(map(f)(m))
}

then you can simply do:

listMonad.flatMap((x: Int) => replicate(x, x))(List(1,2,3)) == List(1,2,2,3,3,3)

Exactly what flatMap does depends on the monad type constructor M (List in this example) since it depends on map and join.
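Putting the pieces together, here is a minimal runnable sketch of the List instance. The answer leaves listMonad and replicate undefined; the definitions below (join as List.flatten, replicate as List.fill) are assumptions chosen to match the expected results above.

```scala
trait Functor[F[_]] {
  def map[A, B](f: A => B)(v: F[A]): F[B]
}

trait Monad[M[_]] extends Functor[M] {
  def pure[A](v: A): M[A]
  def join[A](m: M[M[A]]): M[A]
  // flatMap is derived: map first, then join away the extra layer
  def flatMap[A, B](f: A => M[B])(m: M[A]): M[B] = join(map(f)(m))
}

// Assumed List instance: map delegates to List.map, join to List.flatten
val listMonad: Monad[List] = new Monad[List] {
  def map[A, B](f: A => B)(v: List[A]): List[B] = v.map(f)
  def pure[A](v: A): List[A] = List(v)
  def join[A](m: List[List[A]]): List[A] = m.flatten
}

// Assumed definition of replicate: i copies of value
def replicate[A](i: Int, value: A): List[A] = List.fill(i)(value)

val mapped = listMonad.map((x: Int) => replicate(x, x))(List(1, 2, 3))
val joined = listMonad.join(mapped)
val bound = listMonad.flatMap((x: Int) => replicate(x, x))(List(1, 2, 3))
```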

Why does mapMulti need type information in comparison to flatMap

Notice that the kind of type inference required to deduce the resulting stream type when you use flatMap is very different from that required when you use mapMulti.

When you use flatMap, the type of the resulting stream is the same type as the return type of the lambda body. That's a special thing that the compiler has been designed to infer type variables from (i.e. the compiler "knows about" it).

However, in the case of mapMulti, the type of the resulting stream that you presumably want can only be inferred from the things you do to the consumer lambda parameter. Hypothetically, the compiler could be designed so that, for example, if you have said consumer.accept(1), then it would look at what you have passed to accept, and see that you want a Stream<Integer>, and in the case of getItems().forEach(consumer), the only place where the type Item could have come from is the return type of getItems, so it would need to go look at that instead.

You are basically asking the compiler to infer the parameter types of a lambda, based on the types of arbitrary expressions inside it. The compiler simply has not been designed to do this.

Other than adding the <Item> prefix, there are other (longer) ways to let it infer a Stream<Item> as the return type of mapMulti:

Make the lambda explicitly typed:

var items = users.stream()
    .mapMulti((User u, Consumer<Item> consumer) -> u.getItems().forEach(consumer))
    .collect(Collectors.toSet());

Add a temporary stream variable:

// By looking at the type of itemStream, the compiler can figure out that mapMulti should return a Stream<Item>
Stream<Item> itemStream = users.stream()
    .mapMulti((u, consumer) -> u.getItems().forEach(consumer));
var items = itemStream.collect(Collectors.toSet());

I don't know if this is more "simplified", but I think it is neater if you use method references:

var items = users.stream()
    .map(User::getItems)
    .<Item>mapMulti(Iterable::forEach)
    .collect(Collectors.toSet());

Why does Finatra use flatMap and not just map?

From a theoretical point of view, if we take away the exceptions part (they cannot be reasoned about using category theory anyway), then those two operations are completely identical as long as your construct of choice (in your case Twitter Future) forms a valid monad.

I don't want to go into length over these concepts, so I'm just going to present the laws directly (using Scala Future):

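import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

// "==" below means "completes with the same value"; Future itself has no structural equality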

// Functor identity law
Future(42).map(x => x) == Future(42)

// Monad left-identity law
val f = (x: Int) => Future(x)
Future(42).flatMap(f) == f(42)

// combining those two, since every Monad is also a Functor, we get:
Future(42).map(x => x) == Future(42).flatMap(x => Future(x))

// and if we now generalise identity into any function:
Future(42).map(x => x + 20) == Future(42).flatMap(x => Future(x + 20))

So yes, as you already hinted, those two approaches are identical.
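Since Futures are not compared by value, a runnable check of these laws has to wait for both results. A minimal sketch, where sameResult is a hypothetical helper and the 5-second timeout is arbitrary:

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._

// Hypothetical helper: waits for both Futures and compares their values
def sameResult[A](a: Future[A], b: Future[A]): Boolean =
  Await.result(a, 5.seconds) == Await.result(b, 5.seconds)

// Functor identity: mapping with x => x changes nothing
val identityHolds = sameResult(Future(42).map(x => x), Future(42))

// Generalised form: map(g) versus flatMap(x => Future(g(x)))
val viaMap = Future(42).map(x => x + 20)
val viaFlatMap = Future(42).flatMap(x => Future(x + 20))
```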

However, there are three comments that I have on this, given that we are including exceptions into the mix:

  1. Be careful - when it comes to throwing exceptions, Scala Future (probably Twitter too) violates the left-identity law on purpose, in order to trade it off for some extra safety.

Example:

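import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

def sneakyFuture: Future[Int] = {
  throw new Exception("boom!")
  Future(42) // never reached: the exception is thrown before any Future exists
}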

val f1 = Future(42).flatMap(_ => sneakyFuture)
// Future(Failure(java.lang.Exception: boom!))

val f2 = sneakyFuture
// Exception in thread "main" java.lang.Exception: boom!

  2. As @randbw mentioned, throwing exceptions is not idiomatic in FP, and it violates principles such as purity of functions and referential transparency of values.

Scala and Twitter Future make it easy for you to just throw an exception: as long as it happens in a Future context, the exception will not bubble up but will instead cause that Future to fail. However, that doesn't mean that literally throwing exceptions around in your code should be permitted, because it ruins the structure of your programs (similarly to how GOTO statements do, or break statements in loops, etc.).

Preferred practice is to always evaluate every code path into a value instead of throwing bombs around, which is why it's better to flatMap into a (failed) Future than to map into some code that throws a bomb.


  3. Keep in mind referential transparency.

If you use map instead of flatMap, and someone takes the code from inside the map and extracts it into a function, you're safer if that function returns a Future; otherwise someone might call it outside of a Future context.

Example:

import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

Future(42).map(x => {
  // this should be done inside a Future
  x + 1
})

This is fine. But after a completely valid refactoring (which relies on referential transparency), your code becomes this:

def f(x: Int) = {
  // this should be done inside a Future
  x + 1
}
Future(42).map(x => f(x))

And you will run into problems if someone calls f directly. It's much safer to wrap the code into a Future and flatMap on it.

Of course, you could argue that even when using flatMap someone could rip out the f from .flatMap(x => Future(f(x))), but it's not that likely. On the other hand, simply extracting the response-processing logic into a separate function fits perfectly with functional programming's idea of composing small functions into bigger ones, and it's likely to happen.
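A minimal sketch of the last two recommendations together: the extracted step (a hypothetical process function) returns a Future and signals failure with Future.failed instead of throwing, so it behaves safely both inside flatMap and when called directly.

```scala
import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Try

// Hypothetical extracted step: failure is a value (a failed Future), not a throw
def process(x: Int): Future[Int] =
  if (x < 0) Future.failed(new IllegalArgumentException(s"negative input: $x"))
  else Future(x + 1)

val ok = Future(42).flatMap(process)
val bad = Future(-1).flatMap(process)

// Calling process directly does not throw at the call site either;
// the failure only surfaces when the result is awaited (captured here by Try)
val direct = process(-1)
val badOutcome = Try(Await.result(bad, 5.seconds))
val directOutcome = Try(Await.result(direct, 5.seconds))
```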

Which happens first in flatMap, flatten or map?

The purpose of flatMap functions is to take a function that returns a list and then flatten the result.

So it first maps over the iterable (in the question's example, the mapping function splits each element), then flattens the resulting nested iterable (a List in this case).
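A small sketch with made-up input strings, assuming the mapping step splits each string on spaces, shows that map-then-flatten and flatMap give the same result:

```scala
val lines = List("a b", "c d e")

// Step 1: map, where each string becomes a List[String] (a nested result)
val mapped: List[List[String]] = lines.map(_.split(" ").toList)

// Step 2: flatten the nested List into a single List
val mappedThenFlattened: List[String] = mapped.flatten

// flatMap does both steps in one call, in that order
val viaFlatMap: List[String] = lines.flatMap(_.split(" ").toList)
```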


