Skip to content

Instantly share code, notes, and snippets.

@jackylamhk
Created November 6, 2023 15:07
Show Gist options
  • Save jackylamhk/908607b5213440f26aa2bf3087f9f80f to your computer and use it in GitHub Desktop.
Save jackylamhk/908607b5213440f26aa2bf3087f9f80f to your computer and use it in GitHub Desktop.
import logging
import requests
import json
import msal
from urllib.parse import urlencode
from email.mime.multipart import MIMEMultipart
from helpers import cypto
from helpers.config import AppConfig
# Module-level logger for this client; _validate_response writes debug output here.
logger = logging.getLogger(__name__)
class MSGraphClient:
    """
    Minimal Microsoft Graph REST client built on ``msal`` and ``requests``.

    Authentication uses the client-credentials (app-only) flow by default, or an
    interactive (delegated) sign-in when ``public_auth`` is true. Every request
    asks MSAL for a token first, so MSAL's token cache is consulted on each call.
    """

    def __init__(self,
                 tenant_id: str,
                 client_id: str,
                 client_secret: str | None = None,
                 api_url: str = 'https://graph.microsoft.com/v1.0/',
                 public_auth: bool = False):
        """
        MS Graph Custom Client, uses the msal and requests libraries.

        :param str tenant_id: Directory (tenant) ID used to build the login authority.
        :param str client_id: Application (client) ID of the app registration.
        :param str client_secret: Client secret; required unless ``public_auth`` is true.
        :param str api_url: Graph base URL; paths are appended directly, so keep the
            trailing slash.
        :param bool public_auth: Uses interactive sign in instead of client credentials.
        :raises ValueError: if no client secret is given for the app-only flow, or the
            initial token cannot be acquired.
        """
        self.tenant_id = tenant_id
        self.client_id = client_id
        self._client_secret = client_secret
        self.api_url = api_url
        self.public_auth = public_auth
        self._init_msal()
        self._start_session()

    def _init_msal(self):
        """Create the MSAL application object and the scopes it will request."""
        authority = f"https://login.microsoftonline.com/{self.tenant_id}"
        if self.public_auth:
            # Interactive authentication (user-based / delegated) flow.
            # TODO: these scopes are broad - refine to what callers actually need.
            self.scopes = ["Directory.ReadWrite.All", "Mail.Send", "RoleManagement.ReadWrite.Directory",
                           "User.Export.All", "User.Invite.All", "User.ManageIdentities.All",
                           "User.ReadWrite.All"]
            # Use the tenant-specific authority, matching the confidential branch.
            self.auth_client = msal.PublicClientApplication(self.client_id, authority=authority)
        else:
            if not self._client_secret:
                raise ValueError("GraphCustomClient: client_secret is required unless public_auth is enabled")
            # App-only flow: '.default' requests the application permissions granted
            # to the app registration.
            self.scopes = ['https://graph.microsoft.com/.default']
            # The secret must be passed as a plain string.
            self.auth_client = msal.ConfidentialClientApplication(
                self.client_id, authority=authority, client_credential=self._client_secret)

    def _get_token(self) -> dict:
        """
        Return an ``Authorization`` header dict holding a valid bearer token.

        :raises ValueError: if MSAL did not return an access token.
        """
        token_result = None
        if self.public_auth:
            # acquire_token_silent needs a real cached account; with account=None it is
            # a no-op in current MSAL, which would force a sign-in on every request.
            accounts = self.auth_client.get_accounts()
            if accounts:
                token_result = self.auth_client.acquire_token_silent(self.scopes, account=accounts[0])
            if not token_result:
                token_result = self.auth_client.acquire_token_interactive(self.scopes)
        else:
            # acquire_token_for_client checks MSAL's token cache itself (MSAL >= 1.23).
            token_result = self.auth_client.acquire_token_for_client(scopes=self.scopes)
        access_token = (token_result or {}).get('access_token')
        if not access_token:
            raise ValueError(f"GraphCustomClient Authentication Error: {token_result}")
        return {'Authorization': f"Bearer {access_token}"}

    def _start_session(self):
        """Create the shared HTTP session, seeded with an initial token."""
        self.session = requests.Session()
        # update() keeps requests' default headers (User-Agent, Accept-Encoding, ...)
        # instead of replacing the whole header mapping.
        self.session.headers.update(self._get_token())

    def _construct_uri(self, path: str, params: dict | None = None) -> str:
        """Append *path* to the API base URL and *params* as a query string."""
        uri = f"{self.api_url}{path}"
        # An empty dict would otherwise leave a dangling '?'.
        if not params:
            return uri
        return f"{uri}?{urlencode(params)}"

    def _validate_response(self, response: requests.Response):
        """
        Raise on HTTP errors and return the decoded JSON body.

        :returns: the JSON payload with ``@odata.context`` removed, or ``None`` when
            the response has no body (e.g. ``204 No Content``) or the body is not JSON.
        :raises ConnectionError: for any 4xx/5xx status.
        """
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as err:
            raise ConnectionError(
                f"GraphCustomClient Error: {response.status_code} {response.text}") from err
        if not response.content:
            logger.debug("GraphCustomClient: No response content")
            return None
        # TODO: consider an OData library instead of hand-stripping annotations.
        try:
            payload = response.json()
        except ValueError:
            # Covers both json.JSONDecodeError and requests' JSONDecodeError subclass.
            logger.debug("GraphCustomClient: Response content is not JSON")
            return None
        if isinstance(payload, dict):
            payload.pop('@odata.context', None)
        # Avoid serialising large payloads when debug logging is off.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug("GraphCustomClient: Response: %s", json.dumps(payload))
        return payload

    def _post(self, path: str, payload: dict | str | MIMEMultipart):
        """
        POST *payload* to *path*.

        Dicts are sent as JSON. Anything else is sent as the raw request body with
        ``Content-Type: text/plain`` (the format Graph's MIME ``sendMail`` uses).
        """
        headers = self._get_token()
        uri = self._construct_uri(path)
        if isinstance(payload, dict):
            response = self.session.post(uri, json=payload, headers=headers)
        else:
            headers['Content-Type'] = 'text/plain'
            response = self.session.post(uri, data=payload, headers=headers)
        return self._validate_response(response)

    def _put(self, path: str, payload: dict):
        """PUT *payload* as JSON to *path*."""
        headers = self._get_token()
        response = self.session.put(self._construct_uri(path), json=payload, headers=headers)
        return self._validate_response(response)

    def _query_headers(self, params: dict | None):
        """
        Return ``(headers, params)`` for a GET request.

        Requests with ``$filter`` are sent as Graph advanced queries
        (``ConsistencyLevel: eventual`` plus ``$count=true``). *params* is copied
        first so the caller's dict is never mutated.
        """
        headers = self._get_token()
        if params and '$filter' in params:
            headers['ConsistencyLevel'] = 'eventual'
            params = {**params, '$count': 'true'}
        return headers, params

    def _get(self, path: str, params: dict | None = None):
        """GET *path* and return a single decoded response (one page only)."""
        headers, params = self._query_headers(params)
        response = self.session.get(self._construct_uri(path, params), headers=headers)
        return self._validate_response(response)

    def _get_gen(self, path: str, params: dict | None = None):
        """
        Yield the ``value`` list of each result page, following ``@odata.nextLink``.

        :raises ConnectionError: if any page returns an HTTP error.
        """
        headers, params = self._query_headers(params)
        url = self._construct_uri(path, params)
        while url:
            response = self.session.get(url, headers=headers)
            payload = self._validate_response(response) or {}
            yield payload.get('value', [])
            # nextLink is an absolute URL that already carries the query string.
            url = payload.get('@odata.nextLink')
            if url:
                # Re-fetch the token between pages; a long pagination can outlive it.
                headers.update(self._get_token())

    def _delete(self, path: str, params: dict | None = None):
        """DELETE *path*."""
        headers = self._get_token()
        response = self.session.delete(self._construct_uri(path, params), headers=headers)
        return self._validate_response(response)

    def _patch(self, path: str, payload: dict, params: dict | None = None):
        """PATCH *path* with *payload* as JSON."""
        headers = self._get_token()
        response = self.session.patch(self._construct_uri(path, params), json=payload, headers=headers)
        return self._validate_response(response)

    def get_user_gen(self, params: dict | None = None):
        """Yield every user matching *params*, across all result pages."""
        for users in self._get_gen('users', params):
            yield from users

    def update_user(self, user_id: str, payload: dict):
        """Patch the user's properties with *payload*."""
        return self._patch(f"users/{user_id}", payload)

    def update_user_manager(self, user_id: str, manager_id: str):
        """Set *manager_id* as the manager of *user_id* (same as ``assign_user_manager``)."""
        return self.assign_user_manager(user_id, manager_id)

    def disable_user(self, user_id: str):
        """Disable sign-in for the user and mark their employeeType as 'Terminated'."""
        return self.update_user(user_id, {'accountEnabled': False, 'employeeType': 'Terminated'})

    def get_user_licenses(self, user_id: str):
        """Return the user's ``assignedLicenses`` list."""
        response = self._get(f"users/{user_id}", {'$select': 'assignedLicenses'})
        return response['assignedLicenses']

    def remove_user_licenses(self, user_id: str, licenses: list):
        """
        Remove User Licenses
        :param list licenses: skuId of the license(s)
        """
        payload = {"addLicenses": [], "removeLicenses": list(licenses)}
        return self._post(f"users/{user_id}/assignLicense", payload=payload)

    def get_user_groups(self, user_id: str):
        """Return the groups the user is a member of (first result page only)."""
        return self._get(f"users/{user_id}/memberOf/microsoft.graph.group")

    def remove_group_member(self, group_id: str, user_id: str):
        """Remove the user from the group."""
        return self._delete(f"groups/{group_id}/members/{user_id}/$ref")

    def get_user_directory_role(self, user_id: str):
        """Return the ids of the user's directory roles (first result page only)."""
        return self._get(f"users/{user_id}/memberOf/microsoft.graph.directoryRole", {'$select': 'id'})

    def remove_user_directory_role(self, user_id: str, role_id: str):
        """Remove the user from the directory role."""
        return self._delete(f"directoryRoles/{role_id}/members/{user_id}/$ref")

    def remove_user_manager(self, user_id: str):
        """Clear the user's manager."""
        return self._delete(f"users/{user_id}/manager/$ref")

    def assign_user_manager(self, user_id: str, manager_id: str):
        """Set *manager_id* as the manager of *user_id*."""
        # @odata.id must be an absolute URL; build it from api_url so it points at
        # the same Graph endpoint this client uses.
        payload = {"@odata.id": self._construct_uri(f"users/{manager_id}")}
        return self._put(f"users/{user_id}/manager/$ref", payload)

    def assign_user_license(self, user_id: str, license: str):
        """
        Assigns a license to a user.
        Microsoft 365 Business Premium SKU ID: cbdc14ab-d96c-4c30-b9f4-6ada7cdc1d46
        :param str license: skuId of the license to add (no plans disabled).
        """
        payload = {
            "addLicenses": [{
                "disabledPlans": [],
                "skuId": license,
            }],
            "removeLicenses": [],
        }
        return self._post(f"users/{user_id}/assignLicense", payload=payload)

    def get_user(self, user_id: str, params: dict | None = None):
        """
        Returns a user's information.
        Use the following filter to get their email from their Employee ID:
        { "$select": "userPrincipalName,employeeId", "$filter": f"employeeId eq '{employee_id}'" }
        """
        return self._get(f"users/{user_id}", params)

    def send_email(self, message: dict | MIMEMultipart, sender_email: str | None = None):
        """
        Send an Email via MS Graph.
        :param message: MIME message, or an MS Graph ``sendMail`` JSON payload.
        :param str sender_email: Sender's email; defaults to ``AppConfig.default_sender_email``.
        """
        if not sender_email:
            sender_email = AppConfig.default_sender_email
        if isinstance(message, MIMEMultipart):
            # NOTE(review): presumably returns the base64 text of the MIME message,
            # which _post sends as text/plain - confirm in helpers.cypto.
            message = cypto.base64_encode(message)
        return self._post(f"users/{sender_email}/sendMail", message)

    def create_user(self, payload: dict):
        """
        Create a user on Azure AD.
        Payload required: {"accountEnabled": True, "displayName": display_name, "mailNickname": email_alias,
        "userPrincipalName": email, "passwordProfile" : {"forceChangePasswordNextSignIn": True, "password":
        password}}
        """
        return self._post('users', payload)

    def delete_user(self, user_id: str):
        """
        Delete a user on Azure AD.
        """
        return self._delete(f"users/{user_id}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment