COPY Spark DataFrame rows to PostgreSQL (via JDBC)
import java.io.InputStream
import java.sql.Connection

import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.{ DataFrame, Row }
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection

val jdbcUrl = s"jdbc:postgresql://..." // db credentials elided

val connectionProperties = {
  val props = new java.util.Properties()
  props.setProperty("driver", "org.postgresql.Driver")
  props
}

// Spark reads the "driver" property to let users override the driver it would select by default;
// otherwise it picks the Redshift driver, which doesn't support the JDBC CopyManager.
// https://github.com/apache/spark/blob/v1.6.1/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala#L44-51
val cf: () => Connection = JdbcUtils.createConnectionFactory(jdbcUrl, connectionProperties)

// Convert every partition (an `Iterator[Row]`) to bytes (InputStream)
def rowsToInputStream(rows: Iterator[Row], delimiter: String): InputStream = {
  val bytes: Iterator[Byte] = rows.map { row =>
    (row.mkString(delimiter) + "\n").getBytes
  }.flatten

  new InputStream {
    override def read(): Int =
      if (bytes.hasNext) {
        bytes.next & 0xff // bitwise AND: widen the signed byte to an unsigned int in 0-255
      } else {
        -1
      }
  }
}

// Beware: this will open a db connection for every partition of your DataFrame.
frame.foreachPartition { rows =>
  val conn = cf()
  try {
    val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
    cm.copyIn(
      """COPY my_schema._mytable FROM STDIN WITH (NULL 'null', FORMAT CSV, DELIMITER E'\t')""", // adjust COPY settings as desired; options: https://www.postgresql.org/docs/9.5/static/sql-copy.html
      rowsToInputStream(rows, "\t"))
  } finally {
    conn.close() // close even if copyIn throws, to avoid leaking connections
  }
}
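
For context, a minimal sketch of the setup the snippet assumes: frame is any existing DataFrame whose columns line up with the target table. The sample data, column names, and the Spark 1.6-style sqlContext below are illustrative, not part of the original gist.

// Hypothetical setup: a two-column DataFrame matching a table created as
//   CREATE TABLE my_schema._mytable (id INT, name TEXT);
val frame: DataFrame = sqlContext
  .createDataFrame(Seq((1, "alice"), (2, "bob")))
  .toDF("id", "name")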
@silas commented Nov 21, 2017

Slightly modified to deal with escaping and Spark 2.2

import java.io.InputStream
import java.sql.DriverManager
import java.util.Properties

import org.apache.spark.sql.{DataFrame, Row}
import org.postgresql.copy.CopyManager
import org.postgresql.core.BaseConnection

object CopyHelper {

  def rowsToInputStream(rows: Iterator[Row]): InputStream = {
    val bytes: Iterator[Byte] = rows.map { row =>
      (row.toSeq
        .map { v =>
          if (v == null) {
            """\N"""
          } else {
            "\"" + v.toString.replaceAll("\"", "\"\"") + "\""
          }
        }
        .mkString("\t") + "\n").getBytes
    }.flatten

    new InputStream {
      override def read(): Int =
        if (bytes.hasNext) {
          bytes.next & 0xff // bitwise AND - make the signed byte an unsigned int from 0-255
        } else {
          -1
        }
    }
  }

  def copyIn(driver: String,
             url: String,
             user: String,
             password: String,
             properties: Properties)(df: DataFrame, table: String): Unit = {
    df.foreachPartition { rows =>
      Class.forName(driver)

      val conn = DriverManager.getConnection(url, user, password)

      try {
        val cm = new CopyManager(conn.asInstanceOf[BaseConnection])
        cm.copyIn(
          s"COPY $table " + """FROM STDIN WITH (NULL '\N', FORMAT CSV, DELIMITER E'\t')""",
          rowsToInputStream(rows))
        ()
      } finally {
        conn.close()
      }
    }
  }
}
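
A usage sketch for the helper above, assuming a DataFrame df and an existing table with matching columns; the driver class name is real, but the URL, credentials, and table name are placeholders:

// Hypothetical invocation: the first parameter list carries the connection
// settings, the second the DataFrame and the target table.
CopyHelper.copyIn(
  "org.postgresql.Driver",
  "jdbc:postgresql://localhost:5432/mydb", // placeholder URL
  "user",                                  // placeholder credentials
  "secret",
  new java.util.Properties()
)(df, "my_schema.my_table")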
@eyedia commented May 13, 2019

5 stars from me. Thank you so much!

Two questions:

  1. I had to use .toList.flatten.toIterator in line 27, as flatten has been deprecated on iterators. I hope I won't lose any performance (a lazy alternative is sketched below).
  2. Does this mean I can't have a tab char ('\t') in my cells?

Thanks again!
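
On the first question: .toList.flatten.toIterator works but buffers the entire partition in memory; a lazy alternative is flatMap, sketched here (rowsToBytes is just an illustrative name, not part of the gist). On the second: the Spark 2.2 variant above double-quotes every value, and PostgreSQL's CSV format accepts the delimiter inside quoted fields, so embedded tabs should survive there; the original unquoted snippet would indeed break on them.

import org.apache.spark.sql.Row

// A lazy replacement for the `val bytes` line in rowsToInputStream:
// flatMap consumes rows one at a time instead of first buffering the
// whole partition in a List the way .toList.flatten.toIterator does.
def rowsToBytes(rows: Iterator[Row], delimiter: String): Iterator[Byte] =
  rows.flatMap { row => (row.mkString(delimiter) + "\n").getBytes.iterator }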

@rghv404 commented Aug 1, 2019

This saved an incredible amount of time. Cheers!

@anshul-cached commented Aug 9, 2019

I got org.apache.spark.SparkException: Task not serializable

@sdressler commented Dec 15, 2019

I got org.apache.spark.SparkException: Task not serializable

Had the same. Changing the object definition to

object CopyHelper extends Serializable

fixes it.

@visheshSeshu commented Jan 22, 2020

Hi,
I'm new to Scala. Could you please tell me the solution for the issue below, which I hit while implementing this in Scala with Spark?

:37: error: identifier expected but string literal found.
def copyIn(driver: "org.postgresql.Driver",
^
:58: error: ')' expected but '}' found.
}
