N과 M(9) next_permutation풀이

2024-02-05
#C++#백준

next_permutaion이란?

순열을 구해주는 c++의 라이브러리

순열은 다들 아시니 넘어가겠습니다.

next_permutaion이 어떻게 작동하는가?

next_permuation은 두개의 인자를 받는데 첫번째 인자는 순열을 구할 배열의 시작 주소(iterator), 두번째 인자는 끝나는 주소(iterator)를 받습니다.

주의할 점

이 함수는 주어진 배열을 기준으로 사전 순서로 다음 순열을 만들기 때문에 처음 주어진 배열이

오름차순으로 정렬되어 있지 않는다면 모든 경우를 보여주지 않는다.

  • 예시

Plain text
int main() { vector<int> v = {9, 7, 9, 1}; do { for (int i = 0; i < v.size(); i++) cout << v[i] << " "; cout << endl; } while (next_permutation(v.begin(), v.end())); }

위와 같은 코드가 있다면 처음으로 주어진 배열이 9 7 9 1 이므로 모든 순열이 출력되지 않고 나머지 배열인

Plain text
9 7 9 1 9 9 1 7 9 9 7 1

이 출력된다.


여기서 우리는 모든 순열을 출력하는 것이아니고 길이가 m인 수열을 구해야 한다.

즉, 조합을 구해야 한다.

조합을 구하기 우해서 우리는 temp벡터를 사용하여 원하는만큼의 수를 출력할 것이다.

예시를 보자

Plain text
int main() { vector<int> v = {9, 7, 9, 1}; vector<int> temp = {0, 0, 1, 1}; sort(v.begin(), v.end()); do { for (int i = 0; i < v.size(); i++) if (!temp[i]) cout << v[i] << " "; cout << endl; } while (next_permutation(temp.begin(), temp.end())); }

위와 같이 temp 배열에 출력하고 싶은 수의 개수만큼 0으로 초기화하고 나머지는 1로 하여 출력할 때,

temp가 0인 위치만 출력하도록 한 것이다.

그럼 출력으로 아래와 같은 값을 얻을 수 있다.

Plain text
1 7 1 9 1 9 7 9 7 9 9 9

하지만 여기서 우리가 구하고 싶은 것은 조합과 살짝 다른 개념이므로 이를 응용해서 원하는 답을 잘 구해야 한다.

이를 코드로 풀어보면 아래와 같은 결과가 나온다.

Plain text
#include <iostream> #include <vector> #include <algorithm>using namespace std; int main() { vector<int> input, temp, temp_ans; vector<vector<int>> ans; int n, m, num; cin >> n >> m; for (int i = 0; i < n; i++) { //수 입력 받기 cin >> num; input.push_back(num); } for (int i = 0; i < n; i++) { //출력할만큼 수 지정 if (i < m) temp.push_back(0); else temp.push_back(1); } sort(input.begin(), input.end()); //오름차순 정렬하기 do { temp_ans.clear(); for (int i = 0; i < input.size(); i++) { //현재 수 배열에서 m개 만큼의 수 temp_ans에 저장 if (!temp[i]) temp_ans.push_back(input[i]); } do { //temp_ans로 만들 수 있는 모든 순열 저장하기 ans.push_back(temp_ans); } while (next_permutation(temp_ans.begin(), temp_ans.end())); } while (next_permutation(temp.begin(), temp.end())); sort(ans.begin(), ans.end()); //ans 오름차순 정렬 for (int i = 0; i < ans[0].size(); i++) //첫번째 답은 무조건 출력 cout << ans[0][i] << " "; cout << "\n"; for (int i = 1; i < ans.size(); i++) { bool flag = false; for (int j = 0; j < ans[i].size(); j++) { if (ans[i - 1][j] != ans[i][j]) { //ans벡터의 이전값과 하나라도 다르면 그대로 출력 flag = true; break; } } if (flag) { for (int j = 0; j < ans[i].size(); j++) cout << ans[i][j] << " "; cout << "\n"; } } }

참 쉽죠?