Skip to content

Instantly share code, notes, and snippets.

@harupy
Created November 14, 2022 05:44
Show Gist options
  • Save harupy/42ee881cc14b1a96ed4bc8a60b061f19 to your computer and use it in GitHub Desktop.
Save harupy/42ee881cc14b1a96ed4bc8a60b061f19 to your computer and use it in GitHub Desktop.
import mlflow
from sklearn.linear_model import LogisticRegression
from pathlib import Path
import tempfile
# Log a scikit-learn model, then rewrite the unversioned ``mlflow`` entry in the
# model's requirements.txt and conda.yaml to a pinned version, and re-log the
# patched files into the same run at the same artifact location.

# Requirement specifier that replaces the bare ``mlflow`` dependency.
MLFLOW_PIN = "mlflow==1.30.0"
# Artifact path the model is logged under; patched files are re-logged beneath it.
MODEL_ARTIFACT_PATH = "model"


def _rewrite_lines(text, transform):
    """Return *text* with *transform* applied to each of its lines.

    ``str.splitlines`` discards the final line terminator, so a trailing
    newline is re-appended when *text* originally ended with one.
    """
    new_text = "\n".join(transform(line) for line in text.splitlines())
    if text.endswith("\n"):
        new_text += "\n"
    return new_text


def _pin_requirement(line):
    """Replace a bare ``mlflow`` requirements.txt line with the pinned specifier."""
    return MLFLOW_PIN if line.strip() == "mlflow" else line


def _pin_conda_dependency(line):
    """Replace a bare ``- mlflow`` conda.yaml list item with the pinned specifier."""
    if line.strip() == "- mlflow":
        # Keep the original indentation; only the list item text changes.
        return line.replace("- mlflow", f"- {MLFLOW_PIN}", 1)
    return line


# Dependency files to patch, keyed by file name, mapped to their per-line transform.
_PATCHERS = {
    "requirements.txt": _pin_requirement,
    "conda.yaml": _pin_conda_dependency,
}

with mlflow.start_run():
    mi = mlflow.sklearn.log_model(LogisticRegression(), MODEL_ARTIFACT_PATH)

client = mlflow.MlflowClient()
with tempfile.TemporaryDirectory() as tmpdir:
    dst = mlflow.artifacts.download_artifacts(mi.model_uri, dst_path=tmpdir)
    for p in Path(dst).rglob("*"):
        patch = _PATCHERS.get(p.name)
        # rglob also yields directories; only regular files are patched.
        if patch is None or not p.is_file():
            continue
        p.write_text(_rewrite_lines(p.read_text(), patch))
        # Re-log at the file's own location inside the model directory rather
        # than flattening everything into the model's top-level artifact path.
        rel_path = p.relative_to(dst)
        artifact_path = (Path(MODEL_ARTIFACT_PATH) / rel_path.parent).as_posix()
        client.log_artifact(mi.run_id, str(p), artifact_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment