Open In App

Print all nodes at distance k from a given node

Last Updated : 05 Oct, 2024
Summarize
Comments
Improve
Suggest changes
Like Article
Like
Share
Report
News Follow

Given a binary tree, a target node in the binary tree, and an integer value k, the task is to find all the nodes at a distance k from the given target node. No parent pointers are available.

Note:

  • You have to return the list in sorted order.
  • The tree will not contain duplicate values.

Examples:

Input: target = 2, k = 2

Iterative-Postorder-Traversal

Output: 3
Explanation: The only node at distance 2 from the target node 2 is 3.

Input: target = 3, k = 1

Iterative-Postorder-Traversal-3

Output: 1 6 7
Explanation: The nodes at distance 1 from the target node 3 are 1, 6 and 7.

[Expected Approach – 1] Using Recursion – O(nlogn) Time and O(h) Space

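The idea is to traverse the binary tree recursively to locate the target node. Once it is found, collect every node in the target's own subtree that lies exactly k levels below it. Then visit each ancestor on the path back to the root, at distance d from the target. If d equals k, include the ancestor itself. Otherwise, collect the nodes in the ancestor's other subtree (the one that does not contain the target) that lie k − d − 1 levels below that child, since each of them is d + 1 + depth edges away from the target.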

Below is the implementation of the above approach:

C++
// C++ program to find nodes
// at distance k from target.
#include <bits/stdc++.h>
using namespace std;

// A binary-tree node: an int payload plus links to the two children.
class Node {
  public:
    int data;
    Node *left, *right;

    // Builds a leaf node holding x (both child links start out null).
    Node(int x) : data(x), left(nullptr), right(nullptr) {}
};

// Appends to ans the data of every node in the subtree rooted at root
// that lies exactly dis edges below root.
//
// A negative dis means no node in the subtree can qualify, so the
// subtree is not explored at all. kDistanceRecur can pass a negative
// distance when an ancestor is already more than k edges away from the
// target; without this guard the whole subtree would be walked for
// nothing.
void findNodes(Node *root, int dis, vector<int> &ans) {

    // base case: empty subtree, or the required depth was overshot
    if (root == nullptr || dis < 0)
        return;

    // This node is exactly at the requested depth; its descendants
    // are deeper, so there is no need to go further.
    if (dis == 0) {
        ans.push_back(root->data);
        return;
    }

    findNodes(root->left, dis - 1, ans);
    findNodes(root->right, dis - 1, ans);
}

// Searches the tree rooted at root for the node whose value is target,
// and appends to ans every node that is exactly k edges away from it.
//
// Returns (distance from root to the target) + 1, which equals the
// distance from root's parent to the target. Returns -1 if the target
// is not in this subtree.
int kDistanceRecur(Node *root, int target, int k, vector<int> &ans) {

    // base case
    if (root == nullptr)
        return -1;

    // If current node is target, collect the nodes k levels below it.
    if (root->data == target) {
        findNodes(root, k, ans);
        return 1;
    }

    // left (when != -1) is the distance from root to the target.
    int left = kDistanceRecur(root->left, target, k, ans);
    if (left != -1) {
        if (left == k) {
            // root itself is exactly k edges from the target.
            ans.push_back(root->data);
        } else if (left < k) {
            // A node at depth t in the right subtree is left + 1 + t
            // edges from the target. Only search while k is reachable.
            findNodes(root->right, k - left - 1, ans);
        }
        return left + 1;
    }

    // right (when != -1) is the distance from root to the target.
    int right = kDistanceRecur(root->right, target, k, ans);
    if (right != -1) {
        if (right == k) {
            ans.push_back(root->data);
        } else if (right < k) {
            // Mirror case: look for the remaining distance in the left subtree.
            findNodes(root->left, k - right - 1, ans);
        }
        return right + 1;
    }

    // Target is not in this subtree.
    return -1;
}

// Returns, in ascending order, the values of all nodes that are exactly
// k edges away from the node holding target. The result is empty if the
// target is not present.
vector<int> KDistanceNodes(Node *root, int target, int k) {
    vector<int> nodesAtK;

    // The returned distance is only needed inside the recursion.
    (void)kDistanceRecur(root, target, k, nodesAtK);

    // The problem asks for sorted output; collection order is not by value.
    std::sort(nodesAtK.begin(), nodesAtK.end());
    return nodesAtK;
}

// Prints the values in v on one line, each followed by a single space,
// then ends the line. Takes v by const reference so the vector is not
// copied just to be printed.
void printList(const vector<int> &v) {
    for (int value : v) {
        cout << value << ' ';
    }
    cout << '\n';
}

int main() {

    // Create a hard coded tree.
    //         20
    //       /    \
    //      7      24
    //    /   \
    //   4     3
    //        /
    //       1
    Node *root = new Node(20);
    root->left = new Node(7);
    root->right = new Node(24);
    root->left->left = new Node(4);
    root->left->right = new Node(3);
    root->left->right->left = new Node(1);

    int target = 7, k = 2;
    vector<int> ans = KDistanceNodes(root, target, k);

    printList(ans);
    return 0;
}
Java Python C# JavaScript

Output
1 24 

Time Complexity: O(nlogn), for sorting the result.
Auxiliary Space: O(h), where h is the height of the tree.

[Expected Approach – 2] Using DFS with Parent Pointers – O(nlogn) Time and O(n) Space:

The idea is to recursively find the target node and map each node to its parent node. Then, starting from the target node, apply depth first search (DFS) to find all the nodes at distance k from the target node.

Below is the implementation of the above approach:

C++
// C++ program to find nodes
// at distance k from target.
#include <bits/stdc++.h>
using namespace std;

// Binary-tree node with an int value and two (initially empty) child links.
class Node {
  public:
    int data;
    Node *left = nullptr;
    Node *right = nullptr;

    // Creates a childless node storing x.
    Node(int x) : data(x) {}
};

// Records in parent the parent of every non-root node in the tree rooted
// at root (the root itself gets no entry). Returns the node whose value
// equals target, or nullptr if there is no such node. An empty tree
// (root == nullptr) returns nullptr.
//
// The whole tree is always traversed, even after the target is found.
Node *findTarNode(Node *root, int target, unordered_map<Node*, Node*> &parent) {

    // Empty tree: nothing to map and no target to find.
    if (root == nullptr)
        return nullptr;

    Node *left = nullptr, *right = nullptr;

    // Map the left child to root and search the left subtree.
    if (root->left != nullptr) {
        parent[root->left] = root;
        left = findTarNode(root->left, target, parent);
    }

    // Map the right child to root and search the right subtree.
    if (root->right != nullptr) {
        parent[root->right] = root;
        right = findTarNode(root->right, target, parent);
    }

    // Prefer root itself, then a match in the left subtree, then the right.
    if (root->data == target)
        return root;
    if (left != nullptr)
        return left;
    return right;
}

// depth first function to find nodes k 
// distance away.
void dfs(Node *root, Node *prev, int k, 
         unordered_map<Node *, Node *> &parent, vector<int> &ans) {

    // base case
    if (root == nullptr)
        return;
    
    // If current node is kth 
    // distance away.
    if (k == 0) {
        ans.push_back(root->data);
        return;
    }

    if (root->left != prev)
        dfs(root->left, root, k - 1, parent, ans);

    if (root->right != prev)
        dfs(root->right, root, k - 1, parent, ans);

    if (parent[root] != prev)
        dfs(parent[root], root, k - 1, parent, ans);
}

// Returns, in ascending order, the values of all nodes exactly k edges
// away from the node holding target. The result is empty when the tree
// is empty or the target is absent.
vector<int> KDistanceNodes(Node *root, int target, int k) {
    vector<int> result;

    if (root != nullptr) {
        // Parent links let the DFS move upward as well as downward.
        unordered_map<Node *, Node *> parentOf;
        Node *start = findTarNode(root, target, parentOf);

        dfs(start, nullptr, k, parentOf, result);

        // The problem asks for the answer in sorted order.
        std::sort(result.begin(), result.end());
    }

    return result;
}

// Prints the values in v on one line, each followed by a single space,
// then ends the line. Takes v by const reference so the vector is not
// copied just to be printed.
void printList(const vector<int> &v) {
    for (int value : v) {
        cout << value << ' ';
    }
    cout << '\n';
}

int main() {

    // Create a hard coded tree.
    //         20
    //       /    \
    //      7      24
    //    /   \
    //   4     3
    //        /
    //       1
    Node *root = new Node(20);
    root->left = new Node(7);
    root->right = new Node(24);
    root->left->left = new Node(4);
    root->left->right = new Node(3);
    root->left->right->left = new Node(1);

    int target = 7, k = 2;
    vector<int> ans = KDistanceNodes(root, target, k);

    printList(ans);
    return 0;
}
Java Python C# JavaScript

Output
1 24 

Time Complexity: O(nlogn), for sorting the result.
Auxiliary Space: O(n), for the node-to-parent map, plus O(h) for the recursion stack, where h is the height of the tree.

Related article:



Next Article

Similar Reads

three90RightbarBannerImg