Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 66 additions & 23 deletions obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ object RateLimitingUtil extends MdcLoggable {

def useConsumerLimits = APIUtil.getPropsAsBoolValue("use_consumer_limits", false)

private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + RateLimitingPeriod.toString(period)
private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + "_" + RateLimitingPeriod.toString(period)

private def underConsumerLimits(consumerKey: String, period: LimitCallPeriod, limit: Long): Boolean = {
if (useConsumerLimits) {
Expand Down Expand Up @@ -173,18 +173,51 @@ object RateLimitingUtil extends MdcLoggable {
}

/**
* This function checks rate limiting for a Consumer.
* It will check rate limiting per minute, hour, day, week and month.
* In case any of the above is hit an error is thrown.
* In case two or more limits are hit rate limit with lower period has precedence regarding the error message.
* @param userAndCallContext is a Tuple (Box[User], Option[CallContext]) provided from getUserAndSessionContextFuture function
* @return a Tuple (Box[User], Option[CallContext]) enriched with rate limiting header or an error.
* Rate limiting guard that enforces API call limits for both authorized and anonymous access.
*
* This is the main rate limiting enforcement function that controls access to OBP API endpoints.
* It operates in two modes depending on whether the caller is authenticated or anonymous.
*
* AUTHORIZED ACCESS (with valid consumer credentials):
* - Enforces limits across 6 time periods: per second, minute, hour, day, week, and month
* - Uses consumer_id as the rate limiting key (simplified for current implementation)
* - Note: api_name, api_version, and bank_id may be added to the key in future versions
* - Limits are defined in CallLimit configuration for each consumer
* - Stores counters in Redis with TTL matching the time period
* - Returns 429 status with appropriate error message when any limit is exceeded
* - Lower period limits take precedence in error messages (e.g., per-second over per-minute)
*
* ANONYMOUS ACCESS (no consumer credentials):
* - Only enforces per-hour limits (configurable via "user_consumer_limit_anonymous_access", default: 1000)
* - Uses client IP address as the rate limiting key
* - Designed to prevent abuse while allowing reasonable anonymous usage
*
* REDIS STORAGE MECHANISM:
* - Keys format: {consumer_id}_{PERIOD} (e.g., "consumer123_PER_MINUTE")
* - Values: current call count within the time window
* - TTL: automatically expires keys when time period ends
* - Atomic operations ensure thread-safe counter increments
*
* RATE LIMIT HEADERS:
* - Sets X-Rate-Limit-Limit: maximum allowed requests for the period
* - Sets X-Rate-Limit-Reset: seconds until the limit resets (TTL)
* - Sets X-Rate-Limit-Remaining: requests remaining in current period
*
* ERROR HANDLING:
* - Redis connectivity issues default to allowing the request (fail-open)
* - Rate limiting can be globally disabled via "use_consumer_limits" property
* - Malformed or missing limits default to unlimited access
*
* @param userAndCallContext Tuple containing (Box[User], Option[CallContext]) from authentication
* @return Same tuple structure, either with updated rate limit headers or rate limit exceeded error
*/
def underCallLimits(userAndCallContext: (Box[User], Option[CallContext])): (Box[User], Option[CallContext]) = {
// Configuration and helper functions
def perHourLimitAnonymous = APIUtil.getPropsAsIntValue("user_consumer_limit_anonymous_access", 1000)
def composeMsgAuthorizedAccess(period: LimitCallPeriod, limit: Long): String = TooManyRequests + s" We only allow $limit requests ${RateLimitingPeriod.humanReadable(period)} for this Consumer."
def composeMsgAnonymousAccess(period: LimitCallPeriod, limit: Long): String = TooManyRequests + s" We only allow $limit requests ${RateLimitingPeriod.humanReadable(period)} for anonymous access."

// Helper function to set rate limit headers in successful responses
def setXRateLimits(c: CallLimit, z: (Long, Long), period: LimitCallPeriod): Option[CallContext] = {
val limit = period match {
case PER_SECOND => c.per_second
Expand All @@ -199,6 +232,7 @@ object RateLimitingUtil extends MdcLoggable {
.map(_.copy(xRateLimitReset = z._1))
.map(_.copy(xRateLimitRemaining = limit - z._2))
}
// Helper function to set rate limit headers for anonymous access
def setXRateLimitsAnonymous(id: String, z: (Long, Long), period: LimitCallPeriod): Option[CallContext] = {
val limit = period match {
case PER_HOUR => perHourLimitAnonymous
Expand All @@ -209,6 +243,7 @@ object RateLimitingUtil extends MdcLoggable {
.map(_.copy(xRateLimitRemaining = limit - z._2))
}

// Helper function to create rate limit exceeded response with remaining TTL for authorized users
def exceededRateLimit(c: CallLimit, period: LimitCallPeriod): Option[CallContextLight] = {
val remain = ttl(c.consumer_id, period)
val limit = period match {
Expand All @@ -225,6 +260,7 @@ object RateLimitingUtil extends MdcLoggable {
.map(_.copy(xRateLimitRemaining = 0)).map(_.toLight)
}

// Helper function to create rate limit exceeded response for anonymous users
def exceededRateLimitAnonymous(id: String, period: LimitCallPeriod): Option[CallContextLight] = {
val remain = ttl(id, period)
val limit = period match {
Expand All @@ -236,15 +272,14 @@ object RateLimitingUtil extends MdcLoggable {
.map(_.copy(xRateLimitRemaining = 0)).map(_.toLight)
}

// Main logic: check if we have a CallContext and determine access type
userAndCallContext._2 match {
case Some(cc) =>
cc.rateLimiting match {
case Some(rl) => // Authorized access
val rateLimitingKey =
rl.consumer_id +
rl.api_name.getOrElse("") +
rl.api_version.getOrElse("") +
rl.bank_id.getOrElse("")
case Some(rl) => // AUTHORIZED ACCESS - consumer has valid credentials and rate limits
// Create rate limiting key for Redis storage using consumer_id
val rateLimitingKey = rl.consumer_id
// Check if current request would exceed any of the 6 rate limits
val checkLimits = List(
underConsumerLimits(rateLimitingKey, PER_SECOND, rl.per_second),
underConsumerLimits(rateLimitingKey, PER_MINUTE, rl.per_minute),
Expand All @@ -253,6 +288,7 @@ object RateLimitingUtil extends MdcLoggable {
underConsumerLimits(rateLimitingKey, PER_WEEK, rl.per_week),
underConsumerLimits(rateLimitingKey, PER_MONTH, rl.per_month)
)
// Return 429 error for first exceeded limit (shorter periods take precedence)
checkLimits match {
case x1 :: x2 :: x3 :: x4 :: x5 :: x6 :: Nil if x1 == false =>
(fullBoxOrException(Empty ~> APIFailureNewStyle(composeMsgAuthorizedAccess(PER_SECOND, rl.per_second), 429, exceededRateLimit(rl, PER_SECOND))), userAndCallContext._2)
Expand All @@ -267,14 +303,16 @@ object RateLimitingUtil extends MdcLoggable {
case x1 :: x2 :: x3 :: x4 :: x5 :: x6 :: Nil if x6 == false =>
(fullBoxOrException(Empty ~> APIFailureNewStyle(composeMsgAuthorizedAccess(PER_MONTH, rl.per_month), 429, exceededRateLimit(rl, PER_MONTH))), userAndCallContext._2)
case _ =>
// All limits passed - increment counters and set rate limit headers
val incrementCounters = List (
incrementConsumerCounters(rateLimitingKey, PER_SECOND, rl.per_second), // Responses other than the 429 status code MUST be stored by a cache.
incrementConsumerCounters(rateLimitingKey, PER_MINUTE, rl.per_minute), // Responses other than the 429 status code MUST be stored by a cache.
incrementConsumerCounters(rateLimitingKey, PER_HOUR, rl.per_hour), // Responses other than the 429 status code MUST be stored by a cache.
incrementConsumerCounters(rateLimitingKey, PER_DAY, rl.per_day), // Responses other than the 429 status code MUST be stored by a cache.
incrementConsumerCounters(rateLimitingKey, PER_WEEK, rl.per_week), // Responses other than the 429 status code MUST be stored by a cache.
incrementConsumerCounters(rateLimitingKey, PER_MONTH, rl.per_month) // Responses other than the 429 status code MUST be stored by a cache.
incrementConsumerCounters(rateLimitingKey, PER_SECOND, rl.per_second),
incrementConsumerCounters(rateLimitingKey, PER_MINUTE, rl.per_minute),
incrementConsumerCounters(rateLimitingKey, PER_HOUR, rl.per_hour),
incrementConsumerCounters(rateLimitingKey, PER_DAY, rl.per_day),
incrementConsumerCounters(rateLimitingKey, PER_WEEK, rl.per_week),
incrementConsumerCounters(rateLimitingKey, PER_MONTH, rl.per_month)
)
// Set rate limit headers based on the most restrictive active period
incrementCounters match {
case first :: _ :: _ :: _ :: _ :: _ :: Nil if first._1 > 0 =>
(userAndCallContext._1, setXRateLimits(rl, first, PER_SECOND))
Expand All @@ -292,17 +330,21 @@ object RateLimitingUtil extends MdcLoggable {
(userAndCallContext._1, userAndCallContext._2)
}
}
case None => // Anonymous access
case None => // ANONYMOUS ACCESS - no consumer credentials, use IP-based limiting
// Use client IP address as rate limiting key for anonymous access
val consumerId = cc.ipAddress
// Anonymous access only has per-hour limits to prevent abuse
val checkLimits = List(
underConsumerLimits(consumerId, PER_HOUR, perHourLimitAnonymous)
)
checkLimits match {
case x1 :: Nil if x1 == false =>
case x1 :: Nil if !x1 =>
// Return 429 error if anonymous hourly limit exceeded
(fullBoxOrException(Empty ~> APIFailureNewStyle(composeMsgAnonymousAccess(PER_HOUR, perHourLimitAnonymous), 429, exceededRateLimitAnonymous(consumerId, PER_HOUR))), userAndCallContext._2)
case _ =>
// Limit not exceeded - increment counter and set headers
val incrementCounters = List (
incrementConsumerCounters(consumerId, PER_HOUR, perHourLimitAnonymous), // Responses other than the 429 status code MUST be stored by a cache.
incrementConsumerCounters(consumerId, PER_HOUR, perHourLimitAnonymous)
)
incrementCounters match {
case x1 :: Nil if x1._1 > 0 =>
Expand All @@ -312,7 +354,8 @@ object RateLimitingUtil extends MdcLoggable {
}
}
}
case _ => (userAndCallContext._1, userAndCallContext._2)
case _ => // No CallContext available - pass through without rate limiting
(userAndCallContext._1, userAndCallContext._2)
}
}

Expand Down