Write a Spark application
The following examples show the same application, which estimates the value of pi, written in Scala, Java, and Python. You can find these and other complete example applications in your Spark installation at $SPARK_HOME/examples and at GitHub.
Scala
To avoid Scala compatibility issues, we suggest you use Spark dependencies for the
correct Scala version when you compile a Spark application for an Amazon EMR cluster. The
Scala version you should use depends on the version of Spark installed on your
cluster. For example, Amazon EMR release 5.30.1 uses Spark 2.4.5, which is built with
Scala 2.11. If your cluster uses Amazon EMR release 5.30.1, use Spark dependencies for
Scala 2.11. For more information about the Scala versions used by Spark, see the Apache Spark documentation.
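For example, with an sbt build you can pin both the Scala version and the Spark dependency to match the cluster. The following build.sbt fragment is a minimal sketch for the Amazon EMR 5.30.1 case described above; the project name and the use of the provided scope are assumptions about a typical project layout, since the cluster supplies Spark at runtime.

// build.sbt (sketch): match the Scala and Spark versions installed on the cluster.
// Amazon EMR 5.30.1 runs Spark 2.4.5, which is built with Scala 2.11.
name := "spark-pi-example"
scalaVersion := "2.11.12"

// %% appends the Scala binary version (_2.11) to the artifact name;
// "provided" keeps Spark out of the application JAR because the cluster supplies it.
libraryDependencies += "org.apache.spark" %% "spark-core" % "2.4.5" % "provided"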
package org.apache.spark.examples

import scala.math.random

import org.apache.spark._

/** Computes an approximation to pi */
object SparkPi {
  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("Spark Pi")
    val spark = new SparkContext(conf)
    val slices = if (args.length > 0) args(0).toInt else 2
    val n = math.min(100000L * slices, Int.MaxValue).toInt // avoid overflow
    val count = spark.parallelize(1 until n, slices).map { i =>
      val x = random * 2 - 1
      val y = random * 2 - 1
      if (x*x + y*y < 1) 1 else 0
    }.reduce(_ + _)
    println("Pi is roughly " + 4.0 * count / n)
    spark.stop()
  }
}
Java
package org.apache.spark.examples;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;

import java.util.ArrayList;
import java.util.List;

/**
 * Computes an approximation to pi
 * Usage: JavaSparkPi [slices]
 */
public final class JavaSparkPi {

  public static void main(String[] args) throws Exception {
    SparkConf sparkConf = new SparkConf().setAppName("JavaSparkPi");
    JavaSparkContext jsc = new JavaSparkContext(sparkConf);

    int slices = (args.length == 1) ? Integer.parseInt(args[0]) : 2;
    int n = 100000 * slices;
    List<Integer> l = new ArrayList<Integer>(n);
    for (int i = 0; i < n; i++) {
      l.add(i);
    }

    JavaRDD<Integer> dataSet = jsc.parallelize(l, slices);

    int count = dataSet.map(new Function<Integer, Integer>() {
      @Override
      public Integer call(Integer integer) {
        double x = Math.random() * 2 - 1;
        double y = Math.random() * 2 - 1;
        return (x * x + y * y < 1) ? 1 : 0;
      }
    }).reduce(new Function2<Integer, Integer, Integer>() {
      @Override
      public Integer call(Integer integer, Integer integer2) {
        return integer + integer2;
      }
    });

    System.out.println("Pi is roughly " + 4.0 * count / n);

    jsc.stop();
  }
}
Python
import argparse
import logging
from operator import add
from random import random

from pyspark.sql import SparkSession

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")


def calculate_pi(partitions, output_uri):
    """
    Calculates pi by testing a large number of random numbers against a unit circle
    inscribed inside a square. The trials are partitioned so they can be run in
    parallel on cluster instances.

    :param partitions: The number of partitions to use for the calculation.
    :param output_uri: The URI where the output is written, typically an Amazon S3
                       bucket, such as 's3://example-bucket/pi-calc'.
    """

    def calculate_hit(_):
        x = random() * 2 - 1
        y = random() * 2 - 1
        return 1 if x**2 + y**2 < 1 else 0

    tries = 100000 * partitions
    logger.info(
        "Calculating pi with a total of %s tries in %s partitions.", tries, partitions
    )

    with SparkSession.builder.appName("My PyPi").getOrCreate() as spark:
        hits = (
            spark.sparkContext.parallelize(range(tries), partitions)
            .map(calculate_hit)
            .reduce(add)
        )

        pi = 4.0 * hits / tries
        logger.info("%s tries and %s hits gives pi estimate of %s.", tries, hits, pi)

        if output_uri is not None:
            df = spark.createDataFrame([(tries, hits, pi)], ["tries", "hits", "pi"])
            df.write.mode("overwrite").json(output_uri)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--partitions",
        default=2,
        type=int,
        help="The number of parallel partitions to use when calculating pi.",
    )
    parser.add_argument(
        "--output_uri", help="The URI where output is saved, typically an S3 bucket."
    )
    args = parser.parse_args()

    calculate_pi(args.partitions, args.output_uri)