From ca99f635698ecc71fcdb08bee035888ab6e7a4d8 Mon Sep 17 00:00:00 2001 From: InfiniteLoopSpace <35842605+InfiniteLoopSpace@users.noreply.github.com> Date: Fri, 16 Nov 2018 13:19:38 +0100 Subject: [PATCH] Added cms/protocol which is a fork of https://github.com/mastahyeti/cms/tree/master/protocol with support for encryption. --- cms/protocol/asn1.go | 25 ++ cms/protocol/attribute.go | 74 +++++ cms/protocol/authenvdata.go | 145 ++++++++++ cms/protocol/contentinfo.go | 55 ++++ cms/protocol/eci.go | 40 +++ cms/protocol/enci.go | 34 +++ cms/protocol/envelopeddata.go | 121 +++++++++ cms/protocol/error.go | 39 +++ cms/protocol/issuerserialnumber.go | 77 ++++++ cms/protocol/reciepientinfo.go | 212 +++++++++++++++ cms/protocol/signeddata.go | 422 +++++++++++++++++++++++++++++ cms/protocol/signerinfo.go | 153 +++++++++++ 12 files changed, 1397 insertions(+) create mode 100644 cms/protocol/asn1.go create mode 100644 cms/protocol/attribute.go create mode 100644 cms/protocol/authenvdata.go create mode 100644 cms/protocol/contentinfo.go create mode 100644 cms/protocol/eci.go create mode 100644 cms/protocol/enci.go create mode 100644 cms/protocol/envelopeddata.go create mode 100644 cms/protocol/error.go create mode 100644 cms/protocol/issuerserialnumber.go create mode 100644 cms/protocol/reciepientinfo.go create mode 100644 cms/protocol/signeddata.go create mode 100644 cms/protocol/signerinfo.go diff --git a/cms/protocol/asn1.go b/cms/protocol/asn1.go new file mode 100644 index 0000000..018d1d7 --- /dev/null +++ b/cms/protocol/asn1.go @@ -0,0 +1,25 @@ +package protocol + +import ( + "encoding/asn1" + + asn "github.com/InfiniteLoopSpace/go_S-MIME/asn1" +) + +// RawValue marshals val and returns the asn1.RawValue +func RawValue(val interface{}, params ...string) (rv asn1.RawValue, err error) { + param := "" + if len(params) > 0 { + param = params[0] + } + + var der []byte + if der, err = asn.MarshalWithParams(val, param); err != nil { + return + } + + if _, err = asn.Unmarshal(der, &rv); err != nil { + return + } + return +} diff --git a/cms/protocol/attribute.go b/cms/protocol/attribute.go new file mode 100644 index 0000000..a0935cd --- /dev/null +++ b/cms/protocol/attribute.go @@ -0,0 +1,74 @@ +package protocol + +import ( + "encoding/asn1" +) + +// Attribute ::= SEQUENCE { +// attrType OBJECT IDENTIFIER, +// attrValues SET OF AttributeValue } +// +// AttributeValue ::= ANY +type Attribute struct { + Type asn1.ObjectIdentifier + + // This should be a SET OF ANY, but Go's asn1 parser can't handle slices of + // RawValues. Use value() to get an AnySet of the value. + RawValue []asn1.RawValue `asn1:"set"` +} + +// NewAttribute creates a single-value Attribute. +func NewAttribute(attrType asn1.ObjectIdentifier, val interface{}) (attr Attribute, err error) { + var rv asn1.RawValue + if rv, err = RawValue(val); err != nil { + return + } + + attr = Attribute{attrType, []asn1.RawValue{rv}} + + return +} + +// Attributes is a common Go type for SignedAttributes and UnsignedAttributes. +// +// SignedAttributes ::= SET SIZE (1..MAX) OF Attribute +// +// UnsignedAttributes ::= SET SIZE (1..MAX) OF Attribute +type Attributes []Attribute + +// GetOnlyAttributeValueBytes gets an attribute value, returning an error if the +// attribute occurs multiple times or has multiple values. +func (attrs Attributes) GetOnlyAttributeValueBytes(oid asn1.ObjectIdentifier) (rv asn1.RawValue, err error) { + var vals [][]asn1.RawValue + if vals, err = attrs.GetValues(oid); err != nil { + return + } + if len(vals) != 1 { + err = ASN1Error{"bad attribute count"} + return + } + if len(vals[0]) != 1 { + err = ASN1Error{"bad attribute element count"} + return + } + + return vals[0][0], nil +} + +// GetValues retreives the attributes with the given OID. A nil value is +// returned if the OPTIONAL SET of Attributes is missing from the SignerInfo. An +// empty slice is returned if the specified attribute isn't in the set. +func (attrs Attributes) GetValues(oid asn1.ObjectIdentifier) ([][]asn1.RawValue, error) { + if attrs == nil { + return nil, nil + } + + vals := [][]asn1.RawValue{} + for _, attr := range attrs { + if attr.Type.Equal(oid) { + vals = append(vals, attr.RawValue) + } + } + + return vals, nil +} diff --git a/cms/protocol/authenvdata.go b/cms/protocol/authenvdata.go new file mode 100644 index 0000000..13549aa --- /dev/null +++ b/cms/protocol/authenvdata.go @@ -0,0 +1,145 @@ +package protocol + +import ( + "crypto/tls" + "encoding/asn1" + "log" + + asn "github.com/InfiniteLoopSpace/go_S-MIME/asn1" + oid "github.com/InfiniteLoopSpace/go_S-MIME/oid" +) + +//AuthEnvelopedData ::= SEQUENCE { +// version CMSVersion, +// originatorInfo [0] IMPLICIT OriginatorInfo OPTIONAL, +// recipientInfos RecipientInfos, +// authEncryptedContentInfo EncryptedContentInfo, +/// authAttrs [1] IMPLICIT AuthAttributes OPTIONAL, +// mac MessageAuthenticationCode, +// unauthAttrs [2] IMPLICIT UnauthAttributes OPTIONAL } +//https://tools.ietf.org/html/rfc5083##section-2.1 +type AuthEnvelopedData struct { + Version int + OriginatorInfo asn1.RawValue `asn1:"optional,tag:0"` + RecipientInfos []RecipientInfo `asn1:"set,choice"` + AECI EncryptedContentInfo + AauthAttrs []Attribute `asn1:"set,optional,tag:1"` + MAC []byte + UnAauthAttrs []Attribute `asn1:"set,optional,tag:2"` +} + +// Decrypt decrypts AuthEnvelopedData and returns the plaintext. +func (ed *AuthEnvelopedData) Decrypt(keyPair []tls.Certificate) (plain []byte, err error) { + + // Find the right key + var key []byte + for i := range keyPair { + key, err = ed.decryptKey(keyPair[i]) + switch err { + case ErrNoKeyFound: + continue + case nil: + break + default: + return + } + } + + encAlg := &oid.EncryptionAlgorithm{ + Key: key, + ContentEncryptionAlgorithmIdentifier: ed.AECI.ContentEncryptionAlgorithm, + } + encAlg.MAC = ed.MAC + + plain, err = encAlg.Decrypt(ed.AECI.EContent) + + return +} + +func (ed *AuthEnvelopedData) decryptKey(keyPair tls.Certificate) (key []byte, err error) { + + for i := range ed.RecipientInfos { + + key, err = ed.RecipientInfos[i].decryptKey(keyPair) + if key != nil { + return + } + } + return nil, ErrNoKeyFound +} + +// NewAuthEnvelopedData creates AuthEnvelopedData from an EncryptedContentInfo with mac and given RecipientInfos. +func NewAuthEnvelopedData(eci *EncryptedContentInfo, reciInfos []RecipientInfo, mac []byte) AuthEnvelopedData { + version := 0 + + ed := AuthEnvelopedData{ + Version: version, + RecipientInfos: reciInfos, + AECI: *eci, + MAC: mac, + } + + return ed +} + +func authcontentInfo(ed AuthEnvelopedData) (ci ContentInfo, err error) { + + der, err := asn.Marshal(ed) + if err != nil { + return + } + + ci = ContentInfo{ + ContentType: oid.AuthEnvelopedData, + Content: asn1.RawValue{ + Class: asn1.ClassContextSpecific, + Tag: 0, + Bytes: der, + IsCompound: true, + }, + } + + return +} + +// ContentInfo marshals AuthEnvelopedData and returns ContentInfo. +func (ed AuthEnvelopedData) ContentInfo() (ContentInfo, error) { + nilCI := *new(ContentInfo) + + der, err := asn.Marshal(ed) + if err != nil { + log.Fatal(err) + } + + if err != nil { + return nilCI, err + } + + return ContentInfo{ + ContentType: oid.AuthEnvelopedData, + Content: asn1.RawValue{ + Class: asn1.ClassContextSpecific, + Tag: 0, + Bytes: der, + IsCompound: true, + }, + }, nil + +} + +// AuthEnvelopedDataContent unmarshals ContentInfo and returns AuthEnvelopedData if +// content type is AuthEnvelopedData. +func (ci ContentInfo) AuthEnvelopedDataContent() (*AuthEnvelopedData, error) { + if !ci.ContentType.Equal(oid.AuthEnvelopedData) { + return nil, ErrWrongType + } + + ed := new(AuthEnvelopedData) + if rest, err := asn.Unmarshal(ci.Content.Bytes, ed); err != nil { + return nil, err + } else if len(rest) > 0 { + return nil, ErrTrailingData + } + + return ed, nil +} diff --git a/cms/protocol/contentinfo.go b/cms/protocol/contentinfo.go new file mode 100644 index 0000000..b7c02fc --- /dev/null +++ b/cms/protocol/contentinfo.go @@ -0,0 +1,55 @@ +// Package protocol implemets parts of cryptographic message syntax RFC 5652. +// This package is mostly for handling of the asn1 sturctures of cms. For +// de/encryption and signing/verfiying use to package cms. +package protocol + +import ( + "encoding/asn1" + "fmt" + + asn "github.com/InfiniteLoopSpace/go_S-MIME/asn1" + "github.com/InfiniteLoopSpace/go_S-MIME/b64" +) + +// ContentInfo ::= SEQUENCE { +// contentType ContentType, +// content [0] EXPLICIT ANY DEFINED BY contentType } +// +// ContentType ::= OBJECT IDENTIFIER +type ContentInfo struct { + ContentType asn1.ObjectIdentifier + Content asn1.RawValue `asn1:"explicit,tag:0"` +} + +// ParseContentInfo parses DER-encoded ASN.1 data and returns ContentInfo. +func ParseContentInfo(der []byte) (ci ContentInfo, err error) { + + if err != nil { + return + } + + var rest []byte + if rest, err = asn.Unmarshal(der, &ci); err != nil { + return + } + if len(rest) > 0 { + fmt.Println(ErrTrailingData) + //err = ErrTrailingData + } + + return +} + +// DER returns the DER-encoded ASN.1 data. +func (ci ContentInfo) DER() ([]byte, error) { + return asn.Marshal(ci) +} + +// Base64 encodes the DER-encoded ASN.1 data in base64 for use in S/MIME. +func (ci ContentInfo) Base64() ([]byte, error) { + der, err := ci.DER() + if err != nil { + return nil, err + } + return b64.EncodeBase64(der) +} diff --git a/cms/protocol/eci.go b/cms/protocol/eci.go new file mode 100644 index 0000000..b2b3d57 --- /dev/null +++ b/cms/protocol/eci.go @@ -0,0 +1,40 @@ +package protocol + +import ( + "crypto/x509/pkix" + "encoding/asn1" + + oid "github.com/InfiniteLoopSpace/go_S-MIME/oid" +) + +//EncryptedContentInfo ::= SEQUENCE { +// contentType ContentType, +// contentEncryptionAlgorithm ContentEncryptionAlgorithmIdentifier, +// encryptedContent [0] IMPLICIT EncryptedContent OPTIONAL } +type EncryptedContentInfo struct { + EContentType asn1.ObjectIdentifier + ContentEncryptionAlgorithm pkix.AlgorithmIdentifier + EContent []byte `asn1:"optional,implicit,tag:0"` +} + +// NewEncryptedContentInfo encrypts the conent with the contentEncryptionAlgorithm and retuns +// the EncryptedContentInfo, the key and the MAC. +func NewEncryptedContentInfo(contentType asn1.ObjectIdentifier, contentEncryptionAlg asn1.ObjectIdentifier, content []byte) (eci EncryptedContentInfo, key, mac []byte, err error) { + + encAlg := &oid.EncryptionAlgorithm{ + EncryptionAlgorithmIdentifier: contentEncryptionAlg, + } + + ciphertext, err := encAlg.Encrypt(content) + if err != nil { + return + } + + eci = EncryptedContentInfo{ + EContentType: contentType, + ContentEncryptionAlgorithm: encAlg.ContentEncryptionAlgorithmIdentifier, + EContent: ciphertext, + } + + return eci, encAlg.Key, encAlg.MAC, nil +} diff --git a/cms/protocol/enci.go b/cms/protocol/enci.go new file mode 100644 index 0000000..7a10a3e --- /dev/null +++ b/cms/protocol/enci.go @@ -0,0 +1,34 @@ +package protocol + +import ( + "encoding/asn1" + + oid "github.com/InfiniteLoopSpace/go_S-MIME/oid" +) + +// EncapsulatedContentInfo ::= SEQUENCE { +// eContentType ContentType, +// eContent [0] EXPLICIT OCTET STRING OPTIONAL } +type EncapsulatedContentInfo struct { + EContentType asn1.ObjectIdentifier `` // ContentType ::= OBJECT IDENTIFIER + EContent []byte `asn1:"optional,explicit,tag:0"` // +} + +// NewDataEncapsulatedContentInfo creates a new EncapsulatedContentInfo of type +// id-data. +func NewDataEncapsulatedContentInfo(data []byte) (EncapsulatedContentInfo, error) { + return NewEncapsulatedContentInfo(oid.Data, data) +} + +// NewEncapsulatedContentInfo creates a new EncapsulatedContentInfo. +func NewEncapsulatedContentInfo(contentType asn1.ObjectIdentifier, content []byte) (EncapsulatedContentInfo, error) { + return EncapsulatedContentInfo{ + EContentType: contentType, + EContent: content, + }, nil +} + +// IsTypeData checks if the EContentType is id-data. +func (eci EncapsulatedContentInfo) IsTypeData() bool { + return eci.EContentType.Equal(oid.Data) +} diff --git a/cms/protocol/envelopeddata.go b/cms/protocol/envelopeddata.go new file mode 100644 index 0000000..06b635a --- /dev/null +++ b/cms/protocol/envelopeddata.go @@ -0,0 +1,121 @@ +package protocol + +import ( + "crypto/tls" + "encoding/asn1" + "log" + + asn "github.com/InfiniteLoopSpace/go_S-MIME/asn1" + oid "github.com/InfiniteLoopSpace/go_S-MIME/oid" +) + +//EnvelopedData ::= SEQUENCE { +// version CMSVersion, +// originatorInfo [0] IMPLICIT OriginatorInfo OPTIONAL, +// recipientInfos RecipientInfos, +// encryptedContentInfo EncryptedContentInfo, +// unprotectedAttrs [1] IMPLICIT UnprotectedAttributes OPTIONAL } +type EnvelopedData struct { + Version int + OriginatorInfo asn1.RawValue `asn1:"optional,tag:0"` + RecipientInfos []RecipientInfo `asn1:"set,choice"` + ECI EncryptedContentInfo `` + UnprotectedAttrs []Attribute `asn1:"set,optional,tag:1"` +} + +// Decrypt decrypts the EnvelopedData with the given keyPair and retuns the plaintext. +func (ed *EnvelopedData) Decrypt(keyPairs []tls.Certificate) (plain []byte, err error) { + + // Find the right key + var key []byte + for i := range keyPairs { + key, err = ed.decryptKey(keyPairs[i]) + switch err { + case ErrNoKeyFound: + continue + case nil: + break + default: + return + } + } + if key == nil { + return nil, ErrNoKeyFound + } + + encAlg := &oid.EncryptionAlgorithm{ + Key: key, + ContentEncryptionAlgorithmIdentifier: ed.ECI.ContentEncryptionAlgorithm, + } + + plain, err = encAlg.Decrypt(ed.ECI.EContent) + + return +} + +func (ed *EnvelopedData) decryptKey(keyPair tls.Certificate) (key []byte, err error) { + + for i := range ed.RecipientInfos { + + key, err = ed.RecipientInfos[i].decryptKey(keyPair) + if key != nil { + return + } + } + return nil, ErrNoKeyFound +} + +// EnvelopedDataContent returns EnvelopedData if ContentType is EnvelopedData. +func (ci ContentInfo) EnvelopedDataContent() (*EnvelopedData, error) { + if !ci.ContentType.Equal(oid.EnvelopedData) { + return nil, ErrWrongType + } + + //var Ed interface{} + ed := new(EnvelopedData) + if rest, err := asn.Unmarshal(ci.Content.Bytes, ed); err != nil { + return nil, err + } else if len(rest) > 0 { + return nil, ErrTrailingData + } + + return ed, nil +} + +// ContentInfo returns new ContentInfo with ContentType EnvelopedData. +func (ed EnvelopedData) ContentInfo() (ContentInfo, error) { + nilCI := *new(ContentInfo) + + der, err := asn.Marshal(ed) + if err != nil { + log.Fatal(err) + } + + if err != nil { + return nilCI, err + } + + return ContentInfo{ + ContentType: oid.EnvelopedData, + Content: asn1.RawValue{ + Class: asn1.ClassContextSpecific, + Tag: 0, + Bytes: der, + IsCompound: true, + }, + }, nil + +} + +// NewEnvelopedData creates a new EnvelopedData from the given data. +func NewEnvelopedData(eci *EncryptedContentInfo, reciInfos []RecipientInfo) EnvelopedData { + version := 0 + + ed := EnvelopedData{ + Version: version, + RecipientInfos: reciInfos, + ECI: *eci, + } + + return ed +} diff --git a/cms/protocol/error.go b/cms/protocol/error.go new file mode 100644 index 0000000..19b9dea --- /dev/null +++ b/cms/protocol/error.go @@ -0,0 +1,39 @@ +package protocol + +import ( + "errors" + "fmt" +) + +// ASN1Error is an error from parsing ASN.1 structures. +type ASN1Error struct { + Message string +} + +// Error implements the error interface. +func (err ASN1Error) Error() string { + return fmt.Sprintf("cms/protocol: ASN.1 Error — %s", err.Message) +} + +var ( + // ErrWrongType is returned by methods that make assumptions about types. + // Helper methods are defined for accessing CHOICE and ANY feilds. These + // helper methods get the value of the field, assuming it is of a given type. + // This error is returned if that assumption is wrong and the field has a + // different type. + ErrWrongType = errors.New("cms/protocol: wrong choice or any type") + + // ErrNoCertificate is returned when a requested certificate cannot be found. + ErrNoCertificate = errors.New("no certificate found") + + // ErrNoKeyFound is returned when a requested certificate cannot be found. + ErrNoKeyFound = errors.New("no key for decryption found") + + // ErrUnsupported is returned when an unsupported type or version + // is encountered. + ErrUnsupported = ASN1Error{"unsupported type or version"} + + // ErrTrailingData is returned when extra data is found after parsing an ASN.1 + // structure. + ErrTrailingData = ASN1Error{"unexpected trailing data"} +) diff --git a/cms/protocol/issuerserialnumber.go b/cms/protocol/issuerserialnumber.go new file mode 100644 index 0000000..a90fc5e --- /dev/null +++ b/cms/protocol/issuerserialnumber.go @@ -0,0 +1,77 @@ +package protocol + +import ( + "bytes" + "crypto/x509" + "encoding/asn1" + "fmt" + "math/big" +) + +// IssuerAndSerialNumber ::= SEQUENCE { +// issuer Name, +// serialNumber CertificateSerialNumber } +// +// CertificateSerialNumber ::= INTEGER +type IssuerAndSerialNumber struct { + Issuer asn1.RawValue + SerialNumber *big.Int +} + +// NewIssuerAndSerialNumber creates a IssuerAndSerialNumber SID for the given +// cert. +func NewIssuerAndSerialNumber(cert *x509.Certificate) (sid IssuerAndSerialNumber, err error) { + sid = IssuerAndSerialNumber{ + SerialNumber: new(big.Int).Set(cert.SerialNumber), + } + + if _, err = asn1.Unmarshal(cert.RawIssuer, &sid.Issuer); err != nil { + return + } + + return +} + +// RawValue returns the RawValue of the IssuerAndSerialNumber. +func (ias *IssuerAndSerialNumber) RawValue() (rv asn1.RawValue, err error) { + var der []byte + if der, err = asn1.Marshal(*ias); err != nil { + return + } + + if _, err = asn1.Unmarshal(der, &rv); err != nil { + return + } + + return +} + +// Equal returns true if ias and ias2 agree. +func (ias *IssuerAndSerialNumber) Equal(ias2 IssuerAndSerialNumber) bool { + + if bytes.Compare(ias.Issuer.Bytes, ias2.Issuer.Bytes) != 0 { + return false + } + + if ias.SerialNumber.Cmp(ias2.SerialNumber) != 0 { + return false + } + + return true +} + +// IASstring retuns the ias of the cert as hex encoded string. +func IASstring(cert *x509.Certificate) (iasString string, err error) { + ias, err := NewIssuerAndSerialNumber(cert) + if err != nil { + return + } + + rv, err := ias.RawValue() + if err != nil { + return + } + + iasString = fmt.Sprintf("%x", rv.Bytes) + return +} diff --git a/cms/protocol/reciepientinfo.go b/cms/protocol/reciepientinfo.go new file mode 100644 index 0000000..be933bc --- /dev/null +++ b/cms/protocol/reciepientinfo.go @@ -0,0 +1,212 @@ +package protocol + +import ( + "bytes" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "log" + "time" + + oid "github.com/InfiniteLoopSpace/go_S-MIME/oid" +) + +//RecipientInfo ::= CHOICE { +// ktri KeyTransRecipientInfo, +// kari [1] KeyAgreeRecipientInfo, +// kekri [2] KEKRecipientInfo, +// pwri [3] PasswordRecipientInfo, +// ori [4] OtherRecipientInfo } +type RecipientInfo struct { + KTRI KeyTransRecipientInfo `asn1:"optional"` + KARI KeyAgreeRecipientInfo `asn1:"optional,tag:1"` //KeyAgreeRecipientInfo + KEKRI asn1.RawValue `asn1:"optional,tag:2"` + PWRI asn1.RawValue `asn1:"optional,tag:3"` + ORI asn1.RawValue `asn1:"optional,tag:4"` +} + +func (recInfo *RecipientInfo) decryptKey(keyPair tls.Certificate) (key []byte, err error) { + + return recInfo.KTRI.decryptKey(keyPair) + +} + +//KeyTransRecipientInfo ::= SEQUENCE { +// version CMSVersion, -- always set to 0 or 2 +// rid RecipientIdentifier, +// keyEncryptionAlgorithm KeyEncryptionAlgorithmIdentifier, +// encryptedKey EncryptedKey } +type KeyTransRecipientInfo struct { + Version int + Rid RecipientIdentifier `asn1:"choice"` + KeyEncryptionAlgorithm pkix.AlgorithmIdentifier + EncryptedKey []byte +} + +func (ktri *KeyTransRecipientInfo) decryptKey(keyPair tls.Certificate) (key []byte, err error) { + + ias, err := NewIssuerAndSerialNumber(keyPair.Leaf) + if err != nil { + return + } + + ski := keyPair.Leaf.SubjectKeyId + + //version is the syntax version number. If the SignerIdentifier is + //the CHOICE issuerAndSerialNumber, then the version MUST be 1. If + //the SignerIdentifier is subjectKeyIdentifier, then the version + //MUST be 3. + switch ktri.Version { + case 0: + if ias.Equal(ktri.Rid.IAS) { + alg := oid.PublicKeyAlgorithmToEncrytionAlgorithm[keyPair.Leaf.PublicKeyAlgorithm].Algorithm + if ktri.KeyEncryptionAlgorithm.Algorithm.Equal(alg) { + + decrypter := keyPair.PrivateKey.(crypto.Decrypter) + return decrypter.Decrypt(rand.Reader, ktri.EncryptedKey, nil) + + } + log.Println("Key encrytion algorithm not matching") + } + case 2: + if bytes.Equal(ski, ktri.Rid.SKI) { + alg := oid.PublicKeyAlgorithmToEncrytionAlgorithm[keyPair.Leaf.PublicKeyAlgorithm].Algorithm + if ktri.KeyEncryptionAlgorithm.Algorithm.Equal(alg) { + if alg.Equal(oid.EncryptionAlgorithmRSA) { + return rsa.DecryptPKCS1v15(rand.Reader, keyPair.PrivateKey.(*rsa.PrivateKey), ktri.EncryptedKey) + } + log.Println("Unsupported key encrytion algorithm") + } + log.Println("Key encrytion algorithm not matching") + } + default: + fmt.Println(ktri.Version) + return nil, ErrUnsupported + } + + return nil, nil +} + +//RecipientIdentifier ::= CHOICE { +// issuerAndSerialNumber IssuerAndSerialNumber, +// subjectKeyIdentifier [0] SubjectKeyIdentifier } +type RecipientIdentifier struct { + IAS IssuerAndSerialNumber `asn1:"optional"` + SKI []byte `asn1:"optional,tag:0"` +} + +// NewRecipientInfo creates RecipientInfo for giben recipient and key. +func NewRecipientInfo(recipient *x509.Certificate, key []byte) RecipientInfo { + version := 0 //issuerAndSerialNumber + + rid := RecipientIdentifier{} + + switch version { + case 0: + ias, err := NewIssuerAndSerialNumber(recipient) + if err != nil { + log.Fatal(err) + } + rid.IAS = ias + case 2: + rid.SKI = recipient.SubjectKeyId + } + + kea := oid.PublicKeyAlgorithmToEncrytionAlgorithm[recipient.PublicKeyAlgorithm] + if _, ok := oid.PublicKeyAlgorithmToEncrytionAlgorithm[recipient.PublicKeyAlgorithm]; !ok { + log.Fatal("NewRecipientInfo: PublicKeyAlgorithm not supported") + } + + encrypted, _ := encryptKey(key, recipient) + + info := RecipientInfo{ + KTRI: KeyTransRecipientInfo{ + Version: version, + Rid: rid, + KeyEncryptionAlgorithm: kea, + EncryptedKey: encrypted, + }} + return info +} + +func encryptKey(key []byte, recipient *x509.Certificate) ([]byte, error) { + if pub := recipient.PublicKey.(*rsa.PublicKey); pub != nil { + return rsa.EncryptPKCS1v15(rand.Reader, pub, key) + } + return nil, ErrUnsupportedAlgorithm +} + +// ErrUnsupportedAlgorithm is returned if the algorithm is unsupported. +var ErrUnsupportedAlgorithm = errors.New("cms: cannot decrypt data: unsupported algorithm") + +//KeyAgreeRecipientInfo ::= SEQUENCE { +// version CMSVersion, -- always set to 3 +// originator [0] EXPLICIT OriginatorIdentifierOrKey, +// ukm [1] EXPLICIT UserKeyingMaterial OPTIONAL, +// keyEncryptionAlgorithm KeyEncryptionAlgorithmIdentifier, +// recipientEncryptedKeys RecipientEncryptedKeys } +type KeyAgreeRecipientInfo struct { + Version int + Originator OriginatorIdentifierOrKey `asn1:"explicit,choice,tag:0"` + UKM []byte `asn1:"explicit,optional,tag:1"` + KeyEncryptionAlgorithm pkix.AlgorithmIdentifier `` + RecipientEncryptedKeys []RecipientEncryptedKey `asn1:"sequence"` //RecipientEncryptedKeys ::= SEQUENCE OF RecipientEncryptedKey +} + +//OriginatorIdentifierOrKey ::= CHOICE { +// issuerAndSerialNumber IssuerAndSerialNumber, +// subjectKeyIdentifier [0] SubjectKeyIdentifier, +// originatorKey [1] OriginatorPublicKey } +type OriginatorIdentifierOrKey struct { + IAS IssuerAndSerialNumber `asn1:"optional"` + SKI []byte `asn1:"optional,tag:0"` + OriginatorKey OriginatorPublicKey `asn1:"optional,tag:1"` +} + +//OriginatorPublicKey ::= SEQUENCE { +// algorithm AlgorithmIdentifier, +// publicKey BIT STRING +type OriginatorPublicKey struct { + Algorithm pkix.AlgorithmIdentifier + PublicKey asn1.BitString +} + +//RecipientEncryptedKey ::= SEQUENCE { +// rid KeyAgreeRecipientIdentifier, +// encryptedKey EncryptedKey } +type RecipientEncryptedKey struct { + RID KeyAgreeRecipientIdentifier `asn1:"choice"` + EncryptedKey []byte +} + +//KeyAgreeRecipientIdentifier ::= CHOICE { +// issuerAndSerialNumber IssuerAndSerialNumber, +// rKeyId [0] IMPLICIT RecipientKeyIdentifier } +type KeyAgreeRecipientIdentifier struct { + IAS IssuerAndSerialNumber `asn1:"optional"` + RKeyID RecipientKeyIdentifier `asn1:"optional,tag:0"` +} + +//RecipientKeyIdentifier ::= SEQUENCE { +// subjectKeyIdentifier SubjectKeyIdentifier, +// date GeneralizedTime OPTIONAL, +// other OtherKeyAttribute OPTIONAL } +type RecipientKeyIdentifier struct { + SubjectKeyIdentifier []byte //SubjectKeyIdentifier ::= OCTET STRING + Date time.Time `asn1:"optional"` + Other OtherKeyAttribute `asn1:"optional"` +} + +//OtherKeyAttribute ::= SEQUENCE { +// keyAttrId OBJECT IDENTIFIER, +// keyAttr ANY DEFINED BY keyAttrId OPTIONAL } +type OtherKeyAttribute struct { + KeyAttrID asn1.ObjectIdentifier + KeyAttr asn1.RawValue `asn1:"optional"` +} diff --git a/cms/protocol/signeddata.go b/cms/protocol/signeddata.go new file mode 100644 index 0000000..4481372 --- /dev/null +++ b/cms/protocol/signeddata.go @@ -0,0 +1,422 @@ +package protocol + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "errors" + "fmt" + "io/ioutil" + "net/http" + "time" + + asn "github.com/InfiniteLoopSpace/go_S-MIME/asn1" + oid "github.com/InfiniteLoopSpace/go_S-MIME/oid" +) + +// SignedDataContent returns SignedData if ContentType is SignedData. +func (ci ContentInfo) SignedDataContent() (*SignedData, error) { + if !ci.ContentType.Equal(oid.SignedData) { + return nil, ErrWrongType + } + + sd := new(SignedData) + if rest, err := asn.Unmarshal(ci.Content.Bytes, sd); err != nil { + return nil, err + } else if len(rest) > 0 { + return nil, ErrTrailingData + } + + return sd, nil +} + +// SignedData ::= SEQUENCE { +// version CMSVersion, +// digestAlgorithms DigestAlgorithmIdentifiers, +// encapContentInfo EncapsulatedContentInfo, +// certificates [0] IMPLICIT CertificateSet OPTIONAL, +// crls [1] IMPLICIT RevocationInfoChoices OPTIONAL, +// signerInfos SignerInfos } +type SignedData struct { + Version int `` // CMSVersion ::= INTEGER { v0(0), v1(1), v2(2), v3(3), v4(4), v5(5) } + DigestAlgorithms []pkix.AlgorithmIdentifier `asn1:"set"` //DigestAlgorithmIdentifiers ::= SET OF DigestAlgorithmIdentifier //DigestAlgorithmIdentifier ::= AlgorithmIdentifier + EncapContentInfo EncapsulatedContentInfo `` // + Certificates []asn1.RawValue `asn1:"optional,set,tag:0"` // CertificateSet ::= SET OF CertificateChoices + CRLs []RevocationInfoChoice `asn1:"optional,set,tag:1"` // RevocationInfoChoices ::= SET OF RevocationInfoChoice + SignerInfos []SignerInfo `asn1:"set"` // SignerInfos ::= SET OF SignerInfo +} + +// CertificateChoices ::= CHOICE { +// certificate Certificate, +// extendedCertificate [0] IMPLICIT ExtendedCertificate, -- Obsolete +// v1AttrCert [1] IMPLICIT AttributeCertificateV1, -- Obsolete +// v2AttrCert [2] IMPLICIT AttributeCertificateV2, +// other [3] IMPLICIT OtherCertificateFormat } +type CertificateChoices struct { + Cert x509.Certificate `asn1:"optional"` + V2AttrCert asn1.RawValue `asn1:"optional,tag:2"` + Other OtherCertificateFormat `asn1:"optional,tag:3"` +} + +// OtherCertificateFormat ::= SEQUENCE { +// otherCertFormat OBJECT IDENTIFIER, +// otherCert ANY DEFINED BY otherCertFormat } +type OtherCertificateFormat struct { + OtherCertFormat asn1.ObjectIdentifier + OtherCert asn1.RawValue +} + +// RevocationInfoChoice ::= CHOICE { +// crl CertificateList, +// other [1] IMPLICIT OtherRevocationInfoFormat } +type RevocationInfoChoice struct { + Crl pkix.CertificateList `asn1:"optional"` + Other OtherRevocationInfoFormat `asn1:"optional,tag:1"` +} + +// OtherRevocationInfoFormat ::= SEQUENCE { +// otherRevInfoFormat OBJECT IDENTIFIER, +// otherRevInfo ANY DEFINED BY otherRevInfoFormat } +type OtherRevocationInfoFormat struct { + OtherRevInfoFormat asn1.ObjectIdentifier + OtherRevInfo asn1.RawValue +} + +// NewSignedData creates a new SignedData. +func NewSignedData(eci EncapsulatedContentInfo) (*SignedData, error) { + // The version is picked based on which CMS features are used. We only use + // version 1 features, except for supporting non-data econtent. + version := 1 + if !eci.IsTypeData() { + version = 3 + } + + return &SignedData{ + Version: version, + DigestAlgorithms: []pkix.AlgorithmIdentifier{}, + EncapContentInfo: eci, + SignerInfos: []SignerInfo{}, + }, nil +} + +// AddSignerInfo adds a SignerInfo to the SignedData. +func (sd *SignedData) AddSignerInfo(keypPair tls.Certificate) (err error) { + + for _, cert := range keypPair.Certificate { + if err = sd.AddCertificate(cert); err != nil { + return + } + } + + signer := keypPair.PrivateKey.(crypto.Signer) + + cert := keypPair.Leaf + + ias, err := NewIssuerAndSerialNumber(cert) + if err != nil { + return err + } + + sid := SignerIdentifier{ias, nil} + + digestAlgorithm := digestAlgorithmForPublicKey(cert.PublicKey) + signatureAlgorithm, ok := oid.PublicKeyAlgorithmToSignatureAlgorithm[keypPair.Leaf.PublicKeyAlgorithm] + if !ok { + return errors.New("unsupported certificate public key algorithm") + } + + si := SignerInfo{ + Version: 1, + SID: sid, + DigestAlgorithm: digestAlgorithm, + SignedAttrs: nil, + SignatureAlgorithm: signatureAlgorithm, + Signature: nil, + UnsignedAttrs: nil, + } + + // Get the message + content := sd.EncapContentInfo.EContent + if err != nil { + return err + } + if content == nil { + return errors.New("already detached") + } + + // Digest the message. + hash, err := si.Hash() + if err != nil { + return err + } + md := hash.New() + if _, err = md.Write(content); err != nil { + return err + } + + // Build our SignedAttributes + mdAttr, err := NewAttribute(oid.AttributeMessageDigest, md.Sum(nil)) + if err != nil { + return err + } + ctAttr, err := NewAttribute(oid.AttributeContentType, sd.EncapContentInfo.EContentType) + if err != nil { + return err + } + sTAttr, err := NewAttribute(oid.AttributeSigningTime, time.Now()) + if err != nil { + return err + } + si.SignedAttrs = append(si.SignedAttrs, mdAttr, ctAttr, sTAttr) + + sm, err := asn.MarshalWithParams(si.SignedAttrs, `set`) + if err != nil { + return err + } + + smd := hash.New() + if _, errr := smd.Write(sm); errr != nil { + return errr + } + if si.Signature, err = signer.Sign(rand.Reader, smd.Sum(nil), hash); err != nil { + return err + } + + sd.addDigestAlgorithm(si.DigestAlgorithm) + + sd.SignerInfos = append(sd.SignerInfos, si) + + return nil +} + +// algorithmsForPublicKey takes an opinionated stance on what algorithms to use +// for the given public key. +func digestAlgorithmForPublicKey(pub crypto.PublicKey) pkix.AlgorithmIdentifier { + if ecPub, ok := pub.(*ecdsa.PublicKey); ok { + switch ecPub.Curve { + case elliptic.P384(): + return pkix.AlgorithmIdentifier{Algorithm: oid.DigestAlgorithmSHA384} + case elliptic.P521(): + return pkix.AlgorithmIdentifier{Algorithm: oid.DigestAlgorithmSHA512} + } + } + + return pkix.AlgorithmIdentifier{Algorithm: oid.DigestAlgorithmSHA256} +} + +// ClearCertificates removes all certificates. +func (sd *SignedData) ClearCertificates() { + sd.Certificates = []asn1.RawValue{} +} + +// AddCertificate adds a *x509.Certificate. +func (sd *SignedData) AddCertificate(cert []byte) error { + for _, existing := range sd.Certificates { + if bytes.Equal(existing.Bytes, cert) { + return errors.New("certificate already added") + } + } + + var rv asn1.RawValue + if _, err := asn.Unmarshal(cert, &rv); err != nil { + return err + } + + sd.Certificates = append(sd.Certificates, rv) + + return nil +} + +// addDigestAlgorithm adds a new AlgorithmIdentifier if it doesn't exist yet. +func (sd *SignedData) addDigestAlgorithm(algo pkix.AlgorithmIdentifier) { + for _, existing := range sd.DigestAlgorithms { + if existing.Algorithm.Equal(algo.Algorithm) { + return + } + } + + sd.DigestAlgorithms = append(sd.DigestAlgorithms, algo) +} + +// X509Certificates gets the certificates, assuming that they're X.509 encoded. +func (sd *SignedData) X509Certificates() (map[string]*x509.Certificate, error) { + // Certificates field is optional. Handle missing value. + if sd.Certificates == nil { + return nil, nil + } + + certs := map[string]*x509.Certificate{} + + // Empty set + if len(sd.Certificates) == 0 { + return certs, nil + } + + for _, raw := range sd.Certificates { + if raw.Class != asn1.ClassUniversal || raw.Tag != asn1.TagSequence { + return nil, ErrUnsupported + } + + x509, err := x509.ParseCertificate(raw.FullBytes) + if err != nil { + return nil, err + } + iasString, err := IASstring(x509) + certs[iasString] = x509 + if err != nil { + return nil, err + } + + } + return certs, nil +} + +// ContentInfo returns the SignedData wrapped in a ContentInfo packet. +func (sd *SignedData) ContentInfo() (ContentInfo, error) { + var nilCI ContentInfo + + der, err := asn.Marshal(*sd) + if err != nil { + return nilCI, err + } + + return ContentInfo{ + ContentType: oid.SignedData, + Content: asn1.RawValue{ + Class: asn1.ClassContextSpecific, + Tag: 0, + Bytes: der, + IsCompound: true, + }, + }, nil + +} + +// Verify checks the signature +func (sd *SignedData) Verify(Opts x509.VerifyOptions, detached []byte) (chains [][][]*x509.Certificate, err error) { + certs, _ := sd.X509Certificates() + + opts := Opts + + for _, c := range certs { + opts.Intermediates.AddCert(c) + intermediates, fetchErr := fetchIntermediates(c.IssuingCertificateURL) + for _, e := range fetchErr { + fmt.Printf("Error while fetching intermediates: %s\n", e) + } + + for _, i := range intermediates { + opts.Intermediates.AddCert(i) + } + } + + eContent := detached + if eContent == nil { + eContent = sd.EncapContentInfo.EContent + } + + for _, signer := range sd.SignerInfos { + //Find and check signer Certificate: + sidxxx, _ := signer.SID.IAS.RawValue() + sid := fmt.Sprintf("%x", sidxxx.Bytes) + + cert, exist := certs[sid] + + if !exist { + err = errors.New("Could not find a Certificate for signer with sid : " + sid) + return + } + + var chain [][]*x509.Certificate + chain, err = cert.Verify(Opts) + if err != nil { + return + } + + signedMessage := eContent + if signer.SignedAttrs != nil { + + //Hash message: + var hash crypto.Hash + hash, err = signer.Hash() + if err != nil { + return nil, err + } + md := hash.New() + + _, err = md.Write(eContent) + if err != nil { + return nil, err + } + h := md.Sum(nil) + + var messageDigestAttr []byte + messageDigestAttr, err = signer.GetMessageDigestAttribute() + if err != nil { + return + } + + if !bytes.Equal(messageDigestAttr, h) { + err = errors.New("Signed hash does not match the hash of the message") + return + } + + signedMessage, err = asn.MarshalWithParams(signer.SignedAttrs, `set`) + if err != nil { + return + } + } + + err = cert.CheckSignature(signer.X509SignatureAlgorithm(), signedMessage, signer.Signature) + if err != nil { + return + } + chains = append(chains, chain) + } + + return +} + +func fetchIntermediates(urls []string) (certificates []*x509.Certificate, errs []error) { + for _, url := range urls { + var resp *http.Response + resp, err := http.Get(url) + if err != nil { + errs = append(errs, err) + continue + } + defer resp.Body.Close() + + issuerBytes, err := ioutil.ReadAll(resp.Body) + if err != nil { + errs = append(errs, err) + continue + } + + issuerCert, err := x509.ParseCertificate(issuerBytes) + if err != nil { + errs = append(errs, err) + continue + } + + //Prevent infinite loop + if len(certificates) > 50 { + err = errors.New("To many issuers") + errs = append(errs, err) + return + } + certificates = append(certificates, issuerCert) + + //Recusively fetch issuers + issuers, fetchErrs := fetchIntermediates(issuerCert.IssuingCertificateURL) + certificates = append(certificates, issuers...) + errs = append(errs, fetchErrs...) + } + return +} diff --git a/cms/protocol/signerinfo.go b/cms/protocol/signerinfo.go new file mode 100644 index 0000000..f61e9d1 --- /dev/null +++ b/cms/protocol/signerinfo.go @@ -0,0 +1,153 @@ +package protocol + +import ( + "bytes" + "crypto" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "time" + + asn "github.com/InfiniteLoopSpace/go_S-MIME/asn1" + oid "github.com/InfiniteLoopSpace/go_S-MIME/oid" +) + +// SignerInfo ::= SEQUENCE { +// version CMSVersion, +// sid SignerIdentifier, +// digestAlgorithm DigestAlgorithmIdentifier, +// signedAttrs [0] IMPLICIT SignedAttributes OPTIONAL, +// signatureAlgorithm SignatureAlgorithmIdentifier, +// signature SignatureValue, +// unsignedAttrs [1] IMPLICIT UnsignedAttributes OPTIONAL } +type SignerInfo struct { + Version int `` // CMSVersion ::= INTEGER { v0(0), v1(1), v2(2), v3(3), v4(4), v5(5) } + SID SignerIdentifier `asn1:"choice"` // + DigestAlgorithm pkix.AlgorithmIdentifier `` // DigestAlgorithmIdentifier ::= AlgorithmIdentifier + SignedAttrs []Attribute `asn1:"set,optional,tag:0"` // SignedAttributes ::= SET SIZE (1..MAX) OF Attribute + SignatureAlgorithm pkix.AlgorithmIdentifier `` // SignatureAlgorithmIdentifier ::= AlgorithmIdentifier + Signature []byte `` // SignatureValue ::= OCTET STRING + UnsignedAttrs []Attribute `asn1:"set,optional,tag:1"` // UnsignedAttributes ::= SET SIZE (1..MAX) OF Attribute +} + +//SignerIdentifier ::= CHOICE { +// issuerAndSerialNumber IssuerAndSerialNumber, +// subjectKeyIdentifier [0] SubjectKeyIdentifier } +type SignerIdentifier struct { + IAS IssuerAndSerialNumber `asn1:"optional"` + SKI []byte `asn1:"optional,tag:0"` +} + +// FindCertificate finds this SignerInfo's certificate in a slice of +// certificates. +func (si SignerInfo) FindCertificate(certs []*x509.Certificate) (*x509.Certificate, error) { + switch si.Version { + case 1: // SID is issuer and serial number + isn := si.SID.IAS + + for _, cert := range certs { + if bytes.Equal(cert.RawIssuer, isn.Issuer.FullBytes) && isn.SerialNumber.Cmp(cert.SerialNumber) == 0 { + return cert, nil + } + } + case 3: // SID is SubjectKeyIdentifier + ski := si.SID.SKI + + for _, cert := range certs { + for _, ext := range cert.Extensions { + if oid.SubjectKeyIdentifier.Equal(ext.Id) { + if bytes.Equal(ski, ext.Value) { + return cert, nil + } + } + } + } + default: + return nil, ErrUnsupported + } + + return nil, ErrNoCertificate +} + +// Hash gets the crypto.Hash associated with this SignerInfo's DigestAlgorithm. +// 0 is returned for unrecognized algorithms. +func (si SignerInfo) Hash() (crypto.Hash, error) { + algo := si.DigestAlgorithm.Algorithm.String() + hash := oid.DigestAlgorithmToHash[algo] + if hash == 0 || !hash.Available() { + return 0, ErrUnsupported + } + + return hash, nil +} + +// X509SignatureAlgorithm gets the x509.SignatureAlgorithm that should be used +// for verifying this SignerInfo's signature. +func (si SignerInfo) X509SignatureAlgorithm() x509.SignatureAlgorithm { + var ( + sigOID = si.SignatureAlgorithm.Algorithm.String() + digestOID = si.DigestAlgorithm.Algorithm.String() + ) + + return oid.SignatureAlgorithms[sigOID][digestOID] +} + +// GetContentTypeAttribute gets the signed ContentType attribute from the +// SignerInfo. +func (si SignerInfo) GetContentTypeAttribute() (asn1.ObjectIdentifier, error) { + var sa Attributes + sa = si.SignedAttrs + rv, err := sa.GetOnlyAttributeValueBytes(oid.AttributeContentType) + if err != nil { + return nil, err + } + + var ct asn1.ObjectIdentifier + if rest, err := asn.Unmarshal(rv.FullBytes, &ct); err != nil { + return nil, err + } else if len(rest) > 0 { + return nil, ErrTrailingData + } + + return ct, nil +} + +// GetMessageDigestAttribute gets the signed MessageDigest attribute from the +// SignerInfo. +func (si SignerInfo) GetMessageDigestAttribute() ([]byte, error) { + var sa Attributes + sa = si.SignedAttrs + rv, err := sa.GetOnlyAttributeValueBytes(oid.AttributeMessageDigest) + if err != nil { + return nil, err + } + if rv.Class != asn1.ClassUniversal || rv.Tag != asn1.TagOctetString { + return nil, ASN1Error{"bad class or tag"} + } + + return rv.Bytes, nil +} + +// GetSigningTimeAttribute gets the signed SigningTime attribute from the +// SignerInfo. +func (si SignerInfo) GetSigningTimeAttribute() (time.Time, error) { + var t time.Time + + var sa Attributes + sa = si.SignedAttrs + rv, err := sa.GetOnlyAttributeValueBytes(oid.AttributeSigningTime) + if err != nil { + return t, err + } + if rv.Class != asn1.ClassUniversal || (rv.Tag != asn1.TagUTCTime && rv.Tag != asn1.TagGeneralizedTime) { + return t, ASN1Error{"bad class or tag"} + } + + if rest, err := asn.Unmarshal(rv.FullBytes, &t); err != nil { + return t, err + } else if len(rest) > 0 { + return t, ErrTrailingData + } + + return t, nil +}