재귀 함수 Ⅰ

시뮬레이션

함수

함수에 대한 기초적인 내용은 다들 아실 것 같아서 생략하고, 잘 모를만한 내용만 설명하겠습니다.

파이썬은 C언어랑 다르게 main 함수가 없습니다. 그래서 함수 내부에서 변수가 설정되는 것이 아닌 이상 모든 변수, 객체가 전역변수로 설정되어서 함수 내부나 클래스 내부에서 이용할 수 있습니다.

def func():
    print(arr)
    print(val)

arr = [1, 2, 3]
val = 1000
func()
print(arr)
class A:
    def __init__(self):
        self.x = 0
        self.y = 0
        print(x) # 50

def func():
    print(val.x) # 0
    val.y += 1
    print(val.y) # 1

x = 50
val = A()
func()

여기서 객체는 객체 고유의 메모리 주소 값을 가지고 있기 때문에 함수 내부에서 값을 수정할 수 있습니다. 하지만 전역 변수로 설정된 숫자나 문자열은 값을 변경하려고 하려면 global이라는 키워드가 필요합니다.

# 정상 출력
def func():
    arr.append(4)
    print(arr) #[1, 2, 3, 4]
    
arr = [1, 2, 3]
func()
print(arr) #[1, 2, 3, 4]
# 오류 출력
def func():
    val += 1
    print(val)

val = 100
func()
# 정상 출력
def func():
    global val
    val += 1
    print(val) # 101

val = 100
string = 'abc'
func()
print(val) # 101
# 정상 출력
def func():
    global string
    string += 'd'
    print(string)

string = 'abc'
func()
print(string)

재귀 함수

재귀 함수는 함수를 정의할 때, 함수 내부에서 함수 자신을 호출하는 함수입니다. 재귀 함수는 알고리즘에서 백트래킹, DFS, 유니온 파인드(분리 집합) 등과 관련된 중요한 내용입니다.

가장 유명한 예시인 팩토리얼로 재귀 함수가 어떤 구조로 돌아가는지 설명하겠습니다.

BOJ 10872

일단 입력값을 받는 변수 n, 답을 출력할 ans 변수를 만들어줍니다.

n = int(input())
ans = 1

단순히 for문으로 만들어보면 아래 코드와 같습니다.

n = int(input())
ans = 1

for i in range(1, n+1):
    ans *= i

print(ans)

이것을 단순히 factorial 이라는 함수로 옮기겠습니다. ans의 값을 수정해야 하므로 global 키워드를 사용해야 합니다.

def factorial():
    global ans
    for i in range(1, n+1):
        ans *= i

n = int(input())
ans = 1
facotrial()
print(ans)

아니면 ans를 함수 내부에 두고, return을 이용하여 답을 가져올 수도 있습니다.

def factorial():
    ans = 1
    for i in range(1, n+1):
        ans *= i
    return ans

n = int(input())
print(factorial())

이제 우리는 재귀 함수로 만들려고 하는데, for문 대신 재귀 함수로 만들어주면 됩니다. for문에서 i는 1부터 시작하므로 factorial(1)로 시작하고 ans *= 1, factorial(2)에서는 ans *= 2를 하고, ... 이러다가 facotrial(n)에서는 ans *= n을 하고 함수를 끝내면 됩니다.

def factorial(x):
    global ans
    ans *= x
    if x == n: return
    factorial(x+1)

n = int(input())
ans = 1
factorial(1)

print(ans)

마지막으로 n = 0 일 때가 존재하므로 조건문과 exit() 함수를 붙여주면 됩니다.

def factorial(x):
    global ans
    ans *= x
    if x == n: return
    factorial(x+1)

n = int(input())
if n == 0:
    print(1)
    exit()

ans = 1
factorial(1)

print(ans)

위 코드를 n = 5일 때, 실행 과정을 그려보면 아래와 같습니다.

BOJ 15649

우리가 7주차 순열과 조합에서 이미 풀어보았던 문제입니다. 이번엔 permutation이 아닌 재귀로 풀어보겠습니다.

일단 입력 값을 받는 코드를 짜봅시다.

n, m = map(int, input().split())

함수 이름은 go라고 정하고, 파라미터(매개 변수)는 리스트에 값을 저장해서 길이가 m이 될 때 출력해야 하므로 리스트를 넣어주면 될 것 같습니다.

배열을 파라미터로 설정하면 len(arr)을 통해 길이를 알 수 있고, 멤버 연산자 in을 이용하여 중복된 값을 넣지 않는지 확인할 수도 있습니다.

우리는 go()함수를 한 번 실행될 때마다 arr에 원소가 한 개씩 들어가는 함수로 만들 예정입니다. 그래서 답을 출력하는 리스트 arr은 go()함수를 m번 실행하고 출력하게 됩니다.

def go(arr):
    blah blah

n, m = map(int,input().split())

출력문을 보니까 사전순서대로 넣어야 합니다. 그러므로 for i in range(1, n+1)를 사용하여 arr에 추가합니다. 이때 이미 arr에 들어간 값은 들어가지 않아야 하므로 if문과 멤버 연산자를 이용하여 중복되지 않게 값을 넣어줍시다.

def go(arr):
    for i in range(1, n+1):
        if i not in arr:
            arr.append(i)

n, m = map(int,input().split())

우리는 go()함수를 한 번 실행할 때마다 arr에 원소를 한 개씩 추가하기로 했으니까 해당 go()함수에서는 원소를 한 개 추가했으므로 go()함수를 호출합니다.

def go(arr):
    for i in range(1, n+1):
        if i not in arr:
            arr.append(i)
            go(arr)

n, m = map(int,input().split())

이때 이렇게 코드를 두면 for문이 계속 실행되어서 arr의 배열이 계속 늘어나므로 arr.pop()을 이용하여 i값을 빼줍시다.

def go(arr):
    for i in range(1, n+1):
        if i not in arr:
            arr.append(i)
            go(arr)
            arr.pop()

n, m = map(int,input().split())

4~6줄을 한 줄로 줄인다면 go(arr+[i])를 작성하면 됩니다.

def go(arr):
    for i in range(1, n+1):
        if i not in arr:
            go(arr+[i])

n, m = map(int,input().split())

이제 arr의 길이가 m개가 되면 arr을 출력하고 끝내야 하므로 함수 첫 번째 줄에 조건문을 추가해줍시다.

def go(arr):
    if len(arr) == m:
        print(' '.join(map(str,arr)))
        return
    for i in range(1, n+1):
        if i not in arr:
            go(arr+[i])

n, m = map(int,input().split())

이제 go()함수를 실행해야 하므로 go([]) 혹은 go(list())를 추가해주면 됩니다.

def go(arr):
    if len(arr) == m:
        print(' '.join(map(str,arr)))
        return
    for i in range(1, n+1):
        if i not in arr:
            go(arr+[i])

n, m = map(int,input().split())
go([])

BOJ 15655

N과 M (1)과 다르게 1부터 N까지의 숫자가 아닌 직접 N개의 수를 입력받고, 여기서 M개를 골라 오름차순으로 출력하라고 합니다. 문제를 확인해보니 조합의 성격을 띠고 있습니다.

일단 입력값을 받는 코드를 작성해봅시다.

n, m = map(int,input().split())
arr = list(map(int,input().split()))

오름차순으로 출력해야 하므로 여기서 바로 정렬을 해줍시다.

n, m = map(int,input().split())
arr = sorted(list(map(int,input().split())))

이번에도 마찬가지로 go()라는 이름의 함수를 사용합니다. 파라미터(매개 변수)는 N과 M(1)과 비슷하게 배열을 사용하지만, 이번엔 해당 배열의 요소를 사용했는지 True(1)과 False(0)으로 표현하는 check 배열을 사용해보겠습니다.

이번 go()함수도 한 번 실행할 때마다 check의 배열 값 한 개만 선택해서 1로 바꾸고 끝내는 함수입니다.

그러면 처음 함수를 실행할 때는 아무것도 선택하지 않았으니 go([0]*n)으로 실행해야 합니다.

def go(check):
    blah blah

n, m = map(int,input().split())
arr = sorted(list(map(int,input().split())))
go([0]*n)

check 배열은 0과 1로만 있으니 만약 m개의 배열을 선택했다면 sum(check) == m이 되겠습니다.

그래서 조건문을 이용하여 sum(check) = m이 되면 arr을 출력하고, 함수를 끝내면 됩니다.

def go(check):
    if sum(check) == m:
        ans = []
        for i in range(n):
            if check[i]: ans.append(arr[i])
        print(' '.join(map(str,ans)))
        return
    blah blah

n, m = map(int,input().split())
arr = sorted(list(map(int,input().split())))
go([0]*n)

3 ~ 5줄은 한 줄로 표현할 수 있습니다.

def go(check):
    if sum(check) == m:
        ans = [arr[i] for i in range(n) if check[i]]
        print(' '.join(map(str,ans)))
        return
    blah blah

n, m = map(int,input().split())
arr = sorted(list(map(int,input().split())))
go([0]*n)

해당 코드는 위에 3줄짜리 코드와 마찬가지로 for문을 n번 돌릴 때, check[i]의 값이 True(1)라면 ans에 arr[i]를 추가하는 코드입니다.

이제 for문을 이용하여 check에서 0인 값들 중에 하나를 1로 바꿔주면 됩니다.

def go(check):
    if sum(check) == m:
        ans = [arr[i] for i in range(n) if check[i]]
        print(' '.join(map(str,ans)))
        return
    for i in range(n):
        if check[i] == 0:
            check[i] = 1

n, m = map(int,input().split())
arr = sorted(list(map(int,input().split())))
go([0]*n)

이 함수에서는 check값 하나를 1로 바꾸었으므로 go(check)를 실행해줍니다. 그리고 우리는 다른 경우의 수도 확인해봐야 하니까 다시 check[i] = 0으로 돌려줍시다.

def go(check):
    if sum(check) == m:
        ans = [arr[i] for i in range(n) if check[i]]
        print(' '.join(map(str,ans)))
        return
    for i in range(n):
        if check[i] == 0:
            check[i] = 1
            go(check)
            check[i] = 0

n, m = map(int,input().split())
arr = sorted(list(map(int,input().split())))
go([0]*n)

이대로 코드를 실행해보면 답이랑 다르다는 것을 알 수 있습니다.

예제 2로 예를 들자면 arr = [1, 7, 8, 9]가 있는데, 6줄에 for문이 처음부터 끝까지 돌아가기 때문에 (처음 1을 선택하고 두 번째로 9를 선택하는 경우의 수)와 (처음 9를 선택하고 두 번째로 1을 선택하는 경우의 수)가 나와서 1 9가 2번이나 출력하게 됩니다.

그래서 일단 처음 값을 선택하면 두 번째로 선택할 값은 첫 번째 값보다는 큰 값만 선택할 수 있게, 세 번째로 선택할 값은 두 번째 값보다는 큰 값만 선택할 수 있게 파라미터에 start 변수를 추가하여 6줄 for문을 start부터 n-1까지 돌릴 수 있게 만들어주면 됩니다. 우리는 이미 arr 배열을 오름차순으로 정렬했기 때문에 arr[start]는 arr[0] ~ arr[start-1]의 모든 값보다 큰 값이 되는 것은 자명합니다.

go()함수 내부에서 go()함수를 호출할 때는 for문에서 i를 사용하고 있으므로 다음 go()함수에서는 i+1부터 시작할 수 있도록 go(check, i+1)을 적어주시면 됩니다.

def go(check, start):
    if sum(check) == m:
        ans = [arr[i] for i in range(n) if check[i]]
        print(' '.join(map(str,ans)))
        return
    for i in range(start, n):
        if check[i] == 0:
            check[i] = 1
            go(check, i+1)
            check[i] = 0

n, m = map(int,input().split())
arr = sorted(list(map(int,input().split())))
go([0]*n, 0)

이렇게 1, 6, 9, 14줄을 수정하고 코드를 실행해보면 제대로 답이 나오게 됩니다.

풀어볼 문제

N과 M 시리즈 문제에서 난이도가 실버 3인 문제를 전부 가지고 왔습니다. 나머지 문제는 9주차에서 풀을 예정입니다.

Last updated