jz051.数组中的逆序对

 

题目链接

看到这个题目的第一反应是顺序扫描数组,每扫描到一个数字就逐个比较该数字和它后面的数字的大小,如果后面数字比他小,则这俩数字就组成一个逆序对,但这种算法的时间复杂度为O(n^2)。我们需要尝试找到更快的算法,以数组[7, 5, 6, 4]为例子,我们考虑先比较两个相邻的数字。我们先把数组分解成两个长度为2的子数组,再将他们拆成长度为1的子数组。接下来一边合并相邻的子数组一边统计逆序对的数目。在第一对长度为1的子数组[7], [5]中,7大于5,因此(7,5)组成一个逆序对。同样,在第二对长度为1的子数组[6], [4]中,也有逆序对(6,4)。由于我们已经统计了这两对子数组内部的逆序对,因此需要把这两对子数组排序,以免在以后的统计过程中再重复统计。接下来我们再统计两个长度为2的子数组之间的逆序对,我们先用两个指针分别指向两个子数组的末尾,并每次比较两个指针指向的数字。如果第一个子数组中的数字大于第二个子数组中的数字,则构成逆序对,并且逆序对数目等于第二个子数组中剩余数字的个数。如果第一个数组中的数字小于或等于第二个数组中的数字,则不构成逆序对。每次比较的时候,我们都把较大的数字从后往前复制到一个辅助数组,确保辅助数组中的数字是递增排序的。在把较大的数字复制到辅助数组之后,把对应的指针向前移动一位,接下来进行下一轮比较。总结一下统计逆序对的过程:先把数组分割成子数组,统计出子数组内部的逆序对的数目,在统计逆序对的过程中,还需要对数组进行排序。如果对排序算法很熟悉,那么我们不难发现这个排序的过程实际上就是归并排序,我们可以基于归并排序写出如下代码。对于第一个,我们每次在mergesort中都开了一个新的tmp,这样空间复杂度大概大概相当于NlogN,实际上我们可以改成像2中的,共用一个tmp。这个算法的时间复杂度是O(NlogN),空间复杂度是O(N)。

class Solution:
    def reversePairs1(self, nums: List[int]) -> int:
        self.ans = 0

        def merge(nums, start, mid, end):
            i, j, tmp = start, mid+1, []
            while i <= mid and j <= end:
                if nums[i] <= nums[j]:
                    tmp.append(nums[i])
                    i += 1
                else:
                    self.ans += (mid - i + 1)
                    tmp.append(nums[j])
                    j += 1
            while i <= mid:
                tmp.append(nums[i])
                i += 1
            while j <= end:
                tmp.append(nums[j])
                j += 1
            
            for i in range(len(tmp)):
                nums[start+i] = tmp[i]

        def mergesort(nums, start, end):
            if start >= end:
                return
            mid = (start + end) // 2
            mergesort(nums, start, mid)
            mergesort(nums, mid+1, end)
            merge(nums, start, mid, end)

        mergesort(nums, 0, len(nums)-1)
        return self.ans

    def reversePairs2(self, nums: List[int]) -> int:
        self.ans = 0

        def merge(nums, start, mid, end, tmp):
            i, j = start, mid+1
            while i <= mid and j <= end:
                if nums[i] <= nums[j]:
                    tmp.append(nums[i])
                    i += 1
                else:
                    self.ans += (mid - i + 1)
                    tmp.append(nums[j])
                    j += 1
            while i <= mid:
                tmp.append(nums[i])
                i += 1
            while j <= end:
                tmp.append(nums[j])
                j += 1

            for i in range(len(tmp)):
                nums[start+i] = tmp[i]
            tmp.clear()

        def mergesort(nums, start, end, tmp):
            if start >= end:
                return
            mid = (start + end) // 2
            mergesort(nums, start, mid, tmp)
            mergesort(nums, mid+1, end, tmp)
            merge(nums, start, mid, end, tmp)

        mergesort(nums, 0, len(nums)-1, [])
        return self.ans