OpenID Connect のメモ

Amazon Cognito が発行する ID トークンの署名を検証するコードを書きました。

OpenID Connect (OIDC) では、署名アルゴリズムとして RSA 公開鍵と SHA-256 を組み合わせた RS256 (RSASSA-PKCS1-v1_5 using SHA-256) がよく使用されています。

公開鍵の情報は、JSON Web Key (JWK) 仕様 (RFC 7517) に従って JSON 形式で IdP から提供されます。RSA 鍵の場合、モジュラス (n) と公開指数 (e) は Base64url エンコードされた符号なし整数として格納されています (RFC 7518)。

公開鍵は PEM などの一般的な形式で提供されるわけではないため、公開鍵オブジェクト (RSAPublicKey) に復元するには少し手を加える必要があります。この記事では、JSON から RSAPublicKey オブジェクトを復元し、署名を検証するコードを例示します。

Scala の例

import java.math.BigInteger
import java.net.URI
import java.security.interfaces.RSAPublicKey
import java.security.spec.RSAPublicKeySpec
import java.security.{KeyFactory, PublicKey}
import java.util.Base64

import com.auth0.jwt.JWT
import com.auth0.jwt.algorithms.Algorithm
import spray.json._

/** One entry of a JWKS `keys` array (RFC 7517).
 *
 * `alg` and `use` are OPTIONAL members in RFC 7517, so they are modelled as `Option`:
 * a key set that omits them still parses instead of failing deserialization.
 * For RSA keys, `n` (modulus) and `e` (public exponent) are Base64url-encoded
 * unsigned integers (RFC 7518).
 */
private final case class Key(alg: Option[String],
                             e: String,
                             kid: String,
                             kty: String,
                             n: String,
                             use: Option[String])
/** Top-level JWKS document: `{"keys": [...]}`. A `Vector` (not `Array`) keeps case-class equality structural. */
private final case class JwksJson(keys: Vector[Key])

/** spray-json formats used to deserialize the JWKS document. */
private object JwksJsonProtocol extends DefaultJsonProtocol {
  // Explicit result types: inferred types on implicit definitions are a resolution hazard
  // in Scala 2 and are required to be explicit in Scala 3.
  implicit val keyJsonFormat: RootJsonFormat[Key] = jsonFormat6(Key)
  implicit val inputJsonFormat: RootJsonFormat[JwksJson] = jsonFormat(JwksJson, "keys")
}

/** Mixin that fetches an IdP's JSON Web Key Set (JWKS) and verifies RS256-signed ID tokens.
 *
 * Implementers supply `openIdConnectUrl`, the base URL under which the key set is published
 * as `/.well-known/jwks.json`. Keys are cached in memory by `kid`; the key set is downloaded
 * again only when a token names a `kid` that is not in the cache.
 */
trait OpenIDConnectGetKey {
  val openIdConnectUrl: String

  /** Connect and read timeout, in milliseconds, for the JWKS download. Override to tune. */
  protected def jwksTimeoutMillis: Int = 10000

  // stripSuffix keeps a configured trailing "/" from producing "//.well-known/jwks.json".
  private lazy val jwksJsonUrl =
    new URI(openIdConnectUrl.stripSuffix("/").concat("/.well-known/jwks.json"))

  // Current key cache, replaced wholesale on each fetch.
  // NOTE(review): plain var with no synchronization; confirm whether instances are shared
  // across threads.
  private var keys: Map[String, PublicKey] = Map.empty

  /** Downloads the JWKS, rebuilds each RSA entry as a `PublicKey` and replaces the cache.
   *
   * Entries whose `kty` is not `"RSA"` are skipped: only RSA keys can be rebuilt from the
   * `n`/`e` members via `RSAPublicKeySpec`.
   *
   * @return the freshly downloaded keys, indexed by `kid`
   * @throws java.io.IOException if the key set cannot be downloaded
   */
  def openIDConnectGetKeys: Map[String, PublicKey] = {
    import JwksJsonProtocol._

    val jwksJson = JsonParser(readJwksBytes()).convertTo[JwksJson]
    val rsaFactory = KeyFactory.getInstance("RSA")
    val fetched: Map[String, PublicKey] = jwksJson.keys.collect {
      case key if key.kty == "RSA" =>
        val spec = new RSAPublicKeySpec(decodeBase64UrlUInt(key.n), decodeBase64UrlUInt(key.e))
        key.kid -> rsaFactory.generatePublic(spec)
    }.toMap
    // Replace instead of merging so keys the IdP has rotated out stop being accepted.
    keys = fetched
    fetched
  }

  /** Verifies an RS256-signed ID token and returns its payload claims as JSON.
   *
   * The signing key is looked up by the token's `kid` header; on a cache miss the JWKS is
   * downloaded once more. Issuer and audience are NOT checked here (no `withIssuer` /
   * `withAudience` on the verifier); any time-based claim checks are whatever the auth0
   * verifier applies by default — confirm against the java-jwt docs.
   * NOTE(review): every token with an unknown `kid` triggers a network fetch; consider
   * rate-limiting refreshes.
   *
   * @param idToken compact-serialized JWT
   * @return the decoded payload
   * @throws RuntimeException if no RSA key matches the token's `kid`
   */
  def verify(idToken: String): JsValue = {
    val jwt = JWT.decode(idToken)
    // A missing kid can never match, so do not hit the network for it.
    val maybePublicKey = Option(jwt.getKeyId).flatMap { kid =>
      keys.get(kid).orElse(openIDConnectGetKeys.get(kid))
    }
    val verified = maybePublicKey match {
      case Some(rsaKey: RSAPublicKey) =>
        JWT.require(Algorithm.RSA256(rsaKey, null)).build().verify(jwt)
      case _ =>
        throw new RuntimeException(s"Not exists ${jwt.getKeyId}")
    }
    JsonParser(Base64.getUrlDecoder.decode(verified.getPayload))
  }

  /** Reads the whole JWKS response body with buffered reads and bounded timeouts. */
  private def readJwksBytes(): Array[Byte] = {
    val connection = jwksJsonUrl.toURL.openConnection()
    connection.setConnectTimeout(jwksTimeoutMillis)
    connection.setReadTimeout(jwksTimeoutMillis)
    val input = connection.getInputStream
    try {
      val body = new java.io.ByteArrayOutputStream()
      val chunk = new Array[Byte](8192)
      Iterator
        .continually(input.read(chunk))
        .takeWhile(_ != -1)
        .foreach(count => body.write(chunk, 0, count))
      body.toByteArray
    } finally {
      input.close()
    }
  }

  /** Decodes a Base64url string as an unsigned big-endian integer (JWK "Base64urlUInt"). */
  private def decodeBase64UrlUInt(value: String): BigInteger =
    new BigInteger(1, Base64.getUrlDecoder.decode(value))
}

Java の例

import com.auth0.jwt.JWT;
import com.auth0.jwt.algorithms.Algorithm;
import com.auth0.jwt.interfaces.DecodedJWT;
import com.auth0.jwt.interfaces.Verification;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLConnection;
import java.security.KeyFactory;
import java.security.NoSuchAlgorithmException;
import java.security.interfaces.RSAPublicKey;
import java.security.spec.InvalidKeySpecException;
import java.security.spec.KeySpec;
import java.security.spec.RSAPublicKeySpec;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;

/**
 * Mixin interface that fetches an IdP's JSON Web Key Set (JWKS) and verifies RS256-signed
 * ID tokens against it. Implementers provide the base URL and the key cache storage
 * ({@link #getKeys()} / {@link #saveKeys(Map)}).
 */
public interface OpenIDConnectGetKeys {

    /** Returns the cached keys indexed by {@code kid}; a {@code null} result is treated as an empty cache. */
    Map<String, KeyItem> getKeys();

    /** Receives a freshly downloaded key set after a cache miss so the implementer can cache it. */
    void saveKeys(Map<String, KeyItem> keys);

    /** Base URL under which the key set is published as {@code /.well-known/jwks.json}. */
    String getOpenIdConnectUrl();

    /** Connect and read timeout, in milliseconds, for the JWKS download. Override to tune. */
    default int getJwksTimeoutMillis() {
        return 10000;
    }

    /**
     * Builds the JWKS URL from {@link #getOpenIdConnectUrl()}.
     *
     * @throws URISyntaxException if the resulting string is not a valid URI
     */
    default URI getJwksJsonUrl() throws URISyntaxException {
        String base = getOpenIdConnectUrl();
        // Drop one trailing slash so the path does not become "//.well-known/jwks.json".
        if (base.endsWith("/")) {
            base = base.substring(0, base.length() - 1);
        }
        return new URI(base.concat("/.well-known/jwks.json"));
    }

    /** Decodes a Base64url string as an unsigned big-endian integer (JWK "Base64urlUInt"). */
    default BigInteger decodeBase64UrlUInt(String value) {
        byte[] uintBinary = Base64.getUrlDecoder().decode(value);
        return new BigInteger(1, uintBinary);
    }

    /**
     * Downloads and parses the JWKS.
     *
     * @return the downloaded keys indexed by {@code kid} (empty if the document has no keys)
     * @throws URISyntaxException if the JWKS URL is malformed
     * @throws IOException if the download or JSON parsing fails
     */
    default Map<String, KeyItem> openIDConnectGetKeys() throws URISyntaxException, IOException {
        URLConnection connection = getJwksJsonUrl().toURL().openConnection();
        // Without timeouts a stalled IdP would block token verification indefinitely.
        connection.setConnectTimeout(getJwksTimeoutMillis());
        connection.setReadTimeout(getJwksTimeoutMillis());
        try (InputStream input = connection.getInputStream()) {
            ByteArrayOutputStream body = new ByteArrayOutputStream();
            byte[] chunk = new byte[8192];
            int count;
            while ((count = input.read(chunk)) != -1) {
                body.write(chunk, 0, count);
            }
            Keys keys = new ObjectMapper().readValue(body.toByteArray(), Keys.class);
            Map<String, KeyItem> map = new HashMap<>();
            if (keys != null && keys.keys != null) {
                for (KeyItem key : keys.keys) {
                    map.put(key.kid, key);
                }
            }
            return map;
        }
    }

    /**
     * Verifies an RS256-signed ID token and returns its payload claims.
     *
     * The key is looked up by the token's {@code kid}; on a cache miss the JWKS is downloaded
     * and passed to {@link #saveKeys(Map)}. Issuer and audience are NOT checked here.
     * NOTE(review): every token with an unknown {@code kid} triggers a download; consider
     * rate-limiting refreshes.
     *
     * @param idToken compact-serialized JWT
     * @return the decoded payload claims
     * @throws TokenVerifyException if no key matches, or on I/O, URI or key-construction failures
     */
    default HashMap<String, Object> verify(String idToken) throws TokenVerifyException {
        try {
            DecodedJWT jwt = JWT.decode(idToken);
            String kid = jwt.getKeyId();
            // A missing kid can never match, so do not hit the network for it.
            if (kid == null) {
                throw new TokenVerifyException(String.format("Not exists %s", kid));
            }
            Map<String, KeyItem> keys = getKeys();
            if (keys == null || !keys.containsKey(kid)) {
                keys = openIDConnectGetKeys();
                saveKeys(keys);
            }
            KeyItem key = keys.get(kid);
            if (key == null) {
                throw new TokenVerifyException(String.format("Not exists %s", kid));
            }
            KeySpec spec = new RSAPublicKeySpec(decodeBase64UrlUInt(key.n), decodeBase64UrlUInt(key.e));
            RSAPublicKey publicKey = (RSAPublicKey) KeyFactory.getInstance("RSA").generatePublic(spec);
            Verification verification = JWT.require(Algorithm.RSA256(publicKey, null));
            DecodedJWT verified = verification.build().verify(jwt);
            TypeReference<HashMap<String, Object>> typeRef = new TypeReference<HashMap<String, Object>>() {};
            return new ObjectMapper().readValue(Base64.getUrlDecoder().decode(verified.getPayload()), typeRef);
        } catch (IOException | NoSuchAlgorithmException | URISyntaxException | InvalidKeySpecException e) {
            throw new TokenVerifyException(e);
        }
    }
}