Skip to content

Instantly share code, notes, and snippets.

@TheDIM47
Created October 19, 2021 09:47
Show Gist options
  • Save TheDIM47/2ba6899d9fdc790547a669a8e4b70421 to your computer and use it in GitHub Desktop.
Save TheDIM47/2ba6899d9fdc790547a669a8e4b70421 to your computer and use it in GitHub Desktop.
MySqlCatalog - Flink MySQL catalog implementation
import org.apache.flink.connector.jdbc.catalog.AbstractJdbcCatalog;
import org.apache.flink.connector.jdbc.table.JdbcConnectorOptions;
import org.apache.flink.connector.jdbc.table.JdbcDynamicTableFactory;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.api.constraints.UniqueConstraint;
import org.apache.flink.table.catalog.CatalogBaseTable;
import org.apache.flink.table.catalog.CatalogDatabase;
import org.apache.flink.table.catalog.CatalogDatabaseImpl;
import org.apache.flink.table.catalog.CatalogTable;
import org.apache.flink.table.catalog.ObjectPath;
import org.apache.flink.table.catalog.exceptions.CatalogException;
import org.apache.flink.table.catalog.exceptions.DatabaseNotExistException;
import org.apache.flink.table.catalog.exceptions.TableNotExistException;
import org.apache.flink.table.factories.FactoryUtil;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.sql.*;
import java.util.*;
/**
 * Flink JDBC catalog for MySQL.
 *
 * <p>MySQL schemas are exposed as catalog databases and their base tables (views excluded) as
 * catalog tables. Column types are derived from JDBC result-set metadata, primary keys from
 * {@link DatabaseMetaData}.
 *
 * <p>The constructor takes a full JDBC URL such as {@code jdbc:mysql://host:3306/db?useSSL=false}.
 * Only the scheme and host/port part are kept as the catalog's base URL; the query parameters are
 * passed to the driver as connection properties. A {@code user} or {@code password} query
 * parameter takes precedence over the constructor credentials for connections this catalog opens.
 */
public class MySqlCatalog extends AbstractJdbcCatalog {

    private static final Logger LOG = LoggerFactory.getLogger(MySqlCatalog.class);

    /** MySQL system schemas; never reported by {@link #listDatabases()}. */
    private static final Set<String> builtinSchemas = new HashSet<>(
            Arrays.asList("sys", "mysql", "performance_schema", "information_schema"));

    /** Driver properties for every connection: URL query parameters plus resolved user/password. */
    private final Properties properties;

    /**
     * Creates the catalog.
     *
     * @param catalogName     name of the catalog
     * @param defaultDatabase default database, passed through to {@link AbstractJdbcCatalog}
     * @param username        user name, used unless the URL carries a {@code user} parameter
     * @param pwd             password, used unless the URL carries a {@code password} parameter
     * @param baseUrl         full JDBC URL; its query parameters become connection properties
     * @throws IllegalArgumentException if {@code baseUrl} cannot be parsed or has no host part
     */
    public MySqlCatalog(String catalogName, String defaultDatabase, String username, String pwd, String baseUrl) {
        super(catalogName, defaultDatabase, username, pwd, extractBaseUrl(baseUrl));
        properties = new Properties();
        for (String param : splitParams(baseUrl)) {
            // Limit 2 keeps '=' characters that belong to the value (e.g. base64-encoded secrets).
            String[] kv = param.split("=", 2);
            if (kv.length == 2 && !kv[0].isEmpty() && !kv[1].isEmpty()) {
                properties.setProperty(kv[0], kv[1]);
            }
        }
        // URL-supplied credentials win; otherwise fall back to the constructor arguments.
        properties.setProperty("user", properties.getProperty("user", super.getUsername()));
        properties.setProperty("password", properties.getProperty("password", super.getPassword()));
    }

    /**
     * Opens and immediately closes a connection to {@code defaultUrl} so misconfiguration fails early.
     *
     * @throws ValidationException if the connection cannot be established
     */
    @Override
    public void open() throws CatalogException {
        try (Connection ignored = DriverManager.getConnection(defaultUrl, properties)) {
            // test connection, fail early if we cannot connect to database
        } catch (SQLException e) {
            throw new ValidationException(String.format("Failed connecting to %s via JDBC.", defaultUrl), e);
        }
        LOG.info("Catalog {} established connection to {}", getName(), defaultUrl);
    }

    /** Returns all non-system schema names of the server. */
    @Override
    public List<String> listDatabases() throws CatalogException {
        return listSchemas(null);
    }

    /**
     * Returns an empty-property database descriptor if {@code databaseName} is listed by
     * {@link #listDatabases()}.
     *
     * @throws DatabaseNotExistException if the database is not listed (system schemas included)
     */
    @Override
    public CatalogDatabase getDatabase(String databaseName) throws DatabaseNotExistException, CatalogException {
        if (listDatabases().contains(databaseName)) {
            return new CatalogDatabaseImpl(Collections.emptyMap(), null);
        } else {
            throw new DatabaseNotExistException(getName(), databaseName);
        }
    }

    /**
     * Lists the base tables (views excluded) of {@code databaseName}.
     *
     * @return names qualified as {@code database.table}; empty when the database has no tables or
     *         does not exist (this implementation never throws {@link DatabaseNotExistException})
     * @throws CatalogException if the query fails
     */
    @Override
    public List<String> listTables(String databaseName) throws DatabaseNotExistException, CatalogException {
        List<String> tables = new ArrayList<>();
        String sql = "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = ?";
        try (Connection conn = DriverManager.getConnection(defaultUrl, properties);
             PreparedStatement stmt = conn.prepareStatement(sql)) {
            stmt.setString(1, databaseName);
            try (ResultSet rsTables = stmt.executeQuery()) {
                while (rsTables.next()) {
                    tables.add(databaseName + "." + rsTables.getString(1));
                }
            }
            return tables;
        } catch (SQLException e) {
            throw new CatalogException(String.format("Failed listing tables in database %s", databaseName), e);
        }
    }

    /**
     * Builds a JDBC-connector {@link CatalogTable} for {@code tablePath}: column names/types from the
     * metadata of a {@code SELECT *} statement, primary key from {@link DatabaseMetaData}, and
     * connector options (url, table-name, username, password) pointing at the table.
     *
     * @throws TableNotExistException if {@link #tableExists(ObjectPath)} is false
     * @throws CatalogException       on any JDBC failure or unsupported column type
     */
    @Override
    public CatalogBaseTable getTable(ObjectPath tablePath) throws TableNotExistException, CatalogException {
        if (!tableExists(tablePath)) {
            throw new TableNotExistException(getName(), tablePath);
        }
        MySqlTablePath myPath = MySqlTablePath.fromFlinkTableName(tablePath.getFullName());
        String dbUrl = baseUrl + tablePath.getDatabaseName();
        try (Connection conn = DriverManager.getConnection(dbUrl, properties)) {
            DatabaseMetaData metaData = conn.getMetaData();
            Optional<UniqueConstraint> primaryKey = getPrimaryKey(metaData, myPath.getSchemaName(), myPath.getTableName());

            final Schema schema;
            // Read the metadata while the statement is still open: JDBC does not guarantee that
            // ResultSetMetaData stays usable after its statement is closed.
            try (PreparedStatement ps = conn.prepareStatement(String.format("SELECT * FROM %s", myPath.getFullPath()))) {
                ResultSetMetaData rsmd = ps.getMetaData();
                if (rsmd == null) {
                    throw new CatalogException(String.format(
                            "JDBC driver returned no column metadata for table %s", tablePath.getFullName()));
                }
                int columnCount = rsmd.getColumnCount();
                String[] names = new String[columnCount];
                DataType[] types = new DataType[columnCount];
                for (int i = 1; i <= columnCount; i++) {
                    names[i - 1] = rsmd.getColumnName(i);
                    DataType type = fromJDBCType(rsmd, i);
                    types[i - 1] = rsmd.isNullable(i) == ResultSetMetaData.columnNoNulls ? type.notNull() : type;
                }
                Schema.Builder builder = Schema.newBuilder().fromFields(names, types);
                primaryKey.ifPresent(pk -> builder.primaryKeyNamed(pk.getName(), pk.getColumns()));
                schema = builder.build();
            }

            Map<String, String> props = new HashMap<>();
            props.put(FactoryUtil.CONNECTOR.key(), JdbcDynamicTableFactory.IDENTIFIER);
            props.put(JdbcConnectorOptions.URL.key(), dbUrl);
            props.put(JdbcConnectorOptions.TABLE_NAME.key(), tablePath.getObjectName());
            props.put(JdbcConnectorOptions.USERNAME.key(), username);
            props.put(JdbcConnectorOptions.PASSWORD.key(), pwd);
            return CatalogTable.of(schema, "", Collections.emptyList(), props);
        } catch (Exception e) {
            throw new CatalogException(String.format("Failed getting table %s", tablePath.getFullName()), e);
        }
    }

    /**
     * Returns whether {@code tablePath} is among the base tables of its database, trying the
     * {@code database.table} full name first and the {@link MySqlTablePath} form second.
     */
    @Override
    public boolean tableExists(ObjectPath tablePath) throws CatalogException {
        try {
            List<String> tables = listTables(tablePath.getDatabaseName());
            boolean result = tables.contains(tablePath.getFullName());
            if (!result) {
                String flinkPath = MySqlTablePath.fromFlinkTableName(tablePath.getObjectName()).getFullPath();
                result = tables.contains(flinkPath);
            }
            return result;
        } catch (DatabaseNotExistException e) {
            return false;
        }
    }

    /**
     * Lists non-system schema names, optionally filtered by a SQL {@code LIKE} pattern.
     *
     * @param schemaName {@code LIKE} pattern, or null/blank for no filter
     * @throws CatalogException if the query fails
     */
    private List<String> listSchemas(String schemaName) throws CatalogException {
        boolean filtered = !StringUtils.isNullOrWhitespaceOnly(schemaName);
        // Bind the pattern as a parameter instead of formatting it into the SQL text.
        String sql = "SELECT schema_name FROM information_schema.schemata" + (filtered ? " WHERE schema_name LIKE ?" : "");
        List<String> databases = new ArrayList<>();
        try (Connection conn = DriverManager.getConnection(defaultUrl, properties);
             PreparedStatement ps = conn.prepareStatement(sql)) {
            if (filtered) {
                ps.setString(1, schemaName);
            }
            try (ResultSet rs = ps.executeQuery()) {
                while (rs.next()) {
                    String dbName = rs.getString(1);
                    if (!builtinSchemas.contains(dbName)) {
                        databases.add(dbName);
                    }
                }
            }
            return databases;
        } catch (SQLException e) {
            throw new CatalogException(String.format("Failed listing database in catalog %s", getName()), e);
        }
    }

    // Lower-cased type names as matched against ResultSetMetaData#getColumnTypeName.
    public static final String MYSQL_BIGINT = "bigint";
    public static final String MYSQL_BIGINT_UNSIGNED = "bigint unsigned";
    public static final String MYSQL_BINARY = "binary";
    public static final String MYSQL_BIT = "bit";
    public static final String MYSQL_BLOB = "blob";
    public static final String MYSQL_TINYBLOB = "tinyblob";
    public static final String MYSQL_MEDIUMBLOB = "mediumblob";
    public static final String MYSQL_LONGBLOB = "longblob";
    public static final String MYSQL_BYTE = "tinyint";
    public static final String MYSQL_BYTE_UNSIGNED = "tinyint unsigned";
    public static final String MYSQL_CHAR = "char";
    public static final String MYSQL_CHARACTER_VARYING = "varchar";
    public static final String MYSQL_DATE = "date";
    public static final String MYSQL_DATETIME = "datetime";
    public static final String MYSQL_DECIMAL = "decimal";
    public static final String MYSQL_DECIMAL_UNSIGNED = "decimal unsigned";
    public static final String MYSQL_DOUBLE = "double";
    public static final String MYSQL_DOUBLE_UNSIGNED = "double unsigned";
    public static final String MYSQL_ENUM = "enum";
    public static final String MYSQL_FLOAT = "float";
    public static final String MYSQL_FLOAT_UNSIGNED = "float unsigned";
    public static final String MYSQL_GEOMETRY = "geometry";
    public static final String MYSQL_INT = "int";
    public static final String MYSQL_INT_UNSIGNED = "int unsigned";
    public static final String MYSQL_JSON = "json";
    public static final String MYSQL_MEDIUMINT = "mediumint";
    public static final String MYSQL_MEDIUMINT_UNSIGNED = "mediumint unsigned";
    public static final String MYSQL_NUMERIC = "numeric";
    public static final String MYSQL_SET = "set";
    public static final String MYSQL_SMALLINT = "smallint";
    public static final String MYSQL_SMALLINT_UNSIGNED = "smallint unsigned";
    public static final String MYSQL_TEXT = "text";
    public static final String MYSQL_TINYTEXT = "tinytext";
    public static final String MYSQL_MEDIUMTEXT = "mediumtext";
    public static final String MYSQL_LONGTEXT = "longtext";
    public static final String MYSQL_TIME = "time";
    public static final String MYSQL_TIMESTAMP = "timestamp";
    // NOTE(review): MySQL has no TIMESTAMPTZ type; kept for compatibility with existing callers.
    public static final String MYSQL_TIMESTAMPTZ = "timestamptz";
    public static final String MYSQL_VARBINARY = "varbinary";
    public static final String MYSQL_YEAR = "year";

    /**
     * Maps the MySQL type of column {@code colIndex} to a Flink {@link DataType}.
     *
     * @throws UnsupportedOperationException for type names not handled below
     */
    private DataType fromJDBCType(ResultSetMetaData metadata, int colIndex) throws SQLException {
        // Locale.ROOT: default-locale lower-casing breaks matching (e.g. Turkish turns "INT" into "ınt").
        String myType = metadata.getColumnTypeName(colIndex).toLowerCase(Locale.ROOT);
        int precision = metadata.getPrecision(colIndex);
        int scale = metadata.getScale(colIndex);
        switch (myType) {
            case MYSQL_BIT:
                // BIT(1) is a flag; wider BIT(n) columns are bit strings and cannot be booleans.
                return precision > 1 ? DataTypes.BYTES() : DataTypes.BOOLEAN();
            case MYSQL_BYTE:
                return DataTypes.TINYINT();
            case MYSQL_BYTE_UNSIGNED:
            case MYSQL_SMALLINT:
            case MYSQL_YEAR:
                // NOTE(review): the driver may return java.sql.Date for YEAR columns — confirm.
                return DataTypes.SMALLINT();
            case MYSQL_SMALLINT_UNSIGNED:
            case MYSQL_INT:
            case MYSQL_MEDIUMINT:
            case MYSQL_MEDIUMINT_UNSIGNED:
                return DataTypes.INT();
            case MYSQL_FLOAT:
            case MYSQL_FLOAT_UNSIGNED:
                return DataTypes.FLOAT();
            case MYSQL_DOUBLE:
            case MYSQL_DOUBLE_UNSIGNED:
                return DataTypes.DOUBLE();
            case MYSQL_BIGINT:
            case MYSQL_INT_UNSIGNED:
                return DataTypes.BIGINT();
            case MYSQL_BIGINT_UNSIGNED:
                // Up to 2^64-1 (20 digits): does not fit into a signed BIGINT.
                return DataTypes.DECIMAL(20, 0);
            case MYSQL_DECIMAL:
            case MYSQL_DECIMAL_UNSIGNED:
            case MYSQL_NUMERIC:
                // NOTE(review): MySQL allows precision up to 65, above DecimalType.MAX_PRECISION;
                // such columns are presumably rejected by DataTypes.DECIMAL — confirm.
                if (precision > 0) {
                    return DataTypes.DECIMAL(precision, scale);
                }
                return DataTypes.DECIMAL(DecimalType.MAX_PRECISION, 18);
            case MYSQL_CHAR:
                return DataTypes.CHAR(precision);
            case MYSQL_CHARACTER_VARYING:
                return DataTypes.VARCHAR(precision);
            case MYSQL_JSON:
            case MYSQL_TEXT:
            case MYSQL_TINYTEXT:
            case MYSQL_MEDIUMTEXT:
            case MYSQL_LONGTEXT:
            case MYSQL_ENUM:
            case MYSQL_SET:
                return DataTypes.STRING();
            case MYSQL_DATETIME:
            case MYSQL_TIMESTAMP:
                // DATETIME carries a time-of-day part; mapping it to DATE would drop it.
                return DataTypes.TIMESTAMP(scale);
            case MYSQL_TIMESTAMPTZ:
                return DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(scale);
            case MYSQL_TIME:
                return DataTypes.TIME(scale);
            case MYSQL_DATE:
                return DataTypes.DATE();
            case MYSQL_BINARY:
            case MYSQL_BLOB:
            case MYSQL_TINYBLOB:
            case MYSQL_MEDIUMBLOB:
            case MYSQL_LONGBLOB:
            case MYSQL_GEOMETRY:
            case MYSQL_VARBINARY:
                return DataTypes.BYTES();
            default:
                throw new UnsupportedOperationException(String.format("Doesn't support MySQL type '%s' yet", myType));
        }
    }

    /**
     * Reduces a JDBC URL to {@code jdbc:<scheme>://<host>[:<port>]}, dropping path, query and user info.
     *
     * @throws IllegalArgumentException if the URL cannot be parsed or has no host part
     */
    private static String extractBaseUrl(String url) {
        Objects.requireNonNull(url, "baseUrl");
        java.net.URI uri;
        try {
            uri = new java.net.URI(url);
            if (uri.getHost() == null) {
                // "jdbc:mysql://host/db" parses as an opaque URI with scheme "jdbc";
                // re-parse the "mysql://host/db" part to reach the real scheme and authority.
                uri = new java.net.URI(uri.getRawSchemeSpecificPart());
            }
        } catch (java.net.URISyntaxException e) {
            throw new IllegalArgumentException("Cannot parse JDBC URL", e);
        }
        String authority;
        if (uri.getHost() != null) {
            // getPort() is -1 when no port is given; omit it and let the driver use its default.
            authority = uri.getPort() == -1 ? uri.getHost() : uri.getHost() + ":" + uri.getPort();
        } else if (uri.getRawAuthority() != null) {
            // Registry-based authority (e.g. "h1:3306,h2:3306" or host names with '_');
            // strip any "user:password@" prefix so credentials never end up in the base URL.
            String raw = uri.getRawAuthority();
            authority = raw.substring(raw.lastIndexOf('@') + 1);
        } else {
            throw new IllegalArgumentException("Cannot extract host from JDBC URL");
        }
        return "jdbc:" + uri.getScheme() + "://" + authority;
    }

    /** Returns the {@code &}-separated entries after the first {@code ?}; empty if there are none. */
    private static String[] splitParams(String url) {
        int q = url.indexOf('?');
        if (q < 0 || q == url.length() - 1) {
            return new String[0];
        }
        return url.substring(q + 1).split("&");
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment