Commit afe30f0

Use argparse to setup spark (#2082)
1 parent bf33945 commit afe30f0

File tree

2 files changed: +22 -17 lines changed


images/pyspark-notebook/Dockerfile

Lines changed: 5 additions & 5 deletions
@@ -41,11 +41,11 @@ ENV SPARK_OPTS="--driver-java-options=-Xms1024M --driver-java-options=-Xmx4096M
 COPY setup_spark.py /opt/setup-scripts/
 
 # Setup Spark
-RUN SPARK_VERSION="${spark_version}" \
-    HADOOP_VERSION="${hadoop_version}" \
-    SCALA_VERSION="${scala_version}" \
-    SPARK_DOWNLOAD_URL="${spark_download_url}" \
-    /opt/setup-scripts/setup_spark.py
+RUN /opt/setup-scripts/setup_spark.py \
+    --spark-version="${spark_version}" \
+    --hadoop-version="${hadoop_version}" \
+    --scala-version="${scala_version}" \
+    --spark-download-url="${spark_download_url}"
 
 # Configure IPython system-wide
 COPY ipython_kernel_config.py "/etc/ipython/"
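
Note on the new invocation: the Dockerfile always passes --spark-version="${spark_version}", even when the spark_version build arg is empty. argparse's required=True only demands that the flag be present, so an empty value still parses, and the script then falls back to the latest release (the "or get_latest_spark_version()" line in the script diff below). A minimal sketch of that behaviour, with a stand-in fallback function and a hypothetical version number:

import argparse


def get_latest_spark_version() -> str:
    # Stand-in for the real archive lookup in setup_spark.py
    return "4.0.0"  # hypothetical value, for illustration only


parser = argparse.ArgumentParser()
parser.add_argument("--spark-version", required=True)

# What --spark-version="${spark_version}" expands to when the build arg is empty:
args = parser.parse_args(["--spark-version="])

# The empty string is falsy, so the script resolves the latest stable release instead
version = args.spark_version or get_latest_spark_version()
print(version)  # 4.0.0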

images/pyspark-notebook/setup_spark.py

Lines changed: 17 additions & 12 deletions
@@ -4,9 +4,9 @@
 
 # Requirements:
 # - Run as the root user
-# - Required env variables: SPARK_HOME, HADOOP_VERSION, SPARK_DOWNLOAD_URL
-# - Optional env variables: SPARK_VERSION, SCALA_VERSION
+# - Required env variable: SPARK_HOME
 
+import argparse
 import logging
 import os
 import subprocess
@@ -27,13 +27,10 @@ def get_all_refs(url: str) -> list[str]:
     return [a["href"] for a in soup.find_all("a", href=True)]
 
 
-def get_spark_version() -> str:
+def get_latest_spark_version() -> str:
     """
-    If ${SPARK_VERSION} env variable is non-empty, simply returns it
-    Otherwise, returns the last stable version of Spark using spark archive
+    Returns the last stable version of Spark using spark archive
     """
-    if (version := os.environ["SPARK_VERSION"]) != "":
-        return version
     LOGGER.info("Downloading Spark versions information")
     all_refs = get_all_refs("https://archive.apache.org/dist/spark/")
     stable_versions = [
@@ -106,12 +103,20 @@ def configure_spark(spark_dir_name: str, spark_home: Path) -> None:
 if __name__ == "__main__":
     logging.basicConfig(level=logging.INFO)
 
-    spark_version = get_spark_version()
+    arg_parser = argparse.ArgumentParser()
+    arg_parser.add_argument("--spark-version", required=True)
+    arg_parser.add_argument("--hadoop-version", required=True)
+    arg_parser.add_argument("--scala-version", required=True)
+    arg_parser.add_argument("--spark-download-url", type=Path, required=True)
+    args = arg_parser.parse_args()
+
+    args.spark_version = args.spark_version or get_latest_spark_version()
+
     spark_dir_name = download_spark(
-        spark_version=spark_version,
-        hadoop_version=os.environ["HADOOP_VERSION"],
-        scala_version=os.environ["SCALA_VERSION"],
-        spark_download_url=Path(os.environ["SPARK_DOWNLOAD_URL"]),
+        spark_version=args.spark_version,
+        hadoop_version=args.hadoop_version,
+        scala_version=args.scala_version,
+        spark_download_url=args.spark_download_url,
     )
     configure_spark(
         spark_dir_name=spark_dir_name, spark_home=Path(os.environ["SPARK_HOME"])
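
The hunk above cuts off just after stable_versions, so for context, here is a sketch of how a latest-version lookup like get_latest_spark_version() can be built on top of get_all_refs(). Only the archive URL and the get_all_refs() helper come from the diff; the fetching, filtering, and sorting details below are assumptions, not the repository's exact code:

import re

import requests
from bs4 import BeautifulSoup


def get_all_refs(url: str) -> list[str]:
    # Same idea as the helper shown in the diff: collect every href on the page
    soup = BeautifulSoup(requests.get(url).text, "html.parser")
    return [a["href"] for a in soup.find_all("a", href=True)]


def get_latest_spark_version() -> str:
    all_refs = get_all_refs("https://archive.apache.org/dist/spark/")
    # Assumed pattern: stable releases are listed as directories like "spark-3.5.1/"
    versions = [
        ref.removeprefix("spark-").rstrip("/")
        for ref in all_refs
        if re.fullmatch(r"spark-\d+\.\d+\.\d+/", ref)
    ]
    # Pick the numerically highest version rather than relying on string order
    return max(versions, key=lambda v: tuple(map(int, v.split("."))))


print(get_latest_spark_version())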
