Skip to content

Commit

Permalink
fix: move semantic tokens to visitor
Browse files Browse the repository at this point in the history
  • Loading branch information
TheNuclearNexus committed Sep 27, 2024
1 parent 14327d4 commit 4fbeced
Showing 1 changed file with 128 additions and 140 deletions.
268 changes: 128 additions & 140 deletions language_server/server/features/semantics.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from dataclasses import dataclass, field
import logging
from beet import Context
from beet.core.utils import required_field
from lsprotocol import types as lsp
from mecha import AstCommand, AstItemSlot, AstNode, AstResourceLocation
from mecha import AstCommand, AstItemSlot, AstNode, AstResourceLocation, Visitor, rule, Mecha
from mecha.contrib.nested_location import AstNestedLocation
from bolt import AstAttribute, AstPrelude
from bolt import (
Expand Down Expand Up @@ -67,167 +70,152 @@ def node_to_token(
node: AstNode,
type: int,
modifier: int,
prev_node: tuple[AstNode, tuple[int, ...]] | None,
prev_node: AstNode | None,
) -> tuple[int, ...]:
line_offset = node.location.lineno - 1
col_offset = node.location.colno - 1
length = node.end_location.pos - node.location.pos

if prev_node is not None:
line_offset -= prev_node[0].location.lineno - 1
line_offset -= prev_node.location.lineno - 1

if line_offset == 0:
col_offset -= prev_node[0].location.colno - 1
col_offset -= prev_node.location.colno - 1

token = (line_offset, col_offset, length, type, modifier)
return token

@dataclass
class SemanticTokenCollector(Visitor):
nodes: list[tuple[AstNode, int, int]] = field(default_factory=list)
ctx: Context = required_field()


def parse_command(nodes: list[tuple[AstNode, int, int]], node: AstCommand):
match node.identifier:
case "import:module":
modules: list[AstResourceLocation] = node.arguments

for m in modules:
nodes.append(
(m, TOKEN_TYPES["class" if m.namespace == None else "function"], 0)
)
case "import:module:as:alias":
module: AstResourceLocation = node.arguments[0]
item: AstImportedItem = node.arguments[1]

type = TOKEN_TYPES["class" if module.namespace == None else "function"]

nodes.append((module, type, 0))
nodes.append((item, type, 0))


def walk(root: AstNode):
nodes: list[tuple[AstNode, int, int]] = []

for node in root.walk():
get_token_type(nodes, node)

tokens: list[tuple[int, ...]] = []

for i in range(len(nodes)):
prev_node = None
if i > 0:
prev_node = (nodes[i - 1][0], tokens[i - 1])

node, type, modifier = nodes[i]
tokens.append(node_to_token(node, type, modifier, prev_node))

logging.debug(tokens)
return list(sum(tokens, ()))


def get_token_type(nodes: list[tuple[AstNode, int, int]], node: AstNode):
match node:
case AstFromImport() as from_import:
handle_from_import(nodes, from_import)

case AstCommand() as command:
parse_command(nodes, command)

case AstAssignment() as assignment:
handle_assignment(nodes, assignment)

case AstCall() as call:

if isinstance(call.value, AstAttribute):
attribute = call.value
offset = len(attribute.name)

call_start = SourceLocation(
attribute.end_location.pos - offset,
attribute.end_location.lineno,
attribute.end_location.colno - offset,
)

function = AstNode(
call_start,
attribute.end_location,
)

base = AstNode(
attribute.location,
call_start
)

nodes.append((base, TOKEN_TYPES["variable"], 0))
nodes.append((function, TOKEN_TYPES["function"], 0))
else:
nodes.append((call.value, TOKEN_TYPES["function"], 0))

case AstFunctionSignature() as signature:
handle_function_sig(nodes, signature)

case AstResourceLocation() as nested_location:
nodes.append((nested_location, TOKEN_TYPES["function"], 0))

case AstItemSlot():
nodes.append((node, TOKEN_TYPES["variable"], TOKEN_MODIFIERS["readonly"]))


def handle_function_sig(
nodes: list[tuple[AstNode, int, int]], signature: AstFunctionSignature
):
location: SourceLocation = signature.location
node = AstNode(
location=location,
end_location=SourceLocation(
location.pos + len(signature.name),
location.lineno,
location.colno + len(signature.name),
),
)
nodes.append((node, TOKEN_TYPES["function"], 0))

@rule(AstCommand)
def command(self, node: AstCommand):
match node.identifier:
case "import:module":
modules: list[AstResourceLocation] = node.arguments

for m in modules:
self.nodes.append(
(m, TOKEN_TYPES["class" if m.namespace == None else "function"], 0)
)
case "import:module:as:alias":
module: AstResourceLocation = node.arguments[0]
item: AstImportedItem = node.arguments[1]

type = TOKEN_TYPES["class" if module.namespace == None else "function"]

self.nodes.append((module, type, 0))
self.nodes.append((item, type, 0))


@rule(AstFromImport)
def from_import(self, from_import: AstFromImport):
if isinstance(from_import, AstPrelude):
return

logging.debug(from_import)

location: AstResourceLocation = from_import.arguments[0]
imports: tuple[AstImportedItem] = from_import.arguments[1:]

self.nodes.append(
(
location,
TOKEN_TYPES["class" if location.namespace == None else "function"],
0,
)
)

def handle_assignment(nodes: list[tuple[AstNode, int, int]], assignment: AstAssignment):
operator = assignment.operator
import_offset = len("import")
self.nodes.append((
AstNode(
offset_location(location.end_location, 1),
offset_location(location.end_location, import_offset + 1)
),
TOKEN_TYPES["keyword"],
0
))

for i in imports:
logging.debug(f"{i.name}: {i.location}, {i.end_location}")
self.nodes.append((i, TOKEN_TYPES["variable" if i.identifier else "class"], 0))


@rule(AstCall)
def call(self, call: AstCall):
if isinstance(call.value, AstAttribute):
attribute = call.value
offset = len(attribute.name)

call_start = SourceLocation(
attribute.end_location.pos - offset,
attribute.end_location.lineno,
attribute.end_location.colno - offset,
)

function = AstNode(
call_start,
attribute.end_location,
)

base = AstNode(
attribute.location,
call_start
)

self.nodes.append((base, TOKEN_TYPES["variable"], 0))
self.nodes.append((function, TOKEN_TYPES["function"], 0))
else:
self.nodes.append((call.value, TOKEN_TYPES["function"], 0))


@rule(AstItemSlot)
def item_slot(self, item_slot: AstItemSlot):
self.nodes.append((item_slot, TOKEN_TYPES["variable"], TOKEN_MODIFIERS["readonly"]))

@rule(AstResourceLocation)
def resource_location(self, resource_location: AstResourceLocation):
self.nodes.append((resource_location, TOKEN_TYPES["function"], 0))

@rule(AstFunctionSignature)
def function_signature(
self, signature: AstFunctionSignature
):
location: SourceLocation = signature.location
node = AstNode(
location=location,
end_location=offset_location(signature.location, len(signature.name))
)
self.nodes.append((node, TOKEN_TYPES["function"], 0))

nodes.append((assignment.target, TOKEN_TYPES["variable"], 0))
@rule(AstAssignment)
def assignment(self, assignment: AstAssignment):
operator = assignment.operator

if assignment.type_annotation != None:
nodes.append((assignment.type_annotation, TOKEN_TYPES["class"], 0))
nodes.append((assignment.target, TOKEN_TYPES["variable"], 0))

if assignment.type_annotation != None:
nodes.append((assignment.type_annotation, TOKEN_TYPES["class"], 0))

def handle_from_import(
nodes: list[tuple[AstNode, int, int]], from_import: AstFromImport
):
if isinstance(from_import, AstPrelude):
return

logging.debug(from_import)

location: AstResourceLocation = from_import.arguments[0]
imports: tuple[AstImportedItem] = from_import.arguments[1:]
def walk(self, root: AstNode):
self.nodes = []
self.__call__(root)

nodes.append(
(
location,
TOKEN_TYPES["class" if location.namespace == None else "function"],
0,
)
)
tokens: list[tuple[int, ...]] = []

import_offset = len("import")
nodes.append((
AstNode(
offset_location(location.end_location, 1),
offset_location(location.end_location, import_offset + 1)
),
TOKEN_TYPES["keyword"],
0
))
for i in range(len(self.nodes)):
prev_node = None
if i > 0:
prev_node = self.nodes[i - 1][0]

logging.debug(f"{nodes[-1][0].location}, {nodes[-1][0].end_location}")
node, type, modifier = self.nodes[i]
tokens.append(node_to_token(node, type, modifier, prev_node))

for i in imports:
logging.debug(f"{i.name}: {i.location}, {i.end_location}")
nodes.append((i, TOKEN_TYPES["variable" if i.identifier else "class"], 0))
return list(sum(tokens, ()))


def semantic_tokens(ls: MechaLanguageServer, params: lsp.SemanticTokensParams):
Expand All @@ -240,6 +228,6 @@ def semantic_tokens(ls: MechaLanguageServer, params: lsp.SemanticTokensParams):
compiled_doc = get_compilation_data(ls, ctx, text_doc)
ast = compiled_doc.ast

data = walk(ast) if ast else []
data = SemanticTokenCollector(ctx=ctx).walk(ast) if ast else []

return lsp.SemanticTokens(data=data)

0 comments on commit 4fbeced

Please sign in to comment.