백준 7453
문제
정수로 이루어진 크기가 같은 배열 A, B, C, D가 있다.
A[a], B[b], C[c], D[d]의 합이 0인 (a, b, c, d) 쌍의 개수를 구하는 프로그램을 작성하시오.
입력
첫째 줄에 배열의 크기 n (1 ≤ n ≤ 4000)이 주어진다. 다음 n개 줄에는 A, B, C, D에 포함되는 정수가 공백으로 구분되어져서 주어진다. 배열에 들어있는 정수의 절댓값은 최대 228이다.
출력
합이 0이 되는 쌍의 개수를 출력한다.
간단한 오류를 몇 단계로 나누어 수정하느라 오답이 두 번이나 뜬 문제.
답의 최댓값이 $n^4$까지 될 수 있어, $2000^4 \approx 2^{44}$정도까지 가능한데 그냥 int로 변수를 정의해서 오답이 뜨고,
result만 long long으로 정의하여 한번 더 오류가 떴다. num1 * num2도 각 변수가 int이면 곱하는 순간 범위를 넘길 수 있는 것이 그 이유.
문제를 푼 방법은, 네 array입력 중 두 개씩 짝으로 가능한 합을 두 vector (sum_ab, sum_cd)에 저장하고
각각을 sort하되, 하나는 ascending, 하나는 descending하게 정렬
이후에 합이 0이 되는지 확인하여, 합이 0이면 경우의 수를 세어 적절히 result에 가산하고
0보다 크면 작아지는 방향으로, 0보다 작으면 커지는 방향으로 index를 조절,
각 vector의 마지막에 MAX+1과 -MAX를 포함시킨 것은, result에 가산할 때 다음 index 값과 비교하는 부분이 있는데(44줄)
마지막 index에 닿았을 때를 굳이 index로 확인하지 않고 마무리하기 위함이다.
입력의 최댓값을 알고 있으므로, 해당 범위의 입력에 대해 문제가 생기지 않는 값을 MAX로 지정했고,
MAX, -MAX로 해도 문제가 되지 않지만 혹시나 하는 마음에 하나는 MAX+1, 나머지 하나는 -MAX로 입력하여 두 값의 합이 0이 되지 않도록 조작했다.
힌트는 이분탐색을 제안하고 있는데, 지금처럼 두 개씩 나누어 vector를 구성하는 것 까지는 같을 것 같고, 이후에 한 vector만 sort한 후
sort하지 않은 vector의 element를 순차적으로 travel하면서, sort한 vector를 이분탐색하여 합이 0이 되는 숫자의 갯수를 세어 result에 추가하면 되는 것 같다.
그렇게 하면 두 vector의 크기가 $n^2$이므로 $l := n^2$이면 $\mathbb{O}(l \log l)$정도의 복잡도를 가지게 된다.
$n$으로 표기하면, $\mathbb{O}(n^2 \log n)$
하지만 아래 코드대로 진행하면, 두 vector를 모두 sort해야 하지만, 두 vector를 번갈아가며 travel하기만 하면 되므로
$\mathbb{O}(2l +2 \log l) = \mathbb{O}(l) = \mathbb{O}(n^2)$이어서 더 나을 것으로 예상된다.
아래는 코드.
#include<iostream> #include<vector> #include<algorithm> //#include<queue> //#include<cstdlib> const int MAX = 2000000000; using namespace std; int main() { int n, temp; vector<int> input[4]; vector<int> sum_ab, sum_cd; cin >> n; for (int i = 0; i < n; i++) { for (int j = 0; j < 4; j++) { cin >> temp; input[j].push_back(temp); } } for (int i = 0; i < n; i++) { for (int j = 0; j < n; j++) { sum_ab.push_back(input[0][i] + input[1][j]); sum_cd.push_back(input[2][i] + input[3][j]); } } sort(sum_ab.begin(), sum_ab.end()); sort(sum_cd.begin(), sum_cd.end(), greater<int>()); sum_ab.push_back(MAX+1); sum_cd.push_back(-MAX); long long result = 0, num1, num2; int i = 0, j = 0; while (i < n*n && j < n*n) { if (sum_ab[i] + sum_cd[j] == 0) { num1 = 1; num2 = 1; while (sum_ab[i] == sum_ab[i + 1]) { num1++; i++; } while (sum_cd[j] == sum_cd[j + 1]) { num2++; j++; } result += num1 * num2; i++; j++; } else if (sum_ab[i] + sum_cd[j] > 0) { while (sum_ab[i] + sum_cd[j] > 0) { j++; } } else { while (sum_ab[i] + sum_cd[j] < 0) { i++; } } } cout << result; return 0; }