Created
December 27, 2017 10:59
-
-
Save ficapy/196552486ec5fe594771594d9d5eb67a to your computer and use it in GitHub Desktop.
PostgreSQL 批量更新对比单条更新
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# -*- coding: utf-8 -*- | |
# Author: ficapy | |
import random | |
import csv | |
import time | |
from functools import wraps | |
from io import StringIO | |
from contextlib import closing, contextmanager | |
from psycopg2.pool import ThreadedConnectionPool | |
from psycopg2.extras import execute_values | |
# Connection pool shared by every helper and benchmark in this script.
# Connection settings are hard-coded; adjust them for the target database.
pool = ThreadedConnectionPool(
    minconn=5,
    maxconn=20,
    database='dbname',
    user='username',
    password='pwd',
)
@contextmanager
def get_curs():
    """Yield a cursor on a pooled connection, committing on success.

    A connection is borrowed from the module-level ``pool``. When the
    ``with`` body finishes normally the transaction is committed. When it
    raises, the transaction is rolled back and the exception is re-raised,
    so a failed statement is never mistaken for a successful one. The
    cursor is closed and the connection is returned to the pool in every
    case.

    Yields:
        A cursor opened on the borrowed connection.

    Raises:
        Exception: whatever the ``with`` body or the commit raised,
            propagated after the rollback.
    """
    conn = pool.getconn()
    try:
        curs = conn.cursor()
        try:
            yield curs
            conn.commit()
        except Exception:
            conn.rollback()
            # Re-raise: swallowing the error here would let callers (and the
            # benchmark timings) carry on as if the work had succeeded.
            raise
        finally:
            curs.close()
    finally:
        # Always hand the connection back, even if cursor creation failed.
        pool.putconn(conn)
def timeit(func):
    """Decorate ``func`` so each call runs on a freshly seeded table and is timed.

    Before every call the ``demo`` table is dropped, recreated and bulk-loaded
    via COPY with ids 1..999999 and a random ``mch_id`` in 1..1000. Only the
    call to ``func`` itself is timed; the setup is finished (and its
    transaction closed) before the clock starts.

    Args:
        func: the benchmark callable to wrap.

    Returns:
        A wrapper that prints the start banner and the elapsed seconds, then
        returns whatever ``func`` returned.
    """
    @wraps(func)
    def inner(*args, **kwargs):
        # The in-memory CSV buffer and the setup cursor are both released
        # when this block exits, i.e. before timing begins.
        with closing(StringIO()) as f, get_curs() as curs:
            curs.execute("DROP TABLE IF EXISTS demo;")
            curs.execute("""CREATE TABLE IF NOT EXISTS demo (
                id INT PRIMARY KEY,
                mch_id INT
            )""")
            writer = csv.writer(f)
            # csv stringifies the integers itself; no explicit str() needed.
            writer.writerows(
                (i, random.randint(1, 1000)) for i in range(1, 1000000))
            f.seek(0)  # rewind so COPY reads the buffer from the beginning
            curs.copy_from(f, 'demo', sep=',')
        print("start {}".format(func.__name__))
        # perf_counter is monotonic and high-resolution, so the measured
        # interval is not skewed by system clock adjustments.
        start = time.perf_counter()
        result = func(*args, **kwargs)
        print("{} elapse times: {}".format(
            func.__name__, time.perf_counter() - start))
        return result
    return inner
@timeit
def single(data: dict):
    """Update ``demo`` one row at a time.

    Every ``id -> mch_id`` pair in ``data`` is written with its own UPDATE
    statement, each inside its own ``get_curs()`` block.
    """
    statement = "UPDATE demo SET mch_id = %s WHERE id = %s"
    for row_id, mch_id in data.items():
        with get_curs() as cursor:
            cursor.execute(statement, (mch_id, row_id))
@timeit
def batch(data: dict):
    """Update ``demo`` in bulk through ``execute_values``.

    The ``id -> mch_id`` pairs in ``data`` are supplied as a VALUES list that
    is joined against ``demo`` on ``id``, all inside a single ``get_curs()``
    block.
    """
    statement = (
        "UPDATE demo SET mch_id=tmp.mch_id "
        "FROM (VALUES %s) AS tmp (id,mch_id) "
        "WHERE demo.id=tmp.id"
    )
    rows = list(data.items())
    with get_curs() as cursor:
        execute_values(cursor, statement, rows)
if __name__ == '__main__':
    # One random mch_id (1..1000) for each id in 1..999999; both benchmarks
    # receive the same mapping.
    data = {row_id: random.randint(1, 1000) for row_id in range(1, 1000000)}
    # Run the bulk variant first, then the row-by-row variant.
    for benchmark in (batch, single):
        benchmark(data)
# Result (elapsed wall-clock seconds; 999,999 rows updated per run)
# start batch | |
# batch elapse times: 13.745781898498535 | |
# start single | |
# single elapse times: 343.0373680591583 |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment