Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix multiple issues in current PR by crating V2 APIs #267

Merged
merged 6 commits into from
Sep 13, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
270 changes: 270 additions & 0 deletions wrapper/cpp/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,20 @@

package main

/*
danmamsft marked this conversation as resolved.
Show resolved Hide resolved
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

typedef unsigned long long uint64_t;

*/
import "C"

import (
"context"
"time"
"unsafe"

"github.com/microsoft/moc-sdk-for-go/services/security/authentication"
"github.com/microsoft/moc-sdk-for-go/services/security/keyvault"
Expand All @@ -25,6 +34,28 @@ import (
"github.com/microsoft/moc/pkg/config"
)

// possible Win32 return values
const (
Win32Succeed int = 0 // ERROR_SUCCESS
Win32ErrorInsufficientBuffer int = 122 // ERROR_INSUFFICIENT_BUFFER
Win32ErrorBadArg int = 160 // ERROR_BAD_ARGUMENTS
Win32ErrorFunctionFail int = 1627 // ERROR_FUNCTION_FAILED
)

// helper function to best effort copy over the error message
func copyErrorMessage(errMessage *C.char, errMessageBuffer *C.char, errMessageSize C.ulonglong) {
if (errMessage != nil) {
if (errMessageBuffer != nil && errMessageSize > 0) {
msglength := C.strlen(errMessage)
if (errMessageSize < msglength) {
msglength = errMessageSize
}
C.strncpy(errMessageBuffer, errMessage, msglength)
}
C.free(unsafe.Pointer(errMessage))
}
}

// This function exists to maintain backwards compatability. Please use SecurityLoginCV.
//
//export SecurityLogin
Expand Down Expand Up @@ -63,6 +94,46 @@ func SecurityLoginCV(serverName *C.char, groupName *C.char, loginFilePath *C.cha
return nil
}

//export SecurityLoginV2
danmamsft marked this conversation as resolved.
Show resolved Hide resolved
func SecurityLoginV2(serverName *C.char, groupName *C.char, loginFilePath *C.char, cv *C.char, timeoutInSeconds C.int, errMessageBuffer *C.char, errMessageSize C.ulonglong) C.int {
if (serverName == nil || groupName == nil || loginFilePath == nil || cv == nil) {
telemetry.EmitWrapperTelemetry("SecurityLoginV2", C.GoString(cv), "", "InvalidArgument", C.GoString(serverName))
copyErrorMessage(C.CString("invalid argument"), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorBadArg)
}

loginconfig := auth.LoginConfig{}
err := config.LoadYAMLFile(C.GoString(loginFilePath), &loginconfig)
if err != nil {
telemetry.EmitWrapperTelemetry("SecurityLoginV2", C.GoString(cv), err.Error(), "config.LoadYAMLFile", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

authenticationClient, err := authentication.NewAuthenticationClientAuthMode(C.GoString(serverName), loginconfig)
if err != nil {
telemetry.EmitWrapperTelemetry("SecurityLoginV2", C.GoString(cv), err.Error(), "authentication.NewAuthenticationClientAuthMode", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutInSeconds)*time.Second)
defer cancel()

// Login with config stores the access file in the WSSD_CONFIG environment variable
// set true to auto renew
_, err = authenticationClient.LoginWithConfig(ctx, C.GoString(groupName), loginconfig, true)
if err != nil {
telemetry.EmitWrapperTelemetry("SecurityLoginV2", C.GoString(cv), err.Error(), "authenticationClient.LoginWithConfig", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

//Provide moc version information after login
telemetry.EmitWrapperTelemetry("SecurityLoginV2", C.GoString(cv), "", "", C.GoString(serverName))
return C.int(Win32Succeed)
}

// This function exists to maintain backwards compatability. Please use KeyvaultKeyEncryptDataCV.
danmamsft marked this conversation as resolved.
Show resolved Hide resolved
//
//export KeyvaultKeyEncryptData
Expand Down Expand Up @@ -101,6 +172,60 @@ func KeyvaultKeyEncryptDataCV(serverName *C.char, groupName *C.char, keyvaultNam
return C.CString(*response.Result)
}

//export KeyvaultKeyEncryptDataV2
func KeyvaultKeyEncryptDataV2(serverName *C.char, groupName *C.char, keyvaultName *C.char, keyName *C.char, input *C.char, algorithm *C.char, cv *C.char, timeoutInSeconds C.int, outputBuffer *C.char, outputBufferSize *C.ulonglong, errMessageBuffer *C.char, errMessageSize C.ulonglong) C.int {
if (serverName == nil || groupName == nil || keyvaultName == nil || keyName == nil || input == nil || algorithm == nil || cv == nil || outputBufferSize == nil) {
copyErrorMessage(C.CString("Invalid Argument"), errMessageBuffer, errMessageSize)
danmamsft marked this conversation as resolved.
Show resolved Hide resolved
telemetry.EmitWrapperTelemetry("KeyvaultKeyEncryptDataV2", C.GoString(cv), "", "InvalidArgument", C.GoString(serverName))
return C.int(Win32ErrorBadArg)
}

keyClient, err := getKeyvaultKeyClient(C.GoString(serverName), C.GoString(cv))
// if errror occurs, return an empty string so that caller can tell between error and encrypted blob
danmamsft marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
telemetry.EmitWrapperTelemetry("KeyvaultKeyEncryptDataV2", C.GoString(cv), err.Error(), "getKeyvaultKeyClient", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutInSeconds)*time.Second)
defer cancel()

// input is base64 encoded
var value string
value = C.GoString(input)

alg :=keyvault.JSONWebKeyEncryptionAlgorithm(C.GoString(algorithm))
parameters := &keyvault.KeyOperationsParameters{
Value: &value,
Algorithm: alg,
}

response, err := keyClient.Encrypt(ctx, C.GoString(groupName), C.GoString(keyvaultName), C.GoString(keyName), parameters)
if err != nil {
telemetry.EmitWrapperTelemetry("KeyvaultKeyEncryptDataV2", C.GoString(cv), err.Error(), "keyClient.Encrypt", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

// retrun base64 encoded string
danmamsft marked this conversation as resolved.
Show resolved Hide resolved
encryptedCString := C.CString(*response.Result)
var encryptedCStringLength C.ulonglong = C.strlen(encryptedCString)

if (outputBuffer == nil || *outputBufferSize < encryptedCStringLength) {
telemetry.EmitWrapperTelemetry("KeyvaultKeyEncryptDataV2", C.GoString(cv), "", "InsufficientBuffer", C.GoString(serverName))
*outputBufferSize = encryptedCStringLength;
C.free(unsafe.Pointer(encryptedCString))
return C.int(Win32ErrorInsufficientBuffer)
}

// copy over the result
C.strncpy(outputBuffer, encryptedCString, encryptedCStringLength)
*outputBufferSize = encryptedCStringLength
C.free(unsafe.Pointer(encryptedCString))
return C.int(Win32Succeed)
}

// This function exists to maintain backwards compatability. Please use KeyvaultKeyDecryptDataCV.
//
//export KeyvaultKeyDecryptData
Expand Down Expand Up @@ -137,6 +262,56 @@ func KeyvaultKeyDecryptDataCV(serverName *C.char, groupName *C.char, keyvaultNam
return C.CString(*response.Result)
}

//export KeyvaultKeyDecryptDataV2
func KeyvaultKeyDecryptDataV2(serverName *C.char, groupName *C.char, keyvaultName *C.char, keyName *C.char, input *C.char, algorithm *C.char, cv *C.char, timeoutInSeconds C.int, outputBuffer *C.char, outputBufferSize *C.ulonglong, errMessageBuffer *C.char, errMessageSize C.ulonglong) C.int {
if (serverName == nil || groupName == nil || keyvaultName == nil || keyName == nil || input == nil || algorithm == nil || cv == nil || outputBufferSize == nil) {
telemetry.EmitWrapperTelemetry("KeyvaultKeyDecryptDataV2", C.GoString(cv), "", "InvalidArgument", C.GoString(serverName))
copyErrorMessage(C.CString("Invalid Argument"), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorBadArg)
}

keyClient, err := getKeyvaultKeyClient(C.GoString(serverName), C.GoString(cv))
// if errror occurs, return an empty string so that caller can tell between error and decrypted blob
if err != nil {
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutInSeconds)*time.Second)
defer cancel()

var value string
value = C.GoString(input)

alg :=keyvault.JSONWebKeyEncryptionAlgorithm(C.GoString(algorithm))
parameters := &keyvault.KeyOperationsParameters{
Value: &value,
Algorithm: alg,
}

response, err := keyClient.Decrypt(ctx, C.GoString(groupName), C.GoString(keyvaultName), C.GoString(keyName), parameters)
if err != nil {
telemetry.EmitWrapperTelemetry("KeyvaultKeyDecryptDataV2", C.GoString(cv), err.Error(), "keyClient.Decrypt", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

decryptedCString := C.CString(*response.Result)
var decryptedCStringLength C.ulonglong = C.strlen(decryptedCString)

if (outputBuffer == nil || *outputBufferSize < decryptedCStringLength) {
telemetry.EmitWrapperTelemetry("KeyvaultKeyDecryptDataV2", C.GoString(cv), "", "InsufficientBuffer", C.GoString(serverName))
*outputBufferSize = decryptedCStringLength;
C.free(unsafe.Pointer(decryptedCString))
return C.int(Win32ErrorInsufficientBuffer)
}
// copy over the result
C.strncpy(outputBuffer, decryptedCString, decryptedCStringLength)
*outputBufferSize = decryptedCStringLength
C.free(unsafe.Pointer(decryptedCString))
return C.int(Win32Succeed)
}

// This function exists to maintain backwards compatability. Please use KeyvaultKeyExistCV.
//
//export KeyvaultKeyExist
Expand Down Expand Up @@ -214,6 +389,52 @@ func KeyvaultKeyCreateOrUpdateCV(serverName *C.char, groupName *C.char, keyvault
return nil
}

//export KeyvaultKeyCreateOrUpdateV2
func KeyvaultKeyCreateOrUpdateV2(serverName *C.char, groupName *C.char, keyvaultName *C.char, keyName *C.char, keyTypeName *C.char, keySize C.int, cv *C.char, timeoutInSeconds C.int, errMessageBuffer *C.char, errMessageSize C.ulonglong) C.int {
if (serverName == nil || groupName == nil || keyvaultName == nil || keyName == nil || keyTypeName == nil || cv == nil) {
telemetry.EmitWrapperTelemetry("KeyvaultKeyCreateOrUpdateV2", C.GoString(cv), "", "InvalidArgument", C.GoString(serverName))
copyErrorMessage(C.CString("Invalid Argument"), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorBadArg)
}

keyClient, err := getKeyvaultKeyClient(C.GoString(serverName), C.GoString(cv))
if err != nil {
telemetry.EmitWrapperTelemetry("KeyvaultKeyCreateOrUpdateV2", C.GoString(cv), err.Error(), "getKeyvaultKeyClient", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

var kvConfig *keyvault.Key
kvConfig = &keyvault.Key{}

var keyNameString string
keyNameString = C.GoString(keyName)
kvConfig.Name = &keyNameString
kvConfig.KeyProperties = &keyvault.KeyProperties{}

kvConfig.KeyType = keyvault.JSONWebKeyType(C.GoString(keyTypeName))
var tKeySize int32
tKeySize = int32(keySize)
kvConfig.KeySize = &tKeySize

var keyRotation int64
keyRotation = -1
kvConfig.KeyRotationFrequencyInSeconds = &keyRotation // -1 means disable key rotation

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutInSeconds)*time.Second)
defer cancel()

_, err = keyClient.CreateOrUpdate(ctx, C.GoString(groupName), C.GoString(keyvaultName), C.GoString(keyName), kvConfig)
if err != nil {
telemetry.EmitWrapperTelemetry("KeyvaultKeyCreateOrUpdateV2", C.GoString(cv), err.Error(), "keyClient.CreateOrUpdate", C.GoString(serverName))
//This return cannot be empty!
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

return C.int(Win32Succeed)
}

// This function exists to maintain backwards compatability. Please use KeyvaultKeySignDataCV.
//
//export KeyvaultKeySignData
Expand Down Expand Up @@ -332,6 +553,55 @@ func KeyvaultGetPublicKeyCV(serverName *C.char, groupName *C.char, keyvaultName
return C.CString(*pemPkcs1KeyPub)
}

//export KeyvaultGetPublicKeyV2
func KeyvaultGetPublicKeyV2(serverName *C.char, groupName *C.char, keyvaultName *C.char, keyName *C.char, cv *C.char, timeoutInSeconds C.int, outputBuffer *C.char, outputBufferSize *C.ulonglong, errMessageBuffer *C.char, errMessageSize C.ulonglong) C.int {
if (serverName == nil || groupName == nil || keyvaultName == nil || keyName == nil || outputBufferSize == nil || cv == nil) {
telemetry.EmitWrapperTelemetry("KeyvaultGetPublicKeyV2", C.GoString(cv), "", "InvalidArgument", C.GoString(serverName))
copyErrorMessage(C.CString("Invalid Argument"), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorBadArg)
}

keyClient, err := getKeyvaultKeyClient(C.GoString(serverName), C.GoString(cv))
if err != nil {
telemetry.EmitWrapperTelemetry("KeyvaultGetPublicKeyV2", C.GoString(cv), err.Error(), "getKeyvaultKeyClient", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeoutInSeconds)*time.Second)
defer cancel()

keys, err := keyClient.Get(ctx, C.GoString(groupName), C.GoString(keyvaultName), C.GoString(keyName))
if err != nil {
telemetry.EmitWrapperTelemetry("KeyvaultGetPublicKeyV2", C.GoString(cv), err.Error(), "keyClient.Get", C.GoString(serverName))
copyErrorMessage(C.CString(telemetry.FilterSensitiveData(err.Error())), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

if keys == nil || len(*keys) <= 0 {
telemetry.EmitWrapperTelemetry("KeyvaultGetPublicKeyV2", C.GoString(cv), "", "EmptyKey", C.GoString(serverName))
copyErrorMessage(C.CString("Returned key is empty"), errMessageBuffer, errMessageSize)
return C.int(Win32ErrorFunctionFail)
}

pemPkcs1KeyPub := (*keys)[0].Value
pemPkcs1KeyPubCString := C.CString(*pemPkcs1KeyPub)
var pemPkcs1KeyPubCStringLength C.ulonglong = C.strlen(pemPkcs1KeyPubCString)

if (outputBuffer == nil || *outputBufferSize < pemPkcs1KeyPubCStringLength) {
telemetry.EmitWrapperTelemetry("KeyvaultGetPublicKeyV2", C.GoString(cv), "", "InsufficientBuffer", C.GoString(serverName))
*outputBufferSize = pemPkcs1KeyPubCStringLength;
C.free(unsafe.Pointer(pemPkcs1KeyPubCString))
return C.int(Win32ErrorInsufficientBuffer)
}

// copy over the result and size
C.strncpy(outputBuffer, pemPkcs1KeyPubCString, pemPkcs1KeyPubCStringLength)
*outputBufferSize = pemPkcs1KeyPubCStringLength
C.free(unsafe.Pointer(pemPkcs1KeyPubCString))
return C.int(Win32Succeed)
}

func getKeyvaultKeyClient(serverName string, cv string) (*key.KeyClient, error) {
authorizer, err := auth.NewAuthorizerFromEnvironment(serverName)
if err != nil {
Expand Down
Loading