// -*- Mode: Go; indent-tabs-mode: t -*-
// +build !nosecboot

/*
 * Copyright (C) 2021 Canonical Ltd
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License version 3 as
 * published by the Free Software Foundation.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 *
 */

package secboot

import (
	"io"

	sb "github.com/snapcore/secboot"
	sb_efi "github.com/snapcore/secboot/efi"
	sb_tpm2 "github.com/snapcore/secboot/tpm2"
)

// EFIImageFromBootFile is an exported alias of the unexported
// efiImageFromBootFile, presumably so tests can call it directly.
var EFIImageFromBootFile = efiImageFromBootFile

// LockTPMSealedKeys is an exported alias of the unexported
// lockTPMSealedKeys, presumably so tests can call it directly.
var LockTPMSealedKeys = lockTPMSealedKeys

// MockSbConnectToDefaultTPM installs f as the package-level
// sbConnectToDefaultTPM and returns a function that reinstates the value
// that was in place before the call.
func MockSbConnectToDefaultTPM(f func() (*sb_tpm2.Connection, error)) (restore func()) {
	saved := sbConnectToDefaultTPM
	sbConnectToDefaultTPM = f
	restore = func() { sbConnectToDefaultTPM = saved }
	return restore
}

// MockProvisionTPM installs f as the package-level provisionTPM and returns
// a function that reinstates the value that was in place before the call.
func MockProvisionTPM(f func(tpm *sb_tpm2.Connection, mode sb_tpm2.ProvisionMode, newLockoutAuth []byte) error) (restore func()) {
	saved := provisionTPM
	provisionTPM = f
	restore = func() { provisionTPM = saved }
	return restore
}

// MockSbEfiAddSecureBootPolicyProfile installs f as the package-level
// sbefiAddSecureBootPolicyProfile and returns a function that reinstates the
// value that was in place before the call.
func MockSbEfiAddSecureBootPolicyProfile(f func(profile *sb_tpm2.PCRProtectionProfile, params *sb_efi.SecureBootPolicyProfileParams) error) (restore func()) {
	saved := sbefiAddSecureBootPolicyProfile
	sbefiAddSecureBootPolicyProfile = f
	restore = func() { sbefiAddSecureBootPolicyProfile = saved }
	return restore
}

// MockSbEfiAddBootManagerProfile installs f as the package-level
// sbefiAddBootManagerProfile and returns a function that reinstates the
// value that was in place before the call.
func MockSbEfiAddBootManagerProfile(f func(profile *sb_tpm2.PCRProtectionProfile, params *sb_efi.BootManagerProfileParams) error) (restore func()) {
	saved := sbefiAddBootManagerProfile
	sbefiAddBootManagerProfile = f
	restore = func() { sbefiAddBootManagerProfile = saved }
	return restore
}

// MockSbEfiAddSystemdStubProfile installs f as the package-level
// sbefiAddSystemdStubProfile and returns a function that reinstates the
// value that was in place before the call.
func MockSbEfiAddSystemdStubProfile(f func(profile *sb_tpm2.PCRProtectionProfile, params *sb_efi.SystemdStubProfileParams) error) (restore func()) {
	saved := sbefiAddSystemdStubProfile
	sbefiAddSystemdStubProfile = f
	restore = func() { sbefiAddSystemdStubProfile = saved }
	return restore
}

// MockSbAddSnapModelProfile installs f as the package-level
// sbAddSnapModelProfile and returns a function that reinstates the value
// that was in place before the call.
func MockSbAddSnapModelProfile(f func(profile *sb_tpm2.PCRProtectionProfile, params *sb_tpm2.SnapModelProfileParams) error) (restore func()) {
	saved := sbAddSnapModelProfile
	sbAddSnapModelProfile = f
	restore = func() { sbAddSnapModelProfile = saved }
	return restore
}

// MockSbSealKeyToTPMMultiple installs f as the package-level
// sbSealKeyToTPMMultiple and returns a function that reinstates the value
// that was in place before the call.
func MockSbSealKeyToTPMMultiple(f func(tpm *sb_tpm2.Connection, keys []*sb_tpm2.SealKeyRequest, params *sb_tpm2.KeyCreationParams) (sb_tpm2.PolicyAuthKey, error)) (restore func()) {
	saved := sbSealKeyToTPMMultiple
	sbSealKeyToTPMMultiple = f
	restore = func() { sbSealKeyToTPMMultiple = saved }
	return restore
}

// MockSbUpdateKeyPCRProtectionPolicyMultiple installs f as the package-level
// sbUpdateKeyPCRProtectionPolicyMultiple and returns a function that
// reinstates the value that was in place before the call.
func MockSbUpdateKeyPCRProtectionPolicyMultiple(f func(tpm *sb_tpm2.Connection, keys []*sb_tpm2.SealedKeyObject, authKey sb_tpm2.PolicyAuthKey, pcrProfile *sb_tpm2.PCRProtectionProfile) error) (restore func()) {
	saved := sbUpdateKeyPCRProtectionPolicyMultiple
	sbUpdateKeyPCRProtectionPolicyMultiple = f
	restore = func() { sbUpdateKeyPCRProtectionPolicyMultiple = saved }
	return restore
}

// MockSbBlockPCRProtectionPolicies installs f as the package-level
// sbBlockPCRProtectionPolicies and returns a function that reinstates the
// value that was in place before the call.
func MockSbBlockPCRProtectionPolicies(f func(tpm *sb_tpm2.Connection, pcrs []int) error) (restore func()) {
	saved := sbBlockPCRProtectionPolicies
	sbBlockPCRProtectionPolicies = f
	restore = func() { sbBlockPCRProtectionPolicies = saved }
	return restore
}

// MockSbActivateVolumeWithRecoveryKey installs f as the package-level
// sbActivateVolumeWithRecoveryKey and returns a function that reinstates the
// value that was in place before the call.
func MockSbActivateVolumeWithRecoveryKey(f func(volumeName, sourceDevicePath string,
	keyReader io.Reader, options *sb.ActivateVolumeOptions) error) (restore func()) {
	saved := sbActivateVolumeWithRecoveryKey
	sbActivateVolumeWithRecoveryKey = f
	restore = func() { sbActivateVolumeWithRecoveryKey = saved }
	return restore
}

// MockSbActivateVolumeWithTPMSealedKey installs f as the package-level
// sbActivateVolumeWithTPMSealedKey and returns a function that reinstates
// the value that was in place before the call.
func MockSbActivateVolumeWithTPMSealedKey(f func(tpm *sb_tpm2.Connection, volumeName, sourceDevicePath, keyPath string,
	pinReader io.Reader, options *sb.ActivateVolumeOptions) (bool, error)) (restore func()) {
	saved := sbActivateVolumeWithTPMSealedKey
	sbActivateVolumeWithTPMSealedKey = f
	restore = func() { sbActivateVolumeWithTPMSealedKey = saved }
	return restore
}

// MockSbActivateVolumeWithKey installs f as the package-level
// sbActivateVolumeWithKey and returns a function that reinstates the value
// that was in place before the call.
func MockSbActivateVolumeWithKey(f func(volumeName, sourceDevicePath string, key []byte,
	options *sb.ActivateVolumeOptions) error) (restore func()) {
	saved := sbActivateVolumeWithKey
	sbActivateVolumeWithKey = f
	restore = func() { sbActivateVolumeWithKey = saved }
	return restore
}

// MockSbActivateVolumeWithKeyData replaces the package-level
// sbActivateVolumeWithKeyData with f and returns a function that restores
// the previous value.
func MockSbActivateVolumeWithKeyData(f func(volumeName, sourceDevicePath string, key *sb.KeyData, options *sb.ActivateVolumeOptions) (sb.SnapModelChecker, error)) (restore func()) {
	// Named "old" for consistency with every other Mock* helper in this file.
	old := sbActivateVolumeWithKeyData
	sbActivateVolumeWithKeyData = f
	return func() {
		sbActivateVolumeWithKeyData = old
	}
}

// MockSbMeasureSnapSystemEpochToTPM installs f as the package-level
// sbMeasureSnapSystemEpochToTPM and returns a function that reinstates the
// value that was in place before the call.
func MockSbMeasureSnapSystemEpochToTPM(f func(tpm *sb_tpm2.Connection, pcrIndex int) error) (restore func()) {
	saved := sbMeasureSnapSystemEpochToTPM
	sbMeasureSnapSystemEpochToTPM = f
	restore = func() { sbMeasureSnapSystemEpochToTPM = saved }
	return restore
}

// MockSbMeasureSnapModelToTPM installs f as the package-level
// sbMeasureSnapModelToTPM and returns a function that reinstates the value
// that was in place before the call.
func MockSbMeasureSnapModelToTPM(f func(tpm *sb_tpm2.Connection, pcrIndex int, model sb.SnapModel) error) (restore func()) {
	saved := sbMeasureSnapModelToTPM
	sbMeasureSnapModelToTPM = f
	restore = func() { sbMeasureSnapModelToTPM = saved }
	return restore
}

// MockRandomKernelUUID installs f as the package-level
// randutilRandomKernelUUID and returns a function that reinstates the value
// that was in place before the call.
func MockRandomKernelUUID(f func() string) (restore func()) {
	saved := randutilRandomKernelUUID
	randutilRandomKernelUUID = f
	restore = func() { randutilRandomKernelUUID = saved }
	return restore
}

// MockSbInitializeLUKS2Container installs f as the package-level
// sbInitializeLUKS2Container and returns a function that reinstates the
// value that was in place before the call.
func MockSbInitializeLUKS2Container(f func(devicePath, label string, key []byte,
	opts *sb.InitializeLUKS2ContainerOptions) error) (restore func()) {
	saved := sbInitializeLUKS2Container
	sbInitializeLUKS2Container = f
	restore = func() { sbInitializeLUKS2Container = saved }
	return restore
}

// MockSbAddRecoveryKeyToLUKS2Container installs f as the package-level
// sbAddRecoveryKeyToLUKS2Container and returns a function that reinstates
// the value that was in place before the call.
func MockSbAddRecoveryKeyToLUKS2Container(f func(devicePath string, key []byte, recoveryKey sb.RecoveryKey, opts *sb.KDFOptions) error) (restore func()) {
	saved := sbAddRecoveryKeyToLUKS2Container
	sbAddRecoveryKeyToLUKS2Container = f
	restore = func() { sbAddRecoveryKeyToLUKS2Container = saved }
	return restore
}

// MockIsTPMEnabled installs f as the package-level isTPMEnabled and returns
// a function that reinstates the value that was in place before the call.
func MockIsTPMEnabled(f func(tpm *sb_tpm2.Connection) bool) (restore func()) {
	saved := isTPMEnabled
	isTPMEnabled = f
	restore = func() { isTPMEnabled = saved }
	return restore
}

// MockFDEHasRevealKey installs f as the package-level fdeHasRevealKey and
// returns a function that reinstates the value that was in place before the
// call.
func MockFDEHasRevealKey(f func() bool) (restore func()) {
	saved := fdeHasRevealKey
	fdeHasRevealKey = f
	restore = func() { fdeHasRevealKey = saved }
	return restore
}

// MockSbDeactivateVolume installs f as the package-level sbDeactivateVolume
// and returns a function that reinstates the value that was in place before
// the call.
func MockSbDeactivateVolume(f func(volumeName string) error) (restore func()) {
	saved := sbDeactivateVolume
	sbDeactivateVolume = f
	restore = func() { sbDeactivateVolume = saved }
	return restore
}

// MockSbReadSealedKeyObject installs f as the package-level
// sbReadSealedKeyObject and returns a function that reinstates the value
// that was in place before the call.
func MockSbReadSealedKeyObject(f func(string) (*sb_tpm2.SealedKeyObject, error)) (restore func()) {
	saved := sbReadSealedKeyObject
	sbReadSealedKeyObject = f
	restore = func() { sbReadSealedKeyObject = saved }
	return restore
}
