Skip to content

Commit eabaf46

Browse files
hoist PEM reading function
Signed-off-by: eternal-flame-AD <[email protected]>
1 parent 12f5aeb commit eabaf46

File tree

3 files changed

+67
-42
lines changed

3 files changed

+67
-42
lines changed

v2/shim_v1.go

Lines changed: 18 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import (
1010
"encoding/json"
1111
"encoding/pem"
1212
"errors"
13-
"io"
1413
"log"
1514
"math"
1615
"net/http"
@@ -120,35 +119,28 @@ func NewCompatV1Rpc(compatV1 *CompatV1, cliArgs []string) (*CompatV1Shim, error)
120119
return nil, err
121120
}
122121

123-
var certBytes []byte
124122
var certificateChain []tls.Certificate
125-
for {
126-
var buf [2048]byte
127-
n, err := cliFlags.KexRespFile.Read(buf[:])
128-
if err != nil {
129-
if err == io.EOF {
130-
break
131-
}
132-
return nil, err
133-
}
134-
certBytes = append(certBytes, buf[:n]...)
135123

136-
for block, rest := pem.Decode(certBytes); block != nil; block, rest = pem.Decode(rest) {
137-
if block.Type == "CERTIFICATE" {
138-
parsedCert, err := x509.ParseCertificate(block.Bytes)
139-
if err != nil {
140-
return nil, err
141-
}
142-
// Server signs with IsCA=false, so we can add all of them to the root CA pool without
143-
// trusting things we shouldn't.
144-
rootCAs.AddCert(parsedCert)
145-
certificateChain = append(certificateChain, tls.Certificate{
146-
Certificate: [][]byte{block.Bytes},
147-
Leaf: parsedCert,
148-
})
124+
if err := transport.IteratePEMFile(cliFlags.KexRespFile, func(block *pem.Block) (continueIterate bool, err error) {
125+
if block.Type == "CERTIFICATE" {
126+
parsedCert, err := x509.ParseCertificate(block.Bytes)
127+
if err != nil {
128+
return false, err
149129
}
150-
certBytes = rest
130+
rootCAs.AddCert(parsedCert)
131+
certificateChain = append(certificateChain, tls.Certificate{
132+
Certificate: [][]byte{block.Bytes},
133+
Leaf: parsedCert,
134+
})
135+
return true, nil
151136
}
137+
return true, nil
138+
}); err != nil {
139+
return nil, err
140+
}
141+
142+
if len(certificateChain) == 0 {
143+
return nil, errors.New("no certificate chain found in kex response file")
152144
}
153145

154146
certificateChain[0].PrivateKey = priv

v2/transport/pem_file.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
package transport
2+
3+
import (
4+
"encoding/pem"
5+
"io"
6+
)
7+
8+
func IteratePEMFile(r io.Reader, callback func(block *pem.Block) (continueIterate bool, err error)) error {
9+
var bufferBytes []byte
10+
for {
11+
var buf [2048]byte
12+
n, err := r.Read(buf[:])
13+
if err != nil {
14+
if err == io.EOF {
15+
break
16+
}
17+
return err
18+
}
19+
bufferBytes = append(bufferBytes, buf[:n]...)
20+
21+
for block, rest := pem.Decode(bufferBytes); block != nil; block, rest = pem.Decode(rest) {
22+
continueIterate, err := callback(block)
23+
if err != nil {
24+
return err
25+
}
26+
if !continueIterate {
27+
return nil
28+
}
29+
}
30+
}
31+
return nil
32+
}

v2/transport/transport_auth.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"crypto/x509/pkix"
99
"encoding/hex"
1010
"encoding/pem"
11+
"errors"
1112
"fmt"
1213
"io"
1314
"slices"
@@ -102,27 +103,25 @@ func (s *EphemeralTLSClient) SignPluginCSR(moduleName string, csr *x509.Certific
102103

103104
func (s *EphemeralTLSClient) Kex(req io.Reader, resp io.Writer) error {
104105
var csr *x509.CertificateRequest
105-
var csrBytes []byte
106-
for csr == nil {
107-
var buf [2048]byte
108-
n, err := req.Read(buf[:])
109-
if err != nil {
110-
return err
111-
}
112-
csrBytes = append(csrBytes, buf[:n]...)
113-
block, _ := pem.Decode(csrBytes)
114-
if block == nil {
115-
continue
116-
}
117106

107+
if err := IteratePEMFile(req, func(block *pem.Block) (continueIterate bool, err error) {
118108
if block.Type == "CERTIFICATE REQUEST" {
119-
csrParsed, err := x509.ParseCertificateRequest(block.Bytes)
109+
csr, err = x509.ParseCertificateRequest(block.Bytes)
120110
if err != nil {
121-
return err
111+
return false, err
122112
}
123-
csr = csrParsed
113+
114+
return false, nil
124115
}
116+
return true, nil
117+
}); err != nil {
118+
return err
119+
}
120+
121+
if csr == nil {
122+
return errors.New("no certificate request found in kex request file")
125123
}
124+
126125
dnsName := csr.Subject.CommonName
127126
certBytes, err := s.SignCSR(dnsName, csr)
128127
if err != nil {
@@ -165,6 +164,9 @@ func NewEphemeralTLSClient() (*EphemeralTLSClient, error) {
165164
IsCA: true,
166165
}
167166
caCertBytes, err := x509.CreateCertificate(rand.Reader, caCertTemplate, caCertTemplate, caPub, caPriv)
167+
if err != nil {
168+
return nil, err
169+
}
168170
caCert, err := x509.ParseCertificate(caCertBytes)
169171
if err != nil {
170172
return nil, err
@@ -203,7 +205,6 @@ func NewEphemeralTLSClient() (*EphemeralTLSClient, error) {
203205
},
204206
{
205207
Certificate: [][]byte{caCertBytes},
206-
PrivateKey: caPriv,
207208
},
208209
},
209210
RootCAs: certPool,

0 commit comments

Comments
 (0)