Last active: May 2, 2019 09:29
-
-
Save weaming/0ff215bdddce842b1cde8b39aae1c7b4 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
pip install docker mysqlclient peewee | |
""" | |
import os | |
import random | |
import time | |
import datetime | |
from playhouse.db_url import parseresult_to_dict, urlparse | |
import MySQLdb | |
try: | |
import docker | |
except ImportError: | |
docker = None | |
class Context:
    """Base class that turns ``setup``/``teardown`` hooks into a context manager.

    Entering a ``with`` block runs ``setup`` and yields the instance itself;
    leaving it (normally or via an exception) runs ``teardown``. Exceptions
    raised in the body are never suppressed.
    """

    def setup(self):
        """Hook called by ``__enter__``; does nothing unless overridden."""

    def teardown(self):
        """Hook called by ``__exit__``; does nothing unless overridden."""

    def __enter__(self):
        self.setup()
        return self

    def __exit__(self, *exc_info):
        self.teardown()
class DBMock(Context):
    """Mock a database on an existing MySQL server reached through ``db_url``.

    ``setup`` creates a fresh database named ``<db_name_prefix>_<timestamp>``
    and loads ``sqlfile`` into it; ``teardown`` drops that database again.
    Once ``setup`` has run, the ``db_url`` property points at the new database.

    example db url: mysql://root:password@127.0.0.1:3306/test
    """

    def __init__(self, db_url, sqlfile, db_name_prefix):
        """
        :param db_url: URL of the server to create the mock database on; the
            database part of the URL is replaced by the generated name.
        :param sqlfile: path of a SQL script whose statements end with ";"
            at the end of a line.
        :param db_name_prefix: prefix of the generated database name.
        """
        self.origin_db_url = db_url
        self.sqlfile = sqlfile
        self.db_name = "{}_{}".format(
            db_name_prefix, datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        )
        self.parse_db_url()

    def parse_db_url(self):
        """Parse ``origin_db_url`` into keyword arguments for ``MySQLdb.connect``."""
        print("db_url", self.origin_db_url)
        parsed = urlparse(self.origin_db_url)
        self.connect_kwargs = parseresult_to_dict(parsed)

    @staticmethod
    def _quote_identifier(name):
        """Return *name* as a backtick-quoted MySQL identifier.

        Quoting keeps names with characters such as "-" valid in DDL;
        embedded backticks are escaped by doubling them.
        """
        return "`{}`".format(name.replace("`", "``"))

    @staticmethod
    def _split_statements(sql_text):
        """Split a SQL script into statements on ";" followed by a newline.

        Blank chunks (such as the one after the script's final terminator)
        are dropped: the server rejects an empty query ("Query was empty").
        """
        return [sql for sql in sql_text.split(";\n") if sql.strip()]

    def _execute_without_database(self, sql):
        """Run one statement on the server with no default database selected.

        The connection is always closed, even if the statement fails.
        """
        self.connect_kwargs.pop("database", None)
        self.connect()  # connect without specifying a database
        try:
            self.cursor.execute(sql)
            self.db.commit()
        finally:
            self.cursor.close()
            self.db.close()

    def create_db(self):
        """Create the mock database unless it already exists."""
        self._execute_without_database(
            f"CREATE DATABASE IF NOT EXISTS {self._quote_identifier(self.db_name)};"
        )

    def drop_db(self):
        """Drop the mock database if it exists."""
        self._execute_without_database(
            f"DROP DATABASE IF EXISTS {self._quote_identifier(self.db_name)};"
        )

    def connect(self):
        """Open ``self.db`` and ``self.cursor`` with the current ``connect_kwargs``."""
        self.db = MySQLdb.connect(**self.connect_kwargs)
        self.cursor = self.db.cursor()

    def import_sql(self):
        """Execute every statement of ``sqlfile`` inside the mock database.

        The connection is closed even when a statement fails.
        """
        self.connect_kwargs["database"] = self.db_name
        with open(self.sqlfile) as f:
            statements = self._split_statements(f.read())
        self.connect()
        try:
            for sql in statements:
                self.cursor.execute(sql)
            self.db.commit()
        finally:
            self.cursor.close()
            self.db.close()

    def setup(self):
        """Create the mock database and load ``sqlfile`` into it.

        If loading fails, the half-initialised database is dropped before the
        error propagates: ``__exit__`` (and so ``teardown``) is not run when
        ``__enter__`` raises.
        """
        self.create_db()
        try:
            self.import_sql()
        except Exception:
            self.drop_db()
            raise

    def teardown(self):
        """Drop the mock database."""
        self.drop_db()

    @property
    def db_url(self):
        """URL of the mock database, built from ``connect_kwargs``.

        ``import_sql`` (run by ``setup``) sets the database part.
        """
        return "mysql://{user}:{passwd}@{host}:{port}/{database}".format(
            **self.connect_kwargs
        )
class DockerDBMock(DBMock):
    """Use a throwaway MariaDB docker container to mock the database.

    ``setup`` starts a ``DockerMariadb`` container initialised from
    ``sqlfile``; ``teardown`` stops and removes it. No existing server URL
    is needed, so ``DBMock.__init__`` is deliberately not called.
    """

    sqlfile = None
    database = "test"
    docker_wrapper = None  # type: Optional[DockerMariadb]

    def __init__(self, sqlfile):
        """
        :param sqlfile: path of the SQL dump used to initialise the container.
        """
        self.sqlfile = sqlfile

    def setup(self):
        """Start the container; raises if no ``sqlfile`` was given."""
        if not self.sqlfile:
            raise Exception("missing sqlfile")
        # may block while the container imports the SQL file
        self.docker_wrapper = DockerMariadb(self.sqlfile, self.database)

    def teardown(self):
        """Stop and remove the container; a no-op if setup never completed.

        Without the guard, calling teardown from a ``finally`` after a failed
        setup would raise AttributeError on ``None`` and mask the real error.
        """
        if self.docker_wrapper is None:
            return
        self.docker_wrapper.stop(remove=True, timeout=1)

    @property
    def db_url(self):
        """URL of the database inside the container started by ``setup``."""
        return "mysql://{user}:{password}@{host}:{port}/{database}".format(
            user=self.docker_wrapper.user,
            password=self.docker_wrapper.password,
            host=self.docker_wrapper.host,
            port=self.docker_wrapper.port,
            database=self.docker_wrapper.database,
        )
class DockerMariadb:
    """A throwaway MariaDB container published on a random local port.

    Constructing an instance starts the container with ``sqlfile`` mounted
    into ``/docker-entrypoint-initdb.d/``. It then waits ``IMPORT_TIMEOUT``
    seconds (default 10) and checks that the container is still running.
    Requires the ``docker`` package; without it, instantiation raises
    ``NotImplementedError``.
    """

    def __new__(cls, *args, **kwargs):
        if docker is None:
            raise NotImplementedError(
                "cannot use DockerMariadb without package docker installed"
            )
        # object.__new__ needs the class; the other arguments are for __init__.
        return super().__new__(cls)

    def __init__(self, sqlfile, database="test"):
        """
        :param sqlfile: host path of the SQL dump used to initialise the
            database; relative paths are resolved against the current directory.
        :param database: name of the database created in the container.

        environments:
            IMPORT_TIMEOUT: seconds to wait after starting the container
                (default 10).

        :raises Exception: if the container is not running after the wait.
            The container is removed first, so it is not leaked.
        """
        self.host = "127.0.0.1"
        self.port = random.randint(1000, 25535)
        self.user = "root"
        self.password = "password"
        self.database = database
        self.sqlfile = sqlfile
        self.name = f"test-mariadb-{self.port}"
        self.client = docker.from_env()
        self.container = None  # type: Optional[docker.models.containers.Container]

        # start docker container
        self.db_run()
        time.sleep(int(os.getenv("IMPORT_TIMEOUT") or 10))
        # Container attributes are a snapshot of the last API call; refresh
        # them so the check below sees the state after the wait.
        self.container.reload()
        status = self.status
        if status and status != "running":
            # No caller holds this half-built object to stop it later, so
            # clean the container up here instead of leaking it.
            self.stop(remove=True, timeout=1)
            raise Exception(f"container status is {status}")

    @property
    def status(self):
        """Last fetched container status (e.g. "running"), or None without a container."""
        if self.container:
            return self.container.status

    def db_run(self):
        """Start the MariaDB container in the background and keep a handle to it.

        https://docker-py.readthedocs.io/en/stable/containers.html
        https://hub.docker.com/_/mariadb?tab=description#initializing-a-fresh-instance

        docker run --name mariadb -p 3306:3306 -e MYSQL_ROOT_PASSWORD=password -d mariadb --character-set-server=utf8mb4 --collation-server=utf8mb4_general_ci
        """
        volume_target = "/docker-entrypoint-initdb.d/sql_dump_file.sql"
        # Docker treats a non-absolute volume key as a volume name, not a
        # host path, so resolve the dump path before bind-mounting it.
        host_path = os.path.abspath(self.sqlfile)
        # With detach=True, run() returns the Container object itself.
        self.container = self.client.containers.run(
            "mariadb",
            command="--character-set-server=utf8mb4 --collation-server=utf8mb4_general_ci",
            name=self.name,
            detach=True,
            ports={"3306/tcp": ("127.0.0.1", self.port)},
            environment={
                "MYSQL_ROOT_PASSWORD": self.password,
                "MYSQL_DATABASE": self.database,
            },
            volumes={host_path: {"bind": volume_target, "mode": "ro"}},
        )

    def db_import(self):
        """No-op: the dump is imported by the image from the mounted init directory."""
        pass

    def stop(self, remove=False, timeout=1):
        """Stop the container; with ``remove=True`` also delete it.

        After removal the handle is dropped, so a repeated call is a no-op
        instead of addressing a container that no longer exists.
        """
        if self.container:
            self.container.stop(timeout=timeout)
            if remove:
                self.container.remove(force=True)
                self.container = None

    def __repr__(self):
        return f"<{self.__class__.__name__}= {self.container and self.container.name} (status={self.status} port={self.port})>"
# Manual smoke test: set SQLFILE to a SQL dump and run this module directly.
if __name__ == "__main__":
    dump_path = os.getenv("SQLFILE")
    if not dump_path:
        raise Exception("missing env SQLFILE")
    mock = DockerDBMock(sqlfile=os.path.abspath(dump_path))
    with mock as env:
        print(env.db_url)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.