N과 M(9) next_permutation풀이
2024-02-05
#C++#백준
next_permutaion이란?
순열을 구해주는 c++의 라이브러리
순열은 다들 아시니 넘어가겠습니다.
next_permutaion이 어떻게 작동하는가?
next_permuation은 두개의 인자를 받는데 첫번째 인자는 순열을 구할 배열의 시작 주소(iterator), 두번째 인자는 끝나는 주소(iterator)를 받습니다.
주의할 점
이 함수는 주어진 배열을 기준으로 사전 순서로 다음 순열을 만들기 때문에 처음 주어진 배열이
오름차순으로 정렬되어 있지 않는다면 모든 경우를 보여주지 않는다.
- •
예시
int main() {
vector<int> v = {9, 7, 9, 1};
do {
for (int i = 0; i < v.size(); i++)
cout << v[i] << " ";
cout << endl;
} while (next_permutation(v.begin(), v.end()));
}
위와 같은 코드가 있다면 처음으로 주어진 배열이 9 7 9 1 이므로 모든 순열이 출력되지 않고 나머지 배열인
9 7 9 1
9 9 1 7
9 9 7 1
이 출력된다.
여기서 우리는 모든 순열을 출력하는 것이아니고 길이가 m인 수열을 구해야 한다.
즉, 조합을 구해야 한다.
조합을 구하기 우해서 우리는 temp벡터를 사용하여 원하는만큼의 수를 출력할 것이다.
예시를 보자
int main() {
vector<int> v = {9, 7, 9, 1};
vector<int> temp = {0, 0, 1, 1};
sort(v.begin(), v.end());
do {
for (int i = 0; i < v.size(); i++)
if (!temp[i])
cout << v[i] << " ";
cout << endl;
} while (next_permutation(temp.begin(), temp.end()));
}
위와 같이 temp 배열에 출력하고 싶은 수의 개수만큼 0으로 초기화하고 나머지는 1로 하여 출력할 때,
temp가 0인 위치만 출력하도록 한 것이다.
그럼 출력으로 아래와 같은 값을 얻을 수 있다.
1 7
1 9
1 9
7 9
7 9
9 9
하지만 여기서 우리가 구하고 싶은 것은 조합과 살짝 다른 개념이므로 이를 응용해서 원하는 답을 잘 구해야 한다.
이를 코드로 풀어보면 아래와 같은 결과가 나온다.
#include <iostream>
#include <vector>
#include <algorithm>using namespace std;
int main() {
vector<int> input, temp, temp_ans;
vector<vector<int>> ans;
int n, m, num;
cin >> n >> m;
for (int i = 0; i < n; i++) { //수 입력 받기
cin >> num;
input.push_back(num);
}
for (int i = 0; i < n; i++) { //출력할만큼 수 지정
if (i < m)
temp.push_back(0);
else
temp.push_back(1);
}
sort(input.begin(), input.end()); //오름차순 정렬하기
do {
temp_ans.clear();
for (int i = 0; i < input.size(); i++) { //현재 수 배열에서 m개 만큼의 수 temp_ans에 저장
if (!temp[i])
temp_ans.push_back(input[i]);
}
do { //temp_ans로 만들 수 있는 모든 순열 저장하기
ans.push_back(temp_ans);
} while (next_permutation(temp_ans.begin(), temp_ans.end()));
} while (next_permutation(temp.begin(), temp.end()));
sort(ans.begin(), ans.end()); //ans 오름차순 정렬
for (int i = 0; i < ans[0].size(); i++) //첫번째 답은 무조건 출력
cout << ans[0][i] << " ";
cout << "\n";
for (int i = 1; i < ans.size(); i++) {
bool flag = false;
for (int j = 0; j < ans[i].size(); j++) {
if (ans[i - 1][j] != ans[i][j]) { //ans벡터의 이전값과 하나라도 다르면 그대로 출력
flag = true;
break;
}
}
if (flag) {
for (int j = 0; j < ans[i].size(); j++)
cout << ans[i][j] << " ";
cout << "\n";
}
}
}
참 쉽죠?