Skip to content

Commit

Permalink
Merge pull request #3 from almenscorner/msal
Browse files Browse the repository at this point in the history
v1.1.0
  • Loading branch information
almenscorner committed Feb 13, 2023
2 parents eb8cdca + 0ad8a08 commit 7cb6500
Show file tree
Hide file tree
Showing 10 changed files with 215 additions and 55 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ else:
```

In addition to importing this package to your automation account when running from Azure Automation, you must also import the following packages,
- [adal](https://pypi.org/project/adal/)
- [msal](https://pypi.org/project/msal)
- [azure-core](https://pypi.org/project/azure-core/)
- [azure-storage-blob](https://pypi.org/project/azure-storage-blob/)
- [msrest](https://pypi.org/project/msrest/)
Expand Down Expand Up @@ -146,6 +146,12 @@ To use the tool, you must set a couple of environment variables that will be use
- CONTAINER_NAME - Name of your Azure Storage Container
- AZURE_STORAGE_CONNECTION_STRING - Connection string to your Azure Storage account

If using interactive authentication, the CLIENT_SECRET is not required.

If using certificate authentication, additional environment variables are required,
- THUMBPRINT - Thumbprint of the certificate on your app registration
- KEY_FILE - Path to the private key of the certificate on your app registation

## Azure AD app registration permissions
- DeviceManagementManagedDevices.Read.All
- Directory.Read.All
Expand Down
28 changes: 0 additions & 28 deletions munki_manifest_generator/env_vars.py

This file was deleted.

17 changes: 16 additions & 1 deletion munki_manifest_generator/get_device_catalogs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
This module is used to get the catalogs that should be added or removed from a device.
"""

from collections import defaultdict


def get_device_catalogs(
groups, device_manifest, add_catalogs=False, remove_catalogs=False
):
Expand All @@ -29,14 +32,26 @@ def get_device_catalogs(
if remove_catalogs:
catalogs_to_remove = []

grouped_by_catalog = defaultdict(list)
for item in groups:
grouped_by_catalog[item['catalog']].append(item)

multiple_result = dict(grouped_by_catalog)

for catalog in device_manifest.catalogs:
if catalog not in GROUP_CATALOGS and catalog != "Production":
device_manifest.catalogs.remove(catalog)
catalogs_to_remove.append(catalog)

for group in groups:
if group["catalog"] is not None:
if (
if (group["catalog"] in multiple_result):
# get group names
group_names = [item["name"] for item in multiple_result[group["catalog"]]]
if group["name"] not in group_names:
device_manifest.catalogs.remove(group["catalog"])
catalogs_to_remove.append(group["catalog"])
elif (
group["name"] not in device_manifest.included_manifests
and group["catalog"] in device_manifest.catalogs
):
Expand Down
48 changes: 48 additions & 0 deletions munki_manifest_generator/graph/get_authentication_token.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#!/usr/bin/env python3

"""
This module is used to get the access token for the tenant.
"""

import os
import json

from munki_manifest_generator.graph.obtain_access_token import obtain_accesstoken_app, obtain_accesstoken_cert, obtain_accesstoken_interactive

def getAuth(app, certauth, interactiveauth):
"""
This function authenticates to MS Graph and returns the access token.
:param mode: The mode used when using this tool
:param localauth: Path to dict with keys to authenticate
:param tenant: Which tenant to authenticate to, PROD or DEV
:return: The access token
"""

if certauth:
KEY_FILE = os.environ.get("KEY_FILE")
THUMBPRINT = os.environ.get("THUMBPRINT")
TENANT_NAME = os.environ.get("TENANT_NAME")
CLIENT_ID = os.environ.get("CLIENT_ID")

if not all([KEY_FILE, THUMBPRINT, TENANT_NAME, CLIENT_ID]):
raise Exception("One or more os.environ variables not set")
return obtain_accesstoken_cert(TENANT_NAME, CLIENT_ID, THUMBPRINT, KEY_FILE)

if interactiveauth:
TENANT_NAME = os.environ.get("TENANT_NAME")
CLIENT_ID = os.environ.get("CLIENT_ID")

if not all([TENANT_NAME, CLIENT_ID]):
raise Exception("One or more os.environ variables not set")

return obtain_accesstoken_interactive(TENANT_NAME, CLIENT_ID)

if app:
TENANT_NAME = os.environ.get("TENANT_NAME")
CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
if not all([TENANT_NAME, CLIENT_ID, CLIENT_SECRET]):
raise Exception("One or more os.environ variables not set")

return obtain_accesstoken_app(TENANT_NAME, CLIENT_ID, CLIENT_SECRET)
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
This module is used to get the device group membership and update included manifests.
"""

from munki_manifest_generator.graph.obtain_access_token import obtain_access_token
from munki_manifest_generator.graph.make_api_request import make_api_request

ENDPOINT = "https://graph.microsoft.com/v1.0/devices"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
This module is used to get the user group membership and update included manifests.
"""

from munki_manifest_generator.graph.obtain_access_token import obtain_access_token
from munki_manifest_generator.graph.make_api_request import make_api_request

ENDPOINT = "https://graph.microsoft.com/v1.0/users"
Expand Down
2 changes: 1 addition & 1 deletion munki_manifest_generator/graph/make_api_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def make_api_request(endpoint, token, q_param=None):

headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {0}".format(token["accessToken"]),
"Authorization": "Bearer {0}".format(token["access_token"]),
}

# This section handles a bug with the Python requests module which
Expand Down
123 changes: 113 additions & 10 deletions munki_manifest_generator/graph/obtain_access_token.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,126 @@
#!/usr/bin/env python3

"""
This module is used to obtain an access token for use with Graph API.
This module contains the functions used to get the access token for MS Graph.
"""

from adal import AuthenticationContext

from msal import ConfidentialClientApplication, PublicClientApplication

def obtain_access_token(client_id, client_secret, tenant_name):
"""Return an access token for use with Graph API."""
AUTHORITY = "https://login.microsoftonline.com/"
SCOPE = ["https://graph.microsoft.com/.default"]

auth_context = AuthenticationContext(
"https://login.microsoftonline.com/" + tenant_name

def obtain_accesstoken_app(TENANT_NAME, CLIENT_ID, CLIENT_SECRET):
"""
This function is used to get an access token to MS Graph using client credentials.
:param TENANT_NAME: The name of the Azure tenant
:param CLIENT_ID: The ID of the registered Azure AD application
:param CLIENT_SECRET: Secret of the registered Azure AD application
:return: The access token
"""

# Create app instance
app = ConfidentialClientApplication(
client_id=CLIENT_ID,
client_credential=CLIENT_SECRET,
authority=AUTHORITY + TENANT_NAME,
)

token = None

try:
# Check if token is already cached
token = app.acquire_token_silent(SCOPE, account=None)

# If not, get a new token
if not token:
token = app.acquire_token_for_client(scopes=SCOPE)
if not token:
raise Exception("No token returned")

except Exception as e:
raise Exception("Error obtaining access token: " + str(e))

return token


def obtain_accesstoken_cert(TENANT_NAME, CLIENT_ID, THUMBPRINT, KEY_FILE):
"""
This function is used to get an access token to MS Graph using a certificate.
:param TENANT_NAME: The name of the Azure tenant
:param CLIENT_ID: The ID of the registered Azure AD application
:param THUMBPRINT Thumbprint of the certificate uploaded to Azure AD
:param KEY_FILE: Path to the private key of the certificate
:return: The access token
"""

# Create app instance
app = ConfidentialClientApplication(
client_id=CLIENT_ID,
client_credential={
"thumbprint": THUMBPRINT,
"private_key": open(KEY_FILE).read(),
},
authority=AUTHORITY + TENANT_NAME,
)

token = auth_context.acquire_token_with_client_credentials(
resource="https://graph.microsoft.com",
client_id=client_id,
client_secret=client_secret,
token = None

try:
# Check if token is already cached
token = app.acquire_token_silent(SCOPE, account=None)

# If not, get a new token
if not token:
token = app.acquire_token_for_client(scopes=SCOPE)
if not token:
raise Exception("No token returned")

except Exception as e:
raise Exception("Error obtaining access token: " + str(e))

return token


def obtain_accesstoken_interactive(TENANT_NAME, CLIENT_ID):
"""
This function is used to get an access token to MS Graph interactivly.
:param TENANT_NAME: The name of the Azure tenant
:param CLIENT_ID: The ID of the registered Azure AD application
:return: The access token
"""

# Create app instance
app = PublicClientApplication(
client_id=CLIENT_ID,
client_credential=None,
authority=AUTHORITY + TENANT_NAME,
)

token = None

# Set the required scopes
scopes = [
"DeviceManagementManagedDevices.Read.All",
"Directory.Read.All",
"GroupMember.Read.All",
"Group.Read.All"
]

try:
# Get the token interactively
token = app.acquire_token_interactive(
scopes=scopes, max_age=1200, prompt="select_account"
)

if not token:
raise Exception("No token returned")

except Exception as e:
raise Exception("Error obtaining access token: " + str(e))

return token
38 changes: 28 additions & 10 deletions munki_manifest_generator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from operator import itemgetter
from munki_manifest_generator.env_vars import check_env_vars
from munki_manifest_generator.manifest import Manifest
from munki_manifest_generator.graph.obtain_access_token import obtain_access_token
from munki_manifest_generator.graph.get_authentication_token import getAuth
from munki_manifest_generator.graph.make_api_request import make_api_request
from munki_manifest_generator.graph.get_device_group_membership import (
get_device_group_membership,
Expand All @@ -35,6 +35,8 @@ def main(**kwargs):
l = None
s = None
t = None
c = None
i = None

if not kwargs:
argparser = argparse.ArgumentParser()
Expand All @@ -60,10 +62,22 @@ def main(**kwargs):
)
argparser.add_argument(
"-t",
"--test",
"--test",
help="Enable testing, no changes will be made to manifests on Azure Storage.",
action="store_true",
)
argparser.add_argument(
"-c",
"--certauth",
help="When using certificate auth, the following ENV variables is required: TENANT_NAME, CLIENT_ID, THUMBPRINT, KEY_FILE",
action="store_true",
)
argparser.add_argument(
"-i",
"--interactiveauth",
help="When using interactive auth, the following ENV variables is required: TENANT_NAME, CLIENT_ID",
action="store_true",
)

args = argparser.parse_args()

Expand All @@ -78,23 +92,27 @@ def main(**kwargs):
l = kwargs.get("group_list")
sm = kwargs.get("safe_manifest")
t = kwargs.get("test")
c = kwargs.get("certauth")
i = kwargs.get("interactiveauth")

if t:
print(
"*****Testing mode enabled, no changes will be made to manifests on Azure Storage*****"
)


def run(json_file, group_list, serial_number, SAFE_MANIFEST, TEST):
def run(json_file, group_list, serial_number, SAFE_MANIFEST, TEST, CERTAUTH, INTERACTIVEAUTH):

check_env_vars()
CLIENT_ID = os.environ.get("CLIENT_ID")
CLIENT_SECRET = os.environ.get("CLIENT_SECRET")
TENANT_NAME = os.environ.get("TENANT_NAME")
if not all([os.environ.get("CONTAINER_NAME"), os.environ.get("AZURE_STORAGE_CONNECTION_STRING")]):
raise Exception("Missing required environment variables, stopping...")
CONTAINER_NAME = os.environ.get("CONTAINER_NAME")
CONNECTION_STRING = os.environ.get("AZURE_STORAGE_CONNECTION_STRING")
ENDPOINT = "https://graph.microsoft.com/v1.0/deviceManagement/managedDevices"
TOKEN = obtain_access_token(CLIENT_ID, CLIENT_SECRET, TENANT_NAME)
if not CERTAUTH and not INTERACTIVEAUTH:
APP = True
else:
APP = False
TOKEN = getAuth(APP, CERTAUTH, INTERACTIVEAUTH)
CURRENT_MANIFESTS = get_current_manifest_blobs(
CONNECTION_STRING, CONTAINER_NAME
)
Expand Down Expand Up @@ -264,9 +282,9 @@ def run(json_file, group_list, serial_number, SAFE_MANIFEST, TEST):
)

if not kwargs:
run(args.json, args.group_list, args.serial_number, args.safe_manifest, args.test)
run(args.json, args.group_list, args.serial_number, args.safe_manifest, args.test, args.certauth, args.interactiveauth)
else:
run(j, l, s, sm, t)
run(j, l, s, sm, t, c, i)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 7cb6500

Please sign in to comment.