Created January 6, 2024 01:02
-
-
Save koosaga/2a71ef1ad3df5d4cc9a3342cd48d937f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <bits/stdc++.h> | |
#define sz(v) ((int)(v).size()) | |
#define all(v) (v).begin(), (v).end() | |
#define cr(v, n) (v).clear(), (v).resize(n); | |
using namespace std; | |
using lint = long long;          // 64-bit type for path-weight sums
using pi = std::array<lint, 2>;  // pair of values: (u, v) edge endpoints or (edge id, neighbour)
constexpr int MAXN = 100'005;    // size bound constant; not referenced elsewhere in this file
struct elem { | |
lint far, diam, twosum; | |
}; | |
// Vertex weights indexed 0..n-1: filled by main() and read by take_vertex().
std::vector<int> a{};
namespace AllDirectionTreeDP { | |
// Need to implement four functions: | |
// E: identity | |
// take_vertex: add vertex on top of merged edges | |
// up_root: update child DP to consider parent edge values | |
// merge: merge two child edges | |
// it's good if merges are commutative (its not necessary but be careful of specifics) | |
elem E() { return elem{0, 0}; } | |
elem take_vertex(elem DP, int v) { return elem{a[v] + DP.far, max(a[v] + DP.twosum, DP.diam), 0}; } | |
elem up_root(elem DP, int e) { return elem{DP.far, DP.diam, 0}; } | |
elem merge(elem a, elem b) { return elem{max(a.far, b.far), max(a.diam, b.diam), max({a.twosum, b.twosum, a.far + b.far})}; } | |
void dfs(int x, vector<vector<pi>> &gph, vector<int> &ord, vector<int> &pae) { | |
ord.push_back(x); | |
for (auto &[i, y] : gph[x]) { | |
gph[y].erase(find(all(gph[y]), pi{i ^ 1, x})); | |
pae[y] = (i ^ 1); | |
dfs(y, gph, ord, pae); | |
} | |
} | |
lint solve(int n, vector<pi> edges) { | |
vector<vector<pi>> gph(n); | |
gph.resize(n); | |
for (int i = 0; i < n - 1; i++) { | |
gph[edges[i][0]].push_back({2 * i, edges[i][1]}); | |
gph[edges[i][1]].push_back({2 * i + 1, edges[i][0]}); | |
} | |
vector<int> ord; | |
vector<int> pae(n, -1); | |
dfs(0, gph, ord, pae); | |
vector<elem> dp(n, E()); | |
reverse(all(ord)); | |
for (auto &z : ord) { | |
for (auto &[i, y] : gph[z]) { | |
dp[z] = merge(dp[z], up_root(dp[y], i)); | |
} | |
dp[z] = take_vertex(dp[z], z); | |
} | |
vector<elem> rev_dp(n, E()); | |
reverse(all(ord)); | |
for (auto &z : ord) { | |
vector<elem> pref(sz(gph[z]) + 1, E()); | |
vector<elem> suff(sz(gph[z]) + 1, E()); | |
if (~pae[z]) | |
pref[0] = up_root(rev_dp[z], pae[z]); | |
for (int i = 0; i < sz(gph[z]); i++) { | |
pref[i + 1] = suff[i] = up_root(dp[gph[z][i][1]], gph[z][i][0]); | |
} | |
for (int i = 1; i <= sz(gph[z]); i++) | |
pref[i] = merge(pref[i - 1], pref[i]); | |
for (int i = sz(gph[z]) - 1; i >= 0; i--) | |
suff[i] = merge(suff[i], suff[i + 1]); | |
for (int i = 0; i < sz(gph[z]); i++) { | |
rev_dp[gph[z][i][1]] = take_vertex(merge(pref[i], suff[i + 1]), z); | |
} | |
} | |
lint ans = 0; | |
for (int i = 1; i < n; i++) { | |
ans = max(ans, dp[i].diam + rev_dp[i].diam); | |
} | |
for (int i = 0; i < n; i++) { | |
vector<lint> paths(4); | |
for (auto &[_, j] : gph[i]) { | |
paths.push_back(dp[j].far); | |
} | |
paths.push_back(rev_dp[i].far); | |
sort(all(paths)); | |
reverse(all(paths)); | |
paths.resize(4); | |
ans = max(ans, accumulate(all(paths), 0ll)); | |
} | |
return ans; | |
} | |
} // namespace AllDirectionTreeDP | |
// Reads n, then n vertex weights into the global `a`, then n - 1 edges given
// as 1-based endpoint pairs; prints AllDirectionTreeDP::solve for that tree.
int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    int n;
    cin >> n;

    a.resize(n);
    for (int v = 0; v < n; ++v)
        cin >> a[v];

    vector<pi> edges;
    for (int k = 1; k < n; ++k) {
        int u, w;
        cin >> u >> w;
        edges.push_back({u - 1, w - 1});  // store 0-based endpoints
    }

    cout << AllDirectionTreeDP::solve(n, edges) << "\n";
}
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.