Skip to content

Instantly share code, notes, and snippets.

@geosmart
Created July 19, 2022 03:48
Show Gist options
  • Save geosmart/e031702fdb3f2334cf0d2412ac7ba9c4 to your computer and use it in GitHub Desktop.
Save geosmart/e031702fdb3f2334cf0d2412ac7ba9c4 to your computer and use it in GitHub Desktop.
Snowflake for Python 3
import time
import logging
# Bit widths of the node and sequence fields.
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12
# Width of the timestamp field: 64 bits minus the sign bit and the 22 bits above.
TIMESTAMP_BITS = 41
# Custom epoch in Unix milliseconds; IDs store time relative to it
# (this value is Twitter's original Snowflake epoch).
TIMESTAMP_EPOCH = 1288834974657
# Largest worker id: 0-31
MAX_WORKER_ID = -1 ^ (-1 << WORKER_ID_BITS)
# Largest data-center id: 0-31
MAX_DATA_CENTER_ID = -1 ^ (-1 << DATA_CENTER_ID_BITS)
# Sequence mask: 0-4095
SEQUENCE_MASK = -1 ^ (-1 << SEQUENCE_BITS)
# Left-shift offsets of each field inside the 64-bit ID.
WORKER_ID_SHIFT = SEQUENCE_BITS
DATA_CENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS
TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATA_CENTER_ID_BITS


class Snowflake(object):
    """Generate 64-bit, roughly time-ordered unique IDs (Snowflake layout).

    Bit layout, most significant first:

    - 1 bit: unused sign bit, always 0 so IDs are positive.
    - 41 bits: milliseconds since ``TIMESTAMP_EPOCH`` (about 69 years of range).
    - 5 bits: data-center id, 0-31.
    - 5 bits: worker id, 0-31 (data center + worker = 10 bits, up to 1024 nodes).
    - 12 bits: per-millisecond sequence, 0-4095 (4096 IDs per node per ms).

    No locking is performed; sharing one instance between threads requires
    external synchronization.
    """

    def __init__(self, data_center_id, worker_id, max_backward_ms=2000):
        """Create a generator for one (data center, worker) pair.

        Args:
            data_center_id: Data-center id in ``0..MAX_DATA_CENTER_ID``.
            worker_id: Worker id in ``0..MAX_WORKER_ID``.
            max_backward_ms: Clock rollbacks smaller than this many
                milliseconds (e.g. NTP corrections) are absorbed by reusing
                the last timestamp; larger ones raise. Defaults to 2000.

        Raises:
            ValueError: If either id does not fit in its bit field (it would
                otherwise overwrite the neighbouring field of every ID).
        """
        if not 0 <= data_center_id <= MAX_DATA_CENTER_ID:
            raise ValueError(
                "data_center_id must be in [0, %d], got %r"
                % (MAX_DATA_CENTER_ID, data_center_id)
            )
        if not 0 <= worker_id <= MAX_WORKER_ID:
            raise ValueError(
                "worker_id must be in [0, %d], got %r" % (MAX_WORKER_ID, worker_id)
            )
        self.dc_id = data_center_id
        self.worker_id = worker_id
        self.max_backward_ms = max_backward_ms
        self.last_timestamp = TIMESTAMP_EPOCH
        self.sequence = 0
        self.sequence_overload = 0
        self.errors = 0
        self.generated_ids = 0

    def get_time(self):
        """Return the current wall-clock time as integer Unix milliseconds."""
        return int(time.time() * 1000)

    def get_next_id(self) -> int:
        """Return the next unique ID.

        Raises:
            Exception: If the clock moved backwards by ``max_backward_ms`` or
                more, either before or while waiting for the next millisecond.
        """
        timestamp = self.get_time()
        if timestamp < self.last_timestamp:
            if self.last_timestamp - timestamp < self.max_backward_ms:
                # Small rollback (e.g. NTP adjustment): keep issuing IDs on the
                # last timestamp; the sequence below keeps them unique.
                timestamp = self.last_timestamp
            else:
                # Large rollback: refuse rather than risk duplicate IDs.
                self.errors += 1
                raise Exception(
                    "Clock went backwards! %d < %d" % (timestamp, self.last_timestamp)
                )
        sequence = 0
        if timestamp == self.last_timestamp:
            sequence = (self.sequence + 1) & SEQUENCE_MASK
            if sequence == 0:
                # All 4096 sequence values of this millisecond are used up.
                logging.warning("The sequence has been overload")
                self.sequence_overload += 1
                timestamp = self.til_next_millis(self.last_timestamp)
        # Commit state only after any wait succeeded, so a failed wait cannot
        # rewind the sequence and cause an already-issued ID to be reissued.
        self.sequence = sequence
        self.last_timestamp = timestamp
        self.generated_ids += 1
        return (
            ((timestamp - TIMESTAMP_EPOCH) << TIMESTAMP_LEFT_SHIFT)
            | (self.dc_id << DATA_CENTER_ID_SHIFT)
            | (self.worker_id << WORKER_ID_SHIFT)
            | sequence
        )

    def til_next_millis(self, last_timestamp):
        """Busy-wait until the clock is strictly later than ``last_timestamp``.

        The clock may start behind ``last_timestamp`` (a tolerated rollback);
        in that case this keeps waiting for it to catch up instead of failing.

        Returns:
            The first observed timestamp greater than ``last_timestamp``.

        Raises:
            Exception: If the clock is ``max_backward_ms`` or more behind
                ``last_timestamp`` while waiting.
        """
        timestamp = self.get_time()
        while timestamp <= last_timestamp:
            if last_timestamp - timestamp >= self.max_backward_ms:
                self.errors += 1
                raise Exception(
                    "Clock moved backwards. Refusing to generate id for {}ms".format(
                        last_timestamp - timestamp
                    )
                )
            timestamp = self.get_time()
        return timestamp

    def get_worker_id(self, id: int) -> int:
        """Extract the worker id field from a Snowflake ``id``."""
        return (id >> WORKER_ID_SHIFT) & MAX_WORKER_ID

    def get_data_center_id(self, id: int) -> int:
        """Extract the data-center id field from a Snowflake ``id``."""
        return (id >> DATA_CENTER_ID_SHIFT) & MAX_DATA_CENTER_ID

    def get_generate_timestamp(self, id: int) -> int:
        """Return the Unix-millisecond time at which ``id`` was generated."""
        timestamp_mask = ~(-1 << TIMESTAMP_BITS)
        return ((id >> TIMESTAMP_LEFT_SHIFT) & timestamp_mask) + TIMESTAMP_EPOCH

    def get_id_info(self, id: int) -> dict:
        """Decode ``id`` into its data-center, worker and timestamp fields."""
        return {
            "dc": self.get_data_center_id(id),
            "worker": self.get_worker_id(id),
            "timestamp": self.get_generate_timestamp(id),
        }

    @property
    def stats(self) -> dict:
        """Snapshot of this generator's identity, state and counters."""
        return {
            "dc": self.dc_id,
            "worker": self.worker_id,
            "timestamp": self.get_time(),  # current timestamp for this worker
            "last_timestamp": self.last_timestamp,  # timestamp of the last generated ID
            "sequence": self.sequence,  # sequence number used for last_timestamp
            "sequence_overload": self.sequence_overload,  # times the sequence overflowed
            "errors": self.errors,  # times the clock went backwards beyond tolerance
        }
if __name__ == "__main__":
    import hashlib
    import socket

    hostname = socket.gethostname()
    print(hostname)
    # Derive the (data center, worker) pair from the host name so a host gets
    # the same pair on every run. The modulus follows the bit-field widths
    # instead of a hard-coded 32. Different hosts can still collide (only
    # (MAX_DATA_CENTER_ID + 1) * (MAX_WORKER_ID + 1) combinations exist);
    # assign ids explicitly when uniqueness across hosts matters.
    host_bytes = hostname.encode("utf-8")
    dc = int(hashlib.sha1(host_bytes).hexdigest(), 16) % (MAX_DATA_CENTER_ID + 1)
    worker_id = int(hashlib.blake2s(host_bytes).hexdigest(), 16) % (MAX_WORKER_ID + 1)
    id_util = Snowflake(dc, worker_id)
    snowflake_id = id_util.get_next_id()
    print(snowflake_id)
    print(id_util.stats)
    print(id_util.get_id_info(snowflake_id))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment