
Kth Smallest Element in a Sorted Matrix


Kth Smallest Element in a Sorted Matrix: Given an n x n matrix where each of the rows and columns is sorted in ascending order, return the kth smallest element in the matrix.

Note that it is the kth smallest element in the sorted order, not the kth distinct element.

You must find a solution with a memory complexity better than O(n^2).

Example 1:
Input: matrix = [[1,5,9],[10,11,13],[12,13,15]], k = 8
Output: 13
Explanation: The elements in the matrix are [1,5,9,10,11,12,13,13,15],
and the 8th smallest number is 13.

Example 2:
Input: matrix = [[-5]], k = 1
Output: -5
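The baseline below makes the definition concrete. Duplicates count separately, so the two 13s in Example 1 occupy positions 7 and 8. It copies every element into one array, sorts it, and reads index k - 1. It uses O(n^2) extra memory, so it does not meet the memory requirement above. It is only a reference for checking answers on small inputs. The class and method names are hypothetical.

```java
import java.util.Arrays;

public class BruteForceKth {
    // Reference baseline (hypothetical helper): copy all n*n values, sort, index k - 1.
    // O(n^2) extra memory, so it violates the problem's memory requirement.
    static int kthSmallestBruteForce(int[][] matrix, int k) {
        int n = matrix.length;
        int[] all = new int[n * n];
        int idx = 0;
        for (int[] row : matrix) {
            for (int v : row) {
                all[idx++] = v;
            }
        }
        Arrays.sort(all);
        // Duplicates are kept, so this is the kth smallest element, not the kth distinct one.
        return all[k - 1];
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println(kthSmallestBruteForce(matrix, 8));              // prints 13
        System.out.println(kthSmallestBruteForce(new int[][]{{-5}}, 1));   // prints -5
    }
}
```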

Constraints:
  • n == matrix.length == matrix[i].length
  • 1 <= n <= 300
  • -10^9 <= matrix[i][j] <= 10^9
  • All the rows and columns of matrix are guaranteed to be sorted in non-decreasing order.
  • 1 <= k <= n^2

Try this Problem on your own or check similar problems:

  1. Find K Pairs with Smallest Sums
  2. Kth Smallest Number in Multiplication Table
  3. Find K-th Smallest Pair Distance
Solution:
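public int kthSmallest(int[][] matrix, int k) {
    int n = matrix.length;

    // Search over the value range [smallest element, largest element].
    int start = matrix[0][0], end = matrix[n - 1][n - 1];

    while (start < end) {
        // Rounds toward start (also for negative values), so mid < end always holds.
        int mid = start + (end - start) / 2;

        // Count elements <= mid, walking from the top-right corner.
        int cnt = 0, i = 0, j = n - 1;
        while (i < n && j >= 0) {
            if (matrix[i][j] > mid) --j;   // column j is > mid from row i downward
            else {
                cnt += j + 1;              // matrix[i][0..j] are all <= mid
                ++i;
            }
        }

        if (cnt < k) start = mid + 1;      // fewer than k elements <= mid: answer is larger
        else end = mid;                    // at least k elements <= mid: answer is <= mid
    }
    return start;
}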

Time/Space Complexity:
  • Time Complexity: O(n log(max - min)), where min = matrix[0][0] and max = matrix[n-1][n-1]
  • Space Complexity: O(1)

Explanation:

First, note the relative order of elements in the rows and columns. Every element is no smaller than its left neighbor and its upper neighbor, so your first hint should be to exploit the connection between an element and its successors.

We start by picking an initial set of candidates for the kth smallest number. Each column is sorted, so the smallest element of every column sits in the first row, and the first-row elements are natural candidates. We only need min(k, n) of them. If k is much smaller than n, we do not need the whole first row in the queue: an element in column k or later is at least as large as all k elements matrix[0][0..k-1], so it is never needed to determine the kth smallest value.

We store each candidate as a triple (row, column, value) in a priority queue ordered by value, so the smallest candidate is always at the top. If we poll the queue k - 1 times, the kth smallest number ends up at the top. But when we discard a number, what is the next best element to replace it in the queue? Since the columns are sorted too, it is the element directly below it. If the discarded element is in the last row (current[0] == n - 1), there is nothing below it and we push nothing.

Finally, we return the element at the top of the queue as our result. The time complexity is O(k log(min(k, n))), and the space complexity is O(min(k, n)), which is the maximum number of elements in the queue.

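import java.util.PriorityQueue;

public int kthSmallest(int[][] matrix, int k) {
    // Each entry is {row, col, value}; the smallest value is at the top.
    PriorityQueue<int[]> p = new PriorityQueue<>((a, b) -> Integer.compare(a[2], b[2]));
    int n = matrix.length;

    // Seed with the first min(k, n) elements of row 0 (the top of each column).
    for (int i = 0; i < Math.min(k, n); ++i) {
        p.add(new int[]{0, i, matrix[0][i]});
    }

    // Remove the k - 1 smallest elements, replacing each with the element below it.
    for (int i = 0; i < k - 1; ++i) {
        int[] current = p.poll();
        if (current[0] == n - 1) continue; // last row: nothing below
        p.add(new int[]{current[0] + 1, current[1], matrix[current[0] + 1][current[1]]});
    }

    // The top of the queue is now the kth smallest element.
    return p.peek()[2];
}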
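To watch the heap approach work on Example 1, here is a self-contained sketch. It follows the same steps as the method above and adds an optional trace flag that prints each polled element. The class name HeapKthTrace and the trace parameter are hypothetical additions for illustration.

```java
import java.util.PriorityQueue;

public class HeapKthTrace {
    // Same steps as the heap solution above: seed with the first min(k, n) elements of
    // row 0, poll k - 1 times, and after each poll push the element directly below.
    // Heap entries are {row, col, value}, ordered by value.
    static int kthSmallestHeap(int[][] matrix, int k, boolean trace) {
        int n = matrix.length;
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> Integer.compare(a[2], b[2]));
        for (int col = 0; col < Math.min(k, n); col++) {
            pq.add(new int[]{0, col, matrix[0][col]});
        }
        for (int step = 1; step <= k - 1; step++) {
            int[] cur = pq.poll();
            if (trace) {
                System.out.println("poll " + step + ": " + cur[2] + " at (" + cur[0] + ", " + cur[1] + ")");
            }
            if (cur[0] == n - 1) continue; // last row: nothing below to push
            int below = cur[0] + 1;
            pq.add(new int[]{below, cur[1], matrix[below][cur[1]]});
        }
        return pq.peek()[2];
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println("answer: " + kthSmallestHeap(matrix, 8, true));
    }
}
```

For Example 1 with k = 8, the first six polls remove 1, 5, 9, 10, 11 and 12. The seventh removes one of the two 13s; which one depends on how PriorityQueue breaks ties. The program then prints `answer: 13`. The heap never holds more than three entries here, matching the O(min(k, n)) space bound.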

It's a bit trickier to use binary search. We are looking for the smallest value x such that at least k elements of the matrix are less than or equal to x, and we can find it with a trial-and-error algorithm. We define our search space as the range of values between the first element (matrix[0][0]) and the last element (matrix[n-1][n-1]), inclusive. Note that this is a range of integer values, not only the values that appear in the matrix. We then run a binary search, each time picking the middle value mid (the trial) and checking how many elements are less than or equal to it (the error).

Our metric is the count function, which goes over all rows and, for each row, counts how many numbers are less than or equal to mid. Because the rows are sorted in ascending order, it is enough to find, iterating each row backwards, the first column j whose value is <= mid. For example, if iterating from right to left the first such element is in column 3, then all elements in columns 0 to 3 are also <= mid, so we add cnt += j + 1 (here 4).

Because the columns are sorted too, j never has to move right when we go down to the next row. The whole count is therefore a single walk that starts at the first row and last column. If the current element is greater than mid, we move to the previous column (--j). Otherwise we add j + 1 and go down a row (++i). This is the same walk used to search for a target in a sorted matrix, and it takes at most n + n = 2n steps, i.e. O(n).

If the total count is smaller than k, mid is too small, so we move the start of our search space to mid + 1. Otherwise cnt is greater than or equal to k. In that case we do not need to search the right side of the search space, and we bring our end boundary to the current mid. We cannot discard mid itself from the search space, because with cnt >= k it could be the element we are looking for. Each iteration halves a search space of size max - min, so there are O(log(max - min)) iterations, each costing O(n), for a total of O(n log(max - min)).
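The counting step on its own is shown below as a small sketch. The names StaircaseCount and countLessOrEqual are hypothetical. The loop is the same one used inside the Solution code.

```java
public class StaircaseCount {
    // Counts entries <= target by walking a "staircase" from the top-right corner.
    // If matrix[i][j] > target, every entry below it in column j is also > target (columns
    // are sorted), so move left. Otherwise matrix[i][0..j] are all <= target (rows are
    // sorted), so add j + 1 and move down. At most 2n steps in total.
    static int countLessOrEqual(int[][] matrix, int target) {
        int n = matrix.length;
        int cnt = 0, i = 0, j = n - 1;
        while (i < n && j >= 0) {
            if (matrix[i][j] > target) {
                j--;
            } else {
                cnt += j + 1;
                i++;
            }
        }
        return cnt;
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        for (int mid : new int[]{12, 13, 14}) {
            System.out.println("count(<= " + mid + ") = " + countLessOrEqual(matrix, mid));
        }
    }
}
```

On Example 1 this prints counts of 6, 8 and 8 for 12, 13 and 14. The value 14 is not in the matrix, yet 8 elements are already at or below it. This is why the binary search sets end = mid instead of returning mid as soon as cnt >= k.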
Check the Solution section above for the final code. Note that the binary search can be much slower than the priority queue. It always costs O(n log(max - min)) no matter how small k is, while the heap costs O(k log(min(k, n))). With the given constraints, max - min is at most 2 * 10^9, so the binary search loop runs at most 31 times.

You should discuss the nature of the input with your interviewer. If we are trying to save space and n or k are large, binary search uses only constant extra space. If we are optimizing for time and we know that k << n, then the priority queue solution makes more sense.
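Whichever approach you pick, a quick way to gain confidence is to compare both against a sort-everything baseline on many small random matrices. The harness below is a hypothetical sketch: the class CrossCheck and the helper randomSortedMatrix are not from this post. It builds matrices with many duplicates and checks every k. It only checks correctness and does not measure running time.

```java
import java.util.Arrays;
import java.util.PriorityQueue;
import java.util.Random;

public class CrossCheck {
    // Heap approach: seed with row 0, pop k - 1 times, push the element below each pop.
    static int heap(int[][] m, int k) {
        int n = m.length;
        PriorityQueue<int[]> pq = new PriorityQueue<>((a, b) -> Integer.compare(a[2], b[2]));
        for (int c = 0; c < Math.min(k, n); c++) pq.add(new int[]{0, c, m[0][c]});
        for (int i = 0; i < k - 1; i++) {
            int[] cur = pq.poll();
            if (cur[0] < n - 1) pq.add(new int[]{cur[0] + 1, cur[1], m[cur[0] + 1][cur[1]]});
        }
        return pq.peek()[2];
    }

    // Binary search over values, using the staircase count of elements <= mid.
    static int binarySearch(int[][] m, int k) {
        int n = m.length, start = m[0][0], end = m[n - 1][n - 1];
        while (start < end) {
            int mid = start + (end - start) / 2, cnt = 0;
            for (int i = 0, j = n - 1; i < n && j >= 0; ) {
                if (m[i][j] > mid) j--; else { cnt += j + 1; i++; }
            }
            if (cnt < k) start = mid + 1; else end = mid;
        }
        return start;
    }

    // Reference answer: sort every element (O(n^2) memory, fine for testing).
    static int bruteForce(int[][] m, int k) {
        int[] all = Arrays.stream(m).flatMapToInt(Arrays::stream).sorted().toArray();
        return all[k - 1];
    }

    // Hypothetical generator: each cell is max(up, left) plus a random step in 0..2,
    // so rows and columns are non-decreasing and duplicates are common.
    static int[][] randomSortedMatrix(int n, Random rnd) {
        int[][] m = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                int up = i > 0 ? m[i - 1][j] : -10;
                int left = j > 0 ? m[i][j - 1] : -10;
                m[i][j] = Math.max(up, left) + rnd.nextInt(3);
            }
        }
        return m;
    }

    public static void main(String[] args) {
        Random rnd = new Random(42); // fixed seed so any failure is reproducible
        int checks = 0;
        for (int trial = 0; trial < 200; trial++) {
            int n = 1 + rnd.nextInt(6);
            int[][] m = randomSortedMatrix(n, rnd);
            for (int k = 1; k <= n * n; k++) {
                int expected = bruteForce(m, k);
                if (heap(m, k) != expected || binarySearch(m, k) != expected) {
                    throw new AssertionError("mismatch for k=" + k + " in " + Arrays.deepToString(m));
                }
                checks++;
            }
        }
        System.out.println("all " + checks + " checks passed");
    }
}
```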

With binary search, how can we guarantee that the element we end up with exists in the matrix? Let x be the kth smallest element, and let count(v) be the number of matrix elements less than or equal to v. Then x is in the matrix, count(x) >= k, and count(v) < k for every v < x, so x is the smallest value whose count reaches k.

The loop always keeps x inside [start, end]. If we pick a mid with count(mid) < k, then mid < x, and start = mid + 1 still satisfies start <= x. If we pick a mid with count(mid) >= k, then x <= mid, and end = mid still satisfies x <= end. This holds even when mid itself does not exist in the matrix. The interval shrinks on every iteration, so the loop ends with start == end. Because x never leaves the interval, at that point start == end == x, and we just return start as our final solution.

Imagine consecutive candidate values x1 < x2 < x3 where only x1 is in the matrix and x1 is the answer. No matrix element lies in (x1, x3], so count(x1) = count(x2) = count(x3) >= k. Either x2 or x3 may be picked as mid and become end, but every value below x1 has a count smaller than k, so the search still converges to x1.
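The convergence argument is easiest to see on Example 1 with k = 8. The sketch below repeats the Solution loop and prints the interval on every iteration. The class name BinarySearchTrace and its methods are hypothetical, and the print statement exists only for illustration.

```java
public class BinarySearchTrace {
    // Staircase count of elements <= target (same walk as in the Solution code).
    static int countLessOrEqual(int[][] matrix, int target) {
        int n = matrix.length, cnt = 0, i = 0, j = n - 1;
        while (i < n && j >= 0) {
            if (matrix[i][j] > target) j--;
            else { cnt += j + 1; i++; }
        }
        return cnt;
    }

    // Binary search over values, printing start/end/mid/cnt at every iteration.
    static int kthSmallestTraced(int[][] matrix, int k) {
        int n = matrix.length;
        int start = matrix[0][0], end = matrix[n - 1][n - 1];
        while (start < end) {
            int mid = start + (end - start) / 2;
            int cnt = countLessOrEqual(matrix, mid);
            System.out.println("start=" + start + " end=" + end + " mid=" + mid + " cnt=" + cnt);
            if (cnt < k) start = mid + 1;
            else end = mid;
        }
        return start;
    }

    public static void main(String[] args) {
        int[][] matrix = {{1, 5, 9}, {10, 11, 13}, {12, 13, 15}};
        System.out.println("answer: " + kthSmallestTraced(matrix, 8));
    }
}
```

Running it prints:

    start=1 end=15 mid=8 cnt=2
    start=9 end=15 mid=12 cnt=6
    start=13 end=15 mid=14 cnt=8
    start=13 end=14 mid=13 cnt=8
    answer: 13

The third iteration picks mid = 14, which is not in the matrix but has cnt >= k. End moves to 14, and the next iteration narrows the interval to the matrix element 13.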