11401번 - 이항 계수 3
문제 풀이 시간 : 2시간
문제 요약
- •
자연수 N과 정수 K가 주어졌을 때 이항계수를 1,000,000,007로 나눈 나머지를 구해라
- •
1 ≤ N ≤ 4,000,000
- •
0 ≤ K ≤ N
문제 풀이
이 문제는 알고리즘 시간때 배웠던 dp 알고리즘으로 풀면 되는 아주 간단한 문제 아닌가?
하고 신나게 바로 풀기 시작했다.
b[i][j] = b[i-1][j] + b[i-1][j-1]
초기 코드 (dp)
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
public class Main {
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
int N = Integer.parseInt(st.nextToken());
int K = Integer.parseInt(st.nextToken());
int[][] bin = new int[N + 1][K + 1];
for (int i = 0; i <= N; i++) {
for (int j = 0; j <= Math.min(i, K); j++) {
if (j == 0 || j == i) {
bin[i][j] = 1;
} else {
bin[i][j] = (bin[i - 1][j] + bin[i - 1][j - 1]) % 1_000_000_007;
}
}
}
System.out.println(bin[N][K]);
}
}
이렇게 아주 간단하게 코드를 작성할 수 있었고,
바로 메모리 초과를 받게 되었다.ㅋ
생각해보면 당연한게 N의 범위가 4,000,000이라서 배열이 최대 4,000,000 * 4,000,000이 될 수 있는 것이다.
그래서 생각한게 배열을 최소한으로 사용하면 메모리 초과를 해결할 수 있지 않을까?
K 길이의 배열을 두개만 사용해서 푸는 방법이다.
중간 코드(메모리 해결)
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;
public class Main {
public static int N,K;
public static void calc(int[] a, int[] b, int idx){
for(int i=0;i<=Math.min(idx,K);i++){
if(i==0||i==idx){
a[i] = 1;
}
else{
a[i] = b[i - 1] + b[i];
}
}
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());
K = Integer.parseInt(st.nextToken());
int[] bin = new int[K + 1];
int[] temp = new int[K + 1];
for (int i = 0; i <= N; i++) {
if(i%2==0){
calc(bin,temp,i);
}
else{
calc(temp, bin, i);
}
}
int ans = (N % 2 == 0) ? bin[K] : temp[K];
System.out.println(ans);
}
}
이렇게 했더니 메모리 초과는 해결되었지만
시간초과가 뜬다.
그것도 그럴게 생각해보면 우리가 배운 dp 알고리즘은 O(n^2)의 시간복잡도를 가진다.
즉, 이 문제에서는 n 이 4,000,000이라서 엄청나게 느리게 동작한다는 것이다.
그러면 어떻게 해결해야 할까?
우리는 이항계수를 구하는 식을 알고 있다.
위 식을 사용하면 해당 문제의 조건에서 O(n^2)의 시간복잡도를 거치지 않고 더 빠르게 구할 수 있다.
하지만 문제는 위 식에서 1,000,000,007로 나눈 나머지 값을 출력해야 한다.
우리는 보통 나머지 값을 출력할 때, 모듈러 연산을 사용한다.
모듈러 연산은 우리가 알고리즘 시간에 배운 것처럼 분배법칙이 성립한다.
하지만 모듈러 연산은 아래와 같이 나눗셈에서의 분배법칙은 성립되지 않는다.
즉, 우리는 위의 나눗셈을 곱셉으로 바꿔줘야 한다.
어떻게 나눗셈을 곱셈으로 바꿀까?
역원을 이용하는 것이다.
위와 같이 x에 대한 역원을 구하면 손쉽게 나눗셈을 곱셈으로 표현해줄 수 있다.
즉, 우리가 구하고자 하는 식을 곱셈으로 바꾼다면
로 나타낼 수 있다.
그렇다면 우리가 이제 구해야 하는 것은 k!(n-k)!의 역원을 구하는 것이다.
여기서 우리는 ‘페르마의 소정리’를 사용하게 된다.
페르마의 소정리란?
여기서 나뭇가지 같은 기호는 a가 p의 배수가 아닌 정수라는 뜻이다.
위 식의 예를 들어 p = 7, a = 12라고 한다면,
이 된다.
즉, 12^6 을 7로 나눈 나머지가 1이 된다는 뜻이다.
그럼 이 식을 증명해보자.
다음과 같은 p-1 크기의 집합을 살펴보자. (단, p는 소수)
이 집합의 모든 원소에 a를 곱한다. (단, a는 p의 배수가 아닌 정수)
그럼 아래와 같은 식이 성립되게 되는데,
예를 들어, p = 5, a = 2 라고 한다면,
가 된다.
이것을 증명하기 위해 아래 사실을 보인다.
귀류법으로 해당 사실을 증명해보인다.
인 i,j 가 존재한다고 가정하자.
p가 소수이므로 a와 p는 서로소이다. 따라서 양변을 a로 나누면
가 된다.
여기서 i와 j는 같은 나머지를 가진다는 뜻이 된다.
그러나 1 ≤ i,j ≤ p-1 의 범위에서 i와 j의 값이 다르다는 것이 말이 안된다. 즉, i와 j는 동일하다가 된다.
따라서,
위의 두 집합은 모든 원소들이 서로 다른 값을 가지는 같은 집합이라는 뜻이다.
따라서 두 집합의 모든 원소를 곱한 값도 같은 것이다.
여기서 양변을 (p-1)!로 나누면 아래의 식이 성립하게 된다.
이게 페르마의 소정리이다.
사실 나도 잘 이해가 안된다.ㅎㅎ
풀이(계속)
위 정리를 사용해서 풀이를 계속해보자면-
가 된다.
즉, a (mod p) 의 역원은 a^p-2 (mod p)가 되는 것이다.
그렇다면 우리가 구하고자 했던 역원은?
가 된다.
즉, 최종적으로 우리가 구해야하는 식은 아래와 같다.
이 된다.
이 식을 이용해서 코드를 짜면 아래와 같이 쉽게 구할 수 있게 된다.
최종 코드
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;
public class Main {
public static int N, K;
public static long DIV = 1_000_000_007;
public static long pow(long a, long b) {
if (b == 1) {
return a % DIV;
}
long temp = pow(a, b / 2);
if (b % 2 == 0) {
return temp * temp % DIV;
}
return (temp * temp % DIV) * a % DIV;
}
public static long fact(int a) {
long result = 1L;
while (a > 1) {
result = (result * a) % DIV;
a -= 1;
}
return result;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
StringTokenizer st = new StringTokenizer(br.readLine());
N = Integer.parseInt(st.nextToken());
K = Integer.parseInt(st.nextToken());
long up = fact(N) % DIV;
long down = fact(K) * fact(N - K) % DIV;
System.out.println(up * pow(down, DIV - 2) % DIV);
}
}