Partial Function Application and Currying in Java

Writing functional code in Java has historically been awkward, and some functional idioms were not practical at all. Java 8 introduced several functional programming features (functional interfaces, lambda expressions, the Stream API, and others). In this example, I will show how to do partial application and currying in Java.

Partial function application is the process of fixing some of a function's arguments, producing another function of smaller arity (number of arguments). Currying is the technique of translating a function that takes multiple arguments into a sequence of functions, each taking a single argument. Currying is related to, but not the same as, partial application: a curried function makes partial application trivial, but partial application does not require the function to be curried.

In functional languages, such as Haskell, functions are defined in the curried form by default. That is, subtract 5 7 actually means (subtract 5) 7 (note that Haskell's subtract x y computes y - x, so the result is 2):

Prelude> subtract 5 7
2
Prelude> (subtract 5) 7
2

which is the same as the following:

Prelude> let sub5 = subtract 5
Prelude> sub5 7
2

In Java, we need to write some additional code to achieve the same effect. For demonstration purposes, we will use the following class, which includes two methods for multiplying two and three numbers:

public class Example {

    public static int multiply(int x, int y) {
        return x * y;
    }

    public static int multiply3(int x, int y, int z) {
        return x * y * z;
    }
}

Here is a simple implementation of a partial application function in Java:

public static <A, B, R> Function<B, R> partial(BiFunction<A, B, R> f, A x) {
    return (y) -> f.apply(x, y);
}

Above, we use the Function (a function that accepts one argument) and BiFunction (a function that accepts two arguments) functional interfaces from java.util.function. Unfortunately, Java is more limited here than functional languages: the standard library has BiFunction, but no TriFunction or anything of higher arity. To handle more arguments, we need to define a new functional interface:

@FunctionalInterface
interface TriFunction<A, B, C, R> {
    R apply(A a, B b, C c);
}

Now we can create two more partial application methods:

public static <A, B, C, R> Function<C, R> partial(TriFunction<A, B, C, R> f, A x, B y) {
    return (z) -> f.apply(x, y, z);
}

public static <A, B, C, R> BiFunction<B, C, R> partial(TriFunction<A, B, C, R> f, A x) {
    return (y, z) -> f.apply(x, y, z);
}

With our partial application methods, we can now do the following:

Function<Integer, Integer> mul5 = partial(Example::multiply, 5);
System.out.println(mul5.apply(2)); // 10
Function<Integer, Integer> mul53 = partial(Example::multiply3, 5, 3);
System.out.println(mul53.apply(2)); // 30
BiFunction<Integer, Integer, Integer> mul4 = partial(Example::multiply3, 4);
System.out.println(mul4.apply(2, 5)); // 40
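Partially applied functions become more useful once they are passed to other code. Here is a minimal sketch of that idea, assuming we want to scale a list with the Stream API. The class name PartialWithStreams and the scale method are hypothetical names I made up for illustration; the partial and multiply methods mirror the ones shown above.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;

// Hypothetical demo class; not part of the original Example class.
final class PartialWithStreams {

    // Same shape as the two-argument partial() above: fix the first argument.
    static <A, B, R> Function<B, R> partial(BiFunction<A, B, R> f, A x) {
        return y -> f.apply(x, y);
    }

    static int multiply(int x, int y) {
        return x * y;
    }

    // Multiplies every element by a fixed factor (illustrative helper).
    static List<Integer> scale(List<Integer> values, int factor) {
        // The partially applied function has exactly the shape Stream.map expects.
        Function<Integer, Integer> byFactor = partial(PartialWithStreams::multiply, factor);
        return values.stream().map(byFactor).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(scale(Arrays.asList(1, 2, 3), 5)); // [5, 10, 15]
    }
}
```

The point of the sketch is only that a two-argument method can be adapted to a one-argument slot (here Stream.map) without writing a dedicated lambda at the call site.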

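The partial application methods also accept lambda expressions, not only method references. Note that the fixed argument here is the first operand, so sub5 computes 5 - y (unlike Haskell's subtract 5, which computes y - 5):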

BiFunction<Integer, Integer, Integer> subtract = (x, y) -> x - y;
Function<Integer, Integer> sub5 = partial(subtract, 5);
System.out.println(sub5.apply(2)); // 3

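We can also write the function directly in curried form, as a lambda that returns another lambda, so that no helper method is needed: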

Function<Integer, Function<Integer, Integer>> subtractCur = x -> y -> x - y;
System.out.println(subtractCur.apply(10).apply(4)); // 6
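If a function already exists as a BiFunction, it can be converted to the curried form mechanically instead of being rewritten by hand. The following is a minimal sketch of that conversion; the class name CurryDemo and the helper names curry and uncurry are my own illustrative choices, not standard library methods.

```java
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical demo class; curry/uncurry are illustrative helpers, not JDK APIs.
final class CurryDemo {

    // Turns a two-argument function into a chain of one-argument functions.
    static <A, B, R> Function<A, Function<B, R>> curry(BiFunction<A, B, R> f) {
        return x -> y -> f.apply(x, y);
    }

    // The inverse: collapses a curried function back into a BiFunction.
    static <A, B, R> BiFunction<A, B, R> uncurry(Function<A, Function<B, R>> f) {
        return (x, y) -> f.apply(x).apply(y);
    }

    public static void main(String[] args) {
        BiFunction<Integer, Integer, Integer> subtract = (x, y) -> x - y;

        Function<Integer, Function<Integer, Integer>> curried = curry(subtract);

        // Applying only the first argument is partial application.
        Function<Integer, Integer> fromTen = curried.apply(10);
        System.out.println(fromTen.apply(4)); // 6

        // Round trip back to the two-argument form.
        System.out.println(uncurry(curried).apply(10, 4)); // 6
    }
}
```

This also shows how the two concepts relate: once a function is curried, partially applying its first argument is just a single call to apply.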

Java allows us to implement partial application helpers, but we have to write a separate overload for each arity and for each number of fixed arguments.
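One possible way to reduce the number of overloads, sketched here under some assumptions, is to curry a three-argument function once and then apply as many arguments as needed. The class name Curry3Demo and the curry helper are hypothetical; TriFunction and multiply3 mirror the definitions shown earlier. This only covers fixing arguments from left to right, and a separate curry helper is still needed per arity.

```java
import java.util.function.Function;

// Hypothetical demo class; curry is an illustrative helper, not a JDK API.
final class Curry3Demo {

    @FunctionalInterface
    interface TriFunction<A, B, C, R> {
        R apply(A a, B b, C c);
    }

    // Converts a three-argument function into three nested one-argument functions.
    static <A, B, C, R> Function<A, Function<B, Function<C, R>>> curry(TriFunction<A, B, C, R> f) {
        return x -> y -> z -> f.apply(x, y, z);
    }

    static int multiply3(int x, int y, int z) {
        return x * y * z;
    }

    public static void main(String[] args) {
        Function<Integer, Function<Integer, Function<Integer, Integer>>> mul =
                curry(Curry3Demo::multiply3);

        // One argument fixed: plays the role of partial(f, 4).
        Function<Integer, Function<Integer, Integer>> mul4 = mul.apply(4);
        System.out.println(mul4.apply(2).apply(5)); // 40

        // Two arguments fixed: plays the role of partial(f, 5, 3).
        Function<Integer, Integer> mul53 = mul.apply(5).apply(3);
        System.out.println(mul53.apply(2)); // 30
    }
}
```

The trade-off is readability: the nested Function types get long quickly, which is one reason the explicit partial overloads can still be the more pleasant option in Java.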

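The benefits usually attributed to functional programming are the following:

  • Pure functions are easier to reason about (no unintended side effects, no external state, etc.).
  • Testing pure functions is easier, because the result depends only on the arguments.
  • Debugging is easier, because a failure can be reproduced from the inputs alone.
  • Code without shared mutable state tends to be more stable and reliable.
  • Small, composable functions tend to be easier to comprehend.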

So, functional programming is a highly valued approach to writing code.

Here is the full source code:

package dev.isdn.demo.func.curry;

import java.util.function.Function;
import java.util.function.BiFunction;

public class Example {

    @FunctionalInterface
    interface TriFunction<A, B, C, R> {
        R apply(A a, B b, C c);
    }

    public static int multiply(int x, int y) {
        return x * y;
    }

    public static int multiply3(int x, int y, int z) {
        return x * y * z;
    }

    public static <A, B, R> Function<B, R> partial(BiFunction<A, B, R> f, A x) {
        return (y) -> f.apply(x, y);
    }

    public static <A, B, C, R> Function<C, R> partial(TriFunction<A, B, C, R> f, A x, B y) {
        return (z) -> f.apply(x, y, z);
    }

    public static <A, B, C, R> BiFunction<B, C, R> partial(TriFunction<A, B, C, R> f, A x) {
        return (y, z) -> f.apply(x, y, z);
    }

    public static void main(String[] args) {
        Function<Integer, Integer> mul5 = partial(Example::multiply, 5);
        System.out.println(mul5.apply(2)); // 10

        Function<Integer, Integer> mul53 = partial(Example::multiply3, 5, 3);
        System.out.println(mul53.apply(2)); // 30

        BiFunction<Integer, Integer, Integer> mul4 = partial(Example::multiply3, 4);
        System.out.println(mul4.apply(2, 5)); // 40

        BiFunction<Integer, Integer, Integer> subtract = (x, y) -> x - y;
        Function<Integer, Integer> sub5 = partial(subtract, 5);
        System.out.println(sub5.apply(2)); // 3

        Function<Integer, Function<Integer, Integer>> subtractCur = x -> y -> x - y;
        System.out.println(subtractCur.apply(10).apply(4)); // 6

    }
}
