Last active
October 23, 2022 04:50
-
-
Save AlexDel/fd533a481241e802c8a9b36840435748 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ClassifierRepository:
    """Store classifiers in an S3-compatible bucket and their metadata in SQL.

    The serialized model (``torch.save``) is uploaded through a
    ``minio.Minio`` client. The dict returned by
    ``classifier.metadata.serialize()`` is stored as a ``Document`` row keyed
    by ``uuid``. Recently saved classifiers are kept in ``self.queue``, an
    in-memory cache bounded to ``queue_length`` entries (least recently used
    evicted first), so ``get_classifier`` can skip the download for them.
    """

    def __init__(
            self,
            queue_length: int,
            endpoint: str,
            access_key: str,
            secret_key: str,
            conn_string: str):
        """Create the object-store client, the DB session and an empty cache.

        Args:
            queue_length: maximum number of classifiers kept in the in-memory
                cache; ``<= 0`` means nothing stays cached.
            endpoint: MinIO / S3 endpoint passed to ``minio.Minio``.
            access_key: access key for the object store.
            secret_key: secret key for the object store.
            conn_string: SQLAlchemy database URL passed to ``create_engine``.
        """
        # Local imports: these are the only names new to this class.
        from collections import OrderedDict
        from sqlalchemy.orm import Session

        self.queue_length = queue_length
        # uuid -> cached classifier, least recently used first.
        self.queue = OrderedDict()
        self.s3_client = minio.Minio(
            endpoint=endpoint,
            access_key=access_key,
            secret_key=secret_key,
        )
        # A bare Connection has no add()/commit()/query(); an ORM Session does.
        self.dbsession = Session(bind=sqlalchemy.create_engine(conn_string))

    def save_classifier(
            self,
            uuid: str,
            bucket_name: str,
            object_name: str,
            classifier: LinearClassifier) -> None:
        """Upload ``classifier``, record its metadata, and cache a CPU copy.

        Args:
            uuid: key for the metadata row and the in-memory cache.
            bucket_name: destination bucket.
            object_name: destination object key inside ``bucket_name``.
            classifier: model to persist; must expose ``metadata.serialize()``.

        Raises:
            Errors raised by ``put_object`` propagate unchanged. If the
            metadata row cannot be written, the session is rolled back and the
            error is re-raised; the already-uploaded object is not removed.
        """
        with io.BytesIO() as buffer:
            torch.save(classifier, buffer)
            # put_object requires the payload size; seek() returns the offset.
            length = buffer.seek(0, io.SEEK_END)
            buffer.seek(0)
            self.s3_client.put_object(
                bucket_name=bucket_name,
                object_name=object_name,
                data=buffer,
                length=length,
            )

        try:
            # NOTE(review): assumes Document is an ORM-mapped class whose
            # constructor accepts uuid plus the serialized metadata fields.
            self.dbsession.add(
                Document(uuid=uuid, **classifier.metadata.serialize()))
            self.dbsession.commit()
        except Exception:
            # Roll back so the session stays usable for the next call.
            self.dbsession.rollback()
            raise

        # NOTE(review): if classifier is an nn.Module, .to() moves it in place,
        # so the caller's object also ends up on the CPU.
        self.add_classifier_to_cache_queue(
            uuid=uuid,
            classifier=self.set_model_to_cpu(classifier),
        )
        # `is not None` so device index 0 is not mistaken for "no GPU"
        # (assumes CUDA_DEVICE_ID is None when no GPU is configured).
        # empty_cache() does nothing if CUDA was never initialized.
        if CUDA_DEVICE_ID is not None:
            torch.cuda.empty_cache()

    def get_classifier(
            self,
            uuid: str,
            bucket_name: str,
            object_name: str) -> LinearClassifier:
        """Return the classifier for ``uuid`` on the configured device.

        The classifier is served from the in-memory cache when present.
        Otherwise it is downloaded from ``bucket_name``/``object_name`` and its
        metadata row is attached via ``load_metadata``.

        Raises:
            Errors raised by ``get_object`` propagate unchanged.
            NoResultFound (SQLAlchemy): if no ``Document`` row matches ``uuid``.
        """
        if uuid in self.queue:
            # Mark as most recently used so it is evicted last.
            self.queue.move_to_end(uuid)
            # NOTE(review): Module.to() moves in place, so this also moves the
            # cached entry to the GPU; copy first if the cache must stay on CPU.
            return self.set_model_to_gpu(model=self.queue[uuid])

        # get_object is called outside the try so a failed request cannot make
        # the finally block touch an unbound response.
        response = self.s3_client.get_object(
            bucket_name=bucket_name,
            object_name=object_name,
        )
        try:
            payload = response.read()
        finally:
            response.close()
            response.release_conn()

        with io.BytesIO(payload) as buffer:
            # Load onto CPU first so a model saved on a GPU also loads on a
            # CPU-only host; set_model_to_gpu chooses the final device.
            # SECURITY: torch.load unpickles arbitrary objects, so only read from
            # trusted buckets. Newer PyTorch releases may need weights_only=False
            # to load a whole pickled model.
            classifier = torch.load(buffer, map_location="cpu")

        metadata = self.dbsession.query(Document).filter_by(uuid=uuid).one()
        classifier.load_metadata(metadata)
        return self.set_model_to_gpu(classifier)

    def add_classifier_to_cache_queue(self, uuid: str, classifier) -> None:
        """Cache ``classifier`` under ``uuid``, evicting the oldest beyond ``queue_length``."""
        self.queue[uuid] = classifier
        self.queue.move_to_end(uuid)
        while self.queue and len(self.queue) > self.queue_length:
            self.queue.popitem(last=False)

    @staticmethod
    def set_model_to_cpu(model):
        """Return ``model`` moved to the CPU."""
        return model.to("cpu")

    @staticmethod
    def set_model_to_gpu(model):
        """Return ``model`` on the configured CUDA device, or unchanged without one.

        NOTE(review): assumes CUDA_DEVICE_ID is a CUDA device index (int or
        numeric string) and None when no GPU is configured. Confirm against
        its definition.
        """
        if CUDA_DEVICE_ID is None or not torch.cuda.is_available():
            return model
        return model.to(f"cuda:{CUDA_DEVICE_ID}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment