Last active
April 4, 2020 02:50
-
-
Save bfouts-osaro/136ceeaab7f3e9024f045b2de2d061a5 to your computer and use it in GitHub Desktop.
ModelBlueprint.get_all tests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import requests | |
import pytest | |
import pprint | |
from geri.client.types import ModelBlueprint | |
from geri.client import transport | |
@pytest.fixture(autouse=True)
def disable_server_calls(monkeypatch):
    """Guarantee that no test in this module can reach a real server.

    ``request`` is removed from ``requests.sessions.Session`` for the
    duration of every test. Any HTTP call that is not routed through a
    test's own patched ``Session.get`` ends up needing ``Session.request``
    and therefore fails with ``AttributeError`` instead of touching the
    network.
    """
    monkeypatch.delattr(requests.sessions.Session, "request")
@pytest.fixture(autouse=True)
def geri_prod(monkeypatch):
    """Select the geri client's ``prod`` environment for every test.

    Sets the environment variable ``GERI_ENV`` to ``"prod"`` for the
    duration of each test. The tests that check the request URL and the
    ``prod_ModelBlueprint`` result class rely on this.

    NOTE(review): the original author noted that this "doesn't quite work
    yet". Presumably ``geri.client`` reads ``GERI_ENV`` when it is imported,
    which happens at module load, before this fixture runs. TODO confirm;
    if so, set the variable before ``geri`` is imported (e.g. in
    ``conftest.py``).
    """
    monkeypatch.setenv(name="GERI_ENV", value="prod")
class MockResponse:
    """Minimal stand-in for ``requests.Response``.

    Exposes only what the tests need: a ``status_code`` attribute and a
    ``json()`` method that hands back the payload given at construction.
    """

    def __init__(self, data, status_code=200):
        # Payload returned as-is (same object) by json().
        self.data = data
        # HTTP status reported to the client; 200 unless told otherwise.
        self.status_code = status_code

    def json(self):
        """Return the stored payload, mirroring ``Response.json()``."""
        return self.data
class MockRequest:
    """Fake ``requests.Session.get`` that records its call and returns a
    canned listing containing a single model blueprint.

    Install it with
    ``monkeypatch.setattr(requests.sessions.Session, "get", mock.get)``.
    After the code under test runs, ``url`` and ``request_args`` hold the
    arguments of the most recent ``get`` call.

    Keyword arguments passed to the constructor configure the canned
    response:

    * ``collection_version`` -- value of the resource's
      ``collection_version`` field (default ``0``).
    * ``status_code`` -- HTTP status of the returned ``MockResponse``
      (default ``200``).
    """

    def __init__(self, **kwargs):
        self.url = None  # URL passed to the most recent get() call
        self.request_args = {}  # keyword arguments of the most recent get() call
        self.kwargs = kwargs  # configuration for the canned response

    def get(self, url, **request_args):
        """Record ``url`` and ``request_args`` and return a ``MockResponse``.

        The payload's ``pagination_key`` is ``None``, presumably so that a
        paginating client stops after this single page (TODO confirm
        against ``geri.client``).
        """
        self.url = url
        self.request_args = request_args
        result = {
            'pagination_key': None,
            'resources': [
                {
                    'collection_version': self.kwargs.get('collection_version', 0),
                    'committed': False,
                    'created_at': '2019-10-31T00:09:25.110925',
                    'updated_at': '2019-10-31T00:09:25.110925',
                    'full_kind': 'model_blueprint.1',
                    'kind': 'model_blueprint',
                    'name': 'wills_test_blueprint',
                    'oid': '973d9ad3-5a35-4d5b-929f-3ffee8cf5077',
                    'project': {
                        'full_kind': 'project.default.1',
                        'kind': 'project.default',
                        'name': 'osaro/misc',
                        'oid': '10ed59bd-0ea2-4fd2-a4cc-1efe6c189549',
                        'version': 1
                    },
                    'spec': {
                        'updated_at': '2019-11-07T21:55:52.481618',
                        'aggregators': {
                            'best_model': {
                                '()': 'mask_rcnn.model_blueprints.aggregators.gqfcn.cross_entropy.GQFCNV1CrossEntropyAggregator'
                            }
                        }
                    }
                }]
        }
        # ``status_code`` is mock configuration, so it comes from the
        # constructor kwargs, just like ``collection_version`` above. It used
        # to be looked up in ``request_args`` (the arguments the client
        # passes to ``Session.get``). A real ``requests`` call never carries
        # it there, so every response was a silent 200 whatever the test
        # configured.
        return MockResponse(result, self.kwargs.get('status_code', 200))
def test_show_me_available_properties(monkeypatch):
    """Debugging aid: print the attributes of a ``ModelBlueprint`` result.

    This test makes no assertion. It exists to show which properties the
    object returned by ``ModelBlueprint.get_all()`` exposes. pytest captures
    stdout by default, so run it with ``pytest -s`` (``--capture=no``) to
    see the listing; there is no need to force a failure.
    """
    request = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", request.get)
    results = list(ModelBlueprint.get_all())
    # dir() returns a sorted list, which is far easier to scan than the
    # unordered result of calling __dir__() directly.
    pprint.pprint(dir(results[0]))
def test_model_blueprint_get_all_request_url(monkeypatch):
    """get_all() should GET the model_blueprints endpoint on geri.osaro.io."""
    request = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", request.get)
    # Consume get_all() fully (it may be lazy) so the request is issued;
    # the returned blueprints themselves are not needed here.
    list(ModelBlueprint.get_all())
    assert request.url == "https://geri.osaro.io:443/api/v1/model_blueprints"
def test_model_blueprint_get_all_request_headers(monkeypatch):
    """get_all() should accept JSON and authenticate with a Bearer token."""
    request = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", request.get)
    list(ModelBlueprint.get_all())  # consume fully; get_all() may be lazy
    headers = request.request_args['headers']
    assert headers['Accept'] == 'application/json'
    # The Bearer scheme must lead the header value ("Bearer <token>"). A
    # plain substring check would also accept a malformed value.
    assert headers['Authorization'].startswith('Bearer ')
def test_model_blueprint_get_all_request_stream(monkeypatch):
    """get_all() should pass a falsy ``stream`` argument to ``Session.get``."""
    request = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", request.get)
    list(ModelBlueprint.get_all())  # consume fully; get_all() may be lazy
    # Raises KeyError if the client stops passing ``stream`` explicitly.
    assert not request.request_args['stream']
def test_model_blueprint_get_all_request_data(monkeypatch):
    """get_all() should send an empty ``data`` payload and request every
    ModelBlueprint field."""
    request = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", request.get)
    list(ModelBlueprint.get_all())  # consume fully; get_all() may be lazy
    body = request.request_args['json']
    assert body['data'] == {}
    # Compare as sets: the order of the requested fields is irrelevant.
    assert set(body['fields']) == {
        'full_kind',
        'created_at',
        'project',
        'version',
        'spec',
        'kind',
        'updated_at',
        'committed',
        'oid',
        'collection_version',
        'name',
    }
def test_model_blueprint_result_class(monkeypatch):
    """get_all() should yield instances of the ``prod_ModelBlueprint`` class."""
    mock_get = MockRequest()
    monkeypatch.setattr(requests.sessions.Session, "get", mock_get.get)
    first_blueprint = list(ModelBlueprint.get_all())[0]
    assert type(first_blueprint).__name__ == "prod_ModelBlueprint"
def test_model_blueprint_collection_version(monkeypatch):
    """The response's ``collection_version`` should surface on the result."""
    mock_get = MockRequest(collection_version=111)
    monkeypatch.setattr(requests.sessions.Session, "get", mock_get.get)
    blueprints = list(ModelBlueprint.get_all())
    first_blueprint = blueprints[0]
    assert first_blueprint.collection_version == 111
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment