Skip to content

Instantly share code, notes, and snippets.

@koosaga
Created January 6, 2024 01:02
Show Gist options
  • Save koosaga/2a71ef1ad3df5d4cc9a3342cd48d937f to your computer and use it in GitHub Desktop.
Save koosaga/2a71ef1ad3df5d4cc9a3342cd48d937f to your computer and use it in GitHub Desktop.
#include <bits/stdc++.h>
// Competitive-programming shorthands.
#define sz(v) ((int)(v).size())
#define all(v) (v).begin(), (v).end()
// Clears v and resizes it to n value-initialised elements. Expands to one
// parenthesised expression with no trailing ';', so `cr(v, n);` behaves like an
// ordinary statement, including as the body of an un-braced if/else.
#define cr(v, n) ((v).clear(), (v).resize(n))
using namespace std;
using lint = long long;
// Used both for (directed edge id, neighbour) adjacency entries and for (u, v) edges.
using pi = array<lint, 2>;
// Not referenced elsewhere in this file.
constexpr int MAXN = 100005;
// Summary of a part of the tree used by AllDirectionTreeDP. Every field is a
// sum of vertex weights taken from the global array a.
struct elem {
	// Heaviest downward path starting at the top vertex (for a bundle of child
	// edges that has not been closed by take_vertex yet: the heaviest single arm).
	lint far;
	// Heaviest path lying entirely inside the summarised part.
	lint diam;
	// Scratch value while merging a bundle: heaviest sum of two arms taken from
	// different children; take_vertex and up_root reset it to 0.
	lint twosum;
};
// Vertex weights indexed by 0-based vertex id; main reads them from stdin.
std::vector<int> a;
namespace AllDirectionTreeDP {
// Need to implement four functions:
// E: identity
// take_vertex: add vertex on top of merged edges
// up_root: update child DP to consider parent edge values
// merge: merge two child edges
// it's good if merges are commutative (its not necessary but be careful of specifics)
elem E() { return elem{0, 0}; }
elem take_vertex(elem DP, int v) { return elem{a[v] + DP.far, max(a[v] + DP.twosum, DP.diam), 0}; }
elem up_root(elem DP, int e) { return elem{DP.far, DP.diam, 0}; }
elem merge(elem a, elem b) { return elem{max(a.far, b.far), max(a.diam, b.diam), max({a.twosum, b.twosum, a.far + b.far})}; }
void dfs(int x, vector<vector<pi>> &gph, vector<int> &ord, vector<int> &pae) {
ord.push_back(x);
for (auto &[i, y] : gph[x]) {
gph[y].erase(find(all(gph[y]), pi{i ^ 1, x}));
pae[y] = (i ^ 1);
dfs(y, gph, ord, pae);
}
}
lint solve(int n, vector<pi> edges) {
vector<vector<pi>> gph(n);
gph.resize(n);
for (int i = 0; i < n - 1; i++) {
gph[edges[i][0]].push_back({2 * i, edges[i][1]});
gph[edges[i][1]].push_back({2 * i + 1, edges[i][0]});
}
vector<int> ord;
vector<int> pae(n, -1);
dfs(0, gph, ord, pae);
vector<elem> dp(n, E());
reverse(all(ord));
for (auto &z : ord) {
for (auto &[i, y] : gph[z]) {
dp[z] = merge(dp[z], up_root(dp[y], i));
}
dp[z] = take_vertex(dp[z], z);
}
vector<elem> rev_dp(n, E());
reverse(all(ord));
for (auto &z : ord) {
vector<elem> pref(sz(gph[z]) + 1, E());
vector<elem> suff(sz(gph[z]) + 1, E());
if (~pae[z])
pref[0] = up_root(rev_dp[z], pae[z]);
for (int i = 0; i < sz(gph[z]); i++) {
pref[i + 1] = suff[i] = up_root(dp[gph[z][i][1]], gph[z][i][0]);
}
for (int i = 1; i <= sz(gph[z]); i++)
pref[i] = merge(pref[i - 1], pref[i]);
for (int i = sz(gph[z]) - 1; i >= 0; i--)
suff[i] = merge(suff[i], suff[i + 1]);
for (int i = 0; i < sz(gph[z]); i++) {
rev_dp[gph[z][i][1]] = take_vertex(merge(pref[i], suff[i + 1]), z);
}
}
lint ans = 0;
for (int i = 1; i < n; i++) {
ans = max(ans, dp[i].diam + rev_dp[i].diam);
}
for (int i = 0; i < n; i++) {
vector<lint> paths(4);
for (auto &[_, j] : gph[i]) {
paths.push_back(dp[j].far);
}
paths.push_back(rev_dp[i].far);
sort(all(paths));
reverse(all(paths));
paths.resize(4);
ans = max(ans, accumulate(all(paths), 0ll));
}
return ans;
}
} // namespace AllDirectionTreeDP
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int n;
cin >> n;
a.resize(n);
for (auto &x : a)
cin >> x;
vector<pi> edges;
for (int i = 0; i < n - 1; i++) {
int u, v;
cin >> u >> v;
u--;
v--;
edges.push_back({u, v});
}
cout << AllDirectionTreeDP::solve(n, edges) << "\n";
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment