Tree House CodeChef Solution | Easy Approach | C++ Java Python

Share:

Tree House CodeChef Solution View on Codechef

 Tree House CodeChef Solution

There is a large tree house in an unknown world. It is ruled by the great emperor KZS. It consists of $N$ nodes numbered from $1$ to $N$ in which the people of that world reside. The nodes are organized in a tree structure rooted at node $1$. You need to assign values to the nodes according to the wishes of emperor KZS, which are as follows:

  • The value of node $1$ is $X$.
  • All immediate children of any node have pairwise distinct values.
  • For every node with at least one immediate child, the $\gcd$ of the values of all immediate children is equal to the value of the node.
  • The total sum of the values of all nodes should be minimum.

The greatest common divisor $\gcd(a,b)$ of two positive integers $a$ and $b$ is equal to the largest integer $d$ such that both integers $a$ and $b$ are divisible by $d$.

Print the sum of all values, modulo $10^9+7$.

Input

  • The first line contains an integer $T$, the number of test cases. $T$ test cases follow.
  • The first line of each test case contains two integers $N$ and $X$.
  • Each of the following $N-1$ lines contains two integers $u$ and $v$, denoting an edge between nodes $u$ and $v$.

Output

  • For each test case, print the sum of values, modulo $10^9+7$.

Constraints

  • $1 \le T \le 15$
  • $2 \le N \le 3 \cdot 10^5$
  • $1 \le X \le 10^9$
  • $1 \le u, v \le N$ and $u \ne v$
  • The given edges form a tree
  • The sum of $N$ over all test cases doesn't exceed $3 \cdot 10^5$.

Subtasks

Subtask #1 (100 points): Original Constraints

Sample Input

2
4 1
1 2
1 3
1 4
8 1
1 2
1 3
2 4
2 5
5 6
5 7
7 8

Sample Output

7
11

Explanation

In test case $1$, we give values $1$, $2$ and $3$ to nodes $2$, $3$ and $4$ respectively. So the total sum is $1+1+2+3=7$.

Tree House CodeChef Solution

C++

#include <iostream>
#include <vector>
#include <unordered_set>
#include <set>
#include <unordered_map>
#include <map>
#include <utility>
#include <cmath>
#include <string>
#include <algorithm>
#include <numeric>
#include <chrono>
#include <iomanip>

using namespace std;

// Modulus applied to the printed answer (10^9 + 7).
#define mod 1000000007
// Short aliases for the wide integer and floating-point types.
typedef long long ll;
typedef long double ld;

// loops
// FOR(i, a, b): declares int i and runs it over the half-open range [a, b).
#define FOR(i,a,b) for (int i = (a); i < (b); ++i)
// F0R(i, a): declares int i and runs it over [0, a).
#define F0R(i,a) FOR(i,0,a)
// rep(a): repeats the following statement a times; the hidden counter is named `_`.
#define rep(a) F0R(_,a)

// tud[v]: undirected adjacency list as read from the input.
// tree[v]: children of v once dir() has oriented the tree away from the root.
std::vector<std::vector<int>> tree, tud;
// vis[v] == 1 once v has been attached to the rooted tree by dir().
std::vector<int> vis;

/// Orients the undirected tree stored in `tud` away from node `n`.
///
/// Every node reached from `n` gets its not-yet-visited neighbours appended to
/// tree[node], in adjacency-list order. An explicit stack is used instead of
/// recursion because a path-shaped input with N = 3*10^5 nodes would otherwise
/// need 3*10^5 nested calls and can overflow the call stack.
///
/// Assumes `tud` describes a tree (as the problem guarantees). On a graph with
/// cycles, the spanning tree produced may differ from a recursive DFS.
/// Preconditions: tud, tree and vis cover every node id, and vis is 0 for every
/// node that should be attached.
void dir(int n) {
    std::vector<int> pending{n};
    vis[n] = 1;
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        for (const int nb : tud[cur]) {
            if (vis[nb] == 0) {
                vis[nb] = 1;  // mark on discovery so no node is queued twice
                tree[cur].push_back(nb);
                pending.push_back(nb);
            }
        }
    }
}

/// Bottom-up value of the rooted subtree at `n`, reading children from the
/// global `tree` (filled by dir()).
///
/// Let c_1 >= c_2 >= ... >= c_k be the values of a node's children, sorted in
/// descending order. The node's value is sum_{i=1..k} i * (c_i + 1), so a leaf
/// gets 0. solve() turns houses(1) into the answer as (houses(1) + 1) * X.
///
/// The subtree is listed breadth-first and processed in reverse, so every
/// child is finished before its parent. This avoids the recursion depth of N
/// that a path-shaped tree would need.
ll houses(int n) {
    vector<int> order{n};
    for (size_t head = 0; head < order.size(); ++head) {
        for (int child : tree[order[head]])
            order.push_back(child);
    }

    vector<ll> value(tree.size(), 0);
    vector<ll> weights;  // reused buffer for the current node's child values
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int v = *it;
        weights.clear();
        for (int child : tree[v])
            weights.push_back(value[child]);
        // The largest child value gets multiplier 1, the next 2, and so on.
        sort(weights.begin(), weights.end(), greater<ll>());
        ll k = 0;
        for (size_t i = 0; i < weights.size(); ++i) {
            const ll mult = static_cast<ll>(i) + 1;
            k += weights[i] * mult + mult;
        }
        value[v] = k;
    }
    return value[n];
}

/// Reads one test case (N, X and N-1 edges) from stdin and returns its answer.
///
/// Rebuilds the global adjacency list, roots it at node 1 via dir(), and
/// returns ((houses(1) + 1) mod 10^9+7) * X mod 10^9+7.
ll solve() {
    ll nodes, base;
    cin >> nodes >> base;

    // Fresh per-test-case containers, indexed 1..nodes.
    tud.assign(nodes + 1, vector<int>());
    tree.assign(nodes + 1, vector<int>());
    for (ll e = 1; e < nodes; ++e) {
        int from, to;
        cin >> from >> to;
        tud[from].push_back(to);
        tud[to].push_back(from);
    }
    vis.assign(nodes + 1, 0);
    dir(1);

    const ll unitTotal = (houses(1) + 1) % mod;
    return unitTotal * base % mod;
}

/// Entry point: reads the number of test cases and prints one answer per line.
int main() {
    std::ios_base::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int T{};
    std::cin >> T;
    for (int left = T; left != 0; --left) {
        std::cout << solve() << '\n';
    }
    return 0;
}

Java

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.*;

public class Main {
    static class FastReader {
        BufferedReader br;
        StringTokenizer st;

        public FastReader() {
            br = new BufferedReader(
                    new InputStreamReader(System.in));
        }

        String next() {
            while (st == null || !st.hasMoreElements()) {
                try {
                    st = new StringTokenizer(br.readLine());
                } catch (IOException e) {
                    e.printStackTrace();
                }
            }
            return st.nextToken();
        }

        int nextInt() {
            return Integer.parseInt(next());
        }

        long nextLong() {
            return Long.parseLong(next());
        }

        double nextDouble() {
            return Double.parseDouble(next());
        }

        String nextLine() {
            String str = "";
            try {
                str = br.readLine();
            } catch (IOException e) {
                e.printStackTrace();
            }
            return str;
        }
    }

    static long ans;
    static long MOD = (int) 1e9 + 7;

    public static void main(String[] args) throws Exception {
//        final Scanner cin = new Scanner(System.in);
        final FastReader cin = new FastReader();
        final StringBuilder cout = new StringBuilder();
//        final Main solver = new Main();
//        int[] nums = {8, 10, 2, 5, 9, 6, 3, 8, 2};
//        System.out.println(solver.coutPairs(nums, 6));
        int t = cin.nextInt();
        for (int tc = 1; tc <= t; ++tc) {
            int n = cin.nextInt();
            int x = cin.nextInt();
            List<List<Integer>> graph = new ArrayList<>(n - 1);
            for (int i = 0; i < n; ++i)
                graph.add(new ArrayList<>());
            boolean[] vis = new boolean[n];
            for (int i = 0; i < n - 1; ++i) {
                int u = cin.nextInt();
                int v = cin.nextInt();
                --u;
                --v;
                graph.get(u).add(v);
                graph.get(v).add(u);
            }
            ans = solve(graph, vis, 0) % MOD;
            ans *= x;
            cout.append(ans % MOD).append('\n');
        }
        System.out.println(cout);
    }

    static long solve(List<List<Integer>> g, boolean[] vis, int src) {
        if (!vis[src]) {
            vis[src] = true;
            List<Long> weights = new ArrayList<>();
            for (int child : g.get(src)) {
                if (!vis[child]) {
                    weights.add(solve(g, vis, child));
                }
            }
            Collections.sort(weights, Collections.reverseOrder());
            long curWt = 1;
            for (int i = 0; i < weights.size(); ++i)
                curWt += ((i + 1) * weights.get(i));
            return curWt;
        }
        throw new RuntimeException("ERR");
    }
}

Python

from sys import setrecursionlimit
# Headroom for recursive tree traversal: a path-shaped tree can be 3*10^5 nodes deep.
setrecursionlimit(3 * 10 ** 5)
# Answers are printed modulo 10^9 + 7.
MD = 1000000007
def getval(p,n):
	"""Return the weight of the subtree rooted at n, whose parent is p.

	The tree is read from the global adjacency list A (p is -1 for the root).
	A node whose list has exactly one entry is a leaf and has weight 1.
	Otherwise, its children's weights are sorted in descending order; the
	largest is multiplied by 1, the next by 2, and so on. The node's weight is
	1 plus that weighted sum, i.e. presumably the minimal subtree sum when n's
	own value is 1.

	Runs iteratively. A recursive version needs one interpreter frame per
	level, which can crash CPython's C stack on a path-shaped tree of 3*10^5
	nodes even with a raised recursion limit.
	"""
	# Pre-order listing: every (parent, node) pair precedes its children's pairs.
	order = []
	stack = [(p, n)]
	while stack:
		par, node = stack.pop()
		order.append((par, node))
		if len(A[node]) != 1:
			for x in A[node]:
				if x != par:
					stack.append((node, x))
	# Reverse pre-order finishes every child before its parent.
	val = {}
	for par, node in reversed(order):
		if len(A[node]) == 1:
			val[node] = 1
		else:
			L = sorted((val[x] for x in A[node] if x != par), reverse=True)
			v = 1
			for mult, w in enumerate(L, 1):
				v += mult*w
			val[node] = v
	return val[n]
# Driver: read every test case from stdin and print one answer per case.
# All tokens are read at once (line-by-line input is slow for 3*10^5 edges).
# print() with a single argument gives the same output on Python 2 and 3,
# so the script no longer depends on raw_input or the print statement.
import sys
_tokens = sys.stdin.read().split()
_pos = 0
t = int(_tokens[_pos])
_pos += 1
for i in range(t):
	N = int(_tokens[_pos])
	X = int(_tokens[_pos + 1])
	_pos += 2
	# A stays module-level on purpose: getval() reads it as a global.
	A = [[] for x in range(N+1)]
	# Sentinel: the root never has exactly one entry, so it is never taken for a leaf.
	A[1].append(-1)
	for k in range(N-1):
		u = int(_tokens[_pos])
		v = int(_tokens[_pos + 1])
		_pos += 2
		A[u].append(v)
		A[v].append(u)
	# endfor k
	v = getval(-1,1)
	r = v*X%MD
	print(r)
# endfor i

Tree House CodeChef Solution Tutorial

Check out more such posts here

Leave a Comment

Your email address will not be published. Required fields are marked *

x