author     Yuanjian Li <xyliyuanjian@gmail.com>    2018-06-05 08:23:08 +0700
committer  hyukjinkwon <gurwls223@apache.org>      2018-06-05 08:23:08 +0700
commit     dbb4d83829ec4b51d6e6d3a96f7a4e611d8827bc
tree       357e2fc0b8f0e36193791dd51183984caca4be03
parent     ff0501b0c27dc8149bd5fb38a19d9b0056698766
[SPARK-24215][PYSPARK] Implement _repr_html_ for dataframes in PySpark
## What changes were proposed in this pull request?

Implement `_repr_html_` for PySpark DataFrames when running in a notebook, and add a config named "spark.sql.repl.eagerEval.enabled" to control it. The dev list thread for context: http://apache-spark-developers-list.1001551.n3.nabble.com/eager-execution-and-debuggability-td23928.html

## How was this patch tested?

New unit test in DataFrameSuite and manual test in Jupyter. Screenshots below.

**After:**

![image](https://user-images.githubusercontent.com/4833765/40268422-8db5bef0-5b9f-11e8-80f1-04bc654a4f2c.png)

**Before:**

![image](https://user-images.githubusercontent.com/4833765/40268431-9f92c1b8-5b9f-11e8-9db9-0611f0940b26.png)

Author: Yuanjian Li <xyliyuanjian@gmail.com>

Closes #21370 from xuanyuanking/SPARK-24215.
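A minimal usage sketch of the feature described above, assuming an existing SparkSession named `spark` inside a Jupyter notebook (illustrative, not part of the patch):

```python
# Hedged sketch: enable eager evaluation at runtime and let the notebook render
# the DataFrame through the new _repr_html_ hook. `spark` is assumed to be an
# already-created SparkSession.
spark.conf.set("spark.sql.repl.eagerEval.enabled", "true")
spark.conf.set("spark.sql.repl.eagerEval.maxNumRows", "10")

df = spark.range(3)
df  # Jupyter calls df._repr_html_() and shows an HTML table instead of "DataFrame[id: bigint]"
```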
-rw-r--r--  docs/configuration.md                                         27
-rw-r--r--  python/pyspark/sql/dataframe.py                               65
-rw-r--r--  python/pyspark/sql/tests.py                                   30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala   84
4 files changed, 176 insertions, 30 deletions
diff --git a/docs/configuration.md b/docs/configuration.md
index 64af0e98a8..5588c372d3 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -457,6 +457,33 @@ Apart from these, the following properties are also available, and may be useful
</td>
</tr>
<tr>
+  <td><code>spark.sql.repl.eagerEval.enabled</code></td>
+  <td>false</td>
+  <td>
+    Enables eager evaluation. If true and the REPL you are using supports eager evaluation,
+    the Dataset will be evaluated automatically. The HTML table generated by <code>_repr_html_</code>,
+    which notebooks such as Jupyter invoke, shows the result of the queries the user has defined. For a
+    plain Python REPL, the returned output is formatted like <code>dataframe.show()</code>
+    (see <a href="https://issues.apache.org/jira/browse/SPARK-24215">SPARK-24215</a> for more details).
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.maxNumRows</code></td>
+  <td>20</td>
+  <td>
+    The maximum number of rows included in the eager evaluation output (the HTML table generated by
+    <code>_repr_html_</code>, or the plain text). This only takes effect when
+    <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.sql.repl.eagerEval.truncate</code></td>
+  <td>20</td>
+  <td>
+    The maximum number of characters per cell in the eager evaluation output (the HTML table generated by
+    <code>_repr_html_</code>, or the plain text). This only takes effect when
+    <code>spark.sql.repl.eagerEval.enabled</code> is set to true.
+  </td>
+</tr>
+<tr>
<td><code>spark.files</code></td>
<td></td>
<td>
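The three properties documented above can also be set when the session is created. A hedged sketch using the standard PySpark builder API (the application name is illustrative):

```python
# Hedged sketch: setting the eager-evaluation properties at SparkSession
# creation time instead of at runtime.
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .appName("eager-eval-demo")                           # illustrative app name
         .config("spark.sql.repl.eagerEval.enabled", "true")
         .config("spark.sql.repl.eagerEval.maxNumRows", "20")
         .config("spark.sql.repl.eagerEval.truncate", "20")
         .getOrCreate())
```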
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 808235ab25..1e6a1acebb 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -78,6 +78,9 @@ class DataFrame(object):
        self.is_cached = False
        self._schema = None  # initialized lazily
        self._lazy_rdd = None
+        # Check whether _repr_html_ is supported or not; we use it to avoid calling _jdf twice
+        # from __repr__ and _repr_html_ when eager evaluation is enabled.
+        self._support_repr_html = False

    @property
    @since(1.3)
@@ -351,8 +354,68 @@ class DataFrame(object):
        else:
            print(self._jdf.showString(n, int(truncate), vertical))

+    @property
+    def _eager_eval(self):
+        """Returns True if eager evaluation is enabled.
+        """
+        return self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.enabled", "false").lower() == "true"
+
+    @property
+    def _max_num_rows(self):
+        """Returns the maximum number of rows for eager evaluation.
+        """
+        return int(self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.maxNumRows", "20"))
+
+    @property
+    def _truncate(self):
+        """Returns the truncation length (in characters) for eager evaluation.
+        """
+        return int(self.sql_ctx.getConf(
+            "spark.sql.repl.eagerEval.truncate", "20"))
+
    def __repr__(self):
-        return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+        if not self._support_repr_html and self._eager_eval:
+            vertical = False
+            return self._jdf.showString(
+                self._max_num_rows, self._truncate, vertical)
+        else:
+            return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+
+    def _repr_html_(self):
+        """Returns an HTML representation of the DataFrame when eager evaluation is
+        enabled via 'spark.sql.repl.eagerEval.enabled'. This is only called by REPLs
+        that support rich HTML output, such as Jupyter notebooks.
+        """
+        import cgi
+        if not self._support_repr_html:
+            self._support_repr_html = True
+        if self._eager_eval:
+            max_num_rows = max(self._max_num_rows, 0)
+            vertical = False
+            sock_info = self._jdf.getRowsToPython(
+                max_num_rows, self._truncate, vertical)
+            rows = list(_load_from_socket(sock_info, BatchedSerializer(PickleSerializer())))
+            head = rows[0]
+            row_data = rows[1:]
+            has_more_data = len(row_data) > max_num_rows
+            row_data = row_data[:max_num_rows]
+
+            html = "<table border='1'>\n"
+            # generate table head
+            html += "<tr><th>%s</th></tr>\n" % "</th><th>".join(map(lambda x: cgi.escape(x), head))
+            # generate table rows
+            for row in row_data:
+                html += "<tr><td>%s</td></tr>\n" % "</td><td>".join(
+                    map(lambda x: cgi.escape(x), row))
+            html += "</table>\n"
+            if has_more_data:
+                html += "only showing top %d %s\n" % (
+                    max_num_rows, "row" if max_num_rows == 1 else "rows")
+            return html
+        else:
+            return None

    @since(2.1)
    def checkpoint(self, eager=True):
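For context on the hook added above: notebook frontends such as Jupyter try an object's `_repr_html_` method first and fall back to `__repr__` in plain-text REPLs. A minimal, self-contained sketch of that display protocol; the class below is illustrative only and not part of PySpark:

```python
# Illustrative only: a tiny class showing the display protocol PySpark's
# DataFrame now participates in. Jupyter renders the value returned by
# _repr_html_; a plain REPL prints the value returned by __repr__.
class EagerTable(object):
    def __init__(self, header, rows):
        self.header = header
        self.rows = rows

    def __repr__(self):
        # Plain-text REPLs (python, pyspark shell) end up here.
        lines = ["\t".join(self.header)]
        lines += ["\t".join(str(c) for c in row) for row in self.rows]
        return "\n".join(lines)

    def _repr_html_(self):
        # Jupyter calls this and renders the returned HTML instead.
        head = "<tr><th>%s</th></tr>" % "</th><th>".join(self.header)
        body = "".join("<tr><td>%s</td></tr>" %
                       "</td><td>".join(str(c) for c in row) for row in self.rows)
        return "<table border='1'>%s%s</table>" % (head, body)

# EagerTable(["key", "value"], [(1, "1"), (2, "2")])  # renders as an HTML table in a notebook
```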
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ea2dd7605d..487eb19c3b 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3074,6 +3074,36 @@ class SQLTests(ReusedSQLTestCase):
finally:
shutil.rmtree(path)
+    def test_repr_html(self):
+        import re
+        pattern = re.compile(r'^ *\|', re.MULTILINE)
+        df = self.spark.createDataFrame([(1, "1"), (22222, "22222")], ("key", "value"))
+        self.assertEquals(None, df._repr_html_())
+        with self.sql_conf({"spark.sql.repl.eagerEval.enabled": True}):
+            expected1 = """<table border='1'>
+                |<tr><th>key</th><th>value</th></tr>
+                |<tr><td>1</td><td>1</td></tr>
+                |<tr><td>22222</td><td>22222</td></tr>
+                |</table>
+                |"""
+            self.assertEquals(re.sub(pattern, '', expected1), df._repr_html_())
+            with self.sql_conf({"spark.sql.repl.eagerEval.truncate": 3}):
+                expected2 = """<table border='1'>
+                    |<tr><th>key</th><th>value</th></tr>
+                    |<tr><td>1</td><td>1</td></tr>
+                    |<tr><td>222</td><td>222</td></tr>
+                    |</table>
+                    |"""
+                self.assertEquals(re.sub(pattern, '', expected2), df._repr_html_())
+                with self.sql_conf({"spark.sql.repl.eagerEval.maxNumRows": 1}):
+                    expected3 = """<table border='1'>
+                        |<tr><th>key</th><th>value</th></tr>
+                        |<tr><td>1</td><td>1</td></tr>
+                        |</table>
+                        |only showing top 1 row
+                        |"""
+                    self.assertEquals(re.sub(pattern, '', expected3), df._repr_html_())
+
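One side note on the HTML construction this test exercises: `_repr_html_` above escapes each cell with `cgi.escape`, which is deprecated and removed in Python 3.8+. A hedged sketch of the usual replacement, shown only for illustration:

```python
# Hedged sketch: html.escape is the Python 3 replacement for cgi.escape.
# With quote=False it escapes the same characters (&, <, >) as cgi.escape did.
try:
    from html import escape as html_escape  # Python 3
except ImportError:
    from cgi import escape as html_escape   # Python 2 fallback

print(html_escape("<b>1 & 2</b>", quote=False))  # -> &lt;b&gt;1 &amp; 2&lt;/b&gt;
```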
class HiveSparkSubmitTests(SparkSubmitTests):
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index abb5ae53f4..f552610469 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -231,16 +231,17 @@ class Dataset[T] private[sql](
}
/**
- * Compose the string representing rows for output
+ * Get the rows as a sequence of string sequences, applying the given truncation and vertical settings.
*
- * @param _numRows Number of rows to show
+ * @param numRows Number of rows to return
* @param truncate If set to more than 0, truncates strings to `truncate` characters and
* all cells will be aligned right.
- * @param vertical If set to true, prints output rows vertically (one line per column value).
+ * @param vertical If set to true, the returned rows do not need to be truncated.
*/
- private[sql] def showString(
- _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = {
- val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ private[sql] def getRows(
+ numRows: Int,
+ truncate: Int,
+ vertical: Boolean): Seq[Seq[String]] = {
val newDf = toDF()
val castCols = newDf.logicalPlan.output.map { col =>
// Since binary types in top-level schema fields have a specific format to print,
@@ -251,14 +252,12 @@ class Dataset[T] private[sql](
Column(col).cast(StringType)
}
}
- val takeResult = newDf.select(castCols: _*).take(numRows + 1)
- val hasMoreData = takeResult.length > numRows
- val data = takeResult.take(numRows)
+ val data = newDf.select(castCols: _*).take(numRows + 1)
// For array values, replace Seq and Array with square brackets
// For cells that are beyond `truncate` characters, replace it with the
// first `truncate-3` and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+ schema.fieldNames.toSeq +: data.map { row =>
row.toSeq.map { cell =>
val str = cell match {
case null => "null"
@@ -274,6 +273,26 @@ class Dataset[T] private[sql](
}
}: Seq[String]
}
+ }
+
+ /**
+ * Compose the string representing rows for output
+ *
+ * @param _numRows Number of rows to show
+ * @param truncate If set to more than 0, truncates strings to `truncate` characters and
+ * all cells will be aligned right.
+ * @param vertical If set to true, prints output rows vertically (one line per column value).
+ */
+ private[sql] def showString(
+ _numRows: Int,
+ truncate: Int = 20,
+ vertical: Boolean = false): String = {
+ val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ // Get rows represented as Seq[Seq[String]]; there may be one extra row if more data exists.
+ val tmpRows = getRows(numRows, truncate, vertical)
+
+ val hasMoreData = tmpRows.length - 1 > numRows
+ val rows = tmpRows.take(numRows + 1)
val sb = new StringBuilder
val numCols = schema.fieldNames.length
@@ -291,31 +310,25 @@ class Dataset[T] private[sql](
}
}
+ val paddedRows = rows.map { row =>
+ row.zipWithIndex.map { case (cell, i) =>
+ if (truncate > 0) {
+ StringUtils.leftPad(cell, colWidths(i))
+ } else {
+ StringUtils.rightPad(cell, colWidths(i))
+ }
+ }
+ }
+
// Create SeparateLine
val sep: String = colWidths.map("-" * _).addString(sb, "+", "+", "+\n").toString()
// column names
- rows.head.zipWithIndex.map { case (cell, i) =>
- if (truncate > 0) {
- StringUtils.leftPad(cell, colWidths(i))
- } else {
- StringUtils.rightPad(cell, colWidths(i))
- }
- }.addString(sb, "|", "|", "|\n")
-
+ paddedRows.head.addString(sb, "|", "|", "|\n")
sb.append(sep)
// data
- rows.tail.foreach {
- _.zipWithIndex.map { case (cell, i) =>
- if (truncate > 0) {
- StringUtils.leftPad(cell.toString, colWidths(i))
- } else {
- StringUtils.rightPad(cell.toString, colWidths(i))
- }
- }.addString(sb, "|", "|", "|\n")
- }
-
+ paddedRows.tail.foreach(_.addString(sb, "|", "|", "|\n"))
sb.append(sep)
} else {
// Extended display mode enabled
@@ -346,7 +359,7 @@ class Dataset[T] private[sql](
}
// Print a footer
- if (vertical && data.isEmpty) {
+ if (vertical && rows.tail.isEmpty) {
// In a vertical mode, print an empty row set explicitly
sb.append("(0 rows)\n")
} else if (hasMoreData) {
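To make the refactored formatting rules above concrete, here is a hedged Python rendition of the per-cell truncate-and-pad behaviour that `getRows` and `showString` implement; it is illustrative only, the real logic lives in the Scala code shown in this diff:

```python
# Hedged sketch of the cell formatting rules: truncate long cells (keeping the
# first truncate-3 characters plus "..."), then right-align when truncation is
# on and left-align otherwise.
def truncate_cell(value, truncate=20):
    s = "null" if value is None else str(value)
    if 0 < truncate < len(s):
        return s[:truncate] if truncate < 4 else s[:truncate - 3] + "..."
    return s

def pad_cell(cell, width, truncate=20):
    return cell.rjust(width) if truncate > 0 else cell.ljust(width)

print(pad_cell(truncate_cell("spark eager evaluation", 10), 12))  # -> '  spark e...'
```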
@@ -3209,6 +3222,19 @@ class Dataset[T] private[sql](
}
}
+ private[sql] def getRowsToPython(
+ _numRows: Int,
+ truncate: Int,
+ vertical: Boolean): Array[Any] = {
+ EvaluatePython.registerPicklers()
+ val numRows = _numRows.max(0).min(Int.MaxValue - 1)
+ val rows = getRows(numRows, truncate, vertical).map(_.toArray).toArray
+ val toJava: (Any) => Any = EvaluatePython.toJava(_, ArrayType(ArrayType(StringType)))
+ val iter: Iterator[Array[Byte]] = new SerDeUtil.AutoBatchedPickler(
+ rows.iterator.map(toJava))
+ PythonRDD.serveIterator(iter, "serve-GetRows")
+ }
+
/**
* Collect a Dataset as ArrowPayload byte arrays and serve to PySpark.
*/