Print all nodes at distance k from a given node
Given a binary tree, a target node in the binary tree, and an integer value k, the task is to find all the nodes at a distance k from the given target node. No parent pointers are available.
Note:
- You have to return the list in sorted order.
- The tree will not contain duplicate values.
Examples:
Input: target = 2, k = 2
Output: 3
Explanation: The only node at distance 2 from the given target node 2 is 3.
Input: target = 3, k = 1
Output: 1 6 7
Explanation: Nodes at a distance 1 from the given target node 3 are 1 , 6 & 7.
Table of Content
[Expected Approach – 1] Using Recursion – O(nlogn) Time and O(h) Space
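The idea is to traverse the binary tree recursively until the target node is found. First, collect every node in the target's own left and right subtrees that lies k edges below it. Then, as the recursion unwinds, handle each ancestor on the path from the target back to the root: if the ancestor is d edges away from the target, the ancestor itself is an answer when d = k; otherwise the answers are the nodes that lie (k – d – 1) edges below the ancestor's other child, i.e. the root of the subtree that does not contain the target.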
Below is the implementation of the above approach:
// C++ program to find nodes
// at distance k from target.
#include <bits/stdc++.h>
using namespace std;
// Binary tree node holding an integer value and links to two children.
class Node {
  public:
    int data;
    Node *left, *right;

    // Builds a node carrying x; both child links start out empty.
    Node(int x) : data(x), left(nullptr), right(nullptr) {}
};
// Appends to ans the value of every node that lies exactly dis edges
// below root (root itself when dis is 0).
//   root - subtree to search; may be nullptr
//   dis  - number of edges still to descend
//   ans  - output vector receiving the matching values, left to right
void findNodes(Node *root, int dis, vector<int> &ans) {
    // Nothing to do for an empty subtree. A negative distance can never
    // come back to 0 while descending, so stop immediately instead of
    // walking the whole subtree without ever producing a result.
    if (root == nullptr || dis < 0)
        return;

    if (dis == 0) {
        ans.push_back(root->data);
        return;
    }

    findNodes(root->left, dis - 1, ans);
    findNodes(root->right, dis - 1, ans);
}
// Searches the subtree rooted at root for the node whose value is target
// and appends to ans every node found to be k edges away from it.
//
// Returns -1 when target is not in this subtree. Otherwise returns
// 1 + (number of edges from root to the target), i.e. the distance from
// root's parent to the target, which is what the caller needs.
int kDistanceRecur(Node *root, int target, int k, vector<int> &ans) {
    if (root == nullptr)
        return -1;

    // Target reached: everything k edges below it is an answer.
    if (root->data == target) {
        findNodes(root, k, ans);
        return 1;
    }

    // Try the left child first, then the right one. The subtree that does
    // not contain the target is the "opposite" side to scan.
    Node *children[2] = {root->left, root->right};
    for (int side = 0; side < 2; ++side) {
        int steps = kDistanceRecur(children[side], target, k, ans);
        if (steps == -1)
            continue;

        if (steps == k) {
            // This ancestor is itself k edges from the target.
            ans.push_back(root->data);
        } else {
            // One edge is spent stepping into the opposite child.
            findNodes(children[1 - side], k - steps - 1, ans);
        }
        return steps + 1;
    }

    // Target is in neither subtree.
    return -1;
}
// Returns, in ascending order, the values of all nodes that are k edges
// away from the node whose value is target in the tree rooted at root.
vector<int> KDistanceNodes(Node *root, int target, int k) {
    vector<int> result;
    kDistanceRecur(root, target, k, result);

    // The problem statement asks for sorted output.
    sort(begin(result), end(result));
    return result;
}
// Prints the values of v on one line, each followed by a single space,
// and terminates the line.
// The vector is taken by const reference so that printing does not copy it.
void printList(const vector<int> &v) {
    for (int value : v) {
        cout << value << " ";
    }
    cout << endl;
}
int main() {
// Create a hard coded tree.
// 20
// / \
// 7 24
// / \
// 4 3
// /
// 1
Node *root = new Node(20);
root->left = new Node(7);
root->right = new Node(24);
root->left->left = new Node(4);
root->left->right = new Node(3);
root->left->right->left = new Node(1);
int target = 7, k = 2;
vector<int> ans = KDistanceNodes(root, target, k);
printList(ans);
return 0;
}
// Java program to find nodes
// at distance k from target.
import java.util.ArrayList;
import java.util.Collections;
/** Binary tree node holding an integer value and links to two children. */
class Node {
    int data;
    Node left;
    Node right;

    /**
     * Creates a node carrying {@code x}.
     * Both child links keep their default value {@code null}.
     */
    Node(int x) {
        this.data = x;
    }
}
/** Finds all nodes at distance k from a target node using plain recursion. */
class GfG {

    /**
     * Adds to {@code ans} the value of every node that lies exactly
     * {@code dis} edges below {@code root} ({@code root} itself when
     * {@code dis} is 0).
     *
     * @param root subtree to search; may be {@code null}
     * @param dis  number of edges still to descend
     * @param ans  list receiving the matching values, left to right
     */
    static void findNodes(Node root, int dis, ArrayList<Integer> ans) {
        // Nothing to do for an empty subtree. A negative distance can never
        // come back to 0 while descending, so stop right away instead of
        // walking the whole subtree without ever producing a result.
        if (root == null || dis < 0) {
            return;
        }

        if (dis == 0) {
            ans.add(root.data);
            return;
        }

        findNodes(root.left, dis - 1, ans);
        findNodes(root.right, dis - 1, ans);
    }

    /**
     * Searches the subtree rooted at {@code root} for the node whose value
     * is {@code target} and adds to {@code ans} every node found to be
     * {@code k} edges away from it.
     *
     * @param root   subtree to search; may be {@code null}
     * @param target value of the node distances are measured from
     * @param k      required distance in edges
     * @param ans    list receiving the matching values
     * @return -1 when the target is not in this subtree; otherwise
     *         1 + (number of edges from {@code root} to the target), i.e.
     *         the distance from {@code root}'s parent to the target
     */
    static int kDistanceRecur(Node root, int target, int k,
                              ArrayList<Integer> ans) {
        if (root == null) {
            return -1;
        }

        // Target reached: everything k edges below it is an answer.
        if (root.data == target) {
            findNodes(root, k, ans);
            return 1;
        }

        // Target in the left subtree: this node is `left` edges away from
        // it, so either this node is an answer or the answers sit in the
        // right subtree (one edge is spent stepping into the right child).
        int left = kDistanceRecur(root.left, target, k, ans);
        if (left != -1) {
            if (k - left == 0) {
                ans.add(root.data);
            } else {
                findNodes(root.right, k - left - 1, ans);
            }
            return left + 1;
        }

        // Mirror image: target in the right subtree, scan the left one.
        int right = kDistanceRecur(root.right, target, k, ans);
        if (right != -1) {
            if (k - right == 0) {
                ans.add(root.data);
            } else {
                findNodes(root.left, k - right - 1, ans);
            }
            return right + 1;
        }

        // Target is in neither subtree.
        return -1;
    }

    /**
     * Returns, in ascending order, the values of all nodes that are
     * {@code k} edges away from the node whose value is {@code target}.
     *
     * @param root   root of the tree; may be {@code null}
     * @param target value of the node distances are measured from
     * @param k      required distance in edges
     * @return sorted list of matching values; empty when nothing matches
     */
    static ArrayList<Integer> KDistanceNodes(Node root, int target, int k) {
        ArrayList<Integer> ans = new ArrayList<>();
        kDistanceRecur(root, target, k, ans);

        // The problem statement asks for sorted output.
        Collections.sort(ans);
        return ans;
    }

    /** Prints the values on one line, each followed by a single space. */
    static void printList(ArrayList<Integer> v) {
        for (int i : v) {
            System.out.print(i + " ");
        }
        System.out.println();
    }

    public static void main(String[] args) {
        // Hard-coded sample tree:
        //        20
        //       /  \
        //      7    24
        //     / \
        //    4   3
        //       /
        //      1
        Node root = new Node(20);
        root.left = new Node(7);
        root.right = new Node(24);
        root.left.left = new Node(4);
        root.left.right = new Node(3);
        root.left.right.left = new Node(1);

        int target = 7;
        int k = 2;
        ArrayList<Integer> ans = KDistanceNodes(root, target, k);
        printList(ans);
    }
}
# Python program to find nodes
# at distance k from target.
class Node:
    """Binary tree node holding a value and links to two children."""

    def __init__(self, x):
        # Store the payload; a fresh node has no children yet.
        self.data = x
        self.left = self.right = None
def findNodes(root, dis, ans):
    """Append to ans the data of every node exactly dis edges below root.

    root is the subtree to search (may be None), dis is the number of
    edges still to descend (root itself matches when dis is 0), and ans
    is the list that receives the matching values from left to right.
    """
    # Nothing to do for an empty subtree. A negative distance can never
    # come back to 0 while descending, so stop right away instead of
    # walking the whole subtree without ever producing a result.
    if root is None or dis < 0:
        return

    if dis == 0:
        ans.append(root.data)
        return

    findNodes(root.left, dis - 1, ans)
    findNodes(root.right, dis - 1, ans)
def kDistanceRecur(root, target, k, ans):
    """Locate target below root and collect the nodes k edges away from it.

    Returns -1 when target is not in this subtree. Otherwise returns
    1 + (number of edges from root to the target), i.e. the distance from
    root's parent to the target. Matching values are appended to ans.
    """
    if root is None:
        return -1

    # Target reached: everything k edges below it is an answer.
    if root.data == target:
        findNodes(root, k, ans)
        return 1

    # Try the left child first, then the right one; `far` is the subtree
    # on the opposite side of the one being searched.
    for near, far in ((root.left, root.right), (root.right, root.left)):
        steps = kDistanceRecur(near, target, k, ans)
        if steps == -1:
            continue

        if steps == k:
            # This ancestor is itself k edges from the target.
            ans.append(root.data)
        else:
            # One edge is spent stepping into the opposite child.
            findNodes(far, k - steps - 1, ans)
        return steps + 1

    # Target is in neither subtree.
    return -1
def KDistanceNodes(root, target, k):
    """Return the sorted values of all nodes k edges away from target."""
    found = []
    kDistanceRecur(root, target, k, found)

    # The problem statement asks for sorted output.
    return sorted(found)
def printList(v):
    """Print the items of v on one line, separated by single spaces."""
    print(*v, sep=" ")
if __name__ == "__main__":
    # Hard-coded sample tree:
    #        20
    #       /  \
    #      7    24
    #     / \
    #    4   3
    #       /
    #      1
    root = Node(20)
    root.left, root.right = Node(7), Node(24)
    root.left.left, root.left.right = Node(4), Node(3)
    root.left.right.left = Node(1)

    target, k = 7, 2

    # Nodes two edges away from 7, printed in sorted order.
    ans = KDistanceNodes(root, target, k)
    printList(ans)
// C# program to find nodes
// at distance k from target.
using System;
using System.Collections.Generic;
/// <summary>Binary tree node holding an integer value and two child links.</summary>
class Node {
    public int data;
    public Node left;
    public Node right;

    /// <summary>
    /// Creates a node carrying <paramref name="x"/>; both child links keep
    /// their default value null.
    /// </summary>
    public Node(int x) {
        this.data = x;
    }
}
/// <summary>Finds all nodes at distance k from a target node using plain recursion.</summary>
class GfG {

    /// <summary>
    /// Adds to <paramref name="ans"/> the value of every node that lies
    /// exactly <paramref name="dis"/> edges below <paramref name="root"/>
    /// (the root itself when the distance is 0).
    /// </summary>
    static void findNodes(Node root, int dis, List<int> ans) {
        // Nothing to do for an empty subtree. A negative distance can never
        // come back to 0 while descending, so stop right away instead of
        // walking the whole subtree without ever producing a result.
        if (root == null || dis < 0)
            return;

        if (dis == 0) {
            ans.Add(root.data);
            return;
        }

        findNodes(root.left, dis - 1, ans);
        findNodes(root.right, dis - 1, ans);
    }

    /// <summary>
    /// Searches the subtree rooted at <paramref name="root"/> for the node
    /// whose value is <paramref name="target"/> and adds to
    /// <paramref name="ans"/> every node found to be <paramref name="k"/>
    /// edges away from it.
    /// </summary>
    /// <returns>
    /// -1 when the target is not in this subtree; otherwise
    /// 1 + (number of edges from root to the target), i.e. the distance
    /// from root's parent to the target.
    /// </returns>
    static int kDistanceRecur(Node root, int target, int k,
                              List<int> ans) {
        if (root == null)
            return -1;

        // Target reached: everything k edges below it is an answer.
        if (root.data == target) {
            findNodes(root, k, ans);
            return 1;
        }

        // Target in the left subtree: this node is `left` edges away from
        // it, so either this node is an answer or the answers sit in the
        // right subtree (one edge is spent stepping into the right child).
        int left = kDistanceRecur(root.left, target, k, ans);
        if (left != -1) {
            if (k - left == 0)
                ans.Add(root.data);
            else
                findNodes(root.right, k - left - 1, ans);
            return left + 1;
        }

        // Mirror image: target in the right subtree, scan the left one.
        int right = kDistanceRecur(root.right, target, k, ans);
        if (right != -1) {
            if (k - right == 0)
                ans.Add(root.data);
            else
                findNodes(root.left, k - right - 1, ans);
            return right + 1;
        }

        // Target is in neither subtree.
        return -1;
    }

    /// <summary>
    /// Returns, in ascending order, the values of all nodes that are
    /// <paramref name="k"/> edges away from the node whose value is
    /// <paramref name="target"/>.
    /// </summary>
    static List<int> KDistanceNodes(Node root, int target,
                                    int k) {
        List<int> ans = new List<int>();
        kDistanceRecur(root, target, k, ans);

        // The problem statement asks for sorted output.
        ans.Sort();
        return ans;
    }

    /// <summary>Prints the values on one line, each followed by a single space.</summary>
    static void printList(List<int> v) {
        foreach (int i in v) {
            Console.Write(i + " ");
        }
        Console.WriteLine();
    }

    static void Main() {
        // Hard-coded sample tree:
        //        20
        //       /  \
        //      7    24
        //     / \
        //    4   3
        //       /
        //      1
        Node root = new Node(20);
        root.left = new Node(7);
        root.right = new Node(24);
        root.left.left = new Node(4);
        root.left.right = new Node(3);
        root.left.right.left = new Node(1);

        int target = 7, k = 2;
        List<int> ans = KDistanceNodes(root, target, k);
        printList(ans);
    }
}
// JavaScript program to find nodes
// at distance k from target.
// Binary tree node: `key` holds the value, `left`/`right` the children.
class Node {
    // Creates a node carrying x with no children attached yet.
    constructor(x) {
        Object.assign(this, { key: x, left: null, right: null });
    }
}
// Pushes onto ans the key of every node that lies exactly dis edges
// below root (root itself when dis is 0).
//   root - subtree to search; may be null
//   dis  - number of edges still to descend
//   ans  - output array receiving the matching keys, left to right
function findNodes(root, dis, ans) {
    // Nothing to do for an empty subtree. A negative distance can never
    // come back to 0 while descending, so stop right away instead of
    // walking the whole subtree without ever producing a result.
    if (root === null || dis < 0)
        return;

    if (dis === 0) {
        ans.push(root.key);
        return;
    }

    findNodes(root.left, dis - 1, ans);
    findNodes(root.right, dis - 1, ans);
}
// Searches the subtree rooted at root for the node whose key is target
// and pushes onto ans every node found to be k edges away from it.
//
// Returns -1 when target is not in this subtree. Otherwise returns
// 1 + (number of edges from root to the target), i.e. the distance from
// root's parent to the target.
function kDistanceRecur(root, target, k, ans) {
    if (root === null) {
        return -1;
    }

    // Target reached: everything k edges below it is an answer.
    if (root.key === target) {
        findNodes(root, k, ans);
        return 1;
    }

    // Try the left child first, then the right one; `far` is the subtree
    // on the opposite side of the one being searched.
    const sides = [
        [root.left, root.right],
        [root.right, root.left],
    ];
    for (const [near, far] of sides) {
        const steps = kDistanceRecur(near, target, k, ans);
        if (steps === -1) {
            continue;
        }

        if (k - steps === 0) {
            // This ancestor is itself k edges from the target.
            ans.push(root.key);
        } else {
            // One edge is spent stepping into the opposite child.
            findNodes(far, k - steps - 1, ans);
        }
        return steps + 1;
    }

    // Target is in neither subtree.
    return -1;
}
// Returns, in ascending numeric order, the keys of all nodes that are
// k edges away from the node whose key is target.
function KDistanceNodes(root, target, k) {
    const found = [];
    kDistanceRecur(root, target, k, found);

    // Numeric comparator: the default sort would compare as strings.
    found.sort((lhs, rhs) => lhs - rhs);
    return found;
}
// Logs the items of v on one line, separated by single spaces.
function printList(v) {
    const line = v.join(" ");
    console.log(line);
}
// Hard-coded sample tree:
//        20
//       /  \
//      7    24
//     / \
//    4   3
//       /
//      1
let root = new Node(20);
root.left = new Node(7);
root.right = new Node(24);

let seven = root.left;
seven.left = new Node(4);
seven.right = new Node(3);
seven.right.left = new Node(1);

let target = 7;
let k = 2;

// Nodes two edges away from 7, printed in sorted order.
let ans = KDistanceNodes(root, target, k);
printList(ans);
Output
1 24
Time Complexity: O(nlogn) in the worst case. The traversal itself is O(n); sorting the result is the dominant cost.
Auxiliary Space: O(h), where h is the height of the tree.
[Expected Approach – 2] Using DFS with Parent Pointers – O(nlogn) Time and O(n) Space:
The idea is to recursively find the target node and map each node to its parent node. Then, starting from the target node, apply depth first search (DFS) to find all the nodes at distance k from the target node.
Below is the implementation of the above approach:
// C++ program to find nodes
// at distance k from target.
#include <bits/stdc++.h>
using namespace std;
// Binary tree node holding an integer value and links to two children.
class Node {
  public:
    int data;
    Node *left, *right;

    // Builds a node carrying x; both child links start out empty.
    Node(int x) : data(x), left(nullptr), right(nullptr) {}
};
// Walks the whole subtree below root, recording child -> parent links in
// `parent`, and returns the node whose value equals target (nullptr when
// no such node exists in this subtree).
// root must not be nullptr; the caller checks that before calling.
Node *findTarNode(Node *root, int target, unordered_map<Node*, Node*> &parent) {
    Node *children[2] = {root->left, root->right};
    Node *found[2] = {nullptr, nullptr};

    // Both subtrees are always visited, left first, so that every node
    // below root gets its parent recorded.
    for (int side = 0; side < 2; ++side) {
        if (children[side] == nullptr)
            continue;
        parent[children[side]] = root;
        found[side] = findTarNode(children[side], target, parent);
    }

    // The current node wins over anything reported by its subtrees.
    if (root->data == target)
        return root;

    // Otherwise prefer the left result and fall back to the right one.
    return found[0] != nullptr ? found[0] : found[1];
}
// Depth-first walk outward from root over child and parent links,
// appending to ans the value of every node reached after exactly k steps.
//   root   - node currently visited; may be nullptr
//   prev   - node we arrived from; it is never stepped back into
//   k      - number of steps still to take
//   parent - child -> parent links; a node without an entry is treated as
//            having no parent
//   ans    - output vector receiving the matching values
void dfs(Node *root, Node *prev, int k,
         unordered_map<Node *, Node *> &parent, vector<int> &ans) {
    if (root == nullptr)
        return;

    // Exactly k steps taken: this node is an answer, go no further.
    if (k == 0) {
        ans.push_back(root->data);
        return;
    }

    if (root->left != prev)
        dfs(root->left, root, k - 1, parent, ans);
    if (root->right != prev)
        dfs(root->right, root, k - 1, parent, ans);

    // Look the parent up with find() rather than operator[]: operator[]
    // would insert an empty entry for a node that has none (the tree root)
    // and the original performed the lookup twice.
    auto it = parent.find(root);
    Node *up = (it != parent.end()) ? it->second : nullptr;
    if (up != prev)
        dfs(up, root, k - 1, parent, ans);
}
// Returns, in ascending order, the values of all nodes that are k edges
// away from the node whose value is target. An empty tree yields an
// empty result.
vector<int> KDistanceNodes(Node *root, int target, int k) {
    vector<int> result;

    if (root != nullptr) {
        // Record every child -> parent link and locate the target.
        unordered_map<Node *, Node *> parentOf;
        Node *start = findTarNode(root, target, parentOf);

        // Spread out from the target in all three directions.
        dfs(start, nullptr, k, parentOf, result);

        // The problem statement asks for sorted output.
        sort(result.begin(), result.end());
    }

    return result;
}
// Prints the values of v on one line, each followed by a single space,
// and terminates the line.
// The vector is taken by const reference so that printing does not copy it.
void printList(const vector<int> &v) {
    for (int value : v) {
        cout << value << " ";
    }
    cout << endl;
}
int main() {
// Create a hard coded tree.
// 20
// / \
// 7 24
// / \
// 4 3
// /
// 1
Node *root = new Node(20);
root->left = new Node(7);
root->right = new Node(24);
root->left->left = new Node(4);
root->left->right = new Node(3);
root->left->right->left = new Node(1);
int target = 7, k = 2;
vector<int> ans = KDistanceNodes(root, target, k);
printList(ans);
return 0;
}
// Java program to find nodes
// at distance k from target.
import java.util.*;
/** Binary tree node holding an integer value and links to two children. */
class Node {
    int data;
    Node left;
    Node right;

    /**
     * Creates a node carrying {@code x}.
     * Both child links keep their default value {@code null}.
     */
    Node(int x) {
        this.data = x;
    }
}
/** Finds all nodes at distance k from a target node via a child-to-parent map. */
class GfG {

    /**
     * Walks the whole subtree below {@code root}, recording child-to-parent
     * links in {@code parent}, and returns the node whose value equals
     * {@code target}.
     *
     * @param root   subtree to walk; must not be {@code null}
     * @param target value to look for
     * @param parent map that receives one entry per non-root node visited
     * @return the matching node, or {@code null} when none exists here
     */
    static Node findTarNode(Node root, int target,
                            Map<Node, Node> parent) {
        Node fromLeft = null;
        Node fromRight = null;

        // Both subtrees are always visited so that every node below root
        // ends up with its parent recorded.
        if (root.left != null) {
            parent.put(root.left, root);
            fromLeft = findTarNode(root.left, target, parent);
        }
        if (root.right != null) {
            parent.put(root.right, root);
            fromRight = findTarNode(root.right, target, parent);
        }

        // The current node wins over anything reported by its subtrees.
        if (root.data == target) {
            return root;
        }

        // Otherwise prefer the left result and fall back to the right one.
        return (fromLeft != null) ? fromLeft : fromRight;
    }

    /**
     * Depth-first walk outward from {@code root} over child and parent
     * links, adding to {@code ans} the value of every node reached after
     * exactly {@code k} steps.
     *
     * @param root   node currently visited; may be {@code null}
     * @param prev   node we arrived from; it is never stepped back into
     * @param k      number of steps still to take
     * @param parent child-to-parent links
     * @param ans    list receiving the matching values
     */
    static void dfs(Node root, Node prev, int k,
                    Map<Node, Node> parent,
                    ArrayList<Integer> ans) {
        if (root == null) {
            return;
        }

        // Exactly k steps taken: this node is an answer, go no further.
        if (k == 0) {
            ans.add(root.data);
            return;
        }

        int remaining = k - 1;
        if (root.left != prev) {
            dfs(root.left, root, remaining, parent, ans);
        }
        if (root.right != prev) {
            dfs(root.right, root, remaining, parent, ans);
        }

        // The tree root has no entry, so `up` is null for it.
        Node up = parent.get(root);
        if (up != prev) {
            dfs(up, root, remaining, parent, ans);
        }
    }

    /**
     * Returns, in ascending order, the values of all nodes that are
     * {@code k} edges away from the node whose value is {@code target}.
     *
     * @param root   root of the tree; {@code null} yields an empty list
     * @param target value of the node distances are measured from
     * @param k      required distance in edges
     * @return sorted list of matching values
     */
    static ArrayList<Integer> KDistanceNodes(Node root, int target, int k) {
        ArrayList<Integer> result = new ArrayList<>();
        if (root == null) {
            return result;
        }

        // Record every child-to-parent link and locate the target.
        Map<Node, Node> parentOf = new HashMap<>();
        Node start = findTarNode(root, target, parentOf);

        // Spread out from the target in all three directions.
        dfs(start, null, k, parentOf, result);

        // The problem statement asks for sorted output.
        Collections.sort(result);
        return result;
    }

    /** Prints the values on one line, each followed by a single space. */
    static void printList(ArrayList<Integer> v) {
        StringBuilder line = new StringBuilder();
        for (int value : v) {
            line.append(value).append(" ");
        }
        System.out.println(line);
    }

    public static void main(String[] args) {
        // Hard-coded sample tree:
        //        20
        //       /  \
        //      7    24
        //     / \
        //    4   3
        //       /
        //      1
        Node root = new Node(20);
        Node seven = new Node(7);
        Node three = new Node(3);

        root.left = seven;
        root.right = new Node(24);
        seven.left = new Node(4);
        seven.right = three;
        three.left = new Node(1);

        int target = 7;
        int k = 2;

        // Nodes two edges away from 7, printed in sorted order.
        printList(KDistanceNodes(root, target, k));
    }
}
# Python program to find nodes
# at distance k from target.
class Node:
    """Binary tree node holding a value and links to two children."""

    def __init__(self, x):
        # Store the payload; a fresh node has no children yet.
        self.data = x
        self.left = self.right = None
def findTarNode(root, target, parent):
    """Record child -> parent links below root and return the target node.

    Every node below root gets an entry in the parent dict. Returns the
    node whose data equals target, or None when there is no such node in
    this subtree. root itself must not be None.
    """
    hits = []

    # Both subtrees are always visited, left first, so that the parent
    # dict ends up complete.
    for child in (root.left, root.right):
        if child is not None:
            parent[child] = root
            hits.append(findTarNode(child, target, parent))
        else:
            hits.append(None)

    # The current node wins over anything reported by its subtrees.
    if root.data == target:
        return root

    # Otherwise prefer the left result and fall back to the right one.
    return hits[0] or hits[1]
def dfs(root, prev, k, parent, ans):
    """Collect the data of every node exactly k steps away from root.

    The walk moves over child and parent links and never steps back into
    prev, the node it arrived from. Matching values are appended to ans.
    """
    if not root:
        return

    # Exactly k steps taken: this node is an answer, go no further.
    if k == 0:
        ans.append(root.data)
        return

    remaining = k - 1

    # Downwards: left child first, then right child.
    for child in (root.left, root.right):
        if child != prev:
            dfs(child, root, remaining, parent, ans)

    # Upwards, unless that is where we came from.
    if parent.get(root) != prev:
        dfs(parent[root], root, remaining, parent, ans)
def KDistanceNodes(root, target, k):
    """Return the sorted values of all nodes k edges away from target.

    An empty tree yields an empty list.
    """
    if not root:
        return []

    # Record every child -> parent link and locate the target; the tree
    # root is registered explicitly as having no parent.
    parent = {root: None}
    tar = findTarNode(root, target, parent)

    # Spread out from the target in all three directions.
    found = []
    dfs(tar, None, k, parent, found)

    # The problem statement asks for sorted output.
    found.sort()
    return found
def printList(v):
    """Print the items of v on one line, separated by single spaces."""
    print(*v, sep=" ")
if __name__ == "__main__":
    # Hard-coded sample tree:
    #        20
    #       /  \
    #      7    24
    #     / \
    #    4   3
    #       /
    #      1
    root = Node(20)
    root.left, root.right = Node(7), Node(24)
    root.left.left, root.left.right = Node(4), Node(3)
    root.left.right.left = Node(1)

    target, k = 7, 2

    # Nodes two edges away from 7, printed in sorted order.
    ans = KDistanceNodes(root, target, k)
    printList(ans)
// C# program to find nodes
// at distance k from target.
using System;
using System.Collections.Generic;
/// <summary>Binary tree node holding an integer value and two child links.</summary>
class Node {
    public int data;
    public Node left;
    public Node right;

    /// <summary>
    /// Creates a node carrying <paramref name="x"/>; both child links keep
    /// their default value null.
    /// </summary>
    public Node(int x) {
        this.data = x;
    }
}
/// <summary>Finds all nodes at distance k from a target node via a child-to-parent map.</summary>
class GfG {

    /// <summary>
    /// Walks the whole subtree below <paramref name="root"/>, recording
    /// child-to-parent links in <paramref name="parent"/>, and returns the
    /// node whose value equals <paramref name="target"/> (null when there
    /// is none). The root argument must not be null.
    /// </summary>
    static Node findTarNode(Node root, int target,
                            Dictionary<Node, Node> parent) {
        Node fromLeft = null;
        Node fromRight = null;

        // Both subtrees are always visited so that every node below root
        // ends up with its parent recorded.
        if (root.left != null) {
            parent[root.left] = root;
            fromLeft = findTarNode(root.left, target, parent);
        }
        if (root.right != null) {
            parent[root.right] = root;
            fromRight = findTarNode(root.right, target, parent);
        }

        // The current node wins over anything reported by its subtrees.
        if (root.data == target) {
            return root;
        }

        // Otherwise prefer the left result and fall back to the right one.
        return (fromLeft != null) ? fromLeft : fromRight;
    }

    /// <summary>
    /// Depth-first walk outward from <paramref name="root"/> over child and
    /// parent links, adding to <paramref name="ans"/> the value of every
    /// node reached after exactly <paramref name="k"/> steps. The walk
    /// never steps back into <paramref name="prev"/>.
    /// </summary>
    static void dfs(Node root, Node prev, int k,
                    Dictionary<Node, Node> parent,
                    List<int> ans) {
        if (root == null) {
            return;
        }

        // Exactly k steps taken: this node is an answer, go no further.
        if (k == 0) {
            ans.Add(root.data);
            return;
        }

        int remaining = k - 1;
        if (root.left != prev) {
            dfs(root.left, root, remaining, parent, ans);
        }
        if (root.right != prev) {
            dfs(root.right, root, remaining, parent, ans);
        }

        // The indexer requires an entry for every visited node; the caller
        // registers the tree root with a null parent for that reason.
        Node up = parent[root];
        if (up != prev) {
            dfs(up, root, remaining, parent, ans);
        }
    }

    /// <summary>
    /// Returns, in ascending order, the values of all nodes that are
    /// <paramref name="k"/> edges away from the node whose value is
    /// <paramref name="target"/>. A null tree yields an empty list.
    /// </summary>
    static List<int> KDistanceNodes(Node root, int target, int k) {
        List<int> result = new List<int>();
        if (root == null) {
            return result;
        }

        // Record every child-to-parent link and locate the target; the
        // tree root is registered explicitly as having no parent.
        Dictionary<Node, Node> parentOf = new Dictionary<Node, Node>();
        parentOf[root] = null;
        Node start = findTarNode(root, target, parentOf);

        // Spread out from the target in all three directions.
        dfs(start, null, k, parentOf, result);

        // The problem statement asks for sorted output.
        result.Sort();
        return result;
    }

    /// <summary>Prints the values on one line, each followed by a single space.</summary>
    static void printList(List<int> v) {
        for (int idx = 0; idx < v.Count; idx++) {
            Console.Write(v[idx] + " ");
        }
        Console.WriteLine();
    }

    static void Main(string[] args) {
        // Hard-coded sample tree:
        //        20
        //       /  \
        //      7    24
        //     / \
        //    4   3
        //       /
        //      1
        Node root = new Node(20);
        Node seven = new Node(7);
        Node three = new Node(3);

        root.left = seven;
        root.right = new Node(24);
        seven.left = new Node(4);
        seven.right = three;
        three.left = new Node(1);

        int target = 7;
        int k = 2;

        // Nodes two edges away from 7, printed in sorted order.
        printList(KDistanceNodes(root, target, k));
    }
}
// JavaScript program to find nodes
// at distance k from target.
// Binary tree node: `key` holds the value, `left`/`right` the children.
class Node {
    // Creates a node carrying x with no children attached yet.
    constructor(x) {
        Object.assign(this, { key: x, left: null, right: null });
    }
}
// Walks the whole subtree below root, recording child -> parent links in
// the `parent` Map, and returns the node whose key equals target (null
// when no such node exists in this subtree). root must not be null.
function findTarNode(root, target, parent) {
    // Both subtrees are always visited, left first, so that every node
    // below root ends up with its parent recorded.
    const hits = [root.left, root.right].map((child) => {
        if (child === null) {
            return null;
        }
        parent.set(child, root);
        return findTarNode(child, target, parent);
    });

    // The current node wins over anything reported by its subtrees.
    if (root.key === target) {
        return root;
    }

    // Otherwise prefer the left result and fall back to the right one.
    return hits[0] !== null ? hits[0] : hits[1];
}
// Depth-first walk outward from root over child and parent links,
// pushing onto ans the key of every node reached after exactly k steps.
// The walk never steps back into prev, the node it arrived from.
function dfs(root, prev, k, parent, ans) {
    if (root === null) {
        return;
    }

    // Exactly k steps taken: this node is an answer, go no further.
    if (k === 0) {
        ans.push(root.key);
        return;
    }

    const remaining = k - 1;

    // Downwards: left child first, then right child.
    for (const child of [root.left, root.right]) {
        if (child !== prev) {
            dfs(child, root, remaining, parent, ans);
        }
    }

    // Upwards, unless that is where we came from.
    const up = parent.get(root);
    if (up !== prev) {
        dfs(up, root, remaining, parent, ans);
    }
}
// Returns, in ascending numeric order, the keys of all nodes that are
// k edges away from the node whose key is target. A null tree yields an
// empty array.
function KDistanceNodes(root, target, k) {
    const found = [];
    if (root === null) {
        return found;
    }

    // Record every child -> parent link and locate the target; the tree
    // root is registered explicitly as having no parent.
    const parentOf = new Map();
    parentOf.set(root, null);
    const start = findTarNode(root, target, parentOf);

    // Spread out from the target in all three directions.
    dfs(start, null, k, parentOf, found);

    // Numeric comparator: the default sort would compare as strings.
    found.sort((lhs, rhs) => lhs - rhs);
    return found;
}
// Logs the items of v on one line, separated by single spaces.
function printList(v) {
    const line = v.join(" ");
    console.log(line);
}
// Hard-coded sample tree:
//        20
//       /  \
//      7    24
//     / \
//    4   3
//       /
//      1
let root = new Node(20);
root.left = new Node(7);
root.right = new Node(24);

let seven = root.left;
seven.left = new Node(4);
seven.right = new Node(3);
seven.right.left = new Node(1);

let target = 7;
let k = 2;

// Nodes two edges away from 7, printed in sorted order.
let ans = KDistanceNodes(root, target, k);
printList(ans);
Output
1 24
Time Complexity: O(nlogn), for sorting the result.
Auxiliary Space: O(n), for the map that stores the parent of every node (the recursion stack adds O(h), where h is the height of the tree).
Related article: