Skip to content

Instantly share code, notes, and snippets.

@deanm0000
Created June 20, 2024 18:59
Show Gist options
  • Save deanm0000/95723b17fc8bec3e64a173deae964f0b to your computer and use it in GitHub Desktop.
Save deanm0000/95723b17fc8bec3e64a173deae964f0b to your computer and use it in GitHub Desktop.
def parse_dtypes(df, exclude=()):
    """Infer better dtypes for the string columns of a polars DataFrame.

    Every ``pl.String`` column not listed in *exclude* is tentatively
    converted to Datetime, Date, Int64 and Float64 with ``strict=False``, so
    values that do not convert become null instead of raising. A conversion
    is accepted only when it introduces no new nulls compared with the
    original column. When several conversions qualify for the same column
    the priority is Int64 > Float64 > Datetime > Date.

    Args:
        df: The ``pl.DataFrame`` to inspect.
        exclude: Names of string columns that must be left untouched.

    Returns:
        A new DataFrame with the same columns in the same order, where each
        convertible string column has been replaced by its parsed version.

    Note:
        A string column that holds only nulls (or any string column of a
        zero-row frame) passes every conversion check and therefore comes
        back as Int64.
    """
    str_cols = [
        name
        for name, dtype in df.schema.items()
        if dtype == pl.String and name not in exclude
    ]
    if not str_cols:
        # Nothing to infer. Bail out before building structs from an empty
        # list of expressions, which is not a meaningful polars query.
        return df.select(df.columns)

    datetime_formats = ["%Y-%m-%dT%H:%M:%S", "%Y-%m-%d %H:%M:%S"]
    date_formats = ["%Y-%m-%d"]
    numeric_dtypes = [pl.Int64, pl.Float64]
    # Order is irrelevant here; the winner is picked by the sort further down.
    candidate_dtypes = [pl.Datetime, pl.Date, pl.Float64, pl.Int64]

    # One struct column per candidate dtype (plus "original"), each holding a
    # field for every string column. The struct name doubles as the dtype
    # label used when picking the winning conversion below.
    try_casts = df.select(
        pl.struct(pl.all()).alias("original"),
        pl.struct(
            # coalesce keeps the first format that parses each value.
            pl.coalesce(
                pl.col(col).str.strptime(pl.Datetime, fmt, strict=False)
                for fmt in datetime_formats
            )
            for col in str_cols
        ).alias("Datetime"),
        pl.struct(
            pl.coalesce(
                pl.col(col).str.strptime(pl.Date, fmt, strict=False)
                for fmt in date_formats
            )
            for col in str_cols
        ).alias("Date"),
        *[
            pl.struct(
                pl.col(col).cast(dtype, strict=False) for col in str_cols
            ).alias(str(dtype))
            for dtype in numeric_dtypes
        ],
    )

    # For every (dtype, column) pair count how many nulls the conversion
    # added. Zero added nulls means every non-null value converted cleanly.
    added_nulls = pl.concat(
        [
            try_casts.select(
                pl.lit(str(dtype)).alias("dtype"),
                *(
                    pl.col(str(dtype)).struct.field(col).null_count()
                    - pl.col("original").struct.field(col).null_count()
                    for col in str_cols
                ),
            )
            for dtype in candidate_dtypes
        ]
    )
    good_cast = (
        added_nulls.melt("dtype")
        .filter(pl.col("value") == 0)
        # Sorting the dtype labels alphabetically (Date, Datetime, Float64,
        # Int64) and keeping the last row per column is what yields the
        # Int64 > Float64 > Datetime > Date priority.
        .sort("dtype")
        .unique("variable", keep="last", maintain_order=True)
        .drop("value")
    )

    # Append an "original" fallback row for every column; keep="first" lets a
    # successful conversion win over the fallback.
    fallback = (
        pl.Series("variable", df.columns)
        .to_frame()
        .with_columns(pl.lit("original").alias("dtype"))
    )
    map_cols = pl.concat(
        [good_cast, fallback],
        how="diagonal_relaxed",
    ).unique("variable", keep="first")

    # Rows are (dtype, variable): pick each column out of its winning struct,
    # then restore the caller's original column order.
    return try_casts.select(
        pl.col(struct_name).struct.field(field_name)
        for struct_name, field_name in map_cols.iter_rows()
    ).select(df.columns)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment