/*
 * MySqlCatalog - Flink MySQL catalog implementation.
 *
 * Source: GitHub gist TheDIM47/2ba6899d9fdc790547a669a8e4b70421, created October 19, 2021 09:47
 * (the gist page offers to save it to your computer or use it in GitHub Desktop).
 *
 * Warning carried over from the gist page: this file contains bidirectional Unicode text that may
 * be interpreted or compiled differently than what appears. To review it, open the file in an
 * editor that reveals hidden Unicode characters (see the documentation on bidirectional Unicode
 * characters to learn more).
 */
import org.apache.flink.connector.jdbc.catalog.AbstractJdbcCatalog;
import org.apache.flink.connector.jdbc.table.JdbcConnectorOptions;
import org.apache.flink.connector.jdbc.table.JdbcDynamicTableFactory;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.api.constraints.UniqueConstraint;
import org.apache.flink.table.catalog.CatalogBaseTable;
import org.apache.flink.table.catalog.CatalogDatabase;
import org.apache.flink.table.catalog.CatalogDatabaseImpl;
import org.apache.flink.table.catalog.CatalogTable;
import org.apache.flink.table.catalog.ObjectPath;
import org.apache.flink.table.catalog.exceptions.CatalogException;
import org.apache.flink.table.catalog.exceptions.DatabaseNotExistException;
import org.apache.flink.table.catalog.exceptions.TableNotExistException;
import org.apache.flink.table.factories.FactoryUtil;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.logical.DecimalType;
import org.apache.flink.util.StringUtils;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.URI;
import java.net.URISyntaxException;
import java.sql.*;
import java.util.*;
public class MySqlCatalog extends AbstractJdbcCatalog { | |
private static final Logger LOG = LoggerFactory.getLogger(MySqlCatalog.class); | |
private static final Set<String> builtinSchemas = new HashSet<>( | |
Arrays.asList("sys", "mysql", "performance_schema", "information_schema") | |
); | |
public MySqlCatalog(String catalogName, String defaultDatabase, String username, String pwd, String baseUrl) { | |
super(catalogName, defaultDatabase, username, pwd, extractBaseUrl(baseUrl)); | |
// Split db url params and make connection properties | |
String[] ps = splitParams(baseUrl); | |
properties = new Properties(); | |
for (String param : ps) { | |
String[] kv = param.split("="); | |
if (kv.length == 2 && !"".equals(kv[0]) && !"".equals(kv[1])) | |
properties.setProperty(kv[0], kv[1]); | |
} | |
properties.setProperty("user", properties.getProperty("user") == null ? super.getUsername() : properties.getProperty("user")); | |
properties.setProperty("password", properties.getProperty("password") == null ? super.getPassword() : properties.getProperty("password")); | |
} | |
@Override | |
public void open() throws CatalogException { | |
try (Connection conn = DriverManager.getConnection(defaultUrl, properties)) { | |
// test connection, fail early if we cannot connect to database | |
} catch (SQLException e) { | |
throw new ValidationException(String.format("Failed connecting to %s via JDBC.", defaultUrl), e); | |
} | |
LOG.info("Catalog {} established connection to {}", getName(), defaultUrl); | |
} | |
@Override | |
public List<String> listDatabases() throws CatalogException { | |
return listSchemas(null); | |
} | |
@Override | |
public CatalogDatabase getDatabase(String databaseName) throws DatabaseNotExistException, CatalogException { | |
if (listDatabases().contains(databaseName)) { | |
return new CatalogDatabaseImpl(Collections.emptyMap(), null); | |
} else { | |
throw new DatabaseNotExistException(getName(), databaseName); | |
} | |
} | |
@Override | |
public List<String> listTables(String databaseName) throws DatabaseNotExistException, CatalogException { | |
List<String> tables = new ArrayList<>(); | |
String sql = "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema = ?;"; | |
try (Connection conn = DriverManager.getConnection(defaultUrl, properties)) { | |
try (PreparedStatement stmt = conn.prepareStatement(sql)) { | |
stmt.setString(1, databaseName); | |
try (ResultSet rsTables = stmt.executeQuery()) { | |
while (rsTables.next()) { | |
tables.add(databaseName + "." + rsTables.getString(1)); | |
} | |
return tables; | |
} | |
} | |
} catch (SQLException e) { | |
throw new CatalogException(String.format("Failed listing tables in database %s", getName()), e); | |
} | |
} | |
@Override | |
public CatalogBaseTable getTable(ObjectPath tablePath) throws TableNotExistException, CatalogException { | |
if (!tableExists(tablePath)) { | |
throw new TableNotExistException(getName(), tablePath); | |
} | |
MySqlTablePath myPath = MySqlTablePath.fromFlinkTableName(tablePath.getFullName()); | |
String dbUrl = baseUrl + tablePath.getDatabaseName(); | |
try (Connection conn = DriverManager.getConnection(dbUrl, properties)) { | |
DatabaseMetaData metaData = conn.getMetaData(); | |
Optional<UniqueConstraint> primaryKey = getPrimaryKey(metaData, myPath.getSchemaName(), myPath.getTableName()); | |
final ResultSetMetaData rsmd; | |
try (PreparedStatement ps = conn.prepareStatement(String.format("SELECT * FROM %s;", myPath.getFullPath()))) { | |
rsmd = ps.getMetaData(); | |
} | |
String[] names = new String[rsmd.getColumnCount()]; | |
DataType[] types = new DataType[rsmd.getColumnCount()]; | |
for (int i = 1; i <= rsmd.getColumnCount(); i++) { | |
names[i - 1] = rsmd.getColumnName(i); | |
types[i - 1] = fromJDBCType(rsmd, i); | |
if (rsmd.isNullable(i) == ResultSetMetaData.columnNoNulls) { | |
types[i - 1] = types[i - 1].notNull(); | |
} | |
} | |
Schema.Builder builder = Schema.newBuilder().fromFields(names, types); | |
primaryKey.ifPresent(pk -> builder.primaryKeyNamed(pk.getName(), pk.getColumns())); | |
Schema schema = builder.build(); | |
Map<String, String> props = new HashMap<>(); | |
props.put(FactoryUtil.CONNECTOR.key(), JdbcDynamicTableFactory.IDENTIFIER); | |
props.put(JdbcConnectorOptions.URL.key(), baseUrl + tablePath.getDatabaseName()); | |
props.put(JdbcConnectorOptions.TABLE_NAME.key(), tablePath.getObjectName()); | |
props.put(JdbcConnectorOptions.USERNAME.key(), username); | |
props.put(JdbcConnectorOptions.PASSWORD.key(), pwd); | |
return CatalogTable.of(schema, "", Collections.emptyList(), props); | |
} catch (Exception e) { | |
throw new CatalogException(String.format("Failed getting table %s", tablePath.getFullName()), e); | |
} | |
} | |
@Override | |
public boolean tableExists(ObjectPath tablePath) throws CatalogException { | |
try { | |
List<String> tables = listTables(tablePath.getDatabaseName()); | |
boolean result = tables.contains(tablePath.getFullName()); | |
if (!result) { | |
String flinkPath = MySqlTablePath.fromFlinkTableName(tablePath.getObjectName()).getFullPath(); | |
result = tables.contains(flinkPath); | |
} | |
return result; | |
} catch (DatabaseNotExistException e) { | |
return false; | |
} | |
} | |
private List<String> listSchemas(String schemaName) throws CatalogException { | |
List<String> databases = new ArrayList<>(); | |
try (Connection conn = DriverManager.getConnection(defaultUrl, properties)) { | |
String filter = StringUtils.isNullOrWhitespaceOnly(schemaName) ? "" : | |
String.format(" WHERE schema_name like '%s'", schemaName); | |
String sql = String.format("SELECT schema_name FROM information_schema.schemata %s;", filter); | |
try (PreparedStatement ps = conn.prepareStatement(sql)) { | |
try (ResultSet rs = ps.executeQuery()) { | |
while (rs.next()) { | |
String dbName = rs.getString(1); | |
if (!builtinSchemas.contains(dbName)) { | |
databases.add(rs.getString(1)); | |
} | |
} | |
} | |
} | |
return databases; | |
} catch (Exception e) { | |
throw new CatalogException(String.format("Failed listing database in catalog %s", getName()), e); | |
} | |
} | |
public static final String MYSQL_BIGINT = "bigint"; | |
public static final String MYSQL_BIGINT_UNSIGNED = "bigint unsigned"; | |
public static final String MYSQL_BINARY = "binary"; | |
public static final String MYSQL_BIT = "bit"; | |
public static final String MYSQL_BLOB = "blob"; | |
public static final String MYSQL_BYTE = "tinyint"; | |
public static final String MYSQL_BYTE_UNSIGNED = "tinyint unsigned"; | |
public static final String MYSQL_CHAR = "char"; | |
public static final String MYSQL_CHARACTER_VARYING = "varchar"; | |
public static final String MYSQL_DATE = "date"; | |
public static final String MYSQL_DATETIME = "datetime"; | |
public static final String MYSQL_DECIMAL = "decimal"; | |
public static final String MYSQL_DECIMAL_UNSIGNED = "decimal unsigned"; | |
public static final String MYSQL_DOUBLE = "double"; | |
public static final String MYSQL_FLOAT = "float"; | |
public static final String MYSQL_GEOMETRY = "geometry"; | |
public static final String MYSQL_INT = "int"; | |
public static final String MYSQL_INT_UNSIGNED = "int unsigned"; | |
public static final String MYSQL_JSON = "json"; | |
public static final String MYSQL_MEDIUMINT = "mediumint"; | |
public static final String MYSQL_NUMERIC = "numeric"; | |
public static final String MYSQL_SMALLINT = "smallint"; | |
public static final String MYSQL_TEXT = "text"; | |
public static final String MYSQL_TIME = "time"; | |
public static final String MYSQL_TIMESTAMP = "timestamp"; | |
public static final String MYSQL_TIMESTAMPTZ = "timestamptz"; | |
public static final String MYSQL_VARBINARY = "varbinary"; | |
public static final String MYSQL_YEAR = "year"; | |
private DataType fromJDBCType(ResultSetMetaData metadata, int colIndex) throws SQLException { | |
String myType = metadata.getColumnTypeName(colIndex).toLowerCase(); | |
int precision = metadata.getPrecision(colIndex); | |
int scale = metadata.getScale(colIndex); | |
switch (myType) { | |
case MYSQL_BIT: | |
return DataTypes.BOOLEAN(); | |
case MYSQL_BYTE: | |
return DataTypes.TINYINT(); | |
case MYSQL_BYTE_UNSIGNED: | |
case MYSQL_SMALLINT: | |
case MYSQL_YEAR: | |
return DataTypes.SMALLINT(); | |
case MYSQL_INT: | |
case MYSQL_MEDIUMINT: | |
return DataTypes.INT(); | |
case MYSQL_FLOAT: | |
return DataTypes.FLOAT(); | |
case MYSQL_DOUBLE: | |
return DataTypes.DOUBLE(); | |
case MYSQL_BIGINT_UNSIGNED: | |
case MYSQL_BIGINT: | |
case MYSQL_INT_UNSIGNED: | |
return DataTypes.BIGINT(); | |
case MYSQL_DECIMAL: | |
case MYSQL_DECIMAL_UNSIGNED: | |
case MYSQL_NUMERIC: | |
if (precision > 0) { | |
return DataTypes.DECIMAL(precision, metadata.getScale(colIndex)); | |
} | |
return DataTypes.DECIMAL(DecimalType.MAX_PRECISION, 18); | |
case MYSQL_CHAR: | |
return DataTypes.CHAR(precision); | |
case MYSQL_CHARACTER_VARYING: | |
return DataTypes.VARCHAR(precision); | |
case MYSQL_JSON: | |
case MYSQL_TEXT: | |
return DataTypes.STRING(); | |
case MYSQL_TIMESTAMP: | |
return DataTypes.TIMESTAMP(scale); | |
case MYSQL_TIMESTAMPTZ: | |
return DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(scale); | |
case MYSQL_TIME: | |
return DataTypes.TIME(scale); | |
case MYSQL_DATE: | |
case MYSQL_DATETIME: | |
return DataTypes.DATE(); | |
case MYSQL_BINARY: | |
case MYSQL_BLOB: | |
case MYSQL_GEOMETRY: | |
case MYSQL_VARBINARY: | |
return DataTypes.BYTES(); | |
default: | |
throw new UnsupportedOperationException(String.format("Doesn't support MySQL type '%s' yet", myType)); | |
} | |
} | |
private static String extractBaseUrl(String url) { | |
try { | |
java.net.URI uri = new java.net.URI(url); | |
if (uri.getHost() == null) { | |
uri = new java.net.URI(uri.getRawSchemeSpecificPart()); | |
} | |
return "jdbc:" + uri.getScheme() + "://" + uri.getHost() + ":" + uri.getPort(); | |
} catch (Exception e) { | |
throw new RuntimeException(e); | |
} | |
} | |
private static String[] splitParams(String url) { | |
String[] ps = (url.split("\\?").length > 1 ? url.split("\\?")[1] : "").split("&"); | |
return (ps.length == 1 && "".equals(ps[0])) ? new String[0] : ps; | |
} | |
private final Properties properties; | |
} |
// To join the conversation about this gist on GitHub, sign up for free or, if you already have
// an account, sign in to comment.