백준 7578 - 공장
https://www.acmicpc.net/problem/7578
문제
어떤 공장에는 2N개의 기계가 2열에 걸쳐 N개씩 배치되어 있다. 이 2개의 열을 각각 A열과 B 열이라고 부른다. A열에 있는 N개의 기계는 각각이 B열에 있는 N개의 기계와 하나씩 짝을 이루어 케이블로 연결되어 있다. 즉, A열의 임의의 기계는 B열의 유일한 기계와 케이블로 연결되어 있고, B열의 임의의 기계는 A열의 유일한 기계와 케이블로 연결되어 있다
또한, 각 기계에는 식별번호가 붙어있으며, 짝이 맺어진 기계끼리는 같은 식별번호가 붙어있다. 즉, 각 열에 있는 N개의 기계끼리는 서로 다른 식별번호를 가지고 있으며, 반대쪽 열에 있는 같은 식별번호를 가진 기계와 케이블로 이어져 있다.
공장 작업의 효율성을 위해 기계들은 짝을 맺은 순서대로 배치되지 않으며, 필요에 따라 각 열의 기계들의 순서를 바꾼 바람에 케이블은 마구 엉켜있는 상태이다. 이렇게 엉켜버린 케이블은 잦은 고장의 원인이 되기 때문에, 기계의 위치를 바꾸지 않은 상태에서 케이블을 두 기계를 잇는 직선의 형태로 만들기로 했다.
예를 들어, 위의 그림과 같이 N = 5이고, A열에 위치한 기계의 식별번호가 순서대로 132, 392, 311, 351, 231이고 B열에 위치한 기계의 식별번호가 순서대로 392, 351, 132, 311, 231이라면 케이블들의 교차 횟수 혹은 서로 교차하는 케이블 쌍의 개수는 3이 된다.
정수 N과 A열에 위치한 기계, B열에 위치한 기계의 식별번호가 각각 순서대로 주어질 때에 서로 교차하는 케이블 쌍의 개수를 정확하게 세어 출력하는 프로그램을 작성하시오.
입력
입력은 세 줄로 이루어져 있다. 첫 줄에는 정수 N이 주어지며, 두 번째 줄에는 A열에 위치한 N개 기계의 서로 다른 식별번호가 순서대로 공백문자로 구분되어 주어진다. 세 번째 줄에는 B열에 위치한 N개의 기계의 식별번호가 순서대로 공백문자로 구분되어 주어진다.
단, 1 ≤ N ≤ 500,000이며, 기계의 식별번호는 모두 0 이상 1,000,000 이하의 정수로 주어진다.
출력
여러분은 읽어 들인 2N개의 기계의 배치로부터 서로 교차하는 케이블 쌍의 개수를 정수 형태로 한 줄에 출력해야 한다.
풀이
먼저, 기계의 식별번호 값 자체는 별로 중요하지 않다. 각 열의 몇 번째 기계인지를 중심으로 풀자. 식별번호가 1,000,000 이하이므로, 식별번호가 i인 기계의 위치를 배열로 담아 사용할 수 있다. 식별번호가 132인 기계가 A열의 세 번째에 위치하면 d[132] = 2와 같이 저장하는 것이다. (첫 번째를 0으로 지정하자.)
for문으로 A열 또는 B열의 기계를 순회하며 그보다 앞 순서의 기계 중 해당 기계와 교차하는 것의 수를 전부 더하면 답을 구할 수 있다. A열의 정보를 먼저 입력받으므로, A열을 입력받으며 순서를 저장하고, 이후 B열을 입력받으면서 교차하는 기계 수를 세는 것이 시간을 줄이는데 도움이 될 것이다.
B열의 $x$ 번째 기계의 A열에서의 순서를 $p(x)$라고 하자. 예컨대 예제의 식별번호가 132인 기계는 A열의 첫 번째, B열의 세 번째에 위치하므로 $p(2) = 0$이다.
두 기계의 순서가 A열과 B열에서 반대이면 케이블이 교차하므로, B열의 $x$ 번째 기계에 대해 $y < x$이면서 $p(y) > p(x)$인 $y$ 번째 기계의 수를 세면 된다.
Naive하게 처리하면 $O(n^2)$이므로 탐색 횟수를 줄여야 한다. $p(x)$는 바뀌지 않고, $i$ 번째에 교차 여부를 확인해야 하는 공장은 $i - 1$ 번째에 확인해야 하는 공장을 완전히 포함하므로 이전 탐색 결과를 참조하여 계산할 수 있을 것이다.
지금까지 탐색한 공장의 종점 정보를 담아두는 배열 tr[]를 사용하자. $i$ 번째 공장을 처리했으면 tr[p(i)]를 1로 바꾸는 것이다. 이렇게 할 경우 $y < x$인 공장 $y$에 대해 $p(y) > p(x)$인 것은 tr[p(x):]의 합으로 구할 수 있다. 계속 변화하는 배열의 부분합은 세그먼트 트리를 이용하면 계산과 변경을 모두 $O(\log n)$에 처리할 수 있으므로 문제를 $O(n \log n)$에 처리할 수 있다.
코드
#include <bits/stdc++.h>
#define ll long long
using namespace std;
int n, d[1000005], tr[2000020];
void update(int s, int e, int idx, int tidx){
if (idx < s || idx > e) return;
if (s == e) { tr[tidx]++; return; }
update(s, (s + e) / 2, idx, tidx * 2);
update((s + e) / 2 + 1, e, idx, tidx * 2 + 1);
tr[tidx] = tr[tidx * 2] + tr[tidx * 2 + 1];
}
ll sum(int s, int e, int l, int r, int idx){
if (r < s || e < l) return 0;
if (l <= s && e <= r) return tr[idx];
return sum(s, (s + e) / 2, l, r, idx * 2) + sum((s + e) / 2 + 1, e, l, r, idx * 2 + 1);
}
int main(){
cin.tie(0); ios::sync_with_stdio(false);
cin >> n;
int t, p;
for (int i = 0; i < n; i++) { cin >> t; d[t] = i; }
ll result = 0;
for (int i = 0; i < n; i++){
cin >> t; p = d[t];
result += sum(0, 1000004, p, 1000004, 1);
update(0, 1000004, p, 1);
}
cout << result;
}