Fraud Detection with the DataStream API

Apache Flink offers a DataStream API for building robust, stateful streaming applications. It provides fine-grained control over state and time, which allows for the implementation of advanced event-driven systems. In this step-by-step guide you’ll learn how to build a stateful streaming application with Flink’s DataStream API.

What Are You Building? #

Credit card fraud is a growing concern in the digital age. Criminals steal credit card numbers by running scams or hacking into insecure systems. Stolen numbers are tested by making one or more small purchases, often for a dollar or less. If that works, they then make more significant purchases to get items they can sell or keep for themselves.

In this tutorial, you will build a fraud detection system for alerting on suspicious credit card transactions. Using a simple set of rules, you will see how Flink allows you to implement advanced business logic and act in real time.

Prerequisites #

This walkthrough assumes that you have some familiarity with Java or Scala, but you should be able to follow along even if you are coming from a different programming language.

Help, I’m Stuck! #

If you get stuck, check out the community support resources. In particular, Apache Flink’s user mailing list is consistently ranked as one of the most active of any Apache project and a great way to get help quickly.

How to Follow Along #

If you want to follow along, you will require a computer with:

  • Java 8 or 11
  • Maven

A provided Flink Maven Archetype quickly creates a skeleton project with all the necessary dependencies, so you only need to focus on filling out the business logic. These dependencies include flink-streaming-java, which is the core dependency for all Flink streaming applications, and flink-walkthrough-common, which has data generators and other classes specific to this walkthrough.

Java

$ mvn archetype:generate \
    -DarchetypeGroupId=org.apache.flink \
    -DarchetypeArtifactId=flink-walkthrough-datastream-java \
    -DarchetypeVersion=1.13.0 \
    -DgroupId=frauddetection \
    -DartifactId=frauddetection \
    -Dversion=0.1 \
    -Dpackage=spendreport \
    -DinteractiveMode=false

Scala

$ mvn archetype:generate \
    -DarchetypeGroupId=org.apache.flink \
    -DarchetypeArtifactId=flink-walkthrough-datastream-scala \
    -DarchetypeVersion=1.13.0 \
    -DgroupId=frauddetection \
    -DartifactId=frauddetection \
    -Dversion=0.1 \
    -Dpackage=spendreport \
    -DinteractiveMode=false

You can edit the groupId, artifactId and package if you like. With the above parameters, Maven will create a folder named frauddetection that contains a project with all the dependencies to complete this tutorial. After importing the project into your editor, you can find a file FraudDetectionJob.java (or FraudDetectionJob.scala) with the following code, which you can run directly inside your IDE. Try setting breakpoints throughout the data stream and running the code in DEBUG mode to get a feeling for how everything works.

FraudDetectionJob.java

package spendreport;

import org.apache.flink.streaming.api.datastream.DataStream;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.walkthrough.common.sink.AlertSink;
import org.apache.flink.walkthrough.common.entity.Alert;
import org.apache.flink.walkthrough.common.entity.Transaction;
import org.apache.flink.walkthrough.common.source.TransactionSource;

public class FraudDetectionJob {

    public static void main(String[] args) throws Exception {
        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

        DataStream<Transaction> transactions = env
            .addSource(new TransactionSource())
            .name("transactions");
        
        DataStream<Alert> alerts = transactions
            .keyBy(Transaction::getAccountId)
            .process(new FraudDetector())
            .name("fraud-detector");

        alerts
            .addSink(new AlertSink())
            .name("send-alerts");

        env.execute("Fraud Detection");
    }
}

FraudDetector.java

package spendreport;

import org.apache.flink.streaming.api.functions.KeyedProcessFunction;
import org.apache.flink.util.Collector;
import org.apache.flink.walkthrough.common.entity.Alert;
import org.apache.flink.walkthrough.common.entity.Transaction;

public class FraudDetector extends KeyedProcessFunction<Long, Transaction, Alert> {

    private static final long serialVersionUID = 1L;

    private static final double SMALL_AMOUNT = 1.00;
    private static final double LARGE_AMOUNT = 500.00;
    private static final long ONE_MINUTE = 60 * 1000;

    @Override
    public void processElement(
            Transaction transaction,
            Context context,
            Collector<Alert> collector) throws Exception {

        Alert alert = new Alert();
        alert.setId(transaction.getAccountId());

        collector.collect(alert);
    }
}

FraudDetectionJob.scala

package spendreport

import org.apache.flink.streaming.api.scala._
import org.apache.flink.walkthrough.common.sink.AlertSink
import org.apache.flink.walkthrough.common.entity.Alert
import org.apache.flink.walkthrough.common.entity.Transaction
import org.apache.flink.walkthrough.common.source.TransactionSource

object FraudDetectionJob {

  @throws[Exception]
  def main(args: Array[String]): Unit = {
    val env: StreamExecutionEnvironment = StreamExecutionEnvironment.getExecutionEnvironment

    val transactions: DataStream[Transaction] = env
      .addSource(new TransactionSource)
      .name("transactions")

    val alerts: DataStream[Alert] = transactions
      .keyBy(transaction => transaction.getAccountId)
      .process(new FraudDetector)
      .name("fraud-detector")

    alerts
      .addSink(new AlertSink)
      .name("send-alerts")

    env.execute("Fraud Detection")
  }
}

FraudDetector.scala

package spendreport

import org.apache.flink.streaming.api.functions.KeyedProcessFunction
import org.apache.flink.util.Collector
import org.apache.flink.walkthrough.common.entity.Alert
import org.apache.flink.walkthrough.common.entity.Transaction

object FraudDetector {
  val SMALL_AMOUNT: Double = 1.00
  val LARGE_AMOUNT: Double = 500.00
  val ONE_MINUTE: Long     = 60 * 1000L
}

@SerialVersionUID(1L)
class FraudDetector extends KeyedProcessFunction[Long, Transaction, Alert] {

  @throws[Exception]
  def processElement(
      transaction: Transaction,
      context: KeyedProcessFunction[Long, Transaction, Alert]#Context,
      collector: Collector[Alert]): Unit = {

    val alert = new Alert
    alert.setId(transaction.getAccountId)

    collector.collect(alert)
  }
}

Breaking Down the Code #

Let’s walk step-by-step through the code of these two files. The FraudDetectionJob class defines the data flow of the application and the FraudDetector class defines the business logic of the function that detects fraudulent transactions.

We start by describing how the Job is assembled in the main method of the FraudDetectionJob class.

The Execution Environment #

The first line sets up your StreamExecutionEnvironment. The execution environment is how you set properties for your Job, create your sources, and finally trigger the execution of the Job.

Java

StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();

Scala

val env = StreamExecutionEnvironment.getExecutionEnvironment
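
One consequence of this design is worth seeing in isolation: assembling the pipeline only describes the Job, and no record is processed until execute() is called. The following is a minimal sketch of that deferred-execution idea in plain Java. MiniEnvironment and its methods are hypothetical stand-ins invented for illustration, not Flink classes, and the real StreamExecutionEnvironment does far more than this.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

// Hypothetical stand-in for an execution environment; NOT a Flink class.
final class MiniEnvironment {
    private final List<Function<Integer, Integer>> steps = new ArrayList<>();
    private final List<Integer> input;

    MiniEnvironment(List<Integer> input) {
        this.input = input;
    }

    // Registering a step only records it; no data is touched here.
    MiniEnvironment map(Function<Integer, Integer> step) {
        steps.add(step);
        return this;
    }

    // Only this call pushes the input through the recorded steps.
    List<Integer> execute() {
        List<Integer> out = new ArrayList<>();
        for (Integer value : input) {
            Integer current = value;
            for (Function<Integer, Integer> step : steps) {
                current = step.apply(current);
            }
            out.add(current);
        }
        return out;
    }

    public static void main(String[] args) {
        int[] calls = {0}; // counts how often the step actually runs
        MiniEnvironment env = new MiniEnvironment(Arrays.asList(1, 2, 3))
            .map(x -> { calls[0]++; return x * 10; });

        System.out.println("calls before execute: " + calls[0]); // 0
        System.out.println("result: " + env.execute());          // [10, 20, 30]
        System.out.println("calls after execute: " + calls[0]);  // 3
    }
}
```

This is why a breakpoint inside a function is only hit after env.execute("Fraud Detection") has been reached.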

Creating a Source #

Sources ingest data from external systems, such as Apache Kafka, RabbitMQ, or Apache Pulsar, into Flink Jobs. This walkthrough uses a source that generates an infinite stream of credit card transactions for you to process. Each transaction contains an account ID (accountId), the timestamp (timestamp) of when the transaction occurred, and the amount in US dollars (amount). The name attached to the source is only for debugging purposes, so if something goes wrong, you will know where the error originated.

Java

DataStream<Transaction> transactions = env
    .addSource(new TransactionSource())
    .name("transactions");

Scala

val transactions: DataStream[Transaction] = env
  .addSource(new TransactionSource)
  .name("transactions")
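
To make the shape of the data concrete, here is a minimal sketch in plain Java of an unbounded stream of records carrying the three fields described above. EndlessTransactions, its nested Tx class, and the amounts are hypothetical stand-ins made up for illustration; they are not the walkthrough's TransactionSource or Transaction, whose internals are not shown here.

```java
import java.util.Iterator;

// Hypothetical unbounded source; hasNext() never returns false.
final class EndlessTransactions implements Iterator<EndlessTransactions.Tx> {

    // Hypothetical stand-in for the walkthrough's Transaction class.
    static final class Tx {
        final long accountId;
        final long timestamp; // assumption: epoch milliseconds
        final double amount;  // US dollars

        Tx(long accountId, long timestamp, double amount) {
            this.accountId = accountId;
            this.timestamp = timestamp;
            this.amount = amount;
        }
    }

    // Made-up amounts, reused in a cycle.
    private static final double[] AMOUNTS = {188.23, 0.50, 620.00};
    private long counter = 0;

    @Override
    public boolean hasNext() {
        return true; // the stream never ends on its own
    }

    @Override
    public Tx next() {
        long accountId = counter % 3 + 1; // cycles through accounts 1..3
        double amount = AMOUNTS[(int) (counter % AMOUNTS.length)];
        long timestamp = counter * 1000L; // one made-up event per second
        counter++;
        return new Tx(accountId, timestamp, amount);
    }

    public static void main(String[] args) {
        EndlessTransactions source = new EndlessTransactions();
        for (int i = 0; i < 5; i++) { // the consumer decides when to stop
            Tx tx = source.next();
            System.out.println(tx.accountId + " @" + tx.timestamp + " $" + tx.amount);
        }
    }
}
```

Because the stream has no end, downstream logic has to react to each record as it arrives rather than wait for a complete data set.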

Partitioning Events & Detecting Fraud #

The transactions stream contains many transactions from a large number of users, so it needs to be processed in parallel by multiple fraud detection tasks. Since fraud occurs on a per-account basis, you must ensure that all transactions for the same account are processed by the same parallel task of the fraud detector operator.

To ensure that the same physical task processes all records for a particular key, you can partition a stream using DataStream#keyBy. The process() call adds an operator that applies a function to each partitioned element in the stream. It is common to say the operator immediately after a keyBy, in this case FraudDetector, is executed within a keyed context.

Java

DataStream<Alert> alerts = transactions
    .keyBy(Transaction::getAccountId)
    .process(new FraudDetector())
    .name("fraud-detector");

Scala

val alerts: DataStream[Alert] = transactions
  .keyBy(transaction => transaction.getAccountId)
  .process(new FraudDetector)
  .name("fraud-detector")
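
The guarantee that keyBy provides can be illustrated without Flink. The sketch below, in plain Java, routes each account ID to one of several parallel tasks with a deterministic function, so records with the same key always land on the same task. KeyRouting and taskFor are hypothetical names, and the hash-modulo scheme is a deliberate simplification: Flink's actual assignment goes through key groups and a different hash, so the task indexes here will not match a real job.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical illustration of key-based routing; NOT Flink's real scheme.
final class KeyRouting {

    // Assumption: plain hash modulo parallelism. The only property that
    // matters here is that the result depends on nothing but the key.
    static int taskFor(long accountId, int parallelism) {
        return Math.floorMod(Long.hashCode(accountId), parallelism);
    }

    public static void main(String[] args) {
        int parallelism = 4;
        long[] accountIds = {1, 2, 3, 1, 5, 3, 1}; // made-up key sequence

        // One bucket per parallel task.
        List<List<Long>> tasks = new ArrayList<>();
        for (int i = 0; i < parallelism; i++) {
            tasks.add(new ArrayList<>());
        }

        // Every record goes to the task chosen by its key alone.
        for (long id : accountIds) {
            tasks.get(taskFor(id, parallelism)).add(id);
        }

        for (int i = 0; i < parallelism; i++) {
            System.out.println("task " + i + " -> " + tasks.get(i));
        }
    }
}
```

Several accounts may share a task (here accounts 1 and 5), but no account is ever split across tasks, which is what lets the fraud detector reason about one account at a time.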

Outputting Results #

A sink writes a DataStream to an external system, such as Apache Kafka, Cassandra, or AWS Kinesis. The AlertSink logs each Alert record with log level INFO, instead of writing it to persistent storage, so you can easily see your results.

Java

alerts.addSink(new AlertSink());

Scala

alerts.addSink(new AlertSink)

The Fraud Detector #

The fraud detector is implemented as a KeyedProcessFunction. Its method KeyedProcessFunction#processElement is called for every transaction event. This first version produces an alert on every transaction, which some may say is overly conservative.

The next steps of this tutorial will guide you to expand the fraud detector with more meaningful business logic.

Java

public class FraudDetector extends KeyedProcessFunction<Long, Transaction, Alert> {

    private static final double SMALL_AMOUNT = 1.00;
    private static final double LARGE_AMOUNT = 500.00;
    private static final long ONE_MINUTE = 60 * 1000;

    @Override
    public void processElement(
            Transaction transaction,
            Context context,
            Collector<Alert> collector) throws Exception {

        Alert alert = new Alert();
        alert.setId(transaction.getAccountId());

        collector.collect(alert);
    }
}

Scala

object FraudDetector {
  val SMALL_AMOUNT: Double = 1.00
  val LARGE_AMOUNT: Double = 500.00
  val ONE_MINUTE: Long     = 60 * 1000L
}

@SerialVersionUID(1L)
class FraudDetector extends KeyedProcessFunction[Long, Transaction, Alert] {

  @throws[Exception]
  def processElement(
      transaction: Transaction,
      context: KeyedProcessFunction[Long, Transaction, Alert]#Context,
      collector: Collector[Alert]): Unit = {

    val alert = new Alert
    alert.setId(transaction.getAccountId)

    collector.collect(alert)
  }
}

Writing a Real Application (v1) #

For the first version, the fraud detector should output an alert for any account that makes a small transaction immediately followed by a large one, where small is anything less than $1.00 and large is more than $500. Imagine your fraud detector processes the following stream of transactions for a particular account.
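
The rule itself can be sketched in plain Java before wiring it into Flink. The class below is a hypothetical illustration, not the walkthrough's FraudDetector: it remembers, per account, whether the previous transaction was small, and flags the current one if it is large. The HashMap stands in for per-key state, which a real Flink job would keep in Flink-managed keyed state rather than in a field; the account ID and amounts in main are made up.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, Flink-free sketch of the small-then-large rule.
final class SmallThenLargeRule {
    static final double SMALL_AMOUNT = 1.00;
    static final double LARGE_AMOUNT = 500.00;

    // Per-account flag: was the previous transaction small?
    // Assumption: an in-memory map stands in for keyed state.
    private final Map<Long, Boolean> lastWasSmall = new HashMap<>();

    // Returns true when this transaction should raise an alert.
    boolean process(long accountId, double amount) {
        boolean previousSmall = lastWasSmall.getOrDefault(accountId, false);
        boolean alert = previousSmall && amount > LARGE_AMOUNT;

        // Overwrite the flag on every event, so only the immediately
        // preceding transaction can arm the rule.
        lastWasSmall.put(accountId, amount < SMALL_AMOUNT);
        return alert;
    }

    public static void main(String[] args) {
        SmallThenLargeRule rule = new SmallThenLargeRule();
        // Made-up amounts for a single hypothetical account.
        double[] amounts = {13.01, 25.00, 0.09, 510.00, 102.62, 0.80, 20.12, 602.27};

        List<Integer> alertPositions = new ArrayList<>();
        for (int i = 0; i < amounts.length; i++) {
            if (rule.process(7L, amounts[i])) {
                alertPositions.add(i);
            }
        }
        System.out.println("alerts at positions: " + alertPositions); // [3]
    }
}
```

Only the 510.00 purchase is flagged, because it directly follows the 0.09 one. The 602.27 purchase is not, since a 20.12 transaction sits between it and the preceding small transaction of 0.80.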

Expected Output #

Running this code with the provided TransactionSource will emit fraud alerts for account 3. You should see the following output in your task manager logs:

2019-08-19 14:22:06,220 INFO  org.apache.flink.walkthrough.common.sink.AlertSink - Alert{id=3}
2019-08-19 14:22:11,383 INFO  org.apache.flink.walkthrough.common.sink.AlertSink - Alert{id=3}
2019-08-19 14:22:16,551 INFO  org.apache.flink.walkthrough.common.sink.AlertSink - Alert{id=3}
2019-08-19 14:22:21,723 INFO  org.apache.flink.walkthrough.common.sink.AlertSink - Alert{id=3}
2019-08-19 14:22:26,896 INFO  org.apache.flink.walkthrough.common.sink.AlertSink - Alert{id=3}