/*
 * Decompiled with CFR 0.152.
 */
package org.gcube.io.jsonwebtoken.impl.security;

import java.nio.charset.StandardCharsets;
import java.security.Key;
import java.security.KeyPair;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.PublicKey;
import java.security.SecureRandom;
import java.security.interfaces.ECKey;
import javax.crypto.KeyAgreement;
import javax.crypto.SecretKey;
import org.gcube.io.jsonwebtoken.JweHeader;
import org.gcube.io.jsonwebtoken.impl.DefaultJweHeader;
import org.gcube.io.jsonwebtoken.impl.lang.Bytes;
import org.gcube.io.jsonwebtoken.impl.lang.CheckedFunction;
import org.gcube.io.jsonwebtoken.impl.lang.RequiredParameterReader;
import org.gcube.io.jsonwebtoken.impl.security.AbstractCurve;
import org.gcube.io.jsonwebtoken.impl.security.ConcatKDF;
import org.gcube.io.jsonwebtoken.impl.security.CryptoAlgorithm;
import org.gcube.io.jsonwebtoken.impl.security.DefaultDecryptionKeyRequest;
import org.gcube.io.jsonwebtoken.impl.security.DefaultKeyRequest;
import org.gcube.io.jsonwebtoken.impl.security.DirectKeyAlgorithm;
import org.gcube.io.jsonwebtoken.impl.security.ECCurve;
import org.gcube.io.jsonwebtoken.impl.security.EdwardsCurve;
import org.gcube.io.jsonwebtoken.impl.security.KeysBridge;
import org.gcube.io.jsonwebtoken.impl.security.StandardCurves;
import org.gcube.io.jsonwebtoken.lang.Arrays;
import org.gcube.io.jsonwebtoken.lang.Assert;
import org.gcube.io.jsonwebtoken.security.AeadAlgorithm;
import org.gcube.io.jsonwebtoken.security.Curve;
import org.gcube.io.jsonwebtoken.security.DecryptionKeyRequest;
import org.gcube.io.jsonwebtoken.security.DynamicJwkBuilder;
import org.gcube.io.jsonwebtoken.security.EcPublicJwk;
import org.gcube.io.jsonwebtoken.security.InvalidKeyException;
import org.gcube.io.jsonwebtoken.security.Jwks;
import org.gcube.io.jsonwebtoken.security.KeyAlgorithm;
import org.gcube.io.jsonwebtoken.security.KeyLengthSupplier;
import org.gcube.io.jsonwebtoken.security.KeyPairBuilder;
import org.gcube.io.jsonwebtoken.security.KeyRequest;
import org.gcube.io.jsonwebtoken.security.KeyResult;
import org.gcube.io.jsonwebtoken.security.OctetPublicJwk;
import org.gcube.io.jsonwebtoken.security.PublicJwk;
import org.gcube.io.jsonwebtoken.security.Request;
import org.gcube.io.jsonwebtoken.security.SecureRequest;
import org.gcube.io.jsonwebtoken.security.SecurityException;

/**
 * JWE key management algorithm based on Elliptic Curve Diffie-Hellman Ephemeral Static (ECDH-ES) key agreement,
 * as defined in <a href="https://www.rfc-editor.org/rfc/rfc7518.html#section-4.6">RFC 7518, Section 4.6</a>.
 *
 * <p>A shared secret {@code Z} is computed with a JCA {@link KeyAgreement} and run through Concat KDF (SHA-256)
 * to derive a {@link SecretKey}. That derived key is then handed to the configured wrap algorithm, whose result
 * is returned:</p>
 * <ul>
 *     <li>a {@link DirectKeyAlgorithm} yields the plain {@code ECDH-ES} algorithm id, and</li>
 *     <li>any other secret-key algorithm yields {@code ECDH-ES+<wrap algorithm id>}.</li>
 * </ul>
 *
 * <p>Keys that are {@link ECKey} instances use the {@code ECDH} JCA algorithm; all other keys use {@code XDH}.
 * Edwards curves flagged as signature curves are rejected per
 * <a href="https://www.rfc-editor.org/rfc/rfc8037#section-3.1">RFC 8037, Section 3.1</a>.</p>
 */
class EcdhKeyAlgorithm
extends CryptoAlgorithm
implements KeyAlgorithm<PublicKey, PrivateKey> {
    protected static final String JCA_NAME = "ECDH";
    protected static final String XDH_JCA_NAME = "XDH";
    protected static final String DEFAULT_ID = "ECDH-ES";
    private static final String CONCAT_KDF_HASH_ALG_NAME = "SHA-256";
    // Built from the named constant so the Concat KDF digest name is declared in exactly one place.
    private static final ConcatKDF CONCAT_KDF = new ConcatKDF(CONCAT_KDF_HASH_ALG_NAME);

    /** Secret-key algorithm applied to the key derived from the ECDH agreement. */
    private final KeyAlgorithm<SecretKey, SecretKey> WRAP_ALG;

    /**
     * Computes the JWA algorithm id for the given wrap algorithm.
     *
     * @param wrapAlg the secret-key algorithm applied to the agreed key; must not be {@code null}
     * @return {@code "ECDH-ES"} for direct key agreement, otherwise {@code "ECDH-ES+" + wrapAlg.getId()}
     */
    private static String idFor(KeyAlgorithm<SecretKey, SecretKey> wrapAlg) {
        // Validate here: this method runs inside super(...) before the constructor body, so without this check a
        // null wrapAlg would fail with a bare NullPointerException instead of this descriptive message.
        Assert.notNull(wrapAlg, "Wrap algorithm cannot be null.");
        return wrapAlg instanceof DirectKeyAlgorithm ? DEFAULT_ID : "ECDH-ES+" + wrapAlg.getId();
    }

    /** Creates the direct-key-agreement variant ({@code ECDH-ES}). */
    EcdhKeyAlgorithm() {
        this(new DirectKeyAlgorithm());
    }

    /**
     * Creates an ECDH-ES variant that passes the agreed key to {@code wrapAlg}.
     *
     * @param wrapAlg the secret-key algorithm applied to the derived key; must not be {@code null}
     */
    EcdhKeyAlgorithm(KeyAlgorithm<SecretKey, SecretKey> wrapAlg) {
        super(EcdhKeyAlgorithm.idFor(wrapAlg), JCA_NAME);
        this.WRAP_ALG = Assert.notNull(wrapAlg, "Wrap algorithm cannot be null.");
    }

    /**
     * Generates an ephemeral key pair on {@code curve}.
     *
     * @param curve    the curve to generate the key pair on
     * @param provider optional JCA provider (may be {@code null})
     * @param random   source of randomness for key generation
     * @return the generated key pair
     */
    protected KeyPair generateKeyPair(Curve curve, Provider provider, SecureRandom random) {
        return curve.keyPair().provider(provider).random(random).build();
    }

    /**
     * Runs the JCA key agreement between {@code priv} and {@code pub} and returns the raw shared secret {@code Z}.
     *
     * @param request the originating request; used to select the JCA algorithm name and the SecureRandom
     * @param pub     the other party's public key
     * @param priv    this party's private key (unwrapped via {@link KeysBridge#root} before use)
     * @return the shared secret bytes; callers should clear them when done
     */
    protected byte[] generateZ(final KeyRequest<?> request, final PublicKey pub, final PrivateKey priv) {
        return this.jca(request).withKeyAgreement(new CheckedFunction<KeyAgreement, byte[]>() {

            @Override
            public byte[] apply(KeyAgreement keyAgreement) throws Exception {
                keyAgreement.init(KeysBridge.root(priv), CryptoAlgorithm.ensureSecureRandom(request));
                keyAgreement.doPhase(pub, true);
                return keyAgreement.generateSecret();
            }
        });
    }

    /**
     * Returns the Concat KDF {@code AlgorithmID} value (RFC 7518, Section 4.6.2): the {@code enc} algorithm id
     * for direct key agreement, otherwise this algorithm's own ({@code alg}) id.
     *
     * @param enc the content encryption algorithm of the current request
     * @return the AlgorithmID string
     */
    protected String getConcatKDFAlgorithmId(AeadAlgorithm enc) {
        return this.WRAP_ALG instanceof DirectKeyAlgorithm
                ? Assert.hasText(enc.getId(), "AeadAlgorithm id cannot be null or empty.")
                : this.getId();
    }

    /**
     * Builds the Concat KDF {@code OtherInfo} input in the field order given by RFC 7518, Section 4.6.2:
     * AlgorithmID, PartyUInfo, PartyVInfo (each prefixed with its length), then SuppPubInfo ({@code keydatalen})
     * and an empty SuppPrivInfo.
     *
     * @param keydatalen  the required derived key length in bits
     * @param AlgorithmID the AlgorithmID value; encoded as US-ASCII
     * @param PartyUInfo  the {@code apu} value; absent/empty values are treated as {@link Bytes#EMPTY}
     * @param PartyVInfo  the {@code apv} value; absent/empty values are treated as {@link Bytes#EMPTY}
     * @return the concatenated OtherInfo bytes
     */
    private byte[] createOtherInfo(int keydatalen, String AlgorithmID, byte[] PartyUInfo, byte[] PartyVInfo) {
        Assert.hasText(AlgorithmID, "AlgorithmId cannot be null or empty.");
        byte[] algIdBytes = AlgorithmID.getBytes(StandardCharsets.US_ASCII);
        byte[] partyU = Arrays.length(PartyUInfo) == 0 ? Bytes.EMPTY : PartyUInfo;
        byte[] partyV = Arrays.length(PartyVInfo) == 0 ? Bytes.EMPTY : PartyVInfo;
        // NOTE(review): assumes Bytes.toBytes(int) yields the fixed-width big-endian length prefix the RFC
        // requires — confirm against Bytes.
        return Bytes.concat(
                Bytes.toBytes(algIdBytes.length), algIdBytes, // AlgorithmID
                Bytes.toBytes(partyU.length), partyU,         // PartyUInfo
                Bytes.toBytes(partyV.length), partyV,         // PartyVInfo
                Bytes.toBytes(keydatalen),                    // SuppPubInfo
                Bytes.EMPTY);                                 // SuppPrivInfo (empty)
    }

    /**
     * Returns the bit length of the key to derive: the wrap algorithm's length when it declares one via
     * {@link KeyLengthSupplier}, otherwise the content encryption algorithm's length.
     *
     * @param enc the content encryption algorithm of the current request
     * @return a positive bit length
     */
    private int getKeyBitLength(AeadAlgorithm enc) {
        int bitLength = this.WRAP_ALG instanceof KeyLengthSupplier
                ? ((KeyLengthSupplier) this.WRAP_ALG).getKeyBitLength()
                : enc.getKeyBitLength();
        return Assert.gt(bitLength, 0, "Algorithm keyBitLength must be > 0");
    }

    /**
     * Performs the key agreement and derives the agreed {@link SecretKey} via Concat KDF.
     *
     * @param request    the request supplying the encryption algorithm and the {@code apu}/{@code apv} header values
     * @param publicKey  the other party's public key
     * @param privateKey this party's private key
     * @return the derived secret key
     */
    private SecretKey deriveKey(KeyRequest<?> request, PublicKey publicKey, PrivateKey privateKey) {
        AeadAlgorithm enc = Assert.notNull(request.getEncryptionAlgorithm(), "Request encryptionAlgorithm cannot be null.");
        int requiredCekBitLen = this.getKeyBitLength(enc);
        String algorithmId = this.getConcatKDFAlgorithmId(enc);
        byte[] apu = request.getHeader().getAgreementPartyUInfo();
        byte[] apv = request.getHeader().getAgreementPartyVInfo();
        byte[] otherInfo = this.createOtherInfo(requiredCekBitLen, algorithmId, apu, apv);
        byte[] z = this.generateZ(request, publicKey, privateKey);
        try {
            return CONCAT_KDF.deriveKey(z, requiredCekBitLen, otherInfo);
        } finally {
            // Clear the raw shared secret whether or not derivation succeeds, so it doesn't linger in memory.
            Bytes.clear(z);
        }
    }

    /**
     * Selects {@code ECDH} for {@link ECKey}s and {@code XDH} otherwise. For secure (decryption) requests the
     * request key is inspected; for other requests the payload (the recipient public key) is inspected.
     */
    @Override
    protected String getJcaName(Request<?> request) {
        Object key = request instanceof SecureRequest ? ((SecureRequest<?, ?>) request).getKey() : request.getPayload();
        return key instanceof ECKey ? super.getJcaName(request) : XDH_JCA_NAME;
    }

    /**
     * Resolves the standard curve for {@code key}.
     *
     * @param key the encryption (public) or decryption (private) key
     * @return the key's curve
     * @throws InvalidKeyException if no JWA-standard curve matches, or the curve is an Edwards signature curve
     */
    private static AbstractCurve assertCurve(Key key) {
        Curve curve = StandardCurves.findByKey(key);
        if (curve == null) {
            String type = key instanceof PublicKey ? "encryption " : "decryption ";
            String msg = "Unable to determine JWA-standard Elliptic Curve for " + type + "key [" + KeysBridge.toString(key) + "]";
            throw new InvalidKeyException(msg);
        }
        if (curve instanceof EdwardsCurve && ((EdwardsCurve) curve).isSignatureCurve()) {
            String msg = curve.getId() + " keys may not be used with ECDH-ES key agreement algorithms per " + "https://www.rfc-editor.org/rfc/rfc8037#section-3.1.";
            throw new InvalidKeyException(msg);
        }
        return Assert.isInstanceOf(AbstractCurve.class, curve, "AbstractCurve instance expected.");
    }

    /**
     * Generates an ephemeral key pair on the recipient's curve, derives the agreed key, passes it to the wrap
     * algorithm and records the ephemeral public JWK in the header's {@code epk} parameter.
     *
     * @param request the request whose payload is the recipient's public key
     * @return the wrap algorithm's result
     */
    @Override
    public KeyResult getEncryptionKey(KeyRequest<PublicKey> request) throws SecurityException {
        Assert.notNull(request, "Request cannot be null.");
        JweHeader header = Assert.notNull(request.getHeader(), "Request JweHeader cannot be null.");
        PublicKey publicKey = Assert.notNull(request.getPayload(), "Encryption PublicKey cannot be null.");
        AbstractCurve curve = EcdhKeyAlgorithm.assertCurve(publicKey);
        Assert.stateNotNull(curve, "Internal implementation state: Curve cannot be null.");
        SecureRandom random = ensureSecureRandom(request);
        DynamicJwkBuilder<?, ?> jwkBuilder = Jwks.builder().random(random);
        // NOTE(review): the ephemeral pair is generated with a null Provider (not request.getProvider()) — confirm intent.
        KeyPair pair = this.generateKeyPair(curve, null, random);
        Assert.stateNotNull(pair, "Internal implementation state: KeyPair cannot be null.");
        PublicJwk<?> jwk = jwkBuilder.key(pair.getPublic()).build();
        SecretKey derived = this.deriveKey(request, publicKey, pair.getPrivate());
        DefaultKeyRequest<SecretKey> wrapReq = new DefaultKeyRequest<SecretKey>(derived, request.getProvider(), request.getSecureRandom(), request.getHeader(), request.getEncryptionAlgorithm());
        KeyResult result = this.WRAP_ALG.getEncryptionKey(wrapReq);
        // Only publish the ephemeral key once wrapping has succeeded.
        header.put(DefaultJweHeader.EPK.getId(), jwk);
        return result;
    }

    /**
     * Reads the ephemeral public key from the header's {@code epk} parameter, validates that it matches this
     * private key's curve, derives the agreed key and passes it to the wrap algorithm.
     *
     * @param request the request whose key is the recipient's private key and whose payload is the encrypted key
     * @return the wrap algorithm's decryption result
     * @throws InvalidKeyException if {@code epk} is of the wrong JWK type or not on the expected curve
     */
    @Override
    public SecretKey getDecryptionKey(DecryptionKeyRequest<PrivateKey> request) throws SecurityException {
        Assert.notNull(request, "Request cannot be null.");
        JweHeader header = Assert.notNull(request.getHeader(), "Request JweHeader cannot be null.");
        PrivateKey privateKey = Assert.notNull(request.getKey(), "Decryption PrivateKey cannot be null.");
        RequiredParameterReader reader = new RequiredParameterReader(header);
        PublicJwk<?> epk = reader.get(DefaultJweHeader.EPK);
        AbstractCurve curve = EcdhKeyAlgorithm.assertCurve(privateKey);
        Assert.stateNotNull(curve, "Internal implementation state: Curve cannot be null.");
        // Weierstrass (EC) curves require an EC public JWK; all other supported curves use Octet (OKP) JWKs.
        Class<?> epkClass = curve instanceof ECCurve ? EcPublicJwk.class : OctetPublicJwk.class;
        if (!epkClass.isInstance(epk)) {
            String msg = "JWE Header " + DefaultJweHeader.EPK + " value is not an Elliptic Curve " + "Public JWK. Value: " + epk;
            throw new InvalidKeyException(msg);
        }
        PublicKey epkKey = (PublicKey) epk.toKey();
        // Guard against invalid-curve attacks: the sender's ephemeral key must lie on our key's curve.
        if (!curve.contains(epkKey)) {
            String msg = "JWE Header " + DefaultJweHeader.EPK + " value does not represent " + "a point on the expected curve. Value: " + epk;
            throw new InvalidKeyException(msg);
        }
        SecretKey derived = this.deriveKey(request, epkKey, privateKey);
        // NOTE(review): unlike getEncryptionKey, the request's Provider is not forwarded (null) — preserved as-is.
        DefaultDecryptionKeyRequest<SecretKey> unwrapReq = new DefaultDecryptionKeyRequest<SecretKey>((byte[]) request.getPayload(), null, request.getSecureRandom(), header, request.getEncryptionAlgorithm(), derived);
        return this.WRAP_ALG.getDecryptionKey(unwrapReq);
    }
}

