Skip to content

Instantly share code, notes, and snippets.

@bfouts-osaro
Last active April 4, 2020 02:50
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save bfouts-osaro/136ceeaab7f3e9024f045b2de2d061a5 to your computer and use it in GitHub Desktop.
Save bfouts-osaro/136ceeaab7f3e9024f045b2de2d061a5 to your computer and use it in GitHub Desktop.
ModelBlueprint.get_all tests
import requests
import pytest
import pprint
from geri.client.types import ModelBlueprint
from geri.client import transport
@pytest.fixture(autouse=True)
def disable_server_calls(monkeypatch):
    """Guarantee that no test ever reaches the real server.

    Removes ``requests.sessions.Session.request`` for the duration of each
    test. Any requests call that would go through ``Session.request`` and is
    not intercepted by a mock therefore raises ``AttributeError``. It does not
    go over the network.
    """
    monkeypatch.delattr(requests.sessions.Session, "request")
@pytest.fixture(autouse=True)
def geri_prod(monkeypatch):
    """Select the prod geri environment by setting ``GERI_ENV=prod``.

    The original author noted that this "doesn't quite work yet".

    NOTE(review): this module imports geri at the top of the file, before any
    fixture runs. If geri reads GERI_ENV at import time, setting it here is
    too late. Confirm where GERI_ENV is read.
    """
    monkeypatch.setenv(name="GERI_ENV", value="prod")
class MockResponse:
    """Minimal stand-in for a ``requests.Response``.

    Provides only a ``status_code`` attribute and a ``json()`` method that
    hands back the payload supplied at construction.
    """

    def __init__(self, data, status_code=200):
        self.status_code = status_code
        self.data = data  # returned as-is by json()

    def json(self):
        """Return the stored payload object itself (no copy is made)."""
        return self.data
class MockRequest:
    """Stand-in for ``requests.Session.get`` that records its most recent call.

    Keyword arguments given to the constructor configure the canned response:

    * ``collection_version`` (default 0) is placed in the single returned
      resource.
    * ``status_code`` (default 200) becomes the response's status code.
    """

    def __init__(self, **kwargs):
        self.url = None         # URL passed to the most recent get() call
        self.request_args = {}  # keyword arguments of the most recent get() call
        self.kwargs = kwargs    # response configuration (see class docstring)

    def get(self, url, **request_args):
        """Record the call and return a one-resource model_blueprint listing.

        The returned payload has ``pagination_key`` set to None.
        """
        self.url = url
        self.request_args = request_args
        result = {
            'pagination_key': None,
            'resources': [
                {
                    'collection_version': self.kwargs.get('collection_version', 0),
                    'committed': False,
                    'created_at': '2019-10-31T00:09:25.110925',
                    'updated_at': '2019-10-31T00:09:25.110925',
                    'full_kind': 'model_blueprint.1',
                    'kind': 'model_blueprint',
                    'name': 'wills_test_blueprint',
                    'oid': '973d9ad3-5a35-4d5b-929f-3ffee8cf5077',
                    'project': {
                        'full_kind': 'project.default.1',
                        'kind': 'project.default',
                        'name': 'osaro/misc',
                        'oid': '10ed59bd-0ea2-4fd2-a4cc-1efe6c189549',
                        'version': 1,
                    },
                    'spec': {
                        'updated_at': '2019-11-07T21:55:52.481618',
                        'aggregators': {
                            'best_model': {
                                '()': 'mask_rcnn.model_blueprints.aggregators.gqfcn.cross_entropy.GQFCNV1CrossEntropyAggregator'
                            }
                        },
                    },
                }
            ],
        }
        # status_code is response configuration, so it comes from the
        # constructor kwargs. requests' Session.get() does not accept a
        # status_code argument, so client code could never supply it via
        # request_args. That made non-200 responses impossible before.
        return MockResponse(result, self.kwargs.get('status_code', 200))
def test_show_me_available_properties(monkeypatch):
    """Print the attributes of the first ModelBlueprint returned by get_all().

    This is an exploratory helper rather than a real check. It fails only if
    get_all() raises or yields nothing.

    pytest captures stdout and shows it only for failing tests by default.
    Run ``pytest -s`` (or ``pytest -rP``) to see the listing when the test
    passes.
    """
    request = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", request.get)
    results = list(ModelBlueprint.get_all())
    # dir() returns a sorted list, which is easier to scan than calling the
    # __dir__() dunder directly.
    pprint.pprint(dir(results[0]))
def test_model_blueprint_get_all_request_url(monkeypatch):
    """get_all() requests the model_blueprints endpoint on geri.osaro.io."""
    mock = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", mock.get)
    # Drain the results so the underlying GET is actually issued and recorded.
    list(ModelBlueprint.get_all())
    expected_url = "https://geri.osaro.io:443/api/v1/model_blueprints"
    assert mock.url == expected_url
def test_model_blueprint_get_all_request_headers(monkeypatch):
    """get_all() asks for JSON and authenticates with a Bearer token."""
    request = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", request.get)
    # Drain the results so the underlying GET is actually issued and recorded.
    list(ModelBlueprint.get_all())
    headers = request.request_args['headers']
    assert headers['Accept'] == 'application/json'
    # The auth scheme must lead the header value (RFC 6750). A substring check
    # would also accept malformed values such as "xBearer ...".
    assert headers['Authorization'].startswith('Bearer ')
def test_model_blueprint_get_all_request_stream(monkeypatch):
    """get_all() passes a falsy ``stream`` argument to Session.get."""
    mock = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", mock.get)
    list(ModelBlueprint.get_all())
    # A KeyError here means the client stopped passing `stream` explicitly.
    streamed = mock.request_args['stream']
    assert not streamed
def test_model_blueprint_get_all_request_data(monkeypatch):
    """get_all() sends an empty ``data`` object and requests every field once."""
    request = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", request.get)
    # Drain the results so the underlying GET is actually issued and recorded.
    list(ModelBlueprint.get_all())
    payload = request.request_args['json']
    assert payload['data'] == {}
    expected_fields = [
        'full_kind',
        'created_at',
        'project',
        'version',
        'spec',
        'kind',
        'updated_at',
        'committed',
        'oid',
        'collection_version',
        'name',
    ]
    # Order does not matter, so compare sorted lists. Unlike a set comparison,
    # this still catches a field that is requested more than once.
    assert sorted(payload['fields']) == sorted(expected_fields)
def test_model_blueprint_result_class(monkeypatch):
    """Objects yielded by get_all() belong to a class named prod_ModelBlueprint."""
    mock = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", mock.get)
    first = list(ModelBlueprint.get_all())[0]
    class_name = first.__class__.__name__
    assert class_name == "prod_ModelBlueprint"
def test_model_blueprint_collection_version(monkeypatch):
    """collection_version from the response payload is exposed on the result."""
    expected_version = 111
    mock = MockRequest(collection_version=expected_version)
    monkeypatch.setattr(requests.sessions.Session, "get", mock.get)
    blueprints = list(ModelBlueprint.get_all())
    assert blueprints[0].collection_version == expected_version
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment