Created
July 19, 2022 03:48
-
-
Save geosmart/e031702fdb3f2334cf0d2412ac7ba9c4 to your computer and use it in GitHub Desktop.
Snowflake for Python 3
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import logging | |
# Bit widths of each field of a Snowflake ID.
WORKER_ID_BITS = 5
DATA_CENTER_ID_BITS = 5
SEQUENCE_BITS = 12
TIMESTAMP_BITS = 41
# Custom epoch in milliseconds; IDs store (timestamp - TIMESTAMP_EPOCH).
TIMESTAMP_EPOCH = 1288834974657
# 0-31
MAX_WORKER_ID = -1 ^ (-1 << WORKER_ID_BITS)
# 0-31
MAX_DATA_CENTER_ID = -1 ^ (-1 << DATA_CENTER_ID_BITS)
# 0-4095
SEQUENCE_MASK = -1 ^ (-1 << SEQUENCE_BITS)
# Bit offsets (left-shift amounts) of each field inside the ID.
WORKER_ID_SHIFT = SEQUENCE_BITS
DATA_CENTER_ID_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS
TIMESTAMP_LEFT_SHIFT = SEQUENCE_BITS + WORKER_ID_BITS + DATA_CENTER_ID_BITS


class Snowflake(object):
    """
    Generator for 64-bit Snowflake IDs.

    Bit layout of an ID (most significant first):
      1:  unused (sign bit, keeps the ID positive)
      41: timestamp, milliseconds since TIMESTAMP_EPOCH (41 bits last ~69 years)
      5:  dcId, 0-31
      5:  workerId, 0-31 (dcId + workerId = 10 bits, at most 1024 nodes)
      12: sequence, 0-4095 (up to 4096 IDs per node per millisecond)

    Instances hold mutable state and do no locking; callers sharing one
    instance across threads must serialize calls to get_next_id themselves.
    """

    def __init__(self, data_center_id, worker_id):
        """
        :param data_center_id: data-center id, 0..MAX_DATA_CENTER_ID
        :param worker_id: worker id, 0..MAX_WORKER_ID
        :raises ValueError: if either id is out of range. An out-of-range
            value would spill into the neighbouring bit field and make IDs
            of different nodes collide.
        """
        if not 0 <= data_center_id <= MAX_DATA_CENTER_ID:
            raise ValueError(
                "data_center_id must be between 0 and %d, got %r"
                % (MAX_DATA_CENTER_ID, data_center_id)
            )
        if not 0 <= worker_id <= MAX_WORKER_ID:
            raise ValueError(
                "worker_id must be between 0 and %d, got %r"
                % (MAX_WORKER_ID, worker_id)
            )
        self.dc_id = data_center_id
        self.worker_id = worker_id
        # Millisecond timestamp of the most recently generated ID.
        self.last_timestamp = TIMESTAMP_EPOCH
        # Sequence number used for the most recently generated ID.
        self.sequence = 0
        # Number of times the per-millisecond sequence was exhausted.
        self.sequence_overload = 0
        # Number of times the clock was found to have gone backwards.
        self.errors = 0
        self.generated_ids = 0

    def get_time(self):
        """Return the current wall-clock time in milliseconds."""
        return int(time.time() * 1000)

    def get_next_id(self):
        """
        Generate the next Snowflake ID.

        A clock rollback of less than 2000 ms is tolerated by continuing on
        the last timestamp; a larger rollback raises. ``sequence`` and
        ``last_timestamp`` are only updated once an ID is actually produced,
        so a failed call can be retried without reissuing an old ID.

        :returns: the new ID as a non-negative int
        :raises Exception: if the clock went backwards by 2000 ms or more, or
            if the sequence overflows while the clock is still behind
            ``last_timestamp`` (raised by til_next_millis)
        """
        timestamp = self.get_time()
        if timestamp < self.last_timestamp:
            if self.last_timestamp - timestamp < 2000:
                # Tolerate rollbacks under 2 s (e.g. NTP adjustments) by
                # continuing on the last timestamp.
                timestamp = self.last_timestamp
            else:
                # The server clock moved backwards too far: refuse.
                self.errors += 1
                raise Exception(
                    "Clock went backwards! %d < %d" % (timestamp, self.last_timestamp)
                )
        if timestamp == self.last_timestamp:
            # Same millisecond: take the next sequence number. Keep it in a
            # local so that a failure below leaves self.sequence untouched;
            # otherwise a retry would hand out an already-issued ID.
            sequence = (self.sequence + 1) & SEQUENCE_MASK
            if sequence == 0:
                # All sequence values for this millisecond are used up;
                # wait for the clock to move on (may raise).
                logging.warning("The sequence has been overload")
                self.sequence_overload += 1
                timestamp = self.til_next_millis(self.last_timestamp)
        else:
            sequence = 0
        # Commit state only now that the ID is guaranteed to be produced.
        self.sequence = sequence
        self.last_timestamp = timestamp
        return (
            ((timestamp - TIMESTAMP_EPOCH) << TIMESTAMP_LEFT_SHIFT)
            | (self.dc_id << DATA_CENTER_ID_SHIFT)
            | (self.worker_id << WORKER_ID_SHIFT)
            | sequence
        )

    def til_next_millis(self, last_timestamp):
        """
        Busy-wait until the millisecond clock differs from ``last_timestamp``.

        :returns: the first reading greater than last_timestamp
        :raises Exception: if the reading is earlier than last_timestamp
            (clock went backwards); this also increments ``errors``
        """
        timestamp = self.get_time()
        # Spin until the OS millisecond timestamp changes.
        while timestamp == last_timestamp:
            timestamp = self.get_time()
        if timestamp < last_timestamp:
            # The new reading is smaller than the recorded one, so the OS
            # clock moved backwards: refuse to generate an ID.
            self.errors += 1
            raise Exception(
                "Clock moved backwards. Refusing to generate id for {}ms".format(
                    last_timestamp - timestamp
                )
            )
        return timestamp

    def get_worker_id(self, id: int) -> int:
        """Extract the worker id from a Snowflake ID."""
        return (id >> WORKER_ID_SHIFT) & ~(-1 << WORKER_ID_BITS)

    def get_data_center_id(self, id: int) -> int:
        """Extract the data-center id from a Snowflake ID."""
        return (id >> DATA_CENTER_ID_SHIFT) & ~(-1 << DATA_CENTER_ID_BITS)

    def get_generate_timestamp(self, id: int) -> int:
        """Extract the generation time (ms since the Unix epoch) from an ID."""
        return ((id >> TIMESTAMP_LEFT_SHIFT) & ~(-1 << TIMESTAMP_BITS)) + TIMESTAMP_EPOCH

    def get_id_info(self, id: int) -> dict:
        """Decode an ID into its data-center id, worker id and timestamp."""
        return {
            "dc": self.get_data_center_id(id),
            "worker": self.get_worker_id(id),
            "timestamp": self.get_generate_timestamp(id),
        }

    @property
    def stats(self):
        """Snapshot of this generator's counters and state."""
        return {
            "dc": self.dc_id,
            "worker": self.worker_id,
            "timestamp": self.get_time(),  # current timestamp for this worker
            "last_timestamp": self.last_timestamp,  # the last timestamp that generated ID on
            "sequence": self.sequence,  # the sequence number for last timestamp
            "sequence_overload": self.sequence_overload,  # the number of times that the sequence is overflow
            "errors": self.errors,  # the number of times that clock went backward
        }
if __name__ == "__main__":
    import hashlib
    import socket

    def _bucket(digest_hex):
        """Fold a hex digest onto one of 32 ids (0-31)."""
        return int(digest_hex, 16) % 32

    # Derive this host's data-center and worker ids from two different
    # hashes of its hostname, then generate and decode one ID.
    host = socket.gethostname()
    print(host)
    host_bytes = host.encode("utf-8")
    generator = Snowflake(
        _bucket(hashlib.sha1(host_bytes).hexdigest()),
        _bucket(hashlib.blake2s(host_bytes).hexdigest()),
    )
    new_id = generator.get_next_id()
    print(new_id)
    print(generator.stats)
    print(generator.get_id_info(new_id))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment