Skip to content

Instantly share code, notes, and snippets.

@AlexDel
Last active October 23, 2022 04:50
Show Gist options
  • Save AlexDel/fd533a481241e802c8a9b36840435748 to your computer and use it in GitHub Desktop.
Save AlexDel/fd533a481241e802c8a9b36840435748 to your computer and use it in GitHub Desktop.
class ClassifierRepository:
    """Store classifiers in MinIO plus a SQL metadata table, with a small cache.

    The pickled model (``torch.save``) goes to object storage. The values from
    ``classifier.metadata.serialize()`` are stored as a ``Document`` row keyed
    by ``uuid``. Up to ``queue_length`` recently used classifiers are kept in
    memory and evicted least-recently-used first.
    """

    def __init__(
            self,
            queue_length: int,
            endpoint: str,
            access_key: str,
            secret_key: str,
            conn_string: str):
        """Create the object-storage client and the metadata DB session.

        Args:
            queue_length: Maximum number of classifiers kept in ``self.queue``.
            endpoint: Endpoint passed to ``minio.Minio``.
            access_key: Object-storage access key.
            secret_key: Object-storage secret key.
            conn_string: SQLAlchemy database URL holding the ``Document`` table.
        """
        # Local import: sqlalchemy.orm is a submodule that a bare
        # `import sqlalchemy` may not load.
        from sqlalchemy.orm import Session

        self.queue_length = queue_length
        # uuid -> classifier. Dict insertion order tracks recency (oldest
        # first), which add_classifier_to_cache_queue relies on for eviction.
        self.queue = {}
        self.s3_client = minio.Minio(
            endpoint=endpoint,
            access_key=access_key,
            secret_key=secret_key,
        )
        # An ORM Session (not a bare Connection) provides add/commit/query.
        self.dbsession = Session(sqlalchemy.create_engine(conn_string))

    def save_classifier(
            self,
            uuid: str,
            bucket_name: str,
            object_name: str,
            classifier: "LinearClassifier"):
        """Upload ``classifier``, record its metadata, and cache it on the CPU.

        Args:
            uuid: Key for the ``Document`` row and for the in-memory cache.
            bucket_name: Destination bucket.
            object_name: Destination object key.
            classifier: Model to store. Must expose ``metadata.serialize()``,
                which returns keyword values for ``Document``.

        Raises:
            Any error from the MinIO client or the database session. On a
            database error the session is rolled back before re-raising.
        """
        with io.BytesIO() as temp_buffer:
            torch.save(classifier, temp_buffer)
            payload = temp_buffer.getvalue()
        # Minio.put_object needs the stream length up front.
        with io.BytesIO(payload) as buffer:
            self.s3_client.put_object(
                bucket_name=bucket_name,
                object_name=object_name,
                data=buffer,
                length=len(payload),
            )
        try:
            # Serialized metadata may override 'uuid', as in the original dict merge.
            self.dbsession.add(
                Document(**{'uuid': uuid, **classifier.metadata.serialize()}))
            self.dbsession.commit()
        except Exception:
            # A failed flush leaves the Session unusable until rolled back.
            self.dbsession.rollback()
            raise
        self.add_classifier_to_cache_queue(
            uuid=uuid,
            classifier=self.set_model_to_cpu(classifier),
        )
        if CUDA_DEVICE_ID:
            # Return GPU memory freed by moving the model to the CPU.
            torch.cuda.empty_cache()

    def get_classifier(
            self,
            uuid: str,
            bucket_name: str,
            object_name: str) -> "LinearClassifier":
        """Return the classifier for ``uuid``, moved to the GPU when configured.

        The classifier is served from the in-memory cache when present.
        Otherwise it is downloaded from ``bucket_name``/``object_name`` and
        given its ``Document`` metadata.

        Raises:
            Errors from the MinIO client, ``torch.load`` or the DB query.
            ``Query.one()`` raises unless exactly one row matches ``uuid``.
        """
        if uuid in self.queue:
            # Re-insert to mark the entry as most recently used.
            self.queue[uuid] = self.queue.pop(uuid)
            # NOTE(review): if LinearClassifier is an nn.Module, .to() moves
            # it in place, so the cached entry leaves the CPU too. Copy the
            # model first if the cache must stay CPU-resident.
            return self.set_model_to_gpu(model=self.queue[uuid])

        response = self.s3_client.get_object(
            bucket_name=bucket_name,
            object_name=object_name,
        )
        # `response` is bound here, so cleanup can never mask the real error.
        try:
            raw = response.read()
        finally:
            response.close()
            response.release_conn()
        # NOTE(review): torch.load unpickles arbitrary objects; only load from
        # buckets this service controls. torch >= 2.6 defaults to
        # weights_only=True, which rejects whole pickled modules. Pass
        # weights_only=False there.
        with io.BytesIO(raw) as buffer:
            classifier = torch.load(buffer)
        # assumes load_metadata accepts a Document row — TODO confirm
        metadata = self.dbsession.query(Document).filter_by(uuid=uuid).one()
        classifier.load_metadata(metadata)
        return self.set_model_to_gpu(classifier)

    def add_classifier_to_cache_queue(self, uuid: str, classifier):
        """Cache ``classifier`` as the most recent entry and evict the oldest.

        Entries are evicted oldest first until at most ``queue_length`` remain.
        """
        self.queue.pop(uuid, None)
        self.queue[uuid] = classifier
        while self.queue and len(self.queue) > self.queue_length:
            del self.queue[next(iter(self.queue))]

    def set_model_to_cpu(self, model):
        """Return ``model`` moved to the CPU via ``model.to("cpu")``."""
        return model.to("cpu")

    def set_model_to_gpu(self, model):
        """Return ``model`` on the configured CUDA device.

        The model is returned unchanged when ``CUDA_DEVICE_ID`` is falsy.
        """
        # NOTE(review): truthiness mirrors the empty_cache() guard in
        # save_classifier, so a device id of 0 disables GPU use — confirm.
        if not CUDA_DEVICE_ID:
            return model
        # assumes CUDA_DEVICE_ID is a device index such as 1 or "1" — TODO confirm
        return model.to(f"cuda:{CUDA_DEVICE_ID}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment