1. package chroma
    
  2. 
    
import (
	"compress/gzip"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"math"
	"path/filepath"
	"reflect"
	"regexp"
	"sort"
	"strings"

	"github.com/dlclark/regexp2"
)
    
  18. 
    
  19. // Serialisation of Chroma rules to XML. The format is:
    
  20. //
    
  21. //	<rules>
    
  22. //	  <state name="$STATE">
    
  23. //	    <rule [pattern="$PATTERN"]>
    
  24. //	      [<$EMITTER ...>]
    
  25. //	      [<$MUTATOR ...>]
    
  26. //	    </rule>
    
  27. //	  </state>
    
  28. //	</rules>
    
  29. //
    
  30. // eg. Include("String") would become:
    
  31. //
    
  32. //	<rule>
    
  33. //	  <include state="String" />
    
  34. //	</rule>
    
  35. //
    
  36. //	[null, null, {"kind": "include", "state": "String"}]
    
  37. //
    
  38. // eg. Rule{`\d+`, Text, nil} would become:
    
  39. //
    
  40. //	<rule pattern="\\d+">
    
  41. //	  <token type="Text"/>
    
  42. //	</rule>
    
  43. //
    
  44. // eg. Rule{`"`, String, Push("String")}
    
  45. //
    
  46. //	<rule pattern="\"">
    
  47. //	  <token type="String" />
    
  48. //	  <push state="String" />
    
  49. //	</rule>
    
  50. //
    
  51. // eg. Rule{`(\w+)(\n)`, ByGroups(Keyword, Whitespace), nil},
    
  52. //
    
//	<rule pattern="(\\w+)(\\n)">
//	  <bygroups>
//	    <token type="Keyword" />
//	    <token type="Whitespace" />
//	  </bygroups>
//	</rule>
    
  57. var (
    
  58. 	// ErrNotSerialisable is returned if a lexer contains Rules that cannot be serialised.
    
  59. 	ErrNotSerialisable = fmt.Errorf("not serialisable")
    
  60. 	emitterTemplates   = func() map[string]SerialisableEmitter {
    
  61. 		out := map[string]SerialisableEmitter{}
    
  62. 		for _, emitter := range []SerialisableEmitter{
    
  63. 			&byGroupsEmitter{},
    
  64. 			&usingSelfEmitter{},
    
  65. 			TokenType(0),
    
  66. 			&usingEmitter{},
    
  67. 			&usingByGroup{},
    
  68. 		} {
    
  69. 			out[emitter.EmitterKind()] = emitter
    
  70. 		}
    
  71. 		return out
    
  72. 	}()
    
  73. 	mutatorTemplates = func() map[string]SerialisableMutator {
    
  74. 		out := map[string]SerialisableMutator{}
    
  75. 		for _, mutator := range []SerialisableMutator{
    
  76. 			&includeMutator{},
    
  77. 			&combinedMutator{},
    
  78. 			&multiMutator{},
    
  79. 			&pushMutator{},
    
  80. 			&popMutator{},
    
  81. 		} {
    
  82. 			out[mutator.MutatorKind()] = mutator
    
  83. 		}
    
  84. 		return out
    
  85. 	}()
    
  86. )
    
  87. 
    
  88. // fastUnmarshalConfig unmarshals only the Config from a serialised lexer.
    
  89. func fastUnmarshalConfig(from fs.FS, path string) (*Config, error) {
    
  90. 	r, err := from.Open(path)
    
  91. 	if err != nil {
    
  92. 		return nil, err
    
  93. 	}
    
  94. 	defer r.Close()
    
  95. 	dec := xml.NewDecoder(r)
    
  96. 	for {
    
  97. 		token, err := dec.Token()
    
  98. 		if err != nil {
    
  99. 			if errors.Is(err, io.EOF) {
    
  100. 				return nil, fmt.Errorf("could not find <config> element")
    
  101. 			}
    
  102. 			return nil, err
    
  103. 		}
    
  104. 		switch se := token.(type) {
    
  105. 		case xml.StartElement:
    
  106. 			if se.Name.Local != "config" {
    
  107. 				break
    
  108. 			}
    
  109. 
    
  110. 			var config Config
    
  111. 			err = dec.DecodeElement(&config, &se)
    
  112. 			if err != nil {
    
  113. 				return nil, fmt.Errorf("%s: %w", path, err)
    
  114. 			}
    
  115. 			return &config, nil
    
  116. 		}
    
  117. 	}
    
  118. }
    
  119. 
    
  120. // MustNewXMLLexer constructs a new RegexLexer from an XML file or panics.
    
  121. func MustNewXMLLexer(from fs.FS, path string) *RegexLexer {
    
  122. 	lex, err := NewXMLLexer(from, path)
    
  123. 	if err != nil {
    
  124. 		panic(err)
    
  125. 	}
    
  126. 	return lex
    
  127. }
    
  128. 
    
  129. // NewXMLLexer creates a new RegexLexer from a serialised RegexLexer.
    
  130. func NewXMLLexer(from fs.FS, path string) (*RegexLexer, error) {
    
  131. 	config, err := fastUnmarshalConfig(from, path)
    
  132. 	if err != nil {
    
  133. 		return nil, err
    
  134. 	}
    
  135. 
    
  136. 	for _, glob := range append(config.Filenames, config.AliasFilenames...) {
    
  137. 		_, err := filepath.Match(glob, "")
    
  138. 		if err != nil {
    
  139. 			return nil, fmt.Errorf("%s: %q is not a valid glob: %w", config.Name, glob, err)
    
  140. 		}
    
  141. 	}
    
  142. 
    
  143. 	var analyserFn func(string) float32
    
  144. 
    
  145. 	if config.Analyse != nil {
    
  146. 		type regexAnalyse struct {
    
  147. 			re    *regexp2.Regexp
    
  148. 			score float32
    
  149. 		}
    
  150. 
    
  151. 		regexAnalysers := make([]regexAnalyse, 0, len(config.Analyse.Regexes))
    
  152. 
    
  153. 		for _, ra := range config.Analyse.Regexes {
    
  154. 			re, err := regexp2.Compile(ra.Pattern, regexp2.None)
    
  155. 			if err != nil {
    
  156. 				return nil, fmt.Errorf("%s: %q is not a valid analyser regex: %w", config.Name, ra.Pattern, err)
    
  157. 			}
    
  158. 
    
  159. 			regexAnalysers = append(regexAnalysers, regexAnalyse{re, ra.Score})
    
  160. 		}
    
  161. 
    
  162. 		analyserFn = func(text string) float32 {
    
  163. 			var score float32
    
  164. 
    
  165. 			for _, ra := range regexAnalysers {
    
  166. 				ok, err := ra.re.MatchString(text)
    
  167. 				if err != nil {
    
  168. 					return 0
    
  169. 				}
    
  170. 
    
  171. 				if ok && config.Analyse.First {
    
  172. 					return float32(math.Min(float64(ra.score), 1.0))
    
  173. 				}
    
  174. 
    
  175. 				if ok {
    
  176. 					score += ra.score
    
  177. 				}
    
  178. 			}
    
  179. 
    
  180. 			return float32(math.Min(float64(score), 1.0))
    
  181. 		}
    
  182. 	}
    
  183. 
    
  184. 	return &RegexLexer{
    
  185. 		config:   config,
    
  186. 		analyser: analyserFn,
    
  187. 		fetchRulesFunc: func() (Rules, error) {
    
  188. 			var lexer struct {
    
  189. 				Config
    
  190. 				Rules Rules `xml:"rules"`
    
  191. 			}
    
  192. 			// Try to open .xml fallback to .xml.gz
    
  193. 			fr, err := from.Open(path)
    
  194. 			if err != nil {
    
  195. 				if errors.Is(err, fs.ErrNotExist) {
    
  196. 					path += ".gz"
    
  197. 					fr, err = from.Open(path)
    
  198. 					if err != nil {
    
  199. 						return nil, err
    
  200. 					}
    
  201. 				} else {
    
  202. 					return nil, err
    
  203. 				}
    
  204. 			}
    
  205. 			defer fr.Close()
    
  206. 			var r io.Reader = fr
    
  207. 			if strings.HasSuffix(path, ".gz") {
    
  208. 				r, err = gzip.NewReader(r)
    
  209. 				if err != nil {
    
  210. 					return nil, fmt.Errorf("%s: %w", path, err)
    
  211. 				}
    
  212. 			}
    
  213. 			err = xml.NewDecoder(r).Decode(&lexer)
    
  214. 			if err != nil {
    
  215. 				return nil, fmt.Errorf("%s: %w", path, err)
    
  216. 			}
    
  217. 			return lexer.Rules, nil
    
  218. 		},
    
  219. 	}, nil
    
  220. }
    
  221. 
    
  222. // Marshal a RegexLexer to XML.
    
  223. func Marshal(l *RegexLexer) ([]byte, error) {
    
  224. 	type lexer struct {
    
  225. 		Config Config `xml:"config"`
    
  226. 		Rules  Rules  `xml:"rules"`
    
  227. 	}
    
  228. 
    
  229. 	rules, err := l.Rules()
    
  230. 	if err != nil {
    
  231. 		return nil, err
    
  232. 	}
    
  233. 	root := &lexer{
    
  234. 		Config: *l.Config(),
    
  235. 		Rules:  rules,
    
  236. 	}
    
  237. 	data, err := xml.MarshalIndent(root, "", "  ")
    
  238. 	if err != nil {
    
  239. 		return nil, err
    
  240. 	}
    
  241. 	re := regexp.MustCompile(`></[a-zA-Z]+>`)
    
  242. 	data = re.ReplaceAll(data, []byte(`/>`))
    
  243. 	return data, nil
    
  244. }
    
  245. 
    
  246. // Unmarshal a RegexLexer from XML.
    
  247. func Unmarshal(data []byte) (*RegexLexer, error) {
    
  248. 	type lexer struct {
    
  249. 		Config Config `xml:"config"`
    
  250. 		Rules  Rules  `xml:"rules"`
    
  251. 	}
    
  252. 	root := &lexer{}
    
  253. 	err := xml.Unmarshal(data, root)
    
  254. 	if err != nil {
    
  255. 		return nil, fmt.Errorf("invalid Lexer XML: %w", err)
    
  256. 	}
    
  257. 	lex, err := NewLexer(&root.Config, func() Rules { return root.Rules })
    
  258. 	if err != nil {
    
  259. 		return nil, err
    
  260. 	}
    
  261. 	return lex, nil
    
  262. }
    
  263. 
    
  264. func marshalMutator(e *xml.Encoder, mutator Mutator) error {
    
  265. 	if mutator == nil {
    
  266. 		return nil
    
  267. 	}
    
  268. 	smutator, ok := mutator.(SerialisableMutator)
    
  269. 	if !ok {
    
  270. 		return fmt.Errorf("unsupported mutator: %w", ErrNotSerialisable)
    
  271. 	}
    
  272. 	return e.EncodeElement(mutator, xml.StartElement{Name: xml.Name{Local: smutator.MutatorKind()}})
    
  273. }
    
  274. 
    
  275. func unmarshalMutator(d *xml.Decoder, start xml.StartElement) (Mutator, error) {
    
  276. 	kind := start.Name.Local
    
  277. 	mutator, ok := mutatorTemplates[kind]
    
  278. 	if !ok {
    
  279. 		return nil, fmt.Errorf("unknown mutator %q: %w", kind, ErrNotSerialisable)
    
  280. 	}
    
  281. 	value, target := newFromTemplate(mutator)
    
  282. 	if err := d.DecodeElement(target, &start); err != nil {
    
  283. 		return nil, err
    
  284. 	}
    
  285. 	return value().(SerialisableMutator), nil
    
  286. }
    
  287. 
    
  288. func marshalEmitter(e *xml.Encoder, emitter Emitter) error {
    
  289. 	if emitter == nil {
    
  290. 		return nil
    
  291. 	}
    
  292. 	semitter, ok := emitter.(SerialisableEmitter)
    
  293. 	if !ok {
    
  294. 		return fmt.Errorf("unsupported emitter %T: %w", emitter, ErrNotSerialisable)
    
  295. 	}
    
  296. 	return e.EncodeElement(emitter, xml.StartElement{
    
  297. 		Name: xml.Name{Local: semitter.EmitterKind()},
    
  298. 	})
    
  299. }
    
  300. 
    
  301. func unmarshalEmitter(d *xml.Decoder, start xml.StartElement) (Emitter, error) {
    
  302. 	kind := start.Name.Local
    
  303. 	mutator, ok := emitterTemplates[kind]
    
  304. 	if !ok {
    
  305. 		return nil, fmt.Errorf("unknown emitter %q: %w", kind, ErrNotSerialisable)
    
  306. 	}
    
  307. 	value, target := newFromTemplate(mutator)
    
  308. 	if err := d.DecodeElement(target, &start); err != nil {
    
  309. 		return nil, err
    
  310. 	}
    
  311. 	return value().(SerialisableEmitter), nil
    
  312. }
    
  313. 
    
  314. func (r Rule) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
    
  315. 	start := xml.StartElement{
    
  316. 		Name: xml.Name{Local: "rule"},
    
  317. 	}
    
  318. 	if r.Pattern != "" {
    
  319. 		start.Attr = append(start.Attr, xml.Attr{
    
  320. 			Name:  xml.Name{Local: "pattern"},
    
  321. 			Value: r.Pattern,
    
  322. 		})
    
  323. 	}
    
  324. 	if err := e.EncodeToken(start); err != nil {
    
  325. 		return err
    
  326. 	}
    
  327. 	if err := marshalEmitter(e, r.Type); err != nil {
    
  328. 		return err
    
  329. 	}
    
  330. 	if err := marshalMutator(e, r.Mutator); err != nil {
    
  331. 		return err
    
  332. 	}
    
  333. 	return e.EncodeToken(xml.EndElement{Name: start.Name})
    
  334. }
    
  335. 
    
  336. func (r *Rule) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
    
  337. 	for _, attr := range start.Attr {
    
  338. 		if attr.Name.Local == "pattern" {
    
  339. 			r.Pattern = attr.Value
    
  340. 			break
    
  341. 		}
    
  342. 	}
    
  343. 	for {
    
  344. 		token, err := d.Token()
    
  345. 		if err != nil {
    
  346. 			return err
    
  347. 		}
    
  348. 		switch token := token.(type) {
    
  349. 		case xml.StartElement:
    
  350. 			mutator, err := unmarshalMutator(d, token)
    
  351. 			if err != nil && !errors.Is(err, ErrNotSerialisable) {
    
  352. 				return err
    
  353. 			} else if err == nil {
    
  354. 				if r.Mutator != nil {
    
  355. 					return fmt.Errorf("duplicate mutator")
    
  356. 				}
    
  357. 				r.Mutator = mutator
    
  358. 				continue
    
  359. 			}
    
  360. 			emitter, err := unmarshalEmitter(d, token)
    
  361. 			if err != nil && !errors.Is(err, ErrNotSerialisable) { // nolint: gocritic
    
  362. 				return err
    
  363. 			} else if err == nil {
    
  364. 				if r.Type != nil {
    
  365. 					return fmt.Errorf("duplicate emitter")
    
  366. 				}
    
  367. 				r.Type = emitter
    
  368. 				continue
    
  369. 			} else {
    
  370. 				return err
    
  371. 			}
    
  372. 
    
  373. 		case xml.EndElement:
    
  374. 			return nil
    
  375. 		}
    
  376. 	}
    
  377. }
    
  378. 
    
  379. type xmlRuleState struct {
    
  380. 	Name  string `xml:"name,attr"`
    
  381. 	Rules []Rule `xml:"rule"`
    
  382. }
    
  383. 
    
  384. type xmlRules struct {
    
  385. 	States []xmlRuleState `xml:"state"`
    
  386. }
    
  387. 
    
  388. func (r Rules) MarshalXML(e *xml.Encoder, _ xml.StartElement) error {
    
  389. 	xr := xmlRules{}
    
  390. 	for state, rules := range r {
    
  391. 		xr.States = append(xr.States, xmlRuleState{
    
  392. 			Name:  state,
    
  393. 			Rules: rules,
    
  394. 		})
    
  395. 	}
    
  396. 	return e.EncodeElement(xr, xml.StartElement{Name: xml.Name{Local: "rules"}})
    
  397. }
    
  398. 
    
  399. func (r *Rules) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
    
  400. 	xr := xmlRules{}
    
  401. 	if err := d.DecodeElement(&xr, &start); err != nil {
    
  402. 		return err
    
  403. 	}
    
  404. 	if *r == nil {
    
  405. 		*r = Rules{}
    
  406. 	}
    
  407. 	for _, state := range xr.States {
    
  408. 		(*r)[state.Name] = state.Rules
    
  409. 	}
    
  410. 	return nil
    
  411. }
    
  412. 
    
  413. type xmlTokenType struct {
    
  414. 	Type string `xml:"type,attr"`
    
  415. }
    
  416. 
    
  417. func (t *TokenType) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
    
  418. 	el := xmlTokenType{}
    
  419. 	if err := d.DecodeElement(&el, &start); err != nil {
    
  420. 		return err
    
  421. 	}
    
  422. 	tt, err := TokenTypeString(el.Type)
    
  423. 	if err != nil {
    
  424. 		return err
    
  425. 	}
    
  426. 	*t = tt
    
  427. 	return nil
    
  428. }
    
  429. 
    
  430. func (t TokenType) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
    
  431. 	start.Attr = append(start.Attr, xml.Attr{Name: xml.Name{Local: "type"}, Value: t.String()})
    
  432. 	if err := e.EncodeToken(start); err != nil {
    
  433. 		return err
    
  434. 	}
    
  435. 	return e.EncodeToken(xml.EndElement{Name: start.Name})
    
  436. }
    
  437. 
    
  438. // This hijinks is a bit unfortunate but without it we can't deserialise into TokenType.
    
  439. func newFromTemplate(template interface{}) (value func() interface{}, target interface{}) {
    
  440. 	t := reflect.TypeOf(template)
    
  441. 	if t.Kind() == reflect.Ptr {
    
  442. 		v := reflect.New(t.Elem())
    
  443. 		return v.Interface, v.Interface()
    
  444. 	}
    
  445. 	v := reflect.New(t)
    
  446. 	return func() interface{} { return v.Elem().Interface() }, v.Interface()
    
  447. }
    
  448. 
    
  449. func (b *Emitters) UnmarshalXML(d *xml.Decoder, start xml.StartElement) error {
    
  450. 	for {
    
  451. 		token, err := d.Token()
    
  452. 		if err != nil {
    
  453. 			return err
    
  454. 		}
    
  455. 		switch token := token.(type) {
    
  456. 		case xml.StartElement:
    
  457. 			emitter, err := unmarshalEmitter(d, token)
    
  458. 			if err != nil {
    
  459. 				return err
    
  460. 			}
    
  461. 			*b = append(*b, emitter)
    
  462. 
    
  463. 		case xml.EndElement:
    
  464. 			return nil
    
  465. 		}
    
  466. 	}
    
  467. }
    
  468. 
    
  469. func (b Emitters) MarshalXML(e *xml.Encoder, start xml.StartElement) error {
    
  470. 	if err := e.EncodeToken(start); err != nil {
    
  471. 		return err
    
  472. 	}
    
  473. 	for _, m := range b {
    
  474. 		if err := marshalEmitter(e, m); err != nil {
    
  475. 			return err
    
  476. 		}
    
  477. 	}
    
  478. 	return e.EncodeToken(xml.EndElement{Name: start.Name})
    
  479. }