Educational Codeforces Round 168 - E. Level Up

Updated:

Categories:

Tags: , ,

Educational Codeforces Round 168 - E. Level Up

요약

1번부터 n번까지 monster와 싸우는데, 현재 lv보다 작은 lv의 몬스터는 도망친다.
이때 유저는 k명의 몬스터와 싸우고 나서 레벨업을 한다.

이때, 총 q개의 쿼리가 들어온다. i,x가 들어온다.
k=x일 때, i번째 몬스터는 싸우는 지 여부를 출력해야 한다.

Solution. 1

풀이

k = 1 -> 최대 n/1 + 1
k = 2 -> 최대 n/2 + 1

k = x -> 최대 n/x + 1

k = n -> 최대 n/n + 1

위처럼 k = x일 때, 최대 n/x+1 레벨까지 레벨업을 할 수 있음을 알 수 있다.

우리는 모든 k마다 레벨업할 때의 위치를 구하고 싶다. 레벨업하는 위치의 갯수는 n/1 + n/2 + .. + n/n으로 n(1/1 + 1/2 + … + n/1) 이 된다.
the sum of harmonic numbers $H_n$은 대충 $log_2(n)$ 쯤이다.
따라서 모든 레벨업의 위치는 $Nlog_2N$개 정도 됨을 알 수 있다.

그렇다면 각 레벨업의 위치를 구하려면 어떻게 해야 할까?

모든 k에 대해서 sloc배열을 저장한다.
$sloc[x][lv]$에는 k=x일 때 레벨이 lv이 될 때의 위치를 저장하고 있다.

모든 lv = 1 부터 증가하면서 탐색한다.
이때, 세그먼트 트리에 lv보다 작은 값들은 제거해준다.
참고로 여기서 세그먼트 트리는 현재 숫자의 갯수를 저장하고 있다.
이렇게 하면 세그먼트 트리에는 항상 lv보다 크거나 같은 수들의 갯수를 저장하고 있게 된다.

각 $sloc[x][lv+1]$을 구하려면 $sloc[x][lv]$ ~ $right$ 의 수의 갯수가 x가 되도록 하는 right를 찾으면 된다. right를 이분탐색으로 찾고 각 세그먼트의 sum을 구하는 데 $logN$이 걸리므로, 각 레벨업 위치마다 $O(log_2N)^2$의 시간이 걸린다.

따라서 최종 시간복잡도는 $O(Nlog_2N^3)$이 된다.

쿼리의 경우 i,x가 들어오면 이분탐색으로 $sloc[x]$ 에 대해 해당 i위치를 몇 레벨일 때 만나는 지 구할 수 있다.
이를 이용해서 $a[i]$와 레벨을 비교해서 $a[i]$가 lv보다 크거나 같으면 YES를, 아니라면 NO를 출력해주면 된다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#include <bits/stdc++.h>

#define endl "\n"
#define all(v) (v).begin(), (v).end()
#define For_IMPL(condition, i, a, b, increment, decrement) \
    for (int i = (a); condition; (a < b ? increment : decrement))
#define For(i, a, b) For_IMPL((a < b ? i < b : i > b), i, a, b, ++i, --i)
#define For_(i, a, b) For_IMPL((a < b ? i <= b : i >= b), i, a, b, ++i, --i)
#define ft first
#define sd second

using namespace std;
using ll = long long;
using lll = __int128_t;
using ulll = __uint128_t;
using ull = unsigned long long;
using ld = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using ti3 = tuple<int, int, int>;
using tl3 = tuple<ll, ll, ll>;

template<class T> bool ckmin(T& a, const T& b) { return b < a ? a = b, 1 : 0; }
template<class T> bool ckmax(T& a, const T& b) { return a < b ? a = b, 1 : 0; }

const int INF = 987654321;
const int INF0 = numeric_limits<int>::max();
const ll LNF = 987654321987654321;
const ll LNF0 = numeric_limits<ll>::max();

class Segment {
public:
    vector<ll> tree; //tree[node] := a[start ~ end] 의 합

    Segment() {}
    Segment(int size) {
        this->resize(size);
    }
    void resize(int size) {
        size = (int) floor(log2(size)) + 2;
        size = pow(2, size);
        tree.resize(size, 0);
    }
    ll init(int node, int start, int end) {
        if(start == end) return tree[node] = 1;
        else return tree[node] = init( 2 * node, start, (start + end) / 2) +
                                 init( 2 * node + 1, (start + end) / 2 + 1, end);
    }
    ll sum(int node, int start, int end, int left, int right) {
        if(right < start || end < left) return 0;
        if(left <= start && end <= right) return tree[node];
        return sum(node * 2, start, (start + end) / 2, left, right) +
               sum(node * 2 + 1, (start + end) / 2 + 1, end, left, right);
    }
    void update(int node, int start, int end, int index, ll diff) {
        if(index < start || end < index) return;
        tree[node] += diff;
        if(start != end) {
            update(node * 2, start, (start + end) / 2, index, diff);
            update(node * 2 + 1, (start + end) / 2 + 1, end, index, diff);
        }
    }
    ll search(int node, int start, int end, int left, int right, ll target) {
        if(start == end) return start;
        ll leftSum = sum(2*node, start, (start+end)/2, left, right);
        if(target <= leftSum) return search(2*node, start, (start+end)/2, left, right, target);
        else return search(2*node+1, (start+end)/2+1, end, left, right, target - leftSum);
    }
};

void solve() {
    int n, q; cin >> n >> q;

    vector<int> need(n); iota(all(need),1);

    vector<ll> a(n+1);
    priority_queue<pll,vector<pll>,greater<>> pq;
    For_(i,1,n) {
        cin >> a[i];
        pq.emplace(a[i],i);
    }

    Segment tree(n);
    tree.init(1,1,n);
    vector<ll> sloc[n+1];
    For_(i,1,n) sloc[i].emplace_back(1);
    for(ll lv=1; !need.empty(); lv++) {
        while(!pq.empty() and pq.top().ft < lv) {
            tree.update(1,1,n,pq.top().sd,-1);
            pq.pop();
        }
        vector<int> new_need;
        for(auto x : need) {
            ll l = sloc[x].back();
            ll r = tree.search(1,1,n,l,n,x);
            ll sum = tree.sum(1,1,n,l,r);
            if(sum == x and r+1 <= n) {
                new_need.emplace_back(x);
                sloc[x].emplace_back(r+1);
            }
        }
        swap(need, new_need);
    }

    while(q--) {
        ll i, x; cin >> i >> x;
        ll idx = upper_bound(all(sloc[x]), i) - sloc[x].begin();
        if(a[i] >= idx) cout << "YES\n";
        else cout << "NO\n";
    }
}

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int tc = 1;
//    cin >> tc;
    while(tc--) {
        solve();
//        cout << solve() << endl;
    }


    return 0;
}

Solution. 2

풀이

세그먼트 트리 말고, 머지소트 트리로도 풀 수 있다.

머지소트 트리의 쿼리가 $O(log_2N^2)$이므로 전체 시간복잡도는 $O(Nlog_2N^4)$가 되는데 최적화만 잘 시키면 통과된다.

Solution. 1과 동일하지만 $sloc$ 배열을 구하는 방식이 머지소트 트리로 구하는 방식인 해결 방법이다.

코드

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#include <bits/stdc++.h>

#define endl "\n"
#define all(v) (v).begin(), (v).end()
#define For_IMPL(condition, i, a, b, increment, decrement) \
    for (int i = (a); condition; (a < b ? increment : decrement))
#define For(i, a, b) For_IMPL((a < b ? i < b : i > b), i, a, b, ++i, --i)
#define For_(i, a, b) For_IMPL((a < b ? i <= b : i >= b), i, a, b, ++i, --i)
#define ft first
#define sd second

using namespace std;
using ll = long long;
using lll = __int128_t;
using ulll = __uint128_t;
using ull = unsigned long long;
using ld = long double;
using pii = pair<int, int>;
using pll = pair<ll, ll>;
using ti3 = tuple<int, int, int>;
using tl3 = tuple<ll, ll, ll>;

const int INF = 987654321;
const int INF0 = numeric_limits<int>::max();
const ll LNF = 987654321987654321;
const ll LNF0 = numeric_limits<ll>::max();

class mergeTree {
public:
    vector<vector<ll>> tree; //tree[node] := a[start ~ end] 의 합

    mergeTree() {}
    mergeTree(int size) {
        this->resize(size);
    }
    void resize(int size) {
        size = (int) floor(log2(size)) + 2;
        size = pow(2, size);
        tree.resize(size);
    }

    void update(int node, int start, int end, int index, ll value) {
        if(index < start || end < index) return;
        tree[node].emplace_back(value);
        if(start != end) {
            update(node * 2, start, (start + end) / 2, index, value);
            update(node * 2 + 1, (start + end) / 2 + 1, end, index, value);
        }
    }
    ll query(int node, int start, int end, int left, int right, ll value) {
        if(right < start || end < left) return 0;
        if(left <= start && end <= right) return tree[node].end() - lower_bound(all(tree[node]), value);
        return query(node * 2, start, (start + end) / 2, left, right, value) +
        query(node * 2 + 1, (start + end) / 2 + 1, end, left, right, value);
    }
    // left ~ right 에서 value보다 크거나 같으면서 count번째의 index 반환
    ll search(int node, int start, int end, int left, int right, ll value, ll count) {
        if(start == end) {
            if(count == 1 && tree[node].front() >= value) return start;
            else return -1;
        }
        ll leftCnt = query(2*node, start, (start+end)/2, left, right, value);
        if(count <= leftCnt) return search(2*node, start, (start+end)/2, left, right, value, count);
        else return search(2*node+1, (start+end)/2+1, end, left, right, value, count-leftCnt);
    }
};


void solve() {
    ll n, q; cin >> n >> q;
    vector<ll> a(n+1);
    vector<pll> v;
    For_(i,1,n) {
        cin >> a[i];
        v.emplace_back(a[i],i);
    }
    sort(all(v));

    mergeTree tree(n);
    for(auto [val, idx] : v)
        tree.update(1,1,n,idx,val);

    vector<ll> group[n+1];
    For_(x,1,n) {
        ll left=1, lv=1;
        while(true) {
            ll res = tree.search(1,1,n,left,n,lv,x);
            if(res == -1) break;
            group[x].emplace_back(res);
            lv++;
            left = res+1;
        }
    }

    while(q--) {
        ll i, x; cin >> i >> x;
        int idx = lower_bound(all(group[x]), i) - group[x].begin();
        if(a[i] >= idx+1) cout << "YES\n";
        else cout << "NO\n";
    }
}

int main(void) {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int tc = 1;
//    cin >> tc;
    while(tc--) {
        solve();
//        cout << solve() << endl;
    }


    return 0;
}

후기

조화수의 합(sum of the harmonic numbers) $H_n$가 $ln(n)$과 근접하다는 사실을 몰라서 풀지 못 했다.
솔루션1의 세그트리는 바로 생각 못 했지만 머지소트 트리를 이용한 풀이는 생각했다. 하지만 시간초과가 날 것이라 생각하고 시도하지 않았다.

k = x일 때, n/x개의 레벨만 생길 것이라 생각하긴 했으나 전부 합쳤을 때 시간초과가 날 것이라 예상했다.
조화수의 합은 로그 시간임을 기억하자.

Leave a comment