Skip to content

Commit ac003ce

Browse files
corybcode-asher
authored andcommitted
feat: add configuration options to support mtls
adding options to support mtls with the coder server. This supports adding PEM certs and keys to the tls requests, and also supports adding a CA cert to the trust store. Also allowing for an alternate hostname that may appear in the certs which is useful for testing or for non-standard cert usage.
1 parent 5e55049 commit ac003ce

11 files changed

+327
-37
lines changed

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ gradleVersion=7.4
2929
# Opt-out flag for bundling Kotlin standard library.
3030
# See https://plugins.jetbrains.com/docs/intellij/kotlin.html#kotlin-standard-library for details.
3131
# suppress inspection "UnusedProperty"
32-
kotlin.stdlib.default.dependency=false
32+
kotlin.stdlib.default.dependency=true

src/main/kotlin/com/coder/gateway/CoderGatewayConnectionProvider.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ class CoderGatewayConnectionProvider : GatewayConnectionProvider {
140140
if (token == null) { // User aborted.
141141
throw IllegalArgumentException("Unable to connect to $deploymentURL, $TOKEN is missing")
142142
}
143-
val client = CoderRestClient(deploymentURL, token.first, settings.headerCommand, null)
143+
val client = CoderRestClient(deploymentURL, token.first,null, settings)
144144
return try {
145145
Pair(client, client.me().username)
146146
} catch (ex: AuthenticationResponseException) {

src/main/kotlin/com/coder/gateway/CoderSettingsConfigurable.kt

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class CoderSettingsConfigurable : BoundConfigurable("Coder") {
3939
.comment(
4040
CoderGatewayBundle.message(
4141
"gateway.connector.settings.binary-source.comment",
42-
CoderCLIManager(URL("http://localhost"), CoderCLIManager.getDataDir()).remoteBinaryURL.path,
42+
CoderCLIManager(state, URL("http://localhost"), CoderCLIManager.getDataDir()).remoteBinaryURL.path,
4343
)
4444
)
4545
}.layout(RowLayout.PARENT_GRID)
@@ -73,6 +73,34 @@ class CoderSettingsConfigurable : BoundConfigurable("Coder") {
7373
CoderGatewayBundle.message("gateway.connector.settings.header-command.comment")
7474
)
7575
}.layout(RowLayout.PARENT_GRID)
76+
row(CoderGatewayBundle.message("gateway.connector.settings.tls-cert-path.title")) {
77+
textField().resizableColumn().align(AlignX.FILL)
78+
.bindText(state::tlsCertPath)
79+
.comment(
80+
CoderGatewayBundle.message("gateway.connector.settings.tls-cert-path.comment")
81+
)
82+
}.layout(RowLayout.PARENT_GRID)
83+
row(CoderGatewayBundle.message("gateway.connector.settings.tls-key-path.title")) {
84+
textField().resizableColumn().align(AlignX.FILL)
85+
.bindText(state::tlsKeyPath)
86+
.comment(
87+
CoderGatewayBundle.message("gateway.connector.settings.tls-key-path.comment")
88+
)
89+
}.layout(RowLayout.PARENT_GRID)
90+
row(CoderGatewayBundle.message("gateway.connector.settings.tls-ca-path.title")) {
91+
textField().resizableColumn().align(AlignX.FILL)
92+
.bindText(state::tlsCAPath)
93+
.comment(
94+
CoderGatewayBundle.message("gateway.connector.settings.tls-ca-path.comment")
95+
)
96+
}.layout(RowLayout.PARENT_GRID)
97+
row(CoderGatewayBundle.message("gateway.connector.settings.tls-alt-name.title")) {
98+
textField().resizableColumn().align(AlignX.FILL)
99+
.bindText(state::tlsAlternateHostname)
100+
.comment(
101+
CoderGatewayBundle.message("gateway.connector.settings.tls-alt-name.comment")
102+
)
103+
}.layout(RowLayout.PARENT_GRID)
76104
}
77105
}
78106

src/main/kotlin/com/coder/gateway/sdk/CoderCLIManager.kt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,15 @@ import java.nio.file.StandardCopyOption
2222
import java.security.DigestInputStream
2323
import java.security.MessageDigest
2424
import java.util.zip.GZIPInputStream
25+
import javax.net.ssl.HttpsURLConnection
2526
import javax.xml.bind.annotation.adapters.HexBinaryAdapter
2627

2728

2829
/**
2930
* Manage the CLI for a single deployment.
3031
*/
3132
class CoderCLIManager @JvmOverloads constructor(
33+
private val settings: CoderSettingsState,
3234
private val deploymentURL: URL,
3335
dataDir: Path,
3436
cliDir: Path? = null,
@@ -104,6 +106,10 @@ class CoderCLIManager @JvmOverloads constructor(
104106
conn.setRequestProperty("If-None-Match", "\"$etag\"")
105107
}
106108
conn.setRequestProperty("Accept-Encoding", "gzip")
109+
if (conn is HttpsURLConnection) {
110+
conn.sslSocketFactory = coderSocketFactory(settings)
111+
conn.hostnameVerifier = CoderHostnameVerifier(settings.tlsAlternateHostname)
112+
}
107113

108114
try {
109115
conn.connect()
@@ -463,7 +469,7 @@ class CoderCLIManager @JvmOverloads constructor(
463469
if (settings.binaryDirectory.isBlank()) null
464470
else Path.of(settings.binaryDirectory).toAbsolutePath()
465471

466-
val cli = CoderCLIManager(deploymentURL, dataDir, binDir, settings.binarySource)
472+
val cli = CoderCLIManager(settings, deploymentURL, dataDir, binDir, settings.binarySource)
467473

468474
// Short-circuit if we already have the expected version. This
469475
// lets us bypass the 304 which is slower and may not be
@@ -490,7 +496,7 @@ class CoderCLIManager @JvmOverloads constructor(
490496
}
491497

492498
// Try falling back to the data directory.
493-
val dataCLI = CoderCLIManager(deploymentURL, dataDir, null, settings.binarySource)
499+
val dataCLI = CoderCLIManager(settings, deploymentURL, dataDir, null, settings.binarySource)
494500
val dataCLIMatches = dataCLI.matchesVersion(buildVersion)
495501
if (dataCLIMatches == true) {
496502
return dataCLI

src/main/kotlin/com/coder/gateway/sdk/CoderRestClientService.kt

Lines changed: 237 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,48 @@ import com.coder.gateway.sdk.v2.models.Workspace
1414
import com.coder.gateway.sdk.v2.models.WorkspaceBuild
1515
import com.coder.gateway.sdk.v2.models.WorkspaceTransition
1616
import com.coder.gateway.sdk.v2.models.toAgentModels
17+
import com.coder.gateway.services.CoderSettingsState
1718
import com.google.gson.Gson
1819
import com.google.gson.GsonBuilder
1920
import com.intellij.ide.plugins.PluginManagerCore
2021
import com.intellij.openapi.components.Service
2122
import com.intellij.openapi.extensions.PluginId
2223
import com.intellij.openapi.util.SystemInfo
2324
import okhttp3.OkHttpClient
25+
import okhttp3.internal.tls.OkHostnameVerifier
2426
import okhttp3.logging.HttpLoggingInterceptor
2527
import org.zeroturnaround.exec.ProcessExecutor
2628
import retrofit2.Retrofit
2729
import retrofit2.converter.gson.GsonConverterFactory
30+
import java.io.File
31+
import java.io.FileInputStream
2832
import java.net.HttpURLConnection.HTTP_CREATED
33+
import java.net.InetAddress
34+
import java.net.Socket
2935
import java.net.URL
36+
import java.nio.file.Path
37+
import java.security.KeyFactory
38+
import java.security.KeyStore
39+
import java.security.PrivateKey
40+
import java.security.cert.CertificateException
41+
import java.security.cert.CertificateFactory
42+
import java.security.cert.X509Certificate
43+
import java.security.spec.InvalidKeySpecException
44+
import java.security.spec.PKCS8EncodedKeySpec
3045
import java.time.Instant
46+
import java.util.Base64
47+
import java.util.Locale
3148
import java.util.UUID
49+
import javax.net.ssl.HostnameVerifier
50+
import javax.net.ssl.KeyManagerFactory
51+
import javax.net.ssl.SNIHostName
52+
import javax.net.ssl.SSLContext
53+
import javax.net.ssl.SSLSession
54+
import javax.net.ssl.SSLSocket
55+
import javax.net.ssl.SSLSocketFactory
56+
import javax.net.ssl.TrustManagerFactory
57+
import javax.net.ssl.TrustManager
58+
import javax.net.ssl.X509TrustManager
3259

3360
@Service(Service.Level.APP)
3461
class CoderRestClientService {
@@ -44,18 +71,19 @@ class CoderRestClientService {
4471
*
4572
* @throws [AuthenticationResponseException] if authentication failed.
4673
*/
47-
fun initClientSession(url: URL, token: String, headerCommand: String?): User {
48-
client = CoderRestClient(url, token, headerCommand, null)
74+
fun initClientSession(url: URL, token: String, settings: CoderSettingsState): User {
75+
client = CoderRestClient(url, token, null, settings)
4976
me = client.me()
5077
buildVersion = client.buildInfo().version
5178
isReady = true
5279
return me
5380
}
5481
}
5582

56-
class CoderRestClient(var url: URL, var token: String,
57-
private var headerCommand: String?,
83+
class CoderRestClient(
84+
var url: URL, var token: String,
5885
private var pluginVersion: String?,
86+
private var settings: CoderSettingsState,
5987
) {
6088
private var httpClient: OkHttpClient
6189
private var retroRestClient: CoderV2RestFacade
@@ -66,12 +94,16 @@ class CoderRestClient(var url: URL, var token: String,
6694
pluginVersion = PluginManagerCore.getPlugin(PluginId.getId("com.coder.gateway"))!!.version // this is the id from the plugin.xml
6795
}
6896

97+
val socketFactory = coderSocketFactory(settings)
98+
val trustManagers = coderTrustManagers(settings.tlsCAPath)
6999
httpClient = OkHttpClient.Builder()
100+
.sslSocketFactory(socketFactory, trustManagers[0] as X509TrustManager)
101+
.hostnameVerifier(CoderHostnameVerifier(settings.tlsAlternateHostname))
70102
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("Coder-Session-Token", token).build()) }
71103
.addInterceptor { it.proceed(it.request().newBuilder().addHeader("User-Agent", "Coder Gateway/${pluginVersion} (${SystemInfo.getOsNameAndVersion()}; ${SystemInfo.OS_ARCH})").build()) }
72104
.addInterceptor {
73105
var request = it.request()
74-
val headers = getHeaders(url, headerCommand)
106+
val headers = getHeaders(url, settings.headerCommand)
75107
if (headers.size > 0) {
76108
val builder = request.newBuilder()
77109
headers.forEach { h -> builder.addHeader(h.key, h.value) }
@@ -218,3 +250,203 @@ class CoderRestClient(var url: URL, var token: String,
218250
}
219251
}
220252
}
253+
254+
fun coderSocketFactory(settings: CoderSettingsState) : SSLSocketFactory {
255+
if (settings.tlsCertPath.isBlank() || settings.tlsKeyPath.isBlank()) {
256+
return SSLSocketFactory.getDefault() as SSLSocketFactory
257+
}
258+
259+
val certificateFactory = CertificateFactory.getInstance("X.509")
260+
val certInputStream = FileInputStream(expandPath(settings.tlsCertPath))
261+
val certChain = certificateFactory.generateCertificates(certInputStream)
262+
certInputStream.close()
263+
264+
// ideally we would use something like PemReader from BouncyCastle, but
265+
// BC is used by the IDE. This makes using BC very impractical since
266+
// type casting will mismatch due to the different class loaders.
267+
val privateKeyPem = File(expandPath(settings.tlsKeyPath)).readText()
268+
val start: Int = privateKeyPem.indexOf("-----BEGIN PRIVATE KEY-----")
269+
val end: Int = privateKeyPem.indexOf("-----END PRIVATE KEY-----", start)
270+
val pemBytes: ByteArray = Base64.getDecoder().decode(
271+
privateKeyPem.substring(start + "-----BEGIN PRIVATE KEY-----".length, end)
272+
.replace("\\s+".toRegex(), "")
273+
)
274+
275+
var privateKey : PrivateKey
276+
try {
277+
val kf = KeyFactory.getInstance("RSA")
278+
val keySpec = PKCS8EncodedKeySpec(pemBytes)
279+
privateKey = kf.generatePrivate(keySpec)
280+
} catch (e: InvalidKeySpecException) {
281+
val kf = KeyFactory.getInstance("EC")
282+
val keySpec = PKCS8EncodedKeySpec(pemBytes)
283+
privateKey = kf.generatePrivate(keySpec)
284+
}
285+
286+
val keyStore = KeyStore.getInstance(KeyStore.getDefaultType())
287+
keyStore.load(null)
288+
certChain.withIndex().forEach {
289+
keyStore.setCertificateEntry("cert${it.index}", it.value as X509Certificate)
290+
}
291+
keyStore.setKeyEntry("key", privateKey, null, certChain.toTypedArray())
292+
293+
val keyManagerFactory = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())
294+
keyManagerFactory.init(keyStore, null)
295+
296+
val sslContext = SSLContext.getInstance("TLS")
297+
298+
val trustManagers = coderTrustManagers(settings.tlsCAPath)
299+
sslContext.init(keyManagerFactory.keyManagers, trustManagers, null)
300+
301+
if (settings.tlsAlternateHostname.isBlank()) {
302+
return sslContext.socketFactory
303+
}
304+
305+
return AlternateNameSSLSocketFactory(sslContext.socketFactory, settings.tlsAlternateHostname)
306+
}
307+
308+
fun coderTrustManagers(tlsCAPath: String) : Array<TrustManager> {
309+
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
310+
if (tlsCAPath.isBlank()) {
311+
// return default trust managers
312+
trustManagerFactory.init(null as KeyStore?)
313+
return trustManagerFactory.trustManagers
314+
}
315+
316+
317+
val certificateFactory = CertificateFactory.getInstance("X.509")
318+
val caInputStream = FileInputStream(expandPath(tlsCAPath))
319+
val certChain = certificateFactory.generateCertificates(caInputStream)
320+
321+
val truststore = KeyStore.getInstance(KeyStore.getDefaultType())
322+
truststore.load(null)
323+
certChain.withIndex().forEach {
324+
truststore.setCertificateEntry("cert${it.index}", it.value as X509Certificate)
325+
}
326+
trustManagerFactory.init(truststore)
327+
return trustManagerFactory.trustManagers.map { MergedSystemTrustManger(it as X509TrustManager) }.toTypedArray()
328+
}
329+
330+
fun expandPath(path: String): String {
331+
if (path.startsWith("~/")) {
332+
return Path.of(System.getProperty("user.home"), path.substring(1)).toString()
333+
}
334+
if (path.startsWith("\$HOME/")) {
335+
return Path.of(System.getProperty("user.home"), path.substring(5)).toString()
336+
}
337+
if (path.startsWith("\${user.home}/")) {
338+
return Path.of(System.getProperty("user.home"), path.substring(12)).toString()
339+
}
340+
return path
341+
}
342+
343+
class AlternateNameSSLSocketFactory(private val delegate: SSLSocketFactory, private val alternateName: String) : SSLSocketFactory() {
344+
override fun getDefaultCipherSuites(): Array<String> {
345+
return delegate.defaultCipherSuites
346+
}
347+
348+
override fun getSupportedCipherSuites(): Array<String> {
349+
return delegate.supportedCipherSuites
350+
}
351+
352+
override fun createSocket(): Socket {
353+
val socket = delegate.createSocket() as SSLSocket
354+
customizeSocket(socket)
355+
return socket
356+
}
357+
358+
override fun createSocket(host: String?, port: Int): Socket {
359+
val socket = delegate.createSocket(host, port) as SSLSocket
360+
customizeSocket(socket)
361+
return socket
362+
}
363+
364+
override fun createSocket(host: String?, port: Int, localHost: InetAddress?, localPort: Int): Socket {
365+
val socket = delegate.createSocket(host, port, localHost, localPort) as SSLSocket
366+
customizeSocket(socket)
367+
return socket
368+
}
369+
370+
override fun createSocket(host: InetAddress?, port: Int): Socket {
371+
val socket = delegate.createSocket(host, port) as SSLSocket
372+
customizeSocket(socket)
373+
return socket
374+
}
375+
376+
override fun createSocket(address: InetAddress?, port: Int, localAddress: InetAddress?, localPort: Int): Socket {
377+
val socket = delegate.createSocket(address, port, localAddress, localPort) as SSLSocket
378+
customizeSocket(socket)
379+
return socket
380+
}
381+
382+
override fun createSocket(s: Socket?, host: String?, port: Int, autoClose: Boolean): Socket {
383+
val socket = delegate.createSocket(s, host, port, autoClose) as SSLSocket
384+
customizeSocket(socket)
385+
return socket
386+
}
387+
388+
private fun customizeSocket(socket: SSLSocket) {
389+
val params = socket.sslParameters
390+
params.serverNames = listOf(SNIHostName(alternateName))
391+
socket.sslParameters = params
392+
}
393+
}
394+
395+
class CoderHostnameVerifier(private val alternateName: String) : HostnameVerifier {
396+
override fun verify(host: String, session: SSLSession): Boolean {
397+
if (alternateName.isEmpty()) {
398+
println("using default hostname verifier, alternateName is empty")
399+
return OkHostnameVerifier.verify(host, session)
400+
}
401+
println("Looking for alternate hostname: $alternateName")
402+
val certs = session.peerCertificates ?: return false
403+
for (cert in certs) {
404+
if (cert !is X509Certificate) {
405+
continue
406+
}
407+
val entries = cert.subjectAlternativeNames ?: continue
408+
for (entry in entries) {
409+
val kind = entry[0] as Int
410+
if (kind != 2) { // DNS Name
411+
continue
412+
}
413+
val hostname = entry[1] as String
414+
println("Found cert hostname: $hostname")
415+
if (hostname.lowercase(Locale.getDefault()) == alternateName) {
416+
return true
417+
}
418+
}
419+
}
420+
println("No matching hostname found")
421+
return false
422+
}
423+
}
424+
425+
class MergedSystemTrustManger(private val otherTrustManager: X509TrustManager) : X509TrustManager {
426+
private val systemTrustManager : X509TrustManager
427+
init {
428+
val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())
429+
trustManagerFactory.init(null as KeyStore?)
430+
systemTrustManager = trustManagerFactory.trustManagers.first { it is X509TrustManager } as X509TrustManager
431+
}
432+
433+
override fun checkClientTrusted(chain: Array<out X509Certificate>, authType: String?) {
434+
try {
435+
otherTrustManager.checkClientTrusted(chain, authType)
436+
} catch (e: CertificateException) {
437+
systemTrustManager.checkClientTrusted(chain, authType)
438+
}
439+
}
440+
441+
override fun checkServerTrusted(chain: Array<out X509Certificate>, authType: String?) {
442+
try {
443+
otherTrustManager.checkServerTrusted(chain, authType)
444+
} catch (e: CertificateException) {
445+
systemTrustManager.checkServerTrusted(chain, authType)
446+
}
447+
}
448+
449+
override fun getAcceptedIssuers(): Array<X509Certificate> {
450+
return otherTrustManager.acceptedIssuers + systemTrustManager.acceptedIssuers
451+
}
452+
}

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy