Created
November 6, 2023 15:07
-
-
Save jackylamhk/908607b5213440f26aa2bf3087f9f80f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json
import logging
from email.mime.multipart import MIMEMultipart
from urllib.parse import urlencode

import msal
import requests

from helpers import cypto
from helpers.config import AppConfig
# Module-level logger; MSGraphClient writes response debug output to it.
logger = logging.getLogger(__name__)
class MSGraphClient: | |
def __init__(self, | |
tenant_id: str, | |
client_id: str, | |
client_secret: str = None, | |
api_url: str = 'https://graph.microsoft.com/v1.0/', | |
public_auth = False): | |
""" | |
MS Graph Custom Client, uses the msal and requests libraries. | |
:params bool public_auth: Uses interactive sign in instead of client credentials. | |
""" | |
# Set MS Graph Base URL | |
self.tenant_id = tenant_id | |
self.client_id = client_id | |
self._client_secret = client_secret | |
self.api_url = api_url | |
self.public_auth = public_auth | |
self._init_msal() | |
self._start_session() | |
def _init_msal(self): | |
if self.public_auth == True: | |
# !FIXME Scopes - to be refined | |
# Interactive authentication (user-based) flow | |
self.scopes = ["Directory.ReadWrite.All","Mail.Send","RoleManagement.ReadWrite.Directory","User.Export.All","User.Invite.All","User.ManageIdentities.All","User.ReadWrite.All"] | |
self.auth_client = msal.PublicClientApplication(self.client_id) | |
else: | |
authority = f"https://login.microsoftonline.com/{self.tenant_id}" | |
self.scopes = ['https://graph.microsoft.com/.default'] | |
self.auth_client = msal.ConfidentialClientApplication(self.client_id, authority=authority, client_credential={self._client_secret}) | |
def _get_token(self): | |
token_result = self.auth_client.acquire_token_silent(self.scopes, account=None) | |
# If the token is not available in cache, acquire a new one from Azure AD and save it to a variable | |
if not token_result: | |
if self.public_auth == True: | |
token_result = self.auth_client.acquire_token_interactive(self.scopes, prompt=None, login_hint=None, domain_hint=None, claims_challenge=None, timeout=None, port=None, extra_scopes_to_consent=None, max_age=None, parent_window_handle=None, on_before_launching_ui=None) | |
else: | |
token_result = self.auth_client.acquire_token_for_client(self.scopes) | |
try: | |
access_token = {'Authorization': f"Bearer {token_result['access_token']}"} | |
except KeyError: | |
raise ValueError(f"GraphCustomClient Authentication Error: {token_result}") | |
return access_token | |
def _start_session(self): | |
auth_header = self._get_token() | |
self.session = requests.Session() | |
self.session.headers = auth_header | |
def _construct_uri(self, path: str, params: dict = None): | |
uri = f"{self.api_url}{path}" | |
if params == None: | |
return uri | |
else: | |
return f"{uri}?{urlencode(params)}" | |
def _validate_response(_, response: requests.Response): | |
try: | |
response.raise_for_status() | |
except requests.exceptions.HTTPError: | |
raise ConnectionError(f"GraphCustomClient Error: {response.status_code} {response._content}") | |
# !FIXME Use an OData library instead? | |
try: | |
payload = response.json() | |
payload.pop('@odata.context', None) | |
logger.debug("GraphCustomClient: Response: "+json.dumps(payload)) | |
return payload | |
except requests.JSONDecodeError: | |
logger.debug("GraphCustomClient: No reponse content") | |
return None | |
except json.JSONDecodeError: | |
return None | |
except KeyError: | |
return response.json() | |
def _post(self, path: str, payload: dict | str | MIMEMultipart): | |
auth_header = self._get_token() | |
if type(payload) is dict: | |
response = self.session.post(self._construct_uri(path), json=payload, headers=auth_header) | |
else: | |
auth_header['Content-Type'] = 'text/plain' | |
response = self.session.post(self._construct_uri(path), data=payload, headers=auth_header) | |
return self._validate_response(response) | |
def _put(self, path: str, payload: dict): | |
auth_header = self._get_token() | |
response = self.session.put(self._construct_uri(path), json=payload, headers=auth_header) | |
return self._validate_response(response) | |
def _get(self, path: str, params=None): | |
auth_header = self._get_token() | |
if params and '$filter' in params: | |
auth_header['ConsistencyLevel'] = 'eventual' | |
params['$count'] = 'true' | |
response = self.session.get(self._construct_uri(path,params), headers=auth_header) | |
return self._validate_response(response) | |
def _get_gen(self, path: str, params=None): | |
auth_header = self._get_token() | |
if params and '$filter' in params: | |
auth_header['ConsistencyLevel'] = 'eventual' | |
params['$count'] = 'true' | |
# Get the first page | |
response = requests.get(self._construct_uri(path,params), headers=auth_header) | |
while True: | |
self._validate_response(response) | |
payload = response.json() | |
yield payload['value'] | |
if '@odata.nextLink' not in payload: | |
break | |
# Get the next page | |
response = requests.get(payload['@odata.nextLink'], headers=auth_header) | |
def _delete(self, path: str, params = None): | |
auth_header = self._get_token() | |
response = self.session.delete(self._construct_uri(path,params), headers=auth_header) | |
return self._validate_response(response) | |
def _patch(self, path: str, payload: dict, params = None): | |
auth_header = self._get_token() | |
response = self.session.patch(self._construct_uri(path,params), json=payload, headers=auth_header) | |
return self._validate_response(response) | |
def get_user_gen(self, params: dict = None): | |
users_pages = self._get_gen('users', params) | |
for users in users_pages: | |
for user in users: | |
yield user | |
def update_user(self, user_id: str, payload: dict): | |
return self._patch(f"users/{user_id}", payload) | |
def update_user_manager(self, user_id: str, manager_id: str): | |
self._put(f"users/{user_id}/manager/$ref", { "@odata.id": f"https://graph.microsoft.com/v1.0/users/{manager_id}" }) | |
def disable_user(self, user_id: str): | |
return self.update_user(user_id, {'accountEnabled': False, 'employeeType': 'Terminated'}) | |
def get_user_licenses(self, user_id: str): | |
response = self._get(f"users/{user_id}", {'$select': 'assignedLicenses'}) | |
return response['assignedLicenses'] | |
def remove_user_licenses(self, user_id: str, licenses: list): | |
""" | |
Remove User Licenses | |
:param list licenses: skuId of the license(s) | |
""" | |
payload = {"addLicenses": [], "removeLicenses": []} | |
payload["removeLicenses"] = licenses | |
return self._post(f"users/{user_id}/assignLicense", payload=payload) | |
def get_user_groups(self, user_id: str): | |
return self._get(f"users/{user_id}/memberOf/microsoft.graph.group") | |
def remove_group_member(self, group_id: str, user_id: str): | |
return self._delete(f"groups/{group_id}/members/{user_id}/$ref") | |
def get_user_directory_role(self, user_id: str): | |
return self._get(f"users/{user_id}/memberOf/microsoft.graph.directoryRole", {'$select': 'id'}) | |
def remove_user_directory_role(self, user_id: str, role_id: str): | |
return self._delete(f"directoryRoles/{role_id}/members/{user_id}/$ref") | |
def remove_user_manager(self, user_id: str): | |
return self._delete(f"users/{user_id}/manager/$ref") | |
def assign_user_manager(self, user_id: str, manager_id: str): | |
payload = { | |
"@odata.id": f"https://graph.microsoft.com/v1.0/users/{manager_id}" | |
} | |
return self._put(f"users/{user_id}/manager/$ref", payload) | |
def assign_user_license(self, user_id: str, license: str): | |
""" | |
Assigns a license to a user. | |
Microsoft 365 Business Premium SKU ID: cbdc14ab-d96c-4c30-b9f4-6ada7cdc1d46 | |
""" | |
payload = { | |
"addLicenses": [{ | |
"disabledPlans": [], | |
"skuId": license | |
}], | |
"removeLicenses": [] | |
} | |
return self._post(f"users/{user_id}/assignLicense", payload=payload) | |
def get_user(self, user_id: str, params: dict = None): | |
""" | |
Returns a user's information. | |
Use the following filter to get their email from their Employee ID: | |
{ "$select": "userPrincipalName,employeeId", "$filter": f"employeeId eq '{employee_id}'" } | |
""" | |
return self._get(f"users/{user_id}", params) | |
def send_email(self, message: dict | MIMEMultipart, sender_email: str = None): | |
""" | |
Send an Email via MS Graph. | |
:param str message: MIME or MS Graph API Payload | |
:param str sender_email: Sender's email | |
""" | |
if not sender_email: | |
sender_email = AppConfig.default_sender_email | |
if type(message) is MIMEMultipart: | |
message = cypto.base64_encode(message) | |
return self._post(f"users/{sender_email}/sendMail", message) | |
def create_user(self, payload: dict): | |
""" | |
Create a user on Azure AD. | |
Either pass a payload or pass the required information. | |
Paylod required: {"accountEnabled": True, "displayName": display_name, "mailNickname": email_alias, | |
"userPrincipalName": email, "passwordProfile" : {"forceChangePasswordNextSignIn": True, "password": | |
password}} | |
""" | |
return self._post('users', payload) | |
def delete_user(self, user_id: str): | |
""" | |
Delete a user on Azure AD. | |
""" | |
return self._delete(f"users/{user_id}") |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment