Commit 8b7fb485 authored by Mathieu's avatar Mathieu
Browse files

Encode private IP in JWTs

parent 8a5ea175
......@@ -50,18 +50,18 @@ class ConnectServlet(arguments: ConnectServer.ServletArguments) extends Scalatra
def headers(request: HttpServletRequest) = request.getHeaderNames.map { hn => hn -> request.getHeader(hn) }.toSeq
def proxyRequest(uuid: UUID) = {
withForwardRequest(uuid){ forwardRequest=>
val req = forwardRequest.withHeaders((headers(request)): _*).withHeaders(allowHeaders: _*)
val fR = waitForGet(req)
Ok(fR.body, fR.headers)
}.getOrElse(NotFound())
def proxyRequest(hostIP: Option[String]) = {
hostIP.map { hip =>
withForwardRequest(hip) { forwardRequest =>
val req = forwardRequest.withHeaders((headers(request)): _*).withHeaders(allowHeaders: _*)
val fR = waitForGet(req)
Ok(fR.body, fR.headers)
}
}
}
def withForwardRequest(uuid: UUID)(action: HttpRequest=> ActionResult): Option[ActionResult] = {
K8sService.podIP(uuid).map { podIP =>
action(baseForwardRequest.withHost(podIP).withPort(80).withPath(""))
}
def withForwardRequest(hostIP: String)(action: HttpRequest => ActionResult): ActionResult = {
action(baseForwardRequest.withHost(hostIP).withPort(80).withPath(""))
}
def withAccesToken(action: TokenData => ActionResult): Serializable = {
......@@ -71,7 +71,7 @@ class ConnectServlet(arguments: ConnectServer.ServletArguments) extends Scalatra
Authentication.isValid(request, TokenType.refreshToken) match {
case true =>
withRefreshToken { refreshToken =>
val tokenData = TokenData.accessToken(refreshToken.uuid, refreshToken.login)
val tokenData = TokenData.accessToken(refreshToken.host, refreshToken.login)
buildAndAddCookieToHeader(tokenData)
action(tokenData)
}
......@@ -89,7 +89,7 @@ class ConnectServlet(arguments: ConnectServer.ServletArguments) extends Scalatra
def connectionAppRedirection = {
withAccesToken { tokenData =>
proxyRequest(tokenData.uuid)
proxyRequest(tokenData.host.hostIP).getOrElse(NotFound())
}
}
......@@ -105,21 +105,23 @@ class ConnectServlet(arguments: ConnectServer.ServletArguments) extends Scalatra
// OM instance requests
post("/*") {
withAccesToken { tokenData =>
withForwardRequest(tokenData.uuid) { forwardRequest =>
multiParams("splat").headOption match {
case Some(path) =>
val is = request.getInputStream
val bytes: Array[Byte] = Iterator.continually(is.read()).takeWhile(_ != -1).map(_.asInstanceOf[Byte]).toArray[Byte]
val bb = ByteBuffer.wrap(bytes)
val req = waitForPost(
forwardRequest.withPath(s"/$path").withHeader("Content-Type", "application/octet-stream").withBody(ByteBufferBody(bb))
)
if (req.statusCode < 400) Ok(req.body)
else NotFound()
case None => NotFound()
tokenData.host.hostIP.map { hip =>
withForwardRequest(hip) { forwardRequest =>
multiParams("splat").headOption match {
case Some(path) =>
val is = request.getInputStream
val bytes: Array[Byte] = Iterator.continually(is.read()).takeWhile(_ != -1).map(_.asInstanceOf[Byte]).toArray[Byte]
val bb = ByteBuffer.wrap(bytes)
val req = waitForPost(
forwardRequest.withPath(s"/$path").withHeader("Content-Type", "application/octet-stream").withBody(ByteBufferBody(bb))
)
if (req.statusCode < 400) Ok(req.body)
else NotFound()
case None => NotFound()
}
}
}.getOrElse(NotFound())
}
......@@ -138,8 +140,9 @@ class ConnectServlet(arguments: ConnectServer.ServletArguments) extends Scalatra
else {
DB.uuid(DB.User(DB.Login(login), DB.Password(password))) match {
case Some(uuid) =>
buildAndAddCookieToHeader(TokenData.accessToken(uuid, DB.Login(login)))
buildAndAddCookieToHeader(TokenData.refreshToken(uuid, DB.Login(login)))
val host = Host(uuid, K8sService.hostIP(uuid))
buildAndAddCookieToHeader(TokenData.accessToken(host, DB.Login(login)))
buildAndAddCookieToHeader(TokenData.refreshToken(host, DB.Login(login)))
redirect("/")
case _ => connectionHtml
}
......@@ -169,12 +172,14 @@ class ConnectServlet(arguments: ConnectServer.ServletArguments) extends Scalatra
localPath
} else {
withAccesToken { tokenData =>
withForwardRequest(tokenData.uuid) { forwardRequest =>
Ok(
waitForGet(
forwardRequest.withHeader("Content-Type", requestContentType).withPath(s"$path")
).body
)
tokenData.host.hostIP.map { hip =>
withForwardRequest(hip) { forwardRequest =>
Ok(
waitForGet(
forwardRequest.withHeader("Content-Type", requestContentType).withPath(s"$path")
).body
)
}
}.getOrElse(NotFound())
}
}
......
......@@ -12,6 +12,7 @@ object JWT {
type Secret = String
type Token = String
case class Host(uuid: UUID, hostIP: Option[String])
trait TokenType {
def cookieKey: String
......@@ -34,24 +35,34 @@ object JWT {
def fromTokenContent(content: String, tokenType: TokenType)(implicit secret: Secret) = {
Jwt.decode(content, secret, Seq(JwtAlgorithm.HS256)).map { jwtClaim =>
val login: Login = Login(Json.fromJson(jwtClaim.content, Json.key.login))
val uuid: UUID = UUID(Json.fromJson(jwtClaim.content, Json.key.uuid))
TokenData(login, uuid, jwtClaim.issuedAt.get, jwtClaim.expiration.get, tokenType)
val host = {
val uuid: UUID = UUID(Json.fromJson(jwtClaim.content, Json.key.uuid))
val hostIP: String = Json.fromJson(jwtClaim.content, Json.key.hostIP)
val hip = {
if(hostIP.isEmpty) None
else Some(hostIP)
}
Host(uuid, hip)
}
TokenData(login, host, jwtClaim.issuedAt.get, jwtClaim.expiration.get, tokenType)
}.toOption.filter {
hasExpired(_)
}
}
def accessToken(uuid: UUID, login: Login) = TokenData(login, uuid, now, inFiveMinutes, TokenType.accessToken)
def accessToken(host: Host, login: Login) = TokenData(login, host, now, inFiveMinutes, TokenType.accessToken)
def refreshToken(uuid: UUID, login: Login) = TokenData(login, uuid, now, inOneMonth, TokenType.refreshToken)
def refreshToken(host: Host, login: Login) = TokenData(login, host, now, inOneMonth, TokenType.refreshToken)
}
case class TokenData(login: Login, uuid: UUID, issued: Long, expirationTime: Long, tokenType: TokenType) {
case class TokenData(login: Login, host: Host, issued: Long, expirationTime: Long, tokenType: TokenType) {
def toContent(implicit secret: Secret) = {
implicit val clock = Clock.systemUTC()
val claims = Seq((Json.key.uuid, uuid.value), (Json.key.login, login.value))
val claims = Seq((Json.key.uuid, host.uuid.value), (Json.key.hostIP, host.hostIP.getOrElse("")), (Json.key.login, login.value))
val expandedClaims = claims.map { case (k, v) =>
s"""
......
......@@ -10,6 +10,7 @@ object Json {
object key {
val login = "login"
val uuid = "uuid"
val hostIP = "hostIP"
}
def fromJson(json: String, jsonKey: String): String = {
......
......@@ -79,5 +79,5 @@ object K8sService {
// }
}
def podIP(uuid:UUID) = pod(uuid).map{_.podIP}
def hostIP(uuid:UUID) = pod(uuid).map{_.podIP}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment