OpenID Connect のメモ
Amazon Cognito が発行する ID トークンの署名を検証するコードを書きました。
OpenID Connect (OIDC) では、署名アルゴリズムとして RSA 公開鍵暗号と SHA-256 を組み合わせた RSA-SHA256 (RS256) がよく使用されています。
公開鍵の情報は、JSON Web Key (JWK) 仕様に従った JSON 形式で IdP から提供されます (この記事のコードでは `/.well-known/jwks.json` から取得します)。
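公開鍵は PEM などの一般的な形式では提供されないため、そのままでは公開鍵オブジェクト (RSAPublicKey) に復元できません。JWK の `n` (modulus) と `e` (公開指数) は Base64url でエンコードされた符号なし整数なので、これをデコードして BigInteger に変換し、RSAPublicKeySpec と KeyFactory を使って公開鍵を生成する必要があります。この記事では、JSON から RSAPublicKey オブジェクトを復元して署名を検証するコードを例示します。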
Scala の例
import java.math.BigInteger
import java.net.URI
import java.security.interfaces.RSAPublicKey
import java.security.spec.RSAPublicKeySpec
import java.security.{KeyFactory, PublicKey}
import java.util.Base64

import scala.io.Source

import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import spray.json._

/** One entry of the JWKS `keys` array, as decoded from JSON. */
private case class Key(alg: String, e: String, kid: String, kty: String, n: String, use: String)

/** The JWKS document: an object with a single `keys` array. */
private case class JwksJson(keys: Array[Key])

/** spray-json formats used to decode the JWKS document. */
private object JwksJsonProtocol extends DefaultJsonProtocol {
  implicit val keyJsonFormat: RootJsonFormat[Key] = jsonFormat6(Key)
  implicit val inputJsonFormat: RootJsonFormat[JwksJson] = jsonFormat(JwksJson, "keys")
}

/**
 * Mixin that verifies RS256-signed ID tokens against the RSA public keys
 * published at `openIdConnectUrl + "/.well-known/jwks.json"`.
 *
 * Keys are cached by `kid`; the JWKS document is downloaded again only when a
 * token references a `kid` that is not in the cache.
 */
trait OpenIDConnectGetKey {

  /** Base URL to which `/.well-known/jwks.json` is appended. */
  val openIdConnectUrl: String

  private lazy val jwksJsonUrl = new URI(openIdConnectUrl.concat("/.well-known/jwks.json"))

  // Cache of public keys by `kid`. Not synchronized: concurrent callers may
  // download the JWKS more than once.
  private var keys: Map[String, PublicKey] = Map.empty

  /**
   * Downloads the JWKS document, converts each RSA entry into a `PublicKey`,
   * merges the result into the cache and returns the whole cache.
   *
   * Entries whose `kty` is not `RSA` are skipped: they cannot be described by
   * an `RSAPublicKeySpec`.
   *
   * @return every cached public key, indexed by `kid`
   */
  def openIDConnectGetKeys: Map[String, PublicKey] = {
    import JwksJsonProtocol._
    val input = jwksJsonUrl.toURL.openStream()
    try {
      // JSON text is UTF-8; Source reads through a buffer instead of byte-by-byte.
      val body = Source.fromInputStream(input, "UTF-8").mkString
      val jwksJson = JsonParser(body).convertTo[JwksJson]
      val rsaKeyFactory = KeyFactory.getInstance("RSA")
      keys ++= jwksJson.keys.iterator.collect {
        case key if key.kty == "RSA" =>
          // `n` (modulus) and `e` (exponent) are Base64url-encoded unsigned big-endian integers.
          val modulus = new BigInteger(1, Base64.getUrlDecoder.decode(key.n))
          val publicExponent = new BigInteger(1, Base64.getUrlDecoder.decode(key.e))
          key.kid -> rsaKeyFactory.generatePublic(new RSAPublicKeySpec(modulus, publicExponent))
      }
      keys
    } finally {
      input.close()
    }
  }

  /**
   * Verifies the RS256 signature of `idToken` and returns its payload as JSON.
   *
   * @param idToken the compact-serialized JWT
   * @return the decoded payload (claims) of the token
   * @throws RuntimeException if no published key matches the token's `kid`
   * @throws com.auth0.jwt.exceptions.JWTVerificationException if decoding or verification fails
   */
  def verify(idToken: String): JsValue = {
    val jwt = JWT.decode(idToken)
    val keyId = jwt.getKeyId
    // Refresh the key set only on a cache miss (e.g. after the IdP rotates keys).
    val publicKey = keys
      .get(keyId)
      .orElse(openIDConnectGetKeys.get(keyId))
      .getOrElse(throw new RuntimeException(s"Not exists $keyId"))
    val rsaPublicKey = publicKey match {
      case rsa: RSAPublicKey => rsa
      case other => throw new RuntimeException(s"Not an RSA key $keyId: ${other.getAlgorithm}")
    }
    // Only the public key is needed to verify, so the private key is null.
    JWT.require(Algorithm.RSA256(rsaPublicKey, null)).build().verify(jwt)
    JsonParser(Base64.getUrlDecoder.decode(jwt.getPayload))
  }
}
Java の例
import com.auth0.jwt.JWT; import com.auth0.jwt.algorithms.Algorithm; import com.auth0.jwt.interfaces.DecodedJWT; import com.auth0.jwt.interfaces.Verification; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; import java.math.BigInteger; import java.net.URI; import java.net.URISyntaxException; import java.security.KeyFactory; import java.security.NoSuchAlgorithmException; import java.security.interfaces.RSAPublicKey; import java.security.spec.InvalidKeySpecException; import java.security.spec.KeySpec; import java.security.spec.RSAPublicKeySpec; import java.util.Base64; import java.util.HashMap; import java.util.Map; public interface OpenIDConnectGetKeys { Map<String, KeyItem> getKeys(); void saveKeys(Map<String, KeyItem> keys); String getOpenIdConnectUrl(); default URI getJwksJsonUrl() throws URISyntaxException { return new URI(getOpenIdConnectUrl().concat("/.well-known/jwks.json")); } default BigInteger decodeBase64UrlUInt(String value) { byte[] uintBinary = Base64.getUrlDecoder().decode(value); return new BigInteger(1, uintBinary); } default Map<String, KeyItem> openIDConnectGetKeys() throws URISyntaxException, IOException { try (InputStream input = getJwksJsonUrl().toURL().openStream()) { Map<String, KeyItem> map = new HashMap(); ByteArrayOutputStream buffer = new ByteArrayOutputStream(); while (true) { int ch = input.read(); if (ch == -1) { break; } buffer.write(ch); } ObjectMapper mapper = new ObjectMapper(); Keys keys = mapper.readValue(buffer.toByteArray(), Keys.class); for (KeyItem key: keys.keys) { map.put(key.kid, key); } return map; } } default HashMap<String, Object> verify(String idToken) throws TokenVerifyException { try { DecodedJWT jwt = JWT.decode(idToken); Map<String, KeyItem> keys = getKeys(); if (!keys.containsKey(jwt.getKeyId())) { keys = openIDConnectGetKeys(); saveKeys(keys); } KeyItem key = 
keys.get(jwt.getKeyId()); if (key != null) { BigInteger modulus = decodeBase64UrlUInt(key.n); BigInteger publicExponent = decodeBase64UrlUInt(key.e); KeySpec spec = new RSAPublicKeySpec(modulus, publicExponent); RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(spec); Algorithm algorithm = Algorithm.RSA256(publicKey, null); Verification verification = JWT.require(algorithm); verification.build().verify(jwt); ObjectMapper mapper = new ObjectMapper(); TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {}; return mapper.readValue(Base64.getUrlDecoder().decode(jwt.getPayload()), typeRef); } else { throw new TokenVerifyException(String.format("Not exists %s", jwt.getKeyId())); } } catch (IOException e) { throw new TokenVerifyException(e); } catch (NoSuchAlgorithmException e) { throw new TokenVerifyException(e); } catch (URISyntaxException e) { throw new TokenVerifyException(e); } catch (InvalidKeySpecException e) { throw new TokenVerifyException(e); } } }